@@ -120,25 +120,24 @@ def test_threefry_against_reference_full(self):
120
120
0.3108327388763428 , 0.09639489650726318 , 0.004686474800109863 , 0.8435229063034058 , 0.824237585067749 ,
121
121
0.5873836278915405 , 0.4232727289199829 , 0.2530076503753662 , 0.40300023555755615 , 0.03966474533081055 ,
122
122
0.27904558181762695 , 0.9150195121765137 , 0.48057758808135986 , 0.23821306228637695 , 0.7676635980606079 ], dtype = np .float32 )
123
-
124
123
r = Tensor .rand (20 ).numpy ()
125
- np .testing .assert_allclose (jr , r , atol = 1e-5 , rtol = 1e-5 )
124
+ np .testing .assert_allclose (r , jr , atol = 1e-5 , rtol = 1e-5 )
126
125
127
126
# next 20, np.arange(20, 40, dtype=np.uint32)
128
127
jr = np .array ([0.7444133758544922 , 0.7713677883148193 , 0.8233780860900879 , 0.43871235847473145 , 0.517757773399353 ,
129
128
0.6437174081802368 , 0.967403769493103 , 0.26167726516723633 , 0.6825339794158936 , 0.14966607093811035 ,
130
129
0.28920769691467285 , 0.017063498497009277 , 0.2627382278442383 , 0.9525482654571533 , 0.9351049661636353 ,
131
130
0.43904995918273926 , 0.043945908546447754 , 0.6616791486740112 , 0.6667773723602295 , 0.5228077173233032 ], dtype = np .float32 )
132
131
r = Tensor .rand (20 ).numpy ()
133
- np .testing .assert_allclose (jr , r , atol = 1e-5 , rtol = 1e-5 )
132
+ np .testing .assert_allclose (r , jr , atol = 1e-5 , rtol = 1e-5 )
134
133
135
134
# next 10, np.arange(40, 50, dtype=np.uint32)
136
135
jr = np .array ([0.9614430665969849 , 0.059279561042785645 , 0.01909029483795166 , 0.47882091999053955 , 0.9677121639251709 ,
137
136
0.36863112449645996 , 0.3102607727050781 , 0.06608951091766357 , 0.35329878330230713 , 0.26518797874450684 ], dtype = np .float32 )
138
137
r = Tensor .rand (10 ).numpy ()
139
138
# TODO: this failed because increment happened before _threefry_random_bits
140
139
with self .assertRaises (AssertionError ):
141
- np .testing .assert_allclose (jr , r , atol = 1e-5 , rtol = 1e-5 )
140
+ np .testing .assert_allclose (r , jr , atol = 1e-5 , rtol = 1e-5 )
142
141
143
142
@unittest .skipIf (CI and Device .DEFAULT in ("GPU" , "CUDA" , "METAL" , "NV" ), "no GPU CI" )
144
143
def test_threefry_tensors_cnt (self ):
0 commit comments