"gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/db"
"gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging"
"gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models"
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/utils"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/google/uuid"
logging.INFO("Uploading model API ...")
modelName := cont.Param("modelName")
modelVersion := cont.Param("modelVersion")
- artifactVersion := cont.Param("artifactVersion")
// Confirm if Model with Given ModelId: (ModelName and ModelVersion) is Registered or not:
modelInfo, err := m.iDB.GetModelInfoByNameAndVer(modelName, modelVersion)
return
}
+ // Read the uploaded File
+ fileHeader, err := cont.FormFile("file")
+ if err != nil {
+ statusCode := http.StatusInternalServerError
+ logging.ERROR("failed to read form file: %v", err)
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Internal Server Error",
+ Detail: fmt.Sprintf("Can't read form file| Error: %s", err.Error()),
+ })
+ return
+ }
+
+ // Validate that file has .zip extension
+ if !strings.HasSuffix(strings.ToLower(fileHeader.Filename), ".zip") {
+ statusCode := http.StatusUnsupportedMediaType
+ logging.ERROR("invalid file type: %s", fileHeader.Filename)
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Unsupported Media Type",
+ Detail: fmt.Sprintf("invalid file type: %s, Only .zip files are allowed", fileHeader.Filename),
+ })
+ return
+ }
+
+ file, err := fileHeader.Open()
+ if err != nil {
+ statusCode := http.StatusInternalServerError
+ logging.ERROR("failed to open uploaded file: %s", err.Error())
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Internal Server Error",
+ Detail: fmt.Sprintf("failed to open uploaded file: %s", err.Error()),
+ })
+ return
+ }
+ defer file.Close()
+
+ byteFile, err := io.ReadAll(file)
+ if err != nil {
+ statusCode := http.StatusInternalServerError
+ logging.ERROR("Error reading file content: %s", err.Error())
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Internal Server Error",
+ Detail: fmt.Sprintf("Error reading file content: %s", err.Error()),
+ })
+ return
+ }
+
+ artifactVersion := modelInfo.ModelId.ArtifactVersion
modelKey := fmt.Sprintf("%s_%s_%s", modelName, modelVersion, artifactVersion)
exportBucket := strings.ToLower(modelName)
- //TODO convert multipart.FileHeader to []byte
- fileHeader, _ := cont.FormFile("file")
- //TODO : Accept only .zip file for trained model
- file, _ := fileHeader.Open()
- defer file.Close()
- byteFile, _ := io.ReadAll((file))
+ // Update the Artifact-Version
+ newArtifactVersion, err := utils.IncrementArtifactVersion(artifactVersion)
+ if err != nil {
+ statusCode := http.StatusInternalServerError
+ logging.ERROR("Unable to get newArtifactVersion: %s", err.Error())
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Internal Server Error",
+ Detail: fmt.Sprintf("Unable to get newArtifactVersion: %s", err.Error()),
+ })
+ return
+ }
+ modelInfo.ModelId.ArtifactVersion = newArtifactVersion
+ if err := m.iDB.Update(*modelInfo); err != nil {
+ statusCode := http.StatusInternalServerError
+ logging.ERROR("Unable to update newArtifactVersion: %s", err.Error())
+ cont.JSON(statusCode, models.ProblemDetail{
+ Status: statusCode,
+ Title: "Internal Server Error",
+ Detail: fmt.Sprintf("Unable to update newArtifactVersion: %s", err.Error()),
+ })
+ return
+ }
+ // Upload the file to s3-bucket
logging.INFO("Uploading model : " + modelKey)
if err := m.dbmgr.UploadFile(byteFile, modelKey+os.Getenv("MODEL_FILE_POSTFIX"), exportBucket); err != nil {
- logging.ERROR("Failed to Upload Model : ", err)
+ // Model failed to update: Rollback artifact version to old-one
+ logging.ERROR(fmt.Sprintf("Failed to Upload Model : %s, Rolling back to previous artifact-version : %s", err.Error(), artifactVersion))
+ modelInfo.ModelId.ArtifactVersion = artifactVersion
+ if err := m.iDB.Update(*modelInfo); err != nil {
+ /*
+ Ideally, the following situation should never occur.
+ This scenario can happen when:
+ The model artifact version is incremented to a new version, and
+ The file upload to the bucket fails, and
+ The rollback to the previous artifact version also fails.
+ */
+ logging.ERROR("Unable to rollback to old-artifactVersion: %s", err.Error())
+ }
+
cont.JSON(http.StatusInternalServerError, gin.H{
"code": http.StatusInternalServerError,
"message": err.Error(),
})
return
}
+
+ logging.INFO("model updated")
cont.JSON(http.StatusOK, gin.H{
- "code": http.StatusOK,
- "message": string("Model uploaded successfully.."),
+ "code": http.StatusOK,
+ "message": string("Model uploaded successfully.."),
+ "modelinfo": modelInfo,
})
}
cont.JSON(http.StatusNoContent, nil)
}
+// Deprecated: use the new API reference: UploadModel.
func (m *MmeApiHandler) UpdateArtifact(cont *gin.Context) {
logging.INFO("Update artifact version of model")
modelname := cont.Param("modelname")
"encoding/json"
"fmt"
"io"
+ "log"
"mime/multipart"
"net/http"
"net/http/httptest"
modelArtifactVersion := "1.0.0"
modelInfo := models.ModelRelatedInformation{
ModelId: models.ModelID{
- ModelName: modelName,
- ModelVersion: modelVersion,
+ ModelName: modelName,
+ ModelVersion: modelVersion,
+ ArtifactVersion: modelArtifactVersion,
},
}
iDBMockInst.On("GetModelInfoByNameAndVer").Return(&modelInfo, nil)
writer.Close()
// Upload model
- url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion)
+ url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion)
req := httptest.NewRequest(http.MethodPost, url, body)
req.Header.Set("Content-Type", writer.FormDataContentType())
router.ServeHTTP(responseRecorder, req)
response := responseRecorder.Result()
responseBody, _ := io.ReadAll(response.Body)
assert.Equal(t, http.StatusOK, responseRecorder.Code)
- assert.Equal(t, `{"code":200,"message":"Model uploaded successfully.."}`, string(responseBody))
+ var responseJson map[string]any
+ err = json.Unmarshal(responseBody, &responseJson)
+ if err != nil {
+ t.Errorf("Error to Unmarshal response-body : Error %s", err.Error())
+ }
+
+ newModelInfoStr, err := json.Marshal(responseJson["modelinfo"])
+ if err != nil {
+ t.Errorf("Error to Marshal model-Info : Error %s", err.Error())
+ }
+
+ var newModelInfo models.ModelRelatedInformation
+ if err := json.Unmarshal(newModelInfoStr, &newModelInfo); err != nil {
+ log.Fatal("unmarshal error:", err)
+ }
+ assert.Equal(t, newModelInfo.ModelId.ArtifactVersion, "1.1.0")
}
func TestUploadModelFailureModelNotRegistered(t *testing.T) {
iDBMockInst := new(mme_mocks.IDBMock)
modelName := "test-model"
modelVersion := "1"
- modelArtifactVersion := "1.0.0"
// Returns Empty model, signifying Model is Not registered
iDBMockInst.On("GetModelInfoByNameAndVer").Return(&models.ModelRelatedInformation{}, nil)
handler := apis.NewMmeApiHandler(nil, iDBMockInst)
responseRecorder := httptest.NewRecorder()
// Upload model
- url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion)
+ url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion)
req := httptest.NewRequest(http.MethodPost, url, nil)
router.ServeHTTP(responseRecorder, req)
modelArtifactVersion := "1.0.0"
modelInfo := models.ModelRelatedInformation{
ModelId: models.ModelID{
- ModelName: modelName,
- ModelVersion: modelVersion,
+ ModelName: modelName,
+ ModelVersion: modelVersion,
+ ArtifactVersion: modelArtifactVersion,
},
}
iDBMockInst.On("GetModelInfoByNameAndVer").Return(&modelInfo, nil)
writer.Close()
// Upload model
- url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion)
+ url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion)
req := httptest.NewRequest(http.MethodPost, url, body)
req.Header.Set("Content-Type", writer.FormDataContentType())
router.ServeHTTP(responseRecorder, req)
@host = x.x.x.x:32006
### registraton
-POST http://{{host}}/model-registrations
+POST http://{{host}}/ai-ml-model-registration/v1/model-registrations
Content-Type: application/json
{
### Upload model using multipart/form-data
### Before Uploading, Make sure you have a "Model.zip" file in the current directory
-POST http://{{host}}/ai-ml-model-registration/v1/uploadModel/testmodel/1/1.0.0
+POST http://{{host}}/ai-ml-model-registration/v1/uploadModel/TestModel1/v1.0
Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW
------WebKitFormBoundary7MA4YWxkTrZu0gW
------WebKitFormBoundary7MA4YWxkTrZu0gW--
### Download model
-GET http://{{host}}/ai-ml-model-registration/v1/downloadModel/testmodel/1/1.0.0/model.zip
\ No newline at end of file
+GET http://{{host}}/ai-ml-model-registration/v1/downloadModel/TestModel1/v1.0/1.0.0/model.zip
\ No newline at end of file
api := r.Group("/ai-ml-model-registration/v1")
{
api.POST("/model-registrations", handler.RegisterModel)
- api.POST("/model-registrations/updateArtifact/:modelname/:modelversion/:artifactversion", handler.UpdateArtifact)
+ api.POST("/model-registrations/updateArtifact/:modelname/:modelversion/:artifactversion", handler.UpdateArtifact) // Deprecated: use the new API reference: /uploadModel/:modelName/:modelVersion.
api.GET("/model-registrations/:modelRegistrationId", handler.GetModelInfoById)
api.PUT("/model-registrations/:modelRegistrationId", handler.UpdateModel)
api.DELETE("/model-registrations/:modelRegistrationId", handler.DeleteModel)
api.GET("/getModelInfo/:modelName", handler.GetModelInfoByName)
- api.POST("/uploadModel/:modelName/:modelVersion/:artifactVersion", handler.UploadModel)
+ api.POST("/uploadModel/:modelName/:modelVersion", handler.UploadModel)
api.GET("/downloadModel/:modelName/:modelVersion/:artifactVersion/model.zip", handler.DownloadModel)
}
// As per R1-AP v6
--- /dev/null
+package utils
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging"
+)
+
+func IncrementArtifactVersion(artifactVersion string) (string, error) {
+ parts := strings.Split(artifactVersion, ".")
+ if len(parts) != 3 {
+ logging.ERROR("invalid artifactVersion format: " + artifactVersion)
+ return "", fmt.Errorf("invalid artifactVersion format: %s", artifactVersion)
+ }
+
+ major, err1 := strconv.Atoi(parts[0])
+ minor, err2 := strconv.Atoi(parts[1])
+ patch, err3 := strconv.Atoi(parts[2])
+ if err1 != nil || err2 != nil || err3 != nil {
+ logging.ERROR(fmt.Sprintf("failed to parse artifactVersion numbers: %v, %v, %v", err1, err2, err3))
+ return "", fmt.Errorf("failed to parse artifactVersion numbers: %v, %v, %v", err1, err2, err3)
+ }
+
+ // Increment logic
+ if artifactVersion == "0.0.0" {
+ // Modify to 1.0.0
+ major = 1
+ minor = 0
+ patch = 0
+ } else {
+ // Change from 1.x.0 to 1.(x + 1).0
+ minor += 1
+ }
+
+ // Construct new version string
+ newArtifactVersion := fmt.Sprintf("%d.%d.%d", major, minor, patch)
+ return newArtifactVersion, nil
+}
--- /dev/null
+package utils
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestIncrementArtifactVersion(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected string
+ expected_error string
+ }{
+ {"InitialVersion", "0.0.0", "1.0.0", ""},
+ {"MinorIncrement", "1.0.0", "1.1.0", ""},
+ {"AnotherIncrement", "1.5.0", "1.6.0", ""},
+ {"InvalidFormat", "a.b.c", "", "failed to parse artifactVersion numbers: strconv.Atoi: parsing \"a\": invalid syntax, strconv.Atoi: parsing \"b\": invalid syntax, strconv.Atoi: parsing \"c\": invalid syntax"},
+ {"InvalidFormat", "1.0", "", "invalid artifactVersion format: 1.0"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got, err := IncrementArtifactVersion(tc.input)
+ var err_str string
+ if err == nil {
+ err_str = ""
+ } else {
+ err_str = err.Error()
+ }
+
+ assert.Equal(t, tc.expected_error, err_str)
+ assert.Equal(t, tc.expected, got)
+ })
+ }
+}