FeatureGroup Resolution from InputDataType 19/13819/1
authorashishj1729 <jain.ashish@samsung.com>
Thu, 5 Dec 2024 21:36:52 +0000 (03:06 +0530)
committerashishj1729 <jain.ashish@samsung.com>
Thu, 5 Dec 2024 21:39:06 +0000 (03:09 +0530)
Change-Id: I75255e289cbf4da5cf1f59103450ea0e56c28f16
Signed-off-by: ashishj1729 <jain.ashish@samsung.com>
trainingmgr/common/trainingConfig_parser.py
trainingmgr/controller/trainingjob_controller.py
trainingmgr/db/featuregroup_db.py
trainingmgr/pipeline/mme_mgr.py [new file with mode: 0644]
trainingmgr/service/featuregroup_service.py
trainingmgr/service/mme_service.py [new file with mode: 0644]
trainingmgr/service/training_job_service.py

index b3155fc..844e836 100644 (file)
@@ -87,4 +87,25 @@ def getField(trainingConfig, fieldname):
     trainingConfig = prepocessTrainingConfig(trainingConfig)
     fieldPath = __getLeafPaths()[fieldname]
     return parse_dict_by_fields(trainingConfig, fieldPath)
-    
\ No newline at end of file
+
+def setField(trainingConfig, fieldname, value):
+    """
+    Set the value of a specific field in the training configuration.
+    """
+    trainingConfig = prepocessTrainingConfig(trainingConfig)
+    fieldPath = __getLeafPaths()[fieldname]
+    if not fieldPath:
+        raise KeyError(f"Field '{fieldname}' is not defined in the schema.")
+    
+    # Traverse the path and set the value
+    current_node = trainingConfig
+    for i, key in enumerate(fieldPath):
+        if i == len(fieldPath) - 1:
+            # Last key in the path, set the value
+            current_node[key] = value
+        else:
+            # Traverse or create intermediate dictionaries
+            if key not in current_node or not isinstance(current_node[key], dict):
+                current_node[key] = {}
+            current_node = current_node[key]
+    return trainingConfig
\ No newline at end of file
index 8f2cf1e..9ea0972 100644 (file)
@@ -27,7 +27,7 @@ from trainingmgr.service.training_job_service import delete_training_job, create
 get_steps_state
 from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
 from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
-
+from trainingmgr.service.mme_service import get_modelinfo_by_modelId_service
 training_job_controller = Blueprint('training_job_controller', __name__)
 LOGGER = TrainingMgrConfig().logger
 
@@ -73,13 +73,18 @@ def create_trainingjob():
         if(not validateTrainingConfig(trainingConfig)):
             return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST
         
-        #check if trainingjob is already present with name
+        # check if trainingjob is already present with name
         trainingjob_db = get_trainingjob_by_modelId(model_id)
 
         if trainingjob_db != None:
             return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT
 
-        create_training_job(trainingjob)
+        # Verify if the modelId is registered over mme or not
+        
+        registered_model_dict = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+        if registered_model_dict is None:
+            return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+        create_training_job(trainingjob, registered_model_dict)
 
         return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
         
index 095ac6b..079c888 100644 (file)
@@ -67,6 +67,16 @@ def get_feature_group_by_name_db(featuregroup_name):
     """
     return FeatureGroup.query.filter_by(featuregroup_name=featuregroup_name).first()
 
+def get_feature_groups_from_inputDataType_db(inputDataType):
+    """
+        This Function return all feature group with feature_list as "inputDataType"
+        Return type is a list of tuples
+    """
+    try:
+        return FeatureGroup.query.with_entities(FeatureGroup.featuregroup_name).filter_by(feature_list=inputDataType).all()
+    except Exception as err:
+        raise DBException("Unable to query in get_feature_groups_from_inputDataType_db with error : ", err)
+
 def delete_feature_group_by_name(featuregroup_name):
     """
     This function is used to delete the feature group from db
@@ -77,3 +87,5 @@ def delete_feature_group_by_name(featuregroup_name):
         db.session.commit()
     return
 
+
+
diff --git a/trainingmgr/pipeline/mme_mgr.py b/trainingmgr/pipeline/mme_mgr.py
new file mode 100644 (file)
index 0000000..377f6f3
--- /dev/null
@@ -0,0 +1,80 @@
+# ==================================================================================
+#
+#       Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#          http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+import requests
+from trainingmgr.common.exceptions_utls import TMException
+from flask_api import status
+import requests
+import json
+
+LOGGER = TrainingMgrConfig().logger
+
+# Constants
+MIMETYPE_JSON = "application/json"
+ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response"
+
+
+class MmeMgr:    
+    __instance = None
+
+    def __new__(cls):
+        if cls.__instance is None:
+            cls.__instance = super(MmeMgr, cls).__new__(cls)
+            cls.__instance.__initialized = False
+        return cls.__instance
+    
+    def __init__(self):
+        if self.__initialized:
+            return
+
+        self.mme_ip = TrainingMgrConfig().model_management_service_ip
+        self.mme_port = TrainingMgrConfig().model_management_service_port
+        
+        self.__initialized = True
+    
+        
+    def get_modelInfo_by_modelId(self, modelName, modelVersion):
+        """
+            This function returns the model information for given modelName and ModelVersion from MME
+        """
+        try:
+            url = f'http://{self.mme_ip}:{self.mme_port}/getModelInfo/?modelName={modelName}&modelVersion={modelVersion}'
+            LOGGER.debug(f"Requesting modelInfo from: {url}")
+            response = requests.get(url)
+            if response.status_code == 200:
+                return response.json()
+            elif response.status_code == 404:
+                # The modelinfo is NOT FOUND i.e. model is not registered
+                LOGGER.debug(f"ModelName = {modelName}, ModelVersion = {modelVersion} is not registered on MME")
+                return None                
+            else:
+                err_msg = f"Unexpected response from KFAdapter: {response.status_code}"
+                LOGGER.error(err_msg)
+                return TMException(err_msg)
+
+        except requests.RequestException as err:
+            err_msg = f"Error communicating with MME : {str(err)}"
+            LOGGER.error(err_msg)
+            raise TMException(err_msg)
+        except Exception as err:
+            err_msg = f"Unexpected error in get_modelInfo_by_modelId: {str(err)}"
+            LOGGER.error(err_msg)
+            raise TMException(err_msg)
+        
+    
\ No newline at end of file
index 3c0a354..af0a9d9 100644 (file)
@@ -16,7 +16,7 @@
 #
 # ==================================================================================
 
