Skip to content

Commit

Permalink
Fix AggregateFinal behavior on window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Jul 17, 2022
1 parent ab7a223 commit 8ade6d3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

[Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v0.10.0...main
[Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v0.10.1...main

## [0.10.1][] - 2022-07-17

Version 0.10.1 fixes a bug in user-defined window functions.
Special thanks to Jan Mercl for assistance in debugging this issue.

[0.10.1]: https://github.com/zombiezen/go-sqlite/releases/tag/v0.10.1

### Fixed

- `AggregateFinal` is now called correctly at the end of window functions' usages.

## [0.10.0][] - 2022-07-10

Expand Down
34 changes: 27 additions & 7 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,16 +409,28 @@ type FunctionImpl struct {

// AggregateStep is called for each row
// of an aggregate function's SQL invocation.
//
// Use closure variables to accumulate state between calls to AggregateStep.
AggregateStep func(ctx Context, rowArgs []Value) error
// AggregateFinal is called after all of the aggregate function's input rows
// have been stepped through to construct the result.
// AggregateFinal is called
// after all of the aggregate function's input rows have been stepped through.
// The AggregateFinal function should
// reset the state used in AggregateStep to its initial value.
// When using the function as a non-window aggregate,
// the returned value is used as the function's result.
// When using a function as window aggregate,
// the function will only receive a call to AggregateFinal
// when it processes one or more rows during its evaluation
// to reset state (the returned Value is ignored).
//
// Use closure variables to pass information between AggregateStep and
// AggregateFinal. The AggregateFinal function should also reset any shared
// variables to their initial states before returning.
// Use closure variables to pass information between AggregateStep and AggregateFinal.
AggregateFinal func(ctx Context) (Value, error)

// WindowValue is called to get the current value of a aggregate window function.
// WindowValue is called to get the current value of an aggregate window function.
// This function will not be called when using an aggregate window function
// as an ordinary aggregate function.
//
// Use closure variables to pass information between AggregateStep and WindowValue.
WindowValue func(ctx Context) (Value, error)
// WindowInverse is called to remove
// the oldest presently aggregated result of AggregateStep
Expand Down Expand Up @@ -608,6 +620,14 @@ func funcTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
}

func stepTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
f := getxfuncs(tls, ctx)
if f.xValue != nil {
// SQLite only calls xFinal on window functions
// when they have an aggregate context allocated.
// The actual data is unused, since the closures pass data out-of-band.
lib.Xsqlite3_aggregate_context(tls, ctx, 1)
}

vals := make([]Value, 0, int(n))
for ; len(vals) < cap(vals); valarray += uintptr(ptrSize) {
vals = append(vals, Value{
Expand All @@ -616,7 +636,7 @@ func stepTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
})
}
goCtx := Context{tls: tls, ptr: ctx}
if err := getxfuncs(tls, ctx).xStep(goCtx, vals); err != nil {
if err := f.xStep(goCtx, vals); err != nil {
goCtx.resultError(err)
}
}
Expand Down
10 changes: 10 additions & 0 deletions func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func TestAggFunc(t *testing.T) {
}
}

finalCalled := false
sumintsImpl := &FunctionImpl{
NArgs: 1,
Deterministic: true,
Expand All @@ -114,6 +115,7 @@ func TestAggFunc(t *testing.T) {
return nil
}
sumintsImpl.AggregateFinal = func(ctx Context) (Value, error) {
finalCalled = true
result := IntegerValue(sum)
sum = 0
return result, nil
Expand All @@ -134,6 +136,9 @@ func TestAggFunc(t *testing.T) {
if got := stmt.ColumnInt(0); got != want {
t.Errorf("sum(c)=%d, want %d", got, want)
}
if !finalCalled {
t.Error("xFinal not called")
}
}

// Equivalent of https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
Expand Down Expand Up @@ -172,6 +177,7 @@ func TestWindowFunc(t *testing.T) {
t.Errorf("INSERT: %v", err)
}

finalCalled := false
sumintImpl := &FunctionImpl{
NArgs: 1,
Deterministic: true,
Expand All @@ -194,6 +200,7 @@ func TestWindowFunc(t *testing.T) {
return IntegerValue(sum), nil
}
sumintImpl.AggregateFinal = func(ctx Context) (Value, error) {
finalCalled = true
result := IntegerValue(sum)
sum = 0
return result, nil
Expand Down Expand Up @@ -240,6 +247,9 @@ func TestWindowFunc(t *testing.T) {
if diff := cmp.Diff(want, got, cmp.AllowUnexported(row{})); diff != "" {
t.Errorf("-want +got:\n%s", diff)
}
if !finalCalled {
t.Error("xFinal not called")
}
}

func TestCastTextToInteger(t *testing.T) {
Expand Down

0 comments on commit 8ade6d3

Please sign in to comment.