Skip to content

Commit

Permalink
feat(api): remaining reduction ops (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcvz authored Apr 30, 2024
1 parent 9a6dc1d commit c8cc0e9
Show file tree
Hide file tree
Showing 3 changed files with 988 additions and 54 deletions.
18 changes: 3 additions & 15 deletions src/ops/logical.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::array::Array;
use crate::error::{DataStoreError, OperationError};
use crate::stream::StreamOrDevice;
use crate::utils::{can_reduce_shape, is_broadcastable};
use crate::utils::{axes_or_default_to_all, can_reduce_shape, is_broadcastable};
use mlx_macros::default_device;

impl Array {
Expand Down Expand Up @@ -1126,13 +1126,7 @@ impl Array {
keep_dims: impl Into<Option<bool>>,
stream: StreamOrDevice,
) -> Array {
let axes = match axes.into() {
Some(axes) => axes.to_vec(),
None => {
let axes: Vec<i32> = (0..self.ndim() as i32).collect();
axes
}
};
let axes = axes_or_default_to_all(axes, self.ndim() as i32);

unsafe {
Array::from_ptr(mlx_sys::mlx_any(
Expand Down Expand Up @@ -1171,13 +1165,7 @@ impl Array {
keep_dims: impl Into<Option<bool>>,
stream: StreamOrDevice,
) -> Result<Array, OperationError> {
let axes = match axes.into() {
Some(axes) => axes.to_vec(),
None => {
let axes: Vec<i32> = (0..self.ndim() as i32).collect();
axes
}
};
let axes = axes_or_default_to_all(axes, self.ndim() as i32);

// verify reducing shape only if axes are provided
if !axes.is_empty() {
Expand Down
Loading

0 comments on commit c8cc0e9

Please sign in to comment.