gather working with input_store
kali committed Feb 19, 2025
1 parent bddaa50 commit cff3e36
Showing 13 changed files with 371 additions and 178 deletions.
84 changes: 55 additions & 29 deletions core/src/model/order.rs
@@ -3,6 +3,7 @@ use crate::internal::*;
use bit_set::BitSet;
use std::collections::VecDeque;
use std::fmt::{Debug, Display};
+use tract_itertools::Itertools;

/// Find an evaluation order for a model, using its default inputs and outputs
/// as boundaries.
@@ -11,8 +12,8 @@ where
F: Fact + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
-    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
-    let targets = model.output_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
+    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
+    let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
eval_order_for_nodes(model.nodes(), &inputs, &targets, &[])
}

@@ -38,8 +39,11 @@ where
let mut pending = BitSet::with_capacity(nodes.len());
while let Some((current_node, current_input)) = current_stack.pop() {
let deps_from_inputs = nodes[current_node].inputs.len();
-        let all_deps_count =
-            deps_from_inputs + more_dependencies.iter().filter(|a| a.0 == current_node).count();
+        let all_deps_count = deps_from_inputs
+            + more_dependencies
+                .iter()
+                .filter(|a| a.0 == current_node)
+                .count();
if model_inputs.contains(&current_node) || current_input == all_deps_count {
order.push(current_node);
done.insert(current_node);
@@ -50,7 +54,12 @@
.iter()
.filter(|n| nodes[n.node].inputs.len() > 0)
.map(|n| n.node)
-                .chain(more_dependencies.iter().filter(|a| a.0 == current_node).map(|n| n.1))
+                .chain(
+                    more_dependencies
+                        .iter()
+                        .filter(|a| a.0 == current_node)
+                        .map(|n| n.1),
+                )
.chain(
nodes[current_node]
.inputs
@@ -82,28 +91,34 @@ where
Ok(order)
}

-pub fn build_flush_list<F, O, Flushable>(model: &Graph<F, O>, order: &[usize], outputs: &[OutletId], flushable: Flushable) -> Vec<TVec<usize>>
+pub fn build_flush_list<F, O, Flushable>(
+    model: &Graph<F, O>,
+    order: &[usize],
+    outputs: &[OutletId],
+    flushable: Flushable,
+) -> Vec<TVec<usize>>
where
    F: Fact + Clone + 'static,
-    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
-    Flushable: Fn(&Node<F, O>) -> bool {
-    let mut values_needed_until_step = vec![0; model.nodes().len()];
-    for (step, node) in order.iter().enumerate() {
-        for i in &model.node(*node).inputs {
-            values_needed_until_step[i.node] = step;
-        }
-    }
-    for o in outputs.iter() {
-        values_needed_until_step[o.node] = order.len();
-    }
-    let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
-    for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
-        if flush_at != 0 && (flushable)(model.node(node)) {
-            flush_lists[flush_at].push(node)
-        }
-    }
-    flush_lists
-}
+    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
+    Flushable: Fn(&Node<F, O>) -> bool,
+{
+    let mut values_needed_until_step = vec![0; model.nodes().len()];
+    for (step, node) in order.iter().enumerate() {
+        for i in &model.node(*node).inputs {
+            values_needed_until_step[i.node] = step;
+        }
+    }
+    for o in outputs.iter() {
+        values_needed_until_step[o.node] = order.len();
+    }
+    let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
+
+    for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
+        if flush_at != 0 && (flushable)(model.node(node)) {
+            flush_lists[flush_at].push(node)
+        }
+    }
+    flush_lists
+}

/// Find an evaluation order for a list of model trying to minimize memory occupation.
@@ -112,8 +127,8 @@ where
F: Fact + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
-    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
-    let targets = model.output_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
+    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
+    let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &targets, &[])
}

@@ -238,8 +253,10 @@
}

while !model_outputs.iter().all(|o| done.done.contains(*o)) {
-        let next = if let Some(next) =
-            done.candidates.iter().find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
+        let next = if let Some(next) = done
+            .candidates
+            .iter()
+            .find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
{
next
} else if let Some(next) = done.best_upstream_starter(&dfs) {
@@ -270,7 +287,10 @@ mod tests {
let add = model.wire_node("add", math::add(), &[a, b]).unwrap()[0];
model.auto_outputs().unwrap();
assert_eq!(model.eval_order().unwrap(), vec!(a.node, b.node, add.node));
-        assert_eq!(model.eval_order_opt_ram().unwrap(), vec!(a.node, b.node, add.node));
+        assert_eq!(
+            model.eval_order_opt_ram().unwrap(),
+            vec!(a.node, b.node, add.node)
+        );
}

#[test]
@@ -298,12 +318,18 @@ mod tests {
std::thread::spawn(move || {
rx.send(cloned.eval_order()).unwrap();
});
-        assert!(tx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
+        assert!(tx
+            .recv_timeout(std::time::Duration::from_secs(1))
+            .unwrap()
+            .is_err());
let (rx, tx) = std::sync::mpsc::channel();
std::thread::spawn(move || {
rx.send(model.eval_order_opt_ram()).unwrap();
});
-        assert!(tx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
+        assert!(tx
+            .recv_timeout(std::time::Duration::from_secs(1))
+            .unwrap()
+            .is_err());
}

#[test]
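
The change to build_flush_list above is a pure formatting reflow, but the logic it reformats is worth spelling out: for every node, record the last step of the evaluation order that still reads its output (model outputs are pinned past the final step), then bucket nodes by that step so the runtime knows which intermediate values can be dropped once a step has executed. A rough standalone equivalent, using plain Vecs instead of tract's Graph and TVec and leaving out the `flushable` predicate:

/// For each producer node, find the last step in `order` that consumes it,
/// then bucket producers by that step. `inputs[n]` lists the nodes that node `n` reads.
/// Sketch of the liveness bookkeeping in `build_flush_list`, minus tract's types.
fn flush_lists(order: &[usize], inputs: &[Vec<usize>], outputs: &[usize]) -> Vec<Vec<usize>> {
    let mut needed_until = vec![0usize; inputs.len()];
    for (step, &node) in order.iter().enumerate() {
        for &i in &inputs[node] {
            needed_until[i] = step; // later reads overwrite earlier ones
        }
    }
    for &o in outputs {
        needed_until[o] = order.len(); // outputs must survive the whole run
    }
    let mut lists = vec![Vec::new(); order.len() + 1];
    for (node, &flush_at) in needed_until.iter().enumerate() {
        if flush_at != 0 {
            lists[flush_at].push(node); // safe to free right after step `flush_at`
        }
    }
    lists
}

fn main() {
    // a -> b -> c, with a also feeding c: both a and b stay live until step 2.
    let inputs = vec![vec![], vec![0], vec![0, 1]];
    println!("{:?}", flush_lists(&[0, 1, 2], &inputs, &[2])); // [[], [], [0, 1], [2]]
}
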
73 changes: 56 additions & 17 deletions core/src/ops/array/gather.rs
@@ -4,9 +4,10 @@ use ndarray::*;
use tract_linalg::block_quant::BlockQuantValue;
use tract_linalg::mmm::MMMInputValue;

-#[derive(Debug, Clone, new, Hash)]
+#[derive(Debug, Clone, Hash)]
pub struct Gather {
pub axis: usize,
+    pub output_type: Option<DatumType>,
}

impl Op for Gather {
@@ -18,6 +19,13 @@ impl Op for Gather {
}

impl Gather {
+    pub fn new(axis: usize) -> Gather {
+        Gather {
+            axis,
+            output_type: None,
+        }
+    }

pub fn compute_output_shape<D: DimLike>(
&self,
input_shape: &[D],
@@ -55,40 +63,59 @@ impl Gather {
Ok(output)
}

-    fn eval_bq_to_f16(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
+    fn eval_bq<F: Datum>(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
ensure!(self.axis == 0);
ensure!(data.fact.shape.len() == 2);
let data_shape = &data.fact.shape;
let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
let indices_slice = indices.as_slice::<i64>()?;
let vector_len = data_shape[1];
-        let output_slice = output.as_slice_mut::<f16>()?;
-        for (pos, ix) in indices_slice.iter().enumerate() {
-            let slice = &mut output_slice[pos * vector_len..][..vector_len];
-            for (i, slot) in slice.iter_mut().enumerate() {
-                let offset = data_shape[1] * *ix as usize + i;
-                *slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
-            }
-        }
+        if F::datum_type() == f16::datum_type() {
+            let output_slice = output.as_slice_mut::<f16>()?;
+            for (pos, ix) in indices_slice.iter().enumerate() {
+                let slice = &mut output_slice[pos * vector_len..][..vector_len];
+                for (i, slot) in slice.iter_mut().enumerate() {
+                    let offset = data_shape[1] * *ix as usize + i;
+                    *slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
+                }
+            }
+        } else {
+            let output_slice = output.as_slice_mut::<f32>()?;
+            for (pos, ix) in indices_slice.iter().enumerate() {
+                let slice = &mut output_slice[pos * vector_len..][..vector_len];
+                for (i, slot) in slice.iter_mut().enumerate() {
+                    let offset = data_shape[1] * *ix as usize + i;
+                    *slot = data.fact.format.extract_at_offset_f32(&data.value, offset)
+                }
+            }
+        }
Ok(output)
}

-    fn eval_input_store_to_f16(
+    fn eval_input_store<F: Datum>(
&self,
data: &dyn MMMInputValue,
indices: &TValue,
) -> TractResult<Tensor> {
ensure!(self.axis == 0);
let data_shape = &[data.mn(), data.k()];
let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
-        let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
+        let mut output = unsafe { Tensor::uninitialized::<F>(output_shape)? };
let indices_slice = indices.as_slice::<i64>()?;
let vector_len = data_shape[1];
-        let output_slice = output.as_slice_mut::<f16>()?;
-        for (pos, m) in indices_slice.iter().enumerate() {
-            let slice = &mut output_slice[pos * vector_len..][..vector_len];
-            data.extract_at_mn_f16(*m as usize, slice)?;
-        }
+        if F::datum_type() == f16::datum_type() {
+            let output_slice = output.as_slice_mut::<f16>()?;
+            for (pos, m) in indices_slice.iter().enumerate() {
+                let slice = &mut output_slice[pos * vector_len..][..vector_len];
+                data.extract_at_mn_f16(*m as usize, slice)?;
+            }
+        } else {
+            let output_slice = output.as_slice_mut::<f32>()?;
+            for (pos, m) in indices_slice.iter().enumerate() {
+                let slice = &mut output_slice[pos * vector_len..][..vector_len];
+                data.extract_at_mn_f32(*m as usize, slice)?;
+            }
+        }
Ok(output)
}
@@ -98,10 +125,21 @@ impl TypedOp for Gather {
as_op!();

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        if let Some(dt) = self.output_type {
+            ensure!(
+                inputs[0].datum_type.is_opaque() || inputs[0].datum_type == dt,
+                "Inconsistent datum_type in Gather: attribute is {:?}, but inputs[0] is {:?}",
+                dt,
+                inputs[0].datum_type
+            );
+        } else {
+            ensure!(inputs[0].datum_type.is_number(),
+                "Gather applied to compressed data requires an explicit datum_type attribute for its output");
+        }
ensure!(inputs[1].datum_type == i64::datum_type());
if inputs[0].datum_type.is_opaque() {
let data_shape = block_quant_aware_input_shape(inputs[0])?;
-            Ok(tvec!(f16::fact(
+            Ok(tvec!(self.output_type.unwrap().fact(
&*self.compute_output_shape(&data_shape, &inputs[1].shape)?
)))
} else {
@@ -153,10 +191,11 @@ impl EvalOp for Gather {
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (data, indices) = args_2!(inputs);
let result = if let Ok(opaque) = data.to_scalar::<Opaque>() {
+            let dt = self.output_type.unwrap();
if let Some(data) = opaque.downcast_ref::<BlockQuantValue>() {
-                self.eval_bq_to_f16(data, &indices)?
+                dispatch_floatlike!(Self::eval_bq(dt)(self, data, &indices))?
} else if let Some(data) = opaque.downcast_ref::<Box<dyn MMMInputValue>>() {
-                self.eval_input_store_to_f16(&**data, &indices)?
+                dispatch_floatlike!(Self::eval_input_store(dt)(self, &**data, &indices))?
} else {
bail!("Can't use Gather on {:?} input", data);
}
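
In gather.rs, Gather now carries an explicit output_type and, when its data input is an opaque BlockQuantValue or pre-packed MMMInputValue, dispatches between f16 and f32 outputs instead of always producing f16. Stripped of tract's Tensor and Datum plumbing, both eval_bq and eval_input_store boil down to the same move: for each index, decompress one row into the matching slot of a dense output buffer. A rough standalone sketch of that pattern follows; `gather_rows` and its `extract_row` closure are illustrative stand-ins, not tract API:

// `extract_row` plays the role of extract_at_offset_f16/f32 or extract_at_mn_f16/f32:
// it materializes one logical row of the compressed/packed data into a typed slice.
fn gather_rows<T: Copy + Default>(
    row_len: usize,
    indices: &[i64],
    extract_row: impl Fn(usize, &mut [T]),
) -> Vec<T> {
    let mut output = vec![T::default(); indices.len() * row_len];
    for (pos, &ix) in indices.iter().enumerate() {
        // each selected row lands in a contiguous slot of the output
        let slot = &mut output[pos * row_len..][..row_len];
        extract_row(ix as usize, slot);
    }
    output
}

fn main() {
    // three rows of two values, "decompressed" on the fly from a flat buffer
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out = gather_rows(2, &[2, 0], |row: usize, slot: &mut [f32]| {
        slot.copy_from_slice(&data[row * 2..][..2]);
    });
    assert_eq!(out, vec![5.0, 6.0, 1.0, 2.0]);
}
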
11 changes: 11 additions & 0 deletions core/src/ops/cnn/conv/lazy_im2col.rs
@@ -52,6 +52,14 @@ impl MMMInputFormat for LazyIm2colParams {
) -> TractResult<()> {
unimplemented!()
}
+    fn extract_at_mn_f32(
+        &self,
+        _data: &tract_linalg::mmm::EagerPackedInput,
+        _mn: usize,
+        _slice: &mut [f32],
+    ) -> TractResult<()> {
+        unimplemented!()
+    }
}

impl Display for LazyIm2colParams {
@@ -394,6 +402,9 @@ impl MMMInputValue for LazyIm2colInput {
fn extract_at_mn_f16(&self, _mn: usize, _slice: &mut [f16]) -> TractResult<()> {
unimplemented!()
}
+    fn extract_at_mn_f32(&self, _mn: usize, _slice: &mut [f32]) -> TractResult<()> {
+        unimplemented!()
+    }
}

impl LazyIm2colInput {