From 522ceeed6363e06dd571c68bc63d10d3586cbb9b Mon Sep 17 00:00:00 2001 From: David Chavez Date: Thu, 31 Oct 2024 14:39:58 +0100 Subject: [PATCH] chore(mlx-c): 0.10.0 Add remaining additions (#128) --- .gitmodules | 3 -- README.md | 6 +++ mlx-rs/src/ops/indexing/mod.rs | 72 ++++++++++++++++++++++++++++------ mlx-rs/src/ops/other.rs | 52 ++++++++++++++++++++++++ mlx-rs/src/ops/shapes.rs | 1 + mlx-rs/src/utils.rs | 31 --------------- mlx-sys/Cargo.toml | 4 +- mlx-sys/build.rs | 2 + 8 files changed, 124 insertions(+), 47 deletions(-) diff --git a/.gitmodules b/.gitmodules index a8526e9f7..f73fa4c9e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "mlx-sys/mlx"] - path = mlx-sys/mlx - url = https://github.com/ml-explore/mlx.git [submodule "mlx-sys/src/mlx-c"] path = mlx-sys/src/mlx-c url = https://github.com/ml-explore/mlx-c.git diff --git a/README.md b/README.md index e767597f9..48fb6a18c 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,12 @@ We try to be as welcoming as possible to everybody from any background. We're st mlx-rs is currently in active development, and is not yet complete. +## MSRV + +The minimum supported Rust version is 1.75.0. + +The MSRV is the minimum Rust version that can be used to compile each crate. + ## License mlx-rs is distributed under the terms of the MIT license. See [LICENSE](./LICENSE) for details. diff --git a/mlx-rs/src/ops/indexing/mod.rs b/mlx-rs/src/ops/indexing/mod.rs index 95c9d5c8d..183ef486a 100644 --- a/mlx-rs/src/ops/indexing/mod.rs +++ b/mlx-rs/src/ops/indexing/mod.rs @@ -352,11 +352,10 @@ impl Array { } } - // NOTE: take and take_long_axis are two separate functions in the c++ code. They don't call - // each other. - /// Take values along an axis at the specified indices. /// + /// If no axis is specified, the array is flattened to 1D prior to the indexing operation. + /// /// # Params /// /// - `indices`: The indices to take from the array. @@ -365,22 +364,61 @@ impl Array { pub fn take_along_axis_device( &self, indices: &Array, - axis: i32, + axis: impl Into>, stream: impl AsRef, ) -> Result { + let (input, axis) = match axis.into() { + None => (self.reshape_device(&[-1], &stream)?, 0), + Some(ax) => (self.clone(), ax), + }; + unsafe { let c_array = try_catch_c_ptr_expr! { - mlx_sys::mlx_take_along_axis( - self.c_array, - indices.c_array, - axis, - stream.as_ref().as_ptr(), - ) + mlx_sys::mlx_take_along_axis(input.c_array, indices.c_array, axis, stream.as_ref().as_ptr()) }; Ok(Array::from_ptr(c_array)) } } + + /// Put values along an axis at the specified indices. + /// + /// If no axis is specified, the array is flattened to 1D prior to the indexing operation. + /// + /// # Params + /// - indices: Indices array. These should be broadcastable with the input array excluding the `axis` dimension. + /// - values: Values array. These should be broadcastable with the indices. + /// - axis: Axis in the destination to put the values to. + /// - stream: stream or device to evaluate on. + #[default_device] + pub fn put_along_axis_device( + &self, + indices: &Array, + values: &Array, + axis: impl Into>, + stream: impl AsRef, + ) -> Result { + match axis.into() { + None => unsafe { + let input = self.reshape_device(&[-1], &stream)?; + + let c_array = try_catch_c_ptr_expr! { + mlx_sys::mlx_put_along_axis(input.c_array, indices.c_array, values.c_array, 0, stream.as_ref().as_ptr()) + }; + + let array = Array::from_ptr(c_array); + let array = array.reshape_device(self.shape(), &stream)?; + Ok(array) + }, + Some(ax) => unsafe { + let c_array = try_catch_c_ptr_expr! { + mlx_sys::mlx_put_along_axis(self.c_array, indices.c_array, values.c_array, ax, stream.as_ref().as_ptr()) + }; + + Ok(Array::from_ptr(c_array)) + }, + } + } } /// Indices of the maximum values along the axis. @@ -579,12 +617,24 @@ pub fn argsort_all_device(a: &Array, stream: impl AsRef) -> Result>, stream: impl AsRef, ) -> Result { a.take_along_axis_device(indices, axis, stream) } +/// See [`Array::put_along_axis`] +#[default_device] +pub fn put_along_axis_device( + a: &Array, + indices: &Array, + values: &Array, + axis: impl Into>, + stream: impl AsRef, +) -> Result { + a.put_along_axis_device(indices, values, axis, stream) +} + /// See [`Array::take`] #[default_device] pub fn take_device( diff --git a/mlx-rs/src/ops/other.rs b/mlx-rs/src/ops/other.rs index c81dc1737..dbbfd7754 100644 --- a/mlx-rs/src/ops/other.rs +++ b/mlx-rs/src/ops/other.rs @@ -68,6 +68,39 @@ impl Array { Ok(Array::from_ptr(c_array)) } } + + /// Perform the Walsh-Hadamard transform along the final axis. + /// + /// Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192` + /// for ``DType/float32`` and `2^k <= 16384` for ``DType/float16`` and ``DType/bfloat16``. + /// + /// # Params + /// - scale: scale the output by this factor -- default is `1.0/sqrt(array.dim(-1))` + /// - stream: stream to evaluate on. + #[default_device] + pub fn hadamard_transform_device( + &self, + scale: impl Into>, + stream: impl AsRef, + ) -> Result { + let scale = scale.into(); + let scale = mlx_sys::mlx_optional_float { + value: scale.unwrap_or(0.0), + has_value: scale.is_some(), + }; + + unsafe { + let c_array = try_catch_c_ptr_expr! { + mlx_sys::mlx_hadamard_transform( + self.c_array, + scale, + stream.as_ref().as_ptr(), + ) + }; + + Ok(Array::from_ptr(c_array)) + } + } } /// See [`Array::diag`] @@ -248,4 +281,23 @@ mod tests { let out = einsum("ii->", &[m]).unwrap(); assert_eq!(out, array!(5.0)); } + + #[test] + fn test_hadamard_transform() { + let input = Array::from_slice(&[1.0, -1.0, -1.0, 1.0], &[2, 2]); + let expected = Array::from_slice( + &[ + 0.0, + 2.0_f32 / 2.0_f32.sqrt(), + 0.0, + -2.0_f32 / 2.0_f32.sqrt(), + ], + &[2, 2], + ); + let result = input.hadamard_transform(None).unwrap(); + + let c = result.all_close(&expected, 1e-5, 1e-5, None).unwrap(); + let c_data: &[bool] = c.as_slice(); + assert_eq!(c_data, [true]); + } } diff --git a/mlx-rs/src/ops/shapes.rs b/mlx-rs/src/ops/shapes.rs index 669e8b3d7..8bcfa721e 100644 --- a/mlx-rs/src/ops/shapes.rs +++ b/mlx-rs/src/ops/shapes.rs @@ -688,6 +688,7 @@ impl PadMode { /// `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is /// passed then all axes are extended by the same number on each side. /// - `value`: The value to pad the array with. Default is `0` if not provided. +/// - `mode`: The padding mode. Default is `PadMode::Constant` if not provided. /// /// # Example /// diff --git a/mlx-rs/src/utils.rs b/mlx-rs/src/utils.rs index cb3e940f7..13a3252ff 100644 --- a/mlx-rs/src/utils.rs +++ b/mlx-rs/src/utils.rs @@ -367,37 +367,6 @@ where extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {} -pub(crate) struct VectorVectorArray { - c_vec: mlx_sys::mlx_vector_vector_array, -} - -impl Drop for VectorVectorArray { - fn drop(&mut self) { - unsafe { mlx_sys::mlx_free(self.c_vec as *mut c_void) } - } -} - -impl VectorVectorArray { - pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_vector_array) -> Self { - Self { c_vec } - } - - pub(crate) fn into_values(self) -> T - where - T: FromIterator, - { - unsafe { - let size = mlx_sys::mlx_vector_vector_array_size(self.c_vec); - (0..size) - .map(|i| { - let c_array = mlx_sys::mlx_vector_vector_array_get(self.c_vec, i); - VectorArray::from_ptr(c_array) - }) - .collect::() - } - } -} - pub(crate) struct TupleArrayArray { c_tuple: mlx_tuple_array_array, } diff --git a/mlx-sys/Cargo.toml b/mlx-sys/Cargo.toml index 985b96db1..b1abf7350 100644 --- a/mlx-sys/Cargo.toml +++ b/mlx-sys/Cargo.toml @@ -27,6 +27,6 @@ metal = [] [dependencies] [build-dependencies] -bindgen = "0.69.4" +bindgen = "0.70.1" cmake = "0.1.31" -cc = "1" \ No newline at end of file +cc = "1" diff --git a/mlx-sys/build.rs b/mlx-sys/build.rs index 32c6c81c8..65eca98b8 100644 --- a/mlx-sys/build.rs +++ b/mlx-sys/build.rs @@ -1,5 +1,6 @@ extern crate cmake; +use bindgen::RustTarget; use cmake::Config; use std::env; use std::path::{Path, PathBuf}; @@ -81,6 +82,7 @@ fn main() { // generate bindings let bindings = bindgen::Builder::default() + .rust_target(RustTarget::Stable_1_73) .header("src/mlx-c/mlx/c/mlx.h") .header("src/mlx-c/mlx/c/linalg.h") .header("src/mlx-c/mlx/c/error.h")