from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
from trainingmgr.common.exceptions_utls import DBException, TMException
from trainingmgr.models import TrainingJob
-
+from trainingmgr.common.trainingConfig_parser import getField
trainingmgr_main.LOGGER = pytest.logger
trainingmgr_main.LOCK = Lock()
trainingmgr_main.DATAEXTRACTION_JOBS_CACHE = {}
"""Create a mock TrainingJob object."""
creation_time = datetime.datetime.now()
updation_time = datetime.datetime.now()
+ training_config = {
+ "is_mme" : True,
+ "description": "Test description",
+ "dataPipeline": {
+ "feature_group_name": "test_feature_group",
+ "query_filter": "",
+ "arguments": {"epochs" : 1, "trainingjob_name": "test_job"}
+ },
+ "trainingPipeline": {
+ "pipeline_name": "test_pipeline",
+ "pipeline_version": "2",
+ "enable_versioning": True
+ }
+ }
+
+
return TrainingJob(
trainingjob_name="test_job",
- description="Test description",
- feature_group_name="test_feature_group",
- pipeline_name="test_pipeline",
- experiment_name="test_experiment",
- arguments=json.dumps({"param1": "value1"}),
- query_filter="test_filter",
+ training_config = json.dumps(training_config),
creation_time=creation_time,
run_id="test_run_id",
steps_state=json.dumps({"step1": "completed"}),
updation_time=updation_time,
version=1,
- enable_versioning=True,
- pipeline_version="v1",
- datalake_source=json.dumps({"datalake_source": {"source1": "path1"}}),
model_url="http://test.model.url",
notification_url="http://test.notification.url",
deletion_in_progress=False,
- is_mme=True,
model_name="test_model",
model_info="test_model_info"
)
assert 'trainingjob' in data
job_data = data['trainingjob']
assert job_data['trainingjob_name'] == "test_job"
- assert job_data['description'] == "Test description"
- assert job_data['feature_list'] == "test_feature_group"
- assert job_data['pipeline_name'] == "test_pipeline"
- assert job_data['experiment_name'] == "test_experiment"
- assert job_data['is_mme'] is True
+ assert job_data['training_config']['description'] == "Test description"
+ assert job_data['training_config']['dataPipeline']['feature_group_name'] == "test_feature_group"
+ assert job_data['training_config']['trainingPipeline']['pipeline_name'] == "test_pipeline"
+ assert job_data['training_config']['is_mme'] is True
assert job_data['model_name'] == "test_model"
assert job_data['model_info'] == "test_model_info"
assert job_data['accuracy'] == mock_metrics
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+import json
+
def parse_dict_by_fields(data, fields):
    '''
    Walk a nested dict, following the given sequence of keys, and return
    the value found at the end of the path.

    Example:
        data = {"a": 1, "b": {"c": 4, "d": -1}}
        parse_dict_by_fields(data, ["b", "c"])  ->  4
        parse_dict_by_fields(data, ["b", "d"])  -> -1

    Args:
        data: nested dict to traverse.
        fields: ordered list of keys describing the path to the leaf.

    Raises:
        Exception: if any key along the path is missing or a non-subscriptable
            value is reached before the path is exhausted.
    '''
    try:
        cur = data
        for field in fields:
            cur = cur[field]
        return cur
    except Exception as e:
        # Chain the original error so the root cause stays visible.
        raise Exception("Can't parse Fields: {} in Data : {}| received-error : {}".format(fields, data, e)) from e
+
+
def __getLeafPaths():
    '''
    Return the key-path at which each supported field lives inside a
    TrainingConfig dict.

    Based on the TrainingConfig schema:
    {
        "is_mme": false,
        "description": "This is something3",
        "dataPipeline": {
            "feature_group_name": "base2",
            "query_filter": "",
            "arguments": "{'epochs': '1'}"
        },
        "trainingPipeline": {
            "pipeline_name": "qoe_Pipeline",
            "pipeline_version": "2",
            "enable_versioning": true
        }
    }
    '''
    top_level_fields = ["is_mme", "description"]
    data_pipeline_fields = ["feature_group_name", "query_filter", "arguments"]
    training_pipeline_fields = ["pipeline_name", "pipeline_version", "enable_versioning"]

    paths = {name: [name] for name in top_level_fields}
    paths.update({name: ["dataPipeline", name] for name in data_pipeline_fields})
    paths.update({name: ["trainingPipeline", name] for name in training_pipeline_fields})
    return paths
