diff --git a/.gitignore b/.gitignore
index 0b7ff501b2..2b8a771bc2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ Cargo.lock
 .idea
 .vscode
 .fleet
+.vs
diff --git a/burn-core/src/nn/embedding.rs b/burn-core/src/nn/embedding.rs
index aa6ba6066a..314c9cfe11 100644
--- a/burn-core/src/nn/embedding.rs
+++ b/burn-core/src/nn/embedding.rs
@@ -28,7 +28,9 @@ pub struct EmbeddingConfig {
 /// `N(0, 1)`
 #[derive(Module, Debug)]
 pub struct Embedding<B: Backend> {
-    weight: Param<Tensor<B, 2>>,
+    /// The learnable weights of the module of shape [n_embedding, d_model] initialized
+    /// from a normal distribution `N(0, 1)`.
+    pub weight: Param<Tensor<B, 2>>,
 }
 
 impl EmbeddingConfig {
diff --git a/burn-fusion/src/ops/boolean.rs b/burn-fusion/src/ops/boolean.rs
index e46d9e0a34..4d8a81ca32 100644
--- a/burn-fusion/src/ops/boolean.rs
+++ b/burn-fusion/src/ops/boolean.rs
@@ -266,7 +266,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
             }
         }
 
-        let tensor_first = tensors.get(0).unwrap();
+        let tensor_first = tensors.first().unwrap();
         let client = tensor_first.client.clone();
 
         // Calculate the output shape
diff --git a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs
index 3ef997aa50..cf611165b0 100644
--- a/burn-fusion/src/ops/float.rs
+++ b/burn-fusion/src/ops/float.rs
@@ -1367,7 +1367,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
             }
         }
 
-        let tensor_first = tensors.get(0).unwrap();
+        let tensor_first = tensors.first().unwrap();
         let client = tensor_first.client.clone();
 
         // Calculate the output shape
diff --git a/burn-fusion/src/ops/int.rs b/burn-fusion/src/ops/int.rs
index 33289fea31..409f43aa91 100644
--- a/burn-fusion/src/ops/int.rs
+++ b/burn-fusion/src/ops/int.rs
@@ -434,7 +434,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             }
         }
 
-        let tensor_first = tensors.get(0).unwrap();
+        let tensor_first = tensors.first().unwrap();
         let client = tensor_first.client.clone();
 
         // Calculate the output shape
diff --git a/burn-import/src/onnx/op_configuration.rs b/burn-import/src/onnx/op_configuration.rs
index 63cbbbf352..eacb4d9b8d 100644
--- a/burn-import/src/onnx/op_configuration.rs
+++ b/burn-import/src/onnx/op_configuration.rs
@@ -219,7 +219,7 @@ pub fn flatten_config(curr: &Node) -> (usize, usize) {
     }
 
     // extract the shape of the input tensor
-    let tensor = match curr.inputs.get(0).unwrap().clone().ty {
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
         ArgType::Tensor(tensor) => tensor,
         _ => panic!("Only tensor input is valid"),
     };
@@ -262,7 +262,7 @@ pub fn gather_config(curr: &Node) -> usize {
     }
 
     // extract the shape of the input tensor
-    let tensor = match curr.inputs.get(0).unwrap().clone().ty {
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
         ArgType::Tensor(tensor) => tensor,
         _ => panic!("Only tensor input is valid"),
     };
@@ -355,7 +355,7 @@ pub fn log_softmax_config(node: &Node) -> usize {
     }
 
     // extract the shape of the input tensor
-    let tensor = match node.inputs.get(0).unwrap().clone().ty {
+    let tensor = match node.inputs.first().unwrap().clone().ty {
         ArgType::Tensor(tensor) => tensor,
         _ => panic!("Only tensor input is valid"),
     };
@@ -390,7 +390,7 @@ pub fn softmax_config(node: &Node) -> usize {
     }
 
     // extract the shape of the input tensor
-    let tensor = match node.inputs.get(0).unwrap().clone().ty {
+    let tensor = match node.inputs.first().unwrap().clone().ty {
         ArgType::Tensor(tensor) => tensor,
         _ => panic!("Only tensor input is valid"),
     };
@@ -417,7 +417,7 @@ pub fn concat_config(node: &Node) -> usize {
     let mut axis: i64 = 1;
 
     // extract the shape of the input tensor
-    let tensor = match node.inputs.get(0).unwrap().clone().ty {
+    let tensor = match node.inputs.first().unwrap().clone().ty {
         ArgType::Tensor(tensor) => tensor,
         _ => panic!("Only tensor input is valid"),
     };
diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs
index db5c84c24a..1ddc7f57d9 100644
--- a/burn-import/src/onnx/to_burn.rs
+++ b/burn-import/src/onnx/to_burn.rs
@@ -289,7 +289,7 @@ impl ONNXGraph {
     }
 
     fn constant_conversion(node: Node) -> ConstantNode {
-        let output = node.outputs.get(0).unwrap();
+        let output = node.outputs.first().unwrap();
 
         let attr = convert_constant_value(&node);
 
@@ -340,168 +340,168 @@ impl ONNXGraph {
     }
 
     fn add_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_type();
+        let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         BinaryNode::add(lhs, rhs, output)
     }
 
     fn sub_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_type();
+        let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         BinaryNode::sub(lhs, rhs, output)
     }
 
     fn mul_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_type();
+        let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         BinaryNode::mul(lhs, rhs, output)
     }
 
     fn div_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_type();
+        let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         BinaryNode::div(lhs, rhs, output)
     }
 
     fn matmul_conversion(node: Node) -> MatmulNode {
-        let lhs = node.inputs.get(0).unwrap().to_tensor_type();
+        let lhs = node.inputs.first().unwrap().to_tensor_type();
         let rhs = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
 
         MatmulNode::new(lhs, rhs, output)
     }
 
     fn equal_conversion(node: Node) -> BinaryNode {
-        let lhs = node.inputs.get(0).unwrap().to_type();
+        let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         BinaryNode::equal(lhs, rhs, output)
     }
 
     fn erf_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::erf(input, output)
     }
 
     fn relu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::relu(input, output)
     }
 
     fn gelu_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::gelu(input, output)
     }
 
     fn log_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::log(input, output)
     }
 
     fn flatten_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
         let (start_dim, end_dim) = flatten_config(&node);
 
         UnaryNode::flatten(input, output, start_dim, end_dim)
     }
 
     fn gather_conversion(node: Node) -> GatherNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
         let index = node.inputs.get(1).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let dim = gather_config(&node);
 
         GatherNode::new(input, index, output, dim)
     }
 
     fn transpose_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::transpose(input, output)
     }
 
     fn cast_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::cast(input, output)
     }
 
     fn reshape_conversion(node: Node) -> ReshapeNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let shape = reshape_config(&node);
 
         ReshapeNode::new(input, output, shape)
     }
 
     fn clip_conversion(node: Node) -> ClipNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let (min, max) = clip_config(&node);
 
         ClipNode::new(input, output, min, max)
     }
 
     fn sigmoid_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::sigmoid(input, output)
     }
 
     fn reciprocal_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::reciprocal(input, output)
     }
 
     fn log_softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
         let dim = log_softmax_config(&node);
 
         UnaryNode::log_softmax(input, output, dim)
     }
 
     fn softmax_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
         let dim = softmax_config(&node);
 
         UnaryNode::softmax(input, output, dim)
     }
 
     fn sqrt_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::sqrt(input, output)
     }
 
     fn tanh_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::tanh(input, output)
     }
