Add support for notifications
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / fake_dict_db.py
old mode 100644 (file)
new mode 100755 (executable)
index b6c84a2..5a49f10
 
 """The module provides fake implementation of Shared Data Layer (SDL) database backend interface."""
 import fnmatch
-from typing import (Dict, Set, List, Union)
+from typing import (Callable, Dict, Set, List, Optional, Tuple, Union)
+import queue
+import threading
+import time
 from ricsdl.configuration import _Configuration
 from .dbbackend_abc import DbBackendAbc
 from .dbbackend_abc import DbBackendLockAbc
@@ -43,6 +46,10 @@ class FakeDictBackend(DbBackendAbc):
         super().__init__()
         self._db = {}
         self._configuration = configuration
+        self._queue = queue.Queue(1)
+        self._channel_cbs = {}
+        self._listen_thread = threading.Thread(target=self._listen, daemon=True)
+        self._run_in_thread = False
 
     def __str__(self):
         return str(
@@ -138,6 +145,93 @@ class FakeDictBackend(DbBackendAbc):
             return 0
         return len(self._db[group])
 
+    def set_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                        data_map: Dict[str, bytes]) -> None:
+        self._db.update(data_map.copy())
+        for channel, events in channels_and_events.items():
+            for event in events:
+                self._queue.put((channel, event))
+
+    def set_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                           old_data: bytes, new_data: bytes) -> bool:
+        if self.set_if(ns, key, old_data, new_data):
+            for channel, events in channels_and_events.items():
+                for event in events:
+                    self._queue.put((channel, event))
+            return True
+        return False
+
+    def set_if_not_exists_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                                      key: str, data: bytes) -> bool:
+        if self.set_if_not_exists(ns, key, data):
+            for channel, events in channels_and_events.items():
+                for event in events:
+                    self._queue.put((channel, event))
+            return True
+        return False
+
+    def remove_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]],
+                           keys: List[str]) -> None:
+        for key in keys:
+            self._db.pop(key, None)
+        for channel, events in channels_and_events.items():
+            for event in events:
+                self._queue.put((channel, event))
+
+    def remove_if_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]], key: str,
+                              data: bytes) -> bool:
+        if self.remove_if(ns, key, data):
+            for channel, events in channels_and_events.items():
+                for event in events:
+                    self._queue.put((channel, event))
+            return True
+        return False
+
+    def remove_all_and_publish(self, ns: str, channels_and_events: Dict[str, List[str]]) -> None:
+        # Note: Since fake db has only one namespace, this deletes all keys
+        self._db.clear()
+        for channel, events in channels_and_events.items():
+            for event in events:
+                self._queue.put((channel, event))
+
+    def subscribe_channel(self, ns: str, cb: Callable[[str, str], None],
+                          channels: List[str]) -> None:
+        for channel in channels:
+            self._channel_cbs[channel] = cb
+            if not self._listen_thread.is_alive() and self._run_in_thread:
+                self._listen_thread.start()
+
+    def _listen(self):
+        while True:
+            message = self._queue.get()
+            cb = self._channel_cbs.get(message[0], None)
+            if cb:
+                cb(message[0], message[1])
+            time.sleep(0.001)
+
+    def unsubscribe_channel(self, ns: str, channels: List[str]) -> None:
+        for channel in channels:
+            self._channel_cbs.pop(channel, None)
+
+    def start_event_listener(self) -> None:
+        if self._listen_thread.is_alive():
+            raise Exception("Event loop already started")
+        if len(self._channel_cbs) > 0:
+            self._listen_thread.start()
+        self._run_in_thread = True
+
+    def handle_events(self) -> Optional[Tuple[str, str]]:
+        if self._listen_thread.is_alive() or self._run_in_thread:
+            raise Exception("Event loop already started")
+        try:
+            message = self._queue.get(block=False)
+        except queue.Empty:
+            return None
+        cb = self._channel_cbs.get(message[0], None)
+        if cb:
+            cb(message[0], message[1])
+        return (message[0], message[1])
+
 
 class FakeDictBackendLock(DbBackendLockAbc):
     """