Add FixedShapeTensorTypes and refactor basic shape logic to use them
brandonwillard committed Dec 31, 2021
1 parent 9ce3bd6 commit fe03603
Showing 10 changed files with 388 additions and 162 deletions.
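The new type's constructor, as used throughout this diff, is `FixedShapeTensorType(dtype, shape)`: a `TensorType` whose static shape is fully known at graph-construction time. A minimal usage sketch (an illustration only; it assumes nothing beyond the constructor calls visible in the hunks below and the usual Aesara convention that calling a `Type` instance creates a `Variable`):

    from aesara.tensor.type import FixedShapeTensorType

    # A float64 matrix whose static shape is known to be (2, 3).
    x = FixedShapeTensorType("float64", (2, 3))("x")

    # Because the shape lives on the type, `shape(x)` can be answered with a
    # constant instead of a `Shape` node (see the new `shape` function in
    # aesara/tensor/shape.py below).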
16 changes: 12 additions & 4 deletions aesara/compile/ops.py
@@ -15,7 +15,7 @@
from aesara.graph.type import CType


-def register_view_op_c_code(type, code, version=()):
+def register_view_op_c_code(types, code, version=()):
    """
    Tell ViewOp how to generate C code for an Aesara Type.
@@ -30,7 +30,11 @@ def register_view_op_c_code(type, code, version=()):
    A number indicating the version of the code, for cache.

    """
-    ViewOp.c_code_and_version[type] = (code, version)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        ViewOp.c_code_and_version[typ] = (code, version)


class ViewOp(COp):
@@ -127,7 +131,7 @@ class OutputGuard(ViewOp):
_output_guard = OutputGuard()


-def register_deep_copy_op_c_code(typ, code, version=()):
+def register_deep_copy_op_c_code(types, code, version=()):
    """
    Tell DeepCopyOp how to generate C code for an Aesara Type.
@@ -142,7 +146,11 @@ def register_deep_copy_op_c_code(typ, code, version=()):
    A number indicating the version of the code, for cache.

    """
-    DeepCopyOp.c_code_and_version[typ] = (code, version)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        DeepCopyOp.c_code_and_version[typ] = (code, version)


class DeepCopyOp(COp):
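Both registration helpers in this file now accept either a single type or a list of types. A sketch of the two equivalent call styles (with `view_op_c_code` as a hypothetical stand-in for a real C template string):

    from aesara.compile.ops import register_view_op_c_code
    from aesara.tensor.type import FixedShapeTensorType, TensorType

    # Hypothetical template; real registrations pass actual C code.
    view_op_c_code = "%(oname)s = %(iname)s; Py_XINCREF(%(oname)s);"

    # The old single-type form still works...
    register_view_op_c_code(TensorType, view_op_c_code, version=(1,))

    # ...and a list registers the same code for several types in one call.
    register_view_op_c_code([FixedShapeTensorType, TensorType], view_op_c_code, version=(1,))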
31 changes: 25 additions & 6 deletions aesara/tensor/basic.py
@@ -51,6 +51,7 @@
    shape_tuple,
)
from aesara.tensor.type import (
+    FixedShapeTensorType,
    TensorType,
    discrete_dtypes,
    float_dtypes,
@@ -782,7 +783,7 @@ def c_code_cache_version(self):
        return tuple(version)


-def register_rebroadcast_c_code(typ, code, version=()):
+def register_rebroadcast_c_code(types, code, version=()):
    """
    Tell Rebroadcast how to generate C code for an Aesara Type.
@@ -797,11 +798,15 @@ def register_rebroadcast_c_code(typ, code, version=()):
    A number indicating the version of the code, for cache.

    """
-    Rebroadcast.c_code_and_version[typ] = (code, version)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        Rebroadcast.c_code_and_version[typ] = (code, version)


register_rebroadcast_c_code(
-    TensorType,
+    [FixedShapeTensorType, TensorType],
    """
    if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
        PyErr_Format(PyExc_ValueError,
@@ -1631,11 +1636,11 @@ def make_node(self, *inputs):

        if inputs:
            dtype = inputs[0].type.dtype
+            otype = FixedShapeTensorType(dtype, (len(inputs),))
        else:
            dtype = self.dtype
-        # bcastable = (len(inputs) == 1)
-        bcastable = False
-        otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
+            otype = FixedShapeTensorType(dtype, 0)

        return Apply(self, inputs, [otype()])

    def perform(self, node, inputs, out_):
@@ -2122,6 +2127,13 @@ def addbroadcast(x, *axes):
    A aesara tensor, which is broadcastable along the specified dimensions.

    """
    x = as_tensor_variable(x)
+
+    if isinstance(x.type, FixedShapeTensorType):
+        if not set(i for i, b in enumerate(x.broadcastable) if b).issuperset(axes):
+            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
+        return x
+
    rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
@@ -2152,6 +2164,13 @@ def unbroadcast(x, *axes):
    A aesara tensor, which is unbroadcastable along the specified dimensions.

    """
    x = as_tensor_variable(x)
+
+    if isinstance(x.type, FixedShapeTensorType):
+        if not set(i for i, b in enumerate(x.broadcastable) if not b).issuperset(axes):
+            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
+        return x
+
    rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
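With a `FixedShapeTensorType` input, `addbroadcast` and `unbroadcast` can be decided statically: a fixed shape already determines the broadcast pattern, so the call either returns `x` unchanged or raises. A sketch (assuming, per the `SpecifyShape.make_node` change below, that `specify_shape` with literal shape values produces a fixed-shape output, and that a length-1 dimension is marked broadcastable):

    import aesara.tensor as aet
    from aesara.tensor.basic import addbroadcast
    from aesara.tensor.shape import specify_shape

    x = specify_shape(aet.matrix("x"), (1, 3))  # x.type is a FixedShapeTensorType

    addbroadcast(x, 0)  # dim 0 has length 1: a no-op that returns x itself
    addbroadcast(x, 1)  # dim 1 is not broadcastable: raises ValueError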
3 changes: 2 additions & 1 deletion aesara/tensor/random/utils.py
@@ -11,7 +11,7 @@
from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
-from aesara.tensor.type import int_dtypes
+from aesara.tensor.type import FixedShapeTensorType, int_dtypes


def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
@@ -130,6 +130,7 @@ def normalize_size_param(size):
        # `Scan` performs)
        size = specify_shape(size, (get_vector_length(size),))

+    assert isinstance(size.type, FixedShapeTensorType)
    assert size.dtype in int_dtypes

    return size
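The new assertion records an invariant: after normalization, `size` is always an integer vector with a statically known length. A quick illustration (a sketch; that `normalize_size_param` accepts a plain tuple is an assumption here):

    from aesara.tensor.random.utils import normalize_size_param
    from aesara.tensor.type import FixedShapeTensorType, int_dtypes

    size = normalize_size_param((2, 3))
    # Exactly the two properties asserted in the diff above:
    assert isinstance(size.type, FixedShapeTensorType)
    assert size.dtype in int_dtypes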
120 changes: 88 additions & 32 deletions aesara/tensor/shape.py
@@ -1,23 +1,25 @@
import warnings
-from typing import Dict
+from numbers import Number
+from typing import Dict, List, Tuple, Union

import numpy as np

import aesara
from aesara.gradient import DisconnectedType
-from aesara.graph.basic import Apply, Variable
+from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32
from aesara.tensor import _get_vector_length
from aesara.tensor import basic as aet
+from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
-from aesara.tensor.type import TensorType, int_dtypes, tensor
+from aesara.tensor.type import FixedShapeTensorType, TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant


-def register_shape_c_code(type, code, version=()):
+def register_shape_c_code(types, code, version=()):
    """
    Tell Shape Op how to generate C code for an Aesara Type.
@@ -33,7 +35,11 @@ def register_shape_c_code(type, code, version=()):
    A number indicating the version of the code, for cache.

    """
-    Shape.c_code_and_version[type] = (code, version)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        Shape.c_code_and_version[typ] = (code, version)


class Shape(COp):
@@ -61,7 +67,8 @@ def make_node(self, x):
        # This will fail at execution time.
        if not isinstance(x, Variable):
            x = aet.as_tensor_variable(x)
-        return Apply(self, [x], [aesara.tensor.type.lvector()])
+        otype = FixedShapeTensorType("int64", (x.ndim,))
+        return Apply(self, [x], [otype()])

    def perform(self, node, inp, out_):
        (x,) = inp
@@ -126,16 +133,29 @@ def c_code_cache_version(self):
        return tuple(version)


-shape = Shape()
-_shape = shape  # was used in the past, now use shape directly.
+_shape = Shape()
+
+
+def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
+    """Return the shape of `x`."""
+    x = aet.as_tensor_variable(x)
+    x_type = x.type
+
+    if isinstance(x_type, FixedShapeTensorType):
+        res = aet.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
+    else:
+        res = _shape(x)
+
+    assert isinstance(res.type, FixedShapeTensorType)
+    return res


@_get_vector_length.register(Shape)
def _get_vector_length_Shape(op, var):
    return var.owner.inputs[0].type.ndim
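The rewritten `shape` helper constant-folds fixed shapes: for a `FixedShapeTensorType` input it returns a constant vector and only falls back to a `Shape` node otherwise. For example (a sketch built on the same `specify_shape` route used in the example above):

    import aesara.tensor as aet
    from aesara.tensor.shape import shape, specify_shape

    x = specify_shape(aet.vector("x"), (5,))
    s = shape(x)   # constant-folded to [5]; no `Shape` node in the graph

    y = aet.vector("y")
    t = shape(y)   # falls back to `_shape(y)`, i.e. a `Shape` Apply node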


-def shape_tuple(x):
+def shape_tuple(x: Variable) -> Tuple[Variable]:
    """Get a tuple of symbolic shape values.

    This will return a `ScalarConstant` with the value ``1`` wherever
@@ -319,7 +339,7 @@ def shape_i_op(i):
shape_i_op.cache = {}


-def register_shape_i_c_code(typ, code, check_input, version=()):
+def register_shape_i_c_code(types, code, check_input, version=()):
    """
    Tell Shape_i how to generate C code for an Aesara Type.
@@ -335,10 +355,14 @@ def register_shape_i_c_code(typ, code, check_input, version=()):
    A number indicating the version of the code, for cache.

    """
-    Shape_i.c_code_and_version[typ] = (code, check_input, version)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        Shape_i.c_code_and_version[typ] = (code, check_input, version)


-def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=None):
+def register_specify_shape_c_code(types, code, version=(), c_support_code_apply=None):
    """
    Tell SpecifyShape how to generate C code for an Aesara Type.
@@ -357,7 +381,11 @@ def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=None):
    Extra code.

    """
-    SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply)
+    if not isinstance(types, list):
+        types = [types]
+
+    for typ in types:
+        SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply)


class SpecifyShape(COp):
@@ -388,22 +416,29 @@ class SpecifyShape(COp):
    _f16_ok = True

    def make_node(self, x, shape):
-        if not isinstance(x, Variable):
-            x = aet.as_tensor_variable(x)
-        if shape == () or shape == []:
-            tshape = aet.constant([], dtype="int64")
-        else:
-            tshape = aet.as_tensor_variable(shape, ndim=1)
-            if tshape.dtype not in aesara.tensor.type.integer_dtypes:
-                raise AssertionError(
-                    f"The `shape` must be an integer type. Got {tshape.dtype} instead."
-                )
-        if isinstance(tshape, TensorConstant) and tshape.data.size != x.ndim:
-            ndim = len(tshape.data)
-            raise AssertionError(
-                f"Input `x` is {x.ndim}-dimensional and will never match a {ndim}-dimensional shape."
-            )
-        return Apply(self, [x, tshape], [x.type()])
+        x = aet.as_tensor_variable(x)
+        shape = aet.as_tensor_variable(shape, ndim=1)
+
+        if isinstance(shape, Constant):
+            shape = tuple(shape.data)
+        else:
+            shape = tuple(aet.as_tensor_variable(s, ndim=0) for s in shape)
+
+        if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape):
+            raise TypeError("Shape values must be integer types")
+
+        if len(shape) != x.ndim:
+            raise ValueError(
+                f"Input `x` is {x.ndim}-dimensional and will never match a shape of length {len(shape)}."
+            )
+
+        if all(isinstance(s, Number) for s in shape):
+            out_var = FixedShapeTensorType(x.type.dtype, shape)()
+        else:
+            out_var = x.type()
+
+        in_shape = aet.as_tensor_variable(shape, ndim=1)
+        return Apply(self, [x, in_shape], [out_var])

    def perform(self, node, inp, out_):
        x, shape = inp
@@ -495,7 +530,28 @@ def c_code_cache_version(self):
        return tuple(version)


-specify_shape = SpecifyShape()
+_specify_shape = SpecifyShape()
+
+
+def specify_shape(
+    x: Union[np.ndarray, Number, Variable],
+    shape: Union[
+        int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable
+    ],
+):
+    """Specify a fixed shape for a `Variable`."""
+
+    x = aet.as_tensor_variable(x)
+
+    try:
+        _ = get_vector_length(shape)
+    except ValueError:
+        raise ValueError("Shape must have fixed dimensions")
+
+    if isinstance(shape, Constant):
+        shape = tuple(shape.data)
+
+    return _specify_shape(x, shape)
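The public `specify_shape` is now a wrapper that validates the shape argument before dispatching to the `SpecifyShape` op; when every shape entry is a literal number, `make_node` gives the output a `FixedShapeTensorType`. A usage sketch:

    import aesara.tensor as aet
    from aesara.tensor.shape import specify_shape

    x = aet.matrix("x")

    y = specify_shape(x, (2, 3))
    # All shape entries are literal numbers, so per `make_node` above
    # y.type is FixedShapeTensorType("float64", (2, 3)).

    z = specify_shape(x, (2, aet.iscalar("n")))
    # A symbolic entry disables the fixed-shape shortcut: z.type is x.type.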


@_get_vector_length.register(SpecifyShape)
@@ -708,7 +764,7 @@ def reshape(x, newshape, ndim=None):
            f" scalar. Got {newshape} after conversion to a vector."
        )
    try:
-        ndim = aet.get_vector_length(newshape)
+        ndim = get_vector_length(newshape)
    except ValueError:
        raise ValueError(
            f"The length of the provided shape ({newshape}) cannot "
@@ -791,7 +847,7 @@ def shape_padaxis(t, axis):


register_shape_c_code(
-    TensorType,
+    [FixedShapeTensorType, TensorType],
    """
    npy_intp shape[] = {PyArray_NDIM(%(iname)s)};
    if(%(oname)s == NULL || (PyArray_DIMS(%(oname)s)[0] != shape[0]))
@@ -809,7 +865,7 @@


register_shape_i_c_code(
-    TensorType,
+    [FixedShapeTensorType, TensorType],
    """
    if(!%(oname)s)
        %(oname)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
@@ -827,7 +883,7 @@


register_specify_shape_c_code(
-    TensorType,
+    [FixedShapeTensorType, TensorType],
    """
    if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
        PyErr_Format(PyExc_AssertionError,
(Diffs for the remaining six changed files were not loaded.)