From 72a5bedeb326a44065317608ebaead3303559ade Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Fri, 11 Oct 2024 15:09:30 +0200 Subject: [PATCH] clean up interfaces, update docs --- README.md | 94 +++++++++++++++++++++++++--------------- example_test.go | 8 ++-- lsh/example_test.go | 4 +- lsh/model.go | 32 +++++++++++--- lsh/model_bench_test.go | 4 +- lsh/model_test.go | 22 +++++++--- lsh/model_wide.go | 32 +++++++++++--- lsh/model_wide_test.go | 22 +++++++--- model.go | 27 +++++++++--- model_bench_test.go | 16 +++---- model_test.go | 44 +++++++++---------- model_wide.go | 23 ++++++++-- model_wide_bench_test.go | 4 +- model_wide_test.go | 12 ++++- 14 files changed, 233 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 3f594f2..b65493f 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,24 @@ The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh) ## Usage +There are just three methods you'll usually need: + +- **Fit** *(data, labels, [options])*: create a model from a dataset + + Variants: [`bitknn.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Fit), [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide), [`lsh.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit), [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) +- **Find** *(k, point)*: Given a point, return the distances and indices of its *k* nearest neighbors. + + Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find) + +- **Predict** *(k, point, votes)*: Predict the label for a given point based on its nearest neighbors and write the label votes into the provided vote counter. 
+ + Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict) + +Each of the above methods is available on each model type. There are four model types in total: + +- **Exact k-NN** models: [`bitknn.Model`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model) (64 bits), [`bitknn.WideModel`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel) (*N* * 64 bits) +- **Approximate (ANN)** models: [`lsh.Model`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model) (64 bits), [`lsh.WideModel`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel) (*N* * 64 bits) + ### Basic usage ```go @@ -53,14 +71,16 @@ func main() { votes := make([]float64, 2) k := 2 - model.Predict1(k, 0b101011, bitknn.VoteSlice(votes)) + model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) + // or, just return the nearest neighbors' distances and indices: + // distances, indices := model.Find(k, 0b101011) fmt.Println("Votes:", bitknn.VoteSlice(votes)) // you can also use a map for the votes. 
// this is good if you have a very large number of different labels: votesMap := make(map[int]float64) - model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap)) + model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) fmt.Println("Votes for 0:", votesMap[0]) } ``` @@ -96,13 +116,15 @@ func main() { votes := make([]float64, 2) k := 2 - model.Predict1(k, 0b101011, bitknn.VoteSlice(votes)) + model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) + // or, just return the nearest neighbor's distances and indices: + // distances,indices := model.Find(k, 0b101011) fmt.Println("Votes:", bitknn.VoteSlice(votes)) // you can also use a map for the votes votesMap := make(map[int]float64) - model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap)) + model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) fmt.Println("Votes for 0:", votesMap[0]) } ``` @@ -163,7 +185,7 @@ func main() { k := 2 query := pack.String("fob") - model.Predict1(k, query, bitknn.VoteSlice(votes)) + model.Predict(k, query, bitknn.VoteSlice(votes)) fmt.Println("Votes:", bitknn.VoteSlice(votes)) } @@ -188,37 +210,37 @@ pkg: github.com/keilerkonzept/bitknn cpu: Apple M1 Pro ``` -| Op | N | k | Distance weighting | Vote values | sec / op | B/op | allocs/op | -|------------|---------|-----|--------------------|-------------|--------------|------|-----------| -| `Predict1` | 100 | 3 | | | 138.7n ± 22% | 0 | 0 | -| `Predict1` | 100 | 3 | | ☑️ | 127.8n ± 11% | 0 | 0 | -| `Predict1` | 100 | 3 | linear | | 137.0n ± 11% | 0 | 0 | -| `Predict1` | 100 | 3 | linear | ☑️ | 136.7n ± 10% | 0 | 0 | -| `Predict1` | 100 | 3 | quadratic | | 137.2n ± 7% | 0 | 0 | -| `Predict1` | 100 | 3 | quadratic | ☑️ | 130.4n ± 4% | 0 | 0 | -| `Predict1` | 100 | 3 | custom | | 140.6n ± 7% | 0 | 0 | -| `Predict1` | 100 | 3 | custom | ☑️ | 134.9n ± 13% | 0 | 0 | -| `Predict1` | 100 | 10 | | | 307.4n ± 11% | 0 | 0 | -| `Predict1` | 100 | 10 | | ☑️ | 297.8n ± 15% | 0 | 0 | -| `Predict1` | 100 | 10 | linear | | 288.2n ± 18% | 0 | 0 | -| `Predict1` | 100 | 10 | 
linear | ☑️ | 302.9n ± 14% | 0 | 0 | -| `Predict1` | 100 | 10 | quadratic | | 283.7n ± 15% | 0 | 0 | -| `Predict1` | 100 | 10 | quadratic | ☑️ | 290.0n ± 13% | 0 | 0 | -| `Predict1` | 100 | 10 | custom | | 313.1n ± 17% | 0 | 0 | -| `Predict1` | 100 | 10 | custom | ☑️ | 316.2n ± 11% | 0 | 0 | -| `Predict1` | 100 | 100 | | ☑️ | 545.4n ± 4% | 0 | 0 | -| `Predict1` | 100 | 100 | linear | | 542.4n ± 4% | 0 | 0 | -| `Predict1` | 100 | 100 | linear | ☑️ | 577.5n ± 4% | 0 | 0 | -| `Predict1` | 100 | 100 | quadratic | | 553.1n ± 3% | 0 | 0 | -| `Predict1` | 100 | 100 | quadratic | ☑️ | 582.4n ± 6% | 0 | 0 | -| `Predict1` | 100 | 100 | custom | | 683.8n ± 4% | 0 | 0 | -| `Predict1` | 100 | 100 | custom | ☑️ | 748.5n ± 2% | 0 | 0 | -| `Predict1` | 1000 | 3 | | | 669.5n ± 6% | 0 | 0 | -| `Predict1` | 1000 | 10 | | | 930.3n ± 7% | 0 | 0 | -| `Predict1` | 1000 | 100 | | | 3.762µ ± 5% | 0 | 0 | -| `Predict1` | 1000000 | 3 | | | 532.1µ ± 1% | 0 | 0 | -| `Predict1` | 1000000 | 10 | | | 534.5µ ± 1% | 0 | 0 | -| `Predict1` | 1000000 | 100 | | | 551.7µ ± 1% | 0 | 0 | +| Op | N | k | Distance weighting | Vote values | sec / op | B/op | allocs/op | +|-----------|---------|-----|--------------------|-------------|--------------|------|-----------| +| `Predict` | 100 | 3 | | | 138.7n ± 22% | 0 | 0 | +| `Predict` | 100 | 3 | | ☑️ | 127.8n ± 11% | 0 | 0 | +| `Predict` | 100 | 3 | linear | | 137.0n ± 11% | 0 | 0 | +| `Predict` | 100 | 3 | linear | ☑️ | 136.7n ± 10% | 0 | 0 | +| `Predict` | 100 | 3 | quadratic | | 137.2n ± 7% | 0 | 0 | +| `Predict` | 100 | 3 | quadratic | ☑️ | 130.4n ± 4% | 0 | 0 | +| `Predict` | 100 | 3 | custom | | 140.6n ± 7% | 0 | 0 | +| `Predict` | 100 | 3 | custom | ☑️ | 134.9n ± 13% | 0 | 0 | +| `Predict` | 100 | 10 | | | 307.4n ± 11% | 0 | 0 | +| `Predict` | 100 | 10 | | ☑️ | 297.8n ± 15% | 0 | 0 | +| `Predict` | 100 | 10 | linear | | 288.2n ± 18% | 0 | 0 | +| `Predict` | 100 | 10 | linear | ☑️ | 302.9n ± 14% | 0 | 0 | +| `Predict` | 100 | 10 | quadratic | | 283.7n ± 
15% | 0 | 0 | +| `Predict` | 100 | 10 | quadratic | ☑️ | 290.0n ± 13% | 0 | 0 | +| `Predict` | 100 | 10 | custom | | 313.1n ± 17% | 0 | 0 | +| `Predict` | 100 | 10 | custom | ☑️ | 316.2n ± 11% | 0 | 0 | +| `Predict` | 100 | 100 | | ☑️ | 545.4n ± 4% | 0 | 0 | +| `Predict` | 100 | 100 | linear | | 542.4n ± 4% | 0 | 0 | +| `Predict` | 100 | 100 | linear | ☑️ | 577.5n ± 4% | 0 | 0 | +| `Predict` | 100 | 100 | quadratic | | 553.1n ± 3% | 0 | 0 | +| `Predict` | 100 | 100 | quadratic | ☑️ | 582.4n ± 6% | 0 | 0 | +| `Predict` | 100 | 100 | custom | | 683.8n ± 4% | 0 | 0 | +| `Predict` | 100 | 100 | custom | ☑️ | 748.5n ± 2% | 0 | 0 | +| `Predict` | 1000 | 3 | | | 669.5n ± 6% | 0 | 0 | +| `Predict` | 1000 | 10 | | | 930.3n ± 7% | 0 | 0 | +| `Predict` | 1000 | 100 | | | 3.762µ ± 5% | 0 | 0 | +| `Predict` | 1000000 | 3 | | | 532.1µ ± 1% | 0 | 0 | +| `Predict` | 1000000 | 10 | | | 534.5µ ± 1% | 0 | 0 | +| `Predict` | 1000000 | 100 | | | 551.7µ ± 1% | 0 | 0 | ## License diff --git a/example_test.go b/example_test.go index 9f337b2..d728b78 100644 --- a/example_test.go +++ b/example_test.go @@ -19,14 +19,16 @@ func Example() { votes := make([]float64, 2) k := 2 - model.Predict1(k, 0b101011, bitknn.VoteSlice(votes)) + model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) + // or, just return the nearest neighbor's distances and indices: + // distances,indices := model.Find(k, 0b101011) fmt.Println("Votes:", bitknn.VoteSlice(votes)) // you can also use a map for the votes. 
// this is good if you have a very large number of different labels: votesMap := make(map[int]float64) - model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap)) + model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) fmt.Println("Votes for 0:", votesMap[0]) // Output: // Votes: [0.5 0.25] @@ -52,7 +54,7 @@ func ExampleFitWide() { k := 2 query := pack.String("fob") - model.Predict1(k, query, bitknn.VoteSlice(votes)) + model.Predict(k, query, bitknn.VoteSlice(votes)) fmt.Println("Votes:", bitknn.VoteSlice(votes)) diff --git a/lsh/example_test.go b/lsh/example_test.go index 2f934f3..69edb4a 100644 --- a/lsh/example_test.go +++ b/lsh/example_test.go @@ -23,13 +23,13 @@ func Example() { votes := make([]float64, 2) k := 2 - model.Predict1(k, 0b101011, bitknn.VoteSlice(votes)) + model.Predict(k, 0b101011, bitknn.VoteSlice(votes)) fmt.Println("Votes:", bitknn.VoteSlice(votes)) // you can also use a map for the votes votesMap := make(map[int]float64) - model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap)) + model.Predict(k, 0b101011, bitknn.VoteMap(votesMap)) fmt.Println("Votes for 0:", votesMap[0]) // Output: // Votes: [0.5 0.25] diff --git a/lsh/model.go b/lsh/model.go index b10f0fc..e35f1fd 100644 --- a/lsh/model.go +++ b/lsh/model.go @@ -72,24 +72,42 @@ func Fit(data []uint64, labels []int, hash Hash, opts ...bitknn.Option) *Model { } } -// Predict1 predicts the label for a single input using the LSH model. -func (me *Model) Predict1(k int, x uint64, votes bitknn.VoteCounter) int { +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the pre-allocated slices. +// Returns the distance and index slices, truncated to the actual number of neighbors found. 
+func (me *Model) Find(k int, x uint64) ([]int, []int) { me.PreallocateHeap(k) - return me.Predict1Into(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) + return me.FindInto(k, x, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) +} + +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the provided slices. +// The slices should be pre-allocated to length k+1. +// Returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *Model) FindInto(k int, x uint64, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) ([]int, []int) { + xp := me.Hash.Hash1(x) + k, _ = Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) + return distances[:k], indices[:k] +} + +// Predict predicts the label for a single input using the LSH model. +func (me *Model) Predict(k int, x uint64, votes bitknn.VoteCounter) int { + me.PreallocateHeap(k) + return me.PredictInto(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) } // Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps. -func (me *Model) Predict1Alloc(k int, x uint64, votes bitknn.VoteCounter) int { +func (me *Model) PredictAlloc(k int, x uint64, votes bitknn.VoteCounter) int { bucketDistances := make([]int, k+1) bucketIDs := make([]uint64, k+1) distances := make([]int, k+1) indices := make([]int, k+1) - return me.Predict1Into(k, x, votes, bucketDistances, bucketIDs, distances, indices) + return me.PredictInto(k, x, votes, bucketDistances, bucketIDs, distances, indices) } -// Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. 
-func (me *Model) Predict1Into(k int, x uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { +// PredictInto predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. +func (me *Model) PredictInto(k int, x uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { xp := me.Hash.Hash1(x) k, n := Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) me.Vote(k, distances, indices, votes) diff --git a/lsh/model_bench_test.go b/lsh/model_bench_test.go index f5201b9..1573c40 100644 --- a/lsh/model_bench_test.go +++ b/lsh/model_bench_test.go @@ -9,7 +9,7 @@ import ( "github.com/keilerkonzept/bitknn/lsh" ) -func Benchmark_Model_Predict1(b *testing.B) { +func Benchmark_Model_Predict(b *testing.B) { type bench struct { hashes []lsh.Hash dataSize []int @@ -34,7 +34,7 @@ func Benchmark_Model_Predict1(b *testing.B) { model.PreallocateHeap(k) b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, bitknn.DiscardVotes) + model.Predict(k, query, bitknn.DiscardVotes) } }) } diff --git a/lsh/model_test.go b/lsh/model_test.go index f7690f8..15e6f3d 100644 --- a/lsh/model_test.go +++ b/lsh/model_test.go @@ -87,23 +87,35 @@ func Test_Model_NoHash_IsExact(t *testing.T) { knn.PreallocateHeap(k) ann.PreallocateHeap(k) for _, q := range queries { - knn.Predict1(k, q, bitknn.VoteSlice(knnVotes)) - - ann.Predict1(k, q, bitknn.VoteSlice(annVotes)) + knn.Predict(k, q, bitknn.VoteSlice(knnVotes)) + ann.Predict(k, q, bitknn.VoteSlice(annVotes)) slices.Sort(knn.HeapDistances[:k]) slices.Sort(ann.HeapDistances[:k]) if !reflect.DeepEqual(knn.HeapDistances[:k], ann.HeapDistances[:k]) { t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", knn.HeapDistances[:k], ann.HeapDistances[:k], knn.HeapIndices[:k], ann.HeapIndices[:k]) } - 
ann0.Predict1Alloc(k, q, bitknn.VoteSlice(annVotes)) + kd, ki := knn.Find(k, q) + ad, ai := ann.Find(k, q) + slices.Sort(kd) + slices.Sort(ad) + if !reflect.DeepEqual(kd, ad) { + t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", kd, ad) + } + slices.Sort(ki) + slices.Sort(ai) + if !reflect.DeepEqual(ki, ai) { + t.Fatal("NoHash ANN should result in the same indices for the nearest neighbors: ", ki, ai) + } + + ann0.PredictAlloc(k, q, bitknn.VoteSlice(annVotes)) for i, vk := range knnVotes { va := annVotes[i] if math.Abs(vk-va) > eps { t.Fatalf("ANN: %s: %v: %v %v", pair.name, q, knnVotes, annVotes) } } - ann0.Predict1(k, q, bitknn.VoteSlice(annVotes)) + ann0.Predict(k, q, bitknn.VoteSlice(annVotes)) for i, vk := range knnVotes { va := annVotes[i] if math.Abs(vk-va) > eps { diff --git a/lsh/model_wide.go b/lsh/model_wide.go index 714d269..4092fb2 100644 --- a/lsh/model_wide.go +++ b/lsh/model_wide.go @@ -65,24 +65,42 @@ func FitWide(data [][]uint64, labels []int, hash HashWide, opts ...bitknn.Option } } -// Predict1 predicts the label for a single input using the LSH model. -func (me *WideModel) Predict1(k int, x []uint64, votes bitknn.VoteCounter) int { +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the pre-allocated slices. +// Returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *WideModel) Find(k int, x []uint64) ([]int, []int) { me.PreallocateHeap(k) - return me.Predict1Into(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices) + return me.FindInto(k, x, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices) +} + +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the provided slices. +// The slices should be pre-allocated to length k+1. 
+// Returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *WideModel) FindInto(k int, x []uint64, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) ([]int, []int) { + xp := me.Hash.Hash1Wide(x) + k, _ = NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) + return distances[:k], indices[:k] +} + +// Predict predicts the label for a single input using the LSH model. +func (me *WideModel) Predict(k int, x []uint64, votes bitknn.VoteCounter) int { + me.PreallocateHeap(k) + return me.PredictInto(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.Narrow.HeapDistances, me.Narrow.HeapIndices) } // Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps. -func (me *WideModel) Predict1Alloc(k int, x []uint64, votes bitknn.VoteCounter) int { +func (me *WideModel) PredictAlloc(k int, x []uint64, votes bitknn.VoteCounter) int { bucketDistances := make([]int, k+1) bucketIDs := make([]uint64, k+1) distances := make([]int, k+1) indices := make([]int, k+1) - return me.Predict1Into(k, x, votes, bucketDistances, bucketIDs, distances, indices) + return me.PredictInto(k, x, votes, bucketDistances, bucketIDs, distances, indices) } -// Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. -func (me *WideModel) Predict1Into(k int, x []uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { +// PredictInto predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. 
+func (me *WideModel) PredictInto(k int, x []uint64, votes bitknn.VoteCounter, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { xp := me.Hash.Hash1Wide(x) k0, _ := NearestWide(me.WideData, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) me.WideModel.Narrow.Vote(k0, distances, indices, votes) diff --git a/lsh/model_wide_test.go b/lsh/model_wide_test.go index 1dbbe93..c98ae4e 100644 --- a/lsh/model_wide_test.go +++ b/lsh/model_wide_test.go @@ -44,14 +44,20 @@ func Test_WideModel_64bit_Equal_To_Narrow(t *testing.T) { narrow.PreallocateHeap(k) wide.PreallocateHeap(k) for _, q := range queries { - narrow.Predict1(k, q, bitknn.VoteSlice(narrowVotes)) - wide.Predict1(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) - slices.Sort(narrow.HeapDistances[:k]) - slices.Sort(wide.Narrow.HeapDistances[:k]) + nd, ni := narrow.Find(k, q) + wd, wi := wide.Find(k, []uint64{q}) + if !reflect.DeepEqual(nd, wd) { + t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", nd, wd) + } + if !reflect.DeepEqual(ni, wi) { + t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", ni, wi) + } + narrow.Predict(k, q, bitknn.VoteSlice(narrowVotes)) + wide.Predict(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) } - if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { + if !reflect.DeepEqual(narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) { t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) } for i, vk := range narrowVotes { @@ -60,13 +66,15 @@ func Test_WideModel_64bit_Equal_To_Narrow(t *testing.T) { t.Fatalf("%s: %v: %v %v", pair.name, q, narrowVotes, 
wideVotes) } } - wide.Predict1Alloc(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) + wide.PredictAlloc(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) slices.Sort(narrow.HeapDistances[:k]) slices.Sort(wide.Narrow.HeapDistances[:k]) if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { t.Fatal("Wide KNN should result in the same distances for the nearest neighbors: ", narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) } - if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) { + slices.Sort(narrow.HeapIndices[:k]) + slices.Sort(wide.Narrow.HeapIndices[:k]) + if !reflect.DeepEqual(narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) { t.Fatal("Wide ANN should result in the same indices for the nearest neighbors: ", narrow.HeapIndices[:k], wide.Narrow.HeapIndices[:k]) } for i, vk := range narrowVotes { diff --git a/model.go b/model.go index fc0096e..d6dbec2 100644 --- a/model.go +++ b/model.go @@ -42,21 +42,38 @@ func (me *Model) PreallocateHeap(k int) { me.HeapIndices = slice.OrAlloc(me.HeapIndices, k+1) } +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the pre-allocated slices. +// Returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *Model) Find(k int, x uint64) ([]int, []int) { + me.PreallocateHeap(k) + return me.FindInto(k, x, me.HeapDistances, me.HeapIndices) +} + +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the provided slices. +// The slices should be pre-allocated to length k+1. +// Returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *Model) FindInto(k int, x uint64, distances []int, indices []int) ([]int, []int) { + k = Nearest(me.Data, k, x, distances, indices) + return distances[:k], indices[:k] +} + // Predicts the label of a single input point. 
Each call allocates two new slices of length K+1 for the neighbor heap. -func (me *Model) Predict1Alloc(k int, x uint64, votes VoteCounter) { +func (me *Model) PredictAlloc(k int, x uint64, votes VoteCounter) { distances, indices := make([]int, k+1), make([]int, k+1) - me.Predict1Into(k, x, distances, indices, votes) + me.PredictInto(k, x, distances, indices, votes) } // Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap. -func (me *Model) Predict1(k int, x uint64, votes VoteCounter) { +func (me *Model) Predict(k int, x uint64, votes VoteCounter) { me.HeapDistances = slice.OrAlloc(me.HeapDistances, k+1) me.HeapIndices = slice.OrAlloc(me.HeapIndices, k+1) - me.Predict1Into(k, x, me.HeapDistances, me.HeapIndices, votes) + me.PredictInto(k, x, me.HeapDistances, me.HeapIndices, votes) } // Predicts the label of a single input point, using the given slices for the neighbor heap. -func (me *Model) Predict1Into(k int, x uint64, distances []int, indices []int, votes VoteCounter) { +func (me *Model) PredictInto(k int, x uint64, distances []int, indices []int, votes VoteCounter) { k = Nearest(me.Data, k, x, distances, indices) me.Vote(k, distances, indices, votes) } diff --git a/model_bench_test.go b/model_bench_test.go index 10f21b4..03cf59b 100644 --- a/model_bench_test.go +++ b/model_bench_test.go @@ -9,7 +9,7 @@ import ( "github.com/keilerkonzept/bitknn/internal/testrandom" ) -func Benchmark_Model_Predict1(b *testing.B) { +func Benchmark_Model_Predict(b *testing.B) { type bench struct { dataSize []int k []int @@ -30,7 +30,7 @@ func Benchmark_Model_Predict1(b *testing.B) { model.PreallocateHeap(k) b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, bitknn.DiscardVotes) + model.Predict(k, query, bitknn.DiscardVotes) } }) } @@ -38,7 +38,7 @@ func Benchmark_Model_Predict1(b *testing.B) { } } -func Benchmark_Model_Predict1V(b *testing.B) { +func Benchmark_Model_PredictV(b *testing.B) { votes := make([]float64, 
256) type bench struct { dataSize []int @@ -61,7 +61,7 @@ func Benchmark_Model_Predict1V(b *testing.B) { voteSlice := bitknn.VoteSlice(votes) b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, &voteSlice) + model.Predict(k, query, &voteSlice) } }) } @@ -69,7 +69,7 @@ func Benchmark_Model_Predict1V(b *testing.B) { } } -func Benchmark_Model_Predict1D(b *testing.B) { +func Benchmark_Model_PredictD(b *testing.B) { votes := make([]float64, 256) type bench struct { dataSize []int @@ -93,7 +93,7 @@ func Benchmark_Model_Predict1D(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, &voteSlice) + model.Predict(k, query, &voteSlice) } }) } @@ -102,7 +102,7 @@ func Benchmark_Model_Predict1D(b *testing.B) { } } -func Benchmark_Model_Predict1DV(b *testing.B) { +func Benchmark_Model_PredictDV(b *testing.B) { votes := make([]float64, 256) type bench struct { dataSize []int @@ -127,7 +127,7 @@ func Benchmark_Model_Predict1DV(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, &voteSlice) + model.Predict(k, query, &voteSlice) } }) } diff --git a/model_test.go b/model_test.go index f373a83..8ce86d6 100644 --- a/model_test.go +++ b/model_test.go @@ -23,7 +23,7 @@ func Test_DistanceWeighting_String(t *testing.T) { } } -func Test_Model_Predict1_Predict1Realloc(t *testing.T) { +func Test_Model_Predict_PredictRealloc(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} k := 2 @@ -34,7 +34,7 @@ func Test_Model_Predict1_Predict1Realloc(t *testing.T) { votes := make([]float64, k) model.PreallocateHeap(k) { - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{1, 1} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -42,7 +42,7 @@ func Test_Model_Predict1_Predict1Realloc(t *testing.T) { } } { - model.Predict1Alloc(k, x, bitknn.VoteSlice(votes)) + model.PredictAlloc(k, x, bitknn.VoteSlice(votes)) 
expectedVotes := []float64{1, 1} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -51,7 +51,7 @@ func Test_Model_Predict1_Predict1Realloc(t *testing.T) { } } -func Test_Model_Reslice_Predict1(t *testing.T) { +func Test_Model_Reslice_Predict(t *testing.T) { data := []uint64{0b0000, 0b11111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} values := []float64{1.0, 2.0, 3.0, 4.0} @@ -61,21 +61,21 @@ func Test_Model_Reslice_Predict1(t *testing.T) { x := uint64(0b0010) votes := make([]float64, 2) model.PreallocateHeap(3) - model.Predict1(2, x, bitknn.VoteSlice(votes)) + model.Predict(2, x, bitknn.VoteSlice(votes)) { expectedVotes := []float64{1, 3} if diff := cmp.Diff(expectedVotes, votes); diff != "" { t.Error(diff) } } - model.Predict1(3, x, bitknn.VoteSlice(votes)) + model.Predict(3, x, bitknn.VoteSlice(votes)) { expectedVotes := []float64{5, 3} if diff := cmp.Diff(expectedVotes, votes); diff != "" { t.Error(diff) } } - model.Predict1(2, x, bitknn.VoteSlice(votes)) + model.Predict(2, x, bitknn.VoteSlice(votes)) { expectedVotes := []float64{1, 3} @@ -85,7 +85,7 @@ func Test_Model_Reslice_Predict1(t *testing.T) { } { - model.Predict1(10, x, bitknn.VoteSlice(votes)) + model.Predict(10, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{5, 5} if diff := cmp.Diff(expectedVotes, votes); diff != "" { t.Error(diff) @@ -93,7 +93,7 @@ func Test_Model_Reslice_Predict1(t *testing.T) { } } -func Test_Model_Predict1V(t *testing.T) { +func Test_Model_PredictV(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} values := []float64{1.0, 2.0, 3.0, 4.0} @@ -104,7 +104,7 @@ func Test_Model_Predict1V(t *testing.T) { x := uint64(0b0010) votes := make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{1, 3} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -112,7 +112,7 @@ func Test_Model_Predict1V(t *testing.T) { } } -func 
Test_Model_Predict1D(t *testing.T) { +func Test_Model_PredictD(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} k := 3 @@ -122,7 +122,7 @@ func Test_Model_Predict1D(t *testing.T) { x := uint64(0b0001) votes := make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{1, 0.5} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -130,7 +130,7 @@ func Test_Model_Predict1D(t *testing.T) { } } -func Test_Model_Predict1VL(t *testing.T) { +func Test_Model_PredictVL(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} values := []float64{1.0, 2.0, 3.0, 3.0} @@ -141,7 +141,7 @@ func Test_Model_Predict1VL(t *testing.T) { x := uint64(0b0000) votes := make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{2, 1} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -149,7 +149,7 @@ func Test_Model_Predict1VL(t *testing.T) { } } -func Test_Model_Predict1VQ(t *testing.T) { +func Test_Model_PredictVQ(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} values := []float64{1.0, 2.0, 4.0, 5.0} @@ -159,7 +159,7 @@ func Test_Model_Predict1VQ(t *testing.T) { x := uint64(0b0000) votes := make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{2, 0.8} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -167,7 +167,7 @@ func Test_Model_Predict1VQ(t *testing.T) { } } -func Test_Model_Predict1Q(t *testing.T) { +func Test_Model_PredictQ(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} k := 3 @@ -176,7 +176,7 @@ func Test_Model_Predict1Q(t *testing.T) { x := uint64(0b0000) votes := 
make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{1.2, 0.2} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -184,7 +184,7 @@ func Test_Model_Predict1Q(t *testing.T) { } } -func Test_Model_Predict1VC(t *testing.T) { +func Test_Model_PredictVC(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 1, 0} values := []float64{1.0, 2.0, 3.0, 3.0} @@ -202,7 +202,7 @@ func Test_Model_Predict1VC(t *testing.T) { x := uint64(0b0000) votes := make([]float64, 2) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{4, 3} if diff := cmp.Diff(expectedVotes, votes); diff != "" { @@ -210,7 +210,7 @@ func Test_Model_Predict1VC(t *testing.T) { } } -func Test_Model_Predict1C(t *testing.T) { +func Test_Model_PredictC(t *testing.T) { data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} labels := []int{0, 1, 2, 0} k := 3 @@ -227,7 +227,7 @@ func Test_Model_Predict1C(t *testing.T) { x := uint64(0b0000) votes := make([]float64, 3) model.PreallocateHeap(k) - model.Predict1(k, x, bitknn.VoteSlice(votes)) + model.Predict(k, x, bitknn.VoteSlice(votes)) expectedVotes := []float64{2, 0, 1} if diff := cmp.Diff(expectedVotes, votes); diff != "" { diff --git a/model_wide.go b/model_wide.go index 18dd6ad..1e93520 100644 --- a/model_wide.go +++ b/model_wide.go @@ -21,16 +21,33 @@ func (me *WideModel) PreallocateHeap(k int) { me.Narrow.PreallocateHeap(k) } +// Finds the nearest neighbors of the given point. +// Writes their distances and indices in the dataset into the pre-allocated slices. +// Returns the distance and index slices, truncated to the actual number of neighbors found. 
+func (me *WideModel) Find(k int, x []uint64) ([]int, []int) { + me.Narrow.PreallocateHeap(k) + return me.FindInto(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices) +} + +// FindInto finds the nearest neighbors of the given point. +// It writes their distances and indices in the dataset into the provided slices. +// The slices should be pre-allocated to length k+1. +// It returns the distance and index slices, truncated to the actual number of neighbors found. +func (me *WideModel) FindInto(k int, x []uint64, distances []int, indices []int) ([]int, []int) { + k = NearestWide(me.WideData, k, x, distances, indices) + return distances[:k], indices[:k] +} + // Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap. // Returns the number of neighbors found. 
-func (me *WideModel) Predict1Into(k int, x []uint64, distances []int, indices []int, votes VoteCounter) int { +func (me *WideModel) PredictInto(k int, x []uint64, distances []int, indices []int, votes VoteCounter) int { k = NearestWide(me.WideData, k, x, distances, indices) me.Narrow.Vote(k, distances, indices, votes) return k diff --git a/model_wide_bench_test.go b/model_wide_bench_test.go index 568b4e6..9f0d9d0 100644 --- a/model_wide_bench_test.go +++ b/model_wide_bench_test.go @@ -8,7 +8,7 @@ import ( "github.com/keilerkonzept/bitknn/internal/testrandom" ) -func Benchmark_WideModel_Predict1(b *testing.B) { +func Benchmark_WideModel_Predict(b *testing.B) { type bench struct { dim []int dataSize []int @@ -31,7 +31,7 @@ func Benchmark_WideModel_Predict1(b *testing.B) { model.PreallocateHeap(k) b.ResetTimer() for n := 0; n < b.N; n++ { - model.Predict1(k, query, bitknn.DiscardVotes) + model.Predict(k, query, bitknn.DiscardVotes) } }) } diff --git a/model_wide_test.go b/model_wide_test.go index f79a7ab..4568e1b 100644 --- a/model_wide_test.go +++ b/model_wide_test.go @@ -42,8 +42,16 @@ func Test_Model_64bit_Equal_To_Narrow(t *testing.T) { narrow.PreallocateHeap(k) wide.PreallocateHeap(k) for _, q := range queries { - narrow.Predict1(k, q, bitknn.VoteSlice(narrowVotes)) - wide.Predict1(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) + nd, ni := narrow.Find(k, q) + wd, wi := wide.Find(k, []uint64{q}) + if !reflect.DeepEqual(nd, wd) { + t.Fatal("Wide model should result in the same distances for the nearest neighbors as the narrow model: ", nd, wd) + } + if !reflect.DeepEqual(ni, wi) { + t.Fatal("Wide model should result in the same indices for the nearest neighbors as the narrow model: ", ni, wi) + } + narrow.Predict(k, q, bitknn.VoteSlice(narrowVotes)) + wide.Predict(k, []uint64{q}, bitknn.VoteSlice(wideVotes)) slices.Sort(narrow.HeapDistances[:k]) slices.Sort(wide.Narrow.HeapDistances[:k]) if !reflect.DeepEqual(narrow.HeapDistances[:k], wide.Narrow.HeapDistances[:k]) {