Fix the IPv6 does not work
[pti/o2.git] / o2common / views / route.py
1 # Copyright (C) 2021-2022 Wind River Systems, Inc.
2 #
3 #  Licensed under the Apache License, Version 2.0 (the "License");
4 #  you may not use this file except in compliance with the License.
5 #  You may obtain a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #  Unless required by applicable law or agreed to in writing, software
10 #  distributed under the License is distributed on an "AS IS" BASIS,
11 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 #  See the License for the specific language governing permissions and
13 #  limitations under the License.
14
15 # -*- coding: utf-8 -*-
16 from __future__ import unicode_literals
17
18 # from collections import OrderedDict
19 from functools import wraps
20 # from six import iteritems
21
22 from flask import request
23
24 from flask_restx import Namespace
25 from flask_restx._http import HTTPStatus
26 from flask_restx.marshalling import marshal_with, marshal
27 from flask_restx.utils import merge
28 from flask_restx.mask import Mask  # , apply as apply_mask
29 from flask_restx.model import Model
30 from flask_restx.fields import List, Nested, String
31 from flask_restx.utils import unpack
32
33 from o2common.views.route_exception import BadRequestException
34
35 from o2common.helper import o2logging
36 logger = o2logging.get_logger(__name__)
37
38
39 class O2Namespace(Namespace):
40
41     def __init__(self, name, description=None, path=None, decorators=None,
42                  validate=None, authorizations=None, ordered=False, **kwargs):
43         super().__init__(name, description, path, decorators,
44                          validate, authorizations, ordered, **kwargs)
45
46     def marshal_with(
47         self, fields, as_list=False, code=HTTPStatus.OK, description=None,
48         **kwargs
49     ):
50         """
51         A decorator specifying the fields to use for serialization.
52
53         :param bool as_list: Indicate that the return type is a list \
54             (for the documentation)
55         :param int code: Optionally give the expected HTTP response \
56             code if its different from 200
57
58         """
59
60         def wrapper(func):
61             doc = {
62                 "responses": {
63                     str(code): (description, [fields], kwargs)
64                     if as_list
65                     else (description, fields, kwargs)
66                 },
67                 "__mask__": kwargs.get(
68                     "mask", True
69                 ),  # Mask values can't be determined outside app context
70             }
71             func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
72             return o2_marshal_with(fields, ordered=self.ordered,
73                                    **kwargs)(func)
74
75         return wrapper
76
77
78 class o2_marshal_with(marshal_with):
79     def __init__(
80         self, fields, envelope=None, skip_none=False, mask=None, ordered=False
81     ):
82         """
83         :param fields: a dict of whose keys will make up the final
84                        serialized response output
85         :param envelope: optional key that will be used to envelop the
86                        serialized response
87         """
88         self.fields = fields
89         self.envelope = envelope
90         self.skip_none = skip_none
91         self.ordered = ordered
92         self.mask = Mask(mask, skip=True)
93
94     def __call__(self, f):
95         @wraps(f)
96         def wrapper(*args, **kwargs):
97             resp = f(*args, **kwargs)
98
99             req_args = request.args
100             mask = self._gen_mask_from_selector(**req_args)
101             if mask == '':
102                 mask = self.mask
103
104             # if has_request_context():
105             # mask_header = current_app.config["RESTX_MASK_HEADER"]
106             # mask = request.headers.get(mask_header) or mask
107             if isinstance(resp, tuple):
108                 data, code, headers = unpack(resp)
109                 return (
110                     marshal(
111                         data,
112                         self.fields,
113                         self.envelope,
114                         self.skip_none,
115                         mask,
116                         self.ordered,
117                     ),
118                     code,
119                     headers,
120                 )
121             else:
122                 return marshal(
123                     resp, self.fields, self.envelope, self.skip_none, mask,
124                     self.ordered
125                 )
126
127         return wrapper
128
129     def _gen_mask_from_selector(self, **kwargs) -> str:
130         mask_val = ''
131         if 'all_fields' in kwargs:
132             all_fields_without_space = kwargs['all_fields'].strip()
133             logger.debug('all_fields selector value is {}'.format(
134                 all_fields_without_space))
135             selector = self.__gen_selector_from_model_with_value(
136                 self.fields)
137             mask_val = self.__gen_mask_from_selector(selector)
138
139         elif 'fields' in kwargs and kwargs['fields'] != '':
140             fields_without_space = kwargs['fields'].strip()
141             selector = {}
142             self.__update_selector_value(selector, fields_without_space, True)
143             self.__set_default_mask(selector)
144             mask_val = self.__gen_mask_from_selector(selector)
145
146         elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
147             exclude_fields_without_space = kwargs['exclude_fields'].strip()
148             selector = self.__gen_selector_from_model_with_value(
149                 self.fields)
150             self.__update_selector_value(
151                 selector, exclude_fields_without_space, False)
152             self.__set_default_mask(selector)
153             mask_val = self.__gen_mask_from_selector(selector)
154
155         elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
156             exclude_default_without_space = kwargs['exclude_default'].strip()
157             exclude_default = exclude_default_without_space.lower()
158             if 'true' == exclude_default:
159                 mask_val = '{}'
160
161         else:
162             mask_val = ''
163
164         return mask_val
165
166     def __gen_mask_tree(self, field: str) -> str:
167
168         f = field.split('/', 1)
169         if len(f) > 1:
170             s = self.__gen_mask_tree(f[1])
171             return '%s%s' % (f[0], s)
172         else:
173             return '{%s}' % f[0]
174
175     def __gen_selector_from_model_with_value(
176             self, model: Model, default_val: bool = True) -> dict:
177         selector = dict()
178         for i in model:
179             if type(model[i]) is List:
180                 if type(model[i].container) is String:
181                     selector[i] = default_val
182                     continue
183                 selector[i] = self.__gen_selector_from_model_with_value(
184                     model[i].container.model, default_val)
185                 continue
186             elif type(model[i]) is Nested:
187                 selector[i] = self.__gen_selector_from_model_with_value(
188                     model[i].model, default_val)
189             selector[i] = default_val
190         return selector
191
192     def __update_selector_value(self, selector: dict, filter: str,
193                                 val: bool):
194         fields = filter.split(',')
195         for f in fields:
196             f = f.strip()
197             if '/' in f:
198                 parent = f.split('/')[0]
199                 if parent in selector and type(selector[parent]) is bool:
200                     selector[parent] = dict()
201                 self.__update_selector_tree_value(selector, f, val)
202                 continue
203             if f not in self.fields:
204                 raise BadRequestException(
205                     'Selector attribute {} not found'.format(f))
206             selector[f] = val
207
208     def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
209         filter_list = filter.split('/', 1)
210         if filter_list[0] not in m:
211             m[filter_list[0]] = dict()
212         if len(filter_list) > 1:
213             self.__update_selector_tree_value(
214                 m[filter_list[0]], filter_list[1], val)
215             return
216         m[filter_list[0]] = val
217
218     def __gen_mask_from_selector(self, fields: dict) -> str:
219         mask_li = list()
220         for k, v in fields.items():
221             if type(v) is dict:
222                 s = self.__gen_mask_from_selector(v)
223                 if s == '{}':
224                     continue
225                 mask_li.append('%s%s' % (k, s))
226                 continue
227             if v:
228                 mask_li.append(k)
229
230         return '{%s}' % ','.join(mask_li)
231
232     def __set_default_mask(self, selector: dict, val: bool = True):
233         def convert_mask(mask):
234             # convert mask from {aa,bb,xxx{yyy}} structure to aa,bbxxx/yyy
235             stack = []
236             result = []
237             word = ''
238             for ch in mask:
239                 if ch == '{':
240                     if word:
241                         stack.append(word)
242                         word = ''
243                 elif ch == '}':
244                     if word:
245                         result.append('/'.join(stack + [word]))
246                         word = ''
247                     if stack:
248                         stack.pop()
249                 elif ch == ',':
250                     if word:
251                         result.append('/'.join(stack + [word]))
252                         word = ''
253                 else:
254                     word += ch
255             if word:
256                 result.append(word)
257             return ','.join(result)
258
259         mask = getattr(self.fields, "__mask__")
260         mask = convert_mask(str(mask))
261         if not mask:
262             selector_all = self.__gen_selector_from_model_with_value(
263                 self.fields)
264             for s in selector_all:
265                 selector[s] = val
266             return
267         default_selector = mask
268         self.__update_selector_value(selector, default_selector, val)