diff --git a/Cargo.lock b/Cargo.lock index 6ce302defb..df2a3487b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1508,7 +1508,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1523,7 +1523,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1544,7 +1544,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1565,7 +1565,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bytemuck", "cubecl-common", @@ -1579,7 +1579,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bytemuck", "cubecl-common", @@ -1595,7 +1595,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bytemuck", "cubecl-common", @@ -1621,7 +1621,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1639,11 +1639,12 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bytemuck", "cubecl-core", "cubecl-runtime", + "cubecl-std", "half", "serde", ] @@ -1651,7 +1652,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "cubecl-common", "darling", @@ -1666,7 +1667,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "darling", "proc-macro2", @@ -1677,7 +1678,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1693,7 +1694,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1703,7 +1704,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "async-channel", "async-lock", @@ -1725,7 +1726,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1737,10 +1738,19 @@ 
dependencies = [ "rspirv", ] +[[package]] +name = "cubecl-std" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" +dependencies = [ + "cubecl-core", + "cubecl-runtime", +] + [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5ff9df3f10a533885a2b61f4d55bb2f8d2750627#5ff9df3f10a533885a2b61f4d55bb2f8d2750627" +source = "git+https://github.com/tracel-ai/cubecl?rev=5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680#5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 26dc0cc406..a41db7b150 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,9 @@ r2d2 = "0.8.10" r2d2_sqlite = "0.25.0" rayon = "1.10.0" regex = "1.11.1" -reqwest = { version="0.12.12", default-features = false, features=["rustls-tls"] } +reqwest = { version = "0.12.12", default-features = false, features = [ + "rustls-tls", +] } rmp-serde = "1.3.0" rstest = "0.23.0" rusqlite = "0.32.1" @@ -153,8 +155,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ff9df3f10a533885a2b61f4d55bb2f8d2750627" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ff9df3f10a533885a2b61f4d55bb2f8d2750627" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" } ### For local development. 
### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index ef9f4c73df..5cfa63de99 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -59,6 +59,14 @@ impl BoolTensorOps for Autodiff { B::bool_not(tensor) } + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + B::bool_and(lhs, rhs) + } + + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + B::bool_or(lhs, rhs) + } + fn bool_into_float(tensor: BoolTensor) -> as Backend>::FloatTensorPrimitive { AutodiffTensor::new(B::bool_into_float(tensor)) } diff --git a/crates/burn-candle/src/ops/bool_tensor.rs b/crates/burn-candle/src/ops/bool_tensor.rs index 48e84a0a08..fe3215f573 100644 --- a/crates/burn-candle/src/ops/bool_tensor.rs +++ b/crates/burn-candle/src/ops/bool_tensor.rs @@ -73,6 +73,21 @@ impl BoolTensorOps for Candle< CandleTensor::new(tensor.tensor.eq(&x).unwrap()) } + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap(); + CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap()) + } + + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + CandleTensor::new( + lhs.tensor + .add(&rhs.tensor) + .unwrap() + .clamp(0u32, 1u32) + .unwrap(), + ) + } + fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { super::base::swap_dims(tensor, dim1, dim2) } diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs index 3cc989ad8e..8ccefc828f 100644 --- a/crates/burn-core/src/nn/loss/poisson.rs +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -211,6 +211,8 @@ impl PoissonNllLoss { #[cfg(test)] mod tests { + #![allow(clippy::approx_constant)] + use super::*; use 
crate::tensor::TensorData; use crate::TestBackend; diff --git a/crates/burn-cubecl/src/fusion/matmul/optimization.rs b/crates/burn-cubecl/src/fusion/matmul/optimization.rs index 28581bf07c..08e39f29a1 100644 --- a/crates/burn-cubecl/src/fusion/matmul/optimization.rs +++ b/crates/burn-cubecl/src/fusion/matmul/optimization.rs @@ -411,19 +411,19 @@ fn matmul_launch_kernel<'a, R: Runtime, EG: Numeric, S: MatmulSelector>( || TypeId::of::() == TypeId::of::() { S::select_kernel::, R>( - client, input, output, problem, plane_size, + client, input, output, problem, plane_size, false, ) } else if TypeId::of::() == TypeId::of::() { S::select_kernel::, R>( - client, input, output, problem, plane_size, + client, input, output, problem, plane_size, false, ) } else if S::stage_tf32_supported() { S::select_kernel::, R>( - client, input, output, problem, plane_size, + client, input, output, problem, plane_size, false, ) } else { S::select_kernel::, R>( - client, input, output, problem, plane_size, + client, input, output, problem, plane_size, false, ) } } diff --git a/crates/burn-cubecl/src/kernel/binary.rs b/crates/burn-cubecl/src/kernel/binary.rs index 98222a5de2..3fcce9c0a4 100644 --- a/crates/burn-cubecl/src/kernel/binary.rs +++ b/crates/burn-cubecl/src/kernel/binary.rs @@ -24,6 +24,8 @@ pub(crate) struct SubOp; pub(crate) struct MulOp; pub(crate) struct DivOp; pub(crate) struct RemainderOp; +pub(crate) struct AndOp; +pub(crate) struct OrOp; /// Since Powf only works on float, but we still want to implement the numeric binary op family, we /// set another precision in the family type to cast, when necessary, the input value to a valid @@ -59,6 +61,14 @@ impl BinaryOpFamily for PowOp { type BinaryOp = Self; } +impl BinaryOpFamily for AndOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for OrOp { + type BinaryOp = Self; +} + #[cube] impl BinaryOp for AddOp { fn execute(lhs: Line, rhs: Line) -> Line { @@ -105,6 +115,20 @@ impl BinaryOp for PowOp { } } +#[cube] +impl BinaryOp 
for AndOp { + fn execute(lhs: Line, rhs: Line) -> Line { + Line::cast_from(Line::::cast_from(lhs).and(Line::::cast_from(rhs))) + } +} + +#[cube] +impl BinaryOp for OrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + Line::cast_from(Line::::cast_from(lhs).or(Line::::cast_from(rhs))) + } +} + #[cube(launch_unchecked)] pub(crate) fn kernel_scalar_binop( input: &Tensor>, diff --git a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index 6976fbd68e..cf70974642 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -207,6 +207,7 @@ where cube_dim, cube_count, advanced_config, + false, ); let size = SMM::stage_shape(&smm_config); diff --git a/crates/burn-cubecl/src/ops/bool_ops.rs b/crates/burn-cubecl/src/ops/bool_ops.rs index a29ffefc8e..632ac903b3 100644 --- a/crates/burn-cubecl/src/ops/bool_ops.rs +++ b/crates/burn-cubecl/src/ops/bool_ops.rs @@ -1,4 +1,8 @@ -use crate::{element::BoolElement, kernel, CubeBackend, CubeRuntime, FloatElement, IntElement}; +use crate::{ + element::BoolElement, + kernel::{self, AndOp, OrOp}, + CubeBackend, CubeRuntime, FloatElement, IntElement, +}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; @@ -63,6 +67,14 @@ where kernel::equal_elem::(tensor, BT::false_val()) } + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + kernel::launch_binop::(lhs, rhs) + } + + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + kernel::launch_binop::(lhs, rhs) + } + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { kernel::bool_cast::(tensor) } diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index c0479b32cc..ba7f78e13a 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ 
b/crates/burn-fusion/src/ops/boolean.rs @@ -394,6 +394,78 @@ impl BoolTensorOps for Fusion { out } + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + #[derive(new)] + struct AndOps { + desc: BinaryOpIr, + _b: PhantomData, + } + + impl Operation for AndOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_bool_tensor::(&self.desc.lhs); + let rhs = handles.get_bool_tensor::(&self.desc.rhs); + let output = B::bool_and(lhs, rhs); + handles.register_bool_tensor::(&self.desc.out.id, output); + } + } + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOpIr { + lhs: lhs.into_ir(), + rhs: rhs.into_ir(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream_1, stream_2], + OperationIr::Bool(BoolOperationIr::And(desc.clone())), + AndOps::::new(desc), + ); + + out + } + + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + #[derive(new)] + struct OrOps { + desc: BinaryOpIr, + _b: PhantomData, + } + + impl Operation for OrOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let lhs = handles.get_bool_tensor::(&self.desc.lhs); + let rhs = handles.get_bool_tensor::(&self.desc.rhs); + let output = B::bool_or(lhs, rhs); + handles.register_bool_tensor::(&self.desc.out.id, output); + } + } + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs + .client + .tensor_uninitialized(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOpIr { + lhs: lhs.into_ir(), + rhs: rhs.into_ir(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream_1, stream_2], + OperationIr::Bool(BoolOperationIr::Or(desc.clone())), + OrOps::::new(desc), + ); + + out + } + fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { #[derive(new)] struct SwapDimsOps { diff --git a/crates/burn-fusion/src/stream/context.rs 
b/crates/burn-fusion/src/stream/context.rs index 7146223534..32a2174cd6 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -630,6 +630,16 @@ impl RelativeOps for BoolOperationIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), + BoolOperationIr::And(desc) => BoolOperationIr::And(BinaryOpIr { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }), + BoolOperationIr::Or(desc) => BoolOperationIr::Or(BinaryOpIr { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }), } } } diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index a70048d036..aaf193c7ce 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -2244,6 +2244,7 @@ mod tests { .assert_eq(&expected_indices_tensor, true); } + #[test] fn one_hot() { // Test for OneHot model diff --git a/crates/burn-ir/src/operation.rs b/crates/burn-ir/src/operation.rs index 1197fff632..5fa02df87f 100644 --- a/crates/burn-ir/src/operation.rs +++ b/crates/burn-ir/src/operation.rs @@ -574,6 +574,10 @@ pub enum BoolOperationIr { IntoInt(UnaryOpIr), /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). Not(UnaryOpIr), + /// Operation corresponding to [and](burn_tensor::ops::BoolTensorOps::bool_and). + And(BinaryOpIr), + /// Operation corresponding to [or](burn_tensor::ops::BoolTensorOps::bool_or). + Or(BinaryOpIr), } /// Swap dim operation intermediate representation. 
@@ -1626,6 +1630,8 @@ impl BoolOperationIr { BoolOperationIr::IntoFloat(repr) => vec![&repr.input, &repr.out], BoolOperationIr::IntoInt(repr) => vec![&repr.input, &repr.out], BoolOperationIr::Not(repr) => vec![&repr.input, &repr.out], + BoolOperationIr::And(repr) => vec![&repr.lhs, &repr.rhs, &repr.out], + BoolOperationIr::Or(repr) => vec![&repr.lhs, &repr.rhs, &repr.out], } } } diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs b/crates/burn-ndarray/src/ops/bool_tensor.rs index 79cb7402cc..6243975540 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -87,6 +87,22 @@ impl BoolTensorOp NdArrayTensor { array } } + fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + let output = Zip::from(&lhs.array) + .and(&rhs.array) + .map_collect(|&lhs_val, &rhs_val| (lhs_val && rhs_val)) + .into_shared(); + NdArrayTensor::new(output) + } + + fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + let output = Zip::from(&lhs.array) + .and(&rhs.array) + .map_collect(|&lhs_val, &rhs_val| (lhs_val || rhs_val)) + .into_shared(); + NdArrayTensor::new(output) + } + fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor { new_tensor_float!(NdArrayTensor { array: tensor.array.mapv(|a| (a as i32).elem()).into_shared(), diff --git a/crates/burn-router/src/ops/binary.rs b/crates/burn-router/src/ops/binary.rs index 534c978bd8..3af45fde2d 100644 --- a/crates/burn-router/src/ops/binary.rs +++ b/crates/burn-router/src/ops/binary.rs @@ -53,3 +53,17 @@ macro_rules! binary_int_cmp_ops { $handles.register_bool_tensor::(&$desc.out.id, output); }}; } + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! 
binary_bool_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_bool_tensor::(&$desc.lhs); + let rhs = $handles.get_bool_tensor::(&$desc.rhs); + let output = $ops(lhs, rhs); + + $handles.register_bool_tensor::(&$desc.out.id, output); + }}; +} diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 264cd8833f..cf2634de03 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -165,6 +165,36 @@ impl BoolTensorOps for BackendRouter { out } + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + let client = lhs.client.clone(); + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = BinaryOpIr { + lhs: lhs.into_ir(), + rhs: rhs.into_ir(), + out: out.to_ir_out(), + }; + + client.register(OperationIr::Bool(BoolOperationIr::And(desc))); + + out + } + + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + let client = lhs.client.clone(); + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = BinaryOpIr { + lhs: lhs.into_ir(), + rhs: rhs.into_ir(), + out: out.to_ir_out(), + }; + + client.register(OperationIr::Bool(BoolOperationIr::Or(desc))); + + out + } + fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { let client = tensor.client.clone(); let mut shape = tensor.shape.clone(); diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 00ea2a4410..b73e54b91b 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -9,7 +9,7 @@ use core::future::Future; use super::{RouterTensor, RunnerClient}; use crate::{ - binary_float_cmp_ops, binary_float_ops, binary_int_cmp_ops, binary_int_ops, + binary_bool_ops, binary_float_cmp_ops, binary_float_ops, binary_int_cmp_ops, binary_int_ops, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_dim_ops, scalar_float_ops, scalar_int_cmp_ops, 
scalar_int_dim_ops, scalar_int_ops, unary_float_ops, unary_int_ops, }; @@ -779,6 +779,12 @@ impl RunnerClient for Runner { let output = B::bool_not(tensor); handles.register_bool_tensor::(&desc.out.id, output); } + BoolOperationIr::And(desc) => { + binary_bool_ops!(handles, desc, B::bool_and) + } + BoolOperationIr::Or(desc) => { + binary_bool_ops!(handles, desc, B::bool_or) + } }, OperationIr::Int(op) => match op { IntOperationIr::IntoFloat(desc) => { diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index 6614b3f7bc..497020d307 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -70,6 +70,26 @@ impl BoolTensorOps for LibTorch { ) } + fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.logical_and_(rhs), + |lhs, rhs| rhs.logical_and_(lhs), + |lhs, rhs| lhs.logical_and(rhs), + ) + } + + fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.logical_or_(rhs), + |lhs, rhs| rhs.logical_or_(lhs), + |lhs, rhs| lhs.logical_or(rhs), + ) + } + fn bool_into_int(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.to_kind(tch::Kind::Int64); TchTensor::new(tensor) diff --git a/crates/burn-tensor/src/tensor/api/bool.rs b/crates/burn-tensor/src/tensor/api/bool.rs index a02d51b4a7..cfeb70b6e8 100644 --- a/crates/burn-tensor/src/tensor/api/bool.rs +++ b/crates/burn-tensor/src/tensor/api/bool.rs @@ -39,6 +39,16 @@ where Tensor::new(B::bool_not(self.primitive)) } + /// Performs logical and (`&&`) on two boolean tensors + pub fn bool_and(self, rhs: Tensor) -> Tensor { + Tensor::new(B::bool_and(self.primitive, rhs.primitive)) + } + + /// Performs logical or (`||`) on two boolean tensors + pub fn bool_or(self, rhs: Tensor) -> Tensor { + Tensor::new(B::bool_or(self.primitive, rhs.primitive)) + } + /// Compute the indices of the elements that are non-zero. 
/// /// # Returns diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index ffb74156f9..530ed99220 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -191,6 +191,30 @@ pub trait BoolTensorOps { /// The tensor with the result of the negation. fn bool_not(tensor: BoolTensor) -> BoolTensor; + /// Executes the logical and (`&&`) operation on two boolean tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor with the result of the logical and. + fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; + + /// Executes the logical or (`||`) operation on two boolean tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The tensor with the result of the logical or. + fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; + /// Transposes a bool tensor. 
/// /// # Arguments diff --git a/crates/burn-tensor/src/tests/ops/bool.rs b/crates/burn-tensor/src/tests/ops/bool.rs index 96134b98aa..32f37f3c5c 100644 --- a/crates/burn-tensor/src/tests/ops/bool.rs +++ b/crates/burn-tensor/src/tests/ops/bool.rs @@ -18,4 +18,22 @@ mod tests { let data_expected = TensorData::from([[false, true, false], [true, true, true]]); assert_eq!(data_expected, data_actual); } + + #[test] + fn test_bool_and() { + let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]); + let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]); + let data_actual = tensor1.bool_and(tensor2).into_data(); + let data_expected = TensorData::from([[false, true, false], [false, false, true]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn test_bool_or() { + let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]); + let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]); + let data_actual = tensor1.bool_or(tensor2).into_data(); + let data_expected = TensorData::from([[true, true, false], [true, false, true]]); + assert_eq!(data_expected, data_actual); + } }