Add error handling for ocloud route
[pti/o2.git] / o2common / views / route.py
1 # Copyright (C) 2021-2022 Wind River Systems, Inc.
2 #
3 #  Licensed under the Apache License, Version 2.0 (the "License");
4 #  you may not use this file except in compliance with the License.
5 #  You may obtain a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #  Unless required by applicable law or agreed to in writing, software
10 #  distributed under the License is distributed on an "AS IS" BASIS,
11 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 #  See the License for the specific language governing permissions and
13 #  limitations under the License.
14
15 # -*- coding: utf-8 -*-
16 from __future__ import unicode_literals
17
18 # from collections import OrderedDict
19 from functools import wraps
20 # from six import iteritems
21
22 from flask import request
23
24 from flask_restx import Namespace
25 from flask_restx._http import HTTPStatus
26 from flask_restx.marshalling import marshal_with, marshal
27 from flask_restx.utils import merge
28 from flask_restx.mask import Mask  # , apply as apply_mask
29 from flask_restx.model import Model
30 from flask_restx.fields import List, Nested, String
31 from flask_restx.utils import unpack
32 from o2common.domain.base import Serializer
33
34 from o2common.helper import o2logging
35 logger = o2logging.get_logger(__name__)
36
37
38 class O2Namespace(Namespace):
39
40     def __init__(self, name, description=None, path=None, decorators=None,
41                  validate=None, authorizations=None, ordered=False, **kwargs):
42         super().__init__(name, description, path, decorators,
43                          validate, authorizations, ordered, **kwargs)
44
45     def marshal_with(
46         self, fields, as_list=False, code=HTTPStatus.OK, description=None,
47         **kwargs
48     ):
49         """
50         A decorator specifying the fields to use for serialization.
51
52         :param bool as_list: Indicate that the return type is a list \
53             (for the documentation)
54         :param int code: Optionally give the expected HTTP response \
55             code if its different from 200
56
57         """
58
59         def wrapper(func):
60             doc = {
61                 "responses": {
62                     str(code): (description, [fields], kwargs)
63                     if as_list
64                     else (description, fields, kwargs)
65                 },
66                 "__mask__": kwargs.get(
67                     "mask", True
68                 ),  # Mask values can't be determined outside app context
69             }
70             func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
71             return o2_marshal_with(fields, ordered=self.ordered,
72                                    **kwargs)(func)
73
74         return wrapper
75
76
77 class o2_marshal_with(marshal_with):
78     def __init__(
79         self, fields, envelope=None, skip_none=False, mask=None, ordered=False
80     ):
81         """
82         :param fields: a dict of whose keys will make up the final
83                        serialized response output
84         :param envelope: optional key that will be used to envelop the
85                        serialized response
86         """
87         self.fields = fields
88         self.envelope = envelope
89         self.skip_none = skip_none
90         self.ordered = ordered
91         self.mask = Mask(mask, skip=True)
92
93     def __call__(self, f):
94         @wraps(f)
95         def wrapper(*args, **kwargs):
96             resp = f(*args, **kwargs)
97
98             req_args = request.args
99             mask = self._gen_mask_from_selector(**req_args)
100
101             # mask = self.mask
102
103             # if has_request_context():
104             # mask_header = current_app.config["RESTX_MASK_HEADER"]
105             # mask = request.headers.get(mask_header) or mask
106             if isinstance(resp, tuple):
107                 data, code, headers = unpack(resp)
108                 return (
109                     marshal(
110                         data,
111                         self.fields,
112                         self.envelope,
113                         self.skip_none,
114                         mask,
115                         self.ordered,
116                     ),
117                     code,
118                     headers,
119                 )
120             else:
121                 return marshal(
122                     resp, self.fields, self.envelope, self.skip_none, mask,
123                     self.ordered
124                 )
125
126         return wrapper
127
128     def _gen_mask_from_selector(self, **kwargs) -> str:
129         mask_val = ''
130         if 'all_fields' in kwargs:
131             all_fields_without_space = kwargs['all_fields'].replace(" ", "")
132             all_fields = all_fields_without_space.lower()
133             if 'true' == all_fields:
134                 mask_val = ''
135
136         elif 'fields' in kwargs and kwargs['fields'] != '':
137             fields_without_space = kwargs['fields'].replace(" ", "")
138
139             # filters = fields_without_space.split(',')
140
141             # mask_val_list = []
142             # for f in filters:
143             #     if '/' in f:
144             #         a = self.__gen_mask_tree(f)
145             #         mask_val_list.append(a)
146             #         continue
147             #     mask_val_list.append(f)
148             # mask_val = '{%s}' % ','.join(mask_val_list)
149             selector = {}
150
151             self.__update_selector_value(selector, fields_without_space, True)
152
153             mask_val = self.__gen_mask_from_selector(selector)
154
155         elif 'exclude_fields' in kwargs and kwargs['exclude_fields'] != '':
156             exclude_fields_without_space = kwargs['exclude_fields'].replace(
157                 " ", "")
158
159             selector = self.__gen_selector_from_model_with_value(
160                 self.fields)
161
162             self.__update_selector_value(
163                 selector, exclude_fields_without_space, False)
164
165             mask_val = self.__gen_mask_from_selector(selector)
166         elif 'exclude_default' in kwargs and kwargs['exclude_default'] != '':
167             exclude_default_without_space = kwargs['exclude_default'].replace(
168                 " ", "")
169             exclude_default = exclude_default_without_space.lower()
170             if 'true' == exclude_default:
171                 mask_val = '{}'
172
173         else:
174             mask_val = ''
175
176         return mask_val
177
178     def __gen_mask_tree(self, field: str) -> str:
179
180         f = field.split('/', 1)
181         if len(f) > 1:
182             s = self.__gen_mask_tree(f[1])
183             return '%s%s' % (f[0], s)
184         else:
185             return '{%s}' % f[0]
186
187     def __gen_selector_from_model_with_value(
188             self, model: Model, default_val: bool = True) -> dict:
189         selector = dict()
190         for i in model:
191             if type(model[i]) is List:
192                 if type(model[i].container) is String:
193                     selector[i] = default_val
194                     continue
195                 selector[i] = self.__gen_selector_from_model_with_value(
196                     model[i].container.model, default_val)
197                 continue
198             elif type(model[i]) is Nested:
199                 selector[i] = self.__gen_selector_from_model_with_value(
200                     model[i].model, default_val)
201             selector[i] = default_val
202         return selector
203
204     def __update_selector_value(self, default_selector: dict, filter: str,
205                                 val: bool):
206         fields = filter.split(',')
207         for f in fields:
208             if '/' in f:
209                 self.__update_selector_tree_value(default_selector, f, val)
210                 continue
211             default_selector[f] = val
212
213     def __update_selector_tree_value(self, m: dict, filter: str, val: bool):
214         filter_list = filter.split('/', 1)
215         if filter_list[0] not in m:
216             m[filter_list[0]] = dict()
217         if len(filter_list) > 1:
218             self.__update_selector_tree_value(
219                 m[filter_list[0]], filter_list[1], val)
220             return
221         m[filter_list[0]] = val
222
223     def __gen_mask_from_selector(self, fields: dict) -> str:
224         mask_li = list()
225         for k, v in fields.items():
226             if type(v) is dict:
227                 s = self.__gen_mask_from_selector(v)
228                 mask_li.append('%s%s' % (k, s))
229                 continue
230             if v:
231                 mask_li.append(k)
232
233         return '{%s}' % ','.join(mask_li)
234
235
236 class ProblemDetails(Serializer):
237     def __init__(self, namespace: O2Namespace, code: int, detail: str,
238                  title=None, instance=None
239                  ) -> None:
240         self.ns = namespace
241         self.status = code
242         self.detail = detail
243         self.type = request.path
244         self.title = title if title is not None else self.getTitle(code)
245         self.instance = instance if instance is not None else []
246
247     def getTitle(self, code):
248         return HTTPStatus(code).phrase
249
250     def abort(self):
251         self.ns.abort(self.status, self.detail, **self.serialize())
252
253     def serialize(self):
254         details = {}
255         for key in dir(self):
256             if key == 'ns' or key.startswith('__') or\
257                     callable(getattr(self, key)):
258                 continue
259             else:
260                 details[key] = getattr(self, key)
261         return details