From: ashishj1729 Date: Fri, 10 Jan 2025 09:29:05 +0000 (+0530) Subject: Adding Unit-Test for Trainingmgr_operations X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=refs%2Fchanges%2F11%2F14011%2F1;p=aiml-fw%2Fawmf%2Ftm.git Adding Unit-Test for Trainingmgr_operations Change-Id: I421408bb94d1d72a893b244464a9d57dcf2b399f Signed-off-by: ashishj1729 (cherry picked from commit 60c3df0155391cf33a08de16c95f66b871feab18) --- diff --git a/tests/test_trainingmgr_operations.py b/tests/test_trainingmgr_operations.py index fd9d7ea..7d22161 100644 --- a/tests/test_trainingmgr_operations.py +++ b/tests/test_trainingmgr_operations.py @@ -15,234 +15,196 @@ # # limitations under the License. # # # # ================================================================================== -# import json -# import requests -# from unittest import mock -# from mock import patch, MagicMock -# import pytest -# import flask -# from requests.models import Response -# from threading import Lock -# import os -# import sys -# import datetime -# from flask_api import status -# from dotenv import load_dotenv -# from threading import Lock -# from trainingmgr import trainingmgr_main -# from trainingmgr.common import trainingmgr_operations -# from trainingmgr.common.exceptions_utls import TMException -# from trainingmgr.common.trainingmgr_util import MIMETYPE_JSON -# from trainingmgr.common.trainingmgr_config import TrainingMgrConfig - -# trainingmgr_main.LOGGER = pytest.logger -# trainingmgr_main.LOCK = Lock() -# trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {} - -# class DummyVariable: -# kf_adapter_ip = "localhost" -# kf_adapter_port = 5001 -# data_extraction_ip = "localhost" -# data_extraction_port = 32000 -# model_management_service_ip="localhost" -# model_management_service_port=123123 -# logger = trainingmgr_main.LOGGER - -# @pytest.mark.skip("") -# class Test_data_extraction_start: -# def setup_method(self): -# self.client = trainingmgr_main.APP.test_client(self) -# self.logger = trainingmgr_main.LOGGER - -# de_result = Response() -# de_result.status_code = status.HTTP_200_OK -# de_result.headers={'content-type': MIMETYPE_JSON} -# @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value = de_result) -# def test_success(self, mock1): -# trainingjob_name = "usecase12" -# training_config_obj = DummyVariable() -# feature_list = "*" -# query_filter = "" -# datalake_source = {"InfluxSource": {}} -# _measurement = "liveCell" -# influxdb_info_dict={'host': '', 'port': '', 'token': '', 'source_name': '', 'db_org': '', 'bucket': ''} -# try: -# response = trainingmgr_operations.data_extraction_start(training_config_obj, trainingjob_name, feature_list, -# query_filter, datalake_source, _measurement, influxdb_info_dict) -# assert response.status_code == status.HTTP_200_OK -# assert response.headers['content-type'] == MIMETYPE_JSON -# except: -# assert False - -# @pytest.mark.skip("") -# class Test_data_extraction_status: -# def setup_method(self): -# self.client = trainingmgr_main.APP.test_client(self) -# self.logger = trainingmgr_main.LOGGER - -# de_result = Response() -# de_result.status_code = status.HTTP_200_OK -# de_result.headers={'content-type': MIMETYPE_JSON} -# @patch('trainingmgr.common.trainingmgr_operations.requests.get', return_value = de_result) -# def test_success(self, mock1): -# trainingjob_name = "usecase12" -# training_config_obj = DummyVariable() -# try: -# response = trainingmgr_operations.data_extraction_status(trainingjob_name, training_config_obj) -# assert response.status_code == status.HTTP_200_OK -# assert response.headers['content-type'] == MIMETYPE_JSON -# except: -# assert False - -# @pytest.mark.skip("") -# class Test_training_start: -# def setup_method(self): -# self.client = trainingmgr_main.APP.test_client(self) -# self.logger = trainingmgr_main.LOGGER - -# ts_result = Response() -# ts_result.status_code = status.HTTP_200_OK -# ts_result.headers={'content-type': MIMETYPE_JSON} -# @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value = ts_result) -# def test_success(self, mock1): -# trainingjob_name = "usecase12" -# dict_data = { -# "pipeline_name": "qoe", -# "experiment_name": "default", -# "arguments": "{epoches : 1}", -# "pipeline_version": 1 -# } -# training_config_obj = DummyVariable() -# try: -# response = trainingmgr_operations.training_start(training_config_obj,dict_data,trainingjob_name) -# assert response.headers['content-type'] == MIMETYPE_JSON -# assert response.status_code == status.HTTP_200_OK -# except Exception: -# assert False - -# def test_fail(self): -# trainingjob_name = "usecase12" -# dict_data = { -# "pipeline_name": "qoe", -# "experiment_name": "default", -# "arguments": "{epoches : 1}", -# "pipeline_version": 1 -# } -# training_config_obj = DummyVariable() -# try: -# trainingmgr_operations.training_start(training_config_obj,dict_data,trainingjob_name) -# assert False -# except requests.exceptions.ConnectionError: -# assert True -# except Exception: -# assert False - -# @pytest.mark.skip("") -# class Test_create_dme_filtered_data_job: -# the_response=Response() -# the_response.status_code=status.HTTP_201_CREATED -# @patch('trainingmgr.common.trainingmgr_operations.requests.put', return_value=the_response) -# def test_success(self, mock1): -# training_config_obj = DummyVariable() -# source_name="" -# features=[] -# feature_group_name="test" -# host="10.0.0.50" -# port="31840" -# measured_obj_class="NRCellDU" -# response=trainingmgr_operations.create_dme_filtered_data_job(training_config_obj, source_name, features, feature_group_name, host, port, measured_obj_class) -# assert response.status_code==status.HTTP_201_CREATED, "create_dme_filtered_data_job failed" - -# def test_create_url_host_port_fail(self): -# training_config_obj = DummyVariable() -# source_name="" -# features=[] -# feature_group_name="test" -# measured_obj_class="NRCellDU" -# host="url error" -# port="31840" -# try: -# response=trainingmgr_operations.create_dme_filtered_data_job(training_config_obj, source_name, features, feature_group_name, host, port, measured_obj_class) -# assert False -# except TMException as err: -# assert "URL validation error: " in err.message -# except Exception: -# assert False - -# @pytest.mark.skip("") -# class Test_delete_dme_filtered_data_job: -# the_response=Response() -# the_response.status_code=status.HTTP_204_NO_CONTENT -# @patch('trainingmgr.common.trainingmgr_operations.requests.delete', return_value=the_response) -# def test_success(self, mock1): -# training_config_obj = DummyVariable() -# feature_group_name="test" -# host="10.0.0.50" -# port="31840" -# response=trainingmgr_operations.delete_dme_filtered_data_job(training_config_obj, feature_group_name, host, port) -# assert response.status_code==status.HTTP_204_NO_CONTENT, "delete_dme_filtered_data_job failed" - -# def test_create_url_host_port_fail(self): -# training_config_obj = DummyVariable() -# feature_group_name="test" -# host="url error" -# port="31840" -# try: -# response=trainingmgr_operations.delete_dme_filtered_data_job(training_config_obj, feature_group_name, host, port) -# assert False -# except TMException as err: -# assert "URL validation error: " in err.message -# except Exception: -# assert False - -# @pytest.mark.skip("") -# class Test_get_model_info: - -# @patch('trainingmgr.common.trainingmgr_operations.requests.get') -# def test_get_model_info(self,mock_requests_get): -# training_config_obj = DummyVariable() -# model_name="qoe" -# rapp_id = "rapp_1" -# meta_info = { -# "test": "test" -# } +from mock import patch +import pytest +from requests.models import Response +from threading import Lock +from flask_api import status +from trainingmgr import trainingmgr_main +from trainingmgr.common import trainingmgr_operations +from trainingmgr.common.exceptions_utls import TMException +from trainingmgr.common.trainingmgr_util import MIMETYPE_JSON +from trainingmgr.constants.steps import Steps +from trainingmgr.constants.states import States +trainingmgr_main.LOGGER = pytest.logger +trainingmgr_main.LOCK = Lock() +trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {} + +class DummyVariable: + kf_adapter_ip = "localhost" + kf_adapter_port = 5001 + data_extraction_ip = "localhost" + data_extraction_port = 32000 + model_management_service_ip="localhost" + model_management_service_port=123123 + logger = trainingmgr_main.LOGGER + +class DummyStepsState: + def __init__(self, states): + self.states = states + +class DummyTrainingJob: + def __init__(self, cur_state, notification_url): + self.steps_state = DummyStepsState(cur_state) + self.notification_url = notification_url + +class Test_data_extraction_start: + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + self.logger = trainingmgr_main.LOGGER + + de_result = Response() + de_result.status_code = status.HTTP_200_OK + de_result.headers={'content-type': MIMETYPE_JSON} + @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value = de_result) + def test_success(self, mock1): + trainingjob_id = 1 + featuregroup_name = "base1" + training_config_obj = DummyVariable() + feature_list = "*" + query_filter = "" + datalake_source = {"InfluxSource": {}} + _measurement = "liveCell" + influxdb_info_dict={'host': '', 'port': '', 'token': '', 'source_name': '', 'db_org': '', 'bucket': ''} + try: + response = trainingmgr_operations.data_extraction_start(training_config_obj, trainingjob_id, feature_list, + query_filter, datalake_source, _measurement, influxdb_info_dict, featuregroup_name) + assert response.status_code == status.HTTP_200_OK + assert response.headers['content-type'] == MIMETYPE_JSON + except: + assert False + + +class Test_data_extraction_status: + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + self.logger = trainingmgr_main.LOGGER + + de_result = Response() + de_result.status_code = status.HTTP_200_OK + de_result.headers={'content-type': MIMETYPE_JSON} + @patch('trainingmgr.common.trainingmgr_operations.requests.get', return_value = de_result) + def test_success(self, mock1): + featuregroup_name = "base1" + trainingjob_id = 1 + training_config_obj = DummyVariable() + try: + response = trainingmgr_operations.data_extraction_status(featuregroup_name, trainingjob_id, training_config_obj) + assert response.status_code == status.HTTP_200_OK + assert response.headers['content-type'] == MIMETYPE_JSON + except: + assert False + + +class Test_training_start: + def setup_method(self): + self.client = trainingmgr_main.APP.test_client(self) + self.logger = trainingmgr_main.LOGGER + + ts_result = Response() + ts_result.status_code = status.HTTP_200_OK + ts_result.headers={'content-type': MIMETYPE_JSON} + @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value = ts_result) + def test_success(self, mock1): + trainingjob_id = 1 + dict_data = { + "pipeline_name": "qoe", + "experiment_name": "default", + "arguments": "{epoches : 1}", + "pipeline_version": 1 + } + training_config_obj = DummyVariable() + try: + response = trainingmgr_operations.training_start(training_config_obj, dict_data, trainingjob_id) + assert response.headers['content-type'] == MIMETYPE_JSON + assert response.status_code == status.HTTP_200_OK + except Exception: + assert False + + def test_fail(self): + trainingjob_id = 1 + dict_data = { + "pipeline_name": "qoe", + "experiment_name": "default", + "arguments": "{epoches : 1}", + "pipeline_version": 1 + } + training_config_obj = DummyVariable() + try: + trainingmgr_operations.training_start(training_config_obj, dict_data, trainingjob_id) + assert False + except TMException: + assert True + except Exception: + # Any other Exception signifies test-failure + assert False + + +class Test_create_url_host_port: + def test_success(self): + expected_url = "http://10.0.0.7:38012/training" + url = trainingmgr_operations.create_url_host_port("http", "10.0.0.7", "38012", "training") + assert url == expected_url, "create_url_host_port Failed" -# model_data = { -# "model-name": model_name, -# "rapp-id": rapp_id, -# "meta-info": meta_info -# } -# mock_response=MagicMock(spec=Response) -# mock_response.status_code=200 -# mock_response.json.return_value={'message': {"name": model_name, "data": json.dumps(model_data)}} -# mock_requests_get.return_value= mock_response -# model_info=trainingmgr_operations.get_model_info(training_config_obj, model_name) -# expected_model_info={ -# "model-name": model_name, -# "rapp-id": rapp_id, -# "meta-info": meta_info -# } -# assert model_info==expected_model_info, "get model info failed" - -# @patch('trainingmgr.common.trainingmgr_operations.requests.get') -# def test_negative_get_model_info(self,mock_requests_get): -# training_config_obj = DummyVariable() -# model_name="qoe" -# rapp_id = "rapp_1" -# meta_info = { -# "test": "test" -# } + def test_failure(self): + try: + trainingmgr_operations.create_url_host_port("http", "HOST ERROR", "38012", "training") + assert False + except TMException as err: + assert "URL validation error: " in err.message + except Exception: + assert False + + +class Test_create_dme_filtered_data_job: + the_response=Response() + the_response.status_code=status.HTTP_201_CREATED + @patch('trainingmgr.common.trainingmgr_operations.requests.put', return_value=the_response) + def test_success(self, mock1): + training_config_obj = DummyVariable() + source_name="GNBDU324" + features= "pdcpBytesUl, pdcpBytesDl" + feature_group_name="test" + host="10.0.0.50" + port="31840" + measured_obj_class="NRCellDU" + response=trainingmgr_operations.create_dme_filtered_data_job(training_config_obj, source_name, features, feature_group_name, host, port, measured_obj_class) + assert response.status_code==status.HTTP_201_CREATED, "create_dme_filtered_data_job failed" + + +class Test_delete_dme_filtered_data_job: + the_response=Response() + the_response.status_code=status.HTTP_204_NO_CONTENT + @patch('trainingmgr.common.trainingmgr_operations.requests.delete', return_value=the_response) + def test_success(self, mock1): + training_config_obj = DummyVariable() + feature_group_name="test" + host="10.0.0.50" + port="31840" + response=trainingmgr_operations.delete_dme_filtered_data_job(training_config_obj, feature_group_name, host, port) + assert response.status_code == status.HTTP_204_NO_CONTENT, "delete_dme_filtered_data_job failed" + +class Test_notification_rapp: + steps_state = { + Steps.DATA_EXTRACTION.name: States.NOT_STARTED.name, + Steps.DATA_EXTRACTION_AND_TRAINING.name: States.NOT_STARTED.name, + Steps.TRAINING.name: States.NOT_STARTED.name, + Steps.TRAINING_AND_TRAINED_MODEL.name: States.NOT_STARTED.name, + Steps.TRAINED_MODEL.name: States.NOT_STARTED.name + } + the_response=Response() + the_response.status_code=status.HTTP_200_OK + @patch('trainingmgr.common.trainingmgr_operations.get_trainingjob', return_value = DummyTrainingJob(steps_state, "dummy_url")) + @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value=the_response) + def test_success(self, mock1, mock2): + trainingmgr_operations.notification_rapp(1) -# model_data = { -# "model-name": model_name, -# "rapp-id": rapp_id, -# "meta-info": meta_info -# } -# mock_response=MagicMock(spec=Response) -# mock_response.status_code=500 -# mock_response.json.return_value={'message': {"name": model_name, "data": json.dumps(model_data)}} -# mock_requests_get.return_value= mock_response -# try: -# model_info=trainingmgr_operations.get_model_info(training_config_obj, model_name) -# except TMException as err: -# assert "model info can't be fetched, model_name:" in err.message + the_response=Response() + the_response.status_code=status.HTTP_404_NOT_FOUND + @patch('trainingmgr.common.trainingmgr_operations.get_trainingjob', return_value = DummyTrainingJob(steps_state, "dummy_url")) + @patch('trainingmgr.common.trainingmgr_operations.requests.post', return_value=the_response) + def test_failure(self, mock1, mock2): + resp = trainingmgr_operations.notification_rapp(1) + assert resp is None, f"notification_rapp is supposed to fail and return None, but except it returned {resp}" + + \ No newline at end of file diff --git a/trainingmgr/common/trainingmgr_operations.py b/trainingmgr/common/trainingmgr_operations.py index 4ab1123..4c76994 100644 --- a/trainingmgr/common/trainingmgr_operations.py +++ b/trainingmgr/common/trainingmgr_operations.py @@ -182,20 +182,6 @@ def delete_dme_filtered_data_job(training_config_obj, feature_group_name, host, response = requests.delete(url) return response -def get_model_info(training_config_obj, model_name): - logger = training_config_obj.logger - model_management_service_ip = training_config_obj.model_management_service_ip - model_management_service_port = training_config_obj.model_management_service_port - url ="http://"+str(model_management_service_ip)+":"+str(model_management_service_port)+"/getModelInfo/{}".format(model_name) - response = requests.get(url) - if(response.status_code==status.HTTP_200_OK): - model_info=json.loads(response.json()['message']["data"]) - return model_info - else: - errMsg="model info can't be fetched, model_name: {} , err: {}".format(model_name, response.text) - logger.error(errMsg) - raise TMException(errMsg) - def notification_rapp(trainingjob_id): try: trainingjob = get_trainingjob(trainingjob_id)