+
def prepocessTrainingConfig(trainingConfig):
    '''Return trainingConfig as a dict, decoding it from JSON when given a string.'''
    if not isinstance(trainingConfig, str):
        return trainingConfig
    return json.loads(trainingConfig)
+
def validateTrainingConfig(trainingConfig):
    '''
    Validate a TrainingConfig by checking that every leaf path of the
    TrainingConfig schema is present.

    Extra key:value pairs beyond the schema are NOT treated as invalid.

    Args:
        trainingConfig: dict, or a JSON string encoding a dict.

    Returns:
        bool: True when every schema field is reachable (and, for string
        input, the JSON is well-formed); False otherwise.
    '''
    allPaths = __getLeafPaths()
    try:
        # Decode inside the try so that a malformed JSON string is reported
        # as an invalid config (False) instead of raising to the caller.
        trainingConfig = prepocessTrainingConfig(trainingConfig)
        for fieldPath in allPaths.values():
            parse_dict_by_fields(trainingConfig, fieldPath)
        return True
    except Exception as e:
        print("Unable to Validate Error: ", e)
        return False
+
def getField(trainingConfig, fieldname):
    '''
    Extract a single schema field from a TrainingConfig.

    Args:
        trainingConfig: dict, or a JSON string encoding a dict.
        fieldname: one of the leaf field names of the TrainingConfig schema
            (e.g. "feature_group_name", "pipeline_name", "is_mme").

    Returns:
        The value stored at that field.

    Raises:
        KeyError: if fieldname is not a known schema field.
        Exception: if the field is missing from trainingConfig.
    '''
    trainingConfig = prepocessTrainingConfig(trainingConfig)
    allPaths = __getLeafPaths()
    if fieldname not in allPaths:
        # Bare allPaths[fieldname] would raise a message-less KeyError;
        # name the offending field explicitly instead.
        raise KeyError("Unknown TrainingConfig field: {}".format(fieldname))
    return parse_dict_by_fields(trainingConfig, allPaths[fieldname])
