1 # Copyright (C) 2021-2022 Wind River Systems, Inc.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
15 # -*- coding: utf-8 -*-
16 from __future__ import unicode_literals
18 # from collections import OrderedDict
19 from functools import wraps
20 # from six import iteritems
22 from flask import request
24 from flask_restx import Namespace
25 from flask_restx._http import HTTPStatus
26 from flask_restx.marshalling import marshal_with, marshal
27 from flask_restx.utils import merge
28 from flask_restx.mask import Mask # , apply as apply_mask
29 from flask_restx.model import Model
30 from flask_restx.fields import List, Nested, String
31 from flask_restx.utils import unpack
33 from o2common.views.route_exception import BadRequestException
35 from o2common.helper import o2logging
# Module-level logger obtained from the project's o2logging helper and named
# after this module; used below to trace selector query values.
36 logger = o2logging.get_logger(__name__)
# flask_restx Namespace whose serialization decorator (below) delegates to
# o2_marshal_with, so the response mask is derived from the request's
# selector query parameters (all_fields, fields, exclude_fields,
# exclude_default).
39 class O2Namespace(Namespace):
# The constructor only forwards its arguments, positionally and unchanged,
# to flask_restx.Namespace.__init__.
41 def __init__(self, name, description=None, path=None, decorators=None,
42 validate=None, authorizations=None, ordered=False, **kwargs):
43 super().__init__(name, description, path, decorators,
44 validate, authorizations, ordered, **kwargs)
# Decorator factory: the returned decorator records the response schema for
# str(code) in the wrapped function's __apidoc__ (as [fields] when as_list
# is set, otherwise fields). It then wraps the function with
# o2_marshal_with, which applies the request-driven selector mask when the
# response is serialized.
47 self, fields, as_list=False, code=HTTPStatus.OK, description=None,
51 A decorator specifying the fields to use for serialization.
53 :param bool as_list: Indicate that the return type is a list \
54 (for the documentation)
55 :param int code: Optionally give the expected HTTP response \
56 code if its different from 200
63 str(code): (description, [fields], kwargs)
65 else (description, fields, kwargs)
67 "__mask__": kwargs.get(
69 ), # Mask values can't be determined outside app context
# Merge into any __apidoc__ already attached by other decorators instead of
# replacing it.
71 func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
# Serialize through o2_marshal_with, propagating this namespace's `ordered`
# setting.
72 return o2_marshal_with(fields, ordered=self.ordered,
# flask_restx marshal_with variant whose mask is computed per request from
# the query-string selectors (all_fields, fields, exclude_fields,
# exclude_default); see _gen_mask_from_selector.
78 class o2_marshal_with(marshal_with):
80 self, fields, envelope=None, skip_none=False, mask=None, ordered=False
83 :param fields: a dict of whose keys will make up the final
84 serialized response output
85 :param envelope: optional key that will be used to envelop the
89 self.envelope = envelope
90 self.skip_none = skip_none
91 self.ordered = ordered
# skip=True is forwarded to flask_restx's Mask. NOTE(review): presumably
# this makes masked fields that are missing from the data be skipped rather
# than reported; confirm against the flask_restx Mask documentation.
92 self.mask = Mask(mask, skip=True)
# Wrap the view function `f`. Its return value is serialized with this
# instance's fields, envelope and skip_none settings, using a mask built
# from the current request's query arguments.
94 def __call__(self, f):
96 def wrapper(*args, **kwargs):
97 resp = f(*args, **kwargs)
# The selector query parameters (e.g. ?fields=a,b/c) determine the mask.
# NOTE(review): request.args is a werkzeug MultiDict. Unpacking it with **
# passes only the first value of a repeated parameter.
99 req_args = request.args
100 mask = self._gen_mask_from_selector(**req_args)
# flask_restx's header-based mask lookup is kept below only as disabled
# (commented-out) code.
104 # if has_request_context():
105 # mask_header = current_app.config["RESTX_MASK_HEADER"]
106 # mask = request.headers.get(mask_header) or mask
# A tuple return value is split by flask_restx's unpack() into
# (data, code, headers).
107 if isinstance(resp, tuple):
108 data, code, headers = unpack(resp)
123 resp, self.fields, self.envelope, self.skip_none, mask,
# Translate the selector query parameters into a flask_restx mask string of
# the form '{a,b{c}}'. Parameters are honoured in priority order: all_fields,
# then fields, then exclude_fields, then exclude_default. Only the first one
# present takes effect (elif chain). For the last three, an empty value is
# ignored.
# NOTE(review): emptiness is tested before strip(), so a whitespace-only
# value is not treated as empty.
129 def _gen_mask_from_selector(self, **kwargs) -> str:
# all_fields: the value itself is only logged. The selector is generated
# from the model (presumably self.fields).
131 if 'all_fields' in kwargs:
132 all_fields_without_space = kwargs['all_fields'].strip()
133 logger.debug('all_fields selector value is {}'.format(
134 all_fields_without_space))
135 selector = self.__gen_selector_from_model_with_value(
137 mask_val = self.__gen_mask_from_selector(selector)
# fields: mark the listed fields True, then add the model's default-mask
# fields.
139 elif 'fields' in kwargs and kwargs['fields'] != '':
140 fields_without_space = kwargs['fields'].strip()
142 self.__update_selector_value(selector, fields_without_space, True)
143 self.__set_default_mask(selector)
144 mask_val = self.__gen_mask_from_selector(selector)
# exclude_fields: start from a selector over the model and mark the listed
# fields False. The default mask is then re-applied with val=True, which
# can switch default-mask fields back on.
146 elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
147 exclude_fields_without_space = kwargs['exclude_fields'].strip()
148 selector = self.__gen_selector_from_model_with_value(
150 self.__update_selector_value(
151 selector, exclude_fields_without_space, False)
152 self.__set_default_mask(selector)
153 mask_val = self.__gen_mask_from_selector(selector)
# exclude_default: only a case-insensitive 'true' (surrounding whitespace
# ignored) is acted upon.
155 elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
156 exclude_default_without_space = kwargs['exclude_default'].strip()
157 exclude_default = exclude_default_without_space.lower()
158 if 'true' == exclude_default:
# Recursively convert a '/'-separated path such as 'a/b/c' into nested mask
# syntax. The first segment is prefixed to the mask built for the remainder.
# NOTE(review): apart from its own recursion, this helper appears to have no
# caller in this module. It may be dead code; confirm before relying on it.
166 def __gen_mask_tree(self, field: str) -> str:
168 f = field.split('/', 1)
170 s = self.__gen_mask_tree(f[1])
171 return '%s%s' % (f[0], s)
# Build a selector dict mirroring `model`, where each field maps to
# default_val. List fields whose container is a String are leaves; List
# fields of a nested model and Nested fields are recursed into.
# The exact type checks (type(...) is ...) mean subclasses of List, Nested
# or String fall through to the plain-leaf case.
# NOTE(review): for a Nested field, the recursively built sub-selector is
# immediately overwritten by `selector[i] = default_val`, so Nested fields
# end up as plain default_val. Confirm whether a `continue` is missing
# after the Nested branch.
175 def __gen_selector_from_model_with_value(
176 self, model: Model, default_val: bool = True) -> dict:
179 if type(model[i]) is List:
180 if type(model[i].container) is String:
181 selector[i] = default_val
183 selector[i] = self.__gen_selector_from_model_with_value(
184 model[i].container.model, default_val)
186 elif type(model[i]) is Nested:
187 selector[i] = self.__gen_selector_from_model_with_value(
188 model[i].model, default_val)
189 selector[i] = default_val
# Set `val` in `selector` (mutated in place) for every comma-separated entry
# of `filter`. Path entries ('parent/child') are written into nested dicts
# via __update_selector_tree_value. If the parent currently holds a plain
# bool, it is first replaced by an empty dict, discarding that bool. A plain
# (non-path) name that is not a field of self.fields raises
# BadRequestException.
# NOTE(review): the parameter name `filter` shadows the builtin.
192 def __update_selector_value(self, selector: dict, filter: str,
194 fields = filter.split(',')
198 parent = f.split('/')[0]
199 if parent in selector and type(selector[parent]) is bool:
200 selector[parent] = dict()
201 self.__update_selector_tree_value(selector, f, val)
203 if f not in self.fields:
204 raise BadRequestException(
205 'Selector attribute {} not found'.format(f))
# Walk the '/'-separated path `filter` inside the selector dict `m`. Missing
# segments are created as empty dicts, and the last segment is set to `val`.
# NOTE(review): __update_selector_value only converts the top-level parent
# from bool to dict. If a deeper segment currently holds a bool, the
# `not in m` test on it raises TypeError rather than BadRequestException.
208 def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
209 filter_list = filter.split('/', 1)
210 if filter_list[0] not in m:
211 m[filter_list[0]] = dict()
212 if len(filter_list) > 1:
213 self.__update_selector_tree_value(
214 m[filter_list[0]], filter_list[1], val)
216 m[filter_list[0]] = val
# Render a selector dict as flask_restx mask syntax: the entries are joined
# with commas inside braces. A sub-dict is rendered recursively and appended
# directly after its key, e.g. key 'b' with sub-mask '{c}' contributes
# 'b{c}'. Presumably only entries selected True are emitted; confirm the
# per-entry conditions.
# Not to be confused with _gen_mask_from_selector, which builds the selector
# from the request parameters and then calls this method.
218 def __gen_mask_from_selector(self, fields: dict) -> str:
220 for k, v in fields.items():
222 s = self.__gen_mask_from_selector(v)
225 mask_li.append('%s%s' % (k, s))
230 return '{%s}' % ','.join(mask_li)
# Merge the model's default mask (the __mask__ attribute of self.fields)
# into `selector`, setting the listed fields to `val`.
# NOTE(review): getattr() is called without a default, so self.fields must
# expose __mask__; a plain dict would raise AttributeError.
232 def __set_default_mask(self, selector: dict, val: bool = True):
233 def convert_mask(mask):
234 # convert mask from {aa,bb,xxx{yyy}} structure to aa,bb,xxx/yyy
# i.e. into the comma-separated path syntax accepted by
# __update_selector_value.
245 result.append('/'.join(stack + [word]))
251 result.append('/'.join(stack + [word]))
257 return ','.join(result)
# str() of the model's mask is expected to yield the brace syntax handled by
# convert_mask. NOTE(review): confirm against flask_restx Mask.__str__.
259 mask = getattr(self.fields, "__mask__")
260 mask = convert_mask(str(mask))
# Full selector over the model; its top-level keys are iterated below.
262 selector_all = self.__gen_selector_from_model_with_value(
264 for s in selector_all:
267 default_selector = mask
268 self.__update_selector_value(selector, default_selector, val)