Upgrading ModelInfo Model as per R1AP v6 Specs 18/13818/2
authorashishj1729 <jain.ashish@samsung.com>
Thu, 5 Dec 2024 21:34:52 +0000 (03:04 +0530)
committerashishj1729 <jain.ashish@samsung.com>
Fri, 6 Dec 2024 10:31:02 +0000 (16:01 +0530)
Change-Id: If070badab55a6c1c3e4c21afbaec3a6defab44c4
Signed-off-by: ashishj1729 <jain.ashish@samsung.com>
apis/mmes_apis.go
apis_test/mmes_apis_test.go
db/iDB.go
db/modelInfoRepository.go
main.go
models/modelInfo.go
routers/router.go

index f9ba3ae..2c15970 100644 (file)
@@ -18,18 +18,18 @@ limitations under the License.
 package apis
 
 import (
+       "fmt"
        "io"
        "net/http"
        "os"
-       "fmt"
 
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core"
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/db"
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging"
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models"
        "github.com/gin-gonic/gin"
+       "github.com/go-playground/validator/v10"
        "github.com/google/uuid"
-
 )
 
 type MmeApiHandler struct {
@@ -47,7 +47,7 @@ func NewMmeApiHandler(dbMgr core.DBMgr, iDB db.IDB) *MmeApiHandler {
 
 func (m *MmeApiHandler) RegisterModel(cont *gin.Context) {
 
-       var modelInfo models.ModelInfo
+       var modelInfo models.ModelRelatedInformation
 
        if err := cont.ShouldBindJSON(&modelInfo); err != nil {
                cont.JSON(http.StatusBadRequest, gin.H{
@@ -59,15 +59,24 @@ func (m *MmeApiHandler) RegisterModel(cont *gin.Context) {
        id := uuid.New()
        modelInfo.Id = id.String()
 
-       // TODO: validate the object
+       validate := validator.New()
+       if err := validate.Struct(modelInfo); err != nil {
+               cont.JSON(http.StatusBadRequest, gin.H{
+                       "error": err.Error(),
+               })
+               return
+       }
 
        if err := m.iDB.Create(modelInfo); err != nil {
                logging.ERROR("error", err)
+               cont.JSON(http.StatusBadRequest, gin.H{
+                       "Error": err.Error(),
+               })
                return
        }
 
        logging.INFO("model is saved.")
-
+       cont.Header("Location", "/model-registrations/"+id.String())
        cont.JSON(http.StatusCreated, gin.H{
                "modelInfo": modelInfo,
        })
@@ -113,7 +122,7 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
        queryParams := cont.Request.URL.Query()
        //to check only modelName and modelVersion can be passed.
        allowedParams := map[string]bool{
-               "modelName": true,
+               "modelName":    true,
                "modelVersion": true,
        }
 
@@ -126,11 +135,11 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
                }
        }
 
-       modelName:= cont.Query("modelName")
-       modelVersion:= cont.Query("modelVersion")
+       modelName := cont.Query("modelName")
+       modelVersion := cont.Query("modelVersion")
 
        if modelName == "" {
-               //return all modelinfo stored 
+               //return all modelinfo stored
 
                models, err := m.iDB.GetAll()
                if err != nil {
@@ -146,7 +155,7 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
        } else {
                if modelVersion == "" {
                        // get all modelInfo by model name
-                       modelInfos, err:= m.iDB.GetModelInfoByName(modelName)
+                       modelInfos, err := m.iDB.GetModelInfoByName(modelName)
                        if err != nil {
                                statusCode := http.StatusInternalServerError
                                logging.ERROR("Error occurred, send status code: ", statusCode)
@@ -158,14 +167,13 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
                        }
 
                        cont.JSON(http.StatusOK, gin.H{
-                               "modelinfoList":modelInfos,
+                               "modelinfoList": modelInfos,
                        })
                        return
 
-               } else
-               {
+               } else {
                        // get all modelInfo by model name and version
-                       modelInfo, err:= m.iDB.GetModelInfoByNameAndVer(modelName, modelVersion)
+                       modelInfo, err := m.iDB.GetModelInfoByNameAndVer(modelName, modelVersion)
                        if err != nil {
                                statusCode := http.StatusInternalServerError
                                logging.ERROR("Error occurred, send status code: ", statusCode)
@@ -175,7 +183,7 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
                                })
                                return
                        }
-                       if modelInfo.Id == ""{
+                       if modelInfo.ModelId.ModelName != modelName && modelInfo.ModelId.ModelVersion != modelVersion {
                                statusCode := http.StatusNotFound
                                errMessage := fmt.Sprintf("Record not found with modelName: %s and modelVersion: %s", modelName, modelVersion)
                                logging.ERROR("Record not found, send status code: ", statusCode)
@@ -187,7 +195,7 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
                        }
 
                        cont.JSON(http.StatusOK, gin.H{
-                               "modelinfo":modelInfo,
+                               "modelinfo": modelInfo,
                        })
                        return
                }
@@ -196,7 +204,7 @@ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
 
 func (m *MmeApiHandler) GetModelInfoById(cont *gin.Context) {
        logging.INFO("Get model info by id ...")
-       id := cont.Param("id")
+       id := cont.Param("modelRegistrationId")
        modelInfo, err := m.iDB.GetModelInfoById(id)
        if err != nil {
                logging.ERROR("error:", err)
@@ -206,7 +214,7 @@ func (m *MmeApiHandler) GetModelInfoById(cont *gin.Context) {
                })
                return
        }
-       if modelInfo.Id == ""{
+       if modelInfo.Id == "" {
                statusCode := http.StatusNotFound
                errMessage := fmt.Sprintf("Record not found with id: %s", id)
                logging.ERROR("Record not found, send status code: ", statusCode)
@@ -284,8 +292,8 @@ func (m *MmeApiHandler) GetModel(cont *gin.Context) {
 
 func (m *MmeApiHandler) UpdateModel(c *gin.Context) {
        logging.INFO("Updating model...")
-       id := c.Param("id")
-       var modelInfo models.ModelInfo
+       id := c.Param("modelRegistrationId")
+       var modelInfo models.ModelRelatedInformation
 
        if err := c.ShouldBindJSON(&modelInfo); err != nil {
                c.JSON(http.StatusBadRequest, gin.H{
@@ -308,18 +316,23 @@ func (m *MmeApiHandler) UpdateModel(c *gin.Context) {
 
        logging.INFO("model updated")
        c.JSON(http.StatusOK, gin.H{
-               "modelinfo":modelInfo,
+               "modelinfo": modelInfo,
        })
 }
 
-func (m *MmeApiHandler) DeleteModel(c *gin.Context) {
-       logging.INFO("Deleting model...")
-       id := c.Param("id")
-       if err := m.iDB.Delete(id); err != nil {
-               c.JSON(http.StatusInternalServerError, gin.H{
+func (m *MmeApiHandler) DeleteModel(cont *gin.Context) {
+       id := cont.Param("modelRegistrationId")
+       logging.INFO("Deleting model... id = ", id)
+       rowsAffected, err := m.iDB.Delete(id)
+       if err != nil {
+               cont.JSON(http.StatusInternalServerError, gin.H{
                        "error": err.Error(),
                })
                return
        }
-       c.JSON(http.StatusOK, gin.H{"message": "modelInfo deleted"})
+       message := "modelInfo deleted"
+       if rowsAffected == 0 {
+               message = fmt.Sprintf("id = %s already doesn't exist in database", id)
+       }
+       cont.JSON(http.StatusOK, gin.H{"message": message})
 }
index deb5a9e..4a28a82 100644 (file)
@@ -38,17 +38,19 @@ import (
 )
 
 var registerModelBody = `{
-       "id" : "id", 
-       "modelId": {
-               "modelName": "test-model",
-               "modelVersion":"1"
-       },
-       "description": "testing",
-       "metaInfo": {
-               "metadata": {
-                       "author":"tester"
-               }
-       }
+       "id" : "id",
+    "modelId": {
+        "modelName": "model3",
+        "modelVersion" : "2"
+    },
+    "description": "hello world2",
+    "modelInformation": {
+        "metadata": {
+            "author": "someone"
+        },
+        "inputDataType": "pdcpBytesDl,pdcpBytesUl,kpi",
+        "outputDataType": "c, d"
+    }
 }`
 
 type dbMgrMock struct {
@@ -74,27 +76,27 @@ type iDBMock struct {
        db.IDB
 }
 
-func (i *iDBMock) Create(modelInfo models.ModelInfo) error {
+func (i *iDBMock) Create(modelInfo models.ModelRelatedInformation) error {
        args := i.Called(modelInfo)
        return args.Error(0)
 }
-func (i *iDBMock) GetByID(id string) (*models.ModelInfo, error) {
+func (i *iDBMock) GetByID(id string) (*models.ModelRelatedInformation, error) {
        return nil, nil
 }
-func (i *iDBMock) GetAll() ([]models.ModelInfo, error) {
+func (i *iDBMock) GetAll() ([]models.ModelRelatedInformation, error) {
        args := i.Called()
        if _, ok := args.Get(1).(error); !ok {
-               return args.Get(0).([]models.ModelInfo), nil
+               return args.Get(0).([]models.ModelRelatedInformation), nil
        } else {
-               var emptyModelInfo []models.ModelInfo
+               var emptyModelInfo []models.ModelRelatedInformation
                return emptyModelInfo, args.Error(1)
        }
 }
-func (i *iDBMock) Update(modelInfo models.ModelInfo) error {
+func (i *iDBMock) Update(modelInfo models.ModelRelatedInformation) error {
        return nil
 }
-func (i *iDBMock) Delete(id string) error {
-       return nil
+func (i *iDBMock) Delete(id string) (int64, error) {
+       return 1, nil
 }
 
 func TestRegisterModel(t *testing.T) {
@@ -104,7 +106,7 @@ func TestRegisterModel(t *testing.T) {
        handler := apis.NewMmeApiHandler(nil, iDBMockInst)
        router := routers.InitRouter(handler)
        w := httptest.NewRecorder()
-       req, _ := http.NewRequest("POST", "/registerModel", strings.NewReader(registerModelBody))
+       req, _ := http.NewRequest("POST", "/model-registrations", strings.NewReader(registerModelBody))
        router.ServeHTTP(w, req)
        assert.Equal(t, 201, w.Code)
 }
@@ -115,7 +117,7 @@ func TestWhenSuccessGetModelInfoList(t *testing.T) {
 
        // Setting Mock
        iDBmockInst := new(iDBMock)
-       iDBmockInst.On("GetAll").Return([]models.ModelInfo{
+       iDBmockInst.On("GetAll").Return([]models.ModelRelatedInformation{
                {
                        Id: "1234",
                        ModelId: models.ModelID{
@@ -123,10 +125,12 @@ func TestWhenSuccessGetModelInfoList(t *testing.T) {
                                ModelVersion: "v1.0",
                        },
                        Description: "this is test modelINfo",
-                       ModelSpec: models.ModelSpec{
+                       ModelInformation: models.ModelInformation{
                                Metadata: models.Metadata{
-                                       Author: "testing",
+                                       Author: "someone",
                                },
+                               InputDataType:  "pdcpBytesDl,pdcpBytesUl,kpi",
+                               OutputDataType: "c,d",
                        },
                },
        }, nil)
@@ -141,7 +145,7 @@ func TestWhenSuccessGetModelInfoList(t *testing.T) {
        response := responseRecorder.Result()
        body, _ := io.ReadAll(response.Body)
 
-       var modelInfos []models.ModelInfo
+       var modelInfos []models.ModelRelatedInformation
        logging.INFO(modelInfos)
        json.Unmarshal(body, &modelInfos)
 
@@ -155,7 +159,7 @@ func TestWhenFailGetModelInfoList(t *testing.T) {
 
        // Setting Mock
        iDBmockInst2 := new(iDBMock)
-       iDBmockInst2.On("GetAll").Return([]models.ModelInfo{}, fmt.Errorf("db not available"))
+       iDBmockInst2.On("GetAll").Return([]models.ModelRelatedInformation{}, fmt.Errorf("db not available"))
 
        handler := apis.NewMmeApiHandler(nil, iDBmockInst2)
        router := routers.InitRouter(handler)
@@ -167,7 +171,7 @@ func TestWhenFailGetModelInfoList(t *testing.T) {
        response := responseRecorder.Result()
        body, _ := io.ReadAll(response.Body)
 
-       var modelInfoListResp []models.ModelInfo
+       var modelInfoListResp []models.ModelRelatedInformation
        json.Unmarshal(body, &modelInfoListResp)
 
        assert.Equal(t, 500, responseRecorder.Code)
index 0d1be63..c451ea2 100644 (file)
--- a/db/iDB.go
+++ b/db/iDB.go
@@ -22,12 +22,12 @@ import (
 )
 
 type IDB interface {
-       Create(modelInfo models.ModelInfo) error
-       GetByID(id string) (*models.ModelInfo, error)
-       GetAll() ([]models.ModelInfo, error)
-       GetModelInfoByName(modelName string)([]models.ModelInfo, error)
-       GetModelInfoByNameAndVer(modelName string, modelVersion string)(*models.ModelInfo, error)
-       GetModelInfoById(id string)(*models.ModelInfo, error)
-       Update(modelInfo models.ModelInfo) error
-       Delete(id string) error
+       Create(modelInfo models.ModelRelatedInformation) error
+       GetByID(id string) (*models.ModelRelatedInformation, error)
+       GetAll() ([]models.ModelRelatedInformation, error)
+       GetModelInfoByName(modelName string) ([]models.ModelRelatedInformation, error)
+       GetModelInfoByNameAndVer(modelName string, modelVersion string) (*models.ModelRelatedInformation, error)
+       GetModelInfoById(id string) (*models.ModelRelatedInformation, error)
+       Update(modelInfo models.ModelRelatedInformation) error
+       Delete(id string) (int64, error)
 }
index c72b5b1..95e6acd 100644 (file)
@@ -33,17 +33,17 @@ func NewModelInfoRepository(db *gorm.DB) *ModelInfoRepository {
        }
 }
 
-func (repo *ModelInfoRepository) Create(modelInfo models.ModelInfo) error {
-       repo.db.Create(modelInfo)
-       return nil
+func (repo *ModelInfoRepository) Create(modelInfo models.ModelRelatedInformation) error {
+       result := repo.db.Create(modelInfo)
+       return result.Error
 }
 
-func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelInfo, error) {
+func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInformation, error) {
        return nil, nil
 }
 
-func (repo *ModelInfoRepository) GetAll() ([]models.ModelInfo, error) {
-       var modelInfos []models.ModelInfo
+func (repo *ModelInfoRepository) GetAll() ([]models.ModelRelatedInformation, error) {
+       var modelInfos []models.ModelRelatedInformation
        result := repo.db.Find(&modelInfos)
        if result.Error != nil {
                return nil, result.Error
@@ -51,22 +51,19 @@ func (repo *ModelInfoRepository) GetAll() ([]models.ModelInfo, error) {
        return modelInfos, nil
 }
 
-func (repo *ModelInfoRepository) Update(modelInfo models.ModelInfo) error {
+func (repo *ModelInfoRepository) Update(modelInfo models.ModelRelatedInformation) error {
        if err := repo.db.Save(modelInfo).Error; err != nil {
                return err
        }
        return nil
 }
 
-func (repo *ModelInfoRepository) Delete(id string) error {
-       logging.INFO("id is:", id)
-       if err := repo.db.Delete(&models.ModelInfo{}, "id=?", id).Error; err != nil {
-               return err
-       }
-       return nil
+func (repo *ModelInfoRepository) Delete(id string) (int64, error) {
+       result := repo.db.Delete(&models.ModelRelatedInformation{}, "id = ?", id)
+       return result.RowsAffected, result.Error
 }
-func (repo *ModelInfoRepository) GetModelInfoByName(modelName string)([]models.ModelInfo, error){
-       var modelInfos []models.ModelInfo
+func (repo *ModelInfoRepository) GetModelInfoByName(modelName string) ([]models.ModelRelatedInformation, error) {
+       var modelInfos []models.ModelRelatedInformation
        result := repo.db.Where("model_name = ?", modelName).Find(&modelInfos)
        if result.Error != nil {
                return nil, result.Error
@@ -74,20 +71,21 @@ func (repo *ModelInfoRepository) GetModelInfoByName(modelName string)([]models.M
        return modelInfos, nil
 }
 
-func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string)(*models.ModelInfo, error){
-       var modelInfo models.ModelInfo
+func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string) (*models.ModelRelatedInformation, error) {
+       var modelInfo models.ModelRelatedInformation
        result := repo.db.Where("model_name = ? AND model_version = ?", modelName, modelVersion).Find(&modelInfo)
        if result.Error != nil {
                return nil, result.Error
        }
        return &modelInfo, nil
 }
-func (repo *ModelInfoRepository) GetModelInfoById(id string)(*models.ModelInfo, error){
+
+func (repo *ModelInfoRepository) GetModelInfoById(id string) (*models.ModelRelatedInformation, error) {
        logging.INFO("id is:", id)
-       var modelInfo models.ModelInfo
+       var modelInfo models.ModelRelatedInformation
        result := repo.db.Where("id = ?", id).Find(&modelInfo)
        if result.Error != nil {
                return nil, result.Error
        }
        return &modelInfo, nil
-}
\ No newline at end of file
+}
diff --git a/main.go b/main.go
index 963a26f..560c748 100644 (file)
--- a/main.go
+++ b/main.go
@@ -67,7 +67,7 @@ func main() {
        }
 
        // Auto migrate the scheme
-       db.AutoMigrate(&models.ModelInfo{})
+       db.AutoMigrate(&models.ModelRelatedInformation{})
        repo := modelDB.NewModelInfoRepository(db)
 
        router := routers.InitRouter(
index 2d1c702..f154838 100644 (file)
@@ -19,22 +19,35 @@ limitations under the License.
 package models
 
 type Metadata struct {
-       Author string `json:"author"`
+       Author string `json:"author" validate:"required"`
+       Owner  string `json:"owner"`
 }
 
-type ModelSpec struct {
-       Metadata Metadata `json:"metadata" gorm:"embedded"`
+type TargetEnironment struct {
+       PlatformName    string `json:"platformName" validate:"required"`
+       EnvironmentType string `json:"environmentType" validate:"required"`
+       DependencyList  string `json:"dependencyList" validate:"required"`
+}
+
+type ModelInformation struct {
+       Metadata       Metadata `json:"metadata" gorm:"embedded" validate:"required"`
+       InputDataType  string   `json:"inputDataType" validate:"required"`  // this field will be a Comma Separated List
+       OutputDataType string   `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List
+       // TODO: gorm doesn't support list, need to find the right way
+       // TargetEnvironment []TargetEnironment `json:"targetEnvironment" gorm:"embedded"`
 }
 type ModelID struct {
-       ModelName    string `json:"modelName"`
-       ModelVersion string `json:"modelVersion"`
+       ModelName       string `json:"modelName" validate:"required" gorm:"primaryKey"`
+       ModelVersion    string `json:"modelVersion" validate:"required" gorm:"primaryKey"`
+       ArtifactVersion string `json:"artifactVersion"`
 }
 
-type ModelInfo struct {
-       Id          string    `json:"id" gorm:"primaryKey"`
-       ModelId     ModelID   `json:"model-id,omitempty" gorm:"embedded"`
-       Description string    `json:"description"`
-       ModelSpec   ModelSpec `json:"meta-info" gorm:"embedded"`
+type ModelRelatedInformation struct {
+       Id               string           `json:"id" gorm:"unique"`
+       ModelId          ModelID          `json:"modelId,omitempty" validate:"required" gorm:"embedded;primaryKey"`
+       Description      string           `json:"description" validate:"required"`
+       ModelInformation ModelInformation `json:"modelInformation" validate:"required" gorm:"embedded"`
+       ModelLocation    string           `json:"modelLocation"`
 }
 
 type ModelInfoResponse struct {
index 3cd3235..8d4fa5f 100644 (file)
@@ -26,12 +26,13 @@ func InitRouter(handler *apis.MmeApiHandler) *gin.Engine {
        r := gin.New()
        r.Use(gin.Logger())
        r.Use(gin.Recovery())
+       // As per R1-AP v6
+       r.POST("/model-registrations", handler.RegisterModel)
+       r.GET("/model-registrations/:modelRegistrationId", handler.GetModelInfoById)
+       r.PUT("/model-registrations/:modelRegistrationId", handler.UpdateModel)
+       r.DELETE("/model-registrations/:modelRegistrationId", handler.DeleteModel)
 
-       r.POST("/registerModel", handler.RegisterModel)
        r.GET("/getModelInfo", handler.GetModelInfo)
-       r.PUT("/modelInfo/:id", handler.UpdateModel)
-       r.GET("/modelInfo/:id", handler.GetModelInfoById)
-       r.DELETE("/modelInfo/:id", handler.DeleteModel)
        r.GET("/getModelInfo/:modelName", handler.GetModelInfoByName)
        r.POST("/uploadModel/:modelName", handler.UploadModel)
        r.GET("/downloadModel/:modelName/model.zip", handler.DownloadModel)