From 808360a345c7e9f7e2eca33d20748d724f0a42fe Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Thu, 7 Dec 2023 13:41:27 +0530 Subject: [PATCH] editing add_training_job and edit_training_job api based on mme Issue-Id: AIMLFW-65 Change-Id: I5eec767b26657a290bf0acf2489724eabbfd55a8 Signed-off-by: rajdeep11 --- tests/test.env | 3 ++ tests/test_tm_apis.py | 12 +++++--- tests/test_trainingmgr_util.py | 8 ++--- trainingmgr/common/trainingmgr_config.py | 44 +++++++++++++++++++++++++++- trainingmgr/common/trainingmgr_operations.py | 15 ++++++++++ trainingmgr/common/trainingmgr_util.py | 6 ++-- trainingmgr/trainingmgr_main.py | 43 ++++++++++++++++++++++----- 7 files changed, 112 insertions(+), 19 deletions(-) diff --git a/tests/test.env b/tests/test.env index a5c0048..a7ba1e9 100644 --- a/tests/test.env +++ b/tests/test.env @@ -29,3 +29,6 @@ PS_PASSWORD="abcd" PS_IP="localhost" PS_PORT="30001" ACCESS_CONTROL_ALLOW_ORIGIN="http://localhost:32005" +PIPELINE="{'timeseries':'qoe_pipeline'}" +MODEL_MANAGEMENT_SERVICE_IP=localhost +MODEL_MANAGEMENT_SERVICE_PORT=12343 \ No newline at end of file diff --git a/tests/test_tm_apis.py b/tests/test_tm_apis.py index 4d465a4..436acab 100644 --- a/tests/test_tm_apis.py +++ b/tests/test_tm_apis.py @@ -336,7 +336,9 @@ class Test_training_main: "pipeline_version":"3", "datalake_source":"InfluxSource", "_measurement":"liveCell", - "bucket":"UEData" + "bucket":"UEData", + "is_mme":False, + "model_name": "" } expected_data = b'{"result": "Information stored in database."}' response = self.client.post("/trainingjobs/{}".format("usecase1"), @@ -349,9 +351,9 @@ class Test_training_main: db_result = [('usecase1', 'uc1', '*', 'qoe Pipeline lat v2', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "usecase1"}}', '', datetime.datetime(2022, 10, 12, 10, 0, 59, 923588), '51948a12-aee9-42e5-93a0-b8f4a15bca33', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "FINISHED", "TRAINING": "FINISHED", "TRAINING_AND_TRAINED_MODEL": "FINISHED", "TRAINED_MODEL": "FAILED"}', - datetime.datetime(2022, 10, 12, 10, 2, 31, 888830), 1, False, '3', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False)] + datetime.datetime(2022, 10, 12, 10, 2, 31, 888830), 1, False, '3', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False, False, "","")] - training_data = ('','','','','','','','','','','') + training_data = ('','','','','','','','','','','', '','') @patch('trainingmgr.trainingmgr_main.validate_trainingjob_name', return_value = True) @patch('trainingmgr.trainingmgr_main.get_trainingjob_info_by_name', return_value = db_result) @patch('trainingmgr.trainingmgr_main.check_trainingjob_data', return_value = training_data) @@ -373,7 +375,9 @@ class Test_training_main: "pipeline_version":"3", "datalake_source":"InfluxSource", "_measurement":"liveCell", - "bucket":"UEData" + "bucket":"UEData", + "is_mme": False, + "model_name":"" } expected_data = 'Information updated in database' diff --git a/tests/test_trainingmgr_util.py b/tests/test_trainingmgr_util.py index 1c6e9fb..e655c24 100644 --- a/tests/test_trainingmgr_util.py +++ b/tests/test_trainingmgr_util.py @@ -279,16 +279,16 @@ class Test_check_trainingjob_data: @patch('trainingmgr.common.trainingmgr_util.isinstance',return_value=True) def test_check_trainingjob_data(self,mock1,mock2): usecase_name = "usecase8" - json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1"} + json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1", "is_mme":False, "model_name":""} - expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1') + expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1',False, "") assert check_trainingjob_data(usecase_name, json_data) == expected_data,"data not equal" def test_negative_check_trainingjob_data_1(self): usecase_name = "usecase8" - json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1"} + json_data = { "description":"unittest", "featureGroup_name": "group1" , "pipeline_name":"qoe" , "experiment_name":"experiment1" , "arguments":"arguments1" , "query_filter":"query1" , "enable_versioning":True , "target_deployment":"Near RT RIC" , "pipeline_version":1 , "datalake_source":"cassandra db" , "incremental_training":True , "model":"usecase7" , "model_version":1 , "_measurement":2 , "bucket":"bucket1", "is_mme":False, "model_name":""} - expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1') + expected_data = ("group1", 'unittest', 'qoe', 'experiment1', 'arguments1', 'query1', True, 1, 'cassandra db', 2, 'bucket1',False, "") try: assert check_trainingjob_data(usecase_name, json_data) == expected_data,"data not equal" assert False diff --git a/trainingmgr/common/trainingmgr_config.py b/trainingmgr/common/trainingmgr_config.py index 2bcb0d0..dadbba2 100644 --- a/trainingmgr/common/trainingmgr_config.py +++ b/trainingmgr/common/trainingmgr_config.py @@ -46,7 +46,12 @@ class TrainingMgrConfig: self.__ps_password = getenv('PS_PASSWORD').rstrip() if getenv('PS_PASSWORD') is not None else None self.__ps_ip = getenv('PS_IP').rstrip() if getenv('PS_IP') is not None else None self.__ps_port = getenv('PS_PORT').rstrip() if getenv('PS_PORT') is not None else None + + self.__model_management_service_ip = getenv('MODEL_MANAGEMENT_SERVICE_IP').rstrip() if getenv('MODEL_MANAGEMENT_SERVICE_IP') is not None else None + self.__model_management_service_port = getenv('MODEL_MANAGEMENT_SERVICE_PORT').rstrip() if getenv('MODEL_MANAGEMENT_SERVICE_PORT') is not None else None + self.__allow_control_access_origin = getenv('ACCESS_CONTROL_ALLOW_ORIGIN').rstrip() if getenv('ACCESS_CONTROL_ALLOW_ORIGIN') is not None else None + self.__pipeline = getenv('PIPELINE').rstrip() if getenv('PIPELINE') is not None else None self.tmgr_logger = TMLogger("common/conf_log.yaml") self.__logger = self.tmgr_logger.logger @@ -182,6 +187,29 @@ class TrainingMgrConfig: port number where postgres db is accessible """ return self.__ps_port + + @property + def model_management_service_port(self): + """ + Function for getting model management service port + Args:None + + Returns: + string model_management_service_port + """ + return self.__model_management_service_port + + + @property + def model_management_service_ip(self): + """ + Function for getting model management service ip + Args:None + + Returns: + string model_management_service_ip + """ + return self.__model_management_service_ip @property def allow_control_access_origin(self): @@ -196,6 +224,19 @@ class TrainingMgrConfig: """ return self.__allow_control_access_origin + @property + def pipeline(self): + """ + Function for getting pipelines + + Args: None + + Returns: + string pipelines + + """ + return self.__pipeline + def is_config_loaded_properly(self): """ This function checks where all environment variable got value or not. @@ -207,7 +248,8 @@ class TrainingMgrConfig: for var in [self.__kf_adapter_ip, self.__kf_adapter_port, self.__data_extraction_ip, self.__data_extraction_port, self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user, - self.__ps_password, self.__my_ip, self.__allow_control_access_origin, self.__logger]: + self.__ps_password, self.__my_ip,self.__model_management_service_ip, self.__model_management_service_port, + self.__allow_control_access_origin,self.__pipeline, self.__logger]: if var is None: all_present = False return all_present diff --git a/trainingmgr/common/trainingmgr_operations.py b/trainingmgr/common/trainingmgr_operations.py index e7b2bd0..a0a976c 100644 --- a/trainingmgr/common/trainingmgr_operations.py +++ b/trainingmgr/common/trainingmgr_operations.py @@ -25,6 +25,7 @@ import json import requests import validators from trainingmgr.common.exceptions_utls import TMException +from flask_api import status MIMETYPE_JSON = "application/json" @@ -167,3 +168,17 @@ def delete_dme_filtered_data_job(training_config_obj, feature_group_name, host, logger.debug(url) response = requests.delete(url) return response + +def get_model_info(training_config_obj, model_name): + logger = training_config_obj.logger + model_management_service_ip = training_config_obj.model_management_service_ip + model_management_service_port = training_config_obj.model_management_service_port + url ="http://"+str(model_management_service_ip)+":"+str(model_management_service_port)+"/getModelInfo/{}".format(model_name) + response = requests.get(url) + if(response.status_code==status.HTTP_200_OK): + model_info=json.loads(response.json()['message']) + return model_info + else: + errMsg="model info can't be fetched, model_name: {} , err: {}".format(model_name, response.text) + logger.error(errMsg) + raise TMException(errMsg) diff --git a/trainingmgr/common/trainingmgr_util.py b/trainingmgr/common/trainingmgr_util.py index ed1a16c..e3e7054 100644 --- a/trainingmgr/common/trainingmgr_util.py +++ b/trainingmgr/common/trainingmgr_util.py @@ -133,7 +133,7 @@ def check_trainingjob_data(trainingjob_name, json_data): "arguments", "enable_versioning", "datalake_source", "description", "query_filter", "_measurement", - "bucket"], json_data): + "bucket", "is_mme", "model_name"], json_data): description = json_data["description"] feature_list = json_data["featureGroup_name"] @@ -149,6 +149,8 @@ def check_trainingjob_data(trainingjob_name, json_data): datalake_source = json_data["datalake_source"] _measurement = json_data["_measurement"] bucket = json_data["bucket"] + is_mme=json_data["is_mme"] + model_name=json_data["model_name"] else : raise TMException("check_trainingjob_data- supplied data doesn't have" + \ "all the required fields ") @@ -157,7 +159,7 @@ def check_trainingjob_data(trainingjob_name, json_data): str(err)) from None return (feature_list, description, pipeline_name, experiment_name, arguments, query_filter, enable_versioning, pipeline_version, - datalake_source, _measurement, bucket) + datalake_source, _measurement, bucket, is_mme, model_name) def check_feature_group_data(json_data): """ diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 2278d94..224292e 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -33,7 +33,8 @@ import requests from flask_cors import CORS from werkzeug.utils import secure_filename from modelmetricsdk.model_metrics_sdk import ModelMetricsSdk -from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job +from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job, \ +get_model_info from trainingmgr.common.trainingmgr_config import TrainingMgrConfig from trainingmgr.common.trainingmgr_util import get_one_word_status, check_trainingjob_data, \ check_key_in_dictionary, get_one_key, \ @@ -869,6 +870,10 @@ def trainingjob_operations(trainingjob_name): _measurement for influx db datalake bucket: str bucket name for influx db datalake + is_mme: boolean + whether mme is enabled + model_name: str + name of the model Returns: 1. For post request @@ -904,13 +909,31 @@ def trainingjob_operations(trainingjob_name): else: (featuregroup_name, description, pipeline_name, experiment_name, arguments, query_filter, enable_versioning, pipeline_version, - datalake_source, _measurement, bucket) = \ + datalake_source, _measurement, bucket, is_mme, model_name) = \ check_trainingjob_data(trainingjob_name, json_data) + model_info="" + if is_mme: + pipeline_dict =json.loads(TRAININGMGR_CONFIG_OBJ.pipeline) + model_info=get_model_info(TRAININGMGR_CONFIG_OBJ, model_name) + s=model_info["meta-info"]["feature-list"] + model_type=model_info["meta-info"]["model-type"] + try: + pipeline_name=pipeline_dict[str(model_type)] + except Exception as err: + err="Doesn't support the model type" + raise TMException(err) + pipeline_version=pipeline_name + feature_list=','.join(s) + result= get_feature_groups_db(PS_DB_OBJ) + for res in result: + if feature_list==res[1]: + featuregroup_name=res[0] + break add_update_trainingjob(description, pipeline_name, experiment_name, featuregroup_name, arguments, query_filter, True, enable_versioning, pipeline_version, datalake_source, trainingjob_name, PS_DB_OBJ, _measurement=_measurement, - bucket=bucket) + bucket=bucket, is_mme=is_mme, model_name=model_name, model_info=model_info) api_response = {"result": "Information stored in database."} response_code = status.HTTP_201_CREATED elif(request.method == 'PUT'): @@ -930,14 +953,18 @@ def trainingjob_operations(trainingjob_name): not in [States.FAILED.name, States.FINISHED.name]): raise TMException("Trainingjob(" + trainingjob_name + ") is not in finished or failed status") - (featuregroup_name, description, pipeline_name, experiment_name, - arguments, query_filter, enable_versioning, pipeline_version, - datalake_source, _measurement, bucket) = check_trainingjob_data(trainingjob_name, json_data) - + (featuregroup_name, description, pipeline_name, experiment_name, + arguments, query_filter, enable_versioning, pipeline_version, + datalake_source, _measurement, bucket, is_mme, model_name) = check_trainingjob_data(trainingjob_name, json_data) + if is_mme: + featuregroup_name=results[0][2] + pipeline_name, pipeline_version=results[0][3], results[0][13] + # model name is not changing hence model info is unchanged. + model_info = results[0][22] add_update_trainingjob(description, pipeline_name, experiment_name, featuregroup_name, arguments, query_filter, False, enable_versioning, pipeline_version, datalake_source, trainingjob_name, PS_DB_OBJ, _measurement=_measurement, - bucket=bucket) + bucket=bucket, is_mme=is_mme, model_name=model_name, model_info=model_info) api_response = {"result": "Information updated in database."} response_code = status.HTTP_200_OK except Exception as err: -- 2.16.6