Add support for notifications
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
1 # Copyright (c) 2019 AT&T Intellectual Property.
2 # Copyright (c) 2018-2019 Nokia.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 #
17 # This source code is part of the near-RT RIC (RAN Intelligent Controller)
18 # platform project (RICP).
19 #
20
21
22 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
23 import contextlib
24 import threading
25 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
26 import redis
27 from redis import Redis
28 from redis.sentinel import Sentinel
29 from redis.lock import Lock
30 from redis._compat import nativestr
31 from redis import exceptions as redis_exceptions
32 from ricsdl.configuration import _Configuration
33 from ricsdl.exceptions import (
34     RejectedByBackend,
35     NotConnected,
36     BackendError
37 )
38 from .dbbackend_abc import DbBackendAbc
39 from .dbbackend_abc import DbBackendLockAbc
40
41
42 @contextlib.contextmanager
43 def _map_to_sdl_exception():
44     """Translates known redis exceptions into SDL exceptions."""
45     try:
46         yield
47     except(redis_exceptions.ResponseError) as exc:
48         raise RejectedByBackend("SDL backend rejected the request: {}".
49                                 format(str(exc))) from exc
50     except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
51         raise NotConnected("SDL not connected to backend: {}".
52                            format(str(exc))) from exc
53     except(redis_exceptions.RedisError) as exc:
54         raise BackendError("SDL backend failed to process the request: {}".
55                            format(str(exc))) from exc
56
57
class PubSub(redis.client.PubSub):
    """
    redis-py PubSub variant used by the SDL Redis backend.

    Message handlers registered through subscribe() are called as ``handler(channel, data)``.
    ``channel`` is the channel name with its '{<ns>},' prefix removed. ``data`` is the payload
    decoded as UTF-8 text. Handlers are not given the raw redis-py message dict.
    """
    def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py

        Returns:
            A ``(channel, data)`` tuple when a handler was invoked. The tuple holds the same
            values that were passed to the handler. Returns None for a subscribe/unsubscribe
            message while such messages are ignored. Otherwise returns the parsed message dict.

        Raises:
            RejectedByBackend: the channel name is not valid UTF-8 or has no namespace prefix.
            UnicodeDecodeError: the payload of a handled message is not valid UTF-8.
        """
        message_type = nativestr(response[0])
        if message_type == 'pmessage':
            message = {
                'type': message_type,
                'pattern': response[1],
                'channel': response[2],
                'data': response[3]
            }
        elif message_type == 'pong':
            message = {
                'type': message_type,
                'pattern': None,
                'channel': None,
                'data': response[1]
            }
        else:
            message = {
                'type': message_type,
                'pattern': None,
                'channel': response[1],
                'data': response[2]
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == 'punsubscribe':
                pattern = response[1]
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == 'pmessage':
                handler = self.patterns.get(message['pattern'], None)
            else:
                handler = self.channels.get(message['channel'], None)
            if handler:
                # Send only the namespace-less channel and the decoded notification,
                # not the raw message dict. The namespace passed to the strip helper is
                # only used in its error messages.
                message_channel = self._strip_ns_from_bin_key('', message['channel'])
                message_data = message['data'].decode('utf-8')
                handler(message_channel, message_data)
                return message_channel, message_data
        elif message_type != 'pong':
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

    @classmethod
    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
        """
        Decode a namespaced Redis key and return the part after the first ','.

        Args:
            ns (str): Namespace, used only in error messages.
            nskey (bytes): Key as stored in Redis, e.g. b'{ns},key'.

        Returns:
            str: The key without its namespace prefix.

        Raises:
            RejectedByBackend: the key is not valid UTF-8 or contains no ','.
        """
        try:
            redis_key = nskey.decode('utf-8')
        except UnicodeDecodeError as exc:
            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
            raise RejectedByBackend(msg) from exc
        key_parts = redis_key.split(',', 1)
        if len(key_parts) != 2:
            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
            raise RejectedByBackend(msg)
        return key_parts[1]
136
137
class RedisBackend(DbBackendAbc):
    """
    A class providing an implementation of database backend of Shared Data Layer (SDL), when
    backend database solution is Redis.

    Keys, groups and channels are stored in Redis under a '{<ns>},' prefix.

    Args:
        configuration (_Configuration): SDL configuration, containing credentials to connect to
                                        Redis database backend. When a sentinel port is
                                        configured, the master returned by Redis Sentinel is
                                        used; otherwise db_host:db_port is connected directly.
    """
    def __init__(self, configuration: _Configuration) -> None:
        super().__init__()
        with _map_to_sdl_exception():
            params = configuration.get_params()
            if params.db_sentinel_port:
                sentinel_node = (params.db_host, params.db_sentinel_port)
                self.__sentinel = Sentinel([sentinel_node])
                self.__redis = self.__sentinel.master_for(params.db_sentinel_master_name)
            else:
                self.__redis = Redis(host=params.db_host,
                                     port=params.db_port,
                                     db=0,
                                     max_connections=20)
        # SETIE / DELIE are not core Redis commands; convert their raw replies to booleans.
        self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
        self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)

        self.__redis_pubsub = PubSub(self.__redis.connection_pool, ignore_subscribe_messages=True)
        # Never-started placeholder, so that is_alive() can be queried before a real
        # worker thread is created by start_event_listener()/subscribe_channel().
        self.pubsub_thread = threading.Thread(target=None)
        # True once start_event_listener() has been called (threaded event handling mode).
        self._run_in_thread = False

    def __del__(self):
        self.close()

    def __str__(self):
        return str(
            {
                "DB type": "Redis",
                "Redis connection": repr(self.__redis)
            }
        )

    def is_connected(self):
        """Ping Redis and return the PING result."""
        with _map_to_sdl_exception():
            return self.__redis.ping()

    def close(self):
        """Close the Redis connection."""
        self.__redis.close()

    def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
        """Write all key/value pairs of data_map under namespace ns with one MSET."""
        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
        with _map_to_sdl_exception():
            self.__redis.mset(db_data_map)

    def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
        """Set key to new_data only if its current value is old_data (SETIE command)."""
        db_key = self._add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__redis.execute_command('SETIE', db_key, new_data, old_data)

    def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
        """Set key to data only if the key does not exist yet (SETNX)."""
        db_key = self._add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__redis.setnx(db_key, data)

    def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
        """
        Read the given keys of namespace ns.

        Returns:
            Dict[str, bytes]: Values of the keys that exist; missing keys are omitted.
        """
        db_keys = self._add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            values = self.__redis.mget(db_keys)
        # MGET answers None for a missing key. An empty value (b'') is still a stored
        # value and must be returned, so test against None rather than truthiness.
        return {key: val for key, val in zip(keys, values) if val is not None}

    def find_keys(self, ns: str, key_pattern: str) -> List[str]:
        """Return the keys of namespace ns matching the Redis glob pattern key_pattern."""
        db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
        with _map_to_sdl_exception():
            ret = self.__redis.keys(db_key_pattern)
            return self._strip_ns_from_bin_keys(ns, ret)

    def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
        """Return key/value pairs of namespace ns whose keys match key_pattern."""
        # todo: replace below implementation with redis 'NGET' module
        ret = dict()  # type: Dict[str, bytes]
        with _map_to_sdl_exception():
            matched_keys = self.find_keys(ns, key_pattern)
            if matched_keys:
                ret = self.get(ns, matched_keys)
        return ret

    def remove(self, ns: str, keys: List[str]) -> None:
        """Delete the given keys of namespace ns."""
        db_keys = self._add_keys_ns_prefix(ns, keys)
        with _map_to_sdl_exception():
            self.__redis.delete(*db_keys)

    def remove_if(self, ns: str, key: str, data: bytes) -> bool:
        """Delete key only if its current value is data (DELIE command)."""
        db_key = self._add_key_ns_prefix(ns, key)
        with _map_to_sdl_exception():
            return self.__redis.execute_command('DELIE', db_key, data)

    def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Add members to the Redis set stored under group (SADD)."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__redis.sadd(db_key, *members)

    def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
        """Remove members from the Redis set stored under group (SREM)."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__redis.srem(db_key, *members)

    def remove_group(self, ns: str, group: str) -> None:
        """Delete the whole group key."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            self.__redis.delete(db_key)

    def get_members(self, ns: str, group: str) -> Set[bytes]:
        """Return all members of group (SMEMBERS)."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__redis.smembers(db_key)

    def is_member(self, ns: str, group: str, member: bytes) -> bool:
        """Return whether member belongs to group (SISMEMBER)."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__redis.sismember(db_key, member)

    def group_size(self, ns: str, group: str) -> int:
        """Return the number of members in group (SCARD)."""
        db_key = self._add_key_ns_prefix(ns, group)
        with _map_to_sdl_exception():
            return self.__redis.scard(db_key)

    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                        data_map: Dict[str, bytes]) -> None:
        """Write data_map and publish the given events in one MSETMPUB command."""
        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            # MSETMPUB <key count> <event count> k1 v1 ... kN vN ch1 ev1 ... chM evM
            return self.__redis.execute_command(
                "MSETMPUB",
                len(db_data_map),
                total_events,
                *[val for data in db_data_map.items() for val in data],
                *channels_and_events_prepared,
            )

    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                           old_data: bytes, new_data: bytes) -> bool:
        """Conditionally set key (see set_if) and publish events on success (SETIEPUB)."""
        db_key = self._add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__redis.execute_command("SETIEPUB", db_key, new_data, old_data,
                                               *channels_and_events_prepared)
            return ret == b"OK"

    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                                      key: str, data: bytes) -> bool:
        """Set key if it does not exist and publish events on success (SETNXPUB)."""
        db_key = self._add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__redis.execute_command("SETNXPUB", db_key, data,
                                               *channels_and_events_prepared)
            return ret == b"OK"

    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                           keys: List[str]) -> None:
        """Delete keys and publish the given events in one DELMPUB command."""
        db_keys = self._add_keys_ns_prefix(ns, keys)
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            return self.__redis.execute_command(
                "DELMPUB",
                len(db_keys),
                total_events,
                *db_keys,
                *channels_and_events_prepared,
            )

    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                              data: bytes) -> bool:
        """Delete key if its value is data and publish events on success (DELIEPUB)."""
        db_key = self._add_key_ns_prefix(ns, key)
        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            ret = self.__redis.execute_command("DELIEPUB", db_key, data,
                                               *channels_and_events_prepared)
            return bool(ret)

    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
        """
        Delete every key of namespace ns and publish the given events.

        Note: the KEYS lookup and the DELMPUB command are two separate round trips, so keys
        created in between are not removed.
        """
        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
        with _map_to_sdl_exception():
            # KEYS must also be inside the exception mapping so callers only see SDL errors.
            keys = self.__redis.keys(self._add_key_ns_prefix(ns, "*"))
            return self.__redis.execute_command(
                "DELMPUB",
                len(keys),
                total_events,
                *keys,
                *channels_and_events_prepared,
            )

    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
                          channels: List[str]) -> None:
        """
        Subscribe cb to the given channels of namespace ns.

        If threaded event handling was requested (start_event_listener) but no worker
        thread is running yet, the worker thread is started after the first subscription.
        """
        channels = self._add_keys_ns_prefix(ns, channels)
        for channel in channels:
            with _map_to_sdl_exception():
                self.__redis_pubsub.subscribe(**{channel: cb})
                if not self.pubsub_thread.is_alive() and self._run_in_thread:
                    self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                           daemon=True)

    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
        """Unsubscribe from the given channels of namespace ns."""
        channels = self._add_keys_ns_prefix(ns, channels)
        for channel in channels:
            with _map_to_sdl_exception():
                self.__redis_pubsub.unsubscribe(channel)

    def start_event_listener(self) -> None:
        """
        Switch to threaded event handling.

        Raises:
            RejectedByBackend: the event listener thread is already running.
        """
        if self.pubsub_thread.is_alive():
            raise RejectedByBackend("Event loop already started")
        # Decide based on this client's own subscriptions. PUBSUB CHANNELS on the server
        # also lists channels subscribed by other clients, and a worker started on a
        # pubsub without subscriptions fails. Without subscriptions the worker is started
        # later by subscribe_channel().
        if self.__redis_pubsub.channels:
            with _map_to_sdl_exception():
                self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001,
                                                                       daemon=True)
        self._run_in_thread = True

    def handle_events(self) -> Optional[Tuple[str, str]]:
        """
        Process at most one pending notification in the caller's thread.

        Returns:
            The (channel, data) tuple of a handled notification, or None when nothing
            was handled.

        Raises:
            RejectedByBackend: threaded event handling is in use.
        """
        if self.pubsub_thread.is_alive() or self._run_in_thread:
            raise RejectedByBackend("Event loop already started")
        with _map_to_sdl_exception():
            try:
                return self.__redis_pubsub.get_message(ignore_subscribe_messages=True)
            except RuntimeError:
                # redis-py raises RuntimeError while nothing has been subscribed yet.
                return None

    @classmethod
    def _add_key_ns_prefix(cls, ns: str, key: str):
        return '{' + ns + '},' + key

    @classmethod
    def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
        return [cls._add_key_ns_prefix(ns, key) for key in keylist]

    @classmethod
    def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
        return {cls._add_key_ns_prefix(ns, key): val for key, val in data_dict.items()}

    @classmethod
    def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
        """
        Decode Redis keys and drop their namespace prefix.

        Raises:
            RejectedByBackend: a key is not valid UTF-8 or has no namespace prefix.
        """
        return [PubSub._strip_ns_from_bin_key(ns, nskey) for nskey in nskeylist]

    @classmethod
    def _prepare_channels(cls, ns: str, channels_and_events: Dict[str,
                                                                  List[str]]) -> Tuple[List, int]:
        """
        Flatten {channel: [events]} into [ch, ev, ch, ev, ...] with namespaced channels.

        Returns:
            Tuple[List, int]: The flat argument list and the total number of events.
        """
        channels_and_events_prepared = []
        total_events = 0
        for channel, events in channels_and_events.items():
            db_channel = cls._add_key_ns_prefix(ns, channel)
            for event in events:
                channels_and_events_prepared.append(db_channel)
                channels_and_events_prepared.append(event)
                total_events += 1
        return channels_and_events_prepared, total_events

    def get_redis_connection(self):
        """Return existing Redis database connection."""
        return self.__redis
415
416
class RedisBackendLock(DbBackendLockAbc):
    """
    Redis implementation of the Shared Data Layer (SDL) database backend lock.

    The lock is a redis-py Lock stored under the key '{<ns>},<name>'.

    Args:
        ns (str): Namespace under which this lock is targeted.
        name (str): Lock name, identifies the lock key in a Redis database backend.
        expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
                                 been released earlier by a 'release' method.
        redis_backend (RedisBackend): Database backend object containing connection to Redis
                                      database.
    """
    # Registered script object, shared by all instances; created lazily by _register_scripts().
    lua_get_validity_time = None
    # Lua script:
    #   KEYS[1] - lock key
    #   ARGV[1] - lock token of the caller
    # Returns -10 if the lock key does not exist, -11 if it is held with another token,
    # otherwise the remaining lock validity time in milliseconds (PTTL).
    LUA_GET_VALIDITY_TIME_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token then
            return -10
        end
        if token ~= ARGV[1] then
            return -11
        end
        return redis.call('pttl', KEYS[1])
    """

    def __init__(self, ns: str, name: str, expiration: Union[int, float],
                 redis_backend: RedisBackend) -> None:
        super().__init__(ns, name)
        self.__redis = redis_backend.get_redis_connection()
        with _map_to_sdl_exception():
            lock_key = '{' + ns + '},' + self._lock_name
            self.__redis_lock = Lock(redis=self.__redis, name=lock_key, timeout=expiration)
            self._register_scripts()

    def __str__(self):
        description = {
            "lock DB type": "Redis",
            "lock namespace": self._ns,
            "lock name": self._lock_name,
            "lock status": self._lock_status_to_string(),
        }
        return str(description)

    def acquire(self, retry_interval: Union[int, float] = 0.1,
                retry_timeout: Union[int, float] = 10) -> bool:
        """Try to take the lock, polling every retry_interval seconds for up to retry_timeout."""
        # redis-py Lock polls with its 'sleep' attribute between acquire attempts.
        self.__redis_lock.sleep = retry_interval
        with _map_to_sdl_exception():
            return self.__redis_lock.acquire(blocking_timeout=retry_timeout)

    def release(self) -> None:
        """Release the lock."""
        with _map_to_sdl_exception():
            self.__redis_lock.release()

    def refresh(self) -> None:
        """Reset the lock's expiration time back to its full timeout."""
        with _map_to_sdl_exception():
            self.__redis_lock.reacquire()

    def get_validity_time(self) -> Union[int, float]:
        """
        Return the remaining validity time of the held lock in seconds.

        A whole number of seconds is returned as int, otherwise as float.

        Raises:
            RejectedByBackend: this lock object holds no token, or the Lua script reports
                               a negative error code.
        """
        token = self.__redis_lock.local.token
        if token is None:
            raise RejectedByBackend(
                u'Cannot get validity time of an unlocked lock %s' % self._lock_name)

        with _map_to_sdl_exception():
            validity_ms = self.lua_get_validity_time(keys=[self.__redis_lock.name],
                                                     args=[token],
                                                     client=self.__redis)
        if validity_ms < 0:
            raise RejectedByBackend(
                u'Getting validity time of a lock %s failed with error code: %d'
                % (self._lock_name, validity_ms))
        seconds = validity_ms / 1000.0
        return int(seconds) if seconds.is_integer() else seconds

    def _register_scripts(self):
        """Register the validity-time Lua script once per class."""
        lock_cls = type(self)
        if lock_cls.lua_get_validity_time is not None:
            return
        lock_cls.lua_get_validity_time = self.__redis.register_script(
            lock_cls.LUA_GET_VALIDITY_TIME_SCRIPT)

    def _lock_status_to_string(self) -> str:
        """Describe the lock state; redis errors are reported in the text instead of raised."""
        try:
            if not self.__redis_lock.locked():
                return 'unlocked'
            return 'locked' if self.__redis_lock.owned() else 'locked by someone else'
        except redis_exceptions.RedisError as exc:
            return f'Error: {str(exc)}'