changes for add trainingjob 73/13773/4
authorrajdeep11 <rajdeep.sin@samsung.com>
Mon, 25 Nov 2024 10:43:06 +0000 (16:13 +0530)
committerrajdeep11 <rajdeep.sin@samsung.com>
Mon, 25 Nov 2024 11:03:04 +0000 (16:33 +0530)
Change-Id: I0287b4f51348eb219da68294c0aa001f3aecb226
Signed-off-by: rajdeep11 <rajdeep.sin@samsung.com>
trainingmgr/controller/trainingjob_controller.py
trainingmgr/db/trainingjob_db.py
trainingmgr/models/__init__.py
trainingmgr/schemas/trainingjob_schema.py
trainingmgr/service/training_job_service.py
trainingmgr/trainingmgr_main.py

index 574de25..ee17c47 100644 (file)
 #   limitations under the License.
 #
 # ==================================================================================
+
+import json
 from flask import Blueprint, jsonify, request
+from flask_api import status
+from marshmallow import ValidationError
+from trainingmgr.common.exceptions_utls import TMException
 from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
-from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainining_jobs
+from trainingmgr.schemas.trainingjob_schema import TrainingJobSchema
+from trainingmgr.service.training_job_service import delete_training_job, create_training_job, get_training_job, get_trainingjob_by_modelId, get_trainining_jobs
+from trainingmgr.common.trainingmgr_util import check_key_in_dictionary
+from trainingmgr.common.trainingConfig_parser import validateTrainingConfig
 
 training_job_controller = Blueprint('training_job_controller', __name__)
 LOGGER = TrainingMgrConfig().logger
 
+trainingjob_schema = TrainingJobSchema()
+trainingjobs_schema = TrainingJobSchema(many=True)
+
 @training_job_controller.route('/training-jobs/<int:training_job_id>', methods=['DELETE'])
 def delete_trainingjob(training_job_id):
     LOGGER.debug(f'delete training job : {training_job_id}')
@@ -40,14 +51,39 @@ def delete_trainingjob(training_job_id):
             'message': str(e)
         }), 500
     
+    
 @training_job_controller.route('/training-jobs', methods=['POST'])
 def create_trainingjob():
+
     try:
-        data = request.get_json()
-        create_training_job(data)
-        LOGGER.debug(f'create training job Successfully: {data}')
-        return '', 200
 
