From c430009581b73495e847bb824a258994363804be Mon Sep 17 00:00:00 2001 From: shivaji-dgraph Date: Wed, 15 May 2024 14:52:17 +0530 Subject: [PATCH] add test for dotproduct and cosine index and fix failing tests --- graphql/resolve/query_rewriter.go | 3 +- graphql/resolve/query_test.yaml | 12 +- query/vector/vector_graphql_test.go | 168 +++++++++++++++++++--------- 3 files changed, 124 insertions(+), 59 deletions(-) diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 4755a76496b..509e593a772 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -824,7 +824,8 @@ func rewriteAsSimilarByEmbeddingQuery( if metric == schema.SimilarSearchMetricDotProduct { distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)" } else if metric == schema.SimilarSearchMetricCosine { - distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)" + distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" + + " * (v2 dot v2) ) )) / 2.0)" } // Save vectorString as a query variable, $search_vector diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index c74b9d77471..ab15599d020 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -3367,7 +3367,7 @@ query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) { v2 as Product.productVector - distance as math((v2 - $search_vector) dot (v2 - $search_vector)) + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) } querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { Product.id : Product.id @@ -3397,7 +3397,7 @@ } var(func: similar_to(Product.productVector, 3, val(v1))) { v2 as Product.productVector - distance as math((v2 - v1) dot (v2 - v1)) + distance as math(sqrt((v2 - v1) dot (v2 - v1))) } querySimilarProductById(func: uid(distance), orderasc: val(distance)) { Product.id : Product.id @@ -3428,7 +3428,7 @@ } var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) { v2 as ProjectCosine.description_v - distance as math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2))) + distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0) } querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3453,7 +3453,7 @@ query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) { v2 as ProjectCosine.description_v - distance as math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2))) + distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0) } querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3483,7 +3483,7 @@ } var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) { v2 as ProjectDotProduct.description_v - distance as math(v1 dot v2) + distance as math((1.0 - (v1 dot v2)) /2.0) } querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id @@ -3508,7 +3508,7 @@ query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) { v2 as ProjectDotProduct.description_v - distance as math($search_vector dot v2) + distance as math(( 1.0 - (($search_vector) dot v2)) /2.0) } querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id diff --git a/query/vector/vector_graphql_test.go b/query/vector/vector_graphql_test.go index 058e89dca04..70371265e99 100644 --- a/query/vector/vector_graphql_test.go +++ b/query/vector/vector_graphql_test.go @@ -20,6 +20,8 @@ package query import ( "encoding/json" + "fmt" + "math/rand" "testing" "github.com/dgraph-io/dgraph/dgraphtest" @@ -36,29 +38,56 @@ const ( type Project { id: ID! title: String! @search(by: [exact]) - title_v: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"]) - } - ` + title_v: [Float!] @embedding @search(by: ["hnsw(metric: %v, exponent: 4)"]) + } ` ) -var ( - projects = []ProjectInput{ProjectInput{ - Title: "iCreate with a Mini iPad", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Resistive Touchscreen", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Fitness Band", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Smart Watch", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Smart Ring", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }} -) +func generateProjects(count int) []ProjectInput { + var projects []ProjectInput + for i := 0; i < count; i++ { + title := generateUniqueRandomTitle(projects) + titleV := generateRandomTitleV(5) // Assuming size is fixed at 5 + project := ProjectInput{ + Title: title, + TitleV: titleV, + } + projects = append(projects, project) + } + return projects +} + +func isTitleExists(title string, existingTitles []ProjectInput) bool { + for _, project := range existingTitles { + if project.Title == title { + return true + } + } + return false +} + +func generateUniqueRandomTitle(existingTitles []ProjectInput) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + const titleLength = 10 + title := make([]byte, titleLength) + for { + for i := range title { + title[i] = charset[rand.Intn(len(charset))] + } + titleStr := string(title) + if !isTitleExists(titleStr, existingTitles) { + return titleStr + } + } +} + +func generateRandomTitleV(size int) []float32 { + var titleV []float32 + for i := 0; i < size; i++ { + value := rand.Float32() + titleV = append(titleV, value) + } + return titleV +} func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) { query := ` @@ -79,6 +108,7 @@ func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) { _, err := hc.RunGraphqlQuery(params, false) require.NoError(t, err) } + func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title string) ProjectInput { query := ` query QueryProject($title: String!) { queryProject(filter: { title: { eq: $title } }) { @@ -96,7 +126,6 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin type QueryResult struct { QueryProject []ProjectInput `json:"queryProject"` } - var resp QueryResult err = json.Unmarshal([]byte(string(response)), &resp) require.NoError(t, err) @@ -104,11 +133,10 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin return resp.QueryProject[0] } -func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32) []ProjectInput { +func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32, topk int) []ProjectInput { // query similar project by embedding queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) { querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) { - id title title_v } @@ -120,13 +148,13 @@ func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, ve Query: queryProduct, Variables: map[string]interface{}{ "by": "title_v", - "topK": 3, + "topK": topk, "vector": vector, }} response, err := hc.RunGraphqlQuery(params, false) require.NoError(t, err) type QueryResult struct { - QueryProject []ProjectInput `json:"queryProject"` + QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"` } var resp QueryResult err = json.Unmarshal([]byte(string(response)), &resp) @@ -143,21 +171,75 @@ func TestVectorGraphQLAddVectorPredicate(t *testing.T) { require.NoError(t, err) hc.LoginIntoNamespace("groot", "password", 0) // add schema - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) + require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) } -func TestVectorGraphQlMutationAndQuery(t *testing.T) { +func TestVectorSchema(t *testing.T) { require.NoError(t, client.DropAll()) hc, err := dc.HTTPClient() require.NoError(t, err) hc.LoginIntoNamespace("groot", "password", 0) + schema := `type Project { + id: ID! + title: String! @search(by: [exact]) + title_v: [Float!] + }` + + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) + require.NoError(t, client.DropAll()) + require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "dotproduct"))) + require.NoError(t, client.DropAll()) + require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "cosine"))) +} + +func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "euclidean") // add schema - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) + +} - // add project +func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "cosine") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) + +} + +func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) + +} + +func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphtest.HTTPClient) { var vectors [][]float32 + numProjects := 100 + projects := generateProjects(numProjects) + fmt.Println("projects", len(projects)) for _, project := range projects { vectors = append(vectors, project.TitleV) addProject(t, hc, project) @@ -165,42 +247,24 @@ func TestVectorGraphQlMutationAndQuery(t *testing.T) { for _, project := range projects { p := queryProjectUsingTitle(t, hc, project.Title) + fmt.Println("p", p) require.Equal(t, project.Title, p.Title) require.Equal(t, project.TitleV, p.TitleV) } for _, project := range projects { p := queryProjectUsingTitle(t, hc, project.Title) + fmt.Println("p1", p) + require.Equal(t, project.Title, p.Title) require.Equal(t, project.TitleV, p.TitleV) } // query similar project by embedding for _, project := range projects { - similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV) - + similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects) for _, similarVec := range similarProjects { require.Contains(t, vectors, similarVec.TitleV) } } } - -func TestVectorSchema(t *testing.T) { - require.NoError(t, client.DropAll()) - - hc, err := dc.HTTPClient() - require.NoError(t, err) - hc.LoginIntoNamespace("groot", "password", 0) - - schema := `type Project { - id: ID! - title: String! @search(by: [exact]) - title_v: [Float!] - }` - - // add schema - require.NoError(t, hc.UpdateGQLSchema(schema)) - require.Error(t, hc.UpdateGQLSchema(graphQLVectorSchema)) - require.NoError(t, client.DropAll()) - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) -}