Skip to content

Commit b14237b

Browse files
authored
Fix type inference for disjunctions with nullness (#524)
We extend the type inference to, correctly, consider the case with disjunctions (``or``) with nullness checks (``is None``). For example, ``(A is None) or (A == 3)`` means that the first value in the disjunction (``A is None``) implies that ``A`` is non-null in the second value (``A == 3``).
1 parent 6df5c9e commit b14237b

File tree

7 files changed

+291
-31
lines changed

7 files changed

+291
-31
lines changed

aas_core_codegen/cpp/transpilation.py

+15
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,14 @@ def transform_method_call(
635635
for arg_node in node.args:
636636
arg_type = self.type_map[arg_node]
637637

638+
# NOTE (mristin):
639+
# This is a tough call to make. We decide that a value, for which we
640+
# know that it might be null, will not be de-referenced. On the other
641+
# hand, if the value is certainly not null, we de-reference it.
642+
#
643+
# The problem here is that the actual type of the argument in C++ changes
644+
# depending on whether we check for its nullness before with an implication.
645+
638646
arg: Optional[Stripped]
639647
if isinstance(arg_type, intermediate_type_inference.OptionalTypeAnnotation):
640648
arg, error = self.transform(arg_node)
@@ -680,6 +688,13 @@ def transform_function_call(
680688
for arg_node in node.args:
681689
arg_type = self.type_map[arg_node]
682690

691+
# NOTE (mristin):
692+
# This is a tough call to make. We decide that a value, for which we
693+
# know that it might be null, will not be de-referenced. On the other
694+
# hand, if the value is certainly not null, we de-reference it.
695+
#
696+
# The problem here is that the actual type of the argument in C++ changes
697+
# depending on whether we check for its nullness before with an implication.
683698
arg: Optional[Stripped]
684699
if isinstance(arg_type, intermediate_type_inference.OptionalTypeAnnotation):
685700
arg, error = self.transform(arg_node)

aas_core_codegen/cpp/verification/_generate.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -1538,12 +1538,22 @@ def _generate_non_recursive_verificator_implementation(
15381538

15391539
@ensure(lambda result: (result[0] is not None) ^ (result[1] is not None))
15401540
def _generate_non_recursive_verificator(
1541-
verificator_qualities: VerificatorQualities,
1541+
cls: intermediate.ConcreteClass,
15421542
symbol_table: intermediate.SymbolTable,
1543-
environment: intermediate_type_inference.Environment,
1543+
base_environment: intermediate_type_inference.Environment,
15441544
) -> Tuple[Optional[List[Stripped]], Optional[Error]]:
15451545
"""Generate the non-recursive verificator for the ``cls``."""
1546-
cls = verificator_qualities.cls
1546+
environment = intermediate_type_inference.MutableEnvironment(
1547+
parent=base_environment
1548+
)
1549+
1550+
assert environment.find(Identifier("self")) is None
1551+
environment.set(
1552+
identifier=Identifier("self"),
1553+
type_annotation=intermediate_type_inference.OurTypeAnnotation(our_type=cls),
1554+
)
1555+
1556+
verificator_qualities = VerificatorQualities(cls=cls)
15471557

15481558
if verificator_qualities.is_noop:
15491559
return (
@@ -3007,22 +3017,8 @@ def generate_implementation(
30073017
)
30083018

30093019
for cls in symbol_table.concrete_classes:
3010-
invariant_environment = intermediate_type_inference.MutableEnvironment(
3011-
parent=base_environment
3012-
)
3013-
3014-
assert invariant_environment.find(Identifier("self")) is None
3015-
invariant_environment.set(
3016-
identifier=Identifier("self"),
3017-
type_annotation=intermediate_type_inference.OurTypeAnnotation(our_type=cls),
3018-
)
3019-
3020-
verificator_qualities = VerificatorQualities(cls=cls)
3021-
30223020
nrv_blocks, nrv_error = _generate_non_recursive_verificator(
3023-
verificator_qualities=verificator_qualities,
3024-
symbol_table=symbol_table,
3025-
environment=invariant_environment,
3021+
cls=cls, symbol_table=symbol_table, base_environment=base_environment
30263022
)
30273023
if nrv_error is not None:
30283024
errors.append(nrv_error)

aas_core_codegen/intermediate/type_inference.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ def transform_not(self, node: parse_tree.Not) -> str:
724724

725725
def transform_and(self, node: parse_tree.And) -> str:
726726
values = [] # type: List[str]
727+
727728
for value_node in node.values:
728729
value = self.transform(value_node)
729730
if not _Canonicalizer._needs_no_brackets(value_node):
@@ -1611,20 +1612,34 @@ def transform_or(self, node: parse_tree.Or) -> Optional["TypeAnnotationUnion"]:
16111612

16121613
success = True
16131614

1614-
for value_node in node.values:
1615-
value_type = self.transform(value_node)
1616-
if value_type is None:
1617-
return None
1615+
with contextlib.ExitStack() as exit_stack:
1616+
for value_node in node.values:
1617+
value_type = self.transform(value_node)
1618+
if value_type is None:
1619+
return None
16181620

1619-
if isinstance(value_type, OptionalTypeAnnotation):
1620-
self.errors.append(
1621-
Error(
1622-
value_node.original_node,
1623-
f"Expected the value to be a non-None, "
1624-
f"but got: {value_type}",
1621+
if isinstance(value_type, OptionalTypeAnnotation):
1622+
self.errors.append(
1623+
Error(
1624+
value_node.original_node,
1625+
f"Expected the value to be a non-None, "
1626+
f"but got: {value_type}",
1627+
)
16251628
)
1626-
)
1627-
success = False
1629+
success = False
1630+
1631+
if isinstance(value_node, parse_tree.IsNone):
1632+
canonical_repr = self._representation_map[value_node.value]
1633+
self._non_null.increment(canonical_repr)
1634+
1635+
# fmt: off
1636+
exit_stack.callback(
1637+
lambda a_canonical_repr=canonical_repr: # type: ignore
1638+
self._non_null.decrement(
1639+
a_canonical_repr
1640+
)
1641+
)
1642+
# fmt: on
16281643

16291644
if not success:
16301645
return None

test_data/cpp/test_main/aas_core_meta.v3/expected_output/verification.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -11560,7 +11560,7 @@ void OfSubmodelElementList::Execute() {
1156011560
(
1156111561
(!(instance_->value().has_value()))
1156211562
|| PropertiesOrRangesHaveValueType(
11563-
instance_->value(),
11563+
(*(instance_->value())),
1156411564
(*(instance_->value_type_list_element()))
1156511565
)
1156611566
)

tests/common.py

+23
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,29 @@ def translate_source_to_intermediate(
9393
return intermediate.translate(parsed_symbol_table=parsed_symbol_table, atok=atok)
9494

9595

96+
def must_translate_source_to_intermediate(
97+
source: str,
98+
) -> intermediate.SymbolTable:
99+
atok, parse_exception = parse.source_to_atok(source=source)
100+
if parse_exception:
101+
raise parse_exception # pylint: disable=raising-bad-type
102+
103+
assert atok is not None
104+
105+
parsed_symbol_table, error = parse_atok(atok=atok)
106+
assert error is None, f"{most_underlying_messages(error)}"
107+
assert parsed_symbol_table is not None
108+
109+
symbol_table, error = intermediate.translate(
110+
parsed_symbol_table=parsed_symbol_table, atok=atok
111+
)
112+
assert (
113+
error is None
114+
), f"Unexpected error when parsing the source: {most_underlying_messages(error)}"
115+
assert symbol_table is not None
116+
return symbol_table
117+
118+
96119
#: If set, this environment variable indicates that the golden files should be
97120
#: re-recorded instead of checked against.
98121
RERECORD = os.environ.get("AAS_CORE_CODEGEN_RERECORD", "").lower() in (

tests/cpp/test_verification.py

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# pylint: disable=missing-docstring
2+
3+
import unittest
4+
from typing import List
5+
6+
from aas_core_codegen.common import Stripped, Identifier
7+
from aas_core_codegen.cpp.verification import _generate as cpp_verification_generate
8+
from aas_core_codegen.intermediate import type_inference as intermediate_type_inference
9+
from tests import common as tests_common
10+
11+
12+
class Test_against_recorded(unittest.TestCase):
13+
def test_optional_vector_in_invariant_not_deferenced(self) -> None:
14+
# NOTE (mristin):
15+
# This is a regression test where we check that an optional list without
16+
# implication is not de-referenced in the invariant's transpiled code.
17+
18+
source = """\
19+
class Item:
20+
value: str
21+
22+
def __init__(self, value: str) -> None:
23+
self.value = value
24+
25+
@verification
26+
@implementation_specific
27+
def check_something(value: Optional[List[Item]]) -> bool:
28+
pass
29+
30+
@invariant(
31+
lambda self:
32+
check_something(self.value),
33+
"Some description"
34+
)
35+
class Something:
36+
value: Optional[List[Item]]
37+
38+
def __init__(self, value: Optional[List[Item]] = None) -> None:
39+
self.value = value
40+
41+
__version__ = "dummy"
42+
__xml_namespace__ = "https://dummy.com"
43+
"""
44+
45+
symbol_table = tests_common.must_translate_source_to_intermediate(source=source)
46+
47+
base_environment = intermediate_type_inference.populate_base_environment(
48+
symbol_table=symbol_table
49+
)
50+
51+
something_cls = symbol_table.must_find_concrete_class(
52+
name=Identifier("Something")
53+
)
54+
55+
environment = intermediate_type_inference.MutableEnvironment(
56+
parent=base_environment
57+
)
58+
environment.set(
59+
identifier=Identifier("self"),
60+
type_annotation=intermediate_type_inference.OurTypeAnnotation(
61+
our_type=something_cls
62+
),
63+
)
64+
65+
blocks = [] # type: List[Stripped]
66+
for invariant in something_cls.invariants:
67+
(
68+
condition_expr,
69+
error,
70+
) = cpp_verification_generate._transpile_class_invariant(
71+
invariant=invariant, symbol_table=symbol_table, environment=environment
72+
)
73+
assert error is None, (
74+
f"Unexpected generation error for an invariant: "
75+
f"{tests_common.most_underlying_messages(error)}"
76+
)
77+
assert condition_expr is not None
78+
79+
blocks.append(condition_expr)
80+
81+
assert len(blocks) == 1, (
82+
f"Expected only a single block for a single invariant "
83+
f"in the class {something_cls.name!r}"
84+
)
85+
86+
# NOTE (mristin):
87+
# The implementation of ``CheckSomething`` needs to deal with the optional
88+
# values. As soon as something is optional in C++, it will be provided to
89+
# the function as-is. This is intentional.
90+
self.assertEqual(
91+
"""\
92+
CheckSomething(
93+
instance_->value()
94+
)""",
95+
blocks[0],
96+
)
97+
98+
def test_optional_in_invariant_dereferenced_if_certainly_not_null(
99+
self,
100+
) -> None:
101+
# NOTE (mristin):
102+
# This is a regression test where we check that an optional value is correctly
103+
# de-referenced in the invariant's transpiled code if type inference determines
104+
# that it is certainly not null.
105+
106+
source = """\
107+
class Item:
108+
value: str
109+
110+
def __init__(self, value: str) -> None:
111+
self.value = value
112+
113+
@verification
114+
@implementation_specific
115+
def check_something(value: Optional[List[Item]]) -> bool:
116+
pass
117+
118+
@invariant(
119+
lambda self:
120+
(self.value is None) or check_something(self.value),
121+
"Some description"
122+
)
123+
class Something:
124+
value: Optional[List[Item]]
125+
126+
def __init__(self, value: Optional[List[Item]] = None) -> None:
127+
self.value = value
128+
129+
__version__ = "dummy"
130+
__xml_namespace__ = "https://dummy.com"
131+
"""
132+
133+
symbol_table = tests_common.must_translate_source_to_intermediate(source=source)
134+
135+
base_environment = intermediate_type_inference.populate_base_environment(
136+
symbol_table=symbol_table
137+
)
138+
139+
something_cls = symbol_table.must_find_concrete_class(
140+
name=Identifier("Something")
141+
)
142+
143+
environment = intermediate_type_inference.MutableEnvironment(
144+
parent=base_environment
145+
)
146+
environment.set(
147+
identifier=Identifier("self"),
148+
type_annotation=intermediate_type_inference.OurTypeAnnotation(
149+
our_type=something_cls
150+
),
151+
)
152+
153+
blocks = [] # type: List[Stripped]
154+
for invariant in something_cls.invariants:
155+
(
156+
condition_expr,
157+
error,
158+
) = cpp_verification_generate._transpile_class_invariant(
159+
invariant=invariant, symbol_table=symbol_table, environment=environment
160+
)
161+
assert error is None, (
162+
f"Unexpected generation error for an invariant: "
163+
f"{tests_common.most_underlying_messages(error)}"
164+
)
165+
assert condition_expr is not None
166+
167+
blocks.append(condition_expr)
168+
169+
assert len(blocks) == 1, (
170+
f"Expected only a single block for a single invariant "
171+
f"in the class {something_cls.name!r}"
172+
)
173+
174+
# NOTE (mristin):
175+
# If we know that the value is not optional, we explicitly de-reference it.
176+
self.assertEqual(
177+
"""\
178+
(
179+
(!(instance_->value().has_value()))
180+
|| CheckSomething(
181+
(*(instance_->value()))
182+
)
183+
)""",
184+
blocks[0],
185+
)
186+
187+
188+
if __name__ == "__main__":
189+
unittest.main()

tests/intermediate/test_type_inference.py

+22
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,28 @@ def __init__(self, some_instance: Optional[Some_class] = None) -> None:
252252

253253
Test_with_smoke.execute(source=source)
254254

255+
def test_non_nullness_in_disjunction_with_is_none(self) -> None:
256+
source = textwrap.dedent(
257+
"""\
258+
@invariant(
259+
lambda self:
260+
(self.some_property is None) or (self.some_property == "dummy"),
261+
"Dummy description"
262+
)
263+
class Something:
264+
some_property: Optional[str]
265+
266+
def __init__(self, some_property: Optional[str] = None) -> None:
267+
self.some_property = some_property
268+
269+
270+
__version__ = "dummy"
271+
__xml_namespace__ = "https://dummy.com"
272+
"""
273+
)
274+
275+
Test_with_smoke.execute(source=source)
276+
255277
def test_is_none_fails_on_non_optional(self) -> None:
256278
source = textwrap.dedent(
257279
"""\

0 commit comments

Comments
 (0)