diff --git a/association.go b/association.go index 7ebe065e..5f7896bf 100644 --- a/association.go +++ b/association.go @@ -24,6 +24,36 @@ type associationKey struct { index int } +type AssocData struct { + path []string + through string +} + +func (a AssocData) Through() string { + return a.through +} + +func (a AssocData) FullPath() []string { + return append(a.IntermediaryPath(), a.path[len(a.path)-1]) +} + +func (a AssocData) IntermediaryPath() []string { + tmpPath := append([]string{}, a.path[:len(a.path)-1]...) + return append(tmpPath, a.Through()) +} + +func (a AssocData) Field() string { + return a.path[len(a.path)-1] +} + +func NewAssociationData(sl slice, path []string) AssocData { + a := sl.Get(0).Association(path[0]) + return AssocData{ + path: path, + through: a.Through(), + } +} + type associationData struct { typ AssociationType targetIndex []int diff --git a/rel_test.go b/rel_test.go index fe3385b5..6f951ca4 100644 --- a/rel_test.go +++ b/rel_test.go @@ -100,3 +100,24 @@ type UserRole struct { UserID int `db:",primary"` RoleID int `db:",primary"` } + +type Supplier struct { + ID int + Name string + + Account Account + AccountHistory AccountHistory `through:"account"` +} + +type Account struct { + ID int + SupplierID int + AccountNumber string + AccountHistory AccountHistory +} + +type AccountHistory struct { + ID int + AccountID int + CreditRating int +} diff --git a/repository.go b/repository.go index 1df283e8..de326b79 100644 --- a/repository.go +++ b/repository.go @@ -881,7 +881,6 @@ func (r repository) Preload(ctx context.Context, records interface{}, field stri var ( sl slice - cw = fetchContext(ctx, r.rootAdapter) path = strings.Split(field, ".") rt = reflect.TypeOf(records) ) @@ -897,13 +896,47 @@ func (r repository) Preload(ctx context.Context, records interface{}, field stri sl = NewDocument(records) } + if sl.Len() == 0 { + return nil + } + + assoc := NewAssociationData(sl, path) + if assoc.Through() == "" { + return r.preload(ctx, sl, assoc, queriers...) + } + + return r.preloadThrough(ctx, sl, assoc, queriers...) +} + +func (r repository) preloadThrough(ctx context.Context, sl slice, assoc AssocData, queriers ...Querier) error { + if err := r.preload(ctx, sl, NewAssociationData(sl, assoc.IntermediaryPath())); err != nil { + return err + } + + if err := r.preload(ctx, sl, NewAssociationData(sl, assoc.FullPath())); err != nil { + return err + } + + for i := 0; i < sl.Len(); i++ { + recordDoc := sl.Get(i) + + record, _ := recordDoc.Value(assoc.Through()) + assocDoc, _ := NewDocument(record, true).Value(assoc.Field()) + recordDoc.SetValue(assoc.Field(), assocDoc) + } + + return nil +} + +func (r repository) preload(ctx context.Context, sl slice, assoc AssocData, queriers ...Querier) error { var ( - targets, table, keyField, keyType, ddata, loaded = r.mapPreloadTargets(sl, path) + cw = fetchContext(ctx, r.rootAdapter) + targets, table, keyField, keyType, ddata, loaded = r.mapPreloadTargets(sl, assoc.path) ids = r.targetIDs(targets) query = Build(table, append(queriers, In(keyField, ids...))...) ) - if len(targets) == 0 || loaded && !bool(query.ReloadQuery) { + if loaded && !bool(query.ReloadQuery) { return nil } diff --git a/repository_test.go b/repository_test.go index 2c8d5afd..48f00d9c 100644 --- a/repository_test.go +++ b/repository_test.go @@ -3177,6 +3177,114 @@ func TestRepository_Preload_queryError(t *testing.T) { cur.AssertExpectations(t) } +func TestRepository_PreloadThrough_hasOne(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + supplier = Supplier{ID: 1, Name: "Supplier 1"} + account = Account{ID: 2, SupplierID: supplier.ID, AccountNumber: "222222"} + accountHistory = AccountHistory{ID: 3, AccountID: account.ID, CreditRating: 3333} + cur = &testCursor{} + ) + + adapter.On("Query", From("accounts").Where(In("supplier_id", supplier.ID))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "supplier_id", "account_number"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.MockScan(account.ID, account.SupplierID, account.AccountNumber).Twice() + cur.On("Next").Return(false).Once() + + adapter.On("Query", From("account_histories").Where(In("account_id", account.ID))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "account_id", "credit_rating"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.MockScan(accountHistory.ID, accountHistory.AccountID, accountHistory.CreditRating).Twice() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.Preload(context.TODO(), &supplier, "account_history")) + assert.Equal(t, accountHistory, supplier.AccountHistory) + assert.Equal(t, accountHistory, supplier.Account.AccountHistory) +} + +func TestRepository_PreloadThrough_sliceHasOne(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + suppliers = []Supplier{ + Supplier{ID: 1, Name: "Supplier 1"}, + Supplier{ID: 4, Name: "Supplier 4"}, + } + accounts = []Account{ + Account{ID: 2, SupplierID: suppliers[0].ID, AccountNumber: "222222"}, + Account{ID: 5, SupplierID: suppliers[1].ID, AccountNumber: "555555"}, + } + accountHistories = []AccountHistory{ + AccountHistory{ID: 3, AccountID: accounts[0].ID, CreditRating: 3333}, + AccountHistory{ID: 6, AccountID: accounts[1].ID, CreditRating: 6666}, + } + cur = &testCursor{} + ) + + adapter.On("Query", From("accounts").Where(In("supplier_id", suppliers[0].ID, suppliers[1].ID))).Return(cur, nil).Once() + adapter.On("Query", From("accounts").Where(In("supplier_id", suppliers[1].ID, suppliers[0].ID))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "supplier_id", "account_number"}, nil).Once() + cur.On("Next").Return(true).Twice() + cur.MockScan(accounts[0].ID, accounts[0].SupplierID, accounts[0].AccountNumber).Twice() + cur.MockScan(accounts[1].ID, accounts[1].SupplierID, accounts[1].AccountNumber).Twice() + cur.On("Next").Return(false).Once() + + adapter.On("Query", From("account_histories").Where(In("account_id", accounts[0].ID, accounts[1].ID))).Return(cur, nil).Once() + adapter.On("Query", From("account_histories").Where(In("account_id", accounts[1].ID, accounts[0].ID))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "account_id", "credit_rating"}, nil).Once() + cur.On("Next").Return(true).Twice() + cur.MockScan(accountHistories[0].ID, accountHistories[0].AccountID, accountHistories[0].CreditRating).Twice() + cur.MockScan(accountHistories[1].ID, accountHistories[1].AccountID, accountHistories[1].CreditRating).Twice() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.Preload(context.TODO(), &suppliers, "account_history")) + for i := 0; i < 2; i++ { + assert.Equal(t, accountHistories[i], suppliers[i].AccountHistory) + assert.Equal(t, accountHistories[i], suppliers[i].Account.AccountHistory) + } +} + +func TestRepository_PreloadThrough_intermediaryQueryError(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + supplier = Supplier{ID: 1, Name: "Supplier 1"} + err = errors.New("intermediaryError") + ) + + adapter.On("Query", From("accounts").Where(In("supplier_id", supplier.ID))).Return(&testCursor{}, err).Once() + + assert.Error(t, repo.Preload(context.TODO(), &supplier, "account_history")) +} + +func TestRepository_PreloadThrough_targerQueryError(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + supplier = Supplier{ID: 1, Name: "Supplier 1"} + account = Account{ID: 2, SupplierID: supplier.ID, AccountNumber: "222222"} + cur = &testCursor{} + err = errors.New("targetError") + ) + + adapter.On("Query", From("accounts").Where(In("supplier_id", supplier.ID))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "supplier_id", "account_number"}, nil).Once() + cur.On("Next").Return(true).Once() + cur.MockScan(account.ID, account.SupplierID, account.AccountNumber).Twice() + cur.On("Next").Return(false).Once() + + adapter.On("Query", From("account_histories").Where(In("account_id", account.ID))).Return(cur, err).Once() + + assert.Error(t, repo.Preload(context.TODO(), &supplier, "account_history")) +} + func TestRepository_MustPreload(t *testing.T) { var ( adapter = &testAdapter{}