Skip to content

Commit

Permalink
Added aesara.tensor.full_like equivalent to np.full_like (#567)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo Vieira authored Oct 23, 2021
1 parent 4125390 commit b0ba476
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
19 changes: 19 additions & 0 deletions aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,6 +1613,24 @@ def full(shape, fill_value, dtype=None):
return alloc(fill_value, *shape)


def full_like(
a: TensorVariable,
fill_value: Union[TensorVariable, int, float],
dtype: Union[str, np.generic, np.dtype] = None,
) -> TensorVariable:
"""Equivalent of `numpy.full_like`.
Returns
-------
tensor
tensor the shape of `a` containing `fill_value` of the type of dtype.
"""
fill_value = as_tensor_variable(fill_value)
if dtype is not None:
fill_value = fill_value.astype(dtype)
return fill(a, fill_value)


class MakeVector(COp):
"""Concatenate a number of scalars together into a vector.
Expand Down Expand Up @@ -4482,6 +4500,7 @@ def take_along_axis(arr, indices, axis=0):
"as_tensor",
"extract_diag",
"full",
"full_like",
"empty",
"empty_like",
]
18 changes: 18 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
fill,
flatnonzero,
flatten,
full_like,
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
Expand Down Expand Up @@ -4216,3 +4217,20 @@ def test_ndim_dtype_failures(self):
indices = aet.tensor(np.float64, [False] * 2)
with pytest.raises(IndexError):
aet.take_along_axis(arr, indices)


@pytest.mark.parametrize(
"inp, shape",
[(scalar, ()), (vector, 3), (matrix, (3, 4))],
)
def test_full_like(inp, shape):
fill_value = 5
dtype = config.floatX

x = inp("x")
y = full_like(x, fill_value, dtype=dtype)

np.testing.assert_array_equal(
y.eval({x: np.zeros(shape, dtype=dtype)}),
np.full(shape, fill_value, dtype=dtype),
)

0 comments on commit b0ba476

Please sign in to comment.