From 4fe36aae3d6502e2a0085e27a11caefdf6da6cc8 Mon Sep 17 00:00:00 2001 From: viandoxdev Date: Sat, 16 Dec 2023 16:19:44 +0100 Subject: [PATCH] more optimizations (170ms) --- 2023/.gitignore | 3 + 2023/src/challenges/day12.rs | 598 ++++++++++++----------------------- 2023/src/main.rs | 2 - 3 files changed, 214 insertions(+), 389 deletions(-) diff --git a/2023/.gitignore b/2023/.gitignore index d22fc7c..d7a0574 100644 --- a/2023/.gitignore +++ b/2023/.gitignore @@ -1,2 +1,5 @@ inputs target +perf.data +perf.data.old +flamegraph.svg diff --git a/2023/src/challenges/day12.rs b/2023/src/challenges/day12.rs index 711881b..36f6116 100644 --- a/2023/src/challenges/day12.rs +++ b/2023/src/challenges/day12.rs @@ -1,25 +1,16 @@ use std::{ - collections::HashMap, - fmt::Display, hash::Hash, - sync::{Arc, OnceLock}, + sync::{Arc, OnceLock}, hint::unreachable_unchecked, }; +use rustc_hash::FxHashMap as HashMap; + use anyhow::{anyhow, Context, Result}; use itertools::Itertools; use parking_lot::RwLock; -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use Spring::*; -// TODO: Recover this file: -// - use FxHasher -// - Remove any debugging info -// - make Record a struct of only unnresolved -// - remove resolved -// - use u8 where possible -// - Remove recursion - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum Spring { // '#' @@ -30,411 +21,250 @@ enum Spring { Unknown, } -impl Display for Spring { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Damaged => write!(f, "#"), - Operational => write!(f, "."), - Unknown => write!(f, "?"), - } - } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct Record { + current_group: Option, + // Groups in reverse order for efficient unshift + groups: Vec, + // Springs in reverse order for efficient unshift + springs: Vec<(u8, Spring)>, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum ResolvingError { - UnfinishedGroup, - UnfinishedGroupAtEnd, - NotEnoughSprings, - TooManyDamaged, - NoMoreGroups, - LeftoverGroups, -} - -impl Display for ResolvingError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::TooManyDamaged => write!(f, "got more damaged springs than the group expected"), - Self::NotEnoughSprings => write!( - f, - "the groups would need more springs than there is to complete" - ), - Self::UnfinishedGroup => write!(f, "can't end a group that wasn't finished (on .)"), - Self::UnfinishedGroupAtEnd => { - write!(f, "can't end a group that wasn't finished (at end)") - } - Self::NoMoreGroups => write!(f, "can't start group because there are no more"), - Self::LeftoverGroups => write!(f, "got to the end with leftover groups"), - } - } -} - -#[derive(Debug, Clone, Eq)] -enum Record { - Resolved { - springs: Vec, - }, - Unresolved { - current_group: Option, - // Groups in reverse order for efficient unshift - groups: Vec, - // Springs in reverse order for efficient unshift - springs: Vec<(u32, Spring)>, - // Resolved springs in order - resolved: Vec<(u32, Spring)>, - }, - Invalid(ResolvingError), +static RESOLVE_MEMO: OnceLock>>> = OnceLock::new(); +fn get_resolve_memo() -> Arc>> { + RESOLVE_MEMO + .get_or_init(|| { + let mut map = HashMap::default(); + map.reserve(2usize.pow(19)); + Arc::new(RwLock::new(map)) + }) + .clone() } -impl PartialEq for Record { - fn eq(&self, other: &Self) -> bool { - match self { - Self::Resolved { springs } => match other { - Self::Resolved { springs: others } => springs == others, - _ => false, - }, - Self::Invalid(err) => match other { - Self::Invalid(others) => err == others, - _ => false, +impl Record { + fn resolve(self) -> u64 { + enum Iteration { + Next { + source: usize, + value: u64, + expanded: bool, + key: Record, }, - Self::Unresolved { - current_group, - groups, - springs, - .. - } => match other { - Self::Unresolved { - current_group: others_cg, - groups: others_gs, - springs: others_sprgs, - .. - } => current_group == others_cg && groups == others_gs && springs == others_sprgs, - _ => false, + Finished { + source: usize, + value: u64, }, } - } -} - -impl Hash for Record { - fn hash(&self, state: &mut H) { - std::mem::discriminant(self).hash(state); - match self { - Self::Resolved { springs } => springs.hash(state), - Self::Invalid(e) => e.hash(state), - Self::Unresolved { - current_group, - groups, - springs, - .. - } => { - current_group.hash(state); - groups.hash(state); - springs.hash(state); - } - } - } -} -impl Display for Record { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Invalid(err) => write!(f, "\x1b[1;31m[ invalid: {err} ]\x1b[0m"), - Self::Resolved { springs } => { - write!(f, "\x1b[32m[ ")?; - for s in springs { - write!(f, "{s}")?; + impl Iteration { + fn add_to_value(&mut self, v: u64) { + match self { + Self::Next { value, .. } => *value += v, + Self::Finished { value, .. } => *value += v, } - write!(f, " ]\x1b[0m") } - Self::Unresolved { - groups, - springs, - resolved, - current_group, - } => { - write!(f, "\x1b[2m[ ")?; - if !springs.is_empty() { - for &(c, s) in springs.iter().rev() { - for _ in 0..c { - write!(f, "{s}")?; - } - write!(f, " ")?; - } - } else { - write!(f, "ø")?; - } - write!(f, " | ")?; - if let Some(last) = groups.first() { - if let Some(cur) = current_group { - write!(f, "+{cur},")?; - } - - for g in groups[1..].iter().rev() { - write!(f, "{g},")?; - } - - write!(f, "{last}")?; - } else { - if let Some(cur) = current_group { - write!(f, "+{cur}")?; - } else { - write!(f, "ø")?; - } - } - write!(f, " ] -> [ ")?; - for &(c, s) in resolved { - for _ in 0..c { - write!(f, "{s}")?; - } + fn set_expanded(&mut self) { + if let Self::Next { expanded, .. } = self { + *expanded = true; } - write!(f, " ]\x1b[0m") } } - } -} - -static RESOLVE_MEMO: OnceLock>>> = OnceLock::new(); -fn get_resolve_memo() -> Arc>> { - RESOLVE_MEMO - .get_or_init(|| Arc::new(RwLock::new(HashMap::new()))) - .clone() -} -impl Record { - fn resolve(self) -> usize { let memo = get_resolve_memo(); - if let Some(&res) = memo.read().get(&self) { - return res; + if let Some(&value) = memo.read().get(&self) { + return value; } - let key = self.clone(); + let mut work = vec![Iteration::Next { + source: usize::MAX, + value: 0, + expanded: false, + key: self, + }]; + + while let Some(last) = work.last() { + let index = work.len() - 1; + let key = match last { + Iteration::Next { + key, + expanded: false, + .. + } => key, + Iteration::Next { expanded: true, .. } => { + let Some(Iteration::Next { + source, value, key, .. + }) = work.pop() + else { + unsafe { unreachable_unchecked() } + }; + + memo.write().insert(key, value); + + if source == usize::MAX { + return value; + } else { + work[source].add_to_value(value); + continue; + } - let mut nexts = Vec::new(); - match self { - Self::Resolved { .. } => { - memo.write().insert(self, 1); - return 1; - } - Self::Invalid(_) => {} - Self::Unresolved { + } + &Iteration::Finished { source, value } => { + work.pop(); + work[source].add_to_value(value); + continue; + } + }; + + let Record { mut groups, mut springs, - mut resolved, mut current_group, - } => { - // Minimum number of springs needed to be valid, if all damaged springs are continuous and - // separated by a single operational one (sum of each plus one for each separator). - let min_springs = - (groups.iter().sum::() as usize + groups.len()).saturating_sub(1); - let spring_count: usize = springs.iter().map(|&(c, _)| c as usize).sum(); - let s = springs.pop(); - - match (s, current_group.as_mut()) { - // Fail fast, if we can't meet the quota no need to continue through iterations - // and branches - _ if min_springs > spring_count => { - nexts.push(Self::Invalid(ResolvingError::NotEnoughSprings)); + } = key.clone(); + + work.last_mut().unwrap().set_expanded(); + + macro_rules! nxt { + ($groups:expr, $springs:expr, $cur:expr) => {{ + let key = Record { + groups: $groups, + springs: $springs, + current_group: $cur, + }; + + if let Some(&value) = memo.read().get(&key) { + work.push(Iteration::Finished { + source: index, + value, + }); + } else { + work.push(Iteration::Next { + source: index, + value: 0, + key, + expanded: false, + }); } + }}; + } - // We have no more springs and no more groups, we have resolved the record - (None, None) if groups.is_empty() => nexts.push(Self::Resolved { - springs: resolved - .into_iter() - .flat_map(|(c, s)| std::iter::repeat(s).take(c as usize)) - .collect(), - }), - // We have no more springs but we still have groups - (None, None) => nexts.push(Self::Invalid(ResolvingError::LeftoverGroups)), - // We have no more springs and the last group started has been finished, so we - // have resolved the record - (None, Some(0)) => nexts.push(Self::Resolved { - springs: resolved - .into_iter() - .flat_map(|(c, s)| std::iter::repeat(s).take(c as usize)) - .collect(), - }), - // We have no more springs and the started group isn't finished, this record - // must be invalid - (None, Some(_)) => { - nexts.push(Self::Invalid(ResolvingError::UnfinishedGroupAtEnd)) - } + // Minimum number of springs needed to be valid, if all damaged springs are continuous and + // separated by a single operational one (sum of each plus one for each separator). + let min_springs = (groups.iter().sum::() as usize + groups.len()).saturating_sub(1); + let spring_count: usize = springs.iter().map(|&(c, _)| c as usize).sum(); + let s = springs.pop(); + + match (s, current_group.as_mut()) { + // Fail fast, if we can't meet the quota no need to continue through iterations + // and branches + _ if min_springs > spring_count => {} + + // We have no more springs and no more groups, or we have no more springs and the + // last group started has been finished, so we have resolved the record + (None, None) | (None, Some(0)) if groups.is_empty() => { + work.push(Iteration::Finished { + source: index, + value: 1, + }); + } - // A group is started and we encounter some '#'s while we expect more - (Some((y, Damaged)), Some(x)) if *x >= y => { - *x -= y; - resolved.push((y, Damaged)); - nexts.push(Self::Unresolved { - groups, - springs, - resolved, - current_group, - }) - } - // Since the arm above didn't go through this means we have more '#' than the - // group expects - (Some((_, Damaged)), Some(_)) => { - nexts.push(Self::Invalid(ResolvingError::TooManyDamaged)) - } - // We got '#', we don't have a started group - (Some((y, Damaged)), None) => match groups.last() { - // The next group expects enough - Some(&x) if x >= y => { - current_group = groups.pop().map(|x| x - y); - resolved.push((y, Damaged)); - nexts.push(Self::Unresolved { - current_group, - groups, - springs, - resolved, - }) - } - // We have more than the next group expects - Some(_) => nexts.push(Self::Invalid(ResolvingError::TooManyDamaged)), - // There is no next group - None => nexts.push(Self::Invalid(ResolvingError::NoMoreGroups)), - }, - - // We have '.' after finishing a group, this is normal, end the current group - // and continue - (Some((c, Operational)), Some(&mut 0)) => { - resolved.push((c, Operational)); - nexts.push(Self::Unresolved { - groups, - springs, - resolved, - current_group: None, - }) - } - // We have '.' but the current group hasn't been finished, this record must be - // invalid - (Some((_, Operational)), Some(_)) => { - nexts.push(Self::Invalid(ResolvingError::UnfinishedGroup)) - } - // We have '.' but haven't started any group, this is also normal and we have - // nothing to do, continue - (Some((c, Operational)), None) => { - resolved.push((c, Operational)); - nexts.push(Self::Unresolved { - groups, - springs, - resolved, - current_group, - }) - } + // We have no more springs but we still have groups + (None, None) => {} + // We have no more springs and the started group isn't finished, this record + // must be invalid + (None, Some(_)) => {} - // We got '?' while we had an unfinished group - (Some((y, Unknown)), Some(&mut x @ (1..))) => { - if x >= y { - resolved.push((y, Damaged)); - nexts.push(Self::Unresolved { - current_group: Some(x - y), - groups, - springs, - resolved, - }) - } else { - springs.push((y - x, Unknown)); - resolved.push((x, Damaged)); - nexts.push(Self::Unresolved { - current_group: Some(0), - groups, - springs, - resolved, - }) - } + // A group is started and we encounter some '#'s while we expect more + (Some((y, Damaged)), Some(x)) if *x >= y => { + *x -= y; + nxt!(groups, springs, current_group); + } + // Since the arm above didn't go through this means we have more '#' than the + // group expects + (Some((_, Damaged)), Some(_)) => {} + // We got '#', we don't have a started group + (Some((y, Damaged)), None) => match groups.last() { + // The next group expects enough + Some(&x) if x >= y => { + current_group = groups.pop().map(|x| x - y); + nxt!(groups, springs, current_group); } - - // We got '?' while we had a finished group, end it and leave the rest for - // later - (Some((c, Unknown)), Some(0)) => { - springs.push((c - 1, Unknown)); - resolved.push((1, Operational)); - nexts.push(Self::Unresolved { - current_group: None, - groups, - springs, - resolved, - }) + // We have more than the next group expects or there is no next group + Some(_) | None => {} + }, + + // We have '.' after finishing a group, this is normal, end the current group + // and continue + (Some((_, Operational)), Some(&mut 0)) => nxt!(groups, springs, None), + // We have '.' but the current group hasn't been finished, this record must be + // invalid + (Some((_, Operational)), Some(_)) => {} + // We have '.' but haven't started any group, this is also normal and we have + // nothing to do, continue + (Some((_, Operational)), None) => nxt!(groups, springs, current_group), + + // We got '?' while we had an unfinished group + (Some((y, Unknown)), Some(&mut x @ (1..))) => { + if x >= y { + nxt!(groups, springs, Some(x - y)); + } else { + springs.push((y - x, Unknown)); + nxt!(groups, springs, Some(0)); } + } - // We got '?' while we didn't have a started group - (Some((c, Unknown)), None) => match groups.pop() { - // No group to start, only solution is for all of the '?' to become - // '.' - None => { - resolved.push((c, Operational)); - nexts.push(Self::Unresolved { - current_group: None, - groups, - springs, - resolved, - }) - } - // We have a group g - Some(g) => { - // We can choose to treat i '?' as '.' and the next '?' as a '#' - for i in 0..c { - let mut new_springs = springs.clone(); - let mut new_resolved = resolved.clone(); - // Add back the springs we didn't use - if c - i - 1 > 0 { - new_springs.push((c - i - 1, Unknown)); - } - - new_resolved.push((i, Operational)); - new_resolved.push((1, Damaged)); - nexts.push(Self::Unresolved { - // Already decrement since we also do the first '#' - current_group: Some(g - 1), - groups: groups.clone(), - springs: new_springs, - resolved: new_resolved, - }); + // We got '?' while we had a finished group, end it and leave the rest for + // later + (Some((c, Unknown)), Some(0)) => { + springs.push((c - 1, Unknown)); + nxt!(groups, springs, None); + } + + // We got '?' while we didn't have a started group + (Some((c, Unknown)), None) => match groups.pop() { + // No group to start, only solution is for all of the '?' to become + // '.' + None => nxt!(groups, springs, None), + // We have a group g + Some(g) => { + // We can choose to treat i '?' as '.' and the next '?' as a '#' + for i in 0..c { + let mut new_springs = springs.clone(); + // Add back the springs we didn't use + if c - i - 1 > 0 { + new_springs.push((c - i - 1, Unknown)); } - // Or treat all '?' as '.' - // Add back the group we didn't consume - groups.push(g); - resolved.push((c, Operational)); - nexts.push(Self::Unresolved { - current_group: None, - groups, - springs, - resolved, - }) + // Already decrement since we also do the first '#' + nxt!(groups.clone(), new_springs, Some(g - 1)); } - }, - } + + // Or treat all '?' as '.' + // Add back the group we didn't consume + groups.push(g); + nxt!(groups, springs, None); + } + }, } } - let res = nexts.into_iter().map(Record::resolve).sum(); - memo.write().insert(key, res); - res + + unsafe { unreachable_unchecked() } } fn unfold(self) -> Self { - match self { - Self::Resolved { .. } | Self::Invalid(_) => self, - Self::Unresolved { - current_group, - groups, - springs, - resolved, - } => Self::Unresolved { - groups: std::iter::repeat_with(|| groups.iter().copied()) - .take(5) - .flatten() - .collect(), - springs: std::iter::repeat_with(|| { - std::iter::once((1, Unknown)).chain(springs.iter().copied()) - }) + Self { + groups: std::iter::repeat_with(|| self.groups.iter().copied()) .take(5) .flatten() - .skip(1) .collect(), - current_group, - resolved, - }, + springs: std::iter::repeat_with(|| { + std::iter::once((1, Unknown)).chain(self.springs.iter().copied()) + }) + .take(5) + .flatten() + .skip(1) + .collect(), + ..self } } } @@ -458,30 +288,24 @@ pub async fn day12(input: String) -> Result<(String, String)> { let springs = springs .into_iter() .dedup_with_count() - .map(|(c, x)| (c as u32, x)) + .map(|(c, x)| (c as u8, x)) .collect_vec(); - let groups: Vec = groups_str + let groups: Vec = groups_str .split(',') .map(|x| x.parse()) .rev() .try_collect()?; - Ok(Record::Unresolved { + Ok(Record { springs, groups, - resolved: Vec::new(), current_group: None, }) }) .try_collect()?; - let solve = |records: Vec| { - records - .into_iter() - .map(|r| r.resolve()) - .sum::() - }; + let solve = |records: Vec| records.into_iter().map(|r| r.resolve()).sum::(); let part1 = solve(records.clone()); let part2 = solve(records.into_iter().map(Record::unfold).collect()); diff --git a/2023/src/main.rs b/2023/src/main.rs index 7d14cab..8247f05 100644 --- a/2023/src/main.rs +++ b/2023/src/main.rs @@ -179,8 +179,6 @@ async fn main() -> Result<()> { let days = &days[0..released]; - let days = &days[11..=11]; - let session_file = std::fs::read_to_string("../session")?; let session = session_file.trim_end(); let aoc = Aoc::new(session);