From fe5159c1ff03f0d2506cbad13405dbe4bce10f5f Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Fri, 5 Jul 2024 10:08:32 +0530 Subject: [PATCH] perf(vector): updated marshalling of vector --- dgraphtest/config.go | 16 +- dgraphtest/local_cluster.go | 32 ++- go.mod | 24 +- go.sum | 40 +-- posting/index.go | 123 ++++++++- posting/lists.go | 3 +- posting/mvcc.go | 2 +- query/vector/vector_test.go | 14 +- tok/hnsw/heap.go | 4 + tok/hnsw/helper.go | 201 +++++++++------ tok/hnsw/persistent_hnsw.go | 82 +++--- tok/hnsw/persistent_hnsw_test.go | 88 +------ tok/hnsw/search_layer.go | 19 +- tok/index/helper.go | 46 ++-- tok/index/helper_test.go | 423 +++++++++++++++++++++++++++++++ worker/task.go | 4 +- 16 files changed, 844 insertions(+), 277 deletions(-) create mode 100644 tok/index/helper_test.go diff --git a/dgraphtest/config.go b/dgraphtest/config.go index 95c7b7d959d..5bc4252ce53 100644 --- a/dgraphtest/config.go +++ b/dgraphtest/config.go @@ -147,14 +147,14 @@ func NewClusterConfig() ClusterConfig { } } -func newClusterConfigFrom(cc ClusterConfig) ClusterConfig { - prefix := fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000) - defaultBackupVol := fmt.Sprintf("%v_backup", prefix) - defaultExportVol := fmt.Sprintf("%v_export", prefix) - cc.prefix = prefix - cc.volumes = map[string]string{DefaultBackupDir: defaultBackupVol, DefaultExportDir: defaultExportVol} - return cc -} +//func newClusterConfigFrom(cc ClusterConfig) ClusterConfig { +// prefix := fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000) +// defaultBackupVol := fmt.Sprintf("%v_backup", prefix) +// defaultExportVol := fmt.Sprintf("%v_export", prefix) +// cc.prefix = prefix +// cc.volumes = map[string]string{DefaultBackupDir: defaultBackupVol, DefaultExportDir: defaultExportVol} +// return cc +//} // WithNAlphas sets the number of alphas in the cluster func (cc ClusterConfig) WithNumAlphas(n int) ClusterConfig { diff --git 
a/dgraphtest/local_cluster.go b/dgraphtest/local_cluster.go index 7e6dc8a2dc8..0248255fa47 100644 --- a/dgraphtest/local_cluster.go +++ b/dgraphtest/local_cluster.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "log" + "math/rand" "net/http" "os" "os/exec" @@ -34,6 +35,7 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/volume" docker "github.com/docker/docker/client" @@ -402,6 +404,26 @@ func (c *LocalCluster) Cleanup(verbose bool) { } } +func (c *LocalCluster) cleanupDocker() error { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + // Prune containers + contsReport, err := c.dcli.ContainersPrune(ctx, filters.Args{}) + if err != nil { + log.Fatalf("[ERROR] Error pruning containers: %v", err) + } + log.Printf("[INFO] Pruned containers: %+v\n", contsReport) + + // Prune networks + netsReport, err := c.dcli.NetworksPrune(ctx, filters.Args{}) + if err != nil { + log.Fatalf("[ERROR] Error pruning networks: %v", err) + } + log.Printf("[INFO] Pruned networks: %+v\n", netsReport) + + return nil +} + func (c *LocalCluster) Start() error { log.Printf("[INFO] starting cluster with prefix [%v]", c.conf.prefix) startAll := func() error { @@ -433,10 +455,16 @@ func (c *LocalCluster) Start() error { log.Printf("[WARNING] saw the err, trying again: %v", err) } - log.Printf("[INFO] cleaning up the cluster for retrying!") + if err1 := c.Stop(); err1 != nil { + log.Printf("[WARNING] error while stopping :%v", err1) + } c.Cleanup(true) - c.conf = newClusterConfigFrom(c.conf) + if err := c.cleanupDocker(); err != nil { + log.Printf("[ERROR] while cleaning old dockers %v", err) + } + + c.conf.prefix = fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000) if err := c.init(); err != nil { log.Printf("[ERROR] error while init, returning: %v", err) 
return err diff --git a/go.mod b/go.mod index 8700071976f..6b41e41fbe6 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,8 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 github.com/IBM/sarama v1.41.0 github.com/Masterminds/semver/v3 v3.1.0 + github.com/bits-and-blooms/bitset v1.2.0 github.com/blevesearch/bleve/v2 v2.3.10 - github.com/chewxy/math32 v1.10.1 github.com/dgraph-io/badger/v4 v4.2.0 github.com/dgraph-io/dgo/v230 v230.0.2-0.20240314155021-7b8d289e37f3 github.com/dgraph-io/gqlgen v0.13.2 @@ -54,19 +54,20 @@ require ( github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.9.0 github.com/twpayne/go-geom v1.0.5 + github.com/viterin/vek v0.4.2 github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c go.etcd.io/etcd/raft/v3 v3.5.9 go.opencensus.io v0.24.0 go.uber.org/zap v1.16.0 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 - golang.org/x/mod v0.16.0 - golang.org/x/net v0.22.0 - golang.org/x/sync v0.6.0 - golang.org/x/sys v0.18.0 - golang.org/x/term v0.18.0 - golang.org/x/text v0.14.0 - golang.org/x/tools v0.19.0 + golang.org/x/crypto v0.24.0 + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 + golang.org/x/mod v0.18.0 + golang.org/x/net v0.26.0 + golang.org/x/sync v0.7.0 + golang.org/x/sys v0.22.0 + golang.org/x/term v0.21.0 + golang.org/x/text v0.16.0 + golang.org/x/tools v0.22.0 google.golang.org/grpc v1.62.1 gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.4.0 @@ -78,7 +79,6 @@ require ( github.com/agnivade/levenshtein v1.0.3 // indirect github.com/apache/thrift v0.13.0 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/blevesearch/bleve_index_api v1.0.6 // indirect github.com/blevesearch/geo v0.1.18 // indirect github.com/blevesearch/go-porterstemmer v1.0.3 // indirect @@ -86,6 +86,7 @@ require ( github.com/blevesearch/snowballstem v0.9.0 // indirect github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect 
github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chewxy/math32 v1.10.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.4.0 // indirect @@ -140,6 +141,7 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/tinylib/msgp v1.1.2 // indirect + github.com/viterin/partial v1.1.0 // indirect github.com/xdg/stringprep v1.0.3 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.10.0 // indirect diff --git a/go.sum b/go.sum index 9d5cfb73420..3b68aebc34a 100644 --- a/go.sum +++ b/go.sum @@ -685,6 +685,10 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/vektah/dataloaden v0.2.1-0.20190515034641-a19b9a6e7c9e/go.mod h1:/HUdMve7rvxZma+2ZELQeNh88+003LL7Pf/CZ089j8U= github.com/vektah/gqlparser/v2 v2.1.0/go.mod h1:SyUiHgLATUR8BiYURfTirrTcGpcE+4XkV2se04Px1Ms= +github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E= +github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA= +github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U= +github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= @@ -743,8 +747,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto 
v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -758,8 +762,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -786,8 +790,8 
@@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -834,8 +838,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -856,8 
+860,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -916,13 +920,13 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod 
h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -934,8 +938,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1005,8 +1009,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod 
h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/posting/index.go b/posting/index.go index 38d3e4c6094..3257583899f 100644 --- a/posting/index.go +++ b/posting/index.go @@ -649,6 +649,120 @@ type rebuilder struct { fn func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) } +func (r *rebuilder) RunWithoutTemp(ctx context.Context) error { + stream := pstore.NewStreamAt(r.startTs) + stream.LogPrefix = fmt.Sprintf("Rebuilding index for predicate %s (1/2):", r.attr) + stream.Prefix = r.prefix + stream.NumGo = 128 + txn := NewTxn(r.startTs) + stream.KeyToList = func(key []byte, it *badger.Iterator) (*bpb.KVList, error) { + // We should return quickly if the context is no longer valid. + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + pk, err := x.Parse(key) + if err != nil { + return nil, errors.Wrapf(err, "could not parse key %s", hex.Dump(key)) + } + + l := new(List) + l.key = key + l.plist = new(pb.PostingList) + + found := false + + for it.Valid() { + item := it.Item() + if !bytes.Equal(item.Key(), l.key) { + break + } + l.maxTs = x.Max(l.maxTs, item.Version()) + if item.IsDeletedOrExpired() { + // Don't consider any more versions. 
+ break + } + + found = true + switch item.UserMeta() { + case BitEmptyPosting: + l.minTs = item.Version() + case BitCompletePosting: + if err := unmarshalOrCopy(l.plist, item); err != nil { + return nil, err + } + l.minTs = item.Version() + + // No need to do Next here. The outer loop can take care of skipping + // more versions of the same key. + case BitDeltaPosting: + err := item.Value(func(val []byte) error { + pl := &pb.PostingList{} + if err := pl.Unmarshal(val); err != nil { + return err + } + pl.CommitTs = item.Version() + for _, mpost := range pl.Postings { + // commitTs, startTs are meant to be only in memory, not + // stored on disk. + mpost.CommitTs = item.Version() + } + if l.mutationMap == nil { + l.mutationMap = make(map[uint64]*pb.PostingList) + } + l.mutationMap[pl.CommitTs] = pl + return nil + }) + if err != nil { + return nil, err + } + default: + return nil, errors.Errorf( + "Unexpected meta: %d for key: %s", item.UserMeta(), hex.Dump(key)) + } + if found { + break + } + } + + if _, err := r.fn(pk.Uid, l, txn); err != nil { + return nil, err + } + return nil, nil + } + stream.Send = func(buf *z.Buffer) error { + // TODO. Make an in memory txn with disk backing for more data than memory. 
+ return nil + } + + start := time.Now() + if err := stream.Orchestrate(ctx); err != nil { + return err + } + + txn.Update() + writer := NewTxnWriter(pstore) + + defer func() { + glog.V(1).Infof("Rebuilding index for predicate %s: building index took: %v\n", + r.attr, time.Since(start)) + }() + + ResetCache() + + return x.ExponentialRetry(int(x.Config.MaxRetries), + 20*time.Millisecond, func() error { + err := txn.CommitToDisk(writer, r.startTs) + if err == badger.ErrBannedKey { + glog.Errorf("Error while writing to banned namespace.") + return nil + } + return err + }) +} + func (r *rebuilder) Run(ctx context.Context) error { if r.startTs == 0 { glog.Infof("maxassigned is 0, no indexing work for predicate %s", r.attr) @@ -1175,6 +1289,8 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { factorySpecs = []*tok.FactoryCreateSpec{factorySpec} } + runForVectors := (len(factorySpecs) != 0) + pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { @@ -1200,7 +1316,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { case ErrRetry: time.Sleep(10 * time.Millisecond) default: - edges = append(edges, newEdges...) + if !runForVectors { + edges = append(edges, newEdges...) 
+ } return err } } @@ -1210,6 +1328,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } return edges, err } + if len(factorySpecs) != 0 { + return builder.RunWithoutTemp(ctx) + } return builder.Run(ctx) } diff --git a/posting/lists.go b/posting/lists.go index f00995ef8f2..b1abd569cf7 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -242,8 +242,7 @@ func (lc *LocalCache) Find(pred []byte, filter func([]byte) bool) (uint64, error } if filter(vals.Value.([]byte)) { - result.Uids = append(result.Uids, pk.Uid) - break + return pk.Uid, nil } continue diff --git a/posting/mvcc.go b/posting/mvcc.go index 513f9912820..27cdc6fe7a9 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -535,7 +535,7 @@ func (c *CachePL) Set(l *List, readTs uint64) { } func ShouldGoInCache(pk x.ParsedKey) bool { - return !pk.IsData() && strings.HasSuffix(pk.Attr, "dgraph.type") + return (!pk.IsData() && strings.HasSuffix(pk.Attr, "dgraph.type")) } func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { diff --git a/query/vector/vector_test.go b/query/vector/vector_test.go index 2332a1f8811..eb5568d42c6 100644 --- a/query/vector/vector_test.go +++ b/query/vector/vector_test.go @@ -429,7 +429,7 @@ func TestVectorsMutateFixedLengthWithDiffrentIndexes(t *testing.T) { testVectorMutationSameLength(t) dropPredicate("vtest") - setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dotproduct")) testVectorMutationSameLength(t) dropPredicate("vtest") } @@ -455,8 +455,8 @@ func TestVectorDeadlockwithTimeout(t *testing.T) { }) dropPredicate(pred) setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) - numVectors := 1000 - vectorSize := 10 + numVectors := 10000 + vectorSize := 1000 randomVectors, _ := generateRandomVectors(numVectors, vectorSize, pred) @@ -480,15 +480,15 @@ func TestVectorMutateDiffrentLengthWithDiffrentIndexes(t *testing.T) { 
dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidian")) - testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + testVectorMutationDiffrentLength(t, "can not compute euclidian distance on vectors of different lengths") dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "cosine")) - testVectorMutationDiffrentLength(t, "can not compute dot product on vectors of different lengths") + testVectorMutationDiffrentLength(t, "can not compute cosine distance on vectors of different lengths") dropPredicate("vtest") - setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) - testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dotproduct")) + testVectorMutationDiffrentLength(t, "can not compute dot product on vectors of different lengths") dropPredicate("vtest") } diff --git a/tok/hnsw/heap.go b/tok/hnsw/heap.go index 24c6edbe8fe..da165f835ae 100644 --- a/tok/hnsw/heap.go +++ b/tok/hnsw/heap.go @@ -44,6 +44,10 @@ func (h *minPersistentTupleHeap[T]) Push(x interface{}) { *h = append(*h, x.(minPersistentHeapElement[T])) } +func (h *minPersistentTupleHeap[T]) PopLast() { + heap.Remove(h, h.Len()-1) +} + func (h *minPersistentTupleHeap[T]) Pop() interface{} { old := *h n := len(old) diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 033ef40a99d..b58a911f072 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -3,19 +3,21 @@ package hnsw import ( "context" "encoding/binary" - "encoding/json" + "fmt" "log" "math" "math/rand" "sort" "strconv" "strings" + "unsafe" - "github.com/chewxy/math32" c "github.com/dgraph-io/dgraph/tok/constraints" "github.com/dgraph-io/dgraph/tok/index" "github.com/getsentry/sentry-go" "github.com/pkg/errors" + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" ) const ( @@ -61,62 +63,47 @@ func (s *SearchResult) 
GetExtraMetrics() map[string]uint64 { return s.extraMetrics } -func norm[T c.Float](v []T, floatBits int) T { - vectorNorm, _ := dotProduct(v, v, floatBits) - if floatBits == 32 { - return T(math32.Sqrt(float32(vectorNorm))) +func applyDistanceFunction[T c.Float](a, b []T, floatBits int, funcName string, + applyFn32 func(a, b []float32) float32, applyFn64 func(a, b []float64) float64) (T, error) { + if len(a) != len(b) { + err := errors.New(fmt.Sprintf("can not compute %s on vectors of different lengths", funcName)) + return T(0), err } - if floatBits == 64 { - return T(math.Sqrt(float64(vectorNorm))) + + if floatBits == 32 { + var a1, b1 []float32 + a1 = *(*[]float32)(unsafe.Pointer(&a)) + b1 = *(*[]float32)(unsafe.Pointer(&b)) + return T(applyFn32(a1, b1)), nil + } else if floatBits == 64 { + var a1, b1 []float64 + a1 = *(*[]float64)(unsafe.Pointer(&a)) + b1 = *(*[]float64)(unsafe.Pointer(&b)) + return T(applyFn64(a1, b1)), nil } - panic("Invalid floatBits") + + panic("While applying function on two floats, found an invalid number of float bits") + } // This needs to implement signature of SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter, // but doesn't actually use it. func dotProduct[T c.Float](a, b []T, floatBits int) (T, error) { - var dotProduct T - if len(a) != len(b) { - err := errors.New("can not compute dot product on vectors of different lengths") - return dotProduct, err - } - for i := range a { - dotProduct += a[i] * b[i] - } - return dotProduct, nil + return applyDistanceFunction(a, b, floatBits, "dot product", vek32.Dot, vek.Dot) } // This needs to implement signature of SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter. 
func cosineSimilarity[T c.Float](a, b []T, floatBits int) (T, error) { - dotProd, err := dotProduct(a, b, floatBits) - if err != nil { - return 0, err - } - normA := norm[T](a, floatBits) - normB := norm[T](b, floatBits) - if normA == 0 || normB == 0 { - err := errors.New("can not compute cosine similarity on zero vector") - var empty T - return empty, err - } - return dotProd / (normA * normB), nil + return applyDistanceFunction(a, b, floatBits, "cosine distance", vek32.CosineSimilarity, vek.CosineSimilarity) } // This needs to implement signature of SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter, // but doesn't actually use it. func euclidianDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { - if len(a) != len(b) { - return 0, errors.New("can not subtract vectors of different lengths") - } - var distSq T - for i := range a { - val := a[i] - b[i] - distSq += val * val - } - return distSq, nil + return applyDistanceFunction(a, b, floatBits, "euclidian distance", vek32.Distance, vek.Distance) } // Used for distance, since shorter distance is better @@ -207,24 +194,6 @@ func cannotConvertToUintSlice(s string) error { return errors.Errorf("Cannot convert %s to uint slice", s) } -func diff(a []uint64, b []uint64) []uint64 { - // Turn b into a map - m := make(map[uint64]bool, len(b)) - for _, s := range b { - m[s] = false - } - // Append values from the longest slice that don't exist in the map - var diff []uint64 - for _, s := range a { - if _, ok := m[s]; !ok { - diff = append(diff, s) - continue - } - m[s] = true - } - return diff -} - // TODO: Move SimilarityType to index package. // // Remove "hnsw-isms". 
// encodeUint64MatrixUnsafe serializes matrix into a flat byte buffer of
// native-endian uint64 words (8 bytes each), laid out as:
//
//	[numRows][len(row0)][row0 values...][len(row1)][row1 values...]...
//
// It returns nil for an empty matrix. Because the encoding uses the
// machine's native byte order, it must only be decoded on a machine with
// the same endianness (little-endian everywhere Dgraph runs in practice).
func encodeUint64MatrixUnsafe(matrix [][]uint64) []byte {
	if len(matrix) == 0 {
		return nil
	}

	// One word for the row count, one per row length, one per value.
	words := 1 + len(matrix)
	for _, row := range matrix {
		words += len(row)
	}
	data := make([]byte, words*8)

	offset := 0
	putU64 := func(v uint64) {
		copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&v))[:])
		offset += 8
	}

	putU64(uint64(len(matrix)))
	for _, row := range matrix {
		putU64(uint64(len(row)))
		for _, v := range row {
			putU64(v)
		}
	}
	return data
}

// decodeUint64MatrixUnsafe is the inverse of encodeUint64MatrixUnsafe,
// writing the result into *matrix. Empty input leaves *matrix untouched.
//
// Unlike the first version of this codec, it validates all lengths before
// reading or allocating, so truncated/corrupt input yields an error instead
// of an out-of-range panic or an attempted huge allocation.
func decodeUint64MatrixUnsafe(data []byte, matrix *[][]uint64) error {
	if len(data) == 0 {
		return nil
	}
	if len(data)%8 != 0 {
		return fmt.Errorf("uint64 matrix: data length %d is not a multiple of 8", len(data))
	}

	offset := 0
	readU64 := func() (uint64, bool) {
		if offset+8 > len(data) {
			return 0, false
		}
		v := *(*uint64)(unsafe.Pointer(&data[offset]))
		offset += 8
		return v, true
	}

	rows, _ := readU64() // len(data) >= 8 is guaranteed above
	// Every row carries at least an 8-byte length header, so a sane row
	// count can never exceed the remaining word count.
	if rows > uint64(len(data)-8)/8 {
		return fmt.Errorf("uint64 matrix: row count %d exceeds available data", rows)
	}

	*matrix = make([][]uint64, rows)
	for i := range *matrix {
		rowLen, ok := readU64()
		if !ok {
			return fmt.Errorf("uint64 matrix: truncated at row %d header", i)
		}
		if rowLen > uint64(len(data)-offset)/8 {
			return fmt.Errorf("uint64 matrix: row %d length %d exceeds available data", i, rowLen)
		}
		row := make([]uint64, rowLen)
		for j := range row {
			row[j], _ = readU64() // in bounds: proven by the rowLen check
		}
		(*matrix)[i] = row
	}
	return nil
}
addStartNodeToAllLevels( ctx context.Context, @@ -452,11 +485,7 @@ func (ph *persistentHNSW[T]) addStartNodeToAllLevels( inUuid uint64) ([]*index.KeyValue, error) { edges := []*index.KeyValue{} key := DataKey(ph.vecKey, inUuid) - emptyEdges := make([][]uint64, ph.maxLevels) - emptyEdgesBytes, err := json.Marshal(emptyEdges) - if err != nil { - return []*index.KeyValue{}, err - } + emptyEdgesBytes := encodeUint64MatrixUnsafe(make([][]uint64, ph.maxLevels)) // creates empty at all levels only for entry node edge, err := ph.newPersistentEdgeKeyValueEntry(ctx, key, txn, inUuid, emptyEdgesBytes) if err != nil { @@ -509,7 +538,7 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, allLayerEdges = allLayerNeighbors } else { // all edges of nearest neighbor - err := json.Unmarshal(data.([]byte), &allLayerEdges) + err := decodeUint64MatrixUnsafe(data.([]byte), &allLayerEdges) if err != nil { return nil, err } @@ -527,10 +556,7 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, // on every modification of the layer edges, add it to in mem map so you dont have to always be reading // from persistent storage ph.nodeAllEdges[uuid] = allLayerEdges - inboundEdgesBytes, marshalErr := json.Marshal(allLayerEdges) - if marshalErr != nil { - return nil, marshalErr - } + inboundEdgesBytes := encodeUint64MatrixUnsafe(allLayerEdges) edge := &index.KeyValue{ Entity: uuid, @@ -545,19 +571,38 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, // removeDeadNodes(nnEdges, tc) removes dead nodes from nnEdges and returns the new nnEdges func (ph *persistentHNSW[T]) removeDeadNodes(nnEdges []uint64, tc *TxnCache) ([]uint64, error) { - data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc) - if err != nil && err.Error() == plError { - return []uint64{}, err - } - var deadNodes []uint64 - if data != nil { // if dead nodes exist, convert to []uint64 - deadNodes, err = ParseEdges(string(data.([]byte))) - if err 
!= nil { + // TODO add a path to delete deadNodes + if ph.deadNodes == nil { + data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc) + if err != nil && err.Error() == plError { return []uint64{}, err } - nnEdges = diff(nnEdges, deadNodes) // set nnEdges to be all elements not contained in deadNodes + + var deadNodes []uint64 + if data != nil { // if dead nodes exist, convert to []uint64 + deadNodes, err = ParseEdges(string(data.([]byte))) + if err != nil { + return []uint64{}, err + } + } + + ph.deadNodes = make(map[uint64]struct{}) + for _, n := range deadNodes { + ph.deadNodes[n] = struct{}{} + } + } + if len(ph.deadNodes) == 0 { + return nnEdges, nil + } + + var diff []uint64 + for _, s := range nnEdges { + if _, ok := ph.deadNodes[s]; !ok { + diff = append(diff, s) + continue + } } - return nnEdges, nil + return diff, nil } func Uint64ToBytes(key uint64) []byte { diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 2c62afbd485..00b55552e04 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/bits-and-blooms/bitset" c "github.com/dgraph-io/dgraph/tok/constraints" "github.com/dgraph-io/dgraph/tok/index" opt "github.com/dgraph-io/dgraph/tok/options" @@ -26,7 +27,8 @@ type persistentHNSW[T c.Float] struct { // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first // layer for uuid 65443. The result will be a neighboring uuid. 
nodeAllEdges map[uint64][][]uint64 - visitedUids []uint64 + visitedUids bitset.BitSet + deadNodes map[uint64]struct{} } func GetPersistantOptions[T c.Float](o opt.Options) string { @@ -163,10 +165,13 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( index: entry, filteredOut: entryIsFilteredOut, } + r.setFirstPathNode(best) - //create set using map to append to on future visited nodes - ph.visitedUids = append(ph.visitedUids, best.index) candidateHeap := *buildPersistentHeapByInit([]minPersistentHeapElement[T]{best}) + + var allLayerEdges [][]uint64 + + //create set using map to append to on future visited nodes for candidateHeap.Len() != 0 { currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) if r.numNeighbors() < expectedNeighbors && @@ -181,7 +186,6 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( // guarantees of getting best results. break } - var allLayerEdges [][]uint64 found, err := ph.fillNeighborEdges(currCandidate.index, c, &allLayerEdges) if err != nil { @@ -190,46 +194,55 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( if !found { continue } - currLayerEdges := allLayerEdges[level] - currLayerEdges = diff(currLayerEdges, ph.visitedUids) var eVec []T - for i := range currLayerEdges { + improved := false + for _, currUid := range allLayerEdges[level] { + if ph.visitedUids.Test(uint(currUid)) { + continue + } + if r.indexVisited(currUid) { + continue + } // iterate over candidate's neighbors distances to get // best ones - _ = ph.getVecFromUid(currLayerEdges[i], c, &eVec) + _ = ph.getVecFromUid(currUid, c, &eVec) // intentionally ignoring error -- we catch it // indirectly via eVec == nil check. 
if len(eVec) == 0 { continue } currDist, err := ph.simType.distanceScore(eVec, query, ph.floatBits) - ph.visitedUids = append(ph.visitedUids, currLayerEdges[i]) - r.incrementDistanceComputations() if err != nil { return ph.emptySearchResultWithError(err) } - filteredOut := !filter(query, eVec, currLayerEdges[i]) + filteredOut := !filter(query, eVec, currUid) currElement := initPersistentHeapElement( - currDist, currLayerEdges[i], filteredOut) - nodeVisited := r.nodeVisited(*currElement) - if !nodeVisited { - r.addToVisited(*currElement) - - // If we have not yet found k candidates, we can consider - // any candidate. Otherwise, only consider those that - // are better than our current k nearest neighbors. - // Note that the "numNeighbors" function is a bit tricky: - // If we previously added to the heap M elements that should - // be filtered out, we ignore M elements in the numNeighbors - // check! In this way, we can make sure to allow in up to - // expectedNeighbors "unfiltered" elements. - if ph.simType.isBetterScore(currDist, r.lastNeighborScore()) || - r.numNeighbors() < expectedNeighbors { - candidateHeap.Push(*currElement) - r.addPathNode(*currElement, ph.simType, expectedNeighbors) + currDist, currUid, filteredOut) + r.addToVisited(*currElement) + r.incrementDistanceComputations() + ph.visitedUids.Set(uint(currUid)) + + // If we have not yet found k candidates, we can consider + // any candidate. Otherwise, only consider those that + // are better than our current k nearest neighbors. + // Note that the "numNeighbors" function is a bit tricky: + // If we previously added to the heap M elements that should + // be filtered out, we ignore M elements in the numNeighbors + // check! In this way, we can make sure to allow in up to + // expectedNeighbors "unfiltered" elements. 
+ if r.numNeighbors() < expectedNeighbors || ph.simType.isBetterScore(currDist, r.lastNeighborScore()) { + if candidateHeap.Len() > expectedNeighbors { + candidateHeap.PopLast() } + candidateHeap.Push(*currElement) + r.addPathNode(*currElement, ph.simType, expectedNeighbors) + improved = true } } + + if !improved && r.numNeighbors() >= expectedNeighbors { + break + } } return r, nil } @@ -335,6 +348,8 @@ func (ph *persistentHNSW[T]) SearchWithPath( start := time.Now().UnixMilli() r = index.NewSearchPathResult() + ph.visitedUids.ClearAll() + // 0-profile_vector_entry var startVec []T entry, err := ph.PickStartNode(ctx, c, &startVec) @@ -356,6 +371,7 @@ func (ph *persistentHNSW[T]) SearchWithPath( } layerResult.updateFinalMetrics(r) entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) err = ph.getVecFromUid(entry, c, &startVec) if err != nil { @@ -417,6 +433,8 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inLevel := getInsertLayer(ph.maxLevels) // calculate layer to insert node at (randomized every time) var layerErr error + ph.visitedUids.ClearAll() + for level := 0; level < inLevel; level++ { // perform insertion for layers [level, max_level) only, when level < inLevel just find better start err := ph.getVecFromUid(entry, tc, &startVec) @@ -424,7 +442,7 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err } layerResult, err := ph.searchPersistentLayer(tc, level, entry, startVec, - inVec, false, 1, index.AcceptAll[T]) + inVec, false, ph.efSearch, index.AcceptAll[T]) if err != nil { return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err } @@ -451,10 +469,14 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, layerErr } + entry = layerResult.bestNeighbor().index + nns := layerResult.neighbors for i := 0; i < len(nns); i++ { 
nnUidArray = append(nnUidArray, nns[i].index) - inboundEdgesAllLayersMap[nns[i].index] = make([][]uint64, ph.maxLevels) + if inboundEdgesAllLayersMap[nns[i].index] == nil { + inboundEdgesAllLayersMap[nns[i].index] = make([][]uint64, ph.maxLevels) + } inboundEdgesAllLayersMap[nns[i].index][level] = append(inboundEdgesAllLayersMap[nns[i].index][level], inUuid) // add nn to outboundEdges. diff --git a/tok/hnsw/persistent_hnsw_test.go b/tok/hnsw/persistent_hnsw_test.go index 986abd76046..2befa742030 100644 --- a/tok/hnsw/persistent_hnsw_test.go +++ b/tok/hnsw/persistent_hnsw_test.go @@ -286,19 +286,6 @@ var flatPhs = []*persistentHNSW[float64]{ }, } -var flatPh = &persistentHNSW[float64]{ - maxLevels: 5, - efConstruction: 16, - efSearch: 12, - pred: "0-a", - vecEntryKey: ConcatStrings("0-a", VecEntry), - vecKey: ConcatStrings("0-a", VecKeyword), - vecDead: ConcatStrings("0-a", VecDead), - floatBits: 64, - simType: GetSimType[float64](Euclidian, 64), - nodeAllEdges: make(map[uint64][][]uint64), -} - var flatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorageTest{ { tc: NewTxnCache(&inMemTxn{startTs: 12, commitTs: 40}, 12), @@ -328,6 +315,7 @@ var flatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorag func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) { emptyTsDbs() + flatPh := flatPhs[0] for _, test := range flatEntryInsertToPersistentFlatStorageTests { emptyTsDbs() key := DataKey(flatPh.pred, test.inUuid) @@ -345,12 +333,13 @@ func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) { } } var float1, float2 = []float64{}, []float64{} - index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])].([]byte), &float1, 64) - index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])].([]byte), &float2, 64) + skey := string(key[:]) + index.BytesAsFloatArray(tsDbs[0].inMemTestDb[skey].([]byte), &float1, 64) + index.BytesAsFloatArray(tsDbs[99].inMemTestDb[skey].([]byte), &float2, 64) if !equalFloat64Slice(float1, 
float2) { t.Errorf("Vector value for predicate %q at beginning and end of database were "+ - "not equivalent. Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred].([]float64), - tsDbs[99].inMemTestDb[flatPh.pred].([]float64)) + "not equivalent. Start Value: %v\n, End Value: %v\n %v\n %v", flatPh.pred, tsDbs[0].inMemTestDb[skey], + tsDbs[99].inMemTestDb[skey], float1, float2) } edgesNameList := []string{} for _, edge := range edges { @@ -405,6 +394,7 @@ var nonflatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatSto func TestNonflatEntryInsertToPersistentFlatStorage(t *testing.T) { emptyTsDbs() + flatPh := flatPhs[0] key := DataKey(flatPh.pred, flatEntryInsert.inUuid) for i := range tsDbs { tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(flatEntryInsert.inVec) @@ -479,7 +469,7 @@ var searchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ query: []float64{0.824, 0.319, 0.111}, maxResults: 1, expectedErr: nil, - expectedNns: []uint64{5}, + expectedNns: []uint64{123}, }, } @@ -510,7 +500,7 @@ var flatPopulateBasicInsertsForSearch = []insertToPersistentFlatStorageTest{ }, } -func flatPopulateInserts(insertArr []insertToPersistentFlatStorageTest) error { +func flatPopulateInserts(insertArr []insertToPersistentFlatStorageTest, flatPh *persistentHNSW[float64]) error { emptyTsDbs() for _, in := range insertArr { for i := range tsDbs { @@ -544,7 +534,7 @@ func RunFlatSearchTests(t *testing.T, test searchPersistentFlatStorageTest, flat func TestBasicSearchPersistentFlatStorage(t *testing.T) { for _, flatPh := range flatPhs { emptyTsDbs() - err := flatPopulateInserts(flatPopulateBasicInsertsForSearch) + err := flatPopulateInserts(flatPopulateBasicInsertsForSearch, flatPh) if err != nil { t.Errorf("Error populating inserts: %s", err) return @@ -554,61 +544,3 @@ func TestBasicSearchPersistentFlatStorage(t *testing.T) { } } } - -var flatPopulateOverlappingInserts = []insertToPersistentFlatStorageTest{ - { - tc: 
NewTxnCache(&inMemTxn{startTs: 0, commitTs: 5}, 0), - inUuid: uint64(5), - inVec: []float64{0.1, 0.1, 0.1}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, - { - tc: NewTxnCache(&inMemTxn{startTs: 3, commitTs: 9}, 3), - inUuid: uint64(123), - inVec: []float64{0.824, 0.319, 0.111}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, - { - tc: NewTxnCache(&inMemTxn{startTs: 8, commitTs: 37}, 8), - inUuid: uint64(1), - inVec: []float64{0.3, 0.5, 0.7}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, -} - -var overlappingSearchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ - { - qc: NewQueryCache(&inMemLocalCache{readTs: 45}, 45), - query: []float64{0.3, 0.5, 0.7}, - maxResults: 1, - expectedErr: nil, - expectedNns: []uint64{123}, - }, - { - qc: NewQueryCache(&inMemLocalCache{readTs: 93}, 93), - query: []float64{0.824, 0.319, 0.111}, - maxResults: 1, - expectedErr: nil, - expectedNns: []uint64{123}, - }, -} - -func TestOverlappingInsertsAndSearchPersistentFlatStorage(t *testing.T) { - for _, flatPh := range flatPhs { - emptyTsDbs() - err := flatPopulateInserts(flatPopulateOverlappingInserts) - if err != nil { - t.Errorf("Error from flatPopulateInserts: %s", err) - return - } - for _, test := range overlappingSearchPersistentFlatStorageTests { - RunFlatSearchTests(t, test, flatPh) - } - } -} diff --git a/tok/hnsw/search_layer.go b/tok/hnsw/search_layer.go index 49f129648bb..d01a864d063 100644 --- a/tok/hnsw/search_layer.go +++ b/tok/hnsw/search_layer.go @@ -11,7 +11,7 @@ type searchLayerResult[T c.Float] struct { // neighbors represents the candidates with the best scores so far. neighbors []minPersistentHeapElement[T] // visited represents elements seen (so we don't try to re-visit). 
- visited []minPersistentHeapElement[T] + visited map[uint64]minPersistentHeapElement[T] path []uint64 metrics map[string]uint64 level int @@ -29,7 +29,7 @@ type searchLayerResult[T c.Float] struct { func newLayerResult[T c.Float](level int) *searchLayerResult[T] { return &searchLayerResult[T]{ neighbors: []minPersistentHeapElement[T]{}, - visited: []minPersistentHeapElement[T]{}, + visited: make(map[uint64]minPersistentHeapElement[T]), path: []uint64{}, metrics: make(map[string]uint64), level: level, @@ -38,7 +38,8 @@ func newLayerResult[T c.Float](level int) *searchLayerResult[T] { func (slr *searchLayerResult[T]) setFirstPathNode(n minPersistentHeapElement[T]) { slr.neighbors = []minPersistentHeapElement[T]{n} - slr.visited = []minPersistentHeapElement[T]{n} + slr.visited = make(map[uint64]minPersistentHeapElement[T]) + slr.visited[n.index] = n slr.path = []uint64{n.index} } @@ -86,17 +87,13 @@ func (slr *searchLayerResult[T]) bestNeighbor() minPersistentHeapElement[T] { return slr.neighbors[0] } -func (slr *searchLayerResult[T]) nodeVisited(n minPersistentHeapElement[T]) bool { - for _, visitedNode := range slr.visited { - if visitedNode.index == n.index { - return true - } - } - return false +func (slr *searchLayerResult[T]) indexVisited(n uint64) bool { + _, ok := slr.visited[n] + return ok } func (slr *searchLayerResult[T]) addToVisited(n minPersistentHeapElement[T]) { - slr.visited = append(slr.visited, n) + slr.visited[n.index] = n } func (slr *searchLayerResult[T]) updateFinalMetrics(r *index.SearchPathResult) { diff --git a/tok/index/helper.go b/tok/index/helper.go index 40274ea7a33..de56ac5023d 100644 --- a/tok/index/helper.go +++ b/tok/index/helper.go @@ -19,8 +19,11 @@ package index import ( "encoding/binary" "math" + "reflect" + "unsafe" c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/golang/glog" ) // BytesAsFloatArray[T c.Float](encoded) converts encoded into a []T, @@ -31,40 +34,27 @@ import ( // The result is appended to the given 
retVal slice. If retVal is nil // then a new slice is created and appended to. func BytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { - // Unfortunately, this is not as simple as casting the result, - // and it is also not possible to directly use the - // golang "unsafe" library to directly do the conversion. - // The machine where this operation gets run might prefer - // BigEndian/LittleEndian, but the machine that sent it may have - // preferred the other, and there is no way to tell! - // - // The solution below, unfortunately, requires another memory - // allocation. - // TODO Potential optimization: If we detect that current machine is - // using LittleEndian format, there might be a way of making this - // work with the golang "unsafe" library. floatBytes := floatBits / 8 - *retVal = (*retVal)[:0] - resultLen := len(encoded) / floatBytes - if resultLen == 0 { + if len(encoded) == 0 { + *retVal = []T{} return } - for i := 0; i < resultLen; i++ { - // Assume LittleEndian for encoding since this is - // the assumption elsewhere when reading from client. - // See dgraph-io/dgo/protos/api.pb.go - // See also dgraph-io/dgraph/types/conversion.go - // This also seems to be the preference from many examples - // I have found via Google search. It's unclear why this - // should be a preference. 
- if retVal == nil { - retVal = &[]T{} - } - *retVal = append(*retVal, BytesToFloat[T](encoded, floatBits)) - encoded = encoded[(floatBytes):] + // Ensure the byte slice length is a multiple of 8 (size of float64) + if len(encoded)%floatBytes != 0 { + glog.Errorf("Invalid byte slice length %d %v", len(encoded), encoded) + return } + + if retVal == nil { + *retVal = make([]T, len(encoded)/floatBytes) + } + *retVal = (*retVal)[:0] + header := (*reflect.SliceHeader)(unsafe.Pointer(retVal)) + header.Data = uintptr(unsafe.Pointer(&encoded[0])) + header.Len = len(encoded) / floatBytes + header.Cap = len(encoded) / floatBytes } func BytesToFloat[T c.Float](encoded []byte, floatBits int) T { diff --git a/tok/index/helper_test.go b/tok/index/helper_test.go new file mode 100644 index 00000000000..1968ae8d4d3 --- /dev/null +++ b/tok/index/helper_test.go @@ -0,0 +1,423 @@ +/* + * Copyright 2016-2024 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package index + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "encoding/gob" + "encoding/json" + "fmt" + "testing" + "unsafe" + + "github.com/dgraph-io/dgraph/protos/pb" + c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/viterin/vek/vek32" +) + +// GenerateMatrix generates a 2D slice of uint64 with varying lengths for each row. 
+func GenerateMatrix(rows int) ([][]uint64, *pb.SortResult) { + pbm := &pb.SortResult{} + matrix := make([][]uint64, rows) + value := uint64(100) + for i := range matrix { + cols := i + 1 // Variable number of columns for each row + matrix[i] = make([]uint64, cols) + for j := range matrix[i] { + matrix[i][j] = value + value++ + } + pbm.UidMatrix = append(pbm.UidMatrix, &pb.List{Uids: matrix[i]}) + } + return matrix, pbm +} + +// Encoding and decoding functions +func encodeUint64Matrix(matrix [][]uint64) ([]byte, error) { + var buf bytes.Buffer + + // Write number of rows + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(matrix))); err != nil { + return nil, err + } + + // Write each row's length and data + for _, row := range matrix { + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(row))); err != nil { + return nil, err + } + for _, value := range row { + if err := binary.Write(&buf, binary.LittleEndian, value); err != nil { + return nil, err + } + } + } + + return buf.Bytes(), nil +} + +func decodeUint64Matrix(data []byte) ([][]uint64, error) { + buf := bytes.NewReader(data) + + var numRows uint64 + if err := binary.Read(buf, binary.LittleEndian, &numRows); err != nil { + return nil, err + } + + matrix := make([][]uint64, numRows) + for i := range matrix { + var numCols uint64 + if err := binary.Read(buf, binary.LittleEndian, &numCols); err != nil { + return nil, err + } + matrix[i] = make([]uint64, numCols) + for j := range matrix[i] { + if err := binary.Read(buf, binary.LittleEndian, &matrix[i][j]); err != nil { + return nil, err + } + } + } + + return matrix, nil +} + +func encodeUint64MatrixWithGob(matrix [][]uint64) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + + if err := enc.Encode(matrix); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func decodeUint64MatrixWithGob(data []byte) ([][]uint64, error) { + var matrix [][]uint64 + buf := bytes.NewReader(data) + dec := 
gob.NewDecoder(buf) + + if err := dec.Decode(&matrix); err != nil { + return nil, err + } + + return matrix, nil +} + +func encodeUint64MatrixWithJSON(matrix [][]uint64) ([]byte, error) { + return json.Marshal(matrix) +} + +func decodeUint64MatrixWithJSON(data []byte) ([][]uint64, error) { + var matrix [][]uint64 + if err := json.Unmarshal(data, &matrix); err != nil { + return nil, err + } + return matrix, nil +} + +func encodeUint64MatrixUnsafe(matrix [][]uint64) []byte { + if len(matrix) == 0 { + return nil + } + + // Calculate the total size + var totalSize uint64 + for _, row := range matrix { + totalSize += uint64(len(row))*uint64(unsafe.Sizeof(uint64(0))) + uint64(unsafe.Sizeof(uint64(0))) + } + totalSize += uint64(unsafe.Sizeof(uint64(0))) + + // Create a byte slice with the appropriate size + data := make([]byte, totalSize) + + offset := 0 + // Write number of rows + rows := uint64(len(matrix)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rows))[:]) + offset += 8 + + // Write each row's length and data + for _, row := range matrix { + rowLen := uint64(len(row)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rowLen))[:]) + offset += 8 + for i := range row { + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&row[i]))[:]) + offset += 8 + } + } + + return data +} + +func decodeUint64MatrixUnsafe(data []byte) ([][]uint64, error) { + if len(data) == 0 { + return nil, nil + } + + offset := 0 + // Read number of rows + rows := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + matrix := make([][]uint64, rows) + + for i := 0; i < int(rows); i++ { + // Read row length + rowLen := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + matrix[i] = make([]uint64, rowLen) + for j := 0; j < int(rowLen); j++ { + matrix[i][j] = *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + } + } + + return matrix, nil +} + +func encodeUint64MatrixWithProtobuf(protoMatrix *pb.SortResult) ([]byte, error) { + // Convert the matrix to the 
protobuf structure + return protoMatrix.Marshal() +} + +func decodeUint64MatrixWithProtobuf(data []byte, protoMatrix *pb.SortResult) error { + // Unmarshal the protobuf data into the protobuf structure + return protoMatrix.Unmarshal(data) +} + +// Combined benchmark function +func BenchmarkEncodeDecodeUint64Matrix(b *testing.B) { + matrix, pbm := GenerateMatrix(10) + + b.Run("Binary Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64Matrix(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64Matrix(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("Gob Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithGob(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64MatrixWithGob(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("JSON Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithJSON(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64MatrixWithJSON(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("PB Encoding/Decoding", func(b *testing.B) { + var pba pb.SortResult + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithProtobuf(pbm) + if err != nil { + b.Error(err) + } + + err = decodeUint64MatrixWithProtobuf(data, &pba) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("Unsafe Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data := encodeUint64MatrixUnsafe(matrix) + _, err := decodeUint64MatrixUnsafe(data) + if err != nil { + b.Error(err) + } + } + }) +} + +func dotProductT[T c.Float](a, b []T, floatBits int) { + var dotProduct T + if len(a) != len(b) { + return + } + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + } +} + +func dotProduct(a, b []float32) { + if len(a) != len(b) { + return + } + sum := int8(0) + for i := 0; i < len(a); i += 2 { + sum += 
*(*int8)(unsafe.Pointer(&a[i]))**(*int8)(unsafe.Pointer(&b[i])) + + *(*int8)(unsafe.Pointer(&a[i+1]))**(*int8)(unsafe.Pointer(&b[i+1])) + } +} + +func BenchmarkDotProduct(b *testing.B) { + num := 1500 + data := make([]byte, 64*num) + _, err := rand.Read(data) + if err != nil { + b.Skip() + } + + b.Run(fmt.Sprintf("vek:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + vek32.Dot(temp, temp) + } + }) + + b.Run(fmt.Sprintf("dotProduct:size=%d", len(data)), + func(b *testing.B) { + + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + dotProduct(temp, temp) + } + + }) + + b.Run(fmt.Sprintf("dotProductT:size=%d", len(data)), + func(b *testing.B) { + + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + dotProductT[float32](temp, temp, 32) + } + }) +} + +func pointerFloatConversion[T c.Float](encoded []byte, retVal *[]T, floatBits int) { + floatBytes := floatBits / 8 + + // Ensure the byte slice length is a multiple of 8 (size of float32) + if len(encoded)%floatBytes != 0 { + fmt.Println("Invalid byte slice length") + return + } + + // Create a slice header + *retVal = *(*[]T)(unsafe.Pointer(&encoded)) +} + +func littleEndianBytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { + // Unfortunately, this is not as simple as casting the result, + // and it is also not possible to directly use the + // golang "unsafe" library to directly do the conversion. + // The machine where this operation gets run might prefer + // BigEndian/LittleEndian, but the machine that sent it may have + // preferred the other, and there is no way to tell! + // + // The solution below, unfortunately, requires another memory + // allocation. 
+ // TODO Potential optimization: If we detect that current machine is + // using LittleEndian format, there might be a way of making this + // work with the golang "unsafe" library. + floatBytes := floatBits / 8 + + // Ensure the byte slice length is a multiple of 8 (size of float32) + if len(encoded)%floatBytes != 0 { + fmt.Println("Invalid byte slice length") + return + } + + *retVal = (*retVal)[:0] + resultLen := len(encoded) / floatBytes + if resultLen == 0 { + return + } + for i := 0; i < resultLen; i++ { + // Assume LittleEndian for encoding since this is + // the assumption elsewhere when reading from client. + // See dgraph-io/dgo/protos/api.pb.go + // See also dgraph-io/dgraph/types/conversion.go + // This also seems to be the preference from many examples + // I have found via Google search. It's unclear why this + // should be a preference. + if retVal == nil { + retVal = &[]T{} + } + *retVal = append(*retVal, BytesToFloat[T](encoded, floatBits)) + + encoded = encoded[(floatBytes):] + } +} + +func BenchmarkFloatConverstion(b *testing.B) { + num := 1500 + data := make([]byte, 64*num) + _, err := rand.Read(data) + if err != nil { + b.Skip() + } + + b.Run(fmt.Sprintf("Current:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + BytesAsFloatArray[float32](data, &temp, 64) + } + }) + + b.Run(fmt.Sprintf("pointerFloat:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + pointerFloatConversion[float32](data, &temp, 64) + } + }) + + b.Run(fmt.Sprintf("littleEndianFloat:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + littleEndianBytesAsFloatArray[float32](data, &temp, 64) + } + }) +} diff --git a/worker/task.go b/worker/task.go index 9b949ac1a68..a8969d02ec0 100644 --- a/worker/task.go +++ b/worker/task.go @@ -312,7 +312,7 @@ type funcArgs struct { // The function tells us whether we want to fetch 
value posting lists or uid posting lists. func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn: + case aggregatorFn, passwordFn, similarToFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -325,7 +325,7 @@ func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) case uidInFn, compareScalarFn: // Operate on uid postings return false, nil - case notAFunction, similarToFn: + case notAFunction: return typ.IsScalar(), nil } return false, errors.Errorf("Unhandled case in fetchValuePostings for fn: %s", srcFn.fname)