Skip to content

Commit

Permalink
refactor(dtype): deduplicate type conversion patterns in as_ macro
Browse files Browse the repository at this point in the history
  • Loading branch information
haricot committed Feb 15, 2025
1 parent 9996e5a commit 7785b3b
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions candle-core/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,25 @@ pub trait WithDType:
fn cpu_storage_as(s: &CpuStorage, layout: &crate::Layout, dtype: DType) -> Result<CpuStorage>;
}

macro_rules! as_ {
(U8, U8, $v:expr) => {$v};
(U32, U32, $v:expr) => {$v};
(I64, I64, $v:expr) => {$v};
(F32, F32, $v:expr) => {$v};
(F64, F64, $v:expr) => {$v};
(BF16, BF16, $v:expr) => {$v};
(F16, F16, $v:expr) => {$v};
($in:expr, U8, $v:expr) => { num_traits::AsPrimitive::<u8>::as_($v)};
($in:expr, U32, $v:expr) => { num_traits::AsPrimitive::<u32>::as_($v)};
($in:expr, I64, $v:expr) => { num_traits::AsPrimitive::<i64>::as_($v)};
($in:expr, F32, $v:expr) => { num_traits::AsPrimitive::<f32>::as_($v)};
($in:expr, F64, $v:expr) => { num_traits::AsPrimitive::<f64>::as_($v)};
($in:expr, BF16, $v:expr) => { num_traits::AsPrimitive::<bf16>::as_($v)};
($in:expr, F16, $v:expr) => { num_traits::AsPrimitive::<f16>::as_($v)};
}

macro_rules! cpu_storage_as {
(match:($cpu_storage: expr, $match_dtype: ident), $layout: ident, $with_dtype: ident, ($($dtype: ident),+)) => {{
macro_rules! as_ {
(U8, U8, $v:expr) => {$v};
(U32, U32, $v:expr) => {$v};
(I64, I64, $v:expr) => {$v};
(F32, F32, $v:expr) => {$v};
(F64, F64, $v:expr) => {$v};
(BF16, BF16, $v:expr) => {$v};
(F16, F16, $v:expr) => {$v};
($in:expr, U8, $v:expr) => { num_traits::AsPrimitive::<u8>::as_($v)};
($in:expr, U32, $v:expr) => { num_traits::AsPrimitive::<u32>::as_($v)};
($in:expr, I64, $v:expr) => { num_traits::AsPrimitive::<i64>::as_($v)};
($in:expr, F32, $v:expr) => { num_traits::AsPrimitive::<f32>::as_($v)};
($in:expr, F64, $v:expr) => { num_traits::AsPrimitive::<f64>::as_($v)};
($in:expr, BF16, $v:expr) => { num_traits::AsPrimitive::<bf16>::as_($v)};
($in:expr, F16, $v:expr) => { num_traits::AsPrimitive::<f16>::as_($v)};
}

match ($cpu_storage, $match_dtype) {
$((CpuStorage::$with_dtype(storage), DType::$dtype) => {
Ok({ let data = crate::cpu_backend::unary_map(&storage, $layout,
Expand Down

0 comments on commit 7785b3b

Please sign in to comment.