From 2f4c35bca37b753a7e4d93322eb74647135156fe Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Sat, 22 Jan 2022 00:05:27 +0300 Subject: [PATCH] add is_super method --- aesara/sparse/type.py | 11 +++++++++++ aesara/tensor/type.py | 4 +++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index 2020345a83..c26bb4d220 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -215,6 +215,17 @@ def value_zeros(self, shape): return matrix_constructor(shape, dtype=self.dtype) + def is_super(self, otype): + if ( + isinstance(otype, type(self)) + and otype.dtype == self.dtype + and otype.ndim == self.ndim + and self.format == otype.format + ): + return True + + return False + # Register SparseType's C code for ViewOp. aesara.compile.register_view_op_c_code( diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index f5401060a1..37324ee92c 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -315,7 +315,9 @@ def is_super(self, otype): return False def convert_variable(self, var): - if self.is_super(var.type): + if self.is_super(var.type) and not isinstance( + var.type, aesara.sparse.SparseType + ): # `var.type` is at least as specific as `self`, so we return # `var` as-is return var