From b9b49db849791deae6e67464dd6b41ddfd4ab8bd Mon Sep 17 00:00:00 2001 From: subhash kumar singh Date: Tue, 9 Jul 2024 11:00:54 +0000 Subject: [PATCH] Fix the model registration Fixed model registiration by creating instance of dbMgr. Also code is modified to be more testable. Issue-ID: AIMLFW-108 Change-Id: If6c03cb1aaaeeb0e8c3794bbb787abcf25946859 Signed-off-by: subhash kumar singh --- Dockerfile | 4 ++- apis/mmes_apis.go | 39 ++++++++++++++++----------- apis_test/mmes_apis_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +++ go.sum | 1 + main.go | 7 ++++- routers/router.go | 12 ++++----- 7 files changed, 108 insertions(+), 23 deletions(-) create mode 100644 apis_test/mmes_apis_test.go diff --git a/Dockerfile b/Dockerfile index ec2e804..4d4d951 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,9 @@ WORKDIR ${MME_DIR} # Copy sources into the container COPY . . # Install dependencies from go.mod -RUN go get +RUN go mod tidy + +RUN LOG_FILE_NAME=testing.log go test ./... #Build all packages from current dir into bin RUN go build -o mme_bin . diff --git a/apis/mmes_apis.go b/apis/mmes_apis.go index 340ec32..dbff098 100644 --- a/apis/mmes_apis.go +++ b/apis/mmes_apis.go @@ -34,9 +34,18 @@ type ModelInfo struct { Metainfo map[string]interface{} `json:"meta-info"` } -var dbmgr core.DBMgr +type MmeApiHandler struct { + dbmgr core.DBMgr +} + +func NewMmeApiHandler(dbMgr core.DBMgr) *MmeApiHandler { + handler := &MmeApiHandler{ + dbmgr: dbMgr, + } + return handler +} -func RegisterModel(cont *gin.Context) { +func (m *MmeApiHandler) RegisterModel(cont *gin.Context) { var returnCode int = http.StatusCreated var responseMsg string = "Model registered successfully" @@ -57,9 +66,9 @@ func RegisterModel(cont *gin.Context) { logging.INFO(modelInfo.ModelName, modelInfo.RAppId, modelInfo.Metainfo) modelInfoBytes, _ := json.Marshal(modelInfo) - err := dbmgr.CreateBucket(modelInfo.ModelName) + err := m.dbmgr.CreateBucket(modelInfo.ModelName) if err == nil { - dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName) + m.dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName) } else { returnCode = http.StatusInternalServerError responseMsg = err.Error() @@ -77,7 +86,7 @@ input : Model name : string */ -func GetModelInfo(cont *gin.Context) { +func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) { logging.INFO("Fetching model") bodyBytes, _ := io.ReadAll(cont.Request.Body) //TODO Error checking of request is not in json, i.e. etra ',' at EOF @@ -86,7 +95,7 @@ func GetModelInfo(cont *gin.Context) { model_name := jsonMap["model-name"].(string) logging.INFO("The request model name: ", model_name) - model_info := dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name) + model_info := m.dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name) cont.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, @@ -97,11 +106,11 @@ func GetModelInfo(cont *gin.Context) { /* Provides the model details by param model name */ -func GetModelInfoByName(cont *gin.Context) { +func (m *MmeApiHandler) GetModelInfoByName(cont *gin.Context) { logging.INFO("Get model info by name API ...") modelName := cont.Param("modelName") - model_info := dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName) + model_info := m.dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName) cont.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, @@ -112,7 +121,7 @@ func GetModelInfoByName(cont *gin.Context) { // API to upload the trained model in zip format // TODO : Model version as input -func UploadModel(cont *gin.Context) { +func (m *MmeApiHandler) UploadModel(cont *gin.Context) { logging.INFO("Uploading model API ...") modelName := cont.Param("modelName") //TODO convert multipart.FileHeader to []byted @@ -124,7 +133,7 @@ func UploadModel(cont *gin.Context) { byteFile, _ := io.ReadAll((file)) logging.INFO("Uploading model : ", modelName) - dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName) + m.dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName) cont.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, "message": string("Model uploaded successfully.."), @@ -135,11 +144,11 @@ func UploadModel(cont *gin.Context) { API to download the trained model from bucket Input: model name in path params as "modelName" */ -func DownloadModel(cont *gin.Context) { +func (m *MmeApiHandler) DownloadModel(cont *gin.Context) { logging.INFO("Download model API ...") modelName := cont.Param("modelName") fileName := modelName + os.Getenv("MODEL_FILE_POSTFIX") - fileByes := dbmgr.GetBucketObject(fileName, modelName) + fileByes := m.dbmgr.GetBucketObject(fileName, modelName) //Return file in api reponse using byte slice cont.Header("Content-Disposition", "attachment;"+fileName) @@ -147,15 +156,15 @@ func DownloadModel(cont *gin.Context) { cont.Data(http.StatusOK, "application/octet", fileByes) } -func GetModel(cont *gin.Context) { +func (m *MmeApiHandler) GetModel(cont *gin.Context) { logging.INFO("Fetching model") cont.IndentedJSON(http.StatusOK, " ") } -func UpdateModel() { +func (m *MmeApiHandler) UpdateModel() { logging.INFO("Updating model...") } -func DeleteModel() { +func (m *MmeApiHandler) DeleteModel() { logging.INFO("Deleting model...") } diff --git a/apis_test/mmes_apis_test.go b/apis_test/mmes_apis_test.go new file mode 100644 index 0000000..acc69b5 --- /dev/null +++ b/apis_test/mmes_apis_test.go @@ -0,0 +1,64 @@ +/* +================================================================================== +Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +================================================================================== +*/ +package apis_test + +import ( + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis" + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core" + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var registerModelBody = `{ + "model-name": "test-model", + "rapp-id": "1234", + "meta-info": { + "a": "b" + } +}` + +type dbMgrMock struct { + mock.Mock + core.DBMgr +} + +func (d *dbMgrMock) CreateBucket(bucketName string) (err error) { + args := d.Called(bucketName) + return args.Error(0) +} + +func (d *dbMgrMock) UploadFile(dataBytes []byte, file_name string, bucketName string) { +} +func TestRegisterModel(t *testing.T) { + os.Setenv("LOG_FILE_NAME", "testing") + dbMgrMockInst := new(dbMgrMock) + dbMgrMockInst.On("CreateBucket", "test-model").Return(nil) + handler := apis.NewMmeApiHandler(dbMgrMockInst) + router := routers.InitRouter(handler) + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/registerModel", strings.NewReader(registerModelBody)) + router.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) +} diff --git a/go.mod b/go.mod index c480608..7886fda 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.21.3 require ( github.com/aws/aws-sdk-go v1.47.3 github.com/gin-gonic/gin v1.9.1 + github.com/stretchr/testify v1.8.3 ) require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -24,6 +26,8 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect diff --git a/go.sum b/go.sum index e7877ea..87f381e 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/main.go b/main.go index 4eaf649..4135c8c 100644 --- a/main.go +++ b/main.go @@ -22,12 +22,17 @@ import ( "os" "time" + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis" + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers" ) func main() { - router := routers.InitRouter() + router := routers.InitRouter( + apis.NewMmeApiHandler( + core.GetDBManagerInstance(), + )) server := http.Server{ Addr: os.Getenv("MMES_URL"), Handler: router, diff --git a/routers/router.go b/routers/router.go index 6d7fd7a..12a4918 100644 --- a/routers/router.go +++ b/routers/router.go @@ -22,15 +22,15 @@ import ( "github.com/gin-gonic/gin" ) -func InitRouter() *gin.Engine { +func InitRouter(handler *apis.MmeApiHandler) *gin.Engine { r := gin.New() r.Use(gin.Logger()) r.Use(gin.Recovery()) - r.POST("/registerModel", apis.RegisterModel) - r.GET("/getModelInfo", apis.GetModelInfo) - r.GET("/getModelInfo/:modelName", apis.GetModelInfoByName) - r.POST("/uploadModel/:modelName", apis.UploadModel) - r.GET("/downloadModel/:modelName/model.zip", apis.DownloadModel) + r.POST("/registerModel", handler.RegisterModel) + r.GET("/getModelInfo", handler.GetModelInfo) + r.GET("/getModelInfo/:modelName", handler.GetModelInfoByName) + r.POST("/uploadModel/:modelName", handler.UploadModel) + r.GET("/downloadModel/:modelName/model.zip", handler.DownloadModel) return r } -- 2.16.6