Editing add_training_job and edit_training_job APIs based on MME 82/12182/2
author rajdeep11 <rajdeep.sin@samsung.com>
Thu, 7 Dec 2023 08:11:27 +0000 (13:41 +0530)
committer rajdeep11 <rajdeep.sin@samsung.com>
Thu, 7 Dec 2023 10:31:52 +0000 (16:01 +0530)
Issue-Id: AIMLFW-65

Change-Id: I5eec767b26657a290bf0acf2489724eabbfd55a8
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
tests/test.env
tests/test_tm_apis.py
tests/test_trainingmgr_util.py
trainingmgr/common/trainingmgr_config.py
trainingmgr/common/trainingmgr_operations.py
trainingmgr/common/trainingmgr_util.py
trainingmgr/trainingmgr_main.py

index a5c0048..a7ba1e9 100644 (file)
@@ -29,3 +29,6 @@ PS_PASSWORD="abcd"
 PS_IP="localhost"
 PS_PORT="30001"
 ACCESS_CONTROL_ALLOW_ORIGIN="http://localhost:32005"
+PIPELINE='{"timeseries":"qoe_pipeline"}'
+MODEL_MANAGEMENT_SERVICE_IP=localhost
+MODEL_MANAGEMENT_SERVICE_PORT=12343
\ No newline at end of file
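
The training manager parses this PIPELINE value with json.loads (see the trainingmgr_main.py hunk below), so the mapping must be valid JSON with double-quoted keys and values; that is why the env value above wraps double-quoted JSON in single quotes. A minimal sketch of the intended lookup, assuming the "timeseries" entry from test.env:

import json
import os

# Mirrors the tests/test.env entry above; json.loads requires
# double-quoted keys and values or it raises JSONDecodeError.
os.environ["PIPELINE"] = '{"timeseries":"qoe_pipeline"}'

pipeline_dict = json.loads(os.environ["PIPELINE"])
model_type = "timeseries"  # hypothetical model type reported by the model info
print(pipeline_dict[model_type])  # -> qoe_pipeline
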
index 4d465a4..436acab 100644 (file)
@@ -336,7 +336,9 @@ class Test_training_main:
                     "pipeline_version":"3",
                     "datalake_source":"InfluxSource",
                     "_measurement":"liveCell",
-                    "bucket":"UEData"
+                    "bucket":"UEData",
+                    "is_mme":False,
+                    "model_name":""
                     }
         expected_data = b'{"result": "Information stored in database."}'
         response = self.client.post("/trainingjobs/{}".format("usecase1"),
@@ -349,9 +351,9 @@ class Test_training_main:
     db_result = [('usecase1', 'uc1', '*', 'qoe Pipeline lat v2', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "usecase1"}}',
      '', datetime.datetime(2022, 10, 12, 10, 0, 59, 923588), '51948a12-aee9-42e5-93a0-b8f4a15bca33',
       '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FAILED"}',
-       datetime.datetime(2022, 10, 12, 10, 2, 31, 888830), 1, False, '3', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False)]
+       datetime.datetime(2022, 10, 12, 10, 2, 31, 888830), 1, False, '3', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False, False, "", "")]
     
-    training_data = ('','','','','','','','','','','')
+    training_data = ('', '', '', '', '', '', '', '', '', '', '', '', '')
     @patch('trainingmgr.trainingmgr_main.validate_trainingjob_name', return_value = True)
     @patch('trainingmgr.trainingmgr_main.get_trainingjob_info_by_name', return_value = db_result)
     @patch('trainingmgr.trainingmgr_main.check_trainingjob_data', return_value = training_data)
@@ -373,7 +375,9 @@ class Test_training_main:
                     "pipeline_version":"3",
                     "datalake_source":"InfluxSource",
                     "_measurement":"liveCell",
-                    "bucket":"UEData"
+                    "bucket":"UEData",
+                    "is_mme":False,
+                    "model_name":""
                     }
             
         expected_data = 'Information updated in database'
index 1c6e9fb..e655c24 100644 (file)
@@ -279,16 +279,16 @@ class Test_check_trainingjob_data:
     @patch('trainingmgr.common.trainingmgr_util.isinstance',return_value=True)  
     def test_check_trainingjob_data(self,mock1,mock2):
         usecase_name = "usecase8"
-        json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1"}
+        json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1", "is_mme":False, "model_name":""}
     
-        expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1')
+        expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1', False, "")
         assert check_trainingjob_data(usecase_name, json_data) == expected_data,"data not equal"
     
     def test_negative_check_trainingjob_data_1(self):
         usecase_name = "usecase8"
-        json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1"}
+        json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1", "is_mme":False, "model_name":""}
     
-        expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1')
+        expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1', False, "")
         try:
             assert check_trainingjob_data(usecase_name, json_data) == expected_data,"data not equal"
             assert False
index 2bcb0d0..dadbba2 100644 (file)
@@ -46,7 +46,12 @@ class TrainingMgrConfig:
         self.__ps_password = getenv('PS_PASSWORD').rstrip() if getenv('PS_PASSWORD') is not None else None
         self.__ps_ip = getenv('PS_IP').rstrip() if getenv('PS_IP') is not None else None
         self.__ps_port = getenv('PS_PORT').rstrip() if getenv('PS_PORT') is not None else None
+
+        self.__model_management_service_ip = getenv('MODEL_MANAGEMENT_SERVICE_IP').rstrip() if getenv('MODEL_MANAGEMENT_SERVICE_IP') is not None else None
+        self.__model_management_service_port = getenv('MODEL_MANAGEMENT_SERVICE_PORT').rstrip() if getenv('MODEL_MANAGEMENT_SERVICE_PORT') is not None else None
+
         self.__allow_control_access_origin = getenv('ACCESS_CONTROL_ALLOW_ORIGIN').rstrip() if getenv('ACCESS_CONTROL_ALLOW_ORIGIN') is not None else None
+        self.__pipeline = getenv('PIPELINE').rstrip() if getenv('PIPELINE') is not None else None
 
         self.tmgr_logger = TMLogger("common/conf_log.yaml")
         self.__logger = self.tmgr_logger.logger
@@ -182,6 +187,29 @@ class TrainingMgrConfig:
             port number where postgres db is accessible
         """
         return self.__ps_port
+
+    @property
+    def model_management_service_port(self):
+        """
+        Function for getting model management service port
+        Args: None
+
+        Returns:
+            string model_management_service_port
+        """
+        return self.__model_management_service_port
+
+
+    @property
+    def model_management_service_ip(self):
+        """
+        Function for getting model management service ip
+        Args: None
+
+        Returns:
+            string model_management_service_ip
+        """
+        return self.__model_management_service_ip
 
     @property
     def allow_control_access_origin(self):
@@ -196,6 +224,19 @@ class TrainingMgrConfig:
         """
         return self.__allow_control_access_origin
 
+    @property
+    def pipeline(self):
+        """
+        Function for getting the pipeline mapping (model type to pipeline name)
+
+        Args: None
+
+        Returns:
+            string pipeline mapping
+
+        """
+        return self.__pipeline
+
     def is_config_loaded_properly(self):
         """
         This function checks where all environment variable got value or not.
@@ -207,7 +248,8 @@ class TrainingMgrConfig:
         for var in [self.__kf_adapter_ip, self.__kf_adapter_port,
                     self.__data_extraction_ip, self.__data_extraction_port,
                     self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user,
-                    self.__ps_password, self.__my_ip, self.__allow_control_access_origin, self.__logger]:
+                    self.__ps_password, self.__my_ip, self.__model_management_service_ip, self.__model_management_service_port,
+                    self.__allow_control_access_origin, self.__pipeline, self.__logger]:
             if var is None:
                 all_present = False
         return all_present
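
A minimal usage sketch for the new settings, assuming TrainingMgrConfig takes no constructor arguments and the variable names from tests/test.env above; all the remaining PS_*/KF_* variables must also be set for is_config_loaded_properly to return True:

import os
from trainingmgr.common.trainingmgr_config import TrainingMgrConfig

# Placeholder values mirroring tests/test.env.
os.environ["MODEL_MANAGEMENT_SERVICE_IP"] = "localhost"
os.environ["MODEL_MANAGEMENT_SERVICE_PORT"] = "12343"
os.environ["PIPELINE"] = '{"timeseries":"qoe_pipeline"}'

config = TrainingMgrConfig()
print(config.model_management_service_ip)    # localhost
print(config.model_management_service_port)  # 12343
print(config.is_config_loaded_properly())    # True only once every variable is set
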
index e7b2bd0..a0a976c 100644 (file)
@@ -25,6 +25,7 @@ import json
 import requests
 import validators
 from trainingmgr.common.exceptions_utls import TMException
+from flask_api import status
 
 MIMETYPE_JSON = "application/json"
 
@@ -167,3 +168,17 @@ def delete_dme_filtered_data_job(training_config_obj, feature_group_name, host,
     logger.debug(url)
     response = requests.delete(url)
     return response
+
+def get_model_info(training_config_obj, model_name):
+    """ Fetch model info for the given model_name from the model management service. """
+    logger = training_config_obj.logger
+    model_management_service_ip = training_config_obj.model_management_service_ip
+    model_management_service_port = training_config_obj.model_management_service_port
+    url = "http://" + str(model_management_service_ip) + ":" + str(model_management_service_port) + "/getModelInfo/{}".format(model_name)
+    response = requests.get(url)
+    if response.status_code == status.HTTP_200_OK:
+        model_info = json.loads(response.json()['message'])
+        return model_info
+    err_msg = "model info can't be fetched, model_name: {}, err: {}".format(model_name, response.text)
+    logger.error(err_msg)
+    raise TMException(err_msg)
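
A hedged usage sketch of get_model_info; the response shape, a {"message": "<json string>"} body carrying meta-info with feature-list and model-type, is inferred from how this function and the POST handler below consume it, and the feature names are hypothetical:

import json
from unittest.mock import MagicMock, patch

from trainingmgr.common.trainingmgr_operations import get_model_info

# Hypothetical payload shaped the way trainingmgr_main.py reads it:
# model_info["meta-info"]["feature-list"] and model_info["meta-info"]["model-type"].
fake_body = {"message": json.dumps({
    "meta-info": {"feature-list": ["pdcpBytesDl", "pdcpBytesUl"], "model-type": "timeseries"}
})}

config = MagicMock(model_management_service_ip="localhost",
                   model_management_service_port="12343")

with patch("trainingmgr.common.trainingmgr_operations.requests.get") as mock_get:
    mock_get.return_value.status_code = 200
    mock_get.return_value.json.return_value = fake_body
    info = get_model_info(config, "usecase7")
    print(info["meta-info"]["model-type"])  # -> timeseries
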
index ed1a16c..e3e7054 100644 (file)
@@ -133,7 +133,7 @@ def check_trainingjob_data(trainingjob_name, json_data):
                                  "arguments", "enable_versioning",
                                  "datalake_source", "description",
                                  "query_filter", "_measurement",
-                                 "bucket"], json_data):
+                                 "bucket", "is_mme", "model_name"], json_data):
 
             description = json_data["description"]
             feature_list = json_data["featureGroup_name"]
@@ -149,6 +149,8 @@ def check_trainingjob_data(trainingjob_name, json_data):
             datalake_source = json_data["datalake_source"]
             _measurement = json_data["_measurement"]
             bucket = json_data["bucket"]
+            is_mme = json_data["is_mme"]
+            model_name = json_data["model_name"]
         else :
             raise TMException("check_trainingjob_data- supplied data doesn't have" + \
                                 "all the required fields ")
@@ -157,7 +159,7 @@ def check_trainingjob_data(trainingjob_name, json_data):
                            str(err)) from None
     return (feature_list, description, pipeline_name, experiment_name,
             arguments, query_filter, enable_versioning, pipeline_version,
-            datalake_source, _measurement, bucket)
+            datalake_source, _measurement, bucket, is_mme, model_name)
 
 def check_feature_group_data(json_data):
     """
index 2278d94..224292e 100644 (file)
@@ -33,7 +33,8 @@ import requests
 from flask_cors import CORS
 from werkzeug.utils import secure_filename
 from modelmetricsdk.model_metrics_sdk import ModelMetricsSdk
-from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job
+from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, \
+    delete_dme_filtered_data_job, get_model_info
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 from trainingmgr.common.trainingmgr_util import get_one_word_status, check_trainingjob_data, \
     check_key_in_dictionary, get_one_key, \
@@ -869,6 +870,10 @@ def trainingjob_operations(trainingjob_name):
                     _measurement for influx db datalake
                 bucket: str
                     bucket name for influx db datalake
+                is_mme: boolean
+                    whether MME is enabled for this trainingjob
+                model_name: str
+                    name of the model
 
     Returns:
         1. For post request
@@ -904,13 +909,31 @@ def trainingjob_operations(trainingjob_name):
             else:
                 (featuregroup_name, description, pipeline_name, experiment_name,
                 arguments, query_filter, enable_versioning, pipeline_version,
-                datalake_source, _measurement, bucket) = \
+                datalake_source, _measurement, bucket, is_mme, model_name) = \
                 check_trainingjob_data(trainingjob_name, json_data)
+                model_info=""
+                if is_mme: 
+                    pipeline_dict =json.loads(TRAININGMGR_CONFIG_OBJ.pipeline)
+                    model_info=get_model_info(TRAININGMGR_CONFIG_OBJ, model_name)
+                    s=model_info["meta-info"]["feature-list"]
+                    model_type=model_info["meta-info"]["model-type"]
+                    try:
+                        pipeline_name=pipeline_dict[str(model_type)]
+                    except Exception as err:
+                        err="Doesn't support the model type"
+                        raise TMException(err)
+                    pipeline_version=pipeline_name
+                    feature_list=','.join(s)
+                    result= get_feature_groups_db(PS_DB_OBJ)
+                    for res in result:
+                        if feature_list==res[1]:
+                            featuregroup_name=res[0]
+                            break 
                 add_update_trainingjob(description, pipeline_name, experiment_name, featuregroup_name,
                                     arguments, query_filter, True, enable_versioning,
                                     pipeline_version, datalake_source, trainingjob_name, 
                                     PS_DB_OBJ, _measurement=_measurement,
-                                    bucket=bucket)
+                                    bucket=bucket, is_mme=is_mme, model_name=model_name, model_info=model_info)
                 api_response =  {"result": "Information stored in database."}                 
                 response_code = status.HTTP_201_CREATED
         elif(request.method == 'PUT'):
@@ -930,14 +953,18 @@ def trainingjob_operations(trainingjob_name):
                             not in [States.FAILED.name, States.FINISHED.name]):
                         raise TMException("Trainingjob(" + trainingjob_name + ") is not in finished or failed status")
 
-                (featuregroup_name, description, pipeline_name, experiment_name,
-                arguments, query_filter, enable_versioning, pipeline_version,
-                datalake_source, _measurement, bucket) = check_trainingjob_data(trainingjob_name, json_data)
-
+                    (featuregroup_name, description, pipeline_name, experiment_name,
+                     arguments, query_filter, enable_versioning, pipeline_version,
+                     datalake_source, _measurement, bucket, is_mme, model_name) = check_trainingjob_data(trainingjob_name, json_data)
+                if is_mme:
+                    featuregroup_name = results[0][2]
+                    pipeline_name, pipeline_version = results[0][3], results[0][13]
+                # model_name is not changing, hence model_info is unchanged
+                model_info = results[0][22]
                 add_update_trainingjob(description, pipeline_name, experiment_name, featuregroup_name,
                                         arguments, query_filter, False, enable_versioning,
                                         pipeline_version, datalake_source, trainingjob_name, PS_DB_OBJ, _measurement=_measurement,
-                                        bucket=bucket)
+                                        bucket=bucket, is_mme=is_mme, model_name=model_name, model_info=model_info)
                 api_response = {"result": "Information updated in database."}
                 response_code = status.HTTP_200_OK
     except Exception as err:
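
Finally, an illustrative POST against the extended endpoint. Host, port, and field values are placeholders; with is_mme set to true, pipeline_name, pipeline_version, and featureGroup_name are resolved from the model info and the stored feature groups, so the supplied values are ignored:

import requests

# Placeholder endpoint and values; the body mirrors the documented parameters above.
payload = {
    "description": "qoe training job",
    "featureGroup_name": "group1",
    "pipeline_name": "qoe_pipeline",
    "experiment_name": "Default",
    "arguments": {"epochs": "1"},
    "query_filter": "",
    "enable_versioning": False,
    "pipeline_version": "3",
    "datalake_source": "InfluxSource",
    "_measurement": "liveCell",
    "bucket": "UEData",
    "is_mme": True,            # resolve pipeline and feature group from model info
    "model_name": "usecase7"   # required when is_mme is true
}
response = requests.post("http://localhost:32002/trainingjobs/usecase1", json=payload)
print(response.status_code, response.json())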