From: rajdeep11 Date: Tue, 21 Feb 2023 08:56:44 +0000 (+0530) Subject: adding edit and retrain feature X-Git-Tag: 1.1.0~26 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=201df052479deee5296763cabb683ce761974b82;p=aiml-fw%2Fawmf%2Ftm.git adding edit and retrain feature Issue-Id: AIMLFW-15 Signed-off-by: rajdeep11 Change-Id: Ia121004986bc778f6407aefa77b2aa2331b0aa89 Signed-off-by: rajdeep11 --- diff --git a/trainingmgr/db/common_db_fun.py b/trainingmgr/db/common_db_fun.py index 1389d81..329138c 100644 --- a/trainingmgr/db/common_db_fun.py +++ b/trainingmgr/db/common_db_fun.py @@ -588,3 +588,26 @@ def get_all_jobs_latest_status_version(ps_db_obj): conn.close() return results +def get_info_of_latest_version(trainingjob_name, ps_db_obj): + """ + This function returns information of + usecase. + """ + + conn = ps_db_obj.get_new_conn() + cursor = conn.cursor() + try: + cursor.execute('''select nt.mv from (select max(version) mv,trainingjob_name ''' + \ + '''from trainingjob_info group by trainingjob_name) nt where nt.trainingjob_name=%s''', + (trainingjob_name,)) + version = int(cursor.fetchall()[0][0]) + cursor.execute(''' select * from trainingjob_info where trainingjob_name=%s and version = %s''', + (trainingjob_name, version)) + except Exception as err: + conn.rollback() + conn.close() + raise err + results = cursor.fetchall() + conn.commit() + conn.close() + return results diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 2c65cc4..dc918f7 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -50,7 +50,7 @@ from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainin get_info_by_version, \ get_trainingjob_info_by_name, get_latest_version_trainingjob_name, get_all_versions_info_by_name, \ update_model_download_url, add_update_trainingjob, \ - get_field_of_given_version,get_all_jobs_latest_status_version + get_field_of_given_version,get_all_jobs_latest_status_version, get_info_of_latest_version APP = Flask(__name__) TRAININGMGR_CONFIG_OBJ = None @@ -937,6 +937,135 @@ def trainingjob_operations(trainingjob_name): status= response_code, mimetype=MIMETYPE_JSON) +@APP.route('/trainingjobs/retraining', methods=['POST']) +@cross_origin() +def retraining(): + """ + Function handling rest endpoint to retrain trainingjobs in request json. trainingjob's + overall_status should be failed or finished and its deletion_in_progress should be False + otherwise retraining of that trainingjob is counted in failure. + Args in function: none + Required Args in json: + trainingjobs_list: list + list containing dictionaries + dictionary contains + usecase_name: str + name of trainingjob + notification_url(optional): str + url for notification + feature_filter(optional): str + feature filter + Returns: + json: + success count: int + successful retraining count + failure count: int + failure retraining count + status: HTTP status code 200 + Exceptions: + all exception are provided with exception message and HTTP status code. + """ + LOGGER.debug('request comes for retraining, ' + json.dumps(request.json)) + try: + check_key_in_dictionary(["trainingjobs_list"], request.json) + except Exception as err: + raise APIException(status.HTTP_400_BAD_REQUEST, str(err)) from None + + trainingjobs_list = request.json['trainingjobs_list'] + print("trainingjobs list is :", trainingjobs_list) + if not isinstance(trainingjobs_list, list): + raise APIException(status.HTTP_400_BAD_REQUEST, "not given as list") + + for obj in trainingjobs_list: + try: + check_key_in_dictionary(["trainingjob_name"], obj) + except Exception as err: + raise APIException(status.HTTP_400_BAD_REQUEST, str(err)) from None + + not_possible_to_retrain = [] + possible_to_retrain = [] + + for obj in trainingjobs_list: + trainingjob_name = obj['trainingjob_name'] + results = None + try: + results = get_info_of_latest_version(trainingjob_name, PS_DB_OBJ) + except Exception as err: + not_possible_to_retrain.append(trainingjob_name) + LOGGER.debug(str(err) + "(trainingjob_name is " + trainingjob_name + ")") + continue + + if results: + + if results[0][19]: + not_possible_to_retrain.append(trainingjob_name) + LOGGER.debug("Failed to retrain because deletion in progress" + \ + "(trainingjob_name is " + trainingjob_name + ")") + continue + + if (get_one_word_status(json.loads(results[0][9])) + not in [States.FINISHED.name, States.FAILED.name]): + not_possible_to_retrain.append(trainingjob_name) + LOGGER.debug("Not finished or not failed status" + \ + "(trainingjob_name is " + trainingjob_name + ")") + continue + + enable_versioning = results[0][12] + pipeline_version = results[0][13] + description = results[0][1] + pipeline_name = results[0][3] + experiment_name = results[0][4] + feature_list = results[0][2] + arguments = json.loads(results[0][5])['arguments'] + query_filter = results[0][6] + datalake_source = get_one_key(json.loads(results[0][14])["datalake_source"]) + _measurement = results[0][17] + bucket = results[0][18] + + notification_url = "" + if "notification_url" in obj: + notification_url = obj['notification_url'] + + if "feature_filter" in obj: + query_filter = obj['feature_filter'] + + try: + add_update_trainingjob(description, pipeline_name, experiment_name, + feature_list, arguments, query_filter, False, + enable_versioning, pipeline_version, + datalake_source, trainingjob_name, PS_DB_OBJ, + notification_url, _measurement, bucket) + except Exception as err: + not_possible_to_retrain.append(trainingjob_name) + LOGGER.debug(str(err) + "(usecase_name is " + trainingjob_name + ")") + continue + + url = 'http://' + str(TRAININGMGR_CONFIG_OBJ.my_ip) + \ + ':' + str(TRAININGMGR_CONFIG_OBJ.my_port) + \ + '/trainingjobs/' +trainingjob_name + '/training' + response = requests.post(url) + + if response.status_code == status.HTTP_200_OK: + possible_to_retrain.append(trainingjob_name) + else: + LOGGER.debug("not 200 response" + "(trainingjob_name is " + trainingjob_name + ")") + not_possible_to_retrain.append(trainingjob_name) + + else: + LOGGER.debug("not present in postgres db" + "(trainingjob_name is " + trainingjob_name + ")") + not_possible_to_retrain.append(trainingjob_name) + + LOGGER.debug('success list: ' + str(possible_to_retrain)) + LOGGER.debug('failure list: ' + str(not_possible_to_retrain)) + + return APP.response_class(response=json.dumps( \ + { + "success count": len(possible_to_retrain), + "failure count": len(not_possible_to_retrain) + }), + status=status.HTTP_200_OK, + mimetype='application/json') + @APP.route('/trainingjobs/metadata/') def get_metadata(trainingjob_name): """