From 866544d416cb758ef4cfdbf275f416e837bb5528 Mon Sep 17 00:00:00 2001 From: Keenan Tims Date: Wed, 10 Dec 2025 10:14:57 -0800 Subject: [PATCH] day10: working but way too slow simd implementation --- Cargo.lock | 26 ++++++ Cargo.toml | 1 + src/day10.rs | 218 +++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 203 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1e0192d..dc466c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,6 +72,7 @@ dependencies = [ "rayon", "regex", "rstest", + "wide", ] [[package]] @@ -86,6 +87,12 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +[[package]] +name = "bytemuck" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" + [[package]] name = "cached" version = "0.56.0" @@ -556,6 +563,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "safe_arch" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f7caad094bd561859bcd467734a720c3c1f5d1f338995351fefe2190c45efed" +dependencies = [ + "bytemuck", +] + [[package]] name = "semver" version = "1.0.27" @@ -767,6 +783,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "wide" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbace5de6cfc4866f684318ad85761c89380cfb191982ae96aa65c295bf5897e" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "windows-link" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index dba4132..c34d956 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ misc = {path = "utils/misc"} rayon = "1.11.0" regex = "1.11.1" rstest = "0.26.1" +wide = "1.0.2" [profile.release] lto = true diff --git a/src/day10.rs b/src/day10.rs index 8474e41..42b6d6b 100644 --- a/src/day10.rs +++ b/src/day10.rs @@ -1,13 +1,22 @@ -use std::{collections::BinaryHeap, hash::Hash, iter::repeat_n}; +use std::{ + collections::{BinaryHeap, VecDeque}, + hash::Hash, + iter::{repeat, repeat_n}, +}; use aoc_runner_derive::{aoc, aoc_generator}; +use indicatif::{ProgressBar, ProgressStyle}; +use itertools::Itertools; use regex::Regex; +use wide::{CmpGt, i16x16}; #[derive(Clone, Debug, Default)] struct MachineDefinition { desired: Vec, buttons: Vec>, - joltages: Vec, + buttons2: Vec, + buttons_max: i16x16, + joltages: i16x16, } impl MachineDefinition { @@ -17,6 +26,13 @@ impl MachineDefinition { lights: Vec::from_iter(repeat_n(false, self.desired.len())), } } + + fn create2<'a>(&'a self) -> JoltMachine<'a> { + JoltMachine { + d: self, + joltages: i16x16::splat(0), + } + } } impl From<&str> for MachineDefinition { @@ -27,22 +43,47 @@ impl From<&str> for MachineDefinition { .unwrap(); let parts = parse_re.captures(value).unwrap(); + let joltages: [i16; 16] = parts["joltages"] + .split(',') + .map(|n| n.parse().unwrap()) + .chain(repeat(0)) + .take(16) + .collect_array() + .unwrap(); + + let buttons = parts["buttons"] + .split_ascii_whitespace() + .map(|s| { + s[1..s.len() - 1] + .split(',') + .map(|n| n.parse().unwrap()) + .collect() + }) + .sorted_unstable_by_key(|s: &Vec| s.len()) + .rev() + .collect_vec(); + + let mut buttons2 = Vec::new(); + let mut buttons_max = [0i16; 16]; + + for (i, b) in buttons.iter().enumerate() { + let mut but = [0i16; 16]; + for i in b { + but[*i] = 1; + } + buttons2.push(i16x16::new(but)); + + // find the joltage this button affects with the lowest value + // it is the max number of presses for this button + buttons_max[i] = b.iter().map(|idx| joltages[*idx]).min().unwrap(); + } MachineDefinition { desired: parts["desired"].chars().map(|c| c == '#').collect(), - buttons: parts["buttons"] - .split_ascii_whitespace() - .map(|s| { - s[1..s.len() - 1] - .split(',') - .map(|n| n.parse().unwrap()) - .collect() - }) - .collect(), - joltages: parts["joltages"] - .split(',') - .map(|n| n.parse().unwrap()) - .collect(), + buttons: buttons, + buttons2: buttons2, + buttons_max: i16x16::new(buttons_max), + joltages: i16x16::new(joltages), } } } @@ -53,25 +94,18 @@ struct Machine<'a> { lights: Vec, } -impl<'a> Eq for Machine<'a> {} -impl<'a> PartialEq for Machine<'a> { - fn eq(&self, other: &Self) -> bool { - self.lights == other.lights - } -} - -impl<'a> Hash for Machine<'a> { - fn hash(&self, state: &mut H) { - self.lights.hash(state) - } +#[derive(Clone, Debug)] +struct JoltMachine<'a> { + d: &'a MachineDefinition, + joltages: i16x16, } impl<'a> Machine<'a> { - /// Get the state after pressing `button`, returns None if the state is as desired + /// Get the state after pressing `button`, returns None if the state is as desired. fn press(&self, button: usize) -> Option { let mut new_state = self.lights.clone(); for light in &self.d.buttons[button] { - new_state[*light] = !new_state[*light] + new_state[*light] = !new_state[*light]; } if new_state == self.d.desired { None @@ -82,9 +116,8 @@ impl<'a> Machine<'a> { }) } } - /// Get the possible states from the current position - fn next_states(&self) -> Vec<(usize, Option>)> { + fn next_states(&self) -> Vec<(usize, Option)> { self.d .buttons .iter() @@ -94,14 +127,65 @@ impl<'a> Machine<'a> { } } +impl<'a> JoltMachine<'a> { + fn press_jolts(&self, button: usize, presses: &i16x16) -> (i16x16, Option) { + // let mut new_joltage = self.joltages.clone(); + // // for jolt in &self.d.buttons[button] { + // // new_joltage[*jolt] += 1; + // // } + let new_joltage = self.joltages + self.d.buttons2[button]; + let mut new_presses = presses.clone(); + new_presses.as_mut_array()[button] += 1; + if new_joltage == self.d.joltages { + (new_presses, None) + } else { + ( + new_presses, + Some(Self { + d: self.d, + joltages: new_joltage, + }), + ) + } + } + + fn next_states_jolt(&self, presses: &i16x16) -> Vec<(i16x16, Option)> { + self.d + .buttons + .iter() + .enumerate() + .map(|(i, _but)| self.press_jolts(i, &presses)) + // .inspect(|(p, o)| println!(" {p:?} {o:?}\n")) + // joltages monotonically increase, so cull any where a joltage is higher than needed + .filter(|(presses, candidate)| { + !presses.simd_gt(self.d.buttons_max).any() + && candidate.as_ref().is_none_or(|c| { + !c.joltages.simd_gt(self.d.joltages).any() + // !c.joltages + // .iter() + // .zip(self.d.joltages.iter()) + // .any(|(candidate, expected)| candidate > expected) + }) + }) + .collect() + } +} + #[derive(Debug, Clone)] struct PressSet<'a> { machine: Machine<'a>, presses: usize, } +#[derive(Debug, Clone)] +struct PressSet2<'a> { + machine: JoltMachine<'a>, + presses: i16x16, +} + // NOTE: All compares are reversed so our max heap becomes a min heap impl<'a> Eq for PressSet<'a> {} +impl<'a> Eq for PressSet2<'a> {} impl<'a> PartialEq for PressSet<'a> { fn eq(&self, other: &Self) -> bool { @@ -109,33 +193,44 @@ impl<'a> PartialEq for PressSet<'a> { } } +impl<'a> PartialEq for PressSet2<'a> { + fn eq(&self, other: &Self) -> bool { + other.presses.eq(&self.presses) + } +} + impl<'a> PartialOrd for PressSet<'a> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } +impl<'a> PartialOrd for PressSet2<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + impl<'a> Ord for PressSet<'a> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { other.presses.cmp(&self.presses) } } +impl<'a> Ord for PressSet2<'a> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other.presses.reduce_add().cmp(&self.presses.reduce_add()) + } +} + fn find_best(md: &MachineDefinition) -> usize { let m = md.create(); let mut to_check = BinaryHeap::new(); + to_check.push(PressSet { + presses: 0, + machine: m, + }); - for next in m.next_states() { - if let Some(new_m) = next.1 { - to_check.push(PressSet { - presses: 1, - machine: new_m.clone(), - }) - } else { - // what we found a solution on the first move? - return 1; - } - } while let Some(candidate) = to_check.pop() { let cm = candidate.machine.clone(); for next in cm.next_states() { @@ -153,6 +248,41 @@ fn find_best(md: &MachineDefinition) -> usize { panic!() } +fn find_best_jolts(md: &MachineDefinition) -> usize { + let m = md.create2(); + let mut to_check = VecDeque::new(); + to_check.push_back(PressSet2 { + presses: i16x16::splat(0), + machine: m, + }); + + let mut pb = ProgressBar::no_length() + .with_style( + ProgressStyle::with_template( + "[{elapsed_precise}/{eta_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {per_sec}", + ) + .unwrap(), + ) + .with_finish(indicatif::ProgressFinish::AndLeave); + + while let Some(candidate) = to_check.pop_front() { + pb.inc(1); + pb.set_length(to_check.len() as u64); + let cm = candidate.machine.clone(); + for (presses, next) in cm.next_states_jolt(&candidate.presses) { + if let Some(new_m) = next { + to_check.push_back(PressSet2 { + presses, + machine: new_m.clone(), + }) + } else { + return presses.reduce_add() as usize; + } + } + } + panic!() +} + #[aoc_generator(day10)] fn parse(input: &str) -> Vec { input.lines().map(|l| l.into()).collect() @@ -170,7 +300,11 @@ fn part1(input: &[MachineDefinition]) -> u64 { #[aoc(day10, part2)] fn part2(input: &[MachineDefinition]) -> u64 { - 0 + input + .iter() + .map(find_best_jolts) + .map(|sol| sol as u64) + .sum() } #[cfg(test)]