From 4b7bc6c684b2f0f7c5f86c0c54315449733753e6 Mon Sep 17 00:00:00 2001 From: ashishj1729 Date: Fri, 10 Oct 2025 00:35:48 +0530 Subject: [PATCH] Add Update ArtifactVersion Functionality after Model-Upload via MME 1. Handling Error Scenarios in MME-Upload Model 2. Updating ArtifactVersion when model is uploaded Issue-id: AIMLFW-282 Change-Id: Icc586cde7827dd1d901f58eb09bfbec582109277 Signed-off-by: ashishj1729 --- apis/mmes_apis.go | 106 +++++++++++++++++++++++++++++++++++++++----- apis_test/mmes_apis_test.go | 35 +++++++++++---- request.http | 6 +-- routers/router.go | 4 +- utils/utils.go | 40 +++++++++++++++++ utils/utils_test.go | 37 ++++++++++++++++ 6 files changed, 204 insertions(+), 24 deletions(-) create mode 100644 utils/utils.go create mode 100644 utils/utils_test.go diff --git a/apis/mmes_apis.go b/apis/mmes_apis.go index 38bce11..34e20e7 100644 --- a/apis/mmes_apis.go +++ b/apis/mmes_apis.go @@ -30,6 +30,7 @@ import ( "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/db" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging" "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models" + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/utils" "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" "github.com/google/uuid" @@ -259,7 +260,6 @@ func (m *MmeApiHandler) UploadModel(cont *gin.Context) { logging.INFO("Uploading model API ...") modelName := cont.Param("modelName") modelVersion := cont.Param("modelVersion") - artifactVersion := cont.Param("artifactVersion") // Confirm if Model with Given ModelId: (ModelName and ModelVersion) is Registered or not: modelInfo, err := m.iDB.GetModelInfoByNameAndVer(modelName, modelVersion) @@ -285,27 +285,112 @@ func (m *MmeApiHandler) UploadModel(cont *gin.Context) { return } + // Read the uploaded File + fileHeader, err := cont.FormFile("file") + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("failed to read form file: %v", err) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Internal Server Error", + Detail: fmt.Sprintf("Can't read form file| Error: %s", err.Error()), + }) + return + } + + // Validate that file has .zip extension + if !strings.HasSuffix(strings.ToLower(fileHeader.Filename), ".zip") { + statusCode := http.StatusUnsupportedMediaType + logging.ERROR("invalid file type: %s", fileHeader.Filename) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Unsupported Media Type", + Detail: fmt.Sprintf("invalid file type: %s, Only .zip files are allowed", fileHeader.Filename), + }) + return + } + + file, err := fileHeader.Open() + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("failed to open uploaded file: %s", err.Error()) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Internal Server Error", + Detail: fmt.Sprintf("failed to open uploaded file: %s", err.Error()), + }) + return + } + defer file.Close() + + byteFile, err := io.ReadAll(file) + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("Error reading file content: %s", err.Error()) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Internal Server Error", + Detail: fmt.Sprintf("Error reading file content: %s", err.Error()), + }) + return + } + + artifactVersion := modelInfo.ModelId.ArtifactVersion modelKey := fmt.Sprintf("%s_%s_%s", modelName, modelVersion, artifactVersion) exportBucket := strings.ToLower(modelName) - //TODO convert multipart.FileHeader to []byte - fileHeader, _ := cont.FormFile("file") - //TODO : Accept only .zip file for trained model - file, _ := fileHeader.Open() - defer file.Close() - byteFile, _ := io.ReadAll((file)) + // Update the Artifact-Version + newArtifactVersion, err := utils.IncrementArtifactVersion(artifactVersion) + if err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("Unable to get newArtifactVersion: %s", err.Error()) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Internal Server Error", + Detail: fmt.Sprintf("Unable to get newArtifactVersion: %s", err.Error()), + }) + return + } + modelInfo.ModelId.ArtifactVersion = newArtifactVersion + if err := m.iDB.Update(*modelInfo); err != nil { + statusCode := http.StatusInternalServerError + logging.ERROR("Unable to update newArtifactVersion: %s", err.Error()) + cont.JSON(statusCode, models.ProblemDetail{ + Status: statusCode, + Title: "Internal Server Error", + Detail: fmt.Sprintf("Unable to update newArtifactVersion: %s", err.Error()), + }) + return + } + // Upload the file to s3-bucket logging.INFO("Uploading model : " + modelKey) if err := m.dbmgr.UploadFile(byteFile, modelKey+os.Getenv("MODEL_FILE_POSTFIX"), exportBucket); err != nil { - logging.ERROR("Failed to Upload Model : ", err) + // Model failed to update: Rollback artifact version to old-one + logging.ERROR(fmt.Sprintf("Failed to Upload Model : %s, Rolling back to previous artifact-version : %s", err.Error(), artifactVersion)) + modelInfo.ModelId.ArtifactVersion = artifactVersion + if err := m.iDB.Update(*modelInfo); err != nil { + /* + Ideally, the following situation should never occur. + This scenario can happen when: + The model artifact version is incremented to a new version, and + The file upload to the bucket fails, and + The rollback to the previous artifact version also fails. + */ + logging.ERROR("Unable to rollback to old-artifactVersion: %s", err.Error()) + } + cont.JSON(http.StatusInternalServerError, gin.H{ "code": http.StatusInternalServerError, "message": err.Error(), }) return } + + logging.INFO("model updated") cont.JSON(http.StatusOK, gin.H{ - "code": http.StatusOK, - "message": string("Model uploaded successfully.."), + "code": http.StatusOK, + "message": string("Model uploaded successfully.."), + "modelinfo": modelInfo, }) } @@ -401,6 +486,7 @@ func (m *MmeApiHandler) DeleteModel(cont *gin.Context) { cont.JSON(http.StatusNoContent, nil) } +// Deprecated: use the new API reference: UploadModel. func (m *MmeApiHandler) UpdateArtifact(cont *gin.Context) { logging.INFO("Update artifact version of model") modelname := cont.Param("modelname") diff --git a/apis_test/mmes_apis_test.go b/apis_test/mmes_apis_test.go index 2d09461..edfc7db 100644 --- a/apis_test/mmes_apis_test.go +++ b/apis_test/mmes_apis_test.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "io" + "log" "mime/multipart" "net/http" "net/http/httptest" @@ -376,8 +377,9 @@ func TestUploadModelSuccess(t *testing.T) { modelArtifactVersion := "1.0.0" modelInfo := models.ModelRelatedInformation{ ModelId: models.ModelID{ - ModelName: modelName, - ModelVersion: modelVersion, + ModelName: modelName, + ModelVersion: modelVersion, + ArtifactVersion: modelArtifactVersion, }, } iDBMockInst.On("GetModelInfoByNameAndVer").Return(&modelInfo, nil) @@ -398,7 +400,7 @@ func TestUploadModelSuccess(t *testing.T) { writer.Close() // Upload model - url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion) + url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion) req := httptest.NewRequest(http.MethodPost, url, body) req.Header.Set("Content-Type", writer.FormDataContentType()) router.ServeHTTP(responseRecorder, req) @@ -406,7 +408,22 @@ func TestUploadModelSuccess(t *testing.T) { response := responseRecorder.Result() responseBody, _ := io.ReadAll(response.Body) assert.Equal(t, http.StatusOK, responseRecorder.Code) - assert.Equal(t, `{"code":200,"message":"Model uploaded successfully.."}`, string(responseBody)) + var responseJson map[string]any + err = json.Unmarshal(responseBody, &responseJson) + if err != nil { + t.Errorf("Error to Unmarshal response-body : Error %s", err.Error()) + } + + newModelInfoStr, err := json.Marshal(responseJson["modelinfo"]) + if err != nil { + t.Errorf("Error to Marshal model-Info : Error %s", err.Error()) + } + + var newModelInfo models.ModelRelatedInformation + if err := json.Unmarshal(newModelInfoStr, &newModelInfo); err != nil { + log.Fatal("unmarshal error:", err) + } + assert.Equal(t, newModelInfo.ModelId.ArtifactVersion, "1.1.0") } func TestUploadModelFailureModelNotRegistered(t *testing.T) { @@ -416,7 +433,6 @@ func TestUploadModelFailureModelNotRegistered(t *testing.T) { iDBMockInst := new(mme_mocks.IDBMock) modelName := "test-model" modelVersion := "1" - modelArtifactVersion := "1.0.0" // Returns Empty model, signifying Model is Not registered iDBMockInst.On("GetModelInfoByNameAndVer").Return(&models.ModelRelatedInformation{}, nil) handler := apis.NewMmeApiHandler(nil, iDBMockInst) @@ -424,7 +440,7 @@ func TestUploadModelFailureModelNotRegistered(t *testing.T) { responseRecorder := httptest.NewRecorder() // Upload model - url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion) + url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion) req := httptest.NewRequest(http.MethodPost, url, nil) router.ServeHTTP(responseRecorder, req) @@ -448,8 +464,9 @@ func TestUploadModelFailureModelUploadFailure(t *testing.T) { modelArtifactVersion := "1.0.0" modelInfo := models.ModelRelatedInformation{ ModelId: models.ModelID{ - ModelName: modelName, - ModelVersion: modelVersion, + ModelName: modelName, + ModelVersion: modelVersion, + ArtifactVersion: modelArtifactVersion, }, } iDBMockInst.On("GetModelInfoByNameAndVer").Return(&modelInfo, nil) @@ -471,7 +488,7 @@ func TestUploadModelFailureModelUploadFailure(t *testing.T) { writer.Close() // Upload model - url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s/%s", modelName, modelVersion, modelArtifactVersion) + url := fmt.Sprintf("/ai-ml-model-registration/v1/uploadModel/%s/%s", modelName, modelVersion) req := httptest.NewRequest(http.MethodPost, url, body) req.Header.Set("Content-Type", writer.FormDataContentType()) router.ServeHTTP(responseRecorder, req) diff --git a/request.http b/request.http index 01878d2..5d1f535 100644 --- a/request.http +++ b/request.http @@ -2,7 +2,7 @@ @host = x.x.x.x:32006 ### registraton -POST http://{{host}}/model-registrations +POST http://{{host}}/ai-ml-model-registration/v1/model-registrations Content-Type: application/json { @@ -65,7 +65,7 @@ DELETE http://{{host}}/model-registrations/a43d1a80-e1c5-4d87-b90f-729736bdd89f ### Upload model using multipart/form-data ### Before Uploading, Make sure you have a "Model.zip" file in the current directory -POST http://{{host}}/ai-ml-model-registration/v1/uploadModel/testmodel/1/1.0.0 +POST http://{{host}}/ai-ml-model-registration/v1/uploadModel/TestModel1/v1.0 Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW ------WebKitFormBoundary7MA4YWxkTrZu0gW @@ -76,4 +76,4 @@ Content-Type: application/zip ------WebKitFormBoundary7MA4YWxkTrZu0gW-- ### Download model -GET http://{{host}}/ai-ml-model-registration/v1/downloadModel/testmodel/1/1.0.0/model.zip \ No newline at end of file +GET http://{{host}}/ai-ml-model-registration/v1/downloadModel/TestModel1/v1.0/1.0.0/model.zip \ No newline at end of file diff --git a/routers/router.go b/routers/router.go index 07115ba..21d6766 100644 --- a/routers/router.go +++ b/routers/router.go @@ -29,12 +29,12 @@ func InitRouter(handler *apis.MmeApiHandler) *gin.Engine { api := r.Group("/ai-ml-model-registration/v1") { api.POST("/model-registrations", handler.RegisterModel) - api.POST("/model-registrations/updateArtifact/:modelname/:modelversion/:artifactversion", handler.UpdateArtifact) + api.POST("/model-registrations/updateArtifact/:modelname/:modelversion/:artifactversion", handler.UpdateArtifact) // Deprecated: use the new API reference: /uploadModel/:modelName/:modelVersion. api.GET("/model-registrations/:modelRegistrationId", handler.GetModelInfoById) api.PUT("/model-registrations/:modelRegistrationId", handler.UpdateModel) api.DELETE("/model-registrations/:modelRegistrationId", handler.DeleteModel) api.GET("/getModelInfo/:modelName", handler.GetModelInfoByName) - api.POST("/uploadModel/:modelName/:modelVersion/:artifactVersion", handler.UploadModel) + api.POST("/uploadModel/:modelName/:modelVersion", handler.UploadModel) api.GET("/downloadModel/:modelName/:modelVersion/:artifactVersion/model.zip", handler.DownloadModel) } // As per R1-AP v6 diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..dbd1f87 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,40 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/logging" +) + +func IncrementArtifactVersion(artifactVersion string) (string, error) { + parts := strings.Split(artifactVersion, ".") + if len(parts) != 3 { + logging.ERROR("invalid artifactVersion format: " + artifactVersion) + return "", fmt.Errorf("invalid artifactVersion format: %s", artifactVersion) + } + + major, err1 := strconv.Atoi(parts[0]) + minor, err2 := strconv.Atoi(parts[1]) + patch, err3 := strconv.Atoi(parts[2]) + if err1 != nil || err2 != nil || err3 != nil { + logging.ERROR(fmt.Sprintf("failed to parse artifactVersion numbers: %v, %v, %v", err1, err2, err3)) + return "", fmt.Errorf("failed to parse artifactVersion numbers: %v, %v, %v", err1, err2, err3) + } + + // Increment logic + if artifactVersion == "0.0.0" { + // Modify to 1.0.0 + major = 1 + minor = 0 + patch = 0 + } else { + // Change from 1.x.0 to 1.(x + 1).0 + minor += 1 + } + + // Construct new version string + newArtifactVersion := fmt.Sprintf("%d.%d.%d", major, minor, patch) + return newArtifactVersion, nil +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 0000000..7567e63 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,37 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIncrementArtifactVersion(t *testing.T) { + tests := []struct { + name string + input string + expected string + expected_error string + }{ + {"InitialVersion", "0.0.0", "1.0.0", ""}, + {"MinorIncrement", "1.0.0", "1.1.0", ""}, + {"AnotherIncrement", "1.5.0", "1.6.0", ""}, + {"InvalidFormat", "a.b.c", "", "failed to parse artifactVersion numbers: strconv.Atoi: parsing \"a\": invalid syntax, strconv.Atoi: parsing \"b\": invalid syntax, strconv.Atoi: parsing \"c\": invalid syntax"}, + {"InvalidFormat", "1.0", "", "invalid artifactVersion format: 1.0"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := IncrementArtifactVersion(tc.input) + var err_str string + if err == nil { + err_str = "" + } else { + err_str = err.Error() + } + + assert.Equal(t, tc.expected_error, err_str) + assert.Equal(t, tc.expected, got) + }) + } +} -- 2.16.6