},
"trainingPipeline": {
"pipeline_name": "qoe_Pipeline",
- "pipeline_version": "2"
+ "pipeline_version": "2",
+ "retraining_pipeline_name":"retrain-qoe-pipeline",
+ "retraining_pipeline_version:"2"
}
'''
paths = {
"feature_group_name": ["dataPipeline", "feature_group_name"],
"query_filter" : ["dataPipeline", "query_filter"],
"arguments" : ["dataPipeline", "arguments"],
- "pipeline_name": ["trainingPipeline", "pipeline_name"],
- "pipeline_version": ["trainingPipeline", "pipeline_version"],
+ "training_pipeline_name": ["trainingPipeline", "training_pipeline_name"],
+ "training_pipeline_version": ["trainingPipeline", "training_pipeline_version"],
+ "retraining_pipeline_name": ["trainingPipeline", "retraining_pipeline_name"],
+ "retraining_pipeline_version":["trainingPipeline", "retraining_pipeline_version"]
}
return paths
if registered_model_dict["modelLocation"] != trainingjob.model_location:
return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} and trainingjob created does not have same modelLocation, Please first register at MME properly and then continue"}), status.HTTP_400_BAD_REQUEST
- if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0":
- if registered_model_dict["modelLocation"] == "":
- return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
- else:
- trainingjob = update_trainingPipeline(trainingjob)
- return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
- else:
- trainingjob = update_trainingPipeline(trainingjob)
- return create_training_job(trainingjob=trainingjob, registered_model_dict=registered_model_dict)
+ return create_training_job(trainingjob=trainingjob, registered_model_dict= registered_model_dict)
except ValidationError as error:
return jsonify(error.messages), status.HTTP_400_BAD_REQUEST
from trainingmgr.models import TrainingJob
from trainingmgr.models.trainingjob import ModelID
import json
-from marshmallow import pre_load, post_dump, validates, ValidationError
+from marshmallow import pre_load, post_dump, fields, validates, ValidationError
PATTERN = re.compile(r"\w+")
class Meta:
model = ModelID
load_instance = True
+
+ @post_dump
+ def replace_null_with_empty_string(self, data, **kwargs):
+ for key, value in data.items():
+ if value is None:
+ data[key]=""
+ return data
class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
class Meta:
modelId = ma.Nested(ModelSchema)
+ # consumer_rapp_id = fields.String(allow_none = True, dump_default="")
+
@pre_load
def processModelId(self, data, **kwargs):
modelname = data['modelId']['modelname']
data["training_config"] = json.loads(data["training_config"])
return data
+
+ @post_dump
+ def replace_null_with_empty_string(self, data, **kwargs):
+ for key, value in data.items():
+ if value is None:
+ data[key]=""
+ return data
\ No newline at end of file
response.headers['Location'] = "training-jobs/" + str(training_job_id)
return response, 201
-def update_trainingPipeline(trainingjob):
+
+def fetch_pipelinename_and_version(type, training_config):
try:
- training_config = trainingjob.training_config
- training_config = json.dumps(setField(training_config, "pipeline_name", "qoe_retraining_pipeline"))
- training_config = json.dumps(setField(training_config, "pipeline_version", "qoe_retraining_pipeline"))
- trainingjob.training_config = training_config
- return trainingjob
+ if type =="training":
+ return getField(training_config, "training_pipeline_name"), getField(training_config, "training_pipeline_version")
+ else :
+ return getField(training_config, "retraining_pipeline_name"), getField(training_config, "retraining_pipeline_version")
except Exception as err:
- LOGGER.error(f"error in updating the trainingPipeline due to {str(err)}")
- raise TMException("failed to update the trainingPipeline")
+ raise TMException(f"cant fetch training or retraining pipeline name or version from trainingconfig {training_config}")
\ No newline at end of file
from trainingmgr.common.trainingConfig_parser import validateTrainingConfig, getField
from trainingmgr.handler.async_handler import start_async_handler
from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
-from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, get_training_job, update_artifact_version
+from trainingmgr.service.training_job_service import change_status_tj, change_update_field_value, fetch_pipelinename_and_version, get_training_job, update_artifact_version
APP = Flask(__name__)
TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig()
argument_dict[key] = str(val)
LOGGER.debug(argument_dict)
# Experiment name is harded to be Default
+
+ model_id = trainingjob.modelId
+ registered_model_list = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+
+ if registered_model_list is None:
+ return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+
+ registered_model_dict = registered_model_list[0]
+
+ pipeline_name =""
+ pipeline_version =""
+ if registered_model_dict["modelId"]["artifactVersion"] == "0.0.0":
+ if registered_model_dict["modelLocation"] == "":
+ pipeline_name, pipeline_version = fetch_pipelinename_and_version("training", trainingjob.training_config)
+ else:
+ pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config)
+ if pipeline_name == "" or pipeline_version =="":
+ return jsonify({"Error": "Provide retraining pipeline name and version"}), 500
+ else:
+ pipeline_name, pipeline_version = fetch_pipelinename_and_version("re-training", trainingjob.training_config)
+ if pipeline_name == "" or pipeline_version =="":
+ return jsonify({"Error": "Provide retraining pipeline name and version"}), 500
+
training_details = {
- "pipeline_name": getField(trainingjob.training_config, "pipeline_name"), "experiment_name": 'Default',
- "arguments": argument_dict, "pipeline_version": getField(trainingjob.training_config, "pipeline_version")
+ "pipeline_name": pipeline_name, "experiment_name": 'Default',
+ "arguments": argument_dict, "pipeline_version": pipeline_version
}
LOGGER.debug("training detail for kf adapter is: "+ str(training_details))
response = training_start(TRAININGMGR_CONFIG_OBJ, training_details, trainingjob_id)
status=status.HTTP_200_OK,
mimetype=MIMETYPE_JSON)
-
# Will be migrated to pipline Mgr in next iteration
@APP.route('/trainingjob/pipelineNotification', methods=['POST'])
def pipeline_notification():