Skip to content

Commit

Permalink
Removing the fences speeds everything up and *is* correct this time...
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 5, 2024
1 parent 7b43890 commit 9130b6c
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 71 deletions.
28 changes: 14 additions & 14 deletions candle-core/src/metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub struct MetalDevice {
/// execution order to be linear.
/// It could be relaxed in some circumstances by managing the dependencies in the
/// compute graph ourselves.
fence: metal::Fence,
// fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`], both fences need to match
kernels: Arc<candle_metal_kernels::Kernels>,
Expand Down Expand Up @@ -131,9 +131,9 @@ impl MetalDevice {
&self.device
}

pub(crate) fn fence(&self) -> &metal::Fence {
&self.fence
}
// pub(crate) fn fence(&self) -> &metal::Fence {
// &self.fence
// }

pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
Expand Down Expand Up @@ -225,10 +225,10 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?;
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
// blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
// blit.update_fence(&self.fence);
blit.end_encoding();

// This is necessary for mmapped safetensors
Expand All @@ -251,7 +251,7 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
// blit.wait_for_fence(&self.fence);
blit.fill_buffer(
&buffer,
metal::NSRange {
Expand All @@ -260,7 +260,7 @@ impl MetalDevice {
},
0,
);
blit.update_fence(&self.fence);
// blit.update_fence(&self.fence);
blit.end_encoding();
Ok(buffer)
}
Expand Down Expand Up @@ -1486,9 +1486,9 @@ impl MetalStorage {
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence);
// blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence);
// blit.update_fence(&self.device.fence);
blit.end_encoding();
}
self.device.wait_until_completed()?;
Expand All @@ -1506,16 +1506,16 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
// let fence = device.new_fence();
let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 20,
_ => 10,
};
Ok(Self {
device,
fence,
// fence,
command_queue,
command_buffer,
command_buffer_index,
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ impl QMetalStorage {
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence());
// blit.wait_for_fence(&self.device.fence());
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence());
// blit.update_fence(&self.device.fence());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ fn main() -> Result<()> {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
(Model::Quantized(model), device)
} else {
let dtype = if device.is_cuda() {
DType::BF16
Expand Down
Loading

0 comments on commit 9130b6c

Please sign in to comment.