
Commit

Merge branch 'main' into fix/fuse-select
nathanielsimard committed Feb 12, 2025
2 parents 0dd4f44 + e23c8ef commit 49fa46e
Showing 15 changed files with 325 additions and 13 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -192,7 +192,7 @@ represent the corresponding Burn Op.
| [TfIdfVectorizer][183] |||
| [ThresholdedRelu][184] |||
| [Tile][185] |||
| [TopK][186] | ❌ | ❌ |
| [TopK][186] | ✅ | ✅ |
| [Transpose][187] |||
| [Trilu][188] |||
| [Unique][189] |||
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -112,6 +112,7 @@ fn main() {
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/top_k/top_k_opset_1.onnx")
.input("tests/trilu/trilu_upper.onnx")
.input("tests/trilu/trilu_lower.onnx")
.input("tests/transpose/transpose.onnx")
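For orientation, each `.input(...)` call in this build script registers an ONNX fixture with burn-import's build-time code generator, which turns the file into a generated Rust model module used by the tests. A minimal standalone sketch of that pattern, assuming `burn-import` (with its `onnx` feature) as a build-dependency; the path and output directory are illustrative rather than this crate's exact configuration:

```rust
// build.rs: minimal sketch of registering an ONNX file with burn-import.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        // ONNX fixture to convert; the generated Rust module is named after the file stem.
        .input("tests/top_k/top_k_opset_1.onnx")
        // Directory (under OUT_DIR) where the generated source is written.
        .out_dir("model/")
        .run_from_script();
}
```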
28 changes: 27 additions & 1 deletion crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -121,6 +121,7 @@ include_models!(
sum_int,
tanh,
tile,
top_k_opset_1,
trilu_upper,
trilu_lower,
transpose,
@@ -135,7 +136,7 @@ mod tests {

use super::*;

use burn::tensor::{Bool, Int, Shape, Tensor, TensorData};
use burn::tensor::{cast::ToElement, Bool, Int, Shape, Tensor, TensorData};

use float_cmp::ApproxEq;

@@ -2218,6 +2219,31 @@ mod tests {
}

    #[test]
    fn top_k_opset_1() {
        // Initialize the model
        let device = Default::default();
        let model = top_k_opset_1::Model::<Backend>::new(&device);

        // Run the model
        let input = Tensor::<Backend, 2>::from_floats(
            [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
            &device,
        );
        let (values_tensor, indices_tensor) = model.forward(input);

        // expected results
        let expected_values_tensor =
            TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]);
        let expected_indices_tensor = TensorData::from([[3i64, 2, 1], [3, 2, 1]]);

        values_tensor
            .to_data()
            .assert_eq(&expected_values_tensor, true);
        indices_tensor
            .to_data()
            .assert_eq(&expected_indices_tensor, true);
    }

fn one_hot() {
// Test for OneHot model

76 changes: 76 additions & 0 deletions crates/burn-import/onnx-tests/tests/top_k/top_k.py
@@ -0,0 +1,76 @@
import numpy as np
import onnx
from onnx import helper, TensorProto

# Define the input tensor
X = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]], dtype=np.float32)

# Define the value of K
k = 3
K = np.array([k], dtype=np.int64)
axis = 1
new_dims = [X.shape[0], k]

def create_model(op_set_version: int):
    input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)]

    output_tensors = [
        helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims),
        helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims)
    ]

    # Create the TopK node
    if op_set_version > 1:
        node = helper.make_node(
            'TopK',
            inputs=['X', 'K'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
        )
        input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT32, K.shape))
    else:
        node = helper.make_node(
            'TopK',
            inputs=['X'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
            k=k
        )

    # Create the graph
    graph = helper.make_graph(
        nodes = [node],
        name = 'TopKGraph',
        inputs = input_tensors,
        outputs = output_tensors,
        # Uncomment when initializers are supported. Currently we can't test opset 10/11 since the code will require a k value to be initialized for testing.
        #initializer = [
        #    helper.make_tensor('X', TensorProto.FLOAT, X.shape, X),
        #    helper.make_tensor('K', TensorProto.INT64, [1], [k]),
        #]
    )

    # Create the model
    model = helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)]
    )
    # Check the model
    onnx.checker.check_model(model)

    # Save the model to a file
    onnx.save(model, f'top_k_opset_{op_set_version}.onnx')
    print(f"Model saved to top_k_opset_{op_set_version}.onnx")

def main():
    # Uncomment when initializers are supported.
    # for op_set_version in [1, 10, 11]:
    for op_set_version in [1]:
        create_model(op_set_version)


if __name__ == "__main__":
    main()
Binary file not shown.
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -14,7 +14,7 @@ use super::{
random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode,
random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode,
resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode,
trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::record::PrecisionSettings;
@@ -119,6 +119,7 @@ pub enum Node<PS: PrecisionSettings> {
Squeeze(SqueezeNode),
Sum(SumNode),
Tile(TileNode),
TopK(TopKNode),
Trilu(TriluNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
@@ -173,6 +174,7 @@ macro_rules! match_all {
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Tile(node) => $func(node),
Node::TopK(node) => $func(node),
Node::Trilu(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
@@ -235,6 +237,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Tile(_) => "tile",
Node::TopK(_) => "top_k",
Node::Trilu(_) => "trilu",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -42,6 +42,7 @@ pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod top_k;
pub(crate) mod trilu;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
112 changes: 112 additions & 0 deletions crates/burn-import/src/burn/node/top_k.rs
@@ -0,0 +1,112 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};

#[derive(Config, Debug)]
pub struct TopKConfig {
    pub axis: usize,
    pub k: usize,
}

#[derive(Debug, Clone, new)]
pub struct TopKNode {
    pub input: TensorType,
    pub outputs: Vec<TensorType>,
    pub config: TopKConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode {
    fn output_types(&self) -> Vec<Type> {
        self.outputs
            .iter()
            .map(|t| Type::Tensor(t.clone()))
            .collect()
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let axis = self.config.axis.to_token_stream();
        let k = self.config.k.to_token_stream();

        let input = scope.tensor_use_owned(&self.input, node_position);
        let values_output = &self.outputs[0].name;
        let indices_output = &self.outputs[1].name;

        quote! {
            let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::TopK(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{test::assert_tokens, top_k::TopKNode},
        TensorType,
    };

    #[test]
    fn test_codegen_nodes() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TopKConfig::new(1, 3);

        graph.register(TopKNode::new(
            TensorType::new_float("input_tensor", 4),
            vec![
                TensorType::new_float("values_tensor", 4),
                TensorType::new_int("indices_tensor", 4),
            ],
            config,
        ));

        graph.register_input_output(
            vec!["input_tensor".to_string()],
            vec!["values_tensor".to_string(), "indices_tensor".to_string()],
        );

        let expected = quote! {
            use burn::tensor::Int;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
                    let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize);
                    (values_tensor, indices_tensor)
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
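The generated `forward` bottoms out in Burn's `Tensor::topk_with_indices`, which returns the top `k` values and their integer indices along the given dimension. A minimal sketch of calling that API directly, assuming the `ndarray` backend feature is enabled; the data mirrors the ONNX test above:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let input = Tensor::<NdArray, 2>::from_floats(
        [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
        &device,
    );

    // Top 3 values along dimension 1, largest first, together with their indices.
    let (values, indices) = input.topk_with_indices(3, 1);

    // Expected: values  == [[4.0, 3.0, 2.0], [4.0, 3.0, 2.0]]
    //           indices == [[3, 2, 1], [3, 2, 1]]
    println!("{values}");
    println!("{indices}");
}
```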
49 changes: 48 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -9,7 +9,7 @@ use burn::nn::{
};

use crate::burn::node::{
expand::ExpandShape, pad::PadConfig, tile::TileConfig, trilu::TriluConfig,
expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig, trilu::TriluConfig,
};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};

@@ -900,6 +900,53 @@ pub fn tile_config(node: &Node) -> TileConfig {
TileConfig::new(repeat)
}

/// Create a TopKConfig from the attributes of the node.
pub fn top_k_config(node: &Node) -> TopKConfig {
    // extract the shape of the input data tensor
    let data_tensor = match node.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    let k = match node.inputs.get(1) {
        Some(k_tensor) => k_tensor
            .clone()
            .value
            .expect("TopK: only constant 'k' tensor is currently supported")
            .into_i64s()[0],
        _ => node
            .attrs
            .get("k")
            .expect("TopK: number of top elements 'k' is missing")
            .clone()
            .into_i64(),
    };

    let mut axis = match node.attrs.get("axis") {
        Some(axis) => axis.clone().into_i64(),
        None => -1,
    };

    // if axis is negative, it is counted from the end
    if axis < 0 {
        axis += data_tensor.dim as i64;
    }

    if let Some(largest) = node.attrs.get("largest") {
        if largest.clone().into_i64() != 1 {
            unimplemented!("TopK: only largest elements is supported")
        }
    };

    if let Some(sorted) = node.attrs.get("sorted") {
        if sorted.clone().into_i64() != 1 {
            unimplemented!("TopK: only sorted elements is supported")
        }
    };

    TopKConfig::new(axis as usize, k as usize)
}

/// Create a TriluConfig from the attributes of the node
pub fn trilu_config(node: &Node) -> TriluConfig {
let mut upper = true;
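As a worked example of the axis handling in `top_k_config` above: ONNX `TopK` defaults to `axis = -1`, and a negative axis is resolved against the input rank before the config is built. A small illustrative sketch (not code from this commit):

```rust
// Resolve a possibly negative ONNX axis into a concrete dimension index,
// mirroring the normalization performed in `top_k_config`.
fn normalize_axis(axis: i64, rank: usize) -> usize {
    if axis < 0 {
        (axis + rank as i64) as usize
    } else {
        axis as usize
    }
}

fn main() {
    // For a rank-2 input, the ONNX default axis of -1 maps to dimension 1.
    assert_eq!(normalize_axis(-1, 2), 1);
    assert_eq!(normalize_axis(1, 2), 1);
}
```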