diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index f2e3c1e3ca..bda055231a 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -40,7 +40,7 @@ impl ElementWiseBuilder { impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_ir::OperationIr) { - self.builder.register(operation) + self.builder.register(operation); } fn build(&self) -> JitOptimization { diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index 0585df7674..73a39793fe 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -1147,63 +1147,135 @@ fn select_indices( _ => panic!("Indices tensor isn't an input"), }; - let stride_input = global_stride(inputs, dim, pos_input, precision_input); + let stride_input_dim = global_stride(inputs, dim, pos_input, precision_input); - let mut index = Line::empty(line_size_ref).fill(0); + let mut index = 0u32; + let mut result = Line::empty(line_size_ref); - if comptime![dim > 0] { - let index_before = global_offset( - inputs, - outputs, - write_pos, - comment!(input.clone()), - comptime![Some((0u32, dim))], - config, - ); - index += Line::new(index_before); - } + if comptime![dim != config.rank - 1] { + // In this scenario the select is actually broadcasted along the axis we're working on. + // + // Therefore the same indices are used to fetch multiple entries in the input tensor. - if comptime![dim + 1 < config.rank] { - let index_after = global_offset( + let write_pos_input = write_pos * line_size_ref; + let stride_input_line = global_stride( inputs, - outputs, - write_pos, - input, - comptime![Some((dim + 1, config.rank))], - config, + comptime![config.rank - 1], + pos_input, + precision_input, ); - index += Line::new(index_after); - } - let mut result = Line::empty(line_size_ref); + if comptime![dim > 0] { + let index_before = global_offset( + inputs, + outputs, + write_pos_input, + comment!(input.clone()), + comptime![Some((0u32, dim))], + config, + ); + index += index_before; + } - #[unroll] - for i in 0..line_size_ref { - let index_indices = ((write_pos * line_size_ref) + i) / stride_dim_ref % shape_dim_ref; + if comptime![dim + 1 < config.rank] { + let index_after = global_offset( + inputs, + outputs, + write_pos_input, + comment!(input.clone()), + comptime![Some((dim + 1, config.rank))], + config, + ); + index += index_after; + } + let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref; let offset_dim = read_input::( inputs, outputs, pos_indices, - index_indices, + coordinate_dim, LayoutInfo::IsRef, precision_indices, config, None, ); - let index = index[i] + offset_dim[0] * stride_input; - let input = read_input::( - inputs, - outputs, - pos_input, - index, - LayoutInfo::IsRef, - precision_input, - config, - None, - ); - result[i] = input[0]; + index *= line_size_ref; + index += offset_dim[0] * stride_input_dim; + + #[unroll] + for i in 0..line_size_ref { + let input = read_input::( + inputs, + outputs, + pos_input, + index + i * stride_input_line, + LayoutInfo::IsRef, + precision_input, + config, + None, + ); + result[i] = input[0]; + } + } else { + // In this scenario the select is actually performed on the last dimension we're working on. + // + // Therefore we need to fetch multiple indices that correspond to different entries in the + // input tensor. + + if comptime![dim > 0] { + let index_before = global_offset( + inputs, + outputs, + write_pos, + comment!(input.clone()), + comptime![Some((0u32, dim))], + config, + ); + index += index_before; + } + + if comptime![dim + 1 < config.rank] { + let index_after = global_offset( + inputs, + outputs, + write_pos, + input, + comptime![Some((dim + 1, config.rank))], + config, + ); + index += index_after; + } + + let write_pos_indices = write_pos * line_size_ref; + + #[unroll] + for i in 0..line_size_ref { + let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref; + let offset_dim = read_input::( + inputs, + outputs, + pos_indices, + coordinate_dim, + LayoutInfo::IsRef, + precision_indices, + config, + None, + ); + + let input = read_input::( + inputs, + outputs, + pos_input, + index + (offset_dim[0] * stride_input_dim), + LayoutInfo::IsRef, + precision_input, + config, + None, + ); + result[i] = input[0]; + } } write::(inputs, outputs, locals, write_pos, result, output, config); diff --git a/crates/burn-tensor/src/tests/ops/select.rs b/crates/burn-tensor/src/tests/ops/select.rs index 46f9aac958..b56fc15886 100644 --- a/crates/burn-tensor/src/tests/ops/select.rs +++ b/crates/burn-tensor/src/tests/ops/select.rs @@ -56,6 +56,19 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_select_2d_dim0_vec() { + let device = Default::default(); + let tensor = + TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], &device); + let indices = TestTensorInt::from_data([1, 0, 3, 2], &device); + + let output = tensor.select(0, indices); + let expected = TensorData::from([[2.0, 3.0], [0.0, 1.0], [6.0, 7.0], [4.0, 5.0]]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_select_2d_dim1() { let device = Default::default(); diff --git a/crates/burn-tensor/src/tests/ops/topk.rs b/crates/burn-tensor/src/tests/ops/topk.rs index 9d98926655..a133e5e205 100644 --- a/crates/burn-tensor/src/tests/ops/topk.rs +++ b/crates/burn-tensor/src/tests/ops/topk.rs @@ -43,8 +43,7 @@ mod tests { } #[test] - fn test_topk_with_indices() { - // 1D + fn test_topk_with_indices_1d() { let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); @@ -54,8 +53,10 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); + } - // 3D + #[test] + fn test_topk_with_indices_3d() { let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);