@@ -326,13 +326,9 @@ def test_selective_state_update(dim, dstate, has_z, itype):
326
326
327
327
@pytest .mark .parametrize ("itype" ,
328
328
[torch .float32 , torch .float16 , torch .bfloat16 ])
329
- # @pytest.mark.parametrize('itype', [torch.float16])
330
329
@pytest .mark .parametrize ("has_z" , [False , True ])
331
- # @pytest.mark.parametrize('has_z', [True])
332
330
@pytest .mark .parametrize ("dstate" , [16 , 32 , 64 ])
333
- # @pytest.mark.parametrize("dstate", [16])
334
331
@pytest .mark .parametrize ("dim" , [2048 , 2048 + 16 , 4096 ])
335
- # @pytest.mark.parametrize("dim", [2048])
336
332
def test_selective_state_update_with_batch_indices (dim , dstate , has_z , itype ):
337
333
device = "cuda"
338
334
rtol , atol = (3e-4 , 1e-3 ) if itype == torch .float32 else (5e-3 , 1e-2 )
@@ -391,17 +387,11 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
391
387
392
388
@pytest .mark .parametrize ("itype" ,
393
389
[torch .float32 , torch .float16 , torch .bfloat16 ])
394
- #@pytest.mark.parametrize('itype', [torch.float32])
395
390
@pytest .mark .parametrize ("has_z" , [False , True ])
396
- # @pytest.mark.parametrize('has_z', [True])
397
391
@pytest .mark .parametrize ("tie_hdim" , [False , True ])
398
- # @pytest.mark.parametrize('tie_hdim', [True])
399
392
@pytest .mark .parametrize ("ngroups" , [1 , 2 , 4 ])
400
- # @pytest.mark.parametrize("ngroups", [2])
401
393
@pytest .mark .parametrize ("dstate" , [16 , 32 , 64 ])
402
- # @pytest.mark.parametrize("dstate", [16])
403
394
@pytest .mark .parametrize ("dim" , [2048 , 4096 ])
404
- # @pytest.mark.parametrize("dim", [2048])
405
395
def test_selective_state_update_with_heads_with_batch_indices (
406
396
dim , dstate , ngroups , has_z , tie_hdim , itype ):
407
397
device = "cuda"
0 commit comments