Skip to content

Commit 0e9f509

Browse files
committed
avoid conv1d bound for cudnn
1 parent bfb76e5 commit 0e9f509

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

dfdx-core/src/tensor_ops/utilities/device.rs

Lines changed: 39 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -120,33 +120,61 @@ pub trait Device<E: Dtype>:
120120
+ crate::tensor_ops::axpy::AxpyKernel<E>
121121

122122
// conv1d
123-
+ super::super::conv1d::Conv1DKernel<E>
123+
+ NonCudnnCuda<E>
124+
{
125+
}
126+
127+
#[cfg(feature = "cudnn")]
128+
pub trait NonCudnnCuda<E: Dtype> {}
129+
130+
#[cfg(not(feature = "cudnn"))]
131+
pub trait NonCudnnCuda<E: Dtype>:
132+
// conv1d
133+
super::super::conv1d::Conv1DKernel<E>
124134
{
125135
}
126136

127137
#[cfg(feature = "f16")]
128-
impl Device<f16> for crate::tensor::Cpu {}
129-
#[cfg(feature = "f16")]
130-
impl Device<AMP<f16>> for crate::tensor::Cpu {}
138+
mod f16_ {
139+
use super::*;
140+
impl Device<f16> for crate::tensor::Cpu {}
141+
impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
142+
impl Device<AMP<f16>> for crate::tensor::Cpu {}
143+
impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
144+
}
131145
impl Device<f32> for crate::tensor::Cpu {}
146+
impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
132147
impl Device<f64> for crate::tensor::Cpu {}
148+
impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
133149

134150
#[cfg(all(feature = "cuda", feature = "f16"))]
135-
impl Device<f16> for crate::tensor::Cuda {}
136-
#[cfg(all(feature = "cuda", feature = "f16"))]
137-
impl Device<AMP<f16>> for crate::tensor::Cuda {}
138-
#[cfg(feature = "cuda")]
139-
impl Device<f32> for crate::tensor::Cuda {}
151+
mod cuda_f16 {
152+
use super::*;
153+
impl Device<f16> for crate::tensor::Cuda {}
154+
impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
155+
impl Device<AMP<f16>> for crate::tensor::Cuda {}
156+
impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
157+
}
140158
#[cfg(feature = "cuda")]
141-
impl Device<f64> for crate::tensor::Cuda {}
159+
mod cuda {
160+
use super::*;
161+
impl Device<f32> for crate::tensor::Cuda {}
162+
impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
163+
impl Device<f64> for crate::tensor::Cuda {}
164+
impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
165+
}
142166

143167
// TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
144168
// #[cfg(all(feature = "webgpu", feature = "f16"))]
145169
// impl Device<f16> for crate::tensor::Webgpu {}
146170
// #[cfg(all(feature = "webgpu", feature = "f16"))]
147171
// impl Device<AMP<f16>> for crate::tensor::Webgpu {}
148172
#[cfg(feature = "webgpu")]
149-
impl Device<f32> for crate::tensor::Webgpu {}
173+
mod webgpu {
174+
use super::*;
175+
impl Device<f32> for crate::tensor::Webgpu {}
176+
impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
177+
}
150178

151179
// TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
152180
// #[cfg(feature = "webgpu")]

0 commit comments

Comments (0)