Skip to content

Commit

Permalink
Merge pull request #62 from Netflix/fix-unchecked-cast
Browse files Browse the repository at this point in the history
Remove an unchecked cast in RequiredTypeCollector
  • Loading branch information
paulbakker authored Mar 11, 2021
2 parents 35bba17 + 6dccbd8 commit aa670c1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class CodeGen(private val config: CodeGenConfig) {
}

private fun generateForSchema(schema: String): CodeGenResult {
document = Parser().parseDocument(schema)
document = Parser.parse(schema)
requiredTypeCollector = RequiredTypeCollector(document, queries = config.includeQueries, mutations = config.includeMutations)
val definitions = document.definitions
val dataTypesResult = generateJavaDataType(definitions)
Expand Down Expand Up @@ -187,7 +187,7 @@ class CodeGen(private val config: CodeGenConfig) {
definitions.filterIsInstance<InputObjectTypeExtensionDefinition>().filter { name == it.name }

private fun generateKotlinForSchema(schema: String): KotlinCodeGenResult {
document = Parser().parseDocument(schema)
document = Parser.parse(schema)
requiredTypeCollector = RequiredTypeCollector(document, queries = config.includeQueries, mutations = config.includeMutations)
val definitions = document.definitions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,69 +27,48 @@ class RequiredTypeCollector(
queries: Set<String> = emptySet(),
mutations: Set<String> = emptySet()
) {
val requiredTypes: Set<String>
val requiredTypes: Set<String> = LinkedHashSet()

init {
val queryFieldDefinitions: List<FieldDefinition> = document.definitions.filterIsInstance<ObjectTypeDefinition>()
.find { it.name == "Query" }?.fieldDefinitions ?: emptyList()
val fieldDefinitions = mutableListOf<FieldDefinition>()
for (definition in document.definitions.asSequence().filterIsInstance<ObjectTypeDefinition>()) {
when (definition.name) {
"Query" -> definition.fieldDefinitions.filterTo(fieldDefinitions) { it.name in queries }
"Mutation" -> definition.fieldDefinitions.filterTo(fieldDefinitions) { it.name in mutations }
}
}

val mutationDefinitions: List<FieldDefinition> = document.definitions.filterIsInstance<ObjectTypeDefinition>()
.find { it.name == "Mutation" }?.fieldDefinitions ?: emptyList()
val fieldDefinitions = queryFieldDefinitions.plus(mutationDefinitions).filter { queries.contains(it.name) || mutations.contains(it.name) }
val required = requiredTypes as MutableSet<String>

val nodeTraverserResult = NodeTraverser().postOrder(object : NodeVisitorStub() {
NodeTraverser().postOrder(object : NodeVisitorStub() {
override fun visitInputObjectTypeDefinition(
node: InputObjectTypeDefinition,
context: TraverserContext<Node<Node<*>>>
): TraversalControl {
println(node)
required += node.name

val currentAccumulate = context.getNewAccumulate<Set<String>>()
if (currentAccumulate == null) {
context.setAccumulate(setOf(node.name))
} else {
context.setAccumulate(currentAccumulate.plus(node.name))
node.inputValueDefinitions.forEach {
it.type.findTypeDefinition(document)?.accept(context, this)
}

node.inputValueDefinitions.map { it.type.findTypeDefinition(document)?.accept(context, this) }

return TraversalControl.CONTINUE
}

override fun visitEnumTypeDefinition(
node: EnumTypeDefinition,
context: TraverserContext<Node<Node<*>>>
): TraversalControl {
println(node)


val currentAccumulate = context.getNewAccumulate<Set<String>>()
if (currentAccumulate == null) {
context.setAccumulate(setOf(node.name))
} else {
context.setAccumulate(currentAccumulate.plus(node.name))
}

required += node.name
return TraversalControl.CONTINUE
}

override fun visitInputValueDefinition(
node: InputValueDefinition,
context: TraverserContext<Node<Node<*>>>
): TraversalControl {
println(node)

node.type.findTypeDefinition(document)?.accept(context, this)


return super.visitInputValueDefinition(node, context)
return TraversalControl.CONTINUE
}
}, fieldDefinitions)

requiredTypes = if (nodeTraverserResult != null && nodeTraverserResult is Set<*>) {
nodeTraverserResult as Set<String>
} else {
emptySet()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@

package com.netflix.graphql.dgs.codegen

import graphql.language.Document
import graphql.parser.Parser
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

import org.junit.jupiter.api.Assertions.*

internal class RequiredTypeCollectorTest {
class RequiredTypeCollectorTest {
private val document = Parser.parse("""
type Query {
search(filter: Filter): [Show]
Expand Down

0 comments on commit aa670c1

Please sign in to comment.