models/db: introduce TargetEnvironment side-table; batch attach on reads; add CRUD... 08/14808/6
authoralswp006 <alswp006@gmail.com>
Sat, 9 Aug 2025 10:14:43 +0000 (19:14 +0900)
committerMinje Kim <alswp006@gmail.com>
Mon, 13 Oct 2025 09:31:15 +0000 (09:31 +0000)
- Add internal-only fields to models.TargetEnvironment:
  * ID (uuid, PK)
  * ModelRelatedInformationID (FK -> model_related_informations.id)
  Both are tagged with json:"-" so the public API remains unchanged.

- Migration:
  * AutoMigrate(&ModelRelatedInformation{}, &TargetEnvironment{}) so
    target_environments has 5 columns:
    (id, model_related_information_id, platform_name, environment_type, dependency_list).
  * For existing DBs, legacy rows without model_related_information_id need
    backfill or archival; otherwise they will not attach to parents.

- Fixes the registration error:
  "column model_related_information_id does not exist".

- Tests:
  * Add CRUD unit test (TestCRUD_Subtests) covering create/read/replace/partial(nil)/clear([])/delete.
    All pass with -race.

Change-Id: Id57bf7721ce2295fceb394a0dd82feab72e60a95
Signed-off-by: alswp006 <alswp006@gmail.com>
db/modelInfoRepository.go
db/modelInfoRepository_test.go [new file with mode: 0644]
main.go
models/modelInfo.go

