From 8055df1468cfcdee07637f672d6a91cf1022463f Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Wed, 9 Oct 2024 15:59:35 +0200 Subject: [PATCH] lsh test,fixes --- .gitignore | 1 + internal/heap/heap.go | 4 ++ internal/heap/heap_test.go | 4 ++ lsh/hashes.go | 82 +++++++++++---------- lsh/hashes_test.go | 144 +++++++++++++++++++------------------ lsh/model.go | 28 +++++--- lsh/model_bench_test.go | 18 ++--- lsh/model_test.go | 75 +++++++++++++------ lsh/nearest.go | 86 ++++++++++++++++------ lsh/nearest_test.go | 11 +-- model.go | 4 +- nearest.go | 5 +- 12 files changed, 287 insertions(+), 175 deletions(-) diff --git a/.gitignore b/.gitignore index e985e57..76b137f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.out *.txt +*.test testdata diff --git a/internal/heap/heap.go b/internal/heap/heap.go index 9268bfc..d85bc94 100644 --- a/internal/heap/heap.go +++ b/internal/heap/heap.go @@ -22,6 +22,10 @@ func MakeMax[T int | uint64](distances []int, value []T) Max[T] { } } +func (me *Max[T]) Len() int { + return me.len +} + func (me *Max[T]) swap(i, j int) { me.distances[i], me.distances[j] = me.distances[j], me.distances[i] me.values[i], me.values[j] = me.values[j], me.values[i] diff --git a/internal/heap/heap_test.go b/internal/heap/heap_test.go index e2ba076..fd2e33f 100644 --- a/internal/heap/heap_test.go +++ b/internal/heap/heap_test.go @@ -57,6 +57,10 @@ func TestNeighborHeapPushPop(t *testing.T) { heap.PushPop(25, 4) + if heap.Len() != 3 { + t.Error("Expected length not to change") + } + // Check if heap is reordered correctly expectedDistances := []int{25, 20, 10, 30, diff --git a/lsh/hashes.go b/lsh/hashes.go index a53a459..d6c9b0a 100644 --- a/lsh/hashes.go +++ b/lsh/hashes.go @@ -24,6 +24,24 @@ func (me HashFunc) Hash(data []uint64, out []uint64) { } } +type HashCompose []Hash + +// Hash1 applies the function to a single uint64 value. +func (me HashCompose) Hash1(x uint64) uint64 { + for _, h := range me { + x = h.Hash1(x) + } + return x +} + +// Hash applies the function to a slice of uint64 values. +func (me HashCompose) Hash(data []uint64, out []uint64) { + for _, h := range me { + h.Hash(data, out) + data = out + } +} + // NoHash is the identity function. Used as a dummy [Hash] for testing. type NoHash struct{} @@ -35,6 +53,17 @@ func (me NoHash) Hash(data []uint64, out []uint64) { copy(out, data) } +// ConstantHash is a constant 0 function. Used as a dummy [Hash] for testing. +type ConstantHash struct{} + +// Hash1 returns the given value. +func (me ConstantHash) Hash1(x uint64) uint64 { return 0 } + +// Hash copies the input slice to the output slice. +func (me ConstantHash) Hash(data []uint64, out []uint64) { + clear(out) +} + // MinHashes is a concatenation of [MinHash]es type MinHashes []MinHash @@ -117,47 +146,12 @@ func (me MinHash) Hash(data []uint64, out []uint64) { for j, m := range me { if (d & m) != 0 { out[i] = uint64(j) + break } } } } -var boxBlur3LUT = [8]uint64{ - 0, // 0b000, - 0, // 0b001, - 0, // 0b010, - 1, // 0b011, - 0, // 0b100, - 1, // 0b101, - 1, // 0b110, - 1, // 0b111, -} - -func boxBlur3(x uint64) uint64 { - var b uint64 - b = boxBlur3LUT[x&0b11] - for i := range 61 { - b |= boxBlur3LUT[x&0b111] << (i + 1) - x >>= 1 - } - return b -} - -// BoxBlur3 hashes values by applying a box blur with radius 3 (each bit in the output is the average of the 3 neighboring bits in the input) -type BoxBlur3 struct{} - -// Hash1 hashes a single uint64 value. 
-func (me BoxBlur3) Hash1(x uint64) uint64 { - return boxBlur3(x) -} - -// Hash hashes a slice of uint64 values. -func (me BoxBlur3) Hash(data []uint64, out []uint64) { - for i, d := range data { - out[i] = boxBlur3(d) - } -} - // Blur hashes values based on thresholding the number of bits in common with the given bitmasks. // For bitmasks of consecutive set bits, this is in effect a "blur" of the bit vector. type Blur struct { @@ -254,3 +248,19 @@ func RandomBitSampleR(numBitsSet int, rand *rand.Rand) BitSample { } return BitSample(out) } + +// BoxBlur generates a Blur that averages groups of neighboring bits for each bit in the output. +func BoxBlur(radius int, step int) Blur { + mask := uint64(1< dz { + zCloser++ + } + } + + if zCloser > yCloser { + t.Errorf("Expected Hash1(x) to be closer to Hash1(y) more often than Hash1(x) to be closer to Hash1(z), got %d and %d", yCloser, zCloser) + } + }) + t.Run("Blur_Hamming_LS_Property", func(t *testing.T) { x := uint64(0b1110) y := uint64(0b1100) @@ -171,7 +209,7 @@ func TestBlur(t *testing.T) { xyEqual := 0 xzEqual := 0 - trials := 1000 + trials := 10_000 for range trials { h := lsh.RandomBlurR(3, 10, testrandom.Source) @@ -350,7 +388,7 @@ func TestMinHashes(t *testing.T) { xyEqual := 0 xzEqual := 0 - trials := 1000 + trials := 10_000 for range trials { h := lsh.RandomMinHashesR(3, testrandom.Source) @@ -390,67 +428,33 @@ func TestHashFunc(t *testing.T) { } } -func TestBoxBlur3(t *testing.T) { - t.Run("BoxBlur3_Hash1", func(t *testing.T) { - var h lsh.BoxBlur3 - - testCases := []struct { - input uint64 - want uint64 - }{ - {0xF0F0F0F0, 0xF0F0F0F0}, - {0x0F0F0F0F, 0x0F0F0F0F}, - { - 0b11110010111100101111001011110010, - 0b11110001111100011111000111110000, - }, +func TestDummyHashes(t *testing.T) { + t.Run("NoHash", func(t *testing.T) { + var h lsh.NoHash + query := uint64(0x12345) + data := []uint64{0x12345, 0x54321} + out := make([]uint64, len(data)) + if h.Hash1(query) != query { + t.Fatal() } - - for _, tc := range testCases { - got := h.Hash1(tc.input) - if got != tc.want { - t.Errorf("BoxBlur3.Hash1(%x) = %x; want %x", tc.input, got, tc.want) - } + h.Hash(data, out) + if !reflect.DeepEqual(data, out) { + t.Fatal() } }) - - t.Run("BoxBlur3_Hash", func(t *testing.T) { - var h lsh.BoxBlur3 - - input := []uint64{0xF0F0F0F0, 0x0F0F0F0F, 0x72F2F2F2} - output := make([]uint64, len(input)) - want := []uint64{0xF0F0F0F0, 0x0F0F0F0F, 0x71F1F1F0} - - h.Hash(input, output) - - for i, v := range output { - if v != want[i] { - t.Errorf("BoxBlur3.Hash() for input %x = %x; want %x", input[i], v, want[i]) - } + t.Run("ConstantHash", func(t *testing.T) { + var h lsh.ConstantHash + q := uint64(0x12345) + data := []uint64{0x12345, 0x54321} + out := make([]uint64, len(data)) + if h.Hash1(q) != 0 { + t.Fatal() } - }) - - t.Run("BoxBlur3_Hamming_LS_Property", func(t *testing.T) { - xyEqual := 0 - xzEqual := 0 - trials := 1000 - var h lsh.BoxBlur3 - for range trials { - flip3Bits := uint64(lsh.RandomBitSampleR(3, testrandom.Source)) - flip10Bits := uint64(lsh.RandomBitSampleR(10, testrandom.Source)) - x := testrandom.Query() - y := x ^ flip3Bits - z := x ^ flip10Bits - if h.Hash1(x) == h.Hash1(y) { - xyEqual++ + h.Hash(data, out) + for i := range out { + if out[i] != 0 { + t.Fatal() } - if h.Hash1(x) == h.Hash1(z) { - xzEqual++ - } - } - - if xyEqual <= xzEqual { - t.Errorf("Expected Hash1(x) to equal Hash1(y) more often than Hash1(x) to equal Hash1(z), got %d and %d", xyEqual, xzEqual) } }) } diff --git a/lsh/model.go b/lsh/model.go index d4a1bdc..d4c0b7b 100644 
--- a/lsh/model.go +++ b/lsh/model.go @@ -13,16 +13,19 @@ type Model struct { *bitknn.Model Hash Hash // LSH function mapping points to bucket IDs. - BucketIDs []uint64 // Bucket IDs. - Buckets map[uint64]slice.IndexRange // Bucket contents for each hash (offset+length in Data). - HeapBucketIDs []uint64 + BucketIDs []uint64 // Bucket IDs. + Buckets map[uint64]slice.IndexRange // Bucket contents for each hash (offset+length in Data). + + HeapBucketDistances []int + HeapBucketIDs []uint64 } // PreallocateHeap allocates memory for the nearest neighbor heap. func (me *Model) PreallocateHeap(k int) { + me.HeapBucketDistances = slice.OrAlloc(me.HeapBucketDistances, k+1) + me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) me.HeapDistances = slice.OrAlloc(me.HeapDistances, k+1) me.HeapIndices = slice.OrAlloc(me.HeapIndices, k+1) - me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) } // Fit creates and fits an LSH k-NN model using the provided data, labels, and hash function. @@ -65,22 +68,27 @@ func Fit(data []uint64, labels []int, hash Hash, opts ...bitknn.Option) *Model { // Predict1 predicts the label for a single input using the LSH model. func (me *Model) Predict1(k int, x uint64, votes []float64) int { + me.HeapBucketDistances = slice.OrAlloc(me.HeapBucketDistances, k+1) + me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) me.HeapDistances = slice.OrAlloc(me.HeapDistances, k+1) me.HeapIndices = slice.OrAlloc(me.HeapIndices, k+1) - me.HeapBucketIDs = slice.OrAlloc(me.HeapBucketIDs, k+1) - return me.Predict1Into(k, x, votes, me.HeapDistances, me.HeapBucketIDs, me.HeapIndices) + return me.Predict1Into(k, x, votes, me.HeapBucketDistances, me.HeapBucketIDs, me.HeapDistances, me.HeapIndices) } // Predicts the label of a single input point. Each call allocates three new slices of length [k]+1 for the neighbor heaps. func (me *Model) Predict1Alloc(k int, x uint64, votes []float64) int { - distances, indices, bucketIDs := make([]int, k+1), make([]int, k+1), make([]uint64, k+1) - return me.Predict1Into(k, x, votes, distances, bucketIDs, indices) + bucketDistances := make([]int, k+1) + bucketIDs := make([]uint64, k+1) + distances := make([]int, k+1) + indices := make([]int, k+1) + + return me.Predict1Into(k, x, votes, bucketDistances, bucketIDs, distances, indices) } // Predict1Into predicts the label for a single input using the given slices (of length [k]+1 each) for the neighbor heaps. 
-func (me *Model) Predict1Into(k int, x uint64, votes []float64, distances []int, bucketIDs []uint64, indices []int) int { +func (me *Model) Predict1Into(k int, x uint64, votes []float64, bucketDistances []int, bucketIDs []uint64, distances []int, indices []int) int { xp := me.Hash.Hash1(x) - k, n := Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, distances, bucketIDs, indices) + k, n := Nearest(me.Data, me.BucketIDs, me.Buckets, k, xp, x, bucketDistances, bucketIDs, distances, indices) clear(votes) switch me.DistanceWeighting { diff --git a/lsh/model_bench_test.go b/lsh/model_bench_test.go index 3a6356c..b5d9449 100644 --- a/lsh/model_bench_test.go +++ b/lsh/model_bench_test.go @@ -15,17 +15,19 @@ func Benchmark_Model_Predict1(b *testing.B) { dataSize []int k []int } + mh := lsh.RandomMinHashR(testrandom.Source) + rbh := lsh.RandomBlurR(3, 20, testrandom.Source) hashes := []lsh.Hash{ lsh.NoHash{}, - lsh.BoxBlur3{}, - lsh.RandomBitSampleR(48, testrandom.Source), - lsh.RandomBlurR(3, 10, testrandom.Source), - lsh.RandomMinHashR(testrandom.Source), - lsh.RandomMinHashesR(3, testrandom.Source), + lsh.HashCompose{rbh, mh}, + lsh.RandomBitSampleR(30, testrandom.Source), + rbh, + mh, + lsh.RandomMinHashesR(10, testrandom.Source), } benches := []bench{ - {hashes: hashes, dataSize: []int{100}, k: []int{3, 10}}, - {hashes: hashes, dataSize: []int{1_000_000}, k: []int{1, 2, 3, 10, 100}}, + // {hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}}, + {hashes: hashes, dataSize: []int{1_000_000}, k: []int{1, 1000}}, } for _, bench := range benches { for _, dataSize := range bench.dataSize { @@ -35,9 +37,7 @@ func Benchmark_Model_Predict1(b *testing.B) { for _, k := range bench.k { for _, hash := range bench.hashes { b.Run(fmt.Sprintf("hash=%T_N=%d_k=%d", hash, dataSize, k), func(b *testing.B) { - model := lsh.Fit(data, labels, hash) - model.PreallocateHeap(k) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/lsh/model_test.go b/lsh/model_test.go index b9b9127..6f72ae1 100644 --- a/lsh/model_test.go +++ b/lsh/model_test.go @@ -1,7 +1,7 @@ package lsh_test import ( - "reflect" + "math" "testing" "github.com/keilerkonzept/bitknn" @@ -11,67 +11,102 @@ import ( func Test_Model_NoHash_IsExact(t *testing.T) { var h lsh.NoHash + _ = h + var h0 lsh.ConstantHash rapid.Check(t, func(t *rapid.T) { - k := rapid.IntRange(1, 100).Draw(t, "k") - data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 64, func(a uint64) uint64 { return a }).Draw(t, "data") + k := rapid.IntRange(1, 1001).Draw(t, "k") + data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, func(a uint64) uint64 { return a }).Draw(t, "data") labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels") values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values") queries := rapid.SliceOfN(rapid.Uint64(), 3, 64).Draw(t, "queries") knnVotes := make([]float64, 4) annVotes := make([]float64, 4) type pair struct { - KNN *bitknn.Model - ANN *lsh.Model + name string + KNN *bitknn.Model + ANN *lsh.Model + ANN0 *lsh.Model } pairs := []pair{ { + "V", bitknn.Fit(data, labels, bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithValues(values)), }, { + "LV", bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, 
bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), }, { + "QV", bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), }, { + "CV", bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), - lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), + lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), }, { + "0", bitknn.Fit(data, labels), - lsh.Fit(data, labels, h), + lsh.Fit(data, labels, h0), + lsh.Fit(data, labels, h0), }, { + "L", bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting()), - lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting()), + lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()), + lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()), }, { + "Q", bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting()), - lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting()), + lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()), + lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()), }, { + "C", bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), - lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), + lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), + lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), }, } for _, pair := range pairs { knn := pair.KNN ann := pair.ANN + ann0 := pair.ANN0 + knn.PreallocateHeap(k) + ann.PreallocateHeap(k) for _, q := range queries { - knn.PreallocateHeap(k) knn.Predict1(k, q, knnVotes) - ann.PreallocateHeap(k) + ann.Predict1(k, q, annVotes) - if !reflect.DeepEqual(knnVotes, annVotes) { - t.Fatalf("%v %v", knnVotes, annVotes) + const eps = 1e-8 + for i, vk := range knnVotes { + va := annVotes[i] + if math.Abs(vk-va) > eps { + t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) + } } - knn.Predict1Alloc(k, q, knnVotes) ann.Predict1Alloc(k, q, annVotes) - if !reflect.DeepEqual(knnVotes, annVotes) { - t.Fatalf("%v %v", knnVotes, annVotes) + for i, vk := range knnVotes { + va := annVotes[i] + if math.Abs(vk-va) > eps { + t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) + } + } + ann0.Predict1(k, q, annVotes) + for i, vk := range knnVotes { + va := annVotes[i] + if math.Abs(vk-va) > eps { + t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) + } } } } diff --git a/lsh/nearest.go b/lsh/nearest.go index 5c44af3..ee9c835 100644 --- a/lsh/nearest.go +++ b/lsh/nearest.go @@ -21,20 +21,66 @@ import ( // Returns: // - The number of nearest neighbors found. // - The total number of data points examined. 
-func Nearest(data []uint64, bucketIDs []uint64, buckets map[uint64]slice.IndexRange, k int, xh uint64, x uint64, distances []int, heapBucketIDs []uint64, indices []int) (int, int) { - k0 := nearestBuckets(bucketIDs, k, xh, distances, heapBucketIDs) - k1, n := nearestInBuckets(data, heapBucketIDs[:k0], buckets, k, x, distances, indices) - return k1, n +func Nearest(data []uint64, bucketIDs []uint64, buckets map[uint64]slice.IndexRange, k int, xh uint64, x uint64, bucketDistances []int, heapBucketIDs []uint64, distances []int, indices []int) (int, int) { + dataHeap := heap.MakeMax[int](distances, indices) + exactBucket := buckets[xh] + numExamined := exactBucket.Length + nearestInBucket(data, exactBucket, k, x, &distances[0], &dataHeap) + + // if the exact bucket already contains k neighbors, stop and return them + if dataHeap.Len() == k { + return k, exactBucket.Length + } + + // otherwise, determine the k nearest buckets and find the k nearest neighbors in these buckets. + bucketHeap := heap.MakeMax[uint64](bucketDistances, heapBucketIDs) + nearestBuckets(bucketIDs, k, xh, &bucketDistances[0], &bucketHeap) + n := nearestInBuckets(data, heapBucketIDs[:bucketHeap.Len()], buckets, k, x, xh, &distances[0], &dataHeap) + + return dataHeap.Len(), numExamined + n +} + +func nearestInBucket(data []uint64, b slice.IndexRange, k int, x uint64, distance0 *int, heap *heap.Max[int]) { + if b.Length == 0 { + return + } + + end := b.Offset + b.Length + end0 := b.Offset + min(b.Length, k) + + for i := b.Offset; i < end0; i++ { + dist := bits.OnesCount64(x ^ data[i]) + heap.Push(dist, i) + } + + if b.Length < k { + return + } + + maxDist := *distance0 + for i := b.Offset + k; i < end; i++ { + dist := bits.OnesCount64(x ^ data[i]) + if dist >= maxDist { + continue + } + heap.PushPop(dist, i) + maxDist = *distance0 + } } // nearestInBuckets finds the nearest neighbors within specific buckets. -// It returns the number of neighbors found and the total number of points examined. -func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slice.IndexRange, k int, x uint64, distances []int, indices []int) (int, int) { - heap := heap.MakeMax[int](distances, indices) +// Returns the number of points examined. +func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slice.IndexRange, k int, x, xh uint64, distance0 *int, heap *heap.Max[int]) int { var maxDist int - j := 0 + j := heap.Len() + if j > 0 { + maxDist = *distance0 + } t := 0 for _, bid := range inBuckets { + if bid == xh { // skip exact bucket + continue + } b := buckets[bid] end := b.Offset + b.Length t += b.Length @@ -45,7 +91,7 @@ func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slic continue } heap.PushPop(dist, i) - maxDist = distances[0] + maxDist = *distance0 } continue } @@ -53,7 +99,7 @@ func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slic dist := bits.OnesCount64(x ^ data[i]) if j < k { heap.Push(dist, i) - maxDist = distances[0] + maxDist = *distance0 j++ continue } @@ -61,20 +107,14 @@ func nearestInBuckets(data []uint64, inBuckets []uint64, buckets map[uint64]slic continue } heap.PushPop(dist, i) - maxDist = distances[0] + maxDist = *distance0 } } - if j < k { - return j, t - } - return k, t + return t } // nearestBuckets finds the buckets with IDs that are (Hamming-)nearest to a query point hash. -// It returns the number of nearest buckets found. 
-func nearestBuckets(bucketIDs []uint64, k int, x uint64, distances []int, heapBucketIDs []uint64) int { - heap := heap.MakeMax[uint64](distances, heapBucketIDs) - +func nearestBuckets(bucketIDs []uint64, k int, x uint64, distance0 *int, heap *heap.Max[uint64]) { k0 := min(k, len(bucketIDs)) var maxDist int for _, b := range bucketIDs[:k0] { @@ -82,16 +122,16 @@ func nearestBuckets(bucketIDs []uint64, k int, x uint64, distances []int, heapBu heap.Push(dist, b) } if k0 < k { - return k0 + return } - maxDist = distances[0] + maxDist = *distance0 for _, b := range bucketIDs[k0:] { dist := bits.OnesCount64(x ^ b) if dist >= maxDist { continue } heap.PushPop(dist, b) - maxDist = distances[0] + maxDist = *distance0 } - return k + return } diff --git a/lsh/nearest_test.go b/lsh/nearest_test.go index 3219557..95fa33a 100644 --- a/lsh/nearest_test.go +++ b/lsh/nearest_test.go @@ -17,12 +17,13 @@ func TestNearest(t *testing.T) { } k := 10 distances := make([]int, k+1) + bucketDistances := make([]int, k+1) heapBucketIDs := make([]uint64, k+1) indices := make([]int, k+1) x := uint64(5) xh := uint64(1) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, distances, heapBucketIDs, indices) + k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) if 8 != k { t.Fatal(k) @@ -40,12 +41,13 @@ func TestNearest(t *testing.T) { } k := 3 distances := make([]int, k+1) + bucketDistances := make([]int, k+1) heapBucketIDs := make([]uint64, k+1) indices := make([]int, k+1) x := uint64(5) xh := uint64(1) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, distances, heapBucketIDs, indices) + k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) if 3 != k { t.Fatal(k) @@ -65,13 +67,14 @@ func TestNearest(t *testing.T) { } k := 3 distances := make([]int, k+1) + bucketDistances := make([]int, k+1) heapBucketIDs := make([]uint64, k+1) indices := make([]int, k+1) { x := uint64(5) xh := uint64(1) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, distances, heapBucketIDs, indices) + k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) if 3 != k { t.Fatal(k) @@ -83,7 +86,7 @@ func TestNearest(t *testing.T) { { x := uint64(4) xh := uint64(2) - k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, distances, heapBucketIDs, indices) + k, n := lsh.Nearest(data, bucketIDs, buckets, k, x, xh, bucketDistances, heapBucketIDs, distances, indices) if 3 != k { t.Fatal(k) diff --git a/model.go b/model.go index 6d5fb04..5d6a031 100644 --- a/model.go +++ b/model.go @@ -1,6 +1,8 @@ package bitknn -import "github.com/keilerkonzept/bitknn/internal/slice" +import ( + "github.com/keilerkonzept/bitknn/internal/slice" +) // Create a k-NN model for the given data points and labels. func Fit(data []uint64, labels []int, opts ...Option) *Model { diff --git a/nearest.go b/nearest.go index fd20523..b61f842 100644 --- a/nearest.go +++ b/nearest.go @@ -16,7 +16,7 @@ func Nearest(data []uint64, k int, x uint64, distances, indices []int) int { heap := heap.MakeMax(distances, indices) k0 := min(k, len(data)) - var maxDist int + for i := 0; i < k0; i++ { dist := bits.OnesCount64(x ^ data[i]) heap.Push(dist, i) @@ -24,7 +24,8 @@ func Nearest(data []uint64, k int, x uint64, distances, indices []int) int { if k0 < k { return k0 } - maxDist = distances[0] + + maxDist := distances[0] for i := k; i < len(data); i++ { dist := bits.OnesCount64(x ^ data[i]) if dist >= maxDist {
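
Illustrative notes follow; none of the code below is part of the patch. First, a minimal usage sketch of the lsh API as it looks after this change, using only constructors whose signatures appear in the diff (lsh.HashCompose, lsh.BoxBlur, lsh.Fit, Model.PreallocateHeap, Model.Predict1). The toy data and the BoxBlur parameters are arbitrary illustration values.

// Hypothetical usage sketch (not from the patch): fit an LSH model with a
// composed hash and classify one query point.
package main

import (
	"fmt"

	"github.com/keilerkonzept/bitknn"
	"github.com/keilerkonzept/bitknn/lsh"
)

func main() {
	// Toy data: 64-bit feature vectors and their class labels (0 or 1).
	data := []uint64{0b0000_1111, 0b0001_1111, 0b1111_0000, 0b1110_0000}
	labels := []int{0, 0, 1, 1}

	// HashCompose applies its element hashes left to right (Hash1/Hash above).
	// The BoxBlur(radius, step) parameters here are arbitrary.
	hash := lsh.HashCompose{
		lsh.BoxBlur(1, 2),
		lsh.BoxBlur(3, 8),
	}

	model := lsh.Fit(data, labels, hash, bitknn.WithLinearDistanceWeighting())

	k := 2
	model.PreallocateHeap(k) // sizes the four scratch slices to k+1
	votes := make([]float64, 2)
	model.Predict1(k, 0b0011_1111, votes)
	fmt.Println(votes) // per-class vote weights
}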
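
Second, a self-contained sketch of the search order that the reworked lsh.Nearest implements: scan the query's exact bucket first, and only if that bucket holds fewer than k points, fall back to the Hamming-nearest bucket IDs (the exact bucket is skipped when scanning the fallback buckets) and take the k nearest points over everything scanned. The sketch uses plain maps and sorting instead of the patch's index-range and heap machinery, so it is a conceptual illustration rather than the code path from the diff.

// Conceptual sketch (not from the patch) of the two-stage bucket search:
// exact bucket first, then the k nearest buckets by hash distance.
package main

import (
	"fmt"
	"math/bits"
	"sort"
)

// nearestTwoStage returns up to k Hamming-nearest points to x, scanning the
// query's own bucket xh first and widening to nearby buckets only if needed.
func nearestTwoStage(buckets map[uint64][]uint64, k int, xh, x uint64) []uint64 {
	// Stage 1: the exact bucket.
	candidates := append([]uint64(nil), buckets[xh]...)

	// Stage 2: only if the exact bucket is too small, add the k nearest
	// other buckets (by Hamming distance between bucket ID and xh).
	if len(candidates) < k {
		ids := make([]uint64, 0, len(buckets))
		for id := range buckets {
			if id != xh { // the exact bucket was already scanned
				ids = append(ids, id)
			}
		}
		sort.Slice(ids, func(i, j int) bool {
			return bits.OnesCount64(ids[i]^xh) < bits.OnesCount64(ids[j]^xh)
		})
		for _, id := range ids[:min(k, len(ids))] {
			candidates = append(candidates, buckets[id]...)
		}
	}

	// Keep the k candidates nearest to the query point itself.
	sort.Slice(candidates, func(i, j int) bool {
		return bits.OnesCount64(candidates[i]^x) < bits.OnesCount64(candidates[j]^x)
	})
	return candidates[:min(k, len(candidates))]
}

func main() {
	buckets := map[uint64][]uint64{
		1: {0b0101, 0b0100},
		2: {0b0110},
		7: {0b1111, 0b1011},
	}
	fmt.Println(nearestTwoStage(buckets, 3, 1, 0b0101))
}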
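
Finally, the hunks in lsh/nearest.go and nearest.go share one idiom: push the first k candidates onto a max-heap keyed by distance, then for every further candidate skip anything not strictly closer than the current worst (the heap root, read through distances[0] or *distance0) and otherwise replace the root with PushPop. Below is a standalone sketch of that idiom using the standard library's container/heap in place of the repo's internal/heap package.

// Standalone sketch (not from the patch) of the bounded k-nearest selection
// used by Nearest: a max-heap of the k best distances, where the root is the
// current worst and strictly-worse candidates are skipped without heap work.
package main

import (
	"container/heap"
	"fmt"
	"math/bits"
)

// maxHeap is a max-heap of ints (largest value at the root).
type maxHeap []int

func (h maxHeap) Len() int           { return len(h) }
func (h maxHeap) Less(i, j int) bool { return h[i] > h[j] }
func (h maxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *maxHeap) Push(x any)        { *h = append(*h, x.(int)) }
func (h *maxHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}

// kSmallestDistances returns the k smallest Hamming distances from x to data.
func kSmallestDistances(data []uint64, k int, x uint64) []int {
	h := &maxHeap{}
	for _, d := range data {
		dist := bits.OnesCount64(x ^ d)
		if h.Len() < k {
			heap.Push(h, dist) // fill the heap with the first k candidates
			continue
		}
		if dist >= (*h)[0] {
			continue // not strictly closer than the current worst: skip
		}
		(*h)[0] = dist // "push-pop": replace the worst element...
		heap.Fix(h, 0) // ...and restore the heap property
	}
	return *h
}

func main() {
	data := []uint64{0b0001, 0b0011, 0b0111, 0b1111, 0b1110, 0b1000}
	fmt.Println(kSmallestDistances(data, 3, 0b0000)) // three smallest distances, unordered
}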