Update opMulti operations of the filter
[pti/o2.git] / o2common / views / route.py
1 # Copyright (C) 2021-2022 Wind River Systems, Inc.
2 #
3 #  Licensed under the Apache License, Version 2.0 (the "License");
4 #  you may not use this file except in compliance with the License.
5 #  You may obtain a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #  Unless required by applicable law or agreed to in writing, software
10 #  distributed under the License is distributed on an "AS IS" BASIS,
11 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 #  See the License for the specific language governing permissions and
13 #  limitations under the License.
14
15 # -*- coding: utf-8 -*-
16 from __future__ import unicode_literals
17
18 # from collections import OrderedDict
19 from functools import wraps
20 # from six import iteritems
21
22 from flask import request
23
24 from flask_restx import Namespace
25 from flask_restx._http import HTTPStatus
26 from flask_restx.marshalling import marshal_with, marshal
27 from flask_restx.utils import merge
28 from flask_restx.mask import Mask  # , apply as apply_mask
29 from flask_restx.model import Model
30 from flask_restx.fields import List, Nested, String
31 from flask_restx.utils import unpack
32
33 from o2common.helper import o2logging
34 logger = o2logging.get_logger(__name__)
35
36
37 class O2Namespace(Namespace):
38
39     def __init__(self, name, description=None, path=None, decorators=None,
40                  validate=None, authorizations=None, ordered=False, **kwargs):
41         super().__init__(name, description, path, decorators,
42                          validate, authorizations, ordered, **kwargs)
43
44     def marshal_with(
45         self, fields, as_list=False, code=HTTPStatus.OK, description=None,
46         **kwargs
47     ):
48         """
49         A decorator specifying the fields to use for serialization.
50
51         :param bool as_list: Indicate that the return type is a list \
52             (for the documentation)
53         :param int code: Optionally give the expected HTTP response \
54             code if its different from 200
55
56         """
57
58         def wrapper(func):
59             doc = {
60                 "responses": {
61                     str(code): (description, [fields], kwargs)
62                     if as_list
63                     else (description, fields, kwargs)
64                 },
65                 "__mask__": kwargs.get(
66                     "mask", True
67                 ),  # Mask values can't be determined outside app context
68             }
69             func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
70             return o2_marshal_with(fields, ordered=self.ordered,
71                                    **kwargs)(func)
72
73         return wrapper
74
75
76 class o2_marshal_with(marshal_with):
77     def __init__(
78         self, fields, envelope=None, skip_none=False, mask=None, ordered=False
79     ):
80         """
81         :param fields: a dict of whose keys will make up the final
82                        serialized response output
83         :param envelope: optional key that will be used to envelop the
84                        serialized response
85         """
86         self.fields = fields
87         self.envelope = envelope
88         self.skip_none = skip_none
89         self.ordered = ordered
90         self.mask = Mask(mask, skip=True)
91
92     def __call__(self, f):
93         @wraps(f)
94         def wrapper(*args, **kwargs):
95             resp = f(*args, **kwargs)
96
97             req_args = request.args
98             mask = self._gen_mask_from_selector(**req_args)
99
100             # mask = self.mask
101
102             # if has_request_context():
103             # mask_header = current_app.config["RESTX_MASK_HEADER"]
104             # mask = request.headers.get(mask_header) or mask
105             if isinstance(resp, tuple):
106                 data, code, headers = unpack(resp)
107                 return (
108                     marshal(
109                         data,
110                         self.fields,
111                         self.envelope,
112                         self.skip_none,
113                         mask,
114                         self.ordered,
115                     ),
116                     code,
117                     headers,
118                 )
119             else:
120                 return marshal(
121                     resp, self.fields, self.envelope, self.skip_none, mask,
122                     self.ordered
123                 )
124
125         return wrapper
126
127     def _gen_mask_from_selector(self, **kwargs) -> str:
128         mask_val = ''
129         if 'all_fields' in kwargs:
130             all_fields_without_space = kwargs['all_fields'].replace(" ", "")
131             all_fields = all_fields_without_space.lower()
132             if 'true' == all_fields:
133                 mask_val = ''
134
135         elif 'fields' in kwargs and kwargs['fields'] != '':
136             fields_without_space = kwargs['fields'].replace(" ", "")
137
138             # filters = fields_without_space.split(',')
139
140             # mask_val_list = []
141             # for f in filters:
142             #     if '/' in f:
143             #         a = self.__gen_mask_tree(f)
144             #         mask_val_list.append(a)
145             #         continue
146             #     mask_val_list.append(f)
147             # mask_val = '{%s}' % ','.join(mask_val_list)
148             selector = {}
149
150             self.__update_selector_value(selector, fields_without_space, True)
151
152             mask_val = self.__gen_mask_from_selector(selector)
153
154         elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
155             exclude_fields_without_space = kwargs['exclude_fields'].replace(
156                 " ", "")
157
158             selector = self.__gen_selector_from_model_with_value(
159                 self.fields)
160
161             self.__update_selector_value(
162                 selector, exclude_fields_without_space, False)
163
164             mask_val = self.__gen_mask_from_selector(selector)
165         elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
166             exclude_default_without_space = kwargs['exclude_default'].replace(
167                 " ", "")
168             exclude_default = exclude_default_without_space.lower()
169             if 'true' == exclude_default:
170                 mask_val = '{}'
171
172         else:
173             mask_val = ''
174
175         return mask_val
176
177     def __gen_mask_tree(self, field: str) -> str:
178
179         f = field.split('/', 1)
180         if len(f) > 1:
181             s = self.__gen_mask_tree(f[1])
182             return '%s%s' % (f[0], s)
183         else:
184             return '{%s}' % f[0]
185
186     def __gen_selector_from_model_with_value(
187             self, model: Model, default_val: bool = True) -> dict:
188         selector = dict()
189         for i in model:
190             if type(model[i]) is List:
191                 if type(model[i].container) is String:
192                     selector[i] = default_val
193                     continue
194                 selector[i] = self.__gen_selector_from_model_with_value(
195                     model[i].container.model, default_val)
196                 continue
197             elif type(model[i]) is Nested:
198                 selector[i] = self.__gen_selector_from_model_with_value(
199                     model[i].model, default_val)
200             selector[i] = default_val
201         return selector
202
203     def __update_selector_value(self, default_selector: dict, filter: str,
204                                 val: bool):
205         fields = filter.split(',')
206         for f in fields:
207             if '/' in f:
208                 self.__update_selector_tree_value(default_selector, f, val)
209                 continue
210             default_selector[f] = val
211
212     def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
213         filter_list = filter.split('/', 1)
214         if filter_list[0] not in m:
215             m[filter_list[0]] = dict()
216         if len(filter_list) > 1:
217             self.__update_selector_tree_value(
218                 m[filter_list[0]], filter_list[1], val)
219             return
220         m[filter_list[0]] = val
221
222     def __gen_mask_from_selector(self, fields: dict) -> str:
223         mask_li = list()
224         for k, v in fields.items():
225             if type(v) is dict:
226                 s = self.__gen_mask_from_selector(v)
227                 mask_li.append('%s%s' % (k, s))
228                 continue
229             if v:
230                 mask_li.append(k)
231
232         return '{%s}' % ','.join(mask_li)