Skip to content

Commit

Permalink
Warn when dense conversion is used by sparse tensor methods
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov authored and brandonwillard committed Mar 15, 2022
1 parent 09c3101 commit e824666
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 3 deletions.
89 changes: 89 additions & 0 deletions aesara/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,95 @@ def sp_zeros_like(x):
)


def override_dense(*methods):
    """Build a class decorator that makes the named methods fall back to dense.

    For every name in `methods`, the decorated class gets a replacement
    method that:

    1. converts ``self`` to a dense value with ``self.toarray()``,
    2. converts every positional and keyword argument whose ``.type`` is a
       `SparseType` with ``arg.toarray()``,
    3. emits a warning saying the variable is converted to dense, and
    4. delegates to the implementation of the same name found on the
       decorated class' first base class (``cls.__base__``).

    Parameters
    ----------
    *methods
        Names of the methods to override on the decorated class.

    Returns
    -------
    callable
        A decorator that takes a class, installs the overrides on it, and
        returns that same class.
    """
    # Imported locally so this helper does not depend on a module-level import.
    from functools import wraps

    def densify(value):
        # Only objects that expose a sparse `.type` are converted; everything
        # else (scalars, axes, dense variables, ...) is passed through as-is.
        if hasattr(value, "type") and isinstance(value.type, SparseType):
            return value.toarray()
        return value

    def decorate(cls):
        def native(method):
            # Resolve the dense implementation once, at decoration time.
            original = getattr(cls.__base__, method)

            # `wraps` keeps the name and docstring of the dense implementation
            # on the override, so introspection and `help()` stay meaningful.
            @wraps(original)
            def to_dense(self, *args, **kwargs):
                dense_self = self.toarray()
                dense_args = [densify(arg) for arg in args]
                dense_kwargs = {key: densify(val) for key, val in kwargs.items()}
                # `stacklevel=2` attributes the warning to the code that
                # called the sparse method instead of to this wrapper.
                warn(
                    f"Method {method} is not implemented for sparse variables. The variable will be converted to dense.",
                    stacklevel=2,
                )
                return original(dense_self, *dense_args, **dense_kwargs)

            return to_dense

        for method in methods:
            setattr(cls, method, native(method))
        return cls

    return decorate


