From: rajdeep11 Date: Wed, 30 Oct 2024 19:47:06 +0000 (+0530) Subject: adding test_cases for get_trainingjob_by_name_and_version X-Git-Tag: 3.0.0~45^2~1 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=38c21dcba80b1718761f880bb82a51b3e910b835;p=aiml-fw%2Fawmf%2Ftm.git adding test_cases for get_trainingjob_by_name_and_version Change-Id: I881ba8356cadedd7bd6fcc58a6cfe684876e3f03 Signed-off-by: rajdeep11 --- diff --git a/tests/test_tm_apis.py b/tests/test_tm_apis.py index 9fb69f7..faa8580 100644 --- a/tests/test_tm_apis.py +++ b/tests/test_tm_apis.py @@ -35,6 +35,8 @@ from trainingmgr import trainingmgr_main from trainingmgr.common.tmgr_logger import TMLogger from trainingmgr.common.trainingmgr_config import TrainingMgrConfig from trainingmgr.common.exceptions_utls import DBException, TMException +from trainingmgr.models import TrainingJob + trainingmgr_main.LOGGER = pytest.logger trainingmgr_main.LOCK = Lock() trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {} @@ -265,49 +267,122 @@ class Test_pipeline_notification: assert expected_data in str(response.data) -@pytest.mark.skip("") class Test_get_trainingjob_by_name_version: def setup_method(self): self.client = trainingmgr_main.APP.test_client(self) self.logger = trainingmgr_main.LOGGER - @patch('trainingmgr.trainingmgr_main.get_info_by_version',return_value=[('usecase7', 'auto test', '*', 'prediction with model name', 'Default', '{"arguments": {"epochs": "1", "usecase": "usecase7"}}', 'Enb=20 and Cellnum=6', datetime.datetime(2022, 9, 20,11, 40, 30), '7d09c0bf-7575-4475-86ff-5573fb3c4716', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}', datetime.datetime(2022, 9, 20, 11, 42, 20), 1, True, 'Near RT RIC', '{"datalake_source": {"CassandraSource": {}}}', '{"datalake_source": {"CassandraSource": {}}}','http://10.0.0.47:32002/model/usecase7/1/Model.zip','','','','','',False,'','')]) - @patch('trainingmgr.trainingmgr_main.get_metrics',return_value={"metrics": [{"Accuracy": "0.0"}]}) - @patch('trainingmgr.trainingmgr_main.get_one_key',return_value='cassandra') - def test_get_trainingjob_by_name_version(self,mock1,mock2,mock3): - usecase_name = "usecase7" - version = "1" - response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version)) - expected_data = b'{"trainingjob": {"trainingjob_name": "usecase7", "description": "auto test", "feature_list": "*", "pipeline_name": "prediction with model name", "experiment_name": "Default", "arguments": {"epochs": "1", "usecase": "usecase7"}, "query_filter": "Enb=20 and Cellnum=6", "creation_time": "2022-09-20 11:40:30", "run_id": "7d09c0bf-7575-4475-86ff-5573fb3c4716", "steps_state": {"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}, "updation_time": "2022-09-20 11:42:20", "version": 1, "enable_versioning": true, "pipeline_version": "Near RT RIC", "datalake_source": "cassandra", "model_url": "{\\"datalake_source\\": {\\"CassandraSource\\": {}}}", "notification_url": "http://10.0.0.47:32002/model/usecase7/1/Model.zip", "is_mme": "", "model_name": "", "model_info": "", "accuracy": {"metrics": [{"Accuracy": "0.0"}]}}}' - assert response.status_code == status.HTTP_200_OK, "not equal code" - assert response.data == expected_data, "not equal data" - - @patch('trainingmgr.trainingmgr_main.get_info_by_version',return_value=False) - @patch('trainingmgr.trainingmgr_main.get_metrics',return_value={"metrics": [{"Accuracy": "0.0"}]}) - @patch('trainingmgr.trainingmgr_main.get_one_key',return_value='cassandra') - def test_negative_get_trainingjob_by_name_version(self,mock1,mock2,mock3): - usecase_name = "usecase7" - version = "1" - response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version)) - expected_data = b'{"trainingjob": {"trainingjob_name": "usecase7", "description": "auto test", "feature_list": "*", "pipeline_name": "prediction with model name", "experiment_name": "Default", "arguments": {"epochs": "1", "usecase": "usecase7"}, "query_filter": "Enb=20 and Cellnum=6", "creation_time": "2022-09-20 11:40:30", "run_id": "7d09c0bf-7575-4475-86ff-5573fb3c4716", "steps_state": {"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}, "updation_time": "2022-09-20 11:42:20", "version": 1, "enable_versioning": true, "pipeline_version": "Near RT RIC", "datalake_source": "cassandra", "model_url": "{\\"datalake_source\\": {\\"CassandraSource\\": {}}}", "notification_url": "http://10.0.0.47:32002/model/usecase7/1/Model.zip", "_measurement": "", "bucket": "", "accuracy": {"metrics": [{"Accuracy": "0.0"}]}}}' - trainingmgr_main.LOGGER.debug(expected_data) - trainingmgr_main.LOGGER.debug(response.data) - assert response.content_type == "application/json", "not equal content type" - assert response.status_code == 404, "not equal code" - - def test_negative_get_trainingjob_by_name_version2(self): - usecase_name = "usecase7*" - version = "1" - response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version)) - print(response.data) - assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code" - assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n' - usecase_name="usecase7" - version="a" - response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version)) - assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code" - assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n' + @pytest.fixture + def mock_training_job(self): + """Create a mock TrainingJob object.""" + creation_time = datetime.datetime.now() + updation_time = datetime.datetime.now() + return TrainingJob( + trainingjob_name="test_job", + description="Test description", + feature_group_name="test_feature_group", + pipeline_name="test_pipeline", + experiment_name="test_experiment", + arguments=json.dumps({"param1": "value1"}), + query_filter="test_filter", + creation_time=creation_time, + run_id="test_run_id", + steps_state=json.dumps({"step1": "completed"}), + updation_time=updation_time, + version=1, + enable_versioning=True, + pipeline_version="v1", + datalake_source=json.dumps({"datalake_source": {"source1": "path1"}}), + model_url="http://test.model.url", + notification_url="http://test.notification.url", + deletion_in_progress=False, + is_mme=True, + model_name="test_model", + model_info="test_model_info" + ) + + @pytest.fixture + def mock_metrics(self): + """Create mock metrics data.""" + return {"accuracy": "0.95", "precision": "0.92"} + + @patch('trainingmgr.trainingmgr_main.get_info_by_version') + @patch('trainingmgr.trainingmgr_main.get_metrics') + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + def test_successful_get_trainingjob(self, mock_check_name_and_version, mock_get_metrics, mock_get_info, mock_training_job, mock_metrics): + """Test successful retrieval of training job.""" + # Mock return values + mock_get_info.return_value = mock_training_job + mock_get_metrics.return_value = mock_metrics + + # Make the GET request + response = self.client.get('/trainingjobs/test_job/1') + + # Verify response + assert response.status_code == status.HTTP_200_OK + data = json.loads(response.data) + + assert 'trainingjob' in data + job_data = data['trainingjob'] + assert job_data['trainingjob_name'] == "test_job" + assert job_data['description'] == "Test description" + assert job_data['feature_list'] == "test_feature_group" + assert job_data['pipeline_name'] == "test_pipeline" + assert job_data['experiment_name'] == "test_experiment" + assert job_data['is_mme'] is True + assert job_data['model_name'] == "test_model" + assert job_data['model_info'] == "test_model_info" + assert job_data['accuracy'] == mock_metrics + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=False) + def test_invalid_name_version(self, mock1): + """Test with invalid training job name or version.""" + response = self.client.get('/trainingjobs/invalid_*job/999') + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = json.loads(response.data) + assert "Exception" in data + assert "trainingjob_name or version is not correct" in data["Exception"] + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + @patch('trainingmgr.trainingmgr_main.get_info_by_version', return_value=None) + @patch('trainingmgr.trainingmgr_main.get_metrics', return_value = "No data available") + def test_nonexistent_trainingjob(self, mock1, mock2, mock3): + """Test when training job doesn't exist in database.""" + + response = self.client.get('/trainingjobs/nonexistent_job/1') + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = json.loads(response.data) + assert "Exception" in data + assert "Not found given trainingjob with version" in data["Exception"] + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + @patch('trainingmgr.trainingmgr_main.get_info_by_version', side_effect=Exception("Database error")) + def test_database_error(self, mock1, mock2): + """Test handling of database errors.""" + + response = self.client.get('/trainingjobs/test_job/1') + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + data = json.loads(response.data) + assert "Exception" in data + assert "Database error" in data["Exception"] + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + @patch('trainingmgr.trainingmgr_main.get_info_by_version', return_value=mock_training_job) + @patch('trainingmgr.trainingmgr_main.get_metrics', side_effect=Exception("Metrics error")) + def test_metrics_error(self, mock1, mock2, mock3): + """Test handling of metrics retrieval error.""" + + response = self.client.get('/trainingjobs/test_job/1') + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + data = json.loads(response.data) + assert "Exception" in data + assert "Metrics error" in data["Exception"] + @pytest.mark.skip("") class Test_unpload_pipeline: def setup_method(self):