Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Preload Functionality with Custom Joins and Add Comprehensive Tests #7293

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 17 additions & 34 deletions callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,6 @@
"gorm.io/gorm/utils"
)

// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
Expand Down Expand Up @@ -74,7 +51,6 @@
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
Expand All @@ -84,10 +60,7 @@
}

// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error {

Check failure on line 63 in callbacks/preload.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 calculated cyclomatic complexity for function preloadEntryPoint is 23, max is 10 (cyclop) Raw Output: callbacks/preload.go:63:1: calculated cyclomatic complexity for function preloadEntryPoint is 23, max is 10 (cyclop) func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error { ^
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)

// avoid random traversal of the map
Expand Down Expand Up @@ -116,7 +89,7 @@

for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
Expand All @@ -138,14 +111,14 @@
}

tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil {
return err
}
default:
Expand All @@ -155,7 +128,7 @@
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name], customJoin); err != nil {
return err
}
}
Expand All @@ -182,7 +155,7 @@
return tx
}

func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error {

Check failure on line 158 in callbacks/preload.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 calculated cyclomatic complexity for function preload is 38, max is 10 (cyclop) Raw Output: callbacks/preload.go:158:1: calculated cyclomatic complexity for function preload is 38, max is 10 (cyclop) func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error { ^
var (
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
Expand All @@ -193,6 +166,10 @@
inlineConds []interface{}
)

if customJoin != nil {
tx = customJoin(tx)
}

if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
Expand Down Expand Up @@ -268,7 +245,13 @@

// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
if customJoin != nil {
tx = tx.Preload(p, pvs, func(tx *gorm.DB) *gorm.DB {
return customJoin(tx)
})
} else {
tx = tx.Preload(p, pvs...)
}
}

reflectResults := rel.FieldSchema.MakeSlice().Elem()
Expand Down
2 changes: 1 addition & 1 deletion callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func Preload(db *gorm.DB) {
return
}

db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], nil))
}
}

Expand Down
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ require (
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
)

require (
github.com/mattn/go-sqlite3 v1.14.22 // indirect
gorm.io/driver/sqlite v1.5.6 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
205 changes: 205 additions & 0 deletions tests/preload_custom_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package tests_test

import (
"testing"
"time"

"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

// Structs for preload tests
type PreloadItem struct {
ID uint
Name string
Tags []PreloadTag `gorm:"many2many:preload_items_preload_tags"`
CreatedAt time.Time
}

type PreloadTag struct {
ID uint
Name string
Status string
SubTags []PreloadSubTag `gorm:"many2many:tag_sub_tags"`
}

type PreloadSubTag struct {
ID uint
Name string
Status string
}

// Setup database for preload tests
func setupPreloadTestDB(t *testing.T) *gorm.DB {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to connect database: %v", err)
}
err = db.AutoMigrate(&PreloadItem{}, &PreloadTag{}, &PreloadSubTag{})
if err != nil {
t.Fatalf("failed to migrate database: %v", err)
}
return db
}

// Test default preload functionality
func TestDefaultPreload(t *testing.T) {
db := setupPreloadTestDB(t)

tag1 := PreloadTag{Name: "Tag1", Status: "active"}
item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1}}
db.Create(&item)

var items []PreloadItem
err := db.Preload("Tags").Find(&items).Error
if err != nil {
t.Fatalf("default preload failed: %v", err)
}

if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Name != "Tag1" {
t.Errorf("unexpected default preload results: %v", items)
}
}

// Test preloading with custom joins and conditions
func TestCustomJoinsWithConditions(t *testing.T) {
db := setupPreloadTestDB(t)

tag1 := PreloadTag{Name: "Tag1", Status: "active"}
tag2 := PreloadTag{Name: "Tag2", Status: "inactive"}
item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1, tag2}}
db.Create(&item)

var items []PreloadItem
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id").
Where("preload_tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("custom join with conditions failed: %v", err)
}

if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Status != "active" {
t.Errorf("unexpected results in TestCustomJoinsWithConditions: %v", items)
}
}

