From: Monosij Ghosh Date: Fri, 3 Jan 2025 11:57:00 +0000 (+0530) Subject: test case addtion for TM X-Git-Tag: 4.0.0~35 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=7655e09d74b18b26da1da1b51bfc53c82668974b;p=aiml-fw%2Fawmf%2Ftm.git test case addtion for TM added test cases for trainingjob_controller file Change-Id: If43c6e85631f7068a7df64322acf18aaa8cf9337 Signed-off-by: Monosij Ghosh --- diff --git a/tests/test_trainingjob_controller.py b/tests/test_trainingjob_controller.py new file mode 100644 index 0000000..a7aa34e --- /dev/null +++ b/tests/test_trainingjob_controller.py @@ -0,0 +1,257 @@ +# ================================================================================== +# +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ================================================================================== +import pytest +from flask import Flask +import json +import requests +from unittest import mock +from unittest.mock import patch, MagicMock +import pytest +from requests.models import Response +from threading import Lock +import os +import sys +import datetime +from flask_api import status +from dotenv import load_dotenv +load_dotenv('tests/test.env') +from trainingmgr.constants.states import States +from threading import Lock +from trainingmgr.common.tmgr_logger import TMLogger +from trainingmgr.common.trainingmgr_config import TrainingMgrConfig +from trainingmgr.common.exceptions_utls import DBException, TMException +from trainingmgr.models import TrainingJob +from trainingmgr.models import FeatureGroup +from trainingmgr.common.trainingConfig_parser import getField + +#mock ModelMetricsSdk before importing +mock_modelmetrics_sdk = MagicMock() +sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=mock_modelmetrics_sdk) +from trainingmgr.controller.trainingjob_controller import training_job_controller +from trainingmgr import trainingmgr_main + +trainingmgr_main.LOGGER = pytest.logger +trainingmgr_main.LOCK = Lock() +trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {} + + +class Test_create_trainingjob: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + mocked_TRAININGMGR_CONFIG_OBJ = mock.Mock(name="TRAININGMGR_CONFIG_OBJ") + attrs_TRAININGMGR_CONFIG_OBJ = {'kf_adapter_ip.return_value': '123', 'kf_adapter_port.return_value': '100'} + mocked_TRAININGMGR_CONFIG_OBJ.configure_mock(**attrs_TRAININGMGR_CONFIG_OBJ) + + def test_create_trainingjob_missing_training_config(self): + trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_missing_training_config *******") + expected_data = "The training_config is missing" + trainingjob_req = { + "modelId":{ + "modelname": "modeltest", + "modelversion": "1" + } + } + response = self.client.post("/training-jobs", data = json.dumps(trainingjob_req), + content_type="application/json") + trainingmgr_main.LOGGER.debug(response.data) + print(response) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert expected_data in str(response.data) + + def test_create_trainingjob_invalid_training_config(self): + trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_invalid_training_config *******") + expected_data = "The TrainingConfig is not correct" + trainingjob_req = { + "modelId":{ + "modelname": "modeltest", + "modelversion": "1" + }, + "training_config": { + "description": "trainingjob for testing" + } + } + response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req), + content_type="application/json") + trainingmgr_main.LOGGER.debug(response.data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert expected_data in str(response.data) + + @patch('trainingmgr.controller.trainingjob_controller.get_modelinfo_by_modelId_service', return_value = None) + def test_create_trainingjob_model_not_registered(self, mock1): + trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_model_not_registered *******") + expected_data = "modelId test_model and 1 is not registered at MME, Please first register at MME and then continue" + trainingjob_req = { + "modelId":{ + "modelname": "test_model", + "modelversion": "1" + }, + "model_location": "", + "training_config": { + "description": "trainingjob for testing", + "dataPipeline": { + "feature_group_name": "testing_influxdb_01", + "query_filter": "", + "arguments": "{'epochs': 1'}" + }, + "trainingPipeline": { + "training_pipeline_name": "qoe_Pipeline", + "training_pipeline_version": "qoe_Pipeline", + "retraining_pipeline_name": "qoe_PipelineRetrain", + "retraining_pipeline_version": "qoe_PipelineRetrain", + } + }, + } + response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req), content_type="application/json") + trainingmgr_main.LOGGER.debug(response.data) + print(response.data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert expected_data in str(response.data) + + registered_model_list = [{"modelLocation": "s3://different-location"}] + @patch('trainingmgr.controller.trainingjob_controller.get_modelinfo_by_modelId_service', return_value=registered_model_list) + def test_create_trainingjob_model_location_mismatch(self, mock1): + trainingmgr_main.LOGGER.debug("******* test_create_trainingjob_model_location_mismatch *******") + expected_data = "modelId test_model and 1 and trainingjob created does not have same modelLocation, Please first register at MME properly and then continue" + trainingjob_req = { + "modelId":{ + "modelname": "test_model", + "modelversion": "1" + }, + "model_location": "s3://model-location", + "training_config": { + "description": "trainingjob for testing", + "dataPipeline": { + "feature_group_name": "testing_influxdb_01", + "query_filter": "", + "arguments": "{'epochs': 1'}" + }, + "trainingPipeline": { + "training_pipeline_name": "qoe_Pipeline", + "training_pipeline_version": "qoe_Pipeline", + "retraining_pipeline_name": "qoe_PipelineRetrain", + "retraining_pipeline_version": "qoe_PipelineRetrain", + } + }, + } + response = self.client.post("/training-jobs", data=json.dumps(trainingjob_req), content_type="application/json") + print(response.data) + trainingmgr_main.LOGGER.debug(response.data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert expected_data in str(response.data) + +class Test_DeleteTrainingJob: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + @patch('trainingmgr.controller.trainingjob_controller.delete_training_job', return_value=True) + def test_delete_trainingjob_success(self, mock1): + response = self.client.delete("/training-jobs/{}".format("123")) + trainingmgr_main.LOGGER.debug(response.data) + assert response.status_code == status.HTTP_204_NO_CONTENT + + @patch('trainingmgr.controller.trainingjob_controller.delete_training_job', return_value=False) + def test_delete_trainingjob_not_found(self, mock1): + expected_data = {'message': 'training job with given id is not found'} + + response = self.client.delete("/training-jobs/{}".format("123")) + trainingmgr_main.LOGGER.debug(response.data) + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert expected_data == response.json + + +class Test_GetTrainingJobs: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + tjs = [{'id': 1, 'name': 'Test Job'}] + @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs', return_value = tjs) + @patch('trainingmgr.controller.trainingjob_controller.trainingjobs_schema.dump', return_value = tjs) + def test_get_trainingjobs_success(self, mock1, mock2): + expected_data = [{"id": 1, "name": "Test Job"}] + response = self.client.get('/training-jobs/') + assert response.status_code == 200 + assert expected_data == response.json + + @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs') + def test_get_trainingjobs_tmexception(self, mock_get_trainingjobs): + mock_get_trainingjobs.side_effect = TMException('Training jobs not found') + + response = self.client.get('/training-jobs/') + assert response.status_code == 400 + assert response.json['message'] == 'Training jobs not found' + + @patch('trainingmgr.controller.trainingjob_controller.get_trainining_jobs') + def test_get_trainingjobs_generic_exception(self, mock_get_trainingjobs): + mock_get_trainingjobs.side_effect = Exception('Unexpected error') + response = self.client.get('/training-jobs/') + assert response.status_code == 500 + assert response.json['message'] == 'Unexpected error' + +class Test_GetTrainingJob: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + tj = {'id': 1, 'name': 'Test Job'} + @patch('trainingmgr.controller.trainingjob_controller.get_training_job', return_value = tj) + @patch('trainingmgr.controller.trainingjob_controller.trainingjob_schema.dump', return_value = tj) + def test_get_trainingjob_success(self, mock_schema_dump, mock_get_training_job): + response = self.client.get('/training-jobs/1') + assert response.status_code == 200 + assert response.json == {'id': 1, 'name': 'Test Job'} + + @patch('trainingmgr.controller.trainingjob_controller.get_training_job') + def test_get_trainingjob_tmexception(self, mock_get_training_job): + # Simulate TMException + mock_get_training_job.side_effect = TMException('Training job not found') + + response = self.client.get('/training-jobs/1') + assert response.status_code == 400 + assert response.json['message'] == 'Training job not found' + + @patch('trainingmgr.controller.trainingjob_controller.get_training_job') + def test_get_trainingjob_generic_exception(self, mock_get_training_job): + mock_get_training_job.side_effect = Exception('Unexpected error') + response = self.client.get('/training-jobs/1') + assert response.status_code == 500 + assert response.json['message'] == 'Unexpected error' + + +class Test_GetTrainingJobStatus: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + expected_data = {"status": "running"} + @patch('trainingmgr.controller.trainingjob_controller.get_steps_state', return_value=json.dumps(expected_data)) + def test_get_trainingjob_status(self, mock1): + expected_data = {"status": "running"} + response = self.client.get("/training-jobs/{}/status".format("123")) + trainingmgr_main.LOGGER.debug(response.data) + assert response.status_code == status.HTTP_200_OK + assert expected_data == response.json +