Fix INF-375 get default mask strip failed
[pti/o2.git] / o2common / views / route.py
index b2d537d..9776458 100644 (file)
@@ -30,6 +30,8 @@ from flask_restx.model import Model
 from flask_restx.fields import List, Nested, String
 from flask_restx.utils import unpack
 
+from o2common.views.route_exception import BadRequestException
+
 from o2common.helper import o2logging
 logger = o2logging.get_logger(__name__)
 
@@ -95,9 +97,9 @@ class o2_marshal_with(marshal_with):
             resp = f(*args, **kwargs)
 
             req_args = request.args
-            mask = self._gen_mask_from_filter(**req_args)
-
-            # mask = self.mask
+            mask = self._gen_mask_from_selector(**req_args)
+            if mask == '':
+                mask = self.mask
 
             # if has_request_context():
             # mask_header = current_app.config["RESTX_MASK_HEADER"]
@@ -124,48 +126,34 @@ class o2_marshal_with(marshal_with):
 
         return wrapper
 
-    def _gen_mask_from_filter(self, **kwargs) -> str:
+    def _gen_mask_from_selector(self, **kwargs) -> str:
         mask_val = ''
         if 'all_fields' in kwargs:
-            all_fields_without_space = kwargs['all_fields'].replace(" ", "")
-            all_fields = all_fields_without_space.lower()
-            if 'true' == all_fields:
-                mask_val = ''
+            all_fields_without_space = kwargs['all_fields'].strip()
+            logger.debug('all_fields selector value is {}'.format(
+                all_fields_without_space))
+            selector = self.__gen_selector_from_model_with_value(
+                self.fields)
+            mask_val = self.__gen_mask_from_selector(selector)
 
         elif 'fields' in kwargs and kwargs['fields'] != '':
-            fields_without_space = kwargs['fields'].replace(" ", "")
-
-            # filters = fields_without_space.split(',')
-
-            # mask_val_list = []
-            # for f in filters:
-            #     if '/' in f:
-            #         a = self.__gen_mask_tree(f)
-            #         mask_val_list.append(a)
-            #         continue
-            #     mask_val_list.append(f)
-            # mask_val = '{%s}' % ','.join(mask_val_list)
-            default_fields = {}
-
-            self.__update_filter_value(
-                default_fields, fields_without_space, True)
-
-            mask_val = self.__gen_mask_from_filter_tree(default_fields)
+            fields_without_space = kwargs['fields'].strip()
+            selector = {}
+            self.__update_selector_value(selector, fields_without_space, True)
+            self.__set_default_mask(selector)
+            mask_val = self.__gen_mask_from_selector(selector)
 
         elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
-            exclude_fields_without_space = kwargs['exclude_fields'].replace(
-                " ", "")
-
-            default_fields = self.__gen_filter_tree_from_model_with_value(
+            exclude_fields_without_space = kwargs['exclude_fields'].strip()
+            selector = self.__gen_selector_from_model_with_value(
                 self.fields)
+            self.__update_selector_value(
+                selector, exclude_fields_without_space, False)
+            self.__set_default_mask(selector)
+            mask_val = self.__gen_mask_from_selector(selector)
 
-            self.__update_filter_value(
-                default_fields, exclude_fields_without_space, False)
-
-            mask_val = self.__gen_mask_from_filter_tree(default_fields)
         elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
-            exclude_default_without_space = kwargs['exclude_default'].replace(
-                " ", "")
+            exclude_default_without_space = kwargs['exclude_default'].strip()
             exclude_default = exclude_default_without_space.lower()
             if 'true' == exclude_default:
                 mask_val = '{}'
@@ -184,50 +172,65 @@ class o2_marshal_with(marshal_with):
         else:
             return '{%s}' % f[0]
 
-    def __gen_filter_tree_from_model_with_value(
+    def __gen_selector_from_model_with_value(
             self, model: Model, default_val: bool = True) -> dict:
-        filter = dict()
+        selector = dict()
         for i in model:
             if type(model[i]) is List:
                 if type(model[i].container) is String:
-                    filter[i] = default_val
+                    selector[i] = default_val
                     continue
-                filter[i] = self.__gen_filter_tree_from_model_with_value(
+                selector[i] = self.__gen_selector_from_model_with_value(
                     model[i].container.model, default_val)
                 continue
             elif type(model[i]) is Nested:
-                filter[i] = self.__gen_filter_tree_from_model_with_value(
+                selector[i] = self.__gen_selector_from_model_with_value(
                     model[i].model, default_val)
-            filter[i] = default_val
-        return filter
+            selector[i] = default_val
+        return selector
 
-    def __update_filter_value(self, default_fields: dict, filter: str,
-                              val: bool):
+    def __update_selector_value(self, selector: dict, filter: str,
+                                val: bool):
         fields = filter.split(',')
         for f in fields:
+            f = f.strip()
             if '/' in f:
-                self.__update_filter_tree_value(default_fields, f, val)
+                self.__update_selector_tree_value(selector, f, val)
                 continue
-            default_fields[f] = val
+            if f not in self.fields:
+                raise BadRequestException(
+                    'Selector attribute {} not found'.format(f))
+            selector[f] = val
 
-    def __update_filter_tree_value(self, m: dict, filter: str, val: bool):
+    def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
         filter_list = filter.split('/', 1)
         if filter_list[0] not in m:
             m[filter_list[0]] = dict()
         if len(filter_list) > 1:
-            self.__update_filter_tree_value(
+            self.__update_selector_tree_value(
                 m[filter_list[0]], filter_list[1], val)
             return
         m[filter_list[0]] = val
 
-    def __gen_mask_from_filter_tree(self, fields: dict) -> str:
+    def __gen_mask_from_selector(self, fields: dict) -> str:
         mask_li = list()
         for k, v in fields.items():
             if type(v) is dict:
-                s = self.__gen_mask_from_filter_tree(v)
+                s = self.__gen_mask_from_selector(v)
                 mask_li.append('%s%s' % (k, s))
                 continue
             if v:
                 mask_li.append(k)
 
         return '{%s}' % ','.join(mask_li)
+
+    def __set_default_mask(self, selector: dict, val: bool = True):
+        mask = getattr(self.fields, "__mask__")
+        if not mask:
+            selector_all = self.__gen_selector_from_model_with_value(
+                self.fields)
+            for s in selector_all:
+                selector[s] = val
+            return
+        default_selector = str(mask).strip(' {}')
+        self.__update_selector_value(selector, default_selector, val)