From 4248cbcab434d0b55c8c59e9eec03a4c1df1d42d Mon Sep 17 00:00:00 2001 From: alswp006 Date: Sat, 9 Aug 2025 19:14:43 +0900 Subject: [PATCH] models/db: introduce TargetEnvironment side-table; batch attach on reads; add CRUD tests - Add internal-only fields to models.TargetEnvironment: * ID (uuid, PK) * ModelRelatedInformationID (FK -> model_related_informations.id) Both are tagged with json:"-" so the public API remains unchanged. - Migration: * AutoMigrate(&ModelRelatedInformation{}, &TargetEnvironment{}) so target_environments has 5 columns: (id, model_related_information_id, platform_name, environment_type, dependency_list). * For existing DBs, legacy rows without model_related_information_id need backfill or archival; otherwise they will not attach to parents. - Fixes the registration error: "column model_related_information_id does not exist". - Tests: * Add CRUD unit test (TestCRUD_Subtests) covering create/read/replace/partial(nil)/clear([])/delete. All pass with -race. Change-Id: Id57bf7721ce2295fceb394a0dd82feab72e60a95 Signed-off-by: alswp006 --- db/modelInfoRepository.go | 168 +++++++++++++++++++++++++++++++------ db/modelInfoRepository_test.go | 184 +++++++++++++++++++++++++++++++++++++++++ main.go | 10 ++- models/modelInfo.go | 41 +++++++-- 4 files changed, 366 insertions(+), 37 deletions(-) create mode 100644 db/modelInfoRepository_test.go diff --git a/db/modelInfoRepository.go b/db/modelInfoRepository.go index 95e6acd..fbb60ff 100644 --- a/db/modelInfoRepository.go +++ b/db/modelInfoRepository.go @@ -28,14 +28,40 @@ type ModelInfoRepository struct { } func NewModelInfoRepository(db *gorm.DB) *ModelInfoRepository { - return &ModelInfoRepository{ - db: db, + return &ModelInfoRepository{db: db} +} + +func replaceTargetEnvs(tx *gorm.DB, m *models.ModelRelatedInformation) error { + tes := m.ModelInformation.TargetEnvironment + if tes == nil { + return nil + } + if err := tx.Where("model_related_information_id = ?", m.Id). + Delete(&models.TargetEnvironment{}).Error; err != nil { + return err + } + if len(tes) == 0 { + return nil } + rows := make([]models.TargetEnvironment, 0, len(tes)) + for _, te := range tes { + rows = append(rows, models.TargetEnvironment{ + ModelRelatedInformationID: m.Id, + PlatformName: te.PlatformName, + EnvironmentType: te.EnvironmentType, + DependencyList: te.DependencyList, + }) + } + return tx.Create(&rows).Error } -func (repo *ModelInfoRepository) Create(modelInfo models.ModelRelatedInformation) error { - result := repo.db.Create(modelInfo) - return result.Error +func (repo *ModelInfoRepository) Create(m models.ModelRelatedInformation) error { + return repo.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(&m).Error; err != nil { + return err + } + return replaceTargetEnvs(tx, &m) + }) } func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInformation, error) { @@ -44,48 +70,136 @@ func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInforma func (repo *ModelInfoRepository) GetAll() ([]models.ModelRelatedInformation, error) { var modelInfos []models.ModelRelatedInformation - result := repo.db.Find(&modelInfos) + result := repo.db.Session(&gorm.Session{SkipHooks: true}).Find(&modelInfos) if result.Error != nil { return nil, result.Error } + if err := attachEnvsBatch(repo.db, modelInfos); err != nil { + return nil, err + } return modelInfos, nil } -func (repo *ModelInfoRepository) Update(modelInfo models.ModelRelatedInformation) error { - if err := repo.db.Save(modelInfo).Error; err != nil { - return err - } - return nil +func (repo *ModelInfoRepository) Update(m models.ModelRelatedInformation) error { + return repo.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Save(&m).Error; err != nil { + return err + } + return replaceTargetEnvs(tx, &m) + }) } func (repo *ModelInfoRepository) Delete(id string) (int64, error) { - result := repo.db.Delete(&models.ModelRelatedInformation{}, "id = ?", id) - return result.RowsAffected, result.Error + var rows int64 + err := repo.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("model_related_information_id = ?", id). + Delete(&models.TargetEnvironment{}). + Error; err != nil { + return err + } + res := tx.Delete(&models.ModelRelatedInformation{}, "id = ?", id) + rows = res.RowsAffected + return res.Error + }) + return rows, err } + func (repo *ModelInfoRepository) GetModelInfoByName(modelName string) ([]models.ModelRelatedInformation, error) { var modelInfos []models.ModelRelatedInformation - result := repo.db.Where("model_name = ?", modelName).Find(&modelInfos) - if result.Error != nil { - return nil, result.Error + if err := repo.db.Session(&gorm.Session{SkipHooks: true}). + Where("model_name = ?", modelName). + Find(&modelInfos).Error; err != nil { + return nil, err + } + if err := attachEnvsBatch(repo.db, modelInfos); err != nil { + return nil, err } return modelInfos, nil } func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string) (*models.ModelRelatedInformation, error) { - var modelInfo models.ModelRelatedInformation - result := repo.db.Where("model_name = ? AND model_version = ?", modelName, modelVersion).Find(&modelInfo) - if result.Error != nil { - return nil, result.Error + var m models.ModelRelatedInformation + if err := repo.db.Session(&gorm.Session{SkipHooks: true}). + Where("model_name = ? AND model_version = ?", modelName, modelVersion). + First(&m).Error; err != nil { + return nil, err + } + if err := attachEnvsOne(repo.db, &m); err != nil { + return nil, err } - return &modelInfo, nil + return &m, nil } func (repo *ModelInfoRepository) GetModelInfoById(id string) (*models.ModelRelatedInformation, error) { - logging.INFO("id is:", id) - var modelInfo models.ModelRelatedInformation - result := repo.db.Where("id = ?", id).Find(&modelInfo) - if result.Error != nil { - return nil, result.Error + logging.INFO("id is: ", id) + var m models.ModelRelatedInformation + if err := repo.db.Session(&gorm.Session{SkipHooks: true}). + Where("id = ?", id). + First(&m).Error; err != nil { + return nil, err + } + if err := attachEnvsOne(repo.db, &m); err != nil { + return nil, err } - return &modelInfo, nil + return &m, nil +} + +func attachEnvsOne(tx *gorm.DB, m *models.ModelRelatedInformation) error { + if m == nil || m.Id == "" { + return nil + } + var rows []models.TargetEnvironment + if err := tx.Table("target_environments"). + Select("platform_name, environment_type, dependency_list"). + Where("model_related_information_id = ?", m.Id). + Find(&rows).Error; err != nil { + return err + } + m.ModelInformation.TargetEnvironment = rows + return nil +} + +func attachEnvsBatch(tx *gorm.DB, parents []models.ModelRelatedInformation) error { + if len(parents) == 0 { + return nil + } + + ids := make([]string, 0, len(parents)) + pos := make(map[string]int, len(parents)) + for i := range parents { + if id := parents[i].Id; id != "" { + ids = append(ids, id) + pos[id] = i + } + } + if len(ids) == 0 { + return nil + } + + var rows []struct { + ModelRelatedInformationID string + PlatformName string + EnvironmentType string + DependencyList string + } + if err := tx.Table("target_environments"). + Select("model_related_information_id, platform_name, environment_type, dependency_list"). + Where("model_related_information_id IN ?", ids). + Find(&rows).Error; err != nil { + return err + } + + for _, r := range rows { + if i, ok := pos[r.ModelRelatedInformationID]; ok { + parents[i].ModelInformation.TargetEnvironment = append( + parents[i].ModelInformation.TargetEnvironment, + models.TargetEnvironment{ + PlatformName: r.PlatformName, + EnvironmentType: r.EnvironmentType, + DependencyList: r.DependencyList, + }, + ) + } + } + return nil } diff --git a/db/modelInfoRepository_test.go b/db/modelInfoRepository_test.go new file mode 100644 index 0000000..5938e92 --- /dev/null +++ b/db/modelInfoRepository_test.go @@ -0,0 +1,184 @@ +/* +================================================================================== +Copyright (c) 2025 Minje Kim All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +================================================================================== +*/ +package db + +import ( + "testing" + + "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models" + "github.com/glebarez/sqlite" + "gorm.io/gorm" +) + +func openTestDB(t *testing.T) *gorm.DB { + t.Helper() + + d, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + + if err := d.AutoMigrate( + &models.ModelRelatedInformation{}, + &models.TargetEnvironment{}, + ); err != nil { + t.Fatalf("automigrate: %v", err) + } + return d +} + +func newRepo(t *testing.T) *ModelInfoRepository { + return NewModelInfoRepository(openTestDB(t)) +} + +func mkMRI(name, ver string, envs []models.TargetEnvironment) models.ModelRelatedInformation { + return models.ModelRelatedInformation{ + ModelId: models.ModelID{ModelName: name, ModelVersion: ver}, + Description: "test", + ModelInformation: models.ModelInformation{ + Metadata: models.Metadata{Author: "tester"}, + InputDataType: "csv", + OutputDataType: "json", + TargetEnvironment: envs, + }, + } +} +func TestCRUD_Subtests(t *testing.T) { + repo := newRepo(t) + + var id string + + t.Run("Create_with_envs", func(t *testing.T) { + m := mkMRI("resnet", "1.0.0", []models.TargetEnvironment{ + {PlatformName: "k8s", EnvironmentType: "prod", DependencyList: "cuda=12.1,torch=2.3"}, + {PlatformName: "ec2", EnvironmentType: "stg", DependencyList: "cpu-only"}, + }) + if err := repo.Create(m); err != nil { + t.Fatalf("create: %v", err) + } + got, err := repo.GetModelInfoByNameAndVer("resnet", "1.0.0") + if err != nil { + t.Fatalf("get after create: %v", err) + } + if got.Id == "" { + t.Fatalf("expected non-empty id") + } + if len(got.ModelInformation.TargetEnvironment) != 2 { + t.Fatalf("want 2 envs, got %d", len(got.ModelInformation.TargetEnvironment)) + } + id = got.Id + }) + + t.Run("Read", func(t *testing.T) { + byID, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get by id: %v", err) + } + if byID.Id != id { + t.Fatalf("id mismatch") + } + + all, err := repo.GetAll() + if err != nil { + t.Fatalf("get all: %v", err) + } + if len(all) == 0 { + t.Fatalf("expected >=1") + } + + list, err := repo.GetModelInfoByName("resnet") + if err != nil { + t.Fatalf("get by name: %v", err) + } + if len(list) == 0 { + t.Fatalf("expected >=1 resnet") + } + }) + + t.Run("Update_Replace_with_new_values", func(t *testing.T) { + cur, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get: %v", err) + } + cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{ + {PlatformName: "edge", EnvironmentType: "prod", DependencyList: "cuda=12.2,torch=2.4"}, + } + if err := repo.Update(*cur); err != nil { + t.Fatalf("update replace: %v", err) + } + after, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get after replace: %v", err) + } + if len(after.ModelInformation.TargetEnvironment) != 1 { + t.Fatalf("want 1 env after replace, got %d", len(after.ModelInformation.TargetEnvironment)) + } + if after.ModelInformation.TargetEnvironment[0].PlatformName != "edge" { + t.Fatalf("unexpected platform after replace: %+v", after.ModelInformation.TargetEnvironment[0]) + } + }) + + t.Run("Update_Partial_nil_keeps_existing", func(t *testing.T) { + cur, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get: %v", err) + } + cur.ModelInformation.TargetEnvironment = nil + if err := repo.Update(*cur); err != nil { + t.Fatalf("update partial nil: %v", err) + } + after, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get after partial: %v", err) + } + if len(after.ModelInformation.TargetEnvironment) != 1 { + t.Fatalf("want keep 1 env, got %d", len(after.ModelInformation.TargetEnvironment)) + } + }) + + t.Run("Update_Clear_with_empty_slice", func(t *testing.T) { + cur, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get: %v", err) + } + cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{} + if err := repo.Update(*cur); err != nil { + t.Fatalf("update clear []: %v", err) + } + after, err := repo.GetModelInfoById(id) + if err != nil { + t.Fatalf("get after clear: %v", err) + } + if len(after.ModelInformation.TargetEnvironment) != 0 { + t.Fatalf("want 0 env after clear, got %d", len(after.ModelInformation.TargetEnvironment)) + } + }) + + t.Run("Delete", func(t *testing.T) { + rows, err := repo.Delete(id) + if err != nil { + t.Fatalf("delete: %v", err) + } + if rows != 1 { + t.Fatalf("want 1 row deleted, got %d", rows) + } + if _, err := repo.GetModelInfoById(id); err == nil { + t.Fatalf("expected error after delete, got nil") + } + }) +} diff --git a/main.go b/main.go index 25eb27d..ed29860 100644 --- a/main.go +++ b/main.go @@ -67,7 +67,15 @@ func main() { } // Auto migrate the scheme - db.AutoMigrate(&models.ModelRelatedInformation{}) + err = db.AutoMigrate( + &models.ModelRelatedInformation{}, + &models.TargetEnvironment{}, + ) + if err != nil { + logging.ERROR("Failed to migrate database", "error", err) + os.Exit(-1) + } + repo := modelDB.NewModelInfoRepository(db) router := routers.InitRouter( diff --git a/models/modelInfo.go b/models/modelInfo.go index f154838..dfdca6c 100644 --- a/models/modelInfo.go +++ b/models/modelInfo.go @@ -18,24 +18,40 @@ limitations under the License. package models +import ( + "github.com/google/uuid" + "gorm.io/gorm" +) + type Metadata struct { Author string `json:"author" validate:"required"` Owner string `json:"owner"` } -type TargetEnironment struct { - PlatformName string `json:"platformName" validate:"required"` - EnvironmentType string `json:"environmentType" validate:"required"` - DependencyList string `json:"dependencyList" validate:"required"` +type TargetEnvironment struct { + ID string `gorm:"primaryKey" json:"-"` + ModelRelatedInformationID string `gorm:"index;not null" json:"-"` + PlatformName string `json:"platformName" validate:"required"` + EnvironmentType string `json:"environmentType" validate:"required"` + DependencyList string `json:"dependencyList" validate:"required"` +} + +func (TargetEnvironment) TableName() string { return "target_environments" } + +func (te *TargetEnvironment) BeforeCreate(tx *gorm.DB) error { + if te.ID == "" { + te.ID = uuid.NewString() + } + return nil } type ModelInformation struct { - Metadata Metadata `json:"metadata" gorm:"embedded" validate:"required"` - InputDataType string `json:"inputDataType" validate:"required"` // this field will be a Comma Separated List - OutputDataType string `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List - // TODO: gorm doesn't support list, need to find the right way - // TargetEnvironment []TargetEnironment `json:"targetEnvironment" gorm:"embedded"` + Metadata Metadata `json:"metadata" gorm:"embedded" validate:"required"` + InputDataType string `json:"inputDataType" validate:"required"` // this field will be a Comma Separated List + OutputDataType string `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List + TargetEnvironment []TargetEnvironment `json:"targetEnvironment,omitempty" gorm:"-"` } + type ModelID struct { ModelName string `json:"modelName" validate:"required" gorm:"primaryKey"` ModelVersion string `json:"modelVersion" validate:"required" gorm:"primaryKey"` @@ -54,3 +70,10 @@ type ModelInfoResponse struct { Name string `json:"name"` Data string `json:"data"` } + +func (modelInfo *ModelRelatedInformation) BeforeCreate(tx *gorm.DB) error { + if modelInfo.Id == "" { + modelInfo.Id = uuid.NewString() + } + return nil +} -- 2.16.6