Add support for notifications
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / syncstorage.py
old mode 100644 (file)
new mode 100755 (executable)
index 0a2d9b3..48b5e3d
@@ -20,7 +20,8 @@
 
 """The module provides implementation of the syncronous Shared Data Layer (SDL) interface."""
 import builtins
-from typing import (Dict, Set, List, Union)
+import inspect
+from typing import (Any, Callable, Dict, Set, List, Optional, Tuple, Union)
 from ricsdl.configuration import _Configuration
 from ricsdl.syncstorage_abc import (SyncStorageAbc, SyncLockAbc)
 import ricsdl.backend
@@ -206,6 +207,81 @@ class SyncStorage(SyncStorageAbc):
     def group_size(self, ns: str, group: str) -> int:
         return self.__dbbackend.group_size(ns, group)
 
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, data_map=dict)
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                        data_map: Dict[str, bytes]) -> None:
+        self._validate_key_value_dict(data_map)
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.set_and_publish(ns, channels_and_events, data_map)
+
+    @func_arg_checker(SdlTypeError,
+                      1,
+                      ns=str,
+                      channels_and_events=dict,
+                      key=str,
+                      old_data=bytes,
+                      new_data=bytes)
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           key: str, old_data: bytes, new_data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_and_publish(ns, channels_and_events, key, old_data, new_data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def set_if_not_exists_and_publish(self, ns: str,
+                                      channels_and_events: Dict[str, Union[str, List[str]]],
+                                      key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_not_exists_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, keys=(str, builtins.set))
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           keys: Union[str, Set[str]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        keys = [keys] if isinstance(keys, str) else list(keys)
+        self.__dbbackend.remove_and_publish(ns, channels_and_events, keys)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                              key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.remove_if_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict)
+    def remove_all_and_publish(self, ns: str,
+                               channels_and_events: Dict[str, Union[str, List[str]]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.remove_all_and_publish(ns, channels_and_events)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, cb=Callable, channels=(str, builtins.set))
+    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
+                          channels: Union[str, Set[str]]) -> None:
+        self._validate_callback(cb)
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.subscribe_channel(ns, cb, channels)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels=(str, builtins.set))
+    def unsubscribe_channel(self, ns: str, channels: Union[str, Set[str]]) -> None:
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.unsubscribe_channel(ns, channels)
+
+    def start_event_listener(self) -> None:
+        self.__dbbackend.start_event_listener()
+
+    def handle_events(self) -> Optional[Tuple[str, str]]:
+        return self.__dbbackend.handle_events()
+
     @func_arg_checker(SdlTypeError, 1, ns=str, resource=str, expiration=(int, float))
     def get_lock_resource(self, ns: str, resource: str, expiration: Union[int, float]) -> SyncLock:
         return SyncLock(ns, resource, expiration, self)
@@ -225,3 +301,25 @@ class SyncStorage(SyncStorageAbc):
                 raise SdlTypeError(r"Wrong dict key type: {}={}. Must be: str".format(k, type(k)))
             if not isinstance(v, bytes):
                 raise SdlTypeError(r"Wrong dict value type: {}={}. Must be: bytes".format(v, type(v)))
+
+    @classmethod
+    def _validate_channels_events(cls, channels_and_events: Dict[Any, Any]):
+        for channel, events in channels_and_events.items():
+            if not isinstance(channel, str):
+                raise SdlTypeError(r"Wrong channel type: {}={}. Must be: str".format(
+                    channel, type(channel)))
+            if not isinstance(events, (list, str)):
+                raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                    events, type(events)))
+            if isinstance(events, list):
+                for event in events:
+                    if not isinstance(event, str):
+                        raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                            events, type(events)))
+
+    @classmethod
+    def _validate_callback(cls, cb):
+        param_len = len(inspect.signature(cb).parameters)
+        if param_len != 2:
+            raise SdlTypeError(
+                f"Callback function should take 2 positional argument but {param_len} were given")