Skip to content

Commit a14b40b

Browse files
committed
add SiLU activation function
1 parent 4722a99 commit a14b40b

File tree

7 files changed

+160
-0
lines changed

7 files changed

+160
-0
lines changed

dfdx-core/src/tensor_ops/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ mod roll;
197197
mod select_and_gather;
198198
mod sgd;
199199
mod sigmoid;
200+
mod silu;
200201
mod sin;
201202
mod slice;
202203
mod softmax;
@@ -264,6 +265,7 @@ pub use roll::Roll;
264265
pub use select_and_gather::{GatherTo, SelectTo};
265266
pub use sgd::SgdConfig;
266267
pub use sigmoid::sigmoid;
268+
pub use silu::silu;
267269
pub use sin::sin;
268270
pub use slice::slice;
269271
pub use softmax::softmax;
dfdx-core/src/tensor_ops/silu/cpu_kernel.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use crate::tensor_ops::cpu_kernels::UnaryDerivative;
2+
3+
impl<F: num_traits::Float> UnaryDerivative<F> for super::SiLUKernelOp {
    // NOTE(review): presumably `false` means `df` is handed the input `x`
    // rather than the forward output `f(x)` — confirm against `UnaryDerivative`.
    const DF_USES_FX: bool = false;
    // NOTE(review): presumably `false` because the derivative depends on `x` — confirm.
    const HAS_CONST_DF: bool = false;

    /// SiLU forward value: `x * sigmoid(x) = x / (1 + e^-x)`.
    #[inline(always)]
    fn f(&self, x: &F) -> F {
        *x / (F::one() + x.neg().exp())
    }

    /// SiLU derivative: `(1 + e^-x + x * e^-x) / (1 + e^-x)^2`.
    ///
    /// Equivalent form: `(e^x * (1 + e^x + x)) / (1 + e^x)^2`.
    #[inline(always)]
    fn df(&self, x: &F) -> F {
        let exp_nx = x.neg().exp();
        // The numerator must be grouped explicitly: `/` binds tighter than `+`,
        // so without the parentheses only `x * e^-x` would be divided by the
        // denominator and the result would be `1 + e^-x + (x * e^-x) / (...)^2`.
        (F::one() + exp_nx + *x * exp_nx) / (F::one() + exp_nx).powi(2)
    }
}
dfdx-core/src/tensor_ops/silu/cuda_kernel.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use super::SiLUKernelOp;
2+
#[allow(unused_imports)]
3+
use crate::dtypes::*;
4+
use crate::tensor_ops::cuda_kernels::cuda_unary;
5+
6+
// SAFETY (NOTE(review)): `SiLUKernelOp` is a field-less `#[repr(C)]` unit struct,
// and the CUDA source declares an empty struct of the same name; presumably that
// is what makes passing it to a kernel sound — confirm against cudarc's
// `DeviceRepr` contract.
unsafe impl cudarc::driver::DeviceRepr for SiLUKernelOp {}
7+
8+
// PTX text embedded at compile time from `$OUT_DIR/silu.ptx`
// (presumably produced from `silu.cu` by the build script — confirm).
const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/silu.ptx"));
9+
10+
// Half-precision kernels are only registered when the `f16` feature is enabled.
#[cfg(feature = "f16")]
11+
cuda_unary!(SiLUKernelOp, f16, PTX, "silu_fwd_f16", "silu_bwd_f16");
12+
// `AMP<f16>` is registered with the same f16 kernel names as plain `f16`.
#[cfg(feature = "f16")]
13+
cuda_unary!(SiLUKernelOp, AMP<f16>, PTX, "silu_fwd_f16", "silu_bwd_f16");
14+
// f32 / f64 kernels are registered unconditionally.
cuda_unary!(SiLUKernelOp, f32, PTX, "silu_fwd_f32", "silu_bwd_f32");
15+
cuda_unary!(SiLUKernelOp, f64, PTX, "silu_fwd_f64", "silu_bwd_f64");

