Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 9, 2024
1 parent 8055df1 commit a1f8dec
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.txt
*.test
testdata
*.gz
13 changes: 3 additions & 10 deletions lsh/model_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,12 @@ func Benchmark_Model_Predict1(b *testing.B) {
dataSize []int
k []int
}
mh := lsh.RandomMinHashR(testrandom.Source)
rbh := lsh.RandomBlurR(3, 20, testrandom.Source)
hashes := []lsh.Hash{
lsh.NoHash{},
lsh.HashCompose{rbh, mh},
lsh.RandomBitSampleR(30, testrandom.Source),
rbh,
mh,
lsh.RandomMinHashesR(10, testrandom.Source),
lsh.ConstantHash{}, // should be only a bit slower than exact KNN
}
benches := []bench{
// {hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}},
{hashes: hashes, dataSize: []int{1_000_000}, k: []int{1, 1000}},
{hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}},
{hashes: hashes, dataSize: []int{1_000_000}, k: []int{3, 10, 100}},
}
for _, bench := range benches {
for _, dataSize := range bench.dataSize {
Expand Down
44 changes: 23 additions & 21 deletions lsh/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package lsh_test

import (
"math"
"reflect"
"slices"
"testing"

"github.com/keilerkonzept/bitknn"
Expand All @@ -11,14 +13,14 @@ import (

func Test_Model_NoHash_IsExact(t *testing.T) {
var h lsh.NoHash
_ = h
var h0 lsh.ConstantHash
id := func(a uint64) uint64 { return a }
rapid.Check(t, func(t *rapid.T) {
k := rapid.IntRange(1, 1001).Draw(t, "k")
data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, func(a uint64) uint64 { return a }).Draw(t, "data")
k := rapid.IntRange(3, 1001).Draw(t, "k")
data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data")
labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels")
values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values")
queries := rapid.SliceOfN(rapid.Uint64(), 3, 64).Draw(t, "queries")
queries := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 64, id).Draw(t, "queries")
knnVotes := make([]float64, 4)
annVotes := make([]float64, 4)
type pair struct {
Expand All @@ -31,52 +33,53 @@ func Test_Model_NoHash_IsExact(t *testing.T) {
{
"V",
bitknn.Fit(data, labels, bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithValues(values)),
lsh.Fit(data, labels, h, bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithValues(values)),
},
{
"LV",
bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)),
},
{
"QV",
bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)),
},
{
"CV",
bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)),
lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)),
lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)),
},
{
"0",
bitknn.Fit(data, labels),
lsh.Fit(data, labels, h0),
lsh.Fit(data, labels, h),
lsh.Fit(data, labels, h0),
},
{
"L",
bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting()),
lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()),
lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting()),
lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()),
},
{
"Q",
bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting()),
lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()),
lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting()),
lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()),
},
{
"C",
bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)),
lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)),
lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)),
lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)),
},
}
const eps = 1e-8
for _, pair := range pairs {
knn := pair.KNN
ann := pair.ANN
Expand All @@ -87,25 +90,24 @@ func Test_Model_NoHash_IsExact(t *testing.T) {
knn.Predict1(k, q, knnVotes)

ann.Predict1(k, q, annVotes)
const eps = 1e-8
for i, vk := range knnVotes {
va := annVotes[i]
if math.Abs(vk-va) > eps {
t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes)
}
slices.Sort(knn.HeapDistances[:k])
slices.Sort(ann.HeapDistances[:k])
if !reflect.DeepEqual(knn.HeapDistances[:k], ann.HeapDistances[:k]) {
t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", knn.HeapDistances[:k], ann.HeapDistances[:k], knn.HeapIndices[:k], ann.HeapIndices[:k])
}
ann.Predict1Alloc(k, q, annVotes)

ann0.Predict1Alloc(k, q, annVotes)
for i, vk := range knnVotes {
va := annVotes[i]
if math.Abs(vk-va) > eps {
t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes)
t.Fatalf("ANN: %s: %v: %v %v", pair.name, q, knnVotes, annVotes)
}
}
ann0.Predict1(k, q, annVotes)
for i, vk := range knnVotes {
va := annVotes[i]
if math.Abs(vk-va) > eps {
t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes)
t.Fatalf("ANN0: %s: %v: %v %v", pair.name, q, knnVotes, annVotes)
}
}
}
Expand Down
1 change: 0 additions & 1 deletion lsh/nearest.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,4 @@ func nearestBuckets(bucketIDs []uint64, k int, x uint64, distance0 *int, heap *h
heap.PushPop(dist, b)
maxDist = *distance0
}
return
}
10 changes: 6 additions & 4 deletions nearest.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,27 @@ import (
// cap(distances) = cap(indices) = k+1 >= 1
func Nearest(data []uint64, k int, x uint64, distances, indices []int) int {
heap := heap.MakeMax(distances, indices)
distance0 := &distances[0]

k0 := min(k, len(data))

for i := 0; i < k0; i++ {
dist := bits.OnesCount64(x ^ data[i])
for i, d := range data[:k0] {
dist := bits.OnesCount64(x ^ d)
heap.Push(dist, i)
}

if k0 < k {
return k0
}

maxDist := distances[0]
maxDist := *distance0
for i := k; i < len(data); i++ {
dist := bits.OnesCount64(x ^ data[i])
if dist >= maxDist {
continue
}
heap.PushPop(dist, i)
maxDist = distances[0]
maxDist = *distance0
}
return k
}

0 comments on commit a1f8dec

Please sign in to comment.