Replace strict Type comparisons with Type.is_compatible
brandonwillard committed Dec 31, 2021
1 parent dd6973d commit 9ce3bd6
Showing 26 changed files with 124 additions and 146 deletions.
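Why relax the comparison? A strict check like `x.type == y.type` requires every attribute of the two `Type`s to match exactly, whereas many call sites only need to know that one variable can stand in for another. A minimal sketch of the idea (the toy `Type` below and its `is_compatible` rule are illustrative assumptions, not Aesara's actual implementation):

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass(frozen=True)
    class ToyType:
        """Toy stand-in for a graph `Type`: a dtype plus optional static shape."""

        dtype: str
        shape: Tuple[Optional[int], ...]  # `None` marks an unknown dimension

        def is_compatible(self, other: "ToyType") -> bool:
            # Assumed rule: dtypes and ranks must match, and dimensions that
            # are statically known on both sides must agree.
            if self.dtype != other.dtype or len(self.shape) != len(other.shape):
                return False
            return all(
                s is None or o is None or s == o
                for s, o in zip(self.shape, other.shape)
            )

    t1 = ToyType("float64", (None, 3))
    t2 = ToyType("float64", (5, 3))

    assert t1 != t2              # strict equality compares every field
    assert t1.is_compatible(t2)  # an unknown dimension is not a conflict

Under the strict check, a rewrite that produces a variable carrying more (or less) static information than the original would be rejected outright; the diffs below replace those rejections with the looser compatibility test.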
6 changes: 3 additions & 3 deletions aesara/compile/debugmode.py
@@ -715,7 +715,7 @@ def _find_bad_optimizations0(order, reasons, r_vals):
# check if the value for new_r doesn't match the value for r
new_r_val = r_vals[new_r]
r_val = r_vals[r]
- assert r.type == new_r.type
+ assert r.type.is_compatible(new_r.type)

if hasattr(new_r.tag, "values_eq_approx"):
check = new_r.tag.values_eq_approx(r_val, new_r_val)
@@ -767,7 +767,7 @@ def _find_bad_optimizations1(order, reasons, r_vals):
if re0:
new_r_val = r_vals[re]
r_val = r_vals[re0]
- assert re.type == re0.type
+ assert re.type.is_compatible(re0.type)
if not re.type.values_eq_approx(r_val, new_r_val):
equivalence_sets_broken[id(r_equiv)] = True
there_is_a_problem = True
@@ -809,7 +809,7 @@ def check_variable_norec(new_r):
new_r_val = r_vals[new_r]
r_val = r_vals[r]

- if (r.type != new_r.type) or (
+ if (not r.type.is_compatible(new_r.type)) or (
not r.type.values_eq_approx(r_val, new_r_val)
):
raise BadOptimization(
4 changes: 2 additions & 2 deletions aesara/compile/function/pfunc.py
@@ -111,7 +111,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
v_update = v.type.filter_variable(
v.default_update, allow_convert=False
)
- if v_update.type != v.type:
+ if not v.type.is_compatible(v_update.type):
raise TypeError(
"An update must have a type compatible with "
"the original shared variable"
@@ -205,7 +205,7 @@ def clone_inputs(i):
)

raise TypeError(err_msg, err_sug)
- assert update_val.type == store_into.type
+ assert store_into.type.is_compatible(update_val.type)

update_d[store_into] = update_val
update_expr.append((store_into, update_val))
15 changes: 8 additions & 7 deletions aesara/compile/function/types.py
@@ -626,7 +626,7 @@ def checkSV(sv_ori, sv_rpl):
"type",
type(sv_ori),
)
- assert sv_ori.type == sv_rpl.type, (
+ assert sv_ori.type.is_compatible(sv_rpl.type), (
"Type of given SharedVariable conflicts with original one",
"Type of given SharedVariable:",
sv_rpl.type,
@@ -1424,13 +1424,13 @@ def find_same_graph_in_db(graph_db):
print("need to optimize, because output size is different")
continue
elif not all(
- input_new.type == input_old.type
+ input_old.type.is_compatible(input_new.type)
for input_new, input_old in zip(inputs_new, inputs_old)
):
print("need to optimize, because inputs are of different " "types")
continue
elif not all(
- output_new.type == output_old.type
+ output_old.type.is_compatible(output_new.type)
for output_new, output_old in zip(outputs_new, outputs_old)
):
print("need to optimize, because outputs are of different " "types")
@@ -1471,11 +1471,12 @@ def find_same_graph_in_db(graph_db):
)
)

# hack to remove inconsistent entry in givens
# seems to work, but the source of the inconsistency
# could be worth investigating.
for key, value in temp.items():
- if key.type != value.type:
+ if not key.type.is_compatible(value.type):
warnings.warn(
"`givens` key type is not consistent with its value.",
UserWarning,
)
del givens[key]

flag = is_same_graph(t1, t2, givens=givens)
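The `givens` cleanup in that last hunk amounts to the following standalone sketch (the helper name is hypothetical, and it is duck-typed against any objects whose `.type` exposes `is_compatible`; note that `warnings.warn` takes the message first and the category second):

    import warnings

    def prune_inconsistent_givens(givens):
        # Drop `givens` entries whose key and value types are incompatible,
        # warning about each one removed.
        for key, value in list(givens.items()):
            if not key.type.is_compatible(value.type):
                warnings.warn(
                    "`givens` key type is not consistent with its value.",
                    UserWarning,
                )
                del givens[key]
        return givens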
4 changes: 2 additions & 2 deletions aesara/compile/io.py
@@ -79,11 +79,11 @@ def __init__(
raise TypeError(f"name must be a string! (got: {self.name})")
self.update = update
if update is not None:
- if variable.type != update.type:
+ if not variable.type.is_compatible(update.type):
raise TypeError(
f"Variable '{variable}' has type {variable.type} but an update of "
f"type {update.type}. The type of the update should be "
"the same as the type of the variable"
"compatible with the type of the variable."
)

if mutable is not None:
2 changes: 2 additions & 0 deletions aesara/gpuarray/elemwise.py
@@ -2955,6 +2955,7 @@ def c_headers(self, **kwargs):
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
+ # This check is fine, because it strictly deals with scalar `Type`s
if node.inputs[0].type in complex_types:
raise NotImplementedError("type not supported", type)
# NB: CUDA erfinv function (GPU op) returns NaN if x not in [-1;1],
@@ -2982,6 +2983,7 @@ def c_headers(self, **kwargs):
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
+ # This check is fine, because it strictly deals with scalar `Type`s
if node.inputs[0].type in complex_types:
raise NotImplementedError("type not supported", type)
# NB: CUDA erfcinv function (GPU op) returns NaN if x not in [0;2],
1 change: 1 addition & 0 deletions aesara/gpuarray/opt.py
@@ -458,6 +458,7 @@ def apply(self, fgraph):
if (
new_o.owner
and isinstance(new_o.owner.op, GpuFromHost)
+ # TODO: Should this be `Type.is_compatible`?
and new_o.owner.inputs[0].type == o.type
):
new_o = new_o.owner.inputs[0]
2 changes: 1 addition & 1 deletion aesara/gradient.py
@@ -301,7 +301,7 @@ def _traverse(node):
# correctly, the same as grad
y = aesara.tensor.cast(y, x.type.dtype)
y = x.type.filter_variable(y)
- assert x.type == y.type
+ assert x.type.is_compatible(y.type)
same_type_eval_points.append(y)
else:
same_type_eval_points.append(y)
2 changes: 2 additions & 0 deletions aesara/graph/basic.py
@@ -1543,13 +1543,15 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if x.owner: # Check above tell that y.owner eval to True too.
if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False
+ # TODO: Should we use `Type.is_compatible`?
if x not in in_xs and x.type != y.type:
return False

if len(in_xs) != len(in_ys):
return False

for _x, _y in zip(in_xs, in_ys):
+ # TODO: Should we use `Type.is_compatible`?
if _x.type != _y.type:
return False

17 changes: 9 additions & 8 deletions aesara/graph/fg.py
@@ -428,8 +428,9 @@ def change_input(
) -> None:
"""Change ``node.inputs[i]`` to `new_var`.
- ``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
- current value of ``node.inputs[i]`` which we want to replace.
+ ``new_var.type.is_compatible(old_var.type)`` must be ``True``, where
+ ``old_var`` is the current value of ``node.inputs[i]`` which we want to
+ replace.
For each feature that has an `on_change_input` method, this method calls:
``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)``
@@ -450,18 +451,18 @@ def change_input(
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output":
r = self.outputs[i]
- if r.type != new_var.type:
+ if not r.type.is_compatible(new_var.type):
raise TypeError(
- f"The type of the replacement ({new_var.type}) must be the"
- f" same as the type of the original Variable ({r.type})."
+ f"The type of the replacement ({new_var.type}) must be "
+ f"compatible with the type of the original Variable ({r.type})."
)
self.outputs[i] = new_var
else:
r = node.inputs[i]
- if r.type != new_var.type:
+ if not r.type.is_compatible(new_var.type):
raise TypeError(
- f"The type of the replacement ({new_var.type}) must be the"
- f" same as the type of the original Variable ({r.type})."
+ f"The type of the replacement ({new_var.type}) must be "
+ f"compatible with the type of the original Variable ({r.type})."
)
node.inputs[i] = new_var

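Both guards in `change_input` now share one shape, which can be factored as a small helper (hypothetical, duck-typed against `.type.is_compatible`):

    def check_replacement(old_var, new_var):
        # Mirror of the updated `FunctionGraph.change_input` guard: the
        # replacement's type need only be compatible, not identical.
        if not old_var.type.is_compatible(new_var.type):
            raise TypeError(
                f"The type of the replacement ({new_var.type}) must be "
                f"compatible with the type of the original Variable ({old_var.type})."
            )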
1 change: 1 addition & 0 deletions aesara/graph/op.py
@@ -223,6 +223,7 @@ def make_node(self, *inputs: Variable) -> Apply:
raise ValueError(
f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
)
+ # TODO: Should this be `is_compatible`?
if not all(inp.type == it for inp, it in zip(inputs, self.itypes)):
raise TypeError(
f"Invalid input types for Op {self}:\n"
2 changes: 1 addition & 1 deletion aesara/graph/opt.py
@@ -1768,7 +1768,7 @@ def transform(self, fgraph, node, get_nodes=True):
else:
# ret is just an input variable
assert len(node.outputs) == 1
- if ret.type != node.outputs[0].type:
+ if not node.outputs[0].type.is_compatible(ret.type):
return False

return [ret]
20 changes: 6 additions & 14 deletions aesara/ifelse.py
@@ -25,7 +25,6 @@
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic
from aesara.tensor.shape import Reshape, Shape, SpecifyShape
from aesara.tensor.type import TensorType


__docformat__ = "restructedtext en"
@@ -190,9 +189,9 @@ def make_node(self, c, *args):
# TODO: Attempt to convert types so that they match?
# new_f = t.type.filter_variable(f)

- if t.type != f.type:
+ if not t.type.is_compatible(f.type):
raise TypeError(
"IfElse requires same types for true and false return values: "
"IfElse requires compatible types for true and false return values: "
f"true_branch={t.type}, false_branch={f.type}"
)
if c.ndim > 0:
@@ -369,27 +368,20 @@ def ifelse(
if not isinstance(else_branch_elem, Variable):
else_branch_elem = aet.basic.as_tensor_variable(else_branch_elem)

- if then_branch_elem.type != else_branch_elem.type:
+ if not then_branch_elem.type.is_compatible(else_branch_elem.type):
# If one of them is a TensorType, and the other one can be
# converted into one, then we try to do that.
# This case happens when one of the elements has a GPU type,
# for instance a shared variable that was silently moved to GPU.
- if isinstance(then_branch_elem.type, TensorType) and not isinstance(
- else_branch_elem.type, TensorType
- ):
+ if then_branch_elem.type.is_compatible(else_branch_elem.type):
else_branch_elem = then_branch_elem.type.filter_variable(
else_branch_elem
)

- elif isinstance(else_branch_elem.type, TensorType) and not isinstance(
- then_branch_elem.type, TensorType
- ):
+ elif else_branch_elem.type.is_compatible(then_branch_elem.type):
then_branch_elem = else_branch_elem.type.filter_variable(
then_branch_elem
)

- if then_branch_elem.type != else_branch_elem.type:
- # If the types still don't match, there is a problem.
+ else:
raise TypeError(
"The two branches should have identical types, but "
f"they are {then_branch_elem.type} and {else_branch_elem.type} respectively. This error could be "
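The reworked `ifelse` reconciliation can be read as: when the branch types disagree, try to filter one branch into the other's type, in either direction, and fail only if neither direction works. A sketch under that reading (names are illustrative, and it assumes a `filter_variable`-style conversion on types):

    def reconcile_branches(then_elem, else_elem):
        # Attempt a conversion in each direction before giving up.
        if then_elem.type.is_compatible(else_elem.type):
            return then_elem, then_elem.type.filter_variable(else_elem)
        if else_elem.type.is_compatible(then_elem.type):
            return else_elem.type.filter_variable(then_elem), else_elem
        raise TypeError(
            f"The two branches should have compatible types, but they are "
            f"{then_elem.type} and {else_elem.type} respectively."
        )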
2 changes: 2 additions & 0 deletions aesara/link/vm.py
@@ -87,6 +87,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if (
getattr(out, "ndim", None) == 0
and out not in pre_allocated
+ # TODO: Should this be `is_compatible`?
and ins.type == out.type
):
reuse_out = out
@@ -110,6 +111,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if (
getattr(out, "ndim", None) == 0
and out not in pre_allocated
+ # TODO: Should this be `is_compatible`?
and ins.type == out.type
):
reuse_out = out
2 changes: 1 addition & 1 deletion aesara/scalar/basic.py
@@ -1236,7 +1236,7 @@ def c_code_contiguous(self, node, name, inputs, outputs, sub):
or
# We compare the dtype AND the broadcast flag
# as this function do not broadcast
- node.inputs[0].type != node.outputs[0].type
+ not node.inputs[0].type.is_compatible(node.outputs[0].type)
):
raise MethodNotDefined()

12 changes: 6 additions & 6 deletions aesara/scan/op.py
@@ -534,12 +534,12 @@ def validate_inner_graph(self):

type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
- if type_input != type_output:
+ if not type_input.is_compatible(type_output):
raise TypeError(
"Inconsistency in the inner graph of "
f"scan '{self.name}' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
"and should have compatible types but have "
f"type '{type_input}' and '{type_output}' respectively."
)

@@ -1068,11 +1068,11 @@ def make_node(self, *inputs):
):
outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
- if inner_nonseq.type != outer_nonseq.type:
+ if not inner_nonseq.type.is_compatible(outer_nonseq.type):
raise ValueError(
(
f"Argument {outer_nonseq} given to the scan node does not"
f" match its corresponding loop function variable {inner_nonseq}"
f"Argument {outer_nonseq} given to the scan node is not"
f" compatible with its corresponding loop function variable {inner_nonseq}"
)
)

@@ -1143,7 +1143,7 @@ def __eq__(self, other):
return False

for self_in, other_in in zip(self.inputs, other.inputs):
- if self_in.type != other_in.type:
+ if not self_in.type.is_compatible(other_in.type):
return False

return equal_computations(
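The `Scan.__eq__` change follows the same pattern: two ops whose corresponding inner inputs have compatible, rather than identical, types can still compare equal. The pairwise precondition, as a standalone sketch (hypothetical helper, duck-typed):

    def inputs_compatible(inputs_a, inputs_b):
        # Pairwise type-compatibility check used as a precondition for
        # the structural equality of two inner graphs.
        if len(inputs_a) != len(inputs_b):
            return False
        return all(
            a.type.is_compatible(b.type) for a, b in zip(inputs_a, inputs_b)
        )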
4 changes: 2 additions & 2 deletions aesara/scan/opt.py
@@ -308,7 +308,7 @@ def add_to_replace(y):
for out, idx in to_replace_map.items():
if ( # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
- replace_with_in[idx].type == out.type
+ replace_with_in[idx].type.is_compatible(out.type)
and out in to_keep_set
and out.owner not in existent_nodes_set
):
@@ -557,7 +557,7 @@ def add_to_replace(y):
and
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
- replace_with_in[idx].type == out.type
+ replace_with_in[idx].type.is_compatible(out.type)
):

clean_to_replace.append(out)