Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(computegraph): remove compute_with_context and compute_untyped_with_context #92

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 8 additions & 78 deletions crates/computegraph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
//! context.set_override(multiply_node.input_b(), 2);
//!
//! // Compute the result
//! let result = graph.compute_with_context(multiply_node.output(), &context).unwrap();
//! let options = ComputationOptions {context: Some(&context) };
//! let result = graph.compute_with(multiply_node.output(), &options, None).unwrap();
//! assert_eq!(result, (3 + 4) * 2);
//! ```
//!
Expand Down Expand Up @@ -499,13 +500,13 @@ enum Fallback {
Generator(FallbackGenerator),
}

/// Set predefined values for [`ComputeGraph::compute_with_context`].
/// Set predefined values for [`ComputeGraph::compute_with`].
///
/// Use this container to:
/// - Override values passed to [`InputPort`]s
/// - Set fallback values for unconnected [`InputPort`]s
///
/// To be used with [`ComputeGraph::compute_with_context`] and [`ComputeGraph::compute_untyped_with_context`].
/// To be used with [`ComputeGraph::compute_with`] and [`ComputeGraph::compute_untyped_with`].
#[derive(Debug, Default)]
pub struct ComputationContext {
overrides: Vec<InputPortValue>,
Expand Down Expand Up @@ -1173,6 +1174,8 @@ impl ComputeGraph {
///
/// This function is the untyped version of [`ComputeGraph::compute`].
///
/// Use [`ComputeGraph::compute_untyped_with`] when caching or a context are needed.
///
/// # Arguments
///
/// * `output` - The output port to compute.
Expand Down Expand Up @@ -1637,46 +1640,10 @@ impl ComputeGraph {
}
}

/// Computes the result for a given output port using the provided context, returning a boxed value.
///
/// This function is the untyped version of [`ComputeGraph::compute_with_context`].
///
/// This function behaves similarly to [`ComputeGraph::compute_untyped`], but uses
/// the values given in the context as described in [`ComputationContext`].
///
/// # Arguments
///
/// * `output` - The output port to compute.
/// * `context` - Collection of values used to perform the computation.
///
/// # Returns
///
/// A result containing the computed boxed value or an error.
///
/// # Errors
///
/// An error is returned if:
/// - The node is not found.
/// - The node has the incorrect output type
/// - An input port of the node or a dependency of the node are not connected, and
/// no value is provided via the context
/// - A cycle is detected in the graph.
pub fn compute_untyped_with_context(
&self,
output: OutputPortUntyped,
context: &ComputationContext,
) -> Result<Box<dyn SendSyncAny>, ComputeError> {
self.compute_untyped_with(
output,
&ComputationOptions {
context: Some(context),
},
None,
)
}

/// Computes the result for a given output port.
///
/// Use [`ComputeGraph::compute_with`] when caching or a context are needed.
///
/// # Arguments
///
/// * `output` - The output port to compute.
Expand All @@ -1703,43 +1670,6 @@ impl ComputeGraph {
Ok(*res)
}

/// Computes the result for a given output port using the provided context
///
/// This function behaves similarly to [`ComputeGraph::compute`], but uses
/// the values given in the context as described in [`ComputationContext`].
///
/// # Arguments
///
/// * `output` - The output port to compute.
/// * `context` - Collection of values used to perform the computation,
///
/// # Returns
///
/// A result containing the computed boxed value or an error.
///
/// # Errors
///
/// An error is returned if:
/// - The node is not found.
/// - The node has the incorrect output type
/// - An input port of the node or a dependency of the node are not connected, and
/// no value is provided via the context
/// - A cycle is detected in the graph.
pub fn compute_with_context<T: 'static>(
&self,
output: OutputPort<T>,
context: &ComputationContext,
) -> Result<T, ComputeError> {
let res = self.compute_untyped_with_context(output.port.clone(), context)?;
let res = res
.into_any()
.downcast::<T>()
.map_err(|_| ComputeError::OutputTypeMismatch {
node: output.port.node,
})?;
Ok(*res)
}

/// Computes the result for a given output port using the provided options.
///
/// This function is the primary way to execute computations in the [`ComputeGraph`].
Expand Down
71 changes: 62 additions & 9 deletions crates/computegraph/tests/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@ fn test_context_override() -> Result<()> {
ctx.set_override(addition.input_a(), 5);

assert_eq!(
graph.compute_with_context(addition.output(), &ctx)?,
graph.compute_with(
addition.output(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?,
8,
"ctx should use the latest given value"
);
assert_eq!(
*graph
.compute_untyped_with_context(addition.output().into(), &ctx)?
.compute_untyped_with(
addition.output().into(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?
.as_ref()
.as_any()
.downcast_ref::<usize>()
Expand Down Expand Up @@ -55,7 +67,13 @@ fn test_context_override_skip_dependencies() -> Result<()> {

assert_eq!(
graph
.compute_with_context(addition.output(), &ctx)
.compute_with(
addition.output(),
&ComputationOptions {
context: Some(&ctx),
},
None
)
.expect("This should skip 'invalid_dep' entirely"),
15
);
Expand Down Expand Up @@ -85,10 +103,25 @@ fn test_context_fallback() -> Result<()> {
ctx.set_fallback(5_usize);
ctx.set_fallback(10_usize);

assert_eq!(graph.compute_with_context(addition.output(), &ctx)?, 20);
assert_eq!(
graph.compute_with(
addition.output(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?,
20
);
assert_eq!(
*graph
.compute_untyped_with_context(addition.output().into(), &ctx)?
.compute_untyped_with(
addition.output().into(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?
.as_ref()
.as_any()
.downcast_ref::<usize>()
Expand All @@ -112,10 +145,25 @@ fn test_context_fallback_generator() -> Result<()> {
ctx.set_fallback(5_usize);
ctx.set_fallback_generator(|_name| 10_usize);

assert_eq!(graph.compute_with_context(addition.output(), &ctx)?, 20);
assert_eq!(
graph.compute_with(
addition.output(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?,
20
);
assert_eq!(
*graph
.compute_untyped_with_context(addition.output().into(), &ctx)?
.compute_untyped_with(
addition.output().into(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?
.as_ref()
.as_any()
.downcast_ref::<usize>()
Expand All @@ -127,7 +175,6 @@ fn test_context_fallback_generator() -> Result<()> {

Ok(())
}

#[test]
fn test_context_priority() -> Result<()> {
let mut graph = ComputeGraph::new();
Expand All @@ -144,7 +191,13 @@ fn test_context_priority() -> Result<()> {
ctx.set_fallback(10_usize);

assert_eq!(
graph.compute_with_context(addition.output(), &ctx)?,
graph.compute_with(
addition.output(),
&ComputationOptions {
context: Some(&ctx),
},
None
)?,
1,
"priority should be override > connected > fallback"
);
Expand Down
30 changes: 23 additions & 7 deletions crates/viewport/src/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::ViewportEvent;
use computegraph::{
ComputationContext, ComputeGraph, DynamicNode, InputPort, InputPortUntyped, NodeFactory,
NodeHandle, OutputPort, OutputPortUntyped,
ComputationContext, ComputationOptions, ComputeGraph, DynamicNode, InputPort, InputPortUntyped,
NodeFactory, NodeHandle, OutputPort, OutputPortUntyped,
};
use project::ProjectView;
use std::any::TypeId;
Expand Down Expand Up @@ -472,9 +472,13 @@ impl ViewportPipeline {
let last_node = self.nodes.last().ok_or(ExecuteError::EmptyPipeline)?;
let mut ctx = ComputationContext::default();
ctx.set_fallback(project_view);
let scene = self
.graph
.compute_with_context(last_node.scene_output.clone(), &ctx)?;
let scene = self.graph.compute_with(
last_node.scene_output.clone(),
&ComputationOptions {
context: Some(&ctx),
},
None,
)?;

Ok(scene)
}
Expand All @@ -500,7 +504,13 @@ impl ViewportPipeline {

let result = scene
.graph
.compute_untyped_with_context(scene.update_state_out, &ctx)
.compute_untyped_with(
scene.update_state_out,
&ComputationOptions {
context: Some(&ctx),
},
None,
)
.map_err(ExecuteError::ComputeError)?;
state.state = Some(result);
Ok(())
Expand All @@ -525,7 +535,13 @@ impl ViewportPipeline {

let result = scene
.graph
.compute_with_context(scene.render_primitive_out, &ctx)
.compute_with(
scene.render_primitive_out,
&ComputationOptions {
context: Some(&ctx),
},
None,
)
.map_err(ExecuteError::ComputeError);
let a = ctx.remove_override_untyped(&scene.render_state_in);
debug_assert!(a.is_some());
Expand Down
Loading