Skip to content

Commit ba0ca24

Browse files
authored
Check for optional type in nullness checks (#522)
We forgot to check that the value type is an optional type in type inference. This causes various compilation errors in some of the strongly typed languages as you can not compare non-null types with null (*e.g.*, in C++ or Go). In this patch, we fix the type inference to correctly detect the invalid (non-)nullness checks and report them.
1 parent 11a73d9 commit ba0ca24

File tree

2 files changed

+138
-8
lines changed

2 files changed

+138
-8
lines changed

aas_core_codegen/intermediate/type_inference.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -1458,11 +1458,22 @@ def transform_constant(
14581458
def transform_is_none(
14591459
self, node: parse_tree.IsNone
14601460
) -> Optional["TypeAnnotationUnion"]:
1461-
# Just recurse to fill ``type_map`` on ``value`` even though we know the type in
1462-
# advance
1463-
success = self.transform(node.value) is not None
1461+
value_type = self.transform(node.value)
14641462

1465-
if not success:
1463+
# NOTE (mristin):
1464+
# Something went wrong if we could not infer the type of the ``value``.
1465+
if value_type is None:
1466+
return None
1467+
1468+
if not isinstance(value_type, OptionalTypeAnnotation):
1469+
self.errors.append(
1470+
Error(
1471+
node.value.original_node,
1472+
f"Expected the value to be of an optional type for "
1473+
f"a nullness check (``is None``), "
1474+
f"but got {value_type}",
1475+
)
1476+
)
14661477
return None
14671478

14681479
result = PrimitiveTypeAnnotation(PrimitiveType.BOOL)
@@ -1472,11 +1483,22 @@ def transform_is_none(
14721483
def transform_is_not_none(
14731484
self, node: parse_tree.IsNotNone
14741485
) -> Optional["TypeAnnotationUnion"]:
1475-
# Just recurse to fill ``type_map`` on ``value`` even though we know the type in
1476-
# advance
1477-
success = self.transform(node.value) is not None
1486+
value_type = self.transform(node.value)
14781487

1479-
if not success:
1488+
# NOTE (mristin):
1489+
# Something went wrong if we could not infer the type of the ``value``.
1490+
if value_type is None:
1491+
return None
1492+
1493+
if not isinstance(value_type, OptionalTypeAnnotation):
1494+
self.errors.append(
1495+
Error(
1496+
node.value.original_node,
1497+
f"Expected the value to be of an optional type "
1498+
f"for a non-nullness check (``is not None``), "
1499+
f"but got {value_type}",
1500+
)
1501+
)
14801502
return None
14811503

14821504
result = PrimitiveTypeAnnotation(PrimitiveType.BOOL)

tests/intermediate/test_type_inference.py

+108
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import textwrap
44
import unittest
5+
from typing import List
56

67
import tests.common
78
from aas_core_codegen import intermediate
@@ -51,6 +52,59 @@ def execute(source: str) -> None:
5152
len(inferrer.errors) == 0
5253
), tests.common.most_underlying_messages(inferrer.errors)
5354

55+
def expect_type_inference_to_fail(
56+
self, source: str, expected_joined_message: str
57+
) -> None:
58+
"""Execute a smoke test and expect type inference to fail."""
59+
symbol_table, error = tests.common.translate_source_to_intermediate(
60+
source=source
61+
)
62+
assert error is None, tests.common.most_underlying_messages(error)
63+
64+
assert symbol_table is not None
65+
66+
base_environment = intermediate_type_inference.populate_base_environment(
67+
symbol_table=symbol_table
68+
)
69+
70+
type_inference_errors = [] # type: List[str]
71+
72+
for our_type in symbol_table.our_types:
73+
if isinstance(
74+
our_type, (intermediate.AbstractClass, intermediate.ConcreteClass)
75+
):
76+
environment = intermediate_type_inference.MutableEnvironment(
77+
parent=base_environment
78+
)
79+
environment.set(
80+
Identifier("self"),
81+
intermediate_type_inference.OurTypeAnnotation(our_type=our_type),
82+
)
83+
84+
for invariant in our_type.invariants:
85+
canonicalizer = intermediate_type_inference.Canonicalizer()
86+
canonicalizer.transform(invariant.body)
87+
88+
inferrer = intermediate_type_inference.Inferrer(
89+
symbol_table=symbol_table,
90+
environment=environment,
91+
representation_map=canonicalizer.representation_map,
92+
)
93+
94+
inferrer.transform(invariant.body)
95+
if len(inferrer.errors) > 0:
96+
type_inference_errors.append(
97+
tests.common.most_underlying_messages(inferrer.errors)
98+
)
99+
100+
assert len(type_inference_errors) > 0, (
101+
f"Expected one or more type inference errors, "
102+
f"but got none on the source code:\n{source}"
103+
)
104+
105+
joined_message = "\n".join(type_inference_errors)
106+
self.assertEqual(expected_joined_message, joined_message, source)
107+
54108
def test_enumeration_literal_as_member(self) -> None:
55109
source = textwrap.dedent(
56110
"""\
@@ -165,6 +219,60 @@ def __init__(self, some_instance: Optional[Some_class] = None) -> None:
165219

166220
Test_with_smoke.execute(source=source)
167221

222+
def test_is_none_fails_on_non_optional(self) -> None:
223+
source = textwrap.dedent(
224+
"""\
225+
@invariant(
226+
lambda self:
227+
self.something is None,
228+
"Dummy invariant description"
229+
)
230+
class Some_class:
231+
something: str
232+
233+
def __init__(self, something: str) -> None:
234+
self.something = something
235+
236+
__version__ = "dummy"
237+
__xml_namespace__ = "https://dummy.com"
238+
"""
239+
)
240+
241+
self.expect_type_inference_to_fail(
242+
source=source,
243+
expected_joined_message=(
244+
"Expected the value to be of an optional type "
245+
"for a nullness check (``is None``), but got str"
246+
),
247+
)
248+
249+
def test_is_not_none_fails_on_non_optional(self) -> None:
250+
source = textwrap.dedent(
251+
"""\
252+
@invariant(
253+
lambda self:
254+
self.something is not None,
255+
"Dummy invariant description"
256+
)
257+
class Some_class:
258+
something: str
259+
260+
def __init__(self, something: str) -> None:
261+
self.something = something
262+
263+
__version__ = "dummy"
264+
__xml_namespace__ = "https://dummy.com"
265+
"""
266+
)
267+
268+
self.expect_type_inference_to_fail(
269+
source=source,
270+
expected_joined_message=(
271+
"Expected the value to be of an optional type "
272+
"for a non-nullness check (``is not None``), but got str"
273+
),
274+
)
275+
168276

169277
if __name__ == "__main__":
170278
unittest.main()

0 commit comments

Comments
 (0)