Fixed model registiration by creating instance of dbMgr.
Also code is modified to be more testable.
Issue-ID: AIMLFW-108
Change-Id: If6c03cb1aaaeeb0e8c3794bbb787abcf25946859
Signed-off-by: subhash kumar singh <subh.singh@samsung.com>
# Copy sources into the container
COPY . .
# Install dependencies from go.mod
-RUN go get
+RUN go mod tidy
+
+RUN LOG_FILE_NAME=testing.log go test ./...
#Build all packages from current dir into bin
RUN go build -o mme_bin .
Metainfo map[string]interface{} `json:"meta-info"`
}
-var dbmgr core.DBMgr
+type MmeApiHandler struct {
+ dbmgr core.DBMgr
+}
+
+func NewMmeApiHandler(dbMgr core.DBMgr) *MmeApiHandler {
+ handler := &MmeApiHandler{
+ dbmgr: dbMgr,
+ }
+ return handler
+}
-func RegisterModel(cont *gin.Context) {
+func (m *MmeApiHandler) RegisterModel(cont *gin.Context) {
var returnCode int = http.StatusCreated
var responseMsg string = "Model registered successfully"
logging.INFO(modelInfo.ModelName, modelInfo.RAppId, modelInfo.Metainfo)
modelInfoBytes, _ := json.Marshal(modelInfo)
- err := dbmgr.CreateBucket(modelInfo.ModelName)
+ err := m.dbmgr.CreateBucket(modelInfo.ModelName)
if err == nil {
- dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName)
+ m.dbmgr.UploadFile(modelInfoBytes, modelInfo.ModelName+os.Getenv("INFO_FILE_POSTFIX"), modelInfo.ModelName)
} else {
returnCode = http.StatusInternalServerError
responseMsg = err.Error()
Model name : string
*/
-func GetModelInfo(cont *gin.Context) {
+func (m *MmeApiHandler) GetModelInfo(cont *gin.Context) {
logging.INFO("Fetching model")
bodyBytes, _ := io.ReadAll(cont.Request.Body)
//TODO Error checking of request is not in json, i.e. etra ',' at EOF
model_name := jsonMap["model-name"].(string)
logging.INFO("The request model name: ", model_name)
- model_info := dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name)
+ model_info := m.dbmgr.GetBucketObject(model_name+os.Getenv("INFO_FILE_POSTFIX"), model_name)
cont.JSON(http.StatusOK, gin.H{
"code": http.StatusOK,
/*
Provides the model details by param model name
*/
-func GetModelInfoByName(cont *gin.Context) {
+func (m *MmeApiHandler) GetModelInfoByName(cont *gin.Context) {
logging.INFO("Get model info by name API ...")
modelName := cont.Param("modelName")
- model_info := dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName)
+ model_info := m.dbmgr.GetBucketObject(modelName+os.Getenv("INFO_FILE_POSTFIX"), modelName)
cont.JSON(http.StatusOK, gin.H{
"code": http.StatusOK,
// API to upload the trained model in zip format
// TODO : Model version as input
-func UploadModel(cont *gin.Context) {
+func (m *MmeApiHandler) UploadModel(cont *gin.Context) {
logging.INFO("Uploading model API ...")
modelName := cont.Param("modelName")
//TODO convert multipart.FileHeader to []byted
byteFile, _ := io.ReadAll((file))
logging.INFO("Uploading model : ", modelName)
- dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName)
+ m.dbmgr.UploadFile(byteFile, modelName+os.Getenv("MODEL_FILE_POSTFIX"), modelName)
cont.JSON(http.StatusOK, gin.H{
"code": http.StatusOK,
"message": string("Model uploaded successfully.."),
API to download the trained model from bucket
Input: model name in path params as "modelName"
*/
-func DownloadModel(cont *gin.Context) {
+func (m *MmeApiHandler) DownloadModel(cont *gin.Context) {
logging.INFO("Download model API ...")
modelName := cont.Param("modelName")
fileName := modelName + os.Getenv("MODEL_FILE_POSTFIX")
- fileByes := dbmgr.GetBucketObject(fileName, modelName)
+ fileByes := m.dbmgr.GetBucketObject(fileName, modelName)
//Return file in api reponse using byte slice
cont.Header("Content-Disposition", "attachment;"+fileName)
cont.Data(http.StatusOK, "application/octet", fileByes)
}
-func GetModel(cont *gin.Context) {
+func (m *MmeApiHandler) GetModel(cont *gin.Context) {
logging.INFO("Fetching model")
cont.IndentedJSON(http.StatusOK, " ")
}
-func UpdateModel() {
+func (m *MmeApiHandler) UpdateModel() {
logging.INFO("Updating model...")
}
-func DeleteModel() {
+func (m *MmeApiHandler) DeleteModel() {
logging.INFO("Deleting model...")
}
--- /dev/null
+/*
+==================================================================================
+Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==================================================================================
+*/
+package apis_test
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis"
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core"
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+var registerModelBody = `{
+ "model-name": "test-model",
+ "rapp-id": "1234",
+ "meta-info": {
+ "a": "b"
+ }
+}`
+
+type dbMgrMock struct {
+ mock.Mock
+ core.DBMgr
+}
+
+func (d *dbMgrMock) CreateBucket(bucketName string) (err error) {
+ args := d.Called(bucketName)
+ return args.Error(0)
+}
+
+func (d *dbMgrMock) UploadFile(dataBytes []byte, file_name string, bucketName string) {
+}
+func TestRegisterModel(t *testing.T) {
+ os.Setenv("LOG_FILE_NAME", "testing")
+ dbMgrMockInst := new(dbMgrMock)
+ dbMgrMockInst.On("CreateBucket", "test-model").Return(nil)
+ handler := apis.NewMmeApiHandler(dbMgrMockInst)
+ router := routers.InitRouter(handler)
+ w := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/registerModel", strings.NewReader(registerModelBody))
+ router.ServeHTTP(w, req)
+ assert.Equal(t, 201, w.Code)
+}
require (
github.com/aws/aws-sdk-go v1.47.3
github.com/gin-gonic/gin v1.9.1
+ github.com/stretchr/testify v1.8.3
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/stretchr/objx v0.5.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
"os"
"time"
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/apis"
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/core"
"gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging"
"gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/routers"
)
func main() {
- router := routers.InitRouter()
+ router := routers.InitRouter(
+ apis.NewMmeApiHandler(
+ core.GetDBManagerInstance(),
+ ))
server := http.Server{
Addr: os.Getenv("MMES_URL"),
Handler: router,
"github.com/gin-gonic/gin"
)
-func InitRouter() *gin.Engine {
+func InitRouter(handler *apis.MmeApiHandler) *gin.Engine {
r := gin.New()
r.Use(gin.Logger())
r.Use(gin.Recovery())
- r.POST("/registerModel", apis.RegisterModel)
- r.GET("/getModelInfo", apis.GetModelInfo)
- r.GET("/getModelInfo/:modelName", apis.GetModelInfoByName)
- r.POST("/uploadModel/:modelName", apis.UploadModel)
- r.GET("/downloadModel/:modelName/model.zip", apis.DownloadModel)
+ r.POST("/registerModel", handler.RegisterModel)
+ r.GET("/getModelInfo", handler.GetModelInfo)
+ r.GET("/getModelInfo/:modelName", handler.GetModelInfoByName)
+ r.POST("/uploadModel/:modelName", handler.UploadModel)
+ r.GET("/downloadModel/:modelName/model.zip", handler.DownloadModel)
return r
}