Skip to content

Commit d126bba

Browse files
authored
Add Validating Directives (#583)
Adding support for directives which are evaluated prior to executing any resolvers. This allows validation to be performed on the request and prevent it from executing any significant work by rejecting the request early. The most obvious case for this is authorization: based on the requested fields, we can tell whether the request is valid given the current user, and reject the entire request. If that were applied at resolution time, the request would have partially resolved, only to return errors for the specific fields which are not authorized.
1 parent 801181b commit d126bba

File tree

9 files changed

+434
-48
lines changed

9 files changed

+434
-48
lines changed

directives/visitor.go

+5
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,8 @@ type Resolver interface {
2323
type ResolverInterceptor interface {
2424
Resolve(ctx context.Context, args interface{}, next Resolver) (output interface{}, err error)
2525
}
26+
27+
// Validator directive which executes before anything is resolved, allowing the request to be rejected.
28+
type Validator interface {
29+
Validate(ctx context.Context, args interface{}) error
30+
}

example/directives/authorization/README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,16 @@ $ curl 'http://localhost:8080/query' \
8080
return "hasRole"
8181
}
8282

83-
func (h *HasRoleDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (output interface{}, err error) {
83+
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
8484
u, ok := user.FromContext(ctx)
8585
if !ok {
86-
return nil, fmt.Errorf("user not provided in context")
86+
return fmt.Errorf("user not provided in context")
8787
}
8888
role := strings.ToLower(h.Role)
8989
if !u.HasRole(role) {
90-
return nil, fmt.Errorf("access denied, %q role required", role)
90+
return fmt.Errorf("access denied, %q role required", role)
9191
}
92-
return next.Resolve(ctx, args)
92+
return nil
9393
}
9494
```
9595

example/directives/authorization/authorization.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"strings"
88

9-
"github.com/graph-gophers/graphql-go/directives"
109
"github.com/graph-gophers/graphql-go/example/directives/authorization/user"
1110
)
1211

@@ -36,17 +35,17 @@ func (h *HasRoleDirective) ImplementsDirective() string {
3635
return "hasRole"
3736
}
3837

39-
func (h *HasRoleDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (output interface{}, err error) {
38+
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
4039
u, ok := user.FromContext(ctx)
4140
if !ok {
42-
return nil, fmt.Errorf("user not provided in cotext")
41+
return fmt.Errorf("user not provided in cotext")
4342
}
4443
role := strings.ToLower(h.Role)
4544
if !u.HasRole(role) {
46-
return nil, fmt.Errorf("access denied, %q role required", role)
45+
return fmt.Errorf("access denied, %q role required", role)
4746
}
4847

49-
return next.Resolve(ctx, args)
48+
return nil
5049
}
5150

5251
type Resolver struct{}

example_directives_test.go

+35-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"os"
8+
"strings"
89

910
"github.com/graph-gophers/graphql-go"
1011
"github.com/graph-gophers/graphql-go/directives"
@@ -22,11 +23,31 @@ func (h *HasRoleDirective) ImplementsDirective() string {
2223
return "hasRole"
2324
}
2425

25-
func (h *HasRoleDirective) Resolve(ctx context.Context, in interface{}, next directives.Resolver) (interface{}, error) {
26+
func (h *HasRoleDirective) Validate(ctx context.Context, _ interface{}) error {
2627
if ctx.Value(RoleKey) != h.Role {
27-
return nil, fmt.Errorf("access deinied, role %q required", h.Role)
28+
return fmt.Errorf("access denied, role %q required", h.Role)
2829
}
29-
return next.Resolve(ctx, in)
30+
return nil
31+
}
32+
33+
type UpperDirective struct{}
34+
35+
func (d *UpperDirective) ImplementsDirective() string {
36+
return "upper"
37+
}
38+
39+
func (d *UpperDirective) Resolve(ctx context.Context, args interface{}, next directives.Resolver) (interface{}, error) {
40+
out, err := next.Resolve(ctx, args)
41+
if err != nil {
42+
return out, err
43+
}
44+
45+
s, ok := out.(string)
46+
if !ok {
47+
return out, nil
48+
}
49+
50+
return strings.ToUpper(s), nil
3051
}
3152

3253
type authResolver struct{}
@@ -43,13 +64,14 @@ func ExampleDirectives() {
4364
}
4465
4566
directive @hasRole(role: String!) on FIELD_DEFINITION
67+
directive @upper on FIELD_DEFINITION
4668
4769
type Query {
48-
greet(name: String!): String! @hasRole(role: "admin")
70+
greet(name: String!): String! @hasRole(role: "admin") @upper
4971
}
5072
`
5173
opts := []graphql.SchemaOpt{
52-
graphql.Directives(&HasRoleDirective{}),
74+
graphql.Directives(&HasRoleDirective{}, &UpperDirective{}),
5375
// other options go here
5476
}
5577
schema := graphql.MustParseSchema(s, &authResolver{}, opts...)
@@ -86,7 +108,13 @@ func ExampleDirectives() {
86108
// {
87109
// "errors": [
88110
// {
89-
// "message": "access deinied, role \"admin\" required",
111+
// "message": "access denied, role \"admin\" required",
112+
// "locations": [
113+
// {
114+
// "line": 10,
115+
// "column": 4
116+
// }
117+
// ],
90118
// "path": [
91119
// "greet"
92120
// ]
@@ -97,7 +125,7 @@ func ExampleDirectives() {
97125
// Admin user result:
98126
// {
99127
// "data": {
100-
// "greet": "Hello, GraphQL!"
128+
// "greet": "HELLO, GRAPHQL!"
101129
// }
102130
// }
103131
}

graphql_test.go

+201
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"strings"
89
"sync"
910
"testing"
1011
"time"
1112

1213
"github.com/graph-gophers/graphql-go"
1314
"github.com/graph-gophers/graphql-go/directives"
1415
gqlerrors "github.com/graph-gophers/graphql-go/errors"
16+
"github.com/graph-gophers/graphql-go/example/social"
1517
"github.com/graph-gophers/graphql-go/example/starwars"
1618
"github.com/graph-gophers/graphql-go/gqltesting"
1719
"github.com/graph-gophers/graphql-go/introspection"
@@ -477,6 +479,205 @@ func TestCustomDirective(t *testing.T) {
477479
})
478480
}
479481

482+
func TestCustomValidatingDirective(t *testing.T) {
483+
t.Parallel()
484+
485+
gqltesting.RunTests(t, []*gqltesting.Test{
486+
{
487+
Schema: graphql.MustParseSchema(`
488+
directive @hasRole(role: String!) on FIELD_DEFINITION
489+
490+
schema {
491+
query: Query
492+
}
493+
494+
type Query {
495+
hello: String! @hasRole(role: "ADMIN")
496+
}`,
497+
&helloResolver{},
498+
graphql.Directives(&HasRoleDirective{}),
499+
),
500+
Context: context.WithValue(context.Background(), RoleKey, "USER"),
501+
Query: `
502+
{
503+
hello
504+
}
505+
`,
506+
ExpectedResult: "null",
507+
ExpectedErrors: []*gqlerrors.QueryError{
508+
{Message: `access denied, role "ADMIN" required`, Locations: []gqlerrors.Location{{Line: 9, Column: 6}}, Path: []interface{}{"hello"}},
509+
},
510+
},
511+
{
512+
Schema: graphql.MustParseSchema(`
513+
directive @hasRole(role: String!) on FIELD_DEFINITION
514+
515+
schema {
516+
query: Query
517+
}
518+
519+
type Query {
520+
hello: String! @hasRole(role: "ADMIN")
521+
}`,
522+
&helloResolver{},
523+
graphql.Directives(&HasRoleDirective{}),
524+
),
525+
Context: context.WithValue(context.Background(), RoleKey, "ADMIN"),
526+
Query: `
527+
{
528+
hello
529+
}
530+
`,
531+
ExpectedResult: `
532+
{
533+
"hello": "Hello world!"
534+
}
535+
`,
536+
},
537+
{
538+
Schema: graphql.MustParseSchema(
539+
`directive @hasRole(role: String!) on FIELD_DEFINITION
540+
541+
`+strings.ReplaceAll(
542+
social.Schema,
543+
"role: Role!",
544+
`role: Role! @hasRole(role: "ADMIN")`,
545+
),
546+
&social.Resolver{},
547+
graphql.Directives(&HasRoleDirective{}),
548+
graphql.UseFieldResolvers(),
549+
),
550+
Context: context.WithValue(context.Background(), RoleKey, "ADMIN"),
551+
Query: `
552+
query {
553+
user(id: "0x01") {
554+
role
555+
... on User {
556+
email
557+
}
558+
... on Person {
559+
name
560+
}
561+
}
562+
}
563+
`,
564+
ExpectedResult: `
565+
{
566+
"user": {
567+
"role": "ADMIN",
568+
"email": "Albus@hogwarts.com",
569+
"name": "Albus Dumbledore"
570+
}
571+
}
572+
`,
573+
},
574+
{
575+
Schema: graphql.MustParseSchema(
576+
`directive @hasRole(role: String!) on FIELD_DEFINITION
577+
578+
`+strings.ReplaceAll(
579+
starwars.Schema,
580+
"hero(episode: Episode = NEWHOPE): Character",
581+
`hero(episode: Episode = NEWHOPE): Character @hasRole(role: "REBELLION")`,
582+
),
583+
&starwars.Resolver{},
584+
graphql.Directives(&HasRoleDirective{}),
585+
),
586+
Context: context.WithValue(context.Background(), RoleKey, "EMPIRE"),
587+
Query: `
588+
query HeroesOfTheRebellion($episode: Episode!) {
589+
hero(episode: $episode) {
590+
id name
591+
... on Human { starships { id name } }
592+
... on Droid { primaryFunction }
593+
}
594+
}
595+
`,
596+
Variables: map[string]interface{}{"episode": "NEWHOPE"},
597+
ExpectedResult: "null",
598+
ExpectedErrors: []*gqlerrors.QueryError{
599+
{Message: `access denied, role "REBELLION" required`, Locations: []gqlerrors.Location{{Line: 10, Column: 3}}, Path: []interface{}{"hero"}},
600+
},
601+
},
602+
{
603+
Schema: graphql.MustParseSchema(
604+
`directive @hasRole(role: String!) on FIELD_DEFINITION
605+
606+
`+strings.ReplaceAll(
607+
starwars.Schema,
608+
"starships: [Starship]",
609+
`starships: [Starship] @hasRole(role: "REBELLION")`,
610+
),
611+
&starwars.Resolver{},
612+
graphql.Directives(&HasRoleDirective{}),
613+
),
614+
Context: context.WithValue(context.Background(), RoleKey, "EMPIRE"),
615+
Query: `
616+
query HeroesOfTheRebellion($episode: Episode!) {
617+
hero(episode: $episode) {
618+
id name
619+
... on Human { starships { id name } }
620+
... on Droid { primaryFunction }
621+
}
622+
}
623+
`,
624+
Variables: map[string]interface{}{"episode": "NEWHOPE"},
625+
ExpectedResult: "null",
626+
ExpectedErrors: []*gqlerrors.QueryError{
627+
{Message: `access denied, role "REBELLION" required`, Locations: []gqlerrors.Location{{Line: 68, Column: 3}}, Path: []interface{}{"hero", "starships"}},
628+
},
629+
},
630+
{
631+
Schema: graphql.MustParseSchema(
632+
`directive @restrictImperialUnits on FIELD_DEFINITION
633+
634+
`+strings.ReplaceAll(
635+
starwars.Schema,
636+
"height(unit: LengthUnit = METER): Float!",
637+
`height(unit: LengthUnit = METER): Float! @restrictImperialUnits`,
638+
),
639+
&starwars.Resolver{},
640+
graphql.Directives(&restrictImperialUnitsDirective{}),
641+
),
642+
Context: context.WithValue(context.Background(), RoleKey, "REBELLION"),
643+
Query: `
644+
query HeroesOfTheRebellion($episode: Episode!) {
645+
hero(episode: $episode) {
646+
id name
647+
... on Human { height(unit: FOOT) }
648+
}
649+
}
650+
`,
651+
Variables: map[string]interface{}{"episode": "NEWHOPE"},
652+
ExpectedResult: "null",
653+
ExpectedErrors: []*gqlerrors.QueryError{
654+
{Message: `rebels cannot request imperial units`, Locations: []gqlerrors.Location{{Line: 58, Column: 3}}, Path: []interface{}{"hero", "height"}},
655+
},
656+
},
657+
})
658+
}
659+
660+
type restrictImperialUnitsDirective struct{}
661+
662+
func (d *restrictImperialUnitsDirective) ImplementsDirective() string {
663+
return "restrictImperialUnits"
664+
}
665+
666+
func (d *restrictImperialUnitsDirective) Validate(ctx context.Context, args interface{}) error {
667+
if ctx.Value(RoleKey) == "EMPIRE" {
668+
return nil
669+
}
670+
671+
v, ok := args.(struct {
672+
Unit string
673+
})
674+
if ok && v.Unit == "FOOT" {
675+
return fmt.Errorf("rebels cannot request imperial units")
676+
}
677+
678+
return nil
679+
}
680+
480681
func TestCustomDirectiveStructFieldResolver(t *testing.T) {
481682
t.Parallel()
482683

internal/exec/exec.go

+12
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *ast.Ope
5454
default:
5555
panic("unknown query operation")
5656
}
57+
58+
if errs := validateSelections(ctx, sels, nil, s); errs != nil {
59+
r.Errs = errs
60+
out.Write([]byte("null"))
61+
return
62+
}
63+
5764
r.execSelections(ctx, sels, nil, s, resolver, &out, op.Type == query.Mutation)
5865
}()
5966

@@ -64,6 +71,11 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *ast.Ope
6471
return out.Bytes(), r.Errs
6572
}
6673

74+
type fieldToValidate struct {
75+
field *selected.SchemaField
76+
sels []selected.Selection
77+
}
78+
6779
type fieldToExec struct {
6880
field *selected.SchemaField
6981
sels []selected.Selection

0 commit comments

Comments
 (0)