Make sparse tensor types extend the existing tensor types
brandonwillard committed Feb 15, 2021
1 parent 03487af commit e76a322
Showing 10 changed files with 231 additions and 61 deletions.
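In broad terms, this commit makes SparseType a subclass of TensorType and makes the sparse variable/constant classes subclasses of their dense counterparts. A minimal sketch of the intended effect (my illustration, not code from the commit), assuming Aesara at this revision with SciPy installed:

from aesara import sparse
from aesara.tensor.type import TensorType
from aesara.tensor.var import TensorVariable

# After this change, a sparse matrix variable is also a tensor variable,
# and its type is also a (2-dimensional) tensor type.
x = sparse.csr_matrix(name="x", dtype="float64")

assert isinstance(x, TensorVariable)
assert isinstance(x.type, TensorType)
assert x.ndim == 2 and x.dtype == "float64"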
37 changes: 29 additions & 8 deletions aesara/sparse/basic.py
@@ -22,6 +22,9 @@
 from aesara.misc.safe_asarray import _asarray
 from aesara.sparse.type import SparseType, _is_sparse
 from aesara.sparse.utils import hash_from_sparse
+
+# TODO:
+# from aesara.tensor import _as_tensor_variable
 from aesara.tensor import basic as aet
 from aesara.tensor.basic import Split
 from aesara.tensor.math import add as aet_add
@@ -46,6 +49,7 @@
 from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
 from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
+from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators


 sparse_formats = ["csc", "csr"]
@@ -123,8 +127,9 @@ def _is_dense(x):
     return isinstance(x, np.ndarray)


-# Wrapper type
-def as_sparse_variable(x, name=None):
+# TODO:
+# @_as_tensor_variable.register(scipy.sparse.base.spmatrix)
+def as_sparse_variable(x, name=None, ndim=None, **kwargs):
     """
     Wrapper around SparseVariable constructor to construct
     a Variable with a sparse matrix with the same dtype and
@@ -247,7 +252,7 @@ def sp_zeros_like(x):
     )


-class _sparse_py_operators:
+class _sparse_py_operators(_tensor_py_operators):
     T = property(
         lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
     )
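The mix-in change above means sparse variables now inherit every _tensor_py_operators method that aesara.sparse.basic does not override, while sparse-specific overrides such as .T stay sparse-aware. A rough example of mine, assuming this revision:

from aesara import sparse

x = sparse.csr_matrix(name="x", dtype="float64")

xt = x.T              # overridden here, so the transpose stays sparse
assert x.ndim == 2    # inherited: ndim is read off the underlying type
assert x.dtype == "float64"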
@@ -355,8 +360,7 @@ def __getitem__(self, args):
         return ret


-class SparseVariable(_sparse_py_operators, Variable):
-    dtype = property(lambda self: self.type.dtype)
+class SparseVariable(_sparse_py_operators, TensorVariable):
     format = property(lambda self: self.type.format)

     def __str__(self):
@@ -389,8 +393,7 @@ def aesara_hash(self):
         return hash_from_sparse(d)


-class SparseConstant(Constant, _sparse_py_operators):
-    dtype = property(lambda self: self.type.dtype)
+class SparseConstant(TensorConstant, _sparse_py_operators):
     format = property(lambda self: self.type.format)

     def signature(self):
@@ -413,6 +416,12 @@ def __repr__(self):
 SparseType.Variable = SparseVariable
 SparseType.Constant = SparseConstant

+# TODO:
+# @_as_tensor_variable.register(SparseVariable)
+# @_as_tensor_variable.register(SparseConstant)
+# def _as_tensor_sparse(x, name, ndim):
+#     return x
+

 # for more dtypes, call SparseType(format, dtype)
 def matrix(format, name=None, dtype=None):
@@ -442,7 +451,7 @@ def bsr_matrix(name=None, dtype=None):
 csr_fmatrix = SparseType(format="csr", dtype="float32")
 bsr_fmatrix = SparseType(format="bsr", dtype="float32")

-all_dtypes = SparseType.dtype_set
+all_dtypes = list(SparseType.dtype_specs_map.keys())
 complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
 float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
 int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
@@ -920,6 +929,12 @@ def __init__(self, structured=True):
     def __str__(self):
         return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"

+    def __call__(self, x):
+        if not isinstance(x.type, SparseType):
+            return x
+
+        return super().__call__(x)
+
     def make_node(self, x):
         x = as_sparse_variable(x)
         return Apply(
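The new __call__ makes DenseFromSparse a pass-through for inputs that are not sparse, so dense_from_sparse (the usual DenseFromSparse() instance) can be applied unconditionally; a hedged usage sketch of mine:

import aesara.tensor as aet
from aesara.sparse import csr_matrix, dense_from_sparse

s = csr_matrix(name="s", dtype="float64")
d = aet.matrix("d")

assert dense_from_sparse(s) is not s  # sparse input: the op is applied
assert dense_from_sparse(d) is d      # dense input: returned unchanged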
@@ -997,6 +1012,12 @@ def __init__(self, format):
     def __str__(self):
         return f"{self.__class__.__name__}{{{self.format}}}"

+    def __call__(self, x):
+        if isinstance(x.type, SparseType):
+            return x
+
+        return super().__call__(x)
+
     def make_node(self, x):
         x = aet.as_tensor_variable(x)
         if x.ndim > 2:
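Symmetrically, SparseFromDense instances such as csr_from_dense now return already-sparse inputs unchanged. Another illustrative example (mine, not from the commit):

import aesara.tensor as aet
from aesara.sparse import csr_from_dense, csr_matrix

s = csr_matrix(name="s", dtype="float64")

assert csr_from_dense(s) is s        # sparse input passes straight through
y = csr_from_dense(aet.matrix("d"))  # dense input is converted as before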
72 changes: 40 additions & 32 deletions aesara/sparse/type.py
@@ -10,7 +10,7 @@


 import aesara
-from aesara.graph.type import Type
+from aesara.tensor.type import TensorType


 def _is_sparse(x):
@@ -32,7 +32,7 @@ def _is_sparse(x):
     return isinstance(x, scipy.sparse.spmatrix)


-class SparseType(Type):
+class SparseType(TensorType):
     """
     Fundamental way to create a sparse node.
@@ -60,48 +60,52 @@ class SparseType(Type):
         "csc": scipy.sparse.csc_matrix,
         "bsr": scipy.sparse.bsr_matrix,
     }
-    dtype_set = {
-        "int8",
-        "int16",
-        "int32",
-        "int64",
-        "float32",
-        "uint8",
-        "uint16",
-        "uint32",
-        "uint64",
-        "float64",
-        "complex64",
-        "complex128",
+    dtype_specs_map = {
+        "float32": (float, "npy_float32", "NPY_FLOAT32"),
+        "float64": (float, "npy_float64", "NPY_FLOAT64"),
+        "uint8": (int, "npy_uint8", "NPY_UINT8"),
+        "int8": (int, "npy_int8", "NPY_INT8"),
+        "uint16": (int, "npy_uint16", "NPY_UINT16"),
+        "int16": (int, "npy_int16", "NPY_INT16"),
+        "uint32": (int, "npy_uint32", "NPY_UINT32"),
+        "int32": (int, "npy_int32", "NPY_INT32"),
+        "uint64": (int, "npy_uint64", "NPY_UINT64"),
+        "int64": (int, "npy_int64", "NPY_INT64"),
+        "complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
+        "complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
     }
+    ndim = 2

     # Will be set to SparseVariable SparseConstant later.
     Variable = None
     Constant = None

-    def __init__(self, format, dtype):
+    def __init__(self, format, dtype, broadcastable=None, name=None):
         if not imported_scipy:
             raise Exception(
-                "You can't make SparseType object as SciPy" " is not available."
-            )
-        dtype = str(dtype)
-        if dtype in self.dtype_set:
-            self.dtype = dtype
-        else:
-            raise NotImplementedError(
-                f'unsupported dtype "{dtype}" not in list', list(self.dtype_set)
+                "You can't make SparseType object when SciPy is not available."
             )

-        assert isinstance(format, str)
+        if not isinstance(format, str):
+            raise TypeError("The sparse format parameter must be a string")
+
         if format in self.format_cls:
             self.format = format
         else:
             raise NotImplementedError(
                 f'unsupported format "{format}" not in list',
                 list(self.format_cls.keys()),
             )

+        if broadcastable is None:
+            broadcastable = [False, False]
+
+        super().__init__(dtype, broadcastable, name=name)
+
+    def clone(self, dtype=None, broadcastable=None):
+        new = super().clone(dtype=dtype, broadcastable=broadcastable)
+        new.format = self.format
+        return new
+
     def filter(self, value, strict=False, allow_downcast=None):
         if (
             isinstance(value, self.format_cls[self.format])
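Taken together, SparseType now behaves like a two-dimensional TensorType with an extra format attribute. A small sketch of mine for the constructor and clone semantics introduced above:

from aesara.sparse.type import SparseType

st = SparseType("csr", "float64")
assert st.ndim == 2
assert st.broadcastable == (False, False)  # the [False, False] default above

st2 = st.clone(dtype="float32")
assert st2.format == "csr"  # clone() re-attaches the sparse format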
@@ -152,14 +156,10 @@ def make_variable(self, name=None):
         return self.Variable(self, name=name)

     def __eq__(self, other):
-        return (
-            type(self) == type(other)
-            and other.dtype == self.dtype
-            and other.format == self.format
-        )
+        return super().__eq__(other) and other.format == self.format

     def __hash__(self):
-        return hash(self.dtype) ^ hash(self.format)
+        return super().__hash__() ^ hash(self.format)

     def __str__(self):
         return f"Sparse[{self.dtype}, {self.format}]"
@@ -208,6 +208,14 @@ def get_size(self, shape_info):
             + (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
         )

+    def value_zeros(self, shape):
+        matrix_constructor = getattr(scipy.sparse, f"{self.format}_matrix", None)
+
+        if matrix_constructor is None:
+            raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
+
+        return matrix_constructor(shape, dtype=self.dtype)
+

 # Register SparseType's C code for ViewOp.
 aesara.compile.register_view_op_c_code(
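A quick usage sketch for value_zeros (assuming SciPy is available): it builds an all-zero SciPy matrix in the type's own format and dtype.

from aesara.sparse.type import SparseType

z = SparseType("csr", "float64").value_zeros((3, 4))

assert z.shape == (3, 4)
assert z.nnz == 0  # scipy.sparse.csr_matrix((3, 4), dtype=...) is all-zero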
10 changes: 8 additions & 2 deletions aesara/tensor/basic.py
@@ -347,10 +347,16 @@ def get_scalar_constant_value(
             data = v.tag.unique_value
         else:
             data = v.data
+
         if isinstance(data, np.ndarray):
             return numpy_scalar(data).copy()
-        else:
-            return data
+
+        from aesara.sparse.type import SparseType
+
+        if isinstance(v.type, SparseType):
+            raise NotScalarConstantError()
+
+        return data

     if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
         max_recur -= 1
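Since sparse constants can now reach this code path (their types are TensorType subclasses), the guard raises NotScalarConstantError rather than returning a SciPy matrix. An illustrative check of mine:

import scipy.sparse as sp
from aesara.sparse import as_sparse_variable
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError

s = as_sparse_variable(sp.csr_matrix((1, 1), dtype="float64"))
try:
    get_scalar_constant_value(s)
except NotScalarConstantError:
    pass  # even a 1x1 sparse constant is not a scalar constant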
41 changes: 38 additions & 3 deletions aesara/tensor/blas.py
@@ -166,7 +166,12 @@
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.math import Dot, add, mul, neg, sub
-from aesara.tensor.type import integer_dtypes, tensor, values_eq_approx_remove_inf_nan
+from aesara.tensor.type import (
+    DenseTensorType,
+    integer_dtypes,
+    tensor,
+    values_eq_approx_remove_inf_nan,
+)
 from aesara.utils import memoize

@@ -253,6 +258,7 @@ def make_node(self, y, alpha, A, x, beta):
         A = aet.as_tensor_variable(A)
         alpha = aet.as_tensor_variable(alpha)
         beta = aet.as_tensor_variable(beta)
+
         if y.dtype != A.dtype or y.dtype != x.dtype:
             raise TypeError(
                 "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -263,7 +269,13 @@
             raise TypeError("gemv requires vector for x", x.type)
         if y.ndim != 1:
             raise TypeError("gemv requires vector for y", y.type)
-        return Apply(self, [y, alpha, A, x, beta], [y.type()])
+
+        inputs = [y, alpha, A, x, beta]
+
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
+        return Apply(self, inputs, [y.type()])

     def perform(self, node, inputs, out_storage, params=None):
         y, alpha, A, x, beta = inputs
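Because sparse variables now pass as_tensor_variable and the isinstance checks that used to filter them out, each BLAS op gets an explicit DenseTensorType guard; the same pattern recurs in Ger, Gemm, Dot22, Dot22Scalar, and BatchedDot below. A hedged sketch of mine, using the module's gemv_no_inplace instance:

import aesara.tensor as aet
from aesara.sparse import csr_matrix
from aesara.tensor.blas import gemv_no_inplace

y, x = aet.vector("y"), aet.vector("x")
alpha, beta = aet.scalar("alpha"), aet.scalar("beta")
A = csr_matrix(name="A", dtype=y.dtype)  # sparse matrix in place of dense A

try:
    gemv_no_inplace(y, alpha, A, x, beta)
except NotImplementedError:
    pass  # sparse A is rejected: only dense tensor types are supported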
@@ -360,7 +372,12 @@ def make_node(self, A, alpha, x, y):

         if x.dtype not in ("float32", "float64", "complex64", "complex128"):
             raise TypeError("only float and complex types supported", x.dtype)
-        return Apply(self, [A, alpha, x, y], [A.type()])
+
+        inputs = [A, alpha, x, y]
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
+        return Apply(self, inputs, [A.type()])

     def perform(self, node, inp, out, params=None):
         cA, calpha, cx, cy = inp
@@ -898,6 +915,10 @@ def __getstate__(self):

     def make_node(self, *inputs):
         inputs = list(map(aet.as_tensor_variable, inputs))
+
+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if len(inputs) != 5:
             raise TypeError(
                 f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})"
@@ -1580,6 +1601,10 @@ class Dot22(GemmRelated):
     def make_node(self, x, y):
         x = aet.as_tensor_variable(x)
         y = aet.as_tensor_variable(y)
+
+        if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         dtypes = ("float16", "float32", "float64", "complex64", "complex128")
         if x.type.ndim != 2 or x.type.dtype not in dtypes:
             raise TypeError(x)
@@ -1665,6 +1690,9 @@ def local_dot_to_dot22(fgraph, node):
     if not isinstance(node.op, Dot):
         return

+    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
+        return False
+
     x, y = node.inputs
     if y.type.dtype != x.type.dtype:
         # TODO: upcast one so the types match
@@ -1847,6 +1875,10 @@ class Dot22Scalar(GemmRelated):
     check_input = False

     def make_node(self, x, y, a):
+
+        if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if a.ndim != 0:
             raise TypeError(Gemm.E_scalar, a)
         if x.ndim != 2:
@@ -2066,6 +2098,9 @@ class BatchedDot(COp):
     def make_node(self, *inputs):
         inputs = list(map(aet.as_tensor_variable, inputs))

+        if any(not isinstance(i.type, DenseTensorType) for i in inputs):
+            raise NotImplementedError("Only dense tensor types are supported")
+
         if len(inputs) != 2:
             raise TypeError(f"Two arguments required, but {len(inputs)} given.")
         if inputs[0].ndim not in (2, 3):
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
@@ -32,6 +32,7 @@
 )
 from aesara.tensor.shape import shape
 from aesara.tensor.type import (
+    DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2006,6 +2007,11 @@ def dense_dot(a, b):
     """
     a, b = as_tensor_variable(a), as_tensor_variable(b)

+    if not isinstance(a.type, DenseTensorType) or not isinstance(
+        b.type, DenseTensorType
+    ):
+        raise TypeError("The dense dot product is only supported for dense types")
+
     if a.ndim == 0 or b.ndim == 0:
         return a * b
     elif a.ndim > 2 or b.ndim > 2:
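A short sketch of the new dense_dot behavior (my example; aesara.sparse.dot remains the entry point for sparse operands):

import aesara.tensor as aet
from aesara.sparse import csr_matrix
from aesara.tensor.math import dense_dot

a = aet.matrix("a")
s = csr_matrix(name="s", dtype=a.dtype)

out = dense_dot(a, a)  # fine: both operands are dense

try:
    dense_dot(a, s)
except TypeError:
    pass  # "The dense dot product is only supported for dense types"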