from trainingmgr.common.tmgr_logger import TMLogger
from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
from trainingmgr.common.exceptions_utls import DBException, TMException
+import json
+from trainingmgr.models import TrainingJob
+
trainingmgr_main.LOGGER = pytest.logger
trainingmgr_main.LOCK = Lock()
trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {}
assert expected_data in str(response.data)
-@pytest.mark.skip("")
class Test_get_trainingjob_by_name_version:
def setup_method(self):
self.client = trainingmgr_main.APP.test_client(self)
self.logger = trainingmgr_main.LOGGER
- @patch('trainingmgr.trainingmgr_main.get_info_by_version',return_value=[('usecase7', 'auto test', '*', 'prediction with model name', 'Default', '{"arguments": {"epochs": "1", "usecase": "usecase7"}}', 'Enb=20 and Cellnum=6', datetime.datetime(2022, 9, 20,11, 40, 30), '7d09c0bf-7575-4475-86ff-5573fb3c4716', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}', datetime.datetime(2022, 9, 20, 11, 42, 20), 1, True, 'Near RT RIC', '{"datalake_source": {"CassandraSource": {}}}', '{"datalake_source": {"CassandraSource": {}}}','http://10.0.0.47:32002/model/usecase7/1/Model.zip','','','','','',False,'','')])
- @patch('trainingmgr.trainingmgr_main.get_metrics',return_value={"metrics": [{"Accuracy": "0.0"}]})
- @patch('trainingmgr.trainingmgr_main.get_one_key',return_value='cassandra')
- def test_get_trainingjob_by_name_version(self,mock1,mock2,mock3):
- usecase_name = "usecase7"
- version = "1"
- response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version))
- expected_data = b'{"trainingjob": {"trainingjob_name": "usecase7", "description": "auto test", "feature_list": "*", "pipeline_name": "prediction with model name", "experiment_name": "Default", "arguments": {"epochs": "1", "usecase": "usecase7"}, "query_filter": "Enb=20 and Cellnum=6", "creation_time": "2022-09-20 11:40:30", "run_id": "7d09c0bf-7575-4475-86ff-5573fb3c4716", "steps_state": {"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}, "updation_time": "2022-09-20 11:42:20", "version": 1, "enable_versioning": true, "pipeline_version": "Near RT RIC", "datalake_source": "cassandra", "model_url": "{\\"datalake_source\\": {\\"CassandraSource\\": {}}}", "notification_url": "http://10.0.0.47:32002/model/usecase7/1/Model.zip", "is_mme": "", "model_name": "", "model_info": "", "accuracy": {"metrics": [{"Accuracy": "0.0"}]}}}'
- assert response.status_code == status.HTTP_200_OK, "not equal code"
- assert response.data == expected_data, "not equal data"
-
- @patch('trainingmgr.trainingmgr_main.get_info_by_version',return_value=False)
- @patch('trainingmgr.trainingmgr_main.get_metrics',return_value={"metrics": [{"Accuracy": "0.0"}]})
- @patch('trainingmgr.trainingmgr_main.get_one_key',return_value='cassandra')
- def test_negative_get_trainingjob_by_name_version(self,mock1,mock2,mock3):
- usecase_name = "usecase7"
- version = "1"
- response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version))
- expected_data = b'{"trainingjob": {"trainingjob_name": "usecase7", "description": "auto test", "feature_list": "*", "pipeline_name": "prediction with model name", "experiment_name": "Default", "arguments": {"epochs": "1", "usecase": "usecase7"}, "query_filter": "Enb=20 and Cellnum=6", "creation_time": "2022-09-20 11:40:30", "run_id": "7d09c0bf-7575-4475-86ff-5573fb3c4716", "steps_state": {"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FINISHED"}, "updation_time": "2022-09-20 11:42:20", "version": 1, "enable_versioning": true, "pipeline_version": "Near RT RIC", "datalake_source": "cassandra", "model_url": "{\\"datalake_source\\": {\\"CassandraSource\\": {}}}", "notification_url": "http://10.0.0.47:32002/model/usecase7/1/Model.zip", "_measurement": "", "bucket": "", "accuracy": {"metrics": [{"Accuracy": "0.0"}]}}}'
- trainingmgr_main.LOGGER.debug(expected_data)
- trainingmgr_main.LOGGER.debug(response.data)
- assert response.content_type == "application/json", "not equal content type"
- assert response.status_code == 404, "not equal code"
-
- def test_negative_get_trainingjob_by_name_version2(self):
- usecase_name = "usecase7*"
- version = "1"
- response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version))
- print(response.data)
- assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code"
- assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n'
- usecase_name="usecase7"
- version="a"
- response = self.client.get("/trainingjobs/{}/{}".format(usecase_name, version))
- assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code"
- assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n'
+ @pytest.fixture
+ def mock_training_job(self):
+ """Create a mock TrainingJob object."""
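+        # Field values are arbitrary test data; the keyword arguments mirror the
+        # TrainingJob fields that the tests below assert on.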
+ creation_time = datetime.datetime.now()
+ updation_time = datetime.datetime.now()
+ return TrainingJob(
+ trainingjob_name="test_job",
+ description="Test description",
+ feature_group_name="test_feature_group",
+ pipeline_name="test_pipeline",
+ experiment_name="test_experiment",
+ arguments=json.dumps({"param1": "value1"}),
+ query_filter="test_filter",
+ creation_time=creation_time,
+ run_id="test_run_id",
+ steps_state=json.dumps({"step1": "completed"}),
+ updation_time=updation_time,
+ version=1,
+ enable_versioning=True,
+ pipeline_version="v1",
+ datalake_source=json.dumps({"datalake_source": {"source1": "path1"}}),
+ model_url="http://test.model.url",
+ notification_url="http://test.notification.url",
+ deletion_in_progress=False,
+ is_mme=True,
+ model_name="test_model",
+ model_info="test_model_info"
+ )
+
+ @pytest.fixture
+ def mock_metrics(self):
+ """Create mock metrics data."""
+ return {"accuracy": "0.95", "precision": "0.92"}
+
+ @patch('trainingmgr.trainingmgr_main.get_info_by_version')
+ @patch('trainingmgr.trainingmgr_main.get_metrics')
+ @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+ def test_successful_get_trainingjob(self, mock_check_name_and_version, mock_get_metrics, mock_get_info, mock_training_job, mock_metrics):
+ """Test successful retrieval of training job."""
+ # Mock return values
+ mock_get_info.return_value = mock_training_job
+ mock_get_metrics.return_value = mock_metrics
+
+ # Make the GET request
+ response = self.client.get('/trainingjobs/test_job/1')
+
+ # Verify response
+ assert response.status_code == status.HTTP_200_OK
+ data = json.loads(response.data)
+
+ assert 'trainingjob' in data
+ job_data = data['trainingjob']
+ assert job_data['trainingjob_name'] == "test_job"
+ assert job_data['description'] == "Test description"
+ assert job_data['feature_list'] == "test_feature_group"
+ assert job_data['pipeline_name'] == "test_pipeline"
+ assert job_data['experiment_name'] == "test_experiment"
+ assert job_data['is_mme'] is True
+ assert job_data['model_name'] == "test_model"
+ assert job_data['model_info'] == "test_model_info"
+ assert job_data['accuracy'] == mock_metrics
+
+ @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=False)
+    def test_invalid_name_version(self, mock_check):
+ """Test with invalid training job name or version."""
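+        # The name/version validator is patched to return False, so the request is
+        # rejected regardless of the path values.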
+ response = self.client.get('/trainingjobs/invalid_*job/999')
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ data = json.loads(response.data)
+ assert "Exception" in data
+ assert "trainingjob_name or version is not correct" in data["Exception"]
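+
+    # A minimal parametrized sketch, assuming the unmocked validator still rejects the
+    # inputs exercised by the removed negative test ('*' in the name, non-numeric version):
+    @pytest.mark.parametrize("name,version", [("usecase7*", "1"), ("usecase7", "a")])
+    def test_invalid_name_version_unmocked(self, name, version):
+        response = self.client.get('/trainingjobs/{}/{}'.format(name, version))
+        assert response.status_code == status.HTTP_400_BAD_REQUEST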
+
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+    @patch('trainingmgr.trainingmgr_main.get_info_by_version', return_value=None)
+    @patch('trainingmgr.trainingmgr_main.get_metrics', return_value="No data available")
+    def test_nonexistent_trainingjob(self, mock_get_metrics, mock_get_info, mock_check):
+ """Test when training job doesn't exist in database."""
+
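+        # get_info_by_version returning None represents a missing DB row, which the
+        # endpoint is expected to map to 404.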
+ response = self.client.get('/trainingjobs/nonexistent_job/1')
+
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ data = json.loads(response.data)
+ assert "Exception" in data
+ assert "Not found given trainingjob with version" in data["Exception"]
+
+ @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+ @patch('trainingmgr.trainingmgr_main.get_info_by_version', side_effect=Exception("Database error"))
+    def test_database_error(self, mock_get_info, mock_check):
+ """Test handling of database errors."""
+
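+        # The side_effect makes the DB lookup raise, which should surface as HTTP 500.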
+ response = self.client.get('/trainingjobs/test_job/1')
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ data = json.loads(response.data)
+ assert "Exception" in data
+ assert "Database error" in data["Exception"]
+
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+    @patch('trainingmgr.trainingmgr_main.get_info_by_version')
+    @patch('trainingmgr.trainingmgr_main.get_metrics', side_effect=Exception("Metrics error"))
+    def test_metrics_error(self, mock_get_metrics, mock_get_info, mock_check, mock_training_job):
+        """Test handling of metrics retrieval error."""
+        # The mock_training_job fixture cannot be referenced from the @patch decorator,
+        # so wire up the return value inside the test instead.
+        mock_get_info.return_value = mock_training_job
+
+        response = self.client.get('/trainingjobs/test_job/1')
+
+        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        data = json.loads(response.data)
+        assert "Exception" in data
+        assert "Metrics error" in data["Exception"]
+
@pytest.mark.skip("")
class Test_unpload_pipeline:
def setup_method(self):