@override_dense(
"__abs__",
"__ceil__",
"__floor__",
"__trunc__",
"transpose",
"any",
"all",
"flatten",
"ravel",
"arccos",
"arcsin",
"arctan",
"arccosh",
"arcsinh",
"arctanh",
"ceil",
"cos",
"cosh",
"deg2rad",
"exp",
"exp2",
"expm1",
"floor",
"log",
"log10",
"log1p",
"log2",
"rad2deg",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
"copy",
"prod",
"mean",
"var",
"std",
"min",
"max",
"argmin",
"argmax",
"conj",
"round",
"trace",
"cumsum",
"cumprod",
"ptp",
"squeeze",
"diagonal",
"__and__",
"__or__",
"__xor__",
"__pow__",
"__mod__",
"__divmod__",
"__truediv__",
"__floordiv__",
"reshape",
"dimshuffle",
)
class _sparse_py_operators(_tensor_py_operators):
T = property(
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
Expand Down
14 changes: 11 additions & 3 deletions aesara/sparse/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import scipy.sparse

import aesara
from aesara.graph.type import HasDataType, Type
from aesara.graph.type import HasDataType
from aesara.tensor.type import TensorType


Expand Down Expand Up @@ -148,8 +148,16 @@ def may_share_memory(a, b):
return True
return False

def make_variable(self, name=None):
    """Create a new variable of this type.

    Parameters
    ----------
    name
        Optional name forwarded to the variable constructor.

    Returns
    -------
    The object produced by calling ``self.variable_type`` with this type
    instance and the given `name`.
    """
    variable_cls = self.variable_type
    return variable_cls(self, name=name)
def convert_variable(self, var):
    """Convert `var` to this type via the parent class, rejecting non-sparse results.

    Parameters
    ----------
    var
        The variable to convert.

    Returns
    -------
    Whatever the parent class' ``convert_variable`` returned, provided it is
    either ``None`` or a variable whose type is a `SparseType`.

    Raises
    ------
    NotImplementedError
        If the parent class produced a variable whose type is not a
        `SparseType`.
    """
    res = super().convert_variable(var)

    # Test for "no result" explicitly instead of relying on the truthiness of
    # a symbolic variable.
    # NOTE(review): presumes the parent returns `None` when it cannot convert
    # `var` -- confirm against the parent class.
    if res is not None and not isinstance(res.type, SparseType):
        # TODO: Convert to this sparse format
        raise NotImplementedError(
            "Converting a non-sparse variable to a sparse type is not implemented"
        )

    # TODO: Convert sparse `var`s with different formats to this format?

    return res

def __hash__(self):
    """Return the parent class' hash XOR-ed with the hash of ``self.format``."""
    parent_hash = super().__hash__()
    format_hash = hash(self.format)
    return parent_hash ^ format_hash
Expand Down
234 changes: 234 additions & 0 deletions tests/sparse/test_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from contextlib import ExitStack

import numpy as np
import pytest
from scipy.sparse.csr import csr_matrix

import aesara
import aesara.sparse as sparse
import aesara.tensor as at
from aesara.sparse.type import SparseType
from aesara.tensor.type import DenseTensorType


class TestSparseVariable:
    """Check which methods of a sparse variable stay sparse and which go dense.

    Every test builds its sparse inputs with `sparse.csr_from_dense`, calls a
    method on them, and checks both the symbolic output type and the type of
    the value computed by the compiled function. Methods expected to fall
    back to a dense computation must also emit a "converted to dense"
    `UserWarning`.
    """

    @pytest.mark.parametrize(
        "method, exp_type, cm",
        [
            ("__abs__", DenseTensorType, None),
            ("__neg__", SparseType, ExitStack()),
            ("__ceil__", DenseTensorType, None),
            ("__floor__", DenseTensorType, None),
            ("__trunc__", DenseTensorType, None),
            ("transpose", DenseTensorType, None),
            ("any", DenseTensorType, None),
            ("all", DenseTensorType, None),
            ("flatten", DenseTensorType, None),
            ("ravel", DenseTensorType, None),
            ("arccos", DenseTensorType, None),
            ("arcsin", DenseTensorType, None),
            ("arctan", DenseTensorType, None),
            ("arccosh", DenseTensorType, None),
            ("arcsinh", DenseTensorType, None),
            ("arctanh", DenseTensorType, None),
            ("ceil", DenseTensorType, None),
            ("cos", DenseTensorType, None),
            ("cosh", DenseTensorType, None),
            ("deg2rad", DenseTensorType, None),
            ("exp", DenseTensorType, None),
            ("exp2", DenseTensorType, None),
            ("expm1", DenseTensorType, None),
            ("floor", DenseTensorType, None),
            ("log", DenseTensorType, None),
            ("log10", DenseTensorType, None),
            ("log1p", DenseTensorType, None),
            ("log2", DenseTensorType, None),
            ("rad2deg", DenseTensorType, None),
            ("sin", DenseTensorType, None),
            ("sinh", DenseTensorType, None),
            ("sqrt", DenseTensorType, None),
            ("tan", DenseTensorType, None),
            ("tanh", DenseTensorType, None),
            ("copy", DenseTensorType, None),
            ("sum", DenseTensorType, ExitStack()),
            ("prod", DenseTensorType, None),
            ("mean", DenseTensorType, None),
            ("var", DenseTensorType, None),
            ("std", DenseTensorType, None),
            ("min", DenseTensorType, None),
            ("max", DenseTensorType, None),
            ("argmin", DenseTensorType, None),
            ("argmax", DenseTensorType, None),
            ("nonzero", DenseTensorType, ExitStack()),
            ("nonzero_values", DenseTensorType, None),
            ("argsort", DenseTensorType, ExitStack()),
            ("conj", DenseTensorType, None),
            ("round", DenseTensorType, None),
            ("trace", DenseTensorType, None),
            ("zeros_like", SparseType, ExitStack()),
            ("ones_like", DenseTensorType, ExitStack()),
            ("cumsum", DenseTensorType, None),
            ("cumprod", DenseTensorType, None),
            ("ptp", DenseTensorType, None),
            ("squeeze", DenseTensorType, None),
            ("diagonal", DenseTensorType, None),
        ],
    )
    def test_unary(self, method, exp_type, cm):
        """Call a no-argument method on a sparse matrix and check result types.

        A `cm` of ``None`` means the call is expected to warn about the dense
        conversion; otherwise the given context manager is used as-is.
        """
        x = sparse.csr_from_dense(at.dmatrix("x"))

        method_to_call = getattr(x, method)

        if cm is None:
            cm = pytest.warns(UserWarning, match=".*converted to dense.*")

        exp_res_type = csr_matrix if exp_type == SparseType else np.ndarray

        with cm:
            z = method_to_call()

        # Some methods return a tuple of outputs; treat both cases uniformly.
        z_outs = z if isinstance(z, tuple) else (z,)
        for sym_out in z_outs:
            assert isinstance(sym_out.type, exp_type)

        f = aesara.function([x], z, on_unused_input="ignore")

        res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])

        # Multiple outputs come back from the compiled function as a list.
        res_outs = res if isinstance(res, list) else [res]
        for val_out in res_outs:
            assert isinstance(val_out, exp_res_type)

    @pytest.mark.parametrize(
        "method, exp_type",
        [
            ("__lt__", SparseType),
            ("__le__", SparseType),
            ("__gt__", SparseType),
            ("__ge__", SparseType),
            ("__and__", DenseTensorType),
            ("__or__", DenseTensorType),
            ("__xor__", DenseTensorType),
            ("__add__", SparseType),
            ("__sub__", SparseType),
            ("__mul__", SparseType),
            ("__pow__", DenseTensorType),
            ("__mod__", DenseTensorType),
            ("__divmod__", DenseTensorType),
            ("__truediv__", DenseTensorType),
            ("__floordiv__", DenseTensorType),
        ],
    )
    def test_binary(self, method, exp_type):
        """Call a two-operand method on sparse matrices and check result types.

        Methods expected to produce a dense result must warn about the dense
        conversion; methods expected to stay sparse run without that check.
        """
        x = sparse.csr_from_dense(at.lmatrix("x"))
        y = sparse.csr_from_dense(at.lmatrix("y"))

        method_to_call = getattr(x, method)

        if exp_type == SparseType:
            exp_res_type, cm = csr_matrix, ExitStack()
        else:
            exp_res_type = np.ndarray
            cm = pytest.warns(UserWarning, match=".*converted to dense.*")

        with cm:
            z = method_to_call(y)

        # `__divmod__` and similar may return a tuple of outputs.
        z_outs = z if isinstance(z, tuple) else (z,)
        for sym_out in z_outs:
            assert isinstance(sym_out.type, exp_type)

        f = aesara.function([x, y], z)
        res = f(
            [[1, 0, 2], [-1, 0, 0]],
            [[1, 1, 2], [1, 4, 1]],
        )

        res_outs = res if isinstance(res, list) else [res]
        for val_out in res_outs:
            assert isinstance(val_out, exp_res_type)

    def test_reshape(self):
        """`reshape` on a sparse matrix warns and yields a dense result."""
        x = sparse.csr_from_dense(at.dmatrix("x"))

        with pytest.warns(UserWarning, match=".*converted to dense.*"):
            z = x.reshape((3, 2))

        assert isinstance(z.type, DenseTensorType)

        f = aesara.function([x], z)
        value = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
        assert isinstance(value, np.ndarray)

    def test_dimshuffle(self):
        """`dimshuffle` on a sparse matrix warns and yields a dense result."""
        x = sparse.csr_from_dense(at.dmatrix("x"))

        with pytest.warns(UserWarning, match=".*converted to dense.*"):
            z = x.dimshuffle((1, 0))

        assert isinstance(z.type, DenseTensorType)

        f = aesara.function([x], z)
        value = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
        assert isinstance(value, np.ndarray)

    def test_getitem(self):
        """Slicing a sparse matrix keeps the result sparse."""
        x = sparse.csr_from_dense(at.dmatrix("x"))

        z = x[:, :2]
        assert isinstance(z.type, SparseType)

        f = aesara.function([x], z)
        value = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
        assert isinstance(value, csr_matrix)

    def test_dot(self):
        """`__dot__` between two sparse matrices keeps the result sparse."""
        x = sparse.csr_from_dense(at.lmatrix("x"))
        y = sparse.csr_from_dense(at.lmatrix("y"))

        z = x.__dot__(y)
        assert isinstance(z.type, SparseType)

        f = aesara.function([x, y], z)
        value = f(
            [[1, 0, 2], [-1, 0, 0]],
            [[-1], [2], [1]],
        )
        assert isinstance(value, csr_matrix)

    def test_repeat(self):
        """`repeat` on a sparse matrix warns and yields a dense result."""
        x = sparse.csr_from_dense(at.dmatrix("x"))

        with pytest.warns(UserWarning, match=".*converted to dense.*"):
            z = x.repeat(2, axis=1)

        assert isinstance(z.type, DenseTensorType)

        f = aesara.function([x], z)
        value = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]])
        assert isinstance(value, np.ndarray)

0 comments on commit e824666

Please sign in to comment.