Skip to content

Commit ee64526

Browse files
authored
Using arch option in nvrtc (#675)
* Using arch in nvrtc
* Fixing unused message for cudnn
* Fixing env var when ci-check is active
1 parent 64a60e2 commit ee64526

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

build.rs

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -52,12 +52,16 @@ mod cuda {
5252
.collect::<Vec<_>>();
5353

5454
#[cfg(feature = "ci-check")]
55-
for mut kernel_path in kernel_paths.into_iter() {
56-
kernel_path.set_extension("ptx");
55+
{
56+
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=ci");
57+
58+
for mut kernel_path in kernel_paths.into_iter() {
59+
kernel_path.set_extension("ptx");
5760

58-
let mut ptx_path: std::path::PathBuf = out_dir.clone().into();
59-
ptx_path.push(kernel_path.as_path().file_name().unwrap());
60-
std::fs::File::create(ptx_path).unwrap();
61+
let mut ptx_path: std::path::PathBuf = out_dir.clone().into();
62+
ptx_path.push(kernel_path.as_path().file_name().unwrap());
63+
std::fs::File::create(ptx_path).unwrap();
64+
}
6165
}
6266

6367
#[cfg(not(feature = "ci-check"))]
@@ -76,6 +80,8 @@ mod cuda {
7680
lines.next().unwrap().replace('.', "")
7781
};
7882

83+
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
84+
7985
kernel_paths
8086
.iter()
8187
.for_each(|p| println!("cargo:rerun-if-changed={}", p.display()));

src/tensor/cuda/device.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ pub struct Cuda {
2121
pub(crate) dev: Arc<CudaDevice>,
2222
pub(crate) blas: Arc<CudaBlas>,
2323
#[cfg(feature = "cudnn")]
24+
#[allow(unused)]
2425
pub(crate) cudnn: Arc<cudarc::cudnn::Cudnn>,
2526
/// A second stream for kernels to optionally execute on.
2627
pub(crate) par_stream: Arc<CudaStream>,

src/tensor_ops/reshape_to/cuda_kernel.rs

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ use crate::{
44
};
55
use cudarc::{
66
driver::{DeviceSlice, LaunchAsync},
7-
nvrtc::compile_ptx,
7+
nvrtc::{compile_ptx_with_opts, CompileOptions},
88
types::CudaTypeName,
99
};
1010

@@ -17,7 +17,11 @@ impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
1717
let module = std::format!("reshape_fwd_{}", E::NAME);
1818
if !self.dev.has_func(&module, "reshape_fwd") {
1919
let src = FWD_KERNEL.replace("$T", E::NAME);
20-
let ptx = compile_ptx(src).unwrap();
20+
let opts = CompileOptions {
21+
arch: Some(env!("CUDA_COMPUTE_CAP")),
22+
..Default::default()
23+
};
24+
let ptx = compile_ptx_with_opts(src, opts).unwrap();
2125
self.dev.load_ptx(ptx, &module, &["reshape_fwd"])?;
2226
}
2327
let fwd_fn = self.dev.get_func(&module, "reshape_fwd").unwrap();
@@ -56,7 +60,11 @@ impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
5660
let module = std::format!("reshape_bwd_{}", E::NAME);
5761
if !self.dev.has_func(&module, "reshape_bwd") {
5862
let src = BWD_KERNEL.replace("$T", E::NAME);
59-
let ptx = compile_ptx(src).unwrap();
63+
let opts = CompileOptions {
64+
arch: Some(env!("CUDA_COMPUTE_CAP")),
65+
..Default::default()
66+
};
67+
let ptx = compile_ptx_with_opts(src, opts).unwrap();
6068
self.dev.load_ptx(ptx, &module, &["reshape_bwd"])?;
6169
}
6270
let bwd_fn = self.dev.get_func(&module, "reshape_bwd").unwrap();

0 commit comments

Comments
 (0)