}
func NewModelInfoRepository(db *gorm.DB) *ModelInfoRepository {
- return &ModelInfoRepository{
- db: db,
+ return &ModelInfoRepository{db: db}
+}
+
+func replaceTargetEnvs(tx *gorm.DB, m *models.ModelRelatedInformation) error {
+ tes := m.ModelInformation.TargetEnvironment
+ if tes == nil {
+ return nil
+ }
+ if err := tx.Where("model_related_information_id = ?", m.Id).
+ Delete(&models.TargetEnvironment{}).Error; err != nil {
+ return err
+ }
+ if len(tes) == 0 {
+ return nil
}
+ rows := make([]models.TargetEnvironment, 0, len(tes))
+ for _, te := range tes {
+ rows = append(rows, models.TargetEnvironment{
+ ModelRelatedInformationID: m.Id,
+ PlatformName: te.PlatformName,
+ EnvironmentType: te.EnvironmentType,
+ DependencyList: te.DependencyList,
+ })
+ }
+ return tx.Create(&rows).Error
}
-func (repo *ModelInfoRepository) Create(modelInfo models.ModelRelatedInformation) error {
- result := repo.db.Create(modelInfo)
- return result.Error
+func (repo *ModelInfoRepository) Create(m models.ModelRelatedInformation) error {
+ return repo.db.Transaction(func(tx *gorm.DB) error {
+ if err := tx.Create(&m).Error; err != nil {
+ return err
+ }
+ return replaceTargetEnvs(tx, &m)
+ })
}
func (repo *ModelInfoRepository) GetByID(id string) (*models.ModelRelatedInformation, error) {
func (repo *ModelInfoRepository) GetAll() ([]models.ModelRelatedInformation, error) {
var modelInfos []models.ModelRelatedInformation
- result := repo.db.Find(&modelInfos)
+ result := repo.db.Session(&gorm.Session{SkipHooks: true}).Find(&modelInfos)
if result.Error != nil {
return nil, result.Error
}
+ if err := attachEnvsBatch(repo.db, modelInfos); err != nil {
+ return nil, err
+ }
return modelInfos, nil
}
-func (repo *ModelInfoRepository) Update(modelInfo models.ModelRelatedInformation) error {
- if err := repo.db.Save(modelInfo).Error; err != nil {
- return err
- }
- return nil
+func (repo *ModelInfoRepository) Update(m models.ModelRelatedInformation) error {
+ return repo.db.Transaction(func(tx *gorm.DB) error {
+ if err := tx.Save(&m).Error; err != nil {
+ return err
+ }
+ return replaceTargetEnvs(tx, &m)
+ })
}
func (repo *ModelInfoRepository) Delete(id string) (int64, error) {
- result := repo.db.Delete(&models.ModelRelatedInformation{}, "id = ?", id)
- return result.RowsAffected, result.Error
+ var rows int64
+ err := repo.db.Transaction(func(tx *gorm.DB) error {
+ if err := tx.Where("model_related_information_id = ?", id).
+ Delete(&models.TargetEnvironment{}).
+ Error; err != nil {
+ return err
+ }
+ res := tx.Delete(&models.ModelRelatedInformation{}, "id = ?", id)
+ rows = res.RowsAffected
+ return res.Error
+ })
+ return rows, err
}
+
func (repo *ModelInfoRepository) GetModelInfoByName(modelName string) ([]models.ModelRelatedInformation, error) {
var modelInfos []models.ModelRelatedInformation
- result := repo.db.Where("model_name = ?", modelName).Find(&modelInfos)
- if result.Error != nil {
- return nil, result.Error
+ if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+ Where("model_name = ?", modelName).
+ Find(&modelInfos).Error; err != nil {
+ return nil, err
+ }
+ if err := attachEnvsBatch(repo.db, modelInfos); err != nil {
+ return nil, err
}
return modelInfos, nil
}
func (repo *ModelInfoRepository) GetModelInfoByNameAndVer(modelName string, modelVersion string) (*models.ModelRelatedInformation, error) {
- var modelInfo models.ModelRelatedInformation
- result := repo.db.Where("model_name = ? AND model_version = ?", modelName, modelVersion).Find(&modelInfo)
- if result.Error != nil {
- return nil, result.Error
+ var m models.ModelRelatedInformation
+ if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+ Where("model_name = ? AND model_version = ?", modelName, modelVersion).
+ First(&m).Error; err != nil {
+ return nil, err
+ }
+ if err := attachEnvsOne(repo.db, &m); err != nil {
+ return nil, err
}
- return &modelInfo, nil
+ return &m, nil
}
func (repo *ModelInfoRepository) GetModelInfoById(id string) (*models.ModelRelatedInformation, error) {
- logging.INFO("id is:", id)
- var modelInfo models.ModelRelatedInformation
- result := repo.db.Where("id = ?", id).Find(&modelInfo)
- if result.Error != nil {
- return nil, result.Error
+ logging.INFO("id is: ", id)
+ var m models.ModelRelatedInformation
+ if err := repo.db.Session(&gorm.Session{SkipHooks: true}).
+ Where("id = ?", id).
+ First(&m).Error; err != nil {
+ return nil, err
+ }
+ if err := attachEnvsOne(repo.db, &m); err != nil {
+ return nil, err
}
- return &modelInfo, nil
+ return &m, nil
+}
+
+func attachEnvsOne(tx *gorm.DB, m *models.ModelRelatedInformation) error {
+ if m == nil || m.Id == "" {
+ return nil
+ }
+ var rows []models.TargetEnvironment
+ if err := tx.Table("target_environments").
+ Select("platform_name, environment_type, dependency_list").
+ Where("model_related_information_id = ?", m.Id).
+ Find(&rows).Error; err != nil {
+ return err
+ }
+ m.ModelInformation.TargetEnvironment = rows
+ return nil
+}
+
+func attachEnvsBatch(tx *gorm.DB, parents []models.ModelRelatedInformation) error {
+ if len(parents) == 0 {
+ return nil
+ }
+
+ ids := make([]string, 0, len(parents))
+ pos := make(map[string]int, len(parents))
+ for i := range parents {
+ if id := parents[i].Id; id != "" {
+ ids = append(ids, id)
+ pos[id] = i
+ }
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+
+ var rows []struct {
+ ModelRelatedInformationID string
+ PlatformName string
+ EnvironmentType string
+ DependencyList string
+ }
+ if err := tx.Table("target_environments").
+ Select("model_related_information_id, platform_name, environment_type, dependency_list").
+ Where("model_related_information_id IN ?", ids).
+ Find(&rows).Error; err != nil {
+ return err
+ }
+
+ for _, r := range rows {
+ if i, ok := pos[r.ModelRelatedInformationID]; ok {
+ parents[i].ModelInformation.TargetEnvironment = append(
+ parents[i].ModelInformation.TargetEnvironment,
+ models.TargetEnvironment{
+ PlatformName: r.PlatformName,
+ EnvironmentType: r.EnvironmentType,
+ DependencyList: r.DependencyList,
+ },
+ )
+ }
+ }
+ return nil
}
--- /dev/null
+/*
+==================================================================================
+Copyright (c) 2025 Minje Kim <alswp006@gmail.com> All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==================================================================================
+*/
+package db
+
+import (
+ "testing"
+
+ "gerrit.o-ran-sc.org/r/aiml-fw/awmf/modelmgmtservice/models"
+ "github.com/glebarez/sqlite"
+ "gorm.io/gorm"
+)
+
+func openTestDB(t *testing.T) *gorm.DB {
+ t.Helper()
+
+ d, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("open sqlite: %v", err)
+ }
+
+ if err := d.AutoMigrate(
+ &models.ModelRelatedInformation{},
+ &models.TargetEnvironment{},
+ ); err != nil {
+ t.Fatalf("automigrate: %v", err)
+ }
+ return d
+}
+
+func newRepo(t *testing.T) *ModelInfoRepository {
+ return NewModelInfoRepository(openTestDB(t))
+}
+
+func mkMRI(name, ver string, envs []models.TargetEnvironment) models.ModelRelatedInformation {
+ return models.ModelRelatedInformation{
+ ModelId: models.ModelID{ModelName: name, ModelVersion: ver},
+ Description: "test",
+ ModelInformation: models.ModelInformation{
+ Metadata: models.Metadata{Author: "tester"},
+ InputDataType: "csv",
+ OutputDataType: "json",
+ TargetEnvironment: envs,
+ },
+ }
+}
+func TestCRUD_Subtests(t *testing.T) {
+ repo := newRepo(t)
+
+ var id string
+
+ t.Run("Create_with_envs", func(t *testing.T) {
+ m := mkMRI("resnet", "1.0.0", []models.TargetEnvironment{
+ {PlatformName: "k8s", EnvironmentType: "prod", DependencyList: "cuda=12.1,torch=2.3"},
+ {PlatformName: "ec2", EnvironmentType: "stg", DependencyList: "cpu-only"},
+ })
+ if err := repo.Create(m); err != nil {
+ t.Fatalf("create: %v", err)
+ }
+ got, err := repo.GetModelInfoByNameAndVer("resnet", "1.0.0")
+ if err != nil {
+ t.Fatalf("get after create: %v", err)
+ }
+ if got.Id == "" {
+ t.Fatalf("expected non-empty id")
+ }
+ if len(got.ModelInformation.TargetEnvironment) != 2 {
+ t.Fatalf("want 2 envs, got %d", len(got.ModelInformation.TargetEnvironment))
+ }
+ id = got.Id
+ })
+
+ t.Run("Read", func(t *testing.T) {
+ byID, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get by id: %v", err)
+ }
+ if byID.Id != id {
+ t.Fatalf("id mismatch")
+ }
+
+ all, err := repo.GetAll()
+ if err != nil {
+ t.Fatalf("get all: %v", err)
+ }
+ if len(all) == 0 {
+ t.Fatalf("expected >=1")
+ }
+
+ list, err := repo.GetModelInfoByName("resnet")
+ if err != nil {
+ t.Fatalf("get by name: %v", err)
+ }
+ if len(list) == 0 {
+ t.Fatalf("expected >=1 resnet")
+ }
+ })
+
+ t.Run("Update_Replace_with_new_values", func(t *testing.T) {
+ cur, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get: %v", err)
+ }
+ cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{
+ {PlatformName: "edge", EnvironmentType: "prod", DependencyList: "cuda=12.2,torch=2.4"},
+ }
+ if err := repo.Update(*cur); err != nil {
+ t.Fatalf("update replace: %v", err)
+ }
+ after, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get after replace: %v", err)
+ }
+ if len(after.ModelInformation.TargetEnvironment) != 1 {
+ t.Fatalf("want 1 env after replace, got %d", len(after.ModelInformation.TargetEnvironment))
+ }
+ if after.ModelInformation.TargetEnvironment[0].PlatformName != "edge" {
+ t.Fatalf("unexpected platform after replace: %+v", after.ModelInformation.TargetEnvironment[0])
+ }
+ })
+
+ t.Run("Update_Partial_nil_keeps_existing", func(t *testing.T) {
+ cur, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get: %v", err)
+ }
+ cur.ModelInformation.TargetEnvironment = nil
+ if err := repo.Update(*cur); err != nil {
+ t.Fatalf("update partial nil: %v", err)
+ }
+ after, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get after partial: %v", err)
+ }
+ if len(after.ModelInformation.TargetEnvironment) != 1 {
+ t.Fatalf("want keep 1 env, got %d", len(after.ModelInformation.TargetEnvironment))
+ }
+ })
+
+ t.Run("Update_Clear_with_empty_slice", func(t *testing.T) {
+ cur, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get: %v", err)
+ }
+ cur.ModelInformation.TargetEnvironment = []models.TargetEnvironment{}
+ if err := repo.Update(*cur); err != nil {
+ t.Fatalf("update clear []: %v", err)
+ }
+ after, err := repo.GetModelInfoById(id)
+ if err != nil {
+ t.Fatalf("get after clear: %v", err)
+ }
+ if len(after.ModelInformation.TargetEnvironment) != 0 {
+ t.Fatalf("want 0 env after clear, got %d", len(after.ModelInformation.TargetEnvironment))
+ }
+ })
+
+ t.Run("Delete", func(t *testing.T) {
+ rows, err := repo.Delete(id)
+ if err != nil {
+ t.Fatalf("delete: %v", err)
+ }
+ if rows != 1 {
+ t.Fatalf("want 1 row deleted, got %d", rows)
+ }
+ if _, err := repo.GetModelInfoById(id); err == nil {
+ t.Fatalf("expected error after delete, got nil")
+ }
+ })
+}
}
// Auto migrate the scheme
- db.AutoMigrate(&models.ModelRelatedInformation{})
+ err = db.AutoMigrate(
+ &models.ModelRelatedInformation{},
+ &models.TargetEnvironment{},
+ )
+ if err != nil {
+ logging.ERROR("Failed to migrate database", "error", err)
+ os.Exit(-1)
+ }
+
repo := modelDB.NewModelInfoRepository(db)
router := routers.InitRouter(
package models
+import (
+ "github.com/google/uuid"
+ "gorm.io/gorm"
+)
+
type Metadata struct {
Author string `json:"author" validate:"required"`
Owner string `json:"owner"`
}
-type TargetEnironment struct {
- PlatformName string `json:"platformName" validate:"required"`
- EnvironmentType string `json:"environmentType" validate:"required"`
- DependencyList string `json:"dependencyList" validate:"required"`
+type TargetEnvironment struct {
+ ID string `gorm:"primaryKey" json:"-"`
+ ModelRelatedInformationID string `gorm:"index;not null" json:"-"`
+ PlatformName string `json:"platformName" validate:"required"`
+ EnvironmentType string `json:"environmentType" validate:"required"`
+ DependencyList string `json:"dependencyList" validate:"required"`
+}
+
+func (TargetEnvironment) TableName() string { return "target_environments" }
+
+func (te *TargetEnvironment) BeforeCreate(tx *gorm.DB) error {
+ if te.ID == "" {
+ te.ID = uuid.NewString()
+ }
+ return nil
}
type ModelInformation struct {
- Metadata Metadata `json:"metadata" gorm:"embedded" validate:"required"`
- InputDataType string `json:"inputDataType" validate:"required"` // this field will be a Comma Separated List
- OutputDataType string `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List
- // TODO: gorm doesn't support list, need to find the right way
- // TargetEnvironment []TargetEnironment `json:"targetEnvironment" gorm:"embedded"`
+ Metadata Metadata `json:"metadata" gorm:"embedded" validate:"required"`
+ InputDataType string `json:"inputDataType" validate:"required"` // this field will be a Comma Separated List
+ OutputDataType string `json:"outputDataType" validate:"required"` // this field will be a Comma Separated List
+ TargetEnvironment []TargetEnvironment `json:"targetEnvironment,omitempty" gorm:"-"`
}
+
type ModelID struct {
ModelName string `json:"modelName" validate:"required" gorm:"primaryKey"`
ModelVersion string `json:"modelVersion" validate:"required" gorm:"primaryKey"`
Name string `json:"name"`
Data string `json:"data"`
}
+
+func (modelInfo *ModelRelatedInformation) BeforeCreate(tx *gorm.DB) error {
+ if modelInfo.Id == "" {
+ modelInfo.Id = uuid.NewString()
+ }
+ return nil
+}