From 57f1128aa983d76eb54eb1f4eeb2ca362fc33129 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 12 Feb 2025 13:40:33 -0500 Subject: [PATCH] Fix stage tiling errors --- .../src/kernel/conv/conv2d/gemm/homogeneous/base.rs | 1 + .../burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs | 8 ++++---- .../burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index 6976fbd68e..1d743fce57 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -274,6 +274,7 @@ pub(crate) fn implicit_conv< #[comptime] config: GMM::Config, #[comptime] has_bias: bool, ) { + // num_elements_x_dim / num_elements_y_dim let x_offset = CUBE_POS_X * config.stage_tiling(Ident::Lhs).total_row(); let y_offset = CUBE_POS_Y * config.stage_tiling(Ident::Rhs).total_col(); let k_range = (0, rhs.shape(0)); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs index ea43765281..56a6ac7752 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -121,13 +121,13 @@ impl SimpleIm2col { let (tile_x, tile_y) = match config.tiling_order(ident) { TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y( nth_tile, - stage_tiling.total_row(), - stage_tiling.total_col(), + stage_tiling.tile_count_row(), + stage_tiling.tile_count_col(), ), TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y( nth_tile, - stage_tiling.total_row(), - stage_tiling.total_col(), + stage_tiling.tile_count_row(), + stage_tiling.tile_count_col(), ), }; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs index 19d2bebfc2..ea89b3a165 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs @@ -98,8 +98,8 @@ impl Im2colReader { #[comptime] config: G, ) -> Line { let line_size = config.global_line_size(ident); - let tile_size_x = config.stage_tiling(ident).total_row(); - let tile_size_y = config.stage_tiling(ident).total_col(); + let tile_size_x = config.stage_tiling(ident).tile_shape_row(); + let tile_size_y = config.stage_tiling(ident).tile_shape_col(); let view_tile_m = tile_x * tile_size_x + self.m_offset; let view_tile_k = tile_y * tile_size_y + self.k_offset;