Skip to content

Commit 1e02f9f

Browse files
committed
unstack usage requires continuity
1 parent 27dff97 commit 1e02f9f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dfdx/src/nn/layers/mamba_minimal.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ pub mod stateless {
673673
Err(_delta_a) => unreachable!(),
674674
};
675675
let (delta_a, _delta_a_tape): (Vec<Tensor<(Batch, DInner, DState), _, _, _>>, _) =
676-
delta_a.try_unstack()?;
676+
delta_a.try_contiguous()?.try_unstack()?;
677677
//
678678
// delta B
679679
let delta_bu: Tensor<(usize, Batch, DInner, DState), _, _, _> = match delta_bu.try_realize()
@@ -682,7 +682,7 @@ pub mod stateless {
682682
Err(_delta_bu) => unreachable!(),
683683
};
684684
let (delta_bu, _delta_bu_tape): (Vec<Tensor<(Batch, DInner, DState), _, _, _>>, _) =
685-
delta_bu.try_unstack()?;
685+
delta_bu.try_contiguous()?.try_unstack()?;
686686
//
687687
// C
688688
let c: Tensor<(usize, Batch, DState, C1), _, _, _> = match c

0 commit comments

Comments
 (0)