Read list type of environment variables
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / syncstorage.py
old mode 100644 (file)
new mode 100755 (executable)
index 29adb17..62e3a7f
@@ -1,5 +1,5 @@
 # Copyright (c) 2019 AT&T Intellectual Property.
-# Copyright (c) 2018-2019 Nokia.
+# Copyright (c) 2018-2022 Nokia.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@
 
 """The module provides implementation of the syncronous Shared Data Layer (SDL) interface."""
 import builtins
-from typing import (Dict, Set, List, Union)
+import inspect
+from typing import (Any, Callable, Dict, Set, List, Optional, Tuple, Union)
 from ricsdl.configuration import _Configuration
 from ricsdl.syncstorage_abc import (SyncStorageAbc, SyncLockAbc)
 import ricsdl.backend
@@ -118,7 +119,9 @@ class SyncStorage(SyncStorageAbc):
     """
     def __init__(self, fake_db_backend=None) -> None:
         super().__init__()
+        self.__dbbackend = None
         self.__configuration = _Configuration(fake_db_backend)
+        self.event_separator = self.__configuration.get_event_separator()
         self.__dbbackend = ricsdl.backend.get_backend_instance(self.__configuration)
 
     def __del__(self):
@@ -139,10 +142,12 @@ class SyncStorage(SyncStorageAbc):
             return False
 
     def close(self):
-        self.__dbbackend.close()
+        if self.__dbbackend:
+            self.__dbbackend.close()
 
     @func_arg_checker(SdlTypeError, 1, ns=str, data_map=dict)
     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
+        self._validate_key_value_dict(data_map)
         self.__dbbackend.set(ns, data_map)
 
     @func_arg_checker(SdlTypeError, 1, ns=str, key=str, old_data=bytes, new_data=bytes)
@@ -205,6 +210,81 @@ class SyncStorage(SyncStorageAbc):
     def group_size(self, ns: str, group: str) -> int:
         return self.__dbbackend.group_size(ns, group)
 
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, data_map=dict)
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                        data_map: Dict[str, bytes]) -> None:
+        self._validate_key_value_dict(data_map)
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.set_and_publish(ns, channels_and_events, data_map)
+
+    @func_arg_checker(SdlTypeError,
+                      1,
+                      ns=str,
+                      channels_and_events=dict,
+                      key=str,
+                      old_data=bytes,
+                      new_data=bytes)
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           key: str, old_data: bytes, new_data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_and_publish(ns, channels_and_events, key, old_data, new_data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def set_if_not_exists_and_publish(self, ns: str,
+                                      channels_and_events: Dict[str, Union[str, List[str]]],
+                                      key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_not_exists_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, keys=(str, builtins.set))
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           keys: Union[str, Set[str]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        keys = [keys] if isinstance(keys, str) else list(keys)
+        self.__dbbackend.remove_and_publish(ns, channels_and_events, keys)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                              key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.remove_if_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict)
+    def remove_all_and_publish(self, ns: str,
+                               channels_and_events: Dict[str, Union[str, List[str]]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.remove_all_and_publish(ns, channels_and_events)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, cb=Callable, channels=(str, builtins.set))
+    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
+                          channels: Union[str, Set[str]]) -> None:
+        self._validate_callback(cb)
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.subscribe_channel(ns, cb, channels)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels=(str, builtins.set))
+    def unsubscribe_channel(self, ns: str, channels: Union[str, Set[str]]) -> None:
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.unsubscribe_channel(ns, channels)
+
+    def start_event_listener(self) -> None:
+        self.__dbbackend.start_event_listener()
+
+    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
+        return self.__dbbackend.handle_events()
+
     @func_arg_checker(SdlTypeError, 1, ns=str, resource=str, expiration=(int, float))
     def get_lock_resource(self, ns: str, resource: str, expiration: Union[int, float]) -> SyncLock:
         return SyncLock(ns, resource, expiration, self)
@@ -216,3 +296,39 @@ class SyncStorage(SyncStorageAbc):
     def get_configuration(self) -> _Configuration:
         """Return configuration what was valid when the SDL instance was initiated."""
         return self.__configuration
+
+    @classmethod
+    def _validate_key_value_dict(cls, kv):
+        for k, v in kv.items():
+            if not isinstance(k, str):
+                raise SdlTypeError(r"Wrong dict key type: {}={}. Must be: str".format(k, type(k)))
+            if not isinstance(v, bytes):
+                raise SdlTypeError(r"Wrong dict value type: {}={}. Must be: bytes".format(v, type(v)))
+
+    def _validate_channels_events(self, channels_and_events: Dict[Any, Any]):
+        for channel, events in channels_and_events.items():
+            if not isinstance(channel, str):
+                raise SdlTypeError(r"Wrong channel type: {}={}. Must be: str".format(
+                    channel, type(channel)))
+            if not isinstance(events, (list, str)):
+                raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                    events, type(events)))
+            if isinstance(events, list):
+                for event in events:
+                    if not isinstance(event, str):
+                        raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                            events, type(events)))
+                    if self.event_separator in event:
+                        raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format(
+                            events, self.event_separator))
+            else:
+                if self.event_separator in events:
+                    raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format(
+                        events, self.event_separator))
+
+    @classmethod
+    def _validate_callback(cls, cb):
+        param_len = len(inspect.signature(cb).parameters)
+        if param_len != 2:
+            raise SdlTypeError(
+                f"Callback function should take 2 positional argument but {param_len} were given")