Clean up Join.make_node

brandonwillard committed Dec 31, 2021
1 parent e98a498 commit de52629
Showing 1 changed file with 26 additions and 53 deletions.
aesara/tensor/basic.py (26 additions, 53 deletions)
@@ -2240,65 +2240,48 @@ def __setstate__(self, d):
         if not hasattr(self, "view"):
             self.view = -1
 
-    def make_node(self, *axis_and_tensors):
+    def make_node(self, axis, *tensors):
         """
         Parameters
         ----------
-        axis: an Int or integer-valued Variable
+        axis
+            The axis upon which to join `tensors`.
         tensors
-            A variable number (but not zero) of tensors to
-            concatenate along the specified axis. These tensors must have
-            the same shape along all dimensions other than this axis.
-
-        Returns
-        -------
-        A symbolic Variable
-            It has the same ndim as the input tensors, and the most inclusive
-            dtype.
+            A variable number of tensors to join along the specified axis.
+            These tensors must have the same shape along all dimensions other
+            than `axis`.
 
         """
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        if not tens:
+        if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
-        as_tensor_variable_args = [as_tensor_variable(x) for x in tens]
 
-        dtypes = [x.type.dtype for x in as_tensor_variable_args]
-        out_dtype = aes.upcast(*dtypes)
-
-        def output_maker(bcastable):
-            return tensor(dtype=out_dtype, broadcastable=bcastable)
-
-        return self._make_node_internal(
-            axis, tens, as_tensor_variable_args, output_maker
-        )
+        tensors = [as_tensor_variable(x) for x in tensors]
+        out_dtype = aes.upcast(*[x.type.dtype for x in tensors])
 
-    def _make_node_internal(self, axis, tens, as_tensor_variable_args, output_maker):
-        if not builtins.all(targs.type.ndim for targs in as_tensor_variable_args):
+        if not builtins.all(targs.type.ndim for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " For joining scalar values, see @stack"
+                " Use `stack` to join scalar values."
             )
         # Handle single-tensor joins immediately.
-        if len(as_tensor_variable_args) == 1:
-            bcastable = list(as_tensor_variable_args[0].type.broadcastable)
+        if len(tensors) == 1:
+            bcastable = list(tensors[0].type.broadcastable)
         else:
             # When the axis is fixed, a dimension should be
             # broadcastable if at least one of the inputs is
             # broadcastable on that dimension (see justification below),
             # except for the axis dimension.
             # Initialize bcastable all false, and then fill in some trues with
             # the loops.
-            bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
+            bcastable = [False] * len(tensors[0].type.broadcastable)
             ndim = len(bcastable)
-            # Axis can also be a constant
 
             if not isinstance(axis, int):
                 try:
-                    # Note : `get_scalar_constant_value` returns a ndarray not
-                    # an int
                     axis = int(get_scalar_constant_value(axis))
-
                 except NotScalarConstantError:
                     pass
 
             if isinstance(axis, int):
                 # Basically, broadcastable -> length 1, but the
                 # converse does not hold. So we permit e.g. T/F/T
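
A minimal sketch of how the refactored entry point is exercised from the user-facing API (assuming a standard Aesara install; the variables `x`, `y`, `z` are hypothetical):

    import aesara.tensor as at

    x = at.matrix("x")  # float64 under the default config
    y = at.matrix("y")

    # `at.join` ends up calling Join().make_node(axis, *tensors), so the
    # axis is an explicit leading argument rather than being unpacked
    # from a single `*axis_and_tensors` tuple.
    z = at.join(0, x, y)

    # The output dtype is the upcast of the input dtypes (the
    # `aes.upcast` call above); e.g. float32 joined with float64
    # yields float64.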
@@ -2310,12 +2293,12 @@ def _make_node_internal(self, axis, tens, as_tensor_variable_args, output_maker)
 
                 if axis < -ndim:
                     raise IndexError(
-                        f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
+                        f"Axis value {axis} is out of range for the given input dimensions"
                     )
                 if axis < 0:
                     axis += ndim
 
-                for x in as_tensor_variable_args:
+                for x in tensors:
                     for current_axis, bflag in enumerate(x.type.broadcastable):
                         # Constant negative axis can no longer be negative at
                         # this point. It safe to compare this way.
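
The bounds check and wrap-around in this hunk follow the usual negative-axis convention. A standalone sketch of the same normalization (plain Python, not part of the commit; `normalize_join_axis` is a hypothetical helper):

    def normalize_join_axis(axis: int, ndim: int) -> int:
        # Mirror of the hunk above: axes below -ndim are rejected, and
        # negative axes count backward from the last dimension.
        if axis < -ndim:
            raise IndexError(
                f"Axis value {axis} is out of range for the given input dimensions"
            )
        if axis < 0:
            axis += ndim
        return axis

    assert normalize_join_axis(-1, ndim=3) == 2
    assert normalize_join_axis(1, ndim=3) == 1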
@@ -2327,34 +2310,24 @@ def _make_node_internal(self, axis, tens, as_tensor_variable_args, output_maker)
                     bcastable[axis] = False
                 except IndexError:
                     raise ValueError(
-                        'Join argument "axis" is out of range'
-                        " (given input dimensions)"
+                        f"Axis value {axis} is out of range for the given input dimensions"
                     )
         else:
             # When the axis may vary, no dimension can be guaranteed to be
             # broadcastable.
-            bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
+            bcastable = [False] * len(tensors[0].type.broadcastable)
 
-        if not builtins.all(
-            [x.ndim == len(bcastable) for x in as_tensor_variable_args[1:]]
-        ):
+        if not builtins.all([x.ndim == len(bcastable) for x in tensors]):
             raise TypeError(
-                "Join() can only join tensors with the same " "number of dimensions."
+                "Only tensors with the same number of dimensions can be joined"
             )
 
-        inputs = [as_tensor_variable(axis)] + list(as_tensor_variable_args)
-        if inputs[0].type not in int_types:
-            raise TypeError(
-                "Axis could not be cast to an integer type",
-                axis,
-                inputs[0].type,
-                int_types,
-            )
+        inputs = [as_tensor_variable(axis)] + list(tensors)
 
-        outputs = [output_maker(bcastable)]
+        if inputs[0].type.dtype not in int_dtypes:
+            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
-        node = Apply(self, inputs, outputs)
-        return node
+        return Apply(self, inputs, [tensor(dtype=out_dtype, broadcastable=bcastable)])
 
     def perform(self, node, axis_and_tensors, out_):
         (out,) = out_
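
The broadcastable pattern computed across these hunks reduces to a simple rule: off the join axis, a dimension is broadcastable when at least one input is broadcastable there (the inputs must agree in shape off-axis, so if one has length 1, all do), and the join axis itself never is. A minimal sketch of that rule (`join_bcastable` is a hypothetical helper, not part of the commit):

    from typing import List, Tuple

    def join_bcastable(axis: int, patterns: List[Tuple[bool, ...]]) -> Tuple[bool, ...]:
        # A non-axis dimension is broadcastable if any input is
        # broadcastable there; the joined axis never is.
        out = [any(p[d] for p in patterns) for d in range(len(patterns[0]))]
        out[axis] = False
        return tuple(out)

    # Two inputs of shape (1, n), i.e. pattern (True, False), joined on
    # axis 0: the join axis is forced non-broadcastable.
    assert join_bcastable(0, [(True, False), (True, False)]) == (False, False)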
