@@ -5,13 +5,15 @@ import (
5
5
"encoding/json"
6
6
"errors"
7
7
"fmt"
8
+ "strings"
8
9
"sync"
9
10
"testing"
10
11
"time"
11
12
12
13
"github.com/graph-gophers/graphql-go"
13
14
"github.com/graph-gophers/graphql-go/directives"
14
15
gqlerrors "github.com/graph-gophers/graphql-go/errors"
16
+ "github.com/graph-gophers/graphql-go/example/social"
15
17
"github.com/graph-gophers/graphql-go/example/starwars"
16
18
"github.com/graph-gophers/graphql-go/gqltesting"
17
19
"github.com/graph-gophers/graphql-go/introspection"
@@ -477,6 +479,205 @@ func TestCustomDirective(t *testing.T) {
477
479
})
478
480
}
479
481
482
+ func TestCustomValidatingDirective (t * testing.T ) {
483
+ t .Parallel ()
484
+
485
+ gqltesting .RunTests (t , []* gqltesting.Test {
486
+ {
487
+ Schema : graphql .MustParseSchema (`
488
+ directive @hasRole(role: String!) on FIELD_DEFINITION
489
+
490
+ schema {
491
+ query: Query
492
+ }
493
+
494
+ type Query {
495
+ hello: String! @hasRole(role: "ADMIN")
496
+ }` ,
497
+ & helloResolver {},
498
+ graphql .Directives (& HasRoleDirective {}),
499
+ ),
500
+ Context : context .WithValue (context .Background (), RoleKey , "USER" ),
501
+ Query : `
502
+ {
503
+ hello
504
+ }
505
+ ` ,
506
+ ExpectedResult : "null" ,
507
+ ExpectedErrors : []* gqlerrors.QueryError {
508
+ {Message : `access denied, role "ADMIN" required` , Locations : []gqlerrors.Location {{Line : 9 , Column : 6 }}, Path : []interface {}{"hello" }},
509
+ },
510
+ },
511
+ {
512
+ Schema : graphql .MustParseSchema (`
513
+ directive @hasRole(role: String!) on FIELD_DEFINITION
514
+
515
+ schema {
516
+ query: Query
517
+ }
518
+
519
+ type Query {
520
+ hello: String! @hasRole(role: "ADMIN")
521
+ }` ,
522
+ & helloResolver {},
523
+ graphql .Directives (& HasRoleDirective {}),
524
+ ),
525
+ Context : context .WithValue (context .Background (), RoleKey , "ADMIN" ),
526
+ Query : `
527
+ {
528
+ hello
529
+ }
530
+ ` ,
531
+ ExpectedResult : `
532
+ {
533
+ "hello": "Hello world!"
534
+ }
535
+ ` ,
536
+ },
537
+ {
538
+ Schema : graphql .MustParseSchema (
539
+ `directive @hasRole(role: String!) on FIELD_DEFINITION
540
+
541
+ ` + strings .ReplaceAll (
542
+ social .Schema ,
543
+ "role: Role!" ,
544
+ `role: Role! @hasRole(role: "ADMIN")` ,
545
+ ),
546
+ & social.Resolver {},
547
+ graphql .Directives (& HasRoleDirective {}),
548
+ graphql .UseFieldResolvers (),
549
+ ),
550
+ Context : context .WithValue (context .Background (), RoleKey , "ADMIN" ),
551
+ Query : `
552
+ query {
553
+ user(id: "0x01") {
554
+ role
555
+ ... on User {
556
+ email
557
+ }
558
+ ... on Person {
559
+ name
560
+ }
561
+ }
562
+ }
563
+ ` ,
564
+ ExpectedResult : `
565
+ {
566
+ "user": {
567
+ "role": "ADMIN",
568
+ "email": "Albus@hogwarts.com",
569
+ "name": "Albus Dumbledore"
570
+ }
571
+ }
572
+ ` ,
573
+ },
574
+ {
575
+ Schema : graphql .MustParseSchema (
576
+ `directive @hasRole(role: String!) on FIELD_DEFINITION
577
+
578
+ ` + strings .ReplaceAll (
579
+ starwars .Schema ,
580
+ "hero(episode: Episode = NEWHOPE): Character" ,
581
+ `hero(episode: Episode = NEWHOPE): Character @hasRole(role: "REBELLION")` ,
582
+ ),
583
+ & starwars.Resolver {},
584
+ graphql .Directives (& HasRoleDirective {}),
585
+ ),
586
+ Context : context .WithValue (context .Background (), RoleKey , "EMPIRE" ),
587
+ Query : `
588
+ query HeroesOfTheRebellion($episode: Episode!) {
589
+ hero(episode: $episode) {
590
+ id name
591
+ ... on Human { starships { id name } }
592
+ ... on Droid { primaryFunction }
593
+ }
594
+ }
595
+ ` ,
596
+ Variables : map [string ]interface {}{"episode" : "NEWHOPE" },
597
+ ExpectedResult : "null" ,
598
+ ExpectedErrors : []* gqlerrors.QueryError {
599
+ {Message : `access denied, role "REBELLION" required` , Locations : []gqlerrors.Location {{Line : 10 , Column : 3 }}, Path : []interface {}{"hero" }},
600
+ },
601
+ },
602
+ {
603
+ Schema : graphql .MustParseSchema (
604
+ `directive @hasRole(role: String!) on FIELD_DEFINITION
605
+
606
+ ` + strings .ReplaceAll (
607
+ starwars .Schema ,
608
+ "starships: [Starship]" ,
609
+ `starships: [Starship] @hasRole(role: "REBELLION")` ,
610
+ ),
611
+ & starwars.Resolver {},
612
+ graphql .Directives (& HasRoleDirective {}),
613
+ ),
614
+ Context : context .WithValue (context .Background (), RoleKey , "EMPIRE" ),
615
+ Query : `
616
+ query HeroesOfTheRebellion($episode: Episode!) {
617
+ hero(episode: $episode) {
618
+ id name
619
+ ... on Human { starships { id name } }
620
+ ... on Droid { primaryFunction }
621
+ }
622
+ }
623
+ ` ,
624
+ Variables : map [string ]interface {}{"episode" : "NEWHOPE" },
625
+ ExpectedResult : "null" ,
626
+ ExpectedErrors : []* gqlerrors.QueryError {
627
+ {Message : `access denied, role "REBELLION" required` , Locations : []gqlerrors.Location {{Line : 68 , Column : 3 }}, Path : []interface {}{"hero" , "starships" }},
628
+ },
629
+ },
630
+ {
631
+ Schema : graphql .MustParseSchema (
632
+ `directive @restrictImperialUnits on FIELD_DEFINITION
633
+
634
+ ` + strings .ReplaceAll (
635
+ starwars .Schema ,
636
+ "height(unit: LengthUnit = METER): Float!" ,
637
+ `height(unit: LengthUnit = METER): Float! @restrictImperialUnits` ,
638
+ ),
639
+ & starwars.Resolver {},
640
+ graphql .Directives (& restrictImperialUnitsDirective {}),
641
+ ),
642
+ Context : context .WithValue (context .Background (), RoleKey , "REBELLION" ),
643
+ Query : `
644
+ query HeroesOfTheRebellion($episode: Episode!) {
645
+ hero(episode: $episode) {
646
+ id name
647
+ ... on Human { height(unit: FOOT) }
648
+ }
649
+ }
650
+ ` ,
651
+ Variables : map [string ]interface {}{"episode" : "NEWHOPE" },
652
+ ExpectedResult : "null" ,
653
+ ExpectedErrors : []* gqlerrors.QueryError {
654
+ {Message : `rebels cannot request imperial units` , Locations : []gqlerrors.Location {{Line : 58 , Column : 3 }}, Path : []interface {}{"hero" , "height" }},
655
+ },
656
+ },
657
+ })
658
+ }
659
+
660
+ type restrictImperialUnitsDirective struct {}
661
+
662
+ func (d * restrictImperialUnitsDirective ) ImplementsDirective () string {
663
+ return "restrictImperialUnits"
664
+ }
665
+
666
+ func (d * restrictImperialUnitsDirective ) Validate (ctx context.Context , args interface {}) error {
667
+ if ctx .Value (RoleKey ) == "EMPIRE" {
668
+ return nil
669
+ }
670
+
671
+ v , ok := args .(struct {
672
+ Unit string
673
+ })
674
+ if ok && v .Unit == "FOOT" {
675
+ return fmt .Errorf ("rebels cannot request imperial units" )
676
+ }
677
+
678
+ return nil
679
+ }
680
+
480
681
func TestCustomDirectiveStructFieldResolver (t * testing.T ) {
481
682
t .Parallel ()
482
683
0 commit comments