From 8ade6d3e398f0496880bfb06b1a12c7ca4417ef8 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Sun, 17 Jul 2022 15:29:34 -0700 Subject: [PATCH] Fix AggregateFinal behavior on window functions --- CHANGELOG.md | 13 ++++++++++++- func.go | 34 +++++++++++++++++++++++++++------- func_test.go | 10 ++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df443d4..8efce6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -[Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v0.10.0...main +[Unreleased]: https://github.com/zombiezen/go-sqlite/compare/v0.10.1...main + +## [0.10.1][] - 2022-07-17 + +Version 0.10.1 fixes a bug in user-defined window functions. +Special thanks to Jan Mercl for assistance in debugging this issue. + +[0.10.1]: https://github.com/zombiezen/go-sqlite/releases/tag/v0.10.1 + +### Fixed + +- `AggregateFinal` is now called correctly at the end of window functions' usages. ## [0.10.0][] - 2022-07-10 diff --git a/func.go b/func.go index 73847d2..a5d03ad 100644 --- a/func.go +++ b/func.go @@ -409,16 +409,28 @@ type FunctionImpl struct { // AggregateStep is called for each row // of an aggregate function's SQL invocation. + // + // Use closure variables to accumulate state between calls to AggregateStep. AggregateStep func(ctx Context, rowArgs []Value) error - // AggregateFinal is called after all of the aggregate function's input rows - // have been stepped through to construct the result. + // AggregateFinal is called + // after all of the aggregate function's input rows have been stepped through. + // The AggregateFinal function should + // reset the state used in AggregateStep to its initial value. + // When using the function as a non-window aggregate, + // the returned value is used as the function's result. + // When using a function as window aggregate, + // the function will only receive a call to AggregateFinal + // when it processes one or more rows during its evaluation + // to reset state (the returned Value is ignored). // - // Use closure variables to pass information between AggregateStep and - // AggregateFinal. The AggregateFinal function should also reset any shared - // variables to their initial states before returning. + // Use closure variables to pass information between AggregateStep and AggregateFinal. AggregateFinal func(ctx Context) (Value, error) - // WindowValue is called to get the current value of a aggregate window function. + // WindowValue is called to get the current value of an aggregate window function. + // This function will not be called when using an aggregate window function + // as an ordinary aggregate function. + // + // Use closure variables to pass information between AggregateStep and WindowValue. WindowValue func(ctx Context) (Value, error) // WindowInverse is called to remove // the oldest presently aggregated result of AggregateStep @@ -608,6 +620,14 @@ func funcTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) { } func stepTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) { + f := getxfuncs(tls, ctx) + if f.xValue != nil { + // SQLite only calls xFinal on window functions + // when they have an aggregate context allocated. + // The actual data is unused, since the closures pass data out-of-band. + lib.Xsqlite3_aggregate_context(tls, ctx, 1) + } + vals := make([]Value, 0, int(n)) for ; len(vals) < cap(vals); valarray += uintptr(ptrSize) { vals = append(vals, Value{ @@ -616,7 +636,7 @@ func stepTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) { }) } goCtx := Context{tls: tls, ptr: ctx} - if err := getxfuncs(tls, ctx).xStep(goCtx, vals); err != nil { + if err := f.xStep(goCtx, vals); err != nil { goCtx.resultError(err) } } diff --git a/func_test.go b/func_test.go index 5994f8a..1648b64 100644 --- a/func_test.go +++ b/func_test.go @@ -102,6 +102,7 @@ func TestAggFunc(t *testing.T) { } } + finalCalled := false sumintsImpl := &FunctionImpl{ NArgs: 1, Deterministic: true, @@ -114,6 +115,7 @@ func TestAggFunc(t *testing.T) { return nil } sumintsImpl.AggregateFinal = func(ctx Context) (Value, error) { + finalCalled = true result := IntegerValue(sum) sum = 0 return result, nil @@ -134,6 +136,9 @@ func TestAggFunc(t *testing.T) { if got := stmt.ColumnInt(0); got != want { t.Errorf("sum(c)=%d, want %d", got, want) } + if !finalCalled { + t.Error("xFinal not called") + } } // Equivalent of https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions @@ -172,6 +177,7 @@ func TestWindowFunc(t *testing.T) { t.Errorf("INSERT: %v", err) } + finalCalled := false sumintImpl := &FunctionImpl{ NArgs: 1, Deterministic: true, @@ -194,6 +200,7 @@ func TestWindowFunc(t *testing.T) { return IntegerValue(sum), nil } sumintImpl.AggregateFinal = func(ctx Context) (Value, error) { + finalCalled = true result := IntegerValue(sum) sum = 0 return result, nil @@ -240,6 +247,9 @@ func TestWindowFunc(t *testing.T) { if diff := cmp.Diff(want, got, cmp.AllowUnexported(row{})); diff != "" { t.Errorf("-want +got:\n%s", diff) } + if !finalCalled { + t.Error("xFinal not called") + } } func TestCastTextToInteger(t *testing.T) {