diff --git a/tests/go.mod b/tests/go.mod index 350d17946..8eeab51de 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -29,8 +29,8 @@ require ( github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.24.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/text v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index d2cbc9a95..9f0f067c8 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { + transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return callback(ctx, tx) + }) + } + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("inner rollback") + }); err == nil { + t.Fatalf("nested transaction has no error") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked parent record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should not find rollbacked nested record") + } +} + func TestDisabledNestedTransaction(t *testing.T) { var ( user = *GetUser("transaction-nested", Config{})