afa7450f640e9d425423b8273b24cd67d98ce766
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 from typing import (Dict, Set, List, Union)
25 from redis import Redis
26 from redis.sentinel import Sentinel
27 from redis.lock import Lock
28 from redis._compat import nativestr
29 from redis import exceptions as redis_exceptions
30 from ricsdl.configuration import _Configuration
31 from ricsdl.exceptions import (
32     RejectedByBackend,
33     NotConnected,
34     BackendError
35 )
36 from .dbbackend_abc import DbBackendAbc
37 from .dbbackend_abc import DbBackendLockAbc
38
39
40 @contextlib.contextmanager
41 def _map_to_sdl_exception():
42     """Translates known redis exceptions into SDL exceptions."""
43     try:
44         yield
45     except(redis_exceptions.ResponseError) as exc:
46         raise RejectedByBackend("SDL backend rejected the request: {}".
47                                 format(str(exc))) from exc
48     except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
49         raise NotConnected("SDL not connected to backend: {}".
50                            format(str(exc))) from exc
51     except(redis_exceptions.RedisError) as exc:
52         raise BackendError("SDL backend failed to process the request: {}".
53                            format(str(exc))) from exc
54
55
56 class RedisBackend(DbBackendAbc):
57     """
58     A class providing an implementation of database backend of Shared Data Layer (SDL), when
59     backend database solution is Redis.
60
61     Args:
62         configuration (_Configuration): SDL configuration, containing credentials to connect to
63                                         Redis database backend.
64     """
65     def __init__(self, configuration: _Configuration) -> None:
66         super().__init__()
67         with _map_to_sdl_exception():
68             if configuration.get_params().db_sentinel_port:
69                 sentinel_node = (configuration.get_params().db_host,
70                                  configuration.get_params().db_sentinel_port)
71                 master_name = configuration.get_params().db_sentinel_master_name
72                 self.__sentinel = Sentinel([sentinel_node])
73                 self.__redis = self.__sentinel.master_for(master_name)
74             else:
75                 self.__redis = Redis(host=configuration.get_params().db_host,
76                                      port=configuration.get_params().db_port,
77                                      db=0,
78                                      max_connections=20)
79         self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
80         self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
81
82     def __del__(self):
83         self.close()
84
85     def __str__(self):
86         return str(
87             {
88                 "Redis connection": repr(self.__redis)
89             }
90         )
91
92     def close(self):
93         self.__redis.close()
94
95     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
96         db_data_map = self._add_data_map_ns_prefix(ns, data_map)
97         with _map_to_sdl_exception():
98             self.__redis.mset(db_data_map)
99
100     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
101         db_key = self._add_key_ns_prefix(ns, key)
102         with _map_to_sdl_exception():
103             return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
104
105     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
106         db_key = self._add_key_ns_prefix(ns, key)
107         with _map_to_sdl_exception():
108             return self.__redis.setnx(db_key, data)
109
110     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
111         ret = dict()
112         db_keys = self._add_keys_ns_prefix(ns, keys)
113         with _map_to_sdl_exception():
114             values = self.__redis.mget(db_keys)
115             for idx, val in enumerate(values):
116                 # return only key values, which has a value
117                 if val:
118                     ret[keys[idx]] = val
119             return ret
120
121     def find_keys(self, ns: str, key_prefix: str) -> List[str]:
122         escaped_key_prefix = self._escape_characters(key_prefix)
123         db_escaped_key_prefix = self._add_key_ns_prefix(ns, escaped_key_prefix + '*')
124         with _map_to_sdl_exception():
125             ret = self.__redis.keys(db_escaped_key_prefix)
126             return self._strip_ns_from_bin_keys(ns, ret)
127
128     def find_and_get(self, ns: str, key_prefix: str, atomic: bool) -> Dict[str, bytes]:
129         # todo: replace below implementation with redis 'NGET' module
130         ret = dict()  # type: Dict[str, bytes]
131         with _map_to_sdl_exception():
132             matched_keys = self.find_keys(ns, key_prefix)
133             if matched_keys:
134                 ret = self.get(ns, matched_keys)
135         return ret
136
137     def remove(self, ns: str, keys: List[str]) -> None:
138         db_keys = self._add_keys_ns_prefix(ns, keys)
139         with _map_to_sdl_exception():
140             self.__redis.delete(*db_keys)
141
142     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
143         db_key = self._add_key_ns_prefix(ns, key)
144         with _map_to_sdl_exception():
145             return self.__redis.execute_command('DELIE', db_key, data)
146
147     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
148         db_key = self._add_key_ns_prefix(ns, group)
149         with _map_to_sdl_exception():
150             self.__redis.sadd(db_key, *members)
151
152     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
153         db_key = self._add_key_ns_prefix(ns, group)
154         with _map_to_sdl_exception():
155             self.__redis.srem(db_key, *members)
156
157     def remove_group(self, ns: str, group: str) -> None:
158         db_key = self._add_key_ns_prefix(ns, group)
159         with _map_to_sdl_exception():
160             self.__redis.delete(db_key)
161
162     def get_members(self, ns: str, group: str) -> Set[bytes]:
163         db_key = self._add_key_ns_prefix(ns, group)
164         with _map_to_sdl_exception():
165             return self.__redis.smembers(db_key)
166
167     def is_member(self, ns: str, group: str, member: bytes) -> bool:
168         db_key = self._add_key_ns_prefix(ns, group)
169         with _map_to_sdl_exception():
170             return self.__redis.sismember(db_key, member)
171
172     def group_size(self, ns: str, group: str) -> int:
173         db_key = self._add_key_ns_prefix(ns, group)
174         with _map_to_sdl_exception():
175             return self.__redis.scard(db_key)
176
177     @classmethod
178     def _add_key_ns_prefix(cls, ns: str, key: str):
179         return '{' + ns + '},' + key
180
181     @classmethod
182     def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
183         ret_nskeys = []
184         for k in keylist:
185             ret_nskeys.append('{' + ns + '},' + k)
186         return ret_nskeys
187
188     @classmethod
189     def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
190         ret_nsdict = {}
191         for key, val in data_dict.items():
192             ret_nsdict['{' + ns + '},' + key] = val
193         return ret_nsdict
194
195     @classmethod
196     def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
197         ret_keys = []
198         for k in nskeylist:
199             nskey = k.decode("utf-8").split(',', 1)
200             if len(nskey) != 2:
201                 msg = u'Illegal namespace %s key:%s' % (ns, nskey)
202                 raise RejectedByBackend(msg)
203             ret_keys.append(nskey[1])
204         return ret_keys
205
206     @classmethod
207     def _escape_characters(cls, pattern: str) -> str:
208         return pattern.translate(str.maketrans(
209             {"(": r"\(",
210              ")": r"\)",
211              "[": r"\[",
212              "]": r"\]",
213              "*": r"\*",
214              "?": r"\?",
215              "\\": r"\\"}))
216
217     def get_redis_connection(self):
218         """Return existing Redis database connection."""
219         return self.__redis
220
221
222 class RedisBackendLock(DbBackendLockAbc):
223     """
224     A class providing an implementation of database backend lock of Shared Data Layer (SDL), when
225     backend database solution is Redis.
226
227     Args:
228         ns (str): Namespace under which this lock is targeted.
229         name (str): Lock name, identifies the lock key in a Redis database backend.
230         expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
231                                  been released earlier by a 'release' method.
232         redis_backend (RedisBackend): Database backend object containing connection to Redis
233                                       database.
234     """
235     lua_get_validity_time = None
236     # KEYS[1] - lock name
237     # ARGS[1] - token
238     # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
239     LUA_GET_VALIDITY_TIME_SCRIPT = """
240         local token = redis.call('get', KEYS[1])
241         if not token then
242             return -10
243         end
244         if token ~= ARGV[1] then
245             return -11
246         end
247         return redis.call('pttl', KEYS[1])
248     """
249
250     def __init__(self, ns: str, name: str, expiration: Union[int, float],
251                  redis_backend: RedisBackend) -> None:
252         super().__init__(ns, name)
253         self.__redis = redis_backend.get_redis_connection()
254         with _map_to_sdl_exception():
255             redis_lockname = '{' + ns + '},' + self._lock_name
256             self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)
257             self._register_scripts()
258
259     def __str__(self):
260         return str(
261             {
262                 "lock namespace": self._ns,
263                 "lock name": self._lock_name,
264                 "lock status": self._lock_status_to_string()
265             }
266         )
267
268     def acquire(self, retry_interval: Union[int, float] = 0.1,
269                 retry_timeout: Union[int, float] = 10) -> bool:
270         succeeded = False
271         self.__redis_lock.sleep = retry_interval
272         with _map_to_sdl_exception():
273             succeeded = self.__redis_lock.acquire(blocking_timeout=retry_timeout)
274         return succeeded
275
276     def release(self) -> None:
277         with _map_to_sdl_exception():
278             self.__redis_lock.release()
279
280     def refresh(self) -> None:
281         with _map_to_sdl_exception():
282             self.__redis_lock.reacquire()
283
284     def get_validity_time(self) -> Union[int, float]:
285         validity = 0
286         if self.__redis_lock.local.token is None:
287             msg = u'Cannot get validity time of an unlocked lock %s' % self._lock_name
288             raise RejectedByBackend(msg)
289
290         with _map_to_sdl_exception():
291             validity = self.lua_get_validity_time(keys=[self.__redis_lock.name],
292                                                   args=[self.__redis_lock.local.token],
293                                                   client=self.__redis)
294         if validity < 0:
295             msg = (u'Getting validity time of a lock %s failed with error code: %d'
296                    % (self._lock_name, validity))
297             raise RejectedByBackend(msg)
298         ftime = validity / 1000.0
299         if ftime.is_integer():
300             return int(ftime)
301         return ftime
302
303     def _register_scripts(self):
304         cls = self.__class__
305         client = self.__redis
306         if cls.lua_get_validity_time is None:
307             cls.lua_get_validity_time = client.register_script(cls.LUA_GET_VALIDITY_TIME_SCRIPT)
308
309     def _lock_status_to_string(self) -> str:
310         try:
311             if self.__redis_lock.locked():
312                 if self.__redis_lock.owned():
313                     return 'locked'
314                 return 'locked by someone else'
315             return 'unlocked'
316         except(redis_exceptions.RedisError) as exc:
317             return f'Error: {str(exc)}'