@@ -322,3 +322,161 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float16])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 6e-2, 6e-2
+    if torch.version.hip:
+        atol *= 2
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+
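+    # The state pool holds 10x more entries than the batch; state_indices
+    # picks a random slot per sequence, so the kernel must gather/scatter
+    # through state_batch_indices rather than assume contiguous rows.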
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, dstate, device=device)
+    C = torch.randn(batch_size, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
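+    # The kernel updates the selected state slots in place, while the
+    # reference path runs on a dense gathered copy (state_ref).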
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
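+# Same scenario as above, but with multi-head states and grouped B/C.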
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float32])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("tie_hdim", [False, True])
+# @pytest.mark.parametrize('tie_hdim', [True])
+@pytest.mark.parametrize("ngroups", [1, 2, 4])
+# @pytest.mark.parametrize("ngroups", [2])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_heads_with_batch_indices(
+        dim, dstate, ngroups, has_z, tie_hdim, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+    headdim = 64
+    nheads = dim // headdim
+
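+    # Each cache entry now stores per-head states of shape
+    # (nheads, headdim, dstate); the pool is again 10x the batch size.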
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries,
+                        nheads,
+                        headdim,
+                        dstate,
+                        dtype=itype,
+                        device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
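+    # When tie_hdim is set, dt, A and D collapse to per-head scalars that are
+    # broadcast over headdim (and over dstate for A) via einops repeat below.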
+    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    if not tie_hdim:
+        dt = torch.randn(batch_size,
+                         nheads,
+                         headdim,
+                         device=device,
+                         dtype=itype)
+        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
+        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
+        D = torch.randn(nheads, headdim, device=device)
+    else:
+        dt = repeat(torch.randn(batch_size, nheads, device=device,
+                                dtype=itype),
+                    "b h -> b h p",
+                    p=headdim)
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
+                         "h -> h p",
+                         p=headdim)
+        A = repeat(-torch.rand(nheads, device=device) - 1.0,
+                   "h -> h p n",
+                   p=headdim,
+                   n=dstate)
+        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
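+    # B and C are defined per group rather than per head, so heads within a
+    # group share them.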
+    B = torch.randn(batch_size, ngroups, dstate, device=device)
+    C = torch.randn(batch_size, ngroups, dstate, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)