X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=ricsdl-package%2Fricsdl%2Fsyncstorage.py;h=b15365a168c93a7de203643ec25542b09905d24d;hb=77c5b120496bdcb798d0e02b719c177e8f48d4e9;hp=29adb17bf25207e65758117ac50074d5c3b416a2;hpb=c979c0db16f873c0f8ea6fe5d1b98c15f79d18de;p=ric-plt%2Fsdlpy.git diff --git a/ricsdl-package/ricsdl/syncstorage.py b/ricsdl-package/ricsdl/syncstorage.py old mode 100644 new mode 100755 index 29adb17..b15365a --- a/ricsdl-package/ricsdl/syncstorage.py +++ b/ricsdl-package/ricsdl/syncstorage.py @@ -20,7 +20,8 @@ """The module provides implementation of the syncronous Shared Data Layer (SDL) interface.""" import builtins -from typing import (Dict, Set, List, Union) +import inspect +from typing import (Any, Callable, Dict, Set, List, Optional, Tuple, Union) from ricsdl.configuration import _Configuration from ricsdl.syncstorage_abc import (SyncStorageAbc, SyncLockAbc) import ricsdl.backend @@ -119,6 +120,7 @@ class SyncStorage(SyncStorageAbc): def __init__(self, fake_db_backend=None) -> None: super().__init__() self.__configuration = _Configuration(fake_db_backend) + self.event_separator = self.__configuration.get_event_separator() self.__dbbackend = ricsdl.backend.get_backend_instance(self.__configuration) def __del__(self): @@ -143,6 +145,7 @@ class SyncStorage(SyncStorageAbc): @func_arg_checker(SdlTypeError, 1, ns=str, data_map=dict) def set(self, ns: str, data_map: Dict[str, bytes]) -> None: + self._validate_key_value_dict(data_map) self.__dbbackend.set(ns, data_map) @func_arg_checker(SdlTypeError, 1, ns=str, key=str, old_data=bytes, new_data=bytes) @@ -205,6 +208,81 @@ class SyncStorage(SyncStorageAbc): def group_size(self, ns: str, group: str) -> int: return self.__dbbackend.group_size(ns, group) + @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, data_map=dict) + def set_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]], + data_map: Dict[str, bytes]) -> None: + self._validate_key_value_dict(data_map) + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + self.__dbbackend.set_and_publish(ns, channels_and_events, data_map) + + @func_arg_checker(SdlTypeError, + 1, + ns=str, + channels_and_events=dict, + key=str, + old_data=bytes, + new_data=bytes) + def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]], + key: str, old_data: bytes, new_data: bytes) -> bool: + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + return self.__dbbackend.set_if_and_publish(ns, channels_and_events, key, old_data, new_data) + + @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes) + def set_if_not_exists_and_publish(self, ns: str, + channels_and_events: Dict[str, Union[str, List[str]]], + key: str, data: bytes) -> bool: + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + return self.__dbbackend.set_if_not_exists_and_publish(ns, channels_and_events, key, data) + + @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, keys=(str, builtins.set)) + def remove_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]], + keys: Union[str, Set[str]]) -> None: + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + keys = [keys] if isinstance(keys, str) else list(keys) + self.__dbbackend.remove_and_publish(ns, channels_and_events, keys) + + @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes) + def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]], + key: str, data: bytes) -> bool: + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + return self.__dbbackend.remove_if_and_publish(ns, channels_and_events, key, data) + + @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict) + def remove_all_and_publish(self, ns: str, + channels_and_events: Dict[str, Union[str, List[str]]]) -> None: + self._validate_channels_events(channels_and_events) + for channel, events in channels_and_events.items(): + channels_and_events[channel] = [events] if isinstance(events, str) else events + self.__dbbackend.remove_all_and_publish(ns, channels_and_events) + + @func_arg_checker(SdlTypeError, 1, ns=str, cb=Callable, channels=(str, builtins.set)) + def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None], + channels: Union[str, Set[str]]) -> None: + self._validate_callback(cb) + channels = [channels] if isinstance(channels, str) else list(channels) + self.__dbbackend.subscribe_channel(ns, cb, channels) + + @func_arg_checker(SdlTypeError, 1, ns=str, channels=(str, builtins.set)) + def unsubscribe_channel(self, ns: str, channels: Union[str, Set[str]]) -> None: + channels = [channels] if isinstance(channels, str) else list(channels) + self.__dbbackend.unsubscribe_channel(ns, channels) + + def start_event_listener(self) -> None: + self.__dbbackend.start_event_listener() + + def handle_events(self) -> Optional[Tuple[str, List[str]]]: + return self.__dbbackend.handle_events() + @func_arg_checker(SdlTypeError, 1, ns=str, resource=str, expiration=(int, float)) def get_lock_resource(self, ns: str, resource: str, expiration: Union[int, float]) -> SyncLock: return SyncLock(ns, resource, expiration, self) @@ -216,3 +294,39 @@ class SyncStorage(SyncStorageAbc): def get_configuration(self) -> _Configuration: """Return configuration what was valid when the SDL instance was initiated.""" return self.__configuration + + @classmethod + def _validate_key_value_dict(cls, kv): + for k, v in kv.items(): + if not isinstance(k, str): + raise SdlTypeError(r"Wrong dict key type: {}={}. Must be: str".format(k, type(k))) + if not isinstance(v, bytes): + raise SdlTypeError(r"Wrong dict value type: {}={}. Must be: bytes".format(v, type(v))) + + def _validate_channels_events(self, channels_and_events: Dict[Any, Any]): + for channel, events in channels_and_events.items(): + if not isinstance(channel, str): + raise SdlTypeError(r"Wrong channel type: {}={}. Must be: str".format( + channel, type(channel))) + if not isinstance(events, (list, str)): + raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format( + events, type(events))) + if isinstance(events, list): + for event in events: + if not isinstance(event, str): + raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format( + events, type(events))) + if self.event_separator in event: + raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format( + events, self.event_separator)) + else: + if self.event_separator in events: + raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format( + events, self.event_separator)) + + @classmethod + def _validate_callback(cls, cb): + param_len = len(inspect.signature(cb).parameters) + if param_len != 2: + raise SdlTypeError( + f"Callback function should take 2 positional argument but {param_len} were given")