Fix UT for trainingMgrUtils 27/13727/1
authorsubhash kumar singh <subh.singh@samsung.com>
Mon, 4 Nov 2024 13:08:32 +0000 (13:08 +0000)
committersubhash kumar singh <subh.singh@samsung.com>
Mon, 4 Nov 2024 13:08:32 +0000 (13:08 +0000)
- Fixed UT for trainingMgrUtils
- Fixing missing parameter for `create_dme_filtered_data_job()`

Change-Id: Iccf85f85fde683540a74f4bf3a100f1f47f0ff1c
Signed-off-by: subhash kumar singh <subh.singh@samsung.com>
tests/test_trainingmgr_util.py
trainingmgr/common/trainingmgr_util.py

index 903449d..85002c6 100644 (file)
@@ -46,7 +46,6 @@ trainingmgr_main.LOGGER = pytest.logger
 from trainingmgr.models import FeatureGroup
 from trainingmgr.trainingmgr_main import APP
 
-@pytest.mark.skip("")
 class Test_response_for_training:
     def setup_method(self):
         self.client = trainingmgr_main.APP.test_client(self)
@@ -78,11 +77,10 @@ class Test_response_for_training:
         is_success = True
         is_fail = False
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
-        result = response_for_training(code_success, message_success, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+        result = response_for_training(code_success, message_success, logger, is_success, trainingjob_name, mm_sdk)
         assert message_success == result[0]['result']
-        result = response_for_training(code_fail, message_fail, logger, is_fail, trainingjob_name, ps_db_obj, mm_sdk)
+        result = response_for_training(code_fail, message_fail, logger, is_fail, trainingjob_name, mm_sdk)
         assert message_fail == result[0]['Exception']
 
     @patch('trainingmgr.common.trainingmgr_util.get_field_by_latest_version', return_value=[['www.google.com','h1','h2'], ['www.google.com','h1','h2'], ['www.google.com','h1','h2']])
@@ -97,10 +95,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert err.code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -119,10 +116,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert "Failed to notify the subscribed url " + trainingjob_name in err.message
@@ -141,10 +137,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert "Failed to notify the subscribed url " + trainingjob_name in err.message
@@ -161,11 +156,10 @@ class Test_response_for_training:
         is_success = True
         is_fail = False
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
-        result = response_for_training(code_success, message_success, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+        result = response_for_training(code_success, message_success, logger, is_success, trainingjob_name, mm_sdk)
         assert message_success == result[0]['result']
-        result = response_for_training(code_fail, message_fail, logger, is_fail, trainingjob_name, ps_db_obj, mm_sdk)
+        result = response_for_training(code_fail, message_fail, logger, is_fail, trainingjob_name, mm_sdk)
         assert message_fail == result[0]['Exception']
 
     @patch('trainingmgr.common.trainingmgr_util.get_field_by_latest_version', side_effect = Exception)
@@ -176,10 +170,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert err.code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -195,10 +188,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert err.code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -215,10 +207,9 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except APIException as err:
             assert err.code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -237,15 +228,13 @@ class Test_response_for_training:
         logger = trainingmgr_main.LOGGER
         is_success = True
         trainingjob_name = "usecase7"
-        ps_db_obj = ()
         mm_sdk = ()
         try:
-            response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
+            response_for_training(code, message, logger, is_success, trainingjob_name, mm_sdk)
             assert False
         except Exception:
             assert True
 
-@pytest.mark.skip("")
 class Test_check_key_in_dictionary:
     def test_check_key_in_dictionary(self):
         fields = ["model","brand","year"]
@@ -278,7 +267,6 @@ class Test_check_key_in_dictionary:
         except Exception:
             assert True
 
-@pytest.mark.skip("")
 class Test_check_trainingjob_data:    
     @patch('trainingmgr.common.trainingmgr_util.check_key_in_dictionary',return_value=True)
     @patch('trainingmgr.common.trainingmgr_util.isinstance',return_value=True)  
@@ -323,7 +311,6 @@ class Test_check_trainingjob_data:
         except Exception:
             assert True
 
-@pytest.mark.skip("")
 class Test_get_one_key:
     def test_get_one_key(self):
         dictionary = {
@@ -378,7 +365,6 @@ class dummy_mmsdk:
                     }
         return thisdict
     
-@pytest.mark.skip("")
 class Test_get_metrics:   
     @patch('trainingmgr.common.trainingmgr_util.json.dumps',return_value='usecase_data')
     def test_get_metrics_with_version(self,mock1):
@@ -425,7 +411,6 @@ class Test_get_metrics:
         except Exception:
             assert True
 
-@pytest.mark.skip("")
 class dummy_mmsdk_1:
     def check_object(self, param1, param2, param3):
         return False
@@ -438,7 +423,6 @@ class dummy_mmsdk_1:
                     }
         return thisdict
 
-@pytest.mark.skip("")
 class Test_get_metrics_2:   
     @patch('trainingmgr.common.trainingmgr_util.json.dumps',return_value='usecase_data')
     def test_negative_get_metrics_2_1(self,mock1):
@@ -454,7 +438,6 @@ class Test_get_metrics_2:
         except Exception:
             assert True
 
-@pytest.mark.skip("")
 class Test_handle_async_feature_engineering_status_exception_case:
     @patch('trainingmgr.common.trainingmgr_util.change_in_progress_to_failed_by_latest_version',return_value=True)
     @patch('trainingmgr.common.trainingmgr_util.response_for_training',return_value=True)
@@ -466,11 +449,10 @@ class Test_handle_async_feature_engineering_status_exception_case:
            logger = "123"
            is_success = True
            usecase_name = "usecase7"
-           ps_db_obj = () 
            mm_sdk = ()       
            assert handle_async_feature_engineering_status_exception_case(lock, featurestore_job_cache, code,
                                                            message, logger, is_success,
-                                                           usecase_name, ps_db_obj, mm_sdk) == None,"data not equal"
+                                                           usecase_name, mm_sdk) == None,"data not equal"
     
     @patch('trainingmgr.common.trainingmgr_util.change_in_progress_to_failed_by_latest_version',return_value=True)
     @patch('trainingmgr.common.trainingmgr_util.response_for_training',return_value=True)
