diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/wgpu.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/wgpu.rs
index 30f1bde..990dc24 100644
--- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/wgpu.rs
+++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/wgpu.rs
@@ -9,7 +9,7 @@ use std::fmt::Formatter;
 use tracing::trace;
 use wgpu::{self, util::DeviceExt};
 
-/// Struct responsible for performing matrix multiplication on the GPU.
+/// Matrix multiplication on the GPU using `wgpu`.
 pub struct MatrixMultiplier {
     device: wgpu::Device,
     queue: wgpu::Queue,
diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/lib.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/lib.rs
index 192b2dc..7e3f0ad 100644
--- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/lib.rs
+++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/lib.rs
@@ -8,6 +8,7 @@ use std::future::Future;
 mod backends;
 pub mod variants;
 
+/// The trait that defines how to multiply two matrices.
 pub trait MatrixMultiply<T>: Display {
     fn new(variant: T) -> impl Future<Output = Self> + Send;
     fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32>;
@@ -25,7 +26,7 @@ pub trait Cpu {
     );
 }
 
-/// Matrix multiplication logic that can be run on the CPU.
+/// Matrix multiplication logic that can be run on the GPU.
 pub trait Gpu {
     fn compiled_shader(&self) -> &[u8];
     fn entry_point(&self) -> &'static str {