Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

&[usize] -> TVec<usize> all the shapes #1401

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ impl ValueInterface for Value {
let dt = to_internal_dt(dt);
let len = shape.iter().product::<usize>() * dt.size_of();
anyhow::ensure!(len == data.len());
let tensor = unsafe { Tensor::from_raw_dt(dt, shape, data)? };
let tensor = unsafe { Tensor::from_raw_dt(dt, shape.into(), data)? };
Ok(Value(tensor.into_tvalue()))
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/array/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl EvalOp for MultiBroadcastTo {
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
Ok(tvec!(inputs[0].broadcast_to_shape(&shape)?.into_tvalue()))
Ok(tvec!(inputs[0].broadcast_to_shape(shape.into_owned())?.into_tvalue()))
}
}

Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/array/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ impl Gather {
unsafe fn eval_t<T: Datum>(&self, data: TValue, indices: &TValue) -> TractResult<TValue> {
let data_view = data.to_array_view_unchecked::<T>();
let indices = indices.to_array_view::<i64>()?;
let output_shape = &*self.compute_output_shape(data.shape(), indices.shape())?;
let output_shape = self.compute_output_shape(data.shape(), indices.shape())?;
let mut output = Tensor::uninitialized::<T>(output_shape)?;
let mut output_view = output.to_array_view_mut::<T>()?;
for coords in tract_ndarray::indices(output_shape) {
for coords in tract_ndarray::indices(output_view.shape()) {
let ocoords = coords.as_array_view();
let ocoords = ocoords.as_slice().unwrap();
let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/array/gather_nd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl EvalOp for GatherNd {
let indices = indices.cast_to::<i32>()?;
let indices = indices.to_array_view::<i32>()?;
unsafe {
let mut output = Tensor::uninitialized_dt(data.datum_type(), &shape)?;
let mut output = Tensor::uninitialized_dt(data.datum_type(), shape)?;
dispatch_datum_by_size!(Self::eval_t(data.datum_type())(
self,
&mut output,
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/array/one_hot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl EvalOp for OneHot {
let mut shape: TVec<usize> = input.shape().into();
shape.insert(self.axis, self.dim);
unsafe {
let mut output = self.off.broadcast_scalar_to_shape(&shape)?;
let mut output = self.off.broadcast_scalar_to_shape(shape)?;
dispatch_datum_by_size!(Self::eval_t(self.off.datum_type())(
self,
&input,
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/array/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl Range {
len: usize,
) -> TractResult<Tensor> {
unsafe {
let mut result = Tensor::uninitialized::<T>(&[len])?;
let mut result = Tensor::uninitialized::<T>(tvec!(len))?;
let mut v = start.to_scalar::<T>()?.clone();
let step = step.to_scalar::<T>()?;
for i in 0..len {
Expand Down
4 changes: 1 addition & 3 deletions core/src/ops/array/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ impl Op for FiniteReshape {
op_as_typed_op!();
}



impl EvalOp for FiniteReshape {
fn is_stateless(&self) -> bool {
true
Expand All @@ -29,7 +27,7 @@ impl EvalOp for FiniteReshape {
let input = args_1!(inputs);
let mut tensor = input.into_tensor();
unsafe {
tensor.set_shape_unchecked(&self.shape);
tensor.set_shape_unchecked(self.shape.clone());
}
Ok(tvec!(tensor.into_tvalue()))
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/array/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn eval_slice(input: &Tensor, axis: usize, start: usize, end: usize) -> TractRes
unsafe {
let mut shape: TVec<_> = input.shape().into();
shape[axis] = end - start;
let mut tensor = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
let mut tensor = Tensor::uninitialized_dt(input.datum_type(), shape)?;
tensor.assign_slice_unchecked(.., input, start..end, axis);
Ok(tvec!(tensor.into_tvalue()))
}
Expand Down
6 changes: 3 additions & 3 deletions core/src/ops/array/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ impl EvalOp for Topk {
let k = k.cast_to_scalar::<i64>()? as usize;
output_shape[self.axis] = k;
let dt = input.datum_type();
let mut output_values = Tensor::zero_dt(dt, &output_shape)?;
let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
let mut iterating_shape = output_shape.clone();
let mut output_values = Tensor::zero_dt(dt, output_shape.clone())?;
let mut output_indices = Tensor::zero::<i64>(output_shape.clone())?;
let mut iterating_shape = output_shape;
iterating_shape[self.axis] = 1;
let mut output_indices_view = output_indices.to_array_view_mut::<i64>()?;
for coords in tract_ndarray::indices(&*iterating_shape) {
Expand Down
6 changes: 3 additions & 3 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
self.eval_in_a(&mut a, &b)?;
Ok(a)
} else {
let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
let mut c = unsafe { Tensor::uninitialized_dt(c_dt, c_shape)? };
self.eval_out_of_place(&mut c, &a, &b)?;
Ok(c)
}
Expand Down Expand Up @@ -584,7 +584,7 @@ macro_rules! bin_to_super_type {
let b = b.to_array_view::<u8>()?;
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
.context("no broadcast solution")?;
let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
let mut c = Tensor::zero_dt(*c_dt, c_shape)?;
let view = c.to_array_view_mut::<u8>()?;
$crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
*c = (scale_by($q_op_on_f32(
Expand Down Expand Up @@ -613,7 +613,7 @@ macro_rules! bin_to_super_type {
let b = b.cast_to_dt(accumulator_dt)?.into_owned();
let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])
.context("no broadcast solution")?;
let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
let mut c = Tensor::zero_dt(accumulator_dt, c_shape)?;
match accumulator_dt {
DatumType::F32 => {
let view = c.to_array_view_mut::<f32>()?;
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl EvalOp for Cast {
if input.datum_type() == self.to {
Ok(tvec!(input))
} else if input.datum_type() == TDim::datum_type() {
let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape().into())?;
for (dim, i) in
tract_itertools::izip!(input.as_slice::<TDim>()?, tmp.as_slice_mut::<i64>()?)
{
Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ impl AxisOp {
Reshape(at, from, to) => {
let mut shape: TVec<usize> = tensor.shape().into();
self.change_shape_array(&mut shape, false)?;
if tensor.set_shape(&shape).is_ok() {
if tensor.set_shape(shape).is_ok() {
Ok(())
} else if broadcasting
&& tensor.shape().iter().skip(*at).take(from.len()).all(|d| *d == 1)
Expand Down Expand Up @@ -1116,7 +1116,7 @@ mod proptests {

fn input(&self) -> TractResult<Tensor> {
unsafe {
let mut t = Tensor::uninitialized::<i64>(&self.input)?;
let mut t = Tensor::uninitialized::<i64>(self.input.clone())?;
for i in 0..t.len() {
t.as_slice_mut().unwrap()[i] = i as i64;
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/depth_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ macro_rules! impl_eval {
mul: impl Fn(T, T) -> T + Copy + 'static,
) -> TractResult<TVec<TValue>> {
let (img, kernel, bias) = args_3!(inputs);
let mut output = unsafe { Tensor::uninitialized::<T>(&dw.output_shape.shape)? };
let mut output = unsafe { Tensor::uninitialized::<T>(dw.output_shape.shape.clone())? };
let iptr = img.as_ptr::<T>()?;
let optr = output.as_ptr_mut::<T>()?;
let k_stride_i = kernel.strides()[1];
Expand Down
6 changes: 3 additions & 3 deletions core/src/ops/cnn/conv/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl EvalOp for Im2Col {
unsafe {
let mut input = inputs.remove(0).into_tensor();
let pad_value: Option<&Tensor> = if inputs.len() > 0 { Some(&inputs[0]) } else { None };
let mut output = Tensor::uninitialized::<Opaque>(&geometry.packed_shape)?;
let mut output = Tensor::uninitialized::<Opaque>(geometry.packed_shape.clone())?;
if !self.pool_spec.data_format.has_n() {
input.insert_axis(0)?;
}
Expand All @@ -162,7 +162,7 @@ impl EvalOp for Im2Col {
for g in 0..self.group {
let mut data = Tensor::uninitialized_aligned_dt(
input.datum_type(),
&[geometry.b_pack.len(geometry.k, geometry.n)],
tvec![geometry.b_pack.len(geometry.k, geometry.n)],
geometry.b_pack.alignment(),
)?;
dispatch_copy_by_size!(Patcher::patch(input.datum_type())(
Expand Down Expand Up @@ -264,7 +264,7 @@ impl Patcher {
) -> TractResult<()> {
unsafe {
let pad_value = *pad_value.to_scalar_unchecked();
let mut mega_matrix = Tensor::uninitialized::<T>(&[geometry.k, geometry.n])?;
let mut mega_matrix = Tensor::uninitialized::<T>(tvec![geometry.k, geometry.n])?;
let mut mega_matrix_view = mega_matrix.to_array_view_mut_unchecked::<T>();
let ptr = input.as_ptr_unchecked::<T>();
let ptr = ptr.add(geometry.input_shape_with_n.c_stride() * (g * geometry.ci_per_group));
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/deconv/deconv_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl DeconvSum {
let mut tensor = bias.into_tensor();
let hw = *gemm.shape().last().unwrap();
let n = *output_shape.n().unwrap_or(&1);
let n_o_hkwk_hw = gemm.into_tensor().into_shape(&[
let n_o_hkwk_hw = gemm.into_tensor().into_shape(tvec![
n,
*output_shape.c(),
self.pool_spec.kernel_shape.iter().product(),
Expand Down
8 changes: 4 additions & 4 deletions core/src/ops/cnn/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ pub mod test {
fn reference_sumpool(&self) -> Tensor {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
let geo_in: TVec<isize> = izip!(
Expand Down Expand Up @@ -845,7 +845,7 @@ pub mod test {
fn check_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
self.patch.visit_output(|visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
for c in 0..*output_shape.c() {
Expand All @@ -862,7 +862,7 @@ pub mod test {
fn check_zone_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
let mut output = Tensor::zero::<f32>(output_shape.shape.clone()).unwrap();
for zone in &self.patch.zones {
zone.visit_output(&self.patch, |visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
Expand Down Expand Up @@ -945,7 +945,7 @@ pub mod test {
#[test]
fn test_visitor_1() {
let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
let input = Tensor::zero::<f32>(input_shape.shape.clone()).unwrap();
let patch = PatchSpec::for_data_shape(input_shape.clone())
.with_kernel_shape(tvec![2, 1])
.with_padding(PaddingSpec::SameLower)
Expand Down
10 changes: 6 additions & 4 deletions core/src/ops/cnn/sumpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ impl EvalOp for LirSumPool {
let input = args_1!(inputs);
let geo = self.geometry.to_concrete(input.shape())?;
let values = if input.datum_type().is_float() {
let mut values =
unsafe { Tensor::uninitialized_dt(input.datum_type(), &geo.output_shape.shape)? };
let mut values = unsafe {
Tensor::uninitialized_dt(input.datum_type(), geo.output_shape.shape.clone())?
};
dispatch_floatlike!(Self::eval_t(input.datum_type())(
self,
&*input,
Expand All @@ -117,8 +118,9 @@ impl EvalOp for LirSumPool {
))?;
values
} else {
let mut values =
unsafe { Tensor::uninitialized_dt(DatumType::F32, &geo.output_shape.shape)? };
let mut values = unsafe {
Tensor::uninitialized_dt(DatumType::F32, geo.output_shape.shape.clone())?
};
let input_f32 = input.cast_to_dt(DatumType::F32)?;
self.eval_t::<f32>(input_f32.as_ref(), values.as_ptr_mut()?, geo.as_ref())?;
values.cast_to_dt(input.datum_type())?.into_owned()
Expand Down
7 changes: 5 additions & 2 deletions core/src/ops/downsample/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl EvalOp for Downsample {
let t = if self.modulo > input.shape()[self.axis] {
let mut shape: TVec<usize> = input.shape().into();
shape[self.axis] = 0;
Tensor::uninitialized_dt(input.datum_type(), &shape)?
Tensor::uninitialized_dt(input.datum_type(), shape)?
} else {
let slice = ndarray::Slice::new(self.modulo as isize, None, self.stride);
unsafe fn do_slice<T: Datum>(
Expand All @@ -86,7 +86,10 @@ impl EvalOp for Downsample {
impl TypedOp for Downsample {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(self.axis < inputs[0].rank());
ensure!(self.modulo == 0 || self.stride > 0, "non-zero modulo is only defined with forward strides");
ensure!(
self.modulo == 0 || self.stride > 0,
"non-zero modulo is only defined with forward strides"
);
let mut downed = inputs[0].clone();
let down_len = self.transform_dim(&downed.shape[self.axis]);
downed.shape.set(self.axis, down_len);
Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/einsum/as_blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ impl EvalOp for SGemm {
let n = c_shape[rank - 1];
let k = a.shape()[rank - 1];
unsafe {
let mut c = Tensor::uninitialized::<f32>(&c_shape)?;
let mut c = Tensor::uninitialized::<f32>(c_shape)?;
let c_ptr = c.as_ptr_mut::<f32>()?;
let silent_a_axis = c.rank() - a.rank();
let silent_b_axis = c.rank() - b.rank();
for prefix in ndarray::indices(&c_shape[0..rank - 2]) {
for prefix in ndarray::indices(&c.shape()[0..rank - 2]) {
let mut a_ptr = a_ptr;
let mut b_ptr = b_ptr;
let mut c_ptr = c_ptr;
Expand Down
Loading
Loading