diff --git a/collection.go b/collection.go index a12958b..ec5686f 100644 --- a/collection.go +++ b/collection.go @@ -99,21 +99,21 @@ func (coll *Collection) SimpleFindWithCtx(ctx context.Context, results interface return cur.All(ctx, results) } -//-------------------------------- +// -------------------------------- // Aggregation methods -//-------------------------------- +// -------------------------------- // SimpleAggregateFirst is just same as SimpleAggregateFirstWithCtx, but doesn't get context param. -func (coll *Collection) SimpleAggregateFirst(result interface{}, stages ...interface{}) (bool, error) { - return coll.SimpleAggregateFirstWithCtx(ctx(), result, stages...) +func (coll *Collection) SimpleAggregateFirst(result interface{}, stages []interface{}, opts ...*options.AggregateOptions) (bool, error) { + return coll.SimpleAggregateFirstWithCtx(ctx(), result, stages, opts...) } // SimpleAggregateFirstWithCtx performs a simple aggregation, decodes the first aggregate result and returns it using the provided result parameter. // The value of `stages` can be Operator|bson.M // Note: you can not use this method in a transaction because it does not accept a context. // To participate in transactions, please use the regular aggregation method. -func (coll *Collection) SimpleAggregateFirstWithCtx(ctx context.Context, result interface{}, stages ...interface{}) (bool, error) { - cur, err := coll.SimpleAggregateCursorWithCtx(ctx, stages...) +func (coll *Collection) SimpleAggregateFirstWithCtx(ctx context.Context, result interface{}, stages []interface{}, opts ...*options.AggregateOptions) (bool, error) { + cur, err := coll.SimpleAggregateCursorWithCtx(ctx, stages, opts...) if err != nil { return false, err } @@ -124,16 +124,16 @@ func (coll *Collection) SimpleAggregateFirstWithCtx(ctx context.Context, result } // SimpleAggregate is just same as SimpleAggregateWithCtx, but doesn't get context param. -func (coll *Collection) SimpleAggregate(results interface{}, stages ...interface{}) error { - return coll.SimpleAggregateWithCtx(ctx(), results, stages...) +func (coll *Collection) SimpleAggregate(results interface{}, stages []interface{}, opts ...*options.AggregateOptions) error { + return coll.SimpleAggregateWithCtx(ctx(), results, stages, opts...) } // SimpleAggregateWithCtx performs a simple aggregation, decodes the aggregate result and returns the list using the provided result parameter. // The value of `stages` can be Operator|bson.M // Note: you can not use this method in a transaction because it does not accept a context. // To participate in transactions, please use the regular aggregation method. -func (coll *Collection) SimpleAggregateWithCtx(ctx context.Context, results interface{}, stages ...interface{}) error { - cur, err := coll.SimpleAggregateCursorWithCtx(ctx, stages...) +func (coll *Collection) SimpleAggregateWithCtx(ctx context.Context, results interface{}, stages []interface{}, opts ...*options.AggregateOptions) error { + cur, err := coll.SimpleAggregateCursorWithCtx(ctx, stages, opts...) if err != nil { return err } @@ -143,14 +143,14 @@ func (coll *Collection) SimpleAggregateWithCtx(ctx context.Context, results inte // SimpleAggregateCursor is just same as SimpleAggregateCursorWithCtx, but // doesn't get context. -func (coll *Collection) SimpleAggregateCursor(stages ...interface{}) (*mongo.Cursor, error) { - return coll.SimpleAggregateCursorWithCtx(ctx(), stages...) +func (coll *Collection) SimpleAggregateCursor(stages []interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) { + return coll.SimpleAggregateCursorWithCtx(ctx(), stages, opts...) } // SimpleAggregateCursorWithCtx performs a simple aggregation and returns a cursor over the resulting documents. // Note: you can not use this method in a transaction because it does not accept a context. // To participate in transactions, please use the regular aggregation method. -func (coll *Collection) SimpleAggregateCursorWithCtx(ctx context.Context, stages ...interface{}) (*mongo.Cursor, error) { +func (coll *Collection) SimpleAggregateCursorWithCtx(ctx context.Context, stages []interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) { pipeline := bson.A{} for _, stage := range stages { @@ -161,5 +161,5 @@ func (coll *Collection) SimpleAggregateCursorWithCtx(ctx context.Context, stages } } - return coll.Aggregate(ctx, pipeline, nil) + return coll.Aggregate(ctx, pipeline, opts...) } diff --git a/collection_test.go b/collection_test.go index 0253b51..6cf0fde 100644 --- a/collection_test.go +++ b/collection_test.go @@ -1,6 +1,8 @@ package mgm_test import ( + "testing" + "github.com/kamva/mgm/v3" "github.com/kamva/mgm/v3/builder" "github.com/kamva/mgm/v3/internal/util" @@ -9,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "testing" ) func TestFindByIdWithInvalidId(t *testing.T) { @@ -127,7 +128,7 @@ func TestCollection_SimpleAggregateFirst(t *testing.T) { // We dont want to change document. group := builder.Group("$_id", nil) - found, err := mgm.Coll(&Doc{}).SimpleAggregateFirst(&gotResult, group) + found, err := mgm.Coll(&Doc{}).SimpleAggregateFirst(&gotResult, []interface{}{group}) assert.True(t, found) util.AssertErrIsNil(t, err) @@ -146,7 +147,7 @@ func TestCollection_SimpleAggregateFirstFalse(t *testing.T) { var gotResult *Doc match := bson.M{operator.Match: bson.M{"user_id": "unknown"}} - found, err := mgm.Coll(&Doc{}).SimpleAggregateFirst(gotResult, match) + found, err := mgm.Coll(&Doc{}).SimpleAggregateFirst(gotResult, []interface{}{match}) assert.False(t, found) util.AssertErrIsNil(t, err) @@ -166,7 +167,7 @@ func TestCollection_SimpleAggregate(t *testing.T) { project := bson.M{operator.Project: bson.M{"age": 0}} - err := mgm.Coll(&Doc{}).SimpleAggregate(&gotResult, group, project) + err := mgm.Coll(&Doc{}).SimpleAggregate(&gotResult, []interface{}{group, project}) util.AssertErrIsNil(t, err)