From e76a3220e82faef6c6f95ab95d94df0527f86e79 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Sun, 14 Feb 2021 00:44:48 -0600
Subject: [PATCH] Make sparse tensor types extend the existing tensor types

---
 aesara/sparse/basic.py         | 37 +++++++++++++----
 aesara/sparse/type.py          | 72 +++++++++++++++++++---------------
 aesara/tensor/basic.py         | 10 ++++-
 aesara/tensor/blas.py          | 41 +++++++++++++++++--
 aesara/tensor/math.py          |  6 +++
 aesara/tensor/type.py          | 15 ++++++-
 aesara/tensor/var.py           | 23 +++++++++++
 tests/sparse/test_basic.py     | 67 +++++++++++++++++++++++++++----
 tests/tensor/test_sharedvar.py |  7 +---
 tests/tensor/test_var.py       | 14 ++++++-
 10 files changed, 231 insertions(+), 61 deletions(-)

diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py
index 90825857b0..b42e8a6ff9 100644
--- a/aesara/sparse/basic.py
+++ b/aesara/sparse/basic.py
@@ -22,6 +22,9 @@
 from aesara.misc.safe_asarray import _asarray
 from aesara.sparse.type import SparseType, _is_sparse
 from aesara.sparse.utils import hash_from_sparse
+
+# TODO:
+# from aesara.tensor import _as_tensor_variable
 from aesara.tensor import basic as aet
 from aesara.tensor.basic import Split
 from aesara.tensor.math import add as aet_add
@@ -46,6 +49,7 @@
 from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
 from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
+from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators
 
 
 sparse_formats = ["csc", "csr"]
@@ -123,8 +127,9 @@ def _is_dense(x):
     return isinstance(x, np.ndarray)
 
 
-# Wrapper type
-def as_sparse_variable(x, name=None):
+# TODO:
+# @_as_tensor_variable.register(scipy.sparse.base.spmatrix)
+def as_sparse_variable(x, name=None, ndim=None, **kwargs):
     """
     Wrapper around SparseVariable constructor to construct a Variable
     with a sparse matrix with the same dtype and
@@ -247,7 +252,7 @@ def sp_zeros_like(x):
     )
 
 
-class _sparse_py_operators:
+class _sparse_py_operators(_tensor_py_operators):
     T = property(
         lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
     )
@@ -355,8 +360,7 @@ def __getitem__(self, args):
         return ret
 
 
-class SparseVariable(_sparse_py_operators, Variable):
-    dtype = property(lambda self: self.type.dtype)
+class SparseVariable(_sparse_py_operators, TensorVariable):
     format = property(lambda self: self.type.format)
 
     def __str__(self):
@@ -389,8 +393,7 @@ def aesara_hash(self):
         return hash_from_sparse(d)
 
 
-class SparseConstant(Constant, _sparse_py_operators):
-    dtype = property(lambda self: self.type.dtype)
+class SparseConstant(TensorConstant, _sparse_py_operators):
     format = property(lambda self: self.type.format)
 
     def signature(self):
@@ -413,6 +416,12 @@ def __repr__(self):
 SparseType.Variable = SparseVariable
 SparseType.Constant = SparseConstant
 
+# TODO:
+# @_as_tensor_variable.register(SparseVariable)
+# @_as_tensor_variable.register(SparseConstant)
+# def _as_tensor_sparse(x, name, ndim):
+#     return x
+
 
 # for more dtypes, call SparseType(format, dtype)
 def matrix(format, name=None, dtype=None):
@@ -442,7 +451,7 @@ def bsr_matrix(name=None, dtype=None):
 csr_fmatrix = SparseType(format="csr", dtype="float32")
 bsr_fmatrix = SparseType(format="bsr", dtype="float32")
 
-all_dtypes = SparseType.dtype_set
+all_dtypes = list(SparseType.dtype_specs_map.keys())
 complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
 float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
 int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
"int"] @@ -920,6 +929,12 @@ def __init__(self, structured=True): def __str__(self): return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}" + def __call__(self, x): + if not isinstance(x.type, SparseType): + return x + + return super().__call__(x) + def make_node(self, x): x = as_sparse_variable(x) return Apply( @@ -997,6 +1012,12 @@ def __init__(self, format): def __str__(self): return f"{self.__class__.__name__}{{{self.format}}}" + def __call__(self, x): + if isinstance(x.type, SparseType): + return x + + return super().__call__(x) + def make_node(self, x): x = aet.as_tensor_variable(x) if x.ndim > 2: diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index 3d2d2a68ec..6851bf4896 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -10,7 +10,7 @@ import aesara -from aesara.graph.type import Type +from aesara.tensor.type import TensorType def _is_sparse(x): @@ -32,7 +32,7 @@ def _is_sparse(x): return isinstance(x, scipy.sparse.spmatrix) -class SparseType(Type): +class SparseType(TensorType): """ Fundamental way to create a sparse node. @@ -60,19 +60,19 @@ class SparseType(Type): "csc": scipy.sparse.csc_matrix, "bsr": scipy.sparse.bsr_matrix, } - dtype_set = { - "int8", - "int16", - "int32", - "int64", - "float32", - "uint8", - "uint16", - "uint32", - "uint64", - "float64", - "complex64", - "complex128", + dtype_specs_map = { + "float32": (float, "npy_float32", "NPY_FLOAT32"), + "float64": (float, "npy_float64", "NPY_FLOAT64"), + "uint8": (int, "npy_uint8", "NPY_UINT8"), + "int8": (int, "npy_int8", "NPY_INT8"), + "uint16": (int, "npy_uint16", "NPY_UINT16"), + "int16": (int, "npy_int16", "NPY_INT16"), + "uint32": (int, "npy_uint32", "NPY_UINT32"), + "int32": (int, "npy_int32", "NPY_INT32"), + "uint64": (int, "npy_uint64", "NPY_UINT64"), + "int64": (int, "npy_int64", "NPY_INT64"), + "complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"), + "complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"), } ndim = 2 @@ -80,28 +80,32 @@ class SparseType(Type): Variable = None Constant = None - def __init__(self, format, dtype): + def __init__(self, format, dtype, broadcastable=None, name=None): if not imported_scipy: raise Exception( - "You can't make SparseType object as SciPy" " is not available." - ) - dtype = str(dtype) - if dtype in self.dtype_set: - self.dtype = dtype - else: - raise NotImplementedError( - f'unsupported dtype "{dtype}" not in list', list(self.dtype_set) + "You can't make SparseType object when SciPy is not available." 
             )
 
-        assert isinstance(format, str)
+        if not isinstance(format, str):
+            raise TypeError("The sparse format parameter must be a string")
+
         if format in self.format_cls:
             self.format = format
         else:
             raise NotImplementedError(
                 f'unsupported format "{format}" not in list',
-                list(self.format_cls.keys()),
             )
 
+        if broadcastable is None:
+            broadcastable = [False, False]
+
+        super().__init__(dtype, broadcastable, name=name)
+
+    def clone(self, dtype=None, broadcastable=None):
+        new = super().clone(dtype=dtype, broadcastable=broadcastable)
+        new.format = self.format
+        return new
+
     def filter(self, value, strict=False, allow_downcast=None):
         if (
             isinstance(value, self.format_cls[self.format])
@@ -152,14 +156,10 @@ def make_variable(self, name=None):
         return self.Variable(self, name=name)
 
     def __eq__(self, other):
-        return (
-            type(self) == type(other)
-            and other.dtype == self.dtype
-            and other.format == self.format
-        )
+        return super().__eq__(other) and other.format == self.format
 
     def __hash__(self):
-        return hash(self.dtype) ^ hash(self.format)
+        return super().__hash__() ^ hash(self.format)
 
     def __str__(self):
         return f"Sparse[{self.dtype}, {self.format}]"
@@ -208,6 +208,14 @@ def get_size(self, shape_info):
             + (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
         )
 
+    def value_zeros(self, shape):
+        matrix_constructor = getattr(scipy.sparse, f"{self.format}_matrix", None)
+
+        if matrix_constructor is None:
+            raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
+
+        return matrix_constructor(shape, dtype=self.dtype)
+
 
 # Register SparseType's C code for ViewOp.
 aesara.compile.register_view_op_c_code(
diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py
index f89b908c10..9a6fe0ce4b 100644
--- a/aesara/tensor/basic.py
+++ b/aesara/tensor/basic.py
@@ -347,10 +347,16 @@ def get_scalar_constant_value(
             data = v.tag.unique_value
         else:
             data = v.data
+
         if isinstance(data, np.ndarray):
             return numpy_scalar(data).copy()
-        else:
-            return data
+
+        from aesara.sparse.type import SparseType
+
+        if isinstance(v.type, SparseType):
+            raise NotScalarConstantError()
+
+        return data
 
     if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
         max_recur -= 1
diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py
index 81f2259972..a441328ea9 100644
--- a/aesara/tensor/blas.py
+++ b/aesara/tensor/blas.py
@@ -166,7 +166,12 @@
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.math import Dot, add, mul, neg, sub
-from aesara.tensor.type import integer_dtypes, tensor, values_eq_approx_remove_inf_nan
+from aesara.tensor.type import (
+    DenseTensorType,
+    integer_dtypes,
+    tensor,
+    values_eq_approx_remove_inf_nan,
+)
 from aesara.utils import memoize
 
 
@@ -253,6 +258,7 @@ def make_node(self, y, alpha, A, x, beta):
         A = aet.as_tensor_variable(A)
         alpha = aet.as_tensor_variable(alpha)
         beta = aet.as_tensor_variable(beta)
+
         if y.dtype != A.dtype or y.dtype != x.dtype:
             raise TypeError(
                 "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -263,7 +269,13 @@ def make_node(self, y, alpha, A, x, beta):
             raise TypeError("gemv requires vector for x", x.type)
         if y.ndim != 1:
             raise TypeError("gemv requires vector for y", y.type)
-        return Apply(self, [y, alpha, A, x, beta], [y.type()])
+
+        inputs = [y, alpha, A, x, beta]
+
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
+        return Apply(self, inputs, [y.type()])
 
     def perform(self, node, inputs, out_storage, params=None):
         y, alpha, A, x, beta = inputs
@@ -360,7 +372,12 @@ def make_node(self, A, alpha, x, y):
 
         if x.dtype not in ("float32", "float64", "complex64", "complex128"):
             raise TypeError("only float and complex types supported", x.dtype)
-        return Apply(self, [A, alpha, x, y], [A.type()])
+
+        inputs = [A, alpha, x, y]
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
+        return Apply(self, inputs, [A.type()])
 
     def perform(self, node, inp, out, params=None):
         cA, calpha, cx, cy = inp
@@ -898,6 +915,10 @@ def __getstate__(self):
 
     def make_node(self, *inputs):
         inputs = list(map(aet.as_tensor_variable, inputs))
+
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if len(inputs) != 5:
             raise TypeError(
                 f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
@@ -1580,6 +1601,10 @@ class Dot22(GemmRelated):
     def make_node(self, x, y):
         x = aet.as_tensor_variable(x)
         y = aet.as_tensor_variable(y)
+
+        if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         dtypes = ("float16", "float32", "float64", "complex64", "complex128")
         if x.type.ndim != 2 or x.type.dtype not in dtypes:
             raise TypeError(x)
@@ -1665,6 +1690,9 @@ def local_dot_to_dot22(fgraph, node):
     if not isinstance(node.op, Dot):
         return
 
+    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
+        return False
+
     x, y = node.inputs
     if y.type.dtype != x.type.dtype:
         # TODO: upcast one so the types match
@@ -1847,6 +1875,10 @@ class Dot22Scalar(GemmRelated):
     check_input = False
 
     def make_node(self, x, y, a):
+
+        if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if a.ndim != 0:
             raise TypeError(Gemm.E_scalar, a)
         if x.ndim != 2:
@@ -2066,6 +2098,9 @@ class BatchedDot(COp):
     def make_node(self, *inputs):
         inputs = list(map(aet.as_tensor_variable, inputs))
 
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if len(inputs) != 2:
             raise TypeError(f"Two arguments required, but {len(inputs)} given.")
         if inputs[0].ndim not in (2, 3):
diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py
index 5560f4a12a..fbbcd6d1ac 100644
--- a/aesara/tensor/math.py
+++ b/aesara/tensor/math.py
@@ -32,6 +32,7 @@
 )
 from aesara.tensor.shape import shape
 from aesara.tensor.type import (
+    DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2006,6 +2007,11 @@ def dense_dot(a, b):
     """
     a, b = as_tensor_variable(a), as_tensor_variable(b)
 
+    if not isinstance(a.type, DenseTensorType) or not isinstance(
+        b.type, DenseTensorType
+    ):
+        raise TypeError("The dense dot product is only supported for dense types")
+
     if a.ndim == 0 or b.ndim == 0:
         return a * b
     elif a.ndim > 2 or b.ndim > 2:
diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py
index 84b3a2a0b1..821fbc7eb9 100644
--- a/aesara/tensor/type.py
+++ b/aesara/tensor/type.py
@@ -7,6 +7,7 @@
 from aesara.configdefaults import config
 from aesara.graph.basic import Variable
 from aesara.graph.type import CType
+from aesara.graph.utils import MetaType
 from aesara.misc.safe_asarray import _asarray
 from aesara.utils import apply_across_args
 
@@ -67,6 +68,7 @@ class TensorType(CType):
 
     """
 
+    dtype_specs_map = dtype_specs_map
     context_name = "cpu"
"cpu" filter_checks_isfinite = False """ @@ -280,7 +282,7 @@ def dtype_specs(self): """ try: - return dtype_specs_map[self.dtype] + return self.dtype_specs_map[self.dtype] except KeyError: raise TypeError( f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}" @@ -617,6 +619,17 @@ def get_size(self, shape_info): aesara.compile.ops.expandable_types += (TensorType,) +class DenseTypeMeta(MetaType): + def __instancecheck__(self, o): + if type(o) == TensorType or isinstance(o, DenseTypeMeta): + return True + return False + + +class DenseTensorType(TensorType, metaclass=DenseTypeMeta): + """A `Type` for dense tensors.""" + + def values_eq_approx( a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None ): diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index cfb933485c..c114adc0db 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -8,6 +8,7 @@ from aesara import tensor as aet from aesara.configdefaults import config from aesara.graph.basic import Constant, Variable +from aesara.graph.utils import MetaType from aesara.scalar import ComplexError, IntegerDivisionError from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.type import TensorType @@ -1038,3 +1039,25 @@ def __deepcopy__(self, memo): TensorType.Constant = TensorConstant + + +class DenseVariableMeta(MetaType): + def __instancecheck__(self, o): + if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): + return True + return False + + +class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): + """A `Variable` for dense tensors.""" + + +class DenseConstantMeta(MetaType): + def __instancecheck__(self, o): + if type(o) == TensorConstant or isinstance(o, DenseConstantMeta): + return True + return False + + +class DenseTensorConstant(TensorType, metaclass=DenseConstantMeta): + """A `Constant` for dense tensors.""" diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 4b4b887195..2f04b6b7d2 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -12,7 +12,7 @@ from aesara.compile.io import In, Out from aesara.configdefaults import config from aesara.gradient import GradientError -from aesara.graph.basic import Apply, Constant +from aesara.graph.basic import Apply, Constant, applys_between from aesara.graph.op import Op from aesara.misc.safe_asarray import _asarray from aesara.sparse import ( @@ -78,6 +78,7 @@ true_dot, ) from aesara.sparse.basic import ( + SparseConstant, _is_dense_variable, _is_sparse, _is_sparse_variable, @@ -1017,22 +1018,45 @@ class TestConversion: def setup_method(self): utt.seed_rng() - @pytest.mark.skip def test_basic(self): - a = aet.as_tensor_variable(np.random.rand(5)) + test_val = np.random.rand(5).astype(config.floatX) + a = aet.as_tensor_variable(test_val) s = csc_from_dense(a) val = eval_outputs([s]) - assert str(val.dtype) == "float64" + assert str(val.dtype) == config.floatX assert val.format == "csc" - @pytest.mark.skip - def test_basic_1(self): - a = aet.as_tensor_variable(np.random.rand(5)) + a = aet.as_tensor_variable(test_val) s = csr_from_dense(a) val = eval_outputs([s]) - assert str(val.dtype) == "float64" + assert str(val.dtype) == config.floatX assert val.format == "csr" + test_val = np.eye(3).astype(config.floatX) + a = sp.sparse.csr_matrix(test_val) + s = as_sparse_or_tensor_variable(a) + res = aet.as_tensor_variable(s) + assert isinstance(res, SparseConstant) + + a = sp.sparse.csr_matrix(test_val) + s = as_sparse_or_tensor_variable(a) + from aesara.tensor.exceptions 
+
+        with pytest.raises(NotScalarConstantError):
+            aet.get_scalar_constant_value(s, only_process_constants=True)
+
+    # TODO:
+    # def test_sparse_as_tensor_variable(self):
+    #     csr = sp.sparse.csr_matrix(np.eye(3))
+    #     val = aet.as_tensor_variable(csr)
+    #     assert str(val.dtype) == config.floatX
+    #     assert val.format == "csr"
+    #
+    #     csr = sp.sparse.csc_matrix(np.eye(3))
+    #     val = aet.as_tensor_variable(csr)
+    #     assert str(val.dtype) == config.floatX
+    #     assert val.format == "csc"
+
     def test_dense_from_sparse(self):
         # call dense_from_sparse
         for t in _mtypes:
@@ -1607,6 +1631,33 @@ def test_int32_dtype(self):
         a = np.asarray(np.random.randint(0, 100, (size, size)), dtype=intX)
         f(i, a)
 
+    def test_tensor_dot_types(self):
+
+        x = sparse.csc_matrix("x")
+        x_d = aet.matrix("x_d")
+        y = sparse.csc_matrix("y")
+
+        res = aet.dot(x, y)
+        op_types = set(type(n.op) for n in applys_between([x, y], [res]))
+        assert sparse.basic.StructuredDot in op_types
+        assert aet.math.Dot not in op_types
+
+        res = aet.dot(x_d, y)
+        op_types = set(type(n.op) for n in applys_between([x, y], [res]))
+        assert sparse.basic.StructuredDot in op_types
+        assert aet.math.Dot not in op_types
+
+        res = aet.dot(x, x_d)
+        op_types = set(type(n.op) for n in applys_between([x, y], [res]))
+        assert sparse.basic.StructuredDot in op_types
+        assert aet.math.Dot not in op_types
+
+        res = aet.dot(aet.second(1, x), y)
+        op_types = set(type(n.op) for n in applys_between([x, y], [res]))
+        assert sparse.basic.StructuredDot in op_types
+        assert aet.math.Dot not in op_types
+
+    # @aesara.config.change_flags(compute_test_value="raise", optimizer_verbose=True)
     def test_csr_dense_grad(self):
         # shortcut: testing csc in float32, testing csr in float64
 
diff --git a/tests/tensor/test_sharedvar.py b/tests/tensor/test_sharedvar.py
index db1fe7538c..bae7690ed8 100644
--- a/tests/tensor/test_sharedvar.py
+++ b/tests/tensor/test_sharedvar.py
@@ -429,12 +429,7 @@ def test_specify_shape(self):
             topo_cst[0].op == aesara.compile.function.types.deep_copy_op
 
             # Test that we can take the grad.
-            if aesara.sparse.enable_sparse and isinstance(
-                x1_specify_shape.type, aesara.sparse.SparseType
-            ):
-                # SparseVariable don't support sum for now.
-                assert not hasattr(x1_specify_shape, "sum")
-            else:
+            if hasattr(x1_specify_shape, "sum"):
                 shape_grad = aesara.gradient.grad(x1_specify_shape.sum(), x1_shared)
                 shape_constant_fct_grad = aesara.function([], shape_grad)
                 # aesara.printing.debugprint(shape_constant_fct_grad)
diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py
index 422f467e10..4938bb960f 100644
--- a/tests/tensor/test_var.py
+++ b/tests/tensor/test_var.py
@@ -4,11 +4,12 @@
 
 import aesara
 import tests.unittest_tools as utt
+from aesara.tensor.basic import constant
 from aesara.tensor.elemwise import DimShuffle
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.type import TensorType, dmatrix, iscalar, ivector, matrix
 from aesara.tensor.type_other import MakeSlice
-from aesara.tensor.var import TensorConstant
+from aesara.tensor.var import DenseTensorConstant, DenseTensorVariable, TensorConstant
 
 
 @pytest.mark.parametrize(
@@ -170,3 +171,14 @@ def test__getitem__AdvancedSubtensor():
     z = x[i, None]
     op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
     assert op_types[-1] == AdvancedSubtensor
+
+
+def test_dense_types():
+
+    x = matrix()
+    assert isinstance(x, DenseTensorVariable)
+    assert not isinstance(x, DenseTensorConstant)
+
+    x = constant(1)
+    assert not isinstance(x, DenseTensorVariable)
+    assert isinstance(x, DenseTensorConstant)
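
Reviewer note (illustration only, not part of the patch): the snippet below sketches the behavior this change set is intended to enable, mirroring the new test_dense_types and test_tensor_dot_types tests. It assumes an Aesara checkout with this patch applied; the variable names are hypothetical.

import aesara.tensor as aet

from aesara import sparse
from aesara.graph.basic import applys_between
from aesara.tensor.basic import constant
from aesara.tensor.type import DenseTensorType, TensorType
from aesara.tensor.var import DenseTensorConstant, DenseTensorVariable

# Ordinary dense variables match the new "virtual" dense types.
x_d = aet.matrix("x_d")
assert isinstance(x_d.type, DenseTensorType)
assert isinstance(x_d, DenseTensorVariable)
assert isinstance(constant(1), DenseTensorConstant)

# Sparse types are now TensorType subclasses, but they are excluded from
# the dense-only checks that the new make_node guards in blas.py rely on.
x_s = sparse.csc_matrix("x_s")
y_s = sparse.csc_matrix("y_s")
assert isinstance(x_s.type, TensorType)
assert not isinstance(x_s.type, DenseTensorType)

# aet.dot on sparse inputs builds a StructuredDot graph rather than a
# dense Dot, as exercised by test_tensor_dot_types above.
res = aet.dot(x_s, y_s)
op_types = {type(n.op) for n in applys_between([x_s, y_s], [res])}
assert sparse.basic.StructuredDot in op_types
assert aet.math.Dot not in op_types

The same dense-only check is what Gemv, Ger, Gemm, Dot22, Dot22Scalar, and BatchedDot now apply in make_node, raising NotImplementedError for non-dense inputs.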