adding tests 69/10569/2
authorrajdeep11 <rajdeep.sin@samsung.com>
Tue, 28 Feb 2023 11:24:26 +0000 (16:54 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Thu, 2 Mar 2023 07:24:04 +0000 (12:54 +0530)
Issue-Id: AIMLFW-15

Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
Change-Id: Ie0850e3699ab1835acec08a0a18be733d0f0c0c9
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
tests/test_common_db_fun.py
tests/test_tm_apis.py
trainingmgr/trainingmgr_main.py

index 77a054f..eae43c1 100644 (file)
@@ -28,7 +28,7 @@ from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainin
      get_trainingjob_info_by_name, get_latest_version_trainingjob_name, \
      get_all_versions_info_by_name, get_all_distinct_trainingjobs, \
      get_all_version_num_by_trainingjob_name, update_model_download_url, \
-     add_update_trainingjob, get_all_jobs_latest_status_version
+     add_update_trainingjob, get_all_jobs_latest_status_version, get_info_of_latest_version
 
 mimic_db = {
             "usecase_name": "Tester",
@@ -439,3 +439,20 @@ class Test_Common_Db_Fun:
             fxn_name = "get_all_versions_info_by_name"
             assert str(err) == "Failed to execute query in get_all_versions_info_by_nameDB Error", 'Negative test {} FAILED, Doesnt returned required error'.format(fxn_name)
             assert checker.finished, 'Cursor Not Closed Properly for fxn {} | Negative Test'.format(fxn_name)
+
+    def test_get_info_of_latest_version(self):
+        db_obj = db_helper([["version"], ["*"]])
+        out = get_info_of_latest_version("Tester", db_obj)
+        assert out != None, 'get_info_of_latest_version FAILED'
+
+
+    def test_negative_get_info_of_latest_version(self):
+        checker = Check()
+        try:
+            db_obj = db_helper([["version"], ["*"]], raise_exception=True,check_success_obj=checker)
+            out = get_info_of_latest_version("Tester", db_obj)
+            assert out != None, 'get_info_of_latest_version FAILED'
+        except Exception as err:
+            fxn_name = "get_info_by_version"
+            assert str(err) == "DB Error", 'Negative test {} FAILED, Doesnt returned required error'.format(fxn_name)
+            assert checker.finished, 'Cursor Not Closed Properly for fxn {} | Negative Test'.format(fxn_name)
\ No newline at end of file
index 938be4d..39127af 100644 (file)
@@ -28,6 +28,7 @@ import sys
 import datetime
 from flask_api import status
 from dotenv import load_dotenv
+from trainingmgr.constants.states import States
 from threading import Lock
 from trainingmgr import trainingmgr_main 
 from trainingmgr.common.tmgr_logger import TMLogger
@@ -676,3 +677,114 @@ class Test_get_metadata_1:
                                     content_type="application/json")
         trainingmgr_main.LOGGER.debug(response.data)
         assert response.status_code == status.HTTP_404_NOT_FOUND, "Return status code NOT equal"
+
+## Retraining API test
+class Test_retraining:
+    @patch('trainingmgr.common.trainingmgr_config.TMLogger', return_value = TMLogger("tests/common/conf_log.yaml"))
+    def setup_method(self,mock1,mock2):
+        self.client = trainingmgr_main.APP.test_client(self)
+        self.logger = trainingmgr_main.LOGGER
+        
+    #test_positive_1
+    db_result = [('mynetwork', 'testing', '*', 'testing_pipeline', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "mynetwork"}}', '', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 'No data available', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "IN_PROGRESS", "TRAINING": "NOT_STARTED", "TRAINING_AND_TRAINED_MODEL": "NOT_STARTED", "TRAINED_MODEL": "NOT_STARTED"}', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 1, False, '2', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False)]
+    mocked_TRAININGMGR_CONFIG_OBJ=mock.Mock(name="TRAININGMGR_CONFIG_OBJ")
+    attrs_TRAININGMGR_CONFIG_OBJ = {'my_ip.return_value': '123'}
+    mocked_TRAININGMGR_CONFIG_OBJ.configure_mock(**attrs_TRAININGMGR_CONFIG_OBJ)
+    #postive_1
+    tmres = Response()
+    tmres.code = "expired"
+    tmres.error_type = "expired"
+    tmres.status_code = status.HTTP_200_OK
+    tmres.headers={"content-type": "application/json"}
+    tmres._content = b'{"task_status": "Completed", "result": "Data Pipeline Execution Completed"}'  
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary',return_value=True) 
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', return_value= db_result)
+    @patch('trainingmgr.trainingmgr_main.TRAININGMGR_CONFIG_OBJ', return_value = mocked_TRAININGMGR_CONFIG_OBJ)
+    @patch('trainingmgr.trainingmgr_main.add_update_trainingjob',return_value="")
+    @patch('trainingmgr.trainingmgr_main.get_one_word_status',return_value = States.FINISHED.name)
+    @patch('trainingmgr.trainingmgr_main.requests.post',return_value = tmres)
+    def test_retraining(self,mock1, mock2, mock3,mock4, mock5, mock6):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data=json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal"
+        assert data["success count"]==1 , "Return success count NOT equal"
+
+    #Negative_1
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary',side_effect = Exception('Mocked error'))
+    def test_negative_retraining_1(self,mock1):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        assert response.status_code == status.HTTP_400_BAD_REQUEST, "Return status code NOT equal"  
+
+
+    #Negative_2
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary')
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', side_effect = Exception('Mocked error'))
+    def test_negative_retraining_2(self,mock1,mock2):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data = json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal"
+        assert data["failure count"] == 1, "Return failure count NOT equal"
+        
+
+    #Negative_3_when_deletion_in_progress
+    db_result2 = [('mynetwork', 'testing', '*', 'testing_pipeline', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "mynetwork"}}', '', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 'No data available', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "IN_PROGRESS", "TRAINING": "NOT_STARTED", "TRAINING_AND_TRAINED_MODEL": "NOT_STARTED", "TRAINED_MODEL": "NOT_STARTED"}', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 1, False, '2', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', True)]
+  
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary') 
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', return_value= db_result2)
+    def test_negative_retraining_3(self,mock1, mock2):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data=json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal"
+        assert data["failure count"]==1, "Return failure count NOT equal"
+
+
+    #Negative_4
+    db_result = [('mynetwork', 'testing', '*', 'testing_pipeline', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "mynetwork"}}', '', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 'No data available', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "IN_PROGRESS", "TRAINING": "NOT_STARTED", "TRAINING_AND_TRAINED_MODEL": "NOT_STARTED", "TRAINED_MODEL": "NOT_STARTED"}', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 1, False, '2', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False)]
+      
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary',return_value="") 
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', return_value= db_result)
+    @patch('trainingmgr.trainingmgr_main.add_update_trainingjob',side_effect = Exception('Mocked error'))
+    def test_negative_retraining_4(self,mock1, mock2, mock3):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data=json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal"
+        assert data["failure count"]==1, "Return failure count NOT equal"
+
+
+    #Negative_5
+    db_result = [('mynetwork', 'testing', '*', 'testing_pipeline', 'Default', '{"arguments": {"epochs": "1", "trainingjob_name": "mynetwork"}}', '', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 'No data available', '{"DATA_EXTRACTION": "FINISHED", "DATA_EXTRACTION_AND_TRAINING": "IN_PROGRESS", "TRAINING": "NOT_STARTED", "TRAINING_AND_TRAINED_MODEL": "NOT_STARTED", "TRAINED_MODEL": "NOT_STARTED"}', datetime.datetime(2023, 2, 9, 9, 2, 11, 13916), 1, False, '2', '{"datalake_source": {"InfluxSource": {}}}', 'No data available.', '', 'liveCell', 'UEData', False)]
+    
+
+    tmres = Response()
+    tmres.code = "expired"
+    tmres.error_type = "expired"
+    tmres.status_code = status.HTTP_204_NO_CONTENT
+    tmres.headers={"content-type": "application/json"}
+    tmres._content = b'{"task_status": "Completed", "result": "Data Pipeline Execution Completed"}'  
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary',return_value="") 
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', return_value= db_result)
+    @patch('trainingmgr.trainingmgr_main.add_update_trainingjob',return_value="")
+    @patch('trainingmgr.trainingmgr_main.requests.post',return_value = tmres)
+    def test_negative_retraining_5(self,mock1, mock2, mock3,mock4):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data=json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal" 
+        assert data["failure count"]==1, "Return failure count NOT equal"
+
+      
+    #Negative_6
+    db_result3 = [] 
+    @patch('trainingmgr.trainingmgr_main.check_key_in_dictionary') 
+    @patch('trainingmgr.trainingmgr_main.get_info_of_latest_version', return_value= db_result3)
+    def test_negative_retraining_6(self,mock1, mock2):
+        retrain_req = {"trainingjobs_list": [{"trainingjob_name": "mynetwork"}]}
+        response = self.client.post("/trainingjobs/retraining", data=json.dumps(retrain_req),content_type="application/json")   
+        data=json.loads(response.data)
+        assert response.status_code == status.HTTP_200_OK, "Return status code NOT equal"
+        assert data["failure count"]==1, "Return failure count NOT equal"
index dc918f7..d0b9fa1 100644 (file)
@@ -972,7 +972,6 @@ def retraining():
         raise APIException(status.HTTP_400_BAD_REQUEST, str(err)) from None
 
     trainingjobs_list = request.json['trainingjobs_list']
-    print("trainingjobs list is :", trainingjobs_list)
     if not isinstance(trainingjobs_list, list):
         raise APIException(status.HTTP_400_BAD_REQUEST, "not given as list")
 
@@ -1037,7 +1036,7 @@ def retraining():
                                       notification_url, _measurement, bucket)
             except Exception as err:
                 not_possible_to_retrain.append(trainingjob_name)
-                LOGGER.debug(str(err) + "(usecase_name is " + trainingjob_name + ")")
+                LOGGER.debug(str(err) + "(training job is " + trainingjob_name + ")")
                 continue
 
             url = 'http://' + str(TRAININGMGR_CONFIG_OBJ.my_ip) + \