diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index b7df84404..e273b6c39 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -120,7 +120,11 @@ pub trait WithDType: } macro_rules! cpu_storage_as { - (match:($cpu_storage:expr, $dtype:tt), $layout:ident, $(($in_dtype:ident,$out_dtype:ident)),*) => {{ + (@inner $in_dtype:ident, ($($out_ty:ident),+)) => { + $(($in_dtype, $out_ty)),+ + }; + + (match:($cpu_storage:expr, $dtype:ident), $layout:ident, $in_dtype:ident, ($($out_ty:ident),+)) => {{ macro_rules! as_ { (U8, U8, $v:expr) => {$v}; (U32, U32, $v:expr) => {$v}; @@ -137,14 +141,13 @@ macro_rules! cpu_storage_as { ($in:expr, BF16, $v:expr) => { num_traits::AsPrimitive::::as_($v)}; ($in:expr, F16, $v:expr) => { num_traits::AsPrimitive::::as_($v)}; } + match ($cpu_storage, $dtype) { - $( - (CpuStorage::$in_dtype(storage), DType::$out_dtype) => { - Ok({ let data = crate::cpu_backend::unary_map(&storage, $layout, - |v| as_!($in_dtype, $out_dtype, v)); - CpuStorage::$out_dtype(data) - })}, - )* + $((CpuStorage::$in_dtype(storage), DType::$out_ty) => { + Ok({ let data = crate::cpu_backend::unary_map(&storage, $layout, + |v| as_!($in_dtype, $out_ty, v)); + CpuStorage::$out_ty(data) + })}),+, _ => Err(Error::UnexpectedDType { expected: $dtype, got: $cpu_storage.dtype(), @@ -202,8 +205,7 @@ macro_rules! with_dtype { #[inline] fn cpu_storage_as(s: &CpuStorage, layout: &crate::Layout, dtype: DType) -> Result { - cpu_storage_as!(match:(s, dtype), layout, - ($dtype,U8),($dtype,U32),($dtype,I64),($dtype,F16),($dtype,BF16),($dtype,F32),($dtype,F64)) + cpu_storage_as!(match:(s, dtype), layout, $dtype, (U8, U32, I64, F16, BF16, F32, F64)) } } };