changes to add retraining pipeline in json and fixing the null dump to 58/13958/2
authorrajdeep11 <rajdeep.sin@samsung.com>
Mon, 23 Dec 2024 10:37:30 +0000 (16:07 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Mon, 23 Dec 2024 10:51:07 +0000 (16:21 +0530)
string

Change-Id: I9f348f1363286a87b03d38bd4fa399e5198937ef
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
trainingmgr/common/trainingConfig_parser.py
trainingmgr/controller/trainingjob_controller.py
trainingmgr/schemas/trainingjob_schema.py
trainingmgr/service/training_job_service.py
trainingmgr/trainingmgr_main.py

index 4b38e96..903419e 100644 (file)
@@ -48,7 +48,9 @@ def __getLeafPaths():
             },
             "trainingPipeline": {
                     "pipeline_name": "qoe_Pipeline",
-                    "pipeline_version": "2"
+                    "pipeline_version": "2", 
+                    "retraining_pipeline_name":"retrain-qoe-pipeline",
+                    "retraining_pipeline_version:"2"
         }
     '''
     paths = {
@@ -56,8 +58,10 @@ def __getLeafPaths():
         "feature_group_name": ["dataPipeline", "feature_group_name"],
         "query_filter" : ["dataPipeline", "query_filter"],
         "arguments" : ["dataPipeline", "arguments"],
-        "pipeline_name": ["trainingPipeline", "pipeline_name"],
-        "pipeline_version": ["trainingPipeline", "pipeline_version"],
+        "training_pipeline_name": ["trainingPipeline", "training_pipeline_name"],
+        "training_pipeline_version": ["trainingPipeline", "training_pipeline_version"],
+        "retraining_pipeline_name": ["trainingPipeline", "retraining_pipeline_name"],
+        "retraining_pipeline_version":["trainingPipeline", "retraining_pipeline_version"]
     }
     return paths
 
index 99cece0..0761768 100644 (file)
@@ -85,15 +85,7 @@ def create_trainingjob():
         if registered_model_dict["modelLocation"] != trainingjob.model_location:
             return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} and trainingjob created does not have same modelLocation, Please first register at MME properly and then continue"}), status.HTTP_400_BAD_REQUEST
         
-        if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0":
-            if registered_model_dict["modelLocation"] == "":
-                return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
-            else:
-                trainingjob = update_trainingPipeline(trainingjob)
-                return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
-        else:
-            trainingjob = update_trainingPipeline(trainingjob)
-            return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
+        return create_training_job(trainingjob=trainingjob, registered_model_dict= registered_model_dict)
         
     except ValidationError as error:
         return jsonify(error.messages), status.HTTP_400_BAD_REQUEST
index 2e6034b..f24efb0 100644 (file)
@@ -21,7 +21,7 @@ from trainingmgr.schemas import ma
 from trainingmgr.models import TrainingJob
 from trainingmgr.models.trainingjob import ModelID
 import json
-from marshmallow import pre_load, post_dump, validates, ValidationError
+from marshmallow import pre_load, post_dump, fields, validates, ValidationError
 
 PATTERN = re.compile(r"\w+")
 
@@ -29,6 +29,13 @@ class ModelSchema(ma.SQLAlchemyAutoSchema):
     class Meta:
         model = ModelID
         load_instance = True
+        
+    @post_dump
+    def replace_null_with_empty_string(self, data, **kwargs):
+        for key, value in data.items():
+            if value is None:
+                data[key]=""
+        return data
 
 class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
     class Meta:
@@ -38,6 +45,8 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
     
     modelId = ma.Nested(ModelSchema)
 
+    # consumer_rapp_id = fields.String(allow_none = True, dump_default="")
+
     @pre_load
     def processModelId(self, data, **kwargs):
         modelname = data['modelId']['modelname']
@@ -56,4 +65,11 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
             data["training_config"] = json.loads(data["training_config"])   
         
         return data
+    
+    @post_dump
+    def replace_null_with_empty_string(self, data, **kwargs):
+        for key, value in data.items():
+            if value is None:
+                data[key]=""
+        return data
         
\ No newline at end of file
index 1932277..91caf2f 100644 (file)
@@ -273,13 +273,12 @@ def training(trainingjob):
     response.headers['Location'] = "training-jobs/" + str(training_job_id)
     return response, 201
 
-def update_trainingPipeline(trainingjob):
+    
+def fetch_pipelinename_and_version(type, training_config):
     try:
-        training_config = trainingjob.training_config
-        training_config = json.dumps(setField(training_config, "pipeline_name", "qoe_retraining_pipeline"))
-        training_config = json.dumps(setField(training_config, "pipeline_version", "qoe_retraining_pipeline"))
-        trainingjob.training_config = training_config
-        return trainingjob
+        if type =="training":
+         return getField(training_config, "training_pipeline_name"), getField(training_config, "training_pipeline_version")
+        else :
+            return getField(training_config, "retraining_pipeline_name"), getField(training_config, "retraining_pipeline_version")
     except Exception as err:
-        LOGGER.error(f"error in updating the trainingPipeline due to {str(err)}")
-        raise TMException("failed to update the trainingPipeline")
+        raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}")
\ No newline at end of file
index 627d107..d9741fe 100644 (file)
@@ -46,7 +46,7 @@ from trainingmgr.controller.pipeline_controller import pipeline_controller
 from trainingmgr.common.trainingConfig_parser import validateTrainingConfig, getField
 from trainingmgr.handler.async_handler import start_async_handler
 from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
-from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, get_training_job, update_artifact_version
+from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, fetch_pipelinename_and_version, get_training_job, update_artifact_version
 
 APP = Flask(__name__)
 TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig()
@@ -170,9 +170,32 @@ def data_extraction_notification():
                 argument_dict[key] = str(val)
         LOGGER.debug(argument_dict)
         # Experiment name is harded to be Default
+
+        model_id = trainingjob.modelId
+        registered_model_list = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+
+        if registered_model_list is None:
+            return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+
+        registered_model_dict = registered_model_list[0]
+
+        pipeline_name =""
+        pipeline_version =""
+        if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0":
+            if registered_model_dict["modelLocation"] == "":
+                pipeline_name, pipeline_version = fetch_pipelinename_and_version("training", trainingjob.training_config)
+            else:
+                pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config)
+                if pipeline_name == "" or  pipeline_version =="":
+                    return jsonify({"Error": "Provide retraining pipeline name and version"}), 500
+        else:
+            pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config)
+            if pipeline_name == "" or  pipeline_version =="":
+                return jsonify({"Error": "Provide retraining pipeline name and version"}), 500
+
         training_details = {
-            "pipeline_name": getField(trainingjob.training_config, "pipeline_name"), "experiment_name": 'Default',
-            "arguments": argument_dict, "pipeline_version": getField(trainingjob.training_config, "pipeline_version")
+            "pipeline_name": pipeline_name, "experiment_name": 'Default',
+            "arguments": argument_dict, "pipeline_version": pipeline_version
         }
         LOGGER.debug("training detail for kf adapter is: "+ str(training_details))
         response = training_start(TRAININGMGR_CONFIG_OBJ, training_details, trainingjob_id)
@@ -230,7 +253,6 @@ def data_extraction_notification():
                                     status=status.HTTP_200_OK,
                                     mimetype=MIMETYPE_JSON)
 
-
 # Will be migrated to pipline Mgr in next iteration
 @APP.route('/trainingjob/pipelineNotification', methods=['POST'])
 def pipeline_notification():