Fix Flake8 reported errors (E275)
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
old mode 100644 (file)
new mode 100755 (executable)
index 3364497..d7139fb
@@ -1,5 +1,5 @@
 # Copyright (c) 2019 AT&T Intellectual Property.
-# Copyright (c) 2018-2019 Nokia.
+# Copyright (c) 2018-2022 Nokia.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 
 """The module provides implementation of Shared Data Layer (SDL) database backend interface."""
 import contextlib
-from typing import (Dict, Set, List, Union)
+import threading
+from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import zlib
+import redis
 from redis import Redis
 from redis.sentinel import Sentinel
 from redis.lock import Lock
-from redis._compat import nativestr
+from redis.utils import str_if_bytes
 from redis import exceptions as redis_exceptions
 from ricsdl.configuration import _Configuration
 from ricsdl.exceptions import (
@@ -42,17 +45,102 @@ def _map_to_sdl_exception():
     """Translates known redis exceptions into SDL exceptions."""
     try:
         yield
-    except(redis_exceptions.ResponseError) as exc:
+    except redis_exceptions.ResponseError as exc:
         raise RejectedByBackend("SDL backend rejected the request: {}".
                                 format(str(exc))) from exc
-    except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
+    except (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
         raise NotConnected("SDL not connected to backend: {}".
                            format(str(exc))) from exc
-    except(redis_exceptions.RedisError) as exc:
+    except redis_exceptions.RedisError as exc:
         raise BackendError("SDL backend failed to process the request: {}".
                            format(str(exc))) from exc
 
 
+class PubSub(redis.client.PubSub):
+    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
+        super().__init__(connection_pool, shard_hint=None, ignore_subscribe_messages=ignore_subscribe_messages)
+        self.event_separator = event_separator
+
+    def handle_message(self, response, ignore_subscribe_messages=False):
+        """
+        Parses a pub/sub message. If the channel or pattern was subscribed to
+        with a message handler, the handler is invoked instead of a parsed
+        message being returned.
+
+        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py
+        """
+        message_type = str_if_bytes(response[0])
+        if message_type == 'pmessage':
+            message = {
+                'type': message_type,
+                'pattern': response[1],
+                'channel': response[2],
+                'data': response[3]
+            }
+        elif message_type == 'pong':
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': None,
+                'data': response[1]
+            }
+        else:
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': response[1],
+                'data': response[2]
+            }
+
+        # if this is an unsubscribe message, remove it from memory
+        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
+            if message_type == 'punsubscribe':
+                pattern = response[1]
+                if pattern in self.pending_unsubscribe_patterns:
+                    self.pending_unsubscribe_patterns.remove(pattern)
+                    self.patterns.pop(pattern, None)
+            else:
+                channel = response[1]
+                if channel in self.pending_unsubscribe_channels:
+                    self.pending_unsubscribe_channels.remove(channel)
+                    self.channels.pop(channel, None)
+
+        if message_type in self.PUBLISH_MESSAGE_TYPES:
+            # if there's a message handler, invoke it
+            if message_type == 'pmessage':
+                handler = self.patterns.get(message['pattern'], None)
+            else:
+                handler = self.channels.get(message['channel'], None)
+            if handler:
+                # Need to send only channel and notification instead of raw
+                # message
+                message_channel = self._strip_ns_from_bin_key('', message['channel'])
+                message_data = message['data'].decode('utf-8')
+                messages = message_data.split(self.event_separator)
+                handler(message_channel, messages)
+                return message_channel, messages
+        elif message_type != 'pong':
+            # this is a subscribe/unsubscribe message. ignore if we don't
+            # want them
+            if ignore_subscribe_messages or self.ignore_subscribe_messages:
+                return None
+
+        return message
+
+    @classmethod
+    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
+        try:
+            redis_key = nskey.decode('utf-8')
+        except UnicodeDecodeError as exc:
+            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
+            raise RejectedByBackend(msg)
+        nskey = redis_key.split(',', 1)
+        if len(nskey) != 2:
+            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
+            raise RejectedByBackend(msg)
+        return nskey[1]
+
+
 class RedisBackend(DbBackendAbc):
     """
     A class providing an implementation of database backend of Shared Data Layer (SDL), when
@@ -64,70 +152,65 @@ class RedisBackend(DbBackendAbc):
     """
     def __init__(self, configuration: _Configuration) -> None:
         super().__init__()
