Skip to content

Commit

Permalink
Add dtype option to identity_like
Browse files Browse the repository at this point in the history
Resolves #816
  • Loading branch information
Tommy Guy authored and Ricardo Vieira committed Jun 30, 2022
1 parent 7393b74 commit d09e222
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
18 changes: 16 additions & 2 deletions aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,8 +1422,22 @@ def eye(n, m=None, k=0, dtype=None):
return localop(n, m, k)


def identity_like(x):
return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
"""Create a tensor with ones on main diagonal and zeroes elsewhere.
Parameters
----------
x : tensor
dtype : data-type, optional
Returns
-------
tensor
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
"""
if dtype is None:
dtype = x.dtype
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype)


def infer_broadcastable(shape):
Expand Down
5 changes: 3 additions & 2 deletions doc/library/tensor/basic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -777,12 +777,13 @@ Creating Tensors
:returns: An array where all elements are equal to zero, except for the `k`-th
diagonal, whose values are equal to one.

.. function:: identity_like(x)
.. function:: identity_like(x, dtype=None)

:param x: tensor
:param dtype: The dtype of the returned tensor. If `None`, default to dtype of `x`
:returns: A tensor of same shape as `x` that is filled with zeros everywhere
except for the main diagonal, whose values are equal to one. The output
will have same dtype as `x`.
will have same dtype as `x` unless overridden in `dtype`.

.. function:: stack(tensors, axis=0)

Expand Down
10 changes: 10 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
identity_like,
infer_broadcastable,
inverse_permutation,
join,
Expand Down Expand Up @@ -4392,6 +4393,15 @@ def test_empty():
assert res.dtype == "int64"


def test_identity_like_dtype():
# Test that we allocate eye correctly via identity_like
m = matrix(dtype="int64")
m_out = identity_like(m)
assert m_out.dtype == m.dtype
m_out_float = identity_like(m, dtype=np.float64)
assert m_out_float.dtype == "float64"


def test_atleast_Nd():
ary1 = dscalar()
res_ary1 = atleast_Nd(ary1, n=1)
Expand Down

0 comments on commit d09e222

Please sign in to comment.