Pack notifications to be compatible with SDL golang
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / fake_dict_db.py
old mode 100644 (file)
new mode 100755 (executable)
index b6c84a2..1a63ebd
 
 """The module provides fake implementation of Shared Data Layer (SDL) database backend interface."""
 import fnmatch
 
 """The module provides fake implementation of Shared Data Layer (SDL) database backend interface."""
 import fnmatch
-from typing import (Dict, Set, List, Union)
+from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import queue
+import threading
+import time
 from ricsdl.configuration import _Configuration
 from .dbbackend_abc import DbBackendAbc
 from .dbbackend_abc import DbBackendLockAbc
 from ricsdl.configuration import _Configuration
 from .dbbackend_abc import DbBackendAbc
 from .dbbackend_abc import DbBackendLockAbc
@@ -43,6 +46,10 @@ class FakeDictBackend(DbBackendAbc):
         super().__init__()
         self._db = {}
         self._configuration = configuration
         super().__init__()
         self._db = {}
         self._configuration = configuration
+        self._queue = queue.Queue(1)
+        self._channel_cbs = {}
+        self._listen_thread = threading.Thread(target=self._listen, daemon=True)
+        self._run_in_thread = False
 
     def __str__(self):
         return str(
 
     def __str__(self):
         return str(
@@ -138,6 +145,89 @@ class FakeDictBackend(DbBackendAbc):
             return 0
         return len(self._db[group])
 
             return 0
         return len(self._db[group])
 
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                        data_map: Dict[str, bytes]) -> None:
+        self._db.update(data_map.copy())
+        for channel, events in channels_and_events.items():
+            self._queue.put((channel, events))
+
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                           old_data: bytes, new_data: bytes) -> bool:
+        if self.set_if(ns, key, old_data, new_data):
+            for channel, events in channels_and_events.items():
+                self._queue.put((channel, events))
+            return True
+        return False
+
+    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                                      key: str, data: bytes) -> bool:
+        if self.set_if_not_exists(ns, key, data):
+            for channel, events in channels_and_events.items():
+                self._queue.put((channel, events))
+            return True
+        return False
+
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                           keys: List[str]) -> None:
+        for key in keys:
+            self._db.pop(key, None)
+        for channel, events in channels_and_events.items():
+            self._queue.put((channel, events))
+
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                              data: bytes) -> bool:
+        if self.remove_if(ns, key, data):
+            for channel, events in channels_and_events.items():
+                self._queue.put((channel, events))
+            return True
+        return False
+
+    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
+        # Note: Since fake db has only one namespace, this deletes all keys
+        self._db.clear()
+        for channel, events in channels_and_events.items():
+            self._queue.put((channel, events))
+
+    def subscribe_channel(self, ns: str,
+                          cb: Union[Callable[[str, str], None], Callable[[str, List[str]], None]],
+                          channels: List[str]) -> None:
+        for channel in channels:
+            self._channel_cbs[channel] = cb
+            if not self._listen_thread.is_alive() and self._run_in_thread:
+                self._listen_thread.start()
+
+    def _listen(self):
+        while True:
+            message = self._queue.get()
+            cb = self._channel_cbs.get(message[0], None)
+            if cb:
+                cb(message[0], message[1][0] if (isinstance(message[1], list) and len(message[1]) == 1) else message[1])
+            time.sleep(0.001)
+
+    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
+        for channel in channels:
+            self._channel_cbs.pop(channel, None)
+
+    def start_event_listener(self) -> None:
+        if self._listen_thread.is_alive():
+            raise Exception("Event loop already started")
+        if len(self._channel_cbs) > 0:
+            self._listen_thread.start()
+        self._run_in_thread = True
+
+    def handle_events(self) -> Optional[Union[Tuple[str, str], Tuple[str, List[str]]]]:
+        if self._listen_thread.is_alive() or self._run_in_thread:
+            raise Exception("Event loop already started")
+        try:
+            message = self._queue.get(block=False)
+        except queue.Empty:
+            return None
+        cb = self._channel_cbs.get(message[0], None)
+        notifications = message[1][0] if (isinstance(message[1], list) and len(message[1]) == 1) else message[1]
+        if cb:
+            cb(message[0], notifications)
+        return (message[0], notifications)
+
 
 class FakeDictBackendLock(DbBackendLockAbc):
     """
 
 class FakeDictBackendLock(DbBackendLockAbc):
     """