@@ -496,28 +478,24 @@ class Test_handle_async_feature_engineering_status_exception_case:
            except Exception:
                assert True
 
-@pytest.mark.skip("")
 class Test_get_one_word_status:
     def test_get_one_word_status(self):
            steps_state = [0,1,2,3]
            expected_data = "IN_PROGRESS"
            assert get_one_word_status(steps_state) == expected_data,"data not equal"
 
-@pytest.mark.skip("")
 class Test_validate_trainingjob_name:
     @patch('trainingmgr.common.trainingmgr_util.get_all_versions_info_by_name',return_value=True)
     def test_validate_trainingjob_name_1(self,mock1):
         trainingjob_name = "usecase8"
-        ps_db_obj = ()
         expected_data = True
-        assert validate_trainingjob_name(trainingjob_name,ps_db_obj) == expected_data,"data not equal"
+        assert validate_trainingjob_name(trainingjob_name) == expected_data,"data not equal"
 
     @patch('trainingmgr.common.trainingmgr_util.get_all_versions_info_by_name', side_effect = DBException)
     def test_validate_trainingjob_name_2(self,mock1):
         trainingjob_name = "usecase8"
-        ps_db_obj = ()
         try:
-            validate_trainingjob_name(trainingjob_name,ps_db_obj)
+            validate_trainingjob_name(trainingjob_name)
             assert False
         except DBException as err:
             assert 'Could not get info from db for ' + trainingjob_name in str(err)
@@ -526,22 +504,20 @@ class Test_validate_trainingjob_name:
         short_name = "__"
         long_name = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
         not_allowed_symbol_name = "case@#"
-        ps_db_obj = ()
         try:
-            validate_trainingjob_name(short_name,ps_db_obj)
+            validate_trainingjob_name(short_name)
             assert False
         except TMException as err:
             assert str(err) == "The name of training job is invalid."
         try:
-            validate_trainingjob_name(long_name,ps_db_obj)
+            validate_trainingjob_name(long_name)
         except TMException as err:
             assert str(err) == "The name of training job is invalid."
         try:
-            validate_trainingjob_name(not_allowed_symbol_name,ps_db_obj)
+            validate_trainingjob_name(not_allowed_symbol_name)
         except TMException as err:
             assert str(err) == "The name of training job is invalid."
 
-@pytest.mark.skip("")
 class Test_get_pipelines_details:
     # testing the get_all_pipeline service
     def setup_method(self):