+        request_json = request.get_json()
+
+        if check_key_in_dictionary(["training_config"], request_json):
+            request_json['training_config'] = json.dumps(request_json["training_config"])
+        else:
+            return jsonify({'Exception': 'The training_config is missing'}), status.HTTP_400_BAD_REQUEST
+        
+        trainingjob = trainingjob_schema.load(request_json)
+
+        model_id = trainingjob.modelId
+        
+        trainingConfig = trainingjob.training_config
+        if(not validateTrainingConfig(trainingConfig)):
+            return jsonify({'Exception': 'The TrainingConfig is not correct'}), status.HTTP_400_BAD_REQUEST
+        
+        #check if trainingjob is already present with name
+        trainingjob_db = get_trainingjob_by_modelId(model_id)
+
+        if trainingjob_db != None:
+            return jsonify({"Exception":f"modelId {model_id.modelname} and {model_id.modelversion} is already present in database"}), status.HTTP_409_CONFLICT
+
+        create_training_job(trainingjob)
+
+        return jsonify({"Trainingjob": trainingjob_schema.dump(trainingjob)}), 201
+        
+    except ValidationError as error:
+        return jsonify(error.messages), status.HTTP_400_BAD_REQUEST
     except Exception as e:
         return jsonify({
             'message': str(e)
index befeeb6..f422464 100644 (file)
@@ -20,12 +20,13 @@ import datetime
 import re
 import json
 from trainingmgr.common.exceptions_utls import DBException
-from trainingmgr.models import db, TrainingJob, TrainingJobStatus
+from trainingmgr.common.trainingConfig_parser import getField
+from trainingmgr.models import db, TrainingJob, TrainingJobStatus, ModelID
 from trainingmgr.constants.steps import Steps
 from trainingmgr.constants.states import States
 from sqlalchemy.sql import func
-from sqlalchemy.exc import NoResultFound
-from trainingmgr.common.trainingConfig_parser import getField
+from sqlalchemy.orm.exc import NoResultFound
+
 
 
 
@@ -38,6 +39,20 @@ def get_all_versions_info_by_name(trainingjob_name):
     """   
     return TrainingJob.query.filter_by(trainingjob_name=trainingjob_name).all()
 
+
+def get_trainingjob_info_by_name(trainingjob_name):
+    """
+    This function returns information of training job by name and 
+    by default latest version
+    """
+
+    try:
+        trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first()
+    except Exception as err:
+        raise DBException(DB_QUERY_EXEC_ERROR + \
+            "get_trainingjob_info_by_name"  + str(err))
+    return trainingjob_max_version
+
 def add_update_trainingjob(trainingjob, adding):
     """
     This function add the new row or update existing row with given information
@@ -84,20 +99,7 @@ def add_update_trainingjob(trainingjob, adding):
     except Exception as err:
         raise DBException(DB_QUERY_EXEC_ERROR + \
             "add_update_trainingjob"  + str(err))
-
-def get_trainingjob_info_by_name(trainingjob_name):
-    """
-    This function returns information of training job by name and 
-    by default latest version
-    """
-
-    try:
-        trainingjob_max_version = TrainingJob.query.filter(TrainingJob.trainingjob_name == trainingjob_name).order_by(TrainingJob.version.desc()).first()
-    except Exception as err:
-        raise DBException(DB_QUERY_EXEC_ERROR + \
-            "get_trainingjob_info_by_name"  + str(err))
-    return trainingjob_max_version
-
+    
 def get_info_by_version(trainingjob_name, version):
     """
     This function returns information for given <trainingjob_name, version> trainingjob.
@@ -288,11 +290,26 @@ def delete_trainingjob_version(trainingjob_name, version):
         raise DBException(DB_QUERY_EXEC_ERROR + \
             "delete_trainingjob_version" + str(err))
 
-from trainingmgr.schemas import TrainingJobSchema
-def create_trainingjob(data):
-        tj = TrainingJobSchema().load(data)
-        db.session.add(tj)
-        db.session.commit()
+def create_trainingjob(trainingjob):
+        
+        steps_state = {
+            Steps.DATA_EXTRACTION.name: States.NOT_STARTED.name,
+            Steps.DATA_EXTRACTION_AND_TRAINING.name: States.NOT_STARTED.name,
+            Steps.TRAINING.name: States.NOT_STARTED.name,
+            Steps.TRAINING_AND_TRAINED_MODEL.name: States.NOT_STARTED.name,
+            Steps.TRAINED_MODEL.name: States.NOT_STARTED.name
+        }
+
+        try:
+            training_job_status = TrainingJobStatus(states= json.dumps(steps_state))
+            db.session.add(training_job_status)
+            db.session.commit()     #to get the steps_state id
+
+            trainingjob.steps_state_id = training_job_status.id
+            db.session.add(trainingjob)
+            db.session.commit()
+        except Exception as err:
+            raise DBException(f'{DB_QUERY_EXEC_ERROR} in the create_trainingjob : {str(err)}')
 
 def delete_trainingjob_by_id(id: int):
     """
@@ -325,3 +342,20 @@ def get_trainingjob(id: int=None):
         tjs = TrainingJob.query.all()
         return tjs
     return tj
+
+def get_trainingjob_by_modelId_db(model_id):
+    try:
+        trainingjob = (
+            db.session.query(TrainingJob)
+            .join(ModelID)
+            .filter(
+                ModelID.modelname == model_id.modelname,
+                ModelID.modelversion == model_id.modelversion
+            )
+            .one()
+        )
+        return trainingjob
+    except NoResultFound:
+        return None
+    except Exception as e:
+        raise DBException(f'{DB_QUERY_EXEC_ERROR} in the get_trainingjob_by_modelId_db : {str(e)}')
\ No newline at end of file
index fdb06a8..6ecb42a 100644 (file)
@@ -19,7 +19,7 @@ from flask_sqlalchemy import SQLAlchemy
 
 db = SQLAlchemy()
 
-from trainingmgr.models.trainingjob import TrainingJob
+from trainingmgr.models.trainingjob import TrainingJob, ModelID
 from trainingmgr.models.featuregroup import FeatureGroup
 from trainingmgr.models.steps_state import TrainingJobStatus
 
index 23d8a42..549b7c5 100644 (file)
 #
 # ==================================================================================
 
+import re
 from trainingmgr.schemas import ma
 from trainingmgr.models import TrainingJob
 from trainingmgr.models.trainingjob import ModelID
 
-from marshmallow import pre_load
+from marshmallow import pre_load, validates, ValidationError
+
+PATTERN = re.compile(r"\w+")
 
 class ModelSchema(ma.SQLAlchemyAutoSchema):
     class Meta:
         model = ModelID
         load_instance = True
+
 class TrainingJobSchema(ma.SQLAlchemyAutoSchema):
     class Meta:
         model = TrainingJob
         load_instance = True
+        exclude = ("creation_time", "deletion_in_progress", "version", "updation_time","run_id")
     
     modelId = ma.Nested(ModelSchema)
+
+    @validates("trainingjob_name")
+    def validate_trainingjob_name(self, value):
+
+        if not (3<= len(value) <=50):
+            raise ValidationError("Training job name length must be between 3 and 50 characters")
+        
+        if not PATTERN.fullmatch(value):
+            raise ValidationError("Training job name must be alphanumeric and underscore only.")
     
     @pre_load
     def processModelId(self, data, **kwargs):
index da569fa..8b5a28b 100644 (file)
@@ -15,8 +15,8 @@
 #   limitations under the License.
 #
 # ==================================================================================
-from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob
-from trainingmgr.common.exceptions_utls import DBException
+from trainingmgr.db.trainingjob_db import delete_trainingjob_by_id, create_trainingjob, get_trainingjob, get_trainingjob_by_modelId_db
+from trainingmgr.common.exceptions_utls import DBException, TMException
 from trainingmgr.schemas import TrainingJobSchema
 
 trainingJobSchema = TrainingJobSchema()
@@ -31,8 +31,11 @@ def get_trainining_jobs():
     result = trainingJobsSchema.dump(tjs)
     return result
 
-def create_training_job(data):
-    create_trainingjob(data)
+def create_training_job(trainingjob):
+    try:
+        create_trainingjob(trainingjob)
+    except DBException as err:
+        raise TMException(f"create_training_job failed with exception : {str(err)}")
     
 
 def delete_training_job(training_job_id : int):
@@ -54,4 +57,12 @@ def delete_training_job(training_job_id : int):
         return delete_trainingjob_by_id(id=training_job_id)
     except Exception as err :
         raise DBException(f"delete_trainining_job failed with exception : {str(err)}")
+    
+def get_trainingjob_by_modelId(model_id):
+    try:
+        trainingjob = get_trainingjob_by_modelId_db(model_id)
+        return trainingjob
+
+    except Exception as err:
+        raise DBException(f"get_trainingjob_by_modelId failed with exception : {str(err)}")
 
index 0cdf094..8da790f 100644 (file)
@@ -67,6 +67,7 @@ TRAININGMGR_CONFIG_OBJ = TrainingMgrConfig()
 from middleware.loggingMiddleware import LoggingMiddleware
 APP.wsgi_app = LoggingMiddleware(APP.wsgi_app)
 APP.register_blueprint(training_job_controller)
+
 PS_DB_OBJ = None
 LOGGER = None
 MM_SDK = None