From af8b118f74f43c481022079dd5b04d84cc75f0db Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Wed, 31 May 2023 11:27:19 +0530 Subject: [PATCH] adding Access-Control-Allow-Origin. Change-Id: I27ebe642297ca43b4a13aa6f319357db31b2ce72 Signed-off-by: rajdeep11 --- tests/test.env | 1 + tests/test_tm_apis.py | 2 +- tests/test_trainingmgr_config.py | 5 +++++ trainingmgr/common/trainingmgr_config.py | 16 +++++++++++++++- trainingmgr/trainingmgr_main.py | 19 +++---------------- 5 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/test.env b/tests/test.env index 325004a..a5c0048 100644 --- a/tests/test.env +++ b/tests/test.env @@ -28,3 +28,4 @@ PS_USER=postgres PS_PASSWORD="abcd" PS_IP="localhost" PS_PORT="30001" +ACCESS_CONTROL_ALLOW_ORIGIN="http://localhost:32005" diff --git a/tests/test_tm_apis.py b/tests/test_tm_apis.py index 6bfb8d0..82eb7bb 100644 --- a/tests/test_tm_apis.py +++ b/tests/test_tm_apis.py @@ -28,6 +28,7 @@ import sys import datetime from flask_api import status from dotenv import load_dotenv +load_dotenv('tests/test.env') from trainingmgr.constants.states import States from threading import Lock from trainingmgr import trainingmgr_main @@ -443,7 +444,6 @@ class Test_get_versions_for_pipeline: def setup_method(self,mock1,mock2): self.client = trainingmgr_main.APP.test_client(self) self.logger = trainingmgr_main.LOGGER - load_dotenv('tests/test.env') self.TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig() the_response = Response() diff --git a/tests/test_trainingmgr_config.py b/tests/test_trainingmgr_config.py index ce944cd..08385cf 100644 --- a/tests/test_trainingmgr_config.py +++ b/tests/test_trainingmgr_config.py @@ -99,6 +99,11 @@ class Test_trainingmgr_config: result = self.TRAININGMGR_CONFIG_OBJ.ps_port assert result == expected_data + def test_allow_access_allowed_origin(self): + expected_data = "http://localhost:32005" + result = self.TRAININGMGR_CONFIG_OBJ.allow_control_access_origin + assert result == expected_data + def test_is_config_loaded_properly(self): expected_data = True result = TrainingMgrConfig.is_config_loaded_properly(self.TRAININGMGR_CONFIG_OBJ) diff --git a/trainingmgr/common/trainingmgr_config.py b/trainingmgr/common/trainingmgr_config.py index 4b98958..11c6b92 100644 --- a/trainingmgr/common/trainingmgr_config.py +++ b/trainingmgr/common/trainingmgr_config.py @@ -46,6 +46,7 @@ class TrainingMgrConfig: self.__ps_password = getenv('PS_PASSWORD').rstrip() self.__ps_ip = getenv('PS_IP').rstrip() self.__ps_port = getenv('PS_PORT').rstrip() + self.__allow_control_access_origin = getenv('ACCESS_CONTROL_ALLOW_ORIGIN').rstrip() self.tmgr_logger = TMLogger("common/conf_log.yaml") self.__logger = self.tmgr_logger.logger @@ -182,6 +183,19 @@ class TrainingMgrConfig: """ return self.__ps_port + @property + def allow_control_access_origin(self): + """ + Function for getting allow_control_access_origin + + Args: None + + Returns: + string allow_control_access_origin + + """ + return self.__allow_control_access_origin + def is_config_loaded_properly(self): """ This function checks where all environment variable got value or not. @@ -193,7 +207,7 @@ class TrainingMgrConfig: for var in [self.__kf_adapter_ip, self.__kf_adapter_port, self.__data_extraction_ip, self.__data_extraction_port, self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user, - self.__ps_password, self.__my_ip, self.__logger]: + self.__ps_password, self.__my_ip, self.__allow_control_access_origin, self.__logger]: if var is None: all_present = False return all_present diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 5e4779b..ffb338f 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -30,7 +30,7 @@ import time from flask import Flask, request, send_file from flask_api import status import requests -from flask_cors import cross_origin +from flask_cors import CORS from werkzeug.utils import secure_filename from modelmetricsdk.model_metrics_sdk import ModelMetricsSdk from trainingmgr.common.trainingmgr_operations import data_extraction_start, training_start, data_extraction_status, create_dme_filtered_data_job, delete_dme_filtered_data_job @@ -78,7 +78,6 @@ def error(err): @APP.route('/trainingjobs//', methods=['GET']) -@cross_origin() def get_trainingjob_by_name_version(trainingjob_name, version): """ Rest endpoint to fetch training job details by name and version @@ -189,7 +188,6 @@ def get_trainingjob_by_name_version(trainingjob_name, version): mimetype=MIMETYPE_JSON) @APP.route('/trainingjobs///steps_state', methods=['GET']) # Handled in GUI -@cross_origin() def get_steps_state(trainingjob_name, version): """ Function handling rest end points to get steps_state information for @@ -282,7 +280,6 @@ def get_model(trainingjob_name, version): @APP.route('/trainingjobs//training', methods=['POST']) # Handled in GUI -@cross_origin() def training(trainingjob_name): """ Rest end point to start training job. @@ -536,7 +533,6 @@ def pipeline_notification(): @APP.route('/trainingjobs/latest', methods=['GET']) -@cross_origin() def trainingjobs_operations(): """ Rest endpoint to fetch overall status, latest version of all existing training jobs @@ -585,7 +581,6 @@ def trainingjobs_operations(): mimetype=MIMETYPE_JSON) @APP.route("/pipelines//upload", methods=['POST']) -@cross_origin() def upload_pipeline(pipe_name): """ Function handling rest endpoint to upload pipeline. @@ -684,7 +679,6 @@ def upload_pipeline(pipe_name): @APP.route("/pipelines//versions", methods=['GET']) -@cross_origin() def get_versions_for_pipeline(pipeline_name): """ Function handling rest endpoint to get versions of given pipeline name. @@ -739,7 +733,6 @@ def get_versions_for_pipeline(pipeline_name): mimetype=MIMETYPE_JSON) @APP.route('/pipelines', methods=['GET']) -@cross_origin() def get_all_pipeline_names(): """ Function handling rest endpoint to get all pipeline names. @@ -773,7 +766,6 @@ def get_all_pipeline_names(): return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON) @APP.route('/experiments', methods=['GET']) -@cross_origin() def get_all_experiment_names(): """ Function handling rest endpoint to get all experiment names. @@ -823,7 +815,6 @@ def get_all_experiment_names(): @APP.route('/trainingjobs/', methods=['POST', 'PUT']) # Handled in GUI -@cross_origin() def trainingjob_operations(trainingjob_name): """ Rest endpoind to create or update trainingjob @@ -937,7 +928,6 @@ def trainingjob_operations(trainingjob_name): mimetype=MIMETYPE_JSON) @APP.route('/trainingjobs/retraining', methods=['POST']) -@cross_origin() def retraining(): """ Function handling rest endpoint to retrain trainingjobs in request json. trainingjob's @@ -1065,7 +1055,6 @@ def retraining(): mimetype='application/json') @APP.route('/trainingjobs', methods=['DELETE']) -@cross_origin() def delete_list_of_trainingjob_version(): """ Function handling rest endpoint to delete latest version of trainingjob_name trainingjobs which is @@ -1263,7 +1252,6 @@ def get_metadata(trainingjob_name): mimetype=MIMETYPE_JSON) @APP.route('/featureGroup', methods=['POST']) -@cross_origin() def create_feature_group(): """ Rest endpoint to create feature group @@ -1354,7 +1342,6 @@ def create_feature_group(): mimetype=MIMETYPE_JSON) @APP.route('/featureGroup', methods=['GET']) -@cross_origin() def get_feature_group(): """ Rest endpoint to fetch all the feature groups @@ -1403,7 +1390,6 @@ def get_feature_group(): mimetype=MIMETYPE_JSON) @APP.route('/featureGroup/', methods=['GET']) -@cross_origin() def get_feature_group_by_name(featuregroup_name): """ Rest endpoint to fetch a feature group @@ -1483,7 +1469,6 @@ def get_feature_group_by_name(featuregroup_name): mimetype=MIMETYPE_JSON) @APP.route('/featureGroup', methods=['DELETE']) -@cross_origin() def delete_list_of_feature_group(): """ Function handling rest endpoint to delete featureGroup which is @@ -1645,6 +1630,8 @@ if __name__ == "__main__": DATAEXTRACTION_JOBS_CACHE = get_data_extraction_in_progress_trainingjobs(PS_DB_OBJ) threading.Thread(target=async_feature_engineering_status, daemon=True).start() MM_SDK = ModelMetricsSdk() + list_allow_control_access_origin = TRAININGMGR_CONFIG_OBJ.allow_control_access_origin.split(',') + CORS(APP, resources={r"/*": {"origins": list_allow_control_access_origin}}) LOGGER.debug("Starting AIML-WF training manager .....") APP.run(debug=True, port=int(TRAININGMGR_CONFIG_OBJ.my_port), host='0.0.0.0') except TMException as err: -- 2.16.6