Add support for notifications
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
old mode 100644 (file)
new mode 100755 (executable)
index 3364497..c67be2d
@@ -21,7 +21,9 @@
 
 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
 import contextlib
-from typing import (Dict, Set, List, Union)
+import threading
+from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import redis
 from redis import Redis
 from redis.sentinel import Sentinel
 from redis.lock import Lock
@@ -53,6 +55,86 @@ def _map_to_sdl_exception():
                            format(str(exc))) from exc
 
 
+class PubSub(redis.client.PubSub):
+    def handle_message(self, response, ignore_subscribe_messages=False):
+        """
+        Parses a pub/sub message. If the channel or pattern was subscribed to
+        with a message handler, the handler is invoked instead of a parsed
+        message being returned.
+
+        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py
+        """
+        message_type = nativestr(response[0])
+        if message_type == 'pmessage':
+            message = {
+                'type': message_type,
+                'pattern': response[1],
+                'channel': response[2],
+                'data': response[3]
+            }
+        elif message_type == 'pong':
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': None,
+                'data': response[1]
+            }
+        else:
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': response[1],
+                'data': response[2]
+            }
+
+        # if this is an unsubscribe message, remove it from memory
+        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
+            if message_type == 'punsubscribe':
+                pattern = response[1]
+                if pattern in self.pending_unsubscribe_patterns:
+                    self.pending_unsubscribe_patterns.remove(pattern)
+                    self.patterns.pop(pattern, None)
+            else:
+                channel = response[1]
+                if channel in self.pending_unsubscribe_channels:
+                    self.pending_unsubscribe_channels.remove(channel)
+                    self.channels.pop(channel, None)
+
+        if message_type in self.PUBLISH_MESSAGE_TYPES:
+            # if there's a message handler, invoke it
+            if message_type == 'pmessage':
+                handler = self.patterns.get(message['pattern'], None)
+            else:
+                handler = self.channels.get(message['channel'], None)
+            if handler:
+                # Need to send only channel and notification instead of raw
+                # message
+                message_channel = self._strip_ns_from_bin_key('', message['channel'])
+                message_data = message['data'].decode('utf-8')
+                handler(message_channel, message_data)
+                return message_channel, message_data
+        elif message_type != 'pong':
+            # this is a subscribe/unsubscribe message. ignore if we don't
+            # want them
+            if ignore_subscribe_messages or self.ignore_subscribe_messages:
+                return None
+
+        return message
+
+    @classmethod
+    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
+        try:
+            redis_key = nskey.decode('utf-8')
+        except UnicodeDecodeError as exc:
+            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
+            raise RejectedByBackend(msg)
+        nskey = redis_key.split(',', 1)
+        if len(nskey) != 2:
+            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
+            raise RejectedByBackend(msg)
+        return nskey[1]
+
+
 class RedisBackend(DbBackendAbc):
     """
     A class providing an implementation of database backend of Shared Data Layer (SDL), when
@@ -79,6 +161,10 @@ class RedisBackend(DbBackendAbc):
         self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
         self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
 
+        self.__redis_pubsub = PubSub(self.__redis.connection_pool, ignore_subscribe_messages=True)
+        self.pubsub_thread = threading.Thread(target=None)
+        self._run_in_thread = False
+
     def __del__(self):
         self.close()
 
@@ -178,6 +264,105 @@ class RedisBackend(DbBackendAbc):
         with _map_to_sdl_exception():
             return self.__redis.scard(db_key)
 
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                        data_map: Dict[str, bytes]) -> None:
+        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+        channels_and_events_prepared = []
+        total_events = 0
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__redis.execute_command(
+                "MSETMPUB",
+                len(db_data_map),
+                total_events,
+                *[val for data in db_data_map.items() for val in data],
+                *channels_and_events_prepared,
+            )
+
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                           old_data: bytes, new_data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        channels_and_events_prepared = []
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__redis.execute_command("SETIEPUB", db_key, new_data, old_data,
+                                               *channels_and_events_prepared)
+            return ret == b"OK"
+
+    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                                      key: str, data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__redis.execute_command("SETNXPUB", db_key, data,
+                                               *channels_and_events_prepared)
+            return ret == b"OK"
+
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                           keys: List[str]) -> None:
+        db_keys = self._add_keys_ns_prefix(ns, keys)
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__redis.execute_command(
+                "DELMPUB",
+                len(db_keys),
+                total_events,
+                *db_keys,
+                *channels_and_events_prepared,
+            )
+
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                              data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__redis.execute_command("DELIEPUB", db_key, data,
+                                               *channels_and_events_prepared)
+            return bool(ret)
+
+    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
+        keys = self.__redis.keys(self._add_key_ns_prefix(ns, "*"))
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__redis.execute_command(
+                "DELMPUB",
+                len(keys),
+                total_events,
+                *keys,
+                *channels_and_events_prepared,
+            )
+
+    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
+                          channels: List[str]) -> None:
+        channels = self._add_keys_ns_prefix(ns, channels)
+        for channel in channels:
+            with _map_to_sdl_exception():
+                self.__redis_pubsub.subscribe(**{channel: cb})
+                if not self.pubsub_thread.is_alive() and self._run_in_thread:
+                    self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001,
+                                                                           daemon=True)
+
+    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
+        channels = self._add_keys_ns_prefix(ns, channels)
+        for channel in channels:
+            with _map_to_sdl_exception():
+                self.__redis_pubsub.unsubscribe(channel)
+
+    def start_event_listener(self) -> None:
+        if self.pubsub_thread.is_alive():
+            raise RejectedByBackend("Event loop already started")
+        if len(self.__redis.pubsub_channels()) > 0:
+            self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
+        self._run_in_thread = True
+
+    def handle_events(self) -> Optional[Tuple[str, str]]:
+        if self.pubsub_thread.is_alive() or self._run_in_thread:
+            raise RejectedByBackend("Event loop already started")
+        try:
+            return self.__redis_pubsub.get_message(ignore_subscribe_messages=True)
+        except RuntimeError:
+            return None
+
     @classmethod
     def _add_key_ns_prefix(cls, ns: str, key: str):
         return '{' + ns + '},' + key
@@ -212,6 +397,18 @@ class RedisBackend(DbBackendAbc):
             ret_keys.append(nskey[1])
         return ret_keys
 
+    @classmethod
+    def _prepare_channels(cls, ns: str, channels_and_events: Dict[str,
+                                                                  List[str]]) -> Tuple[List, int]:
+        channels_and_events_prepared = []
+        total_events = 0
+        for channel, events in channels_and_events.items():
+            for event in events:
+                channels_and_events_prepared.append(cls._add_key_ns_prefix(ns, channel))
+                channels_and_events_prepared.append(event)
+                total_events += 1
+        return channels_and_events_prepared, total_events
+
     def get_redis_connection(self):
         """Return existing Redis database connection."""
         return self.__redis