Issue-Id: AIMLFW-183: Add API to get trainingJobs from modelId 83/14283/7
authorashishj1729 <jain.ashish@samsung.com>
Wed, 26 Mar 2025 05:55:37 +0000 (11:25 +0530)
committerashishj1729 <jain.ashish@samsung.com>
Mon, 31 Mar 2025 05:03:47 +0000 (10:33 +0530)
Following Items are implemented:
1. Added an API-endpoint ('/training-jobs/<model_name>/<model_version>') which will take ModelId (modelName + modelVersion) as input and returns all the TrainingJobs associated with the ModelId.

Change-Id: I1363fd9bc17631147d94b1c55f2003e72da1bcc1
Signed-off-by: ashishj1729 <jain.ashish@samsung.com>
tests/test_featuregroup_service.py
tests/test_trainingjob_controller.py
tests/test_trainingjob_db.py [new file with mode: 0644]
tests/test_trainingjob_service.py
trainingmgr/controller/trainingjob_controller.py
trainingmgr/db/trainingjob_db.py
trainingmgr/service/training_job_service.py

index 82ddb8c..746619f 100644 (file)
@@ -47,7 +47,7 @@ sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=moc
 
 from trainingmgr.db.trainingjob_db import (
     change_state_to_failed, delete_trainingjob_by_id, create_trainingjob,
-    get_trainingjob, get_trainingjob_by_modelId_db, change_steps_state,
+    get_trainingjob, change_steps_state,
     change_field_value, change_steps_state_df, changeartifact
 )
 from trainingmgr.common.exceptions_utls import DBException, TMException
index b081096..3e31d24 100644 (file)
@@ -241,4 +241,28 @@ class TestGetTrainingJobStatus:
     def test_get_trainingjob_status(self, mock1):
         response = self.client.get("/training-jobs/123/status")
         assert response.status_code == 200
