Fix Flake8 reported errors (E275)
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
index 12726a3..d7139fb 100755 (executable)
@@ -1,5 +1,5 @@
 # Copyright (c) 2019 AT&T Intellectual Property.
-# Copyright (c) 2018-2019 Nokia.
+# Copyright (c) 2018-2022 Nokia.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ import redis
 from redis import Redis
 from redis.sentinel import Sentinel
 from redis.lock import Lock
-from redis._compat import nativestr
+from redis.utils import str_if_bytes
 from redis import exceptions as redis_exceptions
 from ricsdl.configuration import _Configuration
 from ricsdl.exceptions import (
@@ -45,18 +45,22 @@ def _map_to_sdl_exception():
     """Translates known redis exceptions into SDL exceptions."""
     try:
         yield
-    except(redis_exceptions.ResponseError) as exc:
+    except redis_exceptions.ResponseError as exc:
         raise RejectedByBackend("SDL backend rejected the request: {}".
                                 format(str(exc))) from exc
-    except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
+    except (redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
         raise NotConnected("SDL not connected to backend: {}".
                            format(str(exc))) from exc
-    except(redis_exceptions.RedisError) as exc:
+    except redis_exceptions.RedisError as exc:
         raise BackendError("SDL backend failed to process the request: {}".
                            format(str(exc))) from exc
 
 
 class PubSub(redis.client.PubSub):
+    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
+        super().__init__(connection_pool, shard_hint=None, ignore_subscribe_messages=ignore_subscribe_messages)
+        self.event_separator = event_separator
+
     def handle_message(self, response, ignore_subscribe_messages=False):
         """
         Parses a pub/sub message. If the channel or pattern was subscribed to
@@ -65,7 +69,7 @@ class PubSub(redis.client.PubSub):
 
         Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py
         """
-        message_type = nativestr(response[0])
+        message_type = str_if_bytes(response[0])
         if message_type == 'pmessage':
             message = {
                 'type': message_type,
@@ -112,8 +116,9 @@ class PubSub(redis.client.PubSub):
                 # message
                 message_channel = self._strip_ns_from_bin_key('', message['channel'])
                 message_data = message['data'].decode('utf-8')
-                handler(message_channel, message_data)
-                return message_channel, message_data
+                messages = message_data.split(self.event_separator)
+                handler(message_channel, messages)
+                return message_channel, messages
         elif message_type != 'pong':
             # this is a subscribe/unsubscribe message. ignore if we don't
             # want them
@@ -148,6 +153,7 @@ class RedisBackend(DbBackendAbc):
     def __init__(self, configuration: _Configuration) -> None:
         super().__init__()
         self.next_client_event = 0
+        self.event_separator = configuration.get_event_separator()
         self.clients = list()
         with _map_to_sdl_exception():
             self.clients = self.__create_redis_clients(configuration)
@@ -323,7 +329,7 @@ class RedisBackend(DbBackendAbc):
                 *channels_and_events_prepared,
             )
 
-    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
+    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
                           channels: List[str]) -> None:
         channels = self.__add_keys_ns_prefix(ns, channels)
         for channel in channels:
@@ -349,7 +355,7 @@ class RedisBackend(DbBackendAbc):
                 redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
             redis_ctx.run_in_thread = True
 
-    def handle_events(self) -> Optional[Tuple[str, str]]:
+    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
         if self.next_client_event >= len(self.clients):
             self.next_client_event = 0
         redis_ctx = self.clients[self.next_client_event]
@@ -364,32 +370,29 @@ class RedisBackend(DbBackendAbc):
     def __create_redis_clients(self, config):
         clients = list()
         cfg_params = config.get_params()
-        if cfg_params.db_cluster_addr_list is None:
-            clients.append(self.__create_legacy_redis_client(cfg_params))
-        else:
-            for addr in cfg_params.db_cluster_addr_list.split(","):
-                client = self.__create_redis_client(cfg_params, addr)
-                clients.append(client)
-        return clients
+        for i, addr in enumerate(cfg_params.db_cluster_addrs):
+            port = cfg_params.db_ports[i] if i < len(cfg_params.db_ports) else ""
+            sport = cfg_params.db_sentinel_ports[i] if i < len(cfg_params.db_sentinel_ports) else ""
+            name = cfg_params.db_sentinel_master_names[i] if i < len(cfg_params.db_sentinel_master_names) else ""
 
-    def __create_legacy_redis_client(self, cfg_params):
-        return self.__create_redis_client(cfg_params, cfg_params.db_host)
+            client = self.__create_redis_client(addr, port, sport, name)
+            clients.append(client)
+        return clients
 
-    def __create_redis_client(self, cfg_params, addr):
+    def __create_redis_client(self, addr, port, sentinel_port, master_name):
         new_sentinel = None
         new_redis = None
-        if cfg_params.db_sentinel_port is None:
-            new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
+        if len(sentinel_port) == 0:
+            new_redis = Redis(host=addr, port=port, db=0, max_connections=20)
         else:
-            sentinel_node = (addr, cfg_params.db_sentinel_port)
-            master_name = cfg_params.db_sentinel_master_name
+            sentinel_node = (addr, sentinel_port)
             new_sentinel = Sentinel([sentinel_node])
             new_redis = new_sentinel.master_for(master_name)
 
-        new_redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
+        new_redis.set_response_callback('SETIE', lambda r: r and str_if_bytes(r) == 'OK' or False)
         new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
 
-        redis_pubsub = PubSub(new_redis.connection_pool, ignore_subscribe_messages=True)
+        redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
         pubsub_thread = threading.Thread(target=None)
         run_in_thread = False
 
@@ -448,17 +451,20 @@ class RedisBackend(DbBackendAbc):
             ret_keys.append(nskey[1])
         return ret_keys
 
-    @classmethod
-    def _prepare_channels(cls, ns: str, channels_and_events: Dict[str,
-                                                                  List[str]]) -> Tuple[List, int]:
+    def _prepare_channels(self, ns: str,
+                          channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
         channels_and_events_prepared = []
-        total_events = 0
         for channel, events in channels_and_events.items():
+            one_channel_join_events = None
             for event in events:
-                channels_and_events_prepared.append(cls.__add_key_ns_prefix(ns, channel))
-                channels_and_events_prepared.append(event)
-                total_events += 1
-        return channels_and_events_prepared, total_events
+                if one_channel_join_events is None:
+                    channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
+                    one_channel_join_events = event
+                else:
+                    one_channel_join_events = one_channel_join_events + self.event_separator + event
+            channels_and_events_prepared.append(one_channel_join_events)
+        pairs_cnt = int(len(channels_and_events_prepared) / 2)
+        return channels_and_events_prepared, pairs_cnt
 
     def get_redis_connection(self, ns: str):
         """Return existing Redis database connection valid for the namespace."""
@@ -582,5 +588,5 @@ class RedisBackendLock(DbBackendLockAbc):
                     return 'locked'
                 return 'locked by someone else'
             return 'unlocked'
-        except(redis_exceptions.RedisError) as exc:
+        except redis_exceptions.RedisError as exc:
             return f'Error: {str(exc)}'