@@ -120,33 +120,61 @@ pub trait Device<E: Dtype>:
     + crate::tensor_ops::axpy::AxpyKernel<E>
 
     // conv1d
-    + super::super::conv1d::Conv1DKernel<E>
+    + NonCudnnCuda<E>
+{
+}
+
+#[cfg(feature = "cudnn")]
+pub trait NonCudnnCuda<E: Dtype> {}
+
+#[cfg(not(feature = "cudnn"))]
+pub trait NonCudnnCuda<E: Dtype>:
+    // conv1d
+    super::super::conv1d::Conv1DKernel<E>
 {
 }
 
 #[cfg(feature = "f16")]
-impl Device<f16> for crate::tensor::Cpu {}
-#[cfg(feature = "f16")]
-impl Device<AMP<f16>> for crate::tensor::Cpu {}
+mod f16_ {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
+    impl Device<AMP<f16>> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
+}
 impl Device<f32> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
 impl Device<f64> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
 
 #[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<f16> for crate::tensor::Cuda {}
-#[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<AMP<f16>> for crate::tensor::Cuda {}
-#[cfg(feature = "cuda")]
-impl Device<f32> for crate::tensor::Cuda {}
+mod cuda_f16 {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
+    impl Device<AMP<f16>> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
+}
 #[cfg(feature = "cuda")]
-impl Device<f64> for crate::tensor::Cuda {}
+mod cuda {
+    use super::*;
+    impl Device<f32> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
+    impl Device<f64> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
+}
 
 // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<f16> for crate::tensor::Webgpu {}
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<AMP<f16>> for crate::tensor::Webgpu {}
 #[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
+mod webgpu {
+    use super::*;
+    impl Device<f32> for crate::tensor::Webgpu {}
+    impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
+}
 
 // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
 // #[cfg(feature = "webgpu")]
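The pattern in this hunk, reduced to a minimal sketch. The names NonCudnnCuda and Conv1DKernel mirror the diff; Backend and the cpu feature are hypothetical stand-ins, and all features would need to be declared in Cargo.toml for the cfgs to resolve:

// Stand-in for a backend type such as crate::tensor::Cpu.
struct Backend;

// Stand-in for super::super::conv1d::Conv1DKernel<E> (generics dropped
// for brevity).
trait Conv1DKernel {}

// With the cudnn feature on, the intermediate trait adds no bounds
// (the cuDNN path presumably covers conv1d)...
#[cfg(feature = "cudnn")]
trait NonCudnnCuda {}

// ...and without it, the hand-written kernel becomes a hard requirement.
#[cfg(not(feature = "cudnn"))]
trait NonCudnnCuda: Conv1DKernel {}

// Device itself never mentions the feature; the bound it inherits
// through NonCudnnCuda flips meaning at compile time.
trait Device: NonCudnnCuda {}

// Wrapping the impls in a module lets a single #[cfg] gate the whole
// group, which is why the diff introduces mod f16_, cuda_f16, etc.
// instead of repeating the attribute on every impl.
#[cfg(feature = "cpu")]
mod cpu {
    use super::*;
    impl Conv1DKernel for Backend {}
    impl NonCudnnCuda for Backend {}
    impl Device for Backend {}
}

The apparent payoff: a no-cudnn build still demands a conv1d kernel from every device, while a cudnn build drops that requirement without Device's definition or any call site changing.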