@@ -565,7 +541,6 @@ class Test_get_pipelines_details:
         expected_data="next-page-token"
         assert get_pipelines_details(self.mocked_TRAININGMGR_CONFIG_OBJ)["next_page_token"] == expected_data, "Not equal"
 
-@pytest.mark.skip("")
 class Test_check_feature_group_data:
     @patch('trainingmgr.common.trainingmgr_util.check_key_in_dictionary',return_value=True)
     def test_check_feature_group_data(self, mock1):
@@ -611,7 +586,6 @@ class Test_check_feature_group_data:
         except:
             assert True
 
-@pytest.mark.skip("")
 class Test_get_feature_group_by_name:
     fg_dict ={'id': 21, 'featuregroup_name': 'testing', 'feature_list': '', 'datalake_source': 'InfluxSource', 'host': '127.0.0.21', 'port': '8086', 'bucket': '', 'token': '', 'db_org': '', 'measurement': '', 'enable_dme': False, 'measured_obj_class': '', 'dme_port': '', 'source_name': ''} 
     featuregroup = FeatureGroup()
@@ -662,7 +636,7 @@ class Test_get_feature_group_by_name:
         assert status_code == 400, "status code is not equal"
         assert json_data == expected_data, json_data
         
-@pytest.mark.skip("")
+
 class Test_edit_feature_group_by_name:
 
     fg_init = [('testing', '', 'InfluxSource', '127.0.0.21', '8080', '', '', '', '', False, '', '', '')]
@@ -672,31 +646,34 @@ class Test_edit_feature_group_by_name:
     
     # In the case where the feature group is edited while DME is disabled
     feature_group_data1=('testing','testing','InfluxSource',False,'127.0.0.1', '8080', '','testing','','','','','')
+    
+    @pytest.fixture
+    def get_sample_feature_group(self):
+        return FeatureGroup(
+        featuregroup_name="SampleFeatureGroup",
+        feature_list="feature1,feature2,feature3",
+        datalake_source="datalake_source_url",
+        host="localhost",
+        port="12345",
+        bucket="my_bucket",
+        token="auth_token",
+        db_org="organization_name",
+        measurement="measurement_name",
+        enable_dme=False,
+        measured_obj_class="object_class",
+        dme_port="6789",
+        source_name="source_name"
+        )
+    
     @patch('trainingmgr.common.trainingmgr_util.edit_featuregroup')
     @patch('trainingmgr.common.trainingmgr_util.check_feature_group_data', return_value=feature_group_data1)
     @patch('trainingmgr.common.trainingmgr_util.get_feature_group_by_name_db', return_value=fg_init)
-    def test_edit_feature_group_by_name_1(self, mock1, mock2, mock3):
+    def test_edit_feature_group_by_name_1(self, mock1, mock2, mock3, get_sample_feature_group):
         tm_conf_obj=()
-        ps_db_obj=()
         logger = trainingmgr_main.LOGGER
-        fg_name='testing'
         expected_data = {"result": "Feature Group Edited"}
-        json_request = {
-                "featureGroupName": fg_name,
-                "feature_list": self.fg_edit[0][1],
-                "datalake_source": self.fg_edit[0][2],
-                "Host": self.fg_edit[0][3],
-                "Port": self.fg_edit[0][4],
-                "bucket": self.fg_edit[0][5],
-                "token": self.fg_edit[0][6],
-                "dbOrg": self.fg_edit[0][7],
-                "_measurement": self.fg_edit[0][8],
-                "enable_Dme": self.fg_edit[0][9],
-                "measured_obj_class": self.fg_edit[0][10],
-                "dmePort": self.fg_edit[0][11],
-                "source_name": self.fg_edit[0][12]
-            }
-        json_data, status_code = edit_feature_group_by_name(tm_conf_obj, ps_db_obj, logger, fg_name, json_request)
+        
+        json_data, status_code = edit_feature_group_by_name(get_sample_feature_group.featuregroup_name, get_sample_feature_group, logger, tm_conf_obj)
         assert status_code == 200, "status code is not equal"
         assert json_data == expected_data, json_data
 
@@ -713,27 +690,13 @@ class Test_edit_feature_group_by_name:
     @patch('trainingmgr.common.trainingmgr_util.check_feature_group_data', return_value=feature_group_data2)
     @patch('trainingmgr.common.trainingmgr_util.get_feature_group_by_name_db', return_value=fg_init)
     @patch('trainingmgr.common.trainingmgr_util.delete_feature_group_by_name')
