From 86c6c6423ba42bcac880554e619fb11cc982d1ed Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Mon, 11 Nov 2024 11:30:30 +0530 Subject: [PATCH] logic for notification to rApp Change-Id: I3e6131d30e33289ec8b0da2862827f88af42ef37 Signed-off-by: rajdeep11 --- trainingmgr/common/trainingmgr_operations.py | 12 ++++++++++++ trainingmgr/db/trainingjob_db.py | 2 +- trainingmgr/trainingmgr_main.py | 17 +++++++++++------ 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/trainingmgr/common/trainingmgr_operations.py b/trainingmgr/common/trainingmgr_operations.py index 23321f3..5cff359 100644 --- a/trainingmgr/common/trainingmgr_operations.py +++ b/trainingmgr/common/trainingmgr_operations.py @@ -26,6 +26,7 @@ import requests import validators from trainingmgr.common.exceptions_utls import TMException from flask_api import status +from trainingmgr.db.trainingjob_db import get_steps_state_db MIMETYPE_JSON = "application/json" @@ -182,3 +183,14 @@ def get_model_info(training_config_obj, model_name): errMsg="model info can't be fetched, model_name: {} , err: {}".format(model_name, response.text) logger.error(errMsg) raise TMException(errMsg) + +def notification_rapp(trainingjob, training_config_obj): + steps_state = get_steps_state_db(trainingjob.trainingjob_name) + response = requests.post(trainingjob.notification_url, + data=json.dumps(steps_state), + headers={ + 'content-type': MIMETYPE_JSON, + 'Accept-Charset': 'UTF-8' + }) + if response.status_code != 200: + raise TMException("Notification failed: "+response.text) \ No newline at end of file diff --git a/trainingmgr/db/trainingjob_db.py b/trainingmgr/db/trainingjob_db.py index 4b70adb..b52da7b 100644 --- a/trainingmgr/db/trainingjob_db.py +++ b/trainingmgr/db/trainingjob_db.py @@ -115,7 +115,7 @@ def get_steps_state_db(trainingjob_name, version): """ try: - steps_state = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).filter(TrainingJob.version == version).first().steps_state + steps_state = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).filter(TrainingJob.version == version).first().steps_state.states except Exception as err: raise DBException("Failed to execute query in get_field_of_given_version" + str(err)) diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 6e18acc..165eeff 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -36,7 +36,7 @@ from flask_cors import CORS from werkzeug.utils import secure_filename from modelmetricsdk.model_metrics_sdk import ModelMetricsSdk from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job, \ - get_model_info + get_model_info, notification_rapp from trainingmgr.common.trainingmgr_config import TrainingMgrConfig from trainingmgr.common.trainingmgr_util import get_one_word_status, check_trainingjob_data, \ check_key_in_dictionary, get_one_key, \ @@ -469,6 +469,7 @@ def data_extraction_notification(): States.IN_PROGRESS.name) change_field_of_latest_version(trainingjob_name, "run_id", json_data["run_id"]) + notification_rapp(trainingjob, TRAININGMGR_CONFIG_OBJ) else: raise TMException("KF Adapter- run_status in not scheduled") except requests.exceptions.ConnectionError as err: @@ -570,21 +571,26 @@ def pipeline_notification(): run_status = request.json["run_status"] if run_status == 'SUCCEEDED': + + trainingjob_info=get_trainingjob_info_by_name(trainingjob_name) change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINING.name, States.FINISHED.name) change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINING_AND_TRAINED_MODEL.name, States.IN_PROGRESS.name) - + notification_rapp(trainingjob_info, TRAININGMGR_CONFIG_OBJ) + version = get_latest_version_trainingjob_name(trainingjob_name) + change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINING_AND_TRAINED_MODEL.name, States.FINISHED.name) change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINED_MODEL.name, States.IN_PROGRESS.name) - + notification_rapp(trainingjob_info, TRAININGMGR_CONFIG_OBJ) + if MM_SDK.check_object(trainingjob_name, version, "Model.zip"): model_url = "http://" + str(TRAININGMGR_CONFIG_OBJ.my_ip) + ":" + \ str(TRAININGMGR_CONFIG_OBJ.my_port) + "/model/" + \ @@ -596,10 +602,9 @@ def pipeline_notification(): change_steps_state_of_latest_version(trainingjob_name, Steps.TRAINED_MODEL.name, States.FINISHED.name) + notification_rapp(trainingjob_info, TRAININGMGR_CONFIG_OBJ) # upload to the mme - trainingjob_info=get_trainingjob_info_by_name(trainingjob_name) - - is_mme = getField(trainingjob_info.training_config, "is_mme") + is_mme= trainingjob_info.is_mme if is_mme: model_name=trainingjob_info.model_name #model_name file=MM_SDK.get_model_zip(trainingjob_name, str(version)) -- 2.16.6