X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=ricsdl-package%2Fricsdl%2Fbackend%2Fredis.py;h=3ebc8cb324d5a56255e875f04ad9c9149e89ac78;hb=77c5b120496bdcb798d0e02b719c177e8f48d4e9;hp=3364497e0f8383e6f635a5b24a3f69adda59311e;hpb=bef156a640df036aa97fe1f2656c54a5717fc12b;p=ric-plt%2Fsdlpy.git diff --git a/ricsdl-package/ricsdl/backend/redis.py b/ricsdl-package/ricsdl/backend/redis.py old mode 100644 new mode 100755 index 3364497..3ebc8cb --- a/ricsdl-package/ricsdl/backend/redis.py +++ b/ricsdl-package/ricsdl/backend/redis.py @@ -21,11 +21,14 @@ """The module provides implementation of Shared Data Layer (SDL) database backend interface.""" import contextlib -from typing import (Dict, Set, List, Union) +import threading +from typing import (Callable, Dict, Set, List, Optional, Tuple, Union) +import zlib +import redis from redis import Redis from redis.sentinel import Sentinel from redis.lock import Lock -from redis._compat import nativestr +from redis.utils import str_if_bytes from redis import exceptions as redis_exceptions from ricsdl.configuration import _Configuration from ricsdl.exceptions import ( @@ -53,6 +56,91 @@ def _map_to_sdl_exception(): format(str(exc))) from exc +class PubSub(redis.client.PubSub): + def __init__(self, event_separator, connection_pool, ignore_subscribe_messages=False): + super().__init__(connection_pool, shard_hint=None, ignore_subscribe_messages=ignore_subscribe_messages) + self.event_separator = event_separator + + def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + + Adapted from: https://github.com/andymccurdy/redis-py/blob/master/redis/client.py + """ + message_type = str_if_bytes(response[0]) + if message_type == 'pmessage': + message = { + 'type': message_type, + 'pattern': response[1], + 'channel': response[2], + 'data': response[3] + } + elif message_type == 'pong': + message = { + 'type': message_type, + 'pattern': None, + 'channel': None, + 'data': response[1] + } + else: + message = { + 'type': message_type, + 'pattern': None, + 'channel': response[1], + 'data': response[2] + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == 'punsubscribe': + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == 'pmessage': + handler = self.patterns.get(message['pattern'], None) + else: + handler = self.channels.get(message['channel'], None) + if handler: + # Need to send only channel and notification instead of raw + # message + message_channel = self._strip_ns_from_bin_key('', message['channel']) + message_data = message['data'].decode('utf-8') + messages = message_data.split(self.event_separator) + handler(message_channel, messages) + return message_channel, messages + elif message_type != 'pong': + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + @classmethod + def _strip_ns_from_bin_key(cls, ns: str, nskey: bytes) -> str: + try: + redis_key = nskey.decode('utf-8') + except UnicodeDecodeError as exc: + msg = u'Namespace %s key conversion to string failed: %s' % (ns, str(exc)) + raise RejectedByBackend(msg) + nskey = redis_key.split(',', 1) + if len(nskey) != 2: + msg = u'Namespace %s key:%s has no namespace prefix' % (ns, redis_key) + raise RejectedByBackend(msg) + return nskey[1] + + class RedisBackend(DbBackendAbc): """ A class providing an implementation of database backend of Shared Data Layer (SDL), when @@ -64,70 +152,65 @@ class RedisBackend(DbBackendAbc): """ def __init__(self, configuration: _Configuration) -> None: super().__init__() + self.next_client_event = 0 + self.event_separator = configuration.get_event_separator() + self.clients = list() with _map_to_sdl_exception(): - if configuration.get_params().db_sentinel_port: - sentinel_node = (configuration.get_params().db_host, - configuration.get_params().db_sentinel_port) - master_name = configuration.get_params().db_sentinel_master_name - self.__sentinel = Sentinel([sentinel_node]) - self.__redis = self.__sentinel.master_for(master_name) - else: - self.__redis = Redis(host=configuration.get_params().db_host, - port=configuration.get_params().db_port, - db=0, - max_connections=20) - self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False) - self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False) + self.clients = self.__create_redis_clients(configuration) def __del__(self): self.close() def __str__(self): - return str( - { - "DB type": "Redis", - "Redis connection": repr(self.__redis) - } - ) + out = {"DB type": "Redis"} + for i, r in enumerate(self.clients): + out["Redis client[" + str(i) + "]"] = str(r) + return str(out) def is_connected(self): + is_connected = True with _map_to_sdl_exception(): - return self.__redis.ping() + for c in self.clients: + if not c.redis_client.ping(): + is_connected = False + break + return is_connected def close(self): - self.__redis.close() + for c in self.clients: + c.redis_client.close() def set(self, ns: str, data_map: Dict[str, bytes]) -> None: - db_data_map = self._add_data_map_ns_prefix(ns, data_map) + db_data_map = self.__add_data_map_ns_prefix(ns, data_map) with _map_to_sdl_exception(): - self.__redis.mset(db_data_map) + self.__getClient(ns).mset(db_data_map) def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool: - db_key = self._add_key_ns_prefix(ns, key) + db_key = self.__add_key_ns_prefix(ns, key) with _map_to_sdl_exception(): - return self.__redis.execute_command('SETIE', db_key, new_data, old_data) + return self.__getClient(ns).execute_command('SETIE', db_key, new_data, old_data) def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool: - db_key = self._add_key_ns_prefix(ns, key) + db_key = self.__add_key_ns_prefix(ns, key) with _map_to_sdl_exception(): - return self.__redis.setnx(db_key, data) + return self.__getClient(ns).setnx(db_key, data) def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]: ret = dict() - db_keys = self._add_keys_ns_prefix(ns, keys) + db_keys = self.__add_keys_ns_prefix(ns, keys) with _map_to_sdl_exception(): - values = self.__redis.mget(db_keys) + values = self.__getClient(ns).mget(db_keys) for idx, val in enumerate(values): # return only key values, which has a value - if val: + if val is not None: ret[keys[idx]] = val return ret def find_keys(self, ns: str, key_pattern: str) -> List[str]: - db_key_pattern = self._add_key_ns_prefix(ns, key_pattern) + db_key_pattern = self.__add_key_ns_prefix(ns, key_pattern) with _map_to_sdl_exception(): - ret = self.__redis.keys(db_key_pattern) - return self._strip_ns_from_bin_keys(ns, ret) + ret = self.__getClient(ns).keys(db_key_pattern) + return self.__strip_ns_from_bin_keys(ns, ret) def find_and_get(self, ns: str, key_pattern: str) -> Dict[str, bytes]: # todo: replace below implementation with redis 'NGET' module @@ -139,65 +222,224 @@ class RedisBackend(DbBackendAbc): return ret def remove(self, ns: str, keys: List[str]) -> None: - db_keys = self._add_keys_ns_prefix(ns, keys) + db_keys = self.__add_keys_ns_prefix(ns, keys) with _map_to_sdl_exception(): - self.__redis.delete(*db_keys) + self.__getClient(ns).delete(*db_keys) def remove_if(self, ns: str, key: str, data: bytes) -> bool: - db_key = self._add_key_ns_prefix(ns, key) + db_key = self.__add_key_ns_prefix(ns, key) with _map_to_sdl_exception(): - return self.__redis.execute_command('DELIE', db_key, data) + return self.__getClient(ns).execute_command('DELIE', db_key, data) def add_member(self, ns: str, group: str, members: Set[bytes]) -> None: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) with _map_to_sdl_exception(): - self.__redis.sadd(db_key, *members) + self.__getClient(ns).sadd(db_key, *members) def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) with _map_to_sdl_exception(): - self.__redis.srem(db_key, *members) + self.__getClient(ns).srem(db_key, *members) def remove_group(self, ns: str, group: str) -> None: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) with _map_to_sdl_exception(): - self.__redis.delete(db_key) + self.__getClient(ns).delete(db_key) def get_members(self, ns: str, group: str) -> Set[bytes]: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) with _map_to_sdl_exception(): - return self.__redis.smembers(db_key) + return self.__getClient(ns).smembers(db_key) def is_member(self, ns: str, group: str, member: bytes) -> bool: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) with _map_to_sdl_exception(): - return self.__redis.sismember(db_key, member) + return self.__getClient(ns).sismember(db_key, member) def group_size(self, ns: str, group: str) -> int: - db_key = self._add_key_ns_prefix(ns, group) + db_key = self.__add_key_ns_prefix(ns, group) + with _map_to_sdl_exception(): + return self.__getClient(ns).scard(db_key) + + def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], + data_map: Dict[str, bytes]) -> None: + db_data_map = self.__add_data_map_ns_prefix(ns, data_map) + channels_and_events_prepared = [] + total_events = 0 + channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events) with _map_to_sdl_exception(): - return self.__redis.scard(db_key) + return self.__getClient(ns).execute_command( + "MSETMPUB", + len(db_data_map), + total_events, + *[val for data in db_data_map.items() for val in data], + *channels_and_events_prepared, + ) + + def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str, + old_data: bytes, new_data: bytes) -> bool: + db_key = self.__add_key_ns_prefix(ns, key) + channels_and_events_prepared = [] + channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events) + with _map_to_sdl_exception(): + ret = self.__getClient(ns).execute_command("SETIEMPUB", db_key, new_data, old_data, + *channels_and_events_prepared) + return ret == b"OK" + + def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], + key: str, data: bytes) -> bool: + db_key = self.__add_key_ns_prefix(ns, key) + channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events) + with _map_to_sdl_exception(): + ret = self.__getClient(ns).execute_command("SETNXMPUB", db_key, data, + *channels_and_events_prepared) + return ret == b"OK" + + def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], + keys: List[str]) -> None: + db_keys = self.__add_keys_ns_prefix(ns, keys) + channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events) + with _map_to_sdl_exception(): + return self.__getClient(ns).execute_command( + "DELMPUB", + len(db_keys), + total_events, + *db_keys, + *channels_and_events_prepared, + ) + + def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str, + data: bytes) -> bool: + db_key = self.__add_key_ns_prefix(ns, key) + channels_and_events_prepared, _ = self._prepare_channels(ns, channels_and_events) + with _map_to_sdl_exception(): + ret = self.__getClient(ns).execute_command("DELIEMPUB", db_key, data, + *channels_and_events_prepared) + return bool(ret) + + def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None: + keys = self.__getClient(ns).keys(self.__add_key_ns_prefix(ns, "*")) + channels_and_events_prepared, total_events = self._prepare_channels(ns, channels_and_events) + with _map_to_sdl_exception(): + return self.__getClient(ns).execute_command( + "DELMPUB", + len(keys), + total_events, + *keys, + *channels_and_events_prepared, + ) + + def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None], + channels: List[str]) -> None: + channels = self.__add_keys_ns_prefix(ns, channels) + for channel in channels: + with _map_to_sdl_exception(): + redis_ctx = self.__getClientConn(ns) + redis_ctx.redis_pubsub.subscribe(**{channel: cb}) + if not redis_ctx.pubsub_thread.is_alive() and redis_ctx.run_in_thread: + redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, + daemon=True) + + def unsubscribe_channel(self, ns: str, channels: List[str]) -> None: + channels = self.__add_keys_ns_prefix(ns, channels) + for channel in channels: + with _map_to_sdl_exception(): + self.__getClientConn(ns).redis_pubsub.unsubscribe(channel) + + def start_event_listener(self) -> None: + redis_ctxs = self.__getClientConns() + for redis_ctx in redis_ctxs: + if redis_ctx.pubsub_thread.is_alive(): + raise RejectedByBackend("Event loop already started") + if redis_ctx.redis_pubsub.subscribed and len(redis_ctx.redis_client.pubsub_channels()) > 0: + redis_ctx.pubsub_thread = redis_ctx.redis_pubsub.run_in_thread(sleep_time=0.001, daemon=True) + redis_ctx.run_in_thread = True + + def handle_events(self) -> Optional[Tuple[str, List[str]]]: + if self.next_client_event >= len(self.clients): + self.next_client_event = 0 + redis_ctx = self.clients[self.next_client_event] + self.next_client_event += 1 + if redis_ctx.pubsub_thread.is_alive() or redis_ctx.run_in_thread: + raise RejectedByBackend("Event loop already started") + try: + return redis_ctx.redis_pubsub.get_message(ignore_subscribe_messages=True) + except RuntimeError: + return None + + def __create_redis_clients(self, config): + clients = list() + cfg_params = config.get_params() + if cfg_params.db_cluster_addr_list is None: + clients.append(self.__create_legacy_redis_client(cfg_params)) + else: + for addr in cfg_params.db_cluster_addr_list.split(","): + client = self.__create_redis_client(cfg_params, addr) + clients.append(client) + return clients + + def __create_legacy_redis_client(self, cfg_params): + return self.__create_redis_client(cfg_params, cfg_params.db_host) + + def __create_redis_client(self, cfg_params, addr): + new_sentinel = None + new_redis = None + if cfg_params.db_sentinel_port is None: + new_redis = Redis(host=addr, port=cfg_params.db_port, db=0, max_connections=20) + else: + sentinel_node = (addr, cfg_params.db_sentinel_port) + master_name = cfg_params.db_sentinel_master_name + new_sentinel = Sentinel([sentinel_node]) + new_redis = new_sentinel.master_for(master_name) + + new_redis.set_response_callback('SETIE', lambda r: r and str_if_bytes(r) == 'OK' or False) + new_redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False) + + redis_pubsub = PubSub(self.event_separator, new_redis.connection_pool, ignore_subscribe_messages=True) + pubsub_thread = threading.Thread(target=None) + run_in_thread = False + + return _RedisConn(new_redis, redis_pubsub, pubsub_thread, run_in_thread) + + def __getClientConns(self): + return self.clients + + def __getClientConn(self, ns): + clients_cnt = len(self.clients) + client_id = self.__get_hash(ns) % clients_cnt + return self.clients[client_id] + + def __getClient(self, ns): + clients_cnt = len(self.clients) + client_id = 0 + if clients_cnt > 1: + client_id = self.__get_hash(ns) % clients_cnt + return self.clients[client_id].redis_client + + @classmethod + def __get_hash(cls, str): + return zlib.crc32(str.encode()) @classmethod - def _add_key_ns_prefix(cls, ns: str, key: str): + def __add_key_ns_prefix(cls, ns: str, key: str): return '{' + ns + '},' + key @classmethod - def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]: + def __add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]: ret_nskeys = [] for k in keylist: ret_nskeys.append('{' + ns + '},' + k) return ret_nskeys @classmethod - def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]: + def __add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]: ret_nsdict = {} for key, val in data_dict.items(): ret_nsdict['{' + ns + '},' + key] = val return ret_nsdict @classmethod - def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]: + def __strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]: ret_keys = [] for k in nskeylist: try: @@ -212,9 +454,46 @@ class RedisBackend(DbBackendAbc): ret_keys.append(nskey[1]) return ret_keys - def get_redis_connection(self): - """Return existing Redis database connection.""" - return self.__redis + def _prepare_channels(self, ns: str, + channels_and_events: Dict[str, List[str]]) -> Tuple[List, int]: + channels_and_events_prepared = [] + for channel, events in channels_and_events.items(): + one_channel_join_events = None + for event in events: + if one_channel_join_events is None: + channels_and_events_prepared.append(self.__add_key_ns_prefix(ns, channel)) + one_channel_join_events = event + else: + one_channel_join_events = one_channel_join_events + self.event_separator + event + channels_and_events_prepared.append(one_channel_join_events) + pairs_cnt = int(len(channels_and_events_prepared) / 2) + return channels_and_events_prepared, pairs_cnt + + def get_redis_connection(self, ns: str): + """Return existing Redis database connection valid for the namespace.""" + return self.__getClient(ns) + + +class _RedisConn: + """ + Internal class container to hold redis client connection + """ + + def __init__(self, redis_client, pubsub, pubsub_thread, run_in_thread): + self.redis_client = redis_client + self.redis_pubsub = pubsub + self.pubsub_thread = pubsub_thread + self.run_in_thread = run_in_thread + + def __str__(self): + return str( + { + "Client": repr(self.redis_client), + "Subscrions": self.redis_pubsub.subscribed, + "PubSub thread": repr(self.pubsub_thread), + "Run in thread": self.run_in_thread, + } + ) class RedisBackendLock(DbBackendLockAbc): @@ -248,7 +527,7 @@ class RedisBackendLock(DbBackendLockAbc): def __init__(self, ns: str, name: str, expiration: Union[int, float], redis_backend: RedisBackend) -> None: super().__init__(ns, name) - self.__redis = redis_backend.get_redis_connection() + self.__redis = redis_backend.get_redis_connection(ns) with _map_to_sdl_exception(): redis_lockname = '{' + ns + '},' + self._lock_name self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)