Skip to content

Commit

Permalink
Merge pull request #123 from magnusuMET/feature/ndarray
Browse files Browse the repository at this point in the history
Improve `ndarray` integration
  • Loading branch information
magnusuMET authored Dec 17, 2023
2 parents 3804a00 + 057e97c commit c99fd93
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 35 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ members = [
"netcdf-src",
]
default-members = ["netcdf", "netcdf-sys"]
resolver = "2"
145 changes: 113 additions & 32 deletions netcdf/src/extent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ impl From<RangeFull> for Extent {
impl_for_ref!(RangeFull: Extent);

impl Extent {
#[allow(unused)]
const fn stride(&self) -> Option<isize> {
match *self {
Self::Slice { start: _, stride }
Expand Down Expand Up @@ -460,52 +461,132 @@ impl From<()> for Extents {

pub(crate) type StartCountStride = (Vec<usize>, Vec<usize>, Vec<isize>);

#[allow(dead_code)]
pub(crate) struct StartCountStrideIterItem {
pub(crate) start: usize,
pub(crate) count: usize,
pub(crate) stride: isize,
/// Extent is an index
pub(crate) is_an_index: bool,
/// The dimension can increase
pub(crate) is_growable: bool,
/// The extent has an upper bound
pub(crate) is_upwards_limited: bool,
}

enum StartCountStrideIter<'a> {
All(std::slice::Iter<'a, Dimension<'a>>),
Extent(std::iter::Zip<std::slice::Iter<'a, Extent>, std::slice::Iter<'a, Dimension<'a>>>),
}

impl<'a> Iterator for StartCountStrideIter<'a> {
type Item = StartCountStrideIterItem;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::All(iter) => iter.next().map(|dim| Self::Item {
start: 0,
count: dim.len(),
stride: 1,
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: false,
}),
Self::Extent(iter) => iter.next().map(|(extent, dim)| match *extent {
Extent::Index(start) => Self::Item {
start,
count: 1,
stride: 1,
is_an_index: true,
is_growable: dim.is_unlimited(),
is_upwards_limited: true,
},
Extent::Slice { start, stride } => stride.try_into().map_or_else(
|_| Self::Item {
start,
count: 0,
stride, // negative stride is not used
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: false,
},
|stride| Self::Item {
start,
count: (start..dim.len()).step_by(stride).count(),
stride: stride as isize,
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: false,
},
),
Extent::SliceCount {
start,
count,
stride,
} => Self::Item {
start,
count,
stride,
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: true,
},
Extent::SliceEnd { start, end, stride } => stride.try_into().map_or_else(
|_| Self::Item {
start,
count: 0,
stride, // negative stride is not used
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: true,
},
|stride| Self::Item {
start,
count: (start..end).step_by(stride).count(),
stride: stride as isize,
is_an_index: false,
is_growable: dim.is_unlimited(),
is_upwards_limited: true,
},
),
}),
}
}
}

impl Extents {
pub(crate) fn get_start_count_stride(
&self,
dims: &[Dimension],
) -> Result<StartCountStride, error::Error> {
let (start, count, stride): StartCountStride = match self {
Self::All => {
let start = dims.iter().map(|_| 0).collect();
let counts = dims.iter().map(Dimension::len).collect();
let stride = dims.iter().map(|_| 1).collect();
let mut start = vec![];
let mut count = vec![];
let mut stride = vec![];
for item in self.iter_with_dims(dims)? {
start.push(item.start);
count.push(item.count);
stride.push(item.stride);
}
Ok((start, count, stride))
}

(start, counts, stride)
}
pub(crate) fn iter_with_dims<'a>(
&'a self,
dims: &'a [Dimension],
) -> Result<impl Iterator<Item = StartCountStrideIterItem> + 'a, error::Error> {
match self {
Self::All => Ok(StartCountStrideIter::All(dims.iter())),
Self::Extent(extents) => {
if extents.len() != dims.len() {
return Err(error::Error::DimensionMismatch {
wanted: dims.len(),
actual: extents.len(),
});
}
let (start, count) = dims
.iter()
.zip(extents)
.map(|(d, &e)| match e {
Extent::Index(start) => (start, 1),
Extent::Slice { start, stride } => usize::try_from(stride).map_or_else(
|_| (start, 0),
|stride| (start, (start..d.len()).step_by(stride).count()),
),
Extent::SliceCount {
start,
count,
stride: _,
} => (start, count),
Extent::SliceEnd { start, end, stride } => usize::try_from(stride)
.map_or_else(
|_| (start, 0),
|stride| (start, (start..end).step_by(stride).count()),
),
})
.unzip();
let stride = extents.iter().map(|e| e.stride().unwrap_or(1)).collect();
(start, count, stride)
Ok(StartCountStrideIter::Extent(
extents.iter().zip(dims.iter()),
))
}
};
Ok((start, count, stride))
}
}
}

Expand Down
17 changes: 16 additions & 1 deletion netcdf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
//! let data_i32 = var.value::<i32, _>((40, 0, 0))?;
//!
//! // You can use `values_arr()` to get all the data from the variable.
//! // This requires the `ndarray` feature
//! // Passing `..` will give you the entire slice
//! # #[cfg(feature = "ndarray")]
//! let data = var.values_arr::<i32, _>(..)?;
Expand All @@ -46,7 +47,14 @@
//! // `(40, 0, 0)` and get a dataset of size `100, 100` from this
//! # #[cfg(feature = "ndarray")]
//! let data = var.values_arr::<i32, _>(([40, 0 ,0], [1, 100, 100]))?;
//! # #[cfg(feature = "ndarray")]
//! let data = var.values_arr::<i32, _>((40, ..100, ..100))?;
//!
//! // You can read into an ndarray to reuse an allocation
//! # #[cfg(feature = "ndarray")]
//! let mut data = ndarray::Array::<f32, _>::zeros((100, 100));
//! # #[cfg(feature = "ndarray")]
//! var.values_arr_into((0, .., ..), data.view_mut())?;
//! # Ok(()) }
//! ```
//!
Expand All @@ -67,7 +75,8 @@
//! "crab_coolness_level",
//! &["time", "ncrabs"],
//! )?;
//! // Metadata can be added to the variable
//! // Metadata can be added to the variable, but will not be used when
//! // writing or reading data
//! var.add_attribute("units", "Kelvin")?;
//! var.add_attribute("add_offset", 273.15_f32)?;
//!
Expand All @@ -78,6 +87,12 @@
//! // Values can be added along the unlimited dimension, which
//! // resizes along the `time` axis
//! var.put_values(&data, (11, ..))?;
//!
//! // Using the ndarray feature you can also use
//! # #[cfg(feature = "ndarray")]
//! let values = ndarray::Array::from_shape_fn((5, 10), |(j, i)| (j * 10 + i) as f32);
//! # #[cfg(feature = "ndarray")]
//! var.put_values_arr((11.., ..), values.view())?;
//! # Ok(()) }
//! ```
Expand Down
Loading

0 comments on commit c99fd93

Please sign in to comment.