 import math
 import pickle as pkl
 import time
-from typing import Callable, Iterable, List, Tuple
+from itertools import product
+from typing import Callable, Iterable, List, Optional, Tuple
 
+import pandas as pd
 import torch
 import torch.utils.benchmark as TBenchmark
 from torch.utils.benchmark import Measurement as TMeasurement
@@ -84,6 +86,10 @@ def loop_over_weights(
         fn(a, w_ref, w_q, w_s)
 
 
+_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
+_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
+
+
 def bench(atype: torch.dtype,
           wtype: ScalarType,
           group_size: int,
@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
           sub_label: str,
           benchmark_marlinv1: bool = True,
           sweep_schedules: bool = True) -> Iterable[TMeasurement]:
+    global _SWEEP_SCHEDULES_RESULTS
+
     a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
     sub_label += f", L={len(weights)}"
 
@@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
         best_schedule = None
         schedules = ops.machete_supported_schedules(wtype)
         for schedule in reversed(schedules):
+            schedule_M = int(schedule.split("_")[0].split("x")[1])
+
+            # Prune known bad schedules
+            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
+                continue
 
             def run(a, _, w_q, w_s, schedule=schedule):
                 ops.machete_gemm(a,
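The pruning above keys off the tile M size encoded in the schedule name itself. A minimal sketch of that parsing and the keep/skip condition, assuming schedule names begin with a tile shape such as "128x16_..." (the naming scheme is inferred from the split() calls in this diff, not stated elsewhere):

    # Hypothetical schedule name; only the "<something>x<tile_m>_" prefix is used here.
    schedule = "128x16_example_schedule"
    schedule_M = int(schedule.split("_")[0].split("x")[1])  # -> 16

    # For a given problem size m, the loop keeps only schedules whose tile M
    # falls in [m // 4, 2 * max(m, 16)); e.g. for m = 64 that window is [16, 128),
    # so tile M values 16, 32 and 64 are benchmarked while 8 and 128 are skipped.
    m = 64
    keep = not (schedule_M >= 2 * max(m, 16) or schedule_M < m // 4)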
@@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule):
             res = bench_fn(label, sub_label, "machete_best",
                            lambda: loop_over_weights(a, weights_machete, run))
 
+            results_row = {
+                "M": m,
+                "K": k,
+                "N": n,
+                "group_size": group_size,
+                "schedule": schedule,
+                "median": res.median,
+            }
+            if _SWEEP_SCHEDULES_RESULTS is None:
+                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
+                    columns=results_row.keys())
+            _SWEEP_SCHEDULES_RESULTS.\
+                loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
+
             print(f"  {res.median:5.5}", schedule)
             if not best or res.median < best.median:
                 best = res
@@ -235,18 +262,22 @@ def run_square_bench(args):
     dim_sizes = list(
         range(args.dim_start, args.dim_end + 1, args.dim_increment))
     MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
+
     data = run(args.dtype, args.sweep_schedules, MKNs)
 
     make_output(data, MKNs, f"square_bench-{args.dtype}")
 
 
 def run_range_bench(args):
-    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
-    n = len(dim_sizes)
-    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
-    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
-    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
-    MKNs = list(zip(Ms, Ks, Ns))
+    m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
+    m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
+    m_increment, k_increment, n_increment = \
+        [int(x) for x in args.dim_increment.split(",")]
+    Ms = list(range(m_start, m_end + 1, m_increment))
+    Ks = list(range(k_start, k_end + 1, k_increment))
+    Ns = list(range(n_start, n_end + 1, n_increment))
+    MKNs = list(product(Ms, Ks, Ns))
+
     data = run(args.dtype, args.sweep_schedules, MKNs)
 
     make_output(data, MKNs, f"range_bench-{args.dtype}")
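This changes the semantics of range_bench: instead of zipping one shared dimension range against optional per-axis constants, each of M, K and N now gets its own range and the sweep covers their full cross product. A small illustration of the new parsing, with made-up values that are not from the diff:

    from itertools import product

    # Hypothetical comma separated arguments, as they would arrive in args.dim_*
    dim_start, dim_end, dim_increment = "128,4096,4096", "256,8192,8192", "128,4096,4096"

    m_start, k_start, n_start = [int(x) for x in dim_start.split(",")]
    m_end, k_end, n_end = [int(x) for x in dim_end.split(",")]
    m_inc, k_inc, n_inc = [int(x) for x in dim_increment.split(",")]

    Ms = list(range(m_start, m_end + 1, m_inc))  # [128, 256]
    Ks = list(range(k_start, k_end + 1, k_inc))  # [4096, 8192]
    Ns = list(range(n_start, n_end + 1, n_inc))  # [4096, 8192]
    MKNs = list(product(Ms, Ks, Ns))             # 2 * 2 * 2 = 8 (M, K, N) shapes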
@@ -333,6 +364,9 @@ def to_torch_dtype(dt):
         action="store_true",
         help="Run a sweep over all supported schedules",
     )
+    parser.add_argument("--sweep-csv-out",
+                        help="CSV to store sweep results",
+                        default="sch_sweep_results.csv")
     subparsers = parser.add_subparsers(dest="cmd", required=True)
 
     square_parser = subparsers.add_parser("square_bench")
@@ -342,12 +376,21 @@ def to_torch_dtype(dt):
     square_parser.set_defaults(func=run_square_bench)
 
     range_parser = subparsers.add_parser("range_bench")
-    range_parser.add_argument("--dim-start", type=int, required=True)
-    range_parser.add_argument("--dim-end", type=int, required=True)
-    range_parser.add_argument("--dim-increment", type=int, required=True)
-    range_parser.add_argument("--m-constant", type=int, default=None)
-    range_parser.add_argument("--n-constant", type=int, default=None)
-    range_parser.add_argument("--k-constant", type=int, default=None)
+    range_parser.add_argument(
+        "--dim-start",
+        type=str,
+        required=True,
+        help="Start value for M,K,N as comma separated list")
+    range_parser.add_argument(
+        "--dim-end",
+        type=str,
+        required=True,
+        help="End value (inclusive) for M,K,N as comma separated list")
+    range_parser.add_argument(
+        "--dim-increment",
+        type=str,
+        required=True,
+        help="Increment value for M,K,N as comma separated list")
     range_parser.set_defaults(func=run_range_bench)
 
     model_parser = subparsers.add_parser("model_bench")
@@ -369,4 +412,9 @@ def to_torch_dtype(dt):
     model_parser.set_defaults(func=run_model_bench)
 
     args = parser.parse_args()
+
+    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
     args.func(args)
+
+    if _SWEEP_SCHEDULES_RESULTS is not None:
+        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
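With these additions, a schedule sweep over a range of shapes is driven entirely by the three comma separated --dim-start/--dim-end/--dim-increment arguments plus the top-level --sweep-schedules and --sweep-csv-out flags; the per-schedule medians accumulated in _SWEEP_SCHEDULES_RESULTS during the run are written to that CSV only after args.func(args) returns, and only if a sweep actually populated the DataFrame.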