From ee1f9fbc8a901dcab3d9d73d5b5d012b555ba6d2 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 11 Mar 2024 17:20:18 -0400 Subject: [PATCH] Make block unsafe --- candle-core/src/cuda_backend.rs | 58 +++++++++++++++++---------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index d33eb6e9d0..ea17fb3adf 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -372,34 +372,36 @@ impl BackendDevice for CudaDevice { self.const_impl(v as f64, shape, dtype); } let elem_count = shape.elem_count(); - let slice = match dtype { - DType::U8 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F64(data) + let slice = unsafe { + match dtype { + DType::U8 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F64(data) + } } }; Ok(CudaStorage {