adding Access-Control-Allow-Origin. 50/11250/4
authorrajdeep11 <rajdeep.sin@samsung.com>
Wed, 31 May 2023 05:57:19 +0000 (11:27 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Wed, 31 May 2023 10:32:03 +0000 (16:02 +0530)
Change-Id: I27ebe642297ca43b4a13aa6f319357db31b2ce72
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
tests/test.env
tests/test_tm_apis.py
tests/test_trainingmgr_config.py
trainingmgr/common/trainingmgr_config.py
trainingmgr/trainingmgr_main.py

index 325004a..a5c0048 100644 (file)
@@ -28,3 +28,4 @@ PS_USER=postgres
 PS_PASSWORD="abcd"
 PS_IP="localhost"
 PS_PORT="30001"
+ACCESS_CONTROL_ALLOW_ORIGIN="http://localhost:32005"
index 6bfb8d0..82eb7bb 100644 (file)
@@ -28,6 +28,7 @@ import sys
 import datetime
 from flask_api import status
 from dotenv import load_dotenv
+load_dotenv('tests/test.env')
 from trainingmgr.constants.states import States
 from threading import Lock
 from trainingmgr import trainingmgr_main 
@@ -443,7 +444,6 @@ class Test_get_versions_for_pipeline:
     def setup_method(self,mock1,mock2):
         self.client = trainingmgr_main.APP.test_client(self)
         self.logger = trainingmgr_main.LOGGER
-        load_dotenv('tests/test.env')
         self.TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig()   
 
     the_response = Response()
index ce944cd..08385cf 100644 (file)
@@ -99,6 +99,11 @@ class Test_trainingmgr_config:
         result = self.TRAININGMGR_CONFIG_OBJ.ps_port
         assert result == expected_data
 
+    def test_allow_access_allowed_origin(self):
+        expected_data = "http://localhost:32005"
+        result = self.TRAININGMGR_CONFIG_OBJ.allow_control_access_origin
+        assert result == expected_data
+
     def test_is_config_loaded_properly(self):
         expected_data = True
         result = TrainingMgrConfig.is_config_loaded_properly(self.TRAININGMGR_CONFIG_OBJ)
index 4b98958..11c6b92 100644 (file)
@@ -46,6 +46,7 @@ class TrainingMgrConfig:
         self.__ps_password = getenv('PS_PASSWORD').rstrip()
         self.__ps_ip = getenv('PS_IP').rstrip()
         self.__ps_port = getenv('PS_PORT').rstrip()
+        self.__allow_control_access_origin = getenv('ACCESS_CONTROL_ALLOW_ORIGIN').rstrip()
 
         self.tmgr_logger = TMLogger("common/conf_log.yaml")
         self.__logger = self.tmgr_logger.logger
@@ -182,6 +183,19 @@ class TrainingMgrConfig:
         """
         return self.__ps_port
 
+    @property
+    def allow_control_access_origin(self):
+        """
+        Function for getting allow_control_access_origin
+
+        Args: None
+
+        Returns:
+            string allow_control_access_origin
+        
+        """
+        return self.__allow_control_access_origin
+
     def is_config_loaded_properly(self):
         """
         This function checks where all environment variable got value or not.
@@ -193,7 +207,7 @@ class TrainingMgrConfig:
         for var in [self.__kf_adapter_ip, self.__kf_adapter_port,
                     self.__data_extraction_ip, self.__data_extraction_port,
                     self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user,
-                    self.__ps_password, self.__my_ip, self.__logger]:
+                    self.__ps_password, self.__my_ip, self.__allow_control_access_origin, self.__logger]:
             if var is None:
                 all_present = False
         return all_present
index 5e4779b..ffb338f 100644 (file)
@@ -30,7 +30,7 @@ import time
 from flask import Flask, request, send_file
 from flask_api import status
 import requests
-from flask_cors import cross_origin
+from flask_cors import CORS
 from werkzeug.utils import secure_filename
 from modelmetricsdk.model_metrics_sdk import ModelMetricsSdk
 from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job
@@ -78,7 +78,6 @@ def error(err):
 
 
 @APP.route('/trainingjobs/<trainingjob_name>/<version>', methods=['GET'])
-@cross_origin()
 def get_trainingjob_by_name_version(trainingjob_name, version):
     """
     Rest endpoint to fetch training job details by name and version
@@ -189,7 +188,6 @@ def get_trainingjob_by_name_version(trainingjob_name, version):
                                         mimetype=MIMETYPE_JSON)
 
 @APP.route('/trainingjobs/<trainingjob_name>/<version>/steps_state', methods=['GET']) # Handled in GUI
-@cross_origin()
 def get_steps_state(trainingjob_name, version):
     """
     Function handling rest end points to get steps_state information for