dfdx-core/src/tensor_ops/silu/mod.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
mod cpu_kernel;
2+
3+
#[cfg(feature = "cuda")]
4+
mod cuda_kernel;
5+
6+
#[cfg(feature = "webgpu")]
7+
mod webgpu_kernel;
8+
9+
use super::ops::{try_unary_op, UnaryKernel};
10+
use crate::{shapes::*, tensor::*};
11+
12+
/// Unit struct that identifies the SiLU op to `UnaryKernel` implementations.
///
/// It carries no data; see [silu] for the function it stands for.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default)]
pub struct SiLUKernelOp;
15+
16+
/// [Sigmoid-Weighted Linear Unit (SiLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `x * x.sigmoid()`
17+
///
18+
/// The derivative is `x * sigmoid'(x) + sigmoid(x)`.
19+
///
20+
/// Examples:
21+
/// ```rust
22+
/// # use dfdx_core::prelude::*;
23+
/// # let dev: Cpu = Default::default();
24+
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
25+
/// let r = t.silu();
26+
/// ```
27+
pub fn silu<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>>(
28+
t: Tensor<S, E, D, T>,
29+
) -> Tensor<S, E, D, T> {
30+
t.silu()
31+
}
32+
33+
impl<S: Shape, E: Dtype, D: UnaryKernel<SiLUKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
34+
/// See [silu]
35+
pub fn silu(self) -> Self {
36+
self.try_silu().unwrap()
37+
}
38+
/// See [silu]
39+
pub fn try_silu(self) -> Result<Self, crate::tensor::Error> {
40+
try_unary_op(SiLUKernelOp, self)
41+
}
42+
}
43+
44+
#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::*};

    /// Checks the SiLU forward values and the gradient of `mean(silu(x))`.
    #[test]
    fn test_silu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().silu();
        assert_close_to_literal!(r, [-0.23840584, -0.26894143, 0.0, 0.7310586, 1.761594]);
        let g = r.mean().backward();
        // The mean over 5 elements scales each gradient by 1/5, so the expected
        // values are silu'(x) / 5 with
        //   silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
        // Sanity anchors: silu'(0) = 0.5 -> 0.1, and silu'(x) + silu'(-x) = 1,
        // so paired entries sum to 0.2.
        assert_close_to_literal!(
            g.get(&x),
            [-0.01815685, 0.014465898, 0.1, 0.1855341, 0.21815685]
        );
    }
}

dfdx-core/src/tensor_ops/silu/silu.cu

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "unary_op_macros.cuh"
2+
3+
// Empty tag type for the SiLU op; the Rust side declares a unit struct with the same name.
struct SiLUKernelOp {};
4+
5+
// SiLU forward value: x * sigmoid(x), written as x / (1 + e^-x).
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
    T one = 1.0;
    T denom = one + expg(-x);
    return x / denom;
}
11+
12+
// SiLU derivative: (1 + e^-x + x * e^-x) / (1 + e^-x)^2
// Equivalent form:  (e^x (1 + e^x + x)) / (1 + e^x)^2
template<typename T>
__device__ __forceinline__ T silu_bwd(T x) {
    T one = 1.0;
    T exp_neg_x = expg(-x);
    // 1 + e^-x; squared below to form the denominator.
    T one_plus_exp = (one + exp_neg_x);
    T numer = (one + exp_neg_x + x * exp_neg_x);
    return numer / (one_plus_exp * one_plus_exp);
}
21+
22+
// Instantiate the forward/backward kernels for each supported element type.
// The kernel names here are the ones the Rust `cuda_unary!` registrations refer to.
UNARY_OP(__half, silu_fwd_f16, silu_bwd_f16, SiLUKernelOp, silu_fwd(x), silu_bwd(x))

UNARY_OP(float, silu_fwd_f32, silu_bwd_f32, SiLUKernelOp, silu_fwd(x), silu_bwd(x))

UNARY_OP(double, silu_fwd_f64, silu_bwd_f64, SiLUKernelOp, silu_fwd(x), silu_bwd(x))
dfdx-core/src/tensor_ops/silu/webgpu_kernel.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use std::borrow::Cow;
2+
3+
use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
4+
5+
impl<E: Dtype> UnaryKernel<super::SiLUKernelOp, E> for Webgpu {
6+
const BACKWARD_WITHOUT_INP: bool = false;
7+
8+
const BACKWARD_WITHOUT_DATA: bool = false;
9+
10+
fn forward<S: crate::prelude::Shape>(
11+
&self,
12+
op: super::SiLUKernelOp,
13+
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
14+
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
15+
todo!()
16+
}
17+
18+
fn backward<S: crate::prelude::Shape>(
19+
&self,
20+
op: super::SiLUKernelOp,
21+
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
22+
grad_inp: &mut Self::Vec,
23+
out: &impl crate::prelude::Tensorlike<S, E, Self>,
24+
grad_out: &Self::Vec,
25+
) -> Result<(), crate::prelude::Error> {
26+
todo!()
27+
}
28+
}

dfdx-core/src/tensor_ops/utilities/device.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub trait Device<E: Dtype>:
9292
+ UnaryKernel<super::super::fast_gelu::FastGeLUKernelOp, E>
9393
+ UnaryKernel<super::super::accurate_gelu::AccurateGeLUKernelOp, E>
9494
+ UnaryKernel<super::super::sigmoid::SigmoidKernelOp, E>
95+
+ UnaryKernel<super::super::silu::SiLUKernelOp, E>
9596
+ UnaryKernel<super::super::sin::SinKernelOp, E>
9697
+ UnaryKernel<super::super::sqrt::SqrtKernelOp, E>
9798
+ UnaryKernel<super::super::square::SquareKernelOp, E>

0 commit comments

Comments
 (0)