// Test nested preload functionality with custom joins
func TestNestedPreloadWithCustomJoins(t *testing.T) {
db := setupPreloadTestDB(t)

subTag := PreloadSubTag{Name: "SubTag1", Status: "active"}
tag := PreloadTag{Name: "Tag1", Status: "active", SubTags: []PreloadSubTag{subTag}}
item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag}}
db.Create(&item)

var items []PreloadItem
err := db.Preload("Tags.SubTags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN tag_sub_tags ON tag_sub_tags.preload_sub_tag_id = preload_sub_tags.id").
Where("preload_sub_tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("nested preload with custom joins failed: %v", err)
}

if len(items) != 1 || len(items[0].Tags) != 1 || len(items[0].Tags[0].SubTags) != 1 || items[0].Tags[0].SubTags[0].Name != "SubTag1" {
t.Errorf("unexpected nested preload results: %v", items)
}
}

// Test behavior when no matching records exist
func TestNoMatchingRecords(t *testing.T) {
db := setupPreloadTestDB(t)

tag := PreloadTag{Name: "Tag1", Status: "inactive"}
item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag}}
db.Create(&item)

var items []PreloadItem
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id").
Where("preload_tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("preload with no matching records failed: %v", err)
}

if len(items) != 1 || len(items[0].Tags) != 0 {
t.Errorf("unexpected results in TestNoMatchingRecords: %v", items)
}
}

// Test behavior with an empty database
func TestEmptyDatabase(t *testing.T) {
db := setupPreloadTestDB(t)

var items []PreloadItem
err := db.Preload("Tags").Find(&items).Error
if err != nil {
t.Fatalf("preload with empty database failed: %v", err)
}

if len(items) != 0 {
t.Errorf("unexpected results in TestEmptyDatabase: %v", items)
}
}

// Test multiple items with different tag statuses
func TestMultipleItemsWithDifferentTagStatuses(t *testing.T) {
db := setupPreloadTestDB(t)

tag1 := PreloadTag{Name: "Tag1", Status: "active"}
tag2 := PreloadTag{Name: "Tag2", Status: "inactive"}
item1 := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1}}
item2 := PreloadItem{Name: "Item2", Tags: []PreloadTag{tag2}}
db.Create(&item1)
db.Create(&item2)

var items []PreloadItem
err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id").
Where("preload_tags.status = ?", "active")
}).Find(&items).Error
if err != nil {
t.Fatalf("preload with multiple items failed: %v", err)
}

if len(items) != 2 || len(items[0].Tags) != 1 || len(items[1].Tags) != 0 {
t.Errorf("unexpected results in TestMultipleItemsWithDifferentTagStatuses: %v", items)
}
}

// Test duplicate preload conditions
func TestDuplicatePreloadConditions(t *testing.T) {
db := setupPreloadTestDB(t)

tag1 := PreloadTag{Name: "Tag1", Status: "active"}
tag2 := PreloadTag{Name: "Tag2", Status: "inactive"}
item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1, tag2}}
db.Create(&item)

var activeTagsItems []PreloadItem
var inactiveTagsItems []PreloadItem

err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Where("status = ?", "active")
}).Find(&activeTagsItems).Error
if err != nil {
t.Fatalf("preload for active tags failed: %v", err)
}

err = db.Preload("Tags", func(tx *gorm.DB) *gorm.DB {
return tx.Where("status = ?", "inactive")
}).Find(&inactiveTagsItems).Error
if err != nil {
t.Fatalf("preload for inactive tags failed: %v", err)
}

if len(activeTagsItems) != 1 || len(activeTagsItems[0].Tags) != 1 || activeTagsItems[0].Tags[0].Status != "active" {
t.Errorf("unexpected active tag results in TestDuplicatePreloadConditions: %v", activeTagsItems)
}

if len(inactiveTagsItems) != 1 || len(inactiveTagsItems[0].Tags) != 1 || inactiveTagsItems[0].Tags[0].Status != "inactive" {
t.Errorf("unexpected inactive tag results in TestDuplicatePreloadConditions: %v", inactiveTagsItems)
}
}
Loading