From: rajdeep11 Date: Wed, 30 Oct 2024 20:42:39 +0000 (+0530) Subject: test cases for X-Git-Tag: 3.0.0~45^2 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=f534aea8ca5b9c594f8383f858afc72471383927;p=aiml-fw%2Fawmf%2Ftm.git test cases for get_steps_state and get_model Change-Id: I6f51eb514b5ed8a1db7c874b925e98c4b5350123 Signed-off-by: rajdeep11 --- diff --git a/tests/test_tm_apis.py b/tests/test_tm_apis.py index faa8580..8ceca5a 100644 --- a/tests/test_tm_apis.py +++ b/tests/test_tm_apis.py @@ -403,50 +403,79 @@ class Test_unpload_pipeline: expected = ValueError("file not found in request.files") assert response.content_type == "application/json", "not equal content type" assert response.status_code == 500, "not equal code" -@pytest.mark.skip("") + class Test_get_steps_state: - def setup_method(self): + def setup_method(self): self.client = trainingmgr_main.APP.test_client(self) self.logger = trainingmgr_main.LOGGER + + @pytest.fixture + def mock_steps_state(self): + """Create mock steps state data.""" + return { + "DATA_EXTRACTION": "FINISHED", + "DATA_EXTRACTION_AND_TRAINING": "FINISHED", + "TRAINING": "FINISHED", + "TRAINING_AND_TRAINED_MODEL": "FINISHED", + "TRAINED_MODEL": "FINISHED" + } - @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=[['data_extracted','data_pending'], ['data1','data2']]) - def test_get_steps_state(self,mock1): - usecase_name = "usecase7" - version = "1" - response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version)) - expected_data = b'data_extracted' - assert response.content_type == "application/json", "not equal content type" - assert response.status_code == status.HTTP_200_OK, "not equal code" - assert response.data == expected_data, "not equal data" - - @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=False) - def test_negative_get_steps_state(self,mock1): - usecase_name = "usecase7" - version = "1" - response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version)) - expected_data = b'data_extracted' - assert response.content_type == "application/json", "not equal content type" - assert response.status_code == 404, "not equal code" + @patch('trainingmgr.trainingmgr_main.get_steps_state_db') + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version') + def test_successful_get_steps_state(self, mock_name_and_version, mock_get_steps_state, mock_steps_state): + """Test successful retrieval of steps state.""" + + mock_get_steps_state.return_value = mock_steps_state + response = self.client.get('/trainingjobs/test_job/1/steps_state') + + assert response.status_code == status.HTTP_200_OK + data = response.get_json() + + # Verify all expected states are present + assert "DATA_EXTRACTION" in data + assert "DATA_EXTRACTION_AND_TRAINING" in data + assert "TRAINING" in data + assert "TRAINING_AND_TRAINED_MODEL" in data + assert "TRAINED_MODEL" in data + + # Verify state values + assert data["DATA_EXTRACTION"] == "FINISHED" + assert data["TRAINING"] == "FINISHED" + assert data["TRAINED_MODEL"] == "FINISHED" - @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=Exception("Not found given trainingjob with version")) - def test_negative_get_steps_state_2(self,mock1): - usecase_name = "usecase7" - version = "1" - response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version)) - expected_data = b'data_extracted' - assert response.status_code == 500, "not equal code" + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=False) + def test_invalid_name_version(self, mock1): + """Test with invalid training job name or version.""" + response = self.client.get('/trainingjobs/invalid_job/999/steps_state') + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.get_json() + assert "Exception" in data + assert "trainingjob_name or version is not correct" in data["Exception"] + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + @patch('trainingmgr.trainingmgr_main.get_steps_state_db', return_value=None) + def test_nonexistent_trainingjob(self, mock1, mock2): + """Test when training job doesn't exist in database.""" + + response = self.client.get('/trainingjobs/nonexistent_job/1/steps_state') - def test_negative_get_steps_state_by_name_and_version(self): - usecase_name = "usecase7*" - version = "1" - response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version)) - assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code" - assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n' - usecase_name="usecase7" - version="a" - response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version)) - assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code" - assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n' + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.get_json() + assert "Exception" in data + assert "Not found given trainingjob in database" in data["Exception"] + + @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True) + @patch('trainingmgr.trainingmgr_main.get_steps_state_db', side_effect=Exception("Database error")) + def test_database_error(self, mock1, mock2): + """Test handling of database errors.""" + + response = self.client.get('/trainingjobs/test_job/1/steps_state') + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + data = json.loads(response.data) + assert "Exception" in data + assert "Database error" in data["Exception"] @pytest.mark.skip("") class Test_training_main: @@ -928,7 +957,6 @@ class Test_get_metadata: assert response.status_code==status.HTTP_400_BAD_REQUEST assert response.data == b'{"Exception":"The trainingjob_name is not correct"}\n' -@pytest.mark.skip("") class Test_get_model: def setup_method(self): self.client = trainingmgr_main.APP.test_client(self) diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 2dcbd03..cc4d458 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -274,7 +274,7 @@ def get_steps_state(trainingjob_name, version): LOGGER.error(str(err)) response_data = {"Exception": str(err)} - return APP.response_class(response=response_data, + return APP.response_class(response=json.dumps(response_data), status=response_code, mimetype=MIMETYPE_JSON)