Skip to content

Commit

Permalink
Support float32 in random number generation (#693)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Feb 11, 2025
1 parent afd519a commit 745f564
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
10 changes: 5 additions & 5 deletions cubed/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@
from cubed.vendor.dask.array.core import normalize_chunks


def random(size, *, chunks=None, spec=None):
def random(size, *, dtype=nxp.float64, chunks=None, spec=None):
"""Return random floats in the half-open interval [0.0, 1.0)."""
shape = normalize_shape(size)
dtype = nxp.float64
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
numblocks = tuple(map(len, chunks))
root_seed = pyrandom.getrandbits(128)

extra_func_kwargs = dict(dtype=dtype)
return map_blocks(
_random,
dtype=dtype,
chunks=chunks,
spec=spec,
numblocks=numblocks,
root_seed=root_seed,
extra_func_kwargs=extra_func_kwargs,
)


def _random(x, numblocks=None, root_seed=None, block_id=None):
def _random(x, numblocks=None, root_seed=None, dtype=nxp.float64, block_id=None):
stream_id = block_id_to_offset(block_id, numblocks)
rg = Generator(Philox(key=root_seed + stream_id))
out = rg.random(x.shape)
out = rg.random(x.shape, dtype=dtype)
out = numpy_array_to_backend_array(out)
return out
14 changes: 14 additions & 0 deletions cubed/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,22 @@ def test_random(spec, executor):

assert a.shape == (10, 10)
assert a.chunks == ((4, 4, 2), (5, 5))
assert a.dtype == xp.float64

x = nxp.unique_values(a.compute(executor=executor))
assert x.dtype == xp.float64
assert len(x) > 90


def test_random_dtype(spec, executor):
a = cubed.random.random((10, 10), dtype=xp.float32, chunks=(4, 5), spec=spec)

assert a.shape == (10, 10)
assert a.chunks == ((4, 4, 2), (5, 5))
assert a.dtype == xp.float32

x = nxp.unique_values(a.compute(executor=executor))
assert x.dtype == xp.float32
assert len(x) > 90


Expand Down

0 comments on commit 745f564

Please sign in to comment.