bc4b43b107c9f7cd2b3fb345b8e19cfee18f0334
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 import threading
25 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
26 import zlib
27 import redis
28 from redis import Redis
29 from redis.sentinel import Sentinel
30 from redis.lock import Lock
31 from redis._compat import nativestr
32 from redis import exceptions as redis_exceptions
33 from ricsdl.configuration import _Configuration
34 from ricsdl.exceptions import (
35     RejectedByBackend,
36     NotConnected,
37     BackendError
38 )
39 from .dbbackend_abc import DbBackendAbc
40 from .dbbackend_abc import DbBackendLockAbc
41
42
@contextlib.contextmanager
def _map_to_sdl_exception():
    """Translates known redis exceptions into SDL exceptions.

    Use as ``with _map_to_sdl_exception(): <redis calls>``. A redis-py error raised inside the
    ``with`` body is re-raised as an SDL exception chained to the original one:

    * ``ResponseError``                    -> ``RejectedByBackend``
    * ``ConnectionError`` / ``TimeoutError`` -> ``NotConnected``
    * any other ``RedisError``             -> ``BackendError``

    Any exception that is not a redis-py error propagates unchanged.
    """
    try:
        yield
    # Order matters: the specific redis-py errors derive from RedisError, so they must be
    # matched before the generic RedisError clause.
    except redis_exceptions.ResponseError as err:
        raise RejectedByBackend(f"SDL backend rejected the request: {err}") from err
    except (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as err:
        raise NotConnected(f"SDL not connected to backend: {err}") from err
    except redis_exceptions.RedisError as err:
        raise BackendError(f"SDL backend failed to process the request: {err}") from err
57
58
class PubSub(redis.client.PubSub):
    """
    Redis pub/sub object that delivers SDL notifications to subscriber callbacks.

    A registered handler is called with two arguments instead of the raw message. The first
    is the channel name without its namespace prefix. The second is the payload decoded from
    UTF-8 and split on the event separator: a single string when there is one event, otherwise
    a list of strings.

    Args:
        event_separator (str): Separator used to split a published payload into events.
        connection_pool: Redis connection pool passed on to the redis-py base class.
        ignore_subscribe_messages (bool): Whether subscribe/unsubscribe confirmations are
                                          dropped instead of being returned.
    """
    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
        super().__init__(connection_pool, shard_hint=None,
                         ignore_subscribe_messages=ignore_subscribe_messages)
        self.event_separator = event_separator

    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py

        Returns:
            A (channel, notification) tuple when a handler was invoked, None for an ignored
            subscribe/unsubscribe message, otherwise the parsed message dictionary.
        """
        message_type = nativestr(response[0])
        if message_type == 'pmessage':
            message = {
                'type': message_type,
                'pattern': response[1],
                'channel': response[2],
                'data': response[3]
            }
        elif message_type == 'pong':
            message = {
                'type': message_type,
                'pattern': None,
                'channel': None,
                'data': response[1]
            }
        else:
            message = {
                'type': message_type,
                'pattern': None,
                'channel': response[1],
                'data': response[2]
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == 'punsubscribe':
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == 'pmessage':
                handler = self.patterns.get(message['pattern'], None)
            else:
                handler = self.channels.get(message['channel'], None)
            if handler:
                # Need to send only channel and notification instead of raw
                # message
                message_channel = self._strip_ns_from_bin_key('', message['channel'])
                message_data = message['data'].decode('utf-8')
                messages = message_data.split(self.event_separator)
                notification = messages[0] if len(messages) == 1 else messages
                handler(message_channel, notification)
                return message_channel, notification
        elif message_type != 'pong':
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    @classmethod
    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
        """
        Decode a namespaced Redis key and return the part after the first ','.

        Args:
            ns (str): Namespace, used only in error messages.
            nskey (bytes): Key as received from Redis, e.g. b'{ns},key'.

        Raises:
            RejectedByBackend: If the key is not valid UTF-8 or contains no ','.
        """
        try:
            redis_key = nskey.decode('utf-8')
        except UnicodeDecodeError as exc:
            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
            # Chain the decode error so the root cause stays visible to the caller.
            raise RejectedByBackend(msg) from exc
        prefix_and_key = redis_key.split(',', 1)
        if len(prefix_and_key) != 2:
            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
            raise RejectedByBackend(msg)
        return prefix_and_key[1]
143
144
145 class RedisBackend(DbBackendAbc):
146     """
147     A class providing an implementation of database backend of Shared Data Layer (SDL), when
148     backend database solution is Redis.
149
150     Args:
151         configuration (_Configuration): SDL configuration, containing credentials to connect to
152                                         Redis database backend.
153     """
154     def __init__(self, configuration: _Configuration) -> None:
155         super().__init__()
156         self.next_client_event = 0
157         self.event_separator = configuration.get_event_separator()
158         self.clients = list()
159         with _map_to_sdl_exception():
160             self.clients = self.__create_redis_clients(configuration)
161
162     def __del__(self):
163         self.close()
164
165     def __str__(self):
166         out = {"DB type": "Redis"}
167         for i, r in enumerate(self.clients):
168             out["Redis client[" + str(i) + "]"] = str(r)
169         return str(out)
170
171     def is_connected(self):
172         is_connected = True
173         with _map_to_sdl_exception():
174             for c in self.clients:
175                 if not c.redis_client.ping():
176                     is_connected = False
177                     break
178         return is_connected
179
180     def close(self):
181         for c in self.clients:
182             c.redis_client.close()
183
184     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
185         db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
186         with _map_to_sdl_exception():
187             self.__getClient(ns).mset(db_data_map)
188
189     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
190         db_key = self.__add_key_ns_prefix(ns, key)
191         with _map_to_sdl_exception():
192             return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)
193
194     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
195         db_key = self.__add_key_ns_prefix(ns, key)
196         with _map_to_sdl_exception():
197             return self.__getClient(ns).setnx(db_key, data)
198
199     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
200         ret = dict()
201         db_keys = self.__add_keys_ns_prefix(ns, keys)
202         with _map_to_sdl_exception():
203             values = self.__getClient(ns).mget(db_keys)
204             for idx, val in enumerate(values):
205                 # return only key values, which has a value
206                 if val is not None:
207                     ret[keys[idx]] = val
208             return ret
209
210     def find_keys(self, ns: str, key_pattern: str) -> List[str]:
211         db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
212         with _map_to_sdl_exception():
213             ret = self.__getClient(ns).keys(db_key_pattern)
214             return self.__strip_ns_from_bin_keys(ns, ret)
215
216     def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
217         # todo: replace below implementation with redis 'NGET' module
218         ret = dict()  # type: Dict[str, bytes]
219         with _map_to_sdl_exception():
220             matched_keys = self.find_keys(ns, key_pattern)
221             if matched_keys:
222                 ret = self.get(ns, matched_keys)
223         return ret
224
225     def remove(self, ns: str, keys: List[str]) -> None:
226         db_keys = self.__add_keys_ns_prefix(ns, keys)
227         with _map_to_sdl_exception():
228             self.__getClient(ns).delete(*db_keys)
229
230     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
231         db_key = self.__add_key_ns_prefix(ns, key)
232         with _map_to_sdl_exception():
233             return self.__getClient(ns).execute_command('DELIE', db_key, data)
234
235     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
236         db_key = self.__add_key_ns_prefix(ns, group)
237         with _map_to_sdl_exception():
238             self.__getClient(ns).sadd(db_key, *members)
239
240     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
241         db_key = self.__add_key_ns_prefix(ns, group)
242         with _map_to_sdl_exception():
243             self.__getClient(ns).srem(db_key, *members)
244
245     def remove_group(self, ns: str, group: str) -> None:
246         db_key = self.__add_key_ns_prefix(ns, group)
247         with _map_to_sdl_exception():
248             self.__getClient(ns).delete(db_key)
249
250     def get_members(self, ns: str, group: str) -> Set[bytes]:
251         db_key = self.__add_key_ns_prefix(ns, group)
252         with _map_to_sdl_exception():
253             return self.__getClient(ns).smembers(db_key)
254
255     def is_member(self, ns: str, group: str, member: bytes) -> bool:
256         db_key = self.__add_key_ns_prefix(ns, group)
257         with _map_to_sdl_exception():
258             return self.__getClient(ns).sismember(db_key, member)
259
260     def group_size(self, ns: str, group: str) -> int:
261         db_key = self.__add_key_ns_prefix(ns, group)
262         with _map_to_sdl_exception():
263             return self.__getClient(ns).scard(db_key)
264
265     def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
266                         data_map: Dict[str, bytes]) -> None:
267         db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
268         channels_and_events_prepared = []
269         total_events = 0
270         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
271         with _map_to_sdl_exception():
272             return self.__getClient(ns).execute_command(
273                 "MSETMPUB",
274                 len(db_data_map),
275                 total_events,
276                 *[val for data in db_data_map.items() for val in data],
277                 *channels_and_events_prepared,
278             )
279
280     def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
281                            old_data: bytes, new_data: bytes) -> bool:
282         db_key = self.__add_key_ns_prefix(ns, key)
283         channels_and_events_prepared = []
284         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
285         with _map_to_sdl_exception():
286             ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
287                                                        *channels_and_events_prepared)
288             return ret == b"OK"
289
290     def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
291                                       key: str, data: bytes) -> bool:
292         db_key = self.__add_key_ns_prefix(ns, key)
293         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
294         with _map_to_sdl_exception():
295             ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
296                                                        *channels_and_events_prepared)
297             return ret == b"OK"
298
299     def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
300                            keys: List[str]) -> None:
301         db_keys = self.__add_keys_ns_prefix(ns, keys)
302         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
303         with _map_to_sdl_exception():
304             return self.__getClient(ns).execute_command(
305                 "DELMPUB",
306                 len(db_keys),
307                 total_events,
308                 *db_keys,
309                 *channels_and_events_prepared,
310             )
311
312     def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
313                               data: bytes) -> bool:
314         db_key = self.__add_key_ns_prefix(ns, key)
315         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
316         with _map_to_sdl_exception():
317             ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
318                                                        *channels_and_events_prepared)
319             return bool(ret)
320
321     def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
322         keys = self.__getClient(ns).keys(self.__add_key_ns_prefix(ns, "*"))
323         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
324         with _map_to_sdl_exception():
325             return self.__getClient(ns).execute_command(
326                 "DELMPUB",
327                 len(keys),
328                 total_events,
329                 *keys,
330                 *channels_and_events_prepared,
331             )
332
333     def subscribe_channel(self, ns: str,
334                           cb: Union[Callable[[str, str], None], Callable[[str, List[str]], None]],
335                           channels: List[str]) -> None:
336         channels = self.__add_keys_ns_prefix(ns, channels)
337         for channel in channels:
338             with _map_to_sdl_exception():
339                 redis_ctx = self.__getClientConn(ns)
340                 redis_ctx.redis_pubsub.subscribe(**{channel: cb})
341                 if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
342                     redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
343                                                                                    daemon=True)
344
345     def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
346         channels = self.__add_keys_ns_prefix(ns, channels)
347         for channel in channels:
348             with _map_to_sdl_exception():
349                 self.__getClientConn(ns).redis_pubsub.unsubscribe(channel)
350
351     def start_event_listener(self) -> None:
352         redis_ctxs = self.__getClientConns()
353         for redis_ctx in redis_ctxs:
354             if redis_ctx.pubsub_thread.is_alive():
355                 raise RejectedByBackend("Event loop already started")
356             if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0:
357                 redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
358             redis_ctx.run_in_thread = True
359
360     def handle_events(self) -> Optional[Union[Tuple[str, str], Tuple[str, List[str]]]]:
361         if self.next_client_event >= len(self.clients):
362             self.next_client_event = 0
363         redis_ctx = self.clients[self.next_client_event]
364         self.next_client_event += 1
365         if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
366             raise RejectedByBackend("Event loop already started")
367         try:
368             return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
369         except RuntimeError:
370             return None
371
372     def __create_redis_clients(self, config):
373         clients = list()
374         cfg_params = config.get_params()
375         if cfg_params.db_cluster_addr_list is None:
376             clients.append(self.__create_legacy_redis_client(cfg_params))
377         else:
378             for addr in cfg_params.db_cluster_addr_list.split(","):
379                 client = self.__create_redis_client(cfg_params, addr)
380                 clients.append(client)
381         return clients
382
383     def __create_legacy_redis_client(self, cfg_params):
384         return self.__create_redis_client(cfg_params, cfg_params.db_host)
385
386     def __create_redis_client(self, cfg_params, addr):
387         new_sentinel = None
388         new_redis = None
389         if cfg_params.db_sentinel_port is None:
390             new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
391         else:
392             sentinel_node = (addr, cfg_params.db_sentinel_port)
393             master_name = cfg_params.db_sentinel_master_name
394             new_sentinel = Sentinel([sentinel_node])
395             new_redis = new_sentinel.master_for(master_name)
396
397         new_redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
398         new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
399
400         redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
401         pubsub_thread = threading.Thread(target=None)
402         run_in_thread = False
403
404         return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)
405
406     def __getClientConns(self):
407         return self.clients
408
409     def __getClientConn(self, ns):
410         clients_cnt = len(self.clients)
411         client_id = self.__get_hash(ns) % clients_cnt
412         return self.clients[client_id]
413
414     def __getClient(self, ns):
415         clients_cnt = len(self.clients)
416         client_id = 0
417         if clients_cnt > 1:
418             client_id = self.__get_hash(ns) % clients_cnt
419         return self.clients[client_id].redis_client
420
421     @classmethod
422     def __get_hash(cls, str):
423         return zlib.crc32(str.encode())
424
425     @classmethod
426     def __add_key_ns_prefix(cls, ns: str, key: str):
427         return '{' + ns + '},' + key
428
429     @classmethod
430     def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
431         ret_nskeys = []
432         for k in keylist:
433             ret_nskeys.append('{' + ns + '},' + k)
434         return ret_nskeys
435
436     @classmethod
437     def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
438         ret_nsdict = {}
439         for key, val in data_dict.items():
440             ret_nsdict['{' + ns + '},' + key] = val
441         return ret_nsdict
442
443     @classmethod
444     def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
445         ret_keys = []
446         for k in nskeylist:
447             try:
448                 redis_key = k.decode("utf-8")
449             except UnicodeDecodeError as exc:
450                 msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
451                 raise RejectedByBackend(msg)
452             nskey = redis_key.split(',', 1)
453             if len(nskey) != 2:
454                 msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
455                 raise RejectedByBackend(msg)
456             ret_keys.append(nskey[1])
457         return ret_keys
458
459     def _prepare_channels(self, ns: str,
460                           channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
461         channels_and_events_prepared = []
462         for channel, events in channels_and_events.items():
463             one_channel_join_events = None
464             for event in events:
465                 if one_channel_join_events is None:
466                     channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
467                     one_channel_join_events = event
468                 else:
469                     one_channel_join_events = one_channel_join_events + self.event_separator + event
470             channels_and_events_prepared.append(one_channel_join_events)
471         pairs_cnt = int(len(channels_and_events_prepared) / 2)
472         return channels_and_events_prepared, pairs_cnt
473
474     def get_redis_connection(self, ns: str):
475         """Return existing Redis database connection valid for the namespace."""
476         return self.__getClient(ns)
477
478
class _RedisConn:
    """
    Internal class container to hold redis client connection.

    Groups the objects the Redis backend keeps per database address:

    * the Redis client used for commands;
    * the pub/sub object used for channel subscriptions;
    * the thread object for the pub/sub event loop;
    * a flag telling whether events are to be processed in that thread rather than polled
      through handle_events.
    """

    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
        self.redis_client = redis_client
        self.redis_pubsub = pubsub
        self.pubsub_thread = pubsub_thread
        self.run_in_thread = run_in_thread

    def __str__(self):
        # Diagnostic summary of the connection state.
        return str(
            {
                "Client": repr(self.redis_client),
                "Subscriptions": self.redis_pubsub.subscribed,
                "PubSub thread": repr(self.pubsub_thread),
                "Run in thread": self.run_in_thread,
            }
        )
499
500
class RedisBackendLock(DbBackendLockAbc):
    """
    A class providing an implementation of database backend lock of Shared Data Layer (SDL), when
    backend database solution is Redis.

    The lock is a redis-py Lock stored under the key '{<ns>},<name>'. Redis errors raised by
    the lock operations are translated into SDL exceptions.

    Args:
        ns (str): Namespace under which this lock is targeted.
        name (str): Lock name, identifies the lock key in a Redis database backend.
        expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
                                 been released earlier by a 'release' method.
        redis_backend (RedisBackend): Database backend object containing connection to Redis
                                      database.
    """
    # Registered Lua script object, shared by all instances; filled lazily on first use.
    lua_get_validity_time = None
    # KEYS[1] - lock name
    # ARGV[1] - token
    # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
    LUA_GET_VALIDITY_TIME_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token then
            return -10
        end
        if token ~= ARGV[1] then
            return -11
        end
        return redis.call('pttl', KEYS[1])
    """

    def __init__(self, ns: str, name: str, expiration: Union[int, float],
                 redis_backend: RedisBackend) -> None:
        super().__init__(ns, name)
        self.__redis = redis_backend.get_redis_connection(ns)
        with _map_to_sdl_exception():
            lock_key = '{' + ns + '},' + self._lock_name
            self.__redis_lock = Lock(redis=self.__redis, name=lock_key, timeout=expiration)
            self._register_scripts()

    def __str__(self):
        summary = {
            "lock DB type": "Redis",
            "lock namespace": self._ns,
            "lock name": self._lock_name,
            "lock status": self._lock_status_to_string(),
        }
        return str(summary)

    def acquire(self, retry_interval: Union[int, float] = 0.1,
                retry_timeout: Union[int, float] = 10) -> bool:
        """Try to take the lock, polling every retry_interval seconds for up to retry_timeout."""
        self.__redis_lock.sleep = retry_interval
        with _map_to_sdl_exception():
            return self.__redis_lock.acquire(blocking_timeout=retry_timeout)

    def release(self) -> None:
        """Release the lock."""
        with _map_to_sdl_exception():
            self.__redis_lock.release()

    def refresh(self) -> None:
        """Reset the lock expiration time back to its full timeout (Lock.reacquire)."""
        with _map_to_sdl_exception():
            self.__redis_lock.reacquire()

    def get_validity_time(self) -> Union[int, float]:
        """Return the remaining lifetime of the held lock in seconds.

        The value is an int when it is a whole number of seconds, otherwise a float.

        Raises:
            RejectedByBackend: If this instance holds no token, or if the Lua script reports
                               that the key is missing or owned by another token.
        """
        held_token = self.__redis_lock.local.token
        if held_token is None:
            raise RejectedByBackend(
                u'Cannot get validity time of an unlocked lock %s' % self._lock_name)

        with _map_to_sdl_exception():
            millis = self.lua_get_validity_time(keys=[self.__redis_lock.name],
                                                args=[held_token],
                                                client=self.__redis)
        if millis < 0:
            raise RejectedByBackend(
                u'Getting validity time of a lock %s failed with error code: %d'
                % (self._lock_name, millis))
        seconds = millis / 1000.0
        return int(seconds) if seconds.is_integer() else seconds

    def _register_scripts(self):
        """Register the validity-time Lua script once, on the class, using this client."""
        owner = self.__class__
        if owner.lua_get_validity_time is not None:
            return
        owner.lua_get_validity_time = self.__redis.register_script(
            owner.LUA_GET_VALIDITY_TIME_SCRIPT)

    def _lock_status_to_string(self) -> str:
        """Describe the lock state; a Redis error is reported in the text, not raised."""
        lock = self.__redis_lock
        try:
            if not lock.locked():
                return 'unlocked'
            return 'locked' if lock.owned() else 'locked by someone else'
        except redis_exceptions.RedisError as exc:
            return f'Error: {str(exc)}'