"""The module provides implementation of Shared Data Layer (SDL) database backend interface."""
import contextlib
-from typing import (Dict, Set, List, Union)
+import threading
+from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import zlib
+import redis
from redis import Redis
from redis.sentinel import Sentinel
from redis.lock import Lock
-from redis._compat import nativestr
+from redis.utils import str_if_bytes
from redis import exceptions as redis_exceptions
from ricsdl.configuration import _Configuration
from ricsdl.exceptions import (
format(str(exc))) from exc
+class PubSub(redis.client.PubSub):
+    """Redis pub/sub specialization used by the SDL Redis backend.
+
+    Splits a published payload into individual SDL events with
+    ``event_separator`` and strips the SDL namespace prefix from the channel
+    name before a registered message handler is invoked.
+    """
+    def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False):
+        super().__init__(connection_pool, shard_hint=None, ignore_subscribe_messages=ignore_subscribe_messages)
+        # Separator with which the publisher joined multiple events into one payload.
+        self.event_separator = event_separator
+
+    def handle_message(self, response, ignore_subscribe_messages=False):
+        """
+        Parses a pub/sub message. If the channel or pattern was subscribed to
+        with a message handler, the handler is invoked instead of a parsed
+        message being returned.
+
+        Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py
+        """
+        message_type = str_if_bytes(response[0])
+        if message_type == 'pmessage':
+            message = {
+                'type': message_type,
+                'pattern': response[1],
+                'channel': response[2],
+                'data': response[3]
+            }
+        elif message_type == 'pong':
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': None,
+                'data': response[1]
+            }
+        else:
+            message = {
+                'type': message_type,
+                'pattern': None,
+                'channel': response[1],
+                'data': response[2]
+            }
+
+        # if this is an unsubscribe message, remove it from memory
+        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
+            if message_type == 'punsubscribe':
+                pattern = response[1]
+                if pattern in self.pending_unsubscribe_patterns:
+                    self.pending_unsubscribe_patterns.remove(pattern)
+                    self.patterns.pop(pattern, None)
+            else:
+                channel = response[1]
+                if channel in self.pending_unsubscribe_channels:
+                    self.pending_unsubscribe_channels.remove(channel)
+                    self.channels.pop(channel, None)
+
+        if message_type in self.PUBLISH_MESSAGE_TYPES:
+            # if there's a message handler, invoke it
+            if message_type == 'pmessage':
+                handler = self.patterns.get(message['pattern'], None)
+            else:
+                handler = self.channels.get(message['channel'], None)
+            if handler:
+                # Need to send only channel and notification instead of raw
+                # message
+                message_channel = self._strip_ns_from_bin_key('', message['channel'])
+                message_data = message['data'].decode('utf-8')
+                messages = message_data.split(self.event_separator)
+                handler(message_channel, messages)
+                return message_channel, messages
+        elif message_type != 'pong':
+            # this is a subscribe/unsubscribe message. ignore if we don't
+            # want them
+            if ignore_subscribe_messages or self.ignore_subscribe_messages:
+                return None
+
+        return message
+
+    @classmethod
+    def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str:
+        """Decode *nskey* and return the key without its '{ns},' prefix.
+
+        Raises RejectedByBackend when the key is not valid UTF-8 or carries
+        no namespace prefix.
+        """
+        try:
+            redis_key = nskey.decode('utf-8')
+        except UnicodeDecodeError as exc:
+            msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc))
+            # Chain the original decode error to preserve the failure context.
+            raise RejectedByBackend(msg) from exc
+        # Use a new name for the split result; the original reused the bytes
+        # parameter name 'nskey', shadowing it.
+        parts = redis_key.split(',', 1)
+        if len(parts) != 2:
+            msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key)
+            raise RejectedByBackend(msg)
+        return parts[1]
+
+
class RedisBackend(DbBackendAbc):
"""
A class providing an implementation of database backend of Shared Data Layer (SDL), when
"""
def __init__(self, configuration: _Configuration) -> None:
super().__init__()
+ self.next_client_event = 0
+ self.event_separator = configuration.get_event_separator()
+ self.clients = list()
with _map_to_sdl_exception():
- if configuration.get_params().db_sentinel_port:
- sentinel_node = (configuration.get_params().db_host,
- configuration.get_params().db_sentinel_port)
- master_name = configuration.get_params().db_sentinel_master_name
- self.__sentinel = Sentinel([sentinel_node])
- self.__redis = self.__sentinel.master_for(master_name)
- else:
- self.__redis = Redis(host=configuration.get_params().db_host,
- port=configuration.get_params().db_port,
- db=0,
- max_connections=20)
- self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
- self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+ self.clients = self.__create_redis_clients(configuration)
def __del__(self):
self.close()
def __str__(self):
- return str(
- {
- "DB type": "Redis",
- "Redis connection": repr(self.__redis)
- }
- )
+ out = {"DB type": "Redis"}
+ for i, r in enumerate(self.clients):
+ out["Redis client[" + str(i) + "]"] = str(r)
+ return str(out)
+
+    def is_connected(self):
+        """Return True when every Redis client answers PING, False otherwise."""
+        with _map_to_sdl_exception():
+            return all(conn.redis_client.ping() for conn in self.clients)
def close(self):
- self.__redis.close()
+ for c in self.clients:
+ c.redis_client.close()
def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
- db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+ db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
with _map_to_sdl_exception():
- self.__redis.mset(db_data_map)
+ self.__getClient(ns).mset(db_data_map)
def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
- db_key = self._add_key_ns_prefix(ns, key)
+ db_key = self.__add_key_ns_prefix(ns, key)
with _map_to_sdl_exception():
- return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
+ return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data)
def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
- db_key = self._add_key_ns_prefix(ns, key)
+ db_key = self.__add_key_ns_prefix(ns, key)
with _map_to_sdl_exception():
- return self.__redis.setnx(db_key, data)
+ return self.__getClient(ns).setnx(db_key, data)
def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
ret = dict()
- db_keys = self._add_keys_ns_prefix(ns, keys)
+ db_keys = self.__add_keys_ns_prefix(ns, keys)
with _map_to_sdl_exception():
- values = self.__redis.mget(db_keys)
+ values = self.__getClient(ns).mget(db_keys)
for idx, val in enumerate(values):
# return only key values, which has a value
- if val:
+ if val is not None:
ret[keys[idx]] = val
return ret
def find_keys(self, ns: str, key_pattern: str) -> List[str]:
- db_key_pattern = self._add_key_ns_prefix(ns, key_pattern)
+ db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern)
with _map_to_sdl_exception():
- ret = self.__redis.keys(db_key_pattern)
- return self._strip_ns_from_bin_keys(ns, ret)
+ ret = self.__getClient(ns).keys(db_key_pattern)
+ return self.__strip_ns_from_bin_keys(ns, ret)
def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]:
# todo: replace below implementation with redis 'NGET' module
return ret
def remove(self, ns: str, keys: List[str]) -> None:
- db_keys = self._add_keys_ns_prefix(ns, keys)
+ db_keys = self.__add_keys_ns_prefix(ns, keys)
with _map_to_sdl_exception():
- self.__redis.delete(*db_keys)
+ self.__getClient(ns).delete(*db_keys)
def remove_if(self, ns: str, key: str, data: bytes) -> bool:
- db_key = self._add_key_ns_prefix(ns, key)
+ db_key = self.__add_key_ns_prefix(ns, key)
with _map_to_sdl_exception():
- return self.__redis.execute_command('DELIE', db_key, data)
+ return self.__getClient(ns).execute_command('DELIE', db_key, data)
def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- self.__redis.sadd(db_key, *members)
+ self.__getClient(ns).sadd(db_key, *members)
def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- self.__redis.srem(db_key, *members)
+ self.__getClient(ns).srem(db_key, *members)
def remove_group(self, ns: str, group: str) -> None:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- self.__redis.delete(db_key)
+ self.__getClient(ns).delete(db_key)
def get_members(self, ns: str, group: str) -> Set[bytes]:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- return self.__redis.smembers(db_key)
+ return self.__getClient(ns).smembers(db_key)
def is_member(self, ns: str, group: str, member: bytes) -> bool:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- return self.__redis.sismember(db_key, member)
+ return self.__getClient(ns).sismember(db_key, member)
def group_size(self, ns: str, group: str) -> int:
- db_key = self._add_key_ns_prefix(ns, group)
+ db_key = self.__add_key_ns_prefix(ns, group)
with _map_to_sdl_exception():
- return self.__redis.scard(db_key)
+ return self.__getClient(ns).scard(db_key)
+
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                        data_map: Dict[str, bytes]) -> None:
+        """Write *data_map* under namespace *ns* and publish the events atomically.
+
+        Uses the backend's 'MSETMPUB' extension command so the data write and
+        the event publishing happen as one operation.
+        """
+        db_data_map = self.__add_data_map_ns_prefix(ns, data_map)
+        # Channels get the namespace prefix; the events of one channel are
+        # joined into a single payload by _prepare_channels(). (The original
+        # code pre-initialized the results and immediately overwrote them.)
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            return self.__getClient(ns).execute_command(
+                "MSETMPUB",
+                len(db_data_map),
+                total_events,
+                *[val for data in db_data_map.items() for val in data],
+                *channels_and_events_prepared,
+            )
+
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                           old_data: bytes, new_data: bytes) -> bool:
+        """Overwrite *key* only when its value equals *old_data*; publish on success.
+
+        Returns True when the backend acknowledged the conditional write.
+        """
+        db_key = self.__add_key_ns_prefix(ns, key)
+        # Removed dead pre-initialization of channels_and_events_prepared.
+        channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data,
+                                                       *channels_and_events_prepared)
+        return ret == b"OK"
+
+    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                                      key: str, data: bytes) -> bool:
+        """Write *key* only when it does not exist yet; publish the events on success.
+
+        Returns True when the key was created, False when it already existed.
+        """
+        prepared_channels, _ = self._prepare_channels(ns, channels_and_events)
+        prefixed_key = self.__add_key_ns_prefix(ns, key)
+        with _map_to_sdl_exception():
+            reply = self.__getClient(ns).execute_command(
+                "SETNXMPUB", prefixed_key, data, *prepared_channels)
+        return reply == b"OK"
+
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                           keys: List[str]) -> None:
+        """Delete *keys* from namespace *ns* and publish the events in one operation."""
+        prefixed_keys = self.__add_keys_ns_prefix(ns, keys)
+        prepared_channels, event_count = self._prepare_channels(ns, channels_and_events)
+        # DELMPUB takes the key count and event count up front, then the keys,
+        # followed by alternating channel/payload arguments.
+        command_args = ["DELMPUB", len(prefixed_keys), event_count]
+        command_args.extend(prefixed_keys)
+        command_args.extend(prepared_channels)
+        with _map_to_sdl_exception():
+            return self.__getClient(ns).execute_command(*command_args)
+
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                              data: bytes) -> bool:
+        """Delete *key* only when its value equals *data*; publish events on success.
+
+        Returns True when the key was deleted, False otherwise.
+        """
+        prepared_channels, _ = self._prepare_channels(ns, channels_and_events)
+        prefixed_key = self.__add_key_ns_prefix(ns, key)
+        with _map_to_sdl_exception():
+            reply = self.__getClient(ns).execute_command(
+                "DELIEMPUB", prefixed_key, data, *prepared_channels)
+        return bool(reply)
+
+    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
+        """Delete every key in namespace *ns* and publish the events in one operation."""
+        channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events)
+        with _map_to_sdl_exception():
+            # Run the KEYS lookup inside the exception mapper as well, so that
+            # connection failures are reported as SDL exceptions like in every
+            # other API of this class (it was previously outside the mapper).
+            keys = self.__getClient(ns).keys(self.__add_key_ns_prefix(ns, "*"))
+            return self.__getClient(ns).execute_command(
+                "DELMPUB",
+                len(keys),
+                total_events,
+                *keys,
+                *channels_and_events_prepared,
+            )
+
+    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
+                          channels: List[str]) -> None:
+        # Subscribe callback *cb* to each channel of namespace *ns*. When the
+        # threaded listener mode is already enabled (run_in_thread) but the
+        # worker thread is not running, (re)start it so the new subscription
+        # gets served.
+        channels = self.__add_keys_ns_prefix(ns, channels)
+        for channel in channels:
+            with _map_to_sdl_exception():
+                redis_ctx = self.__getClientConn(ns)
+                redis_ctx.redis_pubsub.subscribe(**{channel: cb})
+                if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread:
+                    redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001,
+                                                                                  daemon=True)
+
+    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
+        """Remove the subscriptions of the given *channels* under namespace *ns*."""
+        for prefixed_channel in self.__add_keys_ns_prefix(ns, channels):
+            with _map_to_sdl_exception():
+                self.__getClientConn(ns).redis_pubsub.unsubscribe(prefixed_channel)
+
+    def start_event_listener(self) -> None:
+        """Enable threaded event listening on every Redis client."""
+        redis_ctxs = self.__getClientConns()
+        for redis_ctx in redis_ctxs:
+            # Starting the listener twice is a client error.
+            if redis_ctx.pubsub_thread.is_alive():
+                raise RejectedByBackend("Event loop already started")
+            # Spawn a worker only when subscriptions already exist; later
+            # subscribe_channel() calls start the thread on demand.
+            if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0:
+                redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True)
+            redis_ctx.run_in_thread = True
+
+    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
+        """Poll one client's pub/sub connection for a single pending event.
+
+        Clients are polled round-robin, one per call. Returns the
+        (channel, events) tuple produced by PubSub.handle_message, or None
+        when nothing is pending. Mutually exclusive with the threaded
+        listener mode started by start_event_listener().
+        """
+        # Wrap the round-robin index and advance it for the next call.
+        if self.next_client_event >= len(self.clients):
+            self.next_client_event = 0
+        redis_ctx = self.clients[self.next_client_event]
+        self.next_client_event += 1
+        if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread:
+            raise RejectedByBackend("Event loop already started")
+        try:
+            return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True)
+        except RuntimeError:
+            # NOTE(review): presumably raised when no subscription exists yet;
+            # treated as "no event available" — confirm against redis-py.
+            return None
+
+    def __create_redis_clients(self, config):
+        """Build one _RedisConn per configured address (single node or cluster list)."""
+        cfg_params = config.get_params()
+        if cfg_params.db_cluster_addr_list is None:
+            # No cluster configured: fall back to the single legacy address.
+            return [self.__create_legacy_redis_client(cfg_params)]
+        return [self.__create_redis_client(cfg_params, addr)
+                for addr in cfg_params.db_cluster_addr_list.split(",")]
+
+    def __create_legacy_redis_client(self, cfg_params):
+        """Create the client for the classic single-host configuration."""
+        legacy_addr = cfg_params.db_host
+        return self.__create_redis_client(cfg_params, legacy_addr)
+
+    def __create_redis_client(self, cfg_params, addr):
+        """Create a _RedisConn for *addr*, directly or via a Redis Sentinel.
+
+        Registers the response callbacks of the SDL conditional set/delete
+        extension commands and prepares an idle pub/sub context. (Removed the
+        dead 'new_sentinel = None' / 'new_redis = None' pre-initializations;
+        both branches assign the client.)
+        """
+        if cfg_params.db_sentinel_port is None:
+            new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20)
+        else:
+            sentinel_node = (addr, cfg_params.db_sentinel_port)
+            master_name = cfg_params.db_sentinel_master_name
+            new_sentinel = Sentinel([sentinel_node])
+            new_redis = new_sentinel.master_for(master_name)
+
+        # SETIE replies 'OK' on success; DELIE replies 1 on success.
+        new_redis.set_response_callback('SETIE', lambda r: r and str_if_bytes(r) == 'OK' or False)
+        new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+
+        redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True)
+        # Placeholder thread object: never started; replaced when listening begins.
+        pubsub_thread = threading.Thread(target=None)
+        run_in_thread = False
+
+        return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread)
+
+    def __getClientConns(self):
+        """Return all internal _RedisConn containers."""
+        return self.clients
+
+    def __getClientConn(self, ns):
+        """Return the _RedisConn serving namespace *ns* (hash-based sharding)."""
+        clients_cnt = len(self.clients)
+        # Consistent with __getClient(): skip the hash when only one client
+        # exists (x % 1 is always 0, so the result is unchanged).
+        client_id = 0
+        if clients_cnt > 1:
+            client_id = self.__get_hash(ns) % clients_cnt
+        return self.clients[client_id]
+
+    def __getClient(self, ns):
+        """Return the Redis client serving namespace *ns* (hash-based sharding)."""
+        count = len(self.clients)
+        index = self.__get_hash(ns) % count if count > 1 else 0
+        return self.clients[index].redis_client
+
+    @classmethod
+    def __get_hash(cls, key_str):
+        """Return the CRC32 checksum of *key_str*, used to pick a client shard."""
+        # Parameter renamed: the original name shadowed the builtin 'str'.
+        return zlib.crc32(key_str.encode())
@classmethod
- def _add_key_ns_prefix(cls, ns: str, key: str):
+ def __add_key_ns_prefix(cls, ns: str, key: str):
return '{' + ns + '},' + key
@classmethod
- def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
+ def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
ret_nskeys = []
for k in keylist:
ret_nskeys.append('{' + ns + '},' + k)
return ret_nskeys
@classmethod
- def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
+ def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
ret_nsdict = {}
for key, val in data_dict.items():
ret_nsdict['{' + ns + '},' + key] = val
return ret_nsdict
@classmethod
- def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
+ def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
ret_keys = []
for k in nskeylist:
try:
ret_keys.append(nskey[1])
return ret_keys
- def get_redis_connection(self):
- """Return existing Redis database connection."""
- return self.__redis
+    def _prepare_channels(self, ns: str,
+                          channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]:
+        """Flatten *channels_and_events* into backend command arguments.
+
+        Returns ([prefixed_channel1, joined_events1, ...], pair_count). The
+        events of one channel are joined into a single payload string with the
+        configured event separator. Channels with an empty event list are
+        skipped; the original accumulation appended a bare None for them,
+        which the backend command cannot handle.
+        """
+        channels_and_events_prepared = []
+        for channel, events in channels_and_events.items():
+            if not events:
+                continue
+            channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel))
+            channels_and_events_prepared.append(self.event_separator.join(events))
+        pairs_cnt = len(channels_and_events_prepared) // 2
+        return channels_and_events_prepared, pairs_cnt
+
+    def get_redis_connection(self, ns: str):
+        """Return existing Redis database connection valid for the namespace.
+
+        The connection is selected with the same hash-based sharding as the
+        data operations, so it targets the client that stores *ns* keys.
+        """
+        return self.__getClient(ns)
+
+
+class _RedisConn:
+    """
+    Internal class container to hold redis client connection
+    """
+
+    def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread):
+        # Active Redis (or Sentinel master) client.
+        self.redis_client = redis_client
+        # PubSub context bound to the client's connection pool.
+        self.redis_pubsub = pubsub
+        # Worker thread serving subscriptions (a placeholder until started).
+        self.pubsub_thread = pubsub_thread
+        # True once the threaded event listener mode has been enabled.
+        self.run_in_thread = run_in_thread
+
+    def __str__(self):
+        return str(
+            {
+                "Client": repr(self.redis_client),
+                # Fixed typo in the diagnostic label: was "Subscrions".
+                "Subscriptions": self.redis_pubsub.subscribed,
+                "PubSub thread": repr(self.pubsub_thread),
+                "Run in thread": self.run_in_thread,
+            }
+        )
class RedisBackendLock(DbBackendLockAbc):
def __init__(self, ns: str, name: str, expiration: Union[int, float],
redis_backend: RedisBackend) -> None:
super().__init__(ns, name)
- self.__redis = redis_backend.get_redis_connection()
+ self.__redis = redis_backend.get_redis_connection(ns)
with _map_to_sdl_exception():
redis_lockname = '{' + ns + '},' + self._lock_name
self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)