+# Copyright (C) 2021-2022 Wind River Systems, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+# from collections import OrderedDict
+from functools import wraps
+# from six import iteritems
+
+from flask import request
+
+from flask_restx import Namespace
+from flask_restx._http import HTTPStatus
+from flask_restx.marshalling import marshal_with, marshal
+from flask_restx.utils import merge
+from flask_restx.mask import Mask # , apply as apply_mask
+from flask_restx.model import Model
+from flask_restx.fields import List, Nested, String
+from flask_restx.utils import unpack
+
+from o2common.helper import o2logging
+logger = o2logging.get_logger(__name__)
+
+
+class O2Namespace(Namespace):
+
+ def __init__(self, name, description=None, path=None, decorators=None,
+ validate=None, authorizations=None, ordered=False, **kwargs):
+ super().__init__(name, description, path, decorators,
+ validate, authorizations, ordered, **kwargs)
+
+ def marshal_with(
+ self, fields, as_list=False, code=HTTPStatus.OK, description=None,
+ **kwargs
+ ):
+ """
+ A decorator specifying the fields to use for serialization.
+
+ :param bool as_list: Indicate that the return type is a list \
+ (for the documentation)
+ :param int code: Optionally give the expected HTTP response \
+ code if its different from 200
+
+ """
+
+ def wrapper(func):
+ doc = {
+ "responses": {
+ str(code): (description, [fields], kwargs)
+ if as_list
+ else (description, fields, kwargs)
+ },
+ "__mask__": kwargs.get(
+ "mask", True
+ ), # Mask values can't be determined outside app context
+ }
+ func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
+ return o2_marshal_with(fields, ordered=self.ordered,
+ **kwargs)(func)
+
+ return wrapper
+
+
+class o2_marshal_with(marshal_with):
+ def __init__(
+ self, fields, envelope=None, skip_none=False, mask=None, ordered=False
+ ):
+ """
+ :param fields: a dict of whose keys will make up the final
+ serialized response output
+ :param envelope: optional key that will be used to envelop the
+ serialized response
+ """
+ self.fields = fields
+ self.envelope = envelope
+ self.skip_none = skip_none
+ self.ordered = ordered
+ self.mask = Mask(mask, skip=True)
+
+ def __call__(self, f):
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ resp = f(*args, **kwargs)
+
+ req_args = request.args
+ mask = self._gen_mask_from_filter(**req_args)
+
+ # mask = self.mask
+
+ # if has_request_context():
+ # mask_header = current_app.config["RESTX_MASK_HEADER"]
+ # mask = request.headers.get(mask_header) or mask
+ if isinstance(resp, tuple):
+ data, code, headers = unpack(resp)
+ return (
+ marshal(
+ data,
+ self.fields,
+ self.envelope,
+ self.skip_none,
+ mask,
+ self.ordered,
+ ),
+ code,
+ headers,
+ )
+ else:
+ return marshal(
+ resp, self.fields, self.envelope, self.skip_none, mask,
+ self.ordered
+ )
+
+ return wrapper
+
+ def _gen_mask_from_filter(self, **kwargs) -> str:
+ mask_val = ''
+ if 'all_fields' in kwargs:
+ all_fields_without_space = kwargs['all_fields'].replace(" ", "")
+ all_fields = all_fields_without_space.lower()
+ if 'true' == all_fields:
+ mask_val = ''
+
+ elif 'fields' in kwargs and kwargs['fields'] != '':
+ fields_without_space = kwargs['fields'].replace(" ", "")
+
+ # filters = fields_without_space.split(',')
+
+ # mask_val_list = []
+ # for f in filters:
+ # if '/' in f:
+ # a = self.__gen_mask_tree(f)
+ # mask_val_list.append(a)
+ # continue
+ # mask_val_list.append(f)
+ # mask_val = '{%s}' % ','.join(mask_val_list)
+ default_fields = {}
+
+ self.__update_filter_value(
+ default_fields, fields_without_space, True)
+
+ mask_val = self.__gen_mask_from_filter_tree(default_fields)
+
+ elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
+ exclude_fields_without_space = kwargs['exclude_fields'].replace(
+ " ", "")
+
+ default_fields = self.__gen_filter_tree_from_model_with_value(
+ self.fields)
+
+ self.__update_filter_value(
+ default_fields, exclude_fields_without_space, False)
+
+ mask_val = self.__gen_mask_from_filter_tree(default_fields)
+ elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
+ exclude_default_without_space = kwargs['exclude_default'].replace(
+ " ", "")
+ exclude_default = exclude_default_without_space.lower()
+ if 'true' == exclude_default:
+ mask_val = '{}'
+
+ else:
+ mask_val = ''
+
+ return mask_val
+
+ def __gen_mask_tree(self, field: str) -> str:
+
+ f = field.split('/', 1)
+ if len(f) > 1:
+ s = self.__gen_mask_tree(f[1])
+ return '%s%s' % (f[0], s)
+ else:
+ return '{%s}' % f[0]
+
+ def __gen_filter_tree_from_model_with_value(
+ self, model: Model, default_val: bool = True) -> dict:
+ filter = dict()
+ for i in model:
+ if type(model[i]) is List:
+ if type(model[i].container) is String:
+ filter[i] = default_val
+ continue
+ filter[i] = self.__gen_filter_tree_from_model_with_value(
+ model[i].container.model, default_val)
+ continue
+ elif type(model[i]) is Nested:
+ filter[i] = self.__gen_filter_tree_from_model_with_value(
+ model[i].model, default_val)
+ filter[i] = default_val
+ return filter
+
+ def __update_filter_value(self, default_fields: dict, filter: str,
+ val: bool):
+ fields = filter.split(',')
+ for f in fields:
+ if '/' in f:
+ self.__update_filter_tree_value(default_fields, f, val)
+ continue
+ default_fields[f] = val
+
+ def __update_filter_tree_value(self, m: dict, filter: str, val: bool):
+ filter_list = filter.split('/', 1)
+ if filter_list[0] not in m:
+ m[filter_list[0]] = dict()
+ if len(filter_list) > 1:
+ self.__update_filter_tree_value(
+ m[filter_list[0]], filter_list[1], val)
+ return
+ m[filter_list[0]] = val
+
+ def __gen_mask_from_filter_tree(self, fields: dict) -> str:
+ mask_li = list()
+ for k, v in fields.items():
+ if type(v) is dict:
+ s = self.__gen_mask_from_filter_tree(v)
+ mask_li.append('%s%s' % (k, s))
+ continue
+ if v:
+ mask_li.append(k)
+
+ return '{%s}' % ','.join(mask_li)