Return details of pipelines instead of their names. 37/13437/2
authorTaewan Kim <t25.kim@samsung.com>
Wed, 25 Sep 2024 12:59:14 +0000 (21:59 +0900)
committerTaewan Kim <t25.kim@samsung.com>
Wed, 25 Sep 2024 13:04:53 +0000 (22:04 +0900)
It relays the return value from kubeflow-adapter.

Issue-ID: AIMLFW-146

Change-Id: Ib223e363a810ddb5861e83ab7ac75ac5c50e8c97
Signed-off-by: Taewan Kim <t25.kim@samsung.com>
tests/test_tm_apis.py
tests/test_trainingmgr_util.py
trainingmgr/common/trainingmgr_util.py
trainingmgr/trainingmgr_main.py

index 4e0dea2..d03eb8b 100644 (file)
@@ -676,11 +676,11 @@ class Test_get_versions_for_pipeline:
     
     @patch('trainingmgr.trainingmgr_main.TRAININGMGR_CONFIG_OBJ', return_value = mocked_TRAININGMGR_CONFIG_OBJ)
     @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response)
-    @patch('trainingmgr.trainingmgr_main.get_all_pipeline_names_svc', return_value=[
-               "qoe_pipeline"
-       ])
+    @patch('trainingmgr.trainingmgr_main.get_pipelines_details', return_value=
+            {"next_page_token":"next-page-token","pipelines":[{"created_at":"created-at","description":"pipeline-description","display_name":"pipeline-name","pipeline_id":"pipeline-id"}],"total_size":"total-size"}
+       )
     def test_get_versions_for_pipeline_positive(self,mock1,mock2, mock3):
-        response = self.client.get("/pipelines/{}/versions".format("qoe_pipeline"))     
+        response = self.client.get("/pipelines/{}/versions".format("pipeline-name"))
         trainingmgr_main.LOGGER.debug(response.data)
         assert response.content_type == "application/json", "not equal content type"
         assert response.status_code == 200, "Return status code NOT equal"   
@@ -719,7 +719,7 @@ class Test_get_versions_for_pipeline:
         print(response.data)
         assert response.content_type != "application/text", "not equal content type"
     
-class Test_get_all_pipeline_names:
+class Test_get_pipelines_details:
     def setup_method(self):
         self.client = trainingmgr_main.APP.test_client(self)
         self.logger = trainingmgr_main.LOGGER
@@ -731,20 +731,20 @@ class Test_get_all_pipeline_names:
     the_response.headers={"content-type": "application/json"}
     the_response._content = b'{ "exp1":"id1","exp2":"id2"}'
     @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response)
-    def test_get_all_pipeline_names(self,mock1):
+    def test_get_pipelines_details(self,mock1):
         response = self.client.get("/pipelines")      
         assert response.content_type == "application/json", "not equal content type"
         assert response.status_code == 500, "Return status code NOT equal"   
         
     @patch('trainingmgr.trainingmgr_main.requests.get', side_effect = requests.exceptions.ConnectionError('Mocked error'))
-    def test_negative_get_all_pipeline_names_1(self,mock1):
+    def test_negative_get_pipelines_details_1(self,mock1):
         response = self.client.get("/pipelines")       
         print(response.data)
         assert response.content_type == "application/json", "not equal content type"
         assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR, "Should have thrown the exception "
         
     @patch('trainingmgr.trainingmgr_main.requests.get', side_effect = TypeError('Mocked error'))
-    def test_negative_get_all_pipeline_names_2(self,mock1):
+    def test_negative_get_pipelines_details_2(self,mock1):
         response = self.client.get("/pipelines")       
         print(response.data)
         assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR, "Should have thrown the exception "
@@ -756,7 +756,7 @@ class Test_get_all_pipeline_names:
     the_response1.headers={"content-type": "application/text"}
     the_response1._content = b'{ "exp1":"id1","exp2":"id2"}'
     @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response1)
-    def test_negative_get_all_pipeline_names_3(self,mock1):
+    def test_negative_get_pipelines_details_3(self,mock1):
         response = self.client.get("/pipelines")       
         print(response.data)
         assert response.content_type != "application/text", "not equal content type"
index bfe6cd8..9884a87 100644 (file)
@@ -37,7 +37,7 @@ from trainingmgr.common.tmgr_logger import TMLogger
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
 from trainingmgr.common.trainingmgr_util import response_for_training, check_key_in_dictionary,check_trainingjob_data, \
     get_one_key, get_metrics, handle_async_feature_engineering_status_exception_case, get_one_word_status, check_trainingjob_data, \
-    validate_trainingjob_name, get_all_pipeline_names_svc, check_feature_group_data, get_feature_group_by_name, edit_feature_group_by_name
+    validate_trainingjob_name, get_pipelines_details, check_feature_group_data, get_feature_group_by_name, edit_feature_group_by_name
 from requests.models import Response   
 from trainingmgr import trainingmgr_main
 from trainingmgr.common.tmgr_logger import TMLogger
@@ -528,7 +528,7 @@ class Test_validate_trainingjob_name:
         except TMException as err:
             assert str(err) == "The name of training job is invalid."
 
-class Test_get_all_pipeline_names_svc:
+class Test_get_pipelines_details:
     # testing the get_all_pipeline service
     def setup_method(self):
         self.client = trainingmgr_main.APP.test_client(self)
@@ -539,7 +539,7 @@ class Test_get_all_pipeline_names_svc:
     the_response.error_type = "expired"
     the_response.status_code = 200
     the_response.headers={"content-type": "application/json"}
