diff --git a/.github/workflows/cd-dgraph.yml b/.github/workflows/cd-dgraph.yml index 6f687373181..898dbcafcbd 100644 --- a/.github/workflows/cd-dgraph.yml +++ b/.github/workflows/cd-dgraph.yml @@ -10,6 +10,10 @@ on: description: releasetag required: true type: string + custom-build: + type: boolean + default: false + description: if checked, images will be pushed to dgraph-custom repo in Dockerhub jobs: dgraph-build-amd64: runs-on: warp-ubuntu-latest-x64-16x @@ -85,7 +89,10 @@ jobs: run: | make docker-image DGRAPH_VERSION=${{ env.DGRAPH_RELEASE_VERSION }}-amd64 [[ "${{ inputs.latest }}" = true ]] && docker tag dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 dgraph/dgraph:latest-amd64 || true + [[ "${{ inputs.custom-build }}" = true ]] && docker tag dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 || true - name: Make Dgraph Standalone Docker Image with Version + #No need to build and push Standalone Image when its a custom build + if: inputs.custom-build == false run: | make docker-image-standalone DGRAPH_VERSION=${{ env.DGRAPH_RELEASE_VERSION }}-amd64 [[ "${{ inputs.latest }}" = true ]] && docker tag dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 dgraph/standalone:latest-amd64 || true @@ -96,8 +103,12 @@ jobs: password: ${{ secrets.DOCKERHUB_PASSWORD_TOKEN }} - name: Push Images to DockerHub run: | - docker push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 - docker push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 + if [ "${{ inputs.custom-build }}" == "true" ]; then + docker push dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 + else + docker push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 + docker push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 + fi dgraph-build-arm64: runs-on: warp-ubuntu-latest-arm64-16x @@ -173,7 +184,10 @@ jobs: run: | make docker-image DGRAPH_VERSION=${{ env.DGRAPH_RELEASE_VERSION }}-arm64 [[ "${{ 
inputs.latest }}" = true ]] && docker tag dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 dgraph/dgraph:latest-arm64 || true + [[ "${{ inputs.custom-build }}" = true ]] && docker tag dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 || true - name: Make Dgraph Standalone Docker Image with Version + #No need to build and push Standalone Image when its a custom build + if: inputs.custom-build == false run: | make docker-image-standalone DGRAPH_VERSION=${{ env.DGRAPH_RELEASE_VERSION }}-arm64 [[ "${{ inputs.latest }}" = true ]] && docker tag dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 dgraph/standalone:latest-arm64 || true @@ -184,8 +198,12 @@ jobs: password: ${{ secrets.DOCKERHUB_PASSWORD_TOKEN }} - name: Push Images to DockerHub run: | - docker push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 - docker push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + if [ "${{ inputs.custom-build }}" == "true" ]; then + docker push dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + else + docker push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + fi dgraph-docker-manifest: needs: [dgraph-build-amd64, dgraph-build-arm64] @@ -215,13 +233,23 @@ jobs: password: ${{ secrets.DOCKERHUB_PASSWORD_TOKEN }} - name: Docker Manifest run: | - # standalone - docker manifest create dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }} --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 - docker manifest push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }} - [[ "${{ inputs.latest }}" = true ]] && docker manifest create dgraph/standalone:latest --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 - [[ "${{ inputs.latest }}" = true ]] && docker 
manifest push dgraph/standalone:latest || true - # dgraph - docker manifest create dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }} --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 - docker manifest push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }} - [[ "${{ inputs.latest }}" = true ]] && docker manifest create dgraph/dgraph:latest --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 || true - [[ "${{ inputs.latest }}" = true ]] && docker manifest push dgraph/dgraph:latest || true + if [ "${{ github.event.inputs.custom-build }}" == "true" ]; then + docker manifest create dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }} --amend dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker manifest push dgraph/dgraph-custom:${{ env.DGRAPH_RELEASE_VERSION }} + else + # standalone + docker manifest create dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }} --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker manifest push dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }} + if [ "${{ github.event.inputs.latest }}" == "true" ]; then + docker manifest create dgraph/standalone:latest --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/standalone:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker manifest push dgraph/standalone:latest + fi + # dgraph + docker manifest create dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }} --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker manifest push dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }} + if [ "${{ github.event.inputs.latest }}" == "true" ]; then + docker manifest create dgraph/dgraph:latest --amend 
dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-amd64 --amend dgraph/dgraph:${{ env.DGRAPH_RELEASE_VERSION }}-arm64 + docker manifest push dgraph/dgraph:latest + fi + fi + diff --git a/.github/workflows/ci-dgraph-upgrade-fixed-versions-tests.yml b/.github/workflows/ci-dgraph-upgrade-fixed-versions-tests.yml index a10746c77e2..df41840b36d 100644 --- a/.github/workflows/ci-dgraph-upgrade-fixed-versions-tests.yml +++ b/.github/workflows/ci-dgraph-upgrade-fixed-versions-tests.yml @@ -1,7 +1,7 @@ name: ci-dgraph-upgrade-fixed-versions-tests on: schedule: - - cron: "0 3 * * *" # 1 run per day + - cron: "00 20 * * *" # 1 run per day jobs: dgraph-upgrade-fixed-versions-tests: runs-on: warp-ubuntu-latest-x64-16x diff --git a/.github/workflows/ci-dgraph-upgrade-tests.yml b/.github/workflows/ci-dgraph-upgrade-tests.yml index 7b6612dfc9a..c8056853904 100644 --- a/.github/workflows/ci-dgraph-upgrade-tests.yml +++ b/.github/workflows/ci-dgraph-upgrade-tests.yml @@ -9,6 +9,8 @@ on: branches: - main - 'release/**' + schedule: + - cron: "00 20 * * *" # 1 run per day jobs: dgraph-upgrade-tests: if: github.event.pull_request.draft == false diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 00000000000..0f4298f94f2 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,53 @@ +name: "CodeQL" + +on: + push: + branches: + - main + - release/** + pull_request: + branches: + - main + - release/** + schedule: + - cron: '0 0 * * *' + +jobs: + analyze: + name: Analyze (${{ matrix.language }}) + runs-on: warp-ubuntu-latest-x64-16x + timeout-minutes: 360 + permissions: + security-events: write + packages: read + + strategy: + fail-fast: false + matrix: + include: + - language: go + build-mode: autobuild + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + build-mode: ${{ matrix.build-mode }} + + - if: 
matrix.build-mode == 'manual' + run: | + echo 'If you are using a "manual" build mode for one or more of the' \ + 'languages you are analyzing, replace this with the commands to build' \ + 'your code, for example:' + echo ' make bootstrap' + echo ' make release' + exit 1 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000000..352f6023cce --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,18 @@ +name: 'Close stale issues and PRs' +on: + schedule: + - cron: '00 02,14 * * *' + +permissions: + issues: write + pull-requests: write + +jobs: + stale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v9 + with: + stale-issue-message: 'This issue has been stale for 60 days and will be closed automatically in 7 days. Comment to keep it open.' + stale-pr-message: 'This PR has been stale for 60 days and will be closed automatically in 7 days. Comment to keep it open.' + operations-per-run: 250 diff --git a/CHANGELOG.md b/CHANGELOG.md index bd7c35b4bb0..386561977e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,110 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project will adhere to [Semantic Versioning](https://semver.org) starting `v22.0.0`. +## [v24.0.1] - 2024-07-30 +[v24.0.1]: https://github.com/dgraph-io/dgraph/compare/v24.0.0...v24.0.1 + +> **Warning** +> After upgrading to v24.0.1, vector index needs to be rebuilt as underlying data has changed. 
+ +- **Fixed** + + - fix(core): Fix regression in parsing json empty string #9108 + - fix(upgrade): fix failing upgrade tests #9042 + - fix(ci): fixing health endpoint issue #9116 + - Fix(graphql): issue with local variable squashing intended JWK index by @matthewmcneely in #9114 + +- **Chore** + - chore(deps): bump urllib3 from 1.26.18 to 1.26.19 /contrib/config/marketplace/aws/tests #9103 + - chore(deps): bump requests from 2.31.0 to 2.32.0 /contrib/config/marketplace/aws/tests #9090 + + +- **Perf** + - perf(vector): updated marshalling of vector #9109 + + +## [v24.0.0] - 2024-06-06 +[v24.0.0]: https://github.com/dgraph-io/dgraph/compare/v24.0.0...v23.1.0 + +> **Warning** +> This will be a breaking change for anyone moving from to `v.24.0.0`. +> If you have any duplicate users or groups in Dgraph ACL, they would not be accessible from now on. Please delete any +> duplicate users and groups before you upgrade. File format is the same, so binary can be directly replaced after. +> deleting duplicate users and groups. 
+ +### Added + +- **Vector** + - feat(graphql): Add vector support to graphql (#9074) + - feat(vector): add vector to schema in #9060 + - feat(vector): Added similar_to in vector in #9062 + - feat(vector): Add vector type to Dgraph in #9050 + - feat(vector): fix live loader and add tests for dropall, drop namespace, live load in #9063 + - fix(vector): show error is invalid input is provided to vector predicate in #9064 + - fix(vector):fix similar_to() error return when data is not present in #9084 + - fix(vector): Update query_rewriter to fix dotproduct and cosine query conversion in #9083 + +- **Core** + - feat(core): Add cache to dgraph.type predicate in #9068 + - [BREAKING]feat(core): add unique constraint support in schema for new predicates in #8827 + - feat(debug): add parse_key to debug tool in #7640 + - feat(acl): support more JWT algorithms for ACL by in #8912 + - feat(restore): add support for namespace aware restore by in #8968 + +- **GraphQL** + - feat(vector): Added lang support by in #8924 + - feat(graphql): allow updatable and nullable id fields. 
(#7736) in #9020 +### Fixed + +- **Core** + - Fix(debug): Close file correctly before exiting on error in #9076 + - fix(restore): fix incr restore and normal restore for vector predicates in #9078 + - Fix(core): Fix deadlock in runMutation and error handling in #9085 + - fix(core): Fixed deadlock that happens due to timeout in #9007 + - fix(core): limit #edges using LimitMutationsNquad config and add logs in #9010 + - fix(core): Update math parsing function by in #9053 + - fix(restore): use different map directory for each group (#8047) in #8972 + - fix(export): Support for any s3 endpoint by in #8978 + - fix(restore): use custom type for sensitive fields by in #8969 + - fix(export): Escape MySQL column names in #8961 + - fix(debug): fix debug tool for schema keys in #7939 + - fix(restore): allow incrementalFrom to be 1 in restore API by in #8988 + - fix(raft):alpha leader fails to stream snapshot to new alpha nodes in #9022 + - fix(query): fix has function in filter in #9043 + - fix(core):Reduce x.ParsedKey memory allocation from 72 to 56 bytes by optimizing struct memory alignment in #9047 + - fix(restore): do not retry restore proposal (#8058) in #9017 + +- **Perf** + - perf(core): Fix performance issue in type filter (#9065) in #9089 + - perf(core): Update postinglistCountAndLength function to improve performance in #9088 + - perf(query): use quickselect instead of sorting while pagination by in #8995 + - perf(query): Update CompressedBin IntersectionAlgo by in #9000 + +- **Chore** + - chore(upgrade): run tests from v23.1.0 -> main in #9097 + - chore(deps): upgrade etcd/raft to v3 by in #7688 + - chore(restore): add more logs for restore request (#8050) in #8975 + - upgrade(go): update go version to 1.22 in #9058 + - chore(deps): bump github.com/apache/thrift from 0.12.0 to 0.13.0 by in #8982 + - chore(deps): bump golang.org/x/net from 0.14.0 to 0.17.0 in #9015 + - chore(deps): use bleve 2.3.10 for more languages in full text search in #9030 + - chore(deps): 
bump golang.org/x/crypto from 0.12.0 to 0.17.0 in #9032 + - chore(deps): bump urllib3 from 1.26.5 to 1.26.18 in /contrib/config/marketplace/aws/tests in #9018 + - chore(deps): bump google.golang.org/grpc from 1.56.2 to 1.56.3 in #9024 + - chore(deps): bump google.golang.org/protobuf from 1.31.0 to 1.33.0in #9051[ + +## [23.1.1] - 2024-04-26 +[v23.1.1]: https://github.com/dgraph-io/dgraph/compare/v23.1.0...v23.1.1 + +### Fixed + +- **Core Dgraph** + - perf(core): Fix performance issue in type filter (#9065) + + +- **CI & Testing** + - ci/cd optimizations (#9069) + ## [v23.1.0] - 2023-08-17 [v23.1.0]: https://github.com/dgraph-io/dgraph/compare/v23.0.1...v23.1.0 diff --git a/check_upgrade/check_upgrade.go b/check_upgrade/check_upgrade.go new file mode 100644 index 00000000000..4c4ec02f52a --- /dev/null +++ b/check_upgrade/check_upgrade.go @@ -0,0 +1,591 @@ +/* + * Copyright 2024 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package checkupgrade + +import ( + "crypto" + "crypto/ed25519" + "encoding/json" + "fmt" + "log" + "os" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/x" +) + +var ( + CheckUpgrade x.SubCommand +) + +const ( + alphaHttp = "http_port" + dgUser = "dgUser" + password = "password" + namespace = "namespace" + aclSecretKeyFilePath = "aclSecretKeyFilePath" + jwtAlg = "jwt-alg" + deleteDup = "delete-duplicates" + guardianGroup = "guardians" +) + +type commandInput struct { + alphaHttp string + dgUser string + password string + namespace uint64 + aclSecretKeyFilePath string + jwtAlg string + dupDelete bool +} + +type aclNode struct { + UID string `json:"uid"` + DgraphXID string `json:"dgraph.xid"` + DgraphType []string `json:"dgraph.type"` + DgraphUserGroup []UserGroup `json:"dgraph.user.group"` +} + +type UserGroup struct { + UID string `json:"uid"` + DgraphXid string `json:"dgraph.xid"` +} + +type Nodes struct { + Nodes []aclNode `json:"nodes"` +} +type Response struct { + Data Nodes `json:"data"` +} + +func parseJWTKey(alg jwt.SigningMethod, key x.Sensitive) (interface{}, interface{}, error) { + switch { + case strings.HasPrefix(alg.Alg(), "HS"): + return key, key, nil + + case strings.HasPrefix(alg.Alg(), "ES"): + pk, err := jwt.ParseECPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as ECDSA private key") + } + return pk, &pk.PublicKey, nil + + case strings.HasPrefix(alg.Alg(), "RS") || strings.HasPrefix(alg.Alg(), "PS"): + pk, err := jwt.ParseRSAPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as RSA private key") + } + return pk, &pk.PublicKey, nil + + case alg.Alg() == "EdDSA": + pk, err := jwt.ParseEdPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as EdDSA private 
key") + } + return pk.(crypto.Signer), pk.(ed25519.PrivateKey).Public(), nil + + default: + return nil, nil, errors.Errorf("unsupported signing algorithm: %v", alg.Alg()) + } +} + +func getAccessJwt(userId string, group string, namespace uint64, aclSecretFile string, + algStr string) (string, error) { + aclKey, err := os.ReadFile(aclSecretFile) + if err != nil { + return "", fmt.Errorf("error reading ACL secret key from file: %s: %s", aclSecretFile, err) + } + + var aclAlg jwt.SigningMethod + var privKey interface{} + if aclKey != nil { + aclAlg = jwt.GetSigningMethod(algStr) + if aclAlg == nil { + return "", fmt.Errorf("unsupported jwt signing algorithm for acl: %v", algStr) + } + privKey, _, err = parseJWTKey(aclAlg, aclKey) + if err != nil { + return "", err + } + } + + token := jwt.NewWithClaims(aclAlg, jwt.MapClaims{ + "userid": userId, + "groups": []string{group}, + "namespace": namespace, + "exp": time.Now().Add(time.Hour).Unix(), + }) + + jwtString, err := token.SignedString(x.MaybeKeyToBytes(privKey)) + if err != nil { + return "", errors.Errorf("unable to encode jwt to string: %v", err) + } + return jwtString, nil +} + +func setupClient(alphaHttp string) (*dgraphapi.HTTPClient, error) { + httpClient, err := dgraphapi.GetHttpClient(alphaHttp, "") + if err != nil { + return nil, errors.Wrapf(err, "while getting HTTP client") + } + return httpClient, nil +} + +func contains(slice []string, value string) bool { + for _, v := range slice { + if v == value { + return true + } + } + return false +} + +func findDuplicateNodes(aclNodes []aclNode) [3]map[string][]string { + du := make(map[string][]string) + dg := make(map[string][]string) + dug := make(map[string][]string) + + for i, node1 := range aclNodes { + for j := i + 1; j < len(aclNodes); j++ { + node2 := aclNodes[j] + if node1.DgraphXID == node2.DgraphXID { + if node1.DgraphType[0] == "dgraph.type.User" && node1.DgraphType[0] == node2.DgraphType[0] { + filterAndRecordDuplicates(du, node1, node2) + } else if 
node1.DgraphType[0] == "dgraph.type.Group" && node1.DgraphType[0] == node2.DgraphType[0] { + filterAndRecordDuplicates(dg, node1, node2) + } else { + filterAndRecordDuplicates(dug, node1, node2) + } + } + } + } + + return [3]map[string][]string{ + du, dg, dug, + } +} + +func filterAndRecordDuplicates(du map[string][]string, node1 aclNode, node2 aclNode) { + if _, exists := du[node1.DgraphXID]; !exists { + du[node1.DgraphXID] = []string{} + } + if !contains(du[node1.DgraphXID], node1.UID) { + du[node1.DgraphXID] = append(du[node1.DgraphXID], node1.UID) + } + if !contains(du[node1.DgraphXID], node2.UID) { + du[node1.DgraphXID] = append(du[node1.DgraphXID], node2.UID) + } +} + +func queryDuplicateNodes(hc *dgraphapi.HTTPClient) ([3]map[string][]string, error) { + query := `{ + nodes(func: has(dgraph.xid)) { + uid + dgraph.xid + dgraph.type + } + }` + + resp, err := hc.PostDqlQuery(query) + if err != nil { + return [3]map[string][]string{}, errors.Wrapf(err, "while querying dgraph for duplicate nodes") + } + + var result Response + if err := json.Unmarshal(resp, &result); err != nil { + return [3]map[string][]string{}, errors.Wrapf(err, "while unmarshalling response: %v", string(resp)) + + } + return findDuplicateNodes(result.Data.Nodes), nil +} + +func printAndDeleteDuplicates(hc *dgraphapi.HTTPClient, entityType string, ns uint64, nodesmap map[string][]string, + dupDelete bool) error { + if len(nodesmap) == 0 { + return nil + } + + fmt.Printf("Found duplicate %ss in namespace: #%v\n", entityType, ns) + for key, node := range nodesmap { + fmt.Printf("dgraph.xid %v , Uids: %v\n", key, node) + } + + if dupDelete { + switch entityType { + case "user": + return deleteDuplicatesUser(hc, nodesmap) + + case "group": + return deleteDuplicatesGroup(hc, nodesmap) + default: + return deleteDuplicatesUserGroup(hc, nodesmap) + + } + } + return nil +} +func deleteUids(hc *dgraphapi.HTTPClient, uids []string, skipUid int, node string) error { + query := `{ + delete { + <%v> * * . 
+ } + }` + + for i, uid := range uids { + if i == skipUid { + continue + } + + fmt.Printf("deleting following uid [%v] of duplicate node:%v\n", uid, node) + _, err := hc.Mutate(fmt.Sprintf(query, uid), true) + if err != nil { + return err + } + } + + return nil +} + +// uniqueStringsExcluding extracts unique strings from nodeCollection excluding those in the exclude slice +func uniqueStringsExcluding(nodeCollection [][]string, exclude []string) []string { + excludeMap := make(map[string]struct{}) + for _, e := range exclude { + excludeMap[e] = struct{}{} + } + + uniqueMap := make(map[string]struct{}) + for _, nodes := range nodeCollection { + for _, node := range nodes { + if _, inExclude := excludeMap[node]; !inExclude { + uniqueMap[node] = struct{}{} + } + } + } + + uniqueSlice := make([]string, 0, len(uniqueMap)) + for node := range uniqueMap { + uniqueSlice = append(uniqueSlice, node) + } + + return uniqueSlice +} + +func queryUserGroup(hc *dgraphapi.HTTPClient, uid string) (aclNode, error) { + query := fmt.Sprintf(`{ + nodes(func: eq("dgraph.xid", "%v")) { + uid + dgraph.xid + } + }`, uid) + resp, err := hc.PostDqlQuery(query) + if err != nil { + return aclNode{}, err + } + var result Response + if err := json.Unmarshal(resp, &result); err != nil { + return aclNode{}, errors.Wrapf(err, "while unmarshalling response: %v", string(resp)) + } + + if len(result.Data.Nodes) > 1 { + return aclNode{}, nil + } + + return result.Data.Nodes[0], nil +} + +func addUsersToGroup(hc *dgraphapi.HTTPClient, users []string, groupUid string) error { + rdf := `` + for _, user := range users { + fmt.Printf("adding user %v to group %v\n", user, groupUid) + node, err := queryUserGroup(hc, user) + if err != nil { + return err + } + if node.UID != "" { + rdf += fmt.Sprintf("<%v> <%v> .\n", groupUid, node.UID) + + } + } + + _, err := hc.Mutate(rdf, true) + if err != nil { + return err + } + return nil +} + +func deleteDuplicatesGroup(hc *dgraphapi.HTTPClient, duplicates 
map[string][]string) error { + query := `{ + nodes(func: uid(%v)) { + uid + dgraph.xid + dgraph.type + ~dgraph.user.group{ + dgraph.xid + } + } + }` + + for group, uids := range duplicates { + var nodeCollection [][]string + + for _, uid := range uids { + resp, err := hc.PostDqlQuery(fmt.Sprintf(query, uid)) + if err != nil { + return err + } + var result Response + if err := json.Unmarshal(resp, &result); err != nil { + log.Fatalf("while unmarshalling response: %v", err) + } + var strs []string + for i := range result.Data.Nodes[0].DgraphUserGroup { + strs = append(strs, result.Data.Nodes[0].DgraphUserGroup[i].DgraphXid) + } + nodeCollection = append(nodeCollection, strs) + } + var saveIndex int + prevLen := 0 + + fmt.Printf("keeping group%v with uid: %v", group, uids[saveIndex]) + if group == guardianGroup { + for k, nodes := range nodeCollection { + if contains(nodes, "groot") && len(nodes) > prevLen { + saveIndex = k + prevLen = len(nodes) + } + } + uniqueUsers := uniqueStringsExcluding(nodeCollection, uids) + if err := addUsersToGroup(hc, uniqueUsers, uids[saveIndex]); err != nil { + return err + } + if err := deleteUids(hc, uids, saveIndex, group); err != nil { + return err + } + + } else { + if err := deleteUids(hc, uids, 0, group); err != nil { + return err + } + } + + } + return nil +} + +func deleteDuplicatesUser(hc *dgraphapi.HTTPClient, duplicates map[string][]string) error { + query := `{ + nodes(func: uid(%v)) { + uid + dgraph.xid + dgraph.type + dgraph.user.group{ + dgraph.xid + } + } + }` + for user, uids := range duplicates { + var groupsCollection [][]string + for _, uid := range uids { + resp, err := hc.PostDqlQuery(fmt.Sprintf(query, uid)) + if err != nil { + return err + } + var result Response + if err := json.Unmarshal(resp, &result); err != nil { + log.Fatalf("while unmarshalling response: %v", err) + } + var strs []string + for i := range result.Data.Nodes[0].DgraphUserGroup { + strs = append(strs, 
result.Data.Nodes[0].DgraphUserGroup[i].DgraphXid) + } + groupsCollection = append(groupsCollection, strs) + } + var saveIndex int + prevLen := 0 + for k, groups := range groupsCollection { + if contains(groups, "guardians") && len(groups) > prevLen { + saveIndex = k + prevLen = len(groups) + } + } + + fmt.Printf("keeping user%v with uid: %v", user, uids[saveIndex]) + + if err := deleteUids(hc, uids, saveIndex, user); err != nil { + return err + } + + } + return nil +} + +func deleteDuplicatesUserGroup(hc *dgraphapi.HTTPClient, duplicates map[string][]string) error { + // we will delete only user in this case + query := `{ + nodes(func: uid(%v)) { + uid + dgraph.xid + dgraph.type + } + }` + + for userGroup, uids := range duplicates { + var saveIndex int + + for i, uid := range uids { + resp, err := hc.PostDqlQuery(fmt.Sprintf(query, uid)) + if err != nil { + return err + } + var result Response + if err := json.Unmarshal(resp, &result); err != nil { + log.Fatalf("while unmarshalling response: %v", err) + } + + if result.Data.Nodes[0].DgraphType[0] == "dgraph.type.group" { + saveIndex = i + break + } + } + fmt.Printf("keeping group%v with uid: %v", userGroup, uids[saveIndex]) + fmt.Print("\n") + + if err := deleteUids(hc, uids, saveIndex, userGroup); err != nil { + return err + } + + } + return nil +} + +func init() { + CheckUpgrade.Cmd = &cobra.Command{ + Use: "checkupgrade", + Short: "Run the checkupgrade tool", + Long: "The checkupgrade tool is used to check for duplicate dgraph.xid's in the Dgraph database before upgrade.", + Run: func(cmd *cobra.Command, args []string) { + run() + }, + Annotations: map[string]string{"group": "tool"}, + } + CheckUpgrade.Cmd.SetHelpTemplate(x.NonRootTemplate) + flag := CheckUpgrade.Cmd.Flags() + flag.String(alphaHttp, "127.0.0.1:8080", "Dgraph Alpha Http server address") + flag.String(namespace, "", "Namespace to check for duplicate nodes") + flag.String(dgUser, "groot", "Username of the namespace's user") + flag.String(password, 
"password", "Password of the namespace's user") + flag.String(aclSecretKeyFilePath, "", "path of file that stores secret key or private key,"+ + " which is used to sign the ACL JWT") + flag.String(jwtAlg, "HS256", "JWT signing algorithm") + flag.String(deleteDup, "false", "set this flag to true to delete duplicates nodes") +} + +func run() { + if err := checkUpgrade(); err != nil { + fmt.Fprintln(os.Stderr, err) + } +} + +func checkUpgrade() error { + fmt.Println("Running check-upgrade tool") + + cmdInput := parseInput() + var accessJwt string + var err error + if cmdInput.aclSecretKeyFilePath != "" { + accessJwt, err = getAccessJwt(dgraphapi.DefaultUser, guardianGroup, 0, cmdInput.aclSecretKeyFilePath, + cmdInput.jwtAlg) + if err != nil { + return errors.Wrapf(err, "while getting access jwt token") + } + } + hc, err := setupClient(cmdInput.alphaHttp) + if err != nil { + return errors.Wrapf(err, "while setting up clients") + } + + hc.AccessJwt = accessJwt + + var namespaces []uint64 + if cmdInput.namespace == 0 { + namespaces, err = hc.ListNamespaces() + if err != nil { + return errors.Wrapf(err, "while lisiting namespaces") + } + } else { + namespaces = append(namespaces, cmdInput.namespace) + } + + for _, ns := range namespaces { + if cmdInput.aclSecretKeyFilePath != "" { + hc.AccessJwt, err = getAccessJwt(dgraphapi.DefaultUser, guardianGroup, ns, cmdInput.aclSecretKeyFilePath, + cmdInput.jwtAlg) + if err != nil { + return errors.Wrapf(err, "while getting access jwt token for namespace %v", ns) + } + } else { + if err := hc.LoginIntoNamespace(cmdInput.dgUser, cmdInput.password, ns); err != nil { + return errors.Wrapf(err, "while logging into namespace %v", ns) + } + } + + duplicates, err := queryDuplicateNodes(hc) + if err != nil { + return err + } + if err := printAndDeleteDuplicates(hc, "user", ns, duplicates[0], cmdInput.dupDelete); err != nil { + return err + } + // example output: + // Found duplicate users in namespace: #0 + // dgraph.xid user1 , Uids: [0x4 
0x3] + if err := printAndDeleteDuplicates(hc, "group", ns, duplicates[1], cmdInput.dupDelete); err != nil { + return err + } + // Found duplicate groups in namespace: #1 + // dgraph.xid group1 , Uids: [0x2714 0x2711] + if err := printAndDeleteDuplicates(hc, "groups and user", ns, duplicates[2], cmdInput.dupDelete); err != nil { + return err + } + // Found duplicate groups and users in namespace: #0 + // dgraph.xid userGroup1 , Uids: [0x7532 0x7531] + } + + fmt.Println("To delete duplicate nodes use following mutation: ") + deleteMut := ` + { + delete { + * * . + } + }` + fmt.Fprint(os.Stderr, deleteMut) + return nil +} + +func parseInput() *commandInput { + return &commandInput{alphaHttp: CheckUpgrade.Conf.GetString(alphaHttp), dgUser: CheckUpgrade.Conf.GetString(dgUser), + password: CheckUpgrade.Conf.GetString(password), namespace: CheckUpgrade.Conf.GetUint64(namespace), + aclSecretKeyFilePath: CheckUpgrade.Conf.GetString(aclSecretKeyFilePath), + jwtAlg: CheckUpgrade.Conf.GetString(jwtAlg), dupDelete: CheckUpgrade.Conf.GetBool(deleteDup)} +} diff --git a/check_upgrade/check_upgrade_test.go b/check_upgrade/check_upgrade_test.go new file mode 100644 index 00000000000..663fe529366 --- /dev/null +++ b/check_upgrade/check_upgrade_test.go @@ -0,0 +1,199 @@ +//go:build integration2 + +/* + * Copyright 2024 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package checkupgrade + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func TestCheckUpgrade(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1). + WithACL(time.Hour).WithVersion("57aa5c4ac") + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + rdfs := ` + _:a "user1" . + _:a "dgraph.type.User" . + _:b "user1" . 
+ _:b "dgraph.type.User" .` + + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + var nss []uint64 + for i := 0; i < 5; i++ { + ns, err := hc.AddNamespace() + require.NoError(t, err) + require.NoError(t, gc.LoginIntoNamespace(context.Background(), "groot", "password", ns)) + mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + nss = append(nss, ns) + } + + conf1 := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour).WithVersion("local") + c1, err := dgraphtest.NewLocalCluster(conf1) + require.NoError(t, err) + defer func() { c1.Cleanup(t.Failed()) }() + require.NoError(t, c1.Start()) + alphaHttp, err := c.GetAlphaHttpPublicPort() + require.NoError(t, err) + + args := []string{ + "checkupgrade", + "--http_port", "localhost:" + alphaHttp, + "--dgUser", "groot", + "--password", "password", + "--namespace", "1", + } + + cmd := exec.Command(filepath.Join(c1.GetTempDir(), "dgraph"), args...) + out, err := cmd.CombinedOutput() + require.NoError(t, err) + actualOutput := string(out) + fmt.Println("logs of checkupgrade tool\n", actualOutput) + expectedOutputPattern := `Found duplicate users in namespace: #\d+\ndgraph\.xid user1 , Uids: \[\d+x\d+ \d+x\d+\]\n` + match, err := regexp.MatchString(expectedOutputPattern, actualOutput) + require.NoError(t, err) + + if !match { + t.Errorf("Output does not match expected pattern.\nExpected pattern:\n%s\n\nGot:\n%s", + expectedOutputPattern, actualOutput) + } +} + +func TestQueryDuplicateNodes(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1). 
+ WithACL(time.Hour).WithVersion("57aa5c4ac").WithAclAlg(jwt.GetSigningMethod("HS256")) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + // defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + rdfs := ` + <0x40> "user1" . + <0x40> "dgraph.type.User" . + <0x50> "user1" . + <0x50> "dgraph.type.User" . + <0x60> "user1" . + <0x60> "dgraph.type.User" . + <0x60> <0x1> . + <0x50> <0x1> . + <0x70> "user1" . + <0x70> "dgraph.type.User" . + <0x80> "user3" . + <0x80> "dgraph.type.User" . + <0x90> "user3" . + <0x90> "dgraph.type.User" . + <0x100> "Group4" . + <0x100> "dgraph.type.Group" . + <0x110> "Group4" . + <0x110> "dgraph.type.Group" . + <0x120> "Group4" . + <0x120> "dgraph.type.Group" . + <0x130> "Group4" . + <0x130> "dgraph.type.Group" . + <0x140> "Group4" . + <0x140> "dgraph.type.Group" . + <0x150> "usrgrp1" . + <0x150> "dgraph.type.User" . + <0x160> "usrgrp1" . + <0x160> "dgraph.type.User" . + <0x170> "usrgrp1" . + <0x170> "dgraph.type.User" . + <0x180> "usrgrp1" . + <0x180> "dgraph.type.Group" . + <0x200> "usrgrp2" . + <0x200> "dgraph.type.Group" . + <0x210> "usrgrp2" . + <0x210> "dgraph.type.User" . 
+ ` + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + duplicateNodes, err := queryDuplicateNodes(hc) + require.NoError(t, err) + + du := map[string][]string{ + "user1": {"0x40", "0x50", "0x60", "0x70"}, + "user3": {"0x80", "0x90"}, + "usrgrp1": {"0x150", "0x160", "0x170"}, + } + + dg := map[string][]string{ + "Group4": {"0x100", "0x110", "0x120", "0x130", "0x140"}, + "usrgrp1": {"0x180", "0x190"}, + } + + dug := map[string][]string{ + "usrgrp1": {"0x150", "0x160", "0x170", "0x180"}, + "usrgrp2": {"0x200", "0x210"}, + } + + expectedDup := [3]map[string][]string{du, dg, dug} + + for i, dn := range duplicateNodes { + for j, d := range dn { + require.Equal(t, len(expectedDup[i][j]), len(d)) + for _, uid := range d { + require.Contains(t, expectedDup[i][j], uid) + } + } + } + require.NoError(t, deleteDuplicatesGroup(hc, duplicateNodes[0])) + require.NoError(t, deleteDuplicatesGroup(hc, duplicateNodes[1])) + require.NoError(t, deleteDuplicatesGroup(hc, duplicateNodes[2])) +} diff --git a/chunker/json_parser.go b/chunker/json_parser.go index 7535f24310f..1d82f9a44bd 100644 --- a/chunker/json_parser.go +++ b/chunker/json_parser.go @@ -230,7 +230,7 @@ func handleBasicType(k string, v interface{}, op int, nq *api.NQuad) error { return nil } - if vf, err := types.ParseVFloat(v); err == nil { + if vf, err := types.ParseVFloat(v); err == nil && len(vf) != 0 { nq.ObjectValue = &api.Value{Val: &api.Value_Vfloat32Val{Vfloat32Val: types.FloatArrayAsBytes(vf)}} return nil } diff --git a/chunker/json_parser_test.go b/chunker/json_parser_test.go index e98ff90ed97..fe1a581b1ca 100644 --- a/chunker/json_parser_test.go +++ b/chunker/json_parser_test.go @@ -74,6 +74,13 @@ type Person struct { School *School `json:"school,omitempty"` } +type Product struct { + Uid string `json:"uid,omitempty"` + Name string `json:"name"` + Discription string `json:"discription"` + Discription_v string `json:"discription_v"` +} + func 
Parse(b []byte, op int) ([]*api.NQuad, error) { nqs := NewNQuadBuffer(1000) err := nqs.ParseJSON(b, op) @@ -1380,3 +1387,128 @@ func BenchmarkNoFacetsFast(b *testing.B) { _, _ = FastParse(json, SetNquads) } } + +func TestNquadsEmptyStringFromJson(t *testing.T) { + json := `[{"name":""}]` + + nq, err := Parse([]byte(json), SetNquads) + require.NoError(t, err) + + fastNQ, err := FastParse([]byte(json), SetNquads) + require.NoError(t, err) + + // The string value should be empty. + require.Equal(t, nq[0].ObjectValue.GetStrVal(), "") + require.Equal(t, fastNQ[0].ObjectValue.GetStrVal(), "") +} + +func TestNquadsJsonEmptyStringVectorPred(t *testing.T) { + p := Product{ + Uid: "1", + Name: "", + Discription_v: "", + } + + b, err := json.Marshal([]Product{p}) + require.NoError(t, err) + + nq, err := Parse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(nq)) + + fastNQ, err := FastParse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(fastNQ)) + + // predicate Name should be empty and edge for Discription_v should not be there + // we do not create edge for "" in float32vector. + exp := &Experiment{ + t: t, + nqs: nq, + schema: `name: string @index(exact) . + discription_v: float32vector .`, + query: `{product(func: uid(1)) { + name + discription_v + }}`, + expected: `{"product":[{ + "name":""}]}`, + } + exp.verify() + + exp.nqs = fastNQ + exp.verify() +} + +func TestNquadsJsonEmptySquareBracketVectorPred(t *testing.T) { + p := Product{ + Name: "ipad", + Discription_v: "[]", + } + + b, err := json.Marshal(p) + require.NoError(t, err) + + nq, err := Parse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(nq)) + + fastNQ, err := FastParse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(fastNQ)) + + // predicate Name should have value "ipad" and edge for Discription_v should not be there + // we do not create edge for [] in float32vector. 
+ exp := &Experiment{ + t: t, + nqs: nq, + schema: `name: string @index(exact) . + discription_v: float32vector .`, + query: `{product(func: eq(name, "ipad")) { + name + discription_v + }}`, + expected: `{"product":[{ + "name":"ipad"}]}`, + } + exp.verify() + + exp.nqs = fastNQ + exp.verify() +} + +func TestNquadsJsonValidVector(t *testing.T) { + p := Product{ + Name: "ipad", + Discription_v: "[1.1, 2.2, 3.3]", + } + + b, err := json.Marshal(p) + require.NoError(t, err) + + nq, err := Parse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(nq)) + + fastNQ, err := FastParse(b, SetNquads) + require.NoError(t, err) + require.Equal(t, 3, len(fastNQ)) + + exp := &Experiment{ + t: t, + nqs: nq, + schema: `name: string @index(exact) . + discription_v: float32vector .`, + query: `{product(func: eq(name, "ipad")) { + name + discription_v + }}`, + expected: `{"product":[{ + "name":"ipad", + "discription_v":[1.1, 2.2, 3.3]}]}`, + } + exp.verify() + + exp.nqs = fastNQ + exp.verify() +} diff --git a/contrib/config/marketplace/aws/tests/requirements.txt b/contrib/config/marketplace/aws/tests/requirements.txt index 5f7261d236b..aec22f7f08e 100644 --- a/contrib/config/marketplace/aws/tests/requirements.txt +++ b/contrib/config/marketplace/aws/tests/requirements.txt @@ -4,7 +4,7 @@ awscli==1.18.40 backports.shutil-get-terminal-size==1.0.0 boto3==1.12.40 botocore==1.15.40 -certifi==2023.7.22 +certifi==2024.7.4 cfn-lint==0.29.5 chardet==3.0.4 colorama==0.4.3 @@ -14,7 +14,7 @@ docker==4.2.0 docutils==0.15.2 dulwich==0.19.15 idna==3.7 -Jinja2==2.11.3 +Jinja2==3.1.4 jmespath==0.9.5 jsonpatch==1.25 jsonpointer==2.0 @@ -29,13 +29,13 @@ pyrsistent==0.16.0 python-dateutil==2.8.1 PyYAML==5.4 reprint==0.5.2 -requests==2.31.0 +requests==2.32.0 rsa==4.7 s3transfer==0.3.3 six==1.14.0 tabulate==0.8.7 taskcat==0.9.17 typing-extensions==3.7.4.2 -urllib3==1.26.18 +urllib3==1.26.19 websocket-client==0.57.0 yattag==1.13.2 diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go 
index b6132a10781..857843ff45a 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -208,6 +208,10 @@ they form a Raft group and provide synchronous replication. Flag("shared-instance", "When set to true, it disables ACLs for non-galaxy users. "+ "It expects the access JWT to be constructed outside dgraph for non-galaxy users as "+ "login is denied to them. Additionally, this disables access to environment variables for minio, aws, etc."). + Flag("type-filter-uid-limit", "TypeFilterUidLimit decides how many elements would be searched directly"+ + " vs searched via type index. If the number of elements are too low, then querying the"+ + " index might be slower. This would allow people to set their limit according to"+ + " their use case."). String()) flag.String("graphql", worker.GraphQLDefaults, z.NewSuperFlagHelp(worker.GraphQLDefaults). @@ -641,16 +645,21 @@ func run() { security := z.NewSuperFlag(Alpha.Conf.GetString("security")).MergeAndCheckDefault( worker.SecurityDefaults) conf := audit.GetAuditConf(Alpha.Conf.GetString("audit")) + + x.Config.Limit = z.NewSuperFlag(Alpha.Conf.GetString("limit")).MergeAndCheckDefault( + worker.LimitDefaults) + opts := worker.Options{ PostingDir: Alpha.Conf.GetString("postings"), WALDir: Alpha.Conf.GetString("wal"), CacheMb: totalCache, CachePercentage: cachePercentage, - MutationsMode: worker.AllowMutations, - AuthToken: security.GetString("token"), - Audit: conf, - ChangeDataConf: Alpha.Conf.GetString("cdc"), + MutationsMode: worker.AllowMutations, + AuthToken: security.GetString("token"), + Audit: conf, + ChangeDataConf: Alpha.Conf.GetString("cdc"), + TypeFilterUidLimit: x.Config.Limit.GetInt64("type-filter-uid-limit"), } keys, err := ee.GetKeys(Alpha.Conf) @@ -665,8 +674,6 @@ func run() { glog.Info("ACL secret key loaded successfully.") } - x.Config.Limit = z.NewSuperFlag(Alpha.Conf.GetString("limit")).MergeAndCheckDefault( - worker.LimitDefaults) abortDur := x.Config.Limit.GetDuration("txn-abort-after") 
switch strings.ToLower(x.Config.Limit.GetString("mutations")) { case "allow": diff --git a/dgraph/cmd/debuginfo/debugging.go b/dgraph/cmd/debuginfo/debugging.go index 18ad5366694..35190f8c4ce 100644 --- a/dgraph/cmd/debuginfo/debugging.go +++ b/dgraph/cmd/debuginfo/debugging.go @@ -79,11 +79,13 @@ func saveDebug(sourceURL, filePath string, duration time.Duration) error { glog.Warningf("error closing resp reader: %v", err) } }() - out, err := os.Create(filePath) if err != nil { return fmt.Errorf("error while creating debug file: %s", err) } + defer func() { + out.Close() + }() _, err = io.Copy(out, resp) return err } diff --git a/dgraph/cmd/root.go b/dgraph/cmd/root.go index c1938b7f236..cd752d5fa5a 100644 --- a/dgraph/cmd/root.go +++ b/dgraph/cmd/root.go @@ -32,6 +32,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + checkupgrade "github.com/dgraph-io/dgraph/check_upgrade" "github.com/dgraph-io/dgraph/dgraph/cmd/alpha" "github.com/dgraph-io/dgraph/dgraph/cmd/bulk" "github.com/dgraph-io/dgraph/dgraph/cmd/cert" @@ -84,6 +85,7 @@ var rootConf = viper.New() var subcommands = []*x.SubCommand{ &bulk.Bulk, &cert.Cert, &conv.Conv, &live.Live, &alpha.Alpha, &zero.Zero, &version.Version, &debug.Debug, &migrate.Migrate, &debuginfo.DebugInfo, &upgrade.Upgrade, &decrypt.Decrypt, &increment.Increment, + &checkupgrade.CheckUpgrade, } func initCmds() { diff --git a/dgraphtest/cluster.go b/dgraphapi/cluster.go similarity index 88% rename from dgraphtest/cluster.go rename to dgraphapi/cluster.go index bbe190ea162..8f833d06d5f 100644 --- a/dgraphtest/cluster.go +++ b/dgraphapi/cluster.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package dgraphtest +package dgraphapi import ( "bytes" @@ -26,7 +26,6 @@ import ( "net/http" "os/exec" "strings" - "testing" "time" "github.com/pkg/errors" @@ -37,6 +36,12 @@ import ( "github.com/dgraph-io/dgraph/x" ) +const ( + localVersion = "local" + DefaultUser = "groot" + DefaultPassword = "password" +) + type Cluster interface { Client() (*GrpcClient, func(), error) HTTPClient() (*HTTPClient, error) @@ -45,6 +50,7 @@ type Cluster interface { AssignUids(gc *dgo.Dgraph, num uint64) error GetVersion() string GetEncKeyPath() (string, error) + GetRepoDir() (string, error) } type GrpcClient struct { @@ -62,11 +68,12 @@ type HttpToken struct { // HTTPClient allows doing operations on Dgraph over http type HTTPClient struct { *HttpToken - adminURL string - graphqlURL string - licenseURL string - stateURL string - dqlURL string + adminURL string + graphqlURL string + licenseURL string + stateURL string + dqlURL string + dqlMutateUrl string } // GraphQLParams are used for making graphql requests to dgraph @@ -90,6 +97,31 @@ type LicenseResponse struct { Extensions map[string]interface{} `json:"license,omitempty"` } +var client *http.Client = &http.Client{ + Timeout: requestTimeout, +} + +func DoReq(req *http.Request) ([]byte, error) { + resp, err := client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "error performing HTTP request") + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("[WARNING] error closing response body: %v", err) + } + }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrapf(err, "error reading response body: url: [%v], err: [%v]", req.URL, err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("got non 200 resp: %v", string(respBody)) + } + return respBody, nil +} + func (hc *HTTPClient) LoginUsingToken(ns uint64) error { q := `mutation login( $namespace: Int, $refreshToken:String) { login(namespace: $namespace, refreshToken: $refreshToken) { @@ -239,7 
+271,7 @@ func (hc *HTTPClient) doPost(body []byte, url string, contentType string) ([]byt req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt) } - return doReq(req) + return DoReq(req) } // RunGraphqlQuery makes a query to graphql (or admin) endpoint @@ -264,7 +296,8 @@ func (hc *HTTPClient) RunGraphqlQuery(params GraphQLParams, admin bool) ([]byte, return nil, errors.Wrap(err, "error unmarshalling GQL response") } if len(gqlResp.Errors) > 0 { - return nil, errors.Wrapf(gqlResp.Errors, "error while running graphql query, resp: %v", string(gqlResp.Data)) + return nil, errors.Wrapf(gqlResp.Errors, "error while running graphql query, resp: %v", + string(gqlResp.Data)) } return gqlResp.Data, nil } @@ -287,8 +320,13 @@ func (hc *HTTPClient) HealthForInstance() ([]byte, error) { // Backup creates a backup of dgraph at a given path func (hc *HTTPClient) Backup(c Cluster, forceFull bool, backupPath string) error { + repoDir, err := c.GetRepoDir() + if err != nil { + return errors.Wrapf(err, "error getting repo directory") + } + // backup API was made async in the commit d3bf7b7b2786bcb99f02e1641f3b656d0a98f7f4 - asyncAPI, err := IsHigherVersion(c.GetVersion(), "d3bf7b7b2786bcb99f02e1641f3b656d0a98f7f4") + asyncAPI, err := IsHigherVersion(c.GetVersion(), "d3bf7b7b2786bcb99f02e1641f3b656d0a98f7f4", repoDir) if err != nil { return errors.Wrapf(err, "error checking incremental restore support") } @@ -378,9 +416,13 @@ func (hc *HTTPClient) WaitForTask(taskId string) error { // Restore performs restore on Dgraph cluster from the given path to backup func (hc *HTTPClient) Restore(c Cluster, backupPath string, backupId string, incrFrom, backupNum int) error { + repoDir, err := c.GetRepoDir() + if err != nil { + return errors.Wrapf(err, "error getting repo directory") + } // incremental restore was introduced in commit 8b3712e93ed2435bea52d957f7b69976c6cfc55b - incrRestoreSupported, err := IsHigherVersion(c.GetVersion(), "8b3712e93ed2435bea52d957f7b69976c6cfc55b") + 
incrRestoreSupported, err := IsHigherVersion(c.GetVersion(), "8b3712e93ed2435bea52d957f7b69976c6cfc55b", repoDir) if err != nil { return errors.Wrapf(err, "error checking incremental restore support") } @@ -657,7 +699,24 @@ func (hc *HTTPClient) PostDqlQuery(query string) ([]byte, error) { if hc.HttpToken != nil { req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt) } - return doReq(req) + return DoReq(req) +} + +func (hc *HTTPClient) Mutate(mutation string, commitNow bool) ([]byte, error) { + url := hc.dqlMutateUrl + if commitNow { + url = hc.dqlMutateUrl + "?commitNow=true" + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBufferString(mutation)) + if err != nil { + return nil, errors.Wrapf(err, "error building req for endpoint [%v]", url) + } + req.Header.Add("Content-Type", "application/rdf") + if hc.HttpToken != nil { + req.Header.Add("X-Dgraph-AccessToken", hc.AccessJwt) + } + return DoReq(req) } // SetupSchema sets up DQL schema @@ -715,20 +774,8 @@ func (gc *GrpcClient) Query(query string) (*api.Response, error) { return txn.Query(ctx, query) } -// ShouldSkipTest skips a given test if clusterVersion < minVersion -func ShouldSkipTest(t *testing.T, minVersion, clusterVersion string) error { - supported, err := IsHigherVersion(clusterVersion, minVersion) - if err != nil { - t.Fatal(err) - } - if !supported { - t.Skipf("test is valid for commits greater than [%v]", minVersion) - } - return nil -} - // IsHigherVersion checks whether "higher" is the higher version compared to "lower" -func IsHigherVersion(higher, lower string) (bool, error) { +func IsHigherVersion(higher, lower, repoDir string) (bool, error) { // the order of if conditions matters here if lower == localVersion { return false, nil @@ -751,3 +798,21 @@ func IsHigherVersion(higher, lower string) (bool, error) { return true, nil } + +func GetHttpClient(alphaUrl, zeroUrl string) (*HTTPClient, error) { + adminUrl := "http://" + alphaUrl + "/admin" + graphQLUrl := "http://" + alphaUrl + 
"/graphql" + licenseUrl := "http://" + zeroUrl + "/enterpriseLicense" + stateUrl := "http://" + zeroUrl + "/state" + dqlUrl := "http://" + alphaUrl + "/query" + dqlMutateUrl := "http://" + alphaUrl + "/mutate" + return &HTTPClient{ + adminURL: adminUrl, + graphqlURL: graphQLUrl, + licenseURL: licenseUrl, + stateURL: stateUrl, + dqlURL: dqlUrl, + dqlMutateUrl: dqlMutateUrl, + HttpToken: &HttpToken{}, + }, nil +} diff --git a/dgraphtest/ee.go b/dgraphapi/ee.go similarity index 95% rename from dgraphtest/ee.go rename to dgraphapi/ee.go index 6f3b4316769..57fe41d9a71 100644 --- a/dgraphtest/ee.go +++ b/dgraphapi/ee.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dgraphtest +package dgraphapi import ( "encoding/json" @@ -123,7 +123,7 @@ func (hc *HTTPClient) CreateGroup(name string) (string, error) { } resp, err := hc.RunGraphqlQuery(params, true) if err != nil { - return "", nil + return "", err } type Response struct { AddGroup struct { @@ -453,3 +453,27 @@ func (hc *HTTPClient) DeleteNamespace(nsID uint64) (uint64, error) { } return 0, errors.New(result.DeleteNamespace.Message) } + +func (hc *HTTPClient) ListNamespaces() ([]uint64, error) { + const listNss = `{ state { + namespaces + } + }` + + params := GraphQLParams{Query: listNss} + resp, err := hc.RunGraphqlQuery(params, true) + if err != nil { + return nil, err + } + + var result struct { + State struct { + Namespaces []uint64 `json:"namespaces"` + } `json:"state"` + } + if err := json.Unmarshal(resp, &result); err != nil { + return nil, errors.Wrap(err, "error unmarshalling response") + } + + return result.State.Namespaces, nil +} diff --git a/dgraphtest/json.go b/dgraphapi/json.go similarity index 97% rename from dgraphtest/json.go rename to dgraphapi/json.go index 47411bd6cb4..73ae5e84005 100644 --- a/dgraphtest/json.go +++ b/dgraphapi/json.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package dgraphtest +package dgraphapi import ( "encoding/json" @@ -30,6 +30,11 @@ import ( "github.com/pkg/errors" ) +const ( + waitDurBeforeRetry = time.Second + requestTimeout = 120 * time.Second +) + func PollTillPassOrTimeout(gcli *GrpcClient, query, want string, timeout time.Duration) error { ticker := time.NewTimer(requestTimeout) defer ticker.Stop() diff --git a/dgraphtest/snapshot.go b/dgraphapi/snapshot.go similarity index 98% rename from dgraphtest/snapshot.go rename to dgraphapi/snapshot.go index bd0207cffdc..430d605ab39 100644 --- a/dgraphtest/snapshot.go +++ b/dgraphapi/snapshot.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package dgraphtest +package dgraphapi import ( "encoding/json" diff --git a/dgraphtest/vector.go b/dgraphapi/vector.go similarity index 99% rename from dgraphtest/vector.go rename to dgraphapi/vector.go index 36b76e680dc..de3dad04878 100644 --- a/dgraphtest/vector.go +++ b/dgraphapi/vector.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package dgraphtest +package dgraphapi import ( "encoding/json" diff --git a/dgraphtest/compose_cluster.go b/dgraphtest/compose_cluster.go index e11e40719f7..640b66ba9f4 100644 --- a/dgraphtest/compose_cluster.go +++ b/dgraphtest/compose_cluster.go @@ -18,6 +18,7 @@ package dgraphtest import ( "github.com/dgraph-io/dgo/v230" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/testutil" ) @@ -27,30 +28,23 @@ func NewComposeCluster() *ComposeCluster { return &ComposeCluster{} } -func (c *ComposeCluster) Client() (*GrpcClient, func(), error) { +func (c *ComposeCluster) Client() (*dgraphapi.GrpcClient, func(), error) { client, err := testutil.DgraphClient(testutil.SockAddr) if err != nil { return nil, nil, err } - return &GrpcClient{Dgraph: client}, func() {}, nil + return &dgraphapi.GrpcClient{Dgraph: client}, func() {}, nil } // HTTPClient creates an HTTP client -func (c *ComposeCluster) HTTPClient() (*HTTPClient, error) { - adminUrl := "http://" + testutil.SockAddrHttp + "/admin" - graphQLUrl := "http://" + testutil.SockAddrHttp + "/graphql" - licenseUrl := "http://" + testutil.SockAddrZeroHttp + "/enterpriseLicense" - stateUrl := "http://" + testutil.SockAddrZeroHttp + "/state" - dqlUrl := "http://" + testutil.SockAddrHttp + "/query" - return &HTTPClient{ - adminURL: adminUrl, - graphqlURL: graphQLUrl, - licenseURL: licenseUrl, - stateURL: stateUrl, - dqlURL: dqlUrl, - HttpToken: &HttpToken{}, - }, nil +func (c *ComposeCluster) HTTPClient() (*dgraphapi.HTTPClient, error) { + httpClient, err := dgraphapi.GetHttpClient(testutil.SockAddrHttp, testutil.SockAddrZeroHttp) + if err != nil { + return nil, err + } + httpClient.HttpToken = &dgraphapi.HttpToken{} + return httpClient, nil } func (c *ComposeCluster) AlphasHealth() ([]string, error) { @@ -72,3 +66,8 @@ func (c *ComposeCluster) GetVersion() string { func (c *ComposeCluster) GetEncKeyPath() (string, error) { return "", errNotImplemented } + +// GetRepoDir returns the repositroty directory of 
the cluster +func (c *ComposeCluster) GetRepoDir() (string, error) { + return "", errNotImplemented +} diff --git a/dgraphtest/config.go b/dgraphtest/config.go index ded018d201b..5bc4252ce53 100644 --- a/dgraphtest/config.go +++ b/dgraphtest/config.go @@ -74,9 +74,8 @@ func AllUpgradeCombos(v20 bool) []UpgradeCombo { // In mainCombos list, we keep latest version to current HEAD as well as // older versions of dgraph to ensure that a change does not cause failures. mainCombos := []UpgradeCombo{ - {"v23.0.1", localVersion, BackupRestore}, - {"v23.0.1", localVersion, InPlace}, - {"v21.03.0", "4fc9cfd", BackupRestore}, + {"v23.1.0", localVersion, BackupRestore}, + {"v23.1.0", localVersion, InPlace}, } if v20 { @@ -120,6 +119,7 @@ type ClusterConfig struct { customPlugins bool snapShotAfterEntries uint64 snapshotAfterDuration time.Duration + repoDir string } func (cc ClusterConfig) WithGraphqlLambdaURL(url string) ClusterConfig { @@ -147,6 +147,15 @@ func NewClusterConfig() ClusterConfig { } } +//func newClusterConfigFrom(cc ClusterConfig) ClusterConfig { +// prefix := fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000) +// defaultBackupVol := fmt.Sprintf("%v_backup", prefix) +// defaultExportVol := fmt.Sprintf("%v_export", prefix) +// cc.prefix = prefix +// cc.volumes = map[string]string{DefaultBackupDir: defaultBackupVol, DefaultExportDir: defaultExportVol} +// return cc +//} + // WithNAlphas sets the number of alphas in the cluster func (cc ClusterConfig) WithNumAlphas(n int) ClusterConfig { cc.numAlphas = n diff --git a/dgraphtest/dcloud_cluster.go b/dgraphtest/dcloud_cluster.go index 98bafc1825b..eb3923a4944 100644 --- a/dgraphtest/dcloud_cluster.go +++ b/dgraphtest/dcloud_cluster.go @@ -29,6 +29,7 @@ import ( "github.com/dgraph-io/dgo/v230" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" ) type DCloudCluster struct { @@ -71,7 +72,7 @@ func (c *DCloudCluster) init() error { return nil } -func (c 
*DCloudCluster) Client() (*GrpcClient, func(), error) { +func (c *DCloudCluster) Client() (*dgraphapi.GrpcClient, func(), error) { var conns []*grpc.ClientConn conn, err := dgo.DialCloud(c.url, c.token) if err != nil { @@ -87,10 +88,10 @@ func (c *DCloudCluster) Client() (*GrpcClient, func(), error) { } } client := dgo.NewDgraphClient(api.NewDgraphClient(conn)) - return &GrpcClient{Dgraph: client}, cleanup, nil + return &dgraphapi.GrpcClient{Dgraph: client}, cleanup, nil } -func (c *DCloudCluster) HTTPClient() (*HTTPClient, error) { +func (c *DCloudCluster) HTTPClient() (*dgraphapi.HTTPClient, error) { return nil, errNotImplemented } @@ -161,3 +162,16 @@ func (c1 *DCloudCluster) AssignUids(client *dgo.Dgraph, num uint64) error { func (c *DCloudCluster) GetVersion() string { return localVersion } + +// GetRepoDir returns the repositroty directory of the cluster +func (c *DCloudCluster) GetRepoDir() (string, error) { + return "", errNotImplemented +} + +func (c *DCloudCluster) AlphasLogs() ([]string, error) { + return nil, errNotImplemented +} + +func (c *DCloudCluster) GetEncKeyPath() (string, error) { + return "", errNotImplemented +} diff --git a/dgraphtest/dgraph.go b/dgraphtest/dgraph.go index cfacf165589..96473d6b0b8 100644 --- a/dgraphtest/dgraph.go +++ b/dgraphtest/dgraph.go @@ -23,6 +23,7 @@ import ( "path/filepath" "strconv" "strings" + "testing" "time" "github.com/docker/docker/api/types/mount" @@ -57,9 +58,6 @@ const ( encKeyFile = "enc-key" encKeyMountPath = "/secrets/enc-key" - DefaultUser = "groot" - DefaultPassword = "password" - goBinMountPath = "/gobin" localVersion = "local" waitDurBeforeRetry = time.Second @@ -174,7 +172,7 @@ func (z *zero) healthURL(c *LocalCluster) (string, error) { if err != nil { return "", err } - return "http://localhost:" + publicPort + "/health", nil + return "http://0.0.0.0:" + publicPort + "/health", nil } func (z *zero) changeStatus(isRunning bool) { @@ -186,7 +184,7 @@ func (z *zero) assignURL(c *LocalCluster) (string, 
error) { if err != nil { return "", err } - return "http://localhost:" + publicPort + "/assign", nil + return "http://0.0.0.0:" + publicPort + "/assign", nil } func (z *zero) alphaURL(c *LocalCluster) (string, error) { @@ -198,7 +196,7 @@ func (z *zero) zeroURL(c *LocalCluster) (string, error) { if err != nil { return "", err } - return "localhost:" + publicPort + "", nil + return "0.0.0.0:" + publicPort + "", nil } type alpha struct { @@ -353,7 +351,7 @@ func (a *alpha) healthURL(c *LocalCluster) (string, error) { if err != nil { return "", err } - return "http://localhost:" + publicPort + "/health", nil + return "http://0.0.0.0:" + publicPort + "/health", nil } func (a *alpha) assignURL(c *LocalCluster) (string, error) { @@ -365,7 +363,7 @@ func (a *alpha) alphaURL(c *LocalCluster) (string, error) { if err != nil { return "", err } - return "localhost:" + publicPort + "", nil + return "0.0.0.0:" + publicPort + "", nil } func (a *alpha) changeStatus(isRunning bool) { @@ -390,8 +388,14 @@ func publicPort(dcli *docker.Client, dc dnode, privatePort string) (string, erro if len(bindings) == 0 { continue } - if port.Port() == privatePort { - return bindings[0].HostPort, nil + if port.Port() != privatePort { + continue + } + + for _, binding := range bindings { + if binding.HostIP == "0.0.0.0" { + return binding.HostPort, nil + } } } @@ -415,3 +419,15 @@ func mountBinary(c *LocalCluster) (mount.Mount, error) { ReadOnly: true, }, nil } + +// ShouldSkipTest skips a given test if clusterVersion < minVersion +func ShouldSkipTest(t *testing.T, minVersion, clusterVersion string) error { + supported, err := IsHigherVersion(clusterVersion, minVersion) + if err != nil { + t.Fatal(err) + } + if !supported { + t.Skipf("test is valid for commits greater than [%v]", minVersion) + } + return nil +} diff --git a/dgraphtest/image.go b/dgraphtest/image.go index 06261fe8404..5b516afae60 100644 --- a/dgraphtest/image.go +++ b/dgraphtest/image.go @@ -26,6 +26,7 @@ import ( "strings" 
"github.com/pkg/errors" + "golang.org/x/mod/modfile" ) func (c *LocalCluster) dgraphImage() string { @@ -151,6 +152,10 @@ func getHash(ref string) (string, error) { func buildDgraphBinary(dir, binaryDir, version string) error { log.Printf("[INFO] building dgraph binary for version [%v]", version) + if err := fixGoModIfNeeded(); err != nil { + return err + } + cmd := exec.Command("make", "dgraph") cmd.Dir = filepath.Join(dir, "dgraph") if out, err := cmd.CombinedOutput(); err != nil { @@ -204,3 +209,53 @@ func copy(src, dst string) error { _, err = io.Copy(destination, source) return err } + +// IsHigherVersion checks whether "higher" is the higher version compared to "lower" +func IsHigherVersion(higher, lower string) (bool, error) { + // the order of if conditions matters here + if lower == localVersion { + return false, nil + } + if higher == localVersion { + return true, nil + } + + // An older commit is usually the ancestor of a newer commit which is a descendant commit + cmd := exec.Command("git", "merge-base", "--is-ancestor", lower, higher) + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + return exitError.ExitCode() == 0, nil + } + + return false, errors.Wrapf(err, "error checking if [%v] is ancestor of [%v]\noutput:%v", + higher, lower, string(out)) + } + + return true, nil +} + +func fixGoModIfNeeded() error { + repoModFilePath := filepath.Join(repoDir, "go.mod") + repoModFile, err := modfile.Parse(repoModFilePath, nil, nil) + if err != nil { + return errors.Wrapf(err, "error parsing mod file in repoDir [%v]", repoDir) + } + + modFile, err := modfile.Parse("go.mod", nil, nil) + if err != nil { + return errors.Wrapf(err, "error while parsing go.mod file") + } + + if len(modFile.Replace) == len(repoModFile.Replace) { + return nil + } + + repoModFile.Replace = modFile.Replace + if data, err := repoModFile.Format(); err != nil { + return errors.Wrapf(err, "error while formatting mod 
file") + } else if err := os.WriteFile(repoModFilePath, data, 0644); err != nil { + return errors.Wrapf(err, "error while writing to go.mod file") + } + return nil +} diff --git a/dgraphtest/load.go b/dgraphtest/load.go index bbbb0901d9d..5093fecb726 100644 --- a/dgraphtest/load.go +++ b/dgraphtest/load.go @@ -34,6 +34,7 @@ import ( "github.com/pkg/errors" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/ee/enc" "github.com/dgraph-io/dgraph/x" ) @@ -138,7 +139,7 @@ func setDQLSchema(c *LocalCluster, files []string) error { if c.conf.acl { ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() - err := gc.LoginIntoNamespace(ctx, DefaultUser, DefaultPassword, x.GalaxyNamespace) + err := gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) if err != nil { return errors.Wrap(err, "error login to default namespace") } @@ -188,7 +189,7 @@ func setGraphQLSchema(c *LocalCluster, files []string) error { } if c.conf.acl { - err := hc.LoginIntoNamespace(DefaultUser, DefaultPassword, nss.Namespace) + err := hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, nss.Namespace) if err != nil { return errors.Wrap(err, "error login into default namespace") } @@ -233,7 +234,7 @@ func (c *LocalCluster) LiveLoad(opts LiveOpts) error { } if c.conf.acl { args = append(args, fmt.Sprintf("--creds=user=%s;password=%s;namespace=%d", - DefaultUser, DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) } if c.conf.encryption { args = append(args, fmt.Sprintf("--encryption=key-file=%v", c.encKeyPath)) @@ -261,7 +262,7 @@ func findGrootAndGuardians(c *LocalCluster) (string, string, error) { if c.conf.acl { ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() - err = gc.LoginIntoNamespace(ctx, DefaultUser, DefaultPassword, x.GalaxyNamespace) + err = 
gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) if err != nil { return "", "", errors.Wrapf(err, "error logging in as groot") } @@ -502,7 +503,7 @@ func (c *LocalCluster) BulkLoad(opts BulkOpts) error { } // AddData will insert a total of end-start triples into the database. -func AddData(gc *GrpcClient, pred string, start, end int) error { +func AddData(gc *dgraphapi.GrpcClient, pred string, start, end int) error { if err := gc.SetupSchema(fmt.Sprintf(`%v: string @index(exact) .`, pred)); err != nil { return err } diff --git a/dgraphtest/local_cluster.go b/dgraphtest/local_cluster.go index 0e2c47b0720..0248255fa47 100644 --- a/dgraphtest/local_cluster.go +++ b/dgraphtest/local_cluster.go @@ -17,11 +17,13 @@ package dgraphtest import ( + "bytes" "context" "encoding/json" "fmt" "io" "log" + "math/rand" "net/http" "os" "os/exec" @@ -33,6 +35,7 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/volume" docker "github.com/docker/docker/client" @@ -43,6 +46,7 @@ import ( "github.com/dgraph-io/dgo/v230" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/x" ) @@ -138,6 +142,7 @@ func (c *LocalCluster) init() error { } } + c.zeros = c.zeros[:0] for i := 0; i < c.conf.numZeros; i++ { zo := &zero{id: i} zo.containerName = fmt.Sprintf(zeroNameFmt, c.conf.prefix, zo.id) @@ -145,6 +150,7 @@ func (c *LocalCluster) init() error { c.zeros = append(c.zeros, zo) } + c.alphas = c.alphas[:0] for i := 0; i < c.conf.numAlphas; i++ { aa := &alpha{id: i} aa.containerName = fmt.Sprintf(alphaNameFmt, c.conf.prefix, aa.id) @@ -278,6 +284,87 @@ func (c *LocalCluster) destroyContainers() error { return nil } +// CheckRunningServices checks open ports using lsof and returns the output as a string +func 
CheckRunningServices() (string, error) { + lsofCmd := exec.Command("lsof", "-i", "-n") + output, err := runCommand(lsofCmd) + if err != nil { + return "", fmt.Errorf("error running lsof command: %v", err) + } + return output, nil +} + +// ListRunningContainers lists running Docker containers using the Docker Go client +func (c *LocalCluster) listRunningContainers() (string, error) { + containers, err := c.dcli.ContainerList(context.Background(), types.ContainerListOptions{}) + if err != nil { + return "", fmt.Errorf("error listing Docker containers: %v", err) + } + + var result bytes.Buffer + for _, container := range containers { + result.WriteString(fmt.Sprintf("ID: %s, Image: %s, Command: %s, Status: %s\n", + container.ID[:10], container.Image, container.Command, container.Status)) + + result.WriteString("Port Mappings:\n") + for _, port := range container.Ports { + result.WriteString(fmt.Sprintf(" %s:%d -> %d\n", port.IP, port.PublicPort, port.PrivatePort)) + } + result.WriteString("\n") + + result.WriteString("Port Mappings:\n") + info, err := c.dcli.ContainerInspect(context.Background(), container.ID) + if err != nil { + return "", errors.Wrap(err, "error inspecting container") + } + + for port, bindings := range info.NetworkSettings.Ports { + if len(bindings) == 0 { + continue + } + result.WriteString(fmt.Sprintf(" %s:%s\n", port.Port(), bindings)) + } + result.WriteString("\n") + } + + return result.String(), nil +} + +// runCommand executes a command and returns its output or an error +func runCommand(cmd *exec.Cmd) (string, error) { + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + err := cmd.Run() + if err != nil { + return "", fmt.Errorf("%v: %v", err, stderr.String()) + } + return out.String(), nil +} + +func (c *LocalCluster) printNetworkStuff() { + log.Printf("Checking running services and ports using lsof, netstat, and Docker...\n") + + // Check running services using lsof + lsofOutput, err := 
CheckRunningServices() + if err != nil { + fmt.Printf("Error checking running services: %v\n", err) + } else { + log.Printf("Output of lsof -i:") + log.Println(lsofOutput) + } + + // List running Docker containers + dockerOutput, err := c.listRunningContainers() + if err != nil { + fmt.Printf("Error listing Docker containers: %v\n", err) + } else { + log.Printf("Running Docker containers:") + log.Println(dockerOutput) + } +} + func (c *LocalCluster) Cleanup(verbose bool) { if c == nil { return @@ -317,6 +404,26 @@ func (c *LocalCluster) Cleanup(verbose bool) { } } +func (c *LocalCluster) cleanupDocker() error { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + // Prune containers + contsReport, err := c.dcli.ContainersPrune(ctx, filters.Args{}) + if err != nil { + log.Fatalf("[ERROR] Error pruning containers: %v", err) + } + log.Printf("[INFO] Pruned containers: %+v\n", contsReport) + + // Prune networks + netsReport, err := c.dcli.NetworksPrune(ctx, filters.Args{}) + if err != nil { + log.Fatalf("[ERROR] Error pruning networks: %v", err) + } + log.Printf("[INFO] Pruned networks: %+v\n", netsReport) + + return nil +} + func (c *LocalCluster) Start() error { log.Printf("[INFO] starting cluster with prefix [%v]", c.conf.prefix) startAll := func() error { @@ -334,26 +441,35 @@ func (c *LocalCluster) Start() error { return c.HealthCheck(false) } - var err error - // sometimes health check doesn't work due to unmapped ports. We dont know why this happens, - // but checking it 4 times before failing the test. - for i := 0; i < 4; i++ { + // sometimes health check doesn't work due to unmapped ports. We dont + // know why this happens, but checking it 3 times before failing the test. 
+ retry := 0 + for { + retry++ - if err = startAll(); err == nil { + if err := startAll(); err == nil { return nil + } else if retry == 3 { + return err + } else { + log.Printf("[WARNING] saw the err, trying again: %v", err) } - log.Printf("[WARNING] Saw the error :%v, trying again", err) + if err1 := c.Stop(); err1 != nil { - log.Printf("[WARNING] error while stopping :%v", err) + log.Printf("[WARNING] error while stopping :%v", err1) + } + c.Cleanup(true) + + if err := c.cleanupDocker(); err != nil { + log.Printf("[ERROR] while cleaning old dockers %v", err) } - c.Cleanup(false) + + c.conf.prefix = fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000) if err := c.init(); err != nil { - c.Cleanup(true) + log.Printf("[ERROR] error while init, returning: %v", err) return err } } - - return err } func (c *LocalCluster) StartZero(id int) error { @@ -449,11 +565,7 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error { if !zo.isRunning { break } - url, err := zo.healthURL(c) - if err != nil { - return errors.Wrap(err, "error getting health URL") - } - if err := c.containerHealthCheck(url); err != nil { + if err := c.containerHealthCheck(zo.healthURL); err != nil { return err } log.Printf("[INFO] container [%v] passed health check", zo.containerName) @@ -470,11 +582,7 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error { if !aa.isRunning { break } - url, err := aa.healthURL(c) - if err != nil { - return errors.Wrap(err, "error getting health URL") - } - if err := c.containerHealthCheck(url); err != nil { + if err := c.containerHealthCheck(aa.healthURL); err != nil { return err } log.Printf("[INFO] container [%v] passed health check", aa.containerName) @@ -486,18 +594,28 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error { return nil } -func (c *LocalCluster) containerHealthCheck(url string) error { +func (c *LocalCluster) containerHealthCheck(url func(c *LocalCluster) (string, error)) error { + endpoint, err := url(c) + 
if err != nil { + return errors.Wrap(err, "error getting health URL") + } + for i := 0; i < 60; i++ { time.Sleep(waitDurBeforeRetry) - req, err := http.NewRequest(http.MethodGet, url, nil) + endpoint, err = url(c) + if err != nil { + return errors.Wrap(err, "error getting health URL") + } + + req, err := http.NewRequest(http.MethodGet, endpoint, nil) if err != nil { - log.Printf("[WARNING] error building req for endpoint [%v], err: [%v]", url, err) + log.Printf("[WARNING] error building req for endpoint [%v], err: [%v]", endpoint, err) continue } - body, err := doReq(req) + body, err := dgraphapi.DoReq(req) if err != nil { - log.Printf("[WARNING] error hitting health endpoint [%v], err: [%v]", url, err) + log.Printf("[WARNING] error hitting health endpoint [%v], err: [%v]", endpoint, err) continue } resp := string(body) @@ -523,7 +641,8 @@ func (c *LocalCluster) containerHealthCheck(url string) error { return nil } - return fmt.Errorf("health failed, cluster took too long to come up [%v]", url) + c.printNetworkStuff() + return fmt.Errorf("health failed, cluster took too long to come up [%v]", endpoint) } func (c *LocalCluster) waitUntilLogin() error { @@ -540,7 +659,7 @@ func (c *LocalCluster) waitUntilLogin() error { ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() for i := 0; i < 10; i++ { - err := client.Login(ctx, DefaultUser, DefaultPassword) + err := client.Login(ctx, dgraphapi.DefaultUser, dgraphapi.DefaultPassword) if err == nil { log.Printf("[INFO] login succeeded") return nil @@ -557,7 +676,7 @@ func (c *LocalCluster) waitUntilGraphqlHealthCheck() error { return errors.Wrap(err, "error creating http client while graphql health check") } if c.conf.acl { - if err := hc.LoginIntoNamespace(DefaultUser, DefaultPassword, x.GalaxyNamespace); err != nil { + if err := hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace); err != nil { return errors.Wrap(err, "error during login while 
graphql health check") } } @@ -581,31 +700,6 @@ func (c *LocalCluster) waitUntilGraphqlHealthCheck() error { return errors.New("error during graphql health check") } -var client *http.Client = &http.Client{ - Timeout: requestTimeout, -} - -func doReq(req *http.Request) ([]byte, error) { - resp, err := client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "error performing HTTP request") - } - defer func() { - if err := resp.Body.Close(); err != nil { - log.Printf("[WARNING] error closing response body: %v", err) - } - }() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errors.Wrapf(err, "error reading response body: url: [%v], err: [%v]", req.URL, err) - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("got non 200 resp: %v", string(respBody)) - } - return respBody, nil -} - // Upgrades the cluster to the provided dgraph version func (c *LocalCluster) Upgrade(version string, strategy UpgradeStrategy) error { if version == c.conf.version { @@ -620,7 +714,7 @@ func (c *LocalCluster) Upgrade(version string, strategy UpgradeStrategy) error { return err } if c.conf.acl { - if err := hc.LoginIntoNamespace(DefaultUser, DefaultPassword, x.GalaxyNamespace); err != nil { + if err := hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace); err != nil { return errors.Wrapf(err, "error during login before upgrade") } } @@ -643,14 +737,14 @@ func (c *LocalCluster) Upgrade(version string, strategy UpgradeStrategy) error { return errors.Wrapf(err, "error creating HTTP client after upgrade") } if c.conf.acl { - if err := hc.LoginIntoNamespace(DefaultUser, DefaultPassword, x.GalaxyNamespace); err != nil { + if err := hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace); err != nil { return errors.Wrapf(err, "error during login after upgrade") } } if err := hc.Restore(c, DefaultBackupDir, "", 0, 1); err != nil { return errors.Wrap(err, "error doing restore 
during upgrade") } - if err := WaitForRestore(c); err != nil { + if err := dgraphapi.WaitForRestore(c); err != nil { return errors.Wrap(err, "error waiting for restore to complete") } return nil @@ -661,7 +755,7 @@ func (c *LocalCluster) Upgrade(version string, strategy UpgradeStrategy) error { return err } if c.conf.acl { - if err := hc.LoginIntoNamespace(DefaultUser, DefaultPassword, x.GalaxyNamespace); err != nil { + if err := hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace); err != nil { return errors.Wrapf(err, "error during login before upgrade") } } @@ -716,7 +810,7 @@ func (c *LocalCluster) recreateContainers() error { } // Client returns a grpc client that can talk to any Alpha in the cluster -func (c *LocalCluster) Client() (*GrpcClient, func(), error) { +func (c *LocalCluster) Client() (*dgraphapi.GrpcClient, func(), error) { // TODO(aman): can we cache the connections? var apiClients []api.DgraphClient var conns []*grpc.ClientConn @@ -744,10 +838,10 @@ func (c *LocalCluster) Client() (*GrpcClient, func(), error) { } } } - return &GrpcClient{Dgraph: client}, cleanup, nil + return &dgraphapi.GrpcClient{Dgraph: client}, cleanup, nil } -func (c *LocalCluster) AlphaClient(id int) (*GrpcClient, func(), error) { +func (c *LocalCluster) AlphaClient(id int) (*dgraphapi.GrpcClient, func(), error) { alpha := c.alphas[id] url, err := alpha.alphaURL(c) if err != nil { @@ -764,39 +858,22 @@ func (c *LocalCluster) AlphaClient(id int) (*GrpcClient, func(), error) { log.Printf("[WARNING] error closing connection: %v", err) } } - return &GrpcClient{Dgraph: client}, cleanup, nil + return &dgraphapi.GrpcClient{Dgraph: client}, cleanup, nil } // HTTPClient creates an HTTP client -func (c *LocalCluster) HTTPClient() (*HTTPClient, error) { - adminURL, err := c.serverURL("alpha", "/admin") - if err != nil { - return nil, err - } - graphqlURL, err := c.serverURL("alpha", "/graphql") - if err != nil { - return nil, err - } - licenseURL, 
err := c.serverURL("zero", "/enterpriseLicense") +func (c *LocalCluster) HTTPClient() (*dgraphapi.HTTPClient, error) { + alphaUrl, err := c.serverURL("alpha", "") if err != nil { return nil, err } - stateURL, err := c.serverURL("zero", "/state") - if err != nil { - return nil, err - } - dqlURL, err := c.dqlURL() + + zeroUrl, err := c.serverURL("zero", "") if err != nil { return nil, err } - return &HTTPClient{ - adminURL: adminURL, - graphqlURL: graphqlURL, - licenseURL: licenseURL, - stateURL: stateURL, - dqlURL: dqlURL, - }, nil + return dgraphapi.GetHttpClient(alphaUrl, zeroUrl) } // serverURL returns url to the 'server' 'endpoint' @@ -808,17 +885,7 @@ func (c *LocalCluster) serverURL(server, endpoint string) (string, error) { if err != nil { return "", err } - url := "http://localhost:" + pubPort + endpoint - return url, nil -} - -// dqlURL returns url to the dql query endpoint -func (c *LocalCluster) dqlURL() (string, error) { - publicPort, err := publicPort(c.dcli, c.alphas[0], alphaHttpPort) - if err != nil { - return "", err - } - url := "http://localhost:" + publicPort + "/query" + url := "0.0.0.0:" + pubPort + endpoint return url, nil } @@ -838,7 +905,7 @@ func (c *LocalCluster) AlphasHealth() ([]string, error) { if err != nil { return nil, errors.Wrapf(err, "error building req for endpoint [%v]", url) } - h, err := doReq(req) + h, err := dgraphapi.DoReq(req) if err != nil { return nil, errors.Wrap(err, "error getting health") } @@ -877,7 +944,7 @@ func (c *LocalCluster) AssignUids(_ *dgo.Dgraph, num uint64) error { if err != nil { return errors.Wrapf(err, "error building req for endpoint [%v]", url) } - body, err := doReq(req) + body, err := dgraphapi.DoReq(req) if err != nil { return err } @@ -901,6 +968,11 @@ func (c *LocalCluster) GetVersion() string { return c.conf.version } +// GetRepoDir returns the repository directory of the cluster +func (c *LocalCluster) GetRepoDir() (string, error) { + return c.conf.repoDir, nil +} + // GetEncKeyPath returns
the path to the encryption key file when encryption is enabled. // It returns an empty string otherwise. The path to the encryption file is valid only // inside the alpha container. @@ -1144,3 +1216,15 @@ func (c *LocalCluster) GeneratePlugins(raceEnabled bool) error { return nil } + +func (c *LocalCluster) GetAlphaGrpcPublicPort() (string, error) { + return publicPort(c.dcli, c.alphas[0], alphaGrpcPort) +} + +func (c *LocalCluster) GetAlphaHttpPublicPort() (string, error) { + return publicPort(c.dcli, c.alphas[0], alphaHttpPort) +} + +func (c *LocalCluster) GetTempDir() string { + return c.tempBinDir +} diff --git a/edgraph/server.go b/edgraph/server.go index 48d556f9319..f59d65c7ff3 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -303,7 +303,7 @@ func parseSchemaFromAlterOperation(ctx context.Context, op *api.Operation) ( // Pre-defined predicates cannot be altered but let the update go through // if the update is equal to the existing one. - if schema.IsPreDefPredChanged(update) { + if schema.CheckAndModifyPreDefPredicate(update) { return nil, errors.Errorf("predicate %s is pre-defined and is not allowed to be"+ " modified", x.ParseAttr(update.Predicate)) } @@ -1344,7 +1344,7 @@ func (s *Server) doQuery(ctx context.Context, req *Request) (resp *api.Response, graphql: isGraphQL, gqlField: req.gqlField, } - if rerr = parseRequest(qc); rerr != nil { + if rerr = parseRequest(ctx, qc); rerr != nil { return } @@ -1565,7 +1565,7 @@ func processQuery(ctx context.Context, qc *queryContext) (*api.Response, error) } // parseRequest parses the incoming request -func parseRequest(qc *queryContext) error { +func parseRequest(ctx context.Context, qc *queryContext) error { start := time.Now() defer func() { qc.latency.Parsing = time.Since(start) @@ -1585,7 +1585,7 @@ func parseRequest(qc *queryContext) error { qc.gmuList = append(qc.gmuList, gmu) } - if err := addQueryIfUnique(qc); err != nil { + if err := addQueryIfUnique(ctx, qc); err != nil { return err } @@ 
-1698,19 +1698,38 @@ func verifyUnique(qc *queryContext, qr query.Request) error { } // addQueryIfUnique adds dummy queries in the request for checking whether predicate is unique in the db -func addQueryIfUnique(qc *queryContext) error { +func addQueryIfUnique(qctx context.Context, qc *queryContext) error { if len(qc.gmuList) == 0 { return nil } - ctx := context.WithValue(context.Background(), schema.IsWrite, false) + ctx := context.WithValue(qctx, schema.IsWrite, false) + namespace, err := x.ExtractNamespace(ctx) + if err != nil { + // It's okay to ignore this here. If namespace is not set, it could mean either there is no + // authorization or it's trying to be bypassed. So the namespace is either 0 or the mutation would fail. + glog.Errorf("Error while extracting namespace, assuming default %s", err) + namespace = 0 + } + isGalaxyQuery := x.IsGalaxyOperation(ctx) + qc.uniqueVars = map[uint64]uniquePredMeta{} var buildQuery strings.Builder for gmuIndex, gmu := range qc.gmuList { for rdfIndex, pred := range gmu.Set { - predSchema, _ := schema.State().Get(ctx, x.NamespaceAttr(pred.Namespace, pred.Predicate)) - if !predSchema.Unique { - continue + if isGalaxyQuery { + // The caller should make sure that the directed edges contain the namespace we want + // to insert into. + namespace = pred.Namespace + } + if pred.Predicate != "dgraph.xid" { + // [TODO] Don't check if it's dgraph.xid. It's a bug as this node might not be aware + // of the schema for the given predicate. This is a bug issue for dgraph.xid hence + // we are bypassing it manually until the bug is fixed. 
+ predSchema, ok := schema.State().Get(ctx, x.NamespaceAttr(namespace, pred.Predicate)) + if !ok || !predSchema.Unique { + continue + } } var predicateName string if pred.Lang != "" { diff --git a/ee/acl/acl_curl_test.go b/ee/acl/acl_curl_test.go index dab256cc878..75075e05fe6 100644 --- a/ee/acl/acl_curl_test.go +++ b/ee/acl/acl_curl_test.go @@ -22,7 +22,7 @@ import ( "github.com/golang/glog" "github.com/stretchr/testify/require" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/testutil" "github.com/dgraph-io/dgraph/x" ) @@ -37,13 +37,13 @@ func (asuite *AclTestSuite) TestCurlAuthorization() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) createAccountAndData(t, gc, hc) // test query through curl @@ -104,8 +104,8 @@ func (asuite *AclTestSuite) TestCurlAuthorization() { require.NoError(t, err, fmt.Sprintf("login through refresh httpToken failed: %v", err)) hcWithGroot, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hcWithGroot.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hcWithGroot.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) createGroupAndAcls(t, unusedGroup, false, hcWithGroot) time.Sleep(expireJwtSleep) testutil.VerifyCurlCmd(t, queryArgs(hc.AccessJwt), &testutil.CurlFailureConfig{ @@ -129,8 +129,8 @@ 
func (asuite *AclTestSuite) TestCurlAuthorization() { ShouldFail: true, DgraphErrMsg: "PermissionDenied", }) - require.NoError(t, hcWithGroot.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hcWithGroot.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) createGroupAndAcls(t, devGroup, true, hcWithGroot) time.Sleep(defaultTimeToSleep) // refresh the jwts again diff --git a/ee/acl/acl_integration_test.go b/ee/acl/acl_integration_test.go index 0679b03fced..eaf360cf218 100644 --- a/ee/acl/acl_integration_test.go +++ b/ee/acl/acl_integration_test.go @@ -27,7 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/x" ) @@ -46,9 +46,9 @@ func (asuite *AclTestSuite) TestPasswordReturn() { t := asuite.T() hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) - query := dgraphtest.GraphQLParams{ + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + query := dgraphapi.GraphQLParams{ Query: ` query { getCurrentUser { @@ -66,8 +66,8 @@ func (asuite *AclTestSuite) TestHealthForAcl() { t := asuite.T() hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resetUser(t, hc) require.NoError(t, hc.LoginIntoNamespace(userid, userpassword, x.GalaxyNamespace)) @@ -77,8 +77,8 @@ func (asuite *AclTestSuite) TestHealthForAcl() { assertNonGuardianFailure(t, "health", false, gqlResp, err) // assert data for guardians - 
require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resp, err := hc.HealthForInstance() require.NoError(t, err, "health request failed") @@ -336,11 +336,11 @@ func (asuite *AclTestSuite) TestGuardianOnlyAccessForAdminEndpoints() { for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { - params := dgraphtest.GraphQLParams{Query: tcase.query} + params := dgraphapi.GraphQLParams{Query: tcase.query} hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resetUser(t, hc) require.NoError(t, hc.LoginIntoNamespace(userid, userpassword, x.GalaxyNamespace)) @@ -351,8 +351,8 @@ func (asuite *AclTestSuite) TestGuardianOnlyAccessForAdminEndpoints() { // for guardians, assert non-ACL error or success if tcase.testGuardianAccess { - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resp, err := hc.RunGraphqlQuery(params, true) if tcase.guardianErr == "" { @@ -378,13 +378,13 @@ func (asuite *AclTestSuite) TestFailedLogin() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - 
dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -410,8 +410,9 @@ func (asuite *AclTestSuite) TestWrongPermission() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.DropAll()) mu := &api.Mutation{SetNquads: []byte(` _:dev "dgraph.type.Group" . @@ -437,3 +438,51 @@ func (asuite *AclTestSuite) TestWrongPermission() { require.Error(t, err, "Setting permission to -1 should have returned error") require.Contains(t, err.Error(), "Value for this predicate should be between 0 and 7") } + +func (asuite *AclTestSuite) TestACLNamespaceEdge() { + t := asuite.T() + gc, cleanup, err := asuite.dc.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + json := ` + { + "set": [ + { + "dgraph.xid": "groot", + "dgraph.password": "password", + "dgraph.type": "dgraph.type.User", + "dgraph.user.group": { + "dgraph.xid": "guardians", + "dgraph.type": "dgraph.type.Group", + "namespace": 1 + }, + "namespace": 1 + } + ] +}` + + mu := &api.Mutation{SetJson: []byte(json), CommitNow: true} + _, err = gc.Mutate(mu) + require.Error(t, err) + require.ErrorContains(t, err, "could not insert duplicate value") // Could be guardian or groot +} + +func (asuite *AclTestSuite) TestACLDuplicateGrootUser() { + t := asuite.T() + gc, cleanup, err := asuite.dc.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword,
x.GalaxyNamespace)) + + rdfs := `_:a "groot" . + _:a "dgraph.type.User" .` + + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.Error(t, err) + require.ErrorContains(t, err, "could not insert duplicate value [groot] for predicate [dgraph.xid]") +} diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index 88b78f52d55..6f7281c5e80 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -26,6 +26,7 @@ import ( "github.com/dgraph-io/dgo/v230" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -35,6 +36,274 @@ var ( userpassword = "simplepassword" ) +const ( + + // This is the groot schema before adding @unique directive to the dgraph.xid predicate + oldGrootSchema = `{ + "schema": [ + { + "predicate": "dgraph.acl.rule", + "type": "uid", + "list": true + }, + { + "predicate":"dgraph.drop.op", + "type":"string" + }, + { + "predicate":"dgraph.graphql.p_query", + "type":"string", + "index":true, + "tokenizer":["sha256"] + }, + { + "predicate": "dgraph.graphql.schema", + "type": "string" + }, + { + "predicate": "dgraph.graphql.xid", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true + }, + { + "predicate": "dgraph.password", + "type": "password" + }, + { + "predicate": "dgraph.rule.permission", + "type": "int" + }, + { + "predicate": "dgraph.rule.predicate", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true + }, + { + "predicate": "dgraph.type", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "list": true + }, + { + "predicate": "dgraph.user.group", + "type": "uid", + "reverse": true, + "list": true + }, + { + "predicate": "dgraph.xid", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true + } + ], + "types": [ + { + "fields": [ + { + "name": "dgraph.graphql.schema" + }, + { + "name": 
"dgraph.graphql.xid" + } + ], + "name": "dgraph.graphql" + }, + { + "fields": [ + { + "name": "dgraph.graphql.p_query" + } + ], + "name": "dgraph.graphql.persisted_query" + }, + { + "fields": [ + { + "name": "dgraph.xid" + }, + { + "name": "dgraph.acl.rule" + } + ], + "name": "dgraph.type.Group" + }, + { + "fields": [ + { + "name": "dgraph.rule.predicate" + }, + { + "name": "dgraph.rule.permission" + } + ], + "name": "dgraph.type.Rule" + }, + { + "fields": [ + { + "name": "dgraph.xid" + }, + { + "name": "dgraph.password" + }, + { + "name": "dgraph.user.group" + } + ], + "name": "dgraph.type.User" + } + ] + }` + + // This is the groot schema after adding @unique directive to the dgraph.xid predicate + newGrootSchema = `{ + "schema": [ + { + "predicate": "dgraph.acl.rule", + "type": "uid", + "list": true + }, + { + "predicate":"dgraph.drop.op", + "type":"string" + }, + { + "predicate":"dgraph.graphql.p_query", + "type":"string", + "index":true, + "tokenizer":["sha256"] + }, + { + "predicate": "dgraph.graphql.schema", + "type": "string" + }, + { + "predicate": "dgraph.graphql.xid", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true + }, + { + "predicate": "dgraph.password", + "type": "password" + }, + { + "predicate": "dgraph.rule.permission", + "type": "int" + }, + { + "predicate": "dgraph.rule.predicate", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true + }, + { + "predicate": "dgraph.type", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "list": true + }, + { + "predicate": "dgraph.user.group", + "type": "uid", + "reverse": true, + "list": true + }, + { + "predicate": "dgraph.xid", + "type": "string", + "index": true, + "tokenizer": [ + "exact" + ], + "upsert": true, + "unique": true + } + ], + "types": [ + { + "fields": [ + { + "name": "dgraph.graphql.schema" + }, + { + "name": "dgraph.graphql.xid" + } + ], + "name": "dgraph.graphql" + }, + { + "fields": [ + { + "name": 
"dgraph.graphql.p_query" + } + ], + "name": "dgraph.graphql.persisted_query" + }, + { + "fields": [ + { + "name": "dgraph.xid" + }, + { + "name": "dgraph.acl.rule" + } + ], + "name": "dgraph.type.Group" + }, + { + "fields": [ + { + "name": "dgraph.rule.predicate" + }, + { + "name": "dgraph.rule.permission" + } + ], + "name": "dgraph.type.Rule" + }, + { + "fields": [ + { + "name": "dgraph.xid" + }, + { + "name": "dgraph.password" + }, + { + "name": "dgraph.user.group" + } + ], + "name": "dgraph.type.User" + } + ] + }` +) + func checkUserCount(t *testing.T, resp []byte, expected int) { type Response struct { AddUser struct { @@ -68,8 +337,8 @@ func (asuite *AclTestSuite) TestGetCurrentUser() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace), "login failed") + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace), "login failed") currentUser, err := hc.GetCurrentUser() require.NoError(t, err) require.Equal(t, currentUser, "groot") @@ -100,8 +369,8 @@ func (asuite *AclTestSuite) TestCreateAndDeleteUsers() { resetUser(t, hc) // adding the user again should fail - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) user, err := hc.CreateUser(userid, userpassword) require.Error(t, err) require.Contains(t, err.Error(), "because id alice already exists") @@ -112,16 +381,16 @@ func (asuite *AclTestSuite) TestCreateAndDeleteUsers() { // delete the user hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, 
x.GalaxyNamespace)) require.NoError(t, hc.DeleteUser(userid), "error while deleteing user") user, err = hc.CreateUser(userid, userpassword) require.NoError(t, err) require.Equal(t, userid, user) } -func resetUser(t *testing.T, hc *dgraphtest.HTTPClient) { - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) +func resetUser(t *testing.T, hc *dgraphapi.HTTPClient) { + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) // clean up the user to allow repeated running of this test require.NoError(t, hc.DeleteUser(userid), "error while deleteing user") @@ -140,9 +409,9 @@ func (asuite *AclTestSuite) TestPreDefinedPredicates() { require.NoError(t, err) defer cleanup() ctx := context.Background() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) - alterPreDefinedPredicates(t, gc.Dgraph) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + alterPreDefinedPredicates(t, gc.Dgraph, asuite.dc.GetVersion()) } func (asuite *AclTestSuite) TestPreDefinedTypes() { @@ -154,8 +423,8 @@ func (asuite *AclTestSuite) TestPreDefinedTypes() { require.NoError(t, err) defer cleanup() ctx := context.Background() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) alterPreDefinedTypes(t, gc.Dgraph) } @@ -172,7 +441,7 @@ func (asuite *AclTestSuite) TestAuthorization() { testAuthorization(t, gc, hc, asuite) } -func getGrootAndGuardiansUid(t *testing.T, gc *dgraphtest.GrpcClient) (string, string) { +func getGrootAndGuardiansUid(t *testing.T, gc *dgraphapi.GrpcClient) (string, string) { grootUserQuery := ` { grootUser(func:eq(dgraph.xid, "groot")){ @@ 
-228,7 +497,7 @@ const ( expireJwtSleep = 21 * time.Second ) -func testAuthorization(t *testing.T, gc *dgraphtest.GrpcClient, hc *dgraphtest.HTTPClient, asuite *AclTestSuite) { +func testAuthorization(t *testing.T, gc *dgraphapi.GrpcClient, hc *dgraphapi.HTTPClient, asuite *AclTestSuite) { createAccountAndData(t, gc, hc) asuite.Upgrade() @@ -239,7 +508,7 @@ func testAuthorization(t *testing.T, gc *dgraphtest.GrpcClient, hc *dgraphtest.H hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) // initially the query should return empty result, mutate and alter // operations should all fail when there are no rules defined on the predicates @@ -295,16 +564,27 @@ var ( }`, predicateToRead, queryAttr) ) -func alterPreDefinedPredicates(t *testing.T, dg *dgo.Dgraph) { +func alterPreDefinedPredicates(t *testing.T, dg *dgo.Dgraph, clusterVersion string) { ctx := context.Background() - // Test that alter requests are allowed if the new update is the same as - // the initial update for a pre-defined predicate. - require.NoError(t, dg.Alter(ctx, &api.Operation{ - Schema: "dgraph.xid: string @index(exact) @upsert .", - })) + // Commit daa5805739ed258e913a157c6e0f126b2291b1b0 represents the latest update to the main branch. + // In this commit, the @unique directive is not applied to ACL predicates. + // Therefore, we are now deciding which schema to test. + // 'newGrootSchema' refers to the default schema with the @unique directive defined on ACL predicates. + // 'oldGrootSchema' refers to the default schema without the @unique directive on ACL predicates. 
+ supported, err := dgraphtest.IsHigherVersion(clusterVersion, "daa5805739ed258e913a157c6e0f126b2291b1b0") + require.NoError(t, err) + if supported { + require.NoError(t, dg.Alter(ctx, &api.Operation{ + Schema: "dgraph.xid: string @index(exact) @upsert @unique .", + })) + } else { + require.NoError(t, dg.Alter(ctx, &api.Operation{ + Schema: "dgraph.xid: string @index(exact) @upsert .", + })) + } - err := dg.Alter(ctx, &api.Operation{Schema: "dgraph.xid: int ."}) + err = dg.Alter(ctx, &api.Operation{Schema: "dgraph.xid: int ."}) require.Error(t, err) require.Contains(t, err.Error(), "predicate dgraph.xid is pre-defined and is not allowed to be modified") @@ -351,7 +631,7 @@ func alterPreDefinedTypes(t *testing.T, dg *dgo.Dgraph) { require.Contains(t, err.Error(), "type dgraph.type.Group is pre-defined and is not allowed to be dropped") } -func queryWithShouldFail(t *testing.T, gc *dgraphtest.GrpcClient, shouldFail bool, query string) { +func queryWithShouldFail(t *testing.T, gc *dgraphapi.GrpcClient, shouldFail bool, query string) { _, err := gc.Query(query) if shouldFail { require.Error(t, err, "the query should have failed") @@ -360,7 +640,7 @@ func queryWithShouldFail(t *testing.T, gc *dgraphtest.GrpcClient, shouldFail boo } } -func mutatePredicateWithUserAccount(t *testing.T, gc *dgraphtest.GrpcClient, shouldFail bool) { +func mutatePredicateWithUserAccount(t *testing.T, gc *dgraphapi.GrpcClient, shouldFail bool) { mu := &api.Mutation{SetNquads: []byte(fmt.Sprintf(`_:a <%s> "string" .`, predicateToWrite)), CommitNow: true} _, err := gc.Mutate(mu) if shouldFail { @@ -370,7 +650,7 @@ func mutatePredicateWithUserAccount(t *testing.T, gc *dgraphtest.GrpcClient, sho } } -func alterPredicateWithUserAccount(t *testing.T, gc *dgraphtest.GrpcClient, shouldFail bool) { +func alterPredicateWithUserAccount(t *testing.T, gc *dgraphapi.GrpcClient, shouldFail bool) { err := gc.Alter(context.Background(), &api.Operation{Schema: fmt.Sprintf(`%s: int .`, predicateToAlter)}) if 
shouldFail { require.Error(t, err, "the alter should have failed") @@ -379,10 +659,10 @@ func alterPredicateWithUserAccount(t *testing.T, gc *dgraphtest.GrpcClient, shou } } -func createAccountAndData(t *testing.T, gc *dgraphtest.GrpcClient, hc *dgraphtest.HTTPClient) { +func createAccountAndData(t *testing.T, gc *dgraphapi.GrpcClient, hc *dgraphapi.HTTPClient) { // use the groot account to clean the database - require.NoError(t, gc.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll(), "Unable to cleanup db") require.NoError(t, gc.Alter(context.Background(), &api.Operation{ @@ -409,7 +689,7 @@ type group struct { Rules []rule `json:"rules"` } -func createGroupAndAcls(t *testing.T, group string, addUserToGroup bool, hc *dgraphtest.HTTPClient) { +func createGroupAndAcls(t *testing.T, group string, addUserToGroup bool, hc *dgraphapi.HTTPClient) { // create a new group createdGroup, err := hc.CreateGroup(group) require.NoError(t, err) @@ -420,7 +700,7 @@ func createGroupAndAcls(t *testing.T, group string, addUserToGroup bool, hc *dgr require.NoError(t, hc.AddUserToGroup(userid, group)) } - rules := []dgraphtest.AclRule{ + rules := []dgraphapi.AclRule{ { Predicate: predicateToRead, Permission: Read.Code, }, @@ -454,9 +734,9 @@ func (asuite *AclTestSuite) TestPredicatePermission() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, 
dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) createAccountAndData(t, gc, hc) @@ -466,12 +746,12 @@ func (asuite *AclTestSuite) TestPredicatePermission() { require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.LoginIntoNamespace(ctx, userid, userpassword, x.GalaxyNamespace), "Logging in with the current password should have succeeded") @@ -509,9 +789,9 @@ func (asuite *AclTestSuite) TestAccessWithoutLoggingIn() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) createAccountAndData(t, gc, hc) @@ -540,8 +820,8 @@ func (asuite *AclTestSuite) TestUnauthorizedDeletion() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -551,8 +831,8 @@ func 
(asuite *AclTestSuite) TestUnauthorizedDeletion() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resetUser(t, hc) @@ -567,7 +847,7 @@ func (asuite *AclTestSuite) TestUnauthorizedDeletion() { nodeUID, ok := resp.Uids["a"] require.True(t, ok) - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{{Predicate: unAuthPred, Permission: 0}}, true)) + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: unAuthPred, Permission: 0}}, true)) asuite.Upgrade() @@ -598,9 +878,9 @@ func (asuite *AclTestSuite) TestGuardianAccess() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) op := api.Operation{Schema: "unauthpred: string @index(exact) ."} @@ -651,7 +931,7 @@ func (asuite *AclTestSuite) TestGuardianAccess() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, hc.RemoveUserFromGroup("guardian", "guardians")) @@ -667,13 +947,13 @@ func (asuite *AclTestSuite) TestQueryRemoveUnauthorizedPred() { gc, cleanup, err := asuite.dc.Client() 
require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -711,7 +991,7 @@ func (asuite *AclTestSuite) TestQueryRemoveUnauthorizedPred() { // give read access of to alice require.NoError(t, hc.AddRulesToGroup(devGroup, - []dgraphtest.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) asuite.Upgrade() userClient, cleanup, err := asuite.dc.Client() require.NoError(t, err) @@ -804,7 +1084,7 @@ func (asuite *AclTestSuite) TestQueryRemoveUnauthorizedPred() { t.Run(tc.description, func(t *testing.T) { // testify does not support subtests running in parallel with suite package // t.Parallel() - require.NoError(t, dgraphtest.PollTillPassOrTimeout(userClient, tc.input, tc.output, timeout)) + require.NoError(t, dgraphapi.PollTillPassOrTimeout(userClient, tc.input, tc.output, timeout)) }) } } @@ -817,12 +1097,12 @@ func (asuite *AclTestSuite) TestExpandQueryWithACLPermissions() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, 
x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -846,7 +1126,7 @@ func (asuite *AclTestSuite) TestExpandQueryWithACLPermissions() { createdGroup, err = hc.CreateGroup(sreGroup) require.NoError(t, err) require.Equal(t, sreGroup, createdGroup) - require.NoError(t, hc.AddRulesToGroup(sreGroup, []dgraphtest.AclRule{{Predicate: "age", Permission: Read.Code}, + require.NoError(t, hc.AddRulesToGroup(sreGroup, []dgraphapi.AclRule{{Predicate: "age", Permission: Read.Code}, {Predicate: "name", Permission: Write.Code}}, true)) require.NoError(t, hc.AddUserToGroup(userid, devGroup)) @@ -870,7 +1150,7 @@ func (asuite *AclTestSuite) TestExpandQueryWithACLPermissions() { query := "{me(func: has(name)){expand(_all_)}}" resp, err := gc.Query(query) require.NoError(t, err, "Error while querying data") - require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"me":[{"name":"RandomGuy","age":23, "nickname":"RG"},{"name":"RandomGuy2","age":25, "nickname":"RG2"}]}`, string(resp.GetJson()))) @@ -881,41 +1161,41 @@ func (asuite *AclTestSuite) TestExpandQueryWithACLPermissions() { defer cleanup() hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) time.Sleep(defaultTimeToSleep) require.NoError(t, userClient.LoginIntoNamespace(ctx, userid, userpassword, x.GalaxyNamespace)) // Query via user when user has no permissions - require.NoError(t, dgraphtest.PollTillPassOrTimeout(userClient, query, `{}`, timeout)) + require.NoError(t, dgraphapi.PollTillPassOrTimeout(userClient, query, `{}`, timeout)) // Give read access of , write access of to dev - require.NoError(t, hc.AddRulesToGroup(devGroup, 
[]dgraphtest.AclRule{{Predicate: "age", Permission: Write.Code}, + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: "age", Permission: Write.Code}, {Predicate: "name", Permission: Read.Code}}, true)) - require.NoError(t, dgraphtest.PollTillPassOrTimeout(userClient, query, + require.NoError(t, dgraphapi.PollTillPassOrTimeout(userClient, query, `{"me":[{"name":"RandomGuy"},{"name":"RandomGuy2"}]}`, timeout)) // Login to groot to modify accesses (2) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Add alice to sre group which has read access to and write access to require.NoError(t, hc.AddUserToGroup(userid, sreGroup)) - require.NoError(t, dgraphtest.PollTillPassOrTimeout(userClient, query, + require.NoError(t, dgraphapi.PollTillPassOrTimeout(userClient, query, `{"me":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}]}`, timeout)) // Login to groot to modify accesses (3) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Give read access of and , write access of to dev - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{{Predicate: "age", Permission: Write.Code}, + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: "age", Permission: Write.Code}, {Predicate: "name", Permission: Read.Code}, {Predicate: "nickname", Permission: Read.Code}}, true)) - require.NoError(t, dgraphtest.PollTillPassOrTimeout(userClient, query, + require.NoError(t, dgraphapi.PollTillPassOrTimeout(userClient, query, `{"me":[{"name":"RandomGuy","age":23, "nickname":"RG"},{"name":"RandomGuy2","age":25, "nickname":"RG2"}]}`, timeout)) } @@ -928,12 
+1208,12 @@ func (asuite *AclTestSuite) TestDeleteQueryWithACLPermissions() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -978,20 +1258,20 @@ func (asuite *AclTestSuite) TestDeleteQueryWithACLPermissions() { // Test that groot has access to all the predicates resp, err = gc.Query(query) require.NoError(t, err, "Error while querying data") - require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"q1":[{"name":"RandomGuy","age":23, "nickname": "RG"},{"name":"RandomGuy2","age":25, "nickname": "RG2"}]}`, string(resp.GetJson()))) // Give Write Access to alice for name and age predicate - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{{Predicate: "name", Permission: Write.Code}, + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: "name", Permission: Write.Code}, {Predicate: "age", Permission: Write.Code}}, true)) asuite.Upgrade() gc, _, err = asuite.dc.Client() require.NoError(t, err) - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) userClient, cleanup, err := asuite.dc.Client() require.NoError(t, err) @@ -1006,18 +1286,18 @@ func (asuite *AclTestSuite) TestDeleteQueryWithACLPermissions() { hc, err = 
asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) resp, err = gc.Query(query) require.NoError(t, err, "Error while querying data") // Only name and age predicates got deleted via user - alice - require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"q1":[{"nickname": "RG"},{"name":"RandomGuy2","age":25, "nickname": "RG2"}]}`, string(resp.GetJson()))) // Give write access of to dev - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{{Predicate: "name", Permission: Write.Code}, + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: "name", Permission: Write.Code}, {Predicate: "age", Permission: Write.Code}, {Predicate: "dgraph.type", Permission: Write.Code}}, true)) time.Sleep(defaultTimeToSleep) @@ -1029,7 +1309,7 @@ func (asuite *AclTestSuite) TestDeleteQueryWithACLPermissions() { resp, err = gc.Query(query) require.NoError(t, err, "Error while querying data") // Because alise had permission to dgraph.type the node reference has been deleted - require.NoError(t, dgraphtest.CompareJSON(`{"q1":[{"name":"RandomGuy2","age":25, "nickname": "RG2"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{"q1":[{"name":"RandomGuy2","age":25, "nickname": "RG2"}]}`, string(resp.GetJson()))) } @@ -1041,12 +1321,12 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, 
hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -1093,7 +1373,7 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { // Test that groot has access to all the predicates resp, err := gc.Query(query) require.NoError(t, err, "Error while querying data") - require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],`+ `"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`, string(resp.GetJson()))) @@ -1203,8 +1483,8 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) time.Sleep(defaultTimeToSleep) @@ -1214,13 +1494,13 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { t.Run(desc, func(t *testing.T) { resp, err := userClient.Query(tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.outputNoPerm, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(tc.outputNoPerm, string(resp.Json))) }) } // Give read access of to dev require.NoError(t, hc.AddRulesToGroup(devGroup, - []dgraphtest.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) time.Sleep(defaultTimeToSleep) for _, tc := range tests { @@ -1228,16 +1508,16 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { t.Run(desc, func(t *testing.T) { resp, err := userClient.Query(tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.outputNamePerm, string(resp.Json))) 
+ require.NoError(t, dgraphapi.CompareJSON(tc.outputNamePerm, string(resp.Json))) }) } // Login to groot to modify accesses (1) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Give read access of and to dev - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{{Predicate: "name", Permission: Read.Code}, + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{{Predicate: "name", Permission: Read.Code}, {Predicate: "age", Permission: Read.Code}}, true)) time.Sleep(defaultTimeToSleep) @@ -1247,7 +1527,7 @@ func (asuite *AclTestSuite) TestValQueryWithACLPermissions() { t.Run(desc, func(t *testing.T) { resp, err := userClient.Query(tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.outputNameAgePerm, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(tc.outputNameAgePerm, string(resp.Json))) }) } @@ -1261,13 +1541,13 @@ func (asuite *AclTestSuite) TestAllPredsPermission() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -1317,7 +1597,7 @@ func (asuite *AclTestSuite) TestAllPredsPermission() { // Test that groot has access to all the predicates resp, err := gc.Query(query) require.NoError(t, err, "Error while querying data") - 
require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],`+ `"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`, string(resp.GetJson()))) @@ -1370,19 +1650,19 @@ func (asuite *AclTestSuite) TestAllPredsPermission() { t.Run(desc, func(t *testing.T) { resp, err := userClient.Query(tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.outputNoPerm, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(tc.outputNoPerm, string(resp.Json))) }) } // Login to groot to modify accesses (1) hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Give read access of all predicates to dev require.NoError(t, hc.AddRulesToGroup(devGroup, - []dgraphtest.AclRule{{Predicate: "dgraph.all", Permission: Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "dgraph.all", Permission: Read.Code}}, true)) time.Sleep(defaultTimeToSleep) for _, tc := range tests { @@ -1390,7 +1670,7 @@ func (asuite *AclTestSuite) TestAllPredsPermission() { t.Run(desc, func(t *testing.T) { resp, err := userClient.Query(tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.outputNameAgePerm, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(tc.outputNameAgePerm, string(resp.Json))) }) } @@ -1408,7 +1688,7 @@ func (asuite *AclTestSuite) TestAllPredsPermission() { // Give write access of all predicates to dev. Now mutation should succeed. 
require.NoError(t, hc.AddRulesToGroup(devGroup, - []dgraphtest.AclRule{{Predicate: "dgraph.all", Permission: Write.Code | Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "dgraph.all", Permission: Write.Code | Read.Code}}, true)) time.Sleep(defaultTimeToSleep) _, err = userClient.Mutate(mu) @@ -1423,12 +1703,12 @@ func (asuite *AclTestSuite) TestNewACLPredicates() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -1482,7 +1762,7 @@ func (asuite *AclTestSuite) TestNewACLPredicates() { resp, err := userClient.NewTxn().Query(ctx, tc.input) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(tc.output, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(tc.output, string(resp.Json))) }) } @@ -1524,12 +1804,12 @@ func (asuite *AclTestSuite) TestDeleteRule() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) 
addDataAndRules(ctx, t, gc, hc) @@ -1537,8 +1817,8 @@ func (asuite *AclTestSuite) TestDeleteRule() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) userClient, cleanup, err := asuite.dc.Client() require.NoError(t, err) @@ -1551,7 +1831,7 @@ func (asuite *AclTestSuite) TestDeleteRule() { resp, err := userClient.Query(queryName) require.NoError(t, err, "Error while querying data") - require.NoError(t, dgraphtest.CompareJSON(`{"me":[{"name":"RandomGuy"},{"name":"RandomGuy2"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{"me":[{"name":"RandomGuy"},{"name":"RandomGuy2"}]}`, string(resp.GetJson()))) require.NoError(t, hc.RemovePredicateFromGroup(devGroup, "name")) @@ -1559,10 +1839,10 @@ func (asuite *AclTestSuite) TestDeleteRule() { resp, err = userClient.Query(queryName) require.NoError(t, err, "Error while querying data") - require.NoError(t, dgraphtest.CompareJSON(string(resp.GetJson()), `{}`)) + require.NoError(t, dgraphapi.CompareJSON(string(resp.GetJson()), `{}`)) } -func addDataAndRules(ctx context.Context, t *testing.T, gc *dgraphtest.GrpcClient, hc *dgraphtest.HTTPClient) { +func addDataAndRules(ctx context.Context, t *testing.T, gc *dgraphapi.GrpcClient, hc *dgraphapi.HTTPClient) { require.NoError(t, gc.DropAll()) op := api.Operation{Schema: ` name : string @index(exact) . 
@@ -1636,12 +1916,12 @@ func (asuite *AclTestSuite) TestQueryUserInfo() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) require.NoError(t, hc.LoginIntoNamespace(userid, userpassword, x.GalaxyNamespace)) @@ -1663,12 +1943,12 @@ func (asuite *AclTestSuite) TestQueryUserInfo() { } } ` - params := dgraphtest.GraphQLParams{ + params := dgraphapi.GraphQLParams{ Query: gqlQuery, } gqlResp, err := hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(` + require.NoError(t, dgraphapi.CompareJSON(` { "queryUser": [ { @@ -1734,7 +2014,7 @@ func (asuite *AclTestSuite) TestQueryUserInfo() { resp, err := userClient.Query(query) require.NoError(t, err, "Error while querying ACL") - require.NoError(t, dgraphtest.CompareJSON(`{"me":[]}`, string(resp.GetJson()))) + require.NoError(t, dgraphapi.CompareJSON(`{"me":[]}`, string(resp.GetJson()))) gqlQuery = ` query { @@ -1749,12 +2029,12 @@ func (asuite *AclTestSuite) TestQueryUserInfo() { } } }` - params = dgraphtest.GraphQLParams{Query: gqlQuery} + params = dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err = hc.RunGraphqlQuery(params, true) require.NoError(t, err) // The user should only be able to see their group dev and themselves as the user. 
- require.NoError(t, dgraphtest.CompareJSON(`{ + require.NoError(t, dgraphapi.CompareJSON(`{ "queryGroup": [ { "name": "dev", @@ -1799,27 +2079,28 @@ func (asuite *AclTestSuite) TestQueryUserInfo() { } } }` - params = dgraphtest.GraphQLParams{Query: gqlQuery} + params = dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err = hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"getGroup": null}`, string(gqlResp))) + require.NoError(t, dgraphapi.CompareJSON(`{"getGroup": null}`, string(gqlResp))) } func (asuite *AclTestSuite) TestQueriesWithUserAndGroupOfSameName() { t := asuite.T() + dgraphtest.ShouldSkipTest(t, asuite.dc.GetVersion(), "7b1f473ddf01547e24b44f580a68e6b049502d69") ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) // Creates a user -- alice @@ -1847,7 +2128,7 @@ func (asuite *AclTestSuite) TestQueriesWithUserAndGroupOfSameName() { // add rules to groups require.NoError(t, hc.AddRulesToGroup("alice", - []dgraphtest.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) asuite.Upgrade() @@ -1863,7 +2144,7 @@ func (asuite *AclTestSuite) TestQueriesWithUserAndGroupOfSameName() { age } }` - require.NoError(t, dgraphtest.PollTillPassOrTimeout(dc, query, + 
require.NoError(t, dgraphapi.PollTillPassOrTimeout(dc, query, `{"q":[{"name":"RandomGuy"},{"name":"RandomGuy2"}]}`, timeout)) } @@ -1872,8 +2153,8 @@ func (asuite *AclTestSuite) TestQueriesForNonGuardianUserWithoutGroup() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Create a new user without any groups, queryGroup should return an empty result. resetUser(t, hc) @@ -1895,10 +2176,10 @@ func (asuite *AclTestSuite) TestQueriesForNonGuardianUserWithoutGroup() { } ` - params := dgraphtest.GraphQLParams{Query: gqlQuery} + params := dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err := hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"queryGroup": []}`, string(gqlResp))) + require.NoError(t, dgraphapi.CompareJSON(`{"queryGroup": []}`, string(gqlResp))) gqlQuery = ` query { @@ -1909,10 +2190,10 @@ func (asuite *AclTestSuite) TestQueriesForNonGuardianUserWithoutGroup() { } } }` - params = dgraphtest.GraphQLParams{Query: gqlQuery} + params = dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err = hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"queryUser": [{ "groups": [], "name": "alice"}]}`, string(gqlResp))) + require.NoError(t, dgraphapi.CompareJSON(`{"queryUser": [{ "groups": [], "name": "alice"}]}`, string(gqlResp))) } func (asuite *AclTestSuite) TestSchemaQueryWithACL() { @@ -1923,136 +2204,7 @@ func (asuite *AclTestSuite) TestSchemaQueryWithACL() { defer cancel() schemaQuery := "schema{}" - grootSchema := `{ - "schema": [ - { - "predicate": "dgraph.acl.rule", - "type": "uid", - "list": true - }, - { - "predicate":"dgraph.drop.op", - "type":"string" - }, - { - "predicate":"dgraph.graphql.p_query", - 
"type":"string", - "index":true, - "tokenizer":["sha256"] - }, - { - "predicate": "dgraph.graphql.schema", - "type": "string" - }, - { - "predicate": "dgraph.graphql.xid", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ], - "upsert": true - }, - { - "predicate": "dgraph.password", - "type": "password" - }, - { - "predicate": "dgraph.rule.permission", - "type": "int" - }, - { - "predicate": "dgraph.rule.predicate", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ], - "upsert": true - }, - { - "predicate": "dgraph.type", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ], - "list": true - }, - { - "predicate": "dgraph.user.group", - "type": "uid", - "reverse": true, - "list": true - }, - { - "predicate": "dgraph.xid", - "type": "string", - "index": true, - "tokenizer": [ - "exact" - ], - "upsert": true - } - ], - "types": [ - { - "fields": [ - { - "name": "dgraph.graphql.schema" - }, - { - "name": "dgraph.graphql.xid" - } - ], - "name": "dgraph.graphql" - }, - { - "fields": [ - { - "name": "dgraph.graphql.p_query" - } - ], - "name": "dgraph.graphql.persisted_query" - }, - { - "fields": [ - { - "name": "dgraph.xid" - }, - { - "name": "dgraph.acl.rule" - } - ], - "name": "dgraph.type.Group" - }, - { - "fields": [ - { - "name": "dgraph.rule.predicate" - }, - { - "name": "dgraph.rule.permission" - } - ], - "name": "dgraph.type.Rule" - }, - { - "fields": [ - { - "name": "dgraph.xid" - }, - { - "name": "dgraph.password" - }, - { - "name": "dgraph.user.group" - } - ], - "name": "dgraph.type.User" - } - ] -}` + aliceSchema := `{ "schema": [ { @@ -2092,17 +2244,23 @@ func (asuite *AclTestSuite) TestSchemaQueryWithACL() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, 
x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) resp, err := gc.Query(schemaQuery) require.NoError(t, err) - require.JSONEq(t, grootSchema, string(resp.GetJson())) + supported, err := dgraphtest.IsHigherVersion(asuite.dc.GetVersion(), "daa5805739ed258e913a157c6e0f126b2291b1b0") + require.NoError(t, err) + if supported { + require.JSONEq(t, newGrootSchema, string(resp.GetJson())) + } else { + require.JSONEq(t, oldGrootSchema, string(resp.GetJson())) + } // add another user and some data for that user with permissions on predicates resetUser(t, hc) @@ -2129,12 +2287,12 @@ func (asuite *AclTestSuite) TestDeleteUserShouldDeleteUserFromGroup() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) resetUser(t, hc) addDataAndRules(ctx, t, gc, hc) @@ -2143,8 +2301,8 @@ func (asuite *AclTestSuite) TestDeleteUserShouldDeleteUserFromGroup() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, 
hc.DeleteUser(userid)) gqlQuery := ` @@ -2153,7 +2311,7 @@ func (asuite *AclTestSuite) TestDeleteUserShouldDeleteUserFromGroup() { name } }` - params := dgraphtest.GraphQLParams{Query: gqlQuery} + params := dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err := hc.RunGraphqlQuery(params, true) require.NoError(t, err) require.JSONEq(t, `{"queryUser":[{"name":"groot"}]}`, string(gqlResp)) @@ -2168,10 +2326,10 @@ func (asuite *AclTestSuite) TestDeleteUserShouldDeleteUserFromGroup() { } } }` - params = dgraphtest.GraphQLParams{Query: gqlQuery} + params = dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err = hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ + require.NoError(t, dgraphapi.CompareJSON(`{ "queryGroup": [ { "name": "guardians", @@ -2205,12 +2363,12 @@ func (asuite *AclTestSuite) TestGroupDeleteShouldDeleteGroupFromUser() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) resetUser(t, hc) addDataAndRules(ctx, t, gc, hc) @@ -2219,8 +2377,8 @@ func (asuite *AclTestSuite) TestGroupDeleteShouldDeleteGroupFromUser() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, hc.DeleteGroup("dev-a")) gqlQuery := ` 
@@ -2229,10 +2387,10 @@ func (asuite *AclTestSuite) TestGroupDeleteShouldDeleteGroupFromUser() { name } }` - params := dgraphtest.GraphQLParams{Query: gqlQuery} + params := dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err := hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ + require.NoError(t, dgraphapi.CompareJSON(`{ "queryGroup": [ { "name": "guardians" @@ -2255,10 +2413,10 @@ func (asuite *AclTestSuite) TestGroupDeleteShouldDeleteGroupFromUser() { } } }` - params = dgraphtest.GraphQLParams{Query: gqlQuery} + params = dgraphapi.GraphQLParams{Query: gqlQuery} gqlResp, err = hc.RunGraphqlQuery(params, true) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ + require.NoError(t, dgraphapi.CompareJSON(`{ "getUser": { "name": "alice", "groups": [ @@ -2300,11 +2458,11 @@ func (asuite *AclTestSuite) TestAddUpdateGroupWithDuplicateRules() { hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) groupName := "testGroup" - addedRules := []dgraphtest.AclRule{ + addedRules := []dgraphapi.AclRule{ { Predicate: "test", Permission: 1, @@ -2330,9 +2488,9 @@ func (asuite *AclTestSuite) TestAddUpdateGroupWithDuplicateRules() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) - updatedRules := []dgraphtest.AclRule{ + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + updatedRules := []dgraphapi.AclRule{ { Predicate: "test", Permission: 3, @@ -2350,7 +2508,7 @@ func (asuite *AclTestSuite) TestAddUpdateGroupWithDuplicateRules() { require.NoError(t, err) require.Equal(t, groupName, 
updatedGroup.Name) require.Len(t, updatedGroup.Rules, 3) - require.ElementsMatch(t, []dgraphtest.AclRule{updatedRules[0], addedRules[2], updatedRules[2]}, + require.ElementsMatch(t, []dgraphapi.AclRule{updatedRules[0], addedRules[2], updatedRules[2]}, updatedGroup.Rules) updatedGroup1, err := hc.UpdateGroup(groupName, nil, @@ -2359,7 +2517,7 @@ func (asuite *AclTestSuite) TestAddUpdateGroupWithDuplicateRules() { require.Equal(t, groupName, updatedGroup1.Name) require.Len(t, updatedGroup1.Rules, 2) - require.ElementsMatch(t, []dgraphtest.AclRule{updatedRules[0], updatedRules[2]}, updatedGroup1.Rules) + require.ElementsMatch(t, []dgraphapi.AclRule{updatedRules[0], updatedRules[2]}, updatedGroup1.Rules) // cleanup require.NoError(t, hc.DeleteGroup(groupName)) @@ -2373,12 +2531,12 @@ func (asuite *AclTestSuite) TestAllowUIDAccess() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) op := api.Operation{Schema: ` @@ -2400,7 +2558,7 @@ func (asuite *AclTestSuite) TestAllowUIDAccess() { // give read access of to alice require.NoError(t, hc.AddRulesToGroup(devGroup, - []dgraphtest.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "name", Permission: Read.Code}}, true)) asuite.Upgrade() @@ -2420,7 +2578,7 @@ func (asuite *AclTestSuite) TestAllowUIDAccess() { }` resp, err := userClient.Query(uidQuery) require.NoError(t, err) - require.NoError(t, 
dgraphtest.CompareJSON(`{"me":[{"name":"100th User", "uid": "0x64"}]}`, string(resp.GetJson()))) + require.NoError(t, dgraphapi.CompareJSON(`{"me":[{"name":"100th User", "uid": "0x64"}]}`, string(resp.GetJson()))) } func (asuite *AclTestSuite) TestAddNewPredicate() { @@ -2431,12 +2589,12 @@ func (asuite *AclTestSuite) TestAddNewPredicate() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) resetUser(t, hc) @@ -2445,8 +2603,8 @@ func (asuite *AclTestSuite) TestAddNewPredicate() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) userClient, cancel, err := asuite.dc.Client() defer cleanup() require.NoError(t, err) @@ -2474,13 +2632,13 @@ func (asuite *AclTestSuite) TestCrossGroupPermission() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - 
dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) @@ -2498,14 +2656,14 @@ func (asuite *AclTestSuite) TestCrossGroupPermission() { require.NoError(t, err) require.Equal(t, "alterer", createdGroup) // add rules to groups - require.NoError(t, hc.AddRulesToGroup("reader", []dgraphtest.AclRule{{Predicate: "newpred", Permission: 4}}, true)) - require.NoError(t, hc.AddRulesToGroup("writer", []dgraphtest.AclRule{{Predicate: "newpred", Permission: 2}}, true)) - require.NoError(t, hc.AddRulesToGroup("alterer", []dgraphtest.AclRule{{Predicate: "newpred", Permission: 1}}, true)) + require.NoError(t, hc.AddRulesToGroup("reader", []dgraphapi.AclRule{{Predicate: "newpred", Permission: 4}}, true)) + require.NoError(t, hc.AddRulesToGroup("writer", []dgraphapi.AclRule{{Predicate: "newpred", Permission: 2}}, true)) + require.NoError(t, hc.AddRulesToGroup("alterer", []dgraphapi.AclRule{{Predicate: "newpred", Permission: 1}}, true)) // Wait for acl cache to be refreshed time.Sleep(defaultTimeToSleep) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // create 8 users. 
for i := 0; i < 8; i++ { userIdx := strconv.Itoa(i) @@ -2530,7 +2688,7 @@ func (asuite *AclTestSuite) TestCrossGroupPermission() { time.Sleep(defaultTimeToSleep) // operations - dgQuery := func(client *dgraphtest.GrpcClient, shouldFail bool, user string) { + dgQuery := func(client *dgraphapi.GrpcClient, shouldFail bool, user string) { _, err := client.Query(` { me(func: has(newpred)) { @@ -2541,7 +2699,7 @@ func (asuite *AclTestSuite) TestCrossGroupPermission() { require.True(t, (err != nil) == shouldFail, "Query test Failed for: "+user+", shouldFail: "+strconv.FormatBool(shouldFail)) } - dgMutation := func(client *dgraphtest.GrpcClient, shouldFail bool, user string) { + dgMutation := func(client *dgraphapi.GrpcClient, shouldFail bool, user string) { _, err := client.Mutate(&api.Mutation{ Set: []*api.NQuad{ { @@ -2555,7 +2713,7 @@ func (asuite *AclTestSuite) TestCrossGroupPermission() { require.True(t, (err != nil) == shouldFail, "Mutation test failed for: "+user+", shouldFail: "+strconv.FormatBool(shouldFail)) } - dgAlter := func(client *dgraphtest.GrpcClient, shouldFail bool, user string) { + dgAlter := func(client *dgraphapi.GrpcClient, shouldFail bool, user string) { err := client.Alter(ctx, &api.Operation{Schema: `newpred: string @index(exact) .`}) require.True(t, (err != nil) == shouldFail, "Alter test failed for: "+user+", shouldFail: "+strconv.FormatBool(shouldFail)) @@ -2592,12 +2750,12 @@ func (asuite *AclTestSuite) TestMutationWithValueVar() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, 
hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.DropAll()) err = gc.Alter(ctx, &api.Operation{ @@ -2623,7 +2781,7 @@ func (asuite *AclTestSuite) TestMutationWithValueVar() { require.NoError(t, err) require.Equal(t, devGroup, createdGroup) require.NoError(t, hc.AddUserToGroup(userid, devGroup)) - require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphtest.AclRule{ + require.NoError(t, hc.AddRulesToGroup(devGroup, []dgraphapi.AclRule{ { Predicate: "name", Permission: Read.Code | Write.Code, @@ -2683,7 +2841,7 @@ func (asuite *AclTestSuite) TestMutationWithValueVar() { resp, err := userClient.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"me": [{"name":"r1","nickname":"r1"}]}`, string(resp.GetJson()))) + require.NoError(t, dgraphapi.CompareJSON(`{"me": [{"name":"r1","nickname":"r1"}]}`, string(resp.GetJson()))) } func (asuite *AclTestSuite) TestDeleteGuardiansGroupShouldFail() { @@ -2694,12 +2852,12 @@ func (asuite *AclTestSuite) TestDeleteGuardiansGroupShouldFail() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -2707,7 +2865,7 @@ func (asuite *AclTestSuite) TestDeleteGuardiansGroupShouldFail() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + 
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) err = hc.DeleteGroup("guardians") require.Error(t, err) @@ -2722,13 +2880,13 @@ func (asuite *AclTestSuite) TestDeleteGrootUserShouldFail() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -2736,7 +2894,7 @@ func (asuite *AclTestSuite) TestDeleteGrootUserShouldFail() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) err = hc.DeleteUser("groot") require.Error(t, err) @@ -2752,12 +2910,12 @@ func (asuite *AclTestSuite) TestDeleteGrootUserFromGuardiansGroupShouldFail() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, 
dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -2765,7 +2923,7 @@ func (asuite *AclTestSuite) TestDeleteGrootUserFromGuardiansGroupShouldFail() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) err = hc.RemoveUserFromGroup("groot", "guardians") require.Error(t, err) @@ -2780,12 +2938,12 @@ func (asuite *AclTestSuite) TestDeleteGrootAndGuardiansUsingDelNQuadShouldFail() gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -2794,8 +2952,8 @@ func (asuite *AclTestSuite) TestDeleteGrootAndGuardiansUsingDelNQuadShouldFail() gc, cleanup, err = asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) grootUid, guardiansUid := getGrootAndGuardiansUid(t, gc) @@ -2812,7 +2970,7 @@ func (asuite *AclTestSuite) TestDeleteGrootAndGuardiansUsingDelNQuadShouldFail() require.Contains(t, err.Error(), "Properties of guardians group and groot user cannot be deleted") } -func 
deleteGuardiansGroupAndGrootUserShouldFail(t *testing.T, hc *dgraphtest.HTTPClient) { +func deleteGuardiansGroupAndGrootUserShouldFail(t *testing.T, hc *dgraphapi.HTTPClient) { // Try deleting guardians group should fail err := hc.DeleteGroup("guardians") require.Error(t, err) @@ -2831,12 +2989,12 @@ func (asuite *AclTestSuite) TestDropAllShouldResetGuardiansAndGroot() { gc, cleanup, err := asuite.dc.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) addDataAndRules(ctx, t, gc, hc) @@ -2845,8 +3003,8 @@ func (asuite *AclTestSuite) TestDropAllShouldResetGuardiansAndGroot() { gc, cleanup, err = asuite.dc.Client() defer cleanup() require.NoError(t, err) - require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // Try Drop All op := api.Operation{ @@ -2861,7 +3019,7 @@ func (asuite *AclTestSuite) TestDropAllShouldResetGuardiansAndGroot() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) deleteGuardiansGroupAndGrootUserShouldFail(t, hc) // Try Drop Data @@ -2875,6 +3033,6 @@ func (asuite *AclTestSuite) 
TestDropAllShouldResetGuardiansAndGroot() { hc, err = asuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) deleteGuardiansGroupAndGrootUserShouldFail(t, hc) } diff --git a/ee/acl/integration_test.go b/ee/acl/integration_test.go index 76b518b7fae..51eabf5466f 100644 --- a/ee/acl/integration_test.go +++ b/ee/acl/integration_test.go @@ -17,12 +17,13 @@ import ( "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" ) type AclTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster } func (suite *AclTestSuite) SetupTest() { diff --git a/ee/acl/jwt_algo_test.go b/ee/acl/jwt_algo_test.go index 95933d5ec3a..f67ab398f87 100644 --- a/ee/acl/jwt_algo_test.go +++ b/ee/acl/jwt_algo_test.go @@ -26,6 +26,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -48,7 +49,7 @@ func TestACLJwtAlgo(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) // op with Grpc client _, err = gc.Query(`{q(func: uid(0x1)) {uid}}`) @@ -61,8 +62,8 @@ func TestACLJwtAlgo(t *testing.T) { hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // op with HTTP client require.NoError(t, hc.Backup(c, true, 
dgraphtest.DefaultBackupDir)) diff --git a/ee/acl/upgrade_test.go b/ee/acl/upgrade_test.go index cdc2774ee3e..9230445c460 100644 --- a/ee/acl/upgrade_test.go +++ b/ee/acl/upgrade_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -27,7 +28,7 @@ import ( type AclTestSuite struct { suite.Suite lc *dgraphtest.LocalCluster - dc dgraphtest.Cluster + dc dgraphapi.Cluster uc dgraphtest.UpgradeCombo } @@ -54,7 +55,7 @@ func (asuite *AclTestSuite) Upgrade() { func TestACLSuite(t *testing.T) { for _, uc := range dgraphtest.AllUpgradeCombos(true) { - log.Printf("running upgrade tests for confg: %+v", uc) + log.Printf("running upgrade tests for config: %+v", uc) aclSuite := AclTestSuite{uc: uc} suite.Run(t, &aclSuite) if t.Failed() { diff --git a/go.mod b/go.mod index 3da8a99fca6..6b41e41fbe6 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,8 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 github.com/IBM/sarama v1.41.0 github.com/Masterminds/semver/v3 v3.1.0 + github.com/bits-and-blooms/bitset v1.2.0 github.com/blevesearch/bleve/v2 v2.3.10 - github.com/chewxy/math32 v1.10.1 github.com/dgraph-io/badger/v4 v4.2.0 github.com/dgraph-io/dgo/v230 v230.0.2-0.20240314155021-7b8d289e37f3 github.com/dgraph-io/gqlgen v0.13.2 @@ -54,18 +54,20 @@ require ( github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.9.0 github.com/twpayne/go-geom v1.0.5 + github.com/viterin/vek v0.4.2 github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c go.etcd.io/etcd/raft/v3 v3.5.9 go.opencensus.io v0.24.0 go.uber.org/zap v1.16.0 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 - golang.org/x/net v0.22.0 - golang.org/x/sync v0.6.0 - golang.org/x/sys v0.18.0 - golang.org/x/term v0.18.0 - golang.org/x/text v0.14.0 - golang.org/x/tools v0.19.0 + golang.org/x/crypto v0.24.0 + 
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 + golang.org/x/mod v0.18.0 + golang.org/x/net v0.26.0 + golang.org/x/sync v0.7.0 + golang.org/x/sys v0.22.0 + golang.org/x/term v0.21.0 + golang.org/x/text v0.16.0 + golang.org/x/tools v0.22.0 google.golang.org/grpc v1.62.1 gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.4.0 @@ -77,7 +79,6 @@ require ( github.com/agnivade/levenshtein v1.0.3 // indirect github.com/apache/thrift v0.13.0 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/blevesearch/bleve_index_api v1.0.6 // indirect github.com/blevesearch/geo v0.1.18 // indirect github.com/blevesearch/go-porterstemmer v1.0.3 // indirect @@ -85,6 +86,7 @@ require ( github.com/blevesearch/snowballstem v0.9.0 // indirect github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chewxy/math32 v1.10.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.4.0 // indirect @@ -139,10 +141,10 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/tinylib/msgp v1.1.2 // indirect + github.com/viterin/partial v1.1.0 // indirect github.com/xdg/stringprep v1.0.3 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/mod v0.16.0 // indirect golang.org/x/time v0.3.0 // indirect google.golang.org/api v0.122.0 // indirect google.golang.org/appengine v1.6.8 // indirect diff --git a/go.sum b/go.sum index 9d5cfb73420..3b68aebc34a 100644 --- a/go.sum +++ b/go.sum @@ -685,6 +685,10 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/vektah/dataloaden 
v0.2.1-0.20190515034641-a19b9a6e7c9e/go.mod h1:/HUdMve7rvxZma+2ZELQeNh88+003LL7Pf/CZ089j8U= github.com/vektah/gqlparser/v2 v2.1.0/go.mod h1:SyUiHgLATUR8BiYURfTirrTcGpcE+4XkV2se04Px1Ms= +github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E= +github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA= +github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U= +github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= @@ -743,8 +747,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -758,8 +762,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp 
v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -786,8 +790,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net 
v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -834,8 +838,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -856,8 +860,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -916,13 +920,13 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -934,8 +938,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod 
h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1005,8 +1009,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/graphql/authorization/auth.go b/graphql/authorization/auth.go index 70318d22152..d56344ede05 100644 --- a/graphql/authorization/auth.go +++ b/graphql/authorization/auth.go @@ -513,7 +513,7 @@ 
func (a *AuthMeta) FetchJWK(i int) error { func (a *AuthMeta) refreshJWK(i int) error { var err error - for i := 0; i < 3; i++ { + for n := 0; n < 3; n++ { err = a.FetchJWK(i) if err == nil { return nil diff --git a/graphql/bench/schema.graphql b/graphql/bench/schema.graphql index 91ddba4f481..7299905b12b 100644 --- a/graphql/bench/schema.graphql +++ b/graphql/bench/schema.graphql @@ -69,4 +69,4 @@ type Owner { hasRestaurants: [Restaurant] @hasInverse(field: owner) } -# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]} \ No newline at end of file +# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]} diff --git a/graphql/bench/schema_auth.graphql b/graphql/bench/schema_auth.graphql index 245bbc5f7dd..50b62cf009b 100644 --- a/graphql/bench/schema_auth.graphql +++ b/graphql/bench/schema_auth.graphql @@ -248,4 +248,4 @@ type Owner @auth( hasRestaurants: [Restaurant] @hasInverse(field: owner) } -# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]} \ No newline at end of file +# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]} diff --git a/graphql/dgraph/graphquery.go b/graphql/dgraph/graphquery.go index 99bb6352147..23519f45a78 100644 --- a/graphql/dgraph/graphquery.go +++ b/graphql/dgraph/graphquery.go @@ -28,12 +28,16 @@ import ( // validate query, and so doesn't return an error if query is 'malformed' - it might // just write something that wouldn't parse as a Dgraph query. 
func AsString(queries []*dql.GraphQuery) string { - if queries == nil { + if len(queries) == 0 { return "" } var b strings.Builder - x.Check2(b.WriteString("query {\n")) + queryName := queries[len(queries)-1].Attr + x.Check2(b.WriteString("query ")) + addQueryVars(&b, queryName, queries[0].Args) + x.Check2(b.WriteString("{\n")) + numRewrittenQueries := 0 for _, q := range queries { if q == nil { @@ -54,6 +58,24 @@ func AsString(queries []*dql.GraphQuery) string { return b.String() } +func addQueryVars(b *strings.Builder, queryName string, args map[string]string) { + dollarFound := false + for name, val := range args { + if strings.HasPrefix(name, "$") { + if !dollarFound { + x.Check2(b.WriteString(queryName + "(")) + x.Check2(b.WriteString(name + ": " + val)) + dollarFound = true + } else { + x.Check2(b.WriteString(", " + name + ": " + val)) + } + } + } + if dollarFound { + x.Check2(b.WriteString(") ")) + } +} + func writeQuery(b *strings.Builder, query *dql.GraphQuery, prefix string) { if query.Var != "" || query.Alias != "" || query.Attr != "" { x.Check2(b.WriteString(prefix)) @@ -145,6 +167,9 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) { } switch { + // TODO: Instead of the hard-coded strings "uid", "type", etc., use the + // pre-defined constants in dql/parser.go such as dql.uidFunc, dql.typFunc, + // etc. This of course will require that we make these constants public. 
case q.Func.Name == "uid": x.Check2(b.WriteString("(func: ")) writeUIDFunc(b, q.Func.UID, q.Func.Args) @@ -154,6 +179,10 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) { x.Check2(b.WriteString("(func: eq(")) writeFilterArguments(b, q.Func.Args) x.Check2(b.WriteRune(')')) + case q.Func.Name == "similar_to": + x.Check2(b.WriteString("(func: similar_to(")) + writeFilterArguments(b, q.Func.Args) + x.Check2(b.WriteRune(')')) } writeOrderAndPage(b, q, true) x.Check2(b.WriteRune(')')) diff --git a/graphql/e2e/auth/schema.graphql b/graphql/e2e/auth/schema.graphql index a9542def0a7..4148b373968 100644 --- a/graphql/e2e/auth/schema.graphql +++ b/graphql/e2e/auth/schema.graphql @@ -568,7 +568,7 @@ type Contact @auth( query: { rule: "{$ContactRole: { eq: \"ADMINISTRATOR\"}}" } ) { id: ID! - nickName: String @search(by: [exact, term, fulltext, regexp]) + nickName: String @search(by: ["exact", "term", "fulltext", "regexp"]) adminTasks: [AdminTask] @hasInverse(field: forContact) tasks: [Task] @hasInverse(field: forContact) } @@ -577,14 +577,14 @@ type AdminTask @auth( query: { rule: "{$TaskRole: { eq: \"ADMINISTRATOR\"}}" } ) { id: ID! - name: String @search(by: [exact, term, fulltext, regexp]) + name: String @search(by: ["exact", "term", "fulltext", "regexp"]) occurrences: [TaskOccurrence] @hasInverse(field: adminTask) forContact: Contact @hasInverse(field: adminTasks) } type Task { id: ID! 
- name: String @search(by: [exact, term, fulltext, regexp]) + name: String @search(by: ["exact", "term", "fulltext", "regexp"]) occurrences: [TaskOccurrence] @hasInverse(field: task) forContact: Contact @hasInverse(field: tasks) } @@ -608,7 +608,7 @@ type TaskOccurrence @auth( task: Task @hasInverse(field: occurrences) adminTask: AdminTask @hasInverse(field: occurrences) isPublic: Boolean @search - role: String @search(by: [exact, term, fulltext, regexp]) + role: String @search(by: ["exact", "term", "fulltext", "regexp"]) } type Author { diff --git a/graphql/e2e/common/query.go b/graphql/e2e/common/query.go index 5bb4c1136bc..4c26375132a 100644 --- a/graphql/e2e/common/query.go +++ b/graphql/e2e/common/query.go @@ -1285,7 +1285,7 @@ func stringExactFilters(t *testing.T) { func scalarListFilters(t *testing.T) { - // tags is a list of strings with @search(by: exact). So all the filters + // tags is a list of strings with @search(by: "exact"). So all the filters // lt, le, ... mean "is there something in the list that's lt 'Dgraph'", etc. cases := map[string]struct { diff --git a/graphql/e2e/custom_logic/custom_logic_test.go b/graphql/e2e/custom_logic/custom_logic_test.go index e0e4861c2b8..8601228c01a 100644 --- a/graphql/e2e/custom_logic/custom_logic_test.go +++ b/graphql/e2e/custom_logic/custom_logic_test.go @@ -1067,7 +1067,7 @@ func TestCustomFieldsShouldPassBody(t *testing.T) { schema := ` type User { - id: String! @id @search(by: [hash, regexp]) + id: String! @id @search(by: ["hash", "regexp"]) address:String name: String @custom( @@ -2573,7 +2573,7 @@ func TestCustomDQL(t *testing.T) { } type Tweets implements Node { id: ID! - text: String! @search(by: [fulltext, exact]) + text: String! @search(by: ["fulltext", "exact"]) user: User timestamp: DateTime! @search } @@ -2864,7 +2864,7 @@ func TestCustomFieldsWithRestError(t *testing.T) { } type User { - id: String! @id @search(by: [hash, regexp]) + id: String! 
@id @search(by: ["hash", "regexp"]) name: String @custom( http: { diff --git a/graphql/e2e/directives/schema.graphql b/graphql/e2e/directives/schema.graphql index f7e6718608b..84b9e5eee32 100644 --- a/graphql/e2e/directives/schema.graphql +++ b/graphql/e2e/directives/schema.graphql @@ -423,4 +423,4 @@ type CricketTeam implements Team { type LibraryManager { name: String! @id manages: [LibraryMember] -} \ No newline at end of file +} diff --git a/graphql/e2e/normal/schema.graphql b/graphql/e2e/normal/schema.graphql index 396dea4318b..1f8a2d2956e 100644 --- a/graphql/e2e/normal/schema.graphql +++ b/graphql/e2e/normal/schema.graphql @@ -422,4 +422,4 @@ type CricketTeam implements Team { type LibraryManager { name: String! @id manages: [LibraryMember] -} \ No newline at end of file +} diff --git a/graphql/e2e/schema/apollo_service_response.graphql b/graphql/e2e/schema/apollo_service_response.graphql index c0383cdeec5..14f6efa1e0b 100644 --- a/graphql/e2e/schema/apollo_service_response.graphql +++ b/graphql/e2e/schema/apollo_service_response.graphql @@ -83,6 +83,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -196,7 +197,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/e2e/schema/generatedSchema.graphql b/graphql/e2e/schema/generatedSchema.graphql index 6b9c37d5f31..70cc4f3f19c 100644 --- a/graphql/e2e/schema/generatedSchema.graphql +++ b/graphql/e2e/schema/generatedSchema.graphql @@ -64,6 +64,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -177,7 +178,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/e2e/schema/schema_test.go b/graphql/e2e/schema/schema_test.go index 0b37f440264..671f93def11 100644 --- a/graphql/e2e/schema/schema_test.go +++ b/graphql/e2e/schema/schema_test.go @@ -564,7 +564,7 @@ func TestLargeSchemaUpdate(t *testing.T) { schema := "type LargeSchema {" for i := 1; i <= numFields; i++ { - schema = schema + "\n" + fmt.Sprintf("field%d: String! @search(by: [regexp])", i) + schema = schema + "\n" + fmt.Sprintf("field%d: String! 
@search(by: [\"regexp\"])", i) } schema = schema + "\n}" diff --git a/graphql/e2e/subscription/subscription_test.go b/graphql/e2e/subscription/subscription_test.go index 3dda3c231c9..f546ed81ec5 100644 --- a/graphql/e2e/subscription/subscription_test.go +++ b/graphql/e2e/subscription/subscription_test.go @@ -46,7 +46,7 @@ const ( } type Customer { - username: String! @id @search(by: [hash, regexp]) + username: String! @id @search(by: ["hash", "regexp"]) reviews: [Review] @hasInverse(field: by) } diff --git a/graphql/resolve/mutation_rewriter.go b/graphql/resolve/mutation_rewriter.go index 669cfafa6ae..a5914347b76 100644 --- a/graphql/resolve/mutation_rewriter.go +++ b/graphql/resolve/mutation_rewriter.go @@ -1675,6 +1675,12 @@ func rewriteObject( fieldName = fieldName[1 : len(fieldName)-1] } + if fieldDef.HasEmbeddingDirective() { + // embedding is a JSON array of numbers. Rewrite it as a string, for now + var valBytes []byte + valBytes, _ = json.Marshal(val) + val = string(valBytes) + } // TODO: Write a function for aggregating data of fragment from child nodes. 
switch val := val.(type) { case map[string]interface{}: diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index e58f03ab73b..509e593a772 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -19,6 +19,7 @@ package resolve import ( "bytes" "context" + "encoding/json" "fmt" "sort" "strconv" @@ -147,7 +148,14 @@ func (qr *queryRewriter) Rewrite( dgQuery := rewriteAsGet(gqlQuery, uid, xid, authRw) return dgQuery, nil - + case schema.SimilarByIdQuery: + xid, uid, err := gqlQuery.IDArgValue() + if err != nil { + return nil, err + } + return rewriteAsSimilarByIdQuery(gqlQuery, uid, xid, authRw), nil + case schema.SimilarByEmbeddingQuery: + return rewriteAsSimilarByEmbeddingQuery(gqlQuery, authRw), nil case schema.FilterQuery: return rewriteAsQuery(gqlQuery, authRw), nil case schema.PasswordQuery: @@ -612,6 +620,292 @@ func rewriteAsGet( return dgQuery } +// rewriteAsSimilarByIdQuery +// +// rewrites SimilarById graphQL query to nested DQL query blocks +// Example rewrittern query: +// +// query { +// var(func: eq(Product.id, "0528012398")) @filter(type(Product)) { +// vec as Product.embedding +// } +// var() { +// v1 as max(val(vec)) +// } +// var(func: similar_to(Product.embedding, 8, val(v1))) { +// v2 as Product.embedding +// distance as math((v2 - v1) dot (v2 - v1)) +// } +// querySimilarProductById(func: uid(distance) +// @filter(Product.id != "0528012398"), orderasc: val(distance)) { +// Product.id : Product.id +// Product.description : Product.description +// Product.title : Product.title +// Product.imageUrl : Product.imageUrl +// Product.vector_distance : val(distance) +// dgraph.uid : uid +// } +// } +func rewriteAsSimilarByIdQuery( + query schema.Query, + uid uint64, + xidArgToVal map[string]string, + auth *authRewriter) []*dql.GraphQuery { + + // Get graphQL arguments + typ := query.Type() + similarBy := query.ArgValue(schema.SimilarByArgName).(string) + pred := typ.DgraphPredicate(similarBy) 
+ topK := query.ArgValue(schema.SimilarTopKArgName) + similarByField := typ.Field(similarBy) + metric := similarByField.EmbeddingSearchMetric() + distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidian + + if metric == schema.SimilarSearchMetricDotProduct { + distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)" + } else if metric == schema.SimilarSearchMetricCosine { + distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)" + } + + // First generate the query to fetch the uid + // for the given id. For Example, + // var(func: eq(Product.id, "0528012398")) @filter(type(Product)) { + // vec as Product.embedding + // } + dgQuery := rewriteAsGet(query, uid, xidArgToVal, auth) + lastQuery := dgQuery[len(dgQuery)-1] + // Turn the root query into "var" + lastQuery.Attr = "var" + // Save the result to be later used for the last query block, sortQuery + result := lastQuery.Children + + // define the variable "vec" for the search vector + lastQuery.Children = []*dql.GraphQuery{{ + Attr: pred, + Var: "vec", + }} + + // Turn the variable into a "const" by + // remembering the max of it. + // The lookup is going to return exactly one uid + // anyway. For example, + // var() { + // v1 as max(val(vec)) + // } + aggQuery := &dql.GraphQuery{ + Attr: "var" + "()", + Children: []*dql.GraphQuery{ + { + Var: "v1", + Attr: "max(val(vec))", + }, + }, + } + + // Similar_to query, computes the distance for + // ordering the result later. 
+ // Example: + // var(func: similar_to(Product.embedding, 8, val(v1))) { + // v2 as Product.embedding + // distance as math((v2 - v1) dot (v2 - v1)) + // } + similarQuery := &dql.GraphQuery{ + Attr: "var", + Children: []*dql.GraphQuery{ + { + Var: "v2", + Attr: pred, + }, + { + Var: "distance", + Attr: distanceFormula, + }, + }, + Func: &dql.Function{ + Name: "similar_to", + Args: []dql.Arg{ + { + Value: pred, + }, + { + Value: fmt.Sprintf("%v", topK), + }, + { + Value: "val(v1)", + }, + }, + }, + } + + // Rename the distance as .vector_distance + distance := &dql.GraphQuery{ + Alias: typ.Name() + "." + schema.SimilarQueryDistanceFieldName, + Attr: "val(distance)", + } + + var found bool = false + for _, child := range result { + if child.Alias == typ.Name()+"."+schema.SimilarQueryDistanceFieldName { + child.Attr = "val(distance)" + found = true + break + } + } + if !found { + result = append(result, distance) + } + + // order the result by euclidian distance, For example, + // querySimilarProductById(func: uid(distance), orderasc: val(distance)) { + // Product.id : Product.id + // Product.description : Product.description + // Product.title : Product.title + // Product.imageUrl : Product.imageUrl + // Product.vector_distance : val(distance) + // dgraph.uid : uid + // } + // } + sortQuery := &dql.GraphQuery{ + Attr: query.DgraphAlias(), + Children: result, + Func: &dql.Function{ + Name: "uid", + Args: []dql.Arg{{Value: "distance"}}, + }, + Order: []*pb.Order{{Attr: "val(distance)", Desc: false}}, + } + addArgumentsToField(sortQuery, query) + + dgQuery = append(dgQuery, aggQuery, similarQuery, sortQuery) + return dgQuery +} + +// rewriteAsSimilarByEmbeddingQuery +// +// rewrites SimilarByEmbedding graphQL query to nested DQL query blocks +// Example rewrittern query: +// +// query gQLTodQL($search_vector: float32vector = "") { +// var(func: similar_to(Product.embedding, 8, $search_vector)) { +// v2 as Product.embedding +// distance as math((v2 - $search_vector) dot 
(v2 - $search_vector)) +// } +// querySimilarProductById(func: uid(distance), +// @filter(Product.id != "0528012398"), orderasc: val(distance)) { +// Product.id : Product.id +// Product.description : Product.description +// Product.title : Product.title +// Product.imageUrl : Product.imageUrl +// Product.vector_distance : val(distance) +// dgraph.uid : uid +// } +// } +func rewriteAsSimilarByEmbeddingQuery( + query schema.Query, auth *authRewriter) []*dql.GraphQuery { + + dgQuery := rewriteAsQuery(query, auth) + + // Remember dgQuery[0].Children as result type for the last block + // in the rewritten query + result := dgQuery[0].Children + typ := query.Type() + + // Get all the arguments from graphQL query + similarBy := query.ArgValue(schema.SimilarByArgName).(string) + pred := typ.DgraphPredicate(similarBy) + topK := query.ArgValue(schema.SimilarTopKArgName) + vec := query.ArgValue(schema.SimilarVectorArgName).([]interface{}) + vecStr, _ := json.Marshal(vec) + + similarByField := typ.Field(similarBy) + metric := similarByField.EmbeddingSearchMetric() + distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidian + + if metric == schema.SimilarSearchMetricDotProduct { + distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)" + } else if metric == schema.SimilarSearchMetricCosine { + distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" + + " * (v2 dot v2) ) )) / 2.0)" + } + + // Save vectorString as a query variable, $search_vector + queryArgs := dgQuery[0].Args + if queryArgs == nil { + queryArgs = make(map[string]string) + } + queryArgs["$search_vector"] = " float32vector = \"" + string(vecStr) + "\"" + thisFilter := &dql.FilterTree{ + Func: dgQuery[0].Func, + } + + // create the similar_to function and move existing root function + // to the filter tree + addToFilterTree(dgQuery[0], thisFilter) + + // Create similar_to as the root function, passing 
$search_vector as + // the search vector + dgQuery[0].Attr = "var" + dgQuery[0].Func = &dql.Function{ + Name: "similar_to", + Args: []dql.Arg{ + { + Value: pred, + }, + { + Value: fmt.Sprintf("%v", topK), + }, + { + Value: "$search_vector", + }, + }, + } + + // Compute the euclidian distance between the neighbor + // and the search vector + dgQuery[0].Children = []*dql.GraphQuery{ + { + Var: "v2", + Attr: pred, + }, + { + Var: "distance", + Attr: distanceFormula, + }, + } + + // Rename distance as .vector_distance + distance := &dql.GraphQuery{ + Alias: typ.Name() + "." + schema.SimilarQueryDistanceFieldName, + Attr: "val(distance)", + } + + var found bool = false + for _, child := range result { + if child.Alias == typ.Name()+"."+schema.SimilarQueryDistanceFieldName { + child.Attr = "val(distance)" + found = true + break + } + } + if !found { + result = append(result, distance) + } + + // order by distance + sortQuery := &dql.GraphQuery{ + Attr: query.DgraphAlias(), + Children: result, + Func: &dql.Function{ + Name: "uid", + Args: []dql.Arg{{Value: "distance"}}, + }, + Order: []*pb.Order{{Attr: "val(distance)", Desc: false}}, + } + + dgQuery = append(dgQuery, sortQuery) + return dgQuery +} + // Adds common RBAC and UID, Type rules to DQL query. 
// This function is used by rewriteAsQuery and aggregateQuery functions func addCommonRules( diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index c281bc5932a..ab15599d020 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -3353,3 +3353,168 @@ dgraph.uid : uid } } +- name: "query similar_to" + gqlquery: | + query { + querySimilarProductByEmbedding(by: productVector, topK: 1, vector: [0.1, 0.2, 0.3, 0.4, 0.5]) { + id + title + productVector + } + } + + dgquery: |- + query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) { + v2 as Product.productVector + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) + } + querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } +- name: "query vector using uid" + gqlquery: | + query { + querySimilarProductById(by: productVector, topK: 3, id: "0x1") { + id + title + productVector + } + } + + dgquery: |- + query { + var(func: eq(Product.id, "0x1")) @filter(type(Product)) { + vec as Product.productVector + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(Product.productVector, 3, val(v1))) { + v2 as Product.productVector + distance as math(sqrt((v2 - v1) dot (v2 - v1))) + } + querySimilarProductById(func: uid(distance), orderasc: val(distance)) { + Product.id : Product.id + Product.title : Product.title + Product.productVector : Product.productVector + dgraph.uid : uid + Product.vector_distance : val(distance) + } + } + +- name: "query vector by id with cosine distance" + gqlquery: | + query { + querySimilarProjectCosineById(by: description_v, topK: 3, id: "0x1") { + id + title + description_v + } + } + + dgquery: 
|- + query { + var(func: eq(ProjectCosine.id, "0x1")) @filter(type(ProjectCosine)) { + vec as ProjectCosine.description_v + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) { + v2 as ProjectCosine.description_v + distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0) + } + querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) { + ProjectCosine.id : ProjectCosine.id + ProjectCosine.title : ProjectCosine.title + ProjectCosine.description_v : ProjectCosine.description_v + dgraph.uid : uid + ProjectCosine.vector_distance : val(distance) + } + } + +- name: "query similar_to with cosine distance" + gqlquery: | + query { + querySimilarProjectCosineByEmbedding(by: description_v, topK: 1, vector: [0.1, 0.2, 0.3, 0.4, 0.5]) { + id + title + description_v + } + } + + dgquery: |- + query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) { + v2 as ProjectCosine.description_v + distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0) + } + querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) { + ProjectCosine.id : ProjectCosine.id + ProjectCosine.title : ProjectCosine.title + ProjectCosine.description_v : ProjectCosine.description_v + dgraph.uid : uid + ProjectCosine.vector_distance : val(distance) + } + } +- name: "query vector by id with dot product distance" + gqlquery: | + query { + querySimilarProjectDotProductById(by: description_v, topK: 3, id: "0x1") { + id + title + description_v + } + } + + dgquery: |- + query { + var(func: eq(ProjectDotProduct.id, "0x1")) @filter(type(ProjectDotProduct)) { + vec as ProjectDotProduct.description_v + } + var() { + v1 as max(val(vec)) + } + var(func: similar_to(ProjectDotProduct.description_v, 
3, val(v1))) { + v2 as ProjectDotProduct.description_v + distance as math((1.0 - (v1 dot v2)) /2.0) + } + querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) { + ProjectDotProduct.id : ProjectDotProduct.id + ProjectDotProduct.title : ProjectDotProduct.title + ProjectDotProduct.description_v : ProjectDotProduct.description_v + dgraph.uid : uid + ProjectDotProduct.vector_distance : val(distance) + } + } + +- name: "query similar_to with dot product distance" + gqlquery: | + query { + querySimilarProjectDotProductByEmbedding(by: description_v, topK: 1, vector: [0.1, 0.2, 0.3, 0.4, 0.5]) { + id + title + description_v + } + } + + dgquery: |- + query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { + var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) { + v2 as ProjectDotProduct.description_v + distance as math(( 1.0 - (($search_vector) dot v2)) /2.0) + } + querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) { + ProjectDotProduct.id : ProjectDotProduct.id + ProjectDotProduct.title : ProjectDotProduct.title + ProjectDotProduct.description_v : ProjectDotProduct.description_v + dgraph.uid : uid + ProjectDotProduct.vector_distance : val(distance) + } + } \ No newline at end of file diff --git a/graphql/resolve/resolver.go b/graphql/resolve/resolver.go index 7b0936ec755..79dd12b16ab 100644 --- a/graphql/resolve/resolver.go +++ b/graphql/resolve/resolver.go @@ -235,6 +235,8 @@ func (rf *resolverFactory) WithConventionResolvers( s schema.Schema, fns *ResolverFns) ResolverFactory { queries := append(s.Queries(schema.GetQuery), s.Queries(schema.FilterQuery)...) + queries = append(queries, s.Queries(schema.SimilarByIdQuery)...) + queries = append(queries, s.Queries(schema.SimilarByEmbeddingQuery)...) queries = append(queries, s.Queries(schema.PasswordQuery)...) 
queries = append(queries, s.Queries(schema.AggregateQuery)...) for _, q := range queries { diff --git a/graphql/resolve/resolver_error_test.go b/graphql/resolve/resolver_error_test.go index cdd4d2a30ff..be298e67765 100644 --- a/graphql/resolve/resolver_error_test.go +++ b/graphql/resolve/resolver_error_test.go @@ -83,6 +83,14 @@ type Author { postsNullableListRequired: [Post]! } +type Product { + id: String! @id + description: String + title: String + imageUrl: String + productVector: [Float!] @embedding +} + type Post { id: ID! title: String! diff --git a/graphql/resolve/schema.graphql b/graphql/resolve/schema.graphql index 6d8b2a1588d..6b469284860 100644 --- a/graphql/resolve/schema.graphql +++ b/graphql/resolve/schema.graphql @@ -11,7 +11,7 @@ type Hotel { type Country { id: ID! - name: String! @search(by: [trigram, exact]) + name: String! @search(by: ["trigram", "exact"]) states: [State] @hasInverse(field: country) } @@ -499,4 +499,26 @@ type CricketTeam implements Team { type LibraryManager { name: String! @id manages: [LibraryMember] -} \ No newline at end of file +} + +type Product { + id: String! @id + description: String + title: String + imageUrl: String + productVector: [Float!] @embedding +} + +type ProjectCosine { + id: String! @id + description: String + title: String + description_v: [Float!] @embedding @search(by: ["hnsw(metric: cosine, exponent: 4)"]) +} + +type ProjectDotProduct { + id: String! @id + description: String + title: String + description_v: [Float!] 
@embedding @search(by: ["hnsw(metric: dotproduct, exponent: 4)"]) +} diff --git a/graphql/schema/dgraph_schemagen_test.yml b/graphql/schema/dgraph_schemagen_test.yml index 32bb64b3b58..ba066df38da 100644 --- a/graphql/schema/dgraph_schemagen_test.yml +++ b/graphql/schema/dgraph_schemagen_test.yml @@ -141,20 +141,24 @@ schemas: s5: String @search(by: [fulltext]) s6: String @search(by: [trigram]) s7: String @search(by: [regexp]) - s8: String @search(by: [exact, fulltext, term, trigram]) + s8: String @search(by: ["exact", "fulltext", "term", "trigram"]) dt1: DateTime @search dt2: DateTime @search(by: [year]) dt3: DateTime @search(by: [month]) dt4: DateTime @search(by: [day]) dt5: DateTime @search(by: [hour]) + vf1: [Float!] @embedding @search(by: ["hnsw"]) + vf2: [Float!] @embedding @search(by: ["hnsw(exponent: 4, metric: euclidian)"]) + vf3: [Float!] @embedding @search(by: ["hnsw(metric: cosine)"]) + vf4: [Float!] @embedding @search(by: ["hnsw(metric: dotproduct, exponent: 4)"]) e: E @search e1: E @search(by: [hash]) e2: E @search(by: [exact]) e3: E @search(by: [trigram]) e4: E @search(by: [regexp]) - e5: E @search(by: [hash, regexp]) - e6: E @search(by: [hash, trigram]) - e7: E @search(by: [exact, regexp]) + e5: E @search(by: ["hash", "regexp"]) + e6: E @search(by: ["hash", "trigram"]) + e7: E @search(by: ["exact", "regexp"]) } enum E { A } output: | @@ -180,6 +184,10 @@ schemas: X.dt3 X.dt4 X.dt5 + X.vf1 + X.vf2 + X.vf3 + X.vf4 X.e X.e1 X.e2 @@ -210,6 +218,10 @@ schemas: X.dt3: dateTime @index(month) . X.dt4: dateTime @index(day) . X.dt5: dateTime @index(hour) . + X.vf1: float32vector @index(hnsw) . + X.vf2: float32vector @index(hnsw(exponent: "4", metric: "euclidian")) . + X.vf3: float32vector @index(hnsw(metric: "cosine")) . + X.vf4: float32vector @index(hnsw(exponent: "4", metric: "dotproduct")) . X.e: string @index(hash) . X.e1: string @index(hash) . X.e2: string @index(exact) . 
@@ -434,7 +446,7 @@ schemas: f2: String @dgraph(pred: "T.f@no") f3: String @dgraph(pred: "f3@en") name: String! @id - nameHi: String @dgraph(pred: "Person.name@hi") @search(by: [term, exact]) + nameHi: String @dgraph(pred: "Person.name@hi") @search(by: ["term", "exact"]) nameEn: String @dgraph(pred: "Person.name@en") @search(by: [regexp]) nameHiEn: String @dgraph(pred: "Person.name@hi:en") nameHi_En_Untag: String @dgraph(pred: "Person.name@hi:en:.") @@ -503,7 +515,7 @@ schemas: - name: "Field with @id directive and a hash arg in search directive generates correct schema." input: | interface A { - id: String! @id @search(by: [hash, term]) + id: String! @id @search(by: ["hash", "term"]) } type B implements A { correct: Boolean @search @@ -641,13 +653,13 @@ schemas: combined" input: | type A { - p: String @dgraph(pred: "name") @search(by: [exact, term]) + p: String @dgraph(pred: "name") @search(by: ["exact", "term"]) } type B { q: String @dgraph(pred: "name") @search(by: [trigram]) } type C { - q: String @dgraph(pred: "name") @search(by: [exact, term]) + q: String @dgraph(pred: "name") @search(by: ["exact", "term"]) } output: | type A { @@ -664,8 +676,8 @@ schemas: - name: "fields with @dgraph(pred: ...) contain different language." input: | type A { - content: String! @dgraph(pred: "post") @search(by: [exact, term]) - author: String @dgraph(pred: "<公司>") @search(by: [exact, term]) + content: String! 
@dgraph(pred: "post") @search(by: ["exact", "term"]) + author: String @dgraph(pred: "<公司>") @search(by: ["exact", "term"]) } output: | type A { diff --git a/graphql/schema/gqlschema.go b/graphql/schema/gqlschema.go index 7927f547e64..3fff2a7bae2 100644 --- a/graphql/schema/gqlschema.go +++ b/graphql/schema/gqlschema.go @@ -34,9 +34,10 @@ const ( searchDirective = "search" searchArgs = "by" - dgraphDirective = "dgraph" - dgraphTypeArg = "type" - dgraphPredArg = "pred" + dgraphDirective = "dgraph" + dgraphTypeArg = "type" + dgraphPredArg = "pred" + embeddingDirective = "embedding" idDirective = "id" idDirectiveInterfaceArg = "interface" @@ -161,6 +162,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -275,7 +277,8 @@ input GenerateMutationParams { ` directiveDefs = ` directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION @@ -306,7 +309,8 @@ directive @generate( // So, such directives have to be missed too. apolloSupportedDirectiveDefs = ` directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION @@ -458,6 +462,7 @@ var supportedSearches = map[string]searchTypeIndex{ "point": {"Point", "geo"}, "polygon": {"Polygon", "geo"}, "multiPolygon": {"MultiPolygon", "geo"}, + "hnsw": {"Float", "hnsw"}, } // GraphQL scalar/object type -> default search arg @@ -532,6 +537,7 @@ var builtInFilters = map[string]string{ "point": "PointGeoFilter", "polygon": "PolygonGeoFilter", "multiPolygon": "PolygonGeoFilter", + "hnsw": "HNSWSearchFilter", } // GraphQL in-built type -> Dgraph scalar @@ -561,6 +567,7 @@ func ValidatorNoOp( var directiveValidators = map[string]directiveValidator{ inverseDirective: hasInverseValidation, searchDirective: searchValidation, + embeddingDirective: embeddingValidation, dgraphDirective: dgraphDirectiveValidation, idDirective: idValidation, subscriptionDirective: ValidatorNoOp, @@ -1469,6 +1476,7 @@ func getFilterTypes(schema *ast.Schema, fld *ast.FieldDefinition, filterName str filterNames := make([]string, len(searchArgs)) for i, search := range searchArgs { + search = parseSearchType(search) filterNames[i] = builtInFilters[search] // For enum type, if the index is "hash" or "exact", we construct filter named @@ -1671,6 +1679,10 @@ func hasXID(defn *ast.Definition) bool { return fieldAny(nonExternalAndKeyFields(defn), hasIDDirective) } +func hasEmbedding(defn *ast.Definition) bool { + return fieldAny(nonExternalAndKeyFields(defn), hasEmbeddingDirective) +} + // fieldAny returns true if any field in fields satisfies pred func fieldAny(fields ast.FieldList, pred func(*ast.FieldDefinition) bool) bool { for _, fld := range fields { @@ -1945,7 +1957,8 @@ func 
addAggregationResultType(schema *ast.Schema, defn *ast.Definition, provides } } -func addGetQuery(schema *ast.Schema, defn *ast.Definition, providesTypeMap map[string]bool, generateSubscription bool) { +func addGetQuery(schema *ast.Schema, defn *ast.Definition, + providesTypeMap map[string]bool, generateSubscription bool) { hasIDField := hasID(defn) hasXIDField := hasXID(defn) xidCount := xidsCount(defn.Fields) @@ -2006,6 +2019,184 @@ func addGetQuery(schema *ast.Schema, defn *ast.Definition, providesTypeMap map[s } } +// addSimilarByEmbeddingQuery adds a query to perform similarity based search on +// a specified embedding as an array of floats +// schema - the graphQL schema. The schema will be modified by this operation +// defn - The type definition for the object for which this query will be added +func addSimilarByEmbeddingQuery(schema *ast.Schema, defn *ast.Definition) { + + qry := &ast.FieldDefinition{ + Name: SimilarQueryPrefix + defn.Name + SimilarByEmbeddingQuerySuffix, + Type: &ast.Type{ + Elem: &ast.Type{ + NamedType: defn.Name, + }, + }, + } + + // The new field is "vector_distance". 
Add it to input Type + if defn.Fields.ForName(SimilarQueryDistanceFieldName) == nil { + defn.Fields = append(defn.Fields, + &ast.FieldDefinition{ + Name: SimilarQueryDistanceFieldName, + Type: &ast.Type{NamedType: "Float"}}) + } + // Define the enum to + //select from among all predicates with "@embedding" directives + enumName := defn.Name + EmbeddingEnumSuffix + enum := &ast.Definition{ + Kind: ast.Enum, + Name: enumName, + } + + for _, fld := range defn.Fields { + if hasEmbeddingDirective(fld) { + enum.EnumValues = append(enum.EnumValues, + &ast.EnumValueDefinition{Name: fld.Name}) + } + } + schema.Types[enumName] = enum + + //Accept the name of embedding field + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarByArgName, + Type: &ast.Type{NamedType: enumName, NonNull: true}, + }) + + // Accept the topK, number of nearest neighbors to + // return + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarTopKArgName, + Type: &ast.Type{ + NamedType: "Int", + NonNull: true, + }, + }) + + // Accept an array of floats as the search vector + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarVectorArgName, + Type: &ast.Type{ + Elem: &ast.Type{ + NamedType: "Float", + NonNull: true, + }, + NonNull: true, + }, + }) + addFilterArgument(schema, qry) + + schema.Query.Fields = append(schema.Query.Fields, qry) +} + +// addSimilarByIdQuery adds a query that looks up a node based on an id/xid. +// The query then performs a similarity search based on the value of the +// selected embedding field to find similar objects +// schema - The graphQL schema. 
New enums and result type are added to the schema +// defn - The object type for which the query is added +func addSimilarByIdQuery(schema *ast.Schema, defn *ast.Definition, + providesTypeMap map[string]bool) { + hasIDField := hasID(defn) + hasXIDField := hasXID(defn) + xidCount := xidsCount(defn.Fields) + if !hasIDField && !hasXIDField { + return + } + + // create the new query, querySimilarById + qry := &ast.FieldDefinition{ + Name: SimilarQueryPrefix + defn.Name + SimilarByIdQuerySuffix, + Type: &ast.Type{ + Elem: &ast.Type{ + NamedType: defn.Name, + }, + }, + } + + // The new field is "vector_distance". Add it to input Type + if defn.Fields.ForName(SimilarQueryDistanceFieldName) == nil { + defn.Fields = append(defn.Fields, + &ast.FieldDefinition{ + Name: SimilarQueryDistanceFieldName, + Type: &ast.Type{NamedType: "Float"}}) + } + // If the defn, only specified one of ID/XID field, then they are mandatory. + // If it specified both, then they are optional. + if hasIDField { + fields := getIDField(defn, providesTypeMap) + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: fields[0].Name, + Type: &ast.Type{ + NamedType: idTypeFor(defn), + NonNull: !hasXIDField, + }, + }) + } + + if hasXIDField { + var idWithoutUniqueArgExists bool + for _, fld := range defn.Fields { + if hasIDDirective(fld) { + if !hasInterfaceArg(fld) { + idWithoutUniqueArgExists = true + } + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: fld.Name, + Type: &ast.Type{ + NamedType: fld.Type.Name(), + NonNull: !hasIDField && xidCount <= 1, + }, + }) + } + } + if defn.Kind == "INTERFACE" && idWithoutUniqueArgExists { + qry.Directives = append( + qry.Directives, &ast.Directive{Name: deprecatedDirective, + Arguments: ast.ArgumentList{&ast.Argument{Name: "reason", + Value: &ast.Value{Raw: "@id argument for get query on interface is being" + + " deprecated. 
Only those @id fields which have interface argument" + + " set to true will be available in getQuery argument on interface" + + " post v21.11.0, please update your schema accordingly.", + Kind: ast.StringValue}}}}) + } + } + + // Define the enum to + //select from among all predicates with "@embedding" directives + enumName := defn.Name + EmbeddingEnumSuffix + enum := &ast.Definition{ + Kind: ast.Enum, + Name: enumName, + } + + for _, fld := range defn.Fields { + if hasEmbeddingDirective(fld) { + enum.EnumValues = append(enum.EnumValues, + &ast.EnumValueDefinition{Name: fld.Name}) + } + } + schema.Types[enumName] = enum + + // Accept the name of the embedding field. + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarByArgName, + Type: &ast.Type{NamedType: enumName, NonNull: true}, + }) + + // Accept the topK, number of nearest neighbors to + // return + qry.Arguments = append(qry.Arguments, &ast.ArgumentDefinition{ + Name: SimilarTopKArgName, + Type: &ast.Type{ + NamedType: "Int", + NonNull: true, + }, + }) + + addFilterArgument(schema, qry) + schema.Query.Fields = append(schema.Query.Fields, qry) +} + func addFilterQuery( schema *ast.Schema, defn *ast.Definition, @@ -2032,7 +2223,8 @@ func addFilterQuery( } -func addAggregationQuery(schema *ast.Schema, defn *ast.Definition, generateSubscription bool) { +func addAggregationQuery(schema *ast.Schema, + defn *ast.Definition, generateSubscription bool) { qry := &ast.FieldDefinition{ Name: "aggregate" + defn.Name, Type: &ast.Type{ @@ -2049,7 +2241,8 @@ func addAggregationQuery(schema *ast.Schema, defn *ast.Definition, generateSubsc } -func addPasswordQuery(schema *ast.Schema, defn *ast.Definition, providesTypeMap map[string]bool) { +func addPasswordQuery(schema *ast.Schema, + defn *ast.Definition, providesTypeMap map[string]bool) { hasIDField := hasID(defn) hasXIDField := hasXID(defn) if !hasIDField && !hasXIDField { @@ -2095,6 +2288,10 @@ func addQueries( ) { if params.generateGetQuery { 
addGetQuery(schema, defn, providesTypeMap, params.generateSubscription) + if hasEmbedding(defn) { + addSimilarByIdQuery(schema, defn, providesTypeMap) + addSimilarByEmbeddingQuery(schema, defn) + } } if params.generatePasswordQuery { diff --git a/graphql/schema/gqlschema_test.yml b/graphql/schema/gqlschema_test.yml index 64f918f05e7..eb206de22fa 100644 --- a/graphql/schema/gqlschema_test.yml +++ b/graphql/schema/gqlschema_test.yml @@ -77,7 +77,7 @@ invalid_schemas: name: "Enum indexes clash trigram and regexp" input: | type T { - f: E @search(by: [trigram, regexp]) + f: E @search(by: ["trigram", "regexp"]) } enum E { A @@ -91,7 +91,7 @@ invalid_schemas: name: "Enum indexes clash hash and exact" input: | type T { - f: E @search(by: [hash, exact]) + f: E @search(by: ["hash", "exact"]) } enum E { A @@ -100,6 +100,17 @@ invalid_schemas: {"message": "Type T; Field f: the arguments 'hash' and 'exact' can't be used together as arguments to @search.", "locations": [{"line": 2, "column": 9}]} ] + - + name: "HNSW index options malformed" + input: | + type T { + f: [Float!] @embedding @search(by: ["hnsw(metric:dotproduct)"]) + } + errlist: [ + {"message": "Type T; Field f: has the @search directive but the argument 'hnsw(metric:dotproduct)' with search options is malformed. 
Search options are comma-separated key-value pairs in YAML format => ", + "locations": [{"line": 2, "column": 27}]} + ] + - name: "Reference type that is not in input schema" input: | @@ -466,7 +477,7 @@ invalid_schemas: name: "Search with wrong arg for the index" input: | type X { - y: String @search(by: [hash, hour]) + y: String @search(by: ["hash", "hour"]) } errlist: [ {"message": "Type X; Field y: has the @search directive but the argument hour doesn't @@ -479,11 +490,11 @@ invalid_schemas: name: "Search without []" input: | type X { - y: String @search(by: hash) + y: String @search(by: "hash") } errlist: [ {"message": "Type X; Field y: the @search directive requires a list argument, - like @search(by: [hash])", + like @search(by: [\"hash\"])", "locations":[{"line":2, "column":14}]} ] @@ -491,7 +502,7 @@ invalid_schemas: name: "Search doesn't allow hash and exact together" input: | type X { - y: String @search(by: [hash, exact]) + y: String @search(by: ["hash", "exact"]) } errlist: [ {"message": "Type X; Field y: the arguments 'hash' and 'exact' can't be @@ -503,7 +514,7 @@ invalid_schemas: name: "Search with multiple datetime index" input: | type X { - y: DateTime @search(by: [hour, month]) + y: DateTime @search(by: ["hour", "month"]) } errlist: [ {"message": "Type X; Field y: has the search directive on DateTime. DateTime @@ -515,7 +526,7 @@ invalid_schemas: name: "Search doesn't allow trigram and regexp together" input: | type X { - y: String @search(by: [trigram, regexp]) + y: String @search(by: ["trigram", "regexp"]) } errlist: [ {"message": "Type X; Field y: the argument to @search 'trigram' is the same as @@ -527,7 +538,7 @@ invalid_schemas: name: "Search doesn't accept bogus args" input: | type X { - y: String @search(by: [bogus]) + y: String @search(by: ["bogus"]) } errlist: [ {"message": "Type X; Field y: the argument to @search bogus isn't valid.Fields of type @@ -2732,7 +2743,7 @@ invalid_schemas: review: String! 
} errlist: [ - {"message": "Type Product; @remote directive cannot be defined with @key directive", "locations": [ { "line": 174, "column": 12} ] }, + {"message": "Type Product; @remote directive cannot be defined with @key directive", "locations": [ { "line": 176, "column": 12} ] }, ] - name: "directives defined on @external fields that are not @key." @@ -2751,6 +2762,29 @@ invalid_schemas: {"message": "Type Product: Field name: @search directive can not be defined on @external fields that are not @key.", "locations": [ { "line": 3, "column": 18} ] }, ] + - name: "@embedding directive on a field with String type" + input: | + type Product { + id: String! @id + description: String + title: String + productVector: String @embedding + } + errlist: [ + {"message": "Type Product; Field productVector: The field with @embedding directive is of type String, but @embedding directive only applies to fields of type [Float!].", "locations": [ { "line": 5, "column": 3} ] }, + ] + + - name: "@embedding directive on a field with [Int] type" + input: | + type User { + email: String! 
@id + name: String + userVector: [Int] @embedding + } + errlist: [ + {"message": "Type User; Field userVector: The field with @embedding directive is of type [Int], but @embedding directive only applies to fields of type [Float!].", "locations": [ { "line": 4, "column": 3} ] }, + ] + - name: "@requires directive defined on type definitions" input: | type Product @key(fields: "id"){ @@ -3099,8 +3133,8 @@ valid_schemas: strFulltext: String @search(by: [fulltext]) strTrigram: String @search(by: [trigram]) strRegexp: String @search(by: [regexp]) - strRegexpFulltext: String @search(by: [regexp, fulltext]) - strMultipleIndex: String @search(by: [trigram, hash, term, fulltext]) + strRegexpFulltext: String @search(by: ["regexp", "fulltext"]) + strMultipleIndex: String @search(by: ["trigram", "hash", "term", "fulltext"]) dt: DateTime @search dt2: DateTime @search(by: []) dtYear: DateTime @search(by: [year]) @@ -3446,7 +3480,7 @@ valid_schemas: f2: String @dgraph(pred: "T.f@no") name: String! @id f3: String @dgraph(pred: "f3@en") - nameHi: String @dgraph(pred: "Person.name@hi") @search(by: [term, exact]) + nameHi: String @dgraph(pred: "Person.name@hi") @search(by: ["term", "exact"]) nameEn: String @dgraph(pred: "Person.name@en") @search(by: [regexp]) nameHiEn: String @dgraph(pred: "Person.name@hi:en") nameHi_En_Untag: String @dgraph(pred: "Person.name@hi:en:.") diff --git a/graphql/schema/rules.go b/graphql/schema/rules.go index 1f010f91b36..c5799bddd0b 100644 --- a/graphql/schema/rules.go +++ b/graphql/schema/rules.go @@ -29,6 +29,7 @@ import ( "github.com/dgraph-io/gqlparser/v2/gqlerror" "github.com/dgraph-io/gqlparser/v2/parser" "github.com/dgraph-io/gqlparser/v2/validator" + "gopkg.in/yaml.v2" ) const ( @@ -829,6 +830,33 @@ func listValidityCheck(typ *ast.Definition, field *ast.FieldDefinition) gqlerror return nil } +func embeddingValidation(sch *ast.Schema, typ *ast.Definition, + field *ast.FieldDefinition, dir *ast.Directive, + secrets map[string]x.Sensitive) 
gqlerror.List { + var errs []*gqlerror.Error + if field.Type.Elem == nil { + errs = append(errs, + gqlerror.ErrorPosf( + field.Position, + "Type %s; Field %s: The field with @embedding directive is of type %s,"+ + " but @embedding directive only applies"+ + " to fields of type [Float!].", typ.Name, field.Name, field.Type.Name())) + return errs + } + + if !strings.EqualFold(field.Type.Elem.NamedType, "Float") || + !field.Type.Elem.NonNull { + errs = append(errs, + gqlerror.ErrorPosf( + field.Position, + "Type %s; Field %s: The field with @embedding directive is of type [%s], "+ + "but @embedding directive only applies"+ + " to fields of type [Float!].", typ.Name, field.Name, field.Type.Name())) + } + + return errs +} + func hasInverseValidation(sch *ast.Schema, typ *ast.Definition, field *ast.FieldDefinition, dir *ast.Directive, secrets map[string]x.Sensitive) gqlerror.List { @@ -971,7 +999,8 @@ func validateSearchArg(searchArg string, dir *ast.Directive) *gqlerror.Error { isEnum := sch.Types[field.Type.Name()].Kind == ast.Enum - search, ok := supportedSearches[searchArg] + searchType := parseSearchType(searchArg) + search, ok := supportedSearches[searchType] switch { case !ok: // This check can be removed once gqlparser bug @@ -998,6 +1027,24 @@ func validateSearchArg(searchArg string, "doesn't apply to field type %s which is an Enum. Enum only supports "+ "hash, exact, regexp and trigram", typ.Name, field.Name, searchArg, field.Type.Name()) + + case search.dgIndex == "hnsw": + if !hasEmbeddingDirective(field) { + return gqlerror.ErrorPosf( + dir.Position, + "Type %s; Field %s: has the @search directive but the argument %s "+ + "requires the field also has @%s directive.", + typ.Name, field.Name, searchArg, embeddingDirective) + } + _, valid := getSearchOptions(searchArg) + if !valid { + return gqlerror.ErrorPosf( + dir.Position, + "Type %s; Field %s: has the @search directive but the argument '%s' "+ + "with search options is malformed. 
Search options are comma-separated "+ + "key-value pairs in YAML format => ", + typ.Name, field.Name, searchArg) + } } return nil @@ -1041,7 +1088,7 @@ func searchValidation( if arg.Value.Kind != ast.ListValue { errs = append(errs, gqlerror.ErrorPosf( dir.Position, - "Type %s; Field %s: the @search directive requires a list argument, like @search(by: [hash])", + "Type %s; Field %s: the @search directive requires a list argument, like @search(by: [\"hash\"])", typ.Name, field.Name)) return errs } @@ -1056,7 +1103,16 @@ func searchValidation( // Checks that the filter indexes aren't repeated and they // don't clash with each other. - searchIndex := builtInFilters[searchArg] + searchType := parseSearchType(searchArg) + searchIndex := builtInFilters[searchType] + if len(searchIndex) == 0 { + errs = append(errs, gqlerror.ErrorPosf( + dir.Position, + "Type %s; Field %s: the argument to @search '%s' is not among "+ + "supported search types.", + typ.Name, field.Name, searchArg)) + return errs + } if val, ok := searchIndexes[searchIndex]; ok { if field.Type.Name() == "String" || sch.Types[field.Type.Name()].Kind == ast.Enum { errs = append(errs, gqlerror.ErrorPosf( @@ -1087,12 +1143,108 @@ func searchValidation( } } - searchIndexes[searchIndex] = searchArg + searchIndexes[searchIndex] = searchType } return errs } +// parseSearchType(searchArg) parses the searchType from searchArg +// searchArg is specified with the following syntax +// +// := [ ] +// +// hnsw(metric: euclidian, exponent: 6) +// hnsw +// hnsw(exponent: 3) +func parseSearchType(searchArg string) string { + searchType := searchArg + if strings.IndexByte(searchArg, '(') >= 0 { + searchType = searchArg[:strings.IndexByte(searchArg, '(')] + searchType = strings.TrimSpace(searchType) + } + return searchType +} + +// parseSearchOptions(searchArg) parses searchOptions from searchArg +// searchArg is specified with the following syntax +// +// := [ ] +// +// searchOptions := * +// := +// Examples: +// +// 
hnsw(metric: euclidian, exponent: 6) +// hnsw +// hnsw(exponent: 3) +func parseSearchOptions(searchArg string) (map[string]string, bool) { + searchArg = strings.TrimSpace(searchArg) + openParen := strings.Index(searchArg, "(") + + if openParen < 0 && searchArg[len(searchArg)-1] != ')' { + // no search options and supported searchType found + return map[string]string{}, true // valid = true, no search options + } + + if openParen+1 == len(searchArg)-1 { + // found () with no index options between + // '(' & ')' + // TODO: If DQL schema parser allows the pair of parentheses + // without any options then we need to allow this in GraphQL + // schema too + return map[string]string{}, false + } + + if openParen < 0 || searchArg[len(searchArg)-1] != ')' { + // does not have open/close parenthesis + return map[string]string{}, false // valid = false + } + + indexOptions := "{" + searchArg[openParen+1:len(searchArg)-1] + "}" + var kvMap map[string]string + err := yaml.Unmarshal([]byte(indexOptions), &kvMap) + if err != nil { + return map[string]string{}, false + } + + return kvMap, true // parsed valid options +} + +// getSearchOptions(searchArg) Stringifies search options using DQL syntax +func getSearchOptions(searchArg string) (string, bool) { + res := "" + kvMap, ok := parseSearchOptions(searchArg) + if len(kvMap) == 0 { + return res, ok + } + + keys := make([]string, 0, len(kvMap)) + for k := range kvMap { + keys = append(keys, k) + } + + sort.Strings(keys) + + res += "(" + i := 0 + for _, key := range keys { + if len(kvMap[key]) == 0 { + // If the value is null, then return invalid + return "", false + } + res += strings.TrimSpace(key) + ": \"" + + strings.TrimSpace(kvMap[key]) + "\"" + if i < len(keys)-1 { + res += ", " + } + i++ + } + res += ")" + + return res, true // parsed valid options +} + func dgraphDirectiveValidation(sch *ast.Schema, typ *ast.Definition, field *ast.FieldDefinition, dir *ast.Directive, secrets map[string]x.Sensitive) gqlerror.List { var errs 
[]*gqlerror.Error diff --git a/graphql/schema/schemagen.go b/graphql/schema/schemagen.go index 76b9a30fd96..068e808db27 100644 --- a/graphql/schema/schemagen.go +++ b/graphql/schema/schemagen.go @@ -472,10 +472,14 @@ func getAllowedHeaders(sch *ast.Schema, definitions []string, authHeader string) } func getAllSearchIndexes(val *ast.Value) []string { + // all searchArgs were validated before getting here. res := make([]string, len(val.Children)) for i, child := range val.Children { - res[i] = supportedSearches[child.Value.Raw].dgIndex + searchType := parseSearchType(child.Value.Raw) + res[i] = supportedSearches[searchType].dgIndex + searchOptions, _ := getSearchOptions(child.Value.Raw) + res[i] += searchOptions } return res @@ -669,6 +673,13 @@ func genDgSchema(gqlSch *ast.Schema, definitions []string, } } + embedding := f.Directives.ForName(embeddingDirective) + if embedding != nil { + // embeddingValidation ensured GQL type is [Float] + // set typStr to float32vector + typStr = "float32vector" + } + if parentInt == nil { // if field name contains @ then it is a language tagged field. isLang := false diff --git a/graphql/schema/testdata/apolloservice/input/generate-directive.graphql b/graphql/schema/testdata/apolloservice/input/generate-directive.graphql index 0621754ede9..d45f95ace3e 100644 --- a/graphql/schema/testdata/apolloservice/input/generate-directive.graphql +++ b/graphql/schema/testdata/apolloservice/input/generate-directive.graphql @@ -34,4 +34,4 @@ type Person @withSubscription @generate( ) { id: ID! name: String! 
-} \ No newline at end of file +} diff --git a/graphql/schema/testdata/apolloservice/output/auth-directive.graphql b/graphql/schema/testdata/apolloservice/output/auth-directive.graphql index d11d8dd6f8b..4c7c8697b62 100644 --- a/graphql/schema/testdata/apolloservice/output/auth-directive.graphql +++ b/graphql/schema/testdata/apolloservice/output/auth-directive.graphql @@ -77,6 +77,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -190,7 +191,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/apolloservice/output/custom-directive.graphql b/graphql/schema/testdata/apolloservice/output/custom-directive.graphql index 0ffe521b151..2895f84b1a9 100644 --- a/graphql/schema/testdata/apolloservice/output/custom-directive.graphql +++ b/graphql/schema/testdata/apolloservice/output/custom-directive.graphql @@ -69,6 +69,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -182,7 +183,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/apolloservice/output/extended-types.graphql b/graphql/schema/testdata/apolloservice/output/extended-types.graphql index 4c2a6ab110b..5cfe4020f15 100644 --- a/graphql/schema/testdata/apolloservice/output/extended-types.graphql +++ b/graphql/schema/testdata/apolloservice/output/extended-types.graphql @@ -83,6 +83,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -196,7 +197,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/apolloservice/output/generate-directive.graphql b/graphql/schema/testdata/apolloservice/output/generate-directive.graphql index fd38dab07b9..e3acf38139c 100644 --- a/graphql/schema/testdata/apolloservice/output/generate-directive.graphql +++ b/graphql/schema/testdata/apolloservice/output/generate-directive.graphql @@ -79,6 +79,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -192,7 +193,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/apolloservice/output/single-extended-type.graphql b/graphql/schema/testdata/apolloservice/output/single-extended-type.graphql index 94acbbcc7a3..3a375fac52e 100644 --- a/graphql/schema/testdata/apolloservice/output/single-extended-type.graphql +++ b/graphql/schema/testdata/apolloservice/output/single-extended-type.graphql @@ -64,6 +64,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -177,7 +178,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/input/auth-on-interfaces.graphql b/graphql/schema/testdata/schemagen/input/auth-on-interfaces.graphql index 2a50a0f32bf..6a8812cc6fe 100644 --- a/graphql/schema/testdata/schemagen/input/auth-on-interfaces.graphql +++ b/graphql/schema/testdata/schemagen/input/auth-on-interfaces.graphql @@ -28,4 +28,4 @@ type Question implements Post @auth( }""" } ){ answered: Boolean @search -} \ No newline at end of file +} diff --git a/graphql/schema/testdata/schemagen/input/embedding-directive-with-similar-queries.graphql b/graphql/schema/testdata/schemagen/input/embedding-directive-with-similar-queries.graphql new file mode 100644 index 
00000000000..c3c7f33cd9c --- /dev/null +++ b/graphql/schema/testdata/schemagen/input/embedding-directive-with-similar-queries.graphql @@ -0,0 +1,21 @@ +# simple user product GraphQL API - v0.4 + +type Product { + id: String! @id + description: String + title: String + imageUrl: String + product_vector: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"]) +} + +type Purchase @lambdaOnMutate(add: true){ + user: User @hasInverse(field: "purchase_history") + product: Product + date: DateTime @search(by: [day]) +} + +type User { + email: String! @id + purchase_history: [Purchase] + user_vector: [Float!] @embedding @search(by: ["hnsw"]) +} diff --git a/graphql/schema/testdata/schemagen/input/language-tags.graphql b/graphql/schema/testdata/schemagen/input/language-tags.graphql index 97257697b46..0d6633669a9 100644 --- a/graphql/schema/testdata/schemagen/input/language-tags.graphql +++ b/graphql/schema/testdata/schemagen/input/language-tags.graphql @@ -12,7 +12,7 @@ type Person implements Node { f3: String @dgraph(pred: "f3@en") name: String! @id # We can have exact index on language tagged field while having hash index on language untagged field - nameHi: String @dgraph(pred: "Person.name@hi") @search(by: [term, exact]) + nameHi: String @dgraph(pred: "Person.name@hi") @search(by: ["term", "exact"]) nameEn: String @dgraph(pred: "Person.name@en") @search(by: [regexp]) # Below Fields nameHiEn,nameHi_En_Untag won't be added to update/add mutation/ref type # and also to filters, order as they corresponds to multiple language tags diff --git a/graphql/schema/testdata/schemagen/input/searchables-references.graphql b/graphql/schema/testdata/schemagen/input/searchables-references.graphql index fbac2b2cfee..d69dad886ce 100644 --- a/graphql/schema/testdata/schemagen/input/searchables-references.graphql +++ b/graphql/schema/testdata/schemagen/input/searchables-references.graphql @@ -7,7 +7,7 @@ type Author { type Post { postID: ID! - title: String! 
@search(by: [term, fulltext]) - text: String @search(by: [fulltext, term]) + title: String! @search(by: ["term", "fulltext"]) + text: String @search(by: ["fulltext", "term"]) datePublished: DateTime # Have something not search } diff --git a/graphql/schema/testdata/schemagen/input/searchables.graphql b/graphql/schema/testdata/schemagen/input/searchables.graphql index bc704585127..3690268b843 100644 --- a/graphql/schema/testdata/schemagen/input/searchables.graphql +++ b/graphql/schema/testdata/schemagen/input/searchables.graphql @@ -1,7 +1,7 @@ type Post { postID: ID! title: String! @search(by: [term]) - titleByEverything: String! @search(by: [term, fulltext, trigram, hash]) + titleByEverything: String! @search(by: ["term", "fulltext", "trigram", "hash"]) text: String @search(by: [fulltext]) tags: [String] @search(by: [trigram]) @@ -26,8 +26,8 @@ type Post { postTypeRegexp: PostType @search(by: [regexp]) postTypeExact: [PostType] @search(by: [exact]) postTypeHash: PostType @search(by: [hash]) - postTypeRegexpExact: PostType @search(by: [exact, regexp]) - postTypeHashRegexp: PostType @search(by: [hash, regexp]) + postTypeRegexpExact: PostType @search(by: ["exact", "regexp"]) + postTypeHashRegexp: PostType @search(by: ["hash", "regexp"]) postTypeNone: PostType @search(by: []) } diff --git a/graphql/schema/testdata/schemagen/input/union.graphql b/graphql/schema/testdata/schemagen/input/union.graphql index ee36d9f3cc7..d18f80cd9e1 100644 --- a/graphql/schema/testdata/schemagen/input/union.graphql +++ b/graphql/schema/testdata/schemagen/input/union.graphql @@ -38,4 +38,4 @@ type Planet { url: "http://mock:8888/tool/$id" method: "GET" }) -} \ No newline at end of file +} diff --git a/graphql/schema/testdata/schemagen/output/apollo-federation.graphql b/graphql/schema/testdata/schemagen/output/apollo-federation.graphql index 6e2ce105791..58e3caa0114 100644 --- a/graphql/schema/testdata/schemagen/output/apollo-federation.graphql +++ 
b/graphql/schema/testdata/schemagen/output/apollo-federation.graphql @@ -99,6 +99,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -212,7 +213,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/auth-on-interfaces.graphql b/graphql/schema/testdata/schemagen/output/auth-on-interfaces.graphql index ea87104b74f..25fe5454258 100644 --- a/graphql/schema/testdata/schemagen/output/auth-on-interfaces.graphql +++ b/graphql/schema/testdata/schemagen/output/auth-on-interfaces.graphql @@ -81,6 +81,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -194,7 +195,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/authorization.graphql b/graphql/schema/testdata/schemagen/output/authorization.graphql index b04bf479f80..4154417f6bf 100644 --- a/graphql/schema/testdata/schemagen/output/authorization.graphql +++ b/graphql/schema/testdata/schemagen/output/authorization.graphql @@ -77,6 +77,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -190,7 +191,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/comments-and-descriptions.graphql b/graphql/schema/testdata/schemagen/output/comments-and-descriptions.graphql old mode 100755 new mode 100644 index cf11ecc3856..5cfbc133bff --- a/graphql/schema/testdata/schemagen/output/comments-and-descriptions.graphql +++ b/graphql/schema/testdata/schemagen/output/comments-and-descriptions.graphql @@ -90,6 +90,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -203,7 +204,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-dql-query-with-subscription.graphql b/graphql/schema/testdata/schemagen/output/custom-dql-query-with-subscription.graphql old mode 100755 new mode 100644 index 99d4d850e10..42d1b5e8735 --- a/graphql/schema/testdata/schemagen/output/custom-dql-query-with-subscription.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-dql-query-with-subscription.graphql @@ -78,6 +78,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -191,7 +192,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-mutation.graphql b/graphql/schema/testdata/schemagen/output/custom-mutation.graphql index 1b48735de16..b06c8052f6d 100644 --- a/graphql/schema/testdata/schemagen/output/custom-mutation.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-mutation.graphql @@ -68,6 +68,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -181,7 +182,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-nested-types.graphql b/graphql/schema/testdata/schemagen/output/custom-nested-types.graphql old mode 100755 new mode 100644 index 7184d90e5a8..ba46ead1487 --- a/graphql/schema/testdata/schemagen/output/custom-nested-types.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-nested-types.graphql @@ -85,6 +85,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -198,7 +199,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-query-mixed-types.graphql b/graphql/schema/testdata/schemagen/output/custom-query-mixed-types.graphql index d441721cf27..689f77789a8 100644 --- a/graphql/schema/testdata/schemagen/output/custom-query-mixed-types.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-query-mixed-types.graphql @@ -69,6 +69,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -182,7 +183,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-query-not-dgraph-type.graphql b/graphql/schema/testdata/schemagen/output/custom-query-not-dgraph-type.graphql old mode 100755 new mode 100644 index bbc666a87e1..69eff1c9bc2 --- a/graphql/schema/testdata/schemagen/output/custom-query-not-dgraph-type.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-query-not-dgraph-type.graphql @@ -68,6 +68,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -181,7 +182,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/custom-query-with-dgraph-type.graphql b/graphql/schema/testdata/schemagen/output/custom-query-with-dgraph-type.graphql old mode 100755 new mode 100644 index a38c50561d8..34959d52171 --- a/graphql/schema/testdata/schemagen/output/custom-query-with-dgraph-type.graphql +++ b/graphql/schema/testdata/schemagen/output/custom-query-with-dgraph-type.graphql @@ -64,6 +64,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -177,7 +178,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/deprecated.graphql b/graphql/schema/testdata/schemagen/output/deprecated.graphql old mode 100755 new mode 100644 index 52cb4ea15d4..a44e32bd30d --- a/graphql/schema/testdata/schemagen/output/deprecated.graphql +++ b/graphql/schema/testdata/schemagen/output/deprecated.graphql @@ -64,6 +64,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -177,7 +178,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-on-concrete-type-with-interfaces.graphql b/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-on-concrete-type-with-interfaces.graphql old mode 100755 new mode 100644 index b1af0955121..43cb363af0e --- a/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-on-concrete-type-with-interfaces.graphql +++ b/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-on-concrete-type-with-interfaces.graphql @@ -81,6 +81,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -194,7 +195,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-with-interfaces.graphql b/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-with-interfaces.graphql old mode 100755 new mode 100644 index 44da4eec782..40219cc45be --- a/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-with-interfaces.graphql +++ b/graphql/schema/testdata/schemagen/output/dgraph-reverse-directive-with-interfaces.graphql @@ -81,6 +81,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -194,7 +195,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql b/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql new file mode 100644 index 00000000000..77d52a9e96a --- /dev/null +++ b/graphql/schema/testdata/schemagen/output/embedding-directive-with-similar-queries.graphql @@ -0,0 +1,575 @@ +####################### +# Input Schema +####################### + +type Product { + id: String! @id + description: String + title: String + imageUrl: String + product_vector: [Float!] 
@embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"]) + vector_distance: Float +} + +type Purchase @lambdaOnMutate(add: true) { + user(filter: UserFilter): User @hasInverse(field: "purchase_history") + product(filter: ProductFilter): Product + date: DateTime @search(by: [day]) +} + +type User { + email: String! @id + purchase_history(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] @hasInverse(field: user) + user_vector: [Float!] @embedding @search(by: ["hnsw"]) + vector_distance: Float + purchase_historyAggregate(filter: PurchaseFilter): PurchaseAggregateResult +} + +####################### +# Extended Definitions +####################### + +""" +The Int64 scalar type represents a signed 64‐bit numeric non‐fractional value. +Int64 can represent values in range [-(2^63),(2^63 - 1)]. +""" +scalar Int64 + +""" +The DateTime scalar type represents date and time as a string in RFC3339 format. +For example: "1985-04-12T23:20:50.52Z" represents 20 mins 50.52 secs after the 23rd hour of Apr 12th 1985 in UTC. +""" +scalar DateTime + +input IntRange{ + min: Int! + max: Int! +} + +input FloatRange{ + min: Float! + max: Float! +} + +input Int64Range{ + min: Int64! + max: Int64! +} + +input DateTimeRange{ + min: DateTime! + max: DateTime! +} + +input StringRange{ + min: String! + max: String! +} + +enum DgraphIndex { + int + int64 + float + bool + hash + exact + term + fulltext + trigram + regexp + year + month + day + hour + geo + hnsw +} + +input AuthRule { + and: [AuthRule] + or: [AuthRule] + not: AuthRule + rule: String +} + +enum HTTPMethod { + GET + POST + PUT + PATCH + DELETE +} + +enum Mode { + BATCH + SINGLE +} + +input CustomHTTP { + url: String! + method: HTTPMethod! + body: String + graphql: String + mode: Mode + forwardHeaders: [String!] + secretHeaders: [String!] + introspectionHeaders: [String!] + skipIntrospection: Boolean +} + +type Point { + longitude: Float! + latitude: Float! 
+} + +input PointRef { + longitude: Float! + latitude: Float! +} + +input NearFilter { + distance: Float! + coordinate: PointRef! +} + +input PointGeoFilter { + near: NearFilter + within: WithinFilter +} + +type PointList { + points: [Point!]! +} + +input PointListRef { + points: [PointRef!]! +} + +type Polygon { + coordinates: [PointList!]! +} + +input PolygonRef { + coordinates: [PointListRef!]! +} + +type MultiPolygon { + polygons: [Polygon!]! +} + +input MultiPolygonRef { + polygons: [PolygonRef!]! +} + +input WithinFilter { + polygon: PolygonRef! +} + +input ContainsFilter { + point: PointRef + polygon: PolygonRef +} + +input IntersectsFilter { + polygon: PolygonRef + multiPolygon: MultiPolygonRef +} + +input PolygonGeoFilter { + near: NearFilter + within: WithinFilter + contains: ContainsFilter + intersects: IntersectsFilter +} + +input GenerateQueryParams { + get: Boolean + query: Boolean + password: Boolean + aggregate: Boolean +} + +input GenerateMutationParams { + add: Boolean + update: Boolean + delete: Boolean +} + +directive @hasInverse(field: String!) 
on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION +directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION +directive @id(interface: Boolean) on FIELD_DEFINITION +directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION +directive @secret(field: String!, pred: String) on OBJECT | INTERFACE +directive @auth( + password: AuthRule + query: AuthRule, + add: AuthRule, + update: AuthRule, + delete: AuthRule) on OBJECT | INTERFACE +directive @custom(http: CustomHTTP, dql: String) on FIELD_DEFINITION +directive @remote on OBJECT | INTERFACE | UNION | INPUT_OBJECT | ENUM +directive @remoteResponse(name: String) on FIELD_DEFINITION +directive @cascade(fields: [String]) on FIELD +directive @lambda on FIELD_DEFINITION +directive @lambdaOnMutate(add: Boolean, update: Boolean, delete: Boolean) on OBJECT | INTERFACE +directive @cacheControl(maxAge: Int!) on QUERY +directive @generate( + query: GenerateQueryParams, + mutation: GenerateMutationParams, + subscription: Boolean) on OBJECT | INTERFACE + +input IntFilter { + eq: Int + in: [Int] + le: Int + lt: Int + ge: Int + gt: Int + between: IntRange +} + +input Int64Filter { + eq: Int64 + in: [Int64] + le: Int64 + lt: Int64 + ge: Int64 + gt: Int64 + between: Int64Range +} + +input FloatFilter { + eq: Float + in: [Float] + le: Float + lt: Float + ge: Float + gt: Float + between: FloatRange +} + +input DateTimeFilter { + eq: DateTime + in: [DateTime] + le: DateTime + lt: DateTime + ge: DateTime + gt: DateTime + between: DateTimeRange +} + +input StringTermFilter { + allofterms: String + anyofterms: String +} + +input StringRegExpFilter { + regexp: String +} + +input StringFullTextFilter { + alloftext: String + anyoftext: String +} + +input StringExactFilter { + eq: String + in: [String] + le: String + lt: String + ge: String + gt: String + between: StringRange +} + +input StringHashFilter { + eq: String + in: [String] +} + 
+####################### +# Generated Types +####################### + +type AddProductPayload { + product(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product] + numUids: Int +} + +type AddPurchasePayload { + purchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] + numUids: Int +} + +type AddUserPayload { + user(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] + numUids: Int +} + +type DeleteProductPayload { + product(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product] + msg: String + numUids: Int +} + +type DeletePurchasePayload { + purchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] + msg: String + numUids: Int +} + +type DeleteUserPayload { + user(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] + msg: String + numUids: Int +} + +type ProductAggregateResult { + count: Int + idMin: String + idMax: String + descriptionMin: String + descriptionMax: String + titleMin: String + titleMax: String + imageUrlMin: String + imageUrlMax: String +} + +type PurchaseAggregateResult { + count: Int + dateMin: DateTime + dateMax: DateTime +} + +type UpdateProductPayload { + product(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product] + numUids: Int +} + +type UpdatePurchasePayload { + purchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] + numUids: Int +} + +type UpdateUserPayload { + user(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] + numUids: Int +} + +type UserAggregateResult { + count: Int + emailMin: String + emailMax: String +} + +####################### +# Generated Enums +####################### + +enum ProductEmbedding { + product_vector +} + +enum ProductHasFilter { + id + description + title + imageUrl + product_vector + vector_distance +} + +enum ProductOrderable { + id + description + title + imageUrl +} 
+ +enum PurchaseHasFilter { + user + product + date +} + +enum PurchaseOrderable { + date +} + +enum UserEmbedding { + user_vector +} + +enum UserHasFilter { + email + purchase_history + user_vector + vector_distance +} + +enum UserOrderable { + email +} + +####################### +# Generated Inputs +####################### + +input AddProductInput { + id: String! + description: String + title: String + imageUrl: String + product_vector: [Float!] +} + +input AddPurchaseInput { + user: UserRef + product: ProductRef + date: DateTime +} + +input AddUserInput { + email: String! + purchase_history: [PurchaseRef] + user_vector: [Float!] +} + +input ProductFilter { + id: StringHashFilter + has: [ProductHasFilter] + and: [ProductFilter] + or: [ProductFilter] + not: ProductFilter +} + +input ProductOrder { + asc: ProductOrderable + desc: ProductOrderable + then: ProductOrder +} + +input ProductPatch { + id: String + description: String + title: String + imageUrl: String + product_vector: [Float!] +} + +input ProductRef { + id: String + description: String + title: String + imageUrl: String + product_vector: [Float!] +} + +input PurchaseFilter { + date: DateTimeFilter + has: [PurchaseHasFilter] + and: [PurchaseFilter] + or: [PurchaseFilter] + not: PurchaseFilter +} + +input PurchaseOrder { + asc: PurchaseOrderable + desc: PurchaseOrderable + then: PurchaseOrder +} + +input PurchasePatch { + user: UserRef + product: ProductRef + date: DateTime +} + +input PurchaseRef { + user: UserRef + product: ProductRef + date: DateTime +} + +input UpdateProductInput { + filter: ProductFilter! + set: ProductPatch + remove: ProductPatch +} + +input UpdatePurchaseInput { + filter: PurchaseFilter! + set: PurchasePatch + remove: PurchasePatch +} + +input UpdateUserInput { + filter: UserFilter! 
+ set: UserPatch + remove: UserPatch +} + +input UserFilter { + email: StringHashFilter + has: [UserHasFilter] + and: [UserFilter] + or: [UserFilter] + not: UserFilter +} + +input UserOrder { + asc: UserOrderable + desc: UserOrderable + then: UserOrder +} + +input UserPatch { + email: String + purchase_history: [PurchaseRef] + user_vector: [Float!] +} + +input UserRef { + email: String + purchase_history: [PurchaseRef] + user_vector: [Float!] +} + +####################### +# Generated Query +####################### + +type Query { + getProduct(id: String!): Product + querySimilarProductById(id: String!, by: ProductEmbedding!, topK: Int!, filter: ProductFilter): [Product] + querySimilarProductByEmbedding(by: ProductEmbedding!, topK: Int!, vector: [Float!]!, filter: ProductFilter): [Product] + queryProduct(filter: ProductFilter, order: ProductOrder, first: Int, offset: Int): [Product] + aggregateProduct(filter: ProductFilter): ProductAggregateResult + queryPurchase(filter: PurchaseFilter, order: PurchaseOrder, first: Int, offset: Int): [Purchase] + aggregatePurchase(filter: PurchaseFilter): PurchaseAggregateResult + getUser(email: String!): User + querySimilarUserById(email: String!, by: UserEmbedding!, topK: Int!, filter: UserFilter): [User] + querySimilarUserByEmbedding(by: UserEmbedding!, topK: Int!, vector: [Float!]!, filter: UserFilter): [User] + queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] + aggregateUser(filter: UserFilter): UserAggregateResult +} + +####################### +# Generated Mutations +####################### + +type Mutation { + addProduct(input: [AddProductInput!]!, upsert: Boolean): AddProductPayload + updateProduct(input: UpdateProductInput!): UpdateProductPayload + deleteProduct(filter: ProductFilter!): DeleteProductPayload + addPurchase(input: [AddPurchaseInput!]!): AddPurchasePayload + updatePurchase(input: UpdatePurchaseInput!): UpdatePurchasePayload + deletePurchase(filter: PurchaseFilter!): 
DeletePurchasePayload + addUser(input: [AddUserInput!]!, upsert: Boolean): AddUserPayload + updateUser(input: UpdateUserInput!): UpdateUserPayload + deleteUser(filter: UserFilter!): DeleteUserPayload +} + diff --git a/graphql/schema/testdata/schemagen/output/field-with-id-directive.graphql b/graphql/schema/testdata/schemagen/output/field-with-id-directive.graphql old mode 100755 new mode 100644 index c352bb70d87..c570fe404b0 --- a/graphql/schema/testdata/schemagen/output/field-with-id-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/field-with-id-directive.graphql @@ -78,6 +78,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -191,7 +192,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/field-with-multiple-@id-fields.graphql b/graphql/schema/testdata/schemagen/output/field-with-multiple-@id-fields.graphql old mode 100755 new mode 100644 index 13aa2507611..4742f69f583 --- a/graphql/schema/testdata/schemagen/output/field-with-multiple-@id-fields.graphql +++ b/graphql/schema/testdata/schemagen/output/field-with-multiple-@id-fields.graphql @@ -78,6 +78,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -191,7 +192,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/field-with-reverse-predicate-in-dgraph-directive.graphql b/graphql/schema/testdata/schemagen/output/field-with-reverse-predicate-in-dgraph-directive.graphql old mode 100755 new mode 100644 index f2145184ee6..b323d5c0af3 --- a/graphql/schema/testdata/schemagen/output/field-with-reverse-predicate-in-dgraph-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/field-with-reverse-predicate-in-dgraph-directive.graphql @@ -73,6 +73,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -186,7 +187,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-all-empty.graphql b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-all-empty.graphql index 64d88e15dd9..98a287265a8 100644 --- a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-all-empty.graphql +++ b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-all-empty.graphql @@ -76,6 +76,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -189,7 +190,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-circular.graphql b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-circular.graphql index fd2ae029d78..e8a0a4e3262 100644 --- a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-circular.graphql +++ b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-circular.graphql @@ -80,6 +80,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -193,7 +194,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-custom-mutation.graphql b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-custom-mutation.graphql index 480d42cc906..b49eccb0dda 100644 --- a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-custom-mutation.graphql +++ b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-custom-mutation.graphql @@ -68,6 +68,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -181,7 +182,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-directLink.graphql b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-directLink.graphql index ce22f674d5a..12d3fc8f1a8 100644 --- a/graphql/schema/testdata/schemagen/output/filter-cleanSchema-directLink.graphql +++ b/graphql/schema/testdata/schemagen/output/filter-cleanSchema-directLink.graphql @@ -78,6 +78,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -191,7 +192,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/generate-directive.graphql b/graphql/schema/testdata/schemagen/output/generate-directive.graphql index 046c376fd08..2218601efb5 100644 --- a/graphql/schema/testdata/schemagen/output/generate-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/generate-directive.graphql @@ -79,6 +79,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -192,7 +193,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/geo-type.graphql b/graphql/schema/testdata/schemagen/output/geo-type.graphql index 7b007dbea08..229ca9e0bf6 100644 --- a/graphql/schema/testdata/schemagen/output/geo-type.graphql +++ b/graphql/schema/testdata/schemagen/output/geo-type.graphql @@ -70,6 +70,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -183,7 +184,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasInverse-with-interface-having-directive.graphql b/graphql/schema/testdata/schemagen/output/hasInverse-with-interface-having-directive.graphql old mode 100755 new mode 100644 index 2a6e74d672c..18d386066e0 --- a/graphql/schema/testdata/schemagen/output/hasInverse-with-interface-having-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/hasInverse-with-interface-having-directive.graphql @@ -89,6 +89,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -202,7 +203,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasInverse-with-interface.graphql b/graphql/schema/testdata/schemagen/output/hasInverse-with-interface.graphql old mode 100755 new mode 100644 index 7262833b462..6f2de5a2e5d --- a/graphql/schema/testdata/schemagen/output/hasInverse-with-interface.graphql +++ b/graphql/schema/testdata/schemagen/output/hasInverse-with-interface.graphql @@ -91,6 +91,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -204,7 +205,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasInverse-with-type-having-directive.graphql b/graphql/schema/testdata/schemagen/output/hasInverse-with-type-having-directive.graphql old mode 100755 new mode 100644 index 2a6e74d672c..18d386066e0 --- a/graphql/schema/testdata/schemagen/output/hasInverse-with-type-having-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/hasInverse-with-type-having-directive.graphql @@ -89,6 +89,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -202,7 +203,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasInverse.graphql b/graphql/schema/testdata/schemagen/output/hasInverse.graphql old mode 100755 new mode 100644 index e550998e55a..2c5c948720f --- a/graphql/schema/testdata/schemagen/output/hasInverse.graphql +++ b/graphql/schema/testdata/schemagen/output/hasInverse.graphql @@ -70,6 +70,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -183,7 +184,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasInverse_withSubscription.graphql b/graphql/schema/testdata/schemagen/output/hasInverse_withSubscription.graphql old mode 100755 new mode 100644 index 08cebd56299..67d0ebb7360 --- a/graphql/schema/testdata/schemagen/output/hasInverse_withSubscription.graphql +++ b/graphql/schema/testdata/schemagen/output/hasInverse_withSubscription.graphql @@ -70,6 +70,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -183,7 +184,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/hasfilter.graphql b/graphql/schema/testdata/schemagen/output/hasfilter.graphql index e5adcc1acc6..3fd950ad953 100644 --- a/graphql/schema/testdata/schemagen/output/hasfilter.graphql +++ b/graphql/schema/testdata/schemagen/output/hasfilter.graphql @@ -72,6 +72,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -185,7 +186,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/ignore-unsupported-directive.graphql b/graphql/schema/testdata/schemagen/output/ignore-unsupported-directive.graphql old mode 100755 new mode 100644 index 0d50f4fcb38..11a0c470f28 --- a/graphql/schema/testdata/schemagen/output/ignore-unsupported-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/ignore-unsupported-directive.graphql @@ -71,6 +71,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -184,7 +185,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/interface-with-dgraph-pred.graphql b/graphql/schema/testdata/schemagen/output/interface-with-dgraph-pred.graphql index a604dd070c5..315a044ef03 100644 --- a/graphql/schema/testdata/schemagen/output/interface-with-dgraph-pred.graphql +++ b/graphql/schema/testdata/schemagen/output/interface-with-dgraph-pred.graphql @@ -80,6 +80,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -193,7 +194,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/interface-with-id-directive.graphql b/graphql/schema/testdata/schemagen/output/interface-with-id-directive.graphql old mode 100755 new mode 100644 index b69bab2269d..45286f85154 --- a/graphql/schema/testdata/schemagen/output/interface-with-id-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/interface-with-id-directive.graphql @@ -76,6 +76,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -189,7 +190,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/interface-with-no-ids.graphql b/graphql/schema/testdata/schemagen/output/interface-with-no-ids.graphql old mode 100755 new mode 100644 index 90161b27860..15d9aa8534d --- a/graphql/schema/testdata/schemagen/output/interface-with-no-ids.graphql +++ b/graphql/schema/testdata/schemagen/output/interface-with-no-ids.graphql @@ -74,6 +74,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -187,7 +188,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/interfaces-with-types-and-password.graphql b/graphql/schema/testdata/schemagen/output/interfaces-with-types-and-password.graphql old mode 100755 new mode 100644 index 2314846daa9..d1f7645b71b --- a/graphql/schema/testdata/schemagen/output/interfaces-with-types-and-password.graphql +++ b/graphql/schema/testdata/schemagen/output/interfaces-with-types-and-password.graphql @@ -99,6 +99,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -212,7 +213,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/interfaces-with-types.graphql b/graphql/schema/testdata/schemagen/output/interfaces-with-types.graphql old mode 100755 new mode 100644 index 53da87263b2..667dc065979 --- a/graphql/schema/testdata/schemagen/output/interfaces-with-types.graphql +++ b/graphql/schema/testdata/schemagen/output/interfaces-with-types.graphql @@ -99,6 +99,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -212,7 +213,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/lambda-directive.graphql b/graphql/schema/testdata/schemagen/output/lambda-directive.graphql index 8df251f64a7..c083d55a1a4 100644 --- a/graphql/schema/testdata/schemagen/output/lambda-directive.graphql +++ b/graphql/schema/testdata/schemagen/output/lambda-directive.graphql @@ -66,6 +66,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -179,7 +180,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/language-tags.graphql b/graphql/schema/testdata/schemagen/output/language-tags.graphql old mode 100755 new mode 100644 index 75d0f6daa9a..03c3a4b89cf --- a/graphql/schema/testdata/schemagen/output/language-tags.graphql +++ b/graphql/schema/testdata/schemagen/output/language-tags.graphql @@ -12,7 +12,7 @@ type Person implements Node { f2: String @dgraph(pred: "T.f@no") f3: String @dgraph(pred: "f3@en") name: String! @id - nameHi: String @dgraph(pred: "Person.name@hi") @search(by: [term,exact]) + nameHi: String @dgraph(pred: "Person.name@hi") @search(by: ["term","exact"]) nameEn: String @dgraph(pred: "Person.name@en") @search(by: [regexp]) nameHiEn: String @dgraph(pred: "Person.name@hi:en") nameHi_En_Untag: String @dgraph(pred: "Person.name@hi:en:.") @@ -79,6 +79,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -192,7 +193,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/no-id-field-with-searchables.graphql b/graphql/schema/testdata/schemagen/output/no-id-field-with-searchables.graphql old mode 100755 new mode 100644 index 5430dbf9e17..fafc7d3ac19 --- a/graphql/schema/testdata/schemagen/output/no-id-field-with-searchables.graphql +++ b/graphql/schema/testdata/schemagen/output/no-id-field-with-searchables.graphql @@ -63,6 +63,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -176,7 +177,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/no-id-field.graphql b/graphql/schema/testdata/schemagen/output/no-id-field.graphql old mode 100755 new mode 100644 index 1976d50b845..929f96e399b --- a/graphql/schema/testdata/schemagen/output/no-id-field.graphql +++ b/graphql/schema/testdata/schemagen/output/no-id-field.graphql @@ -76,6 +76,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -189,7 +190,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/password-type.graphql b/graphql/schema/testdata/schemagen/output/password-type.graphql old mode 100755 new mode 100644 index 7bfe0bf4007..cb13e7d52fc --- a/graphql/schema/testdata/schemagen/output/password-type.graphql +++ b/graphql/schema/testdata/schemagen/output/password-type.graphql @@ -64,6 +64,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -177,7 +178,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/random.graphql b/graphql/schema/testdata/schemagen/output/random.graphql index 117b334faef..196c54e6e2f 100644 --- a/graphql/schema/testdata/schemagen/output/random.graphql +++ b/graphql/schema/testdata/schemagen/output/random.graphql @@ -82,6 +82,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -195,7 +196,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/searchables-references.graphql b/graphql/schema/testdata/schemagen/output/searchables-references.graphql old mode 100755 new mode 100644 index a1d33610f5a..e2ef7e884ad --- a/graphql/schema/testdata/schemagen/output/searchables-references.graphql +++ b/graphql/schema/testdata/schemagen/output/searchables-references.graphql @@ -12,8 +12,8 @@ type Author { type Post { postID: ID! - title: String! @search(by: [term,fulltext]) - text: String @search(by: [fulltext,term]) + title: String! @search(by: ["term","fulltext"]) + text: String @search(by: ["fulltext","term"]) datePublished: DateTime } @@ -74,6 +74,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -187,7 +188,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/searchables.graphql b/graphql/schema/testdata/schemagen/output/searchables.graphql old mode 100755 new mode 100644 index 8e21af2ddcd..822434df817 --- a/graphql/schema/testdata/schemagen/output/searchables.graphql +++ b/graphql/schema/testdata/schemagen/output/searchables.graphql @@ -5,7 +5,7 @@ type Post { postID: ID! title: String! 
@search(by: [term]) - titleByEverything: String! @search(by: [term,fulltext,trigram,hash]) + titleByEverything: String! @search(by: ["term","fulltext","trigram","hash"]) text: String @search(by: [fulltext]) tags: [String] @search(by: [trigram]) tagsHash: [String] @search(by: [hash]) @@ -26,8 +26,8 @@ type Post { postTypeRegexp: PostType @search(by: [regexp]) postTypeExact: [PostType] @search(by: [exact]) postTypeHash: PostType @search(by: [hash]) - postTypeRegexpExact: PostType @search(by: [exact,regexp]) - postTypeHashRegexp: PostType @search(by: [hash,regexp]) + postTypeRegexpExact: PostType @search(by: ["exact","regexp"]) + postTypeHashRegexp: PostType @search(by: ["hash","regexp"]) postTypeNone: PostType @search(by: []) } @@ -94,6 +94,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -207,7 +208,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/single-type-with-enum.graphql b/graphql/schema/testdata/schemagen/output/single-type-with-enum.graphql old mode 100755 new mode 100644 index dc67bd55967..9e618990057 --- a/graphql/schema/testdata/schemagen/output/single-type-with-enum.graphql +++ b/graphql/schema/testdata/schemagen/output/single-type-with-enum.graphql @@ -72,6 +72,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -185,7 +186,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/single-type.graphql b/graphql/schema/testdata/schemagen/output/single-type.graphql old mode 100755 new mode 100644 index ffd93575abf..dfd68e48be5 --- a/graphql/schema/testdata/schemagen/output/single-type.graphql +++ b/graphql/schema/testdata/schemagen/output/single-type.graphql @@ -67,6 +67,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -180,7 +181,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-implements-multiple-interfaces.graphql b/graphql/schema/testdata/schemagen/output/type-implements-multiple-interfaces.graphql old mode 100755 new mode 100644 index b789d8551f7..1307a08b8cf --- a/graphql/schema/testdata/schemagen/output/type-implements-multiple-interfaces.graphql +++ b/graphql/schema/testdata/schemagen/output/type-implements-multiple-interfaces.graphql @@ -81,6 +81,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -194,7 +195,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-reference.graphql b/graphql/schema/testdata/schemagen/output/type-reference.graphql old mode 100755 new mode 100644 index 721497a18fd..30e7ab07d77 --- a/graphql/schema/testdata/schemagen/output/type-reference.graphql +++ b/graphql/schema/testdata/schemagen/output/type-reference.graphql @@ -71,6 +71,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -184,7 +185,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-with-arguments-on-field.graphql b/graphql/schema/testdata/schemagen/output/type-with-arguments-on-field.graphql index 192482c12fa..dc5a5938e70 100644 --- a/graphql/schema/testdata/schemagen/output/type-with-arguments-on-field.graphql +++ b/graphql/schema/testdata/schemagen/output/type-with-arguments-on-field.graphql @@ -72,6 +72,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -185,7 +186,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-with-custom-field-on-dgraph-type.graphql b/graphql/schema/testdata/schemagen/output/type-with-custom-field-on-dgraph-type.graphql index 7e4c71914f7..2ab2f58c4fa 100644 --- a/graphql/schema/testdata/schemagen/output/type-with-custom-field-on-dgraph-type.graphql +++ b/graphql/schema/testdata/schemagen/output/type-with-custom-field-on-dgraph-type.graphql @@ -71,6 +71,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -184,7 +185,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-with-custom-fields-on-remote-type.graphql b/graphql/schema/testdata/schemagen/output/type-with-custom-fields-on-remote-type.graphql index e85ea68b041..e425637fb7c 100644 --- a/graphql/schema/testdata/schemagen/output/type-with-custom-fields-on-remote-type.graphql +++ b/graphql/schema/testdata/schemagen/output/type-with-custom-fields-on-remote-type.graphql @@ -71,6 +71,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -184,7 +185,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/type-without-orderables.graphql b/graphql/schema/testdata/schemagen/output/type-without-orderables.graphql index ff78a20fbb1..7196b507ad5 100644 --- a/graphql/schema/testdata/schemagen/output/type-without-orderables.graphql +++ b/graphql/schema/testdata/schemagen/output/type-without-orderables.graphql @@ -66,6 +66,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -179,7 +180,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/testdata/schemagen/output/union.graphql b/graphql/schema/testdata/schemagen/output/union.graphql index 33eaa4c6c54..c6ea004e33f 100644 --- a/graphql/schema/testdata/schemagen/output/union.graphql +++ b/graphql/schema/testdata/schemagen/output/union.graphql @@ -113,6 +113,7 @@ enum DgraphIndex { day hour geo + hnsw } input AuthRule { @@ -226,7 +227,8 @@ input GenerateMutationParams { } directive @hasInverse(field: String!) 
on FIELD_DEFINITION -directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION +directive @search(by: [String!]) on FIELD_DEFINITION +directive @embedding on FIELD_DEFINITION directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION directive @id(interface: Boolean) on FIELD_DEFINITION directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION diff --git a/graphql/schema/wrappers.go b/graphql/schema/wrappers.go index 40239982a55..e756d681a36 100644 --- a/graphql/schema/wrappers.go +++ b/graphql/schema/wrappers.go @@ -85,24 +85,38 @@ type EntityRepresentations struct { // Query/Mutation types and arg names const ( - GetQuery QueryType = "get" - FilterQuery QueryType = "query" - AggregateQuery QueryType = "aggregate" - SchemaQuery QueryType = "schema" - EntitiesQuery QueryType = "entities" - PasswordQuery QueryType = "checkPassword" - HTTPQuery QueryType = "http" - DQLQuery QueryType = "dql" - NotSupportedQuery QueryType = "notsupported" - AddMutation MutationType = "add" - UpdateMutation MutationType = "update" - DeleteMutation MutationType = "delete" - HTTPMutation MutationType = "http" - NotSupportedMutation MutationType = "notsupported" - IDType = "ID" - InputArgName = "input" - UpsertArgName = "upsert" - FilterArgName = "filter" + GetQuery QueryType = "get" + SimilarByIdQuery QueryType = "querySimilarById" + SimilarByEmbeddingQuery QueryType = "querySimilarByEmbedding" + FilterQuery QueryType = "query" + AggregateQuery QueryType = "aggregate" + SchemaQuery QueryType = "schema" + EntitiesQuery QueryType = "entities" + PasswordQuery QueryType = "checkPassword" + HTTPQuery QueryType = "http" + DQLQuery QueryType = "dql" + NotSupportedQuery QueryType = "notsupported" + AddMutation MutationType = "add" + UpdateMutation MutationType = "update" + DeleteMutation MutationType = "delete" + HTTPMutation MutationType = "http" + NotSupportedMutation MutationType = "notsupported" + IDType = "ID" + InputArgName = "input" + UpsertArgName = 
"upsert" + FilterArgName = "filter" + SimilarByArgName = "by" + SimilarTopKArgName = "topK" + SimilarVectorArgName = "vector" + EmbeddingEnumSuffix = "Embedding" + SimilarQueryPrefix = "querySimilar" + SimilarByIdQuerySuffix = "ById" + SimilarByEmbeddingQuerySuffix = "ByEmbedding" + SimilarQueryResultTypeSuffix = "WithDistance" + SimilarQueryDistanceFieldName = "vector_distance" + SimilarSearchMetricEuclidian = "euclidian" + SimilarSearchMetricDotProduct = "dotproduct" + SimilarSearchMetricCosine = "cosine" ) // Schema represents a valid GraphQL schema @@ -269,6 +283,8 @@ type FieldDefinition interface { IsID() bool IsExternal() bool HasIDDirective() bool + HasEmbeddingDirective() bool + EmbeddingSearchMetric() string HasInterfaceArg() bool Inverse() FieldDefinition WithMemberType(string) FieldDefinition @@ -1376,8 +1392,11 @@ func (f *field) IDArgValue() (xids map[string]string, uid uint64, err error) { // or Password. Therefore the non ID and Password field is an XID. // TODO maybe there is a better way to do this. 
for _, arg := range f.field.Arguments { + xidArgName = "" if (idField == nil || arg.Name != idField.Name()) && - (passwordField == nil || arg.Name != passwordField.Name()) { + (passwordField == nil || arg.Name != passwordField.Name()) && + (queryType(f.field.Name, nil) != SimilarByIdQuery || + (arg.Name != SimilarTopKArgName && arg.Name != SimilarByArgName && arg.Name != "filter")) { xidArgName = arg.Name } @@ -2007,6 +2026,10 @@ func queryType(name string, custom *ast.Directive) QueryType { return GetQuery case name == "__schema" || name == "__type" || name == "__typename": return SchemaQuery + case strings.HasPrefix(name, SimilarQueryPrefix) && strings.HasSuffix(name, SimilarByIdQuerySuffix): + return SimilarByIdQuery + case strings.HasPrefix(name, SimilarQueryPrefix) && strings.HasSuffix(name, SimilarByEmbeddingQuerySuffix): + return SimilarByEmbeddingQuery case strings.HasPrefix(name, "query"): return FilterQuery case strings.HasPrefix(name, "check"): @@ -2325,6 +2348,30 @@ func hasIDDirective(fd *ast.FieldDefinition) bool { return id != nil } +func (fd *fieldDefinition) HasEmbeddingDirective() bool { + if fd.fieldDef == nil { + return false + } + return hasEmbeddingDirective(fd.fieldDef) +} + +func (fd *fieldDefinition) EmbeddingSearchMetric() string { + if fd.fieldDef == nil || !hasEmbeddingDirective(fd.fieldDef) || + fd.fieldDef.Directives.ForName(searchDirective) == nil { + return "" + } + + searchArg := getSearchArgs(fd.fieldDef)[0] + kvMap, _ := parseSearchOptions(searchArg) + + return kvMap["metric"] +} + +func hasEmbeddingDirective(fd *ast.FieldDefinition) bool { + id := fd.Directives.ForName(embeddingDirective) + return id != nil +} + func (fd *fieldDefinition) HasInterfaceArg() bool { if fd.fieldDef == nil { return false diff --git a/graphql/schema/wrappers_test.go b/graphql/schema/wrappers_test.go index b119842d41a..9867020bd13 100644 --- a/graphql/schema/wrappers_test.go +++ b/graphql/schema/wrappers_test.go @@ -37,7 +37,7 @@ func 
TestDgraphMapping_WithoutDirectives(t *testing.T) { type Author { id: ID! - name: String! @search(by: [hash, trigram]) + name: String! @search(by: ["hash", "trigram"]) dob: DateTime @search reputation: Float @search posts: [Post!] @hasInverse(field: author) @@ -222,7 +222,7 @@ func TestDgraphMapping_WithDirectives(t *testing.T) { type Author @dgraph(type: "dgraph.author") { id: ID! - name: String! @search(by: [hash, trigram]) + name: String! @search(by: ["hash", "trigram"]) dob: DateTime @search reputation: Float @search posts: [Post!] @hasInverse(field: author) diff --git a/posting/index.go b/posting/index.go index c1463b57696..3257583899f 100644 --- a/posting/index.go +++ b/posting/index.go @@ -239,7 +239,7 @@ func (txn *Txn) addReverseMutationHelper(ctx context.Context, plist *List, plist.Lock() defer plist.Unlock() if hasCountIndex { - countBefore, found, _ = plist.getPostingAndLength(txn.StartTs, 0, edge.ValueId) + countBefore, found, _ = plist.getPostingAndLengthNoSort(txn.StartTs, 0, edge.ValueId) if countBefore == -1 { return emptyCountParams, ErrTsTooOld } @@ -649,6 +649,120 @@ type rebuilder struct { fn func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) } +func (r *rebuilder) RunWithoutTemp(ctx context.Context) error { + stream := pstore.NewStreamAt(r.startTs) + stream.LogPrefix = fmt.Sprintf("Rebuilding index for predicate %s (1/2):", r.attr) + stream.Prefix = r.prefix + stream.NumGo = 128 + txn := NewTxn(r.startTs) + stream.KeyToList = func(key []byte, it *badger.Iterator) (*bpb.KVList, error) { + // We should return quickly if the context is no longer valid. 
+ select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + pk, err := x.Parse(key) + if err != nil { + return nil, errors.Wrapf(err, "could not parse key %s", hex.Dump(key)) + } + + l := new(List) + l.key = key + l.plist = new(pb.PostingList) + + found := false + + for it.Valid() { + item := it.Item() + if !bytes.Equal(item.Key(), l.key) { + break + } + l.maxTs = x.Max(l.maxTs, item.Version()) + if item.IsDeletedOrExpired() { + // Don't consider any more versions. + break + } + + found = true + switch item.UserMeta() { + case BitEmptyPosting: + l.minTs = item.Version() + case BitCompletePosting: + if err := unmarshalOrCopy(l.plist, item); err != nil { + return nil, err + } + l.minTs = item.Version() + + // No need to do Next here. The outer loop can take care of skipping + // more versions of the same key. + case BitDeltaPosting: + err := item.Value(func(val []byte) error { + pl := &pb.PostingList{} + if err := pl.Unmarshal(val); err != nil { + return err + } + pl.CommitTs = item.Version() + for _, mpost := range pl.Postings { + // commitTs, startTs are meant to be only in memory, not + // stored on disk. + mpost.CommitTs = item.Version() + } + if l.mutationMap == nil { + l.mutationMap = make(map[uint64]*pb.PostingList) + } + l.mutationMap[pl.CommitTs] = pl + return nil + }) + if err != nil { + return nil, err + } + default: + return nil, errors.Errorf( + "Unexpected meta: %d for key: %s", item.UserMeta(), hex.Dump(key)) + } + if found { + break + } + } + + if _, err := r.fn(pk.Uid, l, txn); err != nil { + return nil, err + } + return nil, nil + } + stream.Send = func(buf *z.Buffer) error { + // TODO. Make an in memory txn with disk backing for more data than memory. 
+ return nil + } + + start := time.Now() + if err := stream.Orchestrate(ctx); err != nil { + return err + } + + txn.Update() + writer := NewTxnWriter(pstore) + + defer func() { + glog.V(1).Infof("Rebuilding index for predicate %s: building index took: %v\n", + r.attr, time.Since(start)) + }() + + ResetCache() + + return x.ExponentialRetry(int(x.Config.MaxRetries), + 20*time.Millisecond, func() error { + err := txn.CommitToDisk(writer, r.startTs) + if err == badger.ErrBannedKey { + glog.Errorf("Error while writing to banned namespace.") + return nil + } + return err + }) +} + func (r *rebuilder) Run(ctx context.Context) error { if r.startTs == 0 { glog.Infof("maxassigned is 0, no indexing work for predicate %s", r.attr) @@ -1175,6 +1289,8 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { factorySpecs = []*tok.FactoryCreateSpec{factorySpec} } + runForVectors := (len(factorySpecs) != 0) + pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { @@ -1200,7 +1316,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { case ErrRetry: time.Sleep(10 * time.Millisecond) default: - edges = append(edges, newEdges...) + if !runForVectors { + edges = append(edges, newEdges...) + } return err } } @@ -1210,6 +1328,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } return edges, err } + if len(factorySpecs) != 0 { + return builder.RunWithoutTemp(ctx) + } return builder.Run(ctx) } @@ -1511,11 +1632,13 @@ func rebuildListType(ctx context.Context, rb *IndexRebuild) error { // DeleteAll deletes all entries in the posting list. func DeleteAll() error { + ResetCache() return pstore.DropAll() } // DeleteData deletes all data for the namespace but leaves types and schema intact. 
func DeleteData(ns uint64) error { + ResetCache() prefix := make([]byte, 9) prefix[0] = x.DefaultPrefix binary.BigEndian.PutUint64(prefix[1:], ns) @@ -1525,6 +1648,7 @@ func DeleteData(ns uint64) error { // DeletePredicate deletes all entries and indices for a given predicate. func DeletePredicate(ctx context.Context, attr string, ts uint64) error { glog.Infof("Dropping predicate: [%s]", attr) + ResetCache() preds := schema.State().PredicatesToDelete(attr) for _, pred := range preds { prefix := x.PredicatePrefix(pred) @@ -1541,6 +1665,8 @@ func DeletePredicate(ctx context.Context, attr string, ts uint64) error { // DeleteNamespace bans the namespace and deletes its predicates/types from the schema. func DeleteNamespace(ns uint64) error { + // TODO: We should only delete cache for certain keys, not all the keys. + ResetCache() schema.State().DeletePredsForNs(ns) return pstore.BanNamespace(ns) } diff --git a/posting/list.go b/posting/list.go index 24b5f02f3f6..8c86562ad96 100644 --- a/posting/list.go +++ b/posting/list.go @@ -25,6 +25,7 @@ import ( "sort" "github.com/dgryski/go-farm" + "github.com/golang/glog" "github.com/golang/protobuf/proto" "github.com/pkg/errors" @@ -556,6 +557,27 @@ func (l *List) getMutation(startTs uint64) []byte { return nil } +func (l *List) setMutationAfterCommit(startTs, commitTs uint64, data []byte) { + pl := new(pb.PostingList) + x.Check(pl.Unmarshal(data)) + pl.CommitTs = commitTs + for _, p := range pl.Postings { + p.CommitTs = commitTs + } + + x.AssertTrue(pl.Pack == nil) + + l.Lock() + if l.mutationMap == nil { + l.mutationMap = make(map[uint64]*pb.PostingList) + } + l.mutationMap[startTs] = pl + if pl.CommitTs != 0 { + l.maxTs = x.Max(l.maxTs, pl.CommitTs) + } + l.Unlock() +} + func (l *List) setMutation(startTs uint64, data []byte) { pl := new(pb.PostingList) x.Check(pl.Unmarshal(data)) @@ -565,6 +587,9 @@ func (l *List) setMutation(startTs uint64, data []byte) { l.mutationMap = make(map[uint64]*pb.PostingList) } 
l.mutationMap[startTs] = pl + if pl.CommitTs != 0 { + l.maxTs = x.Max(l.maxTs, pl.CommitTs) + } l.Unlock() } @@ -666,6 +691,19 @@ func (l *List) iterate(readTs uint64, afterUid uint64, f func(obj *pb.Posting) e }) } + numDeletePostingsRead := 0 + numNormalPostingsRead := 0 + defer func() { + // If we see a lot of these logs, it means that a lot of elements are getting deleted. + // This could be normal, but if we see this too much, that means that rollups are too slow. + if numNormalPostingsRead < numDeletePostingsRead && + (numNormalPostingsRead > 0 || numDeletePostingsRead > 0) { + glog.V(3).Infof("High proportion of deleted data observed for posting list %b: total = %d, "+ + "percent deleted = %d", l.key, numNormalPostingsRead+numDeletePostingsRead, + (numDeletePostingsRead*100)/(numDeletePostingsRead+numNormalPostingsRead)) + } + }() + var ( mp, pp *pb.Posting pitr pIterator @@ -708,6 +746,7 @@ loop: case mp.Uid == 0 || (pp.Uid > 0 && pp.Uid < mp.Uid): // Either mp is empty, or pp is lower than mp. err = f(pp) + numNormalPostingsRead += 1 if err != nil { break loop } @@ -719,18 +758,24 @@ loop: // Either pp is empty, or mp is lower than pp. 
if mp.Op != Del { err = f(mp) + numNormalPostingsRead += 1 if err != nil { break loop } + } else { + numDeletePostingsRead += 1 } prevUid = mp.Uid midx++ case pp.Uid == mp.Uid: if mp.Op != Del { err = f(mp) + numNormalPostingsRead += 1 if err != nil { break loop } + } else { + numDeletePostingsRead += 1 } prevUid = mp.Uid if err = pitr.next(); err != nil { @@ -762,6 +807,38 @@ func (l *List) IsEmpty(readTs, afterUid uint64) (bool, error) { return count == 0, nil } +func (l *List) getPostingAndLengthNoSort(readTs, afterUid, uid uint64) (int, bool, *pb.Posting) { + l.AssertRLock() + + dec := codec.Decoder{Pack: l.plist.Pack} + uids := dec.Seek(uid, codec.SeekStart) + length := codec.ExactLen(l.plist.Pack) + found := len(uids) > 0 && uids[0] == uid + + for _, plist := range l.mutationMap { + for _, mpost := range plist.Postings { + if (mpost.CommitTs > 0 && mpost.CommitTs <= readTs) || (mpost.StartTs == readTs) { + if hasDeleteAll(mpost) { + found = false + length = 0 + continue + } + if mpost.Uid == uid { + found = (mpost.Op == Set) + } + if mpost.Op == Set { + length += 1 + } else { + length -= 1 + } + + } + } + } + + return length, found, nil +} + func (l *List) getPostingAndLength(readTs, afterUid, uid uint64) (int, bool, *pb.Posting) { l.AssertRLock() var count int @@ -1130,6 +1207,8 @@ func (l *List) rollup(readTs uint64, split bool) (*rollupOutput, error) { } if len(out.plist.Splits) > 0 || len(l.mutationMap) > 0 { + // In case there were splits, this would read all the splits from + // Badger. if err := l.encode(out, readTs, split); err != nil { return nil, errors.Wrapf(err, "while encoding") } @@ -1219,9 +1298,16 @@ func (l *List) Uids(opt ListOptions) (*pb.List, error) { // Do The intersection here as it's optimized. 
out.Uids = res + lenBefore := len(res) if opt.Intersect != nil { algo.IntersectWith(out, opt.Intersect, out) } + lenAfter := len(out.Uids) + if lenBefore-lenAfter > 0 { + // If we see this log, that means that iterate is going over too many elements that it doesn't need to + glog.V(3).Infof("Retrieved a list. length before intersection: %d, length after: %d, extra"+ + " elements: %d", lenBefore, lenAfter, lenBefore-lenAfter) + } return out, nil } diff --git a/posting/list_test.go b/posting/list_test.go index 5256938e5d1..45c0d963262 100644 --- a/posting/list_test.go +++ b/posting/list_test.go @@ -42,6 +42,21 @@ func setMaxListSize(newMaxListSize int) { maxListSize = newMaxListSize } +func readPostingListFromDisk(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { + txn := pstore.NewTransactionAt(readTs, false) + defer txn.Discard() + + // When we do rollups, an older version would go to the top of the LSM tree, which can cause + // issues during txn.Get. Therefore, always iterate. 
+ iterOpts := badger.DefaultIteratorOptions + iterOpts.AllVersions = true + iterOpts.PrefetchValues = false + itr := txn.NewKeyIterator(key, iterOpts) + defer itr.Close() + itr.Seek(key) + return ReadPostingList(key, itr) +} + func (l *List) PostingList() *pb.PostingList { l.RLock() defer l.RUnlock() @@ -177,7 +192,7 @@ func checkValue(t *testing.T, ol *List, val string, readTs uint64) { // TODO(txn): Add tests after lru eviction func TestAddMutation_Value(t *testing.T) { key := x.DataKey(x.GalaxyAttr(x.GalaxyAttr("value")), 10) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) edge := &pb.DirectedEdge{ Value: []byte("oh hey there"), @@ -440,7 +455,7 @@ func TestRollupMaxTsIsSet(t *testing.T) { maxListSize = math.MaxInt32 key := x.DataKey(x.GalaxyAttr("bal"), 1333) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) var commits int N := int(1e6) @@ -461,7 +476,7 @@ func TestRollupMaxTsIsSet(t *testing.T) { } require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } commits++ @@ -474,7 +489,7 @@ func TestMillion(t *testing.T) { maxListSize = math.MaxInt32 key := x.DataKey(x.GalaxyAttr("bal"), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) var commits int N := int(1e6) @@ -492,7 +507,7 @@ func TestMillion(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } commits++ @@ -512,7 +527,7 @@ func TestMillion(t *testing.T) { func TestAddMutation_mrjn2(t *testing.T) { ctx := 
context.Background() key := x.DataKey(x.GalaxyAttr("bal"), 1001) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) var readTs uint64 for readTs = 1; readTs < 10; readTs++ { @@ -587,7 +602,7 @@ func TestAddMutation_mrjn2(t *testing.T) { func TestAddMutation_gru(t *testing.T) { key := x.DataKey(x.GalaxyAttr("question.tag"), 0x01) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) { @@ -620,7 +635,7 @@ func TestAddMutation_gru(t *testing.T) { func TestAddMutation_gru2(t *testing.T) { key := x.DataKey(x.GalaxyAttr("question.tag"), 0x100) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) { @@ -667,7 +682,7 @@ func TestAddAndDelMutation(t *testing.T) { // Ensure each test uses unique key since we don't clear the postings // after each test key := x.DataKey(x.GalaxyAttr("dummy_key"), 0x927) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) { @@ -695,7 +710,7 @@ func TestAddAndDelMutation(t *testing.T) { func TestAfterUIDCount(t *testing.T) { key := x.DataKey(x.GalaxyAttr("value"), 22) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // Set value to cars and merge to BadgerDB. edge := &pb.DirectedEdge{} @@ -766,7 +781,7 @@ func TestAfterUIDCount(t *testing.T) { func TestAfterUIDCount2(t *testing.T) { key := x.DataKey(x.GalaxyAttr("value"), 23) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // Set value to cars and merge to BadgerDB. 
@@ -793,7 +808,7 @@ func TestAfterUIDCount2(t *testing.T) { func TestDelete(t *testing.T) { key := x.DataKey(x.GalaxyAttr("value"), 25) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // Set value to cars and merge to BadgerDB. @@ -815,7 +830,7 @@ func TestDelete(t *testing.T) { func TestAfterUIDCountWithCommit(t *testing.T) { key := x.DataKey(x.GalaxyAttr("value"), 26) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // Set value to cars and merge to BadgerDB. @@ -906,7 +921,7 @@ func createMultiPartList(t *testing.T, size int, addFacet bool) (*List, int) { maxListSize = 5000 key := x.DataKey(x.GalaxyAttr(uuid.New().String()), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) commits := 0 for i := 1; i <= size; i++ { @@ -926,7 +941,7 @@ func createMultiPartList(t *testing.T, size int, addFacet bool) (*List, int) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } commits++ @@ -938,7 +953,7 @@ func createMultiPartList(t *testing.T, size int, addFacet bool) (*List, int) { require.Equal(t, uint64(size+1), kv.Version) } require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) require.Nil(t, ol.plist.Pack) require.Equal(t, 0, len(ol.plist.Postings)) @@ -954,7 +969,7 @@ func createAndDeleteMultiPartList(t *testing.T, size int) (*List, int) { maxListSize = 10000 key := x.DataKey(x.GalaxyAttr(uuid.New().String()), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, 
math.MaxUint64) require.NoError(t, err) commits := 0 for i := 1; i <= size; i++ { @@ -969,7 +984,7 @@ func createAndDeleteMultiPartList(t *testing.T, size int) (*List, int) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } commits++ @@ -990,7 +1005,7 @@ func createAndDeleteMultiPartList(t *testing.T, size int) (*List, int) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } commits++ @@ -1002,7 +1017,7 @@ func createAndDeleteMultiPartList(t *testing.T, size int) (*List, int) { func TestLargePlistSplit(t *testing.T) { key := x.DataKey(uuid.New().String(), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) b := make([]byte, 30<<20) _, _ = rand.Read(b) @@ -1019,7 +1034,7 @@ func TestLargePlistSplit(t *testing.T) { _, err = ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) b = make([]byte, 10<<20) _, _ = rand.Read(b) @@ -1038,7 +1053,7 @@ func TestLargePlistSplit(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // require.Nil(t, ol.plist.Bitmap) require.Equal(t, 0, len(ol.plist.Postings)) @@ -1113,7 +1128,7 @@ func TestBinSplit(t *testing.T) { maxListSize = originalListSize }() key := x.DataKey(x.GalaxyAttr(uuid.New().String()), 1331) - ol, err := getNew(key, 
ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) for i := 1; i <= size; i++ { edge := &pb.DirectedEdge{ @@ -1131,7 +1146,7 @@ func TestBinSplit(t *testing.T) { require.Equal(t, uint64(size+1), kv.Version) } require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) require.Equal(t, 0, len(ol.plist.Splits)) require.Equal(t, size, len(ol.plist.Postings)) @@ -1245,7 +1260,7 @@ func TestMultiPartListWriteToDisk(t *testing.T) { require.Equal(t, len(kvs), len(originalList.plist.Splits)+1) require.NoError(t, writePostingListToDisk(kvs)) - newList, err := getNew(kvs[0].Key, ps, math.MaxUint64) + newList, err := readPostingListFromDisk(kvs[0].Key, ps, math.MaxUint64) require.NoError(t, err) opt := ListOptions{ReadTs: uint64(size) + 1} @@ -1294,7 +1309,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { // Add entries to the maps. 
key := x.DataKey(x.GalaxyAttr(uuid.New().String()), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) for i := 1; i <= size; i++ { edge := &pb.DirectedEdge{ @@ -1308,7 +1323,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } } @@ -1335,7 +1350,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } } @@ -1344,7 +1359,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) for _, kv := range kvs { require.Equal(t, baseStartTs+uint64(1+size/2), kv.Version) @@ -1372,7 +1387,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) } } @@ -1381,7 +1396,7 @@ func TestMultiPartListDeleteAndAdd(t *testing.T) { kvs, err = ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) // Verify all entries are once again in the list. 
@@ -1435,7 +1450,7 @@ func TestRecursiveSplits(t *testing.T) { // Create a list that should be split recursively. size := int(1e5) key := x.DataKey(x.GalaxyAttr(uuid.New().String()), 1331) - ol, err := getNew(key, ps, math.MaxUint64) + ol, err := readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) commits := 0 for i := 1; i <= size; i++ { @@ -1457,7 +1472,7 @@ func TestRecursiveSplits(t *testing.T) { kvs, err := ol.Rollup(nil, math.MaxUint64) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) - ol, err = getNew(key, ps, math.MaxUint64) + ol, err = readPostingListFromDisk(key, ps, math.MaxUint64) require.NoError(t, err) require.True(t, len(ol.plist.Splits) > 2) @@ -1496,7 +1511,7 @@ func TestMain(m *testing.M) { func BenchmarkAddMutations(b *testing.B) { key := x.DataKey(x.GalaxyAttr("name"), 1) - l, err := getNew(key, ps, math.MaxUint64) + l, err := readPostingListFromDisk(key, ps, math.MaxUint64) if err != nil { b.Error(err) } diff --git a/posting/lists.go b/posting/lists.go index e813ebfd303..b1abd569cf7 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -49,10 +49,12 @@ func Init(ps *badger.DB, cacheSize int64) { pstore = ps closer = z.NewCloser(1) go x.MonitorMemoryMetrics(closer) + // Initialize cache. if cacheSize == 0 { return } + var err error lCache, err = ristretto.NewCache(&ristretto.Config{ // Use 5% of cache memory for storing counters. @@ -61,11 +63,7 @@ func Init(ps *badger.DB, cacheSize int64) { BufferItems: 64, Metrics: true, Cost: func(val interface{}) int64 { - l, ok := val.(*List) - if !ok { - return int64(0) - } - return int64(l.DeepSize()) + return 0 }, }) x.Check(err) @@ -101,7 +99,8 @@ func GetNoStore(key []byte, readTs uint64) (rlist *List, err error) { type LocalCache struct { sync.RWMutex - startTs uint64 + startTs uint64 + commitTs uint64 // The keys for these maps is a string representation of the Badger key for the posting list. // deltas keep track of the updates made by txn. 
These must be kept around until written to disk @@ -169,6 +168,12 @@ func NoCache(startTs uint64) *LocalCache { return &LocalCache{startTs: startTs} } +func (lc *LocalCache) UpdateCommitTs(commitTs uint64) { + lc.Lock() + defer lc.Unlock() + lc.commitTs = commitTs +} + func (lc *LocalCache) Find(pred []byte, filter func([]byte) bool) (uint64, error) { txn := pstore.NewTransactionAt(lc.startTs, false) defer txn.Discard() @@ -237,8 +242,7 @@ func (lc *LocalCache) Find(pred []byte, filter func([]byte) bool) (uint64, error } if filter(vals.Value.([]byte)) { - result.Uids = append(result.Uids, pk.Uid) - break + return pk.Uid, nil } continue @@ -339,6 +343,8 @@ func (lc *LocalCache) UpdateDeltasAndDiscardLists() { } for key, pl := range lc.plists { + //pk, _ := x.Parse([]byte(key)) + //fmt.Printf("{TXN} Closing %v\n", pk) data := pl.getMutation(lc.startTs) if len(data) > 0 { lc.deltas[key] = data diff --git a/posting/mvcc.go b/posting/mvcc.go index d228ad10f0d..27cdc6fe7a9 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -19,8 +19,8 @@ package posting import ( "bytes" "encoding/hex" - "math" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -59,6 +59,18 @@ type incrRollupi struct { closer *z.Closer } +type CachePL struct { + count int + list *List + lastUpdate uint64 +} + +type GlobalCache struct { + sync.RWMutex + + items map[string]*CachePL +} + var ( // ErrTsTooOld is returned when a transaction is too old to be applied. 
ErrTsTooOld = errors.Errorf("Transaction is too old") @@ -72,6 +84,8 @@ var ( IncrRollup = &incrRollupi{ priorityKeys: make([]*pooledKeys, 2), } + + globalCache = &GlobalCache{items: make(map[string]*CachePL, 100)} ) func init() { @@ -109,7 +123,7 @@ func (ir *incrRollupi) rollUpKey(writer *TxnWriter, key []byte) error { } } - l, err := GetNoStore(key, math.MaxUint64) + l, err := GetNoStore(key, ts) if err != nil { return err } @@ -118,9 +132,18 @@ func (ir *incrRollupi) rollUpKey(writer *TxnWriter, key []byte) error { if err != nil { return err } - // Clear the list from the cache after a rollup. - RemoveCacheFor(key) + globalCache.Lock() + val, ok := globalCache.items[string(key)] + if ok { + val.list = nil + } + globalCache.Unlock() + // TODO Update cache with rolled up results + // If we do a rollup, we typically won't need to update the key in cache. + // The only caveat is that the key written by rollup would be written at +1 + // timestamp, hence bumping the latest TS for the key by 1. The cache should + // understand that. const N = uint64(1000) if glog.V(2) { if count := atomic.AddUint64(&ir.count, 1); count%N == 0 { @@ -172,8 +195,8 @@ func (ir *incrRollupi) Process(closer *z.Closer, getNewTs func(bool) uint64) { currTs := time.Now().Unix() for _, key := range *batch { hash := z.MemHash(key) - if elem := m[hash]; currTs-elem >= 10 { - // Key not present or Key present but last roll up was more than 10 sec ago. + if elem := m[hash]; currTs-elem >= 2 { + // Key not present or Key present but last roll up was more than 2 sec ago. // Add/Update map and rollup. m[hash] = currTs if err := ir.rollUpKey(writer, key); err != nil { @@ -235,7 +258,7 @@ func (txn *Txn) addConflictKey(conflictKey uint64) { } // FillContext updates the given transaction context with data from this transaction. 
-func (txn *Txn) FillContext(ctx *api.TxnContext, gid uint32) { +func (txn *Txn) FillContext(ctx *api.TxnContext, gid uint32, isErrored bool) { txn.Lock() ctx.StartTs = txn.StartTs @@ -249,7 +272,12 @@ func (txn *Txn) FillContext(ctx *api.TxnContext, gid uint32) { ctx.Keys = x.Unique(ctx.Keys) txn.Unlock() - txn.Update() + // If the trasnaction has errored out, we don't need to update it, as these values will never be read. + // Sometimes, the transaction might have failed due to timeout. If we let this trasnactino update, there + // could be deadlock with the running transaction. + if !isErrored { + txn.Update() + } txn.cache.fillPreds(ctx, gid) } @@ -315,30 +343,57 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { return nil } -// ResetCache will clear all the cached list. func ResetCache() { + globalCache.Lock() + globalCache.items = make(map[string]*CachePL) + globalCache.Unlock() lCache.Clear() } -// RemoveCacheFor will delete the list corresponding to the given key. -func RemoveCacheFor(key []byte) { - // TODO: investigate if this can be done by calling Set with a nil value. - lCache.Del(key) +func NewCachePL() *CachePL { + return &CachePL{ + count: 0, + list: nil, + lastUpdate: 0, + } } // RemoveCachedKeys will delete the cached list by this txn. -func (txn *Txn) RemoveCachedKeys() { +func (txn *Txn) UpdateCachedKeys(commitTs uint64) { if txn == nil || txn.cache == nil { return } - for key := range txn.cache.deltas { - lCache.Del(key) - } -} -func WaitForCache() { - // TODO Investigate if this is needed and why Jepsen tests fail with the cache enabled. 
- // lCache.Wait() + for key, delta := range txn.cache.deltas { + pk, _ := x.Parse([]byte(key)) + if !ShouldGoInCache(pk) { + continue + } + globalCache.Lock() + val, ok := globalCache.items[key] + if !ok { + val = NewCachePL() + val.lastUpdate = commitTs + globalCache.items[key] = val + } + if commitTs != 0 { + // TODO Delete this if the values are too old in an async thread + val.lastUpdate = commitTs + } + if !ok { + globalCache.Unlock() + continue + } + + val.count -= 1 + + if commitTs != 0 && val.list != nil { + p := new(pb.PostingList) + x.Check(p.Unmarshal(delta)) + val.list.setMutationAfterCommit(txn.StartTs, commitTs, delta) + } + globalCache.Unlock() + } } func unmarshalOrCopy(plist *pb.PostingList, item *badger.Item) error { @@ -457,33 +512,63 @@ func ReadPostingList(key []byte, it *badger.Iterator) (*List, error) { return l, nil } -func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { - cachedVal, ok := lCache.Get(key) - if ok { - l, ok := cachedVal.(*List) - if ok && l != nil { - // No need to clone the immutable layer or the key since mutations will not modify it. - lCopy := &List{ - minTs: l.minTs, - maxTs: l.maxTs, - key: key, - plist: l.plist, - } - l.RLock() - if l.mutationMap != nil { - lCopy.mutationMap = make(map[uint64]*pb.PostingList, len(l.mutationMap)) - for ts, pl := range l.mutationMap { - lCopy.mutationMap[ts] = proto.Clone(pl).(*pb.PostingList) - } - } - l.RUnlock() - return lCopy, nil - } +func copyList(l *List) *List { + l.AssertRLock() + // No need to clone the immutable layer or the key since mutations will not modify it. 
+ lCopy := &List{ + minTs: l.minTs, + maxTs: l.maxTs, + key: l.key, + plist: l.plist, } + lCopy.mutationMap = make(map[uint64]*pb.PostingList, len(l.mutationMap)) + for k, v := range l.mutationMap { + lCopy.mutationMap[k] = proto.Clone(v).(*pb.PostingList) + } + return lCopy +} +func (c *CachePL) Set(l *List, readTs uint64) { + if c.lastUpdate < readTs && (c.list == nil || c.list.maxTs < l.maxTs) { + c.list = l + } +} + +func ShouldGoInCache(pk x.ParsedKey) bool { + return (!pk.IsData() && strings.HasSuffix(pk.Attr, "dgraph.type")) +} + +func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { if pstore.IsClosed() { return nil, badger.ErrDBClosed } + + pk, _ := x.Parse(key) + + if ShouldGoInCache(pk) { + globalCache.Lock() + cacheItem, ok := globalCache.items[string(key)] + if !ok { + cacheItem = NewCachePL() + globalCache.items[string(key)] = cacheItem + } + cacheItem.count += 1 + + // We use badger subscription to invalidate the cache. For every write we make the value + // corresponding to the key in the cache to nil. So, if we get some non-nil value from the cache + // then it means that no writes have happened after the last set of this key in the cache. + if ok { + if cacheItem.list != nil && cacheItem.list.minTs <= readTs { + cacheItem.list.RLock() + lCopy := copyList(cacheItem.list) + cacheItem.list.RUnlock() + globalCache.Unlock() + return lCopy, nil + } + } + globalCache.Unlock() + } + txn := pstore.NewTransactionAt(readTs, false) defer txn.Discard() @@ -499,6 +584,26 @@ func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { if err != nil { return l, err } - lCache.Set(key, l, 0) + + // Only set l to the cache if readTs >= latestTs, which implies that l is + // the latest version of the PL. We also check that we're reading a version + // from Badger, which is higher than the write registered by the cache. 
+ if ShouldGoInCache(pk) { + globalCache.Lock() + l.RLock() + cacheItem, ok := globalCache.items[string(key)] + if !ok { + cacheItemNew := NewCachePL() + cacheItemNew.count = 1 + cacheItemNew.list = copyList(l) + cacheItemNew.lastUpdate = l.maxTs + globalCache.items[string(key)] = cacheItemNew + } else { + cacheItem.Set(copyList(l), readTs) + } + l.RUnlock() + globalCache.Unlock() + } + return l, nil } diff --git a/posting/mvcc_test.go b/posting/mvcc_test.go index ad0ffa9564b..9061d61d359 100644 --- a/posting/mvcc_test.go +++ b/posting/mvcc_test.go @@ -61,6 +61,42 @@ func TestIncrRollupGetsCancelledQuickly(t *testing.T) { } } +func TestCacheAfterDeltaUpdateRecieved(t *testing.T) { + attr := x.GalaxyAttr("cache") + key := x.IndexKey(attr, "temp") + + // Create a delta from 5->15. Mimick how a follower recieves a delta. + p := new(pb.PostingList) + p.Postings = []*pb.Posting{{ + Uid: 1, + StartTs: 5, + CommitTs: 15, + Op: 1, + }} + delta, err := p.Marshal() + require.NoError(t, err) + + // Write delta to disk and call update + txn := Oracle().RegisterStartTs(5) + txn.cache.deltas[string(key)] = delta + + writer := NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, 15)) + require.NoError(t, writer.Flush()) + + txn.UpdateCachedKeys(15) + + // Read key at timestamp 10. Make sure cache is not updated by this, as there is a later read. 
+ l, err := GetNoStore(key, 10) + require.NoError(t, err) + require.Equal(t, len(l.mutationMap), 0) + + // Read at 20 should show the value + l1, err := GetNoStore(key, 20) + require.NoError(t, err) + require.Equal(t, len(l1.mutationMap), 1) +} + func TestRollupTimestamp(t *testing.T) { attr := x.GalaxyAttr("rollup") key := x.DataKey(attr, 1) @@ -79,8 +115,9 @@ func TestRollupTimestamp(t *testing.T) { edge := &pb.DirectedEdge{ Entity: 1, Attr: attr, - Value: []byte(x.Star), - Op: pb.DirectedEdge_DEL, + + Value: []byte(x.Star), + Op: pb.DirectedEdge_DEL, } addMutation(t, l, edge, Del, 9, 10, false) diff --git a/posting/oracle.go b/posting/oracle.go index dcc310ab8f7..57cf6f317b9 100644 --- a/posting/oracle.go +++ b/posting/oracle.go @@ -290,8 +290,14 @@ func (o *oracle) ProcessDelta(delta *pb.OracleDelta) { o.Lock() defer o.Unlock() - for _, txn := range delta.Txns { - delete(o.pendingTxns, txn.StartTs) + for _, status := range delta.Txns { + txn := o.pendingTxns[status.StartTs] + if txn != nil && status.CommitTs > 0 { + for k := range txn.cache.deltas { + IncrRollup.addKeyToBatch([]byte(k), 0) + } + } + delete(o.pendingTxns, status.StartTs) } curMax := o.MaxAssigned() if delta.MaxAssigned < curMax { diff --git a/query/cloud_test.go b/query/cloud_test.go index fae5ff42e17..b542f83a448 100644 --- a/query/cloud_test.go +++ b/query/cloud_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -37,10 +38,10 @@ func TestMain(m *testing.M) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - x.Panic(dg.LoginIntoNamespace(ctx, dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + x.Panic(dg.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) dc = c - client = dg.Dgraph + client.Dgraph = dg populateCluster(dc) m.Run() } diff --git a/query/common_test.go 
b/query/common_test.go index ea914bbb24f..aafcbd42b55 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -29,7 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/x" ) @@ -75,7 +75,7 @@ func processQuery(ctx context.Context, t *testing.T, query string) (string, erro return string(jsonResponse), err } -func processQueryRDF(ctx context.Context, t *testing.T, query string) (string, error) { +func processQueryRDF(ctx context.Context, query string) (string, error) { txn := client.NewTxn() defer func() { _ = txn.Discard(ctx) }() @@ -264,6 +264,14 @@ type Speaker { language } +type User { + name + password + gender + friend + alive +} + name : string @index(term, exact, trigram) @count @lang . name_lang : string @lang . lang_type : string @index(exact) . @@ -337,10 +345,9 @@ tweet-c : string @index(fulltext) . tweet-d : string @index(trigram) . name2 : string @index(term) . age2 : int @index(int) . -vectorNonIndex : float32vector . ` -func populateCluster(dc dgraphtest.Cluster) { +func populateCluster(dc dgraphapi.Cluster) { x.Panic(client.Alter(context.Background(), &api.Operation{DropAll: true})) // In the query package, we test using hard coded UIDs so that we know what results @@ -348,32 +355,30 @@ func populateCluster(dc dgraphtest.Cluster) { // all the UIDs we are using during the tests. 
x.Panic(dc.AssignUids(client.Dgraph, 65536)) - higher, err := dgraphtest.IsHigherVersion(dc.GetVersion(), "160a0faa5fc6233fdc5a4caa4a7a3d1591f460d0") - x.Panic(err) - var ts string - if higher { - ts = testSchema + `type User { - name - password - gender - friend - alive - user_profile - } - user_profile : float32vector @index(hnsw(metric:"euclidian")) .` - } else { - ts = testSchema + `type User { - name - password - gender - friend - alive - }` - } - setSchema(ts) - err = addTriplesToCluster(` - <1> "[1.0, 1.0, 2.0, 2.0]" . - <2> "[2.0, 1.0, 2.0, 2.0]" . + // higher, err := dgraphtest.IsHigherVersion(dc.GetVersion(), "160a0faa5fc6233fdc5a4caa4a7a3d1591f460d0") + // x.Panic(err) + // var ts string + // if higher { + // ts = testSchema + `type User { + // name + // password + // gender + // friend + // alive + // user_profile + // } + // user_profile : float32vector @index(hnsw(metric:"euclidian")) .` + // } else { + // ts = testSchema + `type User { + // name + // password + // gender + // friend + // alive + // }` + // } + setSchema(testSchema) + err := addTriplesToCluster(` <1> "Michonne" . <2> "King Lear" . <3> "Margaret" . 
diff --git a/query/integration_test.go b/query/integration_test.go index b05b2ae8bed..26a550dd842 100644 --- a/query/integration_test.go +++ b/query/integration_test.go @@ -22,6 +22,7 @@ import ( "context" "testing" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -34,8 +35,8 @@ func TestMain(m *testing.M) { client, cleanup, err = dc.Client() x.Panic(err) defer cleanup() - x.Panic(client.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + x.Panic(client.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) populateCluster(dc) m.Run() diff --git a/query/query0_test.go b/query/query0_test.go index 80739311d7f..43321336b2a 100644 --- a/query/query0_test.go +++ b/query/query0_test.go @@ -27,46 +27,11 @@ import ( "github.com/stretchr/testify/require" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/dql" ) -func TestGetVector(t *testing.T) { - query := ` - { - me(func: has(vectorNonIndex)) { - a as vectorNonIndex - } - aggregation() { - avg(val(a)) - sum(val(a)) - } - } - ` - js := processQueryNoErr(t, query) - k := `{ - "data": { - "me": [ - { - "vectorNonIndex": [1,1,2,2] - }, - { - "vectorNonIndex": [2,1,2,2] - } - ], - "aggregation": [ - { - "avg(val(a))": [1.5,1,2,2] - }, - { - "sum(val(a))": [3,2,4,4] - } - ] - } -}` - require.JSONEq(t, k, js) -} - func TestGetUID(t *testing.T) { query := ` { @@ -3657,5 +3622,5 @@ func TestInvalidRegex(t *testing.T) { } } -var client *dgraphtest.GrpcClient -var dc dgraphtest.Cluster +var client *dgraphapi.GrpcClient +var dc dgraphapi.Cluster diff --git a/query/query3_test.go b/query/query3_test.go index 50dc330b527..0f60f1cfcda 100644 --- a/query/query3_test.go +++ b/query/query3_test.go @@ -29,7 +29,7 @@ import ( "github.com/stretchr/testify/require" 
"google.golang.org/grpc/metadata" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/testutil" ) @@ -2346,9 +2346,9 @@ func TestPasswordExpandAll1(t *testing.T) { ` js := processQueryNoErr(t, query) // During upgrade tests, UIDs of groot and guardians nodes might change. - a := dgraphtest.CompareJSON(`{"data":{"me":[{"alive":true, + a := dgraphapi.CompareJSON(`{"data":{"me":[{"alive":true, "gender":"female", "name":"Michonne"}]}}`, js) - b := dgraphtest.CompareJSON(`{"data":{"me":[{"alive":true, "dgraph.xid":"guardians", + b := dgraphapi.CompareJSON(`{"data":{"me":[{"alive":true, "dgraph.xid":"guardians", "gender":"female","name":"Michonne"}]}}`, js) if a != nil && b != nil { t.Error(a) @@ -2366,9 +2366,9 @@ func TestPasswordExpandAll2(t *testing.T) { ` js := processQueryNoErr(t, query) // During upgrade tests, UIDs of groot and guardians nodes might change. - a := dgraphtest.CompareJSON(`{"data":{"me":[{"alive":true, "checkpwd(password)":false, + a := dgraphapi.CompareJSON(`{"data":{"me":[{"alive":true, "checkpwd(password)":false, "gender":"female", "name":"Michonne"}]}}`, js) - b := dgraphtest.CompareJSON(`{"data":{"me":[{"alive":true, "dgraph.xid":"guardians", + b := dgraphapi.CompareJSON(`{"data":{"me":[{"alive":true, "dgraph.xid":"guardians", "checkpwd(password)":false, "gender":"female", "name":"Michonne"}]}}`, js) if a != nil && b != nil { t.Error(a) @@ -3213,8 +3213,8 @@ func TestMultiRegexInFilter(t *testing.T) { ` res := processQueryNoErr(t, query) // During upgrade tests, UIDs of groot and guardians nodes might change. 
- a := dgraphtest.CompareJSON(`{"data": {"q": [{"alive":true, "gender":"female","name":"Michonne"}]}}`, res) - b := dgraphtest.CompareJSON(`{"data": {"q": [{"alive":true, "dgraph.xid":"guardians", + a := dgraphapi.CompareJSON(`{"data": {"q": [{"alive":true, "gender":"female","name":"Michonne"}]}}`, res) + b := dgraphapi.CompareJSON(`{"data": {"q": [{"alive":true, "dgraph.xid":"guardians", "gender":"female","name":"Michonne"}]}}`, res) if a != nil && b != nil { t.Error(a) diff --git a/query/query4_test.go b/query/query4_test.go index 8cae7c298f2..5fa63608453 100644 --- a/query/query4_test.go +++ b/query/query4_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/x" ) @@ -638,9 +638,9 @@ func TestFilterAtSameLevelOnUIDWithExpand(t *testing.T) { }` js := processQueryNoErr(t, query) // Because the UID for guardians and groot can change while upgrade tests are running - a := dgraphtest.CompareJSON(`{"data":{"q":[{"name":"Michonne","gender":"female","alive":true, + a := dgraphapi.CompareJSON(`{"data":{"q":[{"name":"Michonne","gender":"female","alive":true, "friend":[{"gender":"male","alive":true,"name":"Rick Grimes"}]}]}}`, js) - b := dgraphtest.CompareJSON(`{"data":{"q":[{"name":"Michonne","gender":"female","alive":true, + b := dgraphapi.CompareJSON(`{"data":{"q":[{"name":"Michonne","gender":"female","alive":true, "dgraph.xid":"guardians","friend":[{"gender":"male","alive":true,"name":"Rick Grimes"}]}]}}`, js) if a != nil && b != nil { t.Error(a) diff --git a/query/rdf_result_test.go b/query/rdf_result_test.go index c9e6b9fc0c8..d7ac107c158 100644 --- a/query/rdf_result_test.go +++ b/query/rdf_result_test.go @@ -36,7 +36,7 @@ func TestRDFResult(t *testing.T) { } }` - rdf, err := processQueryRDF(context.Background(), t, query) + rdf, err := processQueryRDF(context.Background(), query) 
require.NoError(t, err) require.Equal(t, rdf, `<0x1> "Michonne" . <0x1> <0x17> . @@ -69,7 +69,7 @@ func TestRDFNormalize(t *testing.T) { } } }` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Error(t, err, "normalize directive is not supported in the rdf output format") } @@ -80,7 +80,7 @@ func TestRDFGroupBy(t *testing.T) { count(uid) } }` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Contains(t, err.Error(), "groupby is not supported in rdf output format") } @@ -91,7 +91,7 @@ func TestRDFUidCount(t *testing.T) { count(uid) } }` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Contains(t, err.Error(), "uid count is not supported in the rdf output format") } @@ -108,7 +108,7 @@ func TestRDFIngoreReflex(t *testing.T) { } } }` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Contains(t, err.Error(), "ignorereflex directive is not supported in the rdf output format") } @@ -121,7 +121,7 @@ func TestRDFRecurse(t *testing.T) { friend } }` - rdf, err := processQueryRDF(context.Background(), t, query) + rdf, err := processQueryRDF(context.Background(), query) require.NoError(t, err) require.Equal(t, rdf, `<0x1> "Michonne" . <0x17> "Rick Grimes" . @@ -137,7 +137,7 @@ func TestRDFIgnoreUid(t *testing.T) { name } }` - rdf, err := processQueryRDF(context.Background(), t, query) + rdf, err := processQueryRDF(context.Background(), query) require.NoError(t, err) require.Equal(t, rdf, `<0x1> "Michonne" . <0x17> "Rick Grimes" . 
@@ -154,7 +154,7 @@ func TestRDFCheckPwd(t *testing.T) { } } ` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Contains(t, err.Error(), "chkpwd function is not supported in the rdf output format") } @@ -172,7 +172,7 @@ func TestRDFPredicateCount(t *testing.T) { } ` - rdf, err := processQueryRDF(context.Background(), t, query) + rdf, err := processQueryRDF(context.Background(), query) require.NoError(t, err) require.Equal(t, `<0x1> "Michonne" . <0x17> "Rick Grimes" . @@ -201,7 +201,7 @@ func TestRDFFacets(t *testing.T) { path @facets(weight) } }` - _, err := processQueryRDF(context.Background(), t, query) + _, err := processQueryRDF(context.Background(), query) require.Contains(t, err.Error(), "facets are not supported in the rdf output format") } @@ -219,7 +219,7 @@ func TestDateRDF(t *testing.T) { } } ` - rdf, err := processQueryRDF(context.Background(), t, query) + rdf, err := processQueryRDF(context.Background(), query) require.NoError(t, err) expected := `<0x1> "Michonne" . <0x1> "female" . 
diff --git a/query/upgrade_test.go b/query/upgrade_test.go index 69dd62e4c44..996a9f6332b 100644 --- a/query/upgrade_test.go +++ b/query/upgrade_test.go @@ -25,29 +25,30 @@ import ( "testing" "time" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) func TestMain(m *testing.M) { - mutate := func(c dgraphtest.Cluster) { + mutate := func(c dgraphapi.Cluster) { dg, cleanup, err := c.Client() x.Panic(err) defer cleanup() - x.Panic(dg.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + x.Panic(dg.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) client = dg dc = c populateCluster(dc) } - query := func(c dgraphtest.Cluster) int { + query := func(c dgraphapi.Cluster) int { dg, cleanup, err := c.Client() x.Panic(err) defer cleanup() - x.Panic(dg.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + x.Panic(dg.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) client = dg dc = c @@ -65,7 +66,7 @@ func TestMain(m *testing.M) { hc, err := c.HTTPClient() x.Panic(err) - x.Panic(hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + x.Panic(hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) mutate(c) x.Panic(c.Upgrade(uc.After, uc.Strategy)) diff --git a/query/vector/integration_test.go b/query/vector/integration_test.go new file mode 100644 index 00000000000..7cb11b1b231 --- /dev/null +++ b/query/vector/integration_test.go @@ -0,0 +1,42 @@ +//go:build integration + +/* + * Copyright 2023 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package query + +import ( + "context" + "testing" + + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" +) + +func TestMain(m *testing.M) { + dc = dgraphtest.NewComposeCluster() + + var err error + var cleanup func() + client, cleanup, err = dc.Client() + x.Panic(err) + defer cleanup() + x.Panic(client.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + m.Run() +} diff --git a/query/vector/vector_graphql_test.go b/query/vector/vector_graphql_test.go new file mode 100644 index 00000000000..ca061a9c781 --- /dev/null +++ b/query/vector/vector_graphql_test.go @@ -0,0 +1,258 @@ +//go:build integration + +/* + * Copyright 2016-2023 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package query + +import ( + "encoding/json" + "fmt" + "math/rand" + "testing" + + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/stretchr/testify/require" +) + +type ProjectInput struct { + Title string `json:"title"` + TitleV []float32 `json:"title_v"` +} + +const ( + graphQLVectorSchema = ` + type Project { + id: ID! + title: String! @search(by: [exact]) + title_v: [Float!] @embedding @search(by: ["hnsw(metric: %v, exponent: 4)"]) + } ` +) + +func generateProjects(count int) []ProjectInput { + var projects []ProjectInput + for i := 0; i < count; i++ { + title := generateUniqueRandomTitle(projects) + titleV := generateRandomTitleV(5) // Assuming size is fixed at 5 + project := ProjectInput{ + Title: title, + TitleV: titleV, + } + projects = append(projects, project) + } + return projects +} + +func isTitleExists(title string, existingTitles []ProjectInput) bool { + for _, project := range existingTitles { + if project.Title == title { + return true + } + } + return false +} + +func generateUniqueRandomTitle(existingTitles []ProjectInput) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + const titleLength = 10 + title := make([]byte, titleLength) + for { + for i := range title { + title[i] = charset[rand.Intn(len(charset))] + } + titleStr := string(title) + if !isTitleExists(titleStr, existingTitles) { + return titleStr + } + } +} + +func generateRandomTitleV(size int) []float32 { + var titleV []float32 + for i := 0; i < size; i++ { + value := rand.Float32() + titleV = append(titleV, value) + } + return titleV +} + +func addProject(t *testing.T, hc *dgraphapi.HTTPClient, project ProjectInput) { + query := ` + mutation addProject($project: AddProjectInput!) 
{ + addProject(input: [$project]) { + project { + title + title_v + } + } + }` + + params := dgraphapi.GraphQLParams{ + Query: query, + Variables: map[string]interface{}{"project": project}, + } + + _, err := hc.RunGraphqlQuery(params, false) + require.NoError(t, err) +} + +func queryProjectUsingTitle(t *testing.T, hc *dgraphapi.HTTPClient, title string) ProjectInput { + query := ` query QueryProject($title: String!) { + queryProject(filter: { title: { eq: $title } }) { + title + title_v + } + }` + + params := dgraphapi.GraphQLParams{ + Query: query, + Variables: map[string]interface{}{"title": title}, + } + response, err := hc.RunGraphqlQuery(params, false) + require.NoError(t, err) + type QueryResult struct { + QueryProject []ProjectInput `json:"queryProject"` + } + var resp QueryResult + err = json.Unmarshal([]byte(string(response)), &resp) + require.NoError(t, err) + + return resp.QueryProject[0] +} + +func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphapi.HTTPClient, vector []float32, topk int) []ProjectInput { + // query similar project by embedding + queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) 
{ + querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) { + title + title_v + } + } + + ` + + params := dgraphapi.GraphQLParams{ + Query: queryProduct, + Variables: map[string]interface{}{ + "by": "title_v", + "topK": topk, + "vector": vector, + }} + response, err := hc.RunGraphqlQuery(params, false) + require.NoError(t, err) + type QueryResult struct { + QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"` + } + var resp QueryResult + err = json.Unmarshal([]byte(string(response)), &resp) + require.NoError(t, err) + + return resp.QueryProject +} + +func TestVectorGraphQLAddVectorPredicate(t *testing.T) { + require.NoError(t, client.DropAll()) + + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + // add schema + require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) +} + +func TestVectorSchema(t *testing.T) { + require.NoError(t, client.DropAll()) + + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := `type Project { + id: ID! + title: String! @search(by: [exact]) + title_v: [Float!] 
+ }` + + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) +} + +func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "euclidean") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) +} + +func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "cosine") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) +} + +func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) +} + +func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphapi.HTTPClient) { + var vectors [][]float32 + numProjects := 100 + projects := generateProjects(numProjects) + for _, project := range projects { + vectors = append(vectors, project.TitleV) + addProject(t, hc, project) + } + + for _, project := range projects { + p := queryProjectUsingTitle(t, hc, project.Title) + require.Equal(t, project.Title, p.Title) + require.Equal(t, project.TitleV, p.TitleV) + } + + for _, project := range projects { + p := queryProjectUsingTitle(t, hc, project.Title) + require.Equal(t, project.Title, p.Title) + require.Equal(t, project.TitleV, p.TitleV) + } + + // query similar project by 
embedding + for _, project := range projects { + similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects) + for _, similarVec := range similarProjects { + require.Contains(t, vectors, similarVec.TitleV) + } + } +} diff --git a/query/vector_test.go b/query/vector/vector_test.go similarity index 72% rename from query/vector_test.go rename to query/vector/vector_test.go index a387b4d8479..eb5568d42c6 100644 --- a/query/vector_test.go +++ b/query/vector/vector_test.go @@ -25,8 +25,12 @@ import ( "math/rand" "strings" "testing" + "time" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" ) @@ -38,6 +42,124 @@ const ( vectorSchemaWithoutIndex = `%v: float32vector .` ) +var client *dgraphapi.GrpcClient +var dc dgraphapi.Cluster + +func setSchema(schema string) { + var err error + for retry := 0; retry < 60; retry++ { + err = client.Alter(context.Background(), &api.Operation{Schema: schema}) + if err == nil { + return + } + time.Sleep(time.Second) + } + panic(fmt.Sprintf("Could not alter schema. Got error %v", err.Error())) +} + +func dropPredicate(pred string) { + err := client.Alter(context.Background(), &api.Operation{ + DropAttr: pred, + }) + if err != nil { + panic(fmt.Sprintf("Could not drop predicate. 
Got error %v", err.Error())) + } +} + +func processQuery(ctx context.Context, t *testing.T, query string) (string, error) { + txn := client.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + t.Logf("error discarding txn: %v", err) + } + }() + + res, err := txn.Query(ctx, query) + if err != nil { + return "", err + } + + response := map[string]interface{}{} + response["data"] = json.RawMessage(string(res.Json)) + + jsonResponse, err := json.Marshal(response) + require.NoError(t, err) + return string(jsonResponse), err +} + +func processQueryRDF(ctx context.Context, t *testing.T, query string) (string, error) { + txn := client.NewTxn() + defer func() { _ = txn.Discard(ctx) }() + + res, err := txn.Do(ctx, &api.Request{ + Query: query, + RespFormat: api.Request_RDF, + }) + if err != nil { + return "", err + } + return string(res.Rdf), err +} + +func processQueryNoErr(t *testing.T, query string) string { + res, err := processQuery(context.Background(), t, query) + require.NoError(t, err) + return res +} + +// processQueryForMetrics works like processQuery but returns metrics instead of response. 
+func processQueryForMetrics(t *testing.T, query string) *api.Metrics { + txn := client.NewTxn() + defer func() { _ = txn.Discard(context.Background()) }() + + res, err := txn.Query(context.Background(), query) + require.NoError(t, err) + return res.Metrics +} + +func processQueryWithVars(t *testing.T, query string, + vars map[string]string) (string, error) { + txn := client.NewTxn() + defer func() { _ = txn.Discard(context.Background()) }() + + res, err := txn.QueryWithVars(context.Background(), query, vars) + if err != nil { + return "", err + } + + response := map[string]interface{}{} + response["data"] = json.RawMessage(string(res.Json)) + + jsonResponse, err := json.Marshal(response) + require.NoError(t, err) + return string(jsonResponse), err +} + +func addTriplesToCluster(triples string) error { + txn := client.NewTxn() + ctx := context.Background() + defer func() { _ = txn.Discard(ctx) }() + + _, err := txn.Mutate(ctx, &api.Mutation{ + SetNquads: []byte(triples), + CommitNow: true, + }) + return err +} + +func deleteTriplesInCluster(triples string) { + txn := client.NewTxn() + ctx := context.Background() + defer func() { _ = txn.Discard(ctx) }() + + _, err := txn.Mutate(ctx, &api.Mutation{ + DelNquads: []byte(triples), + CommitNow: true, + }) + if err != nil { + panic(fmt.Sprintf("Could not delete triples. 
Got error %v", err.Error())) + } +} func updateVector(t *testing.T, triple string, pred string) []float32 { uid := strings.Split(triple, " ")[0] randomVec := generateRandomVector(10) @@ -148,6 +270,10 @@ func querySingleVectorError(t *testing.T, vector, pred string, validateError boo return []float32{}, err } + if len(data.Vector) == 0 { + return []float32{}, nil + } + return data.Vector[0].VTest, nil } @@ -303,24 +429,66 @@ func TestVectorsMutateFixedLengthWithDiffrentIndexes(t *testing.T) { testVectorMutationSameLength(t) dropPredicate("vtest") - setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dotproduct")) testVectorMutationSameLength(t) dropPredicate("vtest") } +func TestVectorDeadlockwithTimeout(t *testing.T) { + pred := "vtest1" + dc = dgraphtest.NewComposeCluster() + var cleanup func() + client, cleanup, err := dc.Client() + x.Panic(err) + defer cleanup() + + for i := 0; i < 5; i++ { + fmt.Println("Testing iteration: ", i) + ctx, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = client.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace) + require.NoError(t, err) + + err = client.Alter(context.Background(), &api.Operation{ + DropAttr: pred, + }) + dropPredicate(pred) + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidian")) + numVectors := 10000 + vectorSize := 1000 + + randomVectors, _ := generateRandomVectors(numVectors, vectorSize, pred) + + txn := client.NewTxn() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer func() { _ = txn.Discard(ctx) }() + defer cancel() + + _, err = txn.Mutate(ctx, &api.Mutation{ + SetNquads: []byte(randomVectors), + CommitNow: true, + }) + require.Error(t, err) + + err = txn.Commit(ctx) + require.Contains(t, err.Error(), "Transaction has already been committed or discarded") + } +} + func 
TestVectorMutateDiffrentLengthWithDiffrentIndexes(t *testing.T) { dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidian")) - testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + testVectorMutationDiffrentLength(t, "can not compute euclidian distance on vectors of different lengths") dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "cosine")) - testVectorMutationDiffrentLength(t, "can not compute dot product on vectors of different lengths") + testVectorMutationDiffrentLength(t, "can not compute cosine distance on vectors of different lengths") dropPredicate("vtest") - setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dot_product")) - testVectorMutationDiffrentLength(t, "can not subtract vectors of different lengths") + setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "dotproduct")) + testVectorMutationDiffrentLength(t, "can not compute dot product on vectors of different lengths") dropPredicate("vtest") } @@ -449,8 +617,9 @@ func TestVectorDelete(t *testing.T) { } triple := deleteTriple(len(triples) - 2) + // after deleteing all vectors, we should get an empty array of vectors in response when we do silimar_to query _, err = querySingleVectorError(t, strings.Split(triple, `"`)[1], "vtest", false) - require.NotNil(t, err) + require.NoError(t, err) } func TestVectorUpdate(t *testing.T) { @@ -559,3 +728,66 @@ func TestVectorTwoTxnWithoutCommit(t *testing.T) { require.Contains(t, resp, vectors[i]) } } + +func TestGetVector(t *testing.T) { + setSchema("vectorNonIndex : float32vector .") + + rdfs := ` + <1> "[1.0, 1.0, 2.0, 2.0]" . 
+ <2> "[2.0, 1.0, 2.0, 2.0]" .` + require.NoError(t, addTriplesToCluster(rdfs)) + + query := ` + { + me(func: has(vectorNonIndex)) { + a as vectorNonIndex + } + aggregation() { + avg(val(a)) + sum(val(a)) + } + } + ` + js := processQueryNoErr(t, query) + k := `{ + "data": { + "me": [ + { + "vectorNonIndex": [ + 1, + 1, + 2, + 2 + ] + }, + { + "vectorNonIndex": [ + 2, + 1, + 2, + 2 + ] + } + ], + "aggregation": [ + { + "avg(val(a))": [ + 1.5, + 1, + 2, + 2 + ] + }, + { + "sum(val(a))": [ + 3, + 2, + 4, + 4 + ] + } + ] + } + }` + require.JSONEq(t, k, js) +} diff --git a/schema/schema.go b/schema/schema.go index 534264b21a0..f3275286082 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -21,6 +21,7 @@ import ( "encoding/hex" "fmt" "math" + "reflect" "sync" "github.com/golang/glog" @@ -50,7 +51,8 @@ var ( type contextKey int const ( - IsWrite contextKey = iota + IsWrite contextKey = iota + IsUniqueDgraphXid = true ) // GetWriteContext returns a context that sets the schema context for writing. @@ -791,6 +793,7 @@ func initialSchemaInternal(namespace uint64, all bool) []*pb.SchemaUpdate { ValueType: pb.Posting_STRING, Directive: pb.SchemaUpdate_INDEX, Upsert: true, + Unique: IsUniqueDgraphXid, Tokenizer: []string{"exact"}, }, { @@ -827,25 +830,60 @@ func initialSchemaInternal(namespace uint64, all bool) []*pb.SchemaUpdate { return initialSchema } -// IsPreDefPredChanged returns true if the initial update for the pre-defined -// predicate is different than the passed update. -// If the passed update is not a pre-defined predicate then it just returns false. -func IsPreDefPredChanged(update *pb.SchemaUpdate) bool { +// CheckAndModifyPreDefPredicate returns true if the initial update for the pre-defined +// predicate is different from the passed update. It may also modify certain predicates +// under specific conditions. +// If the passed update is not a pre-defined predicate, it returns false. 
+func CheckAndModifyPreDefPredicate(update *pb.SchemaUpdate) bool { // Return false for non-pre-defined predicates. if !x.IsPreDefinedPredicate(update.Predicate) { return false } - - initialSchema := CompleteInitialSchema(x.ParseNamespace(update.Predicate)) + ns := x.ParseNamespace(update.Predicate) + initialSchema := CompleteInitialSchema(ns) for _, original := range initialSchema { if original.Predicate != update.Predicate { continue } + + // For the dgraph.xid predicate, only the Unique field is allowed to be changed. + // Previously, the Unique attribute was not applied to the dgraph.xid predicate. + // For users upgrading from a lower version, we will set Unique to true. + if update.Predicate == x.NamespaceAttr(ns, "dgraph.xid") && !update.Unique { + if isDgraphXidChangeValid(original, update) { + update.Unique = true + return false + } + } return !proto.Equal(original, update) } return true } +// isDgraphXidChangeValid returns true if the change in the dgraph.xid predicate is valid. +func isDgraphXidChangeValid(original, update *pb.SchemaUpdate) bool { + changed := compareSchemaUpdates(original, update) + return len(changed) == 1 && changed[0] == "Unique" +} + +func compareSchemaUpdates(original, update *pb.SchemaUpdate) []string { + var changes []string + vOriginal := reflect.ValueOf(*original) + vUpdate := reflect.ValueOf(*update) + + for i := 0; i < vOriginal.NumField(); i++ { + fieldName := vOriginal.Type().Field(i).Name + valueOriginal := vOriginal.Field(i) + valueUpdate := vUpdate.Field(i) + + if !reflect.DeepEqual(valueOriginal.Interface(), valueUpdate.Interface()) { + changes = append(changes, fieldName) + } + } + + return changes +} + // IsPreDefTypeChanged returns true if the initial update for the pre-defined // type is different than the passed update. // If the passed update is not a pre-defined type than it just returns false. 
diff --git a/systest/backup/common/utils.go b/systest/backup/common/utils.go index cab1b54004c..dc60fd2b81f 100644 --- a/systest/backup/common/utils.go +++ b/systest/backup/common/utils.go @@ -108,7 +108,7 @@ func AddItemSchema(t *testing.T, header http.Header, whichAlpha string) { updateSchemaParams := &common.GraphQLParams{ Query: `mutation { updateGQLSchema( - input: { set: { schema: "type Item {id: ID!, name: String! @search(by: [hash]), price: String!}"}}) + input: { set: { schema: "type Item {id: ID!, name: String! @search(by: [\"hash\"]), price: String!}"}}) { gqlSchema { schema diff --git a/systest/backup/encryption/backup_test.go b/systest/backup/encryption/backup_test.go index 17bd5ddfbc5..85136bd62fa 100644 --- a/systest/backup/encryption/backup_test.go +++ b/systest/backup/encryption/backup_test.go @@ -52,6 +52,7 @@ var ( ) func TestBackupMinioE(t *testing.T) { + t.Skip() backupDst = "minio://minio:9001/dgraph-backup?secure=false" addr := testutil.ContainerAddr("minio", 9001) localBackupDst = "minio://" + addr + "/dgraph-backup?secure=false" diff --git a/systest/integration2/acl_test.go b/systest/integration2/acl_test.go new file mode 100644 index 00000000000..3a149fe93f8 --- /dev/null +++ b/systest/integration2/acl_test.go @@ -0,0 +1,115 @@ +//go:build integration2 + +/* + * Copyright 2024 Dgraph Labs, Inc. and Contributors * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package main + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" + + "github.com/stretchr/testify/require" +) + +type S struct { + Predicate string `json:"predicate"` + Type string `json:"type"` + Index bool `json:"index"` + Tokenizer []string `json:"tokenizer"` + Unique bool `json:"unique"` +} + +type Received struct { + Schema []S `json:"schema"` +} + +func testDuplicateUserUpgradeStrat(t *testing.T, strat dgraphtest.UpgradeStrategy) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour).WithVersion("v23.0.1") + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, gc.SetupSchema(`name: string .`)) + + rdfs := ` + _:a "alice" . + _:b "bob" . + _:c "sagar" . 
+ _:d "ajay" .` + _, err = gc.Mutate(&api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}) + require.NoError(t, c.Upgrade("local", strat)) + gc, cleanup, err = c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err = c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + query := "schema {}" + resp, err := gc.Query(query) + require.NoError(t, err) + + var received Received + require.NoError(t, json.Unmarshal([]byte(resp.Json), &received)) + for _, s := range received.Schema { + if s.Predicate == "dgraph.xid" { + require.True(t, s.Unique) + } + } + + query = `{ + q(func: has(name)) { + count(uid) + } + }` + resp, err = gc.Query(query) + require.NoError(t, err) + require.Contains(t, string(resp.Json), `"count":4`) +} + +func TestDuplicateUserWithLiveLoader(t *testing.T) { + testDuplicateUserUpgradeStrat(t, dgraphtest.ExportImport) +} + +func TestDuplicateUserWithBackupRestore(t *testing.T) { + testDuplicateUserUpgradeStrat(t, dgraphtest.BackupRestore) +} + +func TestDuplicateUserWithInPlace(t *testing.T) { + testDuplicateUserUpgradeStrat(t, dgraphtest.InPlace) +} diff --git a/systest/integration2/bulk_loader_test.go b/systest/integration2/bulk_loader_test.go index 035d86ab60d..8257a3d38fc 100644 --- a/systest/integration2/bulk_loader_test.go +++ b/systest/integration2/bulk_loader_test.go @@ -23,7 +23,9 @@ import ( "testing" "time" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" ) @@ -99,10 +101,10 @@ func TestBulkLoaderNoDqlSchema(t *testing.T) { // run some queries and ensure everything looks good hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, 
hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) - params := dgraphtest.GraphQLParams{ + params := dgraphapi.GraphQLParams{ Query: `query { getMessage(uniqueId: 3) { content @@ -112,15 +114,15 @@ func TestBulkLoaderNoDqlSchema(t *testing.T) { } data, err := hc.RunGraphqlQuery(params, false) require.NoError(t, err) - dgraphtest.CompareJSON(`{ + dgraphapi.CompareJSON(`{ "getMessage": { "content": "DVTCTXCVYI", "author": "USYMVFJYXA" } }`, string(data)) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, 1)) - params = dgraphtest.GraphQLParams{ + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, 1)) + params = dgraphapi.GraphQLParams{ Query: `query { getTemplate(uniqueId: 2) { content @@ -129,7 +131,7 @@ func TestBulkLoaderNoDqlSchema(t *testing.T) { } data, err = hc.RunGraphqlQuery(params, false) require.NoError(t, err) - dgraphtest.CompareJSON(`{ + dgraphapi.CompareJSON(`{ "getTemplate": { "content": "t2" } diff --git a/systest/integration2/graphql_schema_auth_test.go b/systest/integration2/graphql_schema_auth_test.go index 7880e2c7d74..497acde6691 100644 --- a/systest/integration2/graphql_schema_auth_test.go +++ b/systest/integration2/graphql_schema_auth_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" @@ -37,8 +38,8 @@ func TestGraphqlSchema(t *testing.T) { hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) // DGRAPHCORE-329 //nolint:lll @@ 
-74,7 +75,7 @@ func TestGraphqlSchema(t *testing.T) { # Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud"]}` require.NoError(t, hc.UpdateGQLSchema(sch1)) - params := dgraphtest.GraphQLParams{ + params := dgraphapi.GraphQLParams{ Query: `query { queryCity(filter: { id: { eq: 0 } }) { name @@ -84,7 +85,7 @@ func TestGraphqlSchema(t *testing.T) { _, err = hc.RunGraphqlQuery(params, false) require.NoError(t, err) - params = dgraphtest.GraphQLParams{ + params = dgraphapi.GraphQLParams{ Query: `query { queryCity(filter: { id: "0" }) { name @@ -714,9 +715,9 @@ func TestGraphqlSchema(t *testing.T) { """ newPlaces: [Place] """ - Uncommon visited places. A place is uncommon if not visited for more than three times prior to the summary date. + Uncommon visited places. A place is Uncommon if not visited for more than three times prior to the summary date. """ - uncommonPlaces: [Place] + UncommonPlaces: [Place] """ The total number of places visited in the day. """ @@ -726,9 +727,9 @@ func TestGraphqlSchema(t *testing.T) { """ newPlaceCount: Int """ - The number of uncommon places visited in the day. + The number of Uncommon places visited in the day. """ - uncommonPlaceCount: Int + UncommonPlaceCount: Int """ Cummulative time (in seconds) stationary. """ @@ -738,7 +739,7 @@ func TestGraphqlSchema(t *testing.T) { """ durationInNewPlaces: Int """ - Cummulative time (in seconds) spent in uncommon places within the day. + Cummulative time (in seconds) spent in Uncommon places within the day. 
""" durationInUncommonPlaces: Int """ diff --git a/systest/integration2/incremental_restore_test.go b/systest/integration2/incremental_restore_test.go index 9099bd6a1db..4078809f6e0 100644 --- a/systest/integration2/incremental_restore_test.go +++ b/systest/integration2/incremental_restore_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) @@ -43,12 +44,12 @@ func TestIncrementalRestore(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) uids := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20} c.AssignUids(gc.Dgraph, uint64(len(uids))) @@ -69,7 +70,7 @@ func TestIncrementalRestore(t *testing.T) { incrFrom := i - 1 require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", incrFrom, i)) - require.NoError(t, dgraphtest.WaitForRestore(c)) + require.NoError(t, dgraphapi.WaitForRestore(c)) for j := 1; j <= i; j++ { resp, err := gc.Query(fmt.Sprintf(`{q(func: uid(%v)) {money}}`, j)) diff --git a/systest/integration2/snapshot_test.go b/systest/integration2/snapshot_test.go index 67de8184731..2831adad9bc 100644 --- a/systest/integration2/snapshot_test.go +++ b/systest/integration2/snapshot_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" 
"github.com/stretchr/testify/require" @@ -42,13 +43,13 @@ func TestSnapshotTranferAfterNewNodeJoins(t *testing.T) { hc, err := c.HTTPClient() require.NoError(t, err) - hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) gc, cleanup, err := c.Client() require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) prevSnapshotTs, err := hc.GetCurrentSnapshotTs(1) require.NoError(t, err) diff --git a/systest/license/integration_test.go b/systest/license/integration_test.go index b8c1258e8f7..78c89093fc0 100644 --- a/systest/license/integration_test.go +++ b/systest/license/integration_test.go @@ -23,12 +23,13 @@ import ( "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" ) type LicenseTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster } func (lsuite *LicenseTestSuite) SetupTest() { diff --git a/systest/license/license_test.go b/systest/license/license_test.go index 739a08f551f..950282590c7 100644 --- a/systest/license/license_test.go +++ b/systest/license/license_test.go @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" ) var expiredKey = []byte(`-----BEGIN PGP MESSAGE----- @@ -138,20 +138,20 @@ func (lsuite *LicenseTestSuite) TestEnterpriseLicenseWithGraphqlEndPoint() { if tt.code == `Success` { require.NoError(t, err) // Check if the license is applied - dgraphtest.CompareJSON(`{"enterpriseLicense":{"response":{"code":"Success"}}}`, string(resp)) + dgraphapi.CompareJSON(`{"enterpriseLicense":{"response":{"code":"Success"}}}`, string(resp)) // check the user 
information in case the license is applied // Expired license should not be enabled even after it is applied assertLicenseNotEnabled(t, hcli, tt.user) } else { - dgraphtest.CompareJSON(`{"enterpriseLicense":null}`, string(resp)) + dgraphapi.CompareJSON(`{"enterpriseLicense":null}`, string(resp)) // check the error message in case the license is not applied require.Contains(t, err.Error(), tt.message) } } } -func assertLicenseNotEnabled(t *testing.T, hcli *dgraphtest.HTTPClient, user string) { +func assertLicenseNotEnabled(t *testing.T, hcli *dgraphapi.HTTPClient, user string) { response, err := hcli.GetZeroState() require.NoError(t, err) diff --git a/systest/license/upgrade_test.go b/systest/license/upgrade_test.go index 14822e176e5..d44b455697e 100644 --- a/systest/license/upgrade_test.go +++ b/systest/license/upgrade_test.go @@ -24,13 +24,14 @@ import ( "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) type LicenseTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster lc *dgraphtest.LocalCluster uc dgraphtest.UpgradeCombo } diff --git a/systest/multi-tenancy/basic_test.go b/systest/multi-tenancy/basic_test.go index d3d15189bd7..d51a025adb0 100644 --- a/systest/multi-tenancy/basic_test.go +++ b/systest/multi-tenancy/basic_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/ee/acl" "github.com/dgraph-io/dgraph/x" @@ -42,7 +43,7 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { // Galaxy Login hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) 
require.NotNil(t, hcli.AccessJwt, "galaxy token is nil") require.NoError(t, err, "login with namespace failed") @@ -56,7 +57,7 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) msuite.AddData(gcli) // Upgrade @@ -66,7 +67,7 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) query := `{ me(func: has(name)) { @@ -76,7 +77,7 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { }` resp, err := gcli.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON( + require.NoError(t, dgraphapi.CompareJSON( `{"me": [{"name":"guy1","nickname":"RG"},{"name": "guy2", "nickname":"RG2"}]}`, string(resp.Json))) // groot of namespace 0 should not see the data of namespace-1 @@ -84,15 +85,15 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) resp, err = gcli.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"me": []}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{"me": []}`, string(resp.Json))) // Login to namespace 1 via groot and create new user alice. 
hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns) require.NotNil(t, hcli.AccessJwt, "token for the namespace is nil") require.NoErrorf(t, err, "login with namespace %d failed", ns) _, err = hcli.CreateUser("alice", "newpassword") @@ -105,21 +106,21 @@ func (msuite *MultitenancyTestSuite) TestAclBasic() { require.NoError(t, gcli.LoginIntoNamespace(context.Background(), "alice", "newpassword", ns)) resp, err = gcli.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{}`, string(resp.Json))) // Create a new group, add alice to that group and give read access of to dev group. _, err = hcli.CreateGroup("dev") require.NoError(t, err) require.NoError(t, hcli.AddUserToGroup("alice", "dev")) require.NoError(t, hcli.AddRulesToGroup("dev", - []dgraphtest.AclRule{{Predicate: "name", Permission: acl.Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: "name", Permission: acl.Read.Code}}, true)) // Now alice should see the name predicate but not nickname. 
gcli, cleanup, err = msuite.dc.Client() defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), "alice", "newpassword", ns)) - dgraphtest.PollTillPassOrTimeout(gcli, query, `{"me": [{"name":"guy1"},{"name": "guy2"}]}`, aclQueryTimeout) + dgraphapi.PollTillPassOrTimeout(gcli, query, `{"me": [{"name":"guy1"},{"name": "guy2"}]}`, aclQueryTimeout) } func (msuite *MultitenancyTestSuite) TestNameSpaceLimitFlag() { @@ -128,7 +129,7 @@ func (msuite *MultitenancyTestSuite) TestNameSpaceLimitFlag() { // Galaxy login hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli.AccessJwt, "galaxy token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", x.GalaxyNamespace) @@ -144,7 +145,7 @@ func (msuite *MultitenancyTestSuite) TestNameSpaceLimitFlag() { defer cleanup() require.NoError(t, e) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) require.NoError(t, gcli.SetupSchema(`name: string .`)) // trying to load more triplets than allowed,It should return error. 
@@ -167,7 +168,7 @@ func (msuite *MultitenancyTestSuite) TestPersistentQuery() { // Galaxy Login hcli1, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli1.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli1.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli1.AccessJwt, "galaxy token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", x.GalaxyNamespace) @@ -181,20 +182,20 @@ func (msuite *MultitenancyTestSuite) TestPersistentQuery() { // Galaxy Login hcli1, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli1.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli1.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli1.AccessJwt, "galaxy token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", x.GalaxyNamespace) // Log into ns hcli2, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli2.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns) + err = hcli2.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns) require.NotNil(t, hcli2.AccessJwt, "token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", ns) sch := `type Product { productID: ID! 
- name: String @search(by: [term]) + name: String @search(by: ["term"]) }` require.NoError(t, hcli1.UpdateGQLSchema(sch)) require.NoError(t, hcli2.UpdateGQLSchema(sch)) @@ -218,7 +219,7 @@ func (msuite *MultitenancyTestSuite) TestPersistentQuery() { require.Error(t, err) require.Contains(t, err.Error(), "PersistedQueryNotFound") - hcli3 := &dgraphtest.HTTPClient{HttpToken: &dgraphtest.HttpToken{AccessJwt: ""}} + hcli3 := &dgraphapi.HTTPClient{HttpToken: &dgraphapi.HttpToken{AccessJwt: ""}} _, err = hcli3.PostPersistentQuery("", sha1) require.Error(t, err) require.Contains(t, err.Error(), "unsupported protocol scheme") @@ -230,7 +231,7 @@ func (msuite *MultitenancyTestSuite) TestTokenExpired() { // Galaxy Login hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli.HttpToken, "galaxy token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", x.GalaxyNamespace) @@ -244,7 +245,7 @@ func (msuite *MultitenancyTestSuite) TestTokenExpired() { // ns Login hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns) require.NotNil(t, hcli.HttpToken, "token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", ns) @@ -266,7 +267,7 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() // Galaxy Login ghcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = ghcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = ghcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, ghcli, "galaxy token is 
nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", x.GalaxyNamespace) @@ -284,7 +285,7 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() defer cleanup() require.NoError(t, e) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns1)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns1)) msuite.AddData(gcli) user1, user2 := "alice", "bob" @@ -293,7 +294,7 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() // Create user alice in ns1 hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns1) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns1) require.NoErrorf(t, err, "login as groot into namespace %d failed", ns1) _, err = hcli.CreateUser(user1, user1passwd) require.NoError(t, err) @@ -309,13 +310,13 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() gcli, cleanup, err = msuite.dc.Client() defer cleanup() require.NoError(t, err) - require.NoError(t, gcli.LoginIntoNamespace(context.Background(), dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns2)) + require.NoError(t, gcli.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns2)) msuite.AddData(gcli) // Create user bob hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns2) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns2) require.NoErrorf(t, err, "login with namespace %d failed", ns2) _, err = hcli.CreateUser(user2, user2passwd) require.NoError(t, err) @@ -331,14 +332,14 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() defer cleanup() require.NoError(t, err) require.NoError(t, 
gcli.LoginIntoNamespace(context.Background(), user1, user1passwd, ns1)) - dgraphtest.PollTillPassOrTimeout(gcli, query, `{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, aclQueryTimeout) + dgraphapi.PollTillPassOrTimeout(gcli, query, `{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, aclQueryTimeout) // Query via bob and check result gcli, cleanup, err = msuite.dc.Client() defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), user2, user2passwd, ns2)) - require.NoError(t, dgraphtest.PollTillPassOrTimeout(gcli, query, `{}`, aclQueryTimeout)) + require.NoError(t, dgraphapi.PollTillPassOrTimeout(gcli, query, `{}`, aclQueryTimeout)) // Query namespace-1 via alice and check result to ensure it still works gcli, cleanup, err = msuite.dc.Client() @@ -347,22 +348,22 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() require.NoError(t, gcli.LoginIntoNamespace(context.Background(), user1, user1passwd, ns1)) resp, err := gcli.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, string(resp.Json))) // Change permissions in namespace-2 hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns2) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns2) require.NoErrorf(t, err, "login as groot into namespace %d failed", ns2) require.NoError(t, hcli.AddRulesToGroup("dev", - []dgraphtest.AclRule{{Predicate: "name", Permission: acl.Read.Code}}, false)) + []dgraphapi.AclRule{{Predicate: "name", Permission: acl.Read.Code}}, false)) // Query namespace-2 gcli, cleanup, err = msuite.dc.Client() defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), user2, user2passwd, 
ns2)) - require.NoError(t, dgraphtest.PollTillPassOrTimeout(gcli, query, + require.NoError(t, dgraphapi.PollTillPassOrTimeout(gcli, query, `{"me": [{"name":"guy2", "nickname": "RG2"}, {"name":"guy1", "nickname": "RG"}]}`, aclQueryTimeout)) // Query namespace-1 @@ -372,7 +373,7 @@ func (msuite *MultitenancyTestSuite) TestTwoPermissionSetsInNameSpacesWithAcl() require.NoError(t, gcli.LoginIntoNamespace(context.Background(), user1, user1passwd, ns1)) resp, err = gcli.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{"me": [{"name":"guy2"}, {"name":"guy1"}]}`, string(resp.Json))) } func (msuite *MultitenancyTestSuite) TestCreateNamespace() { @@ -381,7 +382,7 @@ func (msuite *MultitenancyTestSuite) TestCreateNamespace() { // Galaxy Login hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli.AccessJwt, "Galaxy token is nil") require.NoErrorf(t, err, "login failed") @@ -395,7 +396,7 @@ func (msuite *MultitenancyTestSuite) TestCreateNamespace() { // Log into the namespace as groot hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns) require.NotNil(t, hcli.AccessJwt, "namespace token is nil") require.NoErrorf(t, err, "login with namespace %d failed", ns) @@ -411,7 +412,7 @@ func (msuite *MultitenancyTestSuite) TestResetPassword() { // Galaxy Login hcli1, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli1.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = 
hcli1.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli1.HttpToken, "Galaxy token is nil") require.NoErrorf(t, err, "login failed") @@ -420,7 +421,7 @@ func (msuite *MultitenancyTestSuite) TestResetPassword() { require.NoError(t, err) // Reset Password - _, err = hcli1.ResetPassword(dgraphtest.DefaultUser, "newpassword", ns) + _, err = hcli1.ResetPassword(dgraphapi.DefaultUser, "newpassword", ns) require.NoError(t, err) // Upgrade @@ -429,14 +430,14 @@ func (msuite *MultitenancyTestSuite) TestResetPassword() { // Try and Fail with old password for groot hcli2, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli2.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns) + err = hcli2.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns) require.Error(t, err, "expected error because incorrect login") require.Empty(t, hcli2.AccessJwt, "nil token because incorrect login") // Try and succeed with new password for groot hcli3, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli3.LoginIntoNamespace(dgraphtest.DefaultUser, "newpassword", ns) + err = hcli3.LoginIntoNamespace(dgraphapi.DefaultUser, "newpassword", ns) require.NoError(t, err, "login failed") require.Equal(t, hcli3.Password, "newpassword", "new password matches the reset password") } @@ -447,15 +448,15 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { // Galaxy Login hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NoErrorf(t, err, "login failed") - dg := make(map[uint64]*dgraphtest.GrpcClient) + dg := make(map[uint64]*dgraphapi.GrpcClient) gcli, cleanup, e := msuite.dc.Client() defer cleanup() require.NoError(t, e) require.NoError(t, 
gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) dg[x.GalaxyNamespace] = gcli // Create a new namespace @@ -467,7 +468,7 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) dg[ns] = gcli addData := func(ns uint64) error { @@ -490,7 +491,7 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { ` resp, err := dg[ns].Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(expected, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(expected, string(resp.Json))) } require.NoError(t, addData(x.GalaxyNamespace)) @@ -501,13 +502,13 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { // Upgrade msuite.Upgrade() - dg = make(map[uint64]*dgraphtest.GrpcClient) + dg = make(map[uint64]*dgraphapi.GrpcClient) gcli, cleanup, err = msuite.dc.Client() defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) dg[x.GalaxyNamespace] = gcli // Log into namespace as groot @@ -515,13 +516,13 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) dg[ns] = gcli // Galaxy Login hcli, err = msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = 
hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NoError(t, err, "login failed") // Delete namespace @@ -554,7 +555,7 @@ func (msuite *MultitenancyTestSuite) TestDeleteNamespace() { } } -func (msuite *MultitenancyTestSuite) AddData(gcli *dgraphtest.GrpcClient) { +func (msuite *MultitenancyTestSuite) AddData(gcli *dgraphapi.GrpcClient) { rdfs := ` _:a "guy1" . _:a "RG" . @@ -566,7 +567,7 @@ func (msuite *MultitenancyTestSuite) AddData(gcli *dgraphtest.GrpcClient) { require.NoError(msuite.T(), err) } -func AddNumberOfTriples(gcli *dgraphtest.GrpcClient, start, end int) (*api.Response, error) { +func AddNumberOfTriples(gcli *dgraphapi.GrpcClient, start, end int) (*api.Response, error) { triples := strings.Builder{} for i := start; i <= end; i++ { triples.WriteString(fmt.Sprintf("_:person%[1]v \"person%[1]v\" .\n", i)) @@ -579,12 +580,12 @@ func (msuite *MultitenancyTestSuite) createGroupAndSetPermissions(namespace uint t := msuite.T() hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - require.NoError(t, hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, namespace)) + require.NoError(t, hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, namespace)) require.NotNil(t, hcli.AccessJwt, "namespace token is nil") require.NoErrorf(t, err, "login as groot into namespace %d failed", namespace) _, err = hcli.CreateGroup(group) require.NoError(t, err) require.NoError(t, hcli.AddUserToGroup(user, group)) require.NoError(t, hcli.AddRulesToGroup(group, - []dgraphtest.AclRule{{Predicate: predicate, Permission: acl.Read.Code}}, true)) + []dgraphapi.AclRule{{Predicate: predicate, Permission: acl.Read.Code}}, true)) } diff --git a/systest/multi-tenancy/integration_basic_helper_test.go b/systest/multi-tenancy/integration_basic_helper_test.go index f847cf1638f..d93353c1cc0 100644 --- a/systest/multi-tenancy/integration_basic_helper_test.go +++ 
b/systest/multi-tenancy/integration_basic_helper_test.go @@ -27,7 +27,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/testutil" "github.com/dgraph-io/dgraph/x" ) @@ -67,11 +67,11 @@ func (msuite *MultitenancyTestSuite) TestLiveLoadMulti() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli0.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hcli, err := msuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NoError(t, err, "login failed") // Create a new namespace @@ -81,11 +81,11 @@ func (msuite *MultitenancyTestSuite) TestLiveLoadMulti() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli1.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) // Load data. - galaxyCreds := &testutil.LoginParams{UserID: dgraphtest.DefaultUser, - Passwd: dgraphtest.DefaultPassword, Namespace: x.GalaxyNamespace} + galaxyCreds := &testutil.LoginParams{UserID: dgraphapi.DefaultUser, + Passwd: dgraphapi.DefaultPassword, Namespace: x.GalaxyNamespace} require.NoError(t, msuite.liveLoadData(&liveOpts{ rdfs: fmt.Sprintf(` _:a "galaxy alice" . 
@@ -184,7 +184,7 @@ func (msuite *MultitenancyTestSuite) TestLiveLoadMulti() { err = msuite.liveLoadData(&liveOpts{ rdfs: `_:c "ns hola" .`, schema: `name: string @index(term) .`, - creds: &testutil.LoginParams{UserID: dgraphtest.DefaultUser, Passwd: dgraphtest.DefaultPassword, Namespace: ns}, + creds: &testutil.LoginParams{UserID: dgraphapi.DefaultUser, Passwd: dgraphapi.DefaultPassword, Namespace: ns}, forceNs: -1, }) require.Error(t, err) @@ -193,7 +193,7 @@ func (msuite *MultitenancyTestSuite) TestLiveLoadMulti() { err = msuite.liveLoadData(&liveOpts{ rdfs: `_:c "ns hola" .`, schema: `name: string @index(term) .`, - creds: &testutil.LoginParams{UserID: dgraphtest.DefaultUser, Passwd: dgraphtest.DefaultPassword, Namespace: ns}, + creds: &testutil.LoginParams{UserID: dgraphapi.DefaultUser, Passwd: dgraphapi.DefaultPassword, Namespace: ns}, forceNs: 10, }) require.Error(t, err) @@ -205,7 +205,7 @@ func (msuite *MultitenancyTestSuite) TestLiveLoadMulti() { _:b "ns gary" <%#x> . _:c "ns hola" <%#x> .`, ns, 0x100), schema: `name: string @index(term) .`, - creds: &testutil.LoginParams{UserID: dgraphtest.DefaultUser, Passwd: dgraphtest.DefaultPassword, Namespace: ns}, + creds: &testutil.LoginParams{UserID: dgraphapi.DefaultUser, Passwd: dgraphapi.DefaultPassword, Namespace: ns}, })) resp, err = gcli1.Query(query3) diff --git a/systest/multi-tenancy/integration_test.go b/systest/multi-tenancy/integration_test.go index 1038b11d47a..327bca87c14 100644 --- a/systest/multi-tenancy/integration_test.go +++ b/systest/multi-tenancy/integration_test.go @@ -26,13 +26,14 @@ import ( "github.com/stretchr/testify/suite" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) type MultitenancyTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster } func (msuite *MultitenancyTestSuite) SetupTest() { @@ -45,7 +46,7 @@ func (msuite *MultitenancyTestSuite) 
TearDownTest() { defer cleanup() require.NoError(t, err) require.NoError(t, gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gcli.Alter(context.Background(), &api.Operation{DropAll: true})) } diff --git a/systest/multi-tenancy/upgrade_test.go b/systest/multi-tenancy/upgrade_test.go index e3f0e4fff46..1b73247bc4b 100644 --- a/systest/multi-tenancy/upgrade_test.go +++ b/systest/multi-tenancy/upgrade_test.go @@ -26,13 +26,14 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) type MultitenancyTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster lc *dgraphtest.LocalCluster uc dgraphtest.UpgradeCombo } @@ -61,7 +62,7 @@ func (msuite *MultitenancyTestSuite) Upgrade() { func TestMultitenancySuite(t *testing.T) { for _, uc := range dgraphtest.AllUpgradeCombos(false) { - log.Printf("running upgrade tests for confg: %+v", uc) + log.Printf("running upgrade tests for config: %+v", uc) var msuite MultitenancyTestSuite msuite.uc = uc suite.Run(t, &msuite) diff --git a/systest/mutations-and-queries/integration_test.go b/systest/mutations-and-queries/integration_test.go index 1fce9446d84..89e78e6f7f6 100644 --- a/systest/mutations-and-queries/integration_test.go +++ b/systest/mutations-and-queries/integration_test.go @@ -24,12 +24,13 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" ) type SystestTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster } func (ssuite *SystestTestSuite) SetupTest() { diff --git a/systest/mutations-and-queries/mutations_test.go b/systest/mutations-and-queries/mutations_test.go 
index b14862b5098..59f92c5edf0 100644 --- a/systest/mutations-and-queries/mutations_test.go +++ b/systest/mutations-and-queries/mutations_test.go @@ -36,6 +36,7 @@ import ( "github.com/dgraph-io/dgo/v230" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/testutil" ) @@ -162,7 +163,7 @@ func (ssuite *SystestTestSuite) FacetJsonInputSupportsAnyOfTerms() { require.NoError(t, err, "the query should have succeeded") //var respUser User - dgraphtest.CompareJSON(fmt.Sprintf(` + dgraphapi.CompareJSON(fmt.Sprintf(` { "direct":[ { @@ -246,7 +247,7 @@ func (ssuite *SystestTestSuite) NQuadMutationTest() { txn = gcli.NewTxn() resp, err := txn.Query(ctx, breakfastQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{ "q": [ { + dgraphapi.CompareJSON(`{ "q": [ { "fruit": [ { "xid": "apple" }, { "xid": "banana" } @@ -275,7 +276,7 @@ func (ssuite *SystestTestSuite) NQuadMutationTest() { txn = gcli.NewTxn() resp, err = txn.Query(ctx, breakfastQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{ "q": [ { + dgraphapi.CompareJSON(`{ "q": [ { "fruit": [ { "xid": "apple" } ] @@ -328,7 +329,7 @@ func (ssuite *SystestTestSuite) DeleteAllReverseIndex() { ctx = context.Background() resp, err := gcli.NewTxn().Query(ctx, fmt.Sprintf("{ q(func: uid(%s)) { ~link { uid } }}", bId)) require.NoError(t, err) - dgraphtest.CompareJSON(`{"q":[]}`, string(resp.Json)) + dgraphapi.CompareJSON(`{"q":[]}`, string(resp.Json)) assignedIds, err = gcli.NewTxn().Mutate(ctx, &api.Mutation{ CommitNow: true, @@ -339,7 +340,7 @@ func (ssuite *SystestTestSuite) DeleteAllReverseIndex() { resp, err = gcli.NewTxn().Query(ctx, fmt.Sprintf("{ q(func: uid(%s)) { ~link { uid } }}", cId)) require.NoError(t, err) - dgraphtest.CompareJSON(fmt.Sprintf(`{"q":[{"~link": [{"uid": "%s"}]}]}`, aId), string(resp.Json)) + dgraphapi.CompareJSON(fmt.Sprintf(`{"q":[{"~link": [{"uid": "%s"}]}]}`, aId), string(resp.Json)) } 
func (ssuite *SystestTestSuite) NormalizeEdgeCasesTest() { @@ -478,7 +479,7 @@ func (ssuite *SystestTestSuite) FacetOrderTest() { ctx = context.Background() resp, err := txn.Query(ctx, friendQuery) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q":[ { @@ -1188,7 +1189,7 @@ func (ssuite *SystestTestSuite) DeleteWithExpandAll() { require.Equal(t, 0, len(r.Me)) } -func testTimeValue(t *testing.T, c *dgraphtest.GrpcClient, timeBytes []byte) { +func testTimeValue(t *testing.T, c *dgraphapi.GrpcClient, timeBytes []byte) { nquads := []*api.NQuad{ { Subject: "0x01", @@ -1264,7 +1265,7 @@ func (ssuite *SystestTestSuite) SkipEmptyPLForHas() { }` resp, err := gcli.NewTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"users":[{"name":"u"},{"name":"u1"}]}`, string(resp.Json)) + dgraphapi.CompareJSON(`{"users":[{"name":"u"},{"name":"u1"}]}`, string(resp.Json)) op := &api.Operation{DropAll: true} require.NoError(t, gcli.Alter(ctx, op)) @@ -1331,7 +1332,7 @@ func (ssuite *SystestTestSuite) HasWithDash() { txn = gcli.NewTxn() resp, err := txn.Query(ctx, friendQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{"q":[{"new-friend":[{"name":"Bob"},{"name":"Charlie"}]}]}`, string(resp.Json)) + dgraphapi.CompareJSON(`{"q":[{"new-friend":[{"name":"Bob"},{"name":"Charlie"}]}]}`, string(resp.Json)) } func (ssuite *SystestTestSuite) ListGeoFilterTest() { @@ -1378,7 +1379,7 @@ func (ssuite *SystestTestSuite) ListGeoFilterTest() { } }`) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1436,7 +1437,7 @@ func (ssuite *SystestTestSuite) ListRegexFilterTest() { } }`) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1485,7 +1486,7 @@ func (ssuite *SystestTestSuite) RegexQueryWithVarsWithSlash() { } }`) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1503,7 +1504,7 @@ func (ssuite *SystestTestSuite) 
RegexQueryWithVarsWithSlash() { } }`, map[string]string{"$rex": "/\\/def/"}) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1559,7 +1560,7 @@ func (ssuite *SystestTestSuite) RegexQueryWithVars() { } }`, map[string]string{"$term": "/^rea.*$/"}) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1616,7 +1617,7 @@ func (ssuite *SystestTestSuite) GraphQLVarChild() { } }`, map[string]string{"$alice": a}) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1637,7 +1638,7 @@ func (ssuite *SystestTestSuite) GraphQLVarChild() { } }`, map[string]string{"$bob": b}) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1664,7 +1665,7 @@ func (ssuite *SystestTestSuite) GraphQLVarChild() { } }`, map[string]string{"$friends": friends}) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1721,7 +1722,7 @@ func (ssuite *SystestTestSuite) MathGe() { } }`) require.NoError(t, err) - dgraphtest.CompareJSON(` + dgraphapi.CompareJSON(` { "q": [ { @@ -1958,7 +1959,7 @@ func (ssuite *SystestTestSuite) RestoreReservedPreds() { query := `schema(preds: dgraph.type) {predicate}` resp, err := gcli.NewReadOnlyTxn().Query(ctx, query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"dgraph.type"}]}`, string(resp.Json)) + dgraphapi.CompareJSON(`{"schema": [{"predicate":"dgraph.type"}]}`, string(resp.Json)) } func (ssuite *SystestTestSuite) DropData() { @@ -2005,7 +2006,7 @@ func (ssuite *SystestTestSuite) DropData() { query := `schema(preds: [name, follow]) {predicate}` resp, err := gcli.NewReadOnlyTxn().Query(ctx, query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"name"}, {"predicate":"follow"}]}`, string(resp.Json)) + dgraphapi.CompareJSON(`{"schema": [{"predicate":"name"}, {"predicate":"follow"}]}`, string(resp.Json)) // Check data is 
gone. resp, err = gcli.NewTxn().Query(ctx, `{ @@ -2015,7 +2016,7 @@ func (ssuite *SystestTestSuite) DropData() { } }`) require.NoError(t, err) - dgraphtest.CompareJSON(`{"q": []}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"q": []}`, string(resp.GetJson())) } func (ssuite *SystestTestSuite) DropDataAndDropAll() { @@ -2061,7 +2062,7 @@ func (ssuite *SystestTestSuite) DropType() { query := `schema(type: Person) {}` resp, err := gcli.NewReadOnlyTxn().Query(ctx, query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"types":[{"name":"Person", "fields":[{"name":"name"}]}]}`, + dgraphapi.CompareJSON(`{"types":[{"name":"Person", "fields":[{"name":"name"}]}]}`, string(resp.Json)) require.NoError(t, gcli.Alter(ctx, &api.Operation{ @@ -2072,7 +2073,7 @@ func (ssuite *SystestTestSuite) DropType() { // Check type is gone. resp, err = gcli.NewReadOnlyTxn().Query(ctx, query) require.NoError(t, err) - dgraphtest.CompareJSON("{}", string(resp.Json)) + dgraphapi.CompareJSON("{}", string(resp.Json)) } func (ssuite *SystestTestSuite) DropTypeNoValue() { @@ -2111,7 +2112,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelUIDList() { var wg sync.WaitGroup wg.Add(numRoutines) for i := 0; i < numRoutines; i++ { - go func(dg *dgraphtest.GrpcClient, wg *sync.WaitGroup) { + go func(dg *dgraphapi.GrpcClient, wg *sync.WaitGroup) { defer wg.Done() for { if atomic.AddUint64(&txnCur, 1) > txnTotal { @@ -2146,7 +2147,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelUIDList() { }`, len(insertedMap)) resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[{"uid": "0x1"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[{"uid": "0x1"}]}`, string(resp.GetJson())) // Now start deleting UIDs. 
var insertedUids []int @@ -2162,7 +2163,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelUIDList() { wg.Add(numRoutines) for i := 0; i < numRoutines; i++ { - go func(dg *dgraphtest.GrpcClient, wg *sync.WaitGroup) { + go func(dg *dgraphapi.GrpcClient, wg *sync.WaitGroup) { defer wg.Done() for { if atomic.AddUint64(&txnCur, 1) > txnTotal { @@ -2197,7 +2198,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelUIDList() { }`, insertedCount-len(deletedMap)) resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[{"uid": "0x1"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[{"uid": "0x1"}]}`, string(resp.GetJson())) // Delete all friends now. mu := &api.Mutation{ @@ -2216,7 +2217,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelUIDList() { ctx = context.Background() resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[]}`, string(resp.GetJson())) } func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelScalarPredicate() { @@ -2238,7 +2239,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelScalarPredicate() { var wg sync.WaitGroup wg.Add(numRoutines) for i := 0; i < numRoutines; i++ { - go func(dg *dgraphtest.GrpcClient, wg *sync.WaitGroup) { + go func(dg *dgraphapi.GrpcClient, wg *sync.WaitGroup) { defer wg.Done() for { if atomic.AddUint64(&txnCur, 1) > txnTotal { @@ -2293,7 +2294,7 @@ func (ssuite *SystestTestSuite) CountIndexConcurrentSetDelScalarPredicate() { ctx = context.Background() resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"q":[]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"q":[]}`, string(resp.GetJson())) } func (ssuite *SystestTestSuite) 
CountIndexNonlistPredicateDelete() { @@ -2324,7 +2325,7 @@ func (ssuite *SystestTestSuite) CountIndexNonlistPredicateDelete() { }` resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"q": [{"uid": "0x1"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"q": [{"uid": "0x1"}]}`, string(resp.GetJson())) // Upgrade ssuite.Upgrade() @@ -2345,7 +2346,7 @@ func (ssuite *SystestTestSuite) CountIndexNonlistPredicateDelete() { // Query it using count index. resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"q": [{"uid": "0x1"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"q": [{"uid": "0x1"}]}`, string(resp.GetJson())) } func (ssuite *SystestTestSuite) ReverseCountIndexDelete() { @@ -2375,7 +2376,7 @@ func (ssuite *SystestTestSuite) ReverseCountIndexDelete() { }` resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[{"uid": "0x2"}, {"uid": "0x3"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[{"uid": "0x2"}, {"uid": "0x3"}]}`, string(resp.GetJson())) // Upgrade ssuite.Upgrade() @@ -2394,7 +2395,7 @@ func (ssuite *SystestTestSuite) ReverseCountIndexDelete() { resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[{"uid": "0x3"}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[{"uid": "0x3"}]}`, string(resp.GetJson())) } @@ -2428,7 +2429,7 @@ func (ssuite *SystestTestSuite) ReverseCountIndex() { var wg sync.WaitGroup wg.Add(numRoutines) for i := 0; i < numRoutines; i++ { - go func(dg *dgraphtest.GrpcClient, id string, wg *sync.WaitGroup) { + go func(dg *dgraphapi.GrpcClient, id string, wg *sync.WaitGroup) { defer wg.Done() mu := &api.Mutation{ CommitNow: true, @@ -2461,7 +2462,7 @@ func (ssuite 
*SystestTestSuite) ReverseCountIndex() { ctx = context.Background() resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err, "the query should have succeeded") - dgraphtest.CompareJSON(`{"me":[{"name":"Alice","count(~friend)":10}]}`, string(resp.GetJson())) + dgraphapi.CompareJSON(`{"me":[{"name":"Alice","count(~friend)":10}]}`, string(resp.GetJson())) } func (ssuite *SystestTestSuite) TypePredicateCheck() { @@ -2566,7 +2567,7 @@ func (ssuite *SystestTestSuite) InferSchemaAsList() { require.NoError(t, err) resp, err := gcli.NewReadOnlyTxn().Query(context.Background(), query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, + dgraphapi.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, {"predicate":"nickname"}]}`, string(resp.Json)) } @@ -2594,7 +2595,7 @@ func (ssuite *SystestTestSuite) InferSchemaAsListJSON() { require.NoError(t, err) resp, err := gcli.NewReadOnlyTxn().Query(context.Background(), query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, + dgraphapi.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, {"predicate":"nickname"}]}`, string(resp.Json)) } @@ -2623,7 +2624,7 @@ func (ssuite *SystestTestSuite) ForceSchemaAsListJSON() { require.NoError(t, err) resp, err := gcli.NewReadOnlyTxn().Query(context.Background(), query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, + dgraphapi.CompareJSON(`{"schema": [{"predicate":"name", "list":true}, {"predicate":"nickname"}]}`, string(resp.Json)) } @@ -2652,7 +2653,7 @@ func (ssuite *SystestTestSuite) ForceSchemaAsSingleJSON() { require.NoError(t, err) resp, err := gcli.NewReadOnlyTxn().Query(context.Background(), query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"schema": [{"predicate":"person"}, {"predicate":"nickname"}]}`, + dgraphapi.CompareJSON(`{"schema": [{"predicate":"person"}, {"predicate":"nickname"}]}`, 
string(resp.Json)) } @@ -2692,7 +2693,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicates() { }` resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Bob"}}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Bob"}}]}`, string(resp.GetJson())) upsertQuery := `query { alice as var(func: eq(name, Alice)) }` @@ -2717,7 +2718,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicates() { require.NoError(t, err) resp, err = gcli.NewReadOnlyTxn().Query(context.Background(), q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Carol"}}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Carol"}}]}`, string(resp.GetJson())) } @@ -2757,7 +2758,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicatesReverse() { }` resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Bob"}}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Bob"}}]}`, string(resp.GetJson())) reverseQuery := `{ @@ -2769,7 +2770,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicatesReverse() { }}` resp, err = gcli.NewReadOnlyTxn().Query(ctx, reverseQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{"reverse":[{"name":"Bob","~best_friend": [{"name": "Alice"}]}]}`, + dgraphapi.CompareJSON(`{"reverse":[{"name":"Bob","~best_friend": [{"name": "Alice"}]}]}`, string(resp.GetJson())) upsertQuery := `query { alice as var(func: eq(name, Alice)) }` @@ -2788,7 +2789,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicatesReverse() { resp, err = gcli.NewReadOnlyTxn().Query(ctx, reverseQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{"reverse":[{"name":"Carol","~best_friend": [{"name": "Alice"}]}]}`, + dgraphapi.CompareJSON(`{"reverse":[{"name":"Carol","~best_friend": [{"name": 
"Alice"}]}]}`, string(resp.GetJson())) // Delete the triples and verify the reverse edge is gone. @@ -2816,12 +2817,12 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicatesReverse() { ctx = context.Background() resp, err = gcli.NewReadOnlyTxn().Query(ctx, reverseQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{"reverse":[]}`, + dgraphapi.CompareJSON(`{"reverse":[]}`, string(resp.GetJson())) resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice"}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice"}]}`, string(resp.GetJson())) } @@ -2883,7 +2884,7 @@ func (ssuite *SystestTestSuite) OverwriteUidPredicatesMultipleTxn() { require.NoError(t, err) resp, err = gcli.NewReadOnlyTxn().Query(context.Background(), query) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Carl"}}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice","best_friend": {"name": "Carl"}}]}`, string(resp.GetJson())) } @@ -2927,7 +2928,7 @@ func (ssuite *SystestTestSuite) DeleteAndQuerySameTxn() { q := `{ me(func: has(name)) { name } }` resp, err := txn2.Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice"}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice"}]}`, string(resp.GetJson())) require.NoError(t, txn2.Commit(ctx)) @@ -2940,7 +2941,7 @@ func (ssuite *SystestTestSuite) DeleteAndQuerySameTxn() { // Verify that changes are reflected after the transaction is committed. 
resp, err = gcli.NewReadOnlyTxn().Query(context.Background(), q) require.NoError(t, err) - dgraphtest.CompareJSON(`{"me":[{"name":"Alice"}]}`, + dgraphapi.CompareJSON(`{"me":[{"name":"Alice"}]}`, string(resp.GetJson())) } @@ -2980,7 +2981,7 @@ func (ssuite *SystestTestSuite) AddAndQueryZeroTimeValue() { txn = gcli.NewTxn() resp, err := txn.Query(ctx, datetimeQuery) require.NoError(t, err) - dgraphtest.CompareJSON(`{ + dgraphapi.CompareJSON(`{ "q": [ { "val": "0000-01-01T00:00:00Z" diff --git a/systest/mutations-and-queries/queries_test.go b/systest/mutations-and-queries/queries_test.go index 8d30188baa3..0f1f4cb7554 100644 --- a/systest/mutations-and-queries/queries_test.go +++ b/systest/mutations-and-queries/queries_test.go @@ -27,7 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/testutil" "github.com/dgraph-io/dgraph/x" ) @@ -226,7 +226,7 @@ func (ssuite *SystestTestSuite) MultipleBlockEval() { for _, tc := range tests { resp, err := txn.Query(ctx, fmt.Sprintf(queryFmt, tc.in)) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } @@ -328,7 +328,7 @@ func (ssuite *SystestTestSuite) UnmatchedVarEval() { for _, tc := range tests { resp, err := txn.Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } @@ -453,7 +453,7 @@ func (ssuite *SystestTestSuite) SchemaQueryTestPredicate1() { "types": [` + testutil.GetInternalTypes(false) + ` ] }` - dgraphtest.CompareJSON(js, string(resp.Json)) + dgraphapi.CompareJSON(js, string(resp.Json)) } func (ssuite *SystestTestSuite) SchemaQueryTestPredicate2() { @@ -496,7 +496,7 @@ func (ssuite *SystestTestSuite) SchemaQueryTestPredicate2() { } ] }` - dgraphtest.CompareJSON(js, string(resp.Json)) + 
dgraphapi.CompareJSON(js, string(resp.Json)) } func (ssuite *SystestTestSuite) SchemaQueryTestPredicate3() { @@ -548,7 +548,7 @@ func (ssuite *SystestTestSuite) SchemaQueryTestPredicate3() { } ] }` - dgraphtest.CompareJSON(js, string(resp.Json)) + dgraphapi.CompareJSON(js, string(resp.Json)) } func (ssuite *SystestTestSuite) SchemaQueryTestHTTP() { @@ -573,7 +573,7 @@ func (ssuite *SystestTestSuite) SchemaQueryTestHTTP() { hcli, err := ssuite.dc.HTTPClient() require.NoError(t, err) - err = hcli.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + err = hcli.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) require.NotNil(t, hcli.AccessJwt, "token is nil") require.NoError(t, err) @@ -584,7 +584,7 @@ func (ssuite *SystestTestSuite) SchemaQueryTestHTTP() { var m map[string]json.RawMessage require.NoError(t, json.Unmarshal(res, &m)) require.NotNil(t, m["extensions"]) - dgraphtest.CompareJSON(testutil.GetFullSchemaJSON(testutil.SchemaOptions{UserPreds: ` + dgraphapi.CompareJSON(testutil.GetFullSchemaJSON(testutil.SchemaOptions{UserPreds: ` { "predicate": "name", "type": "string", @@ -747,7 +747,7 @@ func (ssuite *SystestTestSuite) FuzzyMatch() { continue } require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } @@ -1218,7 +1218,7 @@ func (ssuite *SystestTestSuite) CascadeParams() { for _, tc := range tests { resp, err := gcli.NewTxn().Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } @@ -1339,7 +1339,7 @@ func (ssuite *SystestTestSuite) QueryHashIndex() { for _, tc := range tests { resp, err := gcli.NewTxn().Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } @@ -1396,7 +1396,7 @@ func (ssuite *SystestTestSuite) 
RegexpToggleTrigramIndex() { for _, tc := range tests { resp, err := gcli.NewTxn().Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } op = &api.Operation{Schema: `name: string @index(trigram) @lang .`} @@ -1406,7 +1406,7 @@ func (ssuite *SystestTestSuite) RegexpToggleTrigramIndex() { for _, tc := range tests { resp, err := gcli.NewTxn().Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } require.NoError(t, gcli.Alter(ctx, &api.Operation{ @@ -1455,7 +1455,7 @@ func (ssuite *SystestTestSuite) EqWithAlteredIndexOrder() { expectedResult := `{"q":[{"name":"Alice"}]}` resp, err := gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(expectedResult, string(resp.Json)) + dgraphapi.CompareJSON(expectedResult, string(resp.Json)) // now, let's set the schema with trigram before term op = &api.Operation{Schema: `name: string @index(trigram, term) .`} @@ -1464,7 +1464,7 @@ func (ssuite *SystestTestSuite) EqWithAlteredIndexOrder() { // querying with eq should still work resp, err = gcli.NewReadOnlyTxn().Query(ctx, q) require.NoError(t, err) - dgraphtest.CompareJSON(expectedResult, string(resp.Json)) + dgraphapi.CompareJSON(expectedResult, string(resp.Json)) } func (ssuite *SystestTestSuite) GroupByUidWorks() { @@ -1519,17 +1519,17 @@ func (ssuite *SystestTestSuite) GroupByUidWorks() { for _, tc := range tests { resp, err := gcli.NewTxn().Query(ctx, tc.in) require.NoError(t, err) - dgraphtest.CompareJSON(tc.out, string(resp.Json)) + dgraphapi.CompareJSON(tc.out, string(resp.Json)) } } -func doGrpcLogin(ssuite *SystestTestSuite) (*dgraphtest.GrpcClient, func(), error) { +func doGrpcLogin(ssuite *SystestTestSuite) (*dgraphapi.GrpcClient, func(), error) { gcli, cleanup, err := ssuite.dc.Client() if err != nil { return nil, nil, errors.Wrap(err, "error creating 
grpc client") } err = gcli.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace) if err != nil { return nil, nil, errors.Wrap(err, "groot login into galaxy namespace failed") } diff --git a/systest/mutations-and-queries/upgrade_test.go b/systest/mutations-and-queries/upgrade_test.go index 9d1ace45bc7..66847fac3dd 100644 --- a/systest/mutations-and-queries/upgrade_test.go +++ b/systest/mutations-and-queries/upgrade_test.go @@ -27,13 +27,14 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) type SystestTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster lc *dgraphtest.LocalCluster uc dgraphtest.UpgradeCombo } diff --git a/systest/online-restore/namespace-aware/restore_test.go b/systest/online-restore/namespace-aware/restore_test.go index 32fda468f4a..76ef1738b1a 100644 --- a/systest/online-restore/namespace-aware/restore_test.go +++ b/systest/online-restore/namespace-aware/restore_test.go @@ -25,11 +25,12 @@ import ( "github.com/stretchr/testify/require" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) -// func addData(gc *dgraphtest.GrpcClient, pred string, start, end int) error { +// func addData(gc *dgraphapi.GrpcClient, pred string, start, end int) error { // if err := gc.SetupSchema(fmt.Sprintf(`%v: string @index(exact) .`, pred)); err != nil { // return err // } @@ -45,12 +46,12 @@ import ( func commonTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalCluster) { hc, err := existingCluster.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, 
hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) gc, cleanup, err := existingCluster.Client() defer cleanup() require.NoError(t, err) - require.NoError(t, gc.Login(context.Background(), dgraphtest.DefaultUser, dgraphtest.DefaultPassword)) + require.NoError(t, gc.Login(context.Background(), dgraphapi.DefaultUser, dgraphapi.DefaultPassword)) namespaces := []uint64{0} require.NoError(t, dgraphtest.AddData(gc, "pred", 1, 100)) @@ -59,21 +60,21 @@ func commonTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalClu require.NoError(t, err) namespaces = append(namespaces, ns) require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) require.NoError(t, dgraphtest.AddData(gc, "pred", 1, 100+int(ns))) } - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, hc.Backup(existingCluster, false, dgraphtest.DefaultBackupDir)) restoreNamespaces := func(c *dgraphtest.LocalCluster) { hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) for _, ns := range namespaces { require.NoError(t, hc.RestoreTenant(c, dgraphtest.DefaultBackupDir, "", 0, 0, ns)) - require.NoError(t, dgraphtest.WaitForRestore(c)) + require.NoError(t, dgraphapi.WaitForRestore(c)) gc, cleanup, err = c.Client() require.NoError(t, err) @@ -81,7 +82,7 @@ func commonTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalClu // Only the namespace '0' should have data require.NoError(t, gc.LoginIntoNamespace(context.Background(), 
- dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) const query = `{ all(func: has(pred)) { count(uid) @@ -89,12 +90,12 @@ func commonTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalClu }` resp, err := gc.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(fmt.Sprintf(`{"all":[{"count":%v}]}`, 100+ns), string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(fmt.Sprintf(`{"all":[{"count":%v}]}`, 100+ns), string(resp.Json))) // other namespaces should have no data for _, ns2 := range namespaces[1:] { require.Error(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns2)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns2)) } } } @@ -109,12 +110,12 @@ func commonTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalClu func commonIncRestoreTest(t *testing.T, existingCluster, freshCluster *dgraphtest.LocalCluster) { hc, err := existingCluster.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) gc, cleanup, err := existingCluster.Client() defer cleanup() require.NoError(t, err) - require.NoError(t, gc.Login(context.Background(), dgraphtest.DefaultUser, dgraphtest.DefaultPassword)) + require.NoError(t, gc.Login(context.Background(), dgraphapi.DefaultUser, dgraphapi.DefaultPassword)) require.NoError(t, gc.DropAll()) require.NoError(t, dgraphtest.AddData(gc, "pred", 1, 100)) @@ -129,20 +130,20 @@ func commonIncRestoreTest(t *testing.T, existingCluster, freshCluster *dgraphtes for j := 0; j < 5; j++ { for i, ns := range namespaces { require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, ns)) + 
dgraphapi.DefaultUser, dgraphapi.DefaultPassword, ns)) start := i*20 + 1 end := (i + 1) * 20 require.NoError(t, dgraphtest.AddData(gc, "pred", start, end)) } - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, hc.Backup(existingCluster, j == 0, dgraphtest.DefaultBackupDir)) } restoreNamespaces := func(c *dgraphtest.LocalCluster) { hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) for _, ns := range namespaces { for j := 0; j < 5; j++ { incrFrom := j + 1 @@ -151,13 +152,13 @@ func commonIncRestoreTest(t *testing.T, existingCluster, freshCluster *dgraphtes } require.NoError(t, hc.RestoreTenant(c, dgraphtest.DefaultBackupDir, "", incrFrom, j+1, ns)) - require.NoError(t, dgraphtest.WaitForRestore(c)) + require.NoError(t, dgraphapi.WaitForRestore(c)) gc, cleanup, err = c.Client() require.NoError(t, err) defer cleanup() - require.NoError(t, gc.Login(context.Background(), dgraphtest.DefaultUser, dgraphtest.DefaultPassword)) + require.NoError(t, gc.Login(context.Background(), dgraphapi.DefaultUser, dgraphapi.DefaultPassword)) const query = `{ all(func: has(pred)) { count(uid) @@ -165,7 +166,7 @@ func commonIncRestoreTest(t *testing.T, existingCluster, freshCluster *dgraphtes }` resp, err := gc.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(fmt.Sprintf(`{"all":[{"count":%v}]}`, 20*(j+1)), string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(fmt.Sprintf(`{"all":[{"count":%v}]}`, 20*(j+1)), string(resp.Json))) } } } diff --git a/systest/plugin/integration_test.go b/systest/plugin/integration_test.go index 
267ac4c00bc..c8bda8a98e3 100644 --- a/systest/plugin/integration_test.go +++ b/systest/plugin/integration_test.go @@ -24,12 +24,13 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" ) type PluginTestSuite struct { suite.Suite - dc dgraphtest.Cluster + dc dgraphapi.Cluster } func (psuite *PluginTestSuite) SetupTest() { diff --git a/systest/plugin/plugin_test.go b/systest/plugin/plugin_test.go index cfc1891e625..82e5a61d3f5 100644 --- a/systest/plugin/plugin_test.go +++ b/systest/plugin/plugin_test.go @@ -24,7 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dgraph-io/dgo/v230/protos/api" - "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/dgraphapi" ) type testCase struct { @@ -329,7 +329,7 @@ func (psuite *PluginTestSuite) TestPlugins() { for _, test := range testInp[i].cases { reply, err := gcli.Query(test.query) require.NoError(t, err) - dgraphtest.CompareJSON(test.wantResult, string(reply.GetJson())) + dgraphapi.CompareJSON(test.wantResult, string(reply.GetJson())) } }) } diff --git a/systest/plugin/upgrade_test.go b/systest/plugin/upgrade_test.go index 4ff2d59ec0a..effeeb2ea2c 100644 --- a/systest/plugin/upgrade_test.go +++ b/systest/plugin/upgrade_test.go @@ -26,15 +26,16 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" ) type PluginTestSuite struct { suite.Suite - dc dgraphtest.Cluster - lc *dgraphtest.LocalCluster - uc dgraphtest.UpgradeCombo + dc dgraphapi.Cluster + lc *dgraphtest.LocalCluster + uc dgraphtest.UpgradeCombo } func (psuite *PluginTestSuite) SetupSubTest() { diff --git a/systest/unique_test.go b/systest/unique_test.go index 54c6b3ce1b8..4f9614ba8ce 100644 --- a/systest/unique_test.go +++ b/systest/unique_test.go @@ -33,6 +33,7 @@ import ( 
"github.com/dgraph-io/dgo/v230/protos/api" "github.com/dgraph-io/dgraph/dgraph/cmd/live" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" ) @@ -49,7 +50,7 @@ const emailQueryWithUid = `{ } }` -func setUpDgraph(t *testing.T) *dgraphtest.GrpcClient { +func setUpDgraph(t *testing.T) *dgraphapi.GrpcClient { c := dgraphtest.ComposeCluster{} dg, close, err := c.Client() require.NoError(t, err) @@ -126,7 +127,7 @@ func TestUniqueTwoMutationSingleBlankNode(t *testing.T) { require.NoError(t, err) resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [{"email": "example@email.com"}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [{"email": "example@email.com"}]}`, string(resp.Json))) _, err = dg.Mutate(&api.Mutation{ SetNquads: []byte(rdf), CommitNow: true, @@ -150,7 +151,7 @@ func TestUniqueOneMutationSameValueTwoBlankNode(t *testing.T) { resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ ]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ ]}`, string(resp.Json))) } func TestUniqueOneMutationSameValueSingleBlankNode(t *testing.T) { @@ -166,7 +167,7 @@ func TestUniqueOneMutationSameValueSingleBlankNode(t *testing.T) { require.NoError(t, err) resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ { "email": "example@email.com"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ { "email": "example@email.com"}]}`, string(resp.Json))) } @@ -183,7 +184,7 @@ func TestUniqueTwoMutattionsTwoHardCodedUIDs(t *testing.T) { resp, err := dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0x5"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0x5"}]}`, string(resp.Json))) rdf = `<0x6> 
"example@email.com" . ` @@ -208,7 +209,7 @@ func TestUniqueHardCodedUidsWithDiffrentNotation(t *testing.T) { resp, err := dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, string(resp.Json))) rdf = `<0o255> "example@email.com" . ` @@ -219,7 +220,7 @@ func TestUniqueHardCodedUidsWithDiffrentNotation(t *testing.T) { require.NoError(t, err) resp, err = dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, string(resp.Json))) rdf = `<0b10101101> "example@email.com" . ` @@ -230,7 +231,7 @@ func TestUniqueHardCodedUidsWithDiffrentNotation(t *testing.T) { require.NoError(t, err) resp, err = dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, string(resp.Json))) rdf = `<173> "example@email.com" . 
` @@ -241,7 +242,7 @@ func TestUniqueHardCodedUidsWithDiffrentNotation(t *testing.T) { require.NoError(t, err) resp, err = dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ {"email": "example@email.com","uid" :"0xad"}]}`, string(resp.Json))) } @@ -259,7 +260,7 @@ func TestUniqueSingleMutattionsOneHardCodedUIDSameValue(t *testing.T) { require.NoError(t, err) resp, err := dg.Query(emailQueryWithUid) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ { + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ { "email": "example@email.com", "uid":"0x5" }]}`, string(resp.Json))) @@ -280,7 +281,7 @@ func TestUniqueOneMutattionsTwoHardCodedUIDsDiffValue(t *testing.T) { resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ ]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ ]}`, string(resp.Json))) } func TestUniqueUpsertMutation(t *testing.T) { @@ -306,7 +307,7 @@ func TestUniqueUpsertMutation(t *testing.T) { require.NoError(t, err) resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [{"email": "example@email.com"} ]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [{"email": "example@email.com"} ]}`, string(resp.Json))) mu = &api.Mutation{ SetNquads: []byte(` uid(v) "example1@email.com" .`), @@ -328,7 +329,7 @@ func TestUniqueWithConditionalUpsertMutation(t *testing.T) { require.NoError(t, err) resp, err := dg.Query(emailQuery) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [{"email": "example@email.com"} ]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [{"email": "example@email.com"} ]}`, string(resp.Json))) query := `query{ v as person(func: eq(email, "example@email.com")) 
{ @@ -505,7 +506,7 @@ func TestUniqueForInt(t *testing.T) { }` resp, err := dg.Query(query) require.NoError(t, err) - require.NoError(t, dgraphtest.CompareJSON(`{ "q": [ { "mobile": 1234567890}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{ "q": [ { "mobile": 1234567890}]}`, string(resp.Json))) } func TestUniqueForLangDirective(t *testing.T) { @@ -657,7 +658,7 @@ func TestConcurrencyMutationsDiffrentValuesForDiffrentBlankNode(t *testing.T) { resp, err := dg.Query(query) require.NoError(t, err) // there should be 1000 emails in DB. - require.NoError(t, dgraphtest.CompareJSON(`{"allMails":[{"count":1000}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{"allMails":[{"count":1000}]}`, string(resp.Json))) } func TestUniqueTwoTxnWithoutCommit(t *testing.T) { @@ -682,7 +683,7 @@ func TestUniqueTwoTxnWithoutCommit(t *testing.T) { resp, err := dg.Query(emailQuery) require.NoError(t, err) // there should be only one email data as expected. - require.NoError(t, dgraphtest.CompareJSON(`{"q":[{"email":"example@email.com"}]}`, string(resp.Json))) + require.NoError(t, dgraphapi.CompareJSON(`{"q":[{"email":"example@email.com"}]}`, string(resp.Json))) } func TestUniqueSingelTxnDuplicteValuesWithoutCommit(t *testing.T) { @@ -751,8 +752,8 @@ func TestConcurrency2(t *testing.T) { }`, i) resp, err := dg.Query(emailQuery) require.NoError(t, err) - err1 := dgraphtest.CompareJSON(fmt.Sprintf(`{"q":[{"email":"example%v@email.com"}]}`, i), string(resp.Json)) - err2 := dgraphtest.CompareJSON(fmt.Sprintf(`{ "q": [ ]}`), string(resp.Json)) + err1 := dgraphapi.CompareJSON(fmt.Sprintf(`{"q":[{"email":"example%v@email.com"}]}`, i), string(resp.Json)) + err2 := dgraphapi.CompareJSON(fmt.Sprintf(`{ "q": [ ]}`), string(resp.Json)) if err1 != nil && err2 != nil { t.Fatal() } diff --git a/systest/vector/backup_test.go b/systest/vector/backup_test.go new file mode 100644 index 00000000000..db0900a7636 --- /dev/null +++ b/systest/vector/backup_test.go @@ -0,0 
+1,289 @@ +//go:build !oss && integration + +/* + * Copyright 2023 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "slices" + "strings" + "testing" + "time" + + "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" + "github.com/dgraph-io/dgraph/dgraphtest" + "github.com/dgraph-io/dgraph/x" + "github.com/stretchr/testify/require" +) + +func TestVectorIncrBackupRestore(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, gc.SetupSchema(testSchema)) + + numVectors := 500 + pred := "project_discription_v" + allVectors := make([][][]float32, 0, 5) + allRdfs := make([]string, 0, 5) + for i := 1; i <= 5; i++ { + var rdfs string + var vectors [][]float32 + rdfs, vectors = dgraphapi.GenerateRandomVectors(numVectors*(i-1), 
numVectors*i, 1, pred) + allVectors = append(allVectors, vectors) + allRdfs = append(allRdfs, rdfs) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err := gc.Mutate(mu) + require.NoError(t, err) + + t.Logf("taking backup #%v\n", i) + require.NoError(t, hc.Backup(c, i == 1, dgraphtest.DefaultBackupDir)) + } + + for i := 1; i <= 5; i++ { + t.Logf("restoring backup #%v\n", i) + + incrFrom := i - 1 + require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", incrFrom, i)) + require.NoError(t, dgraphapi.WaitForRestore(c)) + query := `{ + vector(func: has(project_discription_v)) { + count(uid) + } + }` + result, err := gc.Query(query) + require.NoError(t, err) + + require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors*i), string(result.GetJson())) + var allSpredVec [][]float32 + for i, vecArr := range allVectors { + if i <= i { + allSpredVec = append(allSpredVec, vecArr...) + } + } + for p, vector := range allVectors[i-1] { + triple := strings.Split(allRdfs[i-1], "\n")[p] + uid := strings.Split(triple, " ")[0] + queriedVector, err := gc.QuerySingleVectorsUsingUid(uid, pred) + require.NoError(t, err) + + require.Equal(t, allVectors[i-1][p], queriedVector[0]) + + similarVectors, err := gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, numVectors) + require.NoError(t, err) + require.GreaterOrEqual(t, len(similarVectors), 10) + for _, similarVector := range similarVectors { + require.Contains(t, allSpredVec, similarVector) + } + } + } +} + +func TestVectorBackupRestore(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, 
x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, gc.SetupSchema(testSchema)) + + numVectors := 1000 + pred := "project_discription_v" + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) + + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + t.Log("taking backup \n") + require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) + + t.Log("restoring backup \n") + require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", 0, 0)) + require.NoError(t, dgraphapi.WaitForRestore(c)) + + testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) +} + +func TestVectorBackupRestoreDropIndex(t *testing.T) { + // setup cluster + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + // add vector predicate + index + require.NoError(t, gc.SetupSchema(testSchema)) + // add data to the vector predicate + numVectors := 3 + pred := "project_discription_v" + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 1, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + t.Log("taking full backup \n") + require.NoError(t, hc.Backup(c, false, 
dgraphtest.DefaultBackupDir)) + + // drop index + require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) + + // add more data to the vector predicate + rdfs, vectors2 := dgraphapi.GenerateRandomVectors(3, numVectors+3, 1, pred) + mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + // delete some entries + mu = &api.Mutation{DelNquads: []byte(strings.Split(rdfs, "\n")[1]), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + vectors2 = slices.Delete(vectors2, 1, 2) + + mu = &api.Mutation{DelNquads: []byte(strings.Split(rdfs, "\n")[0]), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + vectors2 = slices.Delete(vectors2, 0, 1) + + t.Log("taking first incr backup \n") + require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) + + // add index + require.NoError(t, gc.SetupSchema(testSchema)) + + t.Log("taking second incr backup \n") + require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) + + // restore backup + t.Log("restoring backup \n") + require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", 0, 0)) + require.NoError(t, dgraphapi.WaitForRestore(c)) + + query := ` { + vectors(func: has(project_discription_v)) { + count(uid) + } + }` + resp, err := gc.Query(query) + require.NoError(t, err) + require.JSONEq(t, `{"vectors":[{"count":4}]}`, string(resp.GetJson())) + + require.NoError(t, err) + allVec := append(vectors, vectors2...) 
+ + for _, vector := range allVec { + + similarVectors, err := gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 4) + require.NoError(t, err) + for _, similarVector := range similarVectors { + require.Contains(t, allVec, similarVector) + } + } +} + +func TestVectorBackupRestoreReIndexing(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, gc.SetupSchema(testSchema)) + + numVectors := 1000 + pred := "project_discription_v" + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) + + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + t.Log("taking backup \n") + require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) + + rdfs2, vectors2 := dgraphapi.GenerateRandomVectors(numVectors, numVectors+300, 10, pred) + + mu = &api.Mutation{SetNquads: []byte(rdfs2), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + t.Log("restoring backup \n") + require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", 2, 1)) + require.NoError(t, dgraphapi.WaitForRestore(c)) + + for i := 0; i < 5; i++ { + // drop index + require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) + // add index + require.NoError(t, gc.SetupSchema(testSchema)) + } + vectors = append(vectors, vectors2...) 
+ rdfs = rdfs + rdfs2 + testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) +} diff --git a/systest/vector/load_test.go b/systest/vector/load_test.go index 2835d31dfdb..7154fba81da 100644 --- a/systest/vector/load_test.go +++ b/systest/vector/load_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" @@ -51,18 +52,18 @@ func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportForma require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.SetupSchema(testSchema)) numVectors := 100 pred := "project_discription_v" - rdfs, vectors := dgraphtest.GenerateRandomVectors(0, numVectors, 10, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) @@ -87,7 +88,7 @@ func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportForma require.NoError(t, c.LiveLoadFromExport(dgraphtest.DefaultExportDir)) require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) result, err = gc.Query(query) require.NoError(t, err) diff --git a/systest/vector/vector_test.go b/systest/vector/vector_test.go index d15c5a4747c..af786f8cc27 
100644 --- a/systest/vector/vector_test.go +++ b/systest/vector/vector_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/dgraph-io/dgo/v230/protos/api" + "github.com/dgraph-io/dgraph/dgraphapi" "github.com/dgraph-io/dgraph/dgraphtest" "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" @@ -38,7 +39,7 @@ const ( testSchemaWithoutIndex = `project_discription_v: float32vector .` ) -func testVectorQuery(t *testing.T, gc *dgraphtest.GrpcClient, vectors [][]float32, rdfs, pred string, topk int) { +func testVectorQuery(t *testing.T, gc *dgraphapi.GrpcClient, vectors [][]float32, rdfs, pred string, topk int) { for i, vector := range vectors { triple := strings.Split(rdfs, "\n")[i] uid := strings.Split(triple, " ")[0] @@ -65,12 +66,12 @@ func TestVectorDropAll(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) numVectors := 100 pred := "project_discription_v" @@ -85,7 +86,7 @@ func TestVectorDropAll(t *testing.T) { for i := 0; i < 10; i++ { require.NoError(t, gc.SetupSchema(testSchema)) - rdfs, vectors := dgraphtest.GenerateRandomVectors(0, numVectors, 100, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -122,25 +123,25 @@ func TestVectorSnapshot(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, 
x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, c.KillAlpha(1)) hc, err = c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) gc, cleanup, err = c.AlphaClient(0) require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.SetupSchema(testSchema)) @@ -149,7 +150,7 @@ func TestVectorSnapshot(t *testing.T) { numVectors := 500 pred := "project_discription_v" - rdfs, vectors := dgraphtest.GenerateRandomVectors(0, numVectors, 100, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -175,8 +176,8 @@ func TestVectorSnapshot(t *testing.T) { gc, cleanup, err = c.AlphaClient(1) require.NoError(t, err) defer cleanup() - require.NoError(t, gc.Login(context.Background(), dgraphtest.DefaultUser, - dgraphtest.DefaultPassword)) + require.NoError(t, gc.Login(context.Background(), dgraphapi.DefaultUser, + dgraphapi.DefaultPassword)) result, err = gc.Query(query) require.NoError(t, err) @@ -196,12 +197,12 @@ func TestVectorDropNamespace(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, 
dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) numVectors := 500 pred := "project_discription_v" @@ -209,7 +210,7 @@ func TestVectorDropNamespace(t *testing.T) { ns, err := hc.AddNamespace() require.NoError(t, err) require.NoError(t, gc.SetupSchema(testSchema)) - rdfs, vectors := dgraphtest.GenerateRandomVectors(0, numVectors, 100, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -247,18 +248,18 @@ func TestVectorIndexRebuilding(t *testing.T) { require.NoError(t, err) defer cleanup() require.NoError(t, gc.LoginIntoNamespace(context.Background(), - dgraphtest.DefaultUser, dgraphtest.DefaultPassword, x.GalaxyNamespace)) + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) hc, err := c.HTTPClient() require.NoError(t, err) - require.NoError(t, hc.LoginIntoNamespace(dgraphtest.DefaultUser, - dgraphtest.DefaultPassword, x.GalaxyNamespace)) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) require.NoError(t, gc.SetupSchema(testSchema)) pred := "project_discription_v" numVectors := 1000 - rdfs, vectors := dgraphtest.GenerateRandomVectors(0, numVectors, 100, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -288,3 +289,29 @@ func TestVectorIndexRebuilding(t *testing.T) { testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) } + +func 
TestVectorIndexOnVectorPredWithoutData(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, gc.SetupSchema(testSchema)) + pred := "project_discription_v" + + vector := []float32{1.0, 2.0, 3.0} + _, err = gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10) + require.NoError(t, err) +} diff --git a/testutil/schema.go b/testutil/schema.go index 02ef3c5ae8b..e5e0e96ee0f 100644 --- a/testutil/schema.go +++ b/testutil/schema.go @@ -32,7 +32,7 @@ import ( const ( aclPreds = ` -{"predicate":"dgraph.xid","type":"string", "index":true, "tokenizer":["exact"], "upsert":true}, +{"predicate":"dgraph.xid","type":"string", "index":true, "tokenizer":["exact"], "unique": true, "upsert":true}, {"predicate":"dgraph.password","type":"password"}, {"predicate":"dgraph.user.group","list":true, "reverse":true, "type":"uid"}, {"predicate":"dgraph.acl.rule","type":"uid","list":true}, diff --git a/tok/hnsw/heap.go b/tok/hnsw/heap.go index 24c6edbe8fe..da165f835ae 100644 --- a/tok/hnsw/heap.go +++ b/tok/hnsw/heap.go @@ -44,6 +44,10 @@ func (h *minPersistentTupleHeap[T]) Push(x interface{}) { *h = append(*h, x.(minPersistentHeapElement[T])) } +func (h *minPersistentTupleHeap[T]) PopLast() { + heap.Remove(h, h.Len()-1) +} + func (h *minPersistentTupleHeap[T]) Pop() interface{} { old := *h n := len(old) diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 
cbb02bb6fbe..b58a911f072 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -3,19 +3,21 @@ package hnsw import ( "context" "encoding/binary" - "encoding/json" + "fmt" "log" "math" "math/rand" "sort" "strconv" "strings" + "unsafe" - "github.com/chewxy/math32" c "github.com/dgraph-io/dgraph/tok/constraints" "github.com/dgraph-io/dgraph/tok/index" "github.com/getsentry/sentry-go" "github.com/pkg/errors" + "github.com/viterin/vek" + "github.com/viterin/vek/vek32" ) const ( @@ -24,6 +26,7 @@ const ( DotProd = "dotproduct" plError = "\nerror fetching posting list for data key: " dataError = "\nerror fetching data for data key: " + EmptyHNSWTreeError = "HNSW tree has no elements" VecKeyword = "__vector_" visitedVectorsLevel = "visited_vectors_level_" distanceComputations = "vector_distance_computations" @@ -60,62 +63,47 @@ func (s *SearchResult) GetExtraMetrics() map[string]uint64 { return s.extraMetrics } -func norm[T c.Float](v []T, floatBits int) T { - vectorNorm, _ := dotProduct(v, v, floatBits) - if floatBits == 32 { - return T(math32.Sqrt(float32(vectorNorm))) +func applyDistanceFunction[T c.Float](a, b []T, floatBits int, funcName string, + applyFn32 func(a, b []float32) float32, applyFn64 func(a, b []float64) float64) (T, error) { + if len(a) != len(b) { + err := errors.New(fmt.Sprintf("can not compute %s on vectors of different lengths", funcName)) + return T(0), err } - if floatBits == 64 { - return T(math.Sqrt(float64(vectorNorm))) + + if floatBits == 32 { + var a1, b1 []float32 + a1 = *(*[]float32)(unsafe.Pointer(&a)) + b1 = *(*[]float32)(unsafe.Pointer(&b)) + return T(applyFn32(a1, b1)), nil + } else if floatBits == 64 { + var a1, b1 []float64 + a1 = *(*[]float64)(unsafe.Pointer(&a)) + b1 = *(*[]float64)(unsafe.Pointer(&b)) + return T(applyFn64(a1, b1)), nil } - panic("Invalid floatBits") + + panic("While applying function on two floats, found an invalid number of float bits") + } // This needs to implement signature of 
SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter, // but doesn't actually use it. func dotProduct[T c.Float](a, b []T, floatBits int) (T, error) { - var dotProduct T - if len(a) != len(b) { - err := errors.New("can not compute dot product on vectors of different lengths") - return dotProduct, err - } - for i := range a { - dotProduct += a[i] * b[i] - } - return dotProduct, nil + return applyDistanceFunction(a, b, floatBits, "dot product", vek32.Dot, vek.Dot) } // This needs to implement signature of SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter. func cosineSimilarity[T c.Float](a, b []T, floatBits int) (T, error) { - dotProd, err := dotProduct(a, b, floatBits) - if err != nil { - return 0, err - } - normA := norm[T](a, floatBits) - normB := norm[T](b, floatBits) - if normA == 0 || normB == 0 { - err := errors.New("can not compute cosine similarity on zero vector") - var empty T - return empty, err - } - return dotProd / (normA * normB), nil + return applyDistanceFunction(a, b, floatBits, "cosine distance", vek32.CosineSimilarity, vek.CosineSimilarity) } // This needs to implement signature of SimilarityType[T].distanceScore // function, hence it takes in a floatBits parameter, // but doesn't actually use it. 
func euclidianDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { - if len(a) != len(b) { - return 0, errors.New("can not subtract vectors of different lengths") - } - var distSq T - for i := range a { - val := a[i] - b[i] - distSq += val * val - } - return distSq, nil + return applyDistanceFunction(a, b, floatBits, "euclidian distance", vek32.Distance, vek.Distance) } // Used for distance, since shorter distance is better @@ -206,24 +194,6 @@ func cannotConvertToUintSlice(s string) error { return errors.Errorf("Cannot convert %s to uint slice", s) } -func diff(a []uint64, b []uint64) []uint64 { - // Turn b into a map - m := make(map[uint64]bool, len(b)) - for _, s := range b { - m[s] = false - } - // Append values from the longest slice that don't exist in the map - var diff []uint64 - for _, s := range a { - if _, ok := m[s]; !ok { - diff = append(diff, s) - continue - } - m[s] = true - } - return diff -} - // TODO: Move SimilarityType to index package. // // Remove "hnsw-isms". @@ -334,7 +304,7 @@ func populateEdgeDataFromKeyWithCacheType( if data == nil { return false, nil } - err = json.Unmarshal(data.([]byte), &edgeData) + err = decodeUint64MatrixUnsafe(data.([]byte), edgeData) return true, err } @@ -432,9 +402,10 @@ func (ph *persistentHNSW[T]) createEntryAndStartNodes( err := ph.getVecFromUid(entry, c, vec) if err != nil || len(*vec) == 0 { // The entry vector has been deleted. We have to create a new entry vector. - entry, err := ph.PickStartNode(ctx, c, vec) + entry, err := ph.calculateNewEntryVec(ctx, c, vec) if err != nil { - return 0, []*index.KeyValue{}, err + // No other node exists, go with the new node that has come + return create_edges(inUuid) } return create_edges(entry) } @@ -442,6 +413,70 @@ func (ph *persistentHNSW[T]) createEntryAndStartNodes( return entry, edges, nil } +// Converts the matrix into linear array that looks like +// [0: Number of rows 1: Length of row1 2-n: Data of row1 3: Length of row2 ..] 
+func encodeUint64MatrixUnsafe(matrix [][]uint64) []byte { + if len(matrix) == 0 { + return nil + } + + // Calculate the total size + var totalSize uint64 + for _, row := range matrix { + totalSize += uint64(len(row))*uint64(unsafe.Sizeof(uint64(0))) + uint64(unsafe.Sizeof(uint64(0))) + } + totalSize += uint64(unsafe.Sizeof(uint64(0))) + + // Create a byte slice with the appropriate size + data := make([]byte, totalSize) + + offset := 0 + // Write number of rows + rows := uint64(len(matrix)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rows))[:]) + offset += 8 + + // Write each row's length and data + for _, row := range matrix { + rowLen := uint64(len(row)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rowLen))[:]) + offset += 8 + for i := range row { + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&row[i]))[:]) + offset += 8 + } + } + + return data +} + +func decodeUint64MatrixUnsafe(data []byte, matrix *[][]uint64) error { + if len(data) == 0 { + return nil + } + + offset := 0 + // Read number of rows + rows := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + *matrix = make([][]uint64, rows) + + for i := 0; i < int(rows); i++ { + // Read row length + rowLen := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + (*matrix)[i] = make([]uint64, rowLen) + for j := 0; j < int(rowLen); j++ { + (*matrix)[i][j] = *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + } + } + + return nil +} + // adds empty layers to all levels func (ph *persistentHNSW[T]) addStartNodeToAllLevels( ctx context.Context, @@ -450,11 +485,7 @@ func (ph *persistentHNSW[T]) addStartNodeToAllLevels( inUuid uint64) ([]*index.KeyValue, error) { edges := []*index.KeyValue{} key := DataKey(ph.vecKey, inUuid) - emptyEdges := make([][]uint64, ph.maxLevels) - emptyEdgesBytes, err := json.Marshal(emptyEdges) - if err != nil { - return []*index.KeyValue{}, err - } + emptyEdgesBytes := encodeUint64MatrixUnsafe(make([][]uint64, ph.maxLevels)) // creates 
empty at all levels only for entry node edge, err := ph.newPersistentEdgeKeyValueEntry(ctx, key, txn, inUuid, emptyEdgesBytes) if err != nil { @@ -507,7 +538,7 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, allLayerEdges = allLayerNeighbors } else { // all edges of nearest neighbor - err := json.Unmarshal(data.([]byte), &allLayerEdges) + err := decodeUint64MatrixUnsafe(data.([]byte), &allLayerEdges) if err != nil { return nil, err } @@ -525,10 +556,7 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, // on every modification of the layer edges, add it to in mem map so you dont have to always be reading // from persistent storage ph.nodeAllEdges[uuid] = allLayerEdges - inboundEdgesBytes, marshalErr := json.Marshal(allLayerEdges) - if marshalErr != nil { - return nil, marshalErr - } + inboundEdgesBytes := encodeUint64MatrixUnsafe(allLayerEdges) edge := &index.KeyValue{ Entity: uuid, @@ -543,19 +571,38 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache, // removeDeadNodes(nnEdges, tc) removes dead nodes from nnEdges and returns the new nnEdges func (ph *persistentHNSW[T]) removeDeadNodes(nnEdges []uint64, tc *TxnCache) ([]uint64, error) { - data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc) - if err != nil && err.Error() == plError { - return []uint64{}, err - } - var deadNodes []uint64 - if data != nil { // if dead nodes exist, convert to []uint64 - deadNodes, err = ParseEdges(string(data.([]byte))) - if err != nil { + // TODO add a path to delete deadNodes + if ph.deadNodes == nil { + data, err := getDataFromKeyWithCacheType(ph.vecDead, 1, tc) + if err != nil && err.Error() == plError { return []uint64{}, err } - nnEdges = diff(nnEdges, deadNodes) // set nnEdges to be all elements not contained in deadNodes + + var deadNodes []uint64 + if data != nil { // if dead nodes exist, convert to []uint64 + deadNodes, err = ParseEdges(string(data.([]byte))) + if err != nil { + 
return []uint64{}, err + } + } + + ph.deadNodes = make(map[uint64]struct{}) + for _, n := range deadNodes { + ph.deadNodes[n] = struct{}{} + } + } + if len(ph.deadNodes) == 0 { + return nnEdges, nil + } + + var diff []uint64 + for _, s := range nnEdges { + if _, ok := ph.deadNodes[s]; !ok { + diff = append(diff, s) + continue + } } - return nnEdges, nil + return diff, nil } func Uint64ToBytes(key uint64) []byte { diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 6fd51a4b520..00b55552e04 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/bits-and-blooms/bitset" c "github.com/dgraph-io/dgraph/tok/constraints" "github.com/dgraph-io/dgraph/tok/index" opt "github.com/dgraph-io/dgraph/tok/options" @@ -26,7 +27,8 @@ type persistentHNSW[T c.Float] struct { // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first // layer for uuid 65443. The result will be a neighboring uuid. nodeAllEdges map[uint64][][]uint64 - visitedUids []uint64 + visitedUids bitset.BitSet + deadNodes map[uint64]struct{} } func GetPersistantOptions[T c.Float](o opt.Options) string { @@ -163,10 +165,13 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( index: entry, filteredOut: entryIsFilteredOut, } + r.setFirstPathNode(best) - //create set using map to append to on future visited nodes - ph.visitedUids = append(ph.visitedUids, best.index) candidateHeap := *buildPersistentHeapByInit([]minPersistentHeapElement[T]{best}) + + var allLayerEdges [][]uint64 + + //create set using map to append to on future visited nodes for candidateHeap.Len() != 0 { currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) if r.numNeighbors() < expectedNeighbors && @@ -181,7 +186,6 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( // guarantees of getting best results. 
break } - var allLayerEdges [][]uint64 found, err := ph.fillNeighborEdges(currCandidate.index, c, &allLayerEdges) if err != nil { @@ -190,46 +194,55 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( if !found { continue } - currLayerEdges := allLayerEdges[level] - currLayerEdges = diff(currLayerEdges, ph.visitedUids) var eVec []T - for i := range currLayerEdges { + improved := false + for _, currUid := range allLayerEdges[level] { + if ph.visitedUids.Test(uint(currUid)) { + continue + } + if r.indexVisited(currUid) { + continue + } // iterate over candidate's neighbors distances to get // best ones - _ = ph.getVecFromUid(currLayerEdges[i], c, &eVec) + _ = ph.getVecFromUid(currUid, c, &eVec) // intentionally ignoring error -- we catch it // indirectly via eVec == nil check. if len(eVec) == 0 { continue } currDist, err := ph.simType.distanceScore(eVec, query, ph.floatBits) - ph.visitedUids = append(ph.visitedUids, currLayerEdges[i]) - r.incrementDistanceComputations() if err != nil { return ph.emptySearchResultWithError(err) } - filteredOut := !filter(query, eVec, currLayerEdges[i]) + filteredOut := !filter(query, eVec, currUid) currElement := initPersistentHeapElement( - currDist, currLayerEdges[i], filteredOut) - nodeVisited := r.nodeVisited(*currElement) - if !nodeVisited { - r.addToVisited(*currElement) - - // If we have not yet found k candidates, we can consider - // any candidate. Otherwise, only consider those that - // are better than our current k nearest neighbors. - // Note that the "numNeighbors" function is a bit tricky: - // If we previously added to the heap M elements that should - // be filtered out, we ignore M elements in the numNeighbors - // check! In this way, we can make sure to allow in up to - // expectedNeighbors "unfiltered" elements. 
- if ph.simType.isBetterScore(currDist, r.lastNeighborScore()) || - r.numNeighbors() < expectedNeighbors { - candidateHeap.Push(*currElement) - r.addPathNode(*currElement, ph.simType, expectedNeighbors) + currDist, currUid, filteredOut) + r.addToVisited(*currElement) + r.incrementDistanceComputations() + ph.visitedUids.Set(uint(currUid)) + + // If we have not yet found k candidates, we can consider + // any candidate. Otherwise, only consider those that + // are better than our current k nearest neighbors. + // Note that the "numNeighbors" function is a bit tricky: + // If we previously added to the heap M elements that should + // be filtered out, we ignore M elements in the numNeighbors + // check! In this way, we can make sure to allow in up to + // expectedNeighbors "unfiltered" elements. + if r.numNeighbors() < expectedNeighbors || ph.simType.isBetterScore(currDist, r.lastNeighborScore()) { + if candidateHeap.Len() > expectedNeighbors { + candidateHeap.PopLast() } + candidateHeap.Push(*currElement) + r.addPathNode(*currElement, ph.simType, expectedNeighbors) + improved = true } } + + if !improved && r.numNeighbors() >= expectedNeighbors { + break + } } return r, nil } @@ -289,10 +302,10 @@ func (ph *persistentHNSW[T]) calculateNewEntryVec( }) if err != nil { - return 0, errors.Wrapf(err, "HNSW tree has no elements") + return 0, errors.Wrapf(err, EmptyHNSWTreeError) } if itr == 0 { - return itr, errors.New("HNSW tree has no elements") + return itr, errors.New(EmptyHNSWTreeError) } return itr, nil @@ -335,6 +348,8 @@ func (ph *persistentHNSW[T]) SearchWithPath( start := time.Now().UnixMilli() r = index.NewSearchPathResult() + ph.visitedUids.ClearAll() + // 0-profile_vector_entry var startVec []T entry, err := ph.PickStartNode(ctx, c, &startVec) @@ -356,6 +371,7 @@ func (ph *persistentHNSW[T]) SearchWithPath( } layerResult.updateFinalMetrics(r) entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) err = ph.getVecFromUid(entry, c, &startVec) if 
err != nil { @@ -417,6 +433,8 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inLevel := getInsertLayer(ph.maxLevels) // calculate layer to insert node at (randomized every time) var layerErr error + ph.visitedUids.ClearAll() + for level := 0; level < inLevel; level++ { // perform insertion for layers [level, max_level) only, when level < inLevel just find better start err := ph.getVecFromUid(entry, tc, &startVec) @@ -424,7 +442,7 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err } layerResult, err := ph.searchPersistentLayer(tc, level, entry, startVec, - inVec, false, 1, index.AcceptAll[T]) + inVec, false, ph.efSearch, index.AcceptAll[T]) if err != nil { return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, err } @@ -451,10 +469,14 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, return []minPersistentHeapElement[T]{}, []*index.KeyValue{}, layerErr } + entry = layerResult.bestNeighbor().index + nns := layerResult.neighbors for i := 0; i < len(nns); i++ { nnUidArray = append(nnUidArray, nns[i].index) - inboundEdgesAllLayersMap[nns[i].index] = make([][]uint64, ph.maxLevels) + if inboundEdgesAllLayersMap[nns[i].index] == nil { + inboundEdgesAllLayersMap[nns[i].index] = make([][]uint64, ph.maxLevels) + } inboundEdgesAllLayersMap[nns[i].index][level] = append(inboundEdgesAllLayersMap[nns[i].index][level], inUuid) // add nn to outboundEdges. 
diff --git a/tok/hnsw/persistent_hnsw_test.go b/tok/hnsw/persistent_hnsw_test.go index 986abd76046..2befa742030 100644 --- a/tok/hnsw/persistent_hnsw_test.go +++ b/tok/hnsw/persistent_hnsw_test.go @@ -286,19 +286,6 @@ var flatPhs = []*persistentHNSW[float64]{ }, } -var flatPh = &persistentHNSW[float64]{ - maxLevels: 5, - efConstruction: 16, - efSearch: 12, - pred: "0-a", - vecEntryKey: ConcatStrings("0-a", VecEntry), - vecKey: ConcatStrings("0-a", VecKeyword), - vecDead: ConcatStrings("0-a", VecDead), - floatBits: 64, - simType: GetSimType[float64](Euclidian, 64), - nodeAllEdges: make(map[uint64][][]uint64), -} - var flatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorageTest{ { tc: NewTxnCache(&inMemTxn{startTs: 12, commitTs: 40}, 12), @@ -328,6 +315,7 @@ var flatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatStorag func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) { emptyTsDbs() + flatPh := flatPhs[0] for _, test := range flatEntryInsertToPersistentFlatStorageTests { emptyTsDbs() key := DataKey(flatPh.pred, test.inUuid) @@ -345,12 +333,13 @@ func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) { } } var float1, float2 = []float64{}, []float64{} - index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])].([]byte), &float1, 64) - index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])].([]byte), &float2, 64) + skey := string(key[:]) + index.BytesAsFloatArray(tsDbs[0].inMemTestDb[skey].([]byte), &float1, 64) + index.BytesAsFloatArray(tsDbs[99].inMemTestDb[skey].([]byte), &float2, 64) if !equalFloat64Slice(float1, float2) { t.Errorf("Vector value for predicate %q at beginning and end of database were "+ - "not equivalent. Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred].([]float64), - tsDbs[99].inMemTestDb[flatPh.pred].([]float64)) + "not equivalent. 
Start Value: %v\n, End Value: %v\n %v\n %v", flatPh.pred, tsDbs[0].inMemTestDb[skey], + tsDbs[99].inMemTestDb[skey], float1, float2) } edgesNameList := []string{} for _, edge := range edges { @@ -405,6 +394,7 @@ var nonflatEntryInsertToPersistentFlatStorageTests = []insertToPersistentFlatSto func TestNonflatEntryInsertToPersistentFlatStorage(t *testing.T) { emptyTsDbs() + flatPh := flatPhs[0] key := DataKey(flatPh.pred, flatEntryInsert.inUuid) for i := range tsDbs { tsDbs[i].inMemTestDb[string(key[:])] = floatArrayAsBytes(flatEntryInsert.inVec) @@ -479,7 +469,7 @@ var searchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ query: []float64{0.824, 0.319, 0.111}, maxResults: 1, expectedErr: nil, - expectedNns: []uint64{5}, + expectedNns: []uint64{123}, }, } @@ -510,7 +500,7 @@ var flatPopulateBasicInsertsForSearch = []insertToPersistentFlatStorageTest{ }, } -func flatPopulateInserts(insertArr []insertToPersistentFlatStorageTest) error { +func flatPopulateInserts(insertArr []insertToPersistentFlatStorageTest, flatPh *persistentHNSW[float64]) error { emptyTsDbs() for _, in := range insertArr { for i := range tsDbs { @@ -544,7 +534,7 @@ func RunFlatSearchTests(t *testing.T, test searchPersistentFlatStorageTest, flat func TestBasicSearchPersistentFlatStorage(t *testing.T) { for _, flatPh := range flatPhs { emptyTsDbs() - err := flatPopulateInserts(flatPopulateBasicInsertsForSearch) + err := flatPopulateInserts(flatPopulateBasicInsertsForSearch, flatPh) if err != nil { t.Errorf("Error populating inserts: %s", err) return @@ -554,61 +544,3 @@ func TestBasicSearchPersistentFlatStorage(t *testing.T) { } } } - -var flatPopulateOverlappingInserts = []insertToPersistentFlatStorageTest{ - { - tc: NewTxnCache(&inMemTxn{startTs: 0, commitTs: 5}, 0), - inUuid: uint64(5), - inVec: []float64{0.1, 0.1, 0.1}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, - { - tc: NewTxnCache(&inMemTxn{startTs: 3, commitTs: 9}, 3), - inUuid: uint64(123), - 
inVec: []float64{0.824, 0.319, 0.111}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, - { - tc: NewTxnCache(&inMemTxn{startTs: 8, commitTs: 37}, 8), - inUuid: uint64(1), - inVec: []float64{0.3, 0.5, 0.7}, - expectedErr: nil, - expectedEdgesList: nil, - minExpectedEdge: "", - }, -} - -var overlappingSearchPersistentFlatStorageTests = []searchPersistentFlatStorageTest{ - { - qc: NewQueryCache(&inMemLocalCache{readTs: 45}, 45), - query: []float64{0.3, 0.5, 0.7}, - maxResults: 1, - expectedErr: nil, - expectedNns: []uint64{123}, - }, - { - qc: NewQueryCache(&inMemLocalCache{readTs: 93}, 93), - query: []float64{0.824, 0.319, 0.111}, - maxResults: 1, - expectedErr: nil, - expectedNns: []uint64{123}, - }, -} - -func TestOverlappingInsertsAndSearchPersistentFlatStorage(t *testing.T) { - for _, flatPh := range flatPhs { - emptyTsDbs() - err := flatPopulateInserts(flatPopulateOverlappingInserts) - if err != nil { - t.Errorf("Error from flatPopulateInserts: %s", err) - return - } - for _, test := range overlappingSearchPersistentFlatStorageTests { - RunFlatSearchTests(t, test, flatPh) - } - } -} diff --git a/tok/hnsw/search_layer.go b/tok/hnsw/search_layer.go index 49f129648bb..d01a864d063 100644 --- a/tok/hnsw/search_layer.go +++ b/tok/hnsw/search_layer.go @@ -11,7 +11,7 @@ type searchLayerResult[T c.Float] struct { // neighbors represents the candidates with the best scores so far. neighbors []minPersistentHeapElement[T] // visited represents elements seen (so we don't try to re-visit). 
- visited []minPersistentHeapElement[T] + visited map[uint64]minPersistentHeapElement[T] path []uint64 metrics map[string]uint64 level int @@ -29,7 +29,7 @@ type searchLayerResult[T c.Float] struct { func newLayerResult[T c.Float](level int) *searchLayerResult[T] { return &searchLayerResult[T]{ neighbors: []minPersistentHeapElement[T]{}, - visited: []minPersistentHeapElement[T]{}, + visited: make(map[uint64]minPersistentHeapElement[T]), path: []uint64{}, metrics: make(map[string]uint64), level: level, @@ -38,7 +38,8 @@ func newLayerResult[T c.Float](level int) *searchLayerResult[T] { func (slr *searchLayerResult[T]) setFirstPathNode(n minPersistentHeapElement[T]) { slr.neighbors = []minPersistentHeapElement[T]{n} - slr.visited = []minPersistentHeapElement[T]{n} + slr.visited = make(map[uint64]minPersistentHeapElement[T]) + slr.visited[n.index] = n slr.path = []uint64{n.index} } @@ -86,17 +87,13 @@ func (slr *searchLayerResult[T]) bestNeighbor() minPersistentHeapElement[T] { return slr.neighbors[0] } -func (slr *searchLayerResult[T]) nodeVisited(n minPersistentHeapElement[T]) bool { - for _, visitedNode := range slr.visited { - if visitedNode.index == n.index { - return true - } - } - return false +func (slr *searchLayerResult[T]) indexVisited(n uint64) bool { + _, ok := slr.visited[n] + return ok } func (slr *searchLayerResult[T]) addToVisited(n minPersistentHeapElement[T]) { - slr.visited = append(slr.visited, n) + slr.visited[n.index] = n } func (slr *searchLayerResult[T]) updateFinalMetrics(r *index.SearchPathResult) { diff --git a/tok/index/helper.go b/tok/index/helper.go index 40274ea7a33..de56ac5023d 100644 --- a/tok/index/helper.go +++ b/tok/index/helper.go @@ -19,8 +19,11 @@ package index import ( "encoding/binary" "math" + "reflect" + "unsafe" c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/golang/glog" ) // BytesAsFloatArray[T c.Float](encoded) converts encoded into a []T, @@ -31,40 +34,27 @@ import ( // The result is appended to the given 
retVal slice. If retVal is nil // then a new slice is created and appended to. func BytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { - // Unfortunately, this is not as simple as casting the result, - // and it is also not possible to directly use the - // golang "unsafe" library to directly do the conversion. - // The machine where this operation gets run might prefer - // BigEndian/LittleEndian, but the machine that sent it may have - // preferred the other, and there is no way to tell! - // - // The solution below, unfortunately, requires another memory - // allocation. - // TODO Potential optimization: If we detect that current machine is - // using LittleEndian format, there might be a way of making this - // work with the golang "unsafe" library. floatBytes := floatBits / 8 - *retVal = (*retVal)[:0] - resultLen := len(encoded) / floatBytes - if resultLen == 0 { + if len(encoded) == 0 { + *retVal = []T{} return } - for i := 0; i < resultLen; i++ { - // Assume LittleEndian for encoding since this is - // the assumption elsewhere when reading from client. - // See dgraph-io/dgo/protos/api.pb.go - // See also dgraph-io/dgraph/types/conversion.go - // This also seems to be the preference from many examples - // I have found via Google search. It's unclear why this - // should be a preference. 
- if retVal == nil { - retVal = &[]T{} - } - *retVal = append(*retVal, BytesToFloat[T](encoded, floatBits)) - encoded = encoded[(floatBytes):] + // Ensure the byte slice length is a multiple of 8 (size of float64) + if len(encoded)%floatBytes != 0 { + glog.Errorf("Invalid byte slice length %d %v", len(encoded), encoded) + return } + + if retVal == nil { + *retVal = make([]T, len(encoded)/floatBytes) + } + *retVal = (*retVal)[:0] + header := (*reflect.SliceHeader)(unsafe.Pointer(retVal)) + header.Data = uintptr(unsafe.Pointer(&encoded[0])) + header.Len = len(encoded) / floatBytes + header.Cap = len(encoded) / floatBytes } func BytesToFloat[T c.Float](encoded []byte, floatBits int) T { diff --git a/tok/index/helper_test.go b/tok/index/helper_test.go new file mode 100644 index 00000000000..1968ae8d4d3 --- /dev/null +++ b/tok/index/helper_test.go @@ -0,0 +1,423 @@ +/* + * Copyright 2016-2024 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package index + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "encoding/gob" + "encoding/json" + "fmt" + "testing" + "unsafe" + + "github.com/dgraph-io/dgraph/protos/pb" + c "github.com/dgraph-io/dgraph/tok/constraints" + "github.com/viterin/vek/vek32" +) + +// GenerateMatrix generates a 2D slice of uint64 with varying lengths for each row. 
+func GenerateMatrix(rows int) ([][]uint64, *pb.SortResult) { + pbm := &pb.SortResult{} + matrix := make([][]uint64, rows) + value := uint64(100) + for i := range matrix { + cols := i + 1 // Variable number of columns for each row + matrix[i] = make([]uint64, cols) + for j := range matrix[i] { + matrix[i][j] = value + value++ + } + pbm.UidMatrix = append(pbm.UidMatrix, &pb.List{Uids: matrix[i]}) + } + return matrix, pbm +} + +// Encoding and decoding functions +func encodeUint64Matrix(matrix [][]uint64) ([]byte, error) { + var buf bytes.Buffer + + // Write number of rows + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(matrix))); err != nil { + return nil, err + } + + // Write each row's length and data + for _, row := range matrix { + if err := binary.Write(&buf, binary.LittleEndian, uint64(len(row))); err != nil { + return nil, err + } + for _, value := range row { + if err := binary.Write(&buf, binary.LittleEndian, value); err != nil { + return nil, err + } + } + } + + return buf.Bytes(), nil +} + +func decodeUint64Matrix(data []byte) ([][]uint64, error) { + buf := bytes.NewReader(data) + + var numRows uint64 + if err := binary.Read(buf, binary.LittleEndian, &numRows); err != nil { + return nil, err + } + + matrix := make([][]uint64, numRows) + for i := range matrix { + var numCols uint64 + if err := binary.Read(buf, binary.LittleEndian, &numCols); err != nil { + return nil, err + } + matrix[i] = make([]uint64, numCols) + for j := range matrix[i] { + if err := binary.Read(buf, binary.LittleEndian, &matrix[i][j]); err != nil { + return nil, err + } + } + } + + return matrix, nil +} + +func encodeUint64MatrixWithGob(matrix [][]uint64) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + + if err := enc.Encode(matrix); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func decodeUint64MatrixWithGob(data []byte) ([][]uint64, error) { + var matrix [][]uint64 + buf := bytes.NewReader(data) + dec := 
gob.NewDecoder(buf) + + if err := dec.Decode(&matrix); err != nil { + return nil, err + } + + return matrix, nil +} + +func encodeUint64MatrixWithJSON(matrix [][]uint64) ([]byte, error) { + return json.Marshal(matrix) +} + +func decodeUint64MatrixWithJSON(data []byte) ([][]uint64, error) { + var matrix [][]uint64 + if err := json.Unmarshal(data, &matrix); err != nil { + return nil, err + } + return matrix, nil +} + +func encodeUint64MatrixUnsafe(matrix [][]uint64) []byte { + if len(matrix) == 0 { + return nil + } + + // Calculate the total size + var totalSize uint64 + for _, row := range matrix { + totalSize += uint64(len(row))*uint64(unsafe.Sizeof(uint64(0))) + uint64(unsafe.Sizeof(uint64(0))) + } + totalSize += uint64(unsafe.Sizeof(uint64(0))) + + // Create a byte slice with the appropriate size + data := make([]byte, totalSize) + + offset := 0 + // Write number of rows + rows := uint64(len(matrix)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rows))[:]) + offset += 8 + + // Write each row's length and data + for _, row := range matrix { + rowLen := uint64(len(row)) + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&rowLen))[:]) + offset += 8 + for i := range row { + copy(data[offset:offset+8], (*[8]byte)(unsafe.Pointer(&row[i]))[:]) + offset += 8 + } + } + + return data +} + +func decodeUint64MatrixUnsafe(data []byte) ([][]uint64, error) { + if len(data) == 0 { + return nil, nil + } + + offset := 0 + // Read number of rows + rows := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + matrix := make([][]uint64, rows) + + for i := 0; i < int(rows); i++ { + // Read row length + rowLen := *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + + matrix[i] = make([]uint64, rowLen) + for j := 0; j < int(rowLen); j++ { + matrix[i][j] = *(*uint64)(unsafe.Pointer(&data[offset])) + offset += 8 + } + } + + return matrix, nil +} + +func encodeUint64MatrixWithProtobuf(protoMatrix *pb.SortResult) ([]byte, error) { + // Convert the matrix to the 
protobuf structure + return protoMatrix.Marshal() +} + +func decodeUint64MatrixWithProtobuf(data []byte, protoMatrix *pb.SortResult) error { + // Unmarshal the protobuf data into the protobuf structure + return protoMatrix.Unmarshal(data) +} + +// Combined benchmark function +func BenchmarkEncodeDecodeUint64Matrix(b *testing.B) { + matrix, pbm := GenerateMatrix(10) + + b.Run("Binary Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64Matrix(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64Matrix(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("Gob Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithGob(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64MatrixWithGob(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("JSON Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithJSON(matrix) + if err != nil { + b.Error(err) + } + _, err = decodeUint64MatrixWithJSON(data) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("PB Encoding/Decoding", func(b *testing.B) { + var pba pb.SortResult + for i := 0; i < b.N; i++ { + data, err := encodeUint64MatrixWithProtobuf(pbm) + if err != nil { + b.Error(err) + } + + err = decodeUint64MatrixWithProtobuf(data, &pba) + if err != nil { + b.Error(err) + } + } + }) + + b.Run("Unsafe Encoding/Decoding", func(b *testing.B) { + for i := 0; i < b.N; i++ { + data := encodeUint64MatrixUnsafe(matrix) + _, err := decodeUint64MatrixUnsafe(data) + if err != nil { + b.Error(err) + } + } + }) +} + +func dotProductT[T c.Float](a, b []T, floatBits int) { + var dotProduct T + if len(a) != len(b) { + return + } + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + } +} + +func dotProduct(a, b []float32) { + if len(a) != len(b) { + return + } + sum := int8(0) + for i := 0; i < len(a); i += 2 { + sum += 
*(*int8)(unsafe.Pointer(&a[i]))**(*int8)(unsafe.Pointer(&b[i])) + + *(*int8)(unsafe.Pointer(&a[i+1]))**(*int8)(unsafe.Pointer(&b[i+1])) + } +} + +func BenchmarkDotProduct(b *testing.B) { + num := 1500 + data := make([]byte, 64*num) + _, err := rand.Read(data) + if err != nil { + b.Skip() + } + + b.Run(fmt.Sprintf("vek:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + vek32.Dot(temp, temp) + } + }) + + b.Run(fmt.Sprintf("dotProduct:size=%d", len(data)), + func(b *testing.B) { + + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + dotProduct(temp, temp) + } + + }) + + b.Run(fmt.Sprintf("dotProductT:size=%d", len(data)), + func(b *testing.B) { + + temp := make([]float32, num) + BytesAsFloatArray[float32](data, &temp, 32) + for k := 0; k < b.N; k++ { + dotProductT[float32](temp, temp, 32) + } + }) +} + +func pointerFloatConversion[T c.Float](encoded []byte, retVal *[]T, floatBits int) { + floatBytes := floatBits / 8 + + // Ensure the byte slice length is a multiple of 8 (size of float32) + if len(encoded)%floatBytes != 0 { + fmt.Println("Invalid byte slice length") + return + } + + // Create a slice header + *retVal = *(*[]T)(unsafe.Pointer(&encoded)) +} + +func littleEndianBytesAsFloatArray[T c.Float](encoded []byte, retVal *[]T, floatBits int) { + // Unfortunately, this is not as simple as casting the result, + // and it is also not possible to directly use the + // golang "unsafe" library to directly do the conversion. + // The machine where this operation gets run might prefer + // BigEndian/LittleEndian, but the machine that sent it may have + // preferred the other, and there is no way to tell! + // + // The solution below, unfortunately, requires another memory + // allocation. 
+ // TODO Potential optimization: If we detect that current machine is + // using LittleEndian format, there might be a way of making this + // work with the golang "unsafe" library. + floatBytes := floatBits / 8 + + // Ensure the byte slice length is a multiple of 8 (size of float32) + if len(encoded)%floatBytes != 0 { + fmt.Println("Invalid byte slice length") + return + } + + *retVal = (*retVal)[:0] + resultLen := len(encoded) / floatBytes + if resultLen == 0 { + return + } + for i := 0; i < resultLen; i++ { + // Assume LittleEndian for encoding since this is + // the assumption elsewhere when reading from client. + // See dgraph-io/dgo/protos/api.pb.go + // See also dgraph-io/dgraph/types/conversion.go + // This also seems to be the preference from many examples + // I have found via Google search. It's unclear why this + // should be a preference. + if retVal == nil { + retVal = &[]T{} + } + *retVal = append(*retVal, BytesToFloat[T](encoded, floatBits)) + + encoded = encoded[(floatBytes):] + } +} + +func BenchmarkFloatConverstion(b *testing.B) { + num := 1500 + data := make([]byte, 64*num) + _, err := rand.Read(data) + if err != nil { + b.Skip() + } + + b.Run(fmt.Sprintf("Current:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + BytesAsFloatArray[float32](data, &temp, 64) + } + }) + + b.Run(fmt.Sprintf("pointerFloat:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + pointerFloatConversion[float32](data, &temp, 64) + } + }) + + b.Run(fmt.Sprintf("littleEndianFloat:size=%d", len(data)), + func(b *testing.B) { + temp := make([]float32, num) + for k := 0; k < b.N; k++ { + littleEndianBytesAsFloatArray[float32](data, &temp, 64) + } + }) +} diff --git a/worker/backup_ee.go b/worker/backup_ee.go index 31b0156faf6..dc77652dba5 100644 --- a/worker/backup_ee.go +++ b/worker/backup_ee.go @@ -37,6 +37,7 @@ import ( "github.com/dgraph-io/dgraph/ee/enc" 
"github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/protos/pb" + "github.com/dgraph-io/dgraph/tok/hnsw" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" ) @@ -194,6 +195,28 @@ func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { for pred := range group.Tablets { predMap[gid] = append(predMap[gid], pred) } + + } + + // see if any of the predicates are vector predicates and add the supporting + // vector predicates to the backup request. + vecPredMap := make(map[uint32][]string) + for gid, preds := range predMap { + schema, err := GetSchemaOverNetwork(ctx, &pb.SchemaRequest{Predicates: preds}) + if err != nil { + return err + } + + for _, pred := range schema { + if pred.Type == "float32vector" && len(pred.IndexSpecs) != 0 { + vecPredMap[gid] = append(predMap[gid], pred.Predicate+hnsw.VecEntry, pred.Predicate+hnsw.VecKeyword, + pred.Predicate+hnsw.VecDead) + } + } + } + + for gid, preds := range vecPredMap { + predMap[gid] = append(predMap[gid], preds...) } glog.Infof( diff --git a/worker/config.go b/worker/config.go index 3d4cfde14e4..8eb09aa7a8b 100644 --- a/worker/config.go +++ b/worker/config.go @@ -68,6 +68,12 @@ type Options struct { // Define different ChangeDataCapture configurations ChangeDataConf string + + // TypeFilterUidLimit decides how many elements would be searched directly + // vs searched via type index. If the number of elements are too low, then querying the + // index might be slower. This would allow people to set their limit according to + // their use case. + TypeFilterUidLimit int64 } // Config holds an instance of the server options.. diff --git a/worker/draft.go b/worker/draft.go index 23b92099a57..6ed5d97489c 100644 --- a/worker/draft.go +++ b/worker/draft.go @@ -549,12 +549,19 @@ func (n *node) applyMutations(ctx context.Context, proposal *pb.Proposal) (rerr errCh <- process(m.Edges[start:end]) }(start, end) } + // Earlier we were returning after even if one thread had an error. 
We should wait for + // all the transactions to finish. We call txn.Update() when this function exists. This could cause + // a deadlock with runMutation. + var errs error for i := 0; i < numGo; i++ { if err := <-errCh; err != nil { - return err + if errs == nil { + errs = errors.New("Got error while running mutation") + } + errs = errors.Wrapf(err, errs.Error()) } } - return nil + return errs } func (n *node) applyCommitted(proposal *pb.Proposal, key uint64) error { @@ -836,10 +843,13 @@ func (n *node) commitOrAbort(pkey uint64, delta *pb.OracleDelta) error { writer := posting.NewTxnWriter(pstore) toDisk := func(start, commit uint64) { txn := posting.Oracle().GetTxn(start) - if txn == nil { + if txn == nil || commit == 0 { return } - txn.Update() + // If the transaction has failed, we dont need to update it. + if commit != 0 { + txn.Update() + } // We start with 20 ms, so that we end up waiting 5 mins by the end. // If there is any transient issue, it should get fixed within that timeframe. err := x.ExponentialRetry(int(x.Config.MaxRetries), @@ -865,6 +875,7 @@ func (n *node) commitOrAbort(pkey uint64, delta *pb.OracleDelta) error { if err := writer.Flush(); err != nil { return errors.Wrapf(err, "while flushing to disk") } + if x.WorkerConfig.HardSync { if err := pstore.Sync(); err != nil { glog.Errorf("Error while calling Sync while commitOrAbort: %v", err) @@ -879,9 +890,8 @@ func (n *node) commitOrAbort(pkey uint64, delta *pb.OracleDelta) error { // Clear all the cached lists that were touched by this transaction. for _, status := range delta.Txns { txn := posting.Oracle().GetTxn(status.StartTs) - txn.RemoveCachedKeys() + txn.UpdateCachedKeys(status.CommitTs) } - posting.WaitForCache() // Now advance Oracle(), so we can service waiting reads. 
posting.Oracle().ProcessDelta(delta) @@ -1248,6 +1258,7 @@ func (n *node) Run() { } else { ostats.Record(ctx, x.RaftIsLeader.M(0)) } + timer.Record("updating soft state") } if leader { // Leader can send messages in parallel with writing to disk. @@ -1262,6 +1273,7 @@ func (n *node) Run() { // NOTE: We can do some optimizations here to drop messages. n.Send(&rd.Messages[i]) } + timer.Record("leader sending message") } if span != nil { span.Annotate(nil, "Handled ReadStates and SoftState.") @@ -1334,6 +1346,7 @@ func (n *node) Run() { if span != nil { span.Annotate(nil, "Applied or retrieved snapshot.") } + timer.Record("got snapshot") } // Store the hardstate and entries. Note that these are not CommittedEntries. diff --git a/worker/export.go b/worker/export.go index 35b0e2c3083..f094e7e754b 100644 --- a/worker/export.go +++ b/worker/export.go @@ -329,6 +329,9 @@ func toSchema(attr string, update *pb.SchemaUpdate) *bpb.KV { if update.GetUpsert() { x.Check2(buf.WriteString(" @upsert")) } + if update.GetUnique() { + x.Check2(buf.WriteString(" @unique")) + } x.Check2(buf.WriteString(" . \n")) //TODO(Naman): We don't need the version anymore. return &bpb.KV{ diff --git a/worker/mutation.go b/worker/mutation.go index abae1515dde..71b3db0cbc0 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -642,9 +642,9 @@ func Timestamps(ctx context.Context, num *pb.Num) (*pb.AssignedIds, error) { return c.Timestamps(ctx, num) } -func fillTxnContext(tctx *api.TxnContext, startTs uint64) { +func fillTxnContext(tctx *api.TxnContext, startTs uint64, isErrored bool) { if txn := posting.Oracle().GetTxn(startTs); txn != nil { - txn.FillContext(tctx, groups().groupId()) + txn.FillContext(tctx, groups().groupId(), isErrored) } // We do not need to fill linread mechanism anymore, because transaction // start ts is sufficient to wait for, to achieve lin reads. 
@@ -950,7 +950,8 @@ func (w *grpcWorker) proposeAndWait(ctx context.Context, txnCtx *api.TxnContext, node := groups().Node err := node.proposeAndWait(ctx, &pb.Proposal{Mutations: m}) - fillTxnContext(txnCtx, m.StartTs) + // When we are filling txn context, we don't need to update latest delta if the transaction has failed. + fillTxnContext(txnCtx, m.StartTs, err != nil) return err } diff --git a/worker/restore_map.go b/worker/restore_map.go index 4a962f18fac..96e0bc5882a 100644 --- a/worker/restore_map.go +++ b/worker/restore_map.go @@ -27,6 +27,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -44,6 +45,8 @@ import ( "github.com/dgraph-io/dgraph/ee/enc" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/protos/pb" + "github.com/dgraph-io/dgraph/schema" + "github.com/dgraph-io/dgraph/tok/hnsw" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" ) @@ -470,6 +473,7 @@ func (m *mapper) processReqCh(ctx context.Context) error { } return nil } + // We changed the format of predicate in 2103 and 2105. SchemaUpdate and TypeUpdate have // predicate stored within them, so they also need to be updated accordingly. switch in.version { @@ -488,6 +492,13 @@ func (m *mapper) processReqCh(ctx context.Context) error { default: // for manifest versions >= 2015, do nothing. } + + // If the predicate is a vector indexing predicate, skip further processing. + // currently we don't store vector supporting predicates in the schema. + if strings.HasSuffix(parsedKey.Attr, hnsw.VecEntry) || strings.HasSuffix(parsedKey.Attr, hnsw.VecKeyword) || + strings.HasSuffix(parsedKey.Attr, hnsw.VecDead) { + return nil + } // Reset the StreamId to prevent ordering issues while writing to stream writer. 
kv.StreamId = 0 // Schema and type keys are not stored in an intermediate format so their @@ -538,6 +549,23 @@ func (m *mapper) processReqCh(ctx context.Context) error { } } + // If the backup was taken on old version, we need to set unique to true for xid predicates. + if parsedKey.IsSchema() { + var update pb.SchemaUpdate + if err := update.Unmarshal(kv.Value); err != nil { + return err + } + + if strings.HasSuffix(update.Predicate, "dgraph.xid") && !update.Unique && schema.IsUniqueDgraphXid { + update.Unique = true + } + + kv.Value, err = update.Marshal() + if err != nil { + return err + } + } + if err := toBuffer(kv, version); err != nil { return err } diff --git a/worker/server_state.go b/worker/server_state.go index 208ff105f22..9cb256c7acd 100644 --- a/worker/server_state.go +++ b/worker/server_state.go @@ -48,11 +48,11 @@ const ( `client_key=; sasl-mechanism=PLAIN; tls=false;` LimitDefaults = `mutations=allow; query-edge=1000000; normalize-node=10000; ` + `mutations-nquad=1000000; disallow-drop=false; query-timeout=0ms; txn-abort-after=5m; ` + - ` max-retries=10;max-pending-queries=10000;shared-instance=false` + ` max-retries=10;max-pending-queries=10000;shared-instance=false;type-filter-uid-limit=10` ZeroLimitsDefaults = `uid-lease=0; refill-interval=30s; disable-admin-http=false;` GraphQLDefaults = `introspection=true; debug=false; extensions=true; poll-interval=1s; ` + `lambda-url=;` - CacheDefaults = `size-mb=1024; percentage=0,65,35;` + CacheDefaults = `size-mb=1024; percentage=0,80,20;` FeatureFlagsDefaults = `normalize-compatibility-mode=` ) diff --git a/worker/snapshot.go b/worker/snapshot.go index 311804c4ea5..676821a4051 100644 --- a/worker/snapshot.go +++ b/worker/snapshot.go @@ -114,6 +114,8 @@ func (n *node) populateSnapshot(snap pb.Snapshot, pl *conn.Pool) error { if err := deleteStalePreds(ctx, done, snap.ReadTs); err != nil { return err } + // Reset the cache after having received a snapshot. 
+ posting.ResetCache() glog.Infof("Snapshot writes DONE. Sending ACK") // Send an acknowledgement back to the leader. diff --git a/worker/task.go b/worker/task.go index 7761033e2e0..a8969d02ec0 100644 --- a/worker/task.go +++ b/worker/task.go @@ -312,7 +312,7 @@ type funcArgs struct { // The function tells us whether we want to fetch value posting lists or uid posting lists. func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn: + case aggregatorFn, passwordFn, similarToFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -325,7 +325,7 @@ func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) case uidInFn, compareScalarFn: // Operate on uid postings return false, nil - case notAFunction, similarToFn: + case notAFunction: return typ.IsScalar(), nil } return false, errors.Errorf("Unhandled case in fetchValuePostings for fn: %s", srcFn.fname) @@ -381,7 +381,7 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er int(numNeighbors), index.AcceptAll[float32]) } - if err != nil { + if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { return err } sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) @@ -1852,6 +1852,15 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { fc.tokens = append(fc.tokens, tokens...) } + checkUidEmpty := func(uids []uint64) bool { + for _, i := range uids { + if i == 0 { + return false + } + } + return true + } + // In case of non-indexed predicate, there won't be any tokens. We will fetch value // from data keys. // If number of index keys is more than no. 
of uids to filter, so its better to fetch values @@ -1865,6 +1874,10 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { case q.UidList != nil && len(fc.tokens) > len(q.UidList.Uids) && fc.fname != eq: fc.tokens = fc.tokens[:0] fc.n = len(q.UidList.Uids) + case q.UidList != nil && fc.fname == eq && strings.HasSuffix(attr, "dgraph.type") && + int64(len(q.UidList.Uids)) < Config.TypeFilterUidLimit && checkUidEmpty(q.UidList.Uids): + fc.tokens = fc.tokens[:0] + fc.n = len(q.UidList.Uids) default: fc.n = len(fc.tokens) } diff --git a/x/keys.go b/x/keys.go index 9160cc384bb..1b0986fd0be 100644 --- a/x/keys.go +++ b/x/keys.go @@ -303,7 +303,7 @@ type ParsedKey struct { func (p ParsedKey) String() string { if p.IsIndex() { - return fmt.Sprintf("UID: %v, Attr: %v, IsIndex: true, Term: %v", p.Uid, p.Attr, p.Count) + return fmt.Sprintf("UID: %v, Attr: %v, IsIndex: true, Term: %v", p.Uid, p.Attr, []byte(p.Term)) } else if p.IsCountOrCountRev() { return fmt.Sprintf("UID: %v, Attr: %v, IsCount/Ref: true, Count: %v", p.Uid, p.Attr, p.Count) } else {