changes in the data extraction notification code 15/13715/6
authorrajdeep11 <rajdeep.sin@samsung.com>
Wed, 30 Oct 2024 06:38:58 +0000 (12:08 +0530)
committersubhash kumar singh <subh.singh@samsung.com>
Wed, 30 Oct 2024 10:09:39 +0000 (10:09 +0000)
Change-Id: I094fdb6289b9037b6fa2c8179d9f2fa30910e313
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
trainingmgr/common/trainingmgr_util.py
trainingmgr/db/trainingjob_db.py
trainingmgr/trainingmgr_main.py

index 4b8879b..2c36759 100644 (file)
@@ -43,7 +43,7 @@ PATTERN = re.compile(r"\w+")
 featuregroup_schema = FeatureGroupSchema()
 featuregroups_schema = FeatureGroupSchema(many = True)
 
-def response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk):
+def response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk):
     """
     Post training job completion,this function provides notifications to the subscribers, 
     who subscribed for the result of training job and provided a notification url during 
@@ -54,37 +54,34 @@ def response_for_training(code, message, logger, is_success, trainingjob_name, p
     
     try :
         #TODO DB query optimization, all data to fetch in one call
-        notif_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj, "notification_url")
-        if notif_url_result :
-            notification_url = notif_url_result[0][0]
-            model_url_result = None
-            if notification_url != '':
-                model_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj, "model_url")
-                model_url = model_url_result[0][0]
-                version = get_latest_version_trainingjob_name(trainingjob_name, ps_db_obj)
-                metrics = get_metrics(trainingjob_name, version, mm_sdk)
-
-                req_json = None
-                if is_success:
-                    req_json = {
-                        "result": "success", "model_url": model_url,
-                        "trainingjob_name": trainingjob_name, "metrics": metrics
-                    }
-                else:
-                    req_json = {"result": "failed", "trainingjob_name": trainingjob_name}
-            
-                response = requests.post(notification_url,
-                        data=json.dumps(req_json),
-                        headers={
-                            'content-type': MIMETYPE_JSON,
-                            'Accept-Charset': 'UTF-8'
-                        })
-                if ( response.headers['content-type'] != MIMETYPE_JSON
-                        or response.status_code != status.HTTP_200_OK ):
-                    err_msg = "Failed to notify the subscribed url " + trainingjob_name
-                    raise TMException(err_msg)
+        notif_url = get_field_by_latest_version(trainingjob_name, "notification_url")
+        if notif_url :
+
+            model_url = get_field_by_latest_version(trainingjob_name, "model_url")
+            version = get_latest_version_trainingjob_name(trainingjob_name)
+            metrics = get_metrics(trainingjob_name, version, mm_sdk)
+
+            req_json = None
+            if is_success:
+                req_json = {
+                    "result": "success", "model_url": model_url,
+                    "trainingjob_name": trainingjob_name, "metrics": metrics
+                }
+            else:
+                req_json = {"result": "failed", "trainingjob_name": trainingjob_name}
+        
+            response = requests.post(notif_url,
+                    data=json.dumps(req_json),
+                    headers={
+                        'content-type': MIMETYPE_JSON,
+                        'Accept-Charset': 'UTF-8'
+                    })
+            if ( response.headers['content-type'] != MIMETYPE_JSON
+                    or response.status_code != status.HTTP_200_OK ):
+                err_msg = "Failed to notify the subscribed url " + trainingjob_name
+                raise TMException(err_msg)
     except Exception as err:
-        change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj)
+        change_in_progress_to_failed_by_latest_version(trainingjob_name)
         raise APIException(status.HTTP_500_INTERNAL_SERVER_ERROR,
                             str(err) + "(trainingjob name is " + trainingjob_name + ")") from None
     if is_success:
index 3e525d1..e4f4353 100644 (file)
@@ -211,4 +211,22 @@ def get_field_by_latest_version(trainingjob_name, field):
     except Exception as err:
         raise DBException("Failed to execute query in get_field_by_latest_version,"  + str(err))
 
-    return result
\ No newline at end of file
+    return result
+
+def change_field_of_latest_version(trainingjob_name, field, field_value):
+    """
+    This function updates the field's value for given trainingjob.
+    """
+
+    try:
+        trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first()
+        if field == "notification_url":
+            trainingjob_max_version.notification_url = field_value
+            trainingjob_max_version.updation_time = datetime.datetime.utcnow()
+        if field == "run_id":
+            trainingjob_max_version.run_id = field_value
+            trainingjob_max_version.updation_time = datetime.datetime.utcnow()
+        db.session.commit()
+
+    except Exception as err:
+        raise DBException("Failed to execute query in change_field_of_latest_version,"  + str(err))
\ No newline at end of file
index a8859ea..f832542 100644 (file)
@@ -43,14 +43,13 @@ from trainingmgr.common.trainingmgr_util import get_one_word_status, check_train
     response_for_training, get_metrics, \
     handle_async_feature_engineering_status_exception_case, \
     validate_trainingjob_name, get_pipelines_details, check_feature_group_data, check_trainingjob_name_and_version, check_trainingjob_name_or_featuregroup_name, \
-    get_feature_group_by_name, edit_feature_group_by_name
+    get_feature_group_by_name, edit_feature_group_by_name, fetch_pipeline_info_by_name
 from trainingmgr.common.exceptions_utls import APIException,TMException
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
 from trainingmgr.db.trainingmgr_ps_db import PSDB
 from trainingmgr.common.exceptions_utls import DBException
 from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \
-    change_field_of_latest_version, \
     change_in_progress_to_failed_by_latest_version, \
     get_info_by_version, \
     get_latest_version_trainingjob_name, get_all_versions_info_by_name, \
@@ -63,7 +62,7 @@ from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup,
     get_feature_group_by_name_db, delete_feature_group_by_name
 from trainingmgr.db.trainingjob_db import add_update_trainingjob, get_trainingjob_info_by_name, \
     get_all_jobs_latest_status_version, change_steps_state_of_latest_version, get_info_by_version, \
-    get_steps_state_db
+    get_steps_state_db, change_field_of_latest_version
 
 APP = Flask(__name__)
 
@@ -427,14 +426,14 @@ def data_extraction_notification():
             return {"Exception":err_msg}, status.HTTP_400_BAD_REQUEST
             
         trainingjob_name = request.json["trainingjob_name"]
-        results = get_trainingjob_info_by_name(trainingjob_name, PS_DB_OBJ)
-        arguments = json.loads(results[0][5])['arguments']
-        arguments["version"] = results[0][11]
+        trainingjob = get_trainingjob_info_by_name(trainingjob_name)
+        arguments = trainingjob.arguments
+        arguments["version"] = trainingjob.version
         LOGGER.debug(arguments)
 
         dict_data = {
-            "pipeline_name": results[0][3], "experiment_name": results[0][4],
-            "arguments": arguments, "pipeline_version": results[0][13]
+            "pipeline_name": trainingjob.pipeline_name, "experiment_name": trainingjob.experiment_name,
+            "arguments": arguments, "pipeline_version": trainingjob.pipeline_version
         }
 
         response = training_start(TRAININGMGR_CONFIG_OBJ, dict_data, trainingjob_name)
@@ -454,32 +453,32 @@ def data_extraction_notification():
             raise TMException(err_msg)
 
         if json_data["run_status"] == 'scheduled':
-            change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ,
+            change_steps_state_of_latest_version(trainingjob_name,
                                                 Steps.DATA_EXTRACTION_AND_TRAINING.name,
                                                 States.FINISHED.name)
-            change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ,
+            change_steps_state_of_latest_version(trainingjob_name,
                                                 Steps.TRAINING.name,
                                                 States.IN_PROGRESS.name)
-            change_field_of_latest_version(trainingjob_name, PS_DB_OBJ,
+            change_field_of_latest_version(trainingjob_name,
                                         "run_id", json_data["run_id"])
         else:
             raise TMException("KF Adapter- run_status in not scheduled")
     except requests.exceptions.ConnectionError as err:
         err_msg = "Failed to connect KF adapter."
         LOGGER.error(err_msg)
-        if not change_in_progress_to_failed_by_latest_version(trainingjob_name, PS_DB_OBJ) :
+        if not change_in_progress_to_failed_by_latest_version(trainingjob_name) :
             LOGGER.error(ERROR_TYPE_DB_STATUS)
         return response_for_training(err_response_code,
                                         err_msg + str(err) + "(trainingjob name is " + trainingjob_name + ")",
-                                        LOGGER, False, trainingjob_name, PS_DB_OBJ, MM_SDK)
+                                        LOGGER, False, trainingjob_name, MM_SDK)
 
     except Exception as err:
         LOGGER.error("Failed to handle dataExtractionNotification. " + str(err))
-        if not change_in_progress_to_failed_by_latest_version(trainingjob_name, PS_DB_OBJ) :
+        if not change_in_progress_to_failed_by_latest_version(trainingjob_name) :
             LOGGER.error(ERROR_TYPE_DB_STATUS)
         return response_for_training(err_response_code,
                                         str(err) + "(trainingjob name is " + trainingjob_name + ")",
-                                        LOGGER, False, trainingjob_name, PS_DB_OBJ, MM_SDK)
+                                        LOGGER, False, trainingjob_name, MM_SDK)
 
     return APP.response_class(response=json.dumps({"result": "pipeline is scheduled"}),
                                     status=status.HTTP_200_OK,