From 3c79da393c11672ba7e611bc70083298e8c48ba2 Mon Sep 17 00:00:00 2001 From: ashishj1729 Date: Fri, 15 Nov 2024 18:40:10 +0530 Subject: [PATCH] Addition of Parameters in TrainingJob Model as per R1AP v6 Change-Id: I5ab822c9464baa798b7caeeb4770a126e5870f52 Signed-off-by: ashishj1729 --- trainingmgr/models/trainingjob.py | 11 +++++++-- trainingmgr/trainingmgr_main.py | 47 +++++++++++++++++---------------------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/trainingmgr/models/trainingjob.py b/trainingmgr/models/trainingjob.py index 4c6eb33..941e76e 100644 --- a/trainingmgr/models/trainingjob.py +++ b/trainingmgr/models/trainingjob.py @@ -49,9 +49,16 @@ class TrainingJob(db.Model): updation_time = Column(DateTime(timezone=False),onupdate=func.now() ,nullable=True) version = Column(Integer, nullable=True) deletion_in_progress = Column(Boolean, nullable=True) - training_config = Column(String(5000), nullable=False) + # As per R1AP v6 (Optional) + model_location = db.Column(db.String(1000), nullable=True) + training_dataset = db.Column(db.String(1000), nullable=True) + validation_dataset = db.Column(db.String(1000), nullable=True) + training_config = db.Column(db.String(5000), nullable=False) + notification_url = db.Column(db.String(1000), nullable=True) + consumer_rapp_id = db.Column(db.String(1000), nullable=True) + producer_rapp_id = db.Column(db.String(1000), nullable=True) + model_url = Column(String(1000), nullable=True) - notification_url = Column(String(1000), nullable=True) model_id = Column(Integer, nullable=False) #defineing relationships diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 8da790f..740c0c3 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -93,7 +93,7 @@ def error(err): status=err.code, mimetype=MIMETYPE_JSON) -# Training-Config Handled + @APP.route('/trainingjobs//', methods=['GET']) def get_trainingjob_by_name_version(trainingjob_name, version): """ @@ -173,21 +173,17 @@ def get_trainingjob_by_name_version(trainingjob_name, version): if trainingjob: dict_data = { "trainingjob_name": trainingjob.trainingjob_name, + "model_location": trainingjob.model_location, + "training_dataset": trainingjob.training_dataset, + "validation_dataset": trainingjob.validation_dataset, "training_config": json.loads(trainingjob.training_config), - # "description": trainingjob.description, - # "feature_list": trainingjob.feature_group_name, - # "pipeline_name": trainingjob.pipeline_name, - # "experiment_name": trainingjob.experiment_name, - # "arguments": trainingjob.arguments, - # "query_filter": trainingjob.query_filter, + "consumer_rapp_id": trainingjob.consumer_rapp_id, + "producer_rapp_id": trainingjob.producer_rapp_id, "creation_time": str(trainingjob.creation_time), "run_id": trainingjob.run_id, "steps_state": trainingjob.steps_state.states , "updation_time": str(trainingjob.updation_time), "version": trainingjob.version, - # "enable_versioning": trainingjob.enable_versioning, - # "pipeline_version": trainingjob.pipeline_version, - # "datalake_source": get_one_key(json.loads(trainingjob.datalake_source)['datalake_source']), "model_url": trainingjob.model_url, "notification_url": trainingjob.notification_url, "accuracy": data @@ -207,7 +203,7 @@ def get_trainingjob_by_name_version(trainingjob_name, version): status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) + @APP.route('/trainingjobs///steps_state', methods=['GET']) def get_steps_state(trainingjob_name, version): """ @@ -276,7 +272,7 @@ def get_steps_state(trainingjob_name, version): status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) + @APP.route('/model///Model.zip', methods=['GET']) def get_model(trainingjob_name, version): """ @@ -305,7 +301,7 @@ def get_model(trainingjob_name, version): except Exception: return {"Exception": "error while downloading model"}, status.HTTP_500_INTERNAL_SERVER_ERROR -# Training-Config Handled + @APP.route('/trainingjobs//training', methods=['POST']) def training(trainingjob_name): """ @@ -392,7 +388,7 @@ def training(trainingjob_name): return APP.response_class(response=json.dumps(response_data),status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled + @APP.route('/trainingjob/dataExtractionNotification', methods=['POST']) def data_extraction_notification(): """ @@ -490,7 +486,7 @@ def data_extraction_notification(): status=status.HTTP_200_OK, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) + @APP.route('/pipelines/', methods=['GET']) def get_pipeline_info_by_name(pipe_name): """ @@ -534,7 +530,7 @@ def get_pipeline_info_by_name(pipe_name): status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled .. + @APP.route('/trainingjob/pipelineNotification', methods=['POST']) def pipeline_notification(): """ @@ -623,7 +619,7 @@ def pipeline_notification(): "Pipeline notification success.", LOGGER, True, trainingjob_name, MM_SDK) -# Training-Config Handled (No Change) + @APP.route('/trainingjobs/latest', methods=['GET']) def trainingjobs_operations(): """ @@ -672,7 +668,7 @@ def trainingjobs_operations(): status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) .. + @APP.route("/pipelines//upload", methods=['POST']) def upload_pipeline(pipe_name): """ @@ -771,7 +767,6 @@ def upload_pipeline(pipe_name): mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) @APP.route("/pipelines//versions", methods=['GET']) def get_versions_for_pipeline(pipeline_name): """ @@ -825,7 +820,7 @@ def get_versions_for_pipeline(pipeline_name): status=response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) + @APP.route('/pipelines', methods=['GET']) def get_pipelines(): """ @@ -859,7 +854,7 @@ def get_pipelines(): api_response = {"Exception": str(err)} return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) + @APP.route('/experiments', methods=['GET']) def get_all_experiment_names(): """ @@ -908,7 +903,7 @@ def get_all_experiment_names(): status=reponse_code, mimetype=MIMETYPE_JSON) -# Training-Config handled + @APP.route('/trainingjobs/', methods=['POST', 'PUT']) def trainingjob_operations(trainingjob_name): """ @@ -1019,7 +1014,7 @@ def trainingjob_operations(trainingjob_name): status= response_code, mimetype=MIMETYPE_JSON) -# Training-Config Handled (No Change) .. + @APP.route('/trainingjobs/retraining', methods=['POST']) def retraining(): """ @@ -1123,7 +1118,7 @@ def retraining(): status=status.HTTP_200_OK, mimetype='application/json') -# Training-Config Handled (No Change) .. + @APP.route('/trainingjobs', methods=['DELETE']) def delete_list_of_trainingjob_version(): """ @@ -1251,7 +1246,7 @@ def delete_list_of_trainingjob_version(): status=status.HTTP_200_OK, mimetype='application/json') -# Training-Config Handled (No Change) + @APP.route('/trainingjobs/metadata/') def get_metadata(trainingjob_name): """ @@ -1634,7 +1629,7 @@ def delete_list_of_feature_group(): status=status.HTTP_200_OK, mimetype='application/json') -# Training-Config Handled (No Change) + def async_feature_engineering_status(): """ This function takes trainingjobs from DATAEXTRACTION_JOBS_CACHE and checks data extraction status -- 2.16.6