3ebc8cb324d5a56255e875f04ad9c9149e89ac78
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 import threading
25 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
26 import zlib
27 import redis
28 from redis import Redis
29 from redis.sentinel import Sentinel
30 from redis.lock import Lock
31 from redis.utils import str_if_bytes
32 from redis import exceptions as redis_exceptions
33 from ricsdl.configuration import _Configuration
34 from ricsdl.exceptions import (
35     RejectedByBackend,
36     NotConnected,
37     BackendError
38 )
39 from .dbbackend_abc import DbBackendAbc
40 from .dbbackend_abc import DbBackendLockAbc
41
42
@contextlib.contextmanager
def _map_to_sdl_exception():
    """Context manager that re-raises redis-py errors as SDL exceptions.

    Mapping (checked in this order):
        redis ResponseError                    -> RejectedByBackend
        redis ConnectionError / TimeoutError   -> NotConnected
        any other redis RedisError             -> BackendError

    The original redis exception is chained as the ``__cause__`` of the SDL exception.
    Exceptions that are not redis errors pass through unchanged.
    """
    try:
        yield
    except redis_exceptions.RedisError as err:
        reason = str(err)
        if isinstance(err, redis_exceptions.ResponseError):
            raise RejectedByBackend("SDL backend rejected the request: {}".format(reason)) from err
        if isinstance(err, (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError)):
            raise NotConnected("SDL not connected to backend: {}".format(reason)) from err
        raise BackendError("SDL backend failed to process the request: {}".format(reason)) from err
