From a119b1df93d9f301539ba9d04047c994e9a1ced9 Mon Sep 17 00:00:00 2001 From: Ben Weintraub Date: Tue, 4 Jun 2024 10:43:02 -0700 Subject: [PATCH] improve graph reachability filtering and add tests --- .../testdata/cases/viz_from/expected.txt | 4 - pkg/graph/field_filter.go | 91 ------ pkg/graph/field_filter_test.go | 185 ------------ pkg/graph/graph.go | 55 ++-- pkg/graph/graph_test.go | 273 ++++++++++++++++++ pkg/graph/reference_set.go | 62 ++++ pkg/model/fields.go | 9 + 7 files changed, 379 insertions(+), 300 deletions(-) delete mode 100644 pkg/graph/field_filter.go delete mode 100644 pkg/graph/field_filter_test.go create mode 100644 pkg/graph/graph_test.go create mode 100644 pkg/graph/reference_set.go diff --git a/pkg/commands/testdata/cases/viz_from/expected.txt b/pkg/commands/testdata/cases/viz_from/expected.txt index cdadb8d..12d3571 100644 --- a/pkg/commands/testdata/cases/viz_from/expected.txt +++ b/pkg/commands/testdata/cases/viz_from/expected.txt @@ -32,10 +32,6 @@ digraph { object Query fruitFruit nameString - edibleEdible - nameString - edibles[Edible!]! - filterFilter >] n_Apple:p_variety -> n_AppleVariety:main diff --git a/pkg/graph/field_filter.go b/pkg/graph/field_filter.go deleted file mode 100644 index a64e4ac..0000000 --- a/pkg/graph/field_filter.go +++ /dev/null @@ -1,91 +0,0 @@ -package graph - -import ( - "github.com/benweint/gquil/pkg/model" -) - -func makeFilter(root *model.NameReference) *fieldFilter { - result := &fieldFilter{ - onType: root.GetTargetType().Name, - includeFields: map[string]bool{}, - } - - switch root.Kind { - case model.TypeNameReference: - result.includeAll = true - case model.FieldNameReference, model.InputFieldNameReference: - result.includeFields[root.GetFieldName()] = true - } - - return result -} - -func makeFieldFilters(roots []*model.NameReference) map[string]*fieldFilter { - filtersByType := map[string]*fieldFilter{} - for _, root := range roots { - filter := makeFilter(root) - if existingFilter, ok := filtersByType[filter.onType]; ok { - existingFilter.merge(filter) - } else { - filtersByType[filter.onType] = filter - } - } - return filtersByType -} - -func applyFieldFilters(defs model.DefinitionList, roots []*model.NameReference) model.DefinitionList { - var result model.DefinitionList - filters := makeFieldFilters(roots) - - for _, typeDef := range defs { - filter, ok := filters[typeDef.Name] - if !ok { - continue - } - - filteredDef := &model.Definition{ - Kind: typeDef.Kind, - Name: typeDef.Name, - Description: typeDef.Description, - Interfaces: typeDef.Interfaces, - PossibleTypes: typeDef.PossibleTypes, - EnumValues: typeDef.EnumValues, - Fields: applyFilter(filter, typeDef.Fields), - } - - result = append(result, filteredDef) - } - - return result -} - -type fieldFilter struct { - onType string - includeAll bool - includeFields map[string]bool -} - -func (f *fieldFilter) merge(other *fieldFilter) { - if other.includeAll { - f.includeAll = true - return - } - - for field := range other.includeFields { - f.includeFields[field] = true - } -} - -func applyFilter(f *fieldFilter, list model.FieldDefinitionList) model.FieldDefinitionList { - if f.includeAll { - return list - } - - var result model.FieldDefinitionList - for _, field := range list { - if f.includeFields[field.Name] { - result = append(result, field) - } - } - return result -} diff --git a/pkg/graph/field_filter_test.go b/pkg/graph/field_filter_test.go deleted file mode 100644 index 52712cd..0000000 --- a/pkg/graph/field_filter_test.go +++ /dev/null @@ -1,185 +0,0 @@ -package graph - -import ( - "fmt" - "sort" - "testing" - - "github.com/benweint/gquil/pkg/model" - "github.com/stretchr/testify/assert" - "github.com/vektah/gqlparser/v2/ast" -) - -func TestApplyFieldFilters(t *testing.T) { - exampleDefs := model.DefinitionList{ - { - Name: "Alpha", - Kind: ast.Object, - Fields: model.FieldDefinitionList{ - { - Name: "x", - }, - { - Name: "y", - }, - { - Name: "z", - }, - }, - }, - { - Name: "Beta", - Kind: ast.Object, - Fields: model.FieldDefinitionList{ - { - Name: "z", - }, - { - Name: "a", - }, - }, - }, - { - Name: "Theta", - Kind: ast.Object, - Fields: model.FieldDefinitionList{ - { - Name: "x", - }, - { - Name: "y", - }, - }, - }, - { - Name: "InputA", - Kind: ast.InputObject, - Fields: model.FieldDefinitionList{ - { - Name: "a", - }, - { - Name: "b", - }, - }, - }, - } - - for _, tc := range []struct { - name string - defs model.DefinitionList - roots []string - expectedFields []string - }{ - { - name: "single type", - defs: exampleDefs, - roots: []string{"Alpha"}, - expectedFields: []string{ - "Alpha.x", - "Alpha.y", - "Alpha.z", - }, - }, - { - name: "multiple types", - defs: exampleDefs, - roots: []string{"Alpha", "Beta"}, - expectedFields: []string{ - "Alpha.x", - "Alpha.y", - "Alpha.z", - "Beta.a", - "Beta.z", - }, - }, - { - name: "single field on multiple types", - defs: exampleDefs, - roots: []string{"Alpha.x", "Beta.z"}, - expectedFields: []string{ - "Alpha.x", - "Beta.z", - }, - }, - { - name: "multiple fields on the same type", - defs: exampleDefs, - roots: []string{ - "Alpha.x", - "Alpha.z", - }, - expectedFields: []string{ - "Alpha.x", - "Alpha.z", - }, - }, - { - name: "field on type plus whole type", - defs: exampleDefs, - roots: []string{ - "Alpha.x", - "Alpha", - }, - expectedFields: []string{ - "Alpha.x", - "Alpha.y", - "Alpha.z", - }, - }, - { - name: "input field", - defs: exampleDefs, - roots: []string{ - "InputA.a", - }, - expectedFields: []string{ - "InputA.a", - }, - }, - { - name: "input type", - defs: exampleDefs, - roots: []string{ - "InputA", - }, - expectedFields: []string{ - "InputA.a", - "InputA.b", - }, - }, - { - name: "type and then field", - defs: exampleDefs, - roots: []string{ - "Alpha", - "Alpha.x", - }, - expectedFields: []string{ - "Alpha.x", - "Alpha.y", - "Alpha.z", - }, - }, - } { - t.Run(tc.name, func(t *testing.T) { - s := &model.Schema{ - Types: exampleDefs.ToMap(), - } - roots, err := s.ResolveNames(tc.roots) - assert.NoError(t, err) - filtered := applyFieldFilters(exampleDefs, roots) - - var actualFieldNames []string - for _, def := range filtered { - for _, field := range def.Fields { - fieldName := fmt.Sprintf("%s.%s", def.Name, field.Name) - actualFieldNames = append(actualFieldNames, fieldName) - } - } - sort.Strings(actualFieldNames) - - assert.Equal(t, tc.expectedFields, actualFieldNames) - }) - } -} diff --git a/pkg/graph/graph.go b/pkg/graph/graph.go index 32e6152..026e3a6 100644 --- a/pkg/graph/graph.go +++ b/pkg/graph/graph.go @@ -134,39 +134,49 @@ func (g *Graph) GetDefinitions() model.DefinitionMap { } func (g *Graph) ReachableFrom(roots []*model.NameReference, maxDepth int) *Graph { - var defs model.DefinitionList - for _, node := range g.nodes { - defs = append(defs, node) - } - rootDefs := applyFieldFilters(defs, roots) - - seen := model.DefinitionMap{} + seen := referenceSet{} var traverse func(n *model.Definition, depth int) + + traverseField := func(typeName string, f *model.FieldDefinition, depth int) { + key := fieldRef(typeName, f.Name) + if seen[key] { + return + } + seen[key] = true + + if maxDepth > 0 && depth > maxDepth { + return + } + + for _, arg := range f.Arguments { + argType := arg.Type.Unwrap() + traverse(g.nodes[argType.Name], depth+1) + } + + underlyingType := f.Type.Unwrap() + traverse(g.nodes[underlyingType.Name], depth+1) + } + traverse = func(n *model.Definition, depth int) { if maxDepth > 0 && depth > maxDepth { return } - if _, ok := seen[n.Name]; ok { + key := typeRef(n.Name) + if _, ok := seen[key]; ok { return } if n.Kind == ast.Scalar { return } - seen[n.Name] = n + seen[key] = true kind := normalizeKind(n.Kind, g.interfacesAsUnions) switch kind { case ast.Object, ast.InputObject: for _, field := range n.Fields { - for _, arg := range field.Arguments { - argType := arg.Type.Unwrap() - traverse(g.nodes[argType.Name], depth+1) - } - - underlyingType := field.Type.Unwrap() - traverse(g.nodes[underlyingType.Name], depth+1) + traverseField(n.Name, field, depth) } case ast.Union: for _, pt := range n.PossibleTypes { @@ -175,14 +185,19 @@ func (g *Graph) ReachableFrom(roots []*model.NameReference, maxDepth int) *Graph } } - for _, root := range rootDefs { - traverse(root, 1) + for _, root := range roots { + targetType := root.GetTargetType() + if fieldName := root.GetFieldName(); fieldName != "" { + traverseField(targetType.Name, targetType.Fields.Named(fieldName), 1) + } else { + traverse(targetType, 1) + } } filteredNodes := model.DefinitionMap{} for name, node := range g.nodes { - if _, ok := seen[name]; ok { - filteredNodes[name] = node + if seen.includesType(name) { + filteredNodes[name] = seen.filterFields(node) } } diff --git a/pkg/graph/graph_test.go b/pkg/graph/graph_test.go new file mode 100644 index 0000000..6d2f1fe --- /dev/null +++ b/pkg/graph/graph_test.go @@ -0,0 +1,273 @@ +package graph + +import ( + "fmt" + "sort" + "strings" + "testing" + + "github.com/benweint/gquil/pkg/model" + "github.com/stretchr/testify/assert" + "github.com/vektah/gqlparser/v2" + "github.com/vektah/gqlparser/v2/ast" +) + +type edgeSpec struct { + srcType string + dstType string + fieldName string + argName string +} + +func (es edgeSpec) String() string { + return fmt.Sprintf("%s -> %s [%s].%s", es.srcType, es.dstType, es.fieldName, es.argName) +} + +func TestReachableFrom(t *testing.T) { + for _, tc := range []struct { + name string + schema string + roots []string + expectedNodes []string + expectedFields []string + expectedEdges []edgeSpec + maxDepth int + }{ + { + name: "single field root", + schema: `type Query { + alpha: Alpha + beta: Beta + } + + type Alpha { + name: String + } + + type Beta { + name: String + }`, + roots: []string{"Query.alpha"}, + expectedNodes: []string{"Alpha", "Query"}, + expectedFields: []string{"Alpha.name", "Query.alpha"}, + expectedEdges: []edgeSpec{ + { + srcType: "Query", + dstType: "Alpha", + fieldName: "alpha", + }, + }, + }, + { + name: "multiple field roots", + schema: `type Query { + alpha: Alpha + beta: Beta + gaga: Gaga + } + + type Alpha { + name: String + } + + type Beta { + name: String + } + + type Gaga { + name: String + }`, + roots: []string{"Query.alpha", "Query.beta"}, + expectedNodes: []string{"Alpha", "Beta", "Query"}, + expectedFields: []string{"Alpha.name", "Beta.name", "Query.alpha", "Query.beta"}, + expectedEdges: []edgeSpec{ + { + srcType: "Query", + dstType: "Alpha", + fieldName: "alpha", + }, + { + srcType: "Query", + dstType: "Beta", + fieldName: "beta", + }, + }, + }, + { + name: "root field with cycle", + schema: `type Query { + person(name: String): Person + organization(name: String): Organization + } + + type Person { + name: String + friends: [Person] + } + + type Organization { + name: String + }`, + roots: []string{"Person.friends"}, + expectedNodes: []string{"Person"}, + expectedFields: []string{"Person.friends", "Person.name"}, + expectedEdges: []edgeSpec{ + { + srcType: "Person", + dstType: "Person", + fieldName: "friends", + }, + }, + }, + { + name: "unions", + schema: `type Query { + subject(name: String): Subject + events: [Event] + } + + union Subject = Person | Organization + + type Person { + name: String + } + + type Organization { + name: String + } + + type Event { + title: String + }`, + roots: []string{"Query.subject"}, + expectedNodes: []string{"Organization", "Person", "Query", "Subject"}, + expectedFields: []string{"Organization.name", "Person.name", "Query.subject"}, + expectedEdges: []edgeSpec{ + { + srcType: "Query", + dstType: "Subject", + fieldName: "subject", + }, + { + srcType: "Subject", + dstType: "Organization", + }, + { + srcType: "Subject", + dstType: "Person", + }, + }, + }, + { + name: "depth limited", + schema: `type Query { + persons(filter: PersonFilter): [Person] + foods: [Food] + } + + input PersonFilter { + nameLike: String + matchMode: MatchMode + } + + enum MatchMode { + CASE_SENSITIVE + CASE_INSENSITIVE + } + + type Person { + name: String + favoriteFoods: [Food] + } + + type Food { + name: String + }`, + roots: []string{"Query.persons"}, + maxDepth: 2, + expectedNodes: []string{"Person", "PersonFilter", "Query"}, + expectedFields: []string{ + "Person.favoriteFoods", + "Person.name", + "PersonFilter.matchMode", + "PersonFilter.nameLike", + "Query.persons", + }, + expectedEdges: []edgeSpec{ + { + srcType: "Query", + dstType: "Person", + fieldName: "persons", + }, + { + srcType: "Query", + dstType: "PersonFilter", + fieldName: "persons", + argName: "filter", + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + src := ast.Source{ + Name: "testcase", + Input: tc.schema, + } + rawSchema, err := gqlparser.LoadSchema(&src) + assert.NoError(t, err) + + s, err := model.MakeSchema(rawSchema) + assert.NoError(t, err) + + roots, err := s.ResolveNames(tc.roots) + assert.NoError(t, err) + + g := MakeGraph(s) + trimmed := g.ReachableFrom(roots, tc.maxDepth) + + var actualNodes []string + var actualEdges []edgeSpec + var actualFields []string + + for _, node := range trimmed.nodes { + actualNodes = append(actualNodes, node.Name) + for _, field := range node.Fields { + fieldId := node.Name + "." + field.Name + actualFields = append(actualFields, fieldId) + } + } + + sort.Strings(actualNodes) + sort.Strings(actualFields) + + assert.Equal(t, tc.expectedNodes, actualNodes) + assert.Equal(t, tc.expectedFields, actualFields) + + for _, edges := range trimmed.edges { + for _, edge := range edges { + fieldName := "" + if edge.field != nil { + fieldName = edge.field.Name + } + argName := "" + if edge.argument != nil { + argName = edge.argument.Name + } + actualEdge := edgeSpec{ + srcType: edge.src.Name, + dstType: edge.dst.Name, + fieldName: fieldName, + argName: argName, + } + actualEdges = append(actualEdges, actualEdge) + } + } + + sort.Slice(actualEdges, func(i, j int) bool { + return strings.Compare(actualEdges[i].String(), actualEdges[j].String()) < 0 + }) + + assert.Equal(t, tc.expectedEdges, actualEdges) + }) + } +} diff --git a/pkg/graph/reference_set.go b/pkg/graph/reference_set.go new file mode 100644 index 0000000..a6a21a1 --- /dev/null +++ b/pkg/graph/reference_set.go @@ -0,0 +1,62 @@ +package graph + +import "github.com/benweint/gquil/pkg/model" + +// typeOrField refers to either an entire GraphQL type (e.g. a union type), or to a specific field +// on a type. When refering to an entire type, the fieldName field is set to its zero value (empty string). +type typeOrField struct { + typeName string + fieldName string +} + +func typeRef(name string) typeOrField { + return typeOrField{typeName: name} +} + +func fieldRef(typeName, fieldName string) typeOrField { + return typeOrField{ + typeName: typeName, + fieldName: fieldName, + } +} + +// referenceSet captures the set of types & fields which have been encountered when traversing a GraphQL schema. +type referenceSet map[typeOrField]bool + +// includesType returns true if the target referenceSet includes at least one field on the given type name, +// or a key representing the entire type. +func (s referenceSet) includesType(name string) bool { + for key := range s { + if key.typeName == name { + return true + } + } + return false +} + +// includesField returns true if the given referenceSet includes a key representing the given field on the given type. +func (s referenceSet) includesField(typeName, fieldName string) bool { + return s[typeOrField{typeName: typeName, fieldName: fieldName}] +} + +// filterFields returns a copy of the given definition, where the field list has been filtered to only include +// fields which were included in the referenceSet. The original def is not modified by this method. +func (s referenceSet) filterFields(def *model.Definition) *model.Definition { + var filteredFields []*model.FieldDefinition + for _, field := range def.Fields { + if s.includesField(def.Name, field.Name) { + filteredFields = append(filteredFields, field) + } + } + + return &model.Definition{ + Kind: def.Kind, + Name: def.Name, + Description: def.Description, + Directives: def.Directives, + Interfaces: def.Interfaces, + PossibleTypes: def.PossibleTypes, + EnumValues: def.EnumValues, + Fields: filteredFields, + } +} diff --git a/pkg/model/fields.go b/pkg/model/fields.go index 3ea8a01..e8defc0 100644 --- a/pkg/model/fields.go +++ b/pkg/model/fields.go @@ -35,6 +35,15 @@ func (fdl FieldDefinitionList) Sort() { }) } +func (fdl FieldDefinitionList) Named(name string) *FieldDefinition { + for _, field := range fdl { + if field.Name == name { + return field + } + } + return nil +} + func (fd *FieldDefinition) MarshalJSON() ([]byte, error) { m := map[string]any{ "name": fd.Name,