day10: working but way too slow simd implementation
This commit is contained in:
26
Cargo.lock
generated
26
Cargo.lock
generated
@@ -72,6 +72,7 @@ dependencies = [
|
||||
"rayon",
|
||||
"regex",
|
||||
"rstest",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -86,6 +87,12 @@ version = "3.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
|
||||
|
||||
[[package]]
|
||||
name = "cached"
|
||||
version = "0.56.0"
|
||||
@@ -556,6 +563,15 @@ version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "safe_arch"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f7caad094bd561859bcd467734a720c3c1f5d1f338995351fefe2190c45efed"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
@@ -767,6 +783,16 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbace5de6cfc4866f684318ad85761c89380cfb191982ae96aa65c295bf5897e"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"safe_arch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.1"
|
||||
|
||||
@@ -14,6 +14,7 @@ misc = {path = "utils/misc"}
|
||||
rayon = "1.11.0"
|
||||
regex = "1.11.1"
|
||||
rstest = "0.26.1"
|
||||
wide = "1.0.2"
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
||||
218
src/day10.rs
218
src/day10.rs
@@ -1,13 +1,22 @@
|
||||
use std::{collections::BinaryHeap, hash::Hash, iter::repeat_n};
|
||||
use std::{
|
||||
collections::{BinaryHeap, VecDeque},
|
||||
hash::Hash,
|
||||
iter::{repeat, repeat_n},
|
||||
};
|
||||
|
||||
use aoc_runner_derive::{aoc, aoc_generator};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use itertools::Itertools;
|
||||
use regex::Regex;
|
||||
use wide::{CmpGt, i16x16};
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct MachineDefinition {
|
||||
desired: Vec<bool>,
|
||||
buttons: Vec<Vec<usize>>,
|
||||
joltages: Vec<u64>,
|
||||
buttons2: Vec<i16x16>,
|
||||
buttons_max: i16x16,
|
||||
joltages: i16x16,
|
||||
}
|
||||
|
||||
impl MachineDefinition {
|
||||
@@ -17,6 +26,13 @@ impl MachineDefinition {
|
||||
lights: Vec::from_iter(repeat_n(false, self.desired.len())),
|
||||
}
|
||||
}
|
||||
|
||||
fn create2<'a>(&'a self) -> JoltMachine<'a> {
|
||||
JoltMachine {
|
||||
d: self,
|
||||
joltages: i16x16::splat(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for MachineDefinition {
|
||||
@@ -27,22 +43,47 @@ impl From<&str> for MachineDefinition {
|
||||
.unwrap();
|
||||
|
||||
let parts = parse_re.captures(value).unwrap();
|
||||
let joltages: [i16; 16] = parts["joltages"]
|
||||
.split(',')
|
||||
.map(|n| n.parse().unwrap())
|
||||
.chain(repeat(0))
|
||||
.take(16)
|
||||
.collect_array()
|
||||
.unwrap();
|
||||
|
||||
let buttons = parts["buttons"]
|
||||
.split_ascii_whitespace()
|
||||
.map(|s| {
|
||||
s[1..s.len() - 1]
|
||||
.split(',')
|
||||
.map(|n| n.parse().unwrap())
|
||||
.collect()
|
||||
})
|
||||
.sorted_unstable_by_key(|s: &Vec<usize>| s.len())
|
||||
.rev()
|
||||
.collect_vec();
|
||||
|
||||
let mut buttons2 = Vec::new();
|
||||
let mut buttons_max = [0i16; 16];
|
||||
|
||||
for (i, b) in buttons.iter().enumerate() {
|
||||
let mut but = [0i16; 16];
|
||||
for i in b {
|
||||
but[*i] = 1;
|
||||
}
|
||||
buttons2.push(i16x16::new(but));
|
||||
|
||||
// find the joltage this button affects with the lowest value
|
||||
// it is the max number of presses for this button
|
||||
buttons_max[i] = b.iter().map(|idx| joltages[*idx]).min().unwrap();
|
||||
}
|
||||
|
||||
MachineDefinition {
|
||||
desired: parts["desired"].chars().map(|c| c == '#').collect(),
|
||||
buttons: parts["buttons"]
|
||||
.split_ascii_whitespace()
|
||||
.map(|s| {
|
||||
s[1..s.len() - 1]
|
||||
.split(',')
|
||||
.map(|n| n.parse().unwrap())
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
joltages: parts["joltages"]
|
||||
.split(',')
|
||||
.map(|n| n.parse().unwrap())
|
||||
.collect(),
|
||||
buttons: buttons,
|
||||
buttons2: buttons2,
|
||||
buttons_max: i16x16::new(buttons_max),
|
||||
joltages: i16x16::new(joltages),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,25 +94,18 @@ struct Machine<'a> {
|
||||
lights: Vec<bool>,
|
||||
}
|
||||
|
||||
impl<'a> Eq for Machine<'a> {}
|
||||
impl<'a> PartialEq for Machine<'a> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.lights == other.lights
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Hash for Machine<'a> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.lights.hash(state)
|
||||
}
|
||||
#[derive(Clone, Debug)]
|
||||
struct JoltMachine<'a> {
|
||||
d: &'a MachineDefinition,
|
||||
joltages: i16x16,
|
||||
}
|
||||
|
||||
impl<'a> Machine<'a> {
|
||||
/// Get the state after pressing `button`, returns None if the state is as desired
|
||||
/// Get the state after pressing `button`, returns None if the state is as desired.
|
||||
fn press(&self, button: usize) -> Option<Self> {
|
||||
let mut new_state = self.lights.clone();
|
||||
for light in &self.d.buttons[button] {
|
||||
new_state[*light] = !new_state[*light]
|
||||
new_state[*light] = !new_state[*light];
|
||||
}
|
||||
if new_state == self.d.desired {
|
||||
None
|
||||
@@ -82,9 +116,8 @@ impl<'a> Machine<'a> {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the possible states from the current position
|
||||
fn next_states(&self) -> Vec<(usize, Option<Machine<'a>>)> {
|
||||
fn next_states(&self) -> Vec<(usize, Option<Self>)> {
|
||||
self.d
|
||||
.buttons
|
||||
.iter()
|
||||
@@ -94,14 +127,65 @@ impl<'a> Machine<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> JoltMachine<'a> {
|
||||
fn press_jolts(&self, button: usize, presses: &i16x16) -> (i16x16, Option<Self>) {
|
||||
// let mut new_joltage = self.joltages.clone();
|
||||
// // for jolt in &self.d.buttons[button] {
|
||||
// // new_joltage[*jolt] += 1;
|
||||
// // }
|
||||
let new_joltage = self.joltages + self.d.buttons2[button];
|
||||
let mut new_presses = presses.clone();
|
||||
new_presses.as_mut_array()[button] += 1;
|
||||
if new_joltage == self.d.joltages {
|
||||
(new_presses, None)
|
||||
} else {
|
||||
(
|
||||
new_presses,
|
||||
Some(Self {
|
||||
d: self.d,
|
||||
joltages: new_joltage,
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn next_states_jolt(&self, presses: &i16x16) -> Vec<(i16x16, Option<Self>)> {
|
||||
self.d
|
||||
.buttons
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _but)| self.press_jolts(i, &presses))
|
||||
// .inspect(|(p, o)| println!(" {p:?} {o:?}\n"))
|
||||
// joltages monotonically increase, so cull any where a joltage is higher than needed
|
||||
.filter(|(presses, candidate)| {
|
||||
!presses.simd_gt(self.d.buttons_max).any()
|
||||
&& candidate.as_ref().is_none_or(|c| {
|
||||
!c.joltages.simd_gt(self.d.joltages).any()
|
||||
// !c.joltages
|
||||
// .iter()
|
||||
// .zip(self.d.joltages.iter())
|
||||
// .any(|(candidate, expected)| candidate > expected)
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PressSet<'a> {
|
||||
machine: Machine<'a>,
|
||||
presses: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PressSet2<'a> {
|
||||
machine: JoltMachine<'a>,
|
||||
presses: i16x16,
|
||||
}
|
||||
|
||||
// NOTE: All compares are reversed so our max heap becomes a min heap
|
||||
impl<'a> Eq for PressSet<'a> {}
|
||||
impl<'a> Eq for PressSet2<'a> {}
|
||||
|
||||
impl<'a> PartialEq for PressSet<'a> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
@@ -109,33 +193,44 @@ impl<'a> PartialEq for PressSet<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> PartialEq for PressSet2<'a> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
other.presses.eq(&self.presses)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> PartialOrd for PressSet<'a> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> PartialOrd for PressSet2<'a> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Ord for PressSet<'a> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
other.presses.cmp(&self.presses)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Ord for PressSet2<'a> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
other.presses.reduce_add().cmp(&self.presses.reduce_add())
|
||||
}
|
||||
}
|
||||
|
||||
fn find_best(md: &MachineDefinition) -> usize {
|
||||
let m = md.create();
|
||||
let mut to_check = BinaryHeap::new();
|
||||
to_check.push(PressSet {
|
||||
presses: 0,
|
||||
machine: m,
|
||||
});
|
||||
|
||||
for next in m.next_states() {
|
||||
if let Some(new_m) = next.1 {
|
||||
to_check.push(PressSet {
|
||||
presses: 1,
|
||||
machine: new_m.clone(),
|
||||
})
|
||||
} else {
|
||||
// what we found a solution on the first move?
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
while let Some(candidate) = to_check.pop() {
|
||||
let cm = candidate.machine.clone();
|
||||
for next in cm.next_states() {
|
||||
@@ -153,6 +248,41 @@ fn find_best(md: &MachineDefinition) -> usize {
|
||||
panic!()
|
||||
}
|
||||
|
||||
fn find_best_jolts(md: &MachineDefinition) -> usize {
|
||||
let m = md.create2();
|
||||
let mut to_check = VecDeque::new();
|
||||
to_check.push_back(PressSet2 {
|
||||
presses: i16x16::splat(0),
|
||||
machine: m,
|
||||
});
|
||||
|
||||
let mut pb = ProgressBar::no_length()
|
||||
.with_style(
|
||||
ProgressStyle::with_template(
|
||||
"[{elapsed_precise}/{eta_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {per_sec}",
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
.with_finish(indicatif::ProgressFinish::AndLeave);
|
||||
|
||||
while let Some(candidate) = to_check.pop_front() {
|
||||
pb.inc(1);
|
||||
pb.set_length(to_check.len() as u64);
|
||||
let cm = candidate.machine.clone();
|
||||
for (presses, next) in cm.next_states_jolt(&candidate.presses) {
|
||||
if let Some(new_m) = next {
|
||||
to_check.push_back(PressSet2 {
|
||||
presses,
|
||||
machine: new_m.clone(),
|
||||
})
|
||||
} else {
|
||||
return presses.reduce_add() as usize;
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!()
|
||||
}
|
||||
|
||||
#[aoc_generator(day10)]
|
||||
fn parse(input: &str) -> Vec<MachineDefinition> {
|
||||
input.lines().map(|l| l.into()).collect()
|
||||
@@ -170,7 +300,11 @@ fn part1(input: &[MachineDefinition]) -> u64 {
|
||||
|
||||
#[aoc(day10, part2)]
|
||||
fn part2(input: &[MachineDefinition]) -> u64 {
|
||||
0
|
||||
input
|
||||
.iter()
|
||||
.map(find_best_jolts)
|
||||
.map(|sol| sol as u64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user