Add implementation of SDL in python
[ric-plt/sdlpy.git] / ricsdl-package / ricsdl / backend / redis.py
diff --git a/ricsdl-package/ricsdl/backend/redis.py b/ricsdl-package/ricsdl/backend/redis.py
new file mode 100644 (file)
index 0000000..afa7450
--- /dev/null
@@ -0,0 +1,317 @@
+# Copyright (c) 2019 AT&T Intellectual Property.
+# Copyright (c) 2018-2019 Nokia.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#
+# This source code is part of the near-RT RIC (RAN Intelligent Controller)
+# platform project (RICP).
+#
+
+
+"""The module provides implementation of Shared Data Layer (SDL) database backend interface."""
+import contextlib
+from typing import (Dict, Set, List, Union)
+from redis import Redis
+from redis.sentinel import Sentinel
+from redis.lock import Lock
+from redis._compat import nativestr
+from redis import exceptions as redis_exceptions
+from ricsdl.configuration import _Configuration
+from ricsdl.exceptions import (
+    RejectedByBackend,
+    NotConnected,
+    BackendError
+)
+from .dbbackend_abc import DbBackendAbc
+from .dbbackend_abc import DbBackendLockAbc
+
+
+@contextlib.contextmanager
+def _map_to_sdl_exception():
+    """Translates known redis exceptions into SDL exceptions."""
+    try:
+        yield
+    except(redis_exceptions.ResponseError) as exc:
+        raise RejectedByBackend("SDL backend rejected the request: {}".
+                                format(str(exc))) from exc
+    except(redis_exceptions.ConnectionError, redis_exceptions.TimeoutError) as exc:
+        raise NotConnected("SDL not connected to backend: {}".
+                           format(str(exc))) from exc
+    except(redis_exceptions.RedisError) as exc:
+        raise BackendError("SDL backend failed to process the request: {}".
+                           format(str(exc))) from exc
+
+
+class RedisBackend(DbBackendAbc):
+    """
+    A class providing an implementation of database backend of Shared Data Layer (SDL), when
+    backend database solution is Redis.
+
+    Args:
+        configuration (_Configuration): SDL configuration, containing credentials to connect to
+                                        Redis database backend.
+    """
+    def __init__(self, configuration: _Configuration) -> None:
+        super().__init__()
+        with _map_to_sdl_exception():
+            if configuration.get_params().db_sentinel_port:
+                sentinel_node = (configuration.get_params().db_host,
+                                 configuration.get_params().db_sentinel_port)
+                master_name = configuration.get_params().db_sentinel_master_name
+                self.__sentinel = Sentinel([sentinel_node])
+                self.__redis = self.__sentinel.master_for(master_name)
+            else:
+                self.__redis = Redis(host=configuration.get_params().db_host,
+                                     port=configuration.get_params().db_port,
+                                     db=0,
+                                     max_connections=20)
+        self.__redis.set_response_callback('SETIE', lambda r: r and nativestr(r) == 'OK' or False)
+        self.__redis.set_response_callback('DELIE', lambda r: r and int(r) == 1 or False)
+
+    def __del__(self):
+        self.close()
+
+    def __str__(self):
+        return str(
+            {
+                "Redis connection": repr(self.__redis)
+            }
+        )
+
+    def close(self):
+        self.__redis.close()
+
+    def set(self, ns: str, data_map: Dict[str, bytes]) -> None:
+        db_data_map = self._add_data_map_ns_prefix(ns, data_map)
+        with _map_to_sdl_exception():
+            self.__redis.mset(db_data_map)
+
+    def set_if(self, ns: str, key: str, old_data: bytes, new_data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        with _map_to_sdl_exception():
+            return self.__redis.execute_command('SETIE', db_key, new_data, old_data)
+
+    def set_if_not_exists(self, ns: str, key: str, data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        with _map_to_sdl_exception():
+            return self.__redis.setnx(db_key, data)
+
+    def get(self, ns: str, keys: List[str]) -> Dict[str, bytes]:
+        ret = dict()
+        db_keys = self._add_keys_ns_prefix(ns, keys)
+        with _map_to_sdl_exception():
+            values = self.__redis.mget(db_keys)
+            for idx, val in enumerate(values):
+                # return only key values, which has a value
+                if val:
+                    ret[keys[idx]] = val
+            return ret
+
+    def find_keys(self, ns: str, key_prefix: str) -> List[str]:
+        escaped_key_prefix = self._escape_characters(key_prefix)
+        db_escaped_key_prefix = self._add_key_ns_prefix(ns, escaped_key_prefix + '*')
+        with _map_to_sdl_exception():
+            ret = self.__redis.keys(db_escaped_key_prefix)
+            return self._strip_ns_from_bin_keys(ns, ret)
+
+    def find_and_get(self, ns: str, key_prefix: str, atomic: bool) -> Dict[str, bytes]:
+        # todo: replace below implementation with redis 'NGET' module
+        ret = dict()  # type: Dict[str, bytes]
+        with _map_to_sdl_exception():
+            matched_keys = self.find_keys(ns, key_prefix)
+            if matched_keys:
+                ret = self.get(ns, matched_keys)
+        return ret
+
+    def remove(self, ns: str, keys: List[str]) -> None:
+        db_keys = self._add_keys_ns_prefix(ns, keys)
+        with _map_to_sdl_exception():
+            self.__redis.delete(*db_keys)
+
+    def remove_if(self, ns: str, key: str, data: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, key)
+        with _map_to_sdl_exception():
+            return self.__redis.execute_command('DELIE', db_key, data)
+
+    def add_member(self, ns: str, group: str, members: Set[bytes]) -> None:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            self.__redis.sadd(db_key, *members)
+
+    def remove_member(self, ns: str, group: str, members: Set[bytes]) -> None:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            self.__redis.srem(db_key, *members)
+
+    def remove_group(self, ns: str, group: str) -> None:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            self.__redis.delete(db_key)
+
+    def get_members(self, ns: str, group: str) -> Set[bytes]:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            return self.__redis.smembers(db_key)
+
+    def is_member(self, ns: str, group: str, member: bytes) -> bool:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            return self.__redis.sismember(db_key, member)
+
+    def group_size(self, ns: str, group: str) -> int:
+        db_key = self._add_key_ns_prefix(ns, group)
+        with _map_to_sdl_exception():
+            return self.__redis.scard(db_key)
+
+    @classmethod
+    def _add_key_ns_prefix(cls, ns: str, key: str):
+        return '{' + ns + '},' + key
+
+    @classmethod
+    def _add_keys_ns_prefix(cls, ns: str, keylist: List[str]) -> List[str]:
+        ret_nskeys = []
+        for k in keylist:
+            ret_nskeys.append('{' + ns + '},' + k)
+        return ret_nskeys
+
+    @classmethod
+    def _add_data_map_ns_prefix(cls, ns: str, data_dict: Dict[str, bytes]) -> Dict[str, bytes]:
+        ret_nsdict = {}
+        for key, val in data_dict.items():
+            ret_nsdict['{' + ns + '},' + key] = val
+        return ret_nsdict
+
+    @classmethod
+    def _strip_ns_from_bin_keys(cls, ns: str, nskeylist: List[bytes]) -> List[str]:
+        ret_keys = []
+        for k in nskeylist:
+            nskey = k.decode("utf-8").split(',', 1)
+            if len(nskey) != 2:
+                msg = u'Illegal namespace %s key:%s' % (ns, nskey)
+                raise RejectedByBackend(msg)
+            ret_keys.append(nskey[1])
+        return ret_keys
+
+    @classmethod
+    def _escape_characters(cls, pattern: str) -> str:
+        return pattern.translate(str.maketrans(
+            {"(": r"\(",
+             ")": r"\)",
+             "[": r"\[",
+             "]": r"\]",
+             "*": r"\*",
+             "?": r"\?",
+             "\\": r"\\"}))
+
+    def get_redis_connection(self):
+        """Return existing Redis database connection."""
+        return self.__redis
+
+
+class RedisBackendLock(DbBackendLockAbc):
+    """
+    A class providing an implementation of database backend lock of Shared Data Layer (SDL), when
+    backend database solution is Redis.
+
+    Args:
+        ns (str): Namespace under which this lock is targeted.
+        name (str): Lock name, identifies the lock key in a Redis database backend.
+        expiration (int, float): Lock expiration time after which the lock is removed if it hasn't
+                                 been released earlier by a 'release' method.
+        redis_backend (RedisBackend): Database backend object containing connection to Redis
+                                      database.
+    """
+    lua_get_validity_time = None
+    # KEYS[1] - lock name
+    # ARGS[1] - token
+    # return < 0 in case of failure, otherwise return lock validity time in milliseconds.
+    LUA_GET_VALIDITY_TIME_SCRIPT = """
+        local token = redis.call('get', KEYS[1])
+        if not token then
+            return -10
+        end
+        if token ~= ARGV[1] then
+            return -11
+        end
+        return redis.call('pttl', KEYS[1])
+    """
+
+    def __init__(self, ns: str, name: str, expiration: Union[int, float],
+                 redis_backend: RedisBackend) -> None:
+        super().__init__(ns, name)
+        self.__redis = redis_backend.get_redis_connection()
+        with _map_to_sdl_exception():
+            redis_lockname = '{' + ns + '},' + self._lock_name
+            self.__redis_lock = Lock(redis=self.__redis, name=redis_lockname, timeout=expiration)
+            self._register_scripts()
+
+    def __str__(self):
+        return str(
+            {
+                "lock namespace": self._ns,
+                "lock name": self._lock_name,
+                "lock status": self._lock_status_to_string()
+            }
+        )
+
+    def acquire(self, retry_interval: Union[int, float] = 0.1,
+                retry_timeout: Union[int, float] = 10) -> bool:
+        succeeded = False
+        self.__redis_lock.sleep = retry_interval
+        with _map_to_sdl_exception():
+            succeeded = self.__redis_lock.acquire(blocking_timeout=retry_timeout)
+        return succeeded
+
+    def release(self) -> None:
+        with _map_to_sdl_exception():
+            self.__redis_lock.release()
+
+    def refresh(self) -> None:
+        with _map_to_sdl_exception():
+            self.__redis_lock.reacquire()
+
+    def get_validity_time(self) -> Union[int, float]:
+        validity = 0
+        if self.__redis_lock.local.token is None:
+            msg = u'Cannot get validity time of an unlocked lock %s' % self._lock_name
+            raise RejectedByBackend(msg)
+
+        with _map_to_sdl_exception():
+            validity = self.lua_get_validity_time(keys=[self.__redis_lock.name],
+                                                  args=[self.__redis_lock.local.token],
+                                                  client=self.__redis)
+        if validity < 0:
+            msg = (u'Getting validity time of a lock %s failed with error code: %d'
+                   % (self._lock_name, validity))
+            raise RejectedByBackend(msg)
+        ftime = validity / 1000.0
+        if ftime.is_integer():
+            return int(ftime)
+        return ftime
+
+    def _register_scripts(self):
+        cls = self.__class__
+        client = self.__redis
+        if cls.lua_get_validity_time is None:
+            cls.lua_get_validity_time = client.register_script(cls.LUA_GET_VALIDITY_TIME_SCRIPT)
+
+    def _lock_status_to_string(self) -> str:
+        try:
+            if self.__redis_lock.locked():
+                if self.__redis_lock.owned():
+                    return 'locked'
+                return 'locked by someone else'
+            return 'unlocked'
+        except(redis_exceptions.RedisError) as exc:
+            return f'Error: {str(exc)}'