+
\ No newline at end of file
from trainingmgr.constants.states import States
from trainingmgr.db.trainingmgr_ps_db import PSDB
from trainingmgr.common.exceptions_utls import DBException
-from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs, \
- change_in_progress_to_failed_by_latest_version, \
- get_all_versions_info_by_name, \
- update_model_download_url, \
- get_field_of_given_version
+from trainingmgr.db.common_db_fun import get_data_extraction_in_progress_trainingjobs
from trainingmgr.models import db, TrainingJob, FeatureGroup
from trainingmgr.schemas import ma, TrainingJobSchema , FeatureGroupSchema
from trainingmgr.db.featuregroup_db import add_featuregroup, edit_featuregroup, get_feature_groups_db, \
from trainingmgr.db.trainingjob_db import add_update_trainingjob, get_trainingjob_info_by_name, \
get_all_jobs_latest_status_version, change_steps_state_of_latest_version, get_info_by_version, \
get_steps_state_db, change_field_of_latest_version, get_latest_version_trainingjob_name, get_info_of_latest_version, \
- change_field_value_by_version, delete_trainingjob_version
+ change_field_value_by_version, delete_trainingjob_version, change_in_progress_to_failed_by_latest_version, \
+ update_model_download_url, get_all_versions_info_by_name
+from trainingmgr.common.trainingConfig_parser import validateTrainingConfig, getField
APP = Flask(__name__)
trainingjob_schema = TrainingJobSchema()
trainingjobs_schema = TrainingJobSchema(many=True)
+
@APP.errorhandler(APIException)
def error(err):
"""
status=err.code,
mimetype=MIMETYPE_JSON)
-
+# Training-Config Handled
@APP.route('/trainingjobs/<trainingjob_name>/<version>', methods=['GET'])
def get_trainingjob_by_name_version(trainingjob_name, version):
"""
if trainingjob:
dict_data = {
"trainingjob_name": trainingjob.trainingjob_name,
- "description": trainingjob.description,
- "feature_list": trainingjob.feature_group_name,
- "pipeline_name": trainingjob.pipeline_name,
- "experiment_name": trainingjob.experiment_name,
- "arguments": trainingjob.arguments,
- "query_filter": trainingjob.query_filter,
+ "training_config": json.loads(trainingjob.training_config),
+ # "description": trainingjob.description,
+ # "feature_list": trainingjob.feature_group_name,
+ # "pipeline_name": trainingjob.pipeline_name,
+ # "experiment_name": trainingjob.experiment_name,
+ # "arguments": trainingjob.arguments,
+ # "query_filter": trainingjob.query_filter,
"creation_time": str(trainingjob.creation_time),
"run_id": trainingjob.run_id,
"steps_state": json.loads(trainingjob.steps_state),
"updation_time": str(trainingjob.updation_time),
"version": trainingjob.version,
- "enable_versioning": trainingjob.enable_versioning,
- "pipeline_version": trainingjob.pipeline_version,
- "datalake_source": get_one_key(json.loads(trainingjob.datalake_source)['datalake_source']),
+ # "enable_versioning": trainingjob.enable_versioning,
+ # "pipeline_version": trainingjob.pipeline_version,
+ # "datalake_source": get_one_key(json.loads(trainingjob.datalake_source)['datalake_source']),
"model_url": trainingjob.model_url,
"notification_url": trainingjob.notification_url,
- "is_mme": trainingjob.is_mme,
+ # "is_mme": trainingjob.is_mme,
"model_name": trainingjob.model_name,
"model_info": trainingjob.model_info,
"accuracy": data
status=response_code,
mimetype=MIMETYPE_JSON)
-@APP.route('/trainingjobs/<trainingjob_name>/<version>/steps_state', methods=['GET']) # Handled in GUI
+# Training-Config Handled (No Change)
+@APP.route('/trainingjobs/<trainingjob_name>/<version>/steps_state', methods=['GET'])
def get_steps_state(trainingjob_name, version):
"""
Function handling rest end points to get steps_state information for
status=response_code,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change)
@APP.route('/model/<trainingjob_name>/<version>/Model.zip', methods=['GET'])
def get_model(trainingjob_name, version):
"""
except Exception:
return {"Exception": "error while downloading model"}, status.HTTP_500_INTERNAL_SERVER_ERROR
-
-@APP.route('/trainingjobs/<trainingjob_name>/training', methods=['POST']) # Handled in GUI
+# Training-Config Handled
+@APP.route('/trainingjobs/<trainingjob_name>/training', methods=['POST'])
def training(trainingjob_name):
"""
Rest end point to start training job.
"(trainingjob: " + trainingjob_name + ")") from None
else:
- trainingjob = get_trainingjob_info_by_name(trainingjob_name)
- featuregroup= get_feature_group_by_name_db(trainingjob.feature_group_name)
+ trainingjob = get_trainingjob_info_by_name(trainingjob_name)
+
+ featuregroup= get_feature_group_by_name_db(getField(trainingjob.training_config, "feature_group_name"))
feature_list_string = featuregroup.feature_list
influxdb_info_dic={}
influxdb_info_dic["host"]=featuregroup.host
influxdb_info_dic["db_org"] = featuregroup.db_org
influxdb_info_dic["source_name"]= featuregroup.source_name
_measurement = featuregroup.measurement
- query_filter = trainingjob.query_filter
- datalake_source = json.loads(trainingjob.datalake_source)['datalake_source']
+ query_filter = getField(trainingjob.training_config, "query_filter")
+ datalake_source = {featuregroup.datalake_source: {}} # Datalake source should be taken from FeatureGroup (not TrainingJob)
LOGGER.debug('Starting Data Extraction...')
de_response = data_extraction_start(TRAININGMGR_CONFIG_OBJ, trainingjob_name,
feature_list_string, query_filter, datalake_source,
raise TMException("Data extraction doesn't send json type response" + \
"(trainingjob name is " + trainingjob_name + ")") from None
except Exception as err:
+ # print(traceback.format_exc())
response_data = {"Exception": str(err)}
LOGGER.debug("Error is training, job name:" + trainingjob_name + str(err))
return APP.response_class(response=json.dumps(response_data),status=response_code,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled
@APP.route('/trainingjob/dataExtractionNotification', methods=['POST'])
def data_extraction_notification():
"""
trainingjob_name = request.json["trainingjob_name"]
trainingjob = get_trainingjob_info_by_name(trainingjob_name)
- arguments = trainingjob.arguments
+ arguments = getField(trainingjob.training_config, "arguments")
arguments["version"] = trainingjob.version
+ # Argument values must be of type string
+ for key, val in arguments.items():
+ if not isinstance(val, str):
+ arguments[key] = str(val)
LOGGER.debug(arguments)
-
+ # Experiment name is hardcoded to 'Default'
dict_data = {
- "pipeline_name": trainingjob.pipeline_name, "experiment_name": trainingjob.experiment_name,
- "arguments": arguments, "pipeline_version": trainingjob.pipeline_version
+ "pipeline_name": getField(trainingjob.training_config, "pipeline_name"), "experiment_name": 'Default',
+ "arguments": arguments, "pipeline_version": getField(trainingjob.training_config, "pipeline_version")
}
response = training_start(TRAININGMGR_CONFIG_OBJ, dict_data, trainingjob_name)
status=status.HTTP_200_OK,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change)
@APP.route('/pipelines/<pipe_name>', methods=['GET'])
def get_pipeline_info_by_name(pipe_name):
"""
return APP.response_class(response=json.dumps(api_response),
status=response_code,
mimetype=MIMETYPE_JSON)
-
+
+# Training-Config Handled ..
@APP.route('/trainingjob/pipelineNotification', methods=['POST'])
def pipeline_notification():
"""
model_url = "http://" + str(TRAININGMGR_CONFIG_OBJ.my_ip) + ":" + \
str(TRAININGMGR_CONFIG_OBJ.my_port) + "/model/" + \
trainingjob_name + "/" + str(version) + "/Model.zip"
-
+
update_model_download_url(trainingjob_name, version, model_url, PS_DB_OBJ)
# upload to the mme
trainingjob_info=get_trainingjob_info_by_name(trainingjob_name)
- is_mme= trainingjob_info.is_mme
+ is_mme = getField(trainingjob_info.training_config, "is_mme")
if is_mme:
model_name=trainingjob_info.model_name #model_name
file=MM_SDK.get_model_zip(trainingjob_name, str(version))
"Pipeline notification success.",
LOGGER, True, trainingjob_name, MM_SDK)
-
+# Training-Config Handled (No Change)
@APP.route('/trainingjobs/latest', methods=['GET'])
def trainingjobs_operations():
"""
status=response_code,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change) ..
@APP.route("/pipelines/<pipe_name>/upload", methods=['POST'])
def upload_pipeline(pipe_name):
"""
mimetype=MIMETYPE_JSON)
-
-
+# Training-Config Handled (No Change)
@APP.route("/pipelines/<pipeline_name>/versions", methods=['GET'])
def get_versions_for_pipeline(pipeline_name):
"""
status=response_code,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change)
@APP.route('/pipelines', methods=['GET'])
def get_pipelines():
"""
api_response = {"Exception": str(err)}
return APP.response_class(response=json.dumps(api_response),status=response_code,mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change)
@APP.route('/experiments', methods=['GET'])
def get_all_experiment_names():
"""
status=reponse_code,
mimetype=MIMETYPE_JSON)
-
-@APP.route('/trainingjobs/<trainingjob_name>', methods=['POST', 'PUT']) # Handled in GUI
+# Training-Config handled
+@APP.route('/trainingjobs/<trainingjob_name>', methods=['POST', 'PUT'])
def trainingjob_operations(trainingjob_name):
"""
- Rest endpoind to create or update trainingjob
+ Rest endpoint to create or update trainingjob
    Precondition for update : trainingjob's overall_status should be failed
    or finished and deletion process should not be in progress
