Update watcher worker
[pti/o2.git] / o2ims / adapter / unit_of_work.py
diff --git a/o2ims/adapter/unit_of_work.py b/o2ims/adapter/unit_of_work.py
new file mode 100644 (file)
index 0000000..c958ce2
--- /dev/null
@@ -0,0 +1,53 @@
+# Copyright (C) 2021 Wind River Systems, Inc.\r
+#\r
+#  Licensed under the Apache License, Version 2.0 (the "License");\r
+#  you may not use this file except in compliance with the License.\r
+#  You may obtain a copy of the License at\r
+#\r
+#      http://www.apache.org/licenses/LICENSE-2.0\r
+#\r
+#  Unless required by applicable law or agreed to in writing, software\r
+#  distributed under the License is distributed on an "AS IS" BASIS,\r
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
+#  See the License for the specific language governing permissions and\r
+#  limitations under the License.\r
+\r
+# pylint: disable=attribute-defined-outside-init\r
+from __future__ import annotations\r
+from sqlalchemy import create_engine\r
+from sqlalchemy.orm import sessionmaker\r
+from sqlalchemy.orm.session import Session\r
+\r
+from o2ims import config\r
+from o2ims.adapter.ocloud_repository import OcloudSqlAlchemyRepository\r
+from o2ims.adapter.stx_repository import StxObjectSqlAlchemyRepository\r
+from o2ims.service.unit_of_work import AbstractUnitOfWork\r
+\r
+\r
+DEFAULT_SESSION_FACTORY = sessionmaker(\r
+    bind=create_engine(\r
+        config.get_postgres_uri(),\r
+        isolation_level="REPEATABLE READ",\r
+    )\r
+)\r
+\r
+\r
+class SqlAlchemyUnitOfWork(AbstractUnitOfWork):\r
+    def __init__(self, session_factory=DEFAULT_SESSION_FACTORY):\r
+        self.session_factory = session_factory\r
+\r
+    def __enter__(self):\r
+        self.session = self.session_factory()  # type: Session\r
+        self.oclouds = OcloudSqlAlchemyRepository(self.session)\r
+        self.stxobjects = StxObjectSqlAlchemyRepository(self.session)\r
+        return super().__enter__()\r
+\r
+    def __exit__(self, *args):\r
+        super().__exit__(*args)\r
+        self.session.close()\r
+\r
+    def _commit(self):\r
+        self.session.commit()\r
+\r
+    def rollback(self):\r
+        self.session.rollback()\r