Commit cadf65c

stateful has an implicit sequence of 1

1 parent 66ff785
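The gist of the change: in the stateful single-step path the cache already implies a sequence length of one, so the explicit `C1` sequence axis is dropped and the block's input and output become `(Batch, DModel)`. As a toy illustration of why nothing is lost, a length-1 axis can be squeezed away by a reshape that preserves the data (a minimal dfdx sketch; the concrete shapes here are hypothetical):

```rust
use dfdx::prelude::*;

fn main() -> Result<(), dfdx::tensor::Error> {
    let dev = Cpu::default();
    // A (Batch = 2, Seq = 1, DModel = 8) tensor...
    let x: Tensor<(Const<2>, Const<1>, Const<8>), f32, _> = dev.sample_normal();
    // ...carries exactly the same data as its squeezed (Batch, DModel) form.
    let squeezed: Tensor<(Const<2>, Const<8>), f32, _> = x.clone().try_reshape()?;
    assert_eq!(x.as_vec(), squeezed.as_vec());
    Ok(())
}
```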

File tree

1 file changed: +35 −51 lines

dfdx/src/nn/layers/mamba_minimal.rs
@@ -810,7 +810,7 @@ pub mod stateful {
         T: Tape<E, D>,
     >
     Module<(
-        Tensor<(Batch, C1, DModel), E, D, T>,
+        Tensor<(Batch, DModel), E, D, T>,
         MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
     )> for MambaBlock<DModel, DState, DtRank, DConv, DInner, E, D>
     where
@@ -842,35 +842,34 @@ pub mod stateful {
         ): dfdx_core::tensor_ops::TryConcatShapeAlong<Axis<2>, Output = (Batch, DInner, DConv)>,
     {
         type Output = (
-            Tensor<(Batch, C1, DModel), E, D, T>,
+            Tensor<(Batch, DModel), E, D, T>,
             MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
         );

         /// Mamba block forward.
         fn try_forward(
             &self,
             x: (
-                Tensor<(Batch, C1, DModel), E, D, T>,
+                Tensor<(Batch, DModel), E, D, T>,
                 MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
             ),
         ) -> Result<Self::Output, Error> {
             let (x, mut cache) = x;

-            // let (batch, _d_model) = *x.shape();
             let (batch, d_inner, d_conv) = *cache.conv_state.shape();

             // layer 1 (in_proj)
             let (xs, res): (
-                Tensor<(Batch, C1, DInner), _, _, _>,
-                Tensor<(Batch, C1, DInner), _, _, _>,
+                Tensor<(Batch, DInner), _, _, _>,
+                Tensor<(Batch, DInner), _, _, _>,
             ) = {
                 // projects the input DModel into 2*DInner
-                let xs_and_res: Tensor<(Batch, C1, <DInner as Mul<C2>>::Output), _, _, _> =
+                let xs_and_res: Tensor<(Batch, <DInner as Mul<C2>>::Output), _, _, _> =
                     self.in_proj.try_forward(x)?;

                 // splits xs_and_res into (xs, res)
                 let (xs, res, _tape) =
-                    xs_and_res.try_split_tensor_along(Axis::<2>, d_inner, d_inner)?;
+                    xs_and_res.try_split_tensor_along(Axis::<1>, d_inner, d_inner)?;

                 (xs, res)
             };
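The in_proj output now splits along axis 1 instead of axis 2. A toy version of the same pattern, assuming this branch's `try_split_tensor_along` API (axis, then the two split lengths, returning both halves plus the tape); the concrete sizes are hypothetical:

```rust
use dfdx::prelude::*;

fn main() -> Result<(), dfdx::tensor::Error> {
    let dev = Cpu::default();
    // Stand-in for xs_and_res: (Batch = 2, 2 * DInner = 8).
    let xs_and_res: Tensor<(Const<2>, Const<8>), f32, _> = dev.sample_normal();
    // Split along the feature axis into two (Batch, DInner) halves.
    let (xs, res, _tape) =
        xs_and_res.try_split_tensor_along(Axis::<1>, Const::<4>, Const::<4>)?;
    let _: Tensor<(Const<2>, Const<4>), f32, _> = xs;
    let _: Tensor<(Const<2>, Const<4>), f32, _> = res;
    Ok(())
}
```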
@@ -893,12 +892,11 @@ pub mod stateful {
             )?;
             // then concat with the xs as the last column (by the right side)
             let xs: Tensor<(Batch, DInner, C1), _, _, _> =
-                xs.try_permute::<_, Axes3<0, 2, 1>>()?;
-            // let xs = xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
+                xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
             (conv_state, xs).try_concat_tensor_along(Axis::<2>)?
         };

-        let xs: Tensor<(Batch, C1, DInner), E, _, _> = {
+        let xs: Tensor<(Batch, DInner), E, _, _> = {
             let conv1d = self
                 .conv1d
                 .weight
@@ -913,9 +911,7 @@ pub mod stateful {
             let xs = self.conv1d_bias.try_forward(xs)?;

             // activation
-            let xs = xs.try_silu()?;
-
-            xs.try_reshape_like(&(batch, Const::<1>, d_inner))?
+            xs.try_silu()?
         };

         let (ss, cache_ssm_state) = ss_step::<Batch, DState, DtRank, DInner, E, D, T>(
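These two hunks drive the cached depthwise convolution: the incoming activation becomes a single column via reshape (rather than a permute), it is appended to `conv_state` along the last axis, and the conv plus bias plus SiLU now yields a `(Batch, DInner)` result directly, with no trailing reshape back to a length-1 sequence. Conceptually the cache is a per-channel FIFO window over the last DConv inputs; a scalar sketch of that idea (a hypothetical simplification, not this crate's API):

```rust
/// One step of a rolling depthwise-conv window over the last D_CONV inputs:
/// drop the oldest sample, append the newest, and dot with the kernel.
fn conv_step<const D_CONV: usize>(
    state: &mut [f32; D_CONV],
    x: f32,
    kernel: &[f32; D_CONV],
) -> f32 {
    state.rotate_left(1); // discard the oldest input
    state[D_CONV - 1] = x; // append the newest input
    state.iter().zip(kernel.iter()).map(|(s, k)| s * k).sum()
}
```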
@@ -929,7 +925,7 @@ pub mod stateful {
         )?;

         let ys = ss.try_mul(res.try_silu()?)?;
-        let y: Tensor<(Batch, C1, DModel), _, _, _> = self.out_proj.try_forward(ys)?;
+        let y: Tensor<(Batch, DModel), _, _, _> = self.out_proj.try_forward(ys)?;

         cache.ssm_state = cache_ssm_state;

@@ -957,13 +953,13 @@ pub mod stateful {
         //
         a: Tensor<(DInner, DState), E, D, T>,
         d: Tensor<(DInner,), E, D, T>,
-        u: Tensor<(Batch, C1, DInner), E, D, T>,
+        u: Tensor<(Batch, DInner), E, D, T>,
         x_proj: &MatMul<DInner, <DtRank as Add<<DState as Mul<C2>>::Output>>::Output, E, D>,
         dt_proj: &Linear<DtRank, DInner, E, D>,
         ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
     ) -> Result<
         (
-            Tensor<(Batch, C1, DInner), E, D, T>,
+            Tensor<(Batch, DInner), E, D, T>,
             Tensor<(Batch, DInner, DState), E, D, T>,
         ),
         dfdx::tensor::Error,
@@ -987,25 +983,25 @@ pub mod stateful {
         // this is input independent (see Section 3.5.2 "Interpretation of A" from the Mamba paper for why A isn't selective)
         let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?;

-        // (Batch, 1, DtRank + DState * 2)
-        let x_dbl: Tensor<(Batch, C1, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;
+        // (Batch, DtRank + DState * 2)
+        let x_dbl: Tensor<(Batch, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;

         // ∆ (part 1/2)
         // ∆ is input-dependent
-        let (delta, x_dbl_tail, _tape): (Tensor<(Batch, C1, DtRank), _, _, _>, _, _) =
-            x_dbl.try_split_tensor_along(Axis::<2>, dt_rank, d_state * Const::<2>)?;
+        let (delta, x_dbl_tail, _tape): (Tensor<(Batch, DtRank), _, _, _>, _, _) =
+            x_dbl.try_split_tensor_along(Axis::<1>, dt_rank, d_state * Const::<2>)?;

         // B and C
         // B and C are input-dependent
         let (b, c, _tape): (
-            Tensor<(Batch, C1, DState), _, _, _>,
-            Tensor<(Batch, C1, DState), _, _, _>,
+            Tensor<(Batch, DState), _, _, _>,
+            Tensor<(Batch, DState), _, _, _>,
             _,
-        ) = x_dbl_tail.try_split_tensor_along(Axis::<2>, d_state, d_state)?;
+        ) = x_dbl_tail.try_split_tensor_along(Axis::<1>, d_state, d_state)?;

         // ∆ (part 2/2)
         // ∆ is input-dependent
-        let delta: Tensor<(Batch, C1, DInner), _, _, _> = {
+        let delta: Tensor<(Batch, DInner), _, _, _> = {
             // note: don't add dt_proj bias
             let delta = delta.try_matmul(
                 dt_proj
@@ -1021,22 +1017,14 @@ pub mod stateful {
                 dt_proj
                     .bias
                     .retaped::<T>()
-                    .try_broadcast_like(&(batch, Const::<1>, d_inner))?,
+                    .try_broadcast_like(&(batch, d_inner))?,
             )?
             .try_exp()?
             .try_add(one)?)
             .try_ln()?
         };

-        selective_scan_step::<Batch, DState, DInner, E, D, T>(
-            delta.try_permute::<_, Axes3<0, 2, 1>>()?,
-            a,
-            b,
-            c.try_permute::<_, Axes3<1, 0, 2>>()?,
-            d,
-            u,
-            ssm_state_cache,
-        )
+        selective_scan_step::<Batch, DState, DInner, E, D, T>(delta, a, b, c, d, u, ssm_state_cache)
     }

     // Selective Scan.
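The `try_exp` → `try_add(one)` → `try_ln` chain in this hunk is softplus applied elementwise to the projected delta plus the `dt_proj` bias; a scalar sketch (the helper name is hypothetical):

```rust
/// softplus(x) = ln(1 + exp(x)); this is the second half of computing ∆.
fn softplus(x: f32) -> f32 {
    (1.0 + x.exp()).ln()
}
```

With all operands now in `(Batch, DInner)` / `(Batch, DState)` form, the permutes that previously juggled the `C1` axis around the `selective_scan_step` call disappear entirely.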
@@ -1057,16 +1045,16 @@ pub mod stateful {
         D: Device<E>,
         T: Tape<E, D>,
     >(
-        delta: Tensor<(Batch, DInner, C1), E, D, T>,
+        delta: Tensor<(Batch, DInner), E, D, T>,
         a: Tensor<(DInner, DState), E, D, T>,
-        b: Tensor<(Batch, C1, DState), E, D, T>,
-        c: Tensor<(C1, Batch, DState), E, D, T>,
+        b: Tensor<(Batch, DState), E, D, T>,
+        c: Tensor<(Batch, DState), E, D, T>,
         d: Tensor<(DInner,), E, D, T>,
-        u: Tensor<(Batch, C1, DInner), E, D, T>,
+        u: Tensor<(Batch, DInner), E, D, T>,
         mut ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
     ) -> Result<
         (
-            Tensor<(Batch, C1, DInner), E, D, T>,
+            Tensor<(Batch, DInner), E, D, T>,
             Tensor<(Batch, DInner, DState), E, D, T>,
         ),
         dfdx::tensor::Error,
@@ -1078,15 +1066,15 @@ pub mod stateful {
         // - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
         //   "A is the more important term and the performance doesn't change much with the simplification on B"
         let (delta_a, delta_bu): (
-            Tensor<(Batch, DInner, C1, DState), _, _, _>,
-            Tensor<(Batch, DInner, C1, DState), _, _, _>,
+            Tensor<(Batch, DInner, DState), _, _, _>,
+            Tensor<(Batch, DInner, DState), _, _, _>,
         ) = {
-            let target_shape = (batch, d_inner, Const::<1>, d_state);
+            let target_shape = (batch, d_inner, d_state);

             let delta_broadcasted = delta.try_broadcast_like(&target_shape)?;

             let a = a.try_broadcast_like(&target_shape)?;
-            let delta_a: Tensor<(Batch, DInner, C1, DState), _, _, _> =
+            let delta_a: Tensor<(Batch, DInner, DState), _, _, _> =
                 delta_broadcasted.retaped::<T>().try_mul(a)?.try_exp()?;

             let b = b.try_broadcast_like(&target_shape)?;
@@ -1106,13 +1094,9 @@ pub mod stateful {

         let y = ssm_state_cache
             .retaped::<T>()
-            .try_matmul(c.try_permute::<_, Axes3<1, 2, 0>>()?)?;
-        let du = d
-            .try_broadcast_like(&(batch, Const::<1>, d_inner))?
-            .try_mul(u)?;
-        let y = y
-            .try_reshape_like(&(batch, Const::<1>, d_inner))?
-            .try_add(du)?;
+            .try_matmul(c.try_reshape_like(&(batch, d_state, Const::<1>))?)?;
+        let du = d.try_broadcast_like(&(batch, d_inner))?.try_mul(u)?;
+        let y = y.try_reshape_like(&(batch, d_inner))?.try_add(du)?;

         Ok((y, ssm_state_cache))
     }
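Taken together, each call advances the discretized SSM by one step: the state is scaled by `delta_a = exp(∆ ⊙ A)` and bumped by the Euler-simplified `delta_bu`, then the output reads the state through C and adds the skip term D ⊙ u. A scalar sketch of one step, collapsing Batch, DInner, and DState to single entries (a hypothetical simplification of the tensor code above, not the crate's API):

```rust
/// One selective-scan step for a single (batch, channel, state) entry:
///   h' = exp(delta * a) * h + delta * b * u   // state update
///   y  = c * h' + d * u                       // read-out plus skip
fn scan_step(h: &mut f32, delta: f32, a: f32, b: f32, c: f32, d: f32, u: f32) -> f32 {
    *h = (delta * a).exp() * *h + delta * b * u;
    c * *h + d * u
}
```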
