Make sparse tensor types extend the existing tensor types #303

Closed
5 changes: 4 additions & 1 deletion aesara/graph/fg.py
@@ -485,7 +485,10 @@ def replace(self, var, new_var, reason=None, verbose=None, import_missing=False)
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print(reason, var, new_var)
print(
f"{reason}:\t{var.owner or var} [{var.name or var.auto_name}] -> "
f"{new_var.owner or new_var} [{new_var.name or new_var.auto_name}]"
)

new_var = var.type.filter_variable(new_var, allow_convert=True)

43 changes: 28 additions & 15 deletions aesara/graph/opt.py
@@ -124,6 +124,11 @@ def __hash__(self):
_optimizer_idx[0] += 1
return self._optimizer_idx

def __str__(self):
if hasattr(self, "name"):
return f"{type(self).__name__}[{self.name}]"
return repr(self)


class FromFunctionOptimizer(GlobalOptimizer):
"""A `GlobalOptimizer` constructed from a given function."""
@@ -1074,6 +1079,11 @@ def add_requirements(self, fgraph):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)

def __str__(self):
if hasattr(self, "name"):
return f"{type(self).__name__}[{self.name}]"
return repr(self)


class LocalMetaOptimizer(LocalOptimizer):
"""
@@ -1253,6 +1263,14 @@ def decorator(f):
class LocalOptGroup(LocalOptimizer):
"""Takes a list of LocalOptimizer and applies them to the node.

This is where the "tracks" parameters are largely used. If you set
one of the `LocalOptimizer` in `LocalOptGroup.optimizers` to track a
specific `Op` instance, this optimizer will only apply said
`LocalOptimizer` when it's acting on a node whose `Op` exactly matches the
tracked `Op` (the matching is performed using a `dict` lookup).

TODO: Use type-based matching (e.g. like `singledispatch`).

Parameters
----------
optimizers :
@@ -1269,10 +1287,12 @@ class LocalOptGroup(LocalOptimizer):

"""

def __init__(self, *optimizers, **kwargs):
def __init__(self, *optimizers, apply_all_opts=False, profile=False, name=None):
self.name = name

if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happens when created by LocalGroupDB.
optimizers = tuple(optimizers[0])

self.opts = optimizers
assert isinstance(self.opts, tuple)

@@ -1281,10 +1301,10 @@ def __init__(self, *optimizers, **kwargs):
getattr(opt, "retains_inputs", False) for opt in optimizers
)

self.apply_all_opts = kwargs.pop("apply_all_opts", False)
self.profile = kwargs.pop("profile", False)
self.track_map = defaultdict(lambda: [])
assert len(kwargs) == 0
self.apply_all_opts = apply_all_opts
self.profile = profile
self.track_map = defaultdict(list)

if self.profile:
self.time_opts = {}
self.process_count = {}
@@ -1304,12 +1324,8 @@ def __init__(self, *optimizers, **kwargs):
for c in tracks:
self.track_map[c].append(o)

def __str__(self):
return getattr(
self,
"__name__",
f"LocalOptGroup({','.join([str(o) for o in self.opts])})",
)
def __repr__(self):
return f"LocalOptGroup([{', '.join([str(o) for o in self.opts])}])"

def tracks(self):
t = []
@@ -2189,9 +2205,6 @@ def print_profile(stream, prof, level=0):
level=level + 1,
)

def __str__(self):
return getattr(self, "__name__", "<TopoOptimizer instance>")


def out2in(*local_opts, **kwargs):
"""
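For context, the `dict`-based "tracks" dispatch that the new `LocalOptGroup` docstring describes can be illustrated with a toy sketch; `FakeOp` and the optimizer stubs below are placeholders, not Aesara objects:

```python
from collections import defaultdict


class FakeOp:
    """Placeholder standing in for an Aesara `Op` instance."""

    def __init__(self, name):
        self.name = name


class TrackingOpt:
    """Placeholder standing in for a `LocalOptimizer` with a `tracks` list."""

    def __init__(self, name, tracks):
        self.name = name
        self._tracks = tracks

    def tracks(self):
        # `None` means "try this optimizer on every node".
        return self._tracks


add_op = FakeOp("add")
mul_op = FakeOp("mul")

opt_on_add = TrackingOpt("opt_on_add", [add_op])  # tracks one specific Op instance
opt_generic = TrackingOpt("opt_generic", None)    # untracked: tried everywhere

# Build the same kind of mapping as `LocalOptGroup.track_map`:
# tracked `Op` instance -> list of optimizers to try on matching nodes.
track_map = defaultdict(list)
for opt in (opt_on_add, opt_generic):
    tracked = opt.tracks()
    if tracked is None:
        track_map[None].append(opt)
    else:
        for op in tracked:
            track_map[op].append(opt)


def candidates(node_op):
    # Matching is a plain `dict` lookup on the `Op` instance, not an
    # `isinstance` check -- which is what the TODO about type-based
    # matching refers to.
    return track_map[node_op] + track_map[None]


print([o.name for o in candidates(add_op)])  # ['opt_on_add', 'opt_generic']
print([o.name for o in candidates(mul_op)])  # ['opt_generic']
```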
11 changes: 4 additions & 7 deletions aesara/graph/optdb.py
@@ -460,22 +460,19 @@ def register(self, name, obj, *tags, **kwargs):
self.__position__[name] = position

def query(self, *tags, **kwtags):
# For the new `useless` optimizer
opts = list(super().query(*tags, **kwtags))
opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))

ret = self.local_opt(
*opts, apply_all_opts=self.apply_all_opts, profile=self.profile
*opts,
apply_all_opts=self.apply_all_opts,
profile=self.profile,
)
return ret


class TopoDB(DB):
"""

Generate a `GlobalOptimizer` of type TopoOptimizer.

"""
"""Generate a `GlobalOptimizer` of type TopoOptimizer."""

def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
4 changes: 1 addition & 3 deletions aesara/graph/toolbox.py
@@ -571,7 +571,7 @@ def replace_all_validate(

for r, new_r in replacements:
try:
fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
fgraph.replace(r, new_r, reason=reason, verbose=verbose, **kwargs)
except Exception as e:
msg = str(e)
s1 = "The type of the replacement must be the same"
@@ -626,8 +626,6 @@ def replace_all_validate(
print(
"Scan removed", nb, nb2, getattr(reason, "name", reason), r, new_r
)
if verbose:
print(reason, r, new_r)
# The return is needed by replace_all_validate_remove
return chk

37 changes: 29 additions & 8 deletions aesara/sparse/basic.py
@@ -22,6 +22,9 @@
from aesara.misc.safe_asarray import _asarray
from aesara.sparse.type import SparseType, _is_sparse
from aesara.sparse.utils import hash_from_sparse

# TODO:
# from aesara.tensor import _as_tensor_variable
from aesara.tensor import basic as aet
from aesara.tensor.basic import Split
from aesara.tensor.math import add as aet_add
@@ -46,6 +49,7 @@
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector
from aesara.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators


sparse_formats = ["csc", "csr"]
@@ -123,8 +127,9 @@ def _is_dense(x):
return isinstance(x, np.ndarray)


# Wrapper type
def as_sparse_variable(x, name=None):
# TODO:
# @_as_tensor_variable.register(scipy.sparse.base.spmatrix)
def as_sparse_variable(x, name=None, ndim=None, **kwargs):
"""
Wrapper around SparseVariable constructor to construct
a Variable with a sparse matrix with the same dtype and
@@ -247,7 +252,7 @@ def sp_zeros_like(x):
)


class _sparse_py_operators:
class _sparse_py_operators(_tensor_py_operators):
T = property(
lambda self: transpose(self), doc="Return aliased transpose of self (read-only)"
)
@@ -355,8 +360,7 @@ def __getitem__(self, args):
return ret


class SparseVariable(_sparse_py_operators, Variable):
dtype = property(lambda self: self.type.dtype)
class SparseVariable(_sparse_py_operators, TensorVariable):
Contributor: Why are the supertypes in a different order compared to the SparseConstant below?

Member (author): Not sure.

format = property(lambda self: self.type.format)

def __str__(self):
@@ -389,8 +393,7 @@ def aesara_hash(self):
return hash_from_sparse(d)


class SparseConstant(Constant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
class SparseConstant(TensorConstant, _sparse_py_operators):
format = property(lambda self: self.type.format)

def signature(self):
@@ -413,6 +416,12 @@ def __repr__(self):
SparseType.Variable = SparseVariable
SparseType.Constant = SparseConstant

# TODO:
# @_as_tensor_variable.register(SparseVariable)
# @_as_tensor_variable.register(SparseConstant)
# def _as_tensor_sparse(x, name, ndim):
# return x


# for more dtypes, call SparseType(format, dtype)
def matrix(format, name=None, dtype=None):
@@ -442,7 +451,7 @@ def bsr_matrix(name=None, dtype=None):
csr_fmatrix = SparseType(format="csr", dtype="float32")
bsr_fmatrix = SparseType(format="bsr", dtype="float32")

all_dtypes = SparseType.dtype_set
all_dtypes = list(SparseType.dtype_specs_map.keys())
Contributor: Did you make it a list so it can be appended at runtime?

complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
@@ -920,6 +929,12 @@ def __init__(self, structured=True):
def __str__(self):
return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}"

def __call__(self, x):
if not isinstance(x.type, SparseType):
return x

return super().__call__(x)

def make_node(self, x):
x = as_sparse_variable(x)
return Apply(
@@ -997,6 +1012,12 @@ def __init__(self, format):
def __str__(self):
return f"{self.__class__.__name__}{{{self.format}}}"

def __call__(self, x):
if isinstance(x.type, SparseType):
return x

return super().__call__(x)

def make_node(self, x):
x = aet.as_tensor_variable(x)
if x.ndim > 2:
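As a usage sketch of the new `__call__` short-circuits on `DenseFromSparse` and `SparseFromDense` (assuming the module-level `dense_from_sparse`, `csr_from_dense`, and `csr_matrix` helpers exported by `aesara.sparse`):

```python
import aesara.sparse as sparse
import aesara.tensor as aet

x_dense = aet.matrix("x")
x_sparse = sparse.csr_matrix("s", dtype="float64")

# Converting to the representation a variable already has is now a no-op:
# the input is returned unchanged instead of building a redundant node.
assert sparse.dense_from_sparse(x_dense) is x_dense
assert sparse.csr_from_dense(x_sparse) is x_sparse

# Genuine conversions still create the usual Apply nodes.
y = sparse.csr_from_dense(x_dense)      # dense -> csr sparse
z = sparse.dense_from_sparse(x_sparse)  # csr sparse -> dense
```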
72 changes: 40 additions & 32 deletions aesara/sparse/type.py
@@ -10,7 +10,7 @@


import aesara
from aesara.graph.type import Type
from aesara.tensor.type import TensorType


def _is_sparse(x):
@@ -32,7 +32,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix)


class SparseType(Type):
class SparseType(TensorType):
"""
Fundamental way to create a sparse node.

@@ -60,48 +60,52 @@ class SparseType(Type):
"csc": scipy.sparse.csc_matrix,
"bsr": scipy.sparse.bsr_matrix,
}
dtype_set = {
"int8",
"int16",
"int32",
"int64",
"float32",
"uint8",
"uint16",
"uint32",
"uint64",
"float64",
"complex64",
"complex128",
dtype_specs_map = {
"float32": (float, "npy_float32", "NPY_FLOAT32"),
"float64": (float, "npy_float64", "NPY_FLOAT64"),
"uint8": (int, "npy_uint8", "NPY_UINT8"),
"int8": (int, "npy_int8", "NPY_INT8"),
"uint16": (int, "npy_uint16", "NPY_UINT16"),
"int16": (int, "npy_int16", "NPY_INT16"),
"uint32": (int, "npy_uint32", "NPY_UINT32"),
"int32": (int, "npy_int32", "NPY_INT32"),
"uint64": (int, "npy_uint64", "NPY_UINT64"),
"int64": (int, "npy_int64", "NPY_INT64"),
"complex128": (complex, "aesara_complex128", "NPY_COMPLEX128"),
"complex64": (complex, "aesara_complex64", "NPY_COMPLEX64"),
}
ndim = 2

# Will be set to SparseVariable and SparseConstant later.
Variable = None
Constant = None

def __init__(self, format, dtype):
def __init__(self, format, dtype, broadcastable=None, name=None):
if not imported_scipy:
raise Exception(
"You can't make SparseType object as SciPy" " is not available."
)
dtype = str(dtype)
if dtype in self.dtype_set:
self.dtype = dtype
else:
raise NotImplementedError(
f'unsupported dtype "{dtype}" not in list', list(self.dtype_set)
"You can't make SparseType object when SciPy is not available."
Contributor: Is it even a supported use case to NOT have scipy?

Member (author): That's how this project was originally designed, I believe, but there's no point in supporting that optionality now.

)

assert isinstance(format, str)
if not isinstance(format, str):
raise TypeError("The sparse format parameter must be a string")

if format in self.format_cls:
self.format = format
else:
raise NotImplementedError(
f'unsupported format "{format}" not in list',
list(self.format_cls.keys()),
)

if broadcastable is None:
broadcastable = [False, False]

super().__init__(dtype, broadcastable, name=name)

def clone(self, dtype=None, broadcastable=None):
new = super().clone(dtype=dtype, broadcastable=broadcastable)
new.format = self.format
return new

def filter(self, value, strict=False, allow_downcast=None):
if (
isinstance(value, self.format_cls[self.format])
@@ -152,14 +156,10 @@ def make_variable(self, name=None):
return self.Variable(self, name=name)

def __eq__(self, other):
return (
type(self) == type(other)
and other.dtype == self.dtype
and other.format == self.format
)
return super().__eq__(other) and other.format == self.format

def __hash__(self):
return hash(self.dtype) ^ hash(self.format)
return super().__hash__() ^ hash(self.format)

def __str__(self):
return f"Sparse[{self.dtype}, {self.format}]"
@@ -208,6 +208,14 @@ def get_size(self, shape_info):
+ (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
)

def value_zeros(self, shape):
matrix_constructor = getattr(scipy.sparse, f"{self.format}_matrix", None)

if matrix_constructor is None:
raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")

return matrix_constructor(shape, dtype=self.dtype)


# Register SparseType's C code for ViewOp.
aesara.compile.register_view_op_c_code(
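And a small sketch of what the reworked `SparseType` provides now that it subclasses `TensorType` (SciPy assumed to be installed; the exact attribute values follow from the constructor above):

```python
from aesara.sparse.type import SparseType
from aesara.tensor.type import TensorType

st = SparseType("csr", "float64")

# SparseType is now a TensorType with a fixed rank of 2 and
# non-broadcastable dimensions by default.
assert isinstance(st, TensorType)
assert st.broadcastable == (False, False)

# Equality and hashing combine the TensorType behaviour with the format.
assert st == SparseType("csr", "float64")
assert st != SparseType("csc", "float64")

# `value_zeros` builds an empty SciPy matrix in the type's own format.
z = st.value_zeros((3, 4))
print(type(z).__name__, z.shape, z.dtype)  # csr_matrix (3, 4) float64
```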