@@ -810,7 +810,7 @@ pub mod stateful {
         T: Tape<E, D>,
     >
     Module<(
-        Tensor<(Batch, C1, DModel), E, D, T>,
+        Tensor<(Batch, DModel), E, D, T>,
         MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
     )> for MambaBlock<DModel, DState, DtRank, DConv, DInner, E, D>
     where
@@ -842,35 +842,34 @@ pub mod stateful {
     ): dfdx_core::tensor_ops::TryConcatShapeAlong<Axis<2>, Output = (Batch, DInner, DConv)>,
     {
         type Output = (
-            Tensor<(Batch, C1, DModel), E, D, T>,
+            Tensor<(Batch, DModel), E, D, T>,
             MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
         );
 
         /// Mamba block forward.
         fn try_forward(
             &self,
             x: (
-                Tensor<(Batch, C1, DModel), E, D, T>,
+                Tensor<(Batch, DModel), E, D, T>,
                 MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
             ),
         ) -> Result<Self::Output, Error> {
             let (x, mut cache) = x;
 
-            // let (batch, _d_model) = *x.shape();
             let (batch, d_inner, d_conv) = *cache.conv_state.shape();
 
             // layer 1 (in_proj)
             let (xs, res): (
-                Tensor<(Batch, C1, DInner), _, _, _>,
-                Tensor<(Batch, C1, DInner), _, _, _>,
+                Tensor<(Batch, DInner), _, _, _>,
+                Tensor<(Batch, DInner), _, _, _>,
             ) = {
                 // projects the input DModel into 2*DInner
-                let xs_and_res: Tensor<(Batch, C1, <DInner as Mul<C2>>::Output), _, _, _> =
+                let xs_and_res: Tensor<(Batch, <DInner as Mul<C2>>::Output), _, _, _> =
                     self.in_proj.try_forward(x)?;
 
                 // splits xs_and_res into (xs, res)
                 let (xs, res, _tape) =
-                    xs_and_res.try_split_tensor_along(Axis::<2>, d_inner, d_inner)?;
+                    xs_and_res.try_split_tensor_along(Axis::<1>, d_inner, d_inner)?;
 
                 (xs, res)
             };
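Shape bookkeeping for the hunk above, after dropping the length-1 sequence axis (a sketch; $W_{\text{in}}$ is only an illustrative name for the `in_proj` weights, and any bias is omitted):

$$\mathit{xs\_and\_res} = x\,W_{\text{in}}^{\top} \in \mathbb{R}^{B \times 2 D_{\text{inner}}}, \qquad (\mathit{xs},\ \mathit{res}) = \operatorname{split}_{D_{\text{inner}}}\big(\mathit{xs\_and\_res},\ \text{axis}=1\big), \qquad \mathit{xs},\ \mathit{res} \in \mathbb{R}^{B \times D_{\text{inner}}},$$

which is why the split now runs along `Axis::<1>` instead of `Axis::<2>`.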
@@ -893,12 +892,11 @@ pub mod stateful {
                 )?;
                 // then concat with the xs as the last column (by the right side)
                 let xs: Tensor<(Batch, DInner, C1), _, _, _> =
-                    xs.try_permute::<_, Axes3<0, 2, 1>>()?;
-                // let xs = xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
+                    xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
                 (conv_state, xs).try_concat_tensor_along(Axis::<2>)?
             };
 
-            let xs: Tensor<(Batch, C1, DInner), E, _, _> = {
+            let xs: Tensor<(Batch, DInner), E, _, _> = {
                 let conv1d = self
                     .conv1d
                     .weight
@@ -913,9 +911,7 @@ pub mod stateful {
                 let xs = self.conv1d_bias.try_forward(xs)?;
 
                 // activation
-                let xs = xs.try_silu()?;
-
-                xs.try_reshape_like(&(batch, Const::<1>, d_inner))?
+                xs.try_silu()?
             };
 
             let (ss, cache_ssm_state) = ss_step::<Batch, DState, DtRank, DInner, E, D, T>(
@@ -929,7 +925,7 @@ pub mod stateful {
             )?;
 
             let ys = ss.try_mul(res.try_silu()?)?;
-            let y: Tensor<(Batch, C1, DModel), _, _, _> = self.out_proj.try_forward(ys)?;
+            let y: Tensor<(Batch, DModel), _, _, _> = self.out_proj.try_forward(ys)?;
 
             cache.ssm_state = cache_ssm_state;
 
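The output path in this hunk is the usual Mamba gating; per token it now reads (a sketch, with $W_{\text{out}}$ as an illustrative name for the `out_proj` weights):

$$y = \big(\mathit{ss} \odot \operatorname{SiLU}(\mathit{res})\big)\,W_{\text{out}}^{\top} \in \mathbb{R}^{B \times D_{\text{model}}},$$

so `out_proj` maps straight from (Batch, DInner) to (Batch, DModel) with no singleton axis left to squeeze.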
@@ -957,13 +953,13 @@ pub mod stateful {
         //
         a: Tensor<(DInner, DState), E, D, T>,
         d: Tensor<(DInner,), E, D, T>,
-        u: Tensor<(Batch, C1, DInner), E, D, T>,
+        u: Tensor<(Batch, DInner), E, D, T>,
         x_proj: &MatMul<DInner, <DtRank as Add<<DState as Mul<C2>>::Output>>::Output, E, D>,
         dt_proj: &Linear<DtRank, DInner, E, D>,
         ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
     ) -> Result<
         (
-            Tensor<(Batch, C1, DInner), E, D, T>,
+            Tensor<(Batch, DInner), E, D, T>,
             Tensor<(Batch, DInner, DState), E, D, T>,
         ),
         dfdx::tensor::Error,
@@ -987,25 +983,25 @@ pub mod stateful {
         // this is input independent (see Section 3.5.2 "Interpretation of A" from the Mamba paper for why A isn't selective)
         let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?;
 
-        // (Batch, 1, DtRank + DState * 2)
-        let x_dbl: Tensor<(Batch, C1, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;
+        // (Batch, DtRank + DState * 2)
+        let x_dbl: Tensor<(Batch, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;
 
         // ∆ (part 1/2)
         // ∆ is input-dependent
-        let (delta, x_dbl_tail, _tape): (Tensor<(Batch, C1, DtRank), _, _, _>, _, _) =
-            x_dbl.try_split_tensor_along(Axis::<2>, dt_rank, d_state * Const::<2>)?;
+        let (delta, x_dbl_tail, _tape): (Tensor<(Batch, DtRank), _, _, _>, _, _) =
+            x_dbl.try_split_tensor_along(Axis::<1>, dt_rank, d_state * Const::<2>)?;
 
         // B and C
         // B and C are input-dependent
         let (b, c, _tape): (
-            Tensor<(Batch, C1, DState), _, _, _>,
-            Tensor<(Batch, C1, DState), _, _, _>,
+            Tensor<(Batch, DState), _, _, _>,
+            Tensor<(Batch, DState), _, _, _>,
             _,
-        ) = x_dbl_tail.try_split_tensor_along(Axis::<2>, d_state, d_state)?;
+        ) = x_dbl_tail.try_split_tensor_along(Axis::<1>, d_state, d_state)?;
 
         // ∆ (part 2/2)
         // ∆ is input-dependent
-        let delta: Tensor<(Batch, C1, DInner), _, _, _> = {
+        let delta: Tensor<(Batch, DInner), _, _, _> = {
             // note: don't add dt_proj bias
             let delta = delta.try_matmul(
                 dt_proj
@@ -1021,22 +1017,14 @@ pub mod stateful {
                 dt_proj
                     .bias
                     .retaped::<T>()
-                    .try_broadcast_like(&(batch, Const::<1>, d_inner))?,
+                    .try_broadcast_like(&(batch, d_inner))?,
             )?
             .try_exp()?
             .try_add(one)?)
             .try_ln()?
         };
 
-        selective_scan_step::<Batch, DState, DInner, E, D, T>(
-            delta.try_permute::<_, Axes3<0, 2, 1>>()?,
-            a,
-            b,
-            c.try_permute::<_, Axes3<1, 0, 2>>()?,
-            d,
-            u,
-            ssm_state_cache,
-        )
+        selective_scan_step::<Batch, DState, DInner, E, D, T>(delta, a, b, c, d, u, ssm_state_cache)
     }
 
     // Selective Scan.
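For reference, the chain at the end of the ∆ block above (matmul with `dt_proj.weight`, add the broadcast `dt_proj.bias`, then `exp`, add one, `ln`) is just a per-token softplus; writing $\delta \in \mathbb{R}^{B \times D_{t\text{-rank}}}$ for the first split of `x_dbl` and $W_{\Delta}$, $b_{\Delta}$ as illustrative names for the `dt_proj` parameters:

$$\Delta = \operatorname{softplus}\big(\delta\,W_{\Delta}^{\top} + b_{\Delta}\big) = \ln\!\big(1 + e^{\,\delta W_{\Delta}^{\top} + b_{\Delta}}\big) \in \mathbb{R}^{B \times D_{\text{inner}}}.$$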
@@ -1057,16 +1045,16 @@ pub mod stateful {
         D: Device<E>,
         T: Tape<E, D>,
     >(
-        delta: Tensor<(Batch, DInner, C1), E, D, T>,
+        delta: Tensor<(Batch, DInner), E, D, T>,
         a: Tensor<(DInner, DState), E, D, T>,
-        b: Tensor<(Batch, C1, DState), E, D, T>,
-        c: Tensor<(C1, Batch, DState), E, D, T>,
+        b: Tensor<(Batch, DState), E, D, T>,
+        c: Tensor<(Batch, DState), E, D, T>,
         d: Tensor<(DInner,), E, D, T>,
-        u: Tensor<(Batch, C1, DInner), E, D, T>,
+        u: Tensor<(Batch, DInner), E, D, T>,
         mut ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
     ) -> Result<
         (
-            Tensor<(Batch, C1, DInner), E, D, T>,
+            Tensor<(Batch, DInner), E, D, T>,
             Tensor<(Batch, DInner, DState), E, D, T>,
         ),
         dfdx::tensor::Error,
@@ -1078,15 +1066,15 @@ pub mod stateful {
         // - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
         //   "A is the more important term and the performance doesn't change much with the simplification on B"
         let (delta_a, delta_bu): (
-            Tensor<(Batch, DInner, C1, DState), _, _, _>,
-            Tensor<(Batch, DInner, C1, DState), _, _, _>,
+            Tensor<(Batch, DInner, DState), _, _, _>,
+            Tensor<(Batch, DInner, DState), _, _, _>,
         ) = {
-            let target_shape = (batch, d_inner, Const::<1>, d_state);
+            let target_shape = (batch, d_inner, d_state);
 
             let delta_broadcasted = delta.try_broadcast_like(&target_shape)?;
 
             let a = a.try_broadcast_like(&target_shape)?;
-            let delta_a: Tensor<(Batch, DInner, C1, DState), _, _, _> =
+            let delta_a: Tensor<(Batch, DInner, DState), _, _, _> =
                 delta_broadcasted.retaped::<T>().try_mul(a)?.try_exp()?;
 
             let b = b.try_broadcast_like(&target_shape)?;
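As a sketch of the per-token update this step computes (the `delta_bu` product and the state update itself sit in the lines elided between this hunk and the next), with $b$ indexing Batch, $d$ indexing DInner, $n$ indexing DState, and `ssm_state_cache` playing the role of the hidden state $h$:

$$\bar{A}_{bdn} = e^{\Delta_{bd} A_{dn}}, \qquad (\overline{\Delta B}u)_{bdn} = \Delta_{bd}\,B_{bn}\,u_{bd},$$
$$h_{bdn} \leftarrow \bar{A}_{bdn}\,h_{bdn} + (\overline{\Delta B}u)_{bdn}, \qquad y_{bd} = \sum_{n} h_{bdn}\,C_{bn} + D_d\,u_{bd},$$

i.e. zero-order hold for $A$ and the simplified Euler rule for $B$, as the comment above notes.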
@@ -1106,13 +1094,9 @@ pub mod stateful {
 
         let y = ssm_state_cache
             .retaped::<T>()
-            .try_matmul(c.try_permute::<_, Axes3<1, 2, 0>>()?)?;
-        let du = d
-            .try_broadcast_like(&(batch, Const::<1>, d_inner))?
-            .try_mul(u)?;
-        let y = y
-            .try_reshape_like(&(batch, Const::<1>, d_inner))?
-            .try_add(du)?;
+            .try_matmul(c.try_reshape_like(&(batch, d_state, Const::<1>))?)?;
+        let du = d.try_broadcast_like(&(batch, d_inner))?.try_mul(u)?;
+        let y = y.try_reshape_like(&(batch, d_inner))?.try_add(du)?;
 
         Ok((y, ssm_state_cache))
     }
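Taken together, the diff drops the singleton sequence axis from the stateful path: the block now consumes one (Batch, DModel) token per call and hands back the updated cache. Below is a minimal sketch of how a caller might thread the cache through successive steps, assuming the dfdx `Module` trait as it appears in the impl above (`try_forward` returning `Result<_, dfdx::tensor::Error>`); the helper name `step_tokens` and the loop are illustrative and not part of this commit:

use dfdx::prelude::*;

/// Hypothetical helper: run a stateful module one token at a time,
/// threading the recurrent cache (e.g. a MambaStateCache) between calls.
fn step_tokens<M, X, C>(
    block: &M,
    tokens: Vec<X>, // each element is one (Batch, DModel) tensor
    mut cache: C,   // initial cache state
) -> Result<(Vec<X>, C), dfdx::tensor::Error>
where
    M: Module<(X, C), Output = (X, C)>,
{
    let mut outputs = Vec::with_capacity(tokens.len());
    for x in tokens {
        // After this change, `x` carries no singleton sequence axis.
        let (y, next_cache) = block.try_forward((x, cache))?;
        outputs.push(y);
        cache = next_cache;
    }
    Ok((outputs, cache))
}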