changes in the start training 14/13714/4
authorrajdeep11 <rajdeep.sin@samsung.com>
Wed, 30 Oct 2024 06:15:03 +0000 (11:45 +0530)
committersubhash kumar singh <subh.singh@samsung.com>
Wed, 30 Oct 2024 10:08:30 +0000 (10:08 +0000)
Change-Id: I73676734367b89016697ba21c2cdc46864dd6d2f
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
trainingmgr/common/trainingmgr_util.py
trainingmgr/trainingmgr_main.py

index f6cbaa8..4b8879b 100644 (file)
@@ -27,13 +27,14 @@ import requests
 from marshmallow import ValidationError
 from trainingmgr.db.common_db_fun import change_in_progress_to_failed_by_latest_version, \
     get_field_by_latest_version, change_field_of_latest_version, \
-    get_latest_version_trainingjob_name, get_all_versions_info_by_name
+    get_latest_version_trainingjob_name
 from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, get_feature_groups_db, \
 get_feature_group_by_name_db, delete_feature_group_by_name
 from trainingmgr.constants.states import States
 from trainingmgr.common.exceptions_utls import APIException,TMException,DBException
 from trainingmgr.common.trainingmgr_operations import create_dme_filtered_data_job
 from trainingmgr.schemas import ma, TrainingJobSchema , FeatureGroupSchema
+from trainingmgr.db.trainingjob_db import get_all_versions_info_by_name
 
 ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response"
 MIMETYPE_JSON = "application/json"
@@ -331,7 +332,7 @@ def handle_async_feature_engineering_status_exception_case(lock, dataextraction_
             except KeyError as key_err:
                 logger.error("The training job key doesn't exist in DATAEXTRACTION_JOBS_CACHE: " + str(key_err))
 
-def validate_trainingjob_name(trainingjob_name, ps_db_obj):
+def validate_trainingjob_name(trainingjob_name):
     """
     This function returns True if given trainingjob_name exists in db otherwise
     it returns False.
@@ -343,13 +344,13 @@ def validate_trainingjob_name(trainingjob_name, ps_db_obj):
         raise TMException("The name of training job is invalid.")
 
     try:
-        results = get_all_versions_info_by_name(trainingjob_name, ps_db_obj)
+        results = get_all_versions_info_by_name(trainingjob_name)
     except Exception as err:
         errmsg = str(err)
         raise DBException("Could not get info from db for " + trainingjob_name + "," + errmsg)
     if results:
         isavailable = True
-    return isavailable    
+    return isavailable     
 
 def get_pipelines_details(training_config_obj):
     logger=training_config_obj.logger
index 2dc12a2..a8859ea 100644 (file)
@@ -51,9 +51,9 @@ from trainingmgr.db.trainingmgr_ps_db import PSDB
 from trainingmgr.common.exceptions_utls import DBException
 from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \
     change_field_of_latest_version, \
-    change_in_progress_to_failed_by_latest_version, change_steps_state_of_latest_version, \
+    change_in_progress_to_failed_by_latest_version, \
     get_info_by_version, \
-    get_trainingjob_info_by_name, get_latest_version_trainingjob_name, get_all_versions_info_by_name, \
+    get_latest_version_trainingjob_name, get_all_versions_info_by_name, \
     update_model_download_url, \
     get_field_of_given_version,get_all_jobs_latest_status_version, get_info_of_latest_version, \
     delete_trainingjob_version, change_field_value_by_version
@@ -341,27 +341,26 @@ def training(trainingjob_name):
         return {"Exception":"The trainingjob_name is not correct"}, status.HTTP_400_BAD_REQUEST
     LOGGER.debug("Request for training trainingjob  %s ", trainingjob_name)
     try:
-        isDataAvaible = validate_trainingjob_name(trainingjob_name, PS_DB_OBJ)
+        isDataAvaible = validate_trainingjob_name(trainingjob_name)
         if not isDataAvaible:
             response_code = status.HTTP_404_NOT_FOUND
             raise TMException("Given trainingjob name is not present in database" + \
                         "(trainingjob: " + trainingjob_name + ")") from None
         else:
 
-            db_results = get_trainingjob_info_by_name(trainingjob_name, PS_DB_OBJ)
-            featuregroup_name = db_results[0][2]
-            result= get_feature_group_by_name_db(PS_DB_OBJ, featuregroup_name)
-            feature_list_string = result[0][1]
+            trainingjob = get_trainingjob_info_by_name(trainingjob_name)
+            featuregroup= get_feature_group_by_name_db(trainingjob.feature_group_name)
+            feature_list_string = featuregroup.feature_list
             influxdb_info_dic={}
-            influxdb_info_dic["host"]=result[0][3]
-            influxdb_info_dic["port"]=result[0][4]
-            influxdb_info_dic["bucket"]=result[0][5]
-            influxdb_info_dic["token"]=result[0][6]
-            influxdb_info_dic["db_org"] = result[0][7]
-            influxdb_info_dic["source_name"]= result[0][11]
-            query_filter = db_results[0][6]
-            datalake_source = json.loads(db_results[0][14])['datalake_source']
-            _measurement = result[0][8]
+            influxdb_info_dic["host"]=featuregroup.host
+            influxdb_info_dic["port"]=featuregroup.port
+            influxdb_info_dic["bucket"]=featuregroup.bucket
+            influxdb_info_dic["token"]=featuregroup.token
+            influxdb_info_dic["db_org"] = featuregroup.db_org
+            influxdb_info_dic["source_name"]= featuregroup.source_name
+            _measurement = featuregroup.measurement
+            query_filter = trainingjob.query_filter
+            datalake_source = json.loads(trainingjob.datalake_source)['datalake_source']
             LOGGER.debug('Starting Data Extraction...')
             de_response = data_extraction_start(TRAININGMGR_CONFIG_OBJ, trainingjob_name,
                                          feature_list_string, query_filter, datalake_source,
@@ -369,7 +368,7 @@ def training(trainingjob_name):
             if (de_response.status_code == status.HTTP_200_OK ):
                 LOGGER.debug("Response from data extraction for " + \
                         trainingjob_name + " : " + json.dumps(de_response.json()))
-                change_steps_state_of_latest_version(trainingjob_name, PS_DB_OBJ,
+                change_steps_state_of_latest_version(trainingjob_name,
                                                     Steps.DATA_EXTRACTION.name,
                                                     States.IN_PROGRESS.name)
                 with LOCK:
@@ -393,6 +392,7 @@ def training(trainingjob_name):
         LOGGER.debug("Error is training, job name:" + trainingjob_name + str(err))         
     return APP.response_class(response=json.dumps(response_data),status=response_code,
                             mimetype=MIMETYPE_JSON)
+
 @APP.route('/trainingjob/dataExtractionNotification', methods=['POST'])
 def data_extraction_notification():
     """