Skip to content

Commit

Permalink
feat(request-payload-validation): Add test case reproducing bug report (
Browse files Browse the repository at this point in the history
#11)

Fixes #10 . 
* Adds a test case validating the bug report
* Removed graphql-go in favor of gqlparser, which is more active and
recent
* Add SchemaProvider which loads and exposes the Schema for validations
* Refactor code for gqlparser
* Update test cases

Decided to give awareness of the schema to go-graphql-armor, so it can
validate and interpret with its knowledge.

---------

Co-authored-by: ldebruijn <[email protected]>
  • Loading branch information
ldebruijn and ldebruijn authored Dec 22, 2023
1 parent f655893 commit f73dec5
Show file tree
Hide file tree
Showing 12 changed files with 462 additions and 244 deletions.
41 changes: 28 additions & 13 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"flag"
"fmt"
"github.com/ardanlabs/conf/v3"
"github.com/graphql-go/graphql"
"github.com/ldebruijn/go-graphql-armor/internal/app/config"
"github.com/ldebruijn/go-graphql-armor/internal/business/aliases"
"github.com/ldebruijn/go-graphql-armor/internal/business/block_field_suggestions"
Expand All @@ -16,8 +15,12 @@ import (
"github.com/ldebruijn/go-graphql-armor/internal/business/persisted_operations"
"github.com/ldebruijn/go-graphql-armor/internal/business/proxy"
"github.com/ldebruijn/go-graphql-armor/internal/business/readiness"
"github.com/ldebruijn/go-graphql-armor/internal/business/schema"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
log2 "log"
"log/slog"
"net/http"
Expand Down Expand Up @@ -104,9 +107,15 @@ func run(log *slog.Logger, cfg *config.Config, shutdown chan os.Signal) error {
return nil
}

schemaProvider, err := schema.NewSchema(cfg.Schema, log)
if err != nil {
log.Error("Error initializing schema", "err", err)
return nil
}

mux := http.NewServeMux()

mid := middleware(log, cfg, po)
mid := middleware(log, cfg, po, schemaProvider)

mux.Handle("/metrics", promhttp.Handler())
mux.Handle("/internal/healthz/readiness", readiness.NewReadinessHandler())
Expand Down Expand Up @@ -150,12 +159,12 @@ func run(log *slog.Logger, cfg *config.Config, shutdown chan os.Signal) error {
return nil
}

func middleware(log *slog.Logger, cfg *config.Config, po *persisted_operations.PersistedOperationsHandler) func(next http.Handler) http.Handler {
func middleware(log *slog.Logger, cfg *config.Config, po *persisted_operations.PersistedOperationsHandler, schema *schema.Provider) func(next http.Handler) http.Handler {
rec := middleware2.Recover(log)
httpInstrumentation := HttpInstrumentation()

_ = aliases.NewMaxAliasesRule(cfg.MaxAliases)
vr := ValidationRules()
aliases.NewMaxAliasesRule(cfg.MaxAliases)
vr := ValidationRules(schema)

fn := func(next http.Handler) http.Handler {
return rec(httpInstrumentation(po.Execute(vr(next))))
Expand All @@ -178,22 +187,28 @@ func HttpInstrumentation() func(next http.Handler) http.Handler {
}
}

func ValidationRules() func(next http.Handler) http.Handler {
func ValidationRules(schema *schema.Provider) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
payload, err := gql.ParseRequestPayload(r)
if err != nil {
next.ServeHTTP(w, r)
return
}
params := graphql.Params{
RequestString: payload.Query,
Context: r.Context(),
}
result := graphql.Do(params)

if result.HasErrors() {
_ = json.NewEncoder(w).Encode(result)
var query, _ = parser.ParseQuery(&ast.Source{
Name: payload.OperationName,
Input: payload.Query,
})

errs := validator.Validate(schema.Get(), query)

if errs != nil {
response := map[string]interface{}{
"data": nil,
"errors": errs,
}
_ = json.NewEncoder(w).Encode(response)
return
}

Expand Down
124 changes: 123 additions & 1 deletion cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func TestHttpServerIntegration(t *testing.T) {
mockResponse map[string]interface{}
mockStatusCode int
cfgOverrides func(cfg *config.Config) *config.Config
schema string
}
tests := []struct {
name string
Expand All @@ -41,6 +42,16 @@ func TestHttpServerIntegration(t *testing.T) {
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
product(id: ID!): Product
}
type Product {
id: ID!
name: String
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.PersistedOperations.Enabled = true
cfg.PersistedOperations.Store = "./"
Expand Down Expand Up @@ -88,6 +99,16 @@ func TestHttpServerIntegration(t *testing.T) {
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
product(id: ID!): Product
}
type Product {
id: ID!
name: String
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.PersistedOperations.Enabled = true
cfg.PersistedOperations.Store = "./"
Expand Down Expand Up @@ -134,6 +155,16 @@ func TestHttpServerIntegration(t *testing.T) {
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
product(id: ID!): Product
}
type Product {
id: ID!
name: String
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.PersistedOperations.Enabled = true
cfg.PersistedOperations.Store = "./"
Expand Down Expand Up @@ -184,7 +215,7 @@ func TestHttpServerIntegration(t *testing.T) {
request: func() *http.Request {
body := map[string]interface{}{
"query": `
query Foo {
query Foo($image: ImageInput!) {
a1: uploadImage(image: $image)
a2: uploadImage(image: $image)
a3: uploadImage(image: $image)
Expand All @@ -196,12 +227,26 @@ query Foo {
a9: uploadImage(image: $image)
a10: uploadImage(image: $image)
}`,
"variables": map[string]interface{}{
"image": map[string]interface{}{
"id": "1",
},
},
}

bts, _ := json.Marshal(body)
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
uploadImage(image: ImageInput!): String
}
input ImageInput {
id: ID!
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.MaxAliases.Enabled = true
cfg.MaxAliases.Max = 3
Expand Down Expand Up @@ -250,6 +295,16 @@ query Foo {
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
product(id: ID!): Product
}
type Product {
id: ID!
name: String
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.PersistedOperations.Enabled = true
cfg.PersistedOperations.Store = "./"
Expand Down Expand Up @@ -292,6 +347,61 @@ query Foo {
assert.True(t, errorsContainsMessage("Some unexpected error", actual))
},
},
{
name: "validates incoming request payload against schema",
args: args{
request: func() *http.Request {
body := map[string]interface{}{
"query": "query Foo($id: ID!) { product(id: $id) { id name } }",
}

bts, _ := json.Marshal(body)
r := httptest.NewRequest("POST", "/graphql", bytes.NewBuffer(bts))
return r
}(),
schema: `
extend type Query {
product(id: ID!): Product
}
type Product {
id: ID!
name: String
}
`,
cfgOverrides: func(cfg *config.Config) *config.Config {
cfg.PersistedOperations.Enabled = true
cfg.PersistedOperations.Store = "./"
cfg.PersistedOperations.FailUnknownOperations = false
return cfg
},
mockResponse: map[string]interface{}{
"data": map[string]interface{}{
"product": map[string]interface{}{
"id": "1",
"name": "name",
},
},
},
mockStatusCode: http.StatusOK,
},
want: func(t *testing.T, response *http.Response) {
expected := map[string]interface{}{
"data": map[string]interface{}{
"product": map[string]interface{}{
"id": "1",
"name": "name",
},
},
}
_, _ = json.Marshal(expected)
assert.Equal(t, http.StatusOK, response.StatusCode)
actual, err := io.ReadAll(response.Body)
assert.NoError(t, err)
_ = actual
assert.NotContains(t, string(actual), "\"errors\":")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -305,11 +415,23 @@ query Foo {

shutdown := make(chan os.Signal, 1)

// create temp file for storing schema
file, _ := os.CreateTemp("", "")
defer func() {
_ = os.Remove(file.Name())
}()

write, err := file.Write([]byte(tt.args.schema))
assert.NoError(t, err)
assert.NotEqual(t, 0, write)
_ = file.Close()

defaultConfig, _ := config.NewConfig("")
cfg := tt.args.cfgOverrides(defaultConfig)

// set target to mockserver
cfg.Target.Host = mockServer.URL
cfg.Schema.Path = file.Name()

go func() {
err := run(slog.Default(), cfg, shutdown)
Expand Down
40 changes: 40 additions & 0 deletions cmd/test_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package main

import (
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/v2"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
"testing"
)

func TestTest(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{Name: "graph/schema.graphqls", Input: `
extend type User {
id: ID!
}
extend type Product {
upc: String!
}
union _Entity = Product | User
extend type Query {
entity: _Entity
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{Name: "ff", Input: `{
entity {
... on User {
id
}
}
}`})
require.Nil(t, err)
require.Nil(t, validator.Validate(s, q))
}
3 changes: 3 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ target:
host: http://localhost:8081
timeout: 10s
keep_alive: 180s

schema:
path: ./schema.graphql

persisted_operations:
# Enable or disable the feature, enabled by default
Expand Down
17 changes: 11 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ module github.com/ldebruijn/go-graphql-armor

go 1.21.1

require (
cloud.google.com/go/storage v1.33.0
github.com/ardanlabs/conf/v3 v3.1.6
github.com/prometheus/client_golang v1.17.0
github.com/stretchr/testify v1.8.4
github.com/vektah/gqlparser/v2 v2.5.10
google.golang.org/api v0.142.0
)

require (
cloud.google.com/go v0.110.8 // indirect
cloud.google.com/go/compute v1.23.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.2 // indirect
cloud.google.com/go/storage v1.33.0 // indirect
github.com/ardanlabs/conf/v3 v3.1.6 // indirect
github.com/agnivade/levenshtein v1.1.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
Expand All @@ -19,14 +27,12 @@ require (
github.com/google/uuid v1.3.1 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.1 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/graphql-go/graphql v0.8.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.17.0 // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/stretchr/testify v1.8.4 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
Expand All @@ -35,7 +41,6 @@ require (
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/api v0.142.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230920204549-e6e6cdab5c13 // indirect
Expand Down
Loading

0 comments on commit f73dec5

Please sign in to comment.