From 6badad3c6da09ba5d4097acd704b6ed313209340 Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sat, 8 Aug 2020 13:19:34 -0700 Subject: [PATCH] Add custom collation functions CreateCollation registers a Go function as a sqlite collation, similar to CreateFunction. These can be used in queries for custom sorting. --- collation.go | 110 ++++++++++++++++++++++++++++++++++++++++++++++ collation_test.go | 71 ++++++++++++++++++++++++++++++ wrappers.c | 9 ++++ wrappers.h | 3 ++ 4 files changed, 193 insertions(+) create mode 100644 collation.go create mode 100644 collation_test.go diff --git a/collation.go b/collation.go new file mode 100644 index 0000000..9f8b2e2 --- /dev/null +++ b/collation.go @@ -0,0 +1,110 @@ +// Copyright (c) 2020 John Brooks +// +// Permission to use, copy, modify, and distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +package sqlite + +// #include +// #include +// #include "wrappers.h" +// +// static int go_sqlite3_create_collation_v2( +// sqlite3 *db, +// const char *zName, +// int eTextRep, +// uintptr_t pApp, +// int (*xCompare)(void*,int,const void*,int,const void*), +// void (*xDestroy)(void*) +// ) { +// return sqlite3_create_collation_v2( +// db, +// zName, +// eTextRep, +// (void*)pApp, +// xCompare, +// xDestroy); +// } +import "C" +import ( + "sync" +) + +type xcollation struct { + id int + name string + conn *Conn + xCompare func(string, string) int +} + +var xcollations = struct { + mu sync.Mutex + m map[int]*xcollation + next int +}{ + m: make(map[int]*xcollation), +} + +// CreateCollation registers a Go function as a SQLite collation function. +// +// These function are used with the COLLATE operator to implement custom sorting in queries. +// +// The xCompare function must return an integer that is negative, zero, or positive if the first +// string is less than, equal to, or greater than the second, respectively. The function must +// always return the same result for the same inputs and must be commutative. +// +// These are the same properties as strings.Compare(). +// +// https://sqlite.org/datatype3.html#collation +// https://sqlite.org/c3ref/create_collation.html +func (conn *Conn) CreateCollation(name string, xCompare func(string, string) int) error { + cname := C.CString(name) + eTextRep := C.int(C.SQLITE_UTF8) + + x := &xcollation{ + name: name, + conn: conn, + xCompare: xCompare, + } + + xcollations.mu.Lock() + xcollations.next++ + x.id = xcollations.next + xcollations.m[x.id] = x + xcollations.mu.Unlock() + + res := C.go_sqlite3_create_collation_v2( + conn.conn, + cname, + eTextRep, + C.uintptr_t(x.id), + (*[0]byte)(C.c_collation_tramp), + (*[0]byte)(C.c_destroy_collation_tramp), + ) + return conn.reserr("Conn.CreateCollation", name, res) +} + +//export go_collation_tramp +func go_collation_tramp(ptr uintptr, aLen C.int, a *C.char, bLen C.int, b *C.char) C.int { + xcollations.mu.Lock() + x := xcollations.m[int(ptr)] + xcollations.mu.Unlock() + return C.int(x.xCompare(C.GoStringN((*C.char)(a), aLen), C.GoStringN((*C.char)(b), bLen))) +} + +//export go_destroy_collation_tramp +func go_destroy_collation_tramp(ptr uintptr) { + id := int(ptr) + xcollations.mu.Lock() + delete(xcollations.m, id) + xcollations.mu.Unlock() +} diff --git a/collation_test.go b/collation_test.go new file mode 100644 index 0000000..e5c60ef --- /dev/null +++ b/collation_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2020 John Brooks +// +// Permission to use, copy, modify, and distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +package sqlite_test + +import ( + "testing" + + "crawshaw.io/sqlite" + "crawshaw.io/sqlite/sqlitex" +) + +func TestCollation(t *testing.T) { + c, err := sqlite.OpenConn(":memory:", 0) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := c.Close(); err != nil { + t.Error(err) + } + }() + + xCompare := func(a, b string) int { + if len(a) > len(b) { + return 1 + } else if len(a) < len(b) { + return -1 + } else { + return 0 + } + } + + if err := c.CreateCollation("sort_strlen", xCompare); err != nil { + t.Fatal(err) + } + + err = sqlitex.ExecScript(c, ` + CREATE TABLE strs (str); + INSERT INTO strs (str) VALUES ('ccc'),('a'),('bb'),('a'); + `) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := c.PrepareTransient("SELECT str FROM strs ORDER BY str COLLATE sort_strlen") + if err != nil { + t.Fatal(err) + } + wants := []string{"a", "a", "bb", "ccc"} + for i, want := range wants { + if _, err := stmt.Step(); err != nil { + t.Fatal(err) + } + if got := stmt.ColumnText(0); got != want { + t.Errorf("sort_strlen %d got %s, wanted %s", i, got, want) + } + } + stmt.Finalize() +} diff --git a/wrappers.c b/wrappers.c index cbe5d3b..ec674b2 100644 --- a/wrappers.c +++ b/wrappers.c @@ -47,3 +47,12 @@ void c_destroy_tramp(void* ptr) { return go_destroy_tramp((uintptr_t)ptr); } +extern int go_collation_tramp(uintptr_t, int, char *, int, char *); +int c_collation_tramp(void *ptr, int aLen, const void *a, int bLen, const void *b) { + return go_collation_tramp((uintptr_t)ptr, aLen, (char *)a, bLen, (char *)b); +} + +extern void go_destroy_collation_tramp(uintptr_t); +void c_destroy_collation_tramp(void *ptr) { + return go_destroy_collation_tramp((uintptr_t)ptr); +} diff --git a/wrappers.h b/wrappers.h index 4dc38fc..56d004b 100644 --- a/wrappers.h +++ b/wrappers.h @@ -28,4 +28,7 @@ int c_xapply_filter_tramp(void*, const char*); void c_destroy_tramp(void*); +int c_collation_tramp(void *ptr, int aLen, const void *a, int bLen, const void *b); +void c_destroy_collation_tramp(void *ptr); + #endif // WRAPPERS_H