@@ -30,7 +30,9 @@ def run(sz, n_gpus=6, iters=10, use_ring=False):
30
30
with Context (RING = (2 if use_ring else 0 ), DEBUG = max (DEBUG .value , 2 )): return test (devs , N , iters = iters )
31
31
32
32
def main ():
33
+ ONLY_RING = getenv ("ONLY_RING" , 0 )
33
34
n_gpus = getenv ("GPUS" , 6 )
35
+ iters = getenv ("ITERS" , 10 )
34
36
35
37
if getenv ("BENCHMARK_SPLIT" ):
36
38
l , r = 0 , 512
@@ -44,10 +46,10 @@ def main():
44
46
else :
45
47
sz = getenv ("SZ" , 1000 ) * 10 ** 6 # size of data on each gpu
46
48
print (f"Using { sz / 10 ** 9 :.2f} GB of numbers on each of { n_gpus } GPUs, { n_gpus * sz / 10 ** 9 :.2f} GB total." )
47
- (ring_gflops , ring_gbs , ring_secs ) = run (sz , use_ring = True , n_gpus = n_gpus )
48
- (naive_gflops , naive_gbs , naive_secs ) = run (sz , use_ring = False , n_gpus = n_gpus )
49
+ (ring_gflops , ring_gbs , ring_secs ) = run (sz , use_ring = True , n_gpus = n_gpus , iters = iters )
50
+ if not ONLY_RING : (naive_gflops , naive_gbs , naive_secs ) = run (sz , use_ring = False , n_gpus = n_gpus , iters = iters )
49
51
print (f"Ring:\n { ring_secs :.6f} seconds/iter\n { ring_gflops :.2f} GFLOP/s\n { ring_gbs :.2f} GB/s" )
50
- print (f"Naive:\n { naive_secs :.6f} seconds/iter\n { naive_gflops :.2f} GFLOP/s\n { naive_gbs :.2f} GB/s" )
52
+ if not ONLY_RING : print (f"Naive:\n { naive_secs :.6f} seconds/iter\n { naive_gflops :.2f} GFLOP/s\n { naive_gbs :.2f} GB/s" )
51
53
52
54
if __name__ == "__main__" :
53
55
main ()
0 commit comments