day10: working but far too slow SIMD implementation

This commit is contained in:
2025-12-10 10:14:57 -08:00
parent 0a5e5c8798
commit 866544d416
3 changed files with 203 additions and 42 deletions

26
Cargo.lock generated
View File

@@ -72,6 +72,7 @@ dependencies = [
"rayon",
"regex",
"rstest",
"wide",
]
[[package]]
@@ -86,6 +87,12 @@ version = "3.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]]
name = "bytemuck"
version = "1.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
[[package]]
name = "cached"
version = "0.56.0"
@@ -556,6 +563,15 @@ version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "safe_arch"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f7caad094bd561859bcd467734a720c3c1f5d1f338995351fefe2190c45efed"
dependencies = [
"bytemuck",
]
[[package]]
name = "semver"
version = "1.0.27"
@@ -767,6 +783,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "wide"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbace5de6cfc4866f684318ad85761c89380cfb191982ae96aa65c295bf5897e"
dependencies = [
"bytemuck",
"safe_arch",
]
[[package]]
name = "windows-link"
version = "0.2.1"

View File

@@ -14,6 +14,7 @@ misc = {path = "utils/misc"}
rayon = "1.11.0"
regex = "1.11.1"
rstest = "0.26.1"
wide = "1.0.2"
[profile.release]
lto = true

View File