-        assert response.json == {"status": "running"}
\ No newline at end of file
+        assert response.json == {"status": "running"}
+        
+class TestGetTrainingJobInfosFromModelId:
+    def setup_method(self):
+        app = Flask(__name__)
+        app.register_blueprint(training_job_controller)
+        self.client = app.test_client()
+    
+    @pytest.fixture
+    def mock_feature_group_fixture(self):
+        FeatureGroupObj = MagicMock()
+        FeatureGroupObj.featuregroup_name = "test_feature_group"
+        return FeatureGroupObj
+    
+    @patch('trainingmgr.controller.trainingjob_controller.fetch_trainingjob_infos_from_model_id')
+    def test_success(self, mock_fetch_feature_group, mock_feature_group_fixture):
+        mock_fetch_feature_group.return_value = mock_feature_group_fixture
+        response = self.client.get("/training-jobs/abc/1")
+        assert response.status_code == 200
+    
+
+    @patch('trainingmgr.controller.trainingjob_controller.fetch_trainingjob_infos_from_model_id', side_effect = Exception("Generic exception"))
+    def test_internal_error(self, mock_fetch_feature_group, mock_feature_group_fixture):
+        response = self.client.get("/training-jobs/abc/1")
+        assert response.status_code == 500
\ No newline at end of file
diff --git a/tests/test_trainingjob_db.py b/tests/test_trainingjob_db.py
new file mode 100644 (file)
index 0000000..5e15574
--- /dev/null
@@ -0,0 +1,51 @@
+# ==================================================================================
+#
+#      Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+# ==================================================================================
+
+import pytest
+from unittest.mock import patch, MagicMock
+from trainingmgr.models.trainingjob import TrainingJob
+
+from trainingmgr.db.trainingjob_db import (
+    get_trainingjobs_by_model_id_db
+)
+
+class TestGetTrainingJobsByModelIdDb:
+    def test_success(self):
+        mock_session = MagicMock()
+        mock_query = mock_session.query.return_value
+        mock_join = mock_query.join.return_value
+        mock_filter = mock_join.filter.return_value
+        # mock_trainingjobs = [MagicMock(spec=TrainingJob), MagicMock(spec=TrainingJob)]
+        mock_trainingjobs = None
+        mock_filter.all.return_value = mock_trainingjobs
+        
+        with patch("trainingmgr.models.db.session.query", mock_session):
+            result = get_trainingjobs_by_model_id_db("test_model", "v1")
+
+    
+    @patch('trainingmgr.models.db.session.query', side_effect = Exception("Database error"))
+    def test_internal_error(self, mock1):
+        try:
+            model_name = "abc"
+            model_version = "1"
+            get_trainingjobs_by_model_id_db(model_name, model_version)
+            assert False, "The test is supposed to fail, but it passed, It will be considered as failure"
+        except:
+            # Test was supposed to fail, and It failed, So, It will consider Passed
+            pass
+        
\ No newline at end of file
index 7bcab2f..74b4779 100644 (file)
@@ -47,7 +47,7 @@ sys.modules["trainingmgr.handler.async_handler"] = MagicMock(ModelMetricsSdk=moc
 
 from trainingmgr.db.trainingjob_db import (
     change_state_to_failed, delete_trainingjob_by_id, create_trainingjob,
-    get_trainingjob, get_trainingjob_by_modelId_db, change_steps_state,
+    get_trainingjob, change_steps_state,
     change_field_value, change_steps_state_df, changeartifact
 )
 from trainingmgr.common.exceptions_utls import DBException, TMException
@@ -60,9 +60,9 @@ from trainingmgr.service.training_job_service import (
     get_trainining_jobs,
     create_training_job,
     delete_training_job,
-    get_trainingjob_by_modelId,
+    fetch_trainingjob_infos_from_model_id,
     update_artifact_version,
-    training
+    training,
 )
 
 class TestGetTrainingJob:
@@ -327,3 +327,26 @@ class TestTraining:
         response, status_code = training(mock_Trainingjob)
         assert status_code == 500
 
+class TestFetchTrainingJobInfosFromModelId:
+    @patch('trainingmgr.service.training_job_service.get_trainingjobs_by_model_id_db')
+    def test_success(self, mock_gettrainingJob):
+        model_name = "abc"
+        model_version = "1"
+        trainingjobs_info = fetch_trainingjob_infos_from_model_id(model_name, model_version)
+        assert True # Reached here without error, test-passed
+    
+    
+    @patch('trainingmgr.service.training_job_service.get_trainingjobs_by_model_id_db', side_effect = Exception("Generic exception"))
+    def test_internalError(self, mock_gettrainingJob):
+        model_name = "abc"
+        model_version = "1"
+        try:
+            trainingjobs_info = fetch_trainingjob_infos_from_model_id(model_name, model_version)
+            assert False, "The test should have raised an Exception, but It didn't"
+        except Exception as e:
+            # Signifies test-passed
+            pass
+            
+            
+
+    
\ No newline at end of file
index d618f8e..070bcb8 100644 (file)
@@ -24,9 +24,10 @@ from marshmallow import ValidationError
 from trainingmgr.common.exceptions_utls import TMException
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 from trainingmgr.schemas.trainingjob_schema import TrainingJobSchema
+from trainingmgr.schemas.featuregroup_schema import FeatureGroupSchema
 from trainingmgr.schemas.problemdetail_schema import ProblemDetails
 from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainining_jobs, \
-get_steps_state
+get_steps_state, fetch_trainingjob_infos_from_model_id
 from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
 from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
 from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
@@ -114,3 +115,16 @@ def get_trainingjob_status(training_job_id):
     except Exception as err:
         LOGGER.error(f"Error fetching status for training job {training_job_id}: {str(err)}")
         return ProblemDetails(500, "Internal Server Error", str(err)).to_json()
+
+@training_job_controller.route('/training-jobs/<model_name>/<model_version>', methods=['GET'])
+def get_trainingjob_infos_from_model_id(model_name, model_version):
+    '''
+     This API-endpoint takes model_name and model_version into account and returns all the trainingJobInfo related to that ModelId
+    '''
+    LOGGER.debug(f'Requesting trainingJob-info for model-Id for model_Id: {model_name} and {model_version}')
+    try:
+        trainingjob_infos = fetch_trainingjob_infos_from_model_id(model_name, model_version)
+        return jsonify(trainingjobs_schema.dump(trainingjob_infos)), 200
+    except Exception as err:
+        LOGGER.error(f"Error fetching training-job-infos corresponding to model_name = {model_name} and model_version = {model_version} : {str(err)}")
+        return ProblemDetails(500, "Internal Server Error", str(err)).to_json()
\ No newline at end of file
index 0b24939..4378d70 100644 (file)
@@ -24,7 +24,7 @@ from trainingmgr.models import db, TrainingJob, TrainingJobStatus, ModelID
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
 from sqlalchemy.exc import NoResultFound
-
+from sqlalchemy import desc
 
 
 DB_QUERY_EXEC_ERROR = "Failed to execute query in "
@@ -120,22 +120,24 @@ def change_field_value_by_version(trainingjob_name, version, field, field_value)
     except Exception as err:
         raise DBException("Failed to execute query in change_field_value_by_version," + str(err))
 
-def get_trainingjob_by_modelId_db(model_id):
+
+def get_trainingjobs_by_model_id_db(model_name, model_version):
     try:
-        trainingjob = (
+        trainingjobs = (
             db.session.query(TrainingJob)
             .join(ModelID)
             .filter(
-                ModelID.modelname == model_id.modelname,
-                ModelID.modelversion == model_id.modelversion
+                ModelID.modelname == model_name,
+                ModelID.modelversion == model_version
             )
-            .one()
+            .all()
         )
-        return trainingjob
+        return trainingjobs
     except NoResultFound:
         return None
     except Exception as e:
-        raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjob_by_name_db : {str(e)}')
+        raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjobs_by_model_id_db : {str(e)}')
+
 
 def change_steps_state(trainingjob_id, step: Steps, state:States):
 
index 0c6086a..144d141 100644 (file)
@@ -21,8 +21,8 @@ from flask_api import status
 from flask import jsonify
 from trainingmgr.common.trainingmgr_operations import data_extraction_start, notification_rapp
 from trainingmgr.db.model_db import get_model_by_modelId
-from trainingmgr.db.trainingjob_db import change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db, \
-change_steps_state, change_field_value, change_field_value, change_steps_state_df, changeartifact
+from trainingmgr.db.trainingjob_db import change_state_to_failed, delete_trainingjob_by_id, create_trainingjob, get_trainingjob,\
+change_steps_state, change_field_value, change_field_value, change_steps_state_df, changeartifact, get_trainingjobs_by_model_id_db
 from trainingmgr.common.exceptions_utls import DBException, TMException
 from trainingmgr.common.trainingConfig_parser import getField, setField
 from trainingmgr.handler.async_handler import DATAEXTRACTION_JOBS_CACHE
@@ -129,15 +129,6 @@ def delete_training_job(training_job_id : int):
         raise DBException(f"delete_trainining_job failed with exception : {str(err)}")
 
 
-def get_trainingjob_by_modelId(model_id):
-    try:
-        trainingjob = get_trainingjob_by_modelId_db(model_id)
-        return trainingjob
-    except Exception as err:
-        if "No row was found when one was required" in str(err):
-            return None
-        raise DBException(f"get_trainingjob_by_name failed with exception : {str(err)}")
-
 def get_steps_state(trainingjob_id):
     try:    
         trainingjob = get_trainingjob(trainingjob_id)
@@ -283,4 +274,12 @@ def fetch_pipelinename_and_version(type, training_config):
         else :
             return getField(training_config, "retraining_pipeline_name"), getField(training_config, "retraining_pipeline_version")
     except Exception as err:
-        raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}| Error: " + str(err))
\ No newline at end of file
+        raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}| Error: " + str(err))
+    
+
+def fetch_trainingjob_infos_from_model_id(model_name, model_version):
+    try:
+        trainingjob_infos = get_trainingjobs_by_model_id_db(model_name, model_version)
+        return trainingjob_infos
+    except Exception as err:
+        raise TMException(f"Can't fetch trainingjob_infos from model_name {model_name} and model_version {model_version}| Error: " + str(err))
\ No newline at end of file