# limitations under the License.
#
# ==================================================================================
+
+import json
from flask import Blueprint, jsonify, request
+from flask_api import status
+from marshmallow import ValidationError
+from trainingmgr.common.exceptions_utls import TMException
from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
-from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainining_jobs
+from trainingmgr.schemas.trainingjob_schema import TrainingJobSchema
+from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainingjob_by_modelId, get_trainining_jobs
+from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
+from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
training_job_controller = Blueprint('training_job_controller', __name__)
LOGGER = TrainingMgrConfig().logger
+trainingjob_schema = TrainingJobSchema()
+trainingjobs_schema = TrainingJobSchema(many=True)
+
@training_job_controller.route('/training-jobs/<int:training_job_id>', methods=['DELETE'])
def delete_trainingjob(training_job_id):
LOGGER.debug(f'delete training job : {training_job_id}')
'message': str(e)
}), 500
+
@training_job_controller.route('/training-jobs', methods=['POST'])
def create_trainingjob():
+
try:
- data = request.get_json()
- create_training_job(data)
- LOGGER.debug(f'create training job Successfully: {data}')
- return '', 200
+ request_json = request.get_json()
+
+ if check_key_in_dictionary(["training_config"], request_json):
+ request_json['training_config'] = json.dumps(request_json["training_config"])
+ else:
+ return jsonify({'Exception': 'The training_config is missing'}), status.HTTP_400_BAD_REQUEST
+
+ trainingjob = trainingjob_schema.load(request_json)
+
+ model_id = trainingjob.modelId
+
+ trainingConfig = trainingjob.training_config
+ if(not validateTrainingConfig(trainingConfig)):
+ return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST
+
+ #check if trainingjob is already present with name
+ trainingjob_db = get_trainingjob_by_modelId(model_id)
+
+ if trainingjob_db != None:
+ return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT
+
+ create_training_job(trainingjob)
+
+ return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
+
+ except ValidationError as error:
+ return jsonify(error.messages), status.HTTP_400_BAD_REQUEST
except Exception as e:
return jsonify({
'message': str(e)
import re
import json
from trainingmgr.common.exceptions_utls import DBException
-from trainingmgr.models import db, TrainingJob, TrainingJobStatus
+from trainingmgr.common.trainingConfig_parser import getField
+from trainingmgr.models import db, TrainingJob, TrainingJobStatus, ModelID
from trainingmgr.constants.steps import Steps
from trainingmgr.constants.states import States
from sqlalchemy.sql import func
-from sqlalchemy.exc import NoResultFound
-from trainingmgr.common.trainingConfig_parser import getField
+from sqlalchemy.orm.exc import NoResultFound
+
"""
return TrainingJob.query.filter_by(trainingjob_name=trainingjob_name).all()
+
+def get_trainingjob_info_by_name(trainingjob_name):
+ """
+ This function returns information of training job by name and
+ by default latest version
+ """
+
+ try:
+ trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first()
+ except Exception as err:
+ raise DBException(DB_QUERY_EXEC_ERROR + \
+ "get_trainingjob_info_by_name" + str(err))
+ return trainingjob_max_version
+
def add_update_trainingjob(trainingjob, adding):
"""
This function add the new row or update existing row with given information
except Exception as err:
raise DBException(DB_QUERY_EXEC_ERROR + \
"add_update_trainingjob" + str(err))
-
-def get_trainingjob_info_by_name(trainingjob_name):
- """
- This function returns information of training job by name and
- by default latest version
- """
-
- try:
- trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first()
- except Exception as err:
- raise DBException(DB_QUERY_EXEC_ERROR + \
- "get_trainingjob_info_by_name" + str(err))
- return trainingjob_max_version
-
+
def get_info_by_version(trainingjob_name, version):
"""
This function returns information for given <trainingjob_name, version> trainingjob.
raise DBException(DB_QUERY_EXEC_ERROR + \
"delete_trainingjob_version" + str(err))
-from trainingmgr.schemas import TrainingJobSchema
-def create_trainingjob(data):
- tj = TrainingJobSchema().load(data)
- db.session.add(tj)
- db.session.commit()
+def create_trainingjob(trainingjob):
+
+ steps_state = {
+ Steps.DATA_EXTRACTION.name: States.NOT_STARTED.name,
+ Steps.DATA_EXTRACTION_AND_TRAINING.name: States.NOT_STARTED.name,
+ Steps.TRAINING.name: States.NOT_STARTED.name,
+ Steps.TRAINING_AND_TRAINED_MODEL.name: States.NOT_STARTED.name,
+ Steps.TRAINED_MODEL.name: States.NOT_STARTED.name
+ }
+
+ try:
+ training_job_status = TrainingJobStatus(states= json.dumps(steps_state))
+ db.session.add(training_job_status)
+ db.session.commit() #to get the steps_state id
+
+ trainingjob.steps_state_id = training_job_status.id
+ db.session.add(trainingjob)
+ db.session.commit()
+ except Exception as err:
+ raise DBException(f'{DB_QUERY_EXEC_ERROR} in the create_trainingjob : {str(err)}')
def delete_trainingjob_by_id(id: int):
"""
tjs = TrainingJob.query.all()
return tjs
return tj
+
+def get_trainingjob_by_modelId_db(model_id):
+ try:
+ trainingjob = (
+ db.session.query(TrainingJob)
+ .join(ModelID)
+ .filter(
+ ModelID.modelname == model_id.modelname,
+ ModelID.modelversion == model_id.modelversion
+ )
+ .one()
+ )
+ return trainingjob
+ except NoResultFound:
+ return None
+ except Exception as e:
+ raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjob_by_modelId_db : {str(e)}')
\ No newline at end of file
db = SQLAlchemy()
-from trainingmgr.models.trainingjob import TrainingJob
+from trainingmgr.models.trainingjob import TrainingJob, ModelID
from trainingmgr.models.featuregroup import FeatureGroup
from trainingmgr.models.steps_state import TrainingJobStatus
#
# ==================================================================================
+import re
from trainingmgr.schemas import ma
from trainingmgr.models import TrainingJob
from trainingmgr.models.trainingjob import ModelID
-from marshmallow import pre_load
+from marshmallow import pre_load, validates, ValidationError
+
+PATTERN = re.compile(r"\w+")
class ModelSchema(ma.SQLAlchemyAutoSchema):
class Meta:
model = ModelID
load_instance = True
+
class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
class Meta:
model = TrainingJob
load_instance = True
+ exclude = ("creation_time", "deletion_in_progress", "version", "updation_time","run_id")
modelId = ma.Nested(ModelSchema)
+
+ @validates("trainingjob_name")
+ def validate_trainingjob_name(self, value):
+
+ if not (3<= len(value) <=50):
+ raise ValidationError("Training job name length must be between 3 and 50 characters")
+
+ if not PATTERN.fullmatch(value):
+ raise ValidationError("Training job name must be alphanumeric and underscore only.")
@pre_load
def processModelId(self, data, **kwargs):
# limitations under the License.
#
# ==================================================================================
-from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob
-from trainingmgr.common.exceptions_utls import DBException
+from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db
+from trainingmgr.common.exceptions_utls import DBException, TMException
from trainingmgr.schemas import TrainingJobSchema
trainingJobSchema = TrainingJobSchema()
result = trainingJobsSchema.dump(tjs)
return result
-def create_training_job(data):
- create_trainingjob(data)
+def create_training_job(trainingjob):
+ try:
+ create_trainingjob(trainingjob)
+ except DBException as err:
+ raise TMException(f"create_training_job failed with exception : {str(err)}")
def delete_training_job(training_job_id : int):
return delete_trainingjob_by_id(id=training_job_id)
except Exception as err :
raise DBException(f"delete_trainining_job failed with exception : {str(err)}")
+
+def get_trainingjob_by_modelId(model_id):
+ try:
+ trainingjob = get_trainingjob_by_modelId_db(model_id)
+ return trainingjob
+
+ except Exception as err:
+ raise DBException(f"get_trainingjob_by_modelId failed with exception : {str(err)}")
from middleware.loggingMiddleware import LoggingMiddleware
APP.wsgi_app = LoggingMiddleware(APP.wsgi_app)
APP.register_blueprint(training_job_controller)
+
PS_DB_OBJ = None
LOGGER = None
MM_SDK = None