diff --git a/cubed/random.py b/cubed/random.py index 6c60a6c9..4016f94e 100644 --- a/cubed/random.py +++ b/cubed/random.py @@ -9,14 +9,13 @@ from cubed.vendor.dask.array.core import normalize_chunks -def random(size, *, chunks=None, spec=None): +def random(size, *, dtype=nxp.float64, chunks=None, spec=None): """Return random floats in the half-open interval [0.0, 1.0).""" shape = normalize_shape(size) - dtype = nxp.float64 chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) numblocks = tuple(map(len, chunks)) root_seed = pyrandom.getrandbits(128) - + extra_func_kwargs = dict(dtype=dtype) return map_blocks( _random, dtype=dtype, @@ -24,12 +23,13 @@ def random(size, *, chunks=None, spec=None): spec=spec, numblocks=numblocks, root_seed=root_seed, + extra_func_kwargs=extra_func_kwargs, ) -def _random(x, numblocks=None, root_seed=None, block_id=None): +def _random(x, numblocks=None, root_seed=None, dtype=nxp.float64, block_id=None): stream_id = block_id_to_offset(block_id, numblocks) rg = Generator(Philox(key=root_seed + stream_id)) - out = rg.random(x.shape) + out = rg.random(x.shape, dtype=dtype) out = numpy_array_to_backend_array(out) return out diff --git a/cubed/tests/test_random.py b/cubed/tests/test_random.py index 5144538c..16ce795b 100644 --- a/cubed/tests/test_random.py +++ b/cubed/tests/test_random.py @@ -29,8 +29,22 @@ def test_random(spec, executor): assert a.shape == (10, 10) assert a.chunks == ((4, 4, 2), (5, 5)) + assert a.dtype == xp.float64 x = nxp.unique_values(a.compute(executor=executor)) + assert x.dtype == xp.float64 + assert len(x) > 90 + + +def test_random_dtype(spec, executor): + a = cubed.random.random((10, 10), dtype=xp.float32, chunks=(4, 5), spec=spec) + + assert a.shape == (10, 10) + assert a.chunks == ((4, 4, 2), (5, 5)) + assert a.dtype == xp.float32 + + x = nxp.unique_values(a.compute(executor=executor)) + assert x.dtype == xp.float32 assert len(x) > 90