Release ricsdl 3.1.0
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / syncstorage.py
old mode 100644 (file)
new mode 100755 (executable)
index 41deeba..62e3a7f
@@ -1,5 +1,5 @@
 # Copyright (c) 2019 AT&T Intellectual Property.
 # Copyright (c) 2019 AT&T Intellectual Property.
-# Copyright (c) 2018-2019 Nokia.
+# Copyright (c) 2018-2022 Nokia.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 
 """The module provides implementation of the syncronous Shared Data Layer (SDL) interface."""
 import builtins
 
 """The module provides implementation of the syncronous Shared Data Layer (SDL) interface."""
 import builtins
-from typing import (Dict, Set, List, Union)
+import inspect
+from typing import (Any, Callable, Dict, Set, List, Optional, Tuple, Union)
 from ricsdl.configuration import _Configuration
 from ricsdl.syncstorage_abc import (SyncStorageAbc, SyncLockAbc)
 import ricsdl.backend
 from ricsdl.backend.dbbackend_abc import DbBackendAbc
 from ricsdl.configuration import _Configuration
 from ricsdl.syncstorage_abc import (SyncStorageAbc, SyncLockAbc)
 import ricsdl.backend
 from ricsdl.backend.dbbackend_abc import DbBackendAbc
-from ricsdl.exceptions import SdlTypeError
+from ricsdl.exceptions import (SdlException, SdlTypeError)
 
 
 def func_arg_checker(exception, start_arg_idx, **types):
 
 
 def func_arg_checker(exception, start_arg_idx, **types):
@@ -118,7 +119,9 @@ class SyncStorage(SyncStorageAbc):
     """
     def __init__(self, fake_db_backend=None) -> None:
         super().__init__()
     """
     def __init__(self, fake_db_backend=None) -> None:
         super().__init__()
+        self.__dbbackend = None
         self.__configuration = _Configuration(fake_db_backend)
         self.__configuration = _Configuration(fake_db_backend)
+        self.event_separator = self.__configuration.get_event_separator()
         self.__dbbackend = ricsdl.backend.get_backend_instance(self.__configuration)
 
     def __del__(self):
         self.__dbbackend = ricsdl.backend.get_backend_instance(self.__configuration)
 
     def __del__(self):
@@ -132,11 +135,19 @@ class SyncStorage(SyncStorageAbc):
             }
         )
 
             }
         )
 
+    def is_active(self):
+        try:
+            return self.__dbbackend.is_connected()
+        except SdlException:
+            return False
+
     def close(self):
     def close(self):
-        self.__dbbackend.close()
+        if self.__dbbackend:
+            self.__dbbackend.close()
 
     @func_arg_checker(SdlTypeError, 1, ns=str, data_map=dict)
     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
 
     @func_arg_checker(SdlTypeError, 1, ns=str, data_map=dict)
     def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
+        self._validate_key_value_dict(data_map)
         self.__dbbackend.set(ns, data_map)
 
     @func_arg_checker(SdlTypeError, 1, ns=str, key=str, old_data=bytes, new_data=bytes)
         self.__dbbackend.set(ns, data_map)
 
     @func_arg_checker(SdlTypeError, 1, ns=str, key=str, old_data=bytes, new_data=bytes)
