Make sparse tensor types extend the existing tensor types
brandonwillard committed Feb 27, 2022
1 parent 80bfde1 commit a37dac0
Showing 13 changed files with 285 additions and 69 deletions.
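
In practice, the point of this commit is that the sparse variable, constant, and type classes now inherit from their dense counterparts, so generic `TensorVariable`/`TensorType` checks accept them. A minimal sketch of what that implies, assuming this commit of Aesara and SciPy are installed (`csr_matrix` is the existing constructor from `aesara.sparse`):

```python
# Minimal sketch of the new hierarchy; assumes this commit of Aesara is
# importable and uses the existing csr_matrix helper from aesara.sparse.
from aesara.sparse import csr_matrix
from aesara.sparse.type import SparseType
from aesara.tensor.type import TensorType
from aesara.tensor.var import TensorVariable

x = csr_matrix("x", dtype="float64")

# Sparse variables and their types now pass dense isinstance checks ...
assert isinstance(x, TensorVariable)
assert isinstance(x.type, TensorType)

# ... while keeping their sparse-specific attributes.
assert isinstance(x.type, SparseType)
assert x.type.format == "csr"
assert x.ndim == 2
```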
26 changes: 18 additions & 8 deletions aesara/sparse/basic.py
@@ -47,6 +47,7 @@
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators


sparse_formats = ["csc", "csr"]
@@ -124,8 +125,7 @@ def _is_dense(x):
return isinstance(x, np.ndarray)


# Wrapper type
def as_sparse_variable(x, name=None):
def as_sparse_variable(x, name=None, ndim=None, **kwargs):
"""
Wrapper around SparseVariable constructor to construct
a Variable with a sparse matrix with the same dtype and
@@ -248,7 +248,7 @@ def sp_zeros_like(x):
)


class _sparse_py_operators:
class _sparse_py_operators(_tensor_py_operators):
T = property(
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
)
@@ -359,8 +359,7 @@ def __getitem__(self, args):
return ret


class SparseVariable(_sparse_py_operators, Variable):
dtype = property(lambda self: self.type.dtype)
class SparseVariable(_sparse_py_operators, TensorVariable):
format = property(lambda self: self.type.format)

def __str__(self):
@@ -393,8 +392,7 @@ def aesara_hash(self):
return hash_from_sparse(d)


class SparseConstant(Constant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
class SparseConstant(TensorConstant, _sparse_py_operators):
format = property(lambda self: self.type.format)

def signature(self):
@@ -446,7 +444,7 @@ def bsr_matrix(name=None, dtype=None):
csr_fmatrix = SparseType(format="csr", dtype="float32")
bsr_fmatrix = SparseType(format="bsr", dtype="float32")

all_dtypes = SparseType.dtype_set
all_dtypes = list(SparseType.dtype_specs_map.keys())
complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
@@ -924,6 +922,12 @@ def __init__(self, structured=True):
def __str__(self):
return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"

def __call__(self, x):
if not isinstance(x.type, SparseType):
return x

return super().__call__(x)

def make_node(self, x):
x = as_sparse_variable(x)
return Apply(
@@ -1001,6 +1005,12 @@ def __init__(self, format):
def __str__(self):
return f"{self.__class__.__name__}{{{self.format}}}"

def __call__(self, x):
if isinstance(x.type, SparseType):
return x

return super().__call__(x)

def make_node(self, x):
x = at.as_tensor_variable(x)
if x.ndim > 2:
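
The two `__call__` overrides added to `DenseFromSparse` and `SparseFromDense` above turn the conversions into no-ops when the input is already in the target representation. A hedged sketch of that behavior (the variable names are illustrative only):

```python
# Sketch of the __call__ short-circuits added above; inputs are illustrative.
import aesara.tensor as at
from aesara.sparse.basic import DenseFromSparse, SparseFromDense, csr_matrix

d = at.matrix("d")
s = csr_matrix("s", dtype=d.dtype)

# A non-sparse input passes through DenseFromSparse unchanged ...
assert DenseFromSparse()(d) is d
# ... and an already-sparse input passes through SparseFromDense unchanged.
assert SparseFromDense("csr")(s) is s

# Genuine conversions still go through make_node as before.
dense_s = DenseFromSparse()(s)
assert dense_s is not s
```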
97 changes: 58 additions & 39 deletions aesara/sparse/type.py
@@ -2,7 +2,7 @@
import scipy.sparse

import aesara
from aesara.graph.type import Type
from aesara.tensor.type import TensorType


def _is_sparse(x):
@@ -24,7 +24,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix)


class SparseType(Type):
class SparseType(TensorType):
"""
Fundamental way to create a sparse node.
@@ -52,48 +52,45 @@ class SparseType(Type):
"csc": scipy.sparse.csc_matrix,
"bsr": scipy.sparse.bsr_matrix,
}
dtype_set = {
"int8",
"int16",
"int32",
"int64",
"float32",
"uint8",
"uint16",
"uint32",
"uint64",
"float64",
"complex64",
"complex128",
dtype_specs_map = {
"float32": (float, "npy_float32", "NPY_FLOAT32"),
"float64": (float, "npy_float64", "NPY_FLOAT64"),
"uint8": (int, "npy_uint8", "NPY_UINT8"),
"int8": (int, "npy_int8", "NPY_INT8"),
"uint16": (int, "npy_uint16", "NPY_UINT16"),
"int16": (int, "npy_int16", "NPY_INT16"),
"uint32": (int, "npy_uint32", "NPY_UINT32"),
"int32": (int, "npy_int32", "NPY_INT32"),
"uint64": (int, "npy_uint64", "NPY_UINT64"),
"int64": (int, "npy_int64", "NPY_INT64"),
"complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
}
ndim = 2

# Will be set to SparseVariable SparseConstant later.
Variable = None
Constant = None

def __init__(self, format, dtype, shape=None):
dtype = str(dtype)
if dtype in self.dtype_set:
self.dtype = dtype
else:
raise NotImplementedError(
f'unsupported dtype "{dtype}" not in list', list(self.dtype_set)
)

def __init__(self, format, dtype, shape=None, broadcastable=None, name=None):
if shape is None:
shape = (None, None)

self.shape = shape

assert isinstance(format, str)
if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")

if format in self.format_cls:
self.format = format
else:
raise NotImplementedError(
f'unsupported format "{format}" not in list',
list(self.format_cls.keys()),
)
if broadcastable is None:
broadcastable = [False, False]

super().__init__(dtype, shape, name=name)

def clone(self, format=None, dtype=None, shape=None, **kwargs):
if format is None:
@@ -150,24 +147,21 @@ def may_share_memory(a, b):
return True
return False

def make_variable(self, name=None):
return self.Variable(self, name=name)
def convert_variable(self, var):
res = super().convert_variable(var)

def __eq__(self, other):
return (
type(self) == type(other)
and other.dtype == self.dtype
and other.format == self.format
)
if res and not isinstance(res.type, SparseType):
# TODO: Which format do we assign during conversion from a dense
# type?
raise NotImplementedError()

def __hash__(self):
return hash(self.dtype) ^ hash(self.format)
return res

def __str__(self):
return f"Sparse[{self.dtype}, {self.format}]"
def __hash__(self):
return super().__hash__() ^ hash(self.format)

def __repr__(self):
return f"Sparse[{self.dtype}, {self.format}]"
return f"Sparse({self.dtype}, {self.shape}, {self.format})"

def values_eq_approx(self, a, b, eps=1e-6):
# WARNING: equality comparison of sparse matrices is not fast or easy
@@ -210,6 +204,31 @@ def get_size(self, shape_info):
+ (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
)

def value_zeros(self, shape):
matrix_constructor = getattr(scipy.sparse, f"{self.format}_matrix", None)

if matrix_constructor is None:
raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")

return matrix_constructor(shape, dtype=self.dtype)

def __eq__(self, other):
res = super().__eq__(other)

if isinstance(res, bool):
return res and other.format == self.format

return res

def is_super(self, otype):
if not super().is_super(otype):
return False

if self.format == otype.format:
return True

return False


# Register SparseType's C code for ViewOp.
aesara.compile.register_view_op_c_code(
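
Taken together, the `__eq__`, `__hash__`, `is_super`, and `value_zeros` changes mean the storage format now participates in type identity, while the dtype/shape logic is inherited from `TensorType`. A hedged sketch of the resulting behavior (assumes SciPy is installed, since `value_zeros` delegates to `scipy.sparse` constructors):

```python
# Sketch of SparseType semantics after this change; instances here are
# constructed directly for illustration.
from aesara.sparse.type import SparseType

csr64 = SparseType("csr", "float64")
csc64 = SparseType("csc", "float64")

# Equality now combines TensorType's dtype/shape checks with the format.
assert csr64 == SparseType("csr", "float64")
assert csr64 != csc64

# is_super additionally requires a matching format.
assert csr64.is_super(SparseType("csr", "float64"))
assert not csr64.is_super(csc64)

# value_zeros builds an empty SciPy matrix of the matching format/dtype.
z = csr64.value_zeros((3, 4))
print(type(z).__name__, z.dtype)  # csr_matrix float64
```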
9 changes: 7 additions & 2 deletions aesara/tensor/basic.py
@@ -312,8 +312,13 @@ def get_scalar_constant_value(
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
else:
return data

from aesara.sparse.type import SparseType

if isinstance(v.type, SparseType):
raise NotScalarConstantError()

return data

if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
max_recur -= 1
10 changes: 8 additions & 2 deletions aesara/tensor/basic_opt.py
@@ -78,7 +78,12 @@
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from aesara.tensor.type import (
DenseTensorType,
TensorType,
discrete_dtypes,
integer_dtypes,
)
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter

@@ -2954,7 +2959,8 @@ def constant_folding(fgraph, node):

# TODO: `Type` itself should provide an interface for constructing
# instances appropriate for a given constant.
if isinstance(output.type, TensorType):
# TODO: Add handling for sparse types.
if isinstance(output.type, DenseTensorType):
output_type = TensorType(
output.type.dtype,
tuple(s == 1 for s in data.shape),
40 changes: 37 additions & 3 deletions aesara/tensor/blas.py
@@ -166,7 +166,12 @@
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.type import integer_dtypes, tensor, values_eq_approx_remove_inf_nan
from aesara.tensor.type import (
DenseTensorType,
integer_dtypes,
tensor,
values_eq_approx_remove_inf_nan,
)
from aesara.utils import memoize


@@ -263,7 +268,13 @@ def make_node(self, y, alpha, A, x, beta):
raise TypeError("gemv requires vector for x", x.type)
if y.ndim != 1:
raise TypeError("gemv requires vector for y", y.type)
return Apply(self, [y, alpha, A, x, beta], [y.type()])

inputs = [y, alpha, A, x, beta]

if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")

return Apply(self, inputs, [y.type()])

def perform(self, node, inputs, out_storage, params=None):
y, alpha, A, x, beta = inputs
@@ -360,7 +371,12 @@ def make_node(self, A, alpha, x, y):

if x.dtype not in ("float32", "float64", "complex64", "complex128"):
raise TypeError("only float and complex types supported", x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()])

inputs = [A, alpha, x, y]
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")

return Apply(self, inputs, [A.type()])

def perform(self, node, inp, out, params=None):
cA, calpha, cx, cy = inp
@@ -898,6 +914,10 @@ def __getstate__(self):

def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs))

if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")

if len(inputs) != 5:
raise TypeError(
f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
@@ -1579,6 +1599,10 @@ class Dot22(GemmRelated):
def make_node(self, x, y):
x = at.as_tensor_variable(x)
y = at.as_tensor_variable(y)

if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
raise NotImplementedError("Only dense tensor types are supported")

dtypes = ("float16", "float32", "float64", "complex64", "complex128")
if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x)
@@ -1664,6 +1688,9 @@ def local_dot_to_dot22(fgraph, node):
if not isinstance(node.op, Dot):
return

if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
return False

x, y = node.inputs
if y.type.dtype != x.type.dtype:
# TODO: upcast one so the types match
@@ -1868,6 +1895,10 @@ class Dot22Scalar(GemmRelated):
check_input = False

def make_node(self, x, y, a):

if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)):
raise NotImplementedError("Only dense tensor types are supported")

if a.ndim != 0:
raise TypeError(Gemm.E_scalar, a)
if x.ndim != 2:
@@ -2088,6 +2119,9 @@ class BatchedDot(COp):
def make_node(self, *inputs):
inputs = list(map(at.as_tensor_variable, inputs))

if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported")

if len(inputs) != 2:
raise TypeError(f"Two arguments required, but {len(inputs)} given.")
if inputs[0].ndim not in (2, 3):
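
All of the `make_node` additions in this file follow the same pattern: non-dense inputs are rejected at graph-construction time instead of failing later in `perform` or the C code. A rough sketch using `Dot22` as a representative example (instantiating the Op directly, as the module-level `_dot22` instance does; the sparse operand is illustrative only):

```python
# Sketch of the dense-only guard added to the BLAS Ops above.
import aesara.tensor as at
from aesara.sparse.basic import csr_matrix
from aesara.tensor.blas import Dot22

x = at.matrix("x")
s = csr_matrix("s", dtype=x.dtype)

Dot22()(x, x)  # dense inputs still build an Apply node as before

try:
    Dot22()(s, x)
except NotImplementedError as exc:
    print(exc)  # Only dense tensor types are supported
```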
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
@@ -33,6 +33,7 @@
)
from aesara.tensor.shape import shape
from aesara.tensor.type import (
DenseTensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
@@ -2075,6 +2076,11 @@ def dense_dot(a, b):
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)

if not isinstance(a.type, DenseTensorType) or not isinstance(
b.type, DenseTensorType
):
raise TypeError("The dense dot product is only supported for dense types")

if a.ndim == 0 or b.ndim == 0:
return a * b
elif a.ndim > 2 or b.ndim > 2:
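
The new guard in `dense_dot` mirrors the BLAS changes: sparse operands are rejected up front with the `TypeError` shown in the hunk. A hedged sketch (the sparse input is illustrative only):

```python
# Sketch of the dense_dot guard above.
import aesara.tensor as at
from aesara.sparse.basic import csr_matrix
from aesara.tensor.math import dense_dot

a = at.matrix("a")
s = csr_matrix("s", dtype=a.dtype)

dense_dot(a, a)  # both operands dense: works as before

try:
    dense_dot(a, s)
except TypeError as exc:
    print(exc)  # The dense dot product is only supported for dense types
```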