Merge pull request #570 from robertknight/im2col-int8
Implement im2col packing for int8 GEMM
robertknight authored Feb 3, 2025
2 parents b5a6d1b + 0c2ba0c commit 352c986
Showing 11 changed files with 358 additions and 53 deletions.
1 change: 1 addition & 0 deletions rten-simd/src/vec.rs
@@ -71,6 +71,7 @@ pub trait Simd: Copy + Sized {
type Array: Copy
+ std::fmt::Debug
+ std::ops::Index<usize, Output = Self::Elem>
+ std::ops::IndexMut<usize, Output = Self::Elem>
+ AsRef<[Self::Elem]>;

/// Combine elements of `self` and `rhs` according to a mask.
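The new `IndexMut` bound matters for the int8 packing code later in this diff, which mutates individual lanes of a SIMD value's array form. A minimal sketch of the pattern (using the trait's existing `zero` and `to_array` methods):

    // `S::Array` must support per-lane writes, hence the IndexMut bound.
    let mut col_sums = S::zero().to_array();
    col_sums[0] += 1;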
92 changes: 84 additions & 8 deletions src/gemm.rs
@@ -381,7 +381,14 @@ impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT> GemmExecutor<LhsT, RhsT, OutT
///
/// The number of columns in [`ColOffsets`] must be a multiple of this.
pub fn im2col_col_count_step(&self) -> usize {
- self.kernel.nr()
+ self.kernel.im2col_col_count_step()
}

/// Return row count step for building an [`Im2Col`] input.
///
/// The number of rows in [`RowOffsets`] must be a multiple of this.
pub fn im2col_row_count_step(&self) -> usize {
self.kernel.im2col_row_count_step()
}
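// Illustration (not part of the diff): callers pad the im2col offset
// arrays using these two step sizes, as the `build_im2col` test helper
// later in this diff does:
//
//     let n_cols_padded = n_cols.next_multiple_of(gemm.im2col_col_count_step());
//     let n_rows_padded = n_rows.next_multiple_of(gemm.im2col_row_count_step());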

/// Prepack a matrix for use as the right-hand or "B" matrix input.
@@ -2186,21 +2193,34 @@ mod tests {
// This builds a mapping between elements of an image and a
// `[chans, height x width]` matrix where `image[c, y, x]` maps to
// `im2col_matrix[c, y * width + x]`.
- fn build_im2col(image: NdTensorView<f32, 3>, col_count_step: usize) -> Im2Col<f32> {
+ fn build_im2col<T: Copy>(
+ image: NdTensorView<T, 3>,
+ col_count_step: usize,
+ row_count_step: usize,
+ ) -> Im2Col<T> {
let [chans, img_h, img_w] = image.shape();
let [chan_stride, h_stride, w_stride] = image.strides();

- let rows = chans;
let n_cols = img_w * img_h;
let n_cols_padded = n_cols.next_multiple_of(col_count_step);

- let row_offsets = RowOffsets {
+ let rows = chans;
+ let n_rows_padded = rows.next_multiple_of(row_count_step);
+
+ let mut row_offsets = RowOffsets {
chan: (0..rows as i32)
.map(|chan| chan * chan_stride as i32)
.collect(),
y: vec![0; rows],
x: vec![0; rows],
};

for _ in rows..n_rows_padded {
row_offsets.chan.push(i32::MAX);
row_offsets.x.push(i32::MAX);
row_offsets.y.push(i32::MAX);
}

let mut col_offsets = ColOffsets {
y: (0..n_cols)
.map(|i| i as i32 / img_w as i32)
@@ -2212,8 +2232,8 @@ mod tests {
.collect(),
};
for _ in n_cols..n_cols_padded {
- col_offsets.y.push(0);
- col_offsets.x.push(0);
+ col_offsets.y.push(i32::MAX);
+ col_offsets.x.push(i32::MAX);
}
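// nb. `i32::MAX` is a sentinel: padded entries fail the packer's
// `pad_mask` range checks, so they are read as zeros rather than as
// out-of-bounds image elements.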

let max_y_offset = (img_h - 1) * h_stride;
@@ -2224,13 +2244,14 @@ mod tests {
row_offsets,
col_offsets,
n_cols,
n_rows: rows,
max_y_offset: max_y_offset as i32,
max_x_offset: max_x_offset as i32,
}
}

#[test]
- fn test_gemm_im2col() -> Result<(), Box<dyn Error>> {
+ fn test_gemm_im2col_f32() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let gemm = GemmExecutor::default();

@@ -2241,7 +2262,11 @@ mod tests {
let kernel_chans = 3;

let img = NdTensor::<f32, 3>::rand([img_chans, img_h, img_w], &mut rng);
- let im2col = build_im2col(img.view(), gemm.im2col_col_count_step());
+ let im2col = build_im2col(
+ img.view(),
+ gemm.im2col_col_count_step(),
+ gemm.im2col_row_count_step(),
+ );

let kernel_mat = NdTensor::<f32, 2>::rand([kernel_chans, img_chans], &mut rng);
let mut output_mat = NdTensor::<f32, 2>::zeros([kernel_chans, img_h * img_w]);
@@ -2275,6 +2300,57 @@ mod tests {
Ok(())
}

#[test]
fn test_gemm_im2col_u8i8_i32() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);

// nb. If the test fails, set the dimensions to 1 to make debugging easier.
let img_h = 2;
let img_w = 2;
let img_chans = 2;
let kernel_chans = 3;

let img = NdTensor::<i8, 3>::rand([img_chans, img_h, img_w], &mut rng);

for gemm in all_gemms() {
let im2col = build_im2col(
img.view(),
gemm.im2col_col_count_step(),
gemm.im2col_row_count_step(),
);
let kernel_mat = NdTensor::<u8, 2>::rand([kernel_chans, img_chans], &mut rng);
let mut output_mat = NdTensor::<i32, 2>::zeros([kernel_chans, img_h * img_w]);
let out_row_stride = output_mat.row_stride();

gemm.gemm(
output_mat.data_mut().unwrap(),
out_row_stride,
GemmInputA::Unpacked(kernel_mat.view()),
GemmInputB::Im2Col(&im2col),
1., // alpha
0, // beta
None, // bias
None, // a_quant
None, // b_quant
)
.unwrap();

let mut expected = NdTensor::<i32, 2>::zeros([kernel_chans, im2col.cols()]);
for i in 0..expected.rows() {
for j in 0..expected.cols() {
let mut acc = 0;
for k in 0..kernel_mat.cols() {
acc += kernel_mat[[i, k]] as i32 * img[[k, j / img_w, j % img_w]] as i32;
}
expected[[i, j]] = acc;
}
}
expect_equal(&output_mat, &expected)?;
}

Ok(())
}

#[test]
fn test_gemv() -> Result<(), Box<dyn Error>> {
enum Strides {
152 changes: 147 additions & 5 deletions src/gemm/im2col.rs
@@ -2,9 +2,11 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use rten_simd::{SimdInt, SimdMask};

use rten_tensor::{NdTensorView, Storage};

use super::packing::int8::shift_cast_i8_u8;
use crate::slice_cast::cast_pod_mut_slice;

/// Maps rows of an [`Im2Col`] matrix to locations in the source image.
///
/// For efficiency when packing the image, the locations are premultiplied by
Expand Down Expand Up @@ -45,17 +47,26 @@ pub struct Im2Col<'a, T> {

/// Map of im2col row index to input image coordinate, premultiplied with
/// the corresponding stride.
///
/// The arrays may be padded to a multiple of a step size specified by the
/// GEMM kernel. `n_rows` contains the actual number of rows in the virtual
/// matrix.
pub row_offsets: RowOffsets,

/// Map of im2col column index to input image coordinate, premultiplied with
- /// the corresponding stride. The length of arrays in `col_offsets` is
- /// rounded up to the nearest multiple of the panel width. `n_cols` contains
- /// the actual number of columns in the virtual matrix.
+ /// the corresponding stride.
+ ///
+ /// The arrays may be padded to a multiple of a step size specified by the
+ /// GEMM kernel. `n_cols` contains the actual number of columns in the
+ /// virtual matrix.
pub col_offsets: ColOffsets,

/// Number of columns in the im2col matrix.
pub n_cols: usize,

/// Number of rows in the im2col matrix.
pub n_rows: usize,

/// Maximum valid sum of `row_offsets.y + col_offsets.y`. Values above this
/// correspond to the padding region.
pub max_y_offset: i32,
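// Illustration (not part of the diff): for a contiguous
// `[chans=2, h=2, w=2]` image the strides are `[4, 2, 1]`, so element
// `(c, y, x)` lives at flat offset `c*4 + y*2 + x`. The packer reads
// element `(r, j)` of the virtual matrix from
//
//     image_data[row_offsets.chan[r]
//         + (row_offsets.y[r] + col_offsets.y[j])
//         + (row_offsets.x[r] + col_offsets.x[j])]
//
// which is why every coordinate is stored premultiplied by its stride.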
@@ -68,7 +79,7 @@
impl<T: Copy + Default> Im2Col<'_, T> {
/// Return the number of rows in the im2col matrix.
pub fn rows(&self) -> usize {
- self.row_offsets.chan.len()
+ self.n_rows
}

/// Return the number of columns in the im2col matrix.
@@ -78,6 +89,9 @@ impl<T: Copy + Default> Im2Col<'_, T> {

/// Pack part of an image into a packing buffer.
///
/// This method is for use by kernels using the "standard" packing buffer
/// layout for the B / RHS input.
///
/// `NR_REGS` specifies the width of each column panel as a multiple of
/// `S::LEN` elements. In other words, `panel_width` must exactly equal
/// `NR_REGS * S::LEN`.
@@ -188,3 +202,131 @@ impl<T: Copy + Default> Im2Col<'_, T> {
assert_eq!(out_offset, used_size);
}
}
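// For reference (not part of the diff): the "standard" layout above packs
// B into column panels of width `NR_REGS * S::LEN`; within a panel, the
// elements of each row are stored contiguously before moving to the next
// row. The int8 layout implemented below instead interleaves K_TILE = 4
// rows per column and appends per-column i32 sums after each panel.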

impl Im2Col<'_, i8> {
/// Pack part of an image into a packing buffer.
///
/// This method is for use by kernels using int8 dot product instructions
/// to compute `S::LEN x i32` dot products from two input vectors each
/// containing `S::LEN x 4 x i8` (or u8) inputs.
#[inline(always)]
#[allow(unused)] // Only used on some architectures
pub(super) unsafe fn pack_block_i8_dot<S: SimdInt>(
&self,
out: &mut [MaybeUninit<i8>],
rows: Range<usize>,
cols: Range<usize>,
) {
self.pack_block_int8::<S, false>(out, rows, cols);
}
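// Illustration (assumed semantics, matching e.g. x86 VNNI's vpdpbusd or
// Arm's sdot/udot): each i32 lane of the dot product accumulates four
// adjacent byte products,
//
//     fn dot_u8i8x4(a: [u8; 4], b: [i8; 4], acc: i32) -> i32 {
//         acc + a.iter().zip(b).map(|(&x, y)| x as i32 * y as i32).sum::<i32>()
//     }
//
// which is why the packer stores K_TILE = 4 consecutive rows of each
// column before moving to the next column.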

/// Variant of [`pack_block_i8_dot`](Self::pack_block_i8_dot) which shifts
/// i8 values to u8 by adding 128.
#[inline(always)]
#[allow(unused)] // Only used on some architectures
pub(super) unsafe fn pack_block_i8_dot_cast_u8<S: SimdInt>(
&self,
out: &mut [MaybeUninit<u8>],
rows: Range<usize>,
cols: Range<usize>,
) {
let out = cast_pod_mut_slice(out).unwrap();
self.pack_block_int8::<S, true>(out, rows, cols);
}
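// A plausible scalar implementation of `shift_cast_i8_u8` (an assumption;
// the real helper is imported from `super::packing::int8`):
//
//     fn shift_cast_i8_u8(x: i8) -> u8 {
//         (x as u8) ^ 0x80 // wrapping +128 == flipping the sign bit
//     }
//
// The shift lets kernels with unsigned-by-signed dot instructions consume
// i8 data as u8; the +128 bias is cancelled later using the column sums.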

#[inline(always)]
unsafe fn pack_block_int8<S: SimdInt, const CAST_B_U8: bool>(
&self,
out: &mut [MaybeUninit<i8>],
rows: Range<usize>,
cols: Range<usize>,
) {
const K_TILE: usize = size_of::<i32>() / size_of::<i8>();

debug_assert!(rows.end <= self.rows());
debug_assert!(cols.end <= self.cols());

let max_x_offset = S::splat(self.max_x_offset);
let max_y_offset = S::splat(self.max_y_offset);

let col_x_offsets = &self.col_offsets.x;
debug_assert_eq!(col_x_offsets.len() % S::LEN, 0);

let col_y_offsets = &self.col_offsets.y;
debug_assert_eq!(col_y_offsets.len() % S::LEN, 0);

let row_x_offsets = &self.row_offsets.x;
debug_assert_eq!(row_x_offsets.len() % K_TILE, 0);

let row_y_offsets = &self.row_offsets.y;
debug_assert_eq!(row_y_offsets.len() % K_TILE, 0);

let row_chan_offsets = &self.row_offsets.chan;
debug_assert_eq!(row_chan_offsets.len() % K_TILE, 0);

let img_ptr = self.image.storage().as_ptr();
let out_ptr = out.as_mut_ptr();

let mut out_offset = 0;

for start_col in cols.step_by(S::LEN) {
let col_y_offset = S::load(col_y_offsets.get_unchecked(start_col));
let col_x_offset = S::load(col_x_offsets.get_unchecked(start_col));
let zero = S::zero();

let mut col_sums = S::zero().to_array();

for start_row in rows.clone().step_by(4) {
for i in 0..K_TILE {
let k = start_row + i;
let row_x_offset = S::splat(*row_x_offsets.get_unchecked(k));
let row_y_offset = S::splat(*row_y_offsets.get_unchecked(k));
let row_chan_offset = S::splat(*row_chan_offsets.get_unchecked(k));

let x_offsets = row_x_offset.add(col_x_offset);
let y_offsets = row_y_offset.add(col_y_offset);
let offsets = x_offsets.add(y_offsets).add(row_chan_offset);

let pad_mask = y_offsets
.ge(zero)
.and(y_offsets.le(max_y_offset))
.and(x_offsets.ge(zero))
.and(x_offsets.le(max_x_offset));
let pad_mask_array = pad_mask.to_array();

// Set offsets to zero for padding elements; offset zero is
// required to always be valid to read.
let offsets_array = zero.blend(offsets, pad_mask).to_array();

for idx in 0..S::LEN {
let out_ptr = out_ptr.add(out_offset + idx * K_TILE + i);
let src_elem = *img_ptr.add(offsets_array[idx] as usize);

// This should be compiled to a conditional move.
let elem = if pad_mask_array[idx] { src_elem } else { 0 };

if CAST_B_U8 {
let elem = shift_cast_i8_u8(elem);
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem as i8));
} else {
col_sums[idx] += elem as i32;
out_ptr.write(MaybeUninit::new(elem));
}
}
}
out_offset += S::LEN * K_TILE;
}

// Store column sums at end of each panel.
let col_sum_ptr = out_ptr.add(out_offset) as *mut i32;
for i in 0..S::LEN {
*col_sum_ptr.add(i) = col_sums[i];
}
out_offset += S::LEN * K_TILE;
}

// Sanity check
assert_eq!(out_offset, out.len());
}
}
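The column sums appended to each panel are what make zero-point handling cheap. A minimal sketch of the identity they enable (illustrative names, not from this diff):

    // Correct a raw u8 x u8 dot product for the inputs' zero points.
    fn quantized_dot(a: &[u8], b: &[u8], a_zero: i32, b_zero: i32) -> i32 {
        let raw: i32 = a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum();
        let row_sum: i32 = a.iter().map(|&x| x as i32).sum();
        let col_sum: i32 = b.iter().map(|&y| y as i32).sum();
        let k = a.len() as i32;
        // Equals sum((a[i] - a_zero) * (b[i] - b_zero)) over i.
        raw - a_zero * col_sum - b_zero * row_sum + k * a_zero * b_zero
    }

Storing `col_sum` alongside the packed bytes lets the kernel apply this correction without re-reading the panel.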
22 changes: 22 additions & 0 deletions src/gemm/kernels.rs
@@ -148,10 +148,32 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
/// Return a name for this kernel for use in logging etc.
fn name(&self) -> &'static str;

/// Return true if this kernel may encounter saturation in a data type that
/// is smaller than the accumulator.
///
/// The caller will have to prepare inputs (usually the weights) to avoid
/// this. This is primarily an issue for x64 systems without VNNI.
/// See <https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html>.
fn may_saturate(&self) -> bool {
false
}
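// Worked example of the hazard (cf. the oneDNN guide above): without
// VNNI, x64 kernels accumulate u8 x i8 products pairwise into a
// *saturating* i16 (pmaddubsw):
//
//     255 * 127 + 255 * 127 = 64_770 > i16::MAX = 32_767
//
// so the intermediate clamps to 32_767, silently corrupting the result
// unless the caller preconditions the weights (e.g. restricts them to a
// 7-bit range).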

/// Step size used when packing an image using [`pack_im2col`](Kernel::pack_im2col).
///
/// The length of the offset arrays in [`Im2Col::row_offsets`] must be a
/// multiple of this.
fn im2col_row_count_step(&self) -> usize {
1
}

/// Step size used when packing an image using [`pack_im2col`](Kernel::pack_im2col).
///
/// The length of the offset arrays in [`Im2Col::col_offsets`] must be a
/// multiple of this.
fn im2col_col_count_step(&self) -> usize {
self.nr()
}
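// A hypothetical int8 kernel using 4-byte dot-product tiles would
// override the row step to match its packing (an assumption consistent
// with K_TILE = 4 in the int8 im2col packer):
//
//     fn im2col_row_count_step(&self) -> usize {
//         4
//     }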

/// Return the layout of a packing buffer required to pack an A / LHS input.
fn packed_a_layout(
&self,