@@ -199,6 +210,81 @@ class SyncStorage(SyncStorageAbc):
     def group_size(self, ns: str, group: str) -> int:
         return self.__dbbackend.group_size(ns, group)
 
     def group_size(self, ns: str, group: str) -> int:
         return self.__dbbackend.group_size(ns, group)
 
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, data_map=dict)
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                        data_map: Dict[str, bytes]) -> None:
+        self._validate_key_value_dict(data_map)
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.set_and_publish(ns, channels_and_events, data_map)
+
+    @func_arg_checker(SdlTypeError,
+                      1,
+                      ns=str,
+                      channels_and_events=dict,
+                      key=str,
+                      old_data=bytes,
+                      new_data=bytes)
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           key: str, old_data: bytes, new_data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_and_publish(ns, channels_and_events, key, old_data, new_data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def set_if_not_exists_and_publish(self, ns: str,
+                                      channels_and_events: Dict[str, Union[str, List[str]]],
+                                      key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.set_if_not_exists_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, keys=(str, builtins.set))
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                           keys: Union[str, Set[str]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        keys = [keys] if isinstance(keys, str) else list(keys)
+        self.__dbbackend.remove_and_publish(ns, channels_and_events, keys)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict, key=str, data=bytes)
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, Union[str, List[str]]],
+                              key: str, data: bytes) -> bool:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        return self.__dbbackend.remove_if_and_publish(ns, channels_and_events, key, data)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels_and_events=dict)
+    def remove_all_and_publish(self, ns: str,
+                               channels_and_events: Dict[str, Union[str, List[str]]]) -> None:
+        self._validate_channels_events(channels_and_events)
+        for channel, events in channels_and_events.items():
+            channels_and_events[channel] = [events] if isinstance(events, str) else events
+        self.__dbbackend.remove_all_and_publish(ns, channels_and_events)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, cb=Callable, channels=(str, builtins.set))
+    def subscribe_channel(self, ns: str, cb: Callable[[str, List[str]], None],
+                          channels: Union[str, Set[str]]) -> None:
+        self._validate_callback(cb)
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.subscribe_channel(ns, cb, channels)
+
+    @func_arg_checker(SdlTypeError, 1, ns=str, channels=(str, builtins.set))
+    def unsubscribe_channel(self, ns: str, channels: Union[str, Set[str]]) -> None:
+        channels = [channels] if isinstance(channels, str) else list(channels)
+        self.__dbbackend.unsubscribe_channel(ns, channels)
+
+    def start_event_listener(self) -> None:
+        self.__dbbackend.start_event_listener()
+
+    def handle_events(self) -> Optional[Tuple[str, List[str]]]:
+        return self.__dbbackend.handle_events()
+
     @func_arg_checker(SdlTypeError, 1, ns=str, resource=str, expiration=(int, float))
     def get_lock_resource(self, ns: str, resource: str, expiration: Union[int, float]) -> SyncLock:
         return SyncLock(ns, resource, expiration, self)
     @func_arg_checker(SdlTypeError, 1, ns=str, resource=str, expiration=(int, float))
     def get_lock_resource(self, ns: str, resource: str, expiration: Union[int, float]) -> SyncLock:
         return SyncLock(ns, resource, expiration, self)
@@ -210,3 +296,39 @@ class SyncStorage(SyncStorageAbc):
     def get_configuration(self) -> _Configuration:
         """Return configuration what was valid when the SDL instance was initiated."""
         return self.__configuration
     def get_configuration(self) -> _Configuration:
         """Return configuration what was valid when the SDL instance was initiated."""
         return self.__configuration
+
+    @classmethod
+    def _validate_key_value_dict(cls, kv):
+        for k, v in kv.items():
+            if not isinstance(k, str):
+                raise SdlTypeError(r"Wrong dict key type: {}={}. Must be: str".format(k, type(k)))
+            if not isinstance(v, bytes):
+                raise SdlTypeError(r"Wrong dict value type: {}={}. Must be: bytes".format(v, type(v)))
+
+    def _validate_channels_events(self, channels_and_events: Dict[Any, Any]):
+        for channel, events in channels_and_events.items():
+            if not isinstance(channel, str):
+                raise SdlTypeError(r"Wrong channel type: {}={}. Must be: str".format(
+                    channel, type(channel)))
+            if not isinstance(events, (list, str)):
+                raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                    events, type(events)))
+            if isinstance(events, list):
+                for event in events:
+                    if not isinstance(event, str):
+                        raise SdlTypeError(r"Wrong event type: {}={}. Must be: str".format(
+                            events, type(events)))
+                    if self.event_separator in event:
+                        raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format(
+                            events, self.event_separator))
+            else:
+                if self.event_separator in events:
+                    raise SdlTypeError(r"Events {} contains illegal substring (\"{}\")".format(
+                        events, self.event_separator))
+
+    @classmethod
+    def _validate_callback(cls, cb):
+        param_len = len(inspect.signature(cb).parameters)
+        if param_len != 2:
+            raise SdlTypeError(
+                f"Callback function should take 2 positional argument but {param_len} were given")