Skip to content

Commit 6427272

Browse files
authored
minor update to rand [pr] (tinygrad#9566)
1 parent b0e070e commit 6427272

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

test/test_randomness.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,24 @@ def test_threefry_against_reference_full(self):
120120
0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
121121
0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
122122
0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)
123-
124123
r = Tensor.rand(20).numpy()
125-
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
124+
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
126125

127126
# next 20, np.arange(20, 40, dtype=np.uint32)
128127
jr = np.array([0.7444133758544922, 0.7713677883148193, 0.8233780860900879, 0.43871235847473145, 0.517757773399353,
129128
0.6437174081802368, 0.967403769493103, 0.26167726516723633, 0.6825339794158936, 0.14966607093811035,
130129
0.28920769691467285, 0.017063498497009277, 0.2627382278442383, 0.9525482654571533, 0.9351049661636353,
131130
0.43904995918273926, 0.043945908546447754, 0.6616791486740112, 0.6667773723602295, 0.5228077173233032], dtype=np.float32)
132131
r = Tensor.rand(20).numpy()
133-
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
132+
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
134133

135134
# next 10, np.arange(40, 50, dtype=np.uint32)
136135
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
137136
0.36863112449645996, 0.3102607727050781, 0.06608951091766357, 0.35329878330230713, 0.26518797874450684], dtype=np.float32)
138137
r = Tensor.rand(10).numpy()
139138
# TODO: this failed because increment happened before _threefry_random_bits
140139
with self.assertRaises(AssertionError):
141-
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
140+
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
142141

143142
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
144143
def test_threefry_tensors_cnt(self):

tinygrad/tensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,10 @@ def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:boo
501501
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
502502
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
503503
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
504-
_device = device = Device.canonicalize(device)
504+
device = Device.canonicalize(device)
505505

506506
# if shape has 0, return zero tensor
507-
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
507+
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
508508
num = ceildiv(numel * dtype.itemsize, 4)
509509

510510
# generate per device seeds and rng counter if we haven't seen this device yet

0 commit comments

Comments
 (0)