index 95e6acd..fbb60ff 100644 (file)
@@ -28,14 +28,40 @@ type ModelInfoRepository struct {
 }
 
 func NewModelInfoRepository(db *gorm.DB) *ModelInfoRepository {
-       return &ModelInfoRepository{
-               db: db,
+       return &ModelInfoRepository{db: db}
+}
+
+func replaceTargetEnvs(tx *gorm.DB, m *models.ModelRelatedInformation) error {
+       tes := m.ModelInformation.TargetEnvironment
+       if tes == nil {
+               return nil
+       }
+       if err := tx.Where("model_related_information_id = ?", m.Id).
+               Delete(&models.TargetEnvironment{}).Error; err != nil {
+               return err
+       }
+       if len(tes) == 0 {
+               return nil
        }
+       rows := make([]models.TargetEnvironment, 0, len(tes))
+       for _, te := range tes {
+               rows = append(rows, models.TargetEnvironment{
+                       ModelRelatedInformationID: m.Id,
+                       PlatformName:              te.PlatformName,
+                       EnvironmentType:           te.EnvironmentType,
+                       DependencyList:            te.DependencyList,
+               })
+       }
+       return tx.Create(&rows).Error
 }
 
-func (repo *ModelInfoRepository) Create(modelInfo models.ModelRelatedInformation) error {
-       result := repo.db.Create(modelInfo)
-       return result.Error
+func (repo *ModelInfoRepository) Create(m models.ModelRelatedInformation) error {
+       return repo.db.Transaction(func(tx *gorm.DB) error {
+               if err := tx.Create(&m).Error; err != nil {
+                       return err
+               }
+               return replaceTargetEnvs(tx, &m)
+       })
 }
 
 func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInformation, error) {
@@ -44,48 +70,136 @@ func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInforma
 
 func (repo *ModelInfoRepository) GetAll() ([]models.ModelRelatedInformation, error) {
        var modelInfos []models.ModelRelatedInformation
-       result := repo.db.Find(&modelInfos)
+       result := repo.db.Session(&gorm.Session{SkipHooks: true}).Find(&modelInfos)
        if result.Error != nil {
                return nil, result.Error
        }
+       if err := attachEnvsBatch(repo.db, modelInfos); err != nil {
+               return nil, err
+       }
        return modelInfos, nil
 }
 
-func (repo *ModelInfoRepository) Update(modelInfo models.ModelRelatedInformation) error {
-       if err := repo.db.Save(modelInfo).Error; err != nil {
-               return err
-       }
-       return nil
+func (repo *ModelInfoRepository) Update(m models.ModelRelatedInformation) error {
+       return repo.db.Transaction(func(tx *gorm.DB) error {
+               if err := tx.Save(&m).Error; err != nil {
+                       return err
+               }
+               return replaceTargetEnvs(tx, &m)
+       })
 }
 
 func (repo *ModelInfoRepository) Delete(id string) (int64, error) {
-       result := repo.db.Delete(&models.ModelRelatedInformation{}, "id = ?", id)
-       return result.RowsAffected, result.Error
+       var rows int64
+       err := repo.db.Transaction(func(tx *gorm.DB) error {
+               if err := tx.Where("model_related_information_id = ?", id).
+                       Delete(&models.TargetEnvironment{}).
+                       Error; err != nil {
+                       return err
+               }
+               res := tx.Delete(&models.ModelRelatedInformation{}, "id = ?", id)
+               rows = res.RowsAffected
+               return res.Error
+       })
+       return rows, err
 }
+
 func (repo *ModelInfoRepository) GetModelInfoByName(modelName string) ([]models.ModelRelatedInformation, error) {
        var modelInfos []models.ModelRelatedInformation
-       result := repo.db.Where("model_name = ?", modelName).Find(&modelInfos)
-       if result.Error != nil {
-               return nil, result.Error
+       if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+               Where("model_name = ?", modelName).
+               Find(&modelInfos).Error; err != nil {
+               return nil, err
+       }
+       if err := attachEnvsBatch(repo.db, modelInfos); err != nil {
+               return nil, err
        }
        return modelInfos, nil
 }
 
 func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string) (*models.ModelRelatedInformation, error) {
-       var modelInfo models.ModelRelatedInformation
-       result := repo.db.Where("model_name = ? AND model_version = ?", modelName, modelVersion).Find(&modelInfo)
-       if result.Error != nil {
-               return nil, result.Error
+       var m models.ModelRelatedInformation
+       if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+               Where("model_name = ? AND model_version = ?", modelName, modelVersion).
+               First(&m).Error; err != nil {
+               return nil, err
+       }
+       if err := attachEnvsOne(repo.db, &m); err != nil {
+               return nil, err
        }
-       return &modelInfo, nil
+       return &m, nil
 }
 
 func (repo *ModelInfoRepository) GetModelInfoById(id string) (*models.ModelRelatedInformation, error) {
-       logging.INFO("id is:", id)
-       var modelInfo models.ModelRelatedInformation
-       result := repo.db.Where("id = ?", id).Find(&modelInfo)
-       if result.Error != nil {
-               return nil, result.Error
+       logging.INFO("id is: ", id)
+       var m models.ModelRelatedInformation
+       if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+               Where("id = ?", id).
+               First(&m).Error; err != nil {
+               return nil, err
+       }
+       if err := attachEnvsOne(repo.db, &m); err != nil {
+               return nil, err
        }
-       return &modelInfo, nil
+       return &m, nil
+}
+
+func attachEnvsOne(tx *gorm.DB, m *models.ModelRelatedInformation) error {
+       if m == nil || m.Id == "" {
+               return nil
+       }
+       var rows []models.TargetEnvironment
+       if err := tx.Table("target_environments").
+               Select("platform_name, environment_type, dependency_list").
+               Where("model_related_information_id = ?", m.Id).
+               Find(&rows).Error; err != nil {
+               return err
+       }
+       m.ModelInformation.TargetEnvironment = rows
+       return nil
+}
+
+func attachEnvsBatch(tx *gorm.DB, parents []models.ModelRelatedInformation) error {
+       if len(parents) == 0 {
+               return nil
+       }
+
+       ids := make([]string, 0, len(parents))
+       pos := make(map[string]int, len(parents))
+       for i := range parents {
+               if id := parents[i].Id; id != "" {
+                       ids = append(ids, id)
+                       pos[id] = i
+               }
+       }
+       if len(ids) == 0 {
+               return nil
+       }
+
+       var rows []struct {
+               ModelRelatedInformationID string
+               PlatformName              string
+               EnvironmentType           string
+               DependencyList            string
+       }
+       if err := tx.Table("target_environments").
+               Select("model_related_information_id, platform_name, environment_type, dependency_list").
+               Where("model_related_information_id IN ?", ids).
+               Find(&rows).Error; err != nil {
+               return err
+       }
+
+       for _, r := range rows {
+               if i, ok := pos[r.ModelRelatedInformationID]; ok {
+                       parents[i].ModelInformation.TargetEnvironment = append(
+                               parents[i].ModelInformation.TargetEnvironment,
+                               models.TargetEnvironment{
+                                       PlatformName:    r.PlatformName,
+                                       EnvironmentType: r.EnvironmentType,
+                                       DependencyList:  r.DependencyList,
+                               },
+                       )
+               }
+       }
+       return nil
 }
diff --git a/db/modelInfoRepository_test.go b/db/modelInfoRepository_test.go
new file mode 100644 (file)
index 0000000..5938e92
--- /dev/null
@@ -0,0 +1,184 @@
+/*
+==================================================================================
+Copyright (c) 2025 Minje Kim <alswp006@gmail.com> All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==================================================================================
+*/
+package db
+
+import (
+       "testing"
+
+       "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models"
+       "github.com/glebarez/sqlite"
+       "gorm.io/gorm"
+)
+
+func openTestDB(t *testing.T) *gorm.DB {
+       t.Helper()
+
+       d, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
+       if err != nil {
+               t.Fatalf("open sqlite: %v", err)
+       }
+
+       if err := d.AutoMigrate(
+               &models.ModelRelatedInformation{},
+               &models.TargetEnvironment{},
+       ); err != nil {
+               t.Fatalf("automigrate: %v", err)
+       }
+       return d
+}
+
+func newRepo(t *testing.T) *ModelInfoRepository {
+       return NewModelInfoRepository(openTestDB(t))
+}
+
+func mkMRI(name, ver string, envs []models.TargetEnvironment) models.ModelRelatedInformation {
+       return models.ModelRelatedInformation{
+               ModelId:     models.ModelID{ModelName: name, ModelVersion: ver},
+               Description: "test",
+               ModelInformation: models.ModelInformation{
+                       Metadata:          models.Metadata{Author: "tester"},
+                       InputDataType:     "csv",
+                       OutputDataType:    "json",
+                       TargetEnvironment: envs,
+               },
+       }
+}
+func TestCRUD_Subtests(t *testing.T) {
+       repo := newRepo(t)
+
+       var id string
+
+       t.Run("Create_with_envs", func(t *testing.T) {
+               m := mkMRI("resnet", "1.0.0", []models.TargetEnvironment{
+                       {PlatformName: "k8s", EnvironmentType: "prod", DependencyList: "cuda=12.1,torch=2.3"},
+                       {PlatformName: "ec2", EnvironmentType: "stg", DependencyList: "cpu-only"},
+               })
+               if err := repo.Create(m); err != nil {
+                       t.Fatalf("create: %v", err)
+               }
+               got, err := repo.GetModelInfoByNameAndVer("resnet", "1.0.0")
+               if err != nil {
+                       t.Fatalf("get after create: %v", err)
+               }
+               if got.Id == "" {
+                       t.Fatalf("expected non-empty id")
+               }
+               if len(got.ModelInformation.TargetEnvironment) != 2 {
+                       t.Fatalf("want 2 envs, got %d", len(got.ModelInformation.TargetEnvironment))
+               }
+               id = got.Id
+       })
+
+       t.Run("Read", func(t *testing.T) {
+               byID, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get by id: %v", err)
+               }
+               if byID.Id != id {
+                       t.Fatalf("id mismatch")
+               }
+
+               all, err := repo.GetAll()
+               if err != nil {
+                       t.Fatalf("get all: %v", err)
+               }
+               if len(all) == 0 {
+                       t.Fatalf("expected >=1")
+               }
+
+               list, err := repo.GetModelInfoByName("resnet")
+               if err != nil {
+                       t.Fatalf("get by name: %v", err)
+               }
+               if len(list) == 0 {
+                       t.Fatalf("expected >=1 resnet")
+               }
+       })
+
+       t.Run("Update_Replace_with_new_values", func(t *testing.T) {
+               cur, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get: %v", err)
+               }
+               cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{
+                       {PlatformName: "edge", EnvironmentType: "prod", DependencyList: "cuda=12.2,torch=2.4"},
+               }
+               if err := repo.Update(*cur); err != nil {
+                       t.Fatalf("update replace: %v", err)
+               }
+               after, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get after replace: %v", err)
+               }
+               if len(after.ModelInformation.TargetEnvironment) != 1 {
+                       t.Fatalf("want 1 env after replace, got %d", len(after.ModelInformation.TargetEnvironment))
+               }
+               if after.ModelInformation.TargetEnvironment[0].PlatformName != "edge" {
+                       t.Fatalf("unexpected platform after replace: %+v", after.ModelInformation.TargetEnvironment[0])
+               }
+       })
+
+       t.Run("Update_Partial_nil_keeps_existing", func(t *testing.T) {
+               cur, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get: %v", err)
+               }
+               cur.ModelInformation.TargetEnvironment = nil
+               if err := repo.Update(*cur); err != nil {
+                       t.Fatalf("update partial nil: %v", err)
+               }
+               after, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get after partial: %v", err)
+               }
+               if len(after.ModelInformation.TargetEnvironment) != 1 {
+                       t.Fatalf("want keep 1 env, got %d", len(after.ModelInformation.TargetEnvironment))
+               }
+       })
+
+       t.Run("Update_Clear_with_empty_slice", func(t *testing.T) {
+               cur, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get: %v", err)
+               }
+               cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{}
+               if err := repo.Update(*cur); err != nil {
+                       t.Fatalf("update clear []: %v", err)
+               }
+               after, err := repo.GetModelInfoById(id)
+               if err != nil {
+                       t.Fatalf("get after clear: %v", err)
+               }
+               if len(after.ModelInformation.TargetEnvironment) != 0 {
+                       t.Fatalf("want 0 env after clear, got %d", len(after.ModelInformation.TargetEnvironment))
+               }
+       })
+
+       t.Run("Delete", func(t *testing.T) {
+               rows, err := repo.Delete(id)
+               if err != nil {
+                       t.Fatalf("delete: %v", err)
+               }
+               if rows != 1 {
+                       t.Fatalf("want 1 row deleted, got %d", rows)
+               }
+               if _, err := repo.GetModelInfoById(id); err == nil {
+                       t.Fatalf("expected error after delete, got nil")
+               }
+       })
+}
diff --git a/main.go b/main.go
index 25eb27d..ed29860 100644 (file)
--- a/main.go
+++ b/main.go
@@ -67,7 +67,15 @@ func main() {
        }
 
        // Auto migrate the scheme
