Implement a fake SDL database backend
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 from typing import (Dict, Set, List, Union)
25 from redis import Redis
26 from redis.sentinel import Sentinel
27 from redis.lock import Lock
28 from redis._compat import nativestr
29 from redis import exceptions as redis_exceptions
30 from ricsdl.configuration import _Configuration
31 from ricsdl.exceptions import (
32     RejectedByBackend,
33     NotConnected,
34     BackendError
35 )
36 from .dbbackend_abc import DbBackendAbc
37 from .dbbackend_abc import DbBackendLockAbc
38
39
40 @contextlib.contextmanager
41 def _map_to_sdl_exception():
42     """Translates known redis exceptions into SDL exceptions."""
43     try:
44         yield
45     except(redis_exceptions.ResponseError) as exc:
46         raise RejectedByBackend("SDL backend rejected the request: {}".
47                                 format(str(exc))) from exc
48     except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
49         raise NotConnected("SDL not connected to backend: {}".
50                            format(str(exc))) from exc
51     except(redis_exceptions.RedisError) as exc:
52         raise BackendError("SDL backend failed to process the request: {}".
53                            format(str(exc))) from exc
54
55
56 class RedisBackend(DbBackendAbc):
57     """
58     A class providing an implementation of database backend of Shared Data Layer (SDL), when
59     backend database solution is Redis.
60
61     Args:
62         configuration (_Configuration): SDL configuration, containing credentials to connect to
63                                         Redis database backend.
64     """
65     def __init__(self, configuration: _Configuration) -> None:
66         super().__init__()
67         with _map_to_sdl_exception():
68             if configuration.get_params().db_sentinel_port:
69                 sentinel_node = (configuration.get_params().db_host,
70                                  configuration.get_params().db_sentinel_port)
71                 master_name = configuration.get_params().db_sentinel_master_name
72                 self.__sentinel = Sentinel([sentinel_node])
73                 self.__redis = self.__sentinel.master_for(master_name)
74             else:
75                 self.__redis = Redis(host=configuration.get_params().db_host,
76                                      port=configuration.get_params().db_port,
77                                      db=0,
78                                      max_connections=20)
79         self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
80         self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
81
82     def __del__(self):
83         self.close()
84
85     def __str__(self):
86         return str(
87             {
88                 "DB type": "Redis",
89                 "Redis connection": repr(self.__redis)
90             }
91         )
92
93     def close(self):
94         self.__redis.close()
95
96     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
97         db_data_map = self._add_data_map_ns_prefix(ns, data_map)
98         with _map_to_sdl_exception():
99             self.__redis.mset(db_data_map)
100
101     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
102         db_key = self._add_key_ns_prefix(ns, key)
103         with _map_to_sdl_exception():
104             return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
105
106     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
107         db_key = self._add_key_ns_prefix(ns, key)
108         with _map_to_sdl_exception():
109             return self.__redis.setnx(db_key, data)
110
111     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
112         ret = dict()
113         db_keys = self._add_keys_ns_prefix(ns, keys)
114         with _map_to_sdl_exception():
115             values = self.__redis.mget(db_keys)
116             for idx, val in enumerate(values):
117                 # return only key values, which has a value
118                 if val:
119                     ret[keys[idx]] = val
120             return ret
121
122     def find_keys(self, ns: str, key_pattern: str) -> List[str]:
123         db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
124         with _map_to_sdl_exception():
125             ret = self.__redis.keys(db_key_pattern)
126             return self._strip_ns_from_bin_keys(ns, ret)
127
128     def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
129         # todo: replace below implementation with redis 'NGET' module
130         ret = dict()  # type: Dict[str, bytes]
131         with _map_to_sdl_exception():
132             matched_keys = self.find_keys(ns, key_pattern)
133             if matched_keys:
134                 ret = self.get(ns, matched_keys)
135         return ret
136
137     def remove(self, ns: str, keys: List[str]) -> None:
138         db_keys = self._add_keys_ns_prefix(ns, keys)
139         with _map_to_sdl_exception():
140             self.__redis.delete(*db_keys)
141
142     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
143         db_key = self._add_key_ns_prefix(ns, key)
144         with _map_to_sdl_exception():
145             return self.__redis.execute_command('DELIE', db_key, data)
146
147     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
148         db_key = self._add_key_ns_prefix(ns, group)
149         with _map_to_sdl_exception():
150             self.__redis.sadd(db_key, *members)
151
152     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
153         db_key = self._add_key_ns_prefix(ns, group)
154         with _map_to_sdl_exception():
155             self.__redis.srem(db_key, *members)
156
157     def remove_group(self, ns: str, group: str) -> None:
158         db_key = self._add_key_ns_prefix(ns, group)
159         with _map_to_sdl_exception():
160             self.__redis.delete(db_key)
161
162     def get_members(self, ns: str, group: str) -> Set[bytes]:
163         db_key = self._add_key_ns_prefix(ns, group)
164         with _map_to_sdl_exception():
165             return self.__redis.smembers(db_key)
166
167     def is_member(self, ns: str, group: str, member: bytes) -> bool:
168         db_key = self._add_key_ns_prefix(ns, group)
169         with _map_to_sdl_exception():
170             return self.__redis.sismember(db_key, member)
171
172     def group_size(self, ns: str, group: str) -> int:
173         db_key = self._add_key_ns_prefix(ns, group)
174         with _map_to_sdl_exception():
175             return self.__redis.scard(db_key)
176
177     @classmethod
178     def _add_key_ns_prefix(cls, ns: str, key: str):
179         return '{' + ns + '},' + key
180
181     @classmethod
182     def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
183         ret_nskeys = []
184         for k in keylist:
185             ret_nskeys.append('{' + ns + '},' + k)
186         return ret_nskeys
187
188     @classmethod
189     def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
190         ret_nsdict = {}
191         for key, val in data_dict.items():
192             ret_nsdict['{' + ns + '},' + key] = val
193         return ret_nsdict
194
195     @classmethod
196     def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
197         ret_keys = []
198         for k in nskeylist:
199             try:
200                 redis_key = k.decode("utf-8")
201             except UnicodeDecodeError as exc:
202                 msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
203                 raise RejectedByBackend(msg)
204             nskey = redis_key.split(',', 1)
205             if len(nskey) != 2:
206                 msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
207                 raise RejectedByBackend(msg)
208             ret_keys.append(nskey[1])
209         return ret_keys
210
211     def get_redis_connection(self):
212         """Return existing Redis database connection."""
213         return self.__redis
214
215
216 class RedisBackendLock(DbBackendLockAbc):
217     """
218     A class providing an implementation of database backend lock of Shared Data Layer (SDL), when
219     backend database solution is Redis.
220
221     Args:
222         ns (str): Namespace under which this lock is targeted.
223         name (str): Lock name, identifies the lock key in a Redis database backend.
224         expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
225                                  been released earlier by a 'release' method.
226         redis_backend (RedisBackend): Database backend object containing connection to Redis
227                                       database.
228     """
229     lua_get_validity_time = None
230     # KEYS[1] - lock name
231     # ARGS[1] - token
232     # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
233     LUA_GET_VALIDITY_TIME_SCRIPT = """
234         local token = redis.call('get', KEYS[1])
235         if not token then
236             return -10
237         end
238         if token ~= ARGV[1] then
239             return -11
240         end
241         return redis.call('pttl', KEYS[1])
242     """
243
244     def __init__(self, ns: str, name: str, expiration: Union[int, float],
245                  redis_backend: RedisBackend) -> None:
246         super().__init__(ns, name)
247         self.__redis = redis_backend.get_redis_connection()
248         with _map_to_sdl_exception():
249             redis_lockname = '{' + ns + '},' + self._lock_name
250             self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)
251             self._register_scripts()
252
253     def __str__(self):
254         return str(
255             {
256                 "lock DB type": "Redis",
257                 "lock namespace": self._ns,
258                 "lock name": self._lock_name,
259                 "lock status": self._lock_status_to_string()
260             }
261         )
262
263     def acquire(self, retry_interval: Union[int, float] = 0.1,
264                 retry_timeout: Union[int, float] = 10) -> bool:
265         succeeded = False
266         self.__redis_lock.sleep = retry_interval
267         with _map_to_sdl_exception():
268             succeeded = self.__redis_lock.acquire(blocking_timeout=retry_timeout)
269         return succeeded
270
271     def release(self) -> None:
272         with _map_to_sdl_exception():
273             self.__redis_lock.release()
274
275     def refresh(self) -> None:
276         with _map_to_sdl_exception():
277             self.__redis_lock.reacquire()
278
279     def get_validity_time(self) -> Union[int, float]:
280         validity = 0
281         if self.__redis_lock.local.token is None:
282             msg = u'Cannot get validity time of an unlocked lock %s' % self._lock_name
283             raise RejectedByBackend(msg)
284
285         with _map_to_sdl_exception():
286             validity = self.lua_get_validity_time(keys=[self.__redis_lock.name],
287                                                   args=[self.__redis_lock.local.token],
288                                                   client=self.__redis)
289         if validity < 0:
290             msg = (u'Getting validity time of a lock %s failed with error code: %d'
291                    % (self._lock_name, validity))
292             raise RejectedByBackend(msg)
293         ftime = validity / 1000.0
294         if ftime.is_integer():
295             return int(ftime)
296         return ftime
297
298     def _register_scripts(self):
299         cls = self.__class__
300         client = self.__redis
301         if cls.lua_get_validity_time is None:
302             cls.lua_get_validity_time = client.register_script(cls.LUA_GET_VALIDITY_TIME_SCRIPT)
303
304     def _lock_status_to_string(self) -> str:
305         try:
306             if self.__redis_lock.locked():
307                 if self.__redis_lock.owned():
308                     return 'locked'
309                 return 'locked by someone else'
310             return 'unlocked'
311         except(redis_exceptions.RedisError) as exc:
312             return f'Error: {str(exc)}'