57
58
class PubSub(redis.client.PubSub):
    """
    redis-py PubSub whose message handlers receive SDL-style notifications.

    For a published message on a subscribed channel (or pattern) that has a handler,
    the handler is called as ``handler(channel, events)``. Here ``channel`` is the
    channel name with everything up to and including its first comma removed.
    ``events`` is the UTF-8 decoded payload split on ``event_separator``.

    Args:
        event_separator (str): Separator between events in a notification payload.
        connection_pool: redis-py connection pool the subscription connection comes from.
        ignore_subscribe_messages (bool): Forwarded to redis-py; when true, (un)subscribe
                                          confirmation messages are dropped.
    """
    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
        super().__init__(connection_pool, shard_hint=None,
                         ignore_subscribe_messages=ignore_subscribe_messages)
        self.event_separator = event_separator

    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        A handler receives only the un-namespaced channel and the list of events,
        not the raw message. The same ``(channel, events)`` tuple is returned.

        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py

        Returns:
            ``(channel, events)`` when a handler was invoked, None for an ignored
            (un)subscribe confirmation, otherwise the parsed message dict.
        """
        msg_type = str_if_bytes(response[0])
        if msg_type == 'pmessage':
            pattern, channel, data = response[1], response[2], response[3]
        elif msg_type == 'pong':
            pattern, channel, data = None, None, response[1]
        else:
            pattern, channel, data = None, response[1], response[2]
        message = {'type': msg_type, 'pattern': pattern, 'channel': channel, 'data': data}

        # Forget a channel/pattern once the server confirms an unsubscribe we asked for.
        if msg_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if msg_type == 'punsubscribe':
                pending, subscriptions = self.pending_unsubscribe_patterns, self.patterns
            else:
                pending, subscriptions = self.pending_unsubscribe_channels, self.channels
            target = response[1]
            if target in pending:
                pending.remove(target)
                subscriptions.pop(target, None)

        if msg_type in self.PUBLISH_MESSAGE_TYPES:
            if msg_type == 'pmessage':
                handler = self.patterns.get(message['pattern'], None)
            else:
                handler = self.channels.get(message['channel'], None)
            if not handler:
                return message
            # Deliver only the channel (without namespace prefix) and the events.
            sdl_channel = self._strip_ns_from_bin_key('', message['channel'])
            events = message['data'].decode('utf-8').split(self.event_separator)
            handler(sdl_channel, events)
            return sdl_channel, events

        # (Un)subscribe confirmations may be dropped on request; pong replies never are.
        if msg_type != 'pong' and (ignore_subscribe_messages or self.ignore_subscribe_messages):
            return None
        return message

    @classmethod
    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
        """Decode a Redis name and return the part after its first comma.

        Args:
            ns (str): Namespace, used only in error messages.
            nskey (bytes): Namespace-prefixed Redis name.

        Raises:
            RejectedByBackend: If the name is not valid UTF-8 or contains no comma.
        """
        try:
            decoded = nskey.decode('utf-8')
        except UnicodeDecodeError as exc:
            raise RejectedByBackend(
                u'Namespace %s key conversion to string failed: %s' % (ns, str(exc)))
        _, separator, key = decoded.partition(',')
        if not separator:
            raise RejectedByBackend(
                u'Namespace %s key:%s has no namespace prefix' % (ns, decoded))
        return key
142
143
class RedisBackend(DbBackendAbc):
    """
    A class providing an implementation of database backend of Shared Data Layer (SDL), when
    backend database solution is Redis.

    Keys, groups and channels are stored in Redis under the name ``{<ns>},<name>``.
    When a comma-separated list of database addresses is configured, one Redis client is
    created per address. Each namespace is served by the client selected by the CRC32 of
    the namespace name modulo the number of clients.

    Args:
        configuration (_Configuration): SDL configuration, containing credentials to connect to
                                        Redis database backend.
    """
    def __init__(self, configuration: _Configuration) -> None:
        super().__init__()
        self.next_client_event = 0
        self.event_separator = configuration.get_event_separator()
        self.clients = list()
        with _map_to_sdl_exception():
            self.clients = self.__create_redis_clients(configuration)

    def __del__(self):
        # __del__ also runs when __init__ raised before 'clients' was assigned. There is
        # nothing to close then, and touching the attribute would raise AttributeError.
        if hasattr(self, 'clients'):
            self.close()

    def __str__(self):
        out = {"DB type": "Redis"}
        for i, client in enumerate(self.clients):
            out[f"Redis client[{i}]"] = str(client)
        return str(out)

    def is_connected(self):
        """Return True if every Redis client answers PING, stopping at the first failure."""
        with _map_to_sdl_exception():
            return all(c.redis_client.ping() for c in self.clients)

    def close(self):
        """Close every Redis client."""
        for c in self.clients:
            c.redis_client.close()

    def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        with _map_to_sdl_exception():
            self.__getClient(ns).mset(db_data_map)

    def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
        # SETIE is not a core Redis command (presumably provided by a Redis module); its
        # reply is converted to bool by the callback registered in __create_redis_client.
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)

    def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).setnx(db_key, data)

    def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            values = self.__getClient(ns).mget(db_keys)
        # Return only the keys that have a value in the database.
        return {key: val for key, val in zip(keys, values) if val is not None}

    def find_keys(self, ns: str, key_pattern: str) -> List[str]:
        db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).keys(db_key_pattern)
            return self.__strip_ns_from_bin_keys(ns, ret)

    def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
        # todo: replace below implementation with redis 'NGET' module
        # find_keys() and get() already translate redis errors into SDL exceptions.
        matched_keys = self.find_keys(ns, key_pattern)
        if not matched_keys:
            return dict()
        return self.get(ns, matched_keys)

    def remove(self, ns: str, keys: List[str]) -> None:
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(*db_keys)

    def remove_if(self, ns: str, key: str, data: bytes) -> bool:
        # DELIE reply is converted to bool by the callback registered in __create_redis_client.
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command('DELIE', db_key, data)

    def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).sadd(db_key, *members)

    def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).srem(db_key, *members)

    def remove_group(self, ns: str, group: str) -> None:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(db_key)

    def get_members(self, ns: str, group: str) -> Set[bytes]:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).smembers(db_key)

    def is_member(self, ns: str, group: str, member: bytes) -> bool:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).sismember(db_key, member)

    def group_size(self, ns: str, group: str) -> int:
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).scard(db_key)

    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                        data_map: Dict[str, bytes]) -> None:
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            # Arguments: key count, channel/events pair count, key1, val1, ...,
            # channel1, events1, ...
            return self.__getClient(ns).execute_command(
                "MSETMPUB",
                len(db_data_map),
                total_events,
                *[val for data in db_data_map.items() for val in data],
                *channels_and_events_prepared,
            )

    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                           old_data: bytes, new_data: bytes) -> bool:
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                                      key: str, data: bytes) -> bool:
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                           keys: List[str]) -> None:
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command(
                "DELMPUB",
                len(db_keys),
                total_events,
                *db_keys,
                *channels_and_events_prepared,
            )

    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                              data: bytes) -> bool:
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return bool(ret)

    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            client = self.__getClient(ns)
            # The KEYS lookup must also be inside the mapping so that redis errors it
            # raises surface as SDL exceptions, like every other backend call.
            keys = client.keys(self.__add_key_ns_prefix(ns, "*"))
            return client.execute_command(
                "DELMPUB",
                len(keys),
                total_events,
                *keys,
                *channels_and_events_prepared,
            )

    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
                          channels: List[str]) -> None:
        db_channels = self.__add_keys_ns_prefix(ns, channels)
        redis_ctx = self.__getClientConn(ns)
        for channel in db_channels:
            with _map_to_sdl_exception():
                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
                # In thread mode, start the listener thread as soon as there is
                # something to listen to.
                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                                   daemon=True)

    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
        db_channels = self.__add_keys_ns_prefix(ns, channels)
        redis_pubsub = self.__getClientConn(ns).redis_pubsub
        for channel in db_channels:
            with _map_to_sdl_exception():
                redis_pubsub.unsubscribe(channel)

    def start_event_listener(self) -> None:
        """Switch every client to thread mode, starting listener threads where needed.

        Raises:
            RejectedByBackend: If a listener thread is already running.
        """
        for redis_ctx in self.__getClientConns():
            if redis_ctx.pubsub_thread.is_alive():
                raise RejectedByBackend("Event loop already started")
            with _map_to_sdl_exception():
                if (redis_ctx.redis_pubsub.subscribed
                        and len(redis_ctx.redis_client.pubsub_channels()) > 0):
                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(
                        sleep_time=0.001, daemon=True)
            redis_ctx.run_in_thread = True

    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
        """Poll the next client (round robin) for one pending notification.

        Raises:
            RejectedByBackend: If thread mode is in use for the polled client.
        """
        if self.next_client_event >= len(self.clients):
            self.next_client_event = 0
        redis_ctx = self.clients[self.next_client_event]
        self.next_client_event += 1
        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
            raise RejectedByBackend("Event loop already started")
        with _map_to_sdl_exception():
            try:
                return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
            except RuntimeError:
                # NOTE(review): redis-py raises RuntimeError here when no subscription
                # connection exists yet; treat that as "no event".
                return None

    def __create_redis_clients(self, config):
        cfg_params = config.get_params()
        if cfg_params.db_cluster_addr_list is None:
            return [self.__create_legacy_redis_client(cfg_params)]
        return [self.__create_redis_client(cfg_params, addr)
                for addr in cfg_params.db_cluster_addr_list.split(",")]

    def __create_legacy_redis_client(self, cfg_params):
        return self.__create_redis_client(cfg_params, cfg_params.db_host)

    def __create_redis_client(self, cfg_params, addr):
        if cfg_params.db_sentinel_port is None:
            new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
        else:
            new_sentinel = Sentinel([(addr, cfg_params.db_sentinel_port)])
            new_redis = new_sentinel.master_for(cfg_params.db_sentinel_master_name)

        # Convert replies of the conditional set/delete commands to booleans.
        new_redis.set_response_callback('SETIE', lambda r: bool(r) and str_if_bytes(r) == 'OK')
        new_redis.set_response_callback('DELIE', lambda r: bool(r) and int(r) == 1)

        redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
        # Placeholder thread that is never started: is_alive() is False until a real
        # listener thread replaces it.
        pubsub_thread = threading.Thread(target=None)
        run_in_thread = False

        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)

    def __getClientConns(self):
        return self.clients

    def __getClientConn(self, ns):
        # A single client serves every namespace; no need to hash then.
        clients_cnt = len(self.clients)
        if clients_cnt == 1:
            return self.clients[0]
        return self.clients[self.__get_hash(ns) % clients_cnt]

    def __getClient(self, ns):
        return self.__getClientConn(ns).redis_client

    @classmethod
    def __get_hash(cls, text):
        return zlib.crc32(text.encode())

    @classmethod
    def __add_key_ns_prefix(cls, ns: str, key: str):
        return '{' + ns + '},' + key

    @classmethod
    def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
        return [cls.__add_key_ns_prefix(ns, k) for k in keylist]

    @classmethod
    def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
        return {cls.__add_key_ns_prefix(ns, key): val for key, val in data_dict.items()}

    @classmethod
    def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
        """Decode Redis names and drop everything up to and including the first comma.

        Raises:
            RejectedByBackend: If a name is not valid UTF-8 or has no comma.
        """
        ret_keys = []
        for k in nskeylist:
            try:
                redis_key = k.decode("utf-8")
            except UnicodeDecodeError as exc:
                msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
                raise RejectedByBackend(msg) from exc
            nskey = redis_key.split(',', 1)
            if len(nskey) != 2:
                msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
                raise RejectedByBackend(msg)
            ret_keys.append(nskey[1])
        return ret_keys

    def _prepare_channels(self, ns: str,
                          channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
        """Flatten channels and events into the trailing arguments of the *MPUB commands.

        Returns:
            A ``[db_channel1, events1, db_channel2, events2, ...]`` list, in which each
            channel is namespace-prefixed and its events are joined with the event
            separator, together with the number of channel/events pairs in it.

        Channels with an empty event list are skipped. Otherwise they would add a lone
        None (without its channel) and misalign every following pair.
        """
        channels_and_events_prepared = []
        for channel, events in channels_and_events.items():
            if not events:
                continue
            channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
            channels_and_events_prepared.append(self.event_separator.join(events))
        return channels_and_events_prepared, len(channels_and_events_prepared) // 2

    def get_redis_connection(self, ns: str):
        """Return existing Redis database connection valid for the namespace."""
        return self.__getClient(ns)
475
476
class _RedisConn:
    """
    Internal class container to hold redis client connection.

    Args:
        redis_client: Redis client used for data commands.
        pubsub: PubSub object used for channel subscriptions through this client.
        pubsub_thread: Thread object for the PubSub event loop of this client.
        run_in_thread (bool): Whether events of this client are dispatched by a thread.
    """

    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
        self.redis_client = redis_client
        self.redis_pubsub = pubsub
        self.pubsub_thread = pubsub_thread
        self.run_in_thread = run_in_thread

    def __str__(self):
        return str(
            {
                "Client": repr(self.redis_client),
                "Subscriptions": self.redis_pubsub.subscribed,
                "PubSub thread": repr(self.pubsub_thread),
                "Run in thread": self.run_in_thread,
            }
        )
497
498
class RedisBackendLock(DbBackendLockAbc):
    """
    Redis implementation of the Shared Data Layer (SDL) database backend lock.

    The lock is a redis-py ``Lock`` stored under the key ``{<ns>},<lock name>``.

    Args:
        ns (str): Namespace under which this lock is targeted.
        name (str): Lock name, identifies the lock key in a Redis database backend.
        expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
                                 been released earlier by a 'release' method.
        redis_backend (RedisBackend): Database backend object containing connection to Redis
                                      database.
    """
    # Registered Lua script object, shared by all instances; created on first use.
    lua_get_validity_time = None
    # KEYS[1] - lock name
    # ARGS[1] - token
    # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
    LUA_GET_VALIDITY_TIME_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token then
            return -10
        end
        if token ~= ARGV[1] then
            return -11
        end
        return redis.call('pttl', KEYS[1])
    """

    def __init__(self, ns: str, name: str, expiration: Union[int, float],
                 redis_backend: RedisBackend) -> None:
        super().__init__(ns, name)
        self.__redis = redis_backend.get_redis_connection(ns)
        with _map_to_sdl_exception():
            lock_key = '{' + ns + '},' + self._lock_name
            self.__redis_lock = Lock(redis=self.__redis, name=lock_key, timeout=expiration)
            self._register_scripts()

    def __str__(self):
        details = {
            "lock DB type": "Redis",
            "lock namespace": self._ns,
            "lock name": self._lock_name,
            "lock status": self._lock_status_to_string(),
        }
        return str(details)

    def acquire(self, retry_interval: Union[int, float] = 0.1,
                retry_timeout: Union[int, float] = 10) -> bool:
        """Try to take the lock, polling every ``retry_interval`` for up to ``retry_timeout``."""
        self.__redis_lock.sleep = retry_interval
        with _map_to_sdl_exception():
            return self.__redis_lock.acquire(blocking_timeout=retry_timeout)

    def release(self) -> None:
        with _map_to_sdl_exception():
            self.__redis_lock.release()

    def refresh(self) -> None:
        """Reset the lock expiration back to its full timeout."""
        with _map_to_sdl_exception():
            self.__redis_lock.reacquire()

    def get_validity_time(self) -> Union[int, float]:
        """Return the remaining lock validity in seconds (int when it is a whole number).

        Raises:
            RejectedByBackend: If this instance does not hold a token, or the Lua script
                               reports an error (negative result).
        """
        if self.__redis_lock.local.token is None:
            raise RejectedByBackend(
                u'Cannot get validity time of an unlocked lock %s' % self._lock_name)

        with _map_to_sdl_exception():
            validity_ms = self.lua_get_validity_time(keys=[self.__redis_lock.name],
                                                     args=[self.__redis_lock.local.token],
                                                     client=self.__redis)
        if validity_ms < 0:
            raise RejectedByBackend(
                u'Getting validity time of a lock %s failed with error code: %d'
                % (self._lock_name, validity_ms))
        seconds = validity_ms / 1000.0
        return int(seconds) if seconds.is_integer() else seconds

    def _register_scripts(self):
        # Register the Lua script once per class; later instances reuse it.
        owner = self.__class__
        if owner.lua_get_validity_time is None:
            owner.lua_get_validity_time = self.__redis.register_script(
                owner.LUA_GET_VALIDITY_TIME_SCRIPT)

    def _lock_status_to_string(self) -> str:
        try:
            if not self.__redis_lock.locked():
                return 'unlocked'
            return 'locked' if self.__redis_lock.owned() else 'locked by someone else'
        except redis_exceptions.RedisError as exc:
            return f'Error: {str(exc)}'