From: rajdeep11 Date: Mon, 25 Nov 2024 10:43:06 +0000 (+0530) Subject: changes for add trainingjob X-Git-Tag: 3.0.0~31 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=dedc5b9a9b8730858492e7f6c04db768fda9ce29;p=aiml-fw%2Fawmf%2Ftm.git changes for add trainingjob Change-Id: I0287b4f51348eb219da68294c0aa001f3aecb226 Signed-off-by: rajdeep11 --- diff --git a/trainingmgr/controller/trainingjob_controller.py b/trainingmgr/controller/trainingjob_controller.py index 574de25..ee17c47 100644 --- a/trainingmgr/controller/trainingjob_controller.py +++ b/trainingmgr/controller/trainingjob_controller.py @@ -15,13 +15,24 @@ # limitations under the License. # # ================================================================================== + +import json from flask import Blueprint, jsonify, request +from flask_api import status +from marshmallow import ValidationError +from trainingmgr.common.exceptions_utls import TMException from trainingmgr.common.trainingmgr_config import TrainingMgrConfig -from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainining_jobs +from trainingmgr.schemas.trainingjob_schema import TrainingJobSchema +from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainingjob_by_modelId, get_trainining_jobs +from trainingmgr.common.trainingmgr_util import check_key_in_dictionary +from trainingmgr.common.trainingConfig_parser import validateTrainingConfig training_job_controller = Blueprint('training_job_controller', __name__) LOGGER = TrainingMgrConfig().logger +trainingjob_schema = TrainingJobSchema() +trainingjobs_schema = TrainingJobSchema(many=True) + @training_job_controller.route('/training-jobs/', methods=['DELETE']) def delete_trainingjob(training_job_id): LOGGER.debug(f'delete training job : {training_job_id}') @@ -40,14 +51,39 @@ def delete_trainingjob(training_job_id): 'message': str(e) }), 500 + @training_job_controller.route('/training-jobs', methods=['POST']) def create_trainingjob(): + try: - data = request.get_json() - create_training_job(data) - LOGGER.debug(f'create training job Successfully: {data}') - return '', 200 + request_json = request.get_json() + + if check_key_in_dictionary(["training_config"], request_json): + request_json['training_config'] = json.dumps(request_json["training_config"]) + else: + return jsonify({'Exception': 'The training_config is missing'}), status.HTTP_400_BAD_REQUEST + + trainingjob = trainingjob_schema.load(request_json) + + model_id = trainingjob.modelId + + trainingConfig = trainingjob.training_config + if(not validateTrainingConfig(trainingConfig)): + return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST + + #check if trainingjob is already present with name + trainingjob_db = get_trainingjob_by_modelId(model_id) + + if trainingjob_db != None: + return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT + + create_training_job(trainingjob) + + return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201 + + except ValidationError as error: + return jsonify(error.messages), status.HTTP_400_BAD_REQUEST except Exception as e: return jsonify({ 'message': str(e) diff --git a/trainingmgr/db/trainingjob_db.py b/trainingmgr/db/trainingjob_db.py index befeeb6..f422464 100644 --- a/trainingmgr/db/trainingjob_db.py +++ b/trainingmgr/db/trainingjob_db.py @@ -20,12 +20,13 @@ import datetime import re import json from trainingmgr.common.exceptions_utls import DBException -from trainingmgr.models import db, TrainingJob, TrainingJobStatus +from trainingmgr.common.trainingConfig_parser import getField +from trainingmgr.models import db, TrainingJob, TrainingJobStatus, ModelID from trainingmgr.constants.steps import Steps from trainingmgr.constants.states import States from sqlalchemy.sql import func -from sqlalchemy.exc import NoResultFound -from trainingmgr.common.trainingConfig_parser import getField +from sqlalchemy.orm.exc import NoResultFound + @@ -38,6 +39,20 @@ def get_all_versions_info_by_name(trainingjob_name): """ return TrainingJob.query.filter_by(trainingjob_name=trainingjob_name).all() + +def get_trainingjob_info_by_name(trainingjob_name): + """ + This function returns information of training job by name and + by default latest version + """ + + try: + trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first() + except Exception as err: + raise DBException(DB_QUERY_EXEC_ERROR + \ + "get_trainingjob_info_by_name" + str(err)) + return trainingjob_max_version + def add_update_trainingjob(trainingjob, adding): """ This function add the new row or update existing row with given information @@ -84,20 +99,7 @@ def add_update_trainingjob(trainingjob, adding): except Exception as err: raise DBException(DB_QUERY_EXEC_ERROR + \ "add_update_trainingjob" + str(err)) - -def get_trainingjob_info_by_name(trainingjob_name): - """ - This function returns information of training job by name and - by default latest version - """ - - try: - trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first() - except Exception as err: - raise DBException(DB_QUERY_EXEC_ERROR + \ - "get_trainingjob_info_by_name" + str(err)) - return trainingjob_max_version - + def get_info_by_version(trainingjob_name, version): """ This function returns information for given trainingjob. @@ -288,11 +290,26 @@ def delete_trainingjob_version(trainingjob_name, version): raise DBException(DB_QUERY_EXEC_ERROR + \ "delete_trainingjob_version" + str(err)) -from trainingmgr.schemas import TrainingJobSchema -def create_trainingjob(data): - tj = TrainingJobSchema().load(data) - db.session.add(tj) - db.session.commit() +def create_trainingjob(trainingjob): + + steps_state = { + Steps.DATA_EXTRACTION.name: States.NOT_STARTED.name, + Steps.DATA_EXTRACTION_AND_TRAINING.name: States.NOT_STARTED.name, + Steps.TRAINING.name: States.NOT_STARTED.name, + Steps.TRAINING_AND_TRAINED_MODEL.name: States.NOT_STARTED.name, + Steps.TRAINED_MODEL.name: States.NOT_STARTED.name + } + + try: + training_job_status = TrainingJobStatus(states= json.dumps(steps_state)) + db.session.add(training_job_status) + db.session.commit() #to get the steps_state id + + trainingjob.steps_state_id = training_job_status.id + db.session.add(trainingjob) + db.session.commit() + except Exception as err: + raise DBException(f'{DB_QUERY_EXEC_ERROR} in the create_trainingjob : {str(err)}') def delete_trainingjob_by_id(id: int): """ @@ -325,3 +342,20 @@ def get_trainingjob(id: int=None): tjs = TrainingJob.query.all() return tjs return tj + +def get_trainingjob_by_modelId_db(model_id): + try: + trainingjob = ( + db.session.query(TrainingJob) + .join(ModelID) + .filter( + ModelID.modelname == model_id.modelname, + ModelID.modelversion == model_id.modelversion + ) + .one() + ) + return trainingjob + except NoResultFound: + return None + except Exception as e: + raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjob_by_modelId_db : {str(e)}') \ No newline at end of file diff --git a/trainingmgr/models/__init__.py b/trainingmgr/models/__init__.py index fdb06a8..6ecb42a 100644 --- a/trainingmgr/models/__init__.py +++ b/trainingmgr/models/__init__.py @@ -19,7 +19,7 @@ from flask_sqlalchemy import SQLAlchemy db = SQLAlchemy() -from trainingmgr.models.trainingjob import TrainingJob +from trainingmgr.models.trainingjob import TrainingJob, ModelID from trainingmgr.models.featuregroup import FeatureGroup from trainingmgr.models.steps_state import TrainingJobStatus diff --git a/trainingmgr/schemas/trainingjob_schema.py b/trainingmgr/schemas/trainingjob_schema.py index 23d8a42..549b7c5 100644 --- a/trainingmgr/schemas/trainingjob_schema.py +++ b/trainingmgr/schemas/trainingjob_schema.py @@ -16,22 +16,36 @@ # # ================================================================================== +import re from trainingmgr.schemas import ma from trainingmgr.models import TrainingJob from trainingmgr.models.trainingjob import ModelID -from marshmallow import pre_load +from marshmallow import pre_load, validates, ValidationError + +PATTERN = re.compile(r"\w+") class ModelSchema(ma.SQLAlchemyAutoSchema): class Meta: model = ModelID load_instance = True + class TrainingJobSchema(ma.SQLAlchemyAutoSchema): class Meta: model = TrainingJob load_instance = True + exclude = ("creation_time", "deletion_in_progress", "version", "updation_time","run_id") modelId = ma.Nested(ModelSchema) + + @validates("trainingjob_name") + def validate_trainingjob_name(self, value): + + if not (3<= len(value) <=50): + raise ValidationError("Training job name length must be between 3 and 50 characters") + + if not PATTERN.fullmatch(value): + raise ValidationError("Training job name must be alphanumeric and underscore only.") @pre_load def processModelId(self, data, **kwargs): diff --git a/trainingmgr/service/training_job_service.py b/trainingmgr/service/training_job_service.py index da569fa..8b5a28b 100644 --- a/trainingmgr/service/training_job_service.py +++ b/trainingmgr/service/training_job_service.py @@ -15,8 +15,8 @@ # limitations under the License. # # ================================================================================== -from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob -from trainingmgr.common.exceptions_utls import DBException +from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db +from trainingmgr.common.exceptions_utls import DBException, TMException from trainingmgr.schemas import TrainingJobSchema trainingJobSchema = TrainingJobSchema() @@ -31,8 +31,11 @@ def get_trainining_jobs(): result = trainingJobsSchema.dump(tjs) return result -def create_training_job(data): - create_trainingjob(data) +def create_training_job(trainingjob): + try: + create_trainingjob(trainingjob) + except DBException as err: + raise TMException(f"create_training_job failed with exception : {str(err)}") def delete_training_job(training_job_id : int): @@ -54,4 +57,12 @@ def delete_training_job(training_job_id : int): return delete_trainingjob_by_id(id=training_job_id) except Exception as err : raise DBException(f"delete_trainining_job failed with exception : {str(err)}") + +def get_trainingjob_by_modelId(model_id): + try: + trainingjob = get_trainingjob_by_modelId_db(model_id) + return trainingjob + + except Exception as err: + raise DBException(f"get_trainingjob_by_modelId failed with exception : {str(err)}") diff --git a/trainingmgr/trainingmgr_main.py b/trainingmgr/trainingmgr_main.py index 0cdf094..8da790f 100644 --- a/trainingmgr/trainingmgr_main.py +++ b/trainingmgr/trainingmgr_main.py @@ -67,6 +67,7 @@ TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig() from middleware.loggingMiddleware import LoggingMiddleware APP.wsgi_app = LoggingMiddleware(APP.wsgi_app) APP.register_blueprint(training_job_controller) + PS_DB_OBJ = None LOGGER = None MM_SDK = None