822afcf739c5440ebd50b2b679b0dca32ce50a99
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2022 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 import threading
25 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
26 import zlib
27 import redis
28 from redis import Redis
29 from redis.sentinel import Sentinel
30 from redis.lock import Lock
31 from redis.utils import str_if_bytes
32 from redis import exceptions as redis_exceptions
33 from ricsdl.configuration import _Configuration
34 from ricsdl.exceptions import (
35     RejectedByBackend,
36     NotConnected,
37     BackendError
38 )
39 from .dbbackend_abc import DbBackendAbc
40 from .dbbackend_abc import DbBackendLockAbc
41
42
@contextlib.contextmanager
def _map_to_sdl_exception():
    """
    Context manager converting redis-py exceptions raised inside its body into SDL exceptions.

    Mapping (checked in this order):
        redis ResponseError                    -> RejectedByBackend
        redis ConnectionError / TimeoutError   -> NotConnected
        any other redis RedisError             -> BackendError

    The original redis exception is attached as ``__cause__``. Exceptions that are not
    redis errors propagate unchanged.
    """
    try:
        yield
    except redis_exceptions.RedisError as err:
        if isinstance(err, redis_exceptions.ResponseError):
            raise RejectedByBackend("SDL backend rejected the request: {}".format(err)) from err
        if isinstance(err, (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError)):
            raise NotConnected("SDL not connected to backend: {}".format(err)) from err
        raise BackendError("SDL backend failed to process the request: {}".format(err)) from err