@@ -282,7 +280,6 @@ def get_model(trainingjob_name, version):
 
 
 @APP.route('/trainingjobs/<trainingjob_name>/training', methods=['POST']) # Handled in GUI
-@cross_origin()
 def training(trainingjob_name):
     """
     Rest end point to start training job.
@@ -536,7 +533,6 @@ def pipeline_notification():
 
 
 @APP.route('/trainingjobs/latest', methods=['GET'])
-@cross_origin()
 def trainingjobs_operations():
     """
     Rest endpoint to fetch overall status, latest version of all existing training jobs
@@ -585,7 +581,6 @@ def trainingjobs_operations():
                         mimetype=MIMETYPE_JSON)
 
 @APP.route("/pipelines/<pipe_name>/upload", methods=['POST'])
-@cross_origin()
 def upload_pipeline(pipe_name):
     """
     Function handling rest endpoint to upload pipeline.
@@ -684,7 +679,6 @@ def upload_pipeline(pipe_name):
 
 
 @APP.route("/pipelines/<pipeline_name>/versions", methods=['GET'])
-@cross_origin()
 def get_versions_for_pipeline(pipeline_name):
     """
     Function handling rest endpoint to get versions of given pipeline name.
@@ -739,7 +733,6 @@ def get_versions_for_pipeline(pipeline_name):
             mimetype=MIMETYPE_JSON)
  
 @APP.route('/pipelines', methods=['GET'])
-@cross_origin()
 def get_all_pipeline_names():
     """
     Function handling rest endpoint to get all pipeline names.
@@ -773,7 +766,6 @@ def get_all_pipeline_names():
     return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON)
 
 @APP.route('/experiments', methods=['GET'])
-@cross_origin()
 def get_all_experiment_names():
     """
     Function handling rest endpoint to get all experiment names.
@@ -823,7 +815,6 @@ def get_all_experiment_names():
 
 
 @APP.route('/trainingjobs/<trainingjob_name>', methods=['POST', 'PUT']) # Handled in GUI
-@cross_origin()
 def trainingjob_operations(trainingjob_name):
     """
     Rest endpoind to create or update trainingjob
@@ -937,7 +928,6 @@ def trainingjob_operations(trainingjob_name):
                     mimetype=MIMETYPE_JSON)
 
 @APP.route('/trainingjobs/retraining', methods=['POST'])
-@cross_origin()
 def retraining():
     """
     Function handling rest endpoint to retrain trainingjobs in request json. trainingjob's
@@ -1065,7 +1055,6 @@ def retraining():
         mimetype='application/json')
 
 @APP.route('/trainingjobs', methods=['DELETE'])
-@cross_origin()
 def delete_list_of_trainingjob_version():
     """
     Function handling rest endpoint to delete latest version of trainingjob_name trainingjobs which is
@@ -1263,7 +1252,6 @@ def get_metadata(trainingjob_name):
                                         mimetype=MIMETYPE_JSON)
 
 @APP.route('/featureGroup', methods=['POST'])
-@cross_origin()
 def create_feature_group():
     """
     Rest endpoint to create feature group
@@ -1354,7 +1342,6 @@ def create_feature_group():
                                         mimetype=MIMETYPE_JSON)
 
 @APP.route('/featureGroup', methods=['GET'])
-@cross_origin()
 def get_feature_group():
     """
     Rest endpoint to fetch all the feature groups
@@ -1403,7 +1390,6 @@ def get_feature_group():
                         mimetype=MIMETYPE_JSON)
 
 @APP.route('/featureGroup/<featuregroup_name>', methods=['GET'])
-@cross_origin()
 def get_feature_group_by_name(featuregroup_name):
     """
     Rest endpoint to fetch a feature group
@@ -1483,7 +1469,6 @@ def get_feature_group_by_name(featuregroup_name):
                         mimetype=MIMETYPE_JSON) 
 
 @APP.route('/featureGroup', methods=['DELETE'])
-@cross_origin()
 def delete_list_of_feature_group():
     """
     Function handling rest endpoint to delete featureGroup which is
@@ -1645,6 +1630,8 @@ if __name__ == "__main__":
         DATAEXTRACTION_JOBS_CACHE = get_data_extraction_in_progress_trainingjobs(PS_DB_OBJ)
         threading.Thread(target=async_feature_engineering_status, daemon=True).start()
         MM_SDK = ModelMetricsSdk()
+        list_allow_control_access_origin = TRAININGMGR_CONFIG_OBJ.allow_control_access_origin.split(',')
+        CORS(APP, resources={r"/*": {"origins": list_allow_control_access_origin}})
         LOGGER.debug("Starting AIML-WF training manager .....")
         APP.run(debug=True, port=int(TRAININGMGR_CONFIG_OBJ.my_port), host='0.0.0.0')
     except TMException as err: