Add boolean and/or to bool tensors (#2802)
* Add boolean and/or to bool tensors and fix version mismatch for rand

* Add tests

* Fix tch

* Fix for cubecl update

* Add missing test attribute

* Update crates/burn-fusion/src/stream/context.rs
wingertge authored Feb 14, 2025
1 parent d9e4146 commit 2c4c039
Showing 21 changed files with 325 additions and 24 deletions.
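The core change adds elementwise logical AND/OR to the `BoolTensorOps` trait and implements them for each backend. For reference, a minimal backend-independent sketch of the intended semantics in plain Rust (these slice helpers are illustrative only, not part of the diff):

fn bool_and(lhs: &[bool], rhs: &[bool]) -> Vec<bool> {
    // Elementwise logical AND over equal-length inputs.
    lhs.iter().zip(rhs).map(|(a, b)| *a && *b).collect()
}

fn bool_or(lhs: &[bool], rhs: &[bool]) -> Vec<bool> {
    // Elementwise logical OR.
    lhs.iter().zip(rhs).map(|(a, b)| *a || *b).collect()
}

fn main() {
    let lhs = [true, true, false, false];
    let rhs = [true, false, true, false];
    assert_eq!(bool_and(&lhs, &rhs), [true, false, false, false]);
    assert_eq!(bool_or(&lhs, &rhs), [true, true, true, false]);
}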
40 changes: 25 additions & 15 deletions Cargo.lock

(Generated lockfile diff not rendered.)

8 changes: 5 additions & 3 deletions Cargo.toml
@@ -64,7 +64,9 @@ r2d2 = "0.8.10"
r2d2_sqlite = "0.25.0"
rayon = "1.10.0"
regex = "1.11.1"
reqwest = { version="0.12.12", default-features = false, features=["rustls-tls"] }
reqwest = { version = "0.12.12", default-features = false, features = [
"rustls-tls",
] }
rmp-serde = "1.3.0"
rstest = "0.23.0"
rusqlite = "0.32.1"
@@ -153,8 +155,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
- cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ff9df3f10a533885a2b61f4d55bb2f8d2750627" }
- cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ff9df3f10a533885a2b61f4d55bb2f8d2750627" }
+ cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" }
+ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5ac5d57e24dbe8a06fd5a52bc730cb7673c7b680" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
8 changes: 8 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -59,6 +59,14 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_not(tensor)
}

fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_and(lhs, rhs)
}

fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
B::bool_or(lhs, rhs)
}

fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
AutodiffTensor::new(B::bool_into_float(tensor))
}
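Note: boolean AND/OR carry no gradient, so the autodiff decorator forwards both ops straight to the inner backend `B`, mirroring `bool_not` above.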
15 changes: 15 additions & 0 deletions crates/burn-candle/src/ops/bool_tensor.rs
@@ -73,6 +73,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
CandleTensor::new(tensor.tensor.eq(&x).unwrap())
}

fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap();
CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap())
}

fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
CandleTensor::new(
lhs.tensor
.add(&rhs.tensor)
.unwrap()
.clamp(0u32, 1u32)
.unwrap(),
)
}

fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
super::base::swap_dims(tensor, dim1, dim2)
}
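The Candle implementation builds both ops from arithmetic on the 0/1-encoded tensors: AND is `(lhs + rhs) > 1` (only 1 + 1 exceeds 1), and OR is `clamp(lhs + rhs, 0, 1)` (any nonzero sum saturates to 1). A scalar sketch of the same encoding, runnable as plain Rust (the helper names are illustrative):

fn and_via_sum(a: u32, b: u32) -> u32 {
    // 1 + 1 = 2 is the only sum greater than 1.
    ((a + b) > 1) as u32
}

fn or_via_sum(a: u32, b: u32) -> u32 {
    // Any nonzero sum clamps to 1.
    (a + b).clamp(0, 1)
}

fn main() {
    // Verify against the bitwise truth tables over {0, 1}.
    for a in 0..=1u32 {
        for b in 0..=1u32 {
            assert_eq!(and_via_sum(a, b), a & b);
            assert_eq!(or_via_sum(a, b), a | b);
        }
    }
}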
2 changes: 2 additions & 0 deletions crates/burn-core/src/nn/loss/poisson.rs
@@ -211,6 +211,8 @@ impl PoissonNllLoss {

#[cfg(test)]
mod tests {
#![allow(clippy::approx_constant)]

use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
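Note: `clippy::approx_constant` flags float literals that closely approximate constants from `std::f64::consts` (e.g. `2.718` for e); the allow is presumably needed because the test data in this module trips that lint.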
8 changes: 4 additions & 4 deletions crates/burn-cubecl/src/fusion/matmul/optimization.rs
@@ -411,19 +411,19 @@ fn matmul_launch_kernel<'a, R: Runtime, EG: Numeric, S: MatmulSelector>(
|| TypeId::of::<EG>() == TypeId::of::<flex32>()
{
S::select_kernel::<FusedMatmulSpec<EG, half::f16, f32>, R>(
- client, input, output, problem, plane_size,
+ client, input, output, problem, plane_size, false,
)
} else if TypeId::of::<EG>() == TypeId::of::<half::bf16>() {
S::select_kernel::<FusedMatmulSpec<EG, half::bf16, f32>, R>(
- client, input, output, problem, plane_size,
+ client, input, output, problem, plane_size, false,
)
} else if S::stage_tf32_supported() {
S::select_kernel::<FusedMatmulSpec<EG, tf32, f32>, R>(
- client, input, output, problem, plane_size,
+ client, input, output, problem, plane_size, false,
)
} else {
S::select_kernel::<FusedMatmulSpec<EG, EG, f32>, R>(
- client, input, output, problem, plane_size,
+ client, input, output, problem, plane_size, false,
)
}
}
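Note: the trailing `false` added to each `S::select_kernel` call supplies a parameter newly introduced by the cubecl revision bump in Cargo.toml (the "Fix for cubecl update" step in the commit message); the diff itself does not show what the flag controls.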
24 changes: 24 additions & 0 deletions crates/burn-cubecl/src/kernel/binary.rs
@@ -24,6 +24,8 @@ pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
pub(crate) struct AndOp;
pub(crate) struct OrOp;

/// Since Powf only works on float, but we still want to implement the numeric binary op family, we
/// set another precision in the family type to cast, when necessary, the input value to a valid
@@ -59,6 +61,14 @@ impl<F: Float> BinaryOpFamily for PowOp<F> {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for AndOp {
type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for OrOp {
type BinaryOp<C: Numeric> = Self;
}

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
@@ -105,6 +115,20 @@ impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
}
}

#[cube]
impl<N: Numeric> BinaryOp<N> for AndOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::cast_from(Line::<bool>::cast_from(lhs).and(Line::<bool>::cast_from(rhs)))
}
}

#[cube]
impl<N: Numeric> BinaryOp<N> for OrOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Line::cast_from(Line::<bool>::cast_from(lhs).or(Line::<bool>::cast_from(rhs)))
}
}

#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
input: &Tensor<Line<C>>,
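These cube ops reuse the generic numeric binop machinery: each `Line<N>` operand is cast to `Line<bool>`, combined with `and`/`or`, and cast back to `N`. A scalar analogue of that cast-and-combine pattern (a sketch assuming bool elements are stored numerically, with nonzero meaning true):

fn and_op(lhs: u8, rhs: u8) -> u8 {
    // Cast to bool (nonzero test), combine, cast back to 0/1.
    ((lhs != 0) && (rhs != 0)) as u8
}

fn or_op(lhs: u8, rhs: u8) -> u8 {
    ((lhs != 0) || (rhs != 0)) as u8
}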
@@ -207,6 +207,7 @@ where
cube_dim,
cube_count,
advanced_config,
false,
);
let size = SMM::stage_shape(&smm_config);

Expand Down
14 changes: 13 additions & 1 deletion crates/burn-cubecl/src/ops/bool_ops.rs
@@ -1,4 +1,8 @@
- use crate::{element::BoolElement, kernel, CubeBackend, CubeRuntime, FloatElement, IntElement};
+ use crate::{
+     element::BoolElement,
+     kernel::{self, AndOp, OrOp},
+     CubeBackend, CubeRuntime, FloatElement, IntElement,
+ };
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_tensor::{ops::BoolTensorOps, Shape, TensorData};
use std::ops::Range;
@@ -63,6 +67,14 @@ where
kernel::equal_elem::<R, BT, BT>(tensor, BT::false_val())
}

fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
kernel::launch_binop::<R, BT, AndOp>(lhs, rhs)
}

fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
kernel::launch_binop::<R, BT, OrOp>(lhs, rhs)
}

fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
kernel::bool_cast::<R, BT, F>(tensor)
}
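Note: with `AndOp` and `OrOp` registered as `BinaryOpFamily` implementations in `binary.rs`, both ops dispatch through the existing `launch_binop` path, so no dedicated boolean kernels are required.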
(Remaining changed files not rendered.)
