Rename SparseType to SparseTensorType
brandonwillard committed Mar 16, 2022
1 parent b8c1c46 commit 94f5ddf
Showing 15 changed files with 145 additions and 131 deletions.
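The change is a pure rename: every use of `SparseType` becomes `SparseTensorType`, and the longer name pushes a few call sites over the formatter's line length, hence the re-wrapped `Apply(...)` calls below. A minimal sketch of what downstream code looks like after this commit (the `csr` format and `float64` dtype are arbitrary choices for illustration):

    # SparseTensorType is the new name for SparseType; constructing a type
    # and calling it to get a symbolic variable works exactly as before.
    from aesara.sparse import SparseTensorType

    x_type = SparseTensorType(format="csr", dtype="float64")
    x = x_type("x")  # a symbolic CSR-matrix variable named "x"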
2 changes: 1 addition & 1 deletion aesara/__init__.py
@@ -167,7 +167,7 @@ def get_scalar_constant_value(v):
     """
     # Is it necessary to test for presence of aesara.sparse at runtime?
     sparse = globals().get("sparse")
-    if sparse and isinstance(v.type, sparse.SparseType):
+    if sparse and isinstance(v.type, sparse.SparseTensorType):
         if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
             data = v.owner.inputs[0]
             return tensor.get_scalar_constant_value(data)
2 changes: 1 addition & 1 deletion aesara/link/c/type.py
@@ -17,7 +17,7 @@ class CType(Type, CLinkerType):
     - `TensorType`: for numpy.ndarray
-    - `SparseType`: for scipy.sparse
+    - `SparseTensorType`: for scipy.sparse
     But you are encouraged to write your own, as described in WRITEME.
4 changes: 2 additions & 2 deletions aesara/misc/may_share_memory.py
@@ -12,7 +12,7 @@
 try:
     import scipy.sparse
 
-    from aesara.sparse.basic import SparseType
+    from aesara.sparse.basic import SparseTensorType
 
     def _is_sparse(a):
         return scipy.sparse.issparse(a)
@@ -64,4 +64,4 @@ def may_share_memory(a, b, raise_other_type=True):
 
     if a_gpua or b_gpua:
         return False
-    return SparseType.may_share_memory(a, b)
+    return SparseTensorType.may_share_memory(a, b)
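The call changed here is a static helper on the type class, so the rename only touches this one call site. A small sketch of the helper's expected behavior, assuming SciPy is installed (the matrices below are illustrative; SciPy's default `copy=False` keeps references to the caller's buffers):

    import numpy as np
    import scipy.sparse
    from aesara.sparse.basic import SparseTensorType

    data = np.array([1.0, 2.0])
    indices = np.array([0, 1], dtype=np.int32)
    indptr = np.array([0, 1, 2], dtype=np.int32)
    # Both matrices keep a reference to the same `data` buffer.
    a = scipy.sparse.csr_matrix((data, indices, indptr), shape=(2, 2))
    b = scipy.sparse.csr_matrix((data, indices, indptr), shape=(2, 2))
    print(SparseTensorType.may_share_memory(a, b))         # expected: True
    print(SparseTensorType.may_share_memory(a, a.copy()))  # expected: False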
2 changes: 1 addition & 1 deletion aesara/sparse/__init__.py
@@ -9,7 +9,7 @@
     enable_sparse = False
     warn("SciPy can't be imported. Sparse matrix support is disabled.")
 
-from aesara.sparse.type import SparseType, _is_sparse
+from aesara.sparse.type import SparseTensorType, _is_sparse
 
 
 if enable_sparse:
96 changes: 55 additions & 41 deletions aesara/sparse/basic.py
@@ -22,7 +22,7 @@
 from aesara.link.c.op import COp
 from aesara.link.c.type import generic
 from aesara.misc.safe_asarray import _asarray
-from aesara.sparse.type import SparseType, _is_sparse
+from aesara.sparse.type import SparseTensorType, _is_sparse
 from aesara.sparse.utils import hash_from_sparse
 from aesara.tensor import basic as at
 from aesara.tensor.basic import Split
@@ -80,11 +80,11 @@ def _is_sparse_variable(x):
     if not isinstance(x, Variable):
         raise NotImplementedError(
             "this function should only be called on "
-            "*variables* (of type sparse.SparseType "
+            "*variables* (of type sparse.SparseTensorType "
             "or TensorType, for instance), not ",
             x,
         )
-    return isinstance(x.type, SparseType)
+    return isinstance(x.type, SparseTensorType)
 
 
 def _is_dense_variable(x):
@@ -100,7 +100,7 @@ def _is_dense_variable(x):
     if not isinstance(x, Variable):
         raise NotImplementedError(
             "this function should only be called on "
-            "*variables* (of type sparse.SparseType or "
+            "*variables* (of type sparse.SparseTensorType or "
             "TensorType, for instance), not ",
             x,
         )
@@ -159,13 +159,15 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
     else:
         x = x.outputs[0]
     if isinstance(x, Variable):
-        if not isinstance(x.type, SparseType):
-            raise TypeError("Variable type field must be a SparseType.", x, x.type)
+        if not isinstance(x.type, SparseTensorType):
+            raise TypeError(
+                "Variable type field must be a SparseTensorType.", x, x.type
+            )
         return x
     try:
         return constant(x, name=name)
     except TypeError:
-        raise TypeError(f"Cannot convert {x} to SparseType", type(x))
+        raise TypeError(f"Cannot convert {x} to SparseTensorType", type(x))
 
 
 as_sparse = as_sparse_variable
@@ -198,10 +200,10 @@ def constant(x, name=None):
         raise TypeError("sparse.constant must be called on a " "scipy.sparse.spmatrix")
     try:
         return SparseConstant(
-            SparseType(format=x.format, dtype=x.dtype), x.copy(), name=name
+            SparseTensorType(format=x.format, dtype=x.dtype), x.copy(), name=name
         )
     except TypeError:
-        raise TypeError(f"Could not convert {x} to SparseType", type(x))
+        raise TypeError(f"Could not convert {x} to SparseTensorType", type(x))
 
 
 def sp_ones_like(x):
@@ -259,7 +261,7 @@ def to_dense(self, *args, **kwargs):
         self = self.toarray()
         new_args = [
             arg.toarray()
-            if hasattr(arg, "type") and isinstance(arg.type, SparseType)
+            if hasattr(arg, "type") and isinstance(arg.type, SparseTensorType)
             else arg
             for arg in args
         ]
@@ -503,15 +505,15 @@ def __repr__(self):
         return str(self)
 
 
-SparseType.variable_type = SparseVariable
-SparseType.constant_type = SparseConstant
+SparseTensorType.variable_type = SparseVariable
+SparseTensorType.constant_type = SparseConstant
 
 
-# for more dtypes, call SparseType(format, dtype)
+# for more dtypes, call SparseTensorType(format, dtype)
 def matrix(format, name=None, dtype=None):
     if dtype is None:
         dtype = config.floatX
-    type = SparseType(format=format, dtype=dtype)
+    type = SparseTensorType(format=format, dtype=dtype)
     return type(name)
 
 
@@ -527,15 +529,15 @@ def bsr_matrix(name=None, dtype=None):
     return matrix("bsr", name, dtype)
 
 
-# for more dtypes, call SparseType(format, dtype)
-csc_dmatrix = SparseType(format="csc", dtype="float64")
-csr_dmatrix = SparseType(format="csr", dtype="float64")
-bsr_dmatrix = SparseType(format="bsr", dtype="float64")
-csc_fmatrix = SparseType(format="csc", dtype="float32")
-csr_fmatrix = SparseType(format="csr", dtype="float32")
-bsr_fmatrix = SparseType(format="bsr", dtype="float32")
+# for more dtypes, call SparseTensorType(format, dtype)
+csc_dmatrix = SparseTensorType(format="csc", dtype="float64")
+csr_dmatrix = SparseTensorType(format="csr", dtype="float64")
+bsr_dmatrix = SparseTensorType(format="bsr", dtype="float64")
+csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
+csr_fmatrix = SparseTensorType(format="csr", dtype="float32")
+bsr_fmatrix = SparseTensorType(format="bsr", dtype="float32")
 
-all_dtypes = list(SparseType.dtype_specs_map.keys())
+all_dtypes = list(SparseTensorType.dtype_specs_map.keys())
 complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
 float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
 int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
@@ -725,7 +727,7 @@ def make_node(self, data, indices, indptr, shape):
         return Apply(
             self,
             [data, indices, indptr, shape],
-            [SparseType(dtype=data.type.dtype, format=self.format)()],
+            [SparseTensorType(dtype=data.type.dtype, format=self.format)()],
         )
 
     def perform(self, node, inputs, outputs):
@@ -931,7 +933,9 @@ def __init__(self, out_type):
     def make_node(self, x):
         x = as_sparse_variable(x)
         assert x.format in ("csr", "csc")
-        return Apply(self, [x], [SparseType(dtype=self.out_type, format=x.format)()])
+        return Apply(
+            self, [x], [SparseTensorType(dtype=self.out_type, format=x.format)()]
+        )
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
@@ -1014,7 +1018,7 @@ def __str__(self):
         return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"
 
     def __call__(self, x):
-        if not isinstance(x.type, SparseType):
+        if not isinstance(x.type, SparseTensorType):
             return x
 
         return super().__call__(x)
@@ -1097,7 +1101,7 @@ def __str__(self):
         return f"{self.__class__.__name__}{{{self.format}}}"
 
     def __call__(self, x):
-        if isinstance(x.type, SparseType):
+        if isinstance(x.type, SparseTensorType):
             return x
 
         return super().__call__(x)
@@ -1116,12 +1120,14 @@ def make_node(self, x):
         else:
             assert x.ndim == 2
 
-        return Apply(self, [x], [SparseType(dtype=x.type.dtype, format=self.format)()])
+        return Apply(
+            self, [x], [SparseTensorType(dtype=x.type.dtype, format=self.format)()]
+        )
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (out,) = outputs
-        out[0] = SparseType.format_cls[self.format](x)
+        out[0] = SparseTensorType.format_cls[self.format](x)
 
     def grad(self, inputs, gout):
         (x,) = inputs
@@ -1585,7 +1591,11 @@ def make_node(self, x):
         return Apply(
             self,
             [x],
-            [SparseType(dtype=x.type.dtype, format=self.format_map[x.type.format])()],
+            [
+                SparseTensorType(
+                    dtype=x.type.dtype, format=self.format_map[x.type.format]
+                )()
+            ],
         )
 
     def perform(self, node, inputs, outputs):
@@ -2002,7 +2012,7 @@ def make_node(self, diag):
         if diag.type.ndim != 1:
             raise TypeError("data argument must be a vector", diag.type)
 
-        return Apply(self, [diag], [SparseType(dtype=diag.dtype, format="csc")()])
+        return Apply(self, [diag], [SparseTensorType(dtype=diag.dtype, format="csc")()])
 
     def perform(self, node, inputs, outputs):
         (z,) = outputs
@@ -2146,7 +2156,7 @@ def make_node(self, x, y):
         assert y.format in ("csr", "csc")
         out_dtype = aes.upcast(x.type.dtype, y.type.dtype)
         return Apply(
-            self, [x, y], [SparseType(dtype=out_dtype, format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype=out_dtype, format=x.type.format)()]
         )
 
     def perform(self, node, inputs, outputs):
@@ -2183,7 +2193,7 @@ def make_node(self, x, y):
         if x.type.format != y.type.format:
             raise NotImplementedError()
         return Apply(
-            self, [x, y], [SparseType(dtype=x.type.dtype, format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()]
         )
 
     def perform(self, node, inputs, outputs):
@@ -2286,7 +2296,7 @@ def make_node(self, x, y):
         if x.type.dtype != y.type.dtype:
             raise NotImplementedError()
         return Apply(
-            self, [x, y], [SparseType(dtype=x.type.dtype, format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()]
         )
 
     def perform(self, node, inputs, outputs):
@@ -2426,7 +2436,7 @@ def make_node(self, x, y):
         assert y.format in ("csr", "csc")
         out_dtype = aes.upcast(x.type.dtype, y.type.dtype)
         return Apply(
-            self, [x, y], [SparseType(dtype=out_dtype, format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype=out_dtype, format=x.type.format)()]
         )
 
     def perform(self, node, inputs, outputs):
@@ -2469,7 +2479,7 @@ def make_node(self, x, y):
         # Broadcasting of the sparse matrix is not supported.
         # We support nd == 0 used by grad of SpSum()
         assert y.type.ndim in (0, 2)
-        out = SparseType(dtype=dtype, format=x.type.format)()
+        out = SparseTensorType(dtype=dtype, format=x.type.format)()
         return Apply(self, [x, y], [out])
 
     def perform(self, node, inputs, outputs):
@@ -2559,7 +2569,7 @@ def make_node(self, x, y):
                 f"Got {x.type.dtype} and {y.type.dtype}."
             )
         return Apply(
-            self, [x, y], [SparseType(dtype=x.type.dtype, format=x.type.format)()]
+            self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()]
         )
 
     def perform(self, node, inputs, outputs):
@@ -2694,7 +2704,9 @@ def make_node(self, x, y):
 
         if x.type.format != y.type.format:
             raise NotImplementedError()
-        return Apply(self, [x, y], [SparseType(dtype="uint8", format=x.type.format)()])
+        return Apply(
+            self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()]
+        )
 
     def perform(self, node, inputs, outputs):
         (x, y) = inputs
@@ -3050,7 +3062,9 @@ def make_node(self, *mat):
         for x in var:
             assert x.format in ("csr", "csc")
 
-        return Apply(self, var, [SparseType(dtype=self.dtype, format=self.format)()])
+        return Apply(
+            self, var, [SparseTensorType(dtype=self.dtype, format=self.format)()]
+        )
 
     def perform(self, node, block, outputs):
         (out,) = outputs
@@ -3578,7 +3592,7 @@ def make_node(self, x, y):
             raise NotImplementedError()
 
         inputs = [x, y]  # Need to convert? e.g. assparse
-        outputs = [SparseType(dtype=x.type.dtype, format=myformat)()]
+        outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()]
         return Apply(self, inputs, outputs)
 
     def perform(self, node, inp, out_):
@@ -3702,7 +3716,7 @@ def make_node(self, a, b):
             raise NotImplementedError("non-matrix b")
 
         if _is_sparse_variable(b):
-            return Apply(self, [a, b], [SparseType(a.type.format, dtype_out)()])
+            return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()])
         else:
             return Apply(
                 self,
@@ -3719,7 +3733,7 @@ def perform(self, node, inputs, outputs):
             )
 
         variable = a * b
-        if isinstance(node.outputs[0].type, SparseType):
+        if isinstance(node.outputs[0].type, SparseTensorType):
             assert _is_sparse(variable)
             out[0] = variable
             return
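All of the `basic.py` edits follow the same mechanical pattern: wherever an `Op.make_node` allocated its output variable, the output's type is now spelled `SparseTensorType`. A quick way to see the renamed type on a concrete value, assuming SciPy is installed (`constant` is the helper changed above; the 3x3 identity is an arbitrary example):

    import scipy.sparse
    from aesara.sparse.basic import constant

    sp_val = scipy.sparse.csr_matrix(scipy.sparse.eye(3))
    v = constant(sp_val)                # a SparseConstant wrapping a copy of sp_val
    print(type(v.type).__name__)        # SparseTensorType
    print(v.type.format, v.type.dtype)  # csr float64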
6 changes: 4 additions & 2 deletions aesara/sparse/sandbox/sp2.py
@@ -7,7 +7,7 @@
 from aesara.graph.op import Op
 from aesara.sparse.basic import (
     Remove0,
-    SparseType,
+    SparseTensorType,
     _is_sparse,
     as_sparse_variable,
     remove0,
@@ -108,7 +108,9 @@ def make_node(self, n, p, shape):
         assert shape.dtype in discrete_dtypes
 
         return Apply(
-            self, [n, p, shape], [SparseType(dtype=self.dtype, format=self.format)()]
+            self,
+            [n, p, shape],
+            [SparseTensorType(dtype=self.dtype, format=self.format)()],
         )
 
     def perform(self, node, inputs, outputs):
6 changes: 3 additions & 3 deletions aesara/sparse/sharedvar.py
@@ -3,7 +3,7 @@
 import scipy.sparse
 
 from aesara.compile import SharedVariable, shared_constructor
-from aesara.sparse.basic import SparseType, _sparse_py_operators
+from aesara.sparse.basic import SparseTensorType, _sparse_py_operators
 
 
 class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
@@ -16,7 +16,7 @@ def sparse_constructor(
     value, name=None, strict=False, allow_downcast=None, borrow=False, format=None
 ):
     """
-    SharedVariable Constructor for SparseType.
+    SharedVariable Constructor for SparseTensorType.
 
     writeme
@@ -29,7 +29,7 @@ def sparse_constructor(
 
     if format is None:
         format = value.format
-    type = SparseType(format=format, dtype=value.dtype)
+    type = SparseTensorType(format=format, dtype=value.dtype)
     if not borrow:
         value = copy.deepcopy(value)
     return SparseTensorSharedVariable(
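`sparse_constructor` is registered through `shared_constructor`, so the renamed type also surfaces on shared variables. A hedged sketch, assuming `aesara.shared` dispatches SciPy sparse values to this constructor as the registration above suggests (the empty 3x3 float32 matrix is an arbitrary example):

    import scipy.sparse
    import aesara

    w = aesara.shared(scipy.sparse.csr_matrix((3, 3), dtype="float32"), name="w")
    print(type(w).__name__)       # SparseTensorSharedVariable
    print(type(w.type).__name__)  # SparseTensorType
    print(w.type.format)          # csr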