Args in json:
if post/put request is called
json with below fields are given:
- description: str
- description
- featuregroup_name: str
- feature group name
- pipeline_name: str
- name of pipeline
- experiment_name: str
- name of experiment
- arguments: dict
- key-value pairs related to hyper parameters and
- "trainingjob":<trainingjob_name> key-value pair
- query_filter: str
- string indication sql where clause for filtering out features
- enable_versioning: bool
- flag for trainingjob versioning
- pipeline_version: str
- pipeline version
- datalake_source: str
- string indicating datalake source
- _measurement: str
- _measurement for influx db datalake
- bucket: str
- bucket name for influx db datalake
- is_mme: boolean
- whether mme is enabled
- model_name: str
- name of the model
-
+ modelName: str
+ Name of model
+ trainingConfig: dict
+ Training-Configurations, parameter as follows
+ is_mme: boolean
+ whether mme is enabled
+ description: str
+ description
+ dataPipeline: dict
+ Configurations related to dataPipeline, parameter as follows
+ feature_group_name: str
+ feature group name
+ query_filter: str
+            string indicating sql where clause for filtering out features
+ arguments: dict
+ key-value pairs related to hyper parameters and
+ "trainingjob":<trainingjob_name> key-value pair
+ trainingPipeline: dict
+ Configurations related to trainingPipeline, parameter as follows
+ pipeline_name: str
+ name of pipeline
+ pipeline_version: str
+ pipeline version
+ enable_versioning: bool
+ flag for trainingjob versioning
+
Returns:
1. For post request
json:
if not check_trainingjob_name_or_featuregroup_name(trainingjob_name):
return {"Exception":"The trainingjob_name is not correct"}, status.HTTP_400_BAD_REQUEST
+ trainingConfig = request.json["training_config"]
+ if(not validateTrainingConfig(trainingConfig)):
+ return {"Exception":"The TrainingConfig is not correct"}, status.HTTP_400_BAD_REQUEST
+
LOGGER.debug("Training job create/update request(trainingjob name %s) ", trainingjob_name )
try:
json_data = request.json
response_code = status.HTTP_409_CONFLICT
raise TMException("trainingjob name(" + trainingjob_name + ") is already present in database")
else:
- trainingjob = trainingjob_schema.load(request.get_json())
+ processed_json_data = request.get_json()
+ processed_json_data['training_config'] = json.dumps(request.get_json()["training_config"])
+ trainingjob = trainingjob_schema.load(processed_json_data)
model_info=""
- if trainingjob.is_mme:
+ if getField(trainingjob.training_config, "is_mme"):
pipeline_dict =json.loads(TRAININGMGR_CONFIG_OBJ.pipeline)
model_info=get_model_info(TRAININGMGR_CONFIG_OBJ, trainingjob.model_name)
s=model_info["meta-info"]["feature-list"]
response_code = status.HTTP_404_NOT_FOUND
raise TMException("Trainingjob name(" + trainingjob_name + ") is not present in database")
else:
- trainingjob = trainingjob_schema.load(request.get_json())
+ processed_json_data = request.get_json()
+ processed_json_data['training_config'] = json.dumps(request.get_json()["training_config"])
+ trainingjob = trainingjob_schema.load(processed_json_data)
trainingjob_info = get_trainingjob_info_by_name(trainingjob_name)
if trainingjob_info:
if trainingjob_info.deletion_in_progress:
status= response_code,
mimetype=MIMETYPE_JSON)
+# Training-Config Handled (No Change) ..
@APP.route('/trainingjobs/retraining', methods=['POST'])
def retraining():
"""
status=status.HTTP_200_OK,
mimetype='application/json')
+# Training-Config Handled (No Change) ..
@APP.route('/trainingjobs', methods=['DELETE'])
def delete_list_of_trainingjob_version():
"""
status=status.HTTP_200_OK,
mimetype='application/json')
+# Training-Config Handled (No Change)
@APP.route('/trainingjobs/metadata/<trainingjob_name>')
def get_metadata(trainingjob_name):
"""
status=status.HTTP_200_OK,
mimetype='application/json')
+# Training-Config Handled (No Change)
def async_feature_engineering_status():
"""
This function takes trainingjobs from DATAEXTRACTION_JOBS_CACHE and checks data extraction status