-
Notifications
You must be signed in to change notification settings - Fork 504
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fix/fuse-select
- Loading branch information
Showing
15 changed files
with
325 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
import onnx | ||
from onnx import helper, TensorProto | ||
|
||
# Define the input tensor | ||
X = np.array([[0, 1, 2, 3], | ||
[4, 5, 6, 7], | ||
[8, 9, 10, 11]], dtype=np.float32) | ||
|
||
# Define the value of K | ||
k = 3 | ||
K = np.array([k], dtype=np.int64) | ||
axis = 1 | ||
new_dims = [X.shape[0], k] | ||
|
||
def create_model(op_set_version: int): | ||
input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)] | ||
|
||
output_tensors = [ | ||
helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims), | ||
helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims) | ||
] | ||
|
||
# Create the TopK node | ||
if op_set_version > 1: | ||
node = helper.make_node( | ||
'TopK', | ||
inputs=['X', 'K'], | ||
outputs=['Values', 'Indices'], | ||
axis=axis, # Axis along which to find the top K elements | ||
) | ||
input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT32, K.shape)) | ||
else: | ||
node = helper.make_node( | ||
'TopK', | ||
inputs=['X'], | ||
outputs=['Values', 'Indices'], | ||
axis=axis, # Axis along which to find the top K elements | ||
k=k | ||
) | ||
|
||
# Create the graph | ||
graph = helper.make_graph( | ||
nodes = [node], | ||
name = 'TopKGraph', | ||
inputs = input_tensors, | ||
outputs = output_tensors, | ||
# Uncomment when initializers are supported. Currently we can't test opset 10/11 since the code will require a k value to be initialized for testing. | ||
#initializer = [ | ||
# helper.make_tensor('X', TensorProto.FLOAT, X.shape, X), | ||
# helper.make_tensor('K', TensorProto.INT64, [1], [k]), | ||
#] | ||
) | ||
|
||
# Create the model | ||
model = helper.make_model( | ||
graph, | ||
ir_version=8, | ||
opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)] | ||
) | ||
# Check the model | ||
onnx.checker.check_model(model) | ||
|
||
# Save the model to a file | ||
onnx.save(model, f'top_k_opset_{op_set_version}.onnx') | ||
print(f"Model saved to top_k_opset_{op_set_version}.onnx") | ||
|
||
def main(): | ||
# Uncomment when initializers are supported. | ||
# for op_set_version in [1, 10, 11]: | ||
for op_set_version in [1]: | ||
create_model(op_set_version) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
use super::{Node, NodeCodegen}; | ||
use crate::burn::{Scope, TensorType, Type}; | ||
use burn::config::Config; | ||
use burn::record::PrecisionSettings; | ||
use proc_macro2::TokenStream; | ||
use quote::{quote, ToTokens}; | ||
|
||
#[derive(Config, Debug)] | ||
pub struct TopKConfig { | ||
pub axis: usize, | ||
pub k: usize, | ||
} | ||
|
||
#[derive(Debug, Clone, new)] | ||
pub struct TopKNode { | ||
pub input: TensorType, | ||
pub outputs: Vec<TensorType>, | ||
pub config: TopKConfig, | ||
} | ||
|
||
impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode { | ||
fn output_types(&self) -> Vec<Type> { | ||
self.outputs | ||
.iter() | ||
.map(|t| Type::Tensor(t.clone())) | ||
.collect() | ||
} | ||
|
||
fn input_types(&self) -> Vec<Type> { | ||
vec![Type::Tensor(self.input.clone())] | ||
} | ||
|
||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { | ||
let axis = self.config.axis.to_token_stream(); | ||
let k = self.config.k.to_token_stream(); | ||
|
||
let input = scope.tensor_use_owned(&self.input, node_position); | ||
let values_output = &self.outputs[0].name; | ||
let indices_output = &self.outputs[1].name; | ||
|
||
quote! { | ||
let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis); | ||
} | ||
} | ||
|
||
fn into_node(self) -> Node<PS> { | ||
Node::TopK(self) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use burn::record::FullPrecisionSettings; | ||
|
||
use super::*; | ||
use crate::burn::{ | ||
graph::BurnGraph, | ||
node::{test::assert_tokens, top_k::TopKNode}, | ||
TensorType, | ||
}; | ||
|
||
#[test] | ||
fn test_codegen_nodes() { | ||
let mut graph = BurnGraph::<FullPrecisionSettings>::default(); | ||
let config = TopKConfig::new(1, 3); | ||
|
||
graph.register(TopKNode::new( | ||
TensorType::new_float("input_tensor", 4), | ||
vec![ | ||
TensorType::new_float("values_tensor", 4), | ||
TensorType::new_int("indices_tensor", 4), | ||
], | ||
config, | ||
)); | ||
|
||
graph.register_input_output( | ||
vec!["input_tensor".to_string()], | ||
vec!["values_tensor".to_string(), "indices_tensor".to_string()], | ||
); | ||
|
||
let expected = quote! { | ||
use burn::tensor::Int; | ||
use burn::{ | ||
module::Module, | ||
tensor::{backend::Backend, Tensor}, | ||
}; | ||
|
||
#[derive(Module, Debug)] | ||
pub struct Model<B: Backend> { | ||
phantom: core::marker::PhantomData<B>, | ||
device: burn::module::Ignored<B::Device>, | ||
} | ||
|
||
impl<B: Backend> Model <B> { | ||
#[allow(unused_variables)] | ||
pub fn new(device: &B::Device) -> Self { | ||
Self { | ||
phantom: core::marker::PhantomData, | ||
device: burn::module::Ignored(device.clone()), | ||
} | ||
} | ||
#[allow(clippy::let_and_return, clippy::approx_constant)] | ||
pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) { | ||
let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize); | ||
(values_tensor, indices_tensor) | ||
} | ||
} | ||
}; | ||
|
||
assert_tokens(graph.codegen(), expected); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.