From e6f7dce34c78cc0bdbf8d2c53b4862314ad4305e Mon Sep 17 00:00:00 2001
From: ashishj1729
Date: Fri, 6 Dec 2024 03:06:52 +0530
Subject: [PATCH] FeatureGroup Resolution from InputDataType

Change-Id: I75255e289cbf4da5cf1f59103450ea0e56c28f16
Signed-off-by: ashishj1729
---
 trainingmgr/common/trainingConfig_parser.py      | 23 ++++++-
 trainingmgr/controller/trainingjob_controller.py | 11 +++-
 trainingmgr/db/featuregroup_db.py                | 12 ++++
 trainingmgr/pipeline/mme_mgr.py                  | 80 ++++++++++++++++++++++++
 trainingmgr/service/featuregroup_service.py      | 25 +++++++-
 trainingmgr/service/mme_service.py               | 28 +++++++++
 trainingmgr/service/training_job_service.py      | 17 ++++-
 7 files changed, 189 insertions(+), 7 deletions(-)
 create mode 100644 trainingmgr/pipeline/mme_mgr.py
 create mode 100644 trainingmgr/service/mme_service.py

diff --git a/trainingmgr/common/trainingConfig_parser.py b/trainingmgr/common/trainingConfig_parser.py
index b3155fc..844e836 100644
--- a/trainingmgr/common/trainingConfig_parser.py
+++ b/trainingmgr/common/trainingConfig_parser.py
@@ -87,4 +87,25 @@ def getField(trainingConfig, fieldname):
     trainingConfig = prepocessTrainingConfig(trainingConfig)
     fieldPath = __getLeafPaths()[fieldname]
     return parse_dict_by_fields(trainingConfig, fieldPath)
-    
\ No newline at end of file
+
+def setField(trainingConfig, fieldname, value):
+    """
+    Set the value of a specific field in the training configuration.
+    """
+    trainingConfig = prepocessTrainingConfig(trainingConfig)
+    fieldPath = __getLeafPaths()[fieldname]
+    if not fieldPath:
+        raise KeyError(f"Field '{fieldname}' is not defined in the schema.")
+
+    # Traverse the path and set the value
+    current_node = trainingConfig
+    for i, key in enumerate(fieldPath):
+        if i == len(fieldPath) - 1:
+            # Last key in the path, set the value
+            current_node[key] = value
+        else:
+            # Traverse or create intermediate dictionaries
+            if key not in current_node or not isinstance(current_node[key], dict):
+                current_node[key] = {}
+            current_node = current_node[key]
+    return trainingConfig
\ No newline at end of file
diff --git a/trainingmgr/controller/trainingjob_controller.py b/trainingmgr/controller/trainingjob_controller.py
index 8f2cf1e..9ea0972 100644
--- a/trainingmgr/controller/trainingjob_controller.py
+++ b/trainingmgr/controller/trainingjob_controller.py
@@ -27,7 +27,7 @@ from trainingmgr.service.training_job_service import delete_training_job, create
     get_steps_state
 from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
 from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
-
+from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
 training_job_controller = Blueprint('training_job_controller', __name__)
 LOGGER = TrainingMgrConfig().logger
 
@@ -73,13 +73,18 @@ def create_trainingjob():
         if(not validateTrainingConfig(trainingConfig)):
             return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST
 
-        #check if trainingjob is already present with name
+        # check if trainingjob is already present with name
         trainingjob_db = get_trainingjob_by_modelId(model_id)
 
         if trainingjob_db != None:
             return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT
 
-        create_training_job(trainingjob)
+        # Verify if the modelId is registered over mme or not
+
+        registered_model_dict = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+        if registered_model_dict is None:
+            return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+        create_training_job(trainingjob, registered_model_dict)
 
         return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
 
diff --git a/trainingmgr/db/featuregroup_db.py b/trainingmgr/db/featuregroup_db.py
index 095ac6b..079c888 100644
--- a/trainingmgr/db/featuregroup_db.py
+++ b/trainingmgr/db/featuregroup_db.py
@@ -67,6 +67,16 @@ def get_feature_group_by_name_db(featuregroup_name):
     """
     return FeatureGroup.query.filter_by(featuregroup_name=featuregroup_name).first()
 
+def get_feature_groups_from_inputDataType_db(inputDataType):
+    """
+    This Function return all feature group with feature_list as "inputDataType"
+    Return type is a list of tuples
+    """
+    try:
+        return FeatureGroup.query.with_entities(FeatureGroup.featuregroup_name).filter_by(feature_list=inputDataType).all()
+    except Exception as err:
+        raise DBException("Unable to query in get_feature_groups_from_inputDataType_db with error : ", err)
+
 def delete_feature_group_by_name(featuregroup_name):
     """
     This function is used to delete the feature group from db
@@ -77,3 +87,5 @@ def delete_feature_group_by_name(featuregroup_name):
     db.session.commit()
     return
 
+
+
diff --git a/trainingmgr/pipeline/mme_mgr.py b/trainingmgr/pipeline/mme_mgr.py
new file mode 100644
index 0000000..377f6f3
--- /dev/null
+++ b/trainingmgr/pipeline/mme_mgr.py
@@ -0,0 +1,80 @@
+# ==================================================================================
+#
+# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+import requests
+from trainingmgr.common.exceptions_utls import TMException
+from flask_api import status
+import requests
+import json
+
+LOGGER = TrainingMgrConfig().logger
+
+# Constants
+MIMETYPE_JSON = "application/json"
+ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response"
+
+
+class MmeMgr:
+    __instance = None
+
+    def __new__(cls):
+        if cls.__instance is None:
+            cls.__instance = super(MmeMgr, cls).__new__(cls)
+            cls.__instance.__initialized = False
+        return cls.__instance
+
+    def __init__(self):
+        if self.__initialized:
+            return
+
+        self.mme_ip = TrainingMgrConfig().model_management_service_ip
+        self.mme_port = TrainingMgrConfig().model_management_service_port
+
+        self.__initialized = True
+
+
+    def get_modelInfo_by_modelId(self, modelName, modelVersion):
+        """
+        This function returns the model information for given modelName and ModelVersion from MME
+        """
+        try:
+            url = f'http://{self.mme_ip}:{self.mme_port}/getModelInfo/?modelName={modelName}&modelVersion={modelVersion}'
+            LOGGER.debug(f"Requesting modelInfo from: {url}")
+            response = requests.get(url)
+            if response.status_code == 200:
+                return response.json()
+            elif response.status_code == 404:
+                # The modelinfo is NOT FOUND i.e. model is not registered
+                LOGGER.debug(f"ModelName = {modelName}, ModelVersion = {modelVersion} is not registered on MME")
+                return None
+            else:
+                err_msg = f"Unexpected response from MME: {response.status_code}"
+                LOGGER.error(err_msg)
+                raise TMException(err_msg)
+
+        except requests.RequestException as err:
+            err_msg = f"Error communicating with MME : {str(err)}"
+            LOGGER.error(err_msg)
+            raise TMException(err_msg)
+        except Exception as err:
+            err_msg = f"Unexpected error in get_modelInfo_by_modelId: {str(err)}"
+            LOGGER.error(err_msg)
+            raise TMException(err_msg)
+
+    
\ No newline at end of file
diff --git a/trainingmgr/service/featuregroup_service.py b/trainingmgr/service/featuregroup_service.py
index 3c0a354..af0a9d9 100644
--- a/trainingmgr/service/featuregroup_service.py
+++ b/trainingmgr/service/featuregroup_service.py
@@ -16,7 +16,7 @@
 #
 # ==================================================================================
 
-from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db
+from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db, get_feature_groups_from_inputDataType_db
 from trainingmgr.common.exceptions_utls import TMException, DBException
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 
@@ -28,4 +28,25 @@ def get_featuregroup_by_name(featuregroup_name:str):
         featuregroup = get_feature_group_by_name_db(featuregroup_name)
         return featuregroup
     except DBException as err:
-        raise TMException(f"get featuregroup by name service failed with exception : {str(err)}")
\ No newline at end of file
+        raise TMException(f"get featuregroup by name service failed with exception : {str(err)}")
+
+def get_featuregroup_from_inputDataType(inputDataType):
+    LOGGER.debug(f'Deducing featuregroupName from InputDataType : {inputDataType}')
+    try:
+        candidate_list = get_feature_groups_from_inputDataType_db(inputDataType)
+        LOGGER.debug(f'Candidates for inputDataType {inputDataType} are {candidate_list}')
+        if(len(candidate_list) == 0):
+            raise TMException(f'No featureGroup is available for inputDataType {inputDataType}')
+        elif(len(candidate_list) == 1):
+            selected_featuregroup = candidate_list[0][0]
+            LOGGER.debug(f'FeatureGroup Selected for InputDataType {inputDataType} is {selected_featuregroup}')
+            return selected_featuregroup
+        else:
+            raise TMException(f'2 or more featureGroup are available for inputDataType : {inputDataType}\n Available featuregroups are {candidate_list}\n Please specify one featuregroup_name in trainingConfig to resolve conflict')
+    except DBException as err:
+        raise TMException(f"get get_featuregroup_from_inputDataType service failed with exception : {str(err)}")
+    except Exception as err:
+        raise err
+
+
+    
\ No newline at end of file
diff --git a/trainingmgr/service/mme_service.py b/trainingmgr/service/mme_service.py
new file mode 100644
index 0000000..b552dd4
--- /dev/null
+++ b/trainingmgr/service/mme_service.py
@@ -0,0 +1,28 @@
+# ==================================================================================
+#
+# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.pipeline.mme_mgr import MmeMgr
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
+LOGGER = TrainingMgrConfig().logger
+mmeMgrObj = MmeMgr()
+
+def get_modelinfo_by_modelId_service(model_name, model_version):
+    LOGGER.debug(f'get_modelinfo_by_modelId_service from MME service where model_name = {model_name}, model_version = {model_version}')
+    return mmeMgrObj.get_modelInfo_by_modelId(model_name, model_version)
+    
\ No newline at end of file
diff --git a/trainingmgr/service/training_job_service.py b/trainingmgr/service/training_job_service.py
index d84ca8e..cf92a51 100644
--- a/trainingmgr/service/training_job_service.py
+++ b/trainingmgr/service/training_job_service.py
@@ -19,14 +19,20 @@ import json
 from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db, \
     change_steps_state, change_field_value
 from trainingmgr.common.exceptions_utls import DBException, TMException
+from trainingmgr.common.trainingConfig_parser import getField, setField
 from trainingmgr.schemas import TrainingJobSchema
 from trainingmgr.common.trainingmgr_util import get_one_word_status, get_step_in_progress_state
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
 from trainingmgr.service.pipeline_service import terminate_training_service
+from trainingmgr.service.featuregroup_service import get_featuregroup_from_inputDataType
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
 trainingJobSchema = TrainingJobSchema()
 trainingJobsSchema = TrainingJobSchema(many=True)
 
+LOGGER = TrainingMgrConfig().logger
+
 def get_training_job(training_job_id: int):
     try:
         tj =get_trainingjob(training_job_id)
@@ -41,8 +47,17 @@ def get_trainining_jobs():
     except DBException as err:
         raise TMException(f"get_training_jobs failed with exception : {str(err)}")
 
-def create_training_job(trainingjob):
+def create_training_job(trainingjob, registered_model_dict):
     try:
+        # First-of all we need to resolve featureGroupname from inputDatatype
+        training_config = trainingjob.training_config
+        feature_group_name = getField(training_config, "feature_group_name")
+        if feature_group_name == "":
+            # User has not provided feature_group_name, then it MUST be deduced from Registered InputDataType
+            feature_group_name = get_featuregroup_from_inputDataType(registered_model_dict['modelinfo']['modelInformation']['inputDataType'])
+            trainingjob.training_config = json.dumps(setField(training_config, "feature_group_name", feature_group_name))
+            LOGGER.debug("Training Config after FeatureGroup deduction --> " + trainingjob.training_config)
+
         create_trainingjob(trainingjob)
     except DBException as err:
         raise TMException(f"create_training_job failed with exception : {str(err)}")
-- 
2.16.6