319469e4d78bbce2ecd11e2d2b0f97c588be4645
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 from typing import (Dict, Set, List, Union)
25 from redis import Redis
26 from redis.sentinel import Sentinel
27 from redis.lock import Lock
28 from redis._compat import nativestr
29 from redis import exceptions as redis_exceptions
30 from ricsdl.configuration import _Configuration
31 from ricsdl.exceptions import (
32     RejectedByBackend,
33     NotConnected,
34     BackendError
35 )
36 from .dbbackend_abc import DbBackendAbc
37 from .dbbackend_abc import DbBackendLockAbc
38
39
40 @contextlib.contextmanager
41 def _map_to_sdl_exception():
42     """Translates known redis exceptions into SDL exceptions."""
43     try:
44         yield
45     except(redis_exceptions.ResponseError) as exc:
46         raise RejectedByBackend("SDL backend rejected the request: {}".
47                                 format(str(exc))) from exc
48     except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
49         raise NotConnected("SDL not connected to backend: {}".
50                            format(str(exc))) from exc
51     except(redis_exceptions.RedisError) as exc:
52         raise BackendError("SDL backend failed to process the request: {}".
53                            format(str(exc))) from exc
54
55
56 class RedisBackend(DbBackendAbc):
57     """
58     A class providing an implementation of database backend of Shared Data Layer (SDL), when
59     backend database solution is Redis.
60
61     Args:
62         configuration (_Configuration): SDL configuration, containing credentials to connect to
63                                         Redis database backend.
64     """
65     def __init__(self, configuration: _Configuration) -> None:
66         super().__init__()
67         with _map_to_sdl_exception():
68             if configuration.get_params().db_sentinel_port:
69                 sentinel_node = (configuration.get_params().db_host,
70                                  configuration.get_params().db_sentinel_port)
71                 master_name = configuration.get_params().db_sentinel_master_name
72                 self.__sentinel = Sentinel([sentinel_node])
73                 self.__redis = self.__sentinel.master_for(master_name)
74             else:
75                 self.__redis = Redis(host=configuration.get_params().db_host,
76                                      port=configuration.get_params().db_port,
77                                      db=0,
78                                      max_connections=20)
79         self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
80         self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
81
82     def __del__(self):
83         self.close()
84
85     def __str__(self):
86         return str(
87             {
88                 "Redis connection": repr(self.__redis)
89             }
90         )
91
92     def close(self):
93         self.__redis.close()
94
95     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
96         db_data_map = self._add_data_map_ns_prefix(ns, data_map)
97         with _map_to_sdl_exception():
98             self.__redis.mset(db_data_map)
99
100     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
101         db_key = self._add_key_ns_prefix(ns, key)
102         with _map_to_sdl_exception():
103             return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
104
105     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
106         db_key = self._add_key_ns_prefix(ns, key)
107         with _map_to_sdl_exception():
108             return self.__redis.setnx(db_key, data)
109
110     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
111         ret = dict()
112         db_keys = self._add_keys_ns_prefix(ns, keys)
113         with _map_to_sdl_exception():
114             values = self.__redis.mget(db_keys)
115             for idx, val in enumerate(values):
116                 # return only key values, which has a value
117                 if val:
118                     ret[keys[idx]] = val
119             return ret
120
121     def find_keys(self, ns: str, key_pattern: str) -> List[str]:
122         db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
123         with _map_to_sdl_exception():
124             ret = self.__redis.keys(db_key_pattern)
125             return self._strip_ns_from_bin_keys(ns, ret)
126
127     def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
128         # todo: replace below implementation with redis 'NGET' module
129         ret = dict()  # type: Dict[str, bytes]
130         with _map_to_sdl_exception():
131             matched_keys = self.find_keys(ns, key_pattern)
132             if matched_keys:
133                 ret = self.get(ns, matched_keys)
134         return ret
135
136     def remove(self, ns: str, keys: List[str]) -> None:
137         db_keys = self._add_keys_ns_prefix(ns, keys)
138         with _map_to_sdl_exception():
139             self.__redis.delete(*db_keys)
140
141     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
142         db_key = self._add_key_ns_prefix(ns, key)
143         with _map_to_sdl_exception():
144             return self.__redis.execute_command('DELIE', db_key, data)
145
146     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
147         db_key = self._add_key_ns_prefix(ns, group)
148         with _map_to_sdl_exception():
149             self.__redis.sadd(db_key, *members)
150
151     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
152         db_key = self._add_key_ns_prefix(ns, group)
153         with _map_to_sdl_exception():
154             self.__redis.srem(db_key, *members)
155
156     def remove_group(self, ns: str, group: str) -> None:
157         db_key = self._add_key_ns_prefix(ns, group)
158         with _map_to_sdl_exception():
159             self.__redis.delete(db_key)
160
161     def get_members(self, ns: str, group: str) -> Set[bytes]:
162         db_key = self._add_key_ns_prefix(ns, group)
163         with _map_to_sdl_exception():
164             return self.__redis.smembers(db_key)
165
166     def is_member(self, ns: str, group: str, member: bytes) -> bool:
167         db_key = self._add_key_ns_prefix(ns, group)
168         with _map_to_sdl_exception():
169             return self.__redis.sismember(db_key, member)
170
171     def group_size(self, ns: str, group: str) -> int:
172         db_key = self._add_key_ns_prefix(ns, group)
173         with _map_to_sdl_exception():
174             return self.__redis.scard(db_key)
175
176     @classmethod
177     def _add_key_ns_prefix(cls, ns: str, key: str):
178         return '{' + ns + '},' + key
179
180     @classmethod
181     def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
182         ret_nskeys = []
183         for k in keylist:
184             ret_nskeys.append('{' + ns + '},' + k)
185         return ret_nskeys
186
187     @classmethod
188     def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
189         ret_nsdict = {}
190         for key, val in data_dict.items():
191             ret_nsdict['{' + ns + '},' + key] = val
192         return ret_nsdict
193
194     @classmethod
195     def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
196         ret_keys = []
197         for k in nskeylist:
198             try:
199                 redis_key = k.decode("utf-8")
200             except UnicodeDecodeError as exc:
201                 msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
202                 raise RejectedByBackend(msg)
203             nskey = redis_key.split(',', 1)
204             if len(nskey) != 2:
205                 msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
206                 raise RejectedByBackend(msg)
207             ret_keys.append(nskey[1])
208         return ret_keys
209
210     def get_redis_connection(self):
211         """Return existing Redis database connection."""
212         return self.__redis
213
214
215 class RedisBackendLock(DbBackendLockAbc):
216     """
217     A class providing an implementation of database backend lock of Shared Data Layer (SDL), when
218     backend database solution is Redis.
219
220     Args:
221         ns (str): Namespace under which this lock is targeted.
222         name (str): Lock name, identifies the lock key in a Redis database backend.
223         expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
224                                  been released earlier by a 'release' method.
225         redis_backend (RedisBackend): Database backend object containing connection to Redis
226                                       database.
227     """
228     lua_get_validity_time = None
229     # KEYS[1] - lock name
230     # ARGS[1] - token
231     # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
232     LUA_GET_VALIDITY_TIME_SCRIPT = """
233         local token = redis.call('get', KEYS[1])
234         if not token then
235             return -10
236         end
237         if token ~= ARGV[1] then
238             return -11
239         end
240         return redis.call('pttl', KEYS[1])
241     """
242
243     def __init__(self, ns: str, name: str, expiration: Union[int, float],
244                  redis_backend: RedisBackend) -> None:
245         super().__init__(ns, name)
246         self.__redis = redis_backend.get_redis_connection()
247         with _map_to_sdl_exception():
248             redis_lockname = '{' + ns + '},' + self._lock_name
249             self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)
250             self._register_scripts()
251
252     def __str__(self):
253         return str(
254             {
255                 "lock namespace": self._ns,
256                 "lock name": self._lock_name,
257                 "lock status": self._lock_status_to_string()
258             }
259         )
260
261     def acquire(self, retry_interval: Union[int, float] = 0.1,
262                 retry_timeout: Union[int, float] = 10) -> bool:
263         succeeded = False
264         self.__redis_lock.sleep = retry_interval
265         with _map_to_sdl_exception():
266             succeeded = self.__redis_lock.acquire(blocking_timeout=retry_timeout)
267         return succeeded
268
269     def release(self) -> None:
270         with _map_to_sdl_exception():
271             self.__redis_lock.release()
272
273     def refresh(self) -> None:
274         with _map_to_sdl_exception():
275             self.__redis_lock.reacquire()
276
277     def get_validity_time(self) -> Union[int, float]:
278         validity = 0
279         if self.__redis_lock.local.token is None:
280             msg = u'Cannot get validity time of an unlocked lock %s' % self._lock_name
281             raise RejectedByBackend(msg)
282
283         with _map_to_sdl_exception():
284             validity = self.lua_get_validity_time(keys=[self.__redis_lock.name],
285                                                   args=[self.__redis_lock.local.token],
286                                                   client=self.__redis)
287         if validity < 0:
288             msg = (u'Getting validity time of a lock %s failed with error code: %d'
289                    % (self._lock_name, validity))
290             raise RejectedByBackend(msg)
291         ftime = validity / 1000.0
292         if ftime.is_integer():
293             return int(ftime)
294         return ftime
295
296     def _register_scripts(self):
297         cls = self.__class__
298         client = self.__redis
299         if cls.lua_get_validity_time is None:
300             cls.lua_get_validity_time = client.register_script(cls.LUA_GET_VALIDITY_TIME_SCRIPT)
301
302     def _lock_status_to_string(self) -> str:
303         try:
304             if self.__redis_lock.locked():
305                 if self.__redis_lock.owned():
306                     return 'locked'
307                 return 'locked by someone else'
308             return 'unlocked'
309         except(redis_exceptions.RedisError) as exc:
310             return f'Error: {str(exc)}'