Implement sentinel based DB capacity scaling
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 import threading
25 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
26 import zlib
27 import redis
28 from redis import Redis
29 from redis.sentinel import Sentinel
30 from redis.lock import Lock
31 from redis._compat import nativestr
32 from redis import exceptions as redis_exceptions
33 from ricsdl.configuration import _Configuration
34 from ricsdl.exceptions import (
35     RejectedByBackend,
36     NotConnected,
37     BackendError
38 )
39 from .dbbackend_abc import DbBackendAbc
40 from .dbbackend_abc import DbBackendLockAbc
41
42
@contextlib.contextmanager
def _map_to_sdl_exception():
    """
    Context manager that converts redis-py exceptions raised in its body into SDL exceptions.

    Mapping (checked in this order, the redis exception is kept as ``__cause__``):
        redis ResponseError                   -> RejectedByBackend
        redis ConnectionError / TimeoutError  -> NotConnected
        any other redis RedisError            -> BackendError
    Exceptions that are not redis errors pass through unchanged.
    """
    try:
        yield
    except redis_exceptions.ResponseError as exc:
        raise RejectedByBackend(f"SDL backend rejected the request: {exc!s}") from exc
    except (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
        raise NotConnected(f"SDL not connected to backend: {exc!s}") from exc
    except redis_exceptions.RedisError as exc:
        raise BackendError(f"SDL backend failed to process the request: {exc!s}") from exc
55         raise BackendError("SDL backend failed to process the request: {}".
56                            format(str(exc))) from exc
57
58
class PubSub(redis.client.PubSub):
    """
    redis-py ``PubSub`` whose message handlers receive SDL style notifications.

    A handler registered for a channel is called with two strings: the channel name with
    everything up to and including its first ',' (the namespace prefix) removed, and the
    UTF-8 decoded message payload.
    """

    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parse one raw pub/sub response and dispatch it to a registered handler.

        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py

        Args:
            response: Raw pub/sub reply; element 0 is the message type.
            ignore_subscribe_messages (bool): Return None for subscribe/unsubscribe
                confirmations. The instance attribute of the same name has the same effect.

        Returns:
            A ``(channel, data)`` tuple of strings when a handler was invoked (the same
            values the handler received), None for an ignored subscribe/unsubscribe
            confirmation, otherwise the parsed message dict with the keys
            'type', 'pattern', 'channel' and 'data'.
        """
        msg_type = nativestr(response[0])
        # Pattern messages carry the matched pattern before the channel, a pong carries
        # only data, and every other reply is (type, channel, data).
        if msg_type == 'pmessage':
            pattern, channel, data = response[1], response[2], response[3]
        elif msg_type == 'pong':
            pattern, channel, data = None, None, response[1]
        else:
            pattern, channel, data = None, response[1], response[2]
        message = {'type': msg_type, 'pattern': pattern, 'channel': channel, 'data': data}

        if msg_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            # The server confirmed an unsubscribe: drop the registration, but only when
            # that unsubscribe is still pending on our side.
            if msg_type == 'punsubscribe':
                pending, registered = self.pending_unsubscribe_patterns, self.patterns
            else:
                pending, registered = self.pending_unsubscribe_channels, self.channels
            name = response[1]
            if name in pending:
                pending.remove(name)
                registered.pop(name, None)

        if msg_type in self.PUBLISH_MESSAGE_TYPES:
            if msg_type == 'pmessage':
                handler = self.patterns.get(pattern, None)
            else:
                handler = self.channels.get(channel, None)
            if handler:
                # SDL callbacks get the bare channel name and the decoded payload
                # instead of the raw message dict.
                notification = (self._strip_ns_from_bin_key('', channel), data.decode('utf-8'))
                handler(*notification)
                return notification
        elif msg_type != 'pong' and (ignore_subscribe_messages or self.ignore_subscribe_messages):
            # A subscribe/unsubscribe confirmation the caller does not want to see.
            return None

        return message

    @classmethod
    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
        """
        Decode a namespaced Redis key and return the part after its first ','.

        Args:
            ns (str): Namespace; used only in error messages.
            nskey (bytes): Key as received from Redis.

        Raises:
            RejectedByBackend: If the key is not valid UTF-8 or contains no ','.
        """
        try:
            text_key = nskey.decode('utf-8')
        except UnicodeDecodeError as exc:
            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
            raise RejectedByBackend(msg)
        _, separator, key = text_key.partition(',')
        if not separator:
            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, text_key)
            raise RejectedByBackend(msg)
        return key
137
138
class RedisBackend(DbBackendAbc):
    """
    A class providing an implementation of database backend of Shared Data Layer (SDL), when
    backend database solution is Redis.

    One Redis client is created per configured database address (or a single one for the
    legacy single-host setup). A namespace is always served by the same client, selected by
    the CRC32 checksum of the namespace name, and every key is stored as ``{<ns>},<key>``.

    Args:
        configuration (_Configuration): SDL configuration, containing credentials to connect to
                                        Redis database backend.
    """
    def __init__(self, configuration: _Configuration) -> None:
        super().__init__()
        # Index of the client that handle_events() polls next (round robin).
        self.next_client_event = 0
        # Initialised before connecting so that close() (called from __del__) has a list
        # to iterate even if client creation raises.
        self.clients = list()
        with _map_to_sdl_exception():
            self.clients = self.__create_redis_clients(configuration)

    def __del__(self):
        self.close()

    def __str__(self):
        out = {"DB type": "Redis"}
        for i, r in enumerate(self.clients):
            out["Redis client[" + str(i) + "]"] = str(r)
        return str(out)

    def is_connected(self):
        """Return True if every Redis client answers PING, otherwise False."""
        with _map_to_sdl_exception():
            # all() stops at the first client that does not answer.
            return all(c.redis_client.ping() for c in self.clients)

    def close(self):
        """Close the Redis client of every configured database instance."""
        for c in self.clients:
            c.redis_client.close()

    def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
        """Write all key/value pairs of ``data_map`` into namespace ``ns`` with one MSET."""
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        with _map_to_sdl_exception():
            self.__getClient(ns).mset(db_data_map)

    def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
        """Set ``key`` to ``new_data`` only if its current value is ``old_data`` (SETIE)."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)

    def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
        """Set ``key`` to ``data`` only if the key does not exist yet (SETNX)."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).setnx(db_key, data)

    def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
        """Read ``keys`` from namespace ``ns``; keys without a value are left out of the result."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            values = self.__getClient(ns).mget(db_keys)
            # MGET answers None for a missing key; report only keys that have a value.
            return {key: val for key, val in zip(keys, values) if val is not None}

    def find_keys(self, ns: str, key_pattern: str) -> List[str]:
        """Return the keys of namespace ``ns`` matching ``key_pattern`` (namespace prefix removed)."""
        db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).keys(db_key_pattern)
            return self.__strip_ns_from_bin_keys(ns, ret)

    def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
        """Return key/value pairs of namespace ``ns`` whose keys match ``key_pattern``."""
        # todo: replace below implementation with redis 'NGET' module
        ret = dict()  # type: Dict[str, bytes]
        with _map_to_sdl_exception():
            matched_keys = self.find_keys(ns, key_pattern)
            if matched_keys:
                ret = self.get(ns, matched_keys)
        return ret

    def remove(self, ns: str, keys: List[str]) -> None:
        """Delete ``keys`` from namespace ``ns``."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(*db_keys)

    def remove_if(self, ns: str, key: str, data: bytes) -> bool:
        """Delete ``key`` only if its current value is ``data`` (DELIE)."""
        db_key = self.__add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command('DELIE', db_key, data)

    def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Add ``members`` to the set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).sadd(db_key, *members)

    def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Remove ``members`` from the set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).srem(db_key, *members)

    def remove_group(self, ns: str, group: str) -> None:
        """Delete the whole set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__getClient(ns).delete(db_key)

    def get_members(self, ns: str, group: str) -> Set[bytes]:
        """Return all members of the set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).smembers(db_key)

    def is_member(self, ns: str, group: str, member: bytes) -> bool:
        """Return whether ``member`` belongs to the set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).sismember(db_key, member)

    def group_size(self, ns: str, group: str) -> int:
        """Return the number of members in the set stored under ``group``."""
        db_key = self.__add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__getClient(ns).scard(db_key)

    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                        data_map: Dict[str, bytes]) -> None:
        """Write ``data_map`` and publish the given events in one MSETMPUB command."""
        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command(
                "MSETMPUB",
                len(db_data_map),
                total_events,
                *[val for data in db_data_map.items() for val in data],
                *channels_and_events_prepared,
            )

    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                           old_data: bytes, new_data: bytes) -> bool:
        """Conditionally replace ``key`` and publish events (SETIEMPUB); True if it was set."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                                      key: str, data: bytes) -> bool:
        """Set ``key`` if absent and publish events (SETNXMPUB); True if it was set."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return ret == b"OK"

    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                           keys: List[str]) -> None:
        """Delete ``keys`` and publish the given events in one DELMPUB command."""
        db_keys = self.__add_keys_ns_prefix(ns, keys)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            return self.__getClient(ns).execute_command(
                "DELMPUB",
                len(db_keys),
                total_events,
                *db_keys,
                *channels_and_events_prepared,
            )

    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                              data: bytes) -> bool:
        """Delete ``key`` if its value is ``data`` and publish events (DELIEMPUB)."""
        db_key = self.__add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
                                                       *channels_and_events_prepared)
            return bool(ret)

    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
        """Delete every key of namespace ``ns`` and publish the given events (DELMPUB)."""
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            # KEYS is a backend request as well, so its failures have to be mapped to
            # SDL exceptions just like the DELMPUB that follows.
            client = self.__getClient(ns)
            keys = client.keys(self.__add_key_ns_prefix(ns, "*"))
            return client.execute_command(
                "DELMPUB",
                len(keys),
                total_events,
                *keys,
                *channels_and_events_prepared,
            )

    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
                          channels: List[str]) -> None:
        """Register ``cb`` for ``channels`` of namespace ``ns``."""
        channels = self.__add_keys_ns_prefix(ns, channels)
        redis_ctx = self.__getClientConn(ns)
        for channel in channels:
            with _map_to_sdl_exception():
                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
                # After start_event_listener() has enabled threaded mode, make sure a
                # listener thread is running for this client's subscriptions.
                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                                   daemon=True)

    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
        """Unsubscribe from ``channels`` of namespace ``ns``."""
        channels = self.__add_keys_ns_prefix(ns, channels)
        for channel in channels:
            with _map_to_sdl_exception():
                self.__getClientConn(ns).redis_pubsub.unsubscribe(channel)

    def start_event_listener(self) -> None:
        """
        Switch every client to threaded event delivery.

        A listener thread is started right away for clients that already have
        subscriptions; the others get one on their first subscribe_channel() call.

        Raises:
            RejectedByBackend: If a listener thread is already running for any client.
        """
        redis_ctxs = self.__getClientConns()
        # Validate all clients before starting anything, so a rejected call does not
        # leave some listener threads started and others not.
        if any(redis_ctx.pubsub_thread.is_alive() for redis_ctx in redis_ctxs):
            raise RejectedByBackend("Event loop already started")
        for redis_ctx in redis_ctxs:
            # PUBSUB CHANNELS is a network request; map its failures like other requests.
            with _map_to_sdl_exception():
                has_channels = (redis_ctx.redis_pubsub.subscribed
                                and len(redis_ctx.redis_client.pubsub_channels()) > 0)
            if has_channels:
                redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                               daemon=True)
            redis_ctx.run_in_thread = True

    def handle_events(self) -> Optional[Tuple[str, str]]:
        """
        Poll one client for a pending notification, moving to the next client on each call.

        Returns:
            The result of the client's ``get_message``. That is the ``(channel, data)``
            tuple when a registered handler was invoked. Otherwise it is None, or the
            parsed message (see PubSub.handle_message). None is also returned when
            get_message raises RuntimeError.

        Raises:
            RejectedByBackend: If threaded event delivery is in use for the polled client.
        """
        if self.next_client_event >= len(self.clients):
            self.next_client_event = 0
        redis_ctx = self.clients[self.next_client_event]
        self.next_client_event += 1
        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
            raise RejectedByBackend("Event loop already started")
        try:
            return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
        except RuntimeError:
            # NOTE(review): presumably raised by get_message() before anything has been
            # subscribed on this client - confirm against the redis-py version in use.
            return None

    def __create_redis_clients(self, config):
        """Create one _RedisConn per configured database address."""
        cfg_params = config.get_params()
        if cfg_params.db_cluster_addr_list is None:
            return [self.__create_legacy_redis_client(cfg_params)]
        # Tolerate whitespace around the comma separated addresses ("host1, host2").
        return [self.__create_redis_client(cfg_params, addr.strip())
                for addr in cfg_params.db_cluster_addr_list.split(",")]

    def __create_legacy_redis_client(self, cfg_params):
        """Create the single client used when no cluster address list is configured."""
        return self.__create_redis_client(cfg_params, cfg_params.db_host)

    def __create_redis_client(self, cfg_params, addr):
        """
        Create a _RedisConn for ``addr``.

        Without a sentinel port the client connects to ``addr`` directly. Otherwise a
        Sentinel at ``addr`` is asked for the configured master.
        """
        if cfg_params.db_sentinel_port is None:
            new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
        else:
            sentinel_node = (addr, cfg_params.db_sentinel_port)
            master_name = cfg_params.db_sentinel_master_name
            new_sentinel = Sentinel([sentinel_node])
            new_redis = new_sentinel.master_for(master_name)

        # Convert the replies of the conditional-set/delete module commands to booleans.
        new_redis.set_response_callback('SETIE', lambda r: bool(r) and nativestr(r) == 'OK')
        new_redis.set_response_callback('DELIE', lambda r: bool(r) and int(r) == 1)

        redis_pubsub = PubSub(new_redis.connection_pool, ignore_subscribe_messages=True)
        # Placeholder thread that is never started; it is replaced by a real listener
        # thread once threaded event delivery is enabled.
        pubsub_thread = threading.Thread(target=None)
        run_in_thread = False

        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)

    def __getClientConns(self):
        """Return the _RedisConn objects of all configured database instances."""
        return self.clients

    def __getClientConn(self, ns):
        """Return the _RedisConn that serves namespace ``ns``."""
        clients_cnt = len(self.clients)
        client_id = 0
        if clients_cnt > 1:
            client_id = self.__get_hash(ns) % clients_cnt
        return self.clients[client_id]

    def __getClient(self, ns):
        """Return the Redis client that serves namespace ``ns``."""
        return self.__getClientConn(ns).redis_client

    @classmethod
    def __get_hash(cls, text):
        """Return the CRC32 checksum of ``text`` encoded as UTF-8."""
        return zlib.crc32(text.encode())

    @classmethod
    def __add_key_ns_prefix(cls, ns: str, key: str):
        """Return ``key`` prefixed with the namespace, i.e. ``{<ns>},<key>``."""
        return '{' + ns + '},' + key

    @classmethod
    def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
        """Return every key of ``keylist`` prefixed with the namespace."""
        return [cls.__add_key_ns_prefix(ns, k) for k in keylist]

    @classmethod
    def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
        """Return a copy of ``data_dict`` whose keys are prefixed with the namespace."""
        return {cls.__add_key_ns_prefix(ns, key): val for key, val in data_dict.items()}

    @classmethod
    def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
        """
        Decode Redis keys and remove their namespace prefix (everything up to the first ',').

        Raises:
            RejectedByBackend: If a key is not valid UTF-8 or has no ','.
        """
        ret_keys = []
        for k in nskeylist:
            try:
                redis_key = k.decode("utf-8")
            except UnicodeDecodeError as exc:
                msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
                raise RejectedByBackend(msg) from exc
            nskey = redis_key.split(',', 1)
            if len(nskey) != 2:
                msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
                raise RejectedByBackend(msg)
            ret_keys.append(nskey[1])
        return ret_keys

    @classmethod
    def _prepare_channels(cls, ns: str, channels_and_events: Dict[str,
                                                                  List[str]]) -> Tuple[List, int]:
        """
        Flatten ``{channel: [event, ...]}`` into the argument list of the *MPUB commands.

        Returns:
            A list alternating namespaced channel and event (one pair per event) and the
            total number of events.
        """
        channels_and_events_prepared = []
        total_events = 0
        for channel, events in channels_and_events.items():
            db_channel = cls.__add_key_ns_prefix(ns, channel)
            for event in events:
                channels_and_events_prepared.append(db_channel)
                channels_and_events_prepared.append(event)
                total_events += 1
        return channels_and_events_prepared, total_events

    def get_redis_connection(self, ns: str):
        """Return existing Redis database connection valid for the namespace."""
        return self.__getClient(ns)
