Implement sentinel based DB capacity scaling
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
index c319bf0..12726a3 100755 (executable)
@@ -23,6 +23,7 @@
 import contextlib
 import threading
 from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import zlib
 import redis
 from redis import Redis
 from redis.sentinel import Sentinel
@@ -146,63 +147,53 @@ class RedisBackend(DbBackendAbc):
     """
     def __init__(self, configuration: _Configuration) -> None:
         super().__init__()
+        self.next_client_event = 0
+        self.clients = list()
         with _map_to_sdl_exception():
-            if configuration.get_params().db_sentinel_port:
-                sentinel_node = (configuration.get_params().db_host,
-                                 configuration.get_params().db_sentinel_port)
-                master_name = configuration.get_params().db_sentinel_master_name
-                self.__sentinel = Sentinel([sentinel_node])
-                self.__redis = self.__sentinel.master_for(master_name)
-            else:
-                self.__redis = Redis(host=configuration.get_params().db_host,
-                                     port=configuration.get_params().db_port,
-                                     db=0,
-                                     max_connections=20)
-        self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
-        self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
-
-        self.__redis_pubsub = PubSub(self.__redis.connection_pool, ignore_subscribe_messages=True)
-        self.pubsub_thread = threading.Thread(target=None)
-        self._run_in_thread = False
+            self.clients = self.__create_redis_clients(configuration)
 
     def __del__(self):
         self.close()
 
     def __str__(self):
-        return str(
-            {
-                "DB type": "Redis",
-                "Redis connection": repr(self.__redis)
-            }
-        )
+        out = {"DB type": "Redis"}
+        for i, r in enumerate(self.clients):
+            out["Redis client[" + str(i) + "]"] = str(r)
+        return str(out)
 
     def is_connected(self):
+        is_connected = True
         with _map_to_sdl_exception():
-            return self.__redis.ping()
+            for c in self.clients:
+                if not c.redis_client.ping():
+                    is_connected = False
+                    break
+        return is_connected
 
     def close(self):
-        self.__redis.close()
+        for c in self.clients:
+            c.redis_client.close()
 
     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
-        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
         with _map_to_sdl_exception():
-            self.__redis.mset(db_data_map)
+            self.__getClient(ns).mset(db_data_map)
 
     def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
+            return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)
 
     def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.setnx(db_key, data)
+            return self.__getClient(ns).setnx(db_key, data)
 
     def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
         ret = dict()
-        db_keys = self._add_keys_ns_prefix(ns, keys)
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
         with _map_to_sdl_exception():
-            values = self.__redis.mget(db_keys)
+            values = self.__getClient(ns).mget(db_keys)
             for idx, val in enumerate(values):
                 # return only key values, which has a value
                 if val is not None:
@@ -210,10 +201,10 @@ class RedisBackend(DbBackendAbc):
             return ret
 
     def find_keys(self, ns: str, key_pattern: str) -> List[str]:
-        db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
+        db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
         with _map_to_sdl_exception():
-            ret = self.__redis.keys(db_key_pattern)
-            return self._strip_ns_from_bin_keys(ns, ret)
+            ret = self.__getClient(ns).keys(db_key_pattern)
+            return self.__strip_ns_from_bin_keys(ns, ret)
 
     def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
         # todo: replace below implementation with redis 'NGET' module
@@ -225,53 +216,53 @@ class RedisBackend(DbBackendAbc):
         return ret
 
     def remove(self, ns: str, keys: List[str]) -> None:
-        db_keys = self._add_keys_ns_prefix(ns, keys)
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
         with _map_to_sdl_exception():
-            self.__redis.delete(*db_keys)
+            self.__getClient(ns).delete(*db_keys)
 
     def remove_if(self, ns: str, key: str, data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command('DELIE', db_key, data)
+            return self.__getClient(ns).execute_command('DELIE', db_key, data)
 
     def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.sadd(db_key, *members)
+            self.__getClient(ns).sadd(db_key, *members)
 
     def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.srem(db_key, *members)
+            self.__getClient(ns).srem(db_key, *members)
 
     def remove_group(self, ns: str, group: str) -> None:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            self.__redis.delete(db_key)
+            self.__getClient(ns).delete(db_key)
 
     def get_members(self, ns: str, group: str) -> Set[bytes]:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            return self.__redis.smembers(db_key)
+            return self.__getClient(ns).smembers(db_key)
 
     def is_member(self, ns: str, group: str, member: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            return self.__redis.sismember(db_key, member)
+            return self.__getClient(ns).sismember(db_key, member)
 
     def group_size(self, ns: str, group: str) -> int:
-        db_key = self._add_key_ns_prefix(ns, group)
+        db_key = self.__add_key_ns_prefix(ns, group)
         with _map_to_sdl_exception():
-            return self.__redis.scard(db_key)
+            return self.__getClient(ns).scard(db_key)
 
     def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                         data_map: Dict[str, bytes]) -> None:
-        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
         channels_and_events_prepared = []
         total_events = 0
         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command(
+            return self.__getClient(ns).execute_command(
                 "MSETMPUB",
                 len(db_data_map),
                 total_events,
@@ -281,29 +272,29 @@ class RedisBackend(DbBackendAbc):
 
     def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                            old_data: bytes, new_data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         channels_and_events_prepared = []
         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            ret = self.__redis.execute_command("SETIEMPUB", db_key, new_data, old_data,
-                                               *channels_and_events_prepared)
+            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
+                                                       *channels_and_events_prepared)
             return ret == b"OK"
 
     def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                                       key: str, data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            ret = self.__redis.execute_command("SETNXMPUB", db_key, data,
-                                               *channels_and_events_prepared)
+            ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data,
+                                                       *channels_and_events_prepared)
             return ret == b"OK"
 
     def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
                            keys: List[str]) -> None:
-        db_keys = self._add_keys_ns_prefix(ns, keys)
+        db_keys = self.__add_keys_ns_prefix(ns, keys)
         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command(
+            return self.__getClient(ns).execute_command(
                 "DELMPUB",
                 len(db_keys),
                 total_events,
@@ -313,18 +304,18 @@ class RedisBackend(DbBackendAbc):
 
     def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
                               data: bytes) -> bool:
-        db_key = self._add_key_ns_prefix(ns, key)
+        db_key = self.__add_key_ns_prefix(ns, key)
         channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            ret = self.__redis.execute_command("DELIEMPUB", db_key, data,
-                                               *channels_and_events_prepared)
+            ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data,
+                                                       *channels_and_events_prepared)
             return bool(ret)
 
     def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
-        keys = self.__redis.keys(self._add_key_ns_prefix(ns, "*"))
+        keys = self.__getClient(ns).keys(self.__add_key_ns_prefix(ns, "*"))
         channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
         with _map_to_sdl_exception():
-            return self.__redis.execute_command(
+            return self.__getClient(ns).execute_command(
                 "DELMPUB",
                 len(keys),
                 total_events,
@@ -334,55 +325,115 @@ class RedisBackend(DbBackendAbc):
 
     def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
                           channels: List[str]) -> None:
-        channels = self._add_keys_ns_prefix(ns, channels)
+        channels = self.__add_keys_ns_prefix(ns, channels)
         for channel in channels:
             with _map_to_sdl_exception():
-                self.__redis_pubsub.subscribe(**{channel: cb})
-                if not self.pubsub_thread.is_alive() and self._run_in_thread:
-                    self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001,
-                                                                           daemon=True)
+                redis_ctx = self.__getClientConn(ns)
+                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
+                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
+                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
+                                                                                   daemon=True)
 
     def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
-        channels = self._add_keys_ns_prefix(ns, channels)
+        channels = self.__add_keys_ns_prefix(ns, channels)
         for channel in channels:
             with _map_to_sdl_exception():
-                self.__redis_pubsub.unsubscribe(channel)
+                self.__getClientConn(ns).redis_pubsub.unsubscribe(channel)
 
     def start_event_listener(self) -> None:
-        if self.pubsub_thread.is_alive():
-            raise RejectedByBackend("Event loop already started")
-        if len(self.__redis.pubsub_channels()) > 0:
-            self.pubsub_thread = self.__redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
-        self._run_in_thread = True
+        redis_ctxs = self.__getClientConns()
+        for redis_ctx in redis_ctxs:
+            if redis_ctx.pubsub_thread.is_alive():
+                raise RejectedByBackend("Event loop already started")
+            if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0:
+                redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
+            redis_ctx.run_in_thread = True
 
     def handle_events(self) -> Optional[Tuple[str, str]]:
-        if self.pubsub_thread.is_alive() or self._run_in_thread:
+        if self.next_client_event >= len(self.clients):
+            self.next_client_event = 0
+        redis_ctx = self.clients[self.next_client_event]
+        self.next_client_event += 1
+        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
             raise RejectedByBackend("Event loop already started")
         try:
-            return self.__redis_pubsub.get_message(ignore_subscribe_messages=True)
+            return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
         except RuntimeError:
             return None
 
+    def __create_redis_clients(self, config):
+        clients = list()
+        cfg_params = config.get_params()
+        if cfg_params.db_cluster_addr_list is None:
+            clients.append(self.__create_legacy_redis_client(cfg_params))
+        else:
+            for addr in cfg_params.db_cluster_addr_list.split(","):
+                client = self.__create_redis_client(cfg_params, addr)
+                clients.append(client)
+        return clients
+
+    def __create_legacy_redis_client(self, cfg_params):
+        return self.__create_redis_client(cfg_params, cfg_params.db_host)
+
+    def __create_redis_client(self, cfg_params, addr):
+        new_sentinel = None
+        new_redis = None
+        if cfg_params.db_sentinel_port is None:
+            new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
+        else:
+            sentinel_node = (addr, cfg_params.db_sentinel_port)
+            master_name = cfg_params.db_sentinel_master_name
+            new_sentinel = Sentinel([sentinel_node])
+            new_redis = new_sentinel.master_for(master_name)
+
+        new_redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
+        new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+
+        redis_pubsub = PubSub(new_redis.connection_pool, ignore_subscribe_messages=True)
+        pubsub_thread = threading.Thread(target=None)
+        run_in_thread = False
+
+        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)
+
+    def __getClientConns(self):
+        return self.clients
+
+    def __getClientConn(self, ns):
+        clients_cnt = len(self.clients)
+        client_id = self.__get_hash(ns) % clients_cnt
+        return self.clients[client_id]
+
+    def __getClient(self, ns):
+        clients_cnt = len(self.clients)
+        client_id = 0
+        if clients_cnt > 1:
+            client_id = self.__get_hash(ns) % clients_cnt
+        return self.clients[client_id].redis_client
+
+    @classmethod
+    def __get_hash(cls, str):
+        return zlib.crc32(str.encode())
+
     @classmethod
-    def _add_key_ns_prefix(cls, ns: str, key: str):
+    def __add_key_ns_prefix(cls, ns: str, key: str):
         return '{' + ns + '},' + key
 
     @classmethod
-    def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
+    def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
         ret_nskeys = []
         for k in keylist:
             ret_nskeys.append('{' + ns + '},' + k)
         return ret_nskeys
 
     @classmethod
-    def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
+    def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
         ret_nsdict = {}
         for key, val in data_dict.items():
             ret_nsdict['{' + ns + '},' + key] = val
         return ret_nsdict
 
     @classmethod
-    def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
+    def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
         ret_keys = []
         for k in nskeylist:
             try:
@@ -404,14 +455,36 @@ class RedisBackend(DbBackendAbc):
         total_events = 0
         for channel, events in channels_and_events.items():
             for event in events:
-                channels_and_events_prepared.append(cls._add_key_ns_prefix(ns, channel))
+                channels_and_events_prepared.append(cls.__add_key_ns_prefix(ns, channel))
                 channels_and_events_prepared.append(event)
                 total_events += 1
         return channels_and_events_prepared, total_events
 
-    def get_redis_connection(self):
-        """Return existing Redis database connection."""
-        return self.__redis
+    def get_redis_connection(self, ns: str):
+        """Return existing Redis database connection valid for the namespace."""
+        return self.__getClient(ns)
+
+
+class _RedisConn:
+    """
+    Internal class container to hold redis client connection
+    """
+
+    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
+        self.redis_client = redis_client
+        self.redis_pubsub = pubsub
+        self.pubsub_thread = pubsub_thread
+        self.run_in_thread = run_in_thread
+
+    def __str__(self):
+        return str(
+            {
+                "Client": repr(self.redis_client),
+                "Subscrions": self.redis_pubsub.subscribed,
+                "PubSub thread": repr(self.pubsub_thread),
+                "Run in thread": self.run_in_thread,
+            }
+        )
 
 
 class RedisBackendLock(DbBackendLockAbc):
@@ -445,7 +518,7 @@ class RedisBackendLock(DbBackendLockAbc):
     def __init__(self, ns: str, name: str, expiration: Union[int, float],
                  redis_backend: RedisBackend) -> None:
         super().__init__(ns, name)
-        self.__redis = redis_backend.get_redis_connection()
+        self.__redis = redis_backend.get_redis_connection(ns)
         with _map_to_sdl_exception():
             redis_lockname = '{' + ns + '},' + self._lock_name
             self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)