added test cases for coverage 93/14093/1
authorMonosij Ghosh <mono.ghosh@samsung.com>
Mon, 13 Jan 2025 12:59:40 +0000 (18:29 +0530)
committersubhash kumar singh <subh.singh@samsung.com>
Wed, 15 Jan 2025 17:16:48 +0000 (17:16 +0000)
added TCs for trainingmgr_main file

Change-Id: Ie4deb11d2623bfcf4f570c186be69e327d0f6a48
Signed-off-by: Monosij Ghosh <mono.ghosh@samsung.com>
(cherry picked from commit 1b5cce485c96dc3a3a94cbc6869a43d99e0c0da5)

tests/test_trainingmgr_main.py [new file with mode: 0644]

diff --git a/tests/test_trainingmgr_main.py b/tests/test_trainingmgr_main.py
new file mode 100644 (file)
index 0000000..6aad439
--- /dev/null
@@ -0,0 +1,433 @@
+# ==================================================================================
+#
+#      Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+# 
+# ==================================================================================
+
+import pytest
+from flask import Flask
+import json
+import requests
+from unittest import mock
+from unittest.mock import patch, MagicMock
+from requests.models import Response
+from threading import Lock
+import os
+import sys
+import datetime
+from flask_api import status
+from flask import send_file
+from dotenv import load_dotenv
+load_dotenv('tests/test.env') 
+from trainingmgr.constants.states import States
+from threading import Lock
+from trainingmgr.common.tmgr_logger import TMLogger
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+from trainingmgr.common.exceptions_utls import DBException, TMException
+from trainingmgr.models import TrainingJob
+from trainingmgr.models import FeatureGroup
+from trainingmgr.common.trainingConfig_parser import getField
+from trainingmgr.controller import featuregroup_controller, training_job_controller
+from trainingmgr.controller.pipeline_controller import pipeline_controller
+from trainingmgr.handler.async_handler import Model_Metrics_Sdk
+from middleware.loggingMiddleware import LoggingMiddleware
+from io import BytesIO
+
+#mock ModelMetricsSdk before importing
+mock_modelmetrics_sdk = MagicMock()
+sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=mock_modelmetrics_sdk)
+TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig()
+from trainingmgr import trainingmgr_main
+
+class TestGetModel:
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+    
+    mocked_mm_sdk=mock.Mock(name="MM_SDK")
+    attrs_mm_sdk = {'check_object.return_value': True, 'get_model_zip.return_value':"model.zip"}
+    mocked_mm_sdk.configure_mock(**attrs_mm_sdk)
+
+    @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk)
+    @patch('trainingmgr.trainingmgr_main.send_file')
+    def test_getmodel_valid_data(self, mock_sendfile, mock_model_metrics_sdk):
+        mock_sendfile.return_value = trainingmgr_main.APP.response_class(
+            response=b'model.zip',
+            status=200,
+            mimetype='application/zip'
+        )
+        response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="valid_model",
+                                modelversion="1.0", artifactversion = "1.0"), content_type="application/json")
+        print(response.data)
+        assert response.status_code == 200
+        assert b'model.zip' == response.data
+
+    @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value=mocked_mm_sdk)
+    @patch('trainingmgr.trainingmgr_main.send_file')
+    def test_getmodel_model_not_found(self, mock_sendfile, mock_model_metrics_sdk):
+        
+        mock_sendfile.side_effect = Exception("An error occurred (404) when calling the HeadObject operation: Not Found")
+        response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="nonexistent_model",
+                                modelversion="1.0", artifactversion = "1.0"), content_type="application/json")
+        
+        assert response.status_code == 404
+        assert b"error while downloading model as no model with modelId nonexistent_model 1.0 1.0 was found" in response.data
+    
+    @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value=mocked_mm_sdk)
+    @patch('trainingmgr.trainingmgr_main.send_file')
+    def test_getmodel_internal_server_error(self, mock_sendfile, mock_model_metrics_sdk):
+        mock_sendfile.side_effect = Exception("Some general error occurred")
+        response = self.client.get("/model/{modelname}/{modelversion}/{artifactversion}/Model.zip".format(modelname="invalid_model",
+                                modelversion="1.0", artifactversion = "1.0"), content_type="application/json")
+        
+        assert response.status_code == 500
+        assert b"error while downloading model" in response.data
+    
+
+class TestDataExtractionNotification:
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+    
+    @pytest.fixture        
+    def mock_training_job(self):
+        TrainingJob = MagicMock()
+        TrainingJob.id = 1
+        
+        mock_modelId = MagicMock()
+        
+        mock_modelId.modelname = 'test_model'  # Set the modelname attribute
+        mock_modelId.modelversion = '1.0'     # Set the modelversion attribute
+        
+        TrainingJob.modelId.return_value = mock_modelId
+        TrainingJob.model_location = "s3://model-location"
+        TrainingJob.training_config = {
+            "description": "Test description",
+            "dataPipeline": {
+                "feature_group_name": "test_feature_group",
+                "query_filter": "",
+                "arguments": {"epochs" : 1, "trainingjob_name": "test_job"}
+            },
+            "trainingPipeline": {
+                    "pipeline_name": "test_pipeline",
+                    "pipeline_version": "2",
+                    "enable_versioning": True
+            }
+        }
+        mock_steps_state = MagicMock()
+        mock_steps_state.states = json.dumps({"DATA_EXTRACTION":"FINISHED", 
+                            "DATA_EXTRACTION_AND_TRAINING":"FINISHED", 
+                            "TRAINING":"IN_PROGRESS"})  
+        TrainingJob.steps_state = mock_steps_state
+        TrainingJob.deletion_in_progress = False
+        return TrainingJob
+
+
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = False)
+    def test_dataextraction_keynotindict(self, mock_check_dict):        
+        trainingjob_req = {
+                     "trainingjob_id":"123"
+        }
+        response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 400
+        assert "featuregroup_name or trainingjob_id key not available in request" in str(response.data)
+    
+
+    registered_model_list1 = [{ 
+        "modelId" : { 
+            "artifactVersion" : "0.0.0"
+        },
+        "modelLocation": "s3://different-location"
+    }]
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list1)
+    @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("", ""))
+    def test_dataextraction_trainingModel_error(self, mock_fetchname, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job):        
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id":"123"
+        }
+        response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 500
+        assert "Provide retraining pipeline name and version" in str(response.data)
+    
+
+    registered_model_list2 = [{ 
+        "modelId" : { 
+            "artifactVersion" : "0.1.0"
+        },
+    }]
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list2)
+    @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("", ""))
+    def test_dataextraction_retrainingModel_error(self, mock_fetch_pipeline, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job):        
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id":"123"
+        }
+        response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 500
+        assert "Provide retraining pipeline name and version" in str(response.data)
+
+
+    registered_model_list3 = [{ 
+        "modelId" : { 
+            "artifactVersion" : "0.0.0"
+        },
+        "modelLocation": ""
+    }]
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list3)
+    @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("qoe_pipeline", "v1"))
+    @patch('trainingmgr.trainingmgr_main.training_start')
+    def test_dataextraction_trainingModel_invalidresponse_kf(self, mock_training_start, mock_fetch_pipeline, 
+    mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job):        
+         
+        mock_response_invalid = MagicMock()
+        mock_response_invalid.headers = {'content-type': 'text/plain'}
+        mock_response_invalid.status_code = 500
+        mock_training_start.return_value = mock_response_invalid
+
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id":"123"
+        }
+
+        response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req), content_type="application/json")
+        assert response.status_code == 500
+
+
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list3)
+    @patch('trainingmgr.trainingmgr_main.fetch_pipelinename_and_version', return_value = ("qoe_pipeline", "v1"))
+    @patch('trainingmgr.trainingmgr_main.training_start')
+    @patch('trainingmgr.trainingmgr_main.change_status_tj')
+    @patch('trainingmgr.trainingmgr_main.change_update_field_value')
+    def test_dataextraction_trainingModel_validresponse_kf(self, mock_update_field_val, mock_change_status,
+     mock_training_start, mock_fetch_pipeline, mock_getmodelInfo, mock_get_trainingjob, mock_check_dict, mock_training_job):        
+         
+        mock_response_valid = MagicMock()
+        mock_response_valid.headers = {'content-type': 'application/json'}
+        mock_response_valid.status_code = 200
+        mock_response_valid.json.return_value = {'run_status': 'scheduled', 'run_id': '123'}
+        mock_training_start.return_value = mock_response_valid
+
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id":"123"
+        }
+        response = self.client.post('/trainingjob/dataExtractionNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 200
+        assert "pipeline is scheduled" in str(response.data)
+
+
+class TestPipelineNotification:
+
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+
+    mocked_mm_sdk = mock.Mock(name="MM_SDK")
+    @pytest.fixture        
+    def mock_training_job(self):
+        TrainingJob = MagicMock()
+        TrainingJob.id = 1
+
+        TrainingJob.model_location = "s3://model-location"
+        TrainingJob.training_config = {
+            "description": "Test description",
+            "dataPipeline": {
+                "feature_group_name": "test_feature_group",
+                "query_filter": "",
+                "arguments": {"epochs" : 1, "trainingjob_name": "test_job"}
+            },
+            "trainingPipeline": {
+                    "pipeline_name": "test_pipeline",
+                    "pipeline_version": "2",
+                    "enable_versioning": True
+            }
+        }
+        mock_steps_state = MagicMock()
+        mock_steps_state.states = json.dumps({"DATA_EXTRACTION":"FINISHED", 
+                            "DATA_EXTRACTION_AND_TRAINING":"FINISHED", 
+                            "TRAINING":"IN_PROGRESS"})  
+        TrainingJob.steps_state = mock_steps_state
+        TrainingJob.deletion_in_progress = False
+        return TrainingJob
+    
+    registered_model_list = [{ 
+        "modelId" : { 
+            "artifactVersion" : "1.0"
+        },
+        "modelLocation": "s3://different-location"
+    }]
+        
+    @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk)
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.change_status_tj')
+    @patch('trainingmgr.trainingmgr_main.notification_rapp')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list)
+    @patch('trainingmgr.trainingmgr_main.change_update_field_value')
+    
+    def test_success(self, mock_update_field_value, mock_get_modelinfo, mock_notification_rapp, mock_change_status, 
+    mock_get_trainingjob, mock_check_in_dict, mock_mmsdk, mock_training_job):
+        mock_mmsdk.check_object.return_value = True
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id" : "123",
+                     "run_status" : "SUCCEEDED"
+        }
+        response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 200
+        assert response.json == {'Message': 'Training successful'}
+
+
+    mocked_mm_sdk_false = mock.Mock(name="MM_SDK")
+    @patch('trainingmgr.trainingmgr_main.MM_SDK', return_value = mocked_mm_sdk)
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.change_status_tj')
+    @patch('trainingmgr.trainingmgr_main.notification_rapp')
+    @patch('trainingmgr.trainingmgr_main.get_modelinfo_by_modelId_service', return_value = registered_model_list)
+    
+    def test_unsuccess_object_not_present_in_mmsdk(self, mock_get_modelinfo, mock_notification_rapp, mock_change_status, 
+    mock_get_trainingjob, mock_check_in_dict, mock_mmsdk, mock_training_job):  
+        mock_mmsdk.check_object.return_value = False
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id" : "123",
+                     "run_status" : "SUCCEEDED"
+        }
+        response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 500    
+
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_training_job')
+    @patch('trainingmgr.trainingmgr_main.change_status_tj')
+    @patch('trainingmgr.trainingmgr_main.notification_rapp')
+    def test_unsuccess(self, mock_notification_rapp, mock_change_status, mock_get_trainingjob, mock_check_in_dict, 
+     mock_training_job):
+        mock_get_trainingjob.return_value = mock_training_job
+        trainingjob_req = {
+                     "trainingjob_id" : "123",
+                     "run_status" : "UNSUCCEEDED"
+        }
+        response = self.client.post('/trainingjob/pipelineNotification', data = json.dumps(trainingjob_req),
+                                    content_type="application/json")
+        assert response.status_code == 500
+
+class TestFeatureGroupByName:
+
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+    
+    @pytest.fixture
+    def mock_feature_group(self):
+        mock_fg = MagicMock()
+        mock_fg.featuregroup_name = "test_featuregroup"
+        return mock_fg
+
+    @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name')
+    def test_get_feature_group_success(self, mock_get_feature_group):
+        mock_get_feature_group.return_value = {"featuregroup_name": "test_featuregroup", "status": "success"}, status.HTTP_200_OK
+        
+        response = self.client.get("/featureGroup/{featuregroup_name}".format(featuregroup_name = "test_featuregroup"), content_type="application/json")
+        assert response.status_code == 200
+        assert response.json == {"featuregroup_name": "test_featuregroup", "status": "success"}
+
+    @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name')
+    def test_get_feature_group_not_found(self, mock_get_feature_group):
+        mock_get_feature_group.return_value = {"Exception": "Feature group not found"}, status.HTTP_404_NOT_FOUND
+        
+        response = self.client.get("/featureGroup/{featuregroup_name}".format(featuregroup_name = "invalid"), content_type="application/json")
+        assert response.status_code == 404
+        assert response.json == {"Exception": "Feature group not found"}
+    
+    @patch('trainingmgr.trainingmgr_main.edit_feature_group_by_name')
+    @patch('trainingmgr.trainingmgr_main.FeatureGroupSchema.load')
+    def test_put_feature_group_success(self, mock_load, mock_edit_feature_group, mock_feature_group):   
+        mock_load.return_value = mock_feature_group
+        mock_edit_feature_group.return_value = {"featuregroup_name": "test_featuregroup", "status": "updated"}, status.HTTP_200_OK
+        put_data = {"featuregroup_name": "test_featuregroup", "new_field": "value"}
+        
+        response = self.client.put("/featureGroup/{featuregroup_name}".format(featuregroup_name = "test_featuregroup"), json = put_data)
+        assert response.status_code == 200
+        assert response.json == {"featuregroup_name": "test_featuregroup", "status": "updated"}
+
+
+class TestDeleteListOfFeatureGroup:
+    
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db')
+    @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name')
+    @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job')
+    def test_delete_feature_group_success(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key):
+        mock_get_fg.return_value = MagicMock(enable_dme=True, host="localhost", dme_port=8080)
+        mock_delete_fg.return_value = None
+        mock_delete_dme.return_value.status_code = status.HTTP_204_NO_CONTENT
+        
+        request_data = {
+            "featuregroups_list": [
+                {"featureGroup_name": "test_featuregroup"}
+            ]
+        }
+        response = self.client.delete('/featureGroup', json=request_data)
+        assert response.status_code == 200
+        assert response.json == {"success count": 1, "failure count": 0}
+
+    
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = True)
+    @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db')
+    @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name')
+    @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job')
+    def test_delete_feature_group_failure(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key):
+        mock_get_fg.return_value = MagicMock(enable_dme=True, host="localhost", dme_port=8080)
+        mock_delete_fg.side_effect = Exception("Failed to delete feature group")
+        
+        request_data = {
+            "featuregroups_list": [
+                {"featureGroup_name": "test_featuregroup"}
+            ]
+        }
+        response = self.client.delete('/featureGroup', json=request_data)
+        assert response.status_code == 200
+        assert response.json == {"success count": 0, "failure count": 1}
+    
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary', return_value = False)
+    @patch('trainingmgr.trainingmgr_main.get_feature_group_by_name_db')
+    @patch('trainingmgr.trainingmgr_main.delete_feature_group_by_name')
+    @patch('trainingmgr.trainingmgr_main.delete_dme_filtered_data_job')
+    def test_delete_feature_group_invalid_key(self, mock_delete_dme, mock_delete_fg, mock_get_fg, mock_check_key):
+        
+        request_data = {
+            "featuregroups_list": [
+                {"featuregroup_name": "test_featuregroup"}
+            ]
+        }
+        response = self.client.delete('/featureGroup', json=request_data)
+        assert response.status_code == 400
+        assert "Wrong Request syntax" in response.json["Exception"]
+