From 40d1f5752740f89ea7a0144366ea9c62760bc282 Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Wed, 30 Oct 2024 12:08:58 +0530 Subject: [PATCH] changes in the data extraction notification code Change-Id: I094fdb6289b9037b6fa2c8179d9f2fa30910e313 Signed-off-by: rajdeep11 --- trainingmgr/common/trainingmgr_util.py | 59 ++++++++++++++++------------------ trainingmgr/db/trainingjob_db.py | 20 +++++++++++- trainingmgr/trainingmgr_main.py | 29 ++++++++--------- 3 files changed, 61 insertions(+), 47 deletions(-) diff --git a/trainingmgr/common/trainingmgr_util.py b/trainingmgr/common/trainingmgr_util.py index 4b8879b..2c36759 100644 --- a/trainingmgr/common/trainingmgr_util.py +++ b/trainingmgr/common/trainingmgr_util.py @@ -43,7 +43,7 @@ PATTERN = re.compile(r"\w+") featuregroup_schema = FeatureGroupSchema() featuregroups_schema = FeatureGroupSchema(many = True) -def response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk): +def response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk): """ Post training job completion,this function provides notifications to the subscribers, who subscribed for the result of training job and provided a notification url during @@ -54,37 +54,34 @@ def response_for_training(code, message, logger, is_success, trainingjob_name, p try : #TODO DB query optimization, all data to fetch in one call - notif_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj, "notification_url") - if notif_url_result : - notification_url = notif_url_result[0][0] - model_url_result = None - if notification_url != '': - model_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj, "model_url") - model_url = model_url_result[0][0] - version = get_latest_version_trainingjob_name(trainingjob_name, ps_db_obj) - metrics = get_metrics(trainingjob_name, version, mm_sdk) - - req_json = None - if is_success: - req_json = { - "result": "success", "model_url": model_url, - "trainingjob_name": trainingjob_name, "metrics": metrics - } - else: - req_json = {"result": "failed", "trainingjob_name": trainingjob_name} - - response = requests.post(notification_url, - data=json.dumps(req_json), - headers={ - 'content-type': MIMETYPE_JSON, - 'Accept-Charset': 'UTF-8' - }) - if ( response.headers['content-type'] != MIMETYPE_JSON - or response.status_code != status.HTTP_200_OK ): - err_msg = "Failed to notify the subscribed url " + trainingjob_name - raise TMException(err_msg) + notif_url = get_field_by_latest_version(trainingjob_name, "notification_url") + if notif_url : + + model_url = get_field_by_latest_version(trainingjob_name, "model_url") + version = get_latest_version_trainingjob_name(trainingjob_name) + metrics = get_metrics(trainingjob_name, version, mm_sdk) + + req_json = None + if is_success: + req_json = { + "result": "success", "model_url": model_url, + "trainingjob_name": trainingjob_name, "metrics": metrics + } + else: + req_json = {"result": "failed", "trainingjob_name": trainingjob_name} + + response = requests.post(notif_url, + data=json.dumps(req_json), + headers={ + 'content-type': MIMETYPE_JSON, + 'Accept-Charset': 'UTF-8' + }) + if ( response.headers['content-type'] != MIMETYPE_JSON + or response.status_code != status.HTTP_200_OK ): + err_msg = "Failed to notify the subscribed url " + trainingjob_name + raise TMException(err_msg) except Exception as err: - change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj) + change_in_progress_to_failed_by_latest_version(trainingjob_name) raise APIException(status.HTTP_500_INTERNAL_SERVER_ERROR, str(err) + "(trainingjob name is " + trainingjob_name + ")") from None if is_success: diff --git a/trainingmgr/db/trainingjob_db.py b/trainingmgr/db/trainingjob_db.py index 3e525d1..e4f4353 100644 --- a/trainingmgr/db/trainingjob_db.py +++ b/trainingmgr/db/trainingjob_db.py @@ -211,4 +211,22 @@ def get_field_by_latest_version(trainingjob_name, field): except Exception as err: raise DBException("Failed to execute query in get_field_by_latest_version," + str(err)) - return result \ No newline at end of file + return result + +def change_field_of_latest_version(trainingjob_name, field, field_value): + """ + This function updates the field's value for given trainingjob. + """ + + try: + trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first() + if field == "notification_url": + trainingjob_max_version.notification_url = field_value + trainingjob_max_version.updation_time = datetime.datetime.utcnow() + if field == "run_id": + trainingjob_max_version.run_id = field_value + trainingjob_max_version.updation_time = datetime.datetime.utcnow() + db.session.commit() + + except Exception as err: + raise DBException("Failed to execute query in change_field_of_latest_version," + str(err)) \ No newline at end of file diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index a8859ea..f832542 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -43,14 +43,13 @@ from trainingmgr.common.trainingmgr_util import get_one_word_status, check_train response_for_training, get_metrics, \ handle_async_feature_engineering_status_exception_case, \ validate_trainingjob_name, get_pipelines_details, check_feature_group_data, check_trainingjob_name_and_version, check_trainingjob_name_or_featuregroup_name, \ - get_feature_group_by_name, edit_feature_group_by_name + get_feature_group_by_name, edit_feature_group_by_name, fetch_pipeline_info_by_name from trainingmgr.common.exceptions_utls import APIException,TMException from trainingmgr.constants.steps import Steps from trainingmgr.constants.states import States from trainingmgr.db.trainingmgr_ps_db import PSDB from trainingmgr.common.exceptions_utls import DBException from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \ - change_field_of_latest_version, \ change_in_progress_to_failed_by_latest_version, \ get_info_by_version, \ get_latest_version_trainingjob_name, get_all_versions_info_by_name, \ @@ -63,7 +62,7 @@ from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, get_feature_group_by_name_db, delete_feature_group_by_name from trainingmgr.db.trainingjob_db import add_update_trainingjob, get_trainingjob_info_by_name, \ get_all_jobs_latest_status_version, change_steps_state_of_latest_version, get_info_by_version, \ - get_steps_state_db + get_steps_state_db, change_field_of_latest_version APP = Flask(__name__) @@ -427,14 +426,14 @@ def data_extraction_notification(): return {"Exception":err_msg}, status.HTTP_400_BAD_REQUEST trainingjob_name = request.json["trainingjob_name"] - results = get_trainingjob_info_by_name(trainingjob_name, PS_DB_OBJ) - arguments = json.loads(results[0][5])['arguments'] - arguments["version"] = results[0][11] + trainingjob = get_trainingjob_info_by_name(trainingjob_name) + arguments = trainingjob.arguments + arguments["version"] = trainingjob.version LOGGER.debug(arguments) dict_data = { - "pipeline_name": results[0][3], "experiment_name": results[0][4], - "arguments": arguments, "pipeline_version": results[0][13] + "pipeline_name": trainingjob.pipeline_name, "experiment_name": trainingjob.experiment_name, + "arguments": arguments, "pipeline_version": trainingjob.pipeline_version } response = training_start(TRAININGMGR_CONFIG_OBJ, dict_data, trainingjob_name) @@ -454,32 +453,32 @@ def data_extraction_notification(): raise TMException(err_msg) if json_data["run_status"] == 'scheduled': - change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ, + change_steps_state_of_latest_version(trainingjob_name, Steps.DATA_EXTRACTION_AND_TRAINING.name, States.FINISHED.name) - change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ, + change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINING.name, States.IN_PROGRESS.name) - change_field_of_latest_version(trainingjob_name, PS_DB_OBJ, + change_field_of_latest_version(trainingjob_name, "run_id", json_data["run_id"]) else: raise TMException("KF Adapter- run_status in not scheduled") except requests.exceptions.ConnectionError as err: err_msg = "Failed to connect KF adapter." LOGGER.error(err_msg) - if not change_in_progress_to_failed_by_latest_version(trainingjob_name, PS_DB_OBJ) : + if not change_in_progress_to_failed_by_latest_version(trainingjob_name) : LOGGER.error(ERROR_TYPE_DB_STATUS) return response_for_training(err_response_code, err_msg + str(err) + "(trainingjob name is " + trainingjob_name + ")", - LOGGER, False, trainingjob_name, PS_DB_OBJ, MM_SDK) + LOGGER, False, trainingjob_name, MM_SDK) except Exception as err: LOGGER.error("Failed to handle dataExtractionNotification. " + str(err)) - if not change_in_progress_to_failed_by_latest_version(trainingjob_name, PS_DB_OBJ) : + if not change_in_progress_to_failed_by_latest_version(trainingjob_name) : LOGGER.error(ERROR_TYPE_DB_STATUS) return response_for_training(err_response_code, str(err) + "(trainingjob name is " + trainingjob_name + ")", - LOGGER, False, trainingjob_name, PS_DB_OBJ, MM_SDK) + LOGGER, False, trainingjob_name, MM_SDK) return APP.response_class(response=json.dumps({"result": "pipeline is scheduled"}), status=status.HTTP_200_OK, -- 2.16.6