466
467
468 class _RedisConn:
469     """
470     Internal class container to hold redis client connection
471     """
472
473     def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
474         self.redis_client = redis_client
475         self.redis_pubsub = pubsub
476         self.pubsub_thread = pubsub_thread
477         self.run_in_thread = run_in_thread
478
479     def __str__(self):
480         return str(
481             {
482                 "Client": repr(self.redis_client),
483                 "Subscrions": self.redis_pubsub.subscribed,
484                 "PubSub thread": repr(self.pubsub_thread),
485                 "Run in thread": self.run_in_thread,
486             }
487         )
488
489
class RedisBackendLock(DbBackendLockAbc):
    """
    Redis implementation of the Shared Data Layer (SDL) database backend lock.

    Wraps a redis-py ``Lock`` stored under the key ``{<ns>},<name>`` on the Redis
    connection that the backend uses for the namespace.

    Args:
        ns (str): Namespace under which this lock is targeted.
        name (str): Lock name, identifies the lock key in a Redis database backend.
        expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
                                 been released earlier by a 'release' method.
        redis_backend (RedisBackend): Database backend object containing connection to Redis
                                      database.
    """
    # Registered Lua script object; set once on the class and shared by all instances.
    lua_get_validity_time = None
    # Lua script giving the remaining validity time of a lock held with a given token.
    # KEYS[1] - lock name
    # ARGV[1] - token
    # Returns -10 if the lock key does not exist, -11 if it holds another token,
    # otherwise the key's PTTL in milliseconds.
    LUA_GET_VALIDITY_TIME_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token then
            return -10
        end
        if token ~= ARGV[1] then
            return -11
        end
        return redis.call('pttl', KEYS[1])
    """

    def __init__(self, ns: str, name: str, expiration: Union[int, float],
                 redis_backend: RedisBackend) -> None:
        super().__init__(ns, name)
        self.__redis = redis_backend.get_redis_connection(ns)
        with _map_to_sdl_exception():
            self.__redis_lock = Lock(redis=self.__redis,
                                     name='{' + ns + '},' + self._lock_name,
                                     timeout=expiration)
            self._register_scripts()

    def __str__(self):
        details = {
            "lock DB type": "Redis",
            "lock namespace": self._ns,
            "lock name": self._lock_name,
            "lock status": self._lock_status_to_string(),
        }
        return str(details)

    def acquire(self, retry_interval: Union[int, float] = 0.1,
                retry_timeout: Union[int, float] = 10) -> bool:
        """
        Try to take the lock, retrying every ``retry_interval`` seconds for at most
        ``retry_timeout`` seconds. Returns the result of redis-py ``Lock.acquire``.
        """
        # redis-py Lock polls with its ``sleep`` attribute as the retry interval.
        self.__redis_lock.sleep = retry_interval
        with _map_to_sdl_exception():
            return self.__redis_lock.acquire(blocking_timeout=retry_timeout)

    def release(self) -> None:
        """Release the lock."""
        with _map_to_sdl_exception():
            self.__redis_lock.release()

    def refresh(self) -> None:
        """Restart the lock expiration time (redis-py ``Lock.reacquire``)."""
        with _map_to_sdl_exception():
            self.__redis_lock.reacquire()

    def get_validity_time(self) -> Union[int, float]:
        """
        Return the remaining validity time of the held lock in seconds.

        An int is returned when the value is a whole number of seconds, otherwise a float.

        Raises:
            RejectedByBackend: If this lock object holds no token, or the script reports
                               a negative value (lock missing, taken by someone else, or
                               without expiry).
        """
        token = self.__redis_lock.local.token
        if token is None:
            raise RejectedByBackend(u'Cannot get validity time of an unlocked lock %s'
                                    % self._lock_name)

        with _map_to_sdl_exception():
            validity_ms = self.lua_get_validity_time(keys=[self.__redis_lock.name],
                                                     args=[token],
                                                     client=self.__redis)
        if validity_ms < 0:
            raise RejectedByBackend(u'Getting validity time of a lock %s failed with error code: %d'
                                    % (self._lock_name, validity_ms))
        seconds = validity_ms / 1000.0
        return int(seconds) if seconds.is_integer() else seconds

    def _register_scripts(self):
        """Register the validity-time Lua script once, on first instantiation."""
        owner = type(self)
        if owner.lua_get_validity_time is None:
            owner.lua_get_validity_time = self.__redis.register_script(
                owner.LUA_GET_VALIDITY_TIME_SCRIPT)

    def _lock_status_to_string(self) -> str:
        """Describe the lock state, or the Redis error met while querying it."""
        try:
            if not self.__redis_lock.locked():
                return 'unlocked'
            return 'locked' if self.__redis_lock.owned() else 'locked by someone else'
        except redis_exceptions.RedisError as exc:
            return f'Error: {str(exc)}'
585         except(redis_exceptions.RedisError) as exc:
586             return f'Error: {str(exc)}'