Skip to content

Commit

Permalink
add is_super method
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Jan 24, 2022
1 parent 6cb04e2 commit 2f4c35b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
11 changes: 11 additions & 0 deletions aesara/sparse/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ def value_zeros(self, shape):

return matrix_constructor(shape, dtype=self.dtype)

def is_super(self, otype):
if (
isinstance(otype, type(self))
and otype.dtype == self.dtype
and otype.ndim == self.ndim
and self.format == otype.format
):
return True

return False


# Register SparseType's C code for ViewOp.
aesara.compile.register_view_op_c_code(
Expand Down
4 changes: 3 additions & 1 deletion aesara/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def is_super(self, otype):
return False

def convert_variable(self, var):
if self.is_super(var.type):
if self.is_super(var.type) and not isinstance(
var.type, aesara.sparse.SparseType
):
# `var.type` is at least as specific as `self`, so we return
# `var` as-is
return var
Expand Down

0 comments on commit 2f4c35b

Please sign in to comment.