Fix the model registration 68/13168/3 2.0.1 3.0.0
authorsubhash kumar singh <subh.singh@samsung.com>
Tue, 9 Jul 2024 11:00:54 +0000 (11:00 +0000)
committersubhash kumar singh <subh.singh@samsung.com>
Tue, 9 Jul 2024 13:32:03 +0000 (13:32 +0000)
Fixed model registiration by creating instance of dbMgr.
Also code is modified to be more testable.

Issue-ID: AIMLFW-108
Change-Id: If6c03cb1aaaeeb0e8c3794bbb787abcf25946859
Signed-off-by: subhash kumar singh <subh.singh@samsung.com>
Dockerfile
apis/mmes_apis.go
apis_test/mmes_apis_test.go [new file with mode: 0644]
go.mod
go.sum
main.go
routers/router.go

index ec2e804..4d4d951 100644 (file)
@@ -25,7 +25,9 @@ WORKDIR ${MME_DIR}
 # Copy sources into the container
 COPY . .
 # Install dependencies from go.mod
-RUN go get
+RUN go mod tidy
+
+RUN LOG_FILE_NAME=testing.log go test ./...
 
 #Build all packages from current dir into bin 
 RUN go build -o mme_bin .
index 340ec32..dbff098 100644 (file)
@@ -34,9 +34,18 @@ type ModelInfo struct {
        Metainfo  map[string]interface{} `json:"meta-info"`
 }
 
-var dbmgr core.DBMgr
+type MmeApiHandler struct {
+       dbmgr core.DBMgr
+}
+
+func NewMmeApiHandler(dbMgr core.DBMgr) *MmeApiHandler {
+       handler := &MmeApiHandler{
+               dbmgr: dbMgr,
+       }
+       return handler
+}
 
