From 285025f08bdcfd60972bc21118ae1364f14ee347 Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Wed, 30 Oct 2024 11:45:03 +0530 Subject: [PATCH] changes in the start training Change-Id: I73676734367b89016697ba21c2cdc46864dd6d2f Signed-off-by: rajdeep11 --- trainingmgr/common/trainingmgr_util.py | 9 +++++---- trainingmgr/trainingmgr_main.py | 34 +++++++++++++++++----------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/trainingmgr/common/trainingmgr_util.py b/trainingmgr/common/trainingmgr_util.py index f6cbaa8..4b8879b 100644 --- a/trainingmgr/common/trainingmgr_util.py +++ b/trainingmgr/common/trainingmgr_util.py @@ -27,13 +27,14 @@ import requests from marshmallow import ValidationError from trainingmgr.db.common_db_fun import change_in_progress_to_failed_by_latest_version, \ get_field_by_latest_version, change_field_of_latest_version, \ - get_latest_version_trainingjob_name, get_all_versions_info_by_name + get_latest_version_trainingjob_name from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, get_feature_groups_db, \ get_feature_group_by_name_db, delete_feature_group_by_name from trainingmgr.constants.states import States from trainingmgr.common.exceptions_utls import APIException,TMException,DBException from trainingmgr.common.trainingmgr_operations import create_dme_filtered_data_job from trainingmgr.schemas import ma, TrainingJobSchema , FeatureGroupSchema +from trainingmgr.db.trainingjob_db import get_all_versions_info_by_name ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response" MIMETYPE_JSON = "application/json" @@ -331,7 +332,7 @@ def handle_async_feature_engineering_status_exception_case(lock, dataextraction_ except KeyError as key_err: logger.error("The training job key doesn't exist in DATAEXTRACTION_JOBS_CACHE: " + str(key_err)) -def validate_trainingjob_name(trainingjob_name, ps_db_obj): +def validate_trainingjob_name(trainingjob_name): """ This function returns True if given trainingjob_name exists in db otherwise it returns False. @@ -343,13 +344,13 @@ def validate_trainingjob_name(trainingjob_name, ps_db_obj): raise TMException("The name of training job is invalid.") try: - results = get_all_versions_info_by_name(trainingjob_name, ps_db_obj) + results = get_all_versions_info_by_name(trainingjob_name) except Exception as err: errmsg = str(err) raise DBException("Could not get info from db for " + trainingjob_name + "," + errmsg) if results: isavailable = True - return isavailable + return isavailable def get_pipelines_details(training_config_obj): logger=training_config_obj.logger diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 2dc12a2..a8859ea 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -51,9 +51,9 @@ from trainingmgr.db.trainingmgr_ps_db import PSDB from trainingmgr.common.exceptions_utls import DBException from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \ change_field_of_latest_version, \ - change_in_progress_to_failed_by_latest_version, change_steps_state_of_latest_version, \ + change_in_progress_to_failed_by_latest_version, \ get_info_by_version, \ - get_trainingjob_info_by_name, get_latest_version_trainingjob_name, get_all_versions_info_by_name, \ + get_latest_version_trainingjob_name, get_all_versions_info_by_name, \ update_model_download_url, \ get_field_of_given_version,get_all_jobs_latest_status_version, get_info_of_latest_version, \ delete_trainingjob_version, change_field_value_by_version @@ -341,27 +341,26 @@ def training(trainingjob_name): return {"Exception":"The trainingjob_name is not correct"}, status.HTTP_400_BAD_REQUEST LOGGER.debug("Request for training trainingjob %s ", trainingjob_name) try: - isDataAvaible = validate_trainingjob_name(trainingjob_name, PS_DB_OBJ) + isDataAvaible = validate_trainingjob_name(trainingjob_name) if not isDataAvaible: response_code = status.HTTP_404_NOT_FOUND raise TMException("Given trainingjob name is not present in database" + \ "(trainingjob: " + trainingjob_name + ")") from None else: - db_results = get_trainingjob_info_by_name(trainingjob_name, PS_DB_OBJ) - featuregroup_name = db_results[0][2] - result= get_feature_group_by_name_db(PS_DB_OBJ, featuregroup_name) - feature_list_string = result[0][1] + trainingjob = get_trainingjob_info_by_name(trainingjob_name) + featuregroup= get_feature_group_by_name_db(trainingjob.feature_group_name) + feature_list_string = featuregroup.feature_list influxdb_info_dic={} - influxdb_info_dic["host"]=result[0][3] - influxdb_info_dic["port"]=result[0][4] - influxdb_info_dic["bucket"]=result[0][5] - influxdb_info_dic["token"]=result[0][6] - influxdb_info_dic["db_org"] = result[0][7] - influxdb_info_dic["source_name"]= result[0][11] - query_filter = db_results[0][6] - datalake_source = json.loads(db_results[0][14])['datalake_source'] - _measurement = result[0][8] + influxdb_info_dic["host"]=featuregroup.host + influxdb_info_dic["port"]=featuregroup.port + influxdb_info_dic["bucket"]=featuregroup.bucket + influxdb_info_dic["token"]=featuregroup.token + influxdb_info_dic["db_org"] = featuregroup.db_org + influxdb_info_dic["source_name"]= featuregroup.source_name + _measurement = featuregroup.measurement + query_filter = trainingjob.query_filter + datalake_source = json.loads(trainingjob.datalake_source)['datalake_source'] LOGGER.debug('Starting Data Extraction...') de_response = data_extraction_start(TRAININGMGR_CONFIG_OBJ, trainingjob_name, feature_list_string, query_filter, datalake_source, @@ -369,7 +368,7 @@ def training(trainingjob_name): if (de_response.status_code == status.HTTP_200_OK ): LOGGER.debug("Response from data extraction for " + \ trainingjob_name + " : " + json.dumps(de_response.json())) - change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ, + change_steps_state_of_latest_version(trainingjob_name, Steps.DATA_EXTRACTION.name, States.IN_PROGRESS.name) with LOCK: @@ -393,6 +392,7 @@ def training(trainingjob_name): LOGGER.debug("Error is training, job name:" + trainingjob_name + str(err)) return APP.response_class(response=json.dumps(response_data),status=response_code, mimetype=MIMETYPE_JSON) + @APP.route('/trainingjob/dataExtractionNotification', methods=['POST']) def data_extraction_notification(): """ -- 2.16.6