@@ -1,13 +1,22 @@
use std::{collections::BinaryHeap, hash::Hash, iter::repeat_n};
use std::{
collections::{BinaryHeap, VecDeque},
hash::Hash,
iter::{repeat, repeat_n},
};
use aoc_runner_derive::{aoc, aoc_generator};
use indicatif::{ProgressBar, ProgressStyle};
use itertools::Itertools;
use regex::Regex;
use wide::{CmpGt, i16x16};
#[derive(Clone, Debug, Default)]
struct MachineDefinition {
desired: Vec<bool>,
buttons: Vec<Vec<usize>>,
joltages: Vec<u64>,
buttons2: Vec<i16x16>,
buttons_max: i16x16,
joltages: i16x16,
}
impl MachineDefinition {
@@ -17,6 +26,13 @@ impl MachineDefinition {
lights: Vec::from_iter(repeat_n(false, self.desired.len())),
}
}
/// Build the starting joltage machine for this definition: every lane at zero.
fn create2<'a>(&'a self) -> JoltMachine<'a> {
    let zeroed = i16x16::splat(0);
    JoltMachine {
        joltages: zeroed,
        d: self,
    }
}
}
impl From<&str> for MachineDefinition {
@@ -27,22 +43,47 @@ impl From<&str> for MachineDefinition {
.unwrap();
let parts = parse_re.captures(value).unwrap();
let joltages: [i16; 16] = parts["joltages"]
.split(',')
.map(|n| n.parse().unwrap())
.chain(repeat(0))
.take(16)
.collect_array()
.unwrap();
let buttons = parts["buttons"]
.split_ascii_whitespace()
.map(|s| {
s[1..s.len() - 1]
.split(',')
.map(|n| n.parse().unwrap())
.collect()
})
.sorted_unstable_by_key(|s: &Vec<usize>| s.len())
.rev()
.collect_vec();
let mut buttons2 = Vec::new();
let mut buttons_max = [0i16; 16];
for (i, b) in buttons.iter().enumerate() {
let mut but = [0i16; 16];
for i in b {
but[*i] = 1;
}
buttons2.push(i16x16::new(but));
// find the joltage this button affects with the lowest value
// it is the max number of presses for this button
buttons_max[i] = b.iter().map(|idx| joltages[*idx]).min().unwrap();
}
MachineDefinition {
desired: parts["desired"].chars().map(|c| c == '#').collect(),
buttons: parts["buttons"]
.split_ascii_whitespace()
.map(|s| {
s[1..s.len() - 1]
.split(',')
.map(|n| n.parse().unwrap())
.collect()
})
.collect(),
joltages: parts["joltages"]
.split(',')
.map(|n| n.parse().unwrap())
.collect(),
buttons: buttons,
buttons2: buttons2,
buttons_max: i16x16::new(buttons_max),
joltages: i16x16::new(joltages),
}
}
}
@@ -53,25 +94,18 @@ struct Machine<'a> {
lights: Vec<bool>,
}
impl<'a> Eq for Machine<'a> {}
impl<'a> PartialEq for Machine<'a> {
fn eq(&self, other: &Self) -> bool {
self.lights == other.lights
}
}
impl<'a> Hash for Machine<'a> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.lights.hash(state)
}
/// Joltage state of a machine while buttons are being pressed.
#[derive(Clone, Debug)]
struct JoltMachine<'a> {
// Definition this state belongs to (buttons and target joltages).
d: &'a MachineDefinition,
// Joltage accumulated so far, one lane per joltage index.
joltages: i16x16,
}
impl<'a> Machine<'a> {
/// Get the state after pressing `button`, returns None if the state is as desired
/// Get the state after pressing `button`, returns None if the state is as desired.
fn press(&self, button: usize) -> Option<Self> {
let mut new_state = self.lights.clone();
for light in &self.d.buttons[button] {
new_state[*light] = !new_state[*light]
new_state[*light] = !new_state[*light];
}
if new_state == self.d.desired {
None
@@ -82,9 +116,8 @@ impl<'a> Machine<'a> {
})
}
}
/// Get the possible states from the current position
fn next_states(&self) -> Vec<(usize, Option<Machine<'a>>)> {
fn next_states(&self) -> Vec<(usize, Option<Self>)> {
self.d
.buttons
.iter()
@@ -94,14 +127,65 @@ impl<'a> Machine<'a> {
}
}
impl<'a> JoltMachine<'a> {
    /// Press `button` once, starting from the per-button press counts in `presses`.
    ///
    /// Returns the updated press counts together with the resulting machine, or `None`
    /// in place of the machine when the new joltages equal the target exactly.
    ///
    /// # Panics
    /// Panics if `button` is not a valid index into `buttons2`, or is 16 or more
    /// (press counts live in a 16-lane vector).
    fn press_jolts(&self, button: usize, presses: &i16x16) -> (i16x16, Option<Self>) {
        let new_joltage = self.joltages + self.d.buttons2[button];
        // `i16x16` is `Copy`; a plain dereference gives us our own value to mutate.
        let mut new_presses = *presses;
        new_presses.as_mut_array()[button] += 1;
        let next = (new_joltage != self.d.joltages).then_some(Self {
            d: self.d,
            joltages: new_joltage,
        });
        (new_presses, next)
    }

    /// Press every button once from the current state and keep the viable outcomes.
    ///
    /// An outcome is dropped when any button's press count exceeds its entry in
    /// `buttons_max`, or when any joltage lane exceeds the target. Joltages only ever
    /// increase, so an overshoot can never be undone by later presses.
    fn next_states_jolt(&self, presses: &i16x16) -> Vec<(i16x16, Option<Self>)> {
        let target = self.d.joltages;
        let max_presses = self.d.buttons_max;
        (0..self.d.buttons.len())
            .map(|button| self.press_jolts(button, presses))
            .filter(|(pressed, candidate)| {
                !pressed.simd_gt(max_presses).any()
                    && candidate
                        .as_ref()
                        .is_none_or(|c| !c.joltages.simd_gt(target).any())
            })
            .collect()
    }
}
/// A light-machine state paired with the number of presses recorded for it.
#[derive(Debug, Clone)]
struct PressSet<'a> {
// State of the lights.
machine: Machine<'a>,
// Press count; the ordering impls for this type compare on it.
presses: usize,
}
/// A joltage-machine state paired with the per-button press counts that produced it.
#[derive(Debug, Clone)]
struct PressSet2<'a> {
// State of the joltages.
machine: JoltMachine<'a>,
// Lane `b` counts how many times button `b` has been pressed.
presses: i16x16,
}
// NOTE: All compares are reversed so our max heap becomes a min heap
// `Eq` is only a marker; the comparison itself lives in the `PartialEq` impls.
impl<'a> Eq for PressSet<'a> {}
impl<'a> Eq for PressSet2<'a> {}
impl<'a> PartialEq for PressSet<'a> {
fn eq(&self, other: &Self) -> bool {
@@ -109,33 +193,44 @@ impl<'a> PartialEq for PressSet<'a> {
}
}
impl<'a> PartialEq for PressSet2<'a> {
    /// Two press sets are equal when their total press counts match.
    ///
    /// `Ord::cmp` for this type orders by `presses.reduce_add()`, and `Ord` requires
    /// `a == b` exactly when `a.cmp(&b)` is `Equal`. Lane-wise equality broke that
    /// rule for press sets with the same total but a different split across buttons.
    fn eq(&self, other: &Self) -> bool {
        self.presses.reduce_add() == other.presses.reduce_add()
    }
}
impl<'a> PartialOrd for PressSet<'a> {
// Delegates to `Ord::cmp`, so the partial and total orders always agree.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<'a> PartialOrd for PressSet2<'a> {
// Delegates to `Ord::cmp`, so the partial and total orders always agree.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<'a> Ord for PressSet<'a> {
    /// Orders by press count with the result flipped: fewer presses compares as
    /// greater, so a max-first `BinaryHeap` pops the cheapest entry first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.presses.cmp(&other.presses).reverse()
    }
}
impl<'a> Ord for PressSet2<'a> {
    /// Orders by total presses across all buttons with the result flipped: the set
    /// with fewer presses compares as greater.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let mine = self.presses.reduce_add();
        let theirs = other.presses.reduce_add();
        mine.cmp(&theirs).reverse()
    }
}
fn find_best(md: &MachineDefinition) -> usize {
let m = md.create();
let mut to_check = BinaryHeap::new();
to_check.push(PressSet {
presses: 0,
machine: m,
});
for next in m.next_states() {
if let Some(new_m) = next.1 {
to_check.push(PressSet {
presses: 1,
machine: new_m.clone(),
})
} else {
// what we found a solution on the first move?
return 1;
}
}
while let Some(candidate) = to_check.pop() {
let cm = candidate.machine.clone();
for next in cm.next_states() {
@@ -153,6 +248,41 @@ fn find_best(md: &MachineDefinition) -> usize {
panic!()
}
/// Find the fewest button presses that bring every joltage exactly to its target.
///
/// Runs a breadth-first search: states are taken from the front of a FIFO queue and
/// each successor adds exactly one press, so the first state that matches the target
/// uses the minimum number of presses. A progress bar reports how many states have
/// been processed and how many are currently queued.
///
/// # Panics
/// Panics if the queue runs dry without reaching the target joltages.
fn find_best_jolts(md: &MachineDefinition) -> usize {
    let start = md.create2();
    // The loop below only examines states produced by a press, so a target that the
    // untouched machine already satisfies has to be answered here.
    if start.joltages == md.joltages {
        return 0;
    }

    let mut to_check = VecDeque::new();
    to_check.push_back(PressSet2 {
        presses: i16x16::splat(0),
        machine: start,
    });

    let pb = ProgressBar::no_length()
        .with_style(
            ProgressStyle::with_template(
                "[{elapsed_precise}/{eta_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {per_sec}",
            )
            .unwrap(),
        )
        .with_finish(indicatif::ProgressFinish::AndLeave);

    // NOTE(review): there is no visited set, so the same press counts reached through
    // different press orders are queued separately; deduplicating would shrink the search.
    while let Some(candidate) = to_check.pop_front() {
        pb.inc(1);
        pb.set_length(to_check.len() as u64);
        for (presses, next) in candidate.machine.next_states_jolt(&candidate.presses) {
            match next {
                Some(machine) => to_check.push_back(PressSet2 { presses, machine }),
                // `None` marks a state whose joltages equal the target.
                None => return presses.reduce_add() as usize,
            }
        }
    }
    panic!("no sequence of button presses reaches the target joltages")
}
#[aoc_generator(day10)]
fn parse(input: &str) -> Vec<MachineDefinition> {
input.lines().map(|l| l.into()).collect()
@@ -170,7 +300,11 @@ fn part1(input: &[MachineDefinition]) -> u64 {
#[aoc(day10, part2)]
fn part2(input: &[MachineDefinition]) -> u64 {
0
input
.iter()
.map(find_best_jolts)
.map(|sol| sol as u64)
.sum()
}
#[cfg(test)]