From 9ce3bd6a5f5fcf9ab2a5f68f5af986899eeeb0d9 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 29 Dec 2021 22:41:10 -0600 Subject: [PATCH] Replace strict Type comparisons with Type.is_compatible --- aesara/compile/debugmode.py | 6 ++--- aesara/compile/function/pfunc.py | 4 +-- aesara/compile/function/types.py | 15 +++++------ aesara/compile/io.py | 4 +-- aesara/gpuarray/elemwise.py | 2 ++ aesara/gpuarray/opt.py | 1 + aesara/gradient.py | 2 +- aesara/graph/basic.py | 2 ++ aesara/graph/fg.py | 17 +++++++------ aesara/graph/op.py | 1 + aesara/graph/opt.py | 2 +- aesara/ifelse.py | 20 +++++---------- aesara/link/vm.py | 2 ++ aesara/scalar/basic.py | 2 +- aesara/scan/op.py | 12 ++++----- aesara/scan/opt.py | 4 +-- aesara/tensor/basic.py | 21 ++++++---------- aesara/tensor/basic_opt.py | 26 +++++++++---------- aesara/tensor/blas.py | 25 ++++++++++++------- aesara/tensor/extra_ops.py | 11 ++------- aesara/tensor/math_opt.py | 15 ++++++----- aesara/tensor/nnet/basic.py | 17 ++++++------- aesara/tensor/subtensor.py | 33 ++++++------------------- aesara/tensor/subtensor_opt.py | 14 +++++------ aesara/typed_list/basic.py | 4 +-- tests/tensor/nnet/test_abstract_conv.py | 8 +++--- 26 files changed, 124 insertions(+), 146 deletions(-) diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index 7a4ab3a7f5..9ec055e55d 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -715,7 +715,7 @@ def _find_bad_optimizations0(order, reasons, r_vals): # check if the value for new_r doesn't match the value for r new_r_val = r_vals[new_r] r_val = r_vals[r] - assert r.type == new_r.type + assert r.type.is_compatible(new_r.type) if hasattr(new_r.tag, "values_eq_approx"): check = new_r.tag.values_eq_approx(r_val, new_r_val) @@ -767,7 +767,7 @@ def _find_bad_optimizations1(order, reasons, r_vals): if re0: new_r_val = r_vals[re] r_val = r_vals[re0] - assert re.type == re0.type + assert re.type.is_compatible(re0.type) if not re.type.values_eq_approx(r_val, new_r_val): equivalence_sets_broken[id(r_equiv)] = True there_is_a_problem = True @@ -809,7 +809,7 @@ def check_variable_norec(new_r): new_r_val = r_vals[new_r] r_val = r_vals[r] - if (r.type != new_r.type) or ( + if (not r.type.is_compatible(new_r.type)) or ( not r.type.values_eq_approx(r_val, new_r_val) ): raise BadOptimization( diff --git a/aesara/compile/function/pfunc.py b/aesara/compile/function/pfunc.py index 56decca40d..5089d1a693 100644 --- a/aesara/compile/function/pfunc.py +++ b/aesara/compile/function/pfunc.py @@ -111,7 +111,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over): v_update = v.type.filter_variable( v.default_update, allow_convert=False ) - if v_update.type != v.type: + if not v.type.is_compatible(v_update.type): raise TypeError( "An update must have a type compatible with " "the original shared variable" @@ -205,7 +205,7 @@ def clone_inputs(i): ) raise TypeError(err_msg, err_sug) - assert update_val.type == store_into.type + assert store_into.type.is_compatible(update_val.type) update_d[store_into] = update_val update_expr.append((store_into, update_val)) diff --git a/aesara/compile/function/types.py b/aesara/compile/function/types.py index 27122e56e2..561df9c05b 100644 --- a/aesara/compile/function/types.py +++ b/aesara/compile/function/types.py @@ -626,7 +626,7 @@ def checkSV(sv_ori, sv_rpl): "type", type(sv_ori), ) - assert sv_ori.type == sv_rpl.type, ( + assert sv_ori.type.is_compatible(sv_rpl.type), ( "Type of given SharedVariable conflicts with original one", "Type of given 
SharedVariable:", sv_rpl.type, @@ -1424,13 +1424,13 @@ def find_same_graph_in_db(graph_db): print("need to optimize, because output size is different") continue elif not all( - input_new.type == input_old.type + input_old.type.is_compatible(input_new.type) for input_new, input_old in zip(inputs_new, inputs_old) ): print("need to optimize, because inputs are of different " "types") continue elif not all( - output_new.type == output_old.type + output_old.type.is_compatible(output_new.type) for output_new, output_old in zip(outputs_new, outputs_old) ): print("need to optimize, because outputs are of different " "types") @@ -1471,11 +1471,12 @@ def find_same_graph_in_db(graph_db): ) ) - # hack to remove inconsistent entry in givens - # seems to work that but source of inconsistency - # could be worth investigating. for key, value in temp.items(): - if key.type != value.type: + if not key.type.is_compatible(value.type): + warnings.warn( + UserWarning, + "`givens` key type is not consistent with its value.", + ) del givens[key] flag = is_same_graph(t1, t2, givens=givens) diff --git a/aesara/compile/io.py b/aesara/compile/io.py index 3580f3ae46..6a5fcd44fc 100644 --- a/aesara/compile/io.py +++ b/aesara/compile/io.py @@ -79,11 +79,11 @@ def __init__( raise TypeError(f"name must be a string! (got: {self.name})") self.update = update if update is not None: - if variable.type != update.type: + if not variable.type.is_compatible(update.type): raise TypeError( f"Variable '{variable}' has type {variable.type} but an update of " f"type {update.type}. The type of the update should be " - "the same as the type of the variable" + "compatible with the type of the variable." ) if mutable is not None: diff --git a/aesara/gpuarray/elemwise.py b/aesara/gpuarray/elemwise.py index 65313e7e7d..64b12ede16 100644 --- a/aesara/gpuarray/elemwise.py +++ b/aesara/gpuarray/elemwise.py @@ -2955,6 +2955,7 @@ def c_headers(self, **kwargs): def c_code(self, node, name, inp, out, sub): (x,) = inp (z,) = out + # This check is fine, because it strictly deals with scalar `Type`s if node.inputs[0].type in complex_types: raise NotImplementedError("type not supported", type) # NB: CUDA erfinv function (GPU op) returns NaN if x not in [-1;1], @@ -2982,6 +2983,7 @@ def c_headers(self, **kwargs): def c_code(self, node, name, inp, out, sub): (x,) = inp (z,) = out + # This check is fine, because it strictly deals with scalar `Type`s if node.inputs[0].type in complex_types: raise NotImplementedError("type not supported", type) # NB: CUDA erfcinv function (GPU op) returns NaN if x not in [0;2], diff --git a/aesara/gpuarray/opt.py b/aesara/gpuarray/opt.py index d6c97948d2..45f5dfee91 100644 --- a/aesara/gpuarray/opt.py +++ b/aesara/gpuarray/opt.py @@ -458,6 +458,7 @@ def apply(self, fgraph): if ( new_o.owner and isinstance(new_o.owner.op, GpuFromHost) + # TODO: Should this be `Type.is_compatible`? 
and new_o.owner.inputs[0].type == o.type ): new_o = new_o.owner.inputs[0] diff --git a/aesara/gradient.py b/aesara/gradient.py index 02ffaf85f1..b4d668b959 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -301,7 +301,7 @@ def _traverse(node): # correctly, the same as grad y = aesara.tensor.cast(y, x.type.dtype) y = x.type.filter_variable(y) - assert x.type == y.type + assert x.type.is_compatible(y.type) same_type_eval_points.append(y) else: same_type_eval_points.append(y) diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index 4cd2b6bf6d..c8be77eddd 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -1543,6 +1543,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): if x.owner: # Check above tell that y.owner eval to True too. if x.owner.outputs.index(x) != y.owner.outputs.index(y): return False + # TODO: Should we use `Type.is_compatible`? if x not in in_xs and x.type != y.type: return False @@ -1550,6 +1551,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): return False for _x, _y in zip(in_xs, in_ys): + # TODO: Should we use `Type.is_compatible`? if _x.type != _y.type: return False diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 2189d55e82..65fd5fe58a 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -428,8 +428,9 @@ def change_input( ) -> None: """Change ``node.inputs[i]`` to `new_var`. - ``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the - current value of ``node.inputs[i]`` which we want to replace. + ``new_var.type.is_compatible(old_var.type)`` must be ``True``, where + ``old_var`` is the current value of ``node.inputs[i]`` which we want to + replace. For each feature that has an `on_change_input` method, this method calls: ``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)`` @@ -450,18 +451,18 @@ def change_input( # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) if node == "output": r = self.outputs[i] - if r.type != new_var.type: + if not r.type.is_compatible(new_var.type): raise TypeError( - f"The type of the replacement ({new_var.type}) must be the" - f" same as the type of the original Variable ({r.type})." + f"The type of the replacement ({new_var.type}) must be " + f"compatible with the type of the original Variable ({r.type})." ) self.outputs[i] = new_var else: r = node.inputs[i] - if r.type != new_var.type: + if not r.type.is_compatible(new_var.type): raise TypeError( - f"The type of the replacement ({new_var.type}) must be the" - f" same as the type of the original Variable ({r.type})." + f"The type of the replacement ({new_var.type}) must be " + f"compatible with the type of the original Variable ({r.type})." ) node.inputs[i] = new_var diff --git a/aesara/graph/op.py b/aesara/graph/op.py index 3d69ed0bbb..8c32ea163a 100644 --- a/aesara/graph/op.py +++ b/aesara/graph/op.py @@ -223,6 +223,7 @@ def make_node(self, *inputs: Variable) -> Apply: raise ValueError( f"We expected {len(self.itypes)} inputs but got {len(inputs)}." ) + # TODO: Should this be `is_compatible`? 
if not all(inp.type == it for inp, it in zip(inputs, self.itypes)): raise TypeError( f"Invalid input types for Op {self}:\n" diff --git a/aesara/graph/opt.py b/aesara/graph/opt.py index ea3135fe16..9d8cfe0b2f 100644 --- a/aesara/graph/opt.py +++ b/aesara/graph/opt.py @@ -1768,7 +1768,7 @@ def transform(self, fgraph, node, get_nodes=True): else: # ret is just an input variable assert len(node.outputs) == 1 - if ret.type != node.outputs[0].type: + if not node.outputs[0].type.is_compatible(ret.type): return False return [ret] diff --git a/aesara/ifelse.py b/aesara/ifelse.py index bc2f1778c6..fce2bc3438 100644 --- a/aesara/ifelse.py +++ b/aesara/ifelse.py @@ -25,7 +25,6 @@ from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.tensor import basic from aesara.tensor.shape import Reshape, Shape, SpecifyShape -from aesara.tensor.type import TensorType __docformat__ = "restructedtext en" @@ -190,9 +189,9 @@ def make_node(self, c, *args): # TODO: Attempt to convert types so that they match? # new_f = t.type.filter_variable(f) - if t.type != f.type: + if not t.type.is_compatible(f.type): raise TypeError( - "IfElse requires same types for true and false return values: " + "IfElse requires compatible types for true and false return values: " f"true_branch={t.type}, false_branch={f.type}" ) if c.ndim > 0: @@ -369,27 +368,20 @@ def ifelse( if not isinstance(else_branch_elem, Variable): else_branch_elem = aet.basic.as_tensor_variable(else_branch_elem) - if then_branch_elem.type != else_branch_elem.type: + if not then_branch_elem.type.is_compatible(else_branch_elem.type): # If one of them is a TensorType, and the other one can be # converted into one, then we try to do that. # This case happens when one of the elements has a GPU type, # for instance a shared variable that was silently moved to GPU. - if isinstance(then_branch_elem.type, TensorType) and not isinstance( - else_branch_elem.type, TensorType - ): + if then_branch_elem.type.is_compatible(else_branch_elem.type): else_branch_elem = then_branch_elem.type.filter_variable( else_branch_elem ) - - elif isinstance(else_branch_elem.type, TensorType) and not isinstance( - then_branch_elem.type, TensorType - ): + elif else_branch_elem.type.is_compatible(then_branch_elem.type): then_branch_elem = else_branch_elem.type.filter_variable( then_branch_elem ) - - if then_branch_elem.type != else_branch_elem.type: - # If the types still don't match, there is a problem. + else: raise TypeError( "The two branches should have identical types, but " f"they are {then_branch_elem.type} and {else_branch_elem.type} respectively. This error could be " diff --git a/aesara/link/vm.py b/aesara/link/vm.py index b3f4614121..fbbc161107 100644 --- a/aesara/link/vm.py +++ b/aesara/link/vm.py @@ -87,6 +87,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend if ( getattr(out, "ndim", None) == 0 and out not in pre_allocated + # TODO: Should this be `is_compatible`? and ins.type == out.type ): reuse_out = out @@ -110,6 +111,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend if ( getattr(out, "ndim", None) == 0 and out not in pre_allocated + # TODO: Should this be `is_compatible`? 
and ins.type == out.type ): reuse_out = out diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index c097e49241..4cbac867aa 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -1236,7 +1236,7 @@ def c_code_contiguous(self, node, name, inputs, outputs, sub): or # We compare the dtype AND the broadcast flag # as this function do not broadcast - node.inputs[0].type != node.outputs[0].type + not node.inputs[0].type.is_compatible(node.outputs[0].type) ): raise MethodNotDefined() diff --git a/aesara/scan/op.py b/aesara/scan/op.py index ce4549494a..ef4764f78d 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -534,12 +534,12 @@ def validate_inner_graph(self): type_input = self.inputs[inner_iidx].type type_output = self.outputs[inner_oidx].type - if type_input != type_output: + if not type_input.is_compatible(type_output): raise TypeError( "Inconsistency in the inner graph of " f"scan '{self.name}' : an input and an output are " "associated with the same recurrent state " - "and should have the same type but have " + "and should have compatible types but have " f"type '{type_input}' and '{type_output}' respectively." ) @@ -1068,11 +1068,11 @@ def make_node(self, *inputs): ): outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq) new_inputs.append(outer_nonseq) - if inner_nonseq.type != outer_nonseq.type: + if not inner_nonseq.type.is_compatible(outer_nonseq.type): raise ValueError( ( - f"Argument {outer_nonseq} given to the scan node does not" - f" match its corresponding loop function variable {inner_nonseq}" + f"Argument {outer_nonseq} given to the scan node is not" + f" compatible with its corresponding loop function variable {inner_nonseq}" ) ) @@ -1143,7 +1143,7 @@ def __eq__(self, other): return False for self_in, other_in in zip(self.inputs, other.inputs): - if self_in.type != other_in.type: + if not self_in.type.is_compatible(other_in.type): return False return equal_computations( diff --git a/aesara/scan/opt.py b/aesara/scan/opt.py index e6b703c4ef..51eac7e0a7 100644 --- a/aesara/scan/opt.py +++ b/aesara/scan/opt.py @@ -308,7 +308,7 @@ def add_to_replace(y): for out, idx in to_replace_map.items(): if ( # If types are different, conversion Op will be inserted, # and it may trigger an infinite loop. - replace_with_in[idx].type == out.type + replace_with_in[idx].type.is_compatible(out.type) and out in to_keep_set and out.owner not in existent_nodes_set ): @@ -557,7 +557,7 @@ def add_to_replace(y): and # If types are different, conversion Op will be inserted, # and it may trigger an infinite loop. - replace_with_in[idx].type == out.type + replace_with_in[idx].type.is_compatible(out.type) ): clean_to_replace.append(out) diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 49667c571e..e5ae1cf180 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -55,8 +55,6 @@ discrete_dtypes, float_dtypes, int_dtypes, - int_types, - int_vector_types, integer_dtypes, tensor, uint_dtypes, @@ -650,7 +648,7 @@ class Rebroadcast(COp): Examples -------- - `Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in + ``Rebroadcast((0, True), (1, False))(x)`` would make `x` broadcastable in axis 0 and not broadcastable in axis 1. """ @@ -955,6 +953,7 @@ def ones_like(model, dtype=None, opt=False): if dtype is None: dtype = model.type.dtype ret = constant(1.0, dtype=dtype) + # TODO: Should this be `is_compatible`? 
if opt and ret.type == model.type: return ret return fill(model, ret) @@ -979,6 +978,7 @@ def zeros_like(model, dtype=None, opt=False): if dtype is None: dtype = model.type.dtype ret = constant(0.0, dtype=dtype) + # TODO: Should this be `is_compatible`? if opt and ret.type == model.type: return ret return fill(model, ret) @@ -1769,8 +1769,8 @@ class Default(Op): def make_node(self, x, default): x, default = as_tensor_variable(x), as_tensor_variable(default) - if x.type != default.type: - raise TypeError("Both default() arguments must have same type", x, default) + if not default.type.is_compatible(x.type): + raise TypeError("Both arguments must have compatible types") return Apply(self, [x, default], [default.type()]) def perform(self, node, inp, out_): @@ -1872,16 +1872,11 @@ def make_node(self, x, axis, splits): axis = as_tensor_variable(axis) splits = as_tensor_variable(splits) - if splits.type not in int_vector_types: + if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes: raise TypeError("`splits` parameter must be tensors of integer type") - if axis.type not in int_types: - raise TypeError("`axis` parameter must be an integer scalar") - # # The following lines are necessary if we allow splits of zero - # if isinstance(axis, Constant): - # x = unbroadcast(x, int(axis.data)) - # else: - # x = unbroadcast(x, *range(x.type.ndim)) + if axis.type.dtype not in integer_dtypes: + raise TypeError("`axis` parameter must be an integer scalar") inputs = [x, axis, splits] outputs = [x.type() for i in range(self.len_splits)] diff --git a/aesara/tensor/basic_opt.py b/aesara/tensor/basic_opt.py index ec14e1a318..467780ed5c 100644 --- a/aesara/tensor/basic_opt.py +++ b/aesara/tensor/basic_opt.py @@ -120,7 +120,7 @@ def broadcast_like(value, template, fgraph, dtype=None): """ value = as_tensor_variable(value) - if value.type == template.type: + if value.type.is_compatible(template.type): return value if template not in fgraph.variables: raise NotImplementedError( @@ -130,7 +130,7 @@ def broadcast_like(value, template, fgraph, dtype=None): if dtype is None: dtype = template.dtype value = cast(value, dtype) - if value.type == template.type: + if value.type.is_compatible(template.type): return value if hasattr(fgraph, "shape_feature"): new_shape = fgraph.shape_feature.shape_of[template] @@ -1543,9 +1543,9 @@ def dimshuffled_alloc(i): if ( i.owner and isinstance(i.owner.op, Alloc) - and i.owner.inputs[0].type != i.owner.outputs[0].type + and not i.owner.inputs[0].type.is_compatible(i.owner.outputs[0].type) ): - # when `i.owner.inputs[0].type == i.owner.outputs[0].type` we + # when `i.owner.inputs[0].type.is_compatible(i.owner.outputs[0].type)` we # will remove that `Alloc` later assert i.type.ndim == cmp_op.ndim if config.experimental__local_alloc_elemwise_assert: @@ -1623,7 +1623,7 @@ def local_fill_sink(fgraph, node): return False c = node.op(*inputs) for model in models: - if model.type != c.type: + if not model.type.is_compatible(c.type): c = fill(model, c) # The newly created node c doesn't has 'clients', @@ -1721,7 +1721,7 @@ def local_useless_fill(fgraph, node): """ if node.op == fill: r, v = node.inputs - if v.type == node.outputs[0].type: + if v.type.is_compatible(node.outputs[0].type): # this is a useless fill, erase it. 
# also, we don't need to copy over any stack traces here return [v] @@ -1744,7 +1744,7 @@ def local_useless_alloc(fgraph, node): inp = node.inputs[0] output = node.outputs[0] - if inp.type == output.type: + if inp.type.is_compatible(output.type): if inp.ndim == 0: return [inp] else: @@ -1779,7 +1779,7 @@ def local_canonicalize_alloc(fgraph, node): output = node.outputs[0] # Check if dtype and broadcast remain the same. - if inp.type == output.type: + if inp.type.is_compatible(output.type): # We don't need to copy over any stack traces here return [inp] @@ -2210,7 +2210,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): if new_inputs != node.inputs: rval = [node.op(*new_inputs)] - if rval[0].type != node.outputs[0].type: + if not node.outputs[0].type.is_compatible(rval[0].type): # This can happen for example when floatX=float32 # and we do the true division between and int64 # and a constant that will get typed as int8. @@ -2406,7 +2406,7 @@ def local_join_empty(fgraph, node): # by an error in the old join op. copy_stack_trace(node.outputs, ret) - if ret.type != o.type: + if not o.type.is_compatible(ret.type): assert ret.dtype == o.dtype assert ret.ndim == o.ndim ret = patternbroadcast(ret, node.outputs[0].broadcastable) @@ -2511,7 +2511,7 @@ def local_useless_switch(fgraph, node): if node.inputs[1] is node.inputs[2]: # Note: No need to copy over stacktrace, because the input node # already has its own stacktrace - if cond.type == node.inputs[1].type: + if cond.type.is_compatible(node.inputs[1].type): return [node.inputs[1]] ret = fill(cond, node.inputs[1]) @@ -2536,7 +2536,7 @@ def local_useless_switch(fgraph, node): and extract_constant(left, only_process_constants=True) == 0 and right is cond_var.owner.inputs[0] ): - assert right.type == node.outputs[0].type + assert node.outputs[0].type.is_compatible(right.type) # No need to copy over stacktrace, because the right input node # already has its own stacktrace return [right] @@ -2892,7 +2892,7 @@ def local_reshape_lift(fgraph, node): # In rare case the original broadcast was (False, True), but # the new one is (False, False). So don't crash in that case. - if e.type != node.outputs[0].type: + if not node.outputs[0].type.is_compatible(e.type): re = patternbroadcast(e, node.outputs[0].broadcastable) # Copy over stack trace. diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py index fcbc4902b6..57e5e95f21 100644 --- a/aesara/tensor/blas.py +++ b/aesara/tensor/blas.py @@ -1112,7 +1112,7 @@ def res_is_a(fgraph, var, op, maxclients=None): def _as_scalar(res, dtype=None): - """Return ``None`` or a `TensorVariable` whose type is in `float_scalar_types`""" + """Return ``None`` or a `TensorVariable` of float type""" if dtype is None: dtype = config.floatX if np.all(res.type.broadcastable): @@ -1367,7 +1367,7 @@ def item_to_var(t): for j in range(i + 1, len(lst)): s_j, M_j = lst[j] - if M_i.type != M_j.type: + if not M_i.type.is_compatible(M_j.type): continue # print 'TRYING', (s_i, M_i, s_j, M_j) @@ -1393,11 +1393,11 @@ def item_to_var(t): def _gemm_from_node2(fgraph, node): """ - :todo: In many expressions, there are many ways to turn it into a - gemm. For example dot(a,b) + c + d. This function should - return all of them, so that if one version of gemm causes a - cycle in the graph, then another application of gemm can be - tried. + + TODO: In many expressions, there are many ways to turn it into a + gemm. For example dot(a,b) + c + d. 
This function should return all + of them, so that if one version of gemm causes a cycle in the graph, then + another application of gemm can be tried. """ lst = [] @@ -1405,7 +1405,6 @@ def _gemm_from_node2(fgraph, node): _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0) t1 = time.time() - # print "GEMM CANON", lst if len(lst) > 1: lst = _factor_canonicalized(lst) t2 = time.time() @@ -1421,7 +1420,15 @@ def _gemm_from_node2(fgraph, node): # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5, # but never made it into a trac ticket. - if rval and (rval[0][0].type == node.outputs[0].type): + if rval and ( + # TODO FIXME: Clarify this logic (e.g. what sort of "upcasting" + # changes the shapes/broadcast patterns, why is that allowed, + # can the rewrite logic be simplified/made to avoid these + # cryptic checks, etc.) + # All of this code *really* needs to be refactored. + rval[0][0].dtype == node.outputs[0].dtype + and rval[0][0].broadcastable == node.outputs[0].broadcastable + ): return rval, t1 - t0, t2 - t1, t3 - t2 return None, t1 - t0, 0, 0 diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py index 28052e5763..5287b799d5 100644 --- a/aesara/tensor/extra_ops.py +++ b/aesara/tensor/extra_ops.py @@ -27,14 +27,7 @@ from aesara.tensor.math import maximum, minimum, or_, prod from aesara.tensor.math import sum as aet_sum from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor -from aesara.tensor.type import ( - TensorType, - dvector, - int_dtypes, - int_vector_types, - integer_dtypes, - vector, -) +from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from aesara.tensor.var import TensorVariable from aesara.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -138,7 +131,7 @@ def make_node(self, x, v, sorter=None): "numpy.searchsorted with Python 32bit do not support a" " sorter of int64." 
) - if sorter.type not in int_vector_types: + if sorter.type.ndim == 1 and sorter.type.dtype not in int_dtypes: raise TypeError("sorter must be an integer vector", sorter.type) return Apply(self, [x, v, sorter], [out_type()]) diff --git a/aesara/tensor/math_opt.py b/aesara/tensor/math_opt.py index 3c4a26732b..af8412a1ff 100644 --- a/aesara/tensor/math_opt.py +++ b/aesara/tensor/math_opt.py @@ -309,7 +309,7 @@ def local_exp_log(fgraph, node): x = x.owner.inputs[0] old_out = node.outputs[0] new_out = add(1, exp(x)) - if new_out.type != old_out.type: + if not old_out.type.is_compatible(new_out.type): return return [new_out] @@ -333,7 +333,7 @@ def local_exp_log_nan_switch(fgraph, node): x = x.owner.inputs[0] old_out = node.outputs[0] new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) - if new_out.type != old_out.type: + if not old_out.type.is_compatible(new_out.type): return return [new_out] @@ -342,7 +342,7 @@ def local_exp_log_nan_switch(fgraph, node): x = x.owner.inputs[0] old_out = node.outputs[0] new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) - if new_out.type != old_out.type: + if not old_out.type.is_compatible(new_out.type): return return [new_out] @@ -351,7 +351,7 @@ def local_exp_log_nan_switch(fgraph, node): x = x.owner.inputs[0] old_out = node.outputs[0] new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) - if new_out.type != old_out.type: + if not old_out.type.is_compatible(new_out.type): return return [new_out] @@ -425,7 +425,8 @@ def local_expm1(fgraph, node): if new_out.dtype != out.dtype: new_out = cast(new_out, dtype=out.dtype) - if new_out.type != out.type: + + if not out.type.is_compatible(new_out.type): return return [new_out] @@ -1019,8 +1020,6 @@ def same(x, y): if new.type.dtype != out.type.dtype: new = cast(new, out.type.dtype) - assert (new.type == out.type) == (not (new.type != out.type)) - if new.type != out.type: new = fill_chain(new, node.inputs)[0] @@ -1831,7 +1830,7 @@ def local_div_to_reciprocal(fgraph, node): if new_out.dtype != out.dtype: new_out = cast(new_out, dtype=out.dtype) # The ones could have forced a specific length - if new_out.type != out.type: + if not out.type.is_compatible(new_out.type): new_out = broadcast_like(new_out, out, fgraph) return [new_out] else: diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index beb2baf97a..42c1871ea4 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -70,8 +70,7 @@ TensorType, discrete_dtypes, float_dtypes, - ivector, - lvector, + integer_dtypes, values_eq_approx_remove_inf, values_eq_approx_remove_nan, ) @@ -1223,7 +1222,7 @@ def local_softmax_with_bias(fgraph, node): # forget about it return - if sm_bias.type == node.outputs[0].type: + if node.outputs[0].type.is_compatible(sm_bias.type): # This condition is not always true. 
See the test # nnet/tests/test_basic.py:T_SoftmaxWithBias.test_broadcast return [sm_bias] @@ -1818,12 +1817,12 @@ def make_node(self, coding_dist, true_one_of_n): _coding_dist = aet.as_tensor_variable(coding_dist) _true_one_of_n = aet.as_tensor_variable(true_one_of_n) if _coding_dist.type.ndim != 2: - raise TypeError("matrix required for argument: coding_dist") - if _true_one_of_n.type not in (lvector, ivector): - raise TypeError( - "integer vector required for argument: true_one_of_n" - f"(got type: {_true_one_of_n.type} instead of: {lvector})" - ) + raise TypeError("Matrix required for argument `coding_dist`") + if ( + _true_one_of_n.type.ndim != 1 + and _true_one_of_n.type.dtype not in integer_dtypes + ): + raise TypeError("Integer vector required for argument `true_one_of_n`") return Apply( self, diff --git a/aesara/tensor/subtensor.py b/aesara/tensor/subtensor.py index 9af5074cf2..fd029bccfa 100644 --- a/aesara/tensor/subtensor.py +++ b/aesara/tensor/subtensor.py @@ -30,18 +30,11 @@ from aesara.tensor.shape import Reshape from aesara.tensor.type import ( TensorType, - bscalar, complex_dtypes, - cscalar, discrete_dtypes, - dscalar, - fscalar, + float_dtypes, integer_dtypes, - iscalar, - lscalar, tensor, - wscalar, - zscalar, ) from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice @@ -50,18 +43,8 @@ invalid_scal_types = (aes.float64, aes.float32, aes.float16) scal_types = (aes.int64, aes.int32, aes.int16, aes.int8) -tensor_types = ( - lscalar, - iscalar, - wscalar, - bscalar, -) -invalid_tensor_types = ( - fscalar, - dscalar, - cscalar, - zscalar, -) +valid_index_dtypes = discrete_dtypes +invalid_index_dtypes = float_dtypes + complex_dtypes def indices_from_subtensor( @@ -548,7 +531,7 @@ def index_vars_to_types(entry, slice_ok=True): raise AdvancedIndexingError("Invalid index type or slice for Subtensor") if isinstance(entry, Variable) and ( - entry.type in invalid_scal_types or entry.type in invalid_tensor_types + entry.type in invalid_scal_types or entry.type.dtype in invalid_index_dtypes ): raise TypeError("Expected an integer") @@ -559,13 +542,13 @@ def index_vars_to_types(entry, slice_ok=True): if ( isinstance(entry, Variable) - and entry.type in tensor_types + and entry.type.dtype in valid_index_dtypes and np.all(entry.type.broadcastable) ): return aes.get_scalar_type(entry.type.dtype) elif ( isinstance(entry, Type) - and entry in tensor_types + and entry.dtype in valid_index_dtypes and np.all(entry.broadcastable) ): return aes.get_scalar_type(entry.dtype) @@ -700,7 +683,7 @@ def make_node(self, x, *inputs): if len(inputs) != len(input_types): raise IndexError("Not enough inputs to fill in the Subtensor template.") for input, expected_type in zip(inputs, input_types): - if input.type != expected_type: + if not expected_type.is_compatible(input.type): raise TypeError( f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}." ) @@ -1528,7 +1511,7 @@ def make_node(self, x, y, *inputs): "Not enough inputs to fill in the Subtensor template.", inputs, idx_list ) for input, expected_type in zip(inputs, input_types): - if input.type != expected_type: + if not expected_type.is_compatible(input.type): raise TypeError( f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}." 
) diff --git a/aesara/tensor/subtensor_opt.py b/aesara/tensor/subtensor_opt.py index a72372ccd3..11d756895b 100644 --- a/aesara/tensor/subtensor_opt.py +++ b/aesara/tensor/subtensor_opt.py @@ -538,7 +538,7 @@ def local_subtensor_merge(fgraph, node): # Restore original broadcastable dimensions that `subtens()` may # have been unable to infer again - if out.type != orig_out.type: + if not orig_out.type.is_compatible(out.type): assert out.dtype == orig_out.dtype assert out.ndim == orig_out.ndim out = patternbroadcast(out, orig_out.broadcastable) @@ -660,7 +660,7 @@ def local_subtensor_of_alloc(fgraph, node): rval = alloc(nw_val, *nw_dims) if not isinstance(rval, (list, tuple)): rval = [rval] - if rval[0].type != node.outputs[0].type: + if not node.outputs[0].type.is_compatible(rval[0].type): # It happen that the make_node() isn't able to infer the same pattern. # We know it is safe, so fix that. rval[0] = patternbroadcast(rval[0], node.outputs[0].broadcastable) @@ -691,7 +691,7 @@ def local_subtensor_inc_subtensor(fgraph, node): # If the dtypes differ, cast y into x.dtype if x.dtype != y.dtype: y = y.astype(x.dtype) - if out.type == y.type: + if out.type.is_compatible(y.type): # if x[idx] and y have the same type, directly return y return [y] else: @@ -744,7 +744,7 @@ def local_subtensor_make_vector(fgraph, node): if isinstance(idx, (aes.Scalar, TensorType)): old_idx, idx = idx, node.inputs[1] - assert idx.type == old_idx + assert idx.type.is_compatible(old_idx) elif isinstance(node.op, AdvancedSubtensor1): idx = node.inputs[1] @@ -1180,7 +1180,7 @@ def movable(i): AdvancedIncSubtensor, ), ) - and i.type == o_type + and i.type.is_compatible(o_type) and len(fgraph.clients[i]) == 1 and not i.owner.op.set_instead_of_inc ) @@ -1207,8 +1207,8 @@ def movable(i): # stack up the new incsubtensors tip = new_add for mi in movable_inputs: - assert tip.type == o_type - assert tip.type == mi.owner.inputs[0].type + assert o_type.is_compatible(tip.type) + assert mi.owner.inputs[0].type.is_compatible(tip.type) tip = mi.owner.op(tip, *mi.owner.inputs[1:]) # Copy over stacktrace from outputs of the original # "movable" operation to the new operation. 
diff --git a/aesara/typed_list/basic.py b/aesara/typed_list/basic.py index aa49da1592..525fc93068 100644 --- a/aesara/typed_list/basic.py +++ b/aesara/typed_list/basic.py @@ -231,7 +231,7 @@ def __init__(self, inplace=False): def make_node(self, x, toAppend): assert isinstance(x.type, TypedListType) - assert x.type == toAppend.type + assert toAppend.type.is_compatible(x.type) return Apply(self, [x, toAppend], [x.type()]) def perform(self, node, inputs, outputs): @@ -651,7 +651,7 @@ def make_node(self, a): if not isinstance(elem, Variable): elem = aet.as_tensor_variable(elem) a2.append(elem) - if not all(a2[0].type == elem.type for elem in a2): + if not all(a2[0].type.is_compatible(elem.type) for elem in a2): raise TypeError("MakeList need all input variable to be of the same type.") tl = TypedListType(a2[0].type)() diff --git a/tests/tensor/nnet/test_abstract_conv.py b/tests/tensor/nnet/test_abstract_conv.py index e02dd980fa..e9431a40ac 100644 --- a/tests/tensor/nnet/test_abstract_conv.py +++ b/tests/tensor/nnet/test_abstract_conv.py @@ -1507,7 +1507,7 @@ def test_constant_input(self): # Check the forward Op output = conv.abstract_conv2d(constant_tensor, filters) grad_filters = aesara.grad(output.sum(), wrt=filters) - assert grad_filters.type == filters.type, ( + assert filters.type.is_compatible(grad_filters.type), ( grad_filters, grad_filters.type, filters, @@ -1516,7 +1516,7 @@ def test_constant_input(self): output = conv.abstract_conv2d(input, constant_tensor) grad_input = aesara.grad(output.sum(), wrt=input) - assert grad_input.type == input.type, ( + assert input.type.is_compatible(grad_input.type), ( grad_input, grad_input.type, input, @@ -1528,7 +1528,7 @@ def test_constant_input(self): constant_tensor, topgrad, out_shape ) grad_topgrad = aesara.grad(grad_filters.sum(), wrt=topgrad) - assert grad_topgrad.type == topgrad.type, ( + assert topgrad.type.is_compatible(grad_topgrad.type), ( grad_topgrad, grad_topgrad.type, topgrad, @@ -1551,7 +1551,7 @@ def test_constant_input(self): constant_tensor, topgrad, out_shape ) grad_topgrad = aesara.grad(grad_input.sum(), wrt=topgrad) - assert grad_topgrad.type == topgrad.type, ( + assert topgrad.type.is_compatible(grad_topgrad.type), ( grad_topgrad, grad_topgrad.type, topgrad,
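
A note on the pattern applied throughout this patch: the checks change from strict `Type` equality (`r.type == new_r.type`) to `r.type.is_compatible(new_r.type)`, so that a replacement variable is accepted whenever its type is safe to substitute for the original, not only when the two types are identical. The following is a minimal, self-contained sketch of that idea; `DummyTensorType` and the exact `is_compatible` semantics shown here (same dtype and rank, broadcast pattern ignored) are illustrative assumptions, not Aesara's actual implementation.

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class DummyTensorType:
    """Toy stand-in for a tensor Type: a dtype plus a broadcast pattern."""

    dtype: str
    broadcastable: Tuple[bool, ...]

    def is_compatible(self, other: "DummyTensorType") -> bool:
        # Assumed semantics, for illustration only: compatible when the dtype
        # and rank match, even if the broadcast patterns differ.
        return self.dtype == other.dtype and len(self.broadcastable) == len(
            other.broadcastable
        )


original = DummyTensorType("float64", (False, True))
replacement = DummyTensorType("float64", (False, False))

# Strict equality rejects the replacement because the broadcast patterns differ...
assert original != replacement
# ...while the compatibility check accepts it,
assert original.is_compatible(replacement)
# and still rejects a genuinely incompatible type.
assert not original.is_compatible(DummyTensorType("int64", (False, True)))

In the patch itself this is the shape of checks such as the one in `FunctionGraph.change_input`, where the type of the replacement must be compatible with (rather than equal to) the type of the variable being replaced.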