From 8c33b20f20db2226923abfd0af509dc93cb2692e Mon Sep 17 00:00:00 2001 From: ashishj1729 Date: Wed, 26 Mar 2025 11:25:37 +0530 Subject: [PATCH] Issue-Id: AIMLFW-183: Add API to get trainingJobs from modelId Following Items are implemented: 1. Added an API-endpoint ('/training-jobs//') which will take ModelId (modelName + modelVersion) as input and returns all the TrainingJobs associated with the ModelId. Change-Id: I1363fd9bc17631147d94b1c55f2003e72da1bcc1 Signed-off-by: ashishj1729 --- tests/test_featuregroup_service.py | 2 +- tests/test_trainingjob_controller.py | 26 +++++++++++- tests/test_trainingjob_db.py | 51 ++++++++++++++++++++++++ tests/test_trainingjob_service.py | 29 ++++++++++++-- trainingmgr/controller/trainingjob_controller.py | 16 +++++++- trainingmgr/db/trainingjob_db.py | 18 +++++---- trainingmgr/service/training_job_service.py | 23 +++++------ 7 files changed, 139 insertions(+), 26 deletions(-) create mode 100644 tests/test_trainingjob_db.py diff --git a/tests/test_featuregroup_service.py b/tests/test_featuregroup_service.py index 82ddb8c..746619f 100644 --- a/tests/test_featuregroup_service.py +++ b/tests/test_featuregroup_service.py @@ -47,7 +47,7 @@ sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=moc from trainingmgr.db.trainingjob_db import ( change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, - get_trainingjob, get_trainingjob_by_modelId_db, change_steps_state, + get_trainingjob, change_steps_state, change_field_value, change_steps_state_df, changeartifact ) from trainingmgr.common.exceptions_utls import DBException, TMException diff --git a/tests/test_trainingjob_controller.py b/tests/test_trainingjob_controller.py index b081096..3e31d24 100644 --- a/tests/test_trainingjob_controller.py +++ b/tests/test_trainingjob_controller.py @@ -241,4 +241,28 @@ class TestGetTrainingJobStatus: def test_get_trainingjob_status(self, mock1): response = self.client.get("/training-jobs/123/status") assert response.status_code == 200 - assert response.json == {"status": "running"} \ No newline at end of file + assert response.json == {"status": "running"} + +class TestGetTrainingJobInfosFromModelId: + def setup_method(self): + app = Flask(__name__) + app.register_blueprint(training_job_controller) + self.client = app.test_client() + + @pytest.fixture + def mock_feature_group_fixture(self): + FeatureGroupObj = MagicMock() + FeatureGroupObj.featuregroup_name = "test_feature_group" + return FeatureGroupObj + + @patch('trainingmgr.controller.trainingjob_controller.fetch_trainingjob_infos_from_model_id') + def test_success(self, mock_fetch_feature_group, mock_feature_group_fixture): + mock_fetch_feature_group.return_value = mock_feature_group_fixture + response = self.client.get("/training-jobs/abc/1") + assert response.status_code == 200 + + + @patch('trainingmgr.controller.trainingjob_controller.fetch_trainingjob_infos_from_model_id', side_effect = Exception("Generic exception")) + def test_internal_error(self, mock_fetch_feature_group, mock_feature_group_fixture): + response = self.client.get("/training-jobs/abc/1") + assert response.status_code == 500 \ No newline at end of file diff --git a/tests/test_trainingjob_db.py b/tests/test_trainingjob_db.py new file mode 100644 index 0000000..5e15574 --- /dev/null +++ b/tests/test_trainingjob_db.py @@ -0,0 +1,51 @@ +# ================================================================================== +# +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ================================================================================== + +import pytest +from unittest.mock import patch, MagicMock +from trainingmgr.models.trainingjob import TrainingJob + +from trainingmgr.db.trainingjob_db import ( + get_trainingjobs_by_model_id_db +) + +class TestGetTrainingJobsByModelIdDb: + def test_success(self): + mock_session = MagicMock() + mock_query = mock_session.query.return_value + mock_join = mock_query.join.return_value + mock_filter = mock_join.filter.return_value + # mock_trainingjobs = [MagicMock(spec=TrainingJob), MagicMock(spec=TrainingJob)] + mock_trainingjobs = None + mock_filter.all.return_value = mock_trainingjobs + + with patch("trainingmgr.models.db.session.query", mock_session): + result = get_trainingjobs_by_model_id_db("test_model", "v1") + + + @patch('trainingmgr.models.db.session.query', side_effect = Exception("Database error")) + def test_internal_error(self, mock1): + try: + model_name = "abc" + model_version = "1" + get_trainingjobs_by_model_id_db(model_name, model_version) + assert False, "The test is supposed to fail, but it passed, It will be considered as failure" + except: + # Test was supposed to fail, and It failed, So, It will consider Passed + pass + \ No newline at end of file diff --git a/tests/test_trainingjob_service.py b/tests/test_trainingjob_service.py index 7bcab2f..74b4779 100644 --- a/tests/test_trainingjob_service.py +++ b/tests/test_trainingjob_service.py @@ -47,7 +47,7 @@ sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=moc from trainingmgr.db.trainingjob_db import ( change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, - get_trainingjob, get_trainingjob_by_modelId_db, change_steps_state, + get_trainingjob, change_steps_state, change_field_value, change_steps_state_df, changeartifact ) from trainingmgr.common.exceptions_utls import DBException, TMException @@ -60,9 +60,9 @@ from trainingmgr.service.training_job_service import ( get_trainining_jobs, create_training_job, delete_training_job, - get_trainingjob_by_modelId, + fetch_trainingjob_infos_from_model_id, update_artifact_version, - training + training, ) class TestGetTrainingJob: @@ -327,3 +327,26 @@ class TestTraining: response, status_code = training(mock_Trainingjob) assert status_code == 500 +class TestFetchTrainingJobInfosFromModelId: + @patch('trainingmgr.service.training_job_service.get_trainingjobs_by_model_id_db') + def test_success(self, mock_gettrainingJob): + model_name = "abc" + model_version = "1" + trainingjobs_info = fetch_trainingjob_infos_from_model_id(model_name, model_version) + assert True # Reached here without error, test-passed + + + @patch('trainingmgr.service.training_job_service.get_trainingjobs_by_model_id_db', side_effect = Exception("Generic exception")) + def test_internalError(self, mock_gettrainingJob): + model_name = "abc" + model_version = "1" + try: + trainingjobs_info = fetch_trainingjob_infos_from_model_id(model_name, model_version) + assert False, "The test should have raised an Exception, but It didn't" + except Exception as e: + # Signifies test-passed + pass + + + + \ No newline at end of file diff --git a/trainingmgr/controller/trainingjob_controller.py b/trainingmgr/controller/trainingjob_controller.py index d618f8e..070bcb8 100644 --- a/trainingmgr/controller/trainingjob_controller.py +++ b/trainingmgr/controller/trainingjob_controller.py @@ -24,9 +24,10 @@ from marshmallow import ValidationError from trainingmgr.common.exceptions_utls import TMException from trainingmgr.common.trainingmgr_config import TrainingMgrConfig from trainingmgr.schemas.trainingjob_schema import TrainingJobSchema +from trainingmgr.schemas.featuregroup_schema import FeatureGroupSchema from trainingmgr.schemas.problemdetail_schema import ProblemDetails from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainining_jobs, \ -get_steps_state +get_steps_state, fetch_trainingjob_infos_from_model_id from trainingmgr.common.trainingmgr_util import check_key_in_dictionary from trainingmgr.common.trainingConfig_parser import validateTrainingConfig from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service @@ -114,3 +115,16 @@ def get_trainingjob_status(training_job_id): except Exception as err: LOGGER.error(f"Error fetching status for training job {training_job_id}: {str(err)}") return ProblemDetails(500, "Internal Server Error", str(err)).to_json() + +@training_job_controller.route('/training-jobs//', methods=['GET']) +def get_trainingjob_infos_from_model_id(model_name, model_version): + ''' + This API-endpoint takes model_name and model_version into account and returns all the trainingJobInfo related to that ModelId + ''' + LOGGER.debug(f'Requesting trainingJob-info for model-Id for model_Id: {model_name} and {model_version}') + try: + trainingjob_infos = fetch_trainingjob_infos_from_model_id(model_name, model_version) + return jsonify(trainingjobs_schema.dump(trainingjob_infos)), 200 + except Exception as err: + LOGGER.error(f"Error fetching training-job-infos corresponding to model_name = {model_name} and model_version = {model_version} : {str(err)}") + return ProblemDetails(500, "Internal Server Error", str(err)).to_json() \ No newline at end of file diff --git a/trainingmgr/db/trainingjob_db.py b/trainingmgr/db/trainingjob_db.py index 0b24939..4378d70 100644 --- a/trainingmgr/db/trainingjob_db.py +++ b/trainingmgr/db/trainingjob_db.py @@ -24,7 +24,7 @@ from trainingmgr.models import db, TrainingJob, TrainingJobStatus, ModelID from trainingmgr.constants.steps import Steps from trainingmgr.constants.states import States from sqlalchemy.exc import NoResultFound - +from sqlalchemy import desc DB_QUERY_EXEC_ERROR = "Failed to execute query in " @@ -120,22 +120,24 @@ def change_field_value_by_version(trainingjob_name, version, field, field_value) except Exception as err: raise DBException("Failed to execute query in change_field_value_by_version," + str(err)) -def get_trainingjob_by_modelId_db(model_id): + +def get_trainingjobs_by_model_id_db(model_name, model_version): try: - trainingjob = ( + trainingjobs = ( db.session.query(TrainingJob) .join(ModelID) .filter( - ModelID.modelname == model_id.modelname, - ModelID.modelversion == model_id.modelversion + ModelID.modelname == model_name, + ModelID.modelversion == model_version ) - .one() + .all() ) - return trainingjob + return trainingjobs except NoResultFound: return None except Exception as e: - raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjob_by_name_db : {str(e)}') + raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjobs_by_model_id_db : {str(e)}') + def change_steps_state(trainingjob_id, step: Steps, state:States): diff --git a/trainingmgr/service/training_job_service.py b/trainingmgr/service/training_job_service.py index 0c6086a..144d141 100644 --- a/trainingmgr/service/training_job_service.py +++ b/trainingmgr/service/training_job_service.py @@ -21,8 +21,8 @@ from flask_api import status from flask import jsonify from trainingmgr.common.trainingmgr_operations import data_extraction_start, notification_rapp from trainingmgr.db.model_db import get_model_by_modelId -from trainingmgr.db.trainingjob_db import change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db, \ -change_steps_state, change_field_value, change_field_value, change_steps_state_df, changeartifact +from trainingmgr.db.trainingjob_db import change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, get_trainingjob,\ +change_steps_state, change_field_value, change_field_value, change_steps_state_df, changeartifact, get_trainingjobs_by_model_id_db from trainingmgr.common.exceptions_utls import DBException, TMException from trainingmgr.common.trainingConfig_parser import getField, setField from trainingmgr.handler.async_handler import DATAEXTRACTION_JOBS_CACHE @@ -129,15 +129,6 @@ def delete_training_job(training_job_id : int): raise DBException(f"delete_trainining_job failed with exception : {str(err)}") -def get_trainingjob_by_modelId(model_id): - try: - trainingjob = get_trainingjob_by_modelId_db(model_id) - return trainingjob - except Exception as err: - if "No row was found when one was required" in str(err): - return None - raise DBException(f"get_trainingjob_by_name failed with exception : {str(err)}") - def get_steps_state(trainingjob_id): try: trainingjob = get_trainingjob(trainingjob_id) @@ -283,4 +274,12 @@ def fetch_pipelinename_and_version(type, training_config): else : return getField(training_config, "retraining_pipeline_name"), getField(training_config, "retraining_pipeline_version") except Exception as err: - raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}| Error: " + str(err)) \ No newline at end of file + raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}| Error: " + str(err)) + + +def fetch_trainingjob_infos_from_model_id(model_name, model_version): + try: + trainingjob_infos = get_trainingjobs_by_model_id_db(model_name, model_version) + return trainingjob_infos + except Exception as err: + raise TMException(f"Can't fetch trainingjob_infos from model_name {model_name} and model_version {model_version}| Error: " + str(err)) \ No newline at end of file -- 2.16.6