minor fixes for issues reported in sonar 67/10367/1
authorrajdeep11 <rajdeep.sin@samsung.com>
Thu, 2 Feb 2023 12:07:03 +0000 (17:37 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Thu, 2 Feb 2023 12:08:36 +0000 (17:38 +0530)
Issue-Id: AIMLFW-22

Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
Change-Id: I7f2c03ec31f91a3e56191b4df75a5267842b09a1

tests/test_tm_apis.py
tests/test_trainingmgr_util.py
trainingmgr/common/trainingmgr_util.py
trainingmgr/trainingmgr_main.py

index df261bc..938be4d 100644 (file)
@@ -454,9 +454,13 @@ class Test_get_versions_for_pipeline:
     mocked_TRAININGMGR_CONFIG_OBJ=mock.Mock(name="TRAININGMGR_CONFIG_OBJ")
     attrs_TRAININGMGR_CONFIG_OBJ = {'kf_adapter_ip.return_value': '123', 'kf_adapter_port.return_value' : '100'}
     mocked_TRAININGMGR_CONFIG_OBJ.configure_mock(**attrs_TRAININGMGR_CONFIG_OBJ)
+    
     @patch('trainingmgr.trainingmgr_main.TRAININGMGR_CONFIG_OBJ', return_value = mocked_TRAININGMGR_CONFIG_OBJ)
     @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response)
-    def test_get_versions_for_pipeline_positive(self,mock1,mock2):
+    @patch('trainingmgr.trainingmgr_main.get_all_pipeline_names_svc', return_value=[
+               "qoe_pipeline"
+       ])
+    def test_get_versions_for_pipeline_positive(self,mock1,mock2, mock3):
         response = self.client.get("/pipelines/{}/versions".format("qoe_pipeline"))     
         trainingmgr_main.LOGGER.debug(response.data)
         assert response.content_type == "application/json", "not equal content type"
index 3721fed..2700466 100644 (file)
@@ -21,6 +21,7 @@ This file contains the unittesting for Training management utility functions
 """
 from pickle import FALSE
 import sys
+from unittest import mock
 from mock import patch
 from threading import Lock
 import pytest
@@ -36,7 +37,7 @@ from trainingmgr.common.tmgr_logger import TMLogger
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 from trainingmgr.common.trainingmgr_util import response_for_training, check_key_in_dictionary,check_trainingjob_data, \
     get_one_key, get_metrics, handle_async_feature_engineering_status_exception_case, get_one_word_status, check_trainingjob_data, \
-    validate_trainingjob_name
+    validate_trainingjob_name, get_all_pipeline_names_svc
 from requests.models import Response   
 from trainingmgr import trainingmgr_main
 from trainingmgr.common.tmgr_logger import TMLogger
@@ -638,3 +639,26 @@ class Test_validate_trainingjob_name:
             assert False
         except Exception:
             assert True
+
+class Test_get_all_pipeline_names_svc:
+    # testing the get_all_pipeline service
+    def setup_method(self):
+        self.client = trainingmgr_main.APP.test_client(self)
+        self.logger = trainingmgr_main.LOGGER
+    
+    the_response = Response()
+    the_response.code = "expired"
+    the_response.error_type = "expired"
+    the_response.status_code = 200
+    the_response.headers={"content-type": "application/json"}
+    the_response._content = b'{ "qoe_Pipeline":"id1"}'
+
+    mocked_TRAININGMGR_CONFIG_OBJ=mock.Mock(name="TRAININGMGR_CONFIG_OBJ")
+    attrs_TRAININGMGR_CONFIG_OBJ = {'kf_adapter_ip.return_value': '123', 'kf_adapter_port.return_value' : '100'}
+    mocked_TRAININGMGR_CONFIG_OBJ.configure_mock(**attrs_TRAININGMGR_CONFIG_OBJ)
+    
+    @patch('trainingmgr.trainingmgr_main.TRAININGMGR_CONFIG_OBJ', return_value = mocked_TRAININGMGR_CONFIG_OBJ)
+    @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response)
+    def test_get_all_pipeline_names(self,mock1, mock2):
+        expected_data=['qoe_Pipeline']
+        assert get_all_pipeline_names_svc(self.mocked_TRAININGMGR_CONFIG_OBJ) ==expected_data, "Not equal"
\ No newline at end of file
index 38474b0..9ff1ad9 100644 (file)
@@ -25,10 +25,12 @@ import requests
 from trainingmgr.db.common_db_fun import change_in_progress_to_failed_by_latest_version, \
     get_field_by_latest_version, change_field_of_latest_version, \
     get_latest_version_trainingjob_name,get_all_versions_info_by_name
-
 from trainingmgr.constants.states import States
 from trainingmgr.common.exceptions_utls import APIException,TMException,DBException
 
+ERROR_TYPE_KF_ADAPTER_JSON = "Kf adapter doesn't sends json type response"
+MIMETYPE_JSON = "application/json"
+
 def response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk):
     """
     Post training job completion,this function provides notifications to the subscribers, 
@@ -221,3 +223,26 @@ def validate_trainingjob_name(trainingjob_name, ps_db_obj):
     if results:
         isavailable = True
     return isavailable    
