diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 7adde079fc..65e483d78e 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -465,19 +465,12 @@ def change_node_input( self.outputs[i] = new_var else: r = node.inputs[i] - if check: - if isinstance(r.type, aesara.sparse.SparseType): - fail = not ( - isinstance(new_var.type, type(r.type)) - and r.type.dtype == r.type.dtype - ) - else: - fail = not r.type.is_super(new_var.type) - if fail: - raise TypeError( - f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." - ) + if check and not r.type.is_super(new_var.type): + raise TypeError( + f"The type of the replacement ({new_var.type}) must be " + f"compatible with the type of the original Variable ({r.type})." + ) + node.inputs[i] = new_var if r is new_var: diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index c26bb4d220..12d00fc627 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -150,12 +150,12 @@ def may_share_memory(a, b): def make_variable(self, name=None): return self.Variable(self, name=name) - def __eq__(self, other): - return ( - super().__eq__(other) - and type(self) == type(other) - and other.format == self.format - ) + # def __eq__(self, other): + # return ( + # super().__eq__(other) + # and type(self) == type(other) + # and other.format == self.format + # ) def __hash__(self): return super().__hash__() ^ hash(self.format) @@ -215,15 +215,24 @@ def value_zeros(self, shape): return matrix_constructor(shape, dtype=self.dtype) + def __eq__(self, other): + if type(self) != type(other): + return NotImplemented + + return other.dtype == self.dtype and other.format == self.format + def is_super(self, otype): - if ( - isinstance(otype, type(self)) - and otype.dtype == self.dtype - and otype.ndim == self.ndim - and self.format == otype.format - ): + # if ( + # isinstance(otype, SparseType) + # and otype.dtype == self.dtype + # and otype.ndim == self.ndim + # and self.format == otype.format + # and otype.broadcastable == self.broadcastable + # ): + # return True + # return False + if self == otype: return True - return False