57
58
class PubSub(redis.client.PubSub):
    """
    SDL variant of the redis-py PubSub object.

    It differs from the base class in how published messages reach a registered handler.
    The handler is not given the raw redis-py message dict. It is called with two arguments:
    the channel name with its namespace prefix removed, and the list of events obtained by
    splitting the message payload on ``event_separator``.

    Args:
        event_separator (str): Separator used to split a message payload into events.
        connection_pool: redis-py connection pool the pub/sub connection is taken from.
        ignore_subscribe_messages (bool): If True, subscribe/unsubscribe confirmation
                                          messages are not returned to the caller.
    """
    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
        super().__init__(connection_pool, shard_hint=None, ignore_subscribe_messages=ignore_subscribe_messages)
        self.event_separator = event_separator

    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Returns:
            - (channel, events) tuple, when a handler was found and invoked for a
              published message. The channel has its namespace prefix stripped and
              events is the payload split on ``event_separator``.
            - None, for a subscribe/unsubscribe message that is being ignored.
            - Otherwise, the parsed message dict.

        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py
        """
        message_type = str_if_bytes(response[0])
        if message_type == 'pmessage':
            message = {
                'type': message_type,
                'pattern': response[1],
                'channel': response[2],
                'data': response[3]
            }
        elif message_type == 'pong':
            message = {
                'type': message_type,
                'pattern': None,
                'channel': None,
                'data': response[1]
            }
        else:
            message = {
                'type': message_type,
                'pattern': None,
                'channel': response[1],
                'data': response[2]
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == 'punsubscribe':
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == 'pmessage':
                handler = self.patterns.get(message['pattern'], None)
            else:
                handler = self.channels.get(message['channel'], None)
            if handler:
                # Need to send only channel and notification instead of raw
                # message
                message_channel = self._strip_ns_from_bin_key('', message['channel'])
                message_data = message['data'].decode('utf-8')
                messages = message_data.split(self.event_separator)
                handler(message_channel, messages)
                return message_channel, messages
        elif message_type != 'pong':
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    @classmethod
    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
        """
        Decode a Redis key and return the part after its first ',' (the namespace prefix).

        Args:
            ns (str): Namespace name, used only in error messages.
            nskey (bytes): Key as returned by Redis, expected to be '{<ns>},<key>'.

        Returns:
            str: The key without its namespace prefix.

        Raises:
            RejectedByBackend: If the key is not valid UTF-8 or contains no ','.
        """
        try:
            redis_key = nskey.decode('utf-8')
        except UnicodeDecodeError as exc:
            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
            # Chain explicitly, consistent with _map_to_sdl_exception.
            raise RejectedByBackend(msg) from exc
        nskey = redis_key.split(',', 1)
        if len(nskey) != 2:
            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
            raise RejectedByBackend(msg)
        return nskey[1]
142
143
class RedisBackend(DbBackendAbc):
    """
    A class providing an implementation of database backend of Shared Data Layer (SDL), when
    backend database solution is Redis.

    One Redis client is created per configured database address. A namespace is bound to one
    of these clients, selected by the CRC32 hash of the namespace name modulo the number of
    clients. Keys and channels of a namespace are stored in Redis with a '{<ns>},' prefix.

    Args:
        configuration (_Configuration): SDL configuration, containing credentials to connect to
                                        Redis database backend.
    """
    def __init__(self, configuration: _Configuration) -> None:
        super().__init__()
        # Index of the client polled by the next handle_events() call (round robin).
        self.next_client_event = 0
        self.event_separator = configuration.get_event_separator()
        self.clients = list()
        with _map_to_sdl_exception():
            self.clients = self.__create_redis_clients(configuration)

    def __del__(self):
        # __init__ may have raised before 'clients' was assigned; calling close() in
        # that case would only raise AttributeError from the finalizer.
        if hasattr(self, 'clients'):
            self.close()

    def __str__(self):
        out = {"DB type": "Redis"}
        for i, r in enumerate(self.clients):
            out["Redis client[" + str(i) + "]"] = str(r)
        return str(out)

    def is_connected(self):
        """Return True if every configured Redis client answers PING with a truthy reply."""
        with _map_to_sdl_exception():
            return all(c.redis_client.ping() for c in self.clients)

    def close(self):
        """Close the Redis client of every configured database."""
        for c in self.clients:
            c.redis_client.close()

    def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
        """Write all key/value pairs of 'data_map' to namespace 'ns' with a single MSET."""
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        with _map_to_sdl_exception():
            self.__getClient(ns).mset(db_data_map)

    def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
        """Set 'key' to 'new_data' only if its current value is 'old_data'."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            # 'SETIE' is not a standard Redis command (presumably provided by a server-side
            # module). Its reply is converted to bool by the response callback registered
            # in __create_redis_client.
            return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)

    def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
        """Set 'key' to 'data' only if the key does not exist yet (SETNX)."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).setnx(db_key, data)

    def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
        """Read 'keys' from namespace 'ns'; keys without a value are left out of the result."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            values = self.__getClient(ns).mget(db_keys)
            # MGET answers None for missing keys; return only keys that have a value.
            return {key: val for key, val in zip(keys, values) if val is not None}

    def find_keys(self, ns: str, key_pattern: str) -> List[str]:
        """Return keys of namespace 'ns' matching the Redis glob 'key_pattern', without prefix."""
        db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).keys(db_key_pattern)
            return self.__strip_ns_from_bin_keys(ns, ret)

    def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
        """Return key/value pairs of namespace 'ns' whose key matches 'key_pattern'."""
        # todo: replace below implementation with redis 'NGET' module
        ret = dict()  # type: Dict[str, bytes]
        with _map_to_sdl_exception():
            matched_keys = self.find_keys(ns, key_pattern)
            if matched_keys:
                ret = self.get(ns, matched_keys)
        return ret

    def remove(self, ns: str, keys: List[str]) -> None:
        """Delete 'keys' from namespace 'ns'."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(*db_keys)

    def remove_if(self, ns: str, key: str, data: bytes) -> bool:
        """Delete 'key' only if its current value is 'data'."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            # 'DELIE' reply is converted to bool by the callback from __create_redis_client.
            return self.__getClient(ns).execute_command('DELIE', db_key, data)

    def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Add 'members' to the Redis set 'group' (SADD)."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).sadd(db_key, *members)

    def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Remove 'members' from the Redis set 'group' (SREM)."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).srem(db_key, *members)

    def remove_group(self, ns: str, group: str) -> None:
        """Delete the whole Redis set 'group'."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(db_key)

    def get_members(self, ns: str, group: str) -> Set[bytes]:
        """Return all members of the Redis set 'group' (SMEMBERS)."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).smembers(db_key)

    def is_member(self, ns: str, group: str, member: bytes) -> bool:
        """Return whether 'member' belongs to the Redis set 'group' (SISMEMBER)."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).sismember(db_key, member)

    def group_size(self, ns: str, group: str) -> int:
        """Return the number of members in the Redis set 'group' (SCARD)."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).scard(db_key)

    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                        data_map: Dict[str, bytes]) -> None:
        """Write 'data_map' and publish the given events with one MSETMPUB command."""
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            # MSETMPUB <kv pair count> <channel/event pair count> k1 v1 ... ch1 ev1 ...
            return self.__getClient(ns).execute_command(
                "MSETMPUB",
                len(db_data_map),
                total_events,
                *[val for data in db_data_map.items() for val in data],
                *channels_and_events_prepared,
            )

    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                           old_data: bytes, new_data: bytes) -> bool:
        """Conditionally set 'key' (SETIEMPUB); return True if the server replied OK."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                                      key: str, data: bytes) -> bool:
        """Set 'key' if absent (SETNXMPUB); return True if the server replied OK."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                           keys: List[str]) -> None:
        """Delete 'keys' and publish the given events with one DELMPUB command."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command(
                "DELMPUB",
                len(db_keys),
                total_events,
                *db_keys,
                *channels_and_events_prepared,
            )

    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                              data: bytes) -> bool:
        """Conditionally delete 'key' (DELIEMPUB); return the truthiness of the server reply."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return bool(ret)

    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
        """Delete every key of namespace 'ns' and publish the given events (DELMPUB)."""
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        client = self.__getClient(ns)
        with _map_to_sdl_exception():
            # The KEYS lookup talks to Redis too, so it must be covered by the
            # exception mapping (previously its errors escaped as raw redis exceptions).
            keys = client.keys(self.__add_key_ns_prefix(ns, "*"))
            return client.execute_command(
                "DELMPUB",
                len(keys),
                total_events,
                *keys,
                *channels_and_events_prepared,
            )

    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
                          channels: List[str]) -> None:
        """
        Subscribe 'cb' to 'channels' of namespace 'ns'.

        If start_event_listener() has already been called, the listener thread for the
        namespace's client is started here in case it is not yet running.
        """
        channels = self.__add_keys_ns_prefix(ns, channels)
        for channel in channels:
            with _map_to_sdl_exception():
                redis_ctx = self.__getClientConn(ns)
                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                                   daemon=True)

    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
        """Unsubscribe from 'channels' of namespace 'ns'."""
        channels = self.__add_keys_ns_prefix(ns, channels)
        for channel in channels:
            with _map_to_sdl_exception():
                self.__getClientConn(ns).redis_pubsub.unsubscribe(channel)

    def start_event_listener(self) -> None:
        """
        Switch every client to threaded event delivery.

        A listener thread is started right away for clients that already have subscriptions.
        Other clients are marked so that subscribe_channel() starts the thread later.

        Raises:
            RejectedByBackend: If a listener thread is already running.
        """
        redis_ctxs = self.__getClientConns()
        for redis_ctx in redis_ctxs:
            if redis_ctx.pubsub_thread.is_alive():
                raise RejectedByBackend("Event loop already started")
            # PUBSUB CHANNELS is a Redis round trip; map its errors like every other call.
            with _map_to_sdl_exception():
                if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0:
                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
            redis_ctx.run_in_thread = True

    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
        """
        Poll one client, chosen round robin, for a single pending event.

        Returns:
            The value of PubSub.handle_message for the received message, or None.

        Raises:
            RejectedByBackend: If threaded event delivery is active for the polled client.
        """
        if self.next_client_event >= len(self.clients):
            self.next_client_event = 0
        redis_ctx = self.clients[self.next_client_event]
        self.next_client_event += 1
        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
            raise RejectedByBackend("Event loop already started")
        with _map_to_sdl_exception():
            try:
                return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
            except RuntimeError:
                # NOTE(review): presumably raised by redis-py when the pubsub has no
                # connection yet (nothing subscribed); treated as "no event".
                return None

    def __create_redis_clients(self, config):
        """Create one _RedisConn per configured database address."""
        clients = list()
        cfg_params = config.get_params()
        for i, addr in enumerate(cfg_params.db_cluster_addrs):
            # Missing per-address settings default to "" (no sentinel / default port).
            port = cfg_params.db_ports[i] if i < len(cfg_params.db_ports) else ""
            sport = cfg_params.db_sentinel_ports[i] if i < len(cfg_params.db_sentinel_ports) else ""
            name = cfg_params.db_sentinel_master_names[i] if i < len(cfg_params.db_sentinel_master_names) else ""

            client = self.__create_redis_client(addr, port, sport, name)
            clients.append(client)
        return clients

    def __create_redis_client(self, addr, port, sentinel_port, master_name):
        """Create the client, pub/sub object and (unstarted) listener thread for one database."""
        if len(sentinel_port) == 0:
            # No sentinel configured: connect directly to the Redis server.
            new_redis = Redis(host=addr, port=port, db=0, max_connections=20)
        else:
            # Resolve the current master 'master_name' through the sentinel.
            sentinel_node = (addr, sentinel_port)
            new_sentinel = Sentinel([sentinel_node])
            new_redis = new_sentinel.master_for(master_name)

        # Convert replies of the conditional commands into booleans.
        new_redis.set_response_callback('SETIE', lambda r: r and str_if_bytes(r) == 'OK' or False)
        new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)

        redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
        # Placeholder thread that is never started, so is_alive() is False until a real
        # listener thread from run_in_thread() replaces it.
        pubsub_thread = threading.Thread(target=None)
        run_in_thread = False

        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)

    def __getClientConns(self):
        return self.clients

    def __getClientConn(self, ns):
        """Return the _RedisConn the namespace 'ns' is bound to."""
        clients_cnt = len(self.clients)
        client_id = self.__get_hash(ns) % clients_cnt
        return self.clients[client_id]

    def __getClient(self, ns):
        """Return the Redis client the namespace 'ns' is bound to."""
        clients_cnt = len(self.clients)
        client_id = 0
        if clients_cnt > 1:
            client_id = self.__get_hash(ns) % clients_cnt
        return self.clients[client_id].redis_client

    @classmethod
    def __get_hash(cls, text):
        return zlib.crc32(text.encode())

    @classmethod
    def __add_key_ns_prefix(cls, ns: str, key: str):
        return '{' + ns + '},' + key

    @classmethod
    def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
        return [cls.__add_key_ns_prefix(ns, k) for k in keylist]

    @classmethod
    def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
        return {cls.__add_key_ns_prefix(ns, key): val for key, val in data_dict.items()}

    @classmethod
    def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
        """
        Decode Redis keys and return, for each, the part after its first ','.

        Raises:
            RejectedByBackend: If a key is not valid UTF-8 or has no ',' separator.
        """
        ret_keys = []
        for k in nskeylist:
            try:
                redis_key = k.decode("utf-8")
            except UnicodeDecodeError as exc:
                msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
                raise RejectedByBackend(msg) from exc
            nskey = redis_key.split(',', 1)
            if len(nskey) != 2:
                msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
                raise RejectedByBackend(msg)
            ret_keys.append(nskey[1])
        return ret_keys

    def _prepare_channels(self, ns: str,
                          channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
        """
        Flatten 'channels_and_events' into the argument list used by the *MPUB commands.

        For each channel with at least one event, the list gets the namespaced channel name
        followed by all of its events joined with ``event_separator``. Channels with an empty
        event list are skipped. Previously they contributed a lone None with no channel name,
        which broke the channel/event pairing.

        Returns:
            Tuple of (flat [channel, events, channel, events, ...] list, number of pairs).
        """
        channels_and_events_prepared = []
        for channel, events in channels_and_events.items():
            if not events:
                continue
            channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
            channels_and_events_prepared.append(self.event_separator.join(events))
        pairs_cnt = len(channels_and_events_prepared) // 2
        return channels_and_events_prepared, pairs_cnt

    def get_redis_connection(self, ns: str):
        """Return existing Redis database connection valid for the namespace."""
        return self.__getClient(ns)
472
473
class _RedisConn:
    """
    Internal container for the per-database state kept by RedisBackend: the command client,
    the pub/sub object, the listener thread and the flag telling whether events are
    delivered from a thread.
    """

    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
        self.redis_client, self.redis_pubsub = redis_client, pubsub
        self.pubsub_thread, self.run_in_thread = pubsub_thread, run_in_thread

    def __str__(self):
        summary = dict()
        summary["Client"] = repr(self.redis_client)
        # NOTE(review): key spelling kept as-is because existing output consumers may
        # depend on it; consider correcting to "Subscriptions".
        summary["Subscrions"] = self.redis_pubsub.subscribed
        summary["PubSub thread"] = repr(self.pubsub_thread)
        summary["Run in thread"] = self.run_in_thread
        return str(summary)
494
495
class RedisBackendLock(DbBackendLockAbc):
    """
    Redis implementation of the Shared Data Layer (SDL) database backend lock.

    The lock is a redis-py ``Lock`` whose Redis key is '{<ns>},<name>'.

    Args:
        ns (str): Namespace under which this lock is targeted.
        name (str): Lock name; together with the namespace it forms the Redis key of the lock.
        expiration (int, float): Passed to redis-py Lock as 'timeout'. The lock is removed
                                 after this time unless released earlier with 'release'.
        redis_backend (RedisBackend): Backend supplying the Redis connection of the namespace.
    """
    # Registered Lua script object. It is shared by all instances and created lazily by the
    # first lock that calls _register_scripts().
    lua_get_validity_time = None
    # KEYS[1] - lock name
    # ARGV[1] - token
    # Returns -10 if the lock key does not exist, -11 if it is held with another token,
    # otherwise the remaining lock validity time in milliseconds (PTTL).
    LUA_GET_VALIDITY_TIME_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token then
            return -10
        end
        if token ~= ARGV[1] then
            return -11
        end
        return redis.call('pttl', KEYS[1])
    """

    def __init__(self, ns: str, name: str, expiration: Union[int, float],
                 redis_backend: RedisBackend) -> None:
        super().__init__(ns, name)
        self.__redis = redis_backend.get_redis_connection(ns)
        with _map_to_sdl_exception():
            lock_key = '{' + ns + '},' + self._lock_name
            self.__redis_lock = Lock(redis=self.__redis, name=lock_key, timeout=expiration)
            self._register_scripts()

    def __str__(self):
        description = dict()
        description["lock DB type"] = "Redis"
        description["lock namespace"] = self._ns
        description["lock name"] = self._lock_name
        description["lock status"] = self._lock_status_to_string()
        return str(description)

    def acquire(self, retry_interval: Union[int, float] = 0.1,
                retry_timeout: Union[int, float] = 10) -> bool:
        """Try to take the lock, retrying every 'retry_interval' for up to 'retry_timeout'."""
        # redis-py Lock sleeps 'sleep' seconds between acquisition attempts.
        self.__redis_lock.sleep = retry_interval
        with _map_to_sdl_exception():
            return self.__redis_lock.acquire(blocking_timeout=retry_timeout)

    def release(self) -> None:
        """Release the lock."""
        with _map_to_sdl_exception():
            self.__redis_lock.release()

    def refresh(self) -> None:
        """Reset the lock's expiration time (redis-py Lock.reacquire)."""
        with _map_to_sdl_exception():
            self.__redis_lock.reacquire()

    def get_validity_time(self) -> Union[int, float]:
        """
        Return the remaining validity time of the held lock.

        The value is in seconds: an int when it is a whole number, otherwise a float.

        Raises:
            RejectedByBackend: If this instance holds no token, or the Lua script reports
                               an error code.
        """
        token = self.__redis_lock.local.token
        if token is None:
            msg = u'Cannot get validity time of an unlocked lock %s' % self._lock_name
            raise RejectedByBackend(msg)

        with _map_to_sdl_exception():
            validity_ms = self.lua_get_validity_time(keys=[self.__redis_lock.name],
                                                     args=[token],
                                                     client=self.__redis)
        if validity_ms < 0:
            raise RejectedByBackend(u'Getting validity time of a lock %s failed with error code: %d'
                                    % (self._lock_name, validity_ms))
        seconds = validity_ms / 1000.0
        return int(seconds) if seconds.is_integer() else seconds

    def _register_scripts(self):
        """Register the validity-time Lua script once and cache it on the class."""
        owner = type(self)
        if owner.lua_get_validity_time is None:
            owner.lua_get_validity_time = self.__redis.register_script(owner.LUA_GET_VALIDITY_TIME_SCRIPT)

    def _lock_status_to_string(self) -> str:
        """Describe the lock state; redis errors are reported in the text instead of raised."""
        try:
            if not self.__redis_lock.locked():
                return 'unlocked'
            return 'locked' if self.__redis_lock.owned() else 'locked by someone else'
        except redis_exceptions.RedisError as exc:
            return f'Error: {str(exc)}'