@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
101
101
using SmemLayoutO = decltype(tile_to_shape(
102
102
SmemLayoutAtomO{},
103
103
Shape<Int<kBlockM >, Int<kHeadDim >>{}));
104
- using SmemCopyAtomO = Copy_Atom<DefaultCopy , Element>;
105
- using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy , ElementAccum>;
104
+ using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , Element>;
105
+ using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , ElementAccum>;
106
106
107
107
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof (Element);
108
108
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof (Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
125
125
using Gmem_copy_struct = std::conditional_t <
126
126
Has_cp_async,
127
127
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t >,
128
- DefaultCopy
128
+ AutoVectorizingCopyWithAssumedAlignment< 128 >
129
129
>;
130
130
using GmemTiledCopyQKV = decltype(
131
131
make_tiled_copy (Copy_Atom<Gmem_copy_struct, Element>{},
132
132
GmemLayoutAtom{},
133
133
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
134
134
using GmemTiledCopyO = decltype(
135
- make_tiled_copy (Copy_Atom<DefaultCopy , Element>{},
135
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , Element>{},
136
136
GmemLayoutAtom{},
137
137
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
138
138
@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
144
144
Stride< _16, _1>>
145
145
>;
146
146
using GmemTiledCopyOaccum = decltype(
147
- make_tiled_copy (Copy_Atom<DefaultCopy , ElementAccum>{},
147
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , ElementAccum>{},
148
148
GmemLayoutAtomOaccum{},
149
149
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
150
150
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
153
153
GmemLayoutAtomRotcossin{},
154
154
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
155
155
using GmemTiledCopyRotcossinCont = decltype(
156
- make_tiled_copy (Copy_Atom<DefaultCopy , Element>{},
156
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , Element>{},
157
157
GmemLayoutAtomRotcossin{},
158
158
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
159
159
};
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
250
250
composition (SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN >, Int<kBlockM >>{}, GenRowMajor{})));
251
251
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
252
252
253
- using SmemCopyAtomPdS = Copy_Atom<DefaultCopy , elem_type>;
253
+ using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>;
254
254
255
255
using SmemLayoutQdOtransposed = decltype(
256
256
composition (SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim >, Int<kBlockM >>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
263
263
using SmemLayoutdKV = decltype(tile_to_shape(
264
264
SmemLayoutAtomdKV{},
265
265
make_shape (Int<kBlockN >{}, Int<kHeadDim >{})));
266
- using SmemCopyAtomdKV = Copy_Atom<DefaultCopy , elem_type>;
266
+ using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>;
267
267
268
268
using SmemLayoutAtomdQ = decltype(
269
269
composition (Swizzle<kSwizzle , 3 , 3 >{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
272
272
using SmemLayoutdQ = decltype(tile_to_shape(
273
273
SmemLayoutAtomdQ{},
274
274
make_shape (Int<kBlockM >{}, Int<kHeadDim >{})));
275
- using SmemCopyAtomdQ = Copy_Atom<DefaultCopy , elem_type>;
275
+ using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>;
276
276
277
277
// Double buffer for sQ
278
278
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3 ) * sizeof (Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
303
303
using Gmem_copy_struct = std::conditional_t <
304
304
Has_cp_async,
305
305
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t >,
306
- DefaultCopy
306
+ AutoVectorizingCopyWithAssumedAlignment< 128 >
307
307
>;
308
308
using GmemTiledCopyQKV = decltype(
309
309
make_tiled_copy (Copy_Atom<Gmem_copy_struct, elem_type>{},
310
310
GmemLayoutAtom{},
311
311
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
312
312
using GmemTiledCopydO = decltype(
313
- make_tiled_copy (Copy_Atom<DefaultCopy , elem_type>{},
313
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>{},
314
314
GmemLayoutAtom{},
315
315
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
316
316
using GmemTiledCopydKV = decltype(
317
- make_tiled_copy (Copy_Atom<DefaultCopy , elem_type>{},
317
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>{},
318
318
GmemLayoutAtom{},
319
319
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
320
320
using GmemTiledCopydQ = decltype(
321
- make_tiled_copy (Copy_Atom<DefaultCopy , elem_type>{},
321
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , elem_type>{},
322
322
GmemLayoutAtom{},
323
323
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
324
324
using GmemLayoutAtomdQaccum = std::conditional_t <
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
329
329
Stride< _16, _1>>
330
330
>;
331
331
using GmemTiledCopydQaccum = decltype(
332
- make_tiled_copy (Copy_Atom<DefaultCopy , ElementAccum>{},
332
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , ElementAccum>{},
333
333
GmemLayoutAtomdQaccum{},
334
334
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
335
335
336
336
using GmemTiledCopydQaccumAtomicAdd = decltype(
337
- make_tiled_copy (Copy_Atom<DefaultCopy , ElementAccum>{},
337
+ make_tiled_copy (Copy_Atom<AutoVectorizingCopyWithAssumedAlignment< 128 > , ElementAccum>{},
338
338
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
339
339
Stride<_32, _1>>{},
340
340
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
0 commit comments