diff --git a/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/CodeGen.kt b/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/CodeGen.kt index bbc474e4..cbd181ae 100644 --- a/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/CodeGen.kt +++ b/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/CodeGen.kt @@ -210,7 +210,7 @@ class CodeGen(private val config: CodeGenConfig) { .filterIsInstance() .excludeSchemaTypeExtension() .filter { config.generateDataTypes || config.generateInterfaces || it.name in requiredTypeCollector.requiredTypes } - .map { UnionTypeGenerator(config).generate(it, findUnionExtensions(it.name, definitions)) } + .map { UnionTypeGenerator(config, document).generate(it, findUnionExtensions(it.name, definitions)) } .fold(CodeGenResult()) { t: CodeGenResult, u: CodeGenResult -> t.merge(u) } } diff --git a/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/generators/java/UnionTypeGenerator.kt b/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/generators/java/UnionTypeGenerator.kt index 0acd512c..6e514ae6 100644 --- a/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/generators/java/UnionTypeGenerator.kt +++ b/graphql-dgs-codegen-core/src/main/kotlin/com/netflix/graphql/dgs/codegen/generators/java/UnionTypeGenerator.kt @@ -24,12 +24,17 @@ import com.netflix.graphql.dgs.codegen.shouldSkip import com.squareup.javapoet.ClassName import com.squareup.javapoet.JavaFile import com.squareup.javapoet.TypeSpec +import graphql.language.Document import graphql.language.TypeName import graphql.language.UnionTypeDefinition import graphql.language.UnionTypeExtensionDefinition import javax.lang.model.element.Modifier -class UnionTypeGenerator(private val config: CodeGenConfig) { +class UnionTypeGenerator(private val config: CodeGenConfig, private val document: Document) { + + val packageName = config.packageNameTypes + private val typeUtils = TypeUtils(packageName, config, document) + fun generate(definition: UnionTypeDefinition, extensions: List): CodeGenResult { if (definition.shouldSkip(config)) { return CodeGenResult() @@ -41,7 +46,10 @@ class UnionTypeGenerator(private val config: CodeGenConfig) { val memberTypes = definition.memberTypes.plus(extensions.flatMap { it.memberTypes }).asSequence() .filterIsInstance() - .map { member -> ClassName.get(packageName, member.name) } + .map { member -> + typeUtils.findJavaInterfaceName(member.name, packageName) + } + .filterIsInstance() .toList() if (memberTypes.isNotEmpty()) { @@ -52,6 +60,4 @@ class UnionTypeGenerator(private val config: CodeGenConfig) { val javaFile = JavaFile.builder(packageName, javaType.build()).build() return CodeGenResult(javaInterfaces = listOf(javaFile)) } - - val packageName = config.packageNameTypes } diff --git a/graphql-dgs-codegen-core/src/test/kotlin/com/netflix/graphql/dgs/codegen/CodeGenTest.kt b/graphql-dgs-codegen-core/src/test/kotlin/com/netflix/graphql/dgs/codegen/CodeGenTest.kt index f469ee49..ec2a5345 100644 --- a/graphql-dgs-codegen-core/src/test/kotlin/com/netflix/graphql/dgs/codegen/CodeGenTest.kt +++ b/graphql-dgs-codegen-core/src/test/kotlin/com/netflix/graphql/dgs/codegen/CodeGenTest.kt @@ -4530,6 +4530,36 @@ It takes a title and such. assertThat(dataTypes[1].typeSpec.superinterfaces[0].toString()).isEqualTo("java.lang.String") } + @Test + fun `Supports typeMapping in union type generation`() { + val schema = """ + type A { + name: String + } + + type B { + count: Int + } + + union C = A | B + """.trimIndent() + + val result = CodeGen( + CodeGenConfig( + schemas = setOf(schema), + packageName = basePackageName, + typeMapping = mapOf( + "A" to "java.lang.String" + ) + ) + ).generate() + + assertThat(result.javaDataTypes.size).isEqualTo(1) + + assertThat(result.javaDataTypes[0].typeSpec.superinterfaces[0].toString()).isEqualTo("com.netflix.graphql.dgs.codegen.tests.generated.types.C") + assertCompilesJava(result) + } + @Test fun `The default value for Locale should be overridden and wrapped`() { val schema = """