-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathday-14.rs
110 lines (94 loc) · 2.87 KB
/
day-14.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::{collections::HashMap, str::FromStr};
const INPUT: &str = include_str!("day-14.input");

/// Parses the embedded puzzle input (one instruction per line) and prints both answers.
///
/// # Panics
/// Panics, naming the offending line, if any input line is not a valid instruction.
fn main() {
    let program: Vec<Instruction> = INPUT
        .lines()
        .map(|line| {
            line.parse::<Instruction>()
                .unwrap_or_else(|_| panic!("illegal instruction: {}", line))
        })
        .collect();
    println!("part 1: {}", part_1(&program));
    println!("part 2: {}", part_2(&program));
}
/// One line of the day-14 program.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Instruction {
    /// `mask = <bits>`: the first field is the mask with every `X` read as `1`,
    /// the second is the same mask with every `X` read as `0`.
    Mask(u64, u64),
    /// `mem[<addr>] = <value>`, stored as `(addr, value)`.
    Set(u64, u64),
}

impl FromStr for Instruction {
    type Err = ();

    /// Parses either `mask = <bits>` or `mem[<addr>] = <value>`.
    ///
    /// # Errors
    /// Returns `Err(())` when the line has no ` = ` separator, matches neither form,
    /// or contains a number (binary mask, decimal address/value) that fails to parse.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `splitn(2, ..)` keeps everything after the first ` = ` as the argument, so
        // trailing garbage makes the number parse fail instead of being ignored.
        let mut parts = s.splitn(2, " = ");
        let target = parts.next().ok_or(())?;
        let arg = parts.next().ok_or(())?;

        if target == "mask" {
            // Substitute the floating `X` bits, then read the result as base 2.
            let bits = |fill: &str| u64::from_str_radix(&arg.replace('X', fill), 2).map_err(|_| ());
            return Ok(Instruction::Mask(bits("1")?, bits("0")?));
        }

        let addr = target
            .strip_prefix("mem[")
            .and_then(|rest| rest.strip_suffix(']'))
            .ok_or(())?;
        Ok(Instruction::Set(
            addr.parse().map_err(|_| ())?,
            arg.parse().map_err(|_| ())?,
        ))
    }
}
/// Part 1: masks each *value* before writing it, then returns the sum of memory.
///
/// A `Mask(and, or)` turns a write of `value` into `(value & and) | or`. Per the
/// parser, bits written `0` in the mask are cleared by `and`, bits written `1` are set
/// by `or`, and `X` bits pass through unchanged. Later writes to an address overwrite
/// earlier ones.
fn part_1(program: &[Instruction]) -> u64 {
    let mut mem: HashMap<u64, u64> = HashMap::new();
    // True identity mask until the first `mask` instruction: every bit passes through.
    let (mut and, mut or) = (u64::MAX, 0);
    for &instruction in program {
        match instruction {
            Instruction::Mask(a, o) => {
                and = a;
                or = o;
            }
            Instruction::Set(addr, value) => {
                mem.insert(addr, (value & and) | or);
            }
        }
    }
    mem.values().sum()
}
/// Part 2: masks each *address*, writing the value to every address the floating
/// bits can produce, then returns the sum of memory.
///
/// For `Mask(a, o)`: bits set in `o` are forced to `1` in the address. Bits set in `a`
/// but not in `o` (the `X` positions) take every combination of `0` and `1`. All other
/// address bits are kept as-is. Later writes to an address overwrite earlier ones.
fn part_2(program: &[Instruction]) -> u64 {
    let mut mem: HashMap<u64, u64> = HashMap::new();
    let mut or = 0;
    // XOR patterns covering every combination of floating bits; `[0]` means
    // "no mask seen yet", so the address is used unchanged.
    let mut floats: Vec<u64> = vec![0];
    for &instruction in program {
        match instruction {
            Instruction::Mask(a, o) => {
                or = o;
                // `X` positions: set when X is read as 1, clear when read as 0.
                let floating = a & !o;
                floats.clear();
                // Enumerate every submask of `floating`, from `floating` down to 0.
                let mut sub = floating;
                loop {
                    floats.push(sub);
                    if sub == 0 {
                        break;
                    }
                    sub = (sub - 1) & floating;
                }
            }
            Instruction::Set(addr, value) => {
                // `or` is 0 at floating positions, so XOR-ing each pattern toggles
                // the original address bit there, covering both 0 and 1.
                let base = addr | or;
                for &f in &floats {
                    mem.insert(base ^ f, value);
                }
            }
        }
    }
    mem.values().sum()
}