day8: add heap-based impl

This commit is contained in:
2025-12-08 13:16:53 -08:00
parent b5f6bcde11
commit 522e10a3ea

View File

@@ -1,4 +1,4 @@
use std::fmt::Display;
use std::{cmp::Reverse, collections::BinaryHeap, fmt::Display};
use aoc_runner_derive::{aoc, aoc_generator};
use itertools::Itertools;
@@ -112,12 +112,12 @@ fn part1_impl(input: &Circuits, n: usize) -> u64 {
.unwrap() as u64
}
#[aoc(day8, part1)]
#[aoc(day8, part1, Sorted)]
fn part1(input: &Circuits) -> u64 {
part1_impl(input, 1000)
}
#[aoc(day8, part2)]
#[aoc(day8, part2, Sorted)]
fn part2(input: &Circuits) -> u64 {
let mut circuits = input.clone();
@@ -138,6 +138,85 @@ fn part2(input: &Circuits) -> u64 {
panic!()
}
struct JunctionPair {
a: usize,
b: usize,
d: u64,
}
impl Eq for JunctionPair {}
impl PartialEq for JunctionPair {
fn eq(&self, other: &Self) -> bool {
self.d == other.d
}
}
impl PartialOrd for JunctionPair {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.d.cmp(&other.d))
}
}
impl Ord for JunctionPair {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.d.cmp(&other.d)
}
}
fn make_heap(circuits: &Circuits) -> BinaryHeap<Reverse<JunctionPair>> {
BinaryHeap::from_iter(
circuits
.junctions
.iter()
.enumerate()
.tuple_combinations()
.map(|((a_pos, a), (b_pos, b))| {
Reverse(JunctionPair {
a: a_pos,
b: b_pos,
d: a.squared_distance(b),
})
}),
)
}
fn part1_heaped_impl(input: &Circuits, n: usize) -> u64 {
let mut circuits = input.clone();
let mut distances = make_heap(&circuits);
for _ in 0..n {
let pair = distances.pop().unwrap().0;
circuits.connect(pair.a, pair.b);
}
circuits
.circuits
.iter()
.map(|c| c.len())
.sorted_unstable()
.rev()
.take(3)
.reduce(|a, b| a * b)
.unwrap() as u64
}
#[aoc(day8, part1, Heaped)]
fn part1_heaped(input: &Circuits) -> u64 {
part1_heaped_impl(input, 1000)
}
#[aoc(day8, part2, Heaped)]
fn part2_heaped(input: &Circuits) -> u64 {
let mut circuits = input.clone();
let mut distances = make_heap(&circuits);
while let Some(Reverse(jp)) = distances.pop() {
circuits.connect(jp.a, jp.b);
if circuits.circuits.len() == 1 {
return (circuits.junctions[jp.a].pos.0 * circuits.junctions[jp.b].pos.0) as u64;
}
}
panic!()
}
#[cfg(test)]
mod tests {
use super::*;
@@ -168,8 +247,18 @@ mod tests {
assert_eq!(part1_impl(&parse(EXAMPLE), 10), 40);
}
#[test]
fn part1_heaped_example() {
assert_eq!(part1_heaped_impl(&parse(EXAMPLE), 10), 40);
}
#[test]
fn part2_example() {
assert_eq!(part2(&parse(EXAMPLE)), 25272);
}
#[test]
fn part2_heaped_example() {
assert_eq!(part2_heaped(&parse(EXAMPLE)), 25272);
}
}