Minor Buffix includes: 08/13908/1
authorashishj1729 <jain.ashish@samsung.com>
Thu, 19 Dec 2024 09:56:36 +0000 (15:26 +0530)
committerashishj1729 <jain.ashish@samsung.com>
Thu, 19 Dec 2024 10:00:40 +0000 (15:30 +0530)
1. Featuregroup Lists contains Id now
2. TrainingConfig is returned as json rather than string
3. List Trainingjobs is working now
4. FeatureGroupDeduction changes as per MME response change

Change-Id: I063cc2474c5b981c86e32f5bc8bb81faa3b2005c
Signed-off-by: ashishj1729 <jain.ashish@samsung.com>
trainingmgr/controller/trainingjob_controller.py
trainingmgr/models/trainingjob.py
trainingmgr/schemas/trainingjob_schema.py
trainingmgr/service/training_job_service.py
trainingmgr/trainingmgr_main.py

index 17ae8e8..aab12db 100644 (file)
@@ -93,9 +93,10 @@ def create_trainingjob():
 
         # Verify if the modelId is registered over mme or not
         
-        registered_model_dict = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
-        if registered_model_dict is None:
+        registered_model_list = get_modelinfo_by_modelId_service(model_id.modelname, model_id.modelversion)
+        if registered_model_list is None:
             return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is not registered at MME, Please first register at MME and then continue"}), status.HTTP_400_BAD_REQUEST
+        registered_model_dict = registered_model_list[0]
         create_training_job(trainingjob, registered_model_dict)
 
         return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
@@ -111,7 +112,7 @@ def create_trainingjob():
 def get_trainingjobs():
     LOGGER.debug(f'get the trainingjobs')
     try:
-        resp = trainingjob_schema.dump(get_trainining_jobs())
+        resp = trainingjobs_schema.dump(get_trainining_jobs())
         return jsonify(resp), 200
     except TMException as err:
         return jsonify({
index 7783180..84600d4 100644 (file)
@@ -82,5 +82,5 @@ class TrainingJob(db.Model):
     #     self.training_config = json.dumps(value)
 
     def __repr__(self):
-        return f'<Trainingjob {self.trainingjob_name}>'
+        return f'<Trainingjob {self.id}>'
 
index c297b1e..2e6034b 100644 (file)
@@ -20,8 +20,8 @@ import re
 from trainingmgr.schemas import ma
 from trainingmgr.models import TrainingJob
 from trainingmgr.models.trainingjob import ModelID
-
-from marshmallow import pre_load, validates, ValidationError
+import json
+from marshmallow import pre_load, post_dump, validates, ValidationError
 
 PATTERN = re.compile(r"\w+")
 
@@ -46,5 +46,14 @@ class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
         modeldict = dict(modelname=modelname, modelversion=modelversion)
         data['modelId'] = modeldict
         return data
-         
+    
+    @post_dump(pass_many=True)
+    def trainingConfigtoDict(self, data, many, **kwargs):
+        if many:
+            for index in range(len(data)):
+                data[index]["training_config"] = json.loads(data[index]["training_config"])
+        else:
+            data["training_config"] = json.loads(data["training_config"])   
+        
+        return data
         
\ No newline at end of file
index 72764d0..5e76934 100644 (file)
@@ -56,7 +56,7 @@ def create_training_job(trainingjob, registered_model_dict):
         feature_group_name = getField(training_config, "feature_group_name")
         if feature_group_name == "":
             # User has not provided feature_group_name, then it MUST be deduced from Registered InputDataType
-            feature_group_name = get_featuregroup_from_inputDataType(registered_model_dict['modelinfo']['modelInformation']['inputDataType'])
+            feature_group_name = get_featuregroup_from_inputDataType(registered_model_dict['modelInformation']['inputDataType'])
             trainingjob.training_config = json.dumps(setField(training_config, "feature_group_name", feature_group_name))
             LOGGER.debug("Training Config after FeatureGroup deduction --> " + trainingjob.training_config)
             
index 2c92cfb..693378e 100644 (file)
@@ -81,7 +81,7 @@ NOT_LIST="not given as list"
 
 trainingjob_schema = TrainingJobSchema()
 trainingjobs_schema = TrainingJobSchema(many=True)
-
+featuregroups_schema = FeatureGroupSchema(many=True)
 
 @APP.errorhandler(APIException)
 def error(err):
@@ -766,19 +766,8 @@ def get_feature_group():
     api_response={}
     response_code=status.HTTP_500_INTERNAL_SERVER_ERROR
     try:
-        result= get_feature_groups_db()
-        feature_groups=[]
-        for res in result:
-            dict_data={
-                "featuregroup_name": res.featuregroup_name,
-                "features": res.feature_list,
-                "datalake": res.datalake_source,
-                "dme": res.enable_dme
-                }
-            feature_groups.append(dict_data)
-        api_response={"featuregroups":feature_groups}
-        response_code=status.HTTP_200_OK
-
+        api_response={"featuregroups": featuregroups_schema.dump(get_feature_groups_db())}
+        response_code=status.HTTP_200_OK    
     except Exception as err:
         api_response =   {"Exception": str(err)}
         LOGGER.error(str(err))