From 05d709949d5a579c005dcadb88ea320a438e5e02 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Tue, 12 Mar 2024 12:01:09 -0400 Subject: [PATCH] Fix dtypes --- candle-nn/src/layer_norm.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index c7ef0bf5fb..9f5030a866 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -31,7 +31,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use candle::{backend::BackendStorage, cuda_backend::{cudarc::driver::{LaunchAsync, LaunchConfig}, kernel_name, CudaDType}, DType, Device, Result, Storage, Tensor, D}; -//pub use candle_kernels as kernels; +pub use candle_kernels as kernels; #[derive(Debug, Clone, Copy, PartialEq)] pub struct LayerNormConfig { @@ -175,13 +175,13 @@ impl crate::Module for LayerNorm { block_dim: (K_CUDABLOCK_REDUCE_NUM_THREADS,1,1), shared_mem_bytes: 0, }; - let rowwisemoments = cuda_dev.get_or_load_func(&kernel_name::("rowwisemoments"), kernels::LAYERNORM)?; + let rowwisemoments = cuda_dev.get_or_load_func(&format!("rowwisemoments_{}", x.dtype().as_str()), kernels::LAYERNORM)?; let params = (n, self.eps, x_storage, mean_storage, rstd_storage); unsafe { rowwisemoments.launch(cfg_1, params) }; panic!("Done!"); - let layernorm = cuda_dev.get_or_load_func(&kernel_name::("layernorm"), kernels::LAYERNORM)?; + let layernorm = cuda_dev.get_or_load_func(&format!("layernorm_{}", x.dtype().as_str()), kernels::LAYERNORM)?; todo!() } }