X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=o2common%2Fviews%2Froute.py;h=f152fc06c4485a8dcd076d66cc7d4a05ddeb13b6;hb=58994b7d851b47456eed1820d36cc06803777e3b;hp=832e38c8ac10ef6dcf09ebd9b49af94a914daf80;hpb=5e0dacb10819977ef6a452257346f72592cff374;p=pti%2Fo2.git diff --git a/o2common/views/route.py b/o2common/views/route.py index 832e38c..f152fc0 100644 --- a/o2common/views/route.py +++ b/o2common/views/route.py @@ -29,7 +29,8 @@ from flask_restx.mask import Mask # , apply as apply_mask from flask_restx.model import Model from flask_restx.fields import List, Nested, String from flask_restx.utils import unpack -from o2common.domain.base import Serializer + +from o2common.views.route_exception import BadRequestException from o2common.helper import o2logging logger = o2logging.get_logger(__name__) @@ -97,8 +98,8 @@ class o2_marshal_with(marshal_with): req_args = request.args mask = self._gen_mask_from_selector(**req_args) - - # mask = self.mask + if mask == '': + mask = self.mask # if has_request_context(): # mask_header = current_app.config["RESTX_MASK_HEADER"] @@ -128,44 +129,31 @@ class o2_marshal_with(marshal_with): def _gen_mask_from_selector(self, **kwargs) -> str: mask_val = '' if 'all_fields' in kwargs: - all_fields_without_space = kwargs['all_fields'].replace(" ", "") - all_fields = all_fields_without_space.lower() - if 'true' == all_fields: - mask_val = '' + all_fields_without_space = kwargs['all_fields'].strip() + logger.debug('all_fields selector value is {}'.format( + all_fields_without_space)) + selector = self.__gen_selector_from_model_with_value( + self.fields) + mask_val = self.__gen_mask_from_selector(selector) elif 'fields' in kwargs and kwargs['fields'] != '': - fields_without_space = kwargs['fields'].replace(" ", "") - - # filters = fields_without_space.split(',') - - # mask_val_list = [] - # for f in filters: - # if '/' in f: - # a = self.__gen_mask_tree(f) - # mask_val_list.append(a) - # continue - # mask_val_list.append(f) - # mask_val = '{%s}' % ','.join(mask_val_list) + fields_without_space = kwargs['fields'].strip() selector = {} - self.__update_selector_value(selector, fields_without_space, True) - + self.__set_default_mask(selector) mask_val = self.__gen_mask_from_selector(selector) elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '': - exclude_fields_without_space = kwargs['exclude_fields'].replace( - " ", "") - + exclude_fields_without_space = kwargs['exclude_fields'].strip() selector = self.__gen_selector_from_model_with_value( self.fields) - self.__update_selector_value( selector, exclude_fields_without_space, False) - + self.__set_default_mask(selector) mask_val = self.__gen_mask_from_selector(selector) + elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '': - exclude_default_without_space = kwargs['exclude_default'].replace( - " ", "") + exclude_default_without_space = kwargs['exclude_default'].strip() exclude_default = exclude_default_without_space.lower() if 'true' == exclude_default: mask_val = '{}' @@ -201,14 +189,21 @@ class o2_marshal_with(marshal_with): selector[i] = default_val return selector - def __update_selector_value(self, default_selector: dict, filter: str, + def __update_selector_value(self, selector: dict, filter: str, val: bool): fields = filter.split(',') for f in fields: + f = f.strip() if '/' in f: - self.__update_selector_tree_value(default_selector, f, val) + parent = f.split('/')[0] + if parent in selector and type(selector[parent]) is bool: + selector[parent] = dict() + self.__update_selector_tree_value(selector, f, val) continue - default_selector[f] = val + if f not in self.fields: + raise BadRequestException( + 'Selector attribute {} not found'.format(f)) + selector[f] = val def __update_selector_tree_value(self, m: dict, filter: str, val: bool): filter_list = filter.split('/', 1) @@ -225,6 +220,8 @@ class o2_marshal_with(marshal_with): for k, v in fields.items(): if type(v) is dict: s = self.__gen_mask_from_selector(v) + if s == '{}': + continue mask_li.append('%s%s' % (k, s)) continue if v: @@ -232,30 +229,40 @@ class o2_marshal_with(marshal_with): return '{%s}' % ','.join(mask_li) - -class ProblemDetails(Serializer): - def __init__(self, namespace: O2Namespace, code: int, detail: str, - title=None, instance=None - ) -> None: - self.ns = namespace - self.status = code - self.detail = detail - self.type = request.path - self.title = title if title is not None else self.getTitle(code) - self.instance = instance if instance is not None else [] - - def getTitle(self, code): - return HTTPStatus(code).phrase - - def abort(self): - self.ns.abort(self.status, self.detail, **self.serialize()) - - def serialize(self): - details = {} - for key in dir(self): - if key == 'ns' or key.startswith('__') or\ - callable(getattr(self, key)): - continue - else: - details[key] = getattr(self, key) - return details + def __set_default_mask(self, selector: dict, val: bool = True): + def convert_mask(mask): + # convert mask from {aa,bb,xxx{yyy}} structure to aa,bbxxx/yyy + stack = [] + result = [] + word = '' + for ch in mask: + if ch == '{': + if word: + stack.append(word) + word = '' + elif ch == '}': + if word: + result.append('/'.join(stack + [word])) + word = '' + if stack: + stack.pop() + elif ch == ',': + if word: + result.append('/'.join(stack + [word])) + word = '' + else: + word += ch + if word: + result.append(word) + return ','.join(result) + + mask = getattr(self.fields, "__mask__") + mask = convert_mask(str(mask)) + if not mask: + selector_all = self.__gen_selector_from_model_with_value( + self.fields) + for s in selector_all: + selector[s] = val + return + default_selector = mask + self.__update_selector_value(selector, default_selector, val)