Skip to content

Commit 61a23ea

Browse files
committed
Update to Cutlass 3.6.0
1 parent 9375ac9 commit 61a23ea

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

Diff for: csrc/cutlass

Submodule cutlass updated 582 files

Diff for: csrc/flash_attn/src/flash_fwd_kernel.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
12211221
constexpr int kBlockN = kNThreads / kBlockM;
12221222
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
12231223
using GmemTiledCopyOaccum = decltype(
1224-
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
1224+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
12251225
GmemLayoutAtomOaccum{},
12261226
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
12271227
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;

Diff for: csrc/flash_attn/src/kernel_traits.h

+15-15
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
101101
using SmemLayoutO = decltype(tile_to_shape(
102102
SmemLayoutAtomO{},
103103
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
104-
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
105-
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
104+
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
105+
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
106106

107107
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
108108
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
125125
using Gmem_copy_struct = std::conditional_t<
126126
Has_cp_async,
127127
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
128-
DefaultCopy
128+
AutoVectorizingCopyWithAssumedAlignment<128>
129129
>;
130130
using GmemTiledCopyQKV = decltype(
131131
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
132132
GmemLayoutAtom{},
133133
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
134134
using GmemTiledCopyO = decltype(
135-
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
135+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
136136
GmemLayoutAtom{},
137137
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
138138

@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
144144
Stride< _16, _1>>
145145
>;
146146
using GmemTiledCopyOaccum = decltype(
147-
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
147+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
148148
GmemLayoutAtomOaccum{},
149149
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
150150
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
153153
GmemLayoutAtomRotcossin{},
154154
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
155155
using GmemTiledCopyRotcossinCont = decltype(
156-
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
156+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
157157
GmemLayoutAtomRotcossin{},
158158
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
159159
};
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
250250
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
251251
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
252252

253-
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
253+
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
254254

255255
using SmemLayoutQdOtransposed = decltype(
256256
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
263263
using SmemLayoutdKV = decltype(tile_to_shape(
264264
SmemLayoutAtomdKV{},
265265
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
266-
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
266+
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
267267

268268
using SmemLayoutAtomdQ = decltype(
269269
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
272272
using SmemLayoutdQ = decltype(tile_to_shape(
273273
SmemLayoutAtomdQ{},
274274
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
275-
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
275+
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
276276

277277
// Double buffer for sQ
278278
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
303303
using Gmem_copy_struct = std::conditional_t<
304304
Has_cp_async,
305305
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
306-
DefaultCopy
306+
AutoVectorizingCopyWithAssumedAlignment<128>
307307
>;
308308
using GmemTiledCopyQKV = decltype(
309309
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
310310
GmemLayoutAtom{},
311311
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
312312
using GmemTiledCopydO = decltype(
313-
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
313+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
314314
GmemLayoutAtom{},
315315
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
316316
using GmemTiledCopydKV = decltype(
317-
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
317+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
318318
GmemLayoutAtom{},
319319
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
320320
using GmemTiledCopydQ = decltype(
321-
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
321+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
322322
GmemLayoutAtom{},
323323
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
324324
using GmemLayoutAtomdQaccum = std::conditional_t<
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
329329
Stride< _16, _1>>
330330
>;
331331
using GmemTiledCopydQaccum = decltype(
332-
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
332+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
333333
GmemLayoutAtomdQaccum{},
334334
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
335335

336336
using GmemTiledCopydQaccumAtomicAdd = decltype(
337-
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
337+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
338338
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
339339
Stride<_32, _1>>{},
340340
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store

0 commit comments

Comments
 (0)