From: rajdeep11 Date: Mon, 23 Dec 2024 10:37:30 +0000 (+0530) Subject: changes to add retraining pipeline in json and fixing the null dump to X-Git-Tag: 3.0.0~5 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=32eba8eb4a75563acf791fad7d65b4258a379bf5;p=aiml-fw%2Fawmf%2Ftm.git changes to add retraining pipeline in json and fixing the null dump to string Change-Id: I9f348f1363286a87b03d38bd4fa399e5198937ef Signed-off-by: rajdeep11 --- diff --git a/trainingmgr/common/trainingConfig_parser.py b/trainingmgr/common/trainingConfig_parser.py index 4b38e96..903419e 100644 --- a/trainingmgr/common/trainingConfig_parser.py +++ b/trainingmgr/common/trainingConfig_parser.py @@ -48,7 +48,9 @@ def __getLeafPaths(): }, "trainingPipeline": { "pipeline_name": "qoe_Pipeline", - "pipeline_version": "2" + "pipeline_version": "2", + "retraining_pipeline_name":"retrain-qoe-pipeline", + "retraining_pipeline_version:"2" } ''' paths = { @@ -56,8 +58,10 @@ def __getLeafPaths(): "feature_group_name": ["dataPipeline", "feature_group_name"], "query_filter" : ["dataPipeline", "query_filter"], "arguments" : ["dataPipeline", "arguments"], - "pipeline_name": ["trainingPipeline", "pipeline_name"], - "pipeline_version": ["trainingPipeline", "pipeline_version"], + "training_pipeline_name": ["trainingPipeline", "training_pipeline_name"], + "training_pipeline_version": ["trainingPipeline", "training_pipeline_version"], + "retraining_pipeline_name": ["trainingPipeline", "retraining_pipeline_name"], + "retraining_pipeline_version":["trainingPipeline", "retraining_pipeline_version"] } return paths diff --git a/trainingmgr/controller/trainingjob_controller.py b/trainingmgr/controller/trainingjob_controller.py index 99cece0..0761768 100644 --- a/trainingmgr/controller/trainingjob_controller.py +++ b/trainingmgr/controller/trainingjob_controller.py @@ -85,15 +85,7 @@ def create_trainingjob(): if registered_model_dict["modelLocation"] != trainingjob.model_location: return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} and trainingjob created does not have same modelLocation, Please first register at MME properly and then continue"}), status.HTTP_400_BAD_REQUEST - if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0": - if registered_model_dict["modelLocation"] == "": - return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict) - else: - trainingjob = update_trainingPipeline(trainingjob) - return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict) - else: - trainingjob = update_trainingPipeline(trainingjob) - return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict) + return create_training_job(trainingjob=trainingjob, registered_model_dict= registered_model_dict) except ValidationError as error: return jsonify(error.messages), status.HTTP_400_BAD_REQUEST diff --git a/trainingmgr/schemas/trainingjob_schema.py b/trainingmgr/schemas/trainingjob_schema.py index 2e6034b..f24efb0 100644 --- a/trainingmgr/schemas/trainingjob_schema.py +++ b/trainingmgr/schemas/trainingjob_schema.py @@ -21,7 +21,7 @@ from trainingmgr.schemas import ma from trainingmgr.models import TrainingJob from trainingmgr.models.trainingjob import ModelID import json -from marshmallow import pre_load, post_dump, validates, ValidationError +from marshmallow import pre_load, post_dump, fields, validates, ValidationError PATTERN = re.compile(r"\w+") @@ -29,6 +29,13 @@ class ModelSchema(ma.SQLAlchemyAutoSchema): class Meta: model = ModelID load_instance = True + + @post_dump + def replace_null_with_empty_string(self, data, **kwargs): + for key, value in data.items(): + if value is None: + data[key]="" + return data class TrainingJobSchema(ma.SQLAlchemyAutoSchema): class Meta: @@ -38,6 +45,8 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema): modelId = ma.Nested(ModelSchema) + # consumer_rapp_id = fields.String(allow_none = True, dump_default="") + @pre_load def processModelId(self, data, **kwargs): modelname = data['modelId']['modelname'] @@ -56,4 +65,11 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema): data["training_config"] = json.loads(data["training_config"]) return data + + @post_dump + def replace_null_with_empty_string(self, data, **kwargs): + for key, value in data.items(): + if value is None: + data[key]="" + return data \ No newline at end of file diff --git a/trainingmgr/service/training_job_service.py b/trainingmgr/service/training_job_service.py index 1932277..91caf2f 100644 --- a/trainingmgr/service/training_job_service.py +++ b/trainingmgr/service/training_job_service.py @@ -273,13 +273,12 @@ def training(trainingjob): response.headers['Location'] = "training-jobs/" + str(training_job_id) return response, 201 -def update_trainingPipeline(trainingjob): + +def fetch_pipelinename_and_version(type, training_config): try: - training_config = trainingjob.training_config - training_config = json.dumps(setField(training_config, "pipeline_name", "qoe_retraining_pipeline")) - training_config = json.dumps(setField(training_config, "pipeline_version", "qoe_retraining_pipeline")) - trainingjob.training_config = training_config - return trainingjob + if type =="training": + return getField(training_config, "training_pipeline_name"), getField(training_config, "training_pipeline_version") + else : + return getField(training_config, "retraining_pipeline_name"), getField(training_config, "retraining_pipeline_version") except Exception as err: - LOGGER.error(f"error in updating the trainingPipeline due to {str(err)}") - raise TMException("failed to update the trainingPipeline") + raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}") \ No newline at end of file diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 627d107..d9741fe 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -46,7 +46,7 @@ from trainingmgr.controller.pipeline_controller import pipeline_controller from trainingmgr.common.trainingConfig_parser import validateTrainingConfig, getField from trainingmgr.handler.async_handler import start_async_handler from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service -from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, get_training_job, update_artifact_version +from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, fetch_pipelinename_and_version, get_training_job, update_artifact_version APP = Flask(__name__) TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig() @@ -170,9 +170,32 @@ def data_extraction_notification(): argument_dict[key] = str(val) LOGGER.debug(argument_dict) # Experiment name is harded to be Default + + model_id = trainingjob.modelId + registered_model_list = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion) + + if registered_model_list is None: + return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST + + registered_model_dict = registered_model_list[0] + + pipeline_name ="" + pipeline_version ="" + if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0": + if registered_model_dict["modelLocation"] == "": + pipeline_name, pipeline_version = fetch_pipelinename_and_version("training", trainingjob.training_config) + else: + pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config) + if pipeline_name == "" or pipeline_version =="": + return jsonify({"Error": "Provide retraining pipeline name and version"}), 500 + else: + pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config) + if pipeline_name == "" or pipeline_version =="": + return jsonify({"Error": "Provide retraining pipeline name and version"}), 500 + training_details = { - "pipeline_name": getField(trainingjob.training_config, "pipeline_name"), "experiment_name": 'Default', - "arguments": argument_dict, "pipeline_version": getField(trainingjob.training_config, "pipeline_version") + "pipeline_name": pipeline_name, "experiment_name": 'Default', + "arguments": argument_dict, "pipeline_version": pipeline_version } LOGGER.debug("training detail for kf adapter is: "+ str(training_details)) response = training_start(TRAININGMGR_CONFIG_OBJ, training_details, trainingjob_id) @@ -230,7 +253,6 @@ def data_extraction_notification(): status=status.HTTP_200_OK, mimetype=MIMETYPE_JSON) - # Will be migrated to pipline Mgr in next iteration @APP.route('/trainingjob/pipelineNotification', methods=['POST']) def pipeline_notification():