adding edit and retrain feature 89/10489/3
author	rajdeep11 <rajdeep.sin@samsung.com>
Tue, 21 Feb 2023 08:56:44 +0000 (14:26 +0530)
committer	rajdeep11 <rajdeep.sin@samsung.com>
Mon, 27 Feb 2023 10:28:16 +0000 (15:58 +0530)
Issue-Id: AIMLFW-15

Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
Change-Id: Ia121004986bc778f6407aefa77b2aa2331b0aa89
trainingmgr/db/common_db_fun.py
trainingmgr/trainingmgr_main.py

trainingmgr/db/common_db_fun.py
index 1389d81..329138c 100644 (file)
@@ -588,3 +588,26 @@ def get_all_jobs_latest_status_version(ps_db_obj):
                 conn.close()
     return results
 
+def get_info_of_latest_version(trainingjob_name, ps_db_obj):
+    """
+    This function returns information of <trainingjob_name, trainingjob_name trainingjob's latest version>
+    usecase.
+    """
+
+    conn = ps_db_obj.get_new_conn()
+    cursor = conn.cursor()
+    try:
+        cursor.execute('''select nt.mv from (select max(version) mv,trainingjob_name ''' + \
+                       '''from trainingjob_info group by trainingjob_name) nt where nt.trainingjob_name=%s''',
+                       (trainingjob_name,))
+        version = int(cursor.fetchall()[0][0])
+        cursor.execute(''' select * from trainingjob_info where trainingjob_name=%s and version = %s''',
+                       (trainingjob_name, version))
+    except Exception as err:
+        conn.rollback()
+        conn.close()
+        raise err
+    results = cursor.fetchall()
+    conn.commit()
+    conn.close()
+    return results
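
As an aside, a minimal usage sketch for this helper (the trainingjob name below is invented, PS_DB_OBJ stands for the module-level postgres wrapper object that trainingmgr_main.py passes in, and the column index mirrors the ones the retraining endpoint reads; none of these are defined by this hunk):

    # Illustrative sketch only: assumes a configured postgres wrapper (PS_DB_OBJ) is available.
    from trainingmgr.db.common_db_fun import get_info_of_latest_version

    try:
        rows = get_info_of_latest_version("qoe_trainingjob", PS_DB_OBJ)  # hypothetical job name
    except Exception:
        # The helper re-raises DB errors and also raises when the trainingjob has no versions.
        rows = []

    if rows:
        latest = rows[0]           # one row describing the newest version of the trainingjob
        pipeline_name = latest[3]  # column position assumed to match what retraining() reads
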
trainingmgr/trainingmgr_main.py
index 2c65cc4..dc918f7 100644 (file)
@@ -50,7 +50,7 @@ from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainin
     get_info_by_version, \
     get_trainingjob_info_by_name, get_latest_version_trainingjob_name, get_all_versions_info_by_name, \
     update_model_download_url, add_update_trainingjob, \
-    get_field_of_given_version,get_all_jobs_latest_status_version
+    get_field_of_given_version, get_all_jobs_latest_status_version, get_info_of_latest_version
 
 APP = Flask(__name__)
 TRAININGMGR_CONFIG_OBJ = None
@@ -937,6 +937,135 @@ def trainingjob_operations(trainingjob_name):
                     status= response_code,
                     mimetype=MIMETYPE_JSON)
 
+@APP.route('/trainingjobs/retraining', methods=['POST'])
+@cross_origin()
+def retraining():
+    """
+    Function handling rest endpoint to retrain trainingjobs in request json. trainingjob's
+    overall_status should be failed or finished and its deletion_in_progress should be False
+    otherwise retraining of that trainingjob is counted in failure.
+    Args in function: none
+    Required Args in json:
+        trainingjobs_list: list
+                       list containing dictionaries
+                           dictionary contains
+                               usecase_name: str
+                                   name of trainingjob
+                               notification_url(optional): str
+                                   url for notification
+                               feature_filter(optional): str
+                                   feature filter
+    Returns:
+        json:
+            success count: int
+                successful retraining count
+            failure count: int
+                failure retraining count
+        status: HTTP status code 200
+    Exceptions:
+        all exception are provided with exception message and HTTP status code.
+    """
+    LOGGER.debug('Retraining request received: ' + json.dumps(request.json))
+    try:
+        check_key_in_dictionary(["trainingjobs_list"], request.json)
+    except Exception as err:
+        raise APIException(status.HTTP_400_BAD_REQUEST, str(err)) from None
+
+    trainingjobs_list = request.json['trainingjobs_list']
+    LOGGER.debug("trainingjobs_list: " + str(trainingjobs_list))
+    if not isinstance(trainingjobs_list, list):
+        raise APIException(status.HTTP_400_BAD_REQUEST, "trainingjobs_list must be a list")
+
+    for obj in trainingjobs_list:
+        try:
+            check_key_in_dictionary(["trainingjob_name"], obj)
+        except Exception as err:
+            raise APIException(status.HTTP_400_BAD_REQUEST, str(err)) from None
+
+    not_possible_to_retrain = []
+    possible_to_retrain = []
+
+    for obj in trainingjobs_list:
+        trainingjob_name = obj['trainingjob_name']
+        results = None
+        try:
+            results = get_info_of_latest_version(trainingjob_name, PS_DB_OBJ)
+        except Exception as err:
+            not_possible_to_retrain.append(trainingjob_name)
+            LOGGER.debug(str(err) + "(trainingjob_name is " + trainingjob_name + ")")
+            continue
+
+        if results:
+
+            if results[0][19]:  # column 19: deletion_in_progress flag
+                not_possible_to_retrain.append(trainingjob_name)
+                LOGGER.debug("Failed to retrain because deletion is in progress " + \
+                             "(trainingjob_name is " + trainingjob_name + ")")
+                continue
+
+            if (get_one_word_status(json.loads(results[0][9]))  # column 9: status json used to derive overall status
+                    not in [States.FINISHED.name, States.FAILED.name]):
+                not_possible_to_retrain.append(trainingjob_name)
+                LOGGER.debug("Status is neither FINISHED nor FAILED " + \
+                             "(trainingjob_name is " + trainingjob_name + ")")
+                continue
+
+            enable_versioning = results[0][12]
+            pipeline_version = results[0][13]
+            description = results[0][1]
+            pipeline_name = results[0][3]
+            experiment_name = results[0][4]
+            feature_list = results[0][2]
+            arguments = json.loads(results[0][5])['arguments']
+            query_filter = results[0][6]
+            datalake_source = get_one_key(json.loads(results[0][14])["datalake_source"])
+            _measurement = results[0][17]
+            bucket = results[0][18]
+
+            notification_url = ""
+            if "notification_url" in obj:
+                notification_url = obj['notification_url']
+
+            if "feature_filter" in obj:
+                query_filter = obj['feature_filter']
+
+            try:
+                add_update_trainingjob(description, pipeline_name, experiment_name,
+                                      feature_list, arguments, query_filter, False,
+                                      enable_versioning, pipeline_version,
+                                      datalake_source, trainingjob_name, PS_DB_OBJ,
+                                      notification_url, _measurement, bucket)
+            except Exception as err:
+                not_possible_to_retrain.append(trainingjob_name)
+                LOGGER.debug(str(err) + "(usecase_name is " + trainingjob_name + ")")
+                continue
+
+            url = 'http://' + str(TRAININGMGR_CONFIG_OBJ.my_ip) + \
+                  ':' + str(TRAININGMGR_CONFIG_OBJ.my_port) + \
+                  '/trainingjobs/' + trainingjob_name + '/training'
+            response = requests.post(url)
+
+            if response.status_code == status.HTTP_200_OK:
+                possible_to_retrain.append(trainingjob_name)
+            else:
+                LOGGER.debug("not 200 response" + "(trainingjob_name is " + trainingjob_name + ")")
+                not_possible_to_retrain.append(trainingjob_name)
+
+        else:
+            LOGGER.debug("not present in postgres db" + "(trainingjob_name is " + trainingjob_name + ")")
+            not_possible_to_retrain.append(trainingjob_name)
+
+    LOGGER.debug('success list: ' + str(possible_to_retrain))
+    LOGGER.debug('failure list: ' + str(not_possible_to_retrain))
+
+    return APP.response_class(response=json.dumps(
+        {
+            "success count": len(possible_to_retrain),
+            "failure count": len(not_possible_to_retrain)
+        }),
+        status=status.HTTP_200_OK,
+        mimetype=MIMETYPE_JSON)
+
 @APP.route('/trainingjobs/metadata/<trainingjob_name>')
 def get_metadata(trainingjob_name):
     """