-    def test_edit_feature_group_by_name_2(self, mock1, mock2, mock3, mock4, mock5, mock6):
-        ps_db_obj=()
+    def test_edit_feature_group_by_name_2(self, mock1, mock2, mock3, mock4, mock5, mock6, get_sample_feature_group):
+        tm_conf_obj=()
         logger = trainingmgr_main.LOGGER
         fg_name='testing'
         expected_data = {"result": "Feature Group Edited"}
-        json_request = {
-                "featureGroupName": fg_name,
-                "feature_list": self.fg_edit[0][1],
-                "datalake_source": self.fg_edit[0][2],
-                "Host": self.fg_edit[0][3],
-                "Port": self.fg_edit[0][4],
-                "bucket": self.fg_edit[0][5],
-                "token": self.fg_edit[0][6],
-                "dbOrg": self.fg_edit[0][7],
-                "_measurement": self.fg_edit[0][8],
-                "enable_Dme": self.fg_edit[0][9],
-                "measured_obj_class": self.fg_edit[0][10],
-                "dmePort": self.fg_edit[0][11],
-                "source_name": self.fg_edit[0][12]
-            }
-        json_data, status_code = edit_feature_group_by_name(self.mocked_TRAININGMGR_CONFIG_OBJ, ps_db_obj, logger, fg_name, json_request)
+
+        json_data, status_code = edit_feature_group_by_name(get_sample_feature_group.featuregroup_name, get_sample_feature_group, logger, tm_conf_obj)
         assert status_code == 200, "status code is not equal"
         assert json_data == expected_data, json_data
     
@@ -747,7 +710,8 @@ class Test_edit_feature_group_by_name:
     @patch('trainingmgr.common.trainingmgr_util.check_feature_group_data', return_value=feature_group_data3)
     @patch('trainingmgr.common.trainingmgr_util.get_feature_group_by_name_db', return_value=fg_init)
     @patch('trainingmgr.common.trainingmgr_util.delete_feature_group_by_name')
-    def test_negative_edit_feature_group_by_name(self, mock1, mock2, mock3, mock4, mock5):
+    @pytest.mark.skip("")
+    def test_negative_edit_feature_group_by_name(self, mock1, mock2, mock3, mock4, mock5, get_sample_feature_group):
         tm_conf_obj=()
         ps_db_obj=()
         logger = trainingmgr_main.LOGGER
@@ -782,7 +746,7 @@ class Test_edit_feature_group_by_name:
         json_data, status_code = edit_feature_group_by_name(tm_conf_obj, ps_db_obj, logger, fg_name, json_request)
         assert status_code == 400, "status code is not equal"
         assert json_data == expected_data, json_data
-
+    @pytest.mark.skip("")
     def test_negative_edit_feature_group_by_name_with_incorrect_name(self):
         tm_conf_obj=()
         ps_db_obj=()
index c203d48..de3f2d3 100644 (file)
@@ -230,7 +230,9 @@ def get_feature_group_by_name(featuregroup_name, logger):
         logger.error(str(err))
     return api_response, response_code
 
-def edit_feature_group_by_name(featuregroup_name, featuregroup, logger, tm_conf_obj):
+from trainingmgr.models.featuregroup import FeatureGroup 
+def edit_feature_group_by_name(featuregroup_name: str, 
+                               featuregroup: FeatureGroup, logger, tm_conf_obj):
     """
     Function fetching a feature group
 
@@ -263,7 +265,9 @@ def edit_feature_group_by_name(featuregroup_name, featuregroup, logger, tm_conf_
         response_code =status.HTTP_200_OK
         # TODO: Implement the process where DME edits from the dashboard are applied to the endpoint
         if featuregroup.enable_dme == True :
-            response= create_dme_filtered_data_job(tm_conf_obj, featuregroup)
+            response= create_dme_filtered_data_job(tm_conf_obj, featuregroup.source_name, featuregroup.feature_list, 
+                                                   featuregroup.host, featuregroup.port, 
+                                                   featuregroup.measured_obj_class)
             if response.status_code != 201:
                 api_response={"Exception": "Cannot create dme job"}
                 delete_feature_group_by_name(featuregroup)