Skip to content

Commit 7153673

Browse files
authored
Fix swiglu backwards return type (Dao-AILab#1337)
1 parent 641db75 commit 7153673

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

flash_attn/ops/activations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def sqrelu_bwd(g, x):
110110
}
111111
"""
112112
swiglu_bwd_codestring = """
113-
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
113+
template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
114114
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
115115
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
116116
dy = float(x) * x_sigmoid * float(g);

0 commit comments

Comments
 (0)