test case addtion for TM 96/13996/1
authorMonosij Ghosh <mono.ghosh@samsung.com>
Fri, 3 Jan 2025 11:57:00 +0000 (17:27 +0530)
committersubhash kumar singh <subh.singh@samsung.com>
Mon, 6 Jan 2025 05:36:10 +0000 (05:36 +0000)
added test cases for trainingjob_controller file

Change-Id: If43c6e85631f7068a7df64322acf18aaa8cf9337
Signed-off-by: Monosij Ghosh <mono.ghosh@samsung.com>
(cherry picked from commit 7655e09d74b18b26da1da1b51bfc53c82668974b)

tests/test_trainingjob_controller.py [new file with mode: 0644]

diff --git a/tests/test_trainingjob_controller.py b/tests/test_trainingjob_controller.py
new file mode 100644 (file)
index 0000000..a7aa34e
--- /dev/null
@@ -0,0 +1,257 @@
+# ==================================================================================
+#
+#      Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+# ==================================================================================
+import pytest
+from flask import Flask
+import json
+import requests
+from unittest import mock
+from unittest.mock import patch, MagicMock
+import pytest
+from requests.models import Response
+from threading import Lock
+import os
+import sys
+import datetime
+from flask_api import status
+from dotenv import load_dotenv
+load_dotenv('tests/test.env')
+from trainingmgr.constants.states import States
+from threading import Lock
+from trainingmgr.common.tmgr_logger import TMLogger
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+from trainingmgr.common.exceptions_utls import DBException, TMException
+from trainingmgr.models import TrainingJob
+from trainingmgr.models import FeatureGroup
+from trainingmgr.common.trainingConfig_parser import getField
+
+#mock ModelMetricsSdk before importing
+mock_modelmetrics_sdk = MagicMock()
+sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=mock_modelmetrics_sdk)
+from trainingmgr.controller.trainingjob_controller import training_job_controller
+from trainingmgr import trainingmgr_main
+
+trainingmgr_main.LOGGER = pytest.logger
+trainingmgr_main.LOCK = Lock()
+trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {}
+
+
+class Test_create_trainingjob:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+
+    mocked_TRAININGMGR_CONFIG_OBJ = mock.Mock(name="TRAININGMGR_CONFIG_OBJ")
+    attrs_TRAININGMGR_CONFIG_OBJ = {'kf_adapter_ip.return_value': '123', 'kf_adapter_port.return_value': '100'}
+    mocked_TRAININGMGR_CONFIG_OBJ.configure_mock(**attrs_TRAININGMGR_CONFIG_OBJ)
+
+    def test_create_trainingjob_missing_training_config(self):
+        trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_missing_training_config *******")
+        expected_data = "The training_config is missing"
+        trainingjob_req = {
+                    "modelId":{
+                        "modelname": "modeltest",
+                        "modelversion": "1"
+                    }
+            }
+        response = self.client.post("/training-jobs", data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        trainingmgr_main.LOGGER.debug(response.data)
+        print(response)
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+        assert expected_data in str(response.data)
+
+    def test_create_trainingjob_invalid_training_config(self):
+        trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_invalid_training_config *******")
+        expected_data = "The TrainingConfig is not correct"
+        trainingjob_req = {
+                    "modelId":{
+                        "modelname": "modeltest",
+                        "modelversion": "1"
+                    },
+                    "training_config": {
+                        "description": "trainingjob for testing"
+                    }
+            }
+        response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        trainingmgr_main.LOGGER.debug(response.data)
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+        assert expected_data in str(response.data)
+
+    @patch('trainingmgr.controller.trainingjob_controller.get_modelinfo_by_modelId_service', return_value = None)
+    def test_create_trainingjob_model_not_registered(self, mock1):
+        trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_model_not_registered *******")
+        expected_data = "modelId test_model and 1 is not registered at MME, Please first register at MME and then continue"
+        trainingjob_req = {
+            "modelId":{
+                "modelname": "test_model", 
+                "modelversion": "1"
+            },
+            "model_location": "",
+            "training_config": {
+                    "description": "trainingjob for testing",
+                    "dataPipeline": {
+                        "feature_group_name": "testing_influxdb_01",
+                        "query_filter": "",
+                        "arguments": "{'epochs': 1'}"
+                    },
+                    "trainingPipeline": {
+                        "training_pipeline_name": "qoe_Pipeline",
+                        "training_pipeline_version": "qoe_Pipeline",
+                        "retraining_pipeline_name": "qoe_PipelineRetrain",
+                        "retraining_pipeline_version": "qoe_PipelineRetrain",
+                    }
+            },
+        }
+        response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req), content_type="application/json")
+        trainingmgr_main.LOGGER.debug(response.data)
+        print(response.data)
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+        assert expected_data in str(response.data)
+
+    registered_model_list = [{"modelLocation": "s3://different-location"}]
+    @patch('trainingmgr.controller.trainingjob_controller.get_modelinfo_by_modelId_service', return_value=registered_model_list)
+    def test_create_trainingjob_model_location_mismatch(self, mock1):
+        trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_model_location_mismatch *******")
+        expected_data = "modelId test_model and 1 and trainingjob created does not have same modelLocation, Please first register at MME properly and then continue"
+        trainingjob_req = {
+            "modelId":{
+                "modelname": "test_model", 
+                "modelversion": "1"
+            },
+            "model_location": "s3://model-location",
+            "training_config": {
+                    "description": "trainingjob for testing",
+                    "dataPipeline": {
+                        "feature_group_name": "testing_influxdb_01",
+                        "query_filter": "",
+                        "arguments": "{'epochs': 1'}"
+                    },
+                    "trainingPipeline": {
+                        "training_pipeline_name": "qoe_Pipeline",
+                        "training_pipeline_version": "qoe_Pipeline",
+                        "retraining_pipeline_name": "qoe_PipelineRetrain",
+                        "retraining_pipeline_version": "qoe_PipelineRetrain",
+                    }
+            },
+        }
+        response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req), content_type="application/json")
+        print(response.data)
+        trainingmgr_main.LOGGER.debug(response.data)
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+        assert expected_data in str(response.data)
+
+class Test_DeleteTrainingJob:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+
+    @patch('trainingmgr.controller.trainingjob_controller.delete_training_job', return_value=True)
+    def test_delete_trainingjob_success(self, mock1):
+        response = self.client.delete("/training-jobs/{}".format("123"))
+        trainingmgr_main.LOGGER.debug(response.data)
+        assert response.status_code == status.HTTP_204_NO_CONTENT
+    
+    @patch('trainingmgr.controller.trainingjob_controller.delete_training_job', return_value=False)
+    def test_delete_trainingjob_not_found(self, mock1):
+        expected_data = {'message': 'training job with given id is not found'}
+        
+        response = self.client.delete("/training-jobs/{}".format("123"))
+        trainingmgr_main.LOGGER.debug(response.data)
+        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        assert expected_data == response.json
+
+
+class Test_GetTrainingJobs:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+    
+    tjs = [{'id': 1, 'name': 'Test Job'}]
+    @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs', return_value = tjs)
+    @patch('trainingmgr.controller.trainingjob_controller.trainingjobs_schema.dump', return_value = tjs)
+    def test_get_trainingjobs_success(self, mock1, mock2):
+        expected_data = [{"id": 1, "name": "Test Job"}]
+        response = self.client.get('/training-jobs/')
+        assert response.status_code == 200
+        assert expected_data == response.json
+
+    @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs')
+    def test_get_trainingjobs_tmexception(self, mock_get_trainingjobs):
+        mock_get_trainingjobs.side_effect = TMException('Training jobs not found')
+
+        response = self.client.get('/training-jobs/')
+        assert response.status_code == 400
+        assert response.json['message'] == 'Training jobs not found'
+
+    @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs')
+    def test_get_trainingjobs_generic_exception(self, mock_get_trainingjobs):
+        mock_get_trainingjobs.side_effect = Exception('Unexpected error')
+        response = self.client.get('/training-jobs/')
+        assert response.status_code == 500
+        assert response.json['message'] == 'Unexpected error'
+
+class Test_GetTrainingJob:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+
+    tj = {'id': 1, 'name': 'Test Job'}
+    @patch('trainingmgr.controller.trainingjob_controller.get_training_job', return_value = tj)
+    @patch('trainingmgr.controller.trainingjob_controller.trainingjob_schema.dump', return_value = tj)
+    def test_get_trainingjob_success(self, mock_schema_dump, mock_get_training_job):
+        response = self.client.get('/training-jobs/1')
+        assert response.status_code == 200
+        assert response.json == {'id': 1, 'name': 'Test Job'}
+
+    @patch('trainingmgr.controller.trainingjob_controller.get_training_job')
+    def test_get_trainingjob_tmexception(self, mock_get_training_job):
+        # Simulate TMException
+        mock_get_training_job.side_effect = TMException('Training job not found')
+
+        response = self.client.get('/training-jobs/1')
+        assert response.status_code == 400
+        assert response.json['message'] == 'Training job not found'
+
+    @patch('trainingmgr.controller.trainingjob_controller.get_training_job')
+    def test_get_trainingjob_generic_exception(self, mock_get_training_job):
+        mock_get_training_job.side_effect = Exception('Unexpected error')
+        response = self.client.get('/training-jobs/1')
+        assert response.status_code == 500
+        assert response.json['message'] == 'Unexpected error'
+
+
+class Test_GetTrainingJobStatus:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+
+    expected_data = {"status": "running"}
+    @patch('trainingmgr.controller.trainingjob_controller.get_steps_state', return_value=json.dumps(expected_data))
+    def test_get_trainingjob_status(self, mock1):
+        expected_data = {"status": "running"}
+        response = self.client.get("/training-jobs/{}/status".format("123"))
+        trainingmgr_main.LOGGER.debug(response.data)
+        assert response.status_code == status.HTTP_200_OK
+        assert expected_data == response.json
+