From fb99cf3a1bb4dc2db1688a800464c97ee2a77ddb Mon Sep 17 00:00:00 2001 From: rajdeep11 Date: Tue, 8 Oct 2024 20:06:25 +0530 Subject: [PATCH] adding the model discovery Change-Id: I7763f1d2283766f89cdddc8e72fbc1c605692dda Signed-off-by: rajdeep11 --- apis/mmes_apis.go | 126 ++++++++++++++++++++++++++++++++++------------ db/iDB.go | 2 + db/modelInfoRepository.go | 17 +++++++ 3 files changed, 112 insertions(+), 33 deletions(-) diff --git a/apis/mmes_apis.go b/apis/mmes_apis.go index 995bef2..afbe154 100644 --- a/apis/mmes_apis.go +++ b/apis/mmes_apis.go @@ -21,6 +21,7 @@ import ( "io" "net/http" "os" + "fmt" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/db" @@ -107,42 +108,101 @@ func (m *MmeApiHandler) RegisterModel(cont *gin.Context) { This API retrieves model info list managed in modelmgmtservice */ func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) { - logging.INFO("List all model API") - - models, err := m.iDB.GetAll() - if err != nil { - logging.ERROR("error:", err) - cont.JSON(http.StatusInternalServerError, gin.H{ - "code": http.StatusInternalServerError, - "message": err.Error(), - }) - return - } - cont.JSON(http.StatusOK, models) - // bucketList, err := m.dbmgr.ListBucket(os.Getenv("INFO_FILE_POSTFIX")) - // if err != nil { - // statusCode := http.StatusInternalServerError - // logging.ERROR("Error occurred, send status code: ", statusCode) - // cont.JSON(statusCode, gin.H{ - // "code": statusCode, - // "message": "Unexpected Error in server, you can't get model information list", - // }) - // return - // } + logging.INFO("Get model info ") + queryParams := cont.Request.URL.Query() + //to check only modelName and modelVersion can be passed. + allowedParams := map[string]bool{ + "modelName": true, + "modelVersion": true, + } - // modelInfoListResp := []models.ModelInfoResponse{} - // for _, bucket := range bucketList { - // modelInfoListResp = append(modelInfoListResp, models.ModelInfoResponse{ - // Name: bucket.Name, - // Data: string(bucket.Object), - // }) - // } + for key := range queryParams { + if !allowedParams[key] { + cont.JSON(http.StatusBadRequest, gin.H{ + "error": "Only modelName and modelVersion are allowed", + }) + return + } + } - // cont.JSON(http.StatusOK, gin.H{ - // "code": http.StatusOK, - // "message": modelInfoListResp, - // }) + modelName:= cont.Query("modelName") + modelVersion:= cont.Query("modelVersion") + + if modelName == "" { + //return all modelinfo stored + + models, err := m.iDB.GetAll() + if err != nil { + logging.ERROR("error:", err) + cont.JSON(http.StatusInternalServerError, gin.H{ + "code": http.StatusInternalServerError, + "message": err.Error(), + }) + return + } + cont.JSON(http.StatusOK, models) + return + } else { + if modelVersion == "" { + // get all modelInfo by model name + modelInfos, err:= m.iDB.GetModelInfoByName(modelName) + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("Error occurred, send status code: ", statusCode) + cont.JSON(statusCode, gin.H{ + "code": statusCode, + "message": "Unexpected Error in server, you can't get model information list", + }) + return + } + //to check record not found + if len(modelInfos)==0{ + statusCode := http.StatusNotFound + errMessage := fmt.Sprintf("Record not found with modelName: %s", modelName) + logging.ERROR("Record not found, send status code: ", statusCode) + cont.JSON(statusCode, gin.H{ + "code": statusCode, + "message": errMessage, + }) + return + } + + cont.JSON(http.StatusOK, gin.H{ + "modelinfoList":modelInfos, + }) + return + + } else + { + // get all modelInfo by model name and version + modelInfo, err:= m.iDB.GetModelInfoByNameAndVer(modelName, modelVersion) + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("Error occurred, send status code: ", statusCode) + cont.JSON(statusCode, gin.H{ + "code": statusCode, + "message": "Unexpected Error in server, you can't get model information list", + }) + return + } + if modelInfo.Id == ""{ + statusCode := http.StatusNotFound + errMessage := fmt.Sprintf("Record not found with modelName: %s and modelVersion: %s", modelName, modelVersion) + logging.ERROR("Record not found, send status code: ", statusCode) + cont.JSON(statusCode, gin.H{ + "code": statusCode, + "message": errMessage, + }) + return + } + + cont.JSON(http.StatusOK, gin.H{ + "modelinfo":modelInfo, + }) + return + } + } } /* diff --git a/db/iDB.go b/db/iDB.go index a237a69..969c634 100644 --- a/db/iDB.go +++ b/db/iDB.go @@ -25,6 +25,8 @@ type IDB interface { Create(modelInfo models.ModelInfo) error GetByID(id string) (*models.ModelInfo, error) GetAll() ([]models.ModelInfo, error) + GetModelInfoByName(modelName string)([]models.ModelInfo, error) + GetModelInfoByNameAndVer(modelName string, modelVersion string)(*models.ModelInfo, error) Update(modelInfo models.ModelInfo) error Delete(id string) error } diff --git a/db/modelInfoRepository.go b/db/modelInfoRepository.go index 5ac9fdd..e277253 100644 --- a/db/modelInfoRepository.go +++ b/db/modelInfoRepository.go @@ -65,3 +65,20 @@ func (repo *ModelInfoRepository) Delete(id string) error { } return nil } +func (repo *ModelInfoRepository) GetModelInfoByName(modelName string)([]models.ModelInfo, error){ + var modelInfos []models.ModelInfo + result := repo.db.Where("model_name = ?", modelName).Find(&modelInfos) + if result.Error != nil { + return nil, result.Error + } + return modelInfos, nil +} + +func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string)(*models.ModelInfo, error){ + var modelInfo models.ModelInfo + result := repo.db.Where("model_name = ? AND model_version = ?", modelName, modelVersion).Find(&modelInfo) + if result.Error != nil { + return nil, result.Error + } + return &modelInfo, nil +} \ No newline at end of file -- 2.16.6