From 7096681dd8478ee882b988849e11ef8bd695a094 Mon Sep 17 00:00:00 2001 From: Devin Smith Date: Wed, 26 Feb 2025 13:21:21 -0800 Subject: [PATCH] fix: DH-18803: Add qst Type find support for componentType This adds the needed, centralized logic for adapting to qst Type when componentType is not null. This allows downstream callers, such as the Sql adapter layer, to correctly construct the appropriate Types. --- .../java/io/deephaven/engine/sql/Sql.java | 2 +- .../impl/locations/util/PartitionParser.java | 6 +- .../deephaven/qst/type/GenericVectorTest.java | 28 ++++++ .../qst/type/PrimitiveVectorTest.java | 88 ++++++++++++++++++- .../server/flightsql/FlightSqlTest.java | 35 ++++++++ .../java/io/deephaven/kafka/SimpleImpl.java | 7 +- qst/type/build.gradle | 2 + .../deephaven/qst/type/GenericVectorType.java | 11 +++ .../qst/type/PrimitiveVectorType.java | 53 ++++++++++- .../main/java/io/deephaven/qst/type/Type.java | 39 ++++++-- .../java/io/deephaven/qst/type/TypeTest.java | 45 +++++++++- 11 files changed, 294 insertions(+), 22 deletions(-) create mode 100644 engine/table/src/test/java/io/deephaven/qst/type/GenericVectorTest.java diff --git a/engine/sql/src/main/java/io/deephaven/engine/sql/Sql.java b/engine/sql/src/main/java/io/deephaven/engine/sql/Sql.java index a8f9958e67a..46cc950416a 100644 --- a/engine/sql/src/main/java/io/deephaven/engine/sql/Sql.java +++ b/engine/sql/src/main/java/io/deephaven/engine/sql/Sql.java @@ -109,7 +109,7 @@ private static TableHeader adapt(TableDefinition tableDef) { } private static ColumnHeader adapt(ColumnDefinition columnDef) { - return ColumnHeader.of(columnDef.getName(), Type.find(columnDef.getDataType())); + return ColumnHeader.of(columnDef.getName(), Type.find(columnDef.getDataType(), columnDef.getComponentType())); } private enum ToGraphvizDot implements ObjFormatter { diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/util/PartitionParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/util/PartitionParser.java index 446d8a48065..c2c989ad996 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/util/PartitionParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/util/PartitionParser.java @@ -195,9 +195,11 @@ public Comparable parse(@NotNull final String stringValue) { @Nullable public static PartitionParser lookup(@NotNull final Class dataType, @Nullable final Class componentType) { if (componentType != null) { + // This is a short-circuit since we know that Resolver does not support ArrayType. return null; } - return lookup(Type.find(dataType)); + // noinspection ConstantValue + return lookup(Type.find(dataType, componentType)); } /** @@ -289,6 +291,8 @@ public PartitionParser visit(@NotNull final InstantType instantType) { @Override public PartitionParser visit(@NotNull final ArrayType arrayType) { + // If the partition parser ever supports ArrayTypes, make sure the short-circuit in + // PartitionParser.lookup(java.lang.Class, java.lang.Class) is removed. return null; } diff --git a/engine/table/src/test/java/io/deephaven/qst/type/GenericVectorTest.java b/engine/table/src/test/java/io/deephaven/qst/type/GenericVectorTest.java new file mode 100644 index 00000000000..21bfcf95e1d --- /dev/null +++ b/engine/table/src/test/java/io/deephaven/qst/type/GenericVectorTest.java @@ -0,0 +1,28 @@ +// +// Copyright (c) 2016-2025 Deephaven Data Labs and Patent Pending +// +package io.deephaven.qst.type; + +import io.deephaven.vector.ObjectVector; +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class GenericVectorTest { + + @Test + public void stringType() throws ClassNotFoundException { + testConstruction(GenericVectorType.of(ObjectVector.class, Type.stringType())); + } + + @Test + public void instantType() throws ClassNotFoundException { + testConstruction(GenericVectorType.of(ObjectVector.class, Type.instantType())); + } + + private static void testConstruction(GenericVectorType vectorType) + throws ClassNotFoundException { + assertThat(Type.find(vectorType.clazz(), vectorType.componentType().clazz())).isEqualTo(vectorType); + assertThat(GenericVectorType.of(vectorType.clazz(), vectorType.componentType())).isEqualTo(vectorType); + } +} diff --git a/engine/table/src/test/java/io/deephaven/qst/type/PrimitiveVectorTest.java b/engine/table/src/test/java/io/deephaven/qst/type/PrimitiveVectorTest.java index cfca2317aea..90846f75e01 100644 --- a/engine/table/src/test/java/io/deephaven/qst/type/PrimitiveVectorTest.java +++ b/engine/table/src/test/java/io/deephaven/qst/type/PrimitiveVectorTest.java @@ -3,13 +3,20 @@ // package io.deephaven.qst.type; -import static org.assertj.core.api.Assertions.assertThat; - -import io.deephaven.vector.*; +import io.deephaven.vector.ByteVector; +import io.deephaven.vector.CharVector; +import io.deephaven.vector.DoubleVector; +import io.deephaven.vector.FloatVector; +import io.deephaven.vector.IntVector; +import io.deephaven.vector.LongVector; +import io.deephaven.vector.ShortVector; +import org.junit.Test; import java.lang.reflect.InvocationTargetException; +import java.util.stream.Collectors; -import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.failBecauseExceptionWasNotThrown; public class PrimitiveVectorTest { @@ -25,4 +32,77 @@ public void types() FloatVector.type(), DoubleVector.type()); } + + @Test + public void byteVector() throws ClassNotFoundException { + testConstruction(ByteVector.type()); + } + + @Test + public void charVector() throws ClassNotFoundException { + testConstruction(CharVector.type()); + } + + @Test + public void shortVector() throws ClassNotFoundException { + testConstruction(ShortVector.type()); + } + + @Test + public void intVector() throws ClassNotFoundException { + testConstruction(IntVector.type()); + } + + @Test + public void longVector() throws ClassNotFoundException { + testConstruction(LongVector.type()); + } + + @Test + public void floatVector() throws ClassNotFoundException { + testConstruction(FloatVector.type()); + } + + @Test + public void doubleVector() throws ClassNotFoundException { + testConstruction(DoubleVector.type()); + } + + private static void testConstruction(PrimitiveVectorType vectorType) + throws ClassNotFoundException { + assertThat(Type.find(vectorType.clazz())).isEqualTo(vectorType); + assertThat(Type.find(vectorType.clazz(), vectorType.componentType().clazz())).isEqualTo(vectorType); + assertThat(PrimitiveVectorType.of(vectorType.clazz(), vectorType.componentType())).isEqualTo(vectorType); + // fail if component type is bad + for (PrimitiveType badComponent : PrimitiveType.instances().collect(Collectors.toList())) { + if (badComponent.equals(vectorType.componentType())) { + continue; + } + fail(vectorType.clazz(), badComponent); + } + // fail if data type is bad + fail(Object.class, vectorType.componentType()); + for (PrimitiveVectorType primitiveVectorType : TypeHelper.primitiveVectorTypes() + .collect(Collectors.toList())) { + if (primitiveVectorType.equals(vectorType)) { + continue; + } + fail(primitiveVectorType.clazz(), vectorType.componentType()); + } + } + + public static void fail(Class clazz, PrimitiveType ct) { + try { + Type.find(clazz, ct.clazz()); + failBecauseExceptionWasNotThrown(IllegalArgumentException.class); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageContaining("Invalid PrimitiveVectorType"); + } + try { + PrimitiveVectorType.of(clazz, ct); + failBecauseExceptionWasNotThrown(IllegalArgumentException.class); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageContaining("Invalid PrimitiveVectorType"); + } + } } diff --git a/extensions/flight-sql/src/test/java/io/deephaven/server/flightsql/FlightSqlTest.java b/extensions/flight-sql/src/test/java/io/deephaven/server/flightsql/FlightSqlTest.java index 8a979b13faf..04f0cde0d4d 100644 --- a/extensions/flight-sql/src/test/java/io/deephaven/server/flightsql/FlightSqlTest.java +++ b/extensions/flight-sql/src/test/java/io/deephaven/server/flightsql/FlightSqlTest.java @@ -17,10 +17,20 @@ import io.deephaven.engine.table.TableDefinition; import io.deephaven.engine.util.TableTools; import io.deephaven.proto.backplane.grpc.WrappedAuthenticationRequest; +import io.deephaven.qst.type.GenericVectorType; +import io.deephaven.qst.type.Type; import io.deephaven.server.auth.AuthorizationProvider; import io.deephaven.server.config.ServerConfig; import io.deephaven.server.runner.DeephavenApiServerTestBase; import io.deephaven.server.runner.DeephavenApiServerTestBase.TestComponent.Builder; +import io.deephaven.vector.ByteVector; +import io.deephaven.vector.CharVector; +import io.deephaven.vector.DoubleVector; +import io.deephaven.vector.FloatVector; +import io.deephaven.vector.IntVector; +import io.deephaven.vector.LongVector; +import io.deephaven.vector.ObjectVector; +import io.deephaven.vector.ShortVector; import io.grpc.ManagedChannel; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.ActionType; @@ -587,6 +597,31 @@ public void badSqlQuery() { queryError("this is not SQL", FlightStatusCode.INVALID_ARGUMENT, "Flight SQL: query can't be parsed"); } + @Test + public void testDh18803() throws Exception { + // https://deephaven.atlassian.net/browse/DH-18803: Sql fails to adapt Vector types + final TableDefinition td = TableDefinition.of( + ColumnDefinition.ofVector("ByteVector", ByteVector.class), + ColumnDefinition.ofVector("CharVector", CharVector.class), + ColumnDefinition.ofVector("ShortVector", ShortVector.class), + ColumnDefinition.ofVector("IntVector", IntVector.class), + ColumnDefinition.ofVector("LongVector", LongVector.class), + ColumnDefinition.ofVector("FloatVector", FloatVector.class), + ColumnDefinition.ofVector("DoubleVector", DoubleVector.class), + ColumnDefinition.of("StringVector", GenericVectorType.of(ObjectVector.class, Type.stringType()))); + final Table emptyTable = TableTools.newTable(td); + ExecutionContext.getContext().getQueryScope().putParam("MyTable", emptyTable); + { + final SchemaResult schema = flightSqlClient.getExecuteSchema("SELECT * FROM MyTable"); + assertThat(schema.getSchema().getFields()).hasSize(8); + } + { + final FlightInfo info = flightSqlClient.execute("SELECT * FROM MyTable"); + assertThat(info.getSchema().getFields()).hasSize(8); + consume(info, 0, 0, false); + } + } + @Test public void executeSubstrait() { getSchemaUnimplemented(() -> flightSqlClient.getExecuteSubstraitSchema(fakePlan()), diff --git a/extensions/kafka/src/main/java/io/deephaven/kafka/SimpleImpl.java b/extensions/kafka/src/main/java/io/deephaven/kafka/SimpleImpl.java index e897f263c63..0d145625d05 100644 --- a/extensions/kafka/src/main/java/io/deephaven/kafka/SimpleImpl.java +++ b/extensions/kafka/src/main/java/io/deephaven/kafka/SimpleImpl.java @@ -233,13 +233,14 @@ public Optional getSchemaProvider() { @Override Serializer getSerializer(SchemaRegistryClient schemaRegistryClient, TableDefinition definition) { - final Class dataType = definition.getColumn(columnName).getDataType(); - final Serializer serializer = serializer(Type.find(dataType)).orElse(null); + final ColumnDefinition cd = definition.getColumn(columnName); + final Type type = Type.find(cd.getDataType(), cd.getComponentType()); + final Serializer serializer = serializer(type).orElse(null); if (serializer != null) { return serializer; } throw new UncheckedDeephavenException( - String.format("Serializer not found for column %s, type %s", columnName, dataType.getName())); + String.format("Serializer not found for column %s, type %s", columnName, type)); } @Override diff --git a/qst/type/build.gradle b/qst/type/build.gradle index 6a31a0322c0..9235c42b6ec 100644 --- a/qst/type/build.gradle +++ b/qst/type/build.gradle @@ -9,6 +9,8 @@ dependencies { compileOnly project(':util-immutables') annotationProcessor libs.immutables.value + compileOnly libs.jetbrains.annotations + testImplementation libs.assertj testImplementation platform(libs.junit.bom) testImplementation libs.junit.jupiter diff --git a/qst/type/src/main/java/io/deephaven/qst/type/GenericVectorType.java b/qst/type/src/main/java/io/deephaven/qst/type/GenericVectorType.java index be0ddc8d6c1..0e5c9697e3e 100644 --- a/qst/type/src/main/java/io/deephaven/qst/type/GenericVectorType.java +++ b/qst/type/src/main/java/io/deephaven/qst/type/GenericVectorType.java @@ -4,6 +4,7 @@ package io.deephaven.qst.type; import io.deephaven.annotations.SimpleStyle; +import org.immutables.value.Value.Check; import org.immutables.value.Value.Immutable; import org.immutables.value.Value.Parameter; @@ -11,6 +12,8 @@ @SimpleStyle public abstract class GenericVectorType extends ArrayTypeBase { + private static final String OBJECT_VECTOR = "io.deephaven.vector.ObjectVector"; + public static GenericVectorType of( Class clazz, GenericType componentType) { @@ -27,4 +30,12 @@ public static GenericVectorType of( public final R walk(ArrayType.Visitor visitor) { return visitor.visit(this); } + + @Check + final void checkClazz() { + if (!OBJECT_VECTOR.equals(clazz().getName())) { + throw new IllegalArgumentException(String.format("Invalid GenericVectorType. clazz=%s, componentType=%s", + clazz().getName(), componentType())); + } + } } diff --git a/qst/type/src/main/java/io/deephaven/qst/type/PrimitiveVectorType.java b/qst/type/src/main/java/io/deephaven/qst/type/PrimitiveVectorType.java index d9cdedeb743..c4a29f17786 100644 --- a/qst/type/src/main/java/io/deephaven/qst/type/PrimitiveVectorType.java +++ b/qst/type/src/main/java/io/deephaven/qst/type/PrimitiveVectorType.java @@ -73,10 +73,55 @@ public final R walk(ArrayType.Visitor visitor) { } @Check - final void checkClazz() { - if (!VALID_CLASSES.contains(clazz().getName())) { - throw new IllegalArgumentException(String.format("Class '%s' is not a valid '%s'", - clazz(), PrimitiveVectorType.class)); + final void checkPairing() { + final String vectorClassNameFromComponent = componentType().walk(VectorClassName.INSTANCE); + if (!clazz().getName().equals(vectorClassNameFromComponent)) { + throw new IllegalArgumentException(String.format("Invalid PrimitiveVectorType. clazz=%s, componentType=%s", + clazz().getName(), componentType())); + } + } + + private enum VectorClassName implements PrimitiveType.Visitor { + INSTANCE; + + @Override + public String visit(BooleanType booleanType) { + return null; + } + + @Override + public String visit(ByteType byteType) { + return BYTE_VECTOR; + } + + @Override + public String visit(CharType charType) { + return CHAR_VECTOR; + } + + @Override + public String visit(ShortType shortType) { + return SHORT_VECTOR; + } + + @Override + public String visit(IntType intType) { + return INT_VECTOR; + } + + @Override + public String visit(LongType longType) { + return LONG_VECTOR; + } + + @Override + public String visit(FloatType floatType) { + return FLOAT_VECTOR; + } + + @Override + public String visit(DoubleType doubleType) { + return DOUBLE_VECTOR; } } } diff --git a/qst/type/src/main/java/io/deephaven/qst/type/Type.java b/qst/type/src/main/java/io/deephaven/qst/type/Type.java index 959c6d7fec2..a531c657fe5 100644 --- a/qst/type/src/main/java/io/deephaven/qst/type/Type.java +++ b/qst/type/src/main/java/io/deephaven/qst/type/Type.java @@ -3,6 +3,8 @@ // package io.deephaven.qst.type; +import org.jetbrains.annotations.Nullable; + import java.util.List; import java.util.Optional; @@ -19,19 +21,42 @@ public interface Type { * Finds the {@link #knownTypes() known type}, or else creates the relevant {@link NativeArrayType native array * type} or {@link CustomType custom type}. * - * @param clazz the class - * @param the generic type of {@code clazz} + * @param dataType the data type + * @param the generic type of {@code dataType} * @return the type */ - static Type find(Class clazz) { - Optional> found = TypeHelper.findStatic(clazz); + static Type find(Class dataType) { + Optional> found = TypeHelper.findStatic(dataType); if (found.isPresent()) { return found.get(); } - if (clazz.isArray()) { - return NativeArrayType.of(clazz, find(clazz.getComponentType())); + if (dataType.isArray()) { + return NativeArrayType.of(dataType, find(dataType.getComponentType())); } - return CustomType.of(clazz); + return CustomType.of(dataType); + } + + /** + * If {@code componentType} is not {@code null}, this will find the appropriate {@link ArrayType}. Otherwise, this + * is equivalent to {@link #find(Class)}. + * + * @param dataType the data type + * @param componentType the component type + * @return the type + * @param the generic type of {@code dataType} + */ + static Type find(final Class dataType, @Nullable final Class componentType) { + if (componentType == null) { + return find(dataType); + } + final Type ct = find(componentType); + if (dataType.isArray()) { + return NativeArrayType.of(dataType, ct); + } + if (componentType.isPrimitive()) { + return PrimitiveVectorType.of(dataType, (PrimitiveType) ct); + } + return GenericVectorType.of(dataType, (GenericType) ct); } /** diff --git a/qst/type/src/test/java/io/deephaven/qst/type/TypeTest.java b/qst/type/src/test/java/io/deephaven/qst/type/TypeTest.java index 7083033d8d5..7c46997804f 100644 --- a/qst/type/src/test/java/io/deephaven/qst/type/TypeTest.java +++ b/qst/type/src/test/java/io/deephaven/qst/type/TypeTest.java @@ -42,7 +42,6 @@ void numberOfStaticTypes() { assertThat(BoxedType.instances().distinct()).hasSize(8); } - @Test void findBooleans() { check(boolean.class, Boolean.class, booleanType(), BoxedBooleanType.of()); @@ -87,99 +86,139 @@ void findDoubles() { void findString() { assertThat(find(String.class)).isEqualTo(stringType()); assertThat(find(String[].class)).isEqualTo(stringType().arrayType()); + assertThat(find(String[].class, String.class)).isEqualTo(stringType().arrayType()); } @Test void findInstant() { assertThat(find(Instant.class)).isEqualTo(instantType()); assertThat(find(Instant[].class)).isEqualTo(instantType().arrayType()); + assertThat(find(Instant[].class, Instant.class)).isEqualTo(instantType().arrayType()); } @Test void findCustom() { assertThat(find(Custom.class)).isEqualTo(ofCustom(Custom.class)); assertThat(find(Custom[].class)).isEqualTo(ofCustom(Custom.class).arrayType()); + assertThat(find(Custom[].class, Custom.class)).isEqualTo(ofCustom(Custom.class).arrayType()); } @Test void booleanArrayType() { assertThat(find(boolean[].class)).isEqualTo(booleanType().arrayType()); + assertThat(find(boolean[].class, boolean.class)).isEqualTo(booleanType().arrayType()); + assertThat(find(Boolean[].class)).isEqualTo(BoxedBooleanType.of().arrayType()); + assertThat(find(Boolean[].class, Boolean.class)).isEqualTo(BoxedBooleanType.of().arrayType()); } @Test void byteArrayType() { assertThat(find(byte[].class)).isEqualTo(byteType().arrayType()); + assertThat(find(byte[].class, byte.class)).isEqualTo(byteType().arrayType()); + assertThat(find(Byte[].class)).isEqualTo(BoxedByteType.of().arrayType()); + assertThat(find(Byte[].class, Byte.class)).isEqualTo(BoxedByteType.of().arrayType()); } @Test void charArrayType() { assertThat(find(char[].class)).isEqualTo(charType().arrayType()); + assertThat(find(char[].class, char.class)).isEqualTo(charType().arrayType()); + assertThat(find(Character[].class)).isEqualTo(BoxedCharType.of().arrayType()); + assertThat(find(Character[].class, Character.class)).isEqualTo(BoxedCharType.of().arrayType()); } @Test void shortArrayType() { assertThat(find(short[].class)).isEqualTo(shortType().arrayType()); + assertThat(find(short[].class, short.class)).isEqualTo(shortType().arrayType()); + assertThat(find(Short[].class)).isEqualTo(BoxedShortType.of().arrayType()); + assertThat(find(Short[].class, Short.class)).isEqualTo(BoxedShortType.of().arrayType()); } @Test void intArrayType() { assertThat(find(int[].class)).isEqualTo(intType().arrayType()); + assertThat(find(int[].class, int.class)).isEqualTo(intType().arrayType()); + assertThat(find(Integer[].class)).isEqualTo(BoxedIntType.of().arrayType()); + assertThat(find(Integer[].class, Integer.class)).isEqualTo(BoxedIntType.of().arrayType()); } @Test void longArrayType() { assertThat(find(long[].class)).isEqualTo(longType().arrayType()); + assertThat(find(long[].class, long.class)).isEqualTo(longType().arrayType()); + assertThat(find(Long[].class)).isEqualTo(BoxedLongType.of().arrayType()); + assertThat(find(Long[].class, Long.class)).isEqualTo(BoxedLongType.of().arrayType()); } @Test void floatArrayType() { assertThat(find(float[].class)).isEqualTo(floatType().arrayType()); + assertThat(find(float[].class, float.class)).isEqualTo(floatType().arrayType()); + assertThat(find(Float[].class)).isEqualTo(BoxedFloatType.of().arrayType()); + assertThat(find(Float[].class, Float.class)).isEqualTo(BoxedFloatType.of().arrayType()); } @Test void doubleArrayType() { assertThat(find(double[].class)).isEqualTo(doubleType().arrayType()); - assertThat(find(Double[].class)).isEqualTo(BoxedDoubleType.of().arrayType()); + assertThat(find(double[].class, double.class)).isEqualTo(doubleType().arrayType()); + assertThat(find(Double[].class)).isEqualTo(BoxedDoubleType.of().arrayType()); + assertThat(find(Double[].class, Double.class)).isEqualTo(BoxedDoubleType.of().arrayType()); } @Test void nestedPrimitive2x() { assertThat(find(int[][].class)).isEqualTo(intType().arrayType().arrayType()); + assertThat(find(int[][].class, int[].class)).isEqualTo(intType().arrayType().arrayType()); + assertThat(find(Integer[][].class)).isEqualTo(BoxedIntType.of().arrayType().arrayType()); + assertThat(find(Integer[][].class, Integer[].class)).isEqualTo(BoxedIntType.of().arrayType().arrayType()); } @Test void nestedPrimitive3x() { assertThat(find(int[][][].class)).isEqualTo(intType().arrayType().arrayType().arrayType()); + assertThat(find(int[][][].class, int[][].class)).isEqualTo(intType().arrayType().arrayType().arrayType()); + assertThat(find(Integer[][][].class)).isEqualTo(BoxedIntType.of().arrayType().arrayType().arrayType()); + assertThat(find(Integer[][][].class, Integer[][].class)) + .isEqualTo(BoxedIntType.of().arrayType().arrayType().arrayType()); } @Test void nestedStatic2x() { assertThat(find(String[][].class)).isEqualTo(stringType().arrayType().arrayType()); + assertThat(find(String[][].class, String[].class)).isEqualTo(stringType().arrayType().arrayType()); } @Test void nestedStatic3x() { assertThat(find(String[][][].class)).isEqualTo(stringType().arrayType().arrayType().arrayType()); + assertThat(find(String[][][].class, String[][].class)) + .isEqualTo(stringType().arrayType().arrayType().arrayType()); } @Test void nestedCustom2x() { assertThat(find(Custom[][].class)).isEqualTo(CustomType.of(Custom.class).arrayType().arrayType()); + assertThat(find(Custom[][].class, Custom[].class)) + .isEqualTo(CustomType.of(Custom.class).arrayType().arrayType()); } @Test void nestedCustom3x() { assertThat(find(Custom[][][].class)).isEqualTo(CustomType.of(Custom.class).arrayType().arrayType().arrayType()); + assertThat(find(Custom[][][].class, Custom[][].class)) + .isEqualTo(CustomType.of(Custom.class).arrayType().arrayType().arrayType()); } @Test @@ -217,12 +256,14 @@ private static void check( PrimitiveType expectedPrimitive, BoxedType expectedBoxed) { assertThat(find(primitive)).isEqualTo(expectedPrimitive); + assertThat(find(primitive, null)).isEqualTo(expectedPrimitive); assertThat(expectedPrimitive.clazz()).isEqualTo(primitive); assertThat(expectedPrimitive.boxedType()).isEqualTo(expectedBoxed); assertThat(expectedPrimitive.arrayType().componentType()).isEqualTo(expectedPrimitive); assertThat(expectedPrimitive.arrayType().clazz()).isEqualTo(Array.newInstance(primitive, 0).getClass()); assertThat(find(boxed)).isEqualTo(expectedBoxed); + assertThat(find(boxed, null)).isEqualTo(expectedBoxed); assertThat(expectedBoxed.clazz()).isEqualTo(boxed); assertThat(expectedBoxed.primitiveType()).isEqualTo(expectedPrimitive); assertThat(expectedBoxed.arrayType().componentType()).isEqualTo(expectedBoxed);