trainingConfig = prepocessTrainingConfig(trainingConfig)
fieldPath = __getLeafPaths()[fieldname]
return parse_dict_by_fields(trainingConfig, fieldPath)
-
\ No newline at end of file
+
+def setField(trainingConfig, fieldname, value):
+ """
+ Set the value of a specific field in the training configuration.
+ """
+ trainingConfig = prepocessTrainingConfig(trainingConfig)
+ fieldPath = __getLeafPaths()[fieldname]
+ if not fieldPath:
+ raise KeyError(f"Field '{fieldname}' is not defined in the schema.")
+
+ # Traverse the path and set the value
+ current_node = trainingConfig
+ for i, key in enumerate(fieldPath):
+ if i == len(fieldPath) - 1:
+ # Last key in the path, set the value
+ current_node[key] = value
+ else:
+ # Traverse or create intermediate dictionaries
+ if key not in current_node or not isinstance(current_node[key], dict):
+ current_node[key] = {}
+ current_node = current_node[key]
+ return trainingConfig
\ No newline at end of file
get_steps_state
from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
-
+from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
training_job_controller = Blueprint('training_job_controller', __name__)
LOGGER = TrainingMgrConfig().logger
if(not validateTrainingConfig(trainingConfig)):
return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST
- #check if trainingjob is already present with name
+ # check if trainingjob is already present with name
trainingjob_db = get_trainingjob_by_modelId(model_id)
if trainingjob_db != None:
return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT
- create_training_job(trainingjob)
+ # Verify if the modelId is registered over mme or not
+
+ registered_model_dict = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+ if registered_model_dict is None:
+ return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+ create_training_job(trainingjob, registered_model_dict)
return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
"""
return FeatureGroup.query.filter_by(featuregroup_name=featuregroup_name).first()
+def get_feature_groups_from_inputDataType_db(inputDataType):
+ """
+ This Function return all feature group with feature_list as "inputDataType"
+ Return type is a list of tuples
+ """
+ try:
+ return FeatureGroup.query.with_entities(FeatureGroup.featuregroup_name).filter_by(feature_list=inputDataType).all()
+ except Exception as err:
+ raise DBException("Unable to query in get_feature_groups_from_inputDataType_db with error : ", err)
+
def delete_feature_group_by_name(featuregroup_name):
"""
This function is used to delete the feature group from db
db.session.commit()
return
+
+
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+import requests
+from trainingmgr.common.exceptions_utls import TMException
+from flask_api import status
+import requests
+import json
+
+LOGGER = TrainingMgrConfig().logger
+
+# Constants
+MIMETYPE_JSON = "application/json"
+ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response"
+
+
+class MmeMgr:
+ __instance = None
+
+ def __new__(cls):
+ if cls.__instance is None:
+ cls.__instance = super(MmeMgr, cls).__new__(cls)
+ cls.__instance.__initialized = False
+ return cls.__instance
+
+ def __init__(self):
+ if self.__initialized:
+ return
+
+ self.mme_ip = TrainingMgrConfig().model_management_service_ip
+ self.mme_port = TrainingMgrConfig().model_management_service_port
+
+ self.__initialized = True
+
+
+ def get_modelInfo_by_modelId(self, modelName, modelVersion):
+ """
+ This function returns the model information for given modelName and ModelVersion from MME
+ """
+ try:
+ url = f'http://{self.mme_ip}:{self.mme_port}/getModelInfo/?modelName={modelName}&modelVersion={modelVersion}'
+ LOGGER.debug(f"Requesting modelInfo from: {url}")
+ response = requests.get(url)
+ if response.status_code == 200:
+ return response.json()
+ elif response.status_code == 404:
+ # The modelinfo is NOT FOUND i.e. model is not registered
+ LOGGER.debug(f"ModelName = {modelName}, ModelVersion = {modelVersion} is not registered on MME")
+ return None
+ else:
+ err_msg = f"Unexpected response from KFAdapter: {response.status_code}"
+ LOGGER.error(err_msg)
+ return TMException(err_msg)
+
+ except requests.RequestException as err:
+ err_msg = f"Error communicating with MME : {str(err)}"
+ LOGGER.error(err_msg)
+ raise TMException(err_msg)
+ except Exception as err:
+ err_msg = f"Unexpected error in get_modelInfo_by_modelId: {str(err)}"
+ LOGGER.error(err_msg)
+ raise TMException(err_msg)
+
+
\ No newline at end of file
#
# ==================================================================================
-from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db
+from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db, get_feature_groups_from_inputDataType_db
from trainingmgr.common.exceptions_utls import TMException, DBException
from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
featuregroup = get_feature_group_by_name_db(featuregroup_name)
return featuregroup
except DBException as err:
- raise TMException(f"get featuregroup by name service failed with exception : {str(err)}")
\ No newline at end of file
+ raise TMException(f"get featuregroup by name service failed with exception : {str(err)}")
+
+def get_featuregroup_from_inputDataType(inputDataType):
+ LOGGER.debug(f'Deducing featuregroupName from InputDataType : {inputDataType}')
+ try:
+ candidate_list = get_feature_groups_from_inputDataType_db(inputDataType)
+ LOGGER.debug(f'Candidates for inputDataType {inputDataType} are f{candidate_list}')
+ if(len(candidate_list) == 0):
+ raise TMException(f'No featureGroup is available for inputDataType {inputDataType}')
+ elif(len(candidate_list) == 1):
+ selected_featuregroup = candidate_list[0][0]
+ LOGGER.debug(f'FeatureGroup Selected for InputDataType {inputDataType} is {selected_featuregroup}')
+ return selected_featuregroup
+ else:
+ raise TMException(f'2 or more featureGroup are available for inputDataType : {inputDataType}\n Available featuregroups are {candidate_list}\n Please specify one featuregroup_name in trainingConfig to resolve conflict')
+ except DBException as err:
+ raise TMException(f"get get_featuregroup_from_inputDataType service failed with exception : {str(err)}")
+ except Exception as err:
+ raise err
+
+
+
\ No newline at end of file
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.pipeline.mme_mgr import MmeMgr
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
+LOGGER = TrainingMgrConfig().logger
+mmeMgrObj = MmeMgr()
+
+def get_modelinfo_by_modelId_service(model_name, model_version):
+ LOGGER.debug(f'get_modelinfo_by_modelId_service from MME service where model_name = {model_name}, model_version = {model_version}')
+ return mmeMgrObj.get_modelInfo_by_modelId(model_name, model_version)
+
\ No newline at end of file
from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db, \
change_steps_state, change_field_value
from trainingmgr.common.exceptions_utls import DBException, TMException
+from trainingmgr.common.trainingConfig_parser import getField, setField
from trainingmgr.schemas import TrainingJobSchema
from trainingmgr.common.trainingmgr_util import get_one_word_status, get_step_in_progress_state
from trainingmgr.constants.steps import Steps
from trainingmgr.constants.states import States
from trainingmgr.service.pipeline_service import terminate_training_service
+from trainingmgr.service.featuregroup_service import get_featuregroup_from_inputDataType
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
trainingJobSchema = TrainingJobSchema()
trainingJobsSchema = TrainingJobSchema(many=True)
+LOGGER = TrainingMgrConfig().logger
+
def get_training_job(training_job_id: int):
try:
tj =get_trainingjob(training_job_id)
except DBException as err:
raise TMException(f"get_training_jobs failed with exception : {str(err)}")
-def create_training_job(trainingjob):
+def create_training_job(trainingjob, registered_model_dict):
try:
+ # First-of all we need to resolve featureGroupname from inputDatatype
+ training_config = trainingjob.training_config
+ feature_group_name = getField(training_config, "feature_group_name")
+ if feature_group_name == "":
+ # User has not provided feature_group_name, then it MUST be deduced from Registered InputDataType
+ feature_group_name = get_featuregroup_from_inputDataType(registered_model_dict['modelinfo']['modelInformation']['inputDataType'])
+ trainingjob.training_config = json.dumps(setField(training_config, "feature_group_name", feature_group_name))
+ LOGGER.debug("Training Config after FeatureGroup deduction --> " + trainingjob.training_config)
+
create_trainingjob(trainingjob)
except DBException as err:
raise TMException(f"create_training_job failed with exception : {str(err)}")