-func RegisterModel(cont *gin.Context) {
+func (m *MmeApiHandler) RegisterModel(cont *gin.Context) {
        var returnCode int = http.StatusCreated
        var responseMsg string = "Model registered successfully"
 
@@ -57,9 +66,9 @@ func RegisterModel(cont *gin.Context) {
                logging.INFO(modelInfo.ModelName, modelInfo.RAppId, modelInfo.Metainfo)
                modelInfoBytes, _ := json.Marshal(modelInfo)
 
-               err := dbmgr.CreateBucket(modelInfo.ModelName)
+               err := m.dbmgr.CreateBucket(modelInfo.ModelName)
                if err == nil {
-                       dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName)
+                       m.dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName)
                } else {
                        returnCode = http.StatusInternalServerError
                        responseMsg = err.Error()
@@ -77,7 +86,7 @@ input :
 
        Model name : string
 */
-func GetModelInfo(cont *gin.Context) {
+func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
        logging.INFO("Fetching model")
        bodyBytes, _ := io.ReadAll(cont.Request.Body)
        //TODO Error checking of request is not in json, i.e. etra ',' at EOF
@@ -86,7 +95,7 @@ func GetModelInfo(cont *gin.Context) {
        model_name := jsonMap["model-name"].(string)
        logging.INFO("The request model name: ", model_name)
 
-       model_info := dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name)
+       model_info := m.dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name)
 
        cont.JSON(http.StatusOK, gin.H{
                "code":    http.StatusOK,
@@ -97,11 +106,11 @@ func GetModelInfo(cont *gin.Context) {
 /*
 Provides the model details by param model name
 */
-func GetModelInfoByName(cont *gin.Context) {
+func (m *MmeApiHandler) GetModelInfoByName(cont *gin.Context) {
        logging.INFO("Get model info by name API ...")
        modelName := cont.Param("modelName")
 
-       model_info := dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName)
+       model_info := m.dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName)
 
        cont.JSON(http.StatusOK, gin.H{
                "code":    http.StatusOK,
@@ -112,7 +121,7 @@ func GetModelInfoByName(cont *gin.Context) {
 // API to upload the trained model in zip format
 // TODO : Model version as input
 
-func UploadModel(cont *gin.Context) {
+func (m *MmeApiHandler) UploadModel(cont *gin.Context) {
        logging.INFO("Uploading model API ...")
        modelName := cont.Param("modelName")
        //TODO convert multipart.FileHeader to []byted
@@ -124,7 +133,7 @@ func UploadModel(cont *gin.Context) {
        byteFile, _ := io.ReadAll((file))
 
        logging.INFO("Uploading model : ", modelName)
-       dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName)
+       m.dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName)
        cont.JSON(http.StatusOK, gin.H{
                "code":    http.StatusOK,
                "message": string("Model uploaded successfully.."),
@@ -135,11 +144,11 @@ func UploadModel(cont *gin.Context) {
 API to download the trained model from  bucket
 Input: model name in path params as "modelName"
 */
-func DownloadModel(cont *gin.Context) {
+func (m *MmeApiHandler) DownloadModel(cont *gin.Context) {
        logging.INFO("Download model API ...")
        modelName := cont.Param("modelName")
        fileName := modelName + os.Getenv("MODEL_FILE_POSTFIX")
-       fileByes := dbmgr.GetBucketObject(fileName, modelName)
+       fileByes := m.dbmgr.GetBucketObject(fileName, modelName)
 
        //Return file in api reponse using byte slice
        cont.Header("Content-Disposition", "attachment;"+fileName)
@@ -147,15 +156,15 @@ func DownloadModel(cont *gin.Context) {
        cont.Data(http.StatusOK, "application/octet", fileByes)
 }
 
-func GetModel(cont *gin.Context) {
+func (m *MmeApiHandler) GetModel(cont *gin.Context) {
        logging.INFO("Fetching model")
        cont.IndentedJSON(http.StatusOK, " ")
 }
 
-func UpdateModel() {
+func (m *MmeApiHandler) UpdateModel() {
        logging.INFO("Updating model...")
 }
 
-func DeleteModel() {
+func (m *MmeApiHandler) DeleteModel() {
        logging.INFO("Deleting model...")
 }
diff --git a/apis_test/mmes_apis_test.go b/apis_test/mmes_apis_test.go
new file mode 100644 (file)
index 0000000..acc69b5
--- /dev/null
@@ -0,0 +1,64 @@
+/*
+==================================================================================
+Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==================================================================================
+*/
+package apis_test
+
+import (
+       "net/http"
+       "net/http/httptest"
+       "os"
+       "strings"
+       "testing"
+
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis"
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core"
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/mock"
+)
+
+var registerModelBody = `{
+       "model-name": "test-model",
+       "rapp-id": "1234",
+       "meta-info": {
+               "a": "b"
+       }
+}`
+
+type dbMgrMock struct {
+       mock.Mock
+       core.DBMgr
+}
+
+func (d *dbMgrMock) CreateBucket(bucketName string) (err error) {
+       args := d.Called(bucketName)
+       return args.Error(0)
+}
+
+func (d *dbMgrMock) UploadFile(dataBytes []byte, file_name string, bucketName string) {
+}
+func TestRegisterModel(t *testing.T) {
+       os.Setenv("LOG_FILE_NAME", "testing")
+       dbMgrMockInst := new(dbMgrMock)
+       dbMgrMockInst.On("CreateBucket", "test-model").Return(nil)
+       handler := apis.NewMmeApiHandler(dbMgrMockInst)
+       router := routers.InitRouter(handler)
+       w := httptest.NewRecorder()
+       req, _ := http.NewRequest("POST", "/registerModel", strings.NewReader(registerModelBody))
+       router.ServeHTTP(w, req)
+       assert.Equal(t, 201, w.Code)
+}
diff --git a/go.mod b/go.mod
index c480608..7886fda 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -5,11 +5,13 @@ go 1.21.3
 require (
        github.com/aws/aws-sdk-go v1.47.3
        github.com/gin-gonic/gin v1.9.1
+       github.com/stretchr/testify v1.8.3
 )
 
 require (
        github.com/bytedance/sonic v1.9.1 // indirect
        github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+       github.com/davecgh/go-spew v1.1.1 // indirect
        github.com/gabriel-vasile/mimetype v1.4.2 // indirect
        github.com/gin-contrib/sse v0.1.0 // indirect
        github.com/go-playground/locales v0.14.1 // indirect
@@ -24,6 +26,8 @@ require (
        github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
        github.com/modern-go/reflect2 v1.0.2 // indirect
        github.com/pelletier/go-toml/v2 v2.0.8 // indirect
+       github.com/pmezard/go-difflib v1.0.0 // indirect
+       github.com/stretchr/objx v0.5.0 // indirect
        github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
        github.com/ugorji/go/codec v1.2.11 // indirect
        golang.org/x/arch v0.3.0 // indirect
diff --git a/go.sum b/go.sum
index e7877ea..87f381e 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -53,6 +53,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
diff --git a/main.go b/main.go
index 4eaf649..4135c8c 100644 (file)
--- a/main.go
+++ b/main.go
@@ -22,12 +22,17 @@ import (
        "os"
        "time"
 
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis"
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core"
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging"
        "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers"
 )
 
 func main() {
-       router := routers.InitRouter()
+       router := routers.InitRouter(
+               apis.NewMmeApiHandler(
+                       core.GetDBManagerInstance(),
+               ))
        server := http.Server{
                Addr:         os.Getenv("MMES_URL"),
                Handler:      router,
index 6d7fd7a..12a4918 100644 (file)
@@ -22,15 +22,15 @@ import (
        "github.com/gin-gonic/gin"
 )
 
-func InitRouter() *gin.Engine {
+func InitRouter(handler *apis.MmeApiHandler) *gin.Engine {
        r := gin.New()
        r.Use(gin.Logger())
        r.Use(gin.Recovery())
 
-       r.POST("/registerModel", apis.RegisterModel)
-       r.GET("/getModelInfo", apis.GetModelInfo)
-       r.GET("/getModelInfo/:modelName", apis.GetModelInfoByName)
-       r.POST("/uploadModel/:modelName", apis.UploadModel)
-       r.GET("/downloadModel/:modelName/model.zip", apis.DownloadModel)
+       r.POST("/registerModel", handler.RegisterModel)
+       r.GET("/getModelInfo", handler.GetModelInfo)
+       r.GET("/getModelInfo/:modelName", handler.GetModelInfoByName)
+       r.POST("/uploadModel/:modelName", handler.UploadModel)
+       r.GET("/downloadModel/:modelName/model.zip", handler.DownloadModel)
        return r
 }