-    the_response._content = b'{ "qoe_Pipeline":"id1"}'
+    the_response._content = b'{"next_page_token":"next-page-token","pipelines":[{"created_at":"created-at","description":"pipeline-description","display_name":"pipeline-name","pipeline_id":"pipeline-id"}],"total_size":"total-size"}'
 
     mocked_TRAININGMGR_CONFIG_OBJ=mock.Mock(name="TRAININGMGR_CONFIG_OBJ")
     attrs_TRAININGMGR_CONFIG_OBJ = {'kf_adapter_ip.return_value': '123', 'kf_adapter_port.return_value' : '100'}
@@ -547,9 +547,9 @@ class Test_get_all_pipeline_names_svc:
     
     @patch('trainingmgr.trainingmgr_main.TRAININGMGR_CONFIG_OBJ', return_value = mocked_TRAININGMGR_CONFIG_OBJ)
     @patch('trainingmgr.trainingmgr_main.requests.get', return_value = the_response)
-    def test_get_all_pipeline_names(self,mock1, mock2):
-        expected_data=['qoe_Pipeline']
-        assert get_all_pipeline_names_svc(self.mocked_TRAININGMGR_CONFIG_OBJ) ==expected_data, "Not equal"
+    def test_get_pipelines_details(self,mock1, mock2):
+        expected_data="next-page-token"
+        assert get_pipelines_details(self.mocked_TRAININGMGR_CONFIG_OBJ)["next_page_token"] == expected_data, "Not equal"
 
 class Test_check_feature_group_data:
     @patch('trainingmgr.common.trainingmgr_util.check_key_in_dictionary',return_value=True)
@@ -772,4 +772,4 @@ class Test_edit_feature_group_by_name:
         assert status_code == 400, "status code is not equal"
         assert json_data == expected_data, json_data
 
-    # TODO: Test Code in the case where DME is edited from enabled to disabled)
\ No newline at end of file
+    # TODO: Test Code in the case where DME is edited from enabled to disabled)
index e350682..21247bf 100644 (file)
@@ -370,10 +370,7 @@ def validate_trainingjob_name(trainingjob_name, ps_db_obj):
         isavailable = True
     return isavailable    
 
-def get_all_pipeline_names_svc(training_config_obj):
-    # This function returns all the pipeline names 
-
-    pipeline_names = []
+def get_pipelines_details(training_config_obj):
     logger=training_config_obj.logger
     try:
         kf_adapter_ip = training_config_obj.kf_adapter_ip
@@ -386,12 +383,9 @@ def get_all_pipeline_names_svc(training_config_obj):
             err_smg = ERROR_TYPE_KF_ADAPTER_JSON
             logger.error(err_smg)
             raise TMException(err_smg)
-        for pipeline in response.json().keys():
-            pipeline_names.append(pipeline)
     except Exception as err:
         logger.error(str(err))
-    logger.debug(pipeline_names)
-    return pipeline_names
+    return response.json()
 
 def check_trainingjob_name_and_version(trainingjob_name, version):
     if (re.fullmatch(PATTERN, trainingjob_name) and version.isnumeric()):
index a600e70..b62fd29 100644 (file)
@@ -40,7 +40,7 @@ from trainingmgr.common.trainingmgr_util import get_one_word_status, check_train
     check_key_in_dictionary, get_one_key, \
     response_for_training, get_metrics, \
     handle_async_feature_engineering_status_exception_case, \
-    validate_trainingjob_name, get_all_pipeline_names_svc, check_feature_group_data, check_trainingjob_name_and_version, check_trainingjob_name_or_featuregroup_name, \
+    validate_trainingjob_name, get_pipelines_details, check_feature_group_data, check_trainingjob_name_and_version, check_trainingjob_name_or_featuregroup_name, \
     get_feature_group_by_name, edit_feature_group_by_name
 from trainingmgr.common.exceptions_utls import APIException,TMException
 from trainingmgr.constants.steps import Steps
@@ -743,11 +743,10 @@ def get_versions_for_pipeline(pipeline_name):
     LOGGER.debug("Request to get all version for given pipeline(" + pipeline_name + ").")
     response_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     try:
-        pipeline_names=get_all_pipeline_names_svc(TRAININGMGR_CONFIG_OBJ)
-        print(pipeline_names, pipeline_name)
-        for pipeline in pipeline_names:
-            if pipeline == pipeline_name:
-                valid_pipeline=pipeline
+        pipelines = get_pipelines_details(TRAININGMGR_CONFIG_OBJ)
+        for pipeline in pipelines['pipelines']:
+            if pipeline['display_name'] == pipeline_name:
+                valid_pipeline = pipeline['display_name']
                 break
         if valid_pipeline == "":
             raise TMException("Pipeline name not present")
@@ -771,7 +770,7 @@ def get_versions_for_pipeline(pipeline_name):
             mimetype=MIMETYPE_JSON)
  
 @APP.route('/pipelines', methods=['GET'])
-def get_all_pipeline_names():
+def get_pipelines():
     """
     Function handling rest endpoint to get all pipeline names.
 
@@ -795,8 +794,8 @@ def get_all_pipeline_names():
     api_response = {}
     response_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     try:
-        pipeline_names=get_all_pipeline_names_svc(TRAININGMGR_CONFIG_OBJ)
-        api_response = {"pipeline_names": pipeline_names}
+        pipelines = get_pipelines_details(TRAININGMGR_CONFIG_OBJ)
+        api_response = pipelines
         response_code = status.HTTP_200_OK
     except Exception as err:
         LOGGER.error(str(err))