@@ -513,7 +513,7 @@ impl ONNXGraph {
             .map(|input| input.to_tensor_type())
             .collect();
 
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let dim = concat_config(&node);
 
         ConcatNode::new(inputs, output, dim)
@@ -521,8 +521,8 @@ impl ONNXGraph {
 
     fn linear_conversion<PS: PrecisionSettings>(node: Node) -> LinearNode<PS> {
         let name = &node.name;
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = linear_config(&node);
 
         let weight = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Weight is required");
@@ -534,8 +534,8 @@ impl ONNXGraph {
 
     fn dropout_conversion(node: Node) -> DropoutNode {
         let name = &node.name;
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = dropout_config(&node);
 
         DropoutNode::new(name, input, output, config)
@@ -543,8 +543,8 @@ impl ONNXGraph {
 
     fn batch_norm_conversion<PS: PrecisionSettings>(node: Node) -> BatchNormNode<PS> {
         let config = batch_norm_config(&node);
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let dim = input.dim - 2;
 
         let gamma = extract_data_serialize::<PS::FloatElem>(1, &node).expect("Gamma is required");
@@ -570,8 +570,8 @@ impl ONNXGraph {
     }
 
     fn conv1d_conversion<PS: PrecisionSettings>(node: Node) -> Conv1dNode<PS> {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = conv1d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -586,8 +586,8 @@ impl ONNXGraph {
     }
 
     fn conv2d_conversion<PS: PrecisionSettings>(node: Node) -> Conv2dNode<PS> {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = conv2d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -602,8 +602,8 @@ impl ONNXGraph {
     }
 
     fn max_pool2d_conversion(node: Node) -> MaxPool2dNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = max_pool2d_config(&node);
 
         let name = &node.name;
@@ -611,8 +611,8 @@ impl ONNXGraph {
     }
 
     fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode<PS> {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = conv_transpose2d_config(&node);
 
         let bias = node.inputs.len() == 3;
@@ -627,8 +627,8 @@ impl ONNXGraph {
     }
 
     fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
         let config = avg_pool2d_config(&node);
 
         let name = &node.name;
@@ -636,8 +636,8 @@ impl ONNXGraph {
     }
 
     fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
-        let input = node.inputs.get(0).unwrap().to_tensor_type();
-        let output = node.outputs.get(0).unwrap().to_tensor_type();
+        let input = node.inputs.first().unwrap().to_tensor_type();
+        let output = node.outputs.first().unwrap().to_tensor_type();
 
         let name = &node.name;
 
@@ -645,22 +645,22 @@ impl ONNXGraph {
     }
 
     fn cos_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::cos(input, output)
     }
 
     fn exp_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::exp(input, output)
     }
 
     fn neg_conversion(node: Node) -> UnaryNode {
-        let input = node.inputs.get(0).unwrap().to_type();
-        let output = node.outputs.get(0).unwrap().to_type();
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
 
         UnaryNode::neg(input, output)
     }
 }
diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs
index 7457c3efae..38ce607e65 100644
--- a/burn-tensor/src/tensor/api/check.rs
+++ b/burn-tensor/src/tensor/api/check.rs
@@ -353,7 +353,7 @@ impl TensorCheck {
             );
         }
 
-        let shape_reference = tensors.get(0).unwrap().shape();
+        let shape_reference = tensors.first().unwrap().shape();
 
         for tensor in tensors {
             let shape = tensor.shape();
@@ -398,7 +398,7 @@ impl TensorCheck {
             );
         }
 
-        let mut shape_reference = tensors.get(0).unwrap().shape();
+        let mut shape_reference = tensors.first().unwrap().shape();
         shape_reference.dims[dim] = 1; // We want to check every dims except the one where the
                                        // concatenation happens.
 
diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs
index 06eba7b7d9..bd6b9e18c1 100644
--- a/burn-train/src/learner/epoch.rs
+++ b/burn-train/src/learner/epoch.rs
@@ -192,7 +192,7 @@ impl<TI> TrainEpoch<TI> {
         let step = MultiDevicesTrainStep::new(&devices);
 
         // The main device is always the first in the list.
-        let device_main = devices.get(0).expect("A minimum of one device.").clone();
+        let device_main = devices.first().expect("A minimum of one device.").clone();
         let mut interrupted = false;
 
         loop {
diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs
index b8b16dddf9..79cb292bc3 100644
--- a/burn-train/src/learner/train_val.rs
+++ b/burn-train/src/learner/train_val.rs
@@ -124,7 +124,7 @@ impl Learner {
     {
         log::info!("Fitting {}", self.model.to_string());
         // The reference model is always on the first device provided.
-        if let Some(device) = self.devices.get(0) {
+        if let Some(device) = self.devices.first() {
             self.model = self.model.fork(device);
         }
 
diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs
index 3587e54902..8f6b26642f 100644
--- a/burn-wgpu/src/kernel/base.rs
+++ b/burn-wgpu/src/kernel/base.rs
@@ -214,7 +214,7 @@ pub fn build_info<E: WgpuElement, const D: usize>(tensors: &[&WgpuTensor<E, D>])
 
 /// Similar to [build info](build_info) but with dynamic rank.
 pub fn build_info_dyn(shapes: &[&[usize]], strides: &[&[usize]]) -> Vec<u32> {
-    let rank = shapes.get(0).unwrap().len();
+    let rank = shapes.first().unwrap().len();
     let mut info: Vec<u32> = vec![0; shapes.len() * 2 * rank + 1];
     info[0] = rank as u32;
 
diff --git a/burn-wgpu/src/kernel/cat.rs b/burn-wgpu/src/kernel/cat.rs
index 3c4441af9a..763b52f49c 100644
--- a/burn-wgpu/src/kernel/cat.rs
+++ b/burn-wgpu/src/kernel/cat.rs
@@ -14,7 +14,7 @@ pub fn cat<E: WgpuElement, const D: usize>(
     inputs: Vec<WgpuTensor<E, D>>,
     dim: usize,
 ) -> WgpuTensor<E, D> {
-    let first_input = inputs.get(0).unwrap();
+    let first_input = inputs.first().unwrap();
     let client = &first_input.client;
     let mut shape_output = first_input.shape.clone();
     shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum();