diff --git a/crates/burn-import/src/burn/node/split.rs b/crates/burn-import/src/burn/node/split.rs index 17c67dbff3..d3bdf42cee 100644 --- a/crates/burn-import/src/burn/node/split.rs +++ b/crates/burn-import/src/burn/node/split.rs @@ -35,7 +35,6 @@ impl NodeCodegen for SplitNode { let input = scope.tensor_use_owned(&self.input, node_position); let axis = self.config.axis.to_tokens(); - let split_tensors = syn::Ident::new("split_tensors", proc_macro2::Span::call_site()); let outputs = self .outputs .iter() @@ -126,7 +125,7 @@ mod tests { &self, tensor1: Tensor, ) -> (Tensor, Tensor) { - let split_tensors = tensor1.split(2, 0); + let mut split_tensors = tensor1.split(2, 0); let [tensor2, tensor3] = split_tensors.try_into().unwrap(); (tensor2, tensor3) diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 0ef4b3cfb5..8d7bf36ffd 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -915,7 +915,7 @@ fn split_update_outputs(node: &mut Node) { }; let input_dims = match &input_tensor.shape { - Some(shape) => shape.iter().map(|dim| *dim).collect::>(), + Some(shape) => shape.to_vec(), None => panic!("Split: Input tensor shape is not defined"), }; @@ -970,8 +970,8 @@ fn split_update_outputs(node: &mut Node) { let mut sizes = vec![split_size; num_outputs]; // According to ONNX spec, the last chunk will be smaller if not evenly divisible - for i in 0..remainder { - sizes[i] += 1; + for size in sizes.iter_mut().take(remainder) { + *size += 1; } sizes