+
+def get_all_pipeline_names_svc(training_config_obj):
+    # This function returns all the pipeline names 
+
+    pipeline_names = []
+    logger=training_config_obj.logger
+    try:
+        kf_adapter_ip = training_config_obj.kf_adapter_ip
+        kf_adapter_port = training_config_obj.kf_adapter_port
+        if kf_adapter_ip!=None and kf_adapter_port!=None:
+            url = 'http://' + str(kf_adapter_ip) + ':' + str(kf_adapter_port) + '/pipelines'
+        logger.debug(url)
+        response = requests.get(url)
+        if response.headers['content-type'] != MIMETYPE_JSON:
+            err_smg = ERROR_TYPE_KF_ADAPTER_JSON
+            logger.error(err_smg)
+            raise TMException(err_smg)
+        for pipeline in response.json().keys():
+            pipeline_names.append(pipeline)
+    except Exception as err:
+        logger.error(str(err))
+    logger.debug(pipeline_names)
+    return pipeline_names
\ No newline at end of file
index 4c2e082..2c65cc4 100644 (file)
@@ -20,6 +20,7 @@
 This file contains all rest endpoints exposed by Training manager.
 """
 import json
+import re
 from logging import Logger
 import os
 import traceback
@@ -38,7 +39,7 @@ from trainingmgr.common.trainingmgr_util import get_one_word_status, check_train
     check_key_in_dictionary, get_one_key, \
     response_for_training, get_metrics, \
     handle_async_feature_engineering_status_exception_case, \
-    validate_trainingjob_name
+    validate_trainingjob_name, get_all_pipeline_names_svc
 from trainingmgr.common.exceptions_utls import APIException,TMException
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
@@ -621,7 +622,10 @@ def upload_pipeline(pipe_name):
         else:
             result_string = "Didn't get file"
             raise ValueError("file not found in request.files")
-
+        pattern = re.compile(r"[a-zA-Z0-9_]+")
+        if not re.fullmatch(pattern, pipe_name):
+            err_msg="the pipeline name is not valid"
+            raise TMException(err_msg)
         LOGGER.debug("Uploading received for %s", uploaded_file.filename)
         if uploaded_file.filename != '':
             uploaded_file_path = "/tmp/" + secure_filename(uploaded_file.filename)
@@ -658,11 +662,16 @@ def upload_pipeline(pipe_name):
         LOGGER.error(tbk)
         result_code = status.HTTP_500_INTERNAL_SERVER_ERROR
         result_string = "Error while uploading pipeline"
+    except TMException:
+        tbk = traceback.format_exc()
+        LOGGER.error(tbk)
+        result_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        result_string = "Pipeline name is not of valid format"
     except Exception:
         tbk = traceback.format_exc()
         LOGGER.error(tbk)
         result_code = status.HTTP_500_INTERNAL_SERVER_ERROR
-        result_string = "Error while uploading pipeline"
+        result_string = "Error while uploading pipeline cause"
 
     if uploaded_file_path and os.path.isfile(uploaded_file_path):
         LOGGER.debug("Deleting %s", uploaded_file_path)
@@ -698,15 +707,24 @@ def get_versions_for_pipeline(pipeline_name):
     Exceptions:
         all exception are provided with exception message and HTTP status code.
     """
+    valid_pipeline=""
+    api_response = {}            
     LOGGER.debug("Request to get all version for given pipeline(" + pipeline_name + ").")
-    api_response = {}
     response_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     try:
+        pipeline_names=get_all_pipeline_names_svc(TRAININGMGR_CONFIG_OBJ)
+        print(pipeline_names, pipeline_name)
+        for pipeline in pipeline_names:
+            if pipeline == pipeline_name:
+                valid_pipeline=pipeline
+                break
+        if valid_pipeline == "":
+            raise TMException("Pipeline name not present")
         kf_adapter_ip = TRAININGMGR_CONFIG_OBJ.kf_adapter_ip
         kf_adapter_port = TRAININGMGR_CONFIG_OBJ.kf_adapter_port
         if kf_adapter_ip!=None and kf_adapter_port!=None :
           url = 'http://' + str(kf_adapter_ip) + ':' + str(
-            kf_adapter_port) + '/pipelines/' + pipeline_name + \
+            kf_adapter_port) + '/pipelines/' + valid_pipeline + \
             '/versions'
         LOGGER.debug("URL:" + url)
         response = requests.get(url)
@@ -720,8 +738,7 @@ def get_versions_for_pipeline(pipeline_name):
     return APP.response_class(response=json.dumps(api_response),
             status=response_code,
             mimetype=MIMETYPE_JSON)
-     
-
 @APP.route('/pipelines', methods=['GET'])
 @cross_origin()
 def get_all_pipeline_names():
@@ -745,33 +762,16 @@ def get_all_pipeline_names():
         all exception are provided with exception message and HTTP status code.
     """
     LOGGER.debug("Request to get all getting all pipeline names.")
-    response = None
     api_response = {}
     response_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     try:
-        kf_adapter_ip = TRAININGMGR_CONFIG_OBJ.kf_adapter_ip
-        kf_adapter_port = TRAININGMGR_CONFIG_OBJ.kf_adapter_port
-        if kf_adapter_ip!=None and kf_adapter_port!=None:
-            url = 'http://' + str(kf_adapter_ip) + ':' + str(kf_adapter_port) + '/pipelines'
-        LOGGER.debug(url)
-        response = requests.get(url)
-        if response.headers['content-type'] != MIMETYPE_JSON:
-            err_smg = ERROR_TYPE_KF_ADAPTER_JSON
-            LOGGER.error(err_smg)
-            raise TMException(err_smg)
-        pipeline_names = []
-        for pipeline in response.json().keys():
-            pipeline_names.append(pipeline)
-
+        pipeline_names=get_all_pipeline_names_svc(TRAININGMGR_CONFIG_OBJ)
         api_response = {"pipeline_names": pipeline_names}
-        response_code = status.HTTP_200_OK 
+        response_code = status.HTTP_200_OK
     except Exception as err:
         LOGGER.error(str(err))
         api_response =  {"Exception": str(err)}
-    return APP.response_class(response=json.dumps(api_response),
-                                    status=response_code,
-                                    mimetype=MIMETYPE_JSON)
-
+    return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON)
 
 @APP.route('/experiments', methods=['GET'])
 @cross_origin()