Skip to content

Commit

Permalink
chore(mlx-c): 0.10.0 Add remaining additions (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcvz authored Oct 31, 2024
1 parent 0a2baa8 commit 522ceee
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 47 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[submodule "mlx-sys/mlx"]
path = mlx-sys/mlx
url = https://github.com/ml-explore/mlx.git
[submodule "mlx-sys/src/mlx-c"]
path = mlx-sys/src/mlx-c
url = https://github.com/ml-explore/mlx-c.git
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ We try to be as welcoming as possible to everybody from any background. We're st

mlx-rs is currently in active development, and is not yet complete.

## MSRV

The minimum supported Rust version is 1.75.0.

The MSRV is the minimum Rust version that can be used to compile each crate.

## License

mlx-rs is distributed under the terms of the MIT license. See [LICENSE](./LICENSE) for details.
Expand Down
72 changes: 61 additions & 11 deletions mlx-rs/src/ops/indexing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,10 @@ impl Array {
}
}

// NOTE: take and take_long_axis are two separate functions in the c++ code. They don't call
// each other.

/// Take values along an axis at the specified indices.
///
/// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
///
/// # Params
///
/// - `indices`: The indices to take from the array.
Expand All @@ -365,22 +364,61 @@ impl Array {
pub fn take_along_axis_device(
&self,
indices: &Array,
axis: i32,
axis: impl Into<Option<i32>>,
stream: impl AsRef<Stream>,
) -> Result<Array, Exception> {
let (input, axis) = match axis.into() {
None => (self.reshape_device(&[-1], &stream)?, 0),
Some(ax) => (self.clone(), ax),
};

unsafe {
let c_array = try_catch_c_ptr_expr! {
mlx_sys::mlx_take_along_axis(
self.c_array,
indices.c_array,
axis,
stream.as_ref().as_ptr(),
)
mlx_sys::mlx_take_along_axis(input.c_array, indices.c_array, axis, stream.as_ref().as_ptr())
};

Ok(Array::from_ptr(c_array))
}
}

/// Put values along an axis at the specified indices.
///
/// If no axis is specified, the array is flattened to 1D prior to the indexing operation.
///
/// # Params
/// - indices: Indices array. These should be broadcastable with the input array excluding the `axis` dimension.
/// - values: Values array. These should be broadcastable with the indices.
/// - axis: Axis in the destination to put the values to.
/// - stream: stream or device to evaluate on.
#[default_device]
pub fn put_along_axis_device(
&self,
indices: &Array,
values: &Array,
axis: impl Into<Option<i32>>,
stream: impl AsRef<Stream>,
) -> Result<Array, Exception> {
match axis.into() {
None => unsafe {
let input = self.reshape_device(&[-1], &stream)?;

let c_array = try_catch_c_ptr_expr! {
mlx_sys::mlx_put_along_axis(input.c_array, indices.c_array, values.c_array, 0, stream.as_ref().as_ptr())
};

let array = Array::from_ptr(c_array);
let array = array.reshape_device(self.shape(), &stream)?;
Ok(array)
},
Some(ax) => unsafe {
let c_array = try_catch_c_ptr_expr! {
mlx_sys::mlx_put_along_axis(self.c_array, indices.c_array, values.c_array, ax, stream.as_ref().as_ptr())
};

Ok(Array::from_ptr(c_array))
},
}
}
}

/// Indices of the maximum values along the axis.
Expand Down Expand Up @@ -579,12 +617,24 @@ pub fn argsort_all_device(a: &Array, stream: impl AsRef<Stream>) -> Result<Array
pub fn take_along_axis_device(
a: &Array,
indices: &Array,
axis: i32,
axis: impl Into<Option<i32>>,
stream: impl AsRef<Stream>,
) -> Result<Array, Exception> {
a.take_along_axis_device(indices, axis, stream)
}

/// See [`Array::put_along_axis`]
#[default_device]
pub fn put_along_axis_device(
a: &Array,
indices: &Array,
values: &Array,
axis: impl Into<Option<i32>>,
stream: impl AsRef<Stream>,
) -> Result<Array, Exception> {
a.put_along_axis_device(indices, values, axis, stream)
}

/// See [`Array::take`]
#[default_device]
pub fn take_device(
Expand Down
52 changes: 52 additions & 0 deletions mlx-rs/src/ops/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,39 @@ impl Array {
Ok(Array::from_ptr(c_array))
}
}

/// Perform the Walsh-Hadamard transform along the final axis.
///
/// Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192`
/// for ``DType/float32`` and `2^k <= 16384` for ``DType/float16`` and ``DType/bfloat16``.
///
/// # Params
/// - scale: scale the output by this factor -- default is `1.0/sqrt(array.dim(-1))`
/// - stream: stream to evaluate on.
#[default_device]
pub fn hadamard_transform_device(
&self,
scale: impl Into<Option<f32>>,
stream: impl AsRef<Stream>,
) -> Result<Array, Exception> {
let scale = scale.into();
let scale = mlx_sys::mlx_optional_float {
value: scale.unwrap_or(0.0),
has_value: scale.is_some(),
};

unsafe {
let c_array = try_catch_c_ptr_expr! {
mlx_sys::mlx_hadamard_transform(
self.c_array,
scale,
stream.as_ref().as_ptr(),
)
};

Ok(Array::from_ptr(c_array))
}
}
}

/// See [`Array::diag`]
Expand Down Expand Up @@ -248,4 +281,23 @@ mod tests {
let out = einsum("ii->", &[m]).unwrap();
assert_eq!(out, array!(5.0));
}

#[test]
fn test_hadamard_transform() {
let input = Array::from_slice(&[1.0, -1.0, -1.0, 1.0], &[2, 2]);
let expected = Array::from_slice(
&[
0.0,
2.0_f32 / 2.0_f32.sqrt(),
0.0,
-2.0_f32 / 2.0_f32.sqrt(),
],
&[2, 2],
);
let result = input.hadamard_transform(None).unwrap();

let c = result.all_close(&expected, 1e-5, 1e-5, None).unwrap();
let c_data: &[bool] = c.as_slice();
assert_eq!(c_data, [true]);
}
}
1 change: 1 addition & 0 deletions mlx-rs/src/ops/shapes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ impl PadMode {
/// `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is
/// passed then all axes are extended by the same number on each side.
/// - `value`: The value to pad the array with. Default is `0` if not provided.
/// - `mode`: The padding mode. Default is `PadMode::Constant` if not provided.
///
/// # Example
///
Expand Down
31 changes: 0 additions & 31 deletions mlx-rs/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,37 +367,6 @@ where

extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}

pub(crate) struct VectorVectorArray {
c_vec: mlx_sys::mlx_vector_vector_array,
}

impl Drop for VectorVectorArray {
fn drop(&mut self) {
unsafe { mlx_sys::mlx_free(self.c_vec as *mut c_void) }
}
}

impl VectorVectorArray {
pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_vector_array) -> Self {
Self { c_vec }
}

pub(crate) fn into_values<T>(self) -> T
where
T: FromIterator<VectorArray>,
{
unsafe {
let size = mlx_sys::mlx_vector_vector_array_size(self.c_vec);
(0..size)
.map(|i| {
let c_array = mlx_sys::mlx_vector_vector_array_get(self.c_vec, i);
VectorArray::from_ptr(c_array)
})
.collect::<T>()
}
}
}

pub(crate) struct TupleArrayArray {
c_tuple: mlx_tuple_array_array,
}
Expand Down
4 changes: 2 additions & 2 deletions mlx-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ metal = []
[dependencies]

[build-dependencies]
bindgen = "0.69.4"
bindgen = "0.70.1"
cmake = "0.1.31"
cc = "1"
cc = "1"
2 changes: 2 additions & 0 deletions mlx-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
extern crate cmake;

use bindgen::RustTarget;
use cmake::Config;
use std::env;
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -81,6 +82,7 @@ fn main() {

// generate bindings
let bindings = bindgen::Builder::default()
.rust_target(RustTarget::Stable_1_73)
.header("src/mlx-c/mlx/c/mlx.h")
.header("src/mlx-c/mlx/c/linalg.h")
.header("src/mlx-c/mlx/c/error.h")
Expand Down

0 comments on commit 522ceee

Please sign in to comment.