-from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db
+from trainingmgr.db.featuregroup_db import get_feature_group_by_name_db, get_feature_groups_from_inputDataType_db
 from trainingmgr.common.exceptions_utls import TMException, DBException
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 
@@ -28,4 +28,25 @@ def get_featuregroup_by_name(featuregroup_name:str):
         featuregroup = get_feature_group_by_name_db(featuregroup_name)
         return featuregroup
     except DBException as err:
-        raise TMException(f"get featuregroup by name service failed with exception : {str(err)}") 
\ No newline at end of file
+        raise TMException(f"get featuregroup by name service failed with exception : {str(err)}")
+    
+def get_featuregroup_from_inputDataType(inputDataType):
+    LOGGER.debug(f'Deducing featuregroupName from InputDataType : {inputDataType}')
+    try:
+        candidate_list = get_feature_groups_from_inputDataType_db(inputDataType)
+        LOGGER.debug(f'Candidates for inputDataType {inputDataType} are f{candidate_list}')
+        if(len(candidate_list) == 0):
+            raise TMException(f'No featureGroup is available for inputDataType {inputDataType}')
+        elif(len(candidate_list) == 1):
+            selected_featuregroup = candidate_list[0][0]
+            LOGGER.debug(f'FeatureGroup Selected for InputDataType {inputDataType} is {selected_featuregroup}')
+            return selected_featuregroup
+        else:
+            raise TMException(f'2 or more featureGroup are available for inputDataType : {inputDataType}\n Available featuregroups are {candidate_list}\n Please specify one featuregroup_name in trainingConfig to resolve conflict')
+    except DBException as err:
+        raise TMException(f"get get_featuregroup_from_inputDataType service failed with exception : {str(err)}")
+    except Exception as err:
+        raise err
+            
+        
+    
\ No newline at end of file
diff --git a/trainingmgr/service/mme_service.py b/trainingmgr/service/mme_service.py
new file mode 100644 (file)
index 0000000..b552dd4
--- /dev/null
@@ -0,0 +1,28 @@
+# ==================================================================================
+#
+#       Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#          http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+# ==================================================================================
+
+from trainingmgr.pipeline.mme_mgr import MmeMgr
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
+LOGGER = TrainingMgrConfig().logger
+mmeMgrObj = MmeMgr()
+
+def get_modelinfo_by_modelId_service(model_name, model_version):
+    LOGGER.debug(f'get_modelinfo_by_modelId_service from MME service where model_name = {model_name}, model_version = {model_version}')
+    return mmeMgrObj.get_modelInfo_by_modelId(model_name, model_version)
+    
\ No newline at end of file
index d84ca8e..cf92a51 100644 (file)
@@ -19,14 +19,20 @@ import json
 from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db, \
 change_steps_state, change_field_value
 from trainingmgr.common.exceptions_utls import DBException, TMException
+from trainingmgr.common.trainingConfig_parser import getField, setField
 from trainingmgr.schemas import TrainingJobSchema
 from trainingmgr.common.trainingmgr_util import get_one_word_status, get_step_in_progress_state
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
 from trainingmgr.service.pipeline_service import terminate_training_service
+from trainingmgr.service.featuregroup_service import get_featuregroup_from_inputDataType
+from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
+
 trainingJobSchema = TrainingJobSchema()
 trainingJobsSchema = TrainingJobSchema(many=True)
 
+LOGGER = TrainingMgrConfig().logger
+
 def get_training_job(training_job_id: int):
     try:
         tj =get_trainingjob(training_job_id)
@@ -41,8 +47,17 @@ def get_trainining_jobs():
     except DBException as err:
         raise TMException(f"get_training_jobs failed with exception : {str(err)}")
 
-def create_training_job(trainingjob):
+def create_training_job(trainingjob, registered_model_dict):
     try:
+        # First-of all we need to resolve featureGroupname from inputDatatype
+        training_config = trainingjob.training_config
+        feature_group_name = getField(training_config, "feature_group_name")
+        if feature_group_name == "":
+            # User has not provided feature_group_name, then it MUST be deduced from Registered InputDataType
+            feature_group_name = get_featuregroup_from_inputDataType(registered_model_dict['modelinfo']['modelInformation']['inputDataType'])
+            trainingjob.training_config = json.dumps(setField(training_config, "feature_group_name", feature_group_name))
+            LOGGER.debug("Training Config after FeatureGroup deduction --> " + trainingjob.training_config)
+            
         create_trainingjob(trainingjob)
     except DBException as err:
         raise TMException(f"create_training_job failed with exception : {str(err)}")