From a8e36a74fa74a8d9b12d5111dccb714fb10dc9b2 Mon Sep 17 00:00:00 2001 From: ashishj1729 Date: Tue, 12 Nov 2024 13:06:24 +0530 Subject: [PATCH] Addition of TrainingConfig Parameter in trainingJob Model Change-Id: I48e3c9d870682a4c31e4430106bbfbef93279e7f Signed-off-by: ashishj1729 --- tests/test_tm_apis.py | 38 ++++--- trainingmgr/common/trainingConfig_parser.py | 92 ++++++++++++++++ trainingmgr/common/trainingmgr_util.py | 3 +- trainingmgr/db/trainingjob_db.py | 21 ++-- trainingmgr/db/trainingmgr_ps_db.py | 20 ++-- trainingmgr/models/trainingjob.py | 25 ++--- trainingmgr/schemas/trainingjob_schema.py | 4 +- trainingmgr/trainingmgr_main.py | 159 ++++++++++++++++------------ 8 files changed, 239 insertions(+), 123 deletions(-) create mode 100644 trainingmgr/common/trainingConfig_parser.py diff --git a/tests/test_tm_apis.py b/tests/test_tm_apis.py index 989bbce..707001d 100644 --- a/tests/test_tm_apis.py +++ b/tests/test_tm_apis.py @@ -36,7 +36,7 @@ from trainingmgr.common.tmgr_logger import TMLogger from trainingmgr.common.trainingmgr_config import TrainingMgrConfig from trainingmgr.common.exceptions_utls import DBException, TMException from trainingmgr.models import TrainingJob - +from trainingmgr.common.trainingConfig_parser import getField trainingmgr_main.LOGGER = pytest.logger trainingmgr_main.LOCK = Lock() trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {} @@ -281,26 +281,33 @@ class Test_get_trainingjob_by_name_version: """Create a mock TrainingJob object.""" creation_time = datetime.datetime.now() updation_time = datetime.datetime.now() + training_config = { + "is_mme" : True, + "description": "Test description", + "dataPipeline": { + "feature_group_name": "test_feature_group", + "query_filter": "", + "arguments": {"epochs" : 1, "trainingjob_name": "test_job"} + }, + "trainingPipeline": { + "pipeline_name": "test_pipeline", + "pipeline_version": "2", + "enable_versioning": True + } + } + + return TrainingJob( trainingjob_name="test_job", - description="Test description", - feature_group_name="test_feature_group", - pipeline_name="test_pipeline", - experiment_name="test_experiment", - arguments=json.dumps({"param1": "value1"}), - query_filter="test_filter", + training_config = json.dumps(training_config), creation_time=creation_time, run_id="test_run_id", steps_state=json.dumps({"step1": "completed"}), updation_time=updation_time, version=1, - enable_versioning=True, - pipeline_version="v1", - datalake_source=json.dumps({"datalake_source": {"source1": "path1"}}), model_url="http://test.model.url", notification_url="http://test.notification.url", deletion_in_progress=False, - is_mme=True, model_name="test_model", model_info="test_model_info" ) @@ -329,11 +336,10 @@ class Test_get_trainingjob_by_name_version: assert 'trainingjob' in data job_data = data['trainingjob'] assert job_data['trainingjob_name'] == "test_job" - assert job_data['description'] == "Test description" - assert job_data['feature_list'] == "test_feature_group" - assert job_data['pipeline_name'] == "test_pipeline" - assert job_data['experiment_name'] == "test_experiment" - assert job_data['is_mme'] is True + assert job_data['training_config']['description'] == "Test description" + assert job_data['training_config']['dataPipeline']['feature_group_name'] == "test_feature_group" + assert job_data['training_config']['trainingPipeline']['pipeline_name'] == "test_pipeline" + assert job_data['training_config']['is_mme'] is True assert job_data['model_name'] == "test_model" assert job_data['model_info'] == "test_model_info" assert job_data['accuracy'] == mock_metrics diff --git a/trainingmgr/common/trainingConfig_parser.py b/trainingmgr/common/trainingConfig_parser.py new file mode 100644 index 0000000..9467c67 --- /dev/null +++ b/trainingmgr/common/trainingConfig_parser.py @@ -0,0 +1,92 @@ +# ================================================================================== +# +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ================================================================================== + +import json + +def parse_dict_by_fields(data, fields): + ''' + It parses the provided data (dicts) by the fields provided + Example: + data = {"a": 1, "b": {"c" : 4, "d" : {-1}}} + fields = ["a", "b", "c"] = 4 + fields = ["a", "b", "d"] = -1 + ''' + try: + cur = data + for field in fields: + cur = cur[field] + return cur + except Exception as e: + raise Exception("Can't parse Fields: {} in Data : {}| recieved-error : {}".format(fields, data, e)) + + +def __getLeafPaths(): + ''' + It returns paths possible to retrieve data + Based on TrainingConfig Schema: + { + "is_mme" : false, + "description": "This is something3", + "dataPipeline": { + "feature_group_name": "base2", + "query_filter": "", + "arguments": "{'epochs': '1'}" + }, + "trainingPipeline": { + "pipeline_name": "qoe_Pipeline", + "pipeline_version": "2", + "enable_versioning": true + } + ''' + paths = { + "is_mme": ["is_mme"], + "description": ["description"], + "feature_group_name": ["dataPipeline", "feature_group_name"], + "query_filter" : ["dataPipeline", "query_filter"], + "arguments" : ["dataPipeline", "arguments"], + "pipeline_name": ["trainingPipeline", "pipeline_name"], + "pipeline_version": ["trainingPipeline", "pipeline_version"], + "enable_versioning": ["trainingPipeline", "enable_versioning"] + } + return paths + +def prepocessTrainingConfig(trainingConfig): + if (isinstance(trainingConfig, str)): + return json.loads(trainingConfig) + return trainingConfig + +def validateTrainingConfig(trainingConfig): + ''' + One way to Validate TrainingConfig is to see if each Leafpath exists or not + Any other key:value pair than TrainingConfig Schema is not treated as invalid. + ''' + trainingConfig = prepocessTrainingConfig(trainingConfig) + allPaths = __getLeafPaths() + try: + for fieldPath in allPaths.values(): + parse_dict_by_fields(trainingConfig, fieldPath) + return True + except Exception as e: + print("Unable to Validate Error: ", e) + return False + +def getField(trainingConfig, fieldname): + trainingConfig = prepocessTrainingConfig(trainingConfig) + fieldPath = __getLeafPaths()[fieldname] + return parse_dict_by_fields(trainingConfig, fieldPath) + \ No newline at end of file diff --git a/trainingmgr/common/trainingmgr_util.py b/trainingmgr/common/trainingmgr_util.py index de3f2d3..f63854f 100644 --- a/trainingmgr/common/trainingmgr_util.py +++ b/trainingmgr/common/trainingmgr_util.py @@ -444,4 +444,5 @@ class PipelineInfo: "display_name": self.display_name, "description": self.description, "created_at": self.created_at - } \ No newline at end of file + } + diff --git a/trainingmgr/db/trainingjob_db.py b/trainingmgr/db/trainingjob_db.py index 34b7915..4b70adb 100644 --- a/trainingmgr/db/trainingjob_db.py +++ b/trainingmgr/db/trainingjob_db.py @@ -24,13 +24,13 @@ from trainingmgr.models import db, TrainingJob, FeatureGroup from trainingmgr.constants.steps import Steps from trainingmgr.constants.states import States from sqlalchemy.sql import func +from trainingmgr.common.trainingConfig_parser import getField DB_QUERY_EXEC_ERROR = "Failed to execute query in " PATTERN = re.compile(r"\w+") - def get_all_versions_info_by_name(trainingjob_name): """ This function returns information of given trainingjob_name for all version. @@ -45,11 +45,13 @@ def add_update_trainingjob(trainingjob, adding): try: # arguments_string = json.dumps({"arguments": trainingjob.arguments}) datalake_source_dic = {} - datalake_source_dic[trainingjob.datalake_source] = {} - trainingjob.datalake_source = json.dumps({"datalake_source": datalake_source_dic}) + # Needs to be populated from feature_group + # datalake_source_dic[trainingjob.datalake_source] = {} + # trainingjob.datalake_source = json.dumps({"datalake_source": datalake_source_dic}) trainingjob.creation_time = datetime.datetime.utcnow() trainingjob.updation_time = trainingjob.creation_time run_id = "No data available" + trainingjob.run_id = run_id steps_state = { Steps.DATA_EXTRACTION.name: States.NOT_STARTED.name, Steps.DATA_EXTRACTION_AND_TRAINING.name: States.NOT_STARTED.name, @@ -59,21 +61,20 @@ def add_update_trainingjob(trainingjob, adding): } trainingjob.steps_state=json.dumps(steps_state) trainingjob.model_url = "No data available." + trainingjob.notification_url = "No data available." trainingjob.deletion_in_progress = False trainingjob.version = 1 + if not adding: - trainingjob_max_version = db.session.query(TrainingJob).filter(TrainingJob.trainingjob_name == trainingjob.trainingjob_name).order_by(TrainingJob.version.desc()).first() - - if trainingjob_max_version.enable_versioning: + if getField(trainingjob_max_version.training_config, "enable_versioning"): trainingjob.version = trainingjob_max_version.version + 1 db.session.add(trainingjob) else: - - for key, value in trainingjob.items(): - if(key == 'id'): + for attr in vars(trainingjob): + if(attr == 'id' or attr == '_sa_instance_state'): continue - setattr(trainingjob_max_version, key, value) + setattr(trainingjob_max_version, attr, getattr(trainingjob, attr)) else: db.session.add(trainingjob) diff --git a/trainingmgr/db/trainingmgr_ps_db.py b/trainingmgr/db/trainingmgr_ps_db.py index a78a096..bdd79f2 100644 --- a/trainingmgr/db/trainingmgr_ps_db.py +++ b/trainingmgr/db/trainingmgr_ps_db.py @@ -81,24 +81,24 @@ class PSDB(): try: cur2.execute("create table if not exists trainingjob_info(" + \ "trainingjob_name varchar(128) NOT NULL," + \ - "description varchar(2000) NOT NULL," + \ - "feature_list varchar(2000) NOT NULL," + \ - "pipeline_name varchar(128) NOT NULL," + \ - "experiment_name varchar(128) NOT NULL," + \ - "arguments varchar(2000) NOT NULL," + \ - "query_filter varchar(2000) NOT NULL," + \ + "training_config varchar(5000) NOT NULL," + \ + # "feature_list varchar(2000) NOT NULL," + \ + # "pipeline_name varchar(128) NOT NULL," + \ + # "experiment_name varchar(128) NOT NULL," + \ + # "arguments varchar(2000) NOT NULL," + \ + # "query_filter varchar(2000) NOT NULL," + \ "creation_time TIMESTAMP NOT NULL," + \ "run_id varchar(1000) NOT NULL," + \ "steps_state varchar(1000) NOT NULL," + \ "updation_time TIMESTAMP NOT NULL," + \ "version INTEGER NOT NULL," + \ - "enable_versioning BOOLEAN NOT NULL," + \ - "pipeline_version varchar(128) NOT NULL," + \ - "datalake_source varchar(2000) NOT NULL," + \ + # "enable_versioning BOOLEAN NOT NULL," + \ + # "pipeline_version varchar(128) NOT NULL," + \ + # "datalake_source varchar(2000) NOT NULL," + \ "model_url varchar(100) NOT NULL," + \ "notification_url varchar(1000) NOT NULL," + \ "deletion_in_progress BOOLEAN NOT NULL," + \ - "is_mme BOOLEAN NOT NULL," + \ + # "is_mme BOOLEAN NOT NULL," + \ "model_name varchar(128) NOT NULL," + \ "model_info varchar(1000) NOT NULL," \ "PRIMARY KEY (trainingjob_name,version)" + \ diff --git a/trainingmgr/models/trainingjob.py b/trainingmgr/models/trainingjob.py index f64f8a1..64a0535 100644 --- a/trainingmgr/models/trainingjob.py +++ b/trainingmgr/models/trainingjob.py @@ -23,24 +23,15 @@ class TrainingJob(db.Model): __tablename__ = "trainingjob_info_table" id = db.Column(db.Integer, primary_key=True) trainingjob_name= db.Column(db.String(128), nullable=False) - description = db.Column(db.String(2000), nullable=False) - feature_group_name = db.Column(db.String(128), nullable=False) - pipeline_name= db.Column(db.String(128), nullable=False) - experiment_name = db.Column(db.String(128), nullable=False) - arguments = db.Column(db.String(2000), nullable=False) - query_filter = db.Column(db.String(2000), nullable=False) + run_id = db.Column(db.String(1000), nullable=True) + steps_state = db.Column(db.String(1000), nullable=True) creation_time = db.Column(db.DateTime(timezone=False), server_default=func.now(),nullable=False) - run_id = db.Column(db.String(1000), nullable=False) - steps_state = db.Column(db.String(1000), nullable=False) - updation_time = db.Column(db.DateTime(timezone=False),onupdate=func.now() ,nullable=False) - version = db.Column(db.Integer, nullable=False) - enable_versioning = db.Column(db.Boolean, nullable=False) - pipeline_version = db.Column(db.String(128), nullable=False) - datalake_source = db.Column(db.String(2000), nullable=False) - model_url = db.Column(db.String(1000), nullable=False) - notification_url = db.Column(db.String(1000), nullable=False) - deletion_in_progress = db.Column(db.Boolean, nullable=False) - is_mme = db.Column(db.Boolean, nullable=True) + updation_time = db.Column(db.DateTime(timezone=False),onupdate=func.now() ,nullable=True) + version = db.Column(db.Integer, nullable=True) + deletion_in_progress = db.Column(db.Boolean, nullable=True) + training_config = db.Column(db.String(5000), nullable=False) + model_url = db.Column(db.String(1000), nullable=True) + notification_url = db.Column(db.String(1000), nullable=True) model_name = db.Column(db.String(128), nullable=True) model_info = db.Column(db.String(1000), nullable=True) diff --git a/trainingmgr/schemas/trainingjob_schema.py b/trainingmgr/schemas/trainingjob_schema.py index 38b6a6d..b93fd15 100644 --- a/trainingmgr/schemas/trainingjob_schema.py +++ b/trainingmgr/schemas/trainingjob_schema.py @@ -23,4 +23,6 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema): class Meta: model = TrainingJob include_relationships = True - load_instance = True \ No newline at end of file + load_instance = True + + \ No newline at end of file diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index cc4d458..bd8450e 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -49,11 +49,7 @@ from trainingmgr.constants.steps import Steps from trainingmgr.constants.states import States from trainingmgr.db.trainingmgr_ps_db import PSDB from trainingmgr.common.exceptions_utls import DBException -from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \ - change_in_progress_to_failed_by_latest_version, \ - get_all_versions_info_by_name, \ - update_model_download_url, \ - get_field_of_given_version +from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs from trainingmgr.models import db, TrainingJob, FeatureGroup from trainingmgr.schemas import ma, TrainingJobSchema , FeatureGroupSchema from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, get_feature_groups_db, \ @@ -61,7 +57,9 @@ from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, from trainingmgr.db.trainingjob_db import add_update_trainingjob, get_trainingjob_info_by_name, \ get_all_jobs_latest_status_version, change_steps_state_of_latest_version, get_info_by_version, \ get_steps_state_db, change_field_of_latest_version, get_latest_version_trainingjob_name, get_info_of_latest_version, \ - change_field_value_by_version, delete_trainingjob_version + change_field_value_by_version, delete_trainingjob_version, change_in_progress_to_failed_by_latest_version, \ + update_model_download_url, get_all_versions_info_by_name +from trainingmgr.common.trainingConfig_parser import validateTrainingConfig, getField APP = Flask(__name__) @@ -82,6 +80,7 @@ NOT_LIST="not given as list" trainingjob_schema = TrainingJobSchema() trainingjobs_schema = TrainingJobSchema(many=True) + @APP.errorhandler(APIException) def error(err): """ @@ -92,7 +91,7 @@ def error(err): status=err.code, mimetype=MIMETYPE_JSON) - +# Training-Config Handled @APP.route('/trainingjobs//', methods=['GET']) def get_trainingjob_by_name_version(trainingjob_name, version): """ @@ -174,23 +173,24 @@ def get_trainingjob_by_name_version(trainingjob_name, version): if trainingjob: dict_data = { "trainingjob_name": trainingjob.trainingjob_name, - "description": trainingjob.description, - "feature_list": trainingjob.feature_group_name, - "pipeline_name": trainingjob.pipeline_name, - "experiment_name": trainingjob.experiment_name, - "arguments": trainingjob.arguments, - "query_filter": trainingjob.query_filter, + "training_config": json.loads(trainingjob.training_config), + # "description": trainingjob.description, + # "feature_list": trainingjob.feature_group_name, + # "pipeline_name": trainingjob.pipeline_name, + # "experiment_name": trainingjob.experiment_name, + # "arguments": trainingjob.arguments, + # "query_filter": trainingjob.query_filter, "creation_time": str(trainingjob.creation_time), "run_id": trainingjob.run_id, "steps_state": json.loads(trainingjob.steps_state), "updation_time": str(trainingjob.updation_time), "version": trainingjob.version, - "enable_versioning": trainingjob.enable_versioning, - "pipeline_version": trainingjob.pipeline_version, - "datalake_source": get_one_key(json.loads(trainingjob.datalake_source)['datalake_source']), + # "enable_versioning": trainingjob.enable_versioning, + # "pipeline_version": trainingjob.pipeline_version, + # "datalake_source": get_one_key(json.loads(trainingjob.datalake_source)['datalake_source']), "model_url": trainingjob.model_url, "notification_url": trainingjob.notification_url, - "is_mme": trainingjob.is_mme, + # "is_mme": trainingjob.is_mme, "model_name": trainingjob.model_name, "model_info": trainingjob.model_info, "accuracy": data @@ -210,7 +210,8 @@ def get_trainingjob_by_name_version(trainingjob_name, version): status=response_code, mimetype=MIMETYPE_JSON) -@APP.route('/trainingjobs///steps_state', methods=['GET']) # Handled in GUI +# Training-Config Handled (No Change) +@APP.route('/trainingjobs///steps_state', methods=['GET']) def get_steps_state(trainingjob_name, version): """ Function handling rest end points to get steps_state information for @@ -278,6 +279,7 @@ def get_steps_state(trainingjob_name, version): status=response_code, mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) @APP.route('/model///Model.zip', methods=['GET']) def get_model(trainingjob_name, version): """ @@ -306,8 +308,8 @@ def get_model(trainingjob_name, version): except Exception: return {"Exception": "error while downloading model"}, status.HTTP_500_INTERNAL_SERVER_ERROR - -@APP.route('/trainingjobs//training', methods=['POST']) # Handled in GUI +# Training-Config Handled +@APP.route('/trainingjobs//training', methods=['POST']) def training(trainingjob_name): """ Rest end point to start training job. @@ -346,8 +348,9 @@ def training(trainingjob_name): "(trainingjob: " + trainingjob_name + ")") from None else: - trainingjob = get_trainingjob_info_by_name(trainingjob_name) - featuregroup= get_feature_group_by_name_db(trainingjob.feature_group_name) + trainingjob = get_trainingjob_info_by_name(trainingjob_name) + + featuregroup= get_feature_group_by_name_db(getField(trainingjob.training_config, "feature_group_name")) feature_list_string = featuregroup.feature_list influxdb_info_dic={} influxdb_info_dic["host"]=featuregroup.host @@ -357,8 +360,8 @@ def training(trainingjob_name): influxdb_info_dic["db_org"] = featuregroup.db_org influxdb_info_dic["source_name"]= featuregroup.source_name _measurement = featuregroup.measurement - query_filter = trainingjob.query_filter - datalake_source = json.loads(trainingjob.datalake_source)['datalake_source'] + query_filter = getField(trainingjob.training_config, "query_filter") + datalake_source = {featuregroup.datalake_source: {}} # Datalake source should be taken from FeatureGroup (not TrainingJob) LOGGER.debug('Starting Data Extraction...') de_response = data_extraction_start(TRAININGMGR_CONFIG_OBJ, trainingjob_name, feature_list_string, query_filter, datalake_source, @@ -386,11 +389,13 @@ def training(trainingjob_name): raise TMException("Data extraction doesn't send json type response" + \ "(trainingjob name is " + trainingjob_name + ")") from None except Exception as err: + # print(traceback.format_exc()) response_data = {"Exception": str(err)} LOGGER.debug("Error is training, job name:" + trainingjob_name + str(err)) return APP.response_class(response=json.dumps(response_data),status=response_code, mimetype=MIMETYPE_JSON) +# Training-Config Handled @APP.route('/trainingjob/dataExtractionNotification', methods=['POST']) def data_extraction_notification(): """ @@ -426,13 +431,17 @@ def data_extraction_notification(): trainingjob_name = request.json["trainingjob_name"] trainingjob = get_trainingjob_info_by_name(trainingjob_name) - arguments = trainingjob.arguments + arguments = getField(trainingjob.training_config, "arguments") arguments["version"] = trainingjob.version + # Arguments values must be of type string + for key, val in arguments.items(): + if not isinstance(val, str): + arguments[key] = str(val) LOGGER.debug(arguments) - + # Experiment name is harded to be Default dict_data = { - "pipeline_name": trainingjob.pipeline_name, "experiment_name": trainingjob.experiment_name, - "arguments": arguments, "pipeline_version": trainingjob.pipeline_version + "pipeline_name": getField(trainingjob.training_config, "pipeline_name"), "experiment_name": 'Default', + "arguments": arguments, "pipeline_version": getField(trainingjob.training_config, "pipeline_version") } response = training_start(TRAININGMGR_CONFIG_OBJ, dict_data, trainingjob_name) @@ -483,6 +492,7 @@ def data_extraction_notification(): status=status.HTTP_200_OK, mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) @APP.route('/pipelines/', methods=['GET']) def get_pipeline_info_by_name(pipe_name): """ @@ -525,7 +535,8 @@ def get_pipeline_info_by_name(pipe_name): return APP.response_class(response=json.dumps(api_response), status=response_code, mimetype=MIMETYPE_JSON) - + +# Training-Config Handled .. @APP.route('/trainingjob/pipelineNotification', methods=['POST']) def pipeline_notification(): """ @@ -578,7 +589,7 @@ def pipeline_notification(): model_url = "http://" + str(TRAININGMGR_CONFIG_OBJ.my_ip) + ":" + \ str(TRAININGMGR_CONFIG_OBJ.my_port) + "/model/" + \ trainingjob_name + "/" + str(version) + "/Model.zip" - + update_model_download_url(trainingjob_name, version, model_url, PS_DB_OBJ) @@ -588,7 +599,7 @@ def pipeline_notification(): # upload to the mme trainingjob_info=get_trainingjob_info_by_name(trainingjob_name) - is_mme= trainingjob_info.is_mme + is_mme = getField(trainingjob_info.training_config, "is_mme") if is_mme: model_name=trainingjob_info.model_name #model_name file=MM_SDK.get_model_zip(trainingjob_name, str(version)) @@ -624,7 +635,7 @@ def pipeline_notification(): "Pipeline notification success.", LOGGER, True, trainingjob_name, MM_SDK) - +# Training-Config Handled (No Change) @APP.route('/trainingjobs/latest', methods=['GET']) def trainingjobs_operations(): """ @@ -673,6 +684,7 @@ def trainingjobs_operations(): status=response_code, mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) .. @APP.route("/pipelines//upload", methods=['POST']) def upload_pipeline(pipe_name): """ @@ -771,8 +783,7 @@ def upload_pipeline(pipe_name): mimetype=MIMETYPE_JSON) - - +# Training-Config Handled (No Change) @APP.route("/pipelines//versions", methods=['GET']) def get_versions_for_pipeline(pipeline_name): """ @@ -826,6 +837,7 @@ def get_versions_for_pipeline(pipeline_name): status=response_code, mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) @APP.route('/pipelines', methods=['GET']) def get_pipelines(): """ @@ -859,6 +871,7 @@ def get_pipelines(): api_response = {"Exception": str(err)} return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) @APP.route('/experiments', methods=['GET']) def get_all_experiment_names(): """ @@ -907,11 +920,11 @@ def get_all_experiment_names(): status=reponse_code, mimetype=MIMETYPE_JSON) - -@APP.route('/trainingjobs/', methods=['POST', 'PUT']) # Handled in GUI +# Training-Config handled +@APP.route('/trainingjobs/', methods=['POST', 'PUT']) def trainingjob_operations(trainingjob_name): """ - Rest endpoind to create or update trainingjob + Rest endpoint to create or update trainingjob Precondtion for update : trainingjob's overall_status should be failed or finished and deletion processs should not be in progress @@ -922,34 +935,32 @@ def trainingjob_operations(trainingjob_name): Args in json: if post/put request is called json with below fields are given: - description: str - description - featuregroup_name: str - feature group name - pipeline_name: str - name of pipeline - experiment_name: str - name of experiment - arguments: dict - key-value pairs related to hyper parameters and - "trainingjob": key-value pair - query_filter: str - string indication sql where clause for filtering out features - enable_versioning: bool - flag for trainingjob versioning - pipeline_version: str - pipeline version - datalake_source: str - string indicating datalake source - _measurement: str - _measurement for influx db datalake - bucket: str - bucket name for influx db datalake - is_mme: boolean - whether mme is enabled - model_name: str - name of the model - + modelName: str + Name of model + trainingConfig: dict + Training-Configurations, parameter as follows + is_mme: boolean + whether mme is enabled + description: str + description + dataPipeline: dict + Configurations related to dataPipeline, parameter as follows + feature_group_name: str + feature group name + query_filter: str + string indication sql where clause for filtering out features + arguments: dict + key-value pairs related to hyper parameters and + "trainingjob": key-value pair + trainingPipeline: dict + Configurations related to trainingPipeline, parameter as follows + pipeline_name: str + name of pipeline + pipeline_version: str + pipeline version + enable_versioning: bool + flag for trainingjob versioning + Returns: 1. For post request json: @@ -972,6 +983,10 @@ def trainingjob_operations(trainingjob_name): if not check_trainingjob_name_or_featuregroup_name(trainingjob_name): return {"Exception":"The trainingjob_name is not correct"}, status.HTTP_400_BAD_REQUEST + trainingConfig = request.json["training_config"] + if(not validateTrainingConfig(trainingConfig)): + return {"Exception":"The TrainingConfig is not correct"}, status.HTTP_400_BAD_REQUEST + LOGGER.debug("Training job create/update request(trainingjob name %s) ", trainingjob_name ) try: json_data = request.json @@ -982,9 +997,11 @@ def trainingjob_operations(trainingjob_name): response_code = status.HTTP_409_CONFLICT raise TMException("trainingjob name(" + trainingjob_name + ") is already present in database") else: - trainingjob = trainingjob_schema.load(request.get_json()) + processed_json_data = request.get_json() + processed_json_data['training_config'] = json.dumps(request.get_json()["training_config"]) + trainingjob = trainingjob_schema.load(processed_json_data) model_info="" - if trainingjob.is_mme: + if getField(trainingjob.training_config, "is_mme"): pipeline_dict =json.loads(TRAININGMGR_CONFIG_OBJ.pipeline) model_info=get_model_info(TRAININGMGR_CONFIG_OBJ, trainingjob.model_name) s=model_info["meta-info"]["feature-list"] @@ -1013,7 +1030,9 @@ def trainingjob_operations(trainingjob_name): response_code = status.HTTP_404_NOT_FOUND raise TMException("Trainingjob name(" + trainingjob_name + ") is not present in database") else: - trainingjob = trainingjob_schema.load(request.get_json()) + processed_json_data = request.get_json() + processed_json_data['training_config'] = json.dumps(request.get_json()["training_config"]) + trainingjob = trainingjob_schema.load(processed_json_data) trainingjob_info = get_trainingjob_info_by_name(trainingjob_name) if trainingjob_info: if trainingjob_info.deletion_in_progress: @@ -1034,6 +1053,7 @@ def trainingjob_operations(trainingjob_name): status= response_code, mimetype=MIMETYPE_JSON) +# Training-Config Handled (No Change) .. @APP.route('/trainingjobs/retraining', methods=['POST']) def retraining(): """ @@ -1137,6 +1157,7 @@ def retraining(): status=status.HTTP_200_OK, mimetype='application/json') +# Training-Config Handled (No Change) .. @APP.route('/trainingjobs', methods=['DELETE']) def delete_list_of_trainingjob_version(): """ @@ -1264,6 +1285,7 @@ def delete_list_of_trainingjob_version(): status=status.HTTP_200_OK, mimetype='application/json') +# Training-Config Handled (No Change) @APP.route('/trainingjobs/metadata/') def get_metadata(trainingjob_name): """ @@ -1646,6 +1668,7 @@ def delete_list_of_feature_group(): status=status.HTTP_200_OK, mimetype='application/json') +# Training-Config Handled (No Change) def async_feature_engineering_status(): """ This function takes trainingjobs from DATAEXTRACTION_JOBS_CACHE and checks data extraction status -- 2.16.6