-       db.AutoMigrate(&models.ModelRelatedInformation{})
+       err = db.AutoMigrate(
+               &models.ModelRelatedInformation{},
+               &models.TargetEnvironment{},
+       )
+       if err != nil {
+               logging.ERROR("Failed to migrate database", "error", err)
+               os.Exit(-1)
+       }
+
        repo := modelDB.NewModelInfoRepository(db)
 
        router := routers.InitRouter(
index f154838..dfdca6c 100644 (file)
@@ -18,24 +18,40 @@ limitations under the License.
 
 package models
 
+import (
+       "github.com/google/uuid"
+       "gorm.io/gorm"
+)
+
 type Metadata struct {
        Author string `json:"author" validate:"required"`
        Owner  string `json:"owner"`
 }
 
-type TargetEnironment struct {
-       PlatformName    string `json:"platformName" validate:"required"`
-       EnvironmentType string `json:"environmentType" validate:"required"`
-       DependencyList  string `json:"dependencyList" validate:"required"`
+type TargetEnvironment struct {
+       ID                        string `gorm:"primaryKey" json:"-"`
+       ModelRelatedInformationID string `gorm:"index;not null" json:"-"`
+       PlatformName              string `json:"platformName" validate:"required"`
+       EnvironmentType           string `json:"environmentType" validate:"required"`
+       DependencyList            string `json:"dependencyList" validate:"required"`
+}
+
+func (TargetEnvironment) TableName() string { return "target_environments" }
+
+func (te *TargetEnvironment) BeforeCreate(tx *gorm.DB) error {
+       if te.ID == "" {
+               te.ID = uuid.NewString()
+       }
+       return nil
 }
 
 type ModelInformation struct {
-       Metadata       Metadata `json:"metadata" gorm:"embedded" validate:"required"`
-       InputDataType  string   `json:"inputDataType" validate:"required"`  // this field will be a Comma Separated List
-       OutputDataType string   `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List
-       // TODO: gorm doesn't support list, need to find the right way
-       // TargetEnvironment []TargetEnironment `json:"targetEnvironment" gorm:"embedded"`
+       Metadata          Metadata            `json:"metadata" gorm:"embedded" validate:"required"`
+       InputDataType     string              `json:"inputDataType" validate:"required"`  // this field will be a Comma Separated List
+       OutputDataType    string              `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List
+       TargetEnvironment []TargetEnvironment `json:"targetEnvironment,omitempty" gorm:"-"`
 }
+
 type ModelID struct {
        ModelName       string `json:"modelName" validate:"required" gorm:"primaryKey"`
        ModelVersion    string `json:"modelVersion" validate:"required" gorm:"primaryKey"`
@@ -54,3 +70,10 @@ type ModelInfoResponse struct {
        Name string `json:"name"`
        Data string `json:"data"`
 }
+
+func (modelInfo *ModelRelatedInformation) BeforeCreate(tx *gorm.DB) error {
+       if modelInfo.Id == "" {
+               modelInfo.Id = uuid.NewString()
+       }
+       return nil
+}