test cases for 26/13726/4
authorrajdeep11 <rajdeep.sin@samsung.com>
Wed, 30 Oct 2024 20:42:39 +0000 (02:12 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Fri, 8 Nov 2024 10:30:47 +0000 (16:00 +0530)
get_steps_state and get_model

Change-Id: I6f51eb514b5ed8a1db7c874b925e98c4b5350123
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
tests/test_tm_apis.py
trainingmgr/trainingmgr_main.py

index faa8580..8ceca5a 100644 (file)
@@ -403,50 +403,79 @@ class Test_unpload_pipeline:
         expected = ValueError("file not found in request.files")
         assert response.content_type == "application/json", "not equal content type"
         assert response.status_code == 500, "not equal code"
-@pytest.mark.skip("")
+
 class Test_get_steps_state:
-      def setup_method(self):
+    def setup_method(self):
         self.client = trainingmgr_main.APP.test_client(self)
         self.logger = trainingmgr_main.LOGGER
+    
+    @pytest.fixture
+    def mock_steps_state(self):
+        """Create mock steps state data."""
+        return {
+            "DATA_EXTRACTION": "FINISHED",
+            "DATA_EXTRACTION_AND_TRAINING": "FINISHED",
+            "TRAINING": "FINISHED",
+            "TRAINING_AND_TRAINED_MODEL": "FINISHED",
+            "TRAINED_MODEL": "FINISHED"
+        }
       
-      @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=[['data_extracted','data_pending'], ['data1','data2']])
-      def test_get_steps_state(self,mock1):
-          usecase_name = "usecase7"
-          version = "1" 
-          response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version))
-          expected_data = b'data_extracted'
-          assert response.content_type == "application/json", "not equal content type"
-          assert response.status_code == status.HTTP_200_OK, "not equal code"
-          assert response.data == expected_data, "not equal data"
-
-      @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=False)
-      def test_negative_get_steps_state(self,mock1):
-          usecase_name = "usecase7"
-          version = "1" 
-          response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version))
-          expected_data = b'data_extracted'
-          assert response.content_type == "application/json", "not equal content type"
-          assert response.status_code == 404, "not equal code"
+    @patch('trainingmgr.trainingmgr_main.get_steps_state_db')
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version')
+    def test_successful_get_steps_state(self, mock_name_and_version, mock_get_steps_state, mock_steps_state):
+        """Test successful retrieval of steps state."""
+
+        mock_get_steps_state.return_value = mock_steps_state
+        response = self.client.get('/trainingjobs/test_job/1/steps_state')
+        
+        assert response.status_code == status.HTTP_200_OK
+        data = response.get_json()
+        
+        # Verify all expected states are present
+        assert "DATA_EXTRACTION" in data
+        assert "DATA_EXTRACTION_AND_TRAINING" in data
+        assert "TRAINING" in data
+        assert "TRAINING_AND_TRAINED_MODEL" in data
+        assert "TRAINED_MODEL" in data
+        
+        # Verify state values
+        assert data["DATA_EXTRACTION"] == "FINISHED"
+        assert data["TRAINING"] == "FINISHED"
+        assert data["TRAINED_MODEL"] == "FINISHED"
     
-      @patch('trainingmgr.trainingmgr_main.get_field_of_given_version',return_value=Exception("Not found given trainingjob with version"))
-      def test_negative_get_steps_state_2(self,mock1):
-          usecase_name = "usecase7"
-          version = "1" 
-          response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version))
-          expected_data = b'data_extracted'
-          assert response.status_code == 500, "not equal code"
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=False)
+    def test_invalid_name_version(self, mock1):
+        """Test with invalid training job name or version."""
+        response = self.client.get('/trainingjobs/invalid_job/999/steps_state')
+        
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+        data = response.get_json()
+        assert "Exception" in data
+        assert "trainingjob_name or version is not correct" in data["Exception"]
+
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+    @patch('trainingmgr.trainingmgr_main.get_steps_state_db', return_value=None)
+    def test_nonexistent_trainingjob(self, mock1, mock2):
+        """Test when training job doesn't exist in database."""
+
+        response = self.client.get('/trainingjobs/nonexistent_job/1/steps_state')
         
-      def test_negative_get_steps_state_by_name_and_version(self):
-        usecase_name = "usecase7*"
-        version = "1"
-        response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version))
-        assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code"
-        assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n'
-        usecase_name="usecase7"
-        version="a"
-        response = self.client.get("/trainingjobs/{}/{}/steps_state".format(usecase_name, version))
-        assert response.status_code == status.HTTP_400_BAD_REQUEST, "not equal status code"
-        assert response.data == b'{"Exception":"The trainingjob_name or version is not correct"}\n'
+        assert response.status_code == status.HTTP_404_NOT_FOUND
+        data = response.get_json()
+        assert "Exception" in data
+        assert "Not found given trainingjob in database" in data["Exception"]
+
+    @patch('trainingmgr.trainingmgr_main.check_trainingjob_name_and_version', return_value=True)
+    @patch('trainingmgr.trainingmgr_main.get_steps_state_db', side_effect=Exception("Database error"))
+    def test_database_error(self, mock1, mock2):
+        """Test handling of database errors."""
+
+        response = self.client.get('/trainingjobs/test_job/1/steps_state')
+        
+        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+        data = json.loads(response.data)
+        assert "Exception" in data
+        assert "Database error" in data["Exception"]
           
 @pytest.mark.skip("")
 class Test_training_main:
@@ -928,7 +957,6 @@ class Test_get_metadata:
         assert response.status_code==status.HTTP_400_BAD_REQUEST
         assert response.data == b'{"Exception":"The trainingjob_name is not correct"}\n'
 
-@pytest.mark.skip("")
 class Test_get_model:
         def setup_method(self):
             self.client = trainingmgr_main.APP.test_client(self)
index 2dcbd03..cc4d458 100644 (file)
@@ -274,7 +274,7 @@ def get_steps_state(trainingjob_name, version):
         LOGGER.error(str(err))
         response_data = {"Exception": str(err)}
 
-    return APP.response_class(response=response_data,
+    return APP.response_class(response=json.dumps(response_data),
                                       status=response_code,
                                       mimetype=MIMETYPE_JSON)