+        self.next_client_event = 0
+        self.event_separator = configuration.get_event_separator()
+        self.clients = list()
         with _map_to_sdl_exception():
-            if configuration.get_params().db_sentinel_port:
-                sentinel_node = (configuration.get_params().db_host,
-                                 configuration.get_params().db_sentinel_port)
-                master_name = configuration.get_params().db_sentinel_master_name
-                self.__sentinel = Sentinel([sentinel_node])
-                self.__redis = self.__sentinel.master_for(master_name)
-            else:
-                self.__redis = Redis(host=configuration.get_params().db_host,
-                                     port=configuration.get_params().db_port,
-                                     db=0,
-                                     max_connections=20)
-        self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
-        self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+            self.clients = self.__create_redis_clients(configuration)
 
     def __del__(self):
         self.close()
 
     def __str__(self):
-        return str(
-            {
-                "DB type": "Redis",
-                "Redis connection": repr(self.__redis)
-            }
-        )
+        out = {"DB type": "Redis"}
+        for i, r in enumerate(self.clients):
+            out["Redis client[" + str(i) + "]"] = str(r)
+        return str(out)
 
     def is_connected(self):
+        is_connected = True
         with _map_to_sdl_exception():
-            return self.__redis.ping()
+            for c in self.clients:
+                if not c.redis_client.ping():
+                    is_connected = False
+                    break
+        return is_connected
 
     def close(self):
-        self.__redis.close()
+        for c in self.clients:
+            c.redis_client.close()
 
     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
-        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
         with _map_to_sdl_exception():
-            self.__redis.mset(db_data_map)
+            self.__getClient(ns).mset(db_data_map)
 
     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
+            return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)
 
     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.setnx(db_key, data)
+            return self.__getClient(ns).setnx(db_key, data)
 
     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
         ret = dict()
-        db_keys = self._add_keys_ns_prefix(ns, keys)
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
         with _map_to_sdl_exception():
-            values = self.__redis.mget(db_keys)
+            values = self.__getClient(ns).mget(db_keys)
             for idx, val in enumerate(values):
                 # return only key values, which has a value
-                if val:
+                if val is not None:
                     ret[keys[idx]] = val
             return ret
 
     def find_keys(self, ns: str, key_pattern: str) -> List[str]:
-        db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
+        db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
         with _map_to_sdl_exception():
-            ret = self.__redis.keys(db_key_pattern)
-            return self._strip_ns_from_bin_keys(ns, ret)
+            ret = self.__getClient(ns).keys(db_key_pattern)
+            return self.__strip_ns_from_bin_keys(ns, ret)
 
     def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
         # todo: replace below implementation with redis 'NGET' module
@@ -139,65 +222,221 @@ class RedisBackend(DbBackendAbc):
         return ret
 
     def remove(self, ns: str, keys: List[str]) -> None:
-        db_keys = self._add_keys_ns_prefix(ns, keys)
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
         with _map_to_sdl_exception():
-            self.__redis.delete(*db_keys)
+            self.__getClient(ns).delete(*db_keys)
 
     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command('DELIE', db_key, data)
+            return self.__getClient(ns).execute_command('DELIE', db_key, data)
 
     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.sadd(db_key, *members)
+            self.__getClient(ns).sadd(db_key, *members)
 
     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.srem(db_key, *members)
+            self.__getClient(ns).srem(db_key, *members)
 
     def remove_group(self, ns: str, group: str) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.delete(db_key)
+            self.__getClient(ns).delete(db_key)
 
     def get_members(self, ns: str, group: str) -> Set[bytes]:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            return self.__redis.smembers(db_key)
+            return self.__getClient(ns).smembers(db_key)
 
     def is_member(self, ns: str, group: str, member: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            return self.__redis.sismember(db_key, member)
+            return self.__getClient(ns).sismember(db_key, member)
 
     def group_size(self, ns: str, group: str) -> int:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            return self.__getClient(ns).scard(db_key)
+
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                        data_map: Dict[str, bytes]) -> None:
+        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
+        channels_and_events_prepared = []
+        total_events = 0
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            return self.__redis.scard(db_key)
+            return self.__getClient(ns).execute_command(
+                "MSETMPUB",
+                len(db_data_map),
+                total_events,
+                *[val for data in db_data_map.items() for val in data],
+                *channels_and_events_prepared,
+            )
+
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                           old_data: bytes, new_data: bytes) -> bool:
+        db_key = self.__add_key_ns_prefix(ns, key)
+        channels_and_events_prepared = []
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
+                                                       *channels_and_events_prepared)
+            return ret == b"OK"
+
+    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                                      key: str, data: bytes) -> bool:
+        db_key = self.__add_key_ns_prefix(ns, key)
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
+                                                       *channels_and_events_prepared)
+            return ret == b"OK"
+
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                           keys: List[str]) -> None:
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__getClient(ns).execute_command(
+                "DELMPUB",
+                len(db_keys),
+                total_events,
+                *db_keys,
+                *channels_and_events_prepared,
+            )
+
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                              data: bytes) -> bool:
+        db_key = self.__add_key_ns_prefix(ns, key)
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
+                                                       *channels_and_events_prepared)
+            return bool(ret)
+
+    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
+        keys = self.__getClient(ns).keys(self.__add_key_ns_prefix(ns, "*"))
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__getClient(ns).execute_command(
+                "DELMPUB",
+                len(keys),
+                total_events,
+                *keys,
+                *channels_and_events_prepared,
+            )
+
+    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
+                          channels: List[str]) -> None:
+        channels = self.__add_keys_ns_prefix(ns, channels)
+        for channel in channels:
+            with _map_to_sdl_exception():
+                redis_ctx = self.__getClientConn(ns)
+                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
+                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
+                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
+                                                                                   daemon=True)
+
+    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
+        channels = self.__add_keys_ns_prefix(ns, channels)
+        for channel in channels:
+            with _map_to_sdl_exception():
+                self.__getClientConn(ns).redis_pubsub.unsubscribe(channel)
+
+    def start_event_listener(self) -> None:
+        redis_ctxs = self.__getClientConns()
+        for redis_ctx in redis_ctxs:
+            if redis_ctx.pubsub_thread.is_alive():
+                raise RejectedByBackend("Event loop already started")
+            if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0:
+                redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
+            redis_ctx.run_in_thread = True
+
+    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
+        if self.next_client_event >= len(self.clients):
+            self.next_client_event = 0
+        redis_ctx = self.clients[self.next_client_event]
+        self.next_client_event += 1
+        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
+            raise RejectedByBackend("Event loop already started")
+        try:
+            return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
+        except RuntimeError:
+            return None
+
+    def __create_redis_clients(self, config):
+        clients = list()
+        cfg_params = config.get_params()
+        for i, addr in enumerate(cfg_params.db_cluster_addrs):
+            port = cfg_params.db_ports[i] if i < len(cfg_params.db_ports) else ""
+            sport = cfg_params.db_sentinel_ports[i] if i < len(cfg_params.db_sentinel_ports) else ""
+            name = cfg_params.db_sentinel_master_names[i] if i < len(cfg_params.db_sentinel_master_names) else ""
+
+            client = self.__create_redis_client(addr, port, sport, name)
+            clients.append(client)
+        return clients
+
+    def __create_redis_client(self, addr, port, sentinel_port, master_name):
+        new_sentinel = None
+        new_redis = None
+        if len(sentinel_port) == 0:
+            new_redis = Redis(host=addr, port=port, db=0, max_connections=20)
+        else:
+            sentinel_node = (addr, sentinel_port)
+            new_sentinel = Sentinel([sentinel_node])
+            new_redis = new_sentinel.master_for(master_name)
+
+        new_redis.set_response_callback('SETIE', lambda r: r and str_if_bytes(r) == 'OK' or False)
+        new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+
+        redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
+        pubsub_thread = threading.Thread(target=None)
+        run_in_thread = False
+
+        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)
+
+    def __getClientConns(self):
+        return self.clients
+
+    def __getClientConn(self, ns):
+        clients_cnt = len(self.clients)
+        client_id = self.__get_hash(ns) % clients_cnt
+        return self.clients[client_id]
+
+    def __getClient(self, ns):
+        clients_cnt = len(self.clients)
+        client_id = 0
+        if clients_cnt > 1:
+            client_id = self.__get_hash(ns) % clients_cnt
+        return self.clients[client_id].redis_client
+
+    @classmethod
+    def __get_hash(cls, str):
+        return zlib.crc32(str.encode())
 
     @classmethod
-    def _add_key_ns_prefix(cls, ns: str, key: str):
+    def __add_key_ns_prefix(cls, ns: str, key: str):
         return '{' + ns + '},' + key
 
     @classmethod
-    def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
+    def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
         ret_nskeys = []
         for k in keylist:
             ret_nskeys.append('{' + ns + '},' + k)
         return ret_nskeys
 
     @classmethod
-    def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
+    def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
         ret_nsdict = {}
         for key, val in data_dict.items():
             ret_nsdict['{' + ns + '},' + key] = val
         return ret_nsdict
 
     @classmethod
-    def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
+    def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
         ret_keys = []
         for k in nskeylist:
             try:
@@ -212,9 +451,46 @@ class RedisBackend(DbBackendAbc):
             ret_keys.append(nskey[1])
         return ret_keys
 
-    def get_redis_connection(self):
-        """Return existing Redis database connection."""
-        return self.__redis
+    def _prepare_channels(self, ns: str,
+                          channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
+        channels_and_events_prepared = []
+        for channel, events in channels_and_events.items():
+            one_channel_join_events = None
+            for event in events:
+                if one_channel_join_events is None:
+                    channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
+                    one_channel_join_events = event
+                else:
+                    one_channel_join_events = one_channel_join_events + self.event_separator + event
+            channels_and_events_prepared.append(one_channel_join_events)
+        pairs_cnt = int(len(channels_and_events_prepared) / 2)
+        return channels_and_events_prepared, pairs_cnt
+
+    def get_redis_connection(self, ns: str):
+        """Return existing Redis database connection valid for the namespace."""
+        return self.__getClient(ns)
+
+
+class _RedisConn:
+    """
+    Internal class container to hold redis client connection
+    """
+
+    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
+        self.redis_client = redis_client
+        self.redis_pubsub = pubsub
+        self.pubsub_thread = pubsub_thread
+        self.run_in_thread = run_in_thread
+
+    def __str__(self):
+        return str(
+            {
+                "Client": repr(self.redis_client),
+                "Subscrions": self.redis_pubsub.subscribed,
+                "PubSub thread": repr(self.pubsub_thread),
+                "Run in thread": self.run_in_thread,
+            }
+        )
 
 
 class RedisBackendLock(DbBackendLockAbc):
@@ -248,7 +524,7 @@ class RedisBackendLock(DbBackendLockAbc):
     def __init__(self, ns: str, name: str, expiration: Union[int, float],
                  redis_backend: RedisBackend) -> None:
         super().__init__(ns, name)
-        self.__redis = redis_backend.get_redis_connection()
+        self.__redis = redis_backend.get_redis_connection(ns)
         with _map_to_sdl_exception():
             redis_lockname = '{' + ns + '},' + self._lock_name
             self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)
@@ -312,5 +588,5 @@ class RedisBackendLock(DbBackendLockAbc):
                     return 'locked'
                 return 'locked by someone else'
             return 'unlocked'
-        except(redis_exceptions.RedisError) as exc:
+        except redis_exceptions.RedisError as exc:
             return f'Error: {str(exc)}'