Skip to content

Commit

Permalink
Reservoir sampler (#884)
Browse files Browse the repository at this point in the history
Reservoir sampler query operator
  • Loading branch information
thorfour authored Jun 4, 2024
1 parent 71fef5f commit f4615c3
Show file tree
Hide file tree
Showing 8 changed files with 455 additions and 1 deletion.
44 changes: 44 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,47 @@ func (a *AssertBucket) Upload(ctx context.Context, path string, r io.Reader) err
}
return a.Bucket.Upload(ctx, path, r)
}

func Test_DB_Sample(t *testing.T) {
t.Parallel()
config := NewTableConfig(
dynparquet.SampleDefinition(),
)
logger := newTestLogger(t)

c, err := New(WithLogger(logger))
t.Cleanup(func() {
require.NoError(t, c.Close())
})
require.NoError(t, err)
db, err := c.DB(context.Background(), "test")
require.NoError(t, err)
table, err := db.Table("test", config)
require.NoError(t, err)

ctx := context.Background()
for i := 0; i < 500; i++ {
samples := dynparquet.GenerateTestSamples(10)
r, err := samples.ToRecord()
require.NoError(t, err)
_, err = table.InsertRecord(ctx, r)
require.NoError(t, err)
}

pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer pool.AssertSize(t, 0)
lock := &sync.Mutex{}
rows := int64(0)
sampleSize := int64(13)
engine := query.NewEngine(pool, db.TableProvider())
err = engine.ScanTable("test").
Sample(sampleSize). // Sample 13 rows
Execute(context.Background(), func(ctx context.Context, r arrow.Record) error {
lock.Lock()
defer lock.Unlock()
rows += r.NumRows()
return nil
})
require.NoError(t, err)
require.Equal(t, sampleSize, rows)
}
12 changes: 12 additions & 0 deletions query/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Builder interface {
Limit(expr logicalplan.Expr) Builder
Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error
Explain(ctx context.Context) (string, error)
Sample(size int64) Builder
}

type LocalEngine struct {
Expand Down Expand Up @@ -143,6 +144,17 @@ func (b LocalQueryBuilder) Limit(
}
}

func (b LocalQueryBuilder) Sample(
size int64,
) Builder {
return LocalQueryBuilder{
pool: b.pool,
tracer: b.tracer,
planBuilder: b.planBuilder.Sample(logicalplan.Literal(size)),
execOpts: b.execOpts,
}
}

func (b LocalQueryBuilder) Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error {
ctx, span := b.tracer.Start(ctx, "LocalQueryBuilder/Execute")
defer span.End()
Expand Down
16 changes: 16 additions & 0 deletions query/logicalplan/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ func resolveAggregation(plan *LogicalPlan, agg *AggregationFunction) ([]*Aggrega
}
}

func (b Builder) Sample(expr Expr) Builder {
if expr == nil {
return b
}

return Builder{
err: b.err,
plan: &LogicalPlan{
Input: b.plan,
Sample: &Sample{
Expr: expr,
},
},
}
}

func (b Builder) Build() (*LogicalPlan, error) {
if b.err != nil {
return nil, b.err
Expand Down
16 changes: 16 additions & 0 deletions query/logicalplan/logicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type LogicalPlan struct {
Projection *Projection
Aggregation *Aggregation
Limit *Limit
Sample *Sample
}

// Callback is a function that is called throughout a chain of operators
Expand Down Expand Up @@ -159,6 +160,13 @@ func (plan *LogicalPlan) DataTypeForExpr(expr Expr) (arrow.DataType, error) {
return nil, fmt.Errorf("data type for expr %v within Distinct: %w", expr, err)
}

return t, nil
case plan.Sample != nil:
t, err := expr.DataType(plan.Input)
if err != nil {
return nil, fmt.Errorf("data type for expr %v within Sample: %w", expr, err)
}

return t, nil
default:
return nil, fmt.Errorf("unknown logical plan")
Expand Down Expand Up @@ -414,3 +422,11 @@ type Limit struct {
func (l *Limit) String() string {
return "Limit" + " Expr: " + fmt.Sprint(l.Expr)
}

type Sample struct {
Expr Expr
}

func (s *Sample) String() string {
return "Sample" + " Expr: " + fmt.Sprint(s.Expr)
}
5 changes: 4 additions & 1 deletion query/logicalplan/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ func ValidateSingleFieldSet(plan *LogicalPlan) *PlanValidationError {
if plan.Limit != nil {
fieldsSet = append(fieldsSet, 6)
}
if plan.Sample != nil {
fieldsSet = append(fieldsSet, 7)
}

if len(fieldsSet) != 1 {
fieldsFound := make([]string, 0)
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation"}
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation", "Limit", "Sample"}
for _, i := range fieldsSet {
fieldsFound = append(fieldsFound, fields[i])
}
Expand Down
14 changes: 14 additions & 0 deletions query/physicalplan/physicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/apache/arrow/go/v16/arrow/scalar"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -471,6 +472,19 @@ func Build(
if ordered {
oInfo.nodeMaintainsOrdering()
}
case plan.Sample != nil:
v := plan.Sample.Expr.(*logicalplan.LiteralExpr).Value.(*scalar.Int64).Value
perSampler := v / int64(len(prev))
r := v % int64(len(prev))
for i := range prev {
adjust := int64(0)
if i < int(r) {
adjust = 1
}
s := NewReservoirSampler(perSampler + adjust)
prev[i].SetNext(s)
prev[i] = s
}
default:
panic("Unsupported plan")
}
Expand Down
153 changes: 153 additions & 0 deletions query/physicalplan/sampler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package physicalplan

import (
"context"
"fmt"
"math"
"math/rand"

"github.com/apache/arrow/go/v16/arrow"
)

type ReservoirSampler struct {
next PhysicalPlan

// size is the max number of rows in the reservoir
size int64

// reservoir is the set of records that have been sampled. They may vary in schema due to dynamic columns.
reservoir []sample

w float64 // w is the probability of keeping a record
n int64 // n is the number of rows that have been sampled thus far
i float64 // i is the current row number being sampled
}

// NewReservoirSampler will create a new ReservoirSampler operator that will sample up to size rows of all records seen by Callback.
func NewReservoirSampler(size int64) *ReservoirSampler {
return &ReservoirSampler{
size: size,
w: math.Exp(math.Log(rand.Float64()) / float64(size)),
}
}

func (s *ReservoirSampler) SetNext(p PhysicalPlan) {
s.next = p
}

func (s *ReservoirSampler) Draw() *Diagram {
var child *Diagram
if s.next != nil {
child = s.next.Draw()
}
details := fmt.Sprintf("Reservoir Sampler (%v)", s.size)
return &Diagram{Details: details, Child: child}
}

func (s *ReservoirSampler) Close() {
for _, r := range s.reservoir {
r.r.Release()
}
s.next.Close()
}

// Callback collects all the records to sample.
func (s *ReservoirSampler) Callback(_ context.Context, r arrow.Record) error {
r = s.fill(r)
if r == nil { // The record fit in the reservoir
return nil
}
if s.n == s.size { // The reservoir just filled up. Slice the reservoir to the correct size so we can easily perform row replacement
s.sliceReservoir()
}

// Sample the record
s.sample(r)
return nil
}

// fill will fill the reservoir with the first size records.
func (s *ReservoirSampler) fill(r arrow.Record) arrow.Record {
if s.n >= s.size {
return r
}

if s.n+r.NumRows() <= s.size { // The record fits in the reservoir
s.reservoir = append(s.reservoir, sample{r: r, i: -1}) // -1 means the record is not sampled; use the entire record
r.Retain()
s.n += r.NumRows()
return nil
}

// The record partially fits in the reservoir
s.reservoir = append(s.reservoir, sample{r: r.NewSlice(0, s.size-s.n), i: -1})
r = r.NewSlice(s.size-s.n, r.NumRows())
s.n = s.size
return r
}

func (s *ReservoirSampler) sliceReservoir() {
newReservoir := make([]sample, 0, s.size)
for _, r := range s.reservoir {
for j := int64(0); j < r.r.NumRows(); j++ {
newReservoir = append(newReservoir, sample{r: r.r, i: j})
r.r.Retain()
}
r.r.Release()
}
s.reservoir = newReservoir
}

// sample implements the reservoir sampling algorithm found https://en.wikipedia.org/wiki/Reservoir_sampling.
func (s *ReservoirSampler) sample(r arrow.Record) {
n := s.n + r.NumRows()
if s.i == 0 {
s.i = float64(s.n) - 1
} else if s.i < float64(n) {
s.replace(rand.Intn(int(s.size)), sample{r: r, i: int64(s.i) - s.n})
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size))
}

for s.i < float64(n) {
s.i += math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1
if s.i < float64(n) {
// replace a random item of the reservoir with row i
s.replace(rand.Intn(int(s.size)), sample{r: r, i: int64(s.i) - s.n})
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size))
}
}
s.n = n
}

// Finish sends all the records in the reservoir to the next operator.
func (s *ReservoirSampler) Finish(ctx context.Context) error {
// Send all the records in the reservoir to the next operator
for _, r := range s.reservoir {
if r.i == -1 {
if err := s.next.Callback(ctx, r.r); err != nil {
return err
}
continue
}

record := r.r.NewSlice(r.i, r.i+1)
defer record.Release()
if err := s.next.Callback(ctx, record); err != nil {
return err
}
}

return s.next.Finish(ctx)
}

// replace will replace the row at index i with the row in the record r at index j.
func (s *ReservoirSampler) replace(i int, newRow sample) {
s.reservoir[i].r.Release()
s.reservoir[i] = newRow
newRow.r.Retain()
}

type sample struct {
r arrow.Record
i int64
}
Loading

0 comments on commit f4615c3

Please sign in to comment.