From a109a8b55881dc86f0a15659a4d4433a32de302d Mon Sep 17 00:00:00 2001 From: Monosij Ghosh Date: Mon, 13 Jan 2025 18:29:40 +0530 Subject: [PATCH] added test cases for coverage added TCs for trainingmgr_main file Change-Id: Ie4deb11d2623bfcf4f570c186be69e327d0f6a48 Signed-off-by: Monosij Ghosh (cherry picked from commit 1b5cce485c96dc3a3a94cbc6869a43d99e0c0da5) --- tests/test_trainingmgr_main.py | 433 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 433 insertions(+) create mode 100644 tests/test_trainingmgr_main.py diff --git a/tests/test_trainingmgr_main.py b/tests/test_trainingmgr_main.py new file mode 100644 index 0000000..6aad439 --- /dev/null +++ b/tests/test_trainingmgr_main.py @@ -0,0 +1,433 @@ +# ================================================================================== +# +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ================================================================================== + +import pytest +from flask import Flask +import json +import requests +from unittest import mock +from unittest.mock import patch, MagicMock +from requests.models import Response +from threading import Lock +import os +import sys +import datetime +from flask_api import status +from flask import send_file +from dotenv import load_dotenv +load_dotenv('tests/test.env') +from trainingmgr.constants.states import States +from threading import Lock +from trainingmgr.common.tmgr_logger import TMLogger +from trainingmgr.common.trainingmgr_config import TrainingMgrConfig +from trainingmgr.common.exceptions_utls import DBException, TMException +from trainingmgr.models import TrainingJob +from trainingmgr.models import FeatureGroup +from trainingmgr.common.trainingConfig_parser import getField +from trainingmgr.controller import featuregroup_controller, training_job_controller +from trainingmgr.controller.pipeline_controller import pipeline_controller +from trainingmgr.handler.async_handler import Model_Metrics_Sdk +from middleware.loggingMiddleware import LoggingMiddleware +from io import BytesIO + +#mock ModelMetricsSdk before importing +mock_modelmetrics_sdk = MagicMock() +sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=mock_modelmetrics_sdk) +TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig() +from trainingmgr import trainingmgr_main + +class TestGetModel: + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + + mocked_mm_sdk=mock.Mock(name="MM_SDK") + attrs_mm_sdk = {'check_object.return_value': True, 'get_model_zip.return_value':"model.zip"} + mocked_mm_sdk.configure_mock(**attrs_mm_sdk) + + @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk) + @patch('trainingmgr.trainingmgr_main.send_file') + def test_getmodel_valid_data(self, mock_sendfile, mock_model_metrics_sdk): + mock_sendfile.return_value = trainingmgr_main.APP.response_class( + response=b'model.zip', + status=200, + mimetype='application/zip' + ) + response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="valid_model", + modelversion="1.0", artifactversion = "1.0"), content_type="application/json") + print(response.data) + assert response.status_code == 200 + assert b'model.zip' == response.data + + @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value=mocked_mm_sdk) + @patch('trainingmgr.trainingmgr_main.send_file') + def test_getmodel_model_not_found(self, mock_sendfile, mock_model_metrics_sdk): + + mock_sendfile.side_effect = Exception("An error occurred (404) when calling the HeadObject operation: Not Found") + response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="nonexistent_model", + modelversion="1.0", artifactversion = "1.0"), content_type="application/json") + + assert response.status_code == 404 + assert b"error while downloading model as no model with modelId nonexistent_model 1.0 1.0 was found" in response.data + + @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value=mocked_mm_sdk) + @patch('trainingmgr.trainingmgr_main.send_file') + def test_getmodel_internal_server_error(self, mock_sendfile, mock_model_metrics_sdk): + mock_sendfile.side_effect = Exception("Some general error occurred") + response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="invalid_model", + modelversion="1.0", artifactversion = "1.0"), content_type="application/json") + + assert response.status_code == 500 + assert b"error while downloading model" in response.data + + +class TestDataExtractionNotification: + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + + @pytest.fixture + def mock_training_job(self): + TrainingJob = MagicMock() + TrainingJob.id = 1 + + mock_modelId = MagicMock() + + mock_modelId.modelname = 'test_model' # Set the modelname attribute + mock_modelId.modelversion = '1.0' # Set the modelversion attribute + + TrainingJob.modelId.return_value = mock_modelId + TrainingJob.model_location = "s3://model-location" + TrainingJob.training_config = { + "description": "Test description", + "dataPipeline": { + "feature_group_name": "test_feature_group", + "query_filter": "", + "arguments": {"epochs" : 1, "trainingjob_name": "test_job"} + }, + "trainingPipeline": { + "pipeline_name": "test_pipeline", + "pipeline_version": "2", + "enable_versioning": True + } + } + mock_steps_state = MagicMock() + mock_steps_state.states = json.dumps({"DATA_EXTRACTION":"FINISHED", + "DATA_EXTRACTION_AND_TRAINING":"FINISHED", + "TRAINING":"IN_PROGRESS"}) + TrainingJob.steps_state = mock_steps_state + TrainingJob.deletion_in_progress = False + return TrainingJob + + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = False) + def test_dataextraction_keynotindict(self, mock_check_dict): + trainingjob_req = { + "trainingjob_id":"123" + } + response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 400 + assert "featuregroup_name or trainingjob_id key not available in request" in str(response.data) + + + registered_model_list1 = [{ + "modelId" : { + "artifactVersion" : "0.0.0" + }, + "modelLocation": "s3://different-location" + }] + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list1) + @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("", "")) + def test_dataextraction_trainingModel_error(self, mock_fetchname, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job): + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id":"123" + } + response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 500 + assert "Provide retraining pipeline name and version" in str(response.data) + + + registered_model_list2 = [{ + "modelId" : { + "artifactVersion" : "0.1.0" + }, + }] + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list2) + @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("", "")) + def test_dataextraction_retrainingModel_error(self, mock_fetch_pipeline, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job): + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id":"123" + } + response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 500 + assert "Provide retraining pipeline name and version" in str(response.data) + + + registered_model_list3 = [{ + "modelId" : { + "artifactVersion" : "0.0.0" + }, + "modelLocation": "" + }] + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list3) + @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("qoe_pipeline", "v1")) + @patch('trainingmgr.trainingmgr_main.training_start') + def test_dataextraction_trainingModel_invalidresponse_kf(self, mock_training_start, mock_fetch_pipeline, + mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job): + + mock_response_invalid = MagicMock() + mock_response_invalid.headers = {'content-type': 'text/plain'} + mock_response_invalid.status_code = 500 + mock_training_start.return_value = mock_response_invalid + + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id":"123" + } + + response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), content_type="application/json") + assert response.status_code == 500 + + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list3) + @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("qoe_pipeline", "v1")) + @patch('trainingmgr.trainingmgr_main.training_start') + @patch('trainingmgr.trainingmgr_main.change_status_tj') + @patch('trainingmgr.trainingmgr_main.change_update_field_value') + def test_dataextraction_trainingModel_validresponse_kf(self, mock_update_field_val, mock_change_status, + mock_training_start, mock_fetch_pipeline, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job): + + mock_response_valid = MagicMock() + mock_response_valid.headers = {'content-type': 'application/json'} + mock_response_valid.status_code = 200 + mock_response_valid.json.return_value = {'run_status': 'scheduled', 'run_id': '123'} + mock_training_start.return_value = mock_response_valid + + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id":"123" + } + response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 200 + assert "pipeline is scheduled" in str(response.data) + + +class TestPipelineNotification: + + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + + mocked_mm_sdk = mock.Mock(name="MM_SDK") + @pytest.fixture + def mock_training_job(self): + TrainingJob = MagicMock() + TrainingJob.id = 1 + + TrainingJob.model_location = "s3://model-location" + TrainingJob.training_config = { + "description": "Test description", + "dataPipeline": { + "feature_group_name": "test_feature_group", + "query_filter": "", + "arguments": {"epochs" : 1, "trainingjob_name": "test_job"} + }, + "trainingPipeline": { + "pipeline_name": "test_pipeline", + "pipeline_version": "2", + "enable_versioning": True + } + } + mock_steps_state = MagicMock() + mock_steps_state.states = json.dumps({"DATA_EXTRACTION":"FINISHED", + "DATA_EXTRACTION_AND_TRAINING":"FINISHED", + "TRAINING":"IN_PROGRESS"}) + TrainingJob.steps_state = mock_steps_state + TrainingJob.deletion_in_progress = False + return TrainingJob + + registered_model_list = [{ + "modelId" : { + "artifactVersion" : "1.0" + }, + "modelLocation": "s3://different-location" + }] + + @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk) + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.change_status_tj') + @patch('trainingmgr.trainingmgr_main.notification_rapp') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list) + @patch('trainingmgr.trainingmgr_main.change_update_field_value') + + def test_success(self, mock_update_field_value, mock_get_modelinfo, mock_notification_rapp, mock_change_status, + mock_get_trainingjob, mock_check_in_dict, mock_mmsdk, mock_training_job): + mock_mmsdk.check_object.return_value = True + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id" : "123", + "run_status" : "SUCCEEDED" + } + response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 200 + assert response.json == {'Message': 'Training successful'} + + + mocked_mm_sdk_false = mock.Mock(name="MM_SDK") + @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk) + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.change_status_tj') + @patch('trainingmgr.trainingmgr_main.notification_rapp') + @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list) + + def test_unsuccess_object_not_present_in_mmsdk(self, mock_get_modelinfo, mock_notification_rapp, mock_change_status, + mock_get_trainingjob, mock_check_in_dict, mock_mmsdk, mock_training_job): + mock_mmsdk.check_object.return_value = False + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id" : "123", + "run_status" : "SUCCEEDED" + } + response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 500 + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_training_job') + @patch('trainingmgr.trainingmgr_main.change_status_tj') + @patch('trainingmgr.trainingmgr_main.notification_rapp') + def test_unsuccess(self, mock_notification_rapp, mock_change_status, mock_get_trainingjob, mock_check_in_dict, + mock_training_job): + mock_get_trainingjob.return_value = mock_training_job + trainingjob_req = { + "trainingjob_id" : "123", + "run_status" : "UNSUCCEEDED" + } + response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req), + content_type="application/json") + assert response.status_code == 500 + +class TestFeatureGroupByName: + + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + + @pytest.fixture + def mock_feature_group(self): + mock_fg = MagicMock() + mock_fg.featuregroup_name = "test_featuregroup" + return mock_fg + + @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name') + def test_get_feature_group_success(self, mock_get_feature_group): + mock_get_feature_group.return_value = {"featuregroup_name": "test_featuregroup", "status": "success"}, status.HTTP_200_OK + + response = self.client.get("/featureGroup/{featuregroup_name}".format(featuregroup_name = "test_featuregroup"), content_type="application/json") + assert response.status_code == 200 + assert response.json == {"featuregroup_name": "test_featuregroup", "status": "success"} + + @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name') + def test_get_feature_group_not_found(self, mock_get_feature_group): + mock_get_feature_group.return_value = {"Exception": "Feature group not found"}, status.HTTP_404_NOT_FOUND + + response = self.client.get("/featureGroup/{featuregroup_name}".format(featuregroup_name = "invalid"), content_type="application/json") + assert response.status_code == 404 + assert response.json == {"Exception": "Feature group not found"} + + @patch('trainingmgr.trainingmgr_main.edit_feature_group_by_name') + @patch('trainingmgr.trainingmgr_main.FeatureGroupSchema.load') + def test_put_feature_group_success(self, mock_load, mock_edit_feature_group, mock_feature_group): + mock_load.return_value = mock_feature_group + mock_edit_feature_group.return_value = {"featuregroup_name": "test_featuregroup", "status": "updated"}, status.HTTP_200_OK + put_data = {"featuregroup_name": "test_featuregroup", "new_field": "value"} + + response = self.client.put("/featureGroup/{featuregroup_name}".format(featuregroup_name = "test_featuregroup"), json = put_data) + assert response.status_code == 200 + assert response.json == {"featuregroup_name": "test_featuregroup", "status": "updated"} + + +class TestDeleteListOfFeatureGroup: + + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db') + @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name') + @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job') + def test_delete_feature_group_success(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key): + mock_get_fg.return_value = MagicMock(enable_dme=True, host="localhost", dme_port=8080) + mock_delete_fg.return_value = None + mock_delete_dme.return_value.status_code = status.HTTP_204_NO_CONTENT + + request_data = { + "featuregroups_list": [ + {"featureGroup_name": "test_featuregroup"} + ] + } + response = self.client.delete('/featureGroup', json=request_data) + assert response.status_code == 200 + assert response.json == {"success count": 1, "failure count": 0} + + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True) + @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db') + @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name') + @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job') + def test_delete_feature_group_failure(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key): + mock_get_fg.return_value = MagicMock(enable_dme=True, host="localhost", dme_port=8080) + mock_delete_fg.side_effect = Exception("Failed to delete feature group") + + request_data = { + "featuregroups_list": [ + {"featureGroup_name": "test_featuregroup"} + ] + } + response = self.client.delete('/featureGroup', json=request_data) + assert response.status_code == 200 + assert response.json == {"success count": 0, "failure count": 1} + + @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = False) + @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db') + @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name') + @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job') + def test_delete_feature_group_invalid_key(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key): + + request_data = { + "featuregroups_list": [ + {"featuregroup_name": "test_featuregroup"} + ] + } + response = self.client.delete('/featureGroup', json=request_data) + assert response.status_code == 400 + assert "Wrong Request syntax" in response.json["Exception"] + -- 2.16.6