from flask_restx.fields import List, Nested, String
from flask_restx.utils import unpack
+from o2common.views.route_exception import BadRequestException
+
from o2common.helper import o2logging
logger = o2logging.get_logger(__name__)
resp = f(*args, **kwargs)
req_args = request.args
- mask = self._gen_mask_from_filter(**req_args)
-
- # mask = self.mask
+ mask = self._gen_mask_from_selector(**req_args)
+ if mask == '':
+ mask = self.mask
# if has_request_context():
# mask_header = current_app.config["RESTX_MASK_HEADER"]
return wrapper
- def _gen_mask_from_filter(self, **kwargs) -> str:
+ def _gen_mask_from_selector(self, **kwargs) -> str:
mask_val = ''
if 'all_fields' in kwargs:
- all_fields_without_space = kwargs['all_fields'].replace(" ", "")
- all_fields = all_fields_without_space.lower()
- if 'true' == all_fields:
- mask_val = ''
+ all_fields_without_space = kwargs['all_fields'].strip()
+ logger.debug('all_fields selector value is {}'.format(
+ all_fields_without_space))
+ selector = self.__gen_selector_from_model_with_value(
+ self.fields)
+ mask_val = self.__gen_mask_from_selector(selector)
elif 'fields' in kwargs and kwargs['fields'] != '':
- fields_without_space = kwargs['fields'].replace(" ", "")
-
- # filters = fields_without_space.split(',')
-
- # mask_val_list = []
- # for f in filters:
- # if '/' in f:
- # a = self.__gen_mask_tree(f)
- # mask_val_list.append(a)
- # continue
- # mask_val_list.append(f)
- # mask_val = '{%s}' % ','.join(mask_val_list)
- default_fields = {}
-
- self.__update_filter_value(
- default_fields, fields_without_space, True)
-
- mask_val = self.__gen_mask_from_filter_tree(default_fields)
+ fields_without_space = kwargs['fields'].strip()
+ selector = {}
+ self.__update_selector_value(selector, fields_without_space, True)
+ self.__set_default_mask(selector)
+ mask_val = self.__gen_mask_from_selector(selector)
elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
- exclude_fields_without_space = kwargs['exclude_fields'].replace(
- " ", "")
-
- default_fields = self.__gen_filter_tree_from_model_with_value(
+ exclude_fields_without_space = kwargs['exclude_fields'].strip()
+ selector = self.__gen_selector_from_model_with_value(
self.fields)
+ self.__update_selector_value(
+ selector, exclude_fields_without_space, False)
+ self.__set_default_mask(selector)
+ mask_val = self.__gen_mask_from_selector(selector)
- self.__update_filter_value(
- default_fields, exclude_fields_without_space, False)
-
- mask_val = self.__gen_mask_from_filter_tree(default_fields)
elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
- exclude_default_without_space = kwargs['exclude_default'].replace(
- " ", "")
+ exclude_default_without_space = kwargs['exclude_default'].strip()
exclude_default = exclude_default_without_space.lower()
if 'true' == exclude_default:
mask_val = '{}'
else:
return '{%s}' % f[0]
- def __gen_filter_tree_from_model_with_value(
+ def __gen_selector_from_model_with_value(
self, model: Model, default_val: bool = True) -> dict:
- filter = dict()
+ selector = dict()
for i in model:
if type(model[i]) is List:
if type(model[i].container) is String:
- filter[i] = default_val
+ selector[i] = default_val
continue
- filter[i] = self.__gen_filter_tree_from_model_with_value(
+ selector[i] = self.__gen_selector_from_model_with_value(
model[i].container.model, default_val)
continue
elif type(model[i]) is Nested:
- filter[i] = self.__gen_filter_tree_from_model_with_value(
+ selector[i] = self.__gen_selector_from_model_with_value(
model[i].model, default_val)
- filter[i] = default_val
- return filter
+ selector[i] = default_val
+ return selector
- def __update_filter_value(self, default_fields: dict, filter: str,
- val: bool):
+ def __update_selector_value(self, selector: dict, filter: str,
+ val: bool):
fields = filter.split(',')
for f in fields:
+ f = f.strip()
if '/' in f:
- self.__update_filter_tree_value(default_fields, f, val)
+ parent = f.split('/')[0]
+ if parent in selector and type(selector[parent]) is bool:
+ selector[parent] = dict()
+ self.__update_selector_tree_value(selector, f, val)
continue
- default_fields[f] = val
+ if f not in self.fields:
+ raise BadRequestException(
+ 'Selector attribute {} not found'.format(f))
+ selector[f] = val
- def __update_filter_tree_value(self, m: dict, filter: str, val: bool):
+ def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
filter_list = filter.split('/', 1)
if filter_list[0] not in m:
m[filter_list[0]] = dict()
if len(filter_list) > 1:
- self.__update_filter_tree_value(
+ self.__update_selector_tree_value(
m[filter_list[0]], filter_list[1], val)
return
m[filter_list[0]] = val
- def __gen_mask_from_filter_tree(self, fields: dict) -> str:
+ def __gen_mask_from_selector(self, fields: dict) -> str:
mask_li = list()
for k, v in fields.items():
if type(v) is dict:
- s = self.__gen_mask_from_filter_tree(v)
+ s = self.__gen_mask_from_selector(v)
+ if s == '{}':
+ continue
mask_li.append('%s%s' % (k, s))
continue
if v:
mask_li.append(k)
return '{%s}' % ','.join(mask_li)
+
+ def __set_default_mask(self, selector: dict, val: bool = True):
+ def convert_mask(mask):
+ # convert mask from {aa,bb,xxx{yyy}} structure to aa,bbxxx/yyy
+ stack = []
+ result = []
+ word = ''
+ for ch in mask:
+ if ch == '{':
+ if word:
+ stack.append(word)
+ word = ''
+ elif ch == '}':
+ if word:
+ result.append('/'.join(stack + [word]))
+ word = ''
+ if stack:
+ stack.pop()
+ elif ch == ',':
+ if word:
+ result.append('/'.join(stack + [word]))
+ word = ''
+ else:
+ word += ch
+ if word:
+ result.append(word)
+ return ','.join(result)
+
+ mask = getattr(self.fields, "__mask__")
+ mask = convert_mask(str(mask))
+ if not mask:
+ selector_all = self.__gen_selector_from_model_with_value(
+ self.fields)
+ for s in selector_all:
+ selector[s] = val
+ return
+ default_selector = mask
+ self.__update_selector_value(selector, default_selector, val)