Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix accumulator only accumulating direct children #524

Merged
merged 9 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/function/accumulated.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{accumulator, storage::DatabaseGen, Id};
use crate::{accumulator, hash::FxHashSet, storage::DatabaseGen, DatabaseKeyIndex, Id};

use super::{Configuration, IngredientImpl};

Expand All @@ -21,14 +21,24 @@ where
// First ensure the result is up to date
self.fetch(db, key);

let database_key_index = self.database_key_index(key);
accumulator.produced_by(runtime, database_key_index, &mut output);
let db_key = self.database_key_index(key);
let mut visited: FxHashSet<DatabaseKeyIndex> = FxHashSet::default();
let mut stack: Vec<DatabaseKeyIndex> = vec![db_key];

if let Some(origin) = self.origin(key) {
for input in origin.inputs() {
if let Ok(input) = input.try_into() {
accumulator.produced_by(runtime, input, &mut output);
}
while let Some(k) = stack.pop() {
if visited.insert(k) {
accumulator.produced_by(runtime, k, &mut output);

let origin = db.lookup_ingredient(k.ingredient_index).origin(k.key_index);
let inputs = origin.iter().flat_map(|origin| origin.inputs());
// Careful: we want to push in execution order, so reverse order to
// ensure the first child that was executed will be the first child popped
// from the stack.
stack.extend(
inputs
.flat_map(|input| TryInto::<DatabaseKeyIndex>::try_into(input).into_iter())
.rev(),
);
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/runtime/local_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub enum QueryOrigin {

impl QueryOrigin {
/// Indices for queries *read* by this query
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
let opt_edges = match self {
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
Expand All @@ -86,7 +86,7 @@ impl QueryOrigin {
}

/// Indices for queries *written* by this query (if any)
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
let opt_edges = match self {
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None,
Expand Down Expand Up @@ -127,7 +127,7 @@ impl QueryEdges {
/// Returns the (tracked) inputs that were executed in computing this memoized value.
///
/// These will always be in execution order.
pub(crate) fn inputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn inputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
self.input_outputs
.iter()
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input)
Expand All @@ -137,7 +137,7 @@ impl QueryEdges {
/// Returns the (tracked) outputs that were executed in computing this memoized value.
///
/// These will always be in execution order.
pub(crate) fn outputs(&self) -> impl Iterator<Item = DependencyIndex> + '_ {
pub(crate) fn outputs(&self) -> impl DoubleEndedIterator<Item = DependencyIndex> + '_ {
self.input_outputs
.iter()
.filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output)
Expand Down
57 changes: 57 additions & 0 deletions tests/accumulate-chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//! Test that when having nested tracked functions
//! we don't drop any values when accumulating.

mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

#[salsa::tracked]
fn push_logs(db: &dyn Database) {
push_a_logs(db);
}

#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
Log("log a".to_string()).accumulate(db);
push_b_logs(db);
}

#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
// No logs
push_c_logs(db);
}

#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
// No logs
push_d_logs(db);
}

#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
Log("log d".to_string()).accumulate(db);
}

#[test]
fn accumulate_chain() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Check that we get all the logs.
expect![[r#"
[
Log(
"log a",
),
Log(
"log d",
),
]"#]]
.assert_eq(&format!("{:#?}", logs));
})
}
64 changes: 64 additions & 0 deletions tests/accumulate-execution-order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//! Demonstrates that accumulation is done in the order
//! in which things were originally executed.

mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

#[salsa::tracked]
fn push_logs(db: &dyn Database) {
push_a_logs(db);
}

#[salsa::tracked]
fn push_a_logs(db: &dyn Database) {
Log("log a".to_string()).accumulate(db);
push_b_logs(db);
push_c_logs(db);
push_d_logs(db);
}

#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
Log("log b".to_string()).accumulate(db);
push_d_logs(db);
}

#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
Log("log c".to_string()).accumulate(db);
}

#[salsa::tracked]
fn push_d_logs(db: &dyn Database) {
Log("log d".to_string()).accumulate(db);
}

#[test]
fn accumulate_execution_order() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Check that we get logs in execution order
expect![[r#"
[
Log(
"log a",
),
Log(
"log b",
),
Log(
"log d",
),
Log(
"log c",
),
]"#]]
.assert_eq(&format!("{:#?}", logs));
})
}
104 changes: 104 additions & 0 deletions tests/accumulate-no-duplicates.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//! Test that we don't get duplicate accumulated values

mod common;

use expect_test::expect;
use salsa::{Accumulator, Database};
use test_log::test;

// A(1) {
// B
// B
// C {
// D {
// A(2) {
// B
// }
// B
// }
// E
// }
// B
// }

#[salsa::accumulator]
struct Log(#[allow(dead_code)] String);

#[salsa::input]
struct MyInput {
n: u32,
}

#[salsa::tracked]
fn push_logs(db: &dyn Database) {
push_a_logs(db, MyInput::new(db, 1));
}

#[salsa::tracked]
fn push_a_logs(db: &dyn Database, input: MyInput) {
Log("log a".to_string()).accumulate(db);
if input.n(db) == 1 {
push_b_logs(db);
push_b_logs(db);
push_c_logs(db);
push_b_logs(db);
} else {
push_b_logs(db);
}
}

#[salsa::tracked]
fn push_b_logs(db: &dyn Database) {
Log("log b".to_string()).accumulate(db);
}

#[salsa::tracked]
fn push_c_logs(db: &dyn Database) {
Log("log c".to_string()).accumulate(db);
push_d_logs(db);
push_e_logs(db);
}

// Note this isn't tracked
fn push_d_logs(db: &dyn Database) {
Log("log d".to_string()).accumulate(db);
push_a_logs(db, MyInput::new(db, 2));
push_b_logs(db);
}

#[salsa::tracked]
fn push_e_logs(db: &dyn Database) {
Log("log e".to_string()).accumulate(db);
}

#[test]
fn accumulate_no_duplicates() {
salsa::default_database().attach(|db| {
let logs = push_logs::accumulated::<Log>(db);
// Test that there aren't duplicate B logs.
// Note that log A appears twice, because they both come
// from different inputs.
expect![[r#"
[
Log(
"log a",
),
Log(
"log b",
),
Log(
"log c",
),
Log(
"log d",
),
Log(
"log a",
),
Log(
"log e",
),
]"#]]
.assert_eq(&format!("{:#?}", logs));
})
}
Loading