diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..03f4879 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustflags = ["-Ctarget-cpu=native"] diff --git a/src/day16.rs b/src/day16.rs index 2d9e98a..53481fd 100644 --- a/src/day16.rs +++ b/src/day16.rs @@ -8,8 +8,8 @@ fn split_once<'a>(input: &'a str, delimeter: &str) -> Option<(&'a str, &'a str)> struct Rule<'a> { name: &'a str, - a: RangeInclusive, - b: RangeInclusive, + a: RangeInclusive, + b: RangeInclusive, } impl<'a> Rule<'a> { fn parse(input: &'a str) -> Option { @@ -17,22 +17,22 @@ impl<'a> Rule<'a> { let (a, b) = split_once(rem, " or ")?; let a = { let (start, end) = split_once(a, "-")?; - RangeInclusive::::new(start.parse().ok()?, end.parse().ok()?) + RangeInclusive::::new(start.parse().ok()?, end.parse().ok()?) }; let b = { let (start, end) = split_once(b, "-")?; - RangeInclusive::::new(start.parse().ok()?, end.parse().ok()?) + RangeInclusive::::new(start.parse().ok()?, end.parse().ok()?) }; Some(Rule { name, a, b }) } - fn matches(&self, value: usize) -> bool { + fn matches(&self, value: u16) -> bool { self.a.contains(&value) || self.b.contains(&value) } } #[aoc(day16, part1)] -fn solve_d16_p1(input: &str) -> usize { +fn solve_d16_p1(input: &str) -> u16 { let (rules, rem) = split_once(input, "\n\nyour ticket:\n").unwrap(); let (_your_ticket, nearby_tickets) = split_once(rem, "\n\nnearby tickets:\n").unwrap(); @@ -41,7 +41,7 @@ fn solve_d16_p1(input: &str) -> usize { .split('\n') .flat_map(|line| line.split(',')) .filter_map(|x| { - let value: usize = x.parse().unwrap(); + let value: u16 = x.parse().unwrap(); if rules.iter().any(|rule| rule.matches(value)) { None } else { @@ -112,18 +112,196 @@ fn solve_d16_p2(input: &str) -> usize { .product() } -#[test] -fn test_foo() { - const INPUT: &str = "class: 0-1 or 4-19 -row: 0-5 or 8-19 -seat: 0-13 or 16-19 - -your ticket: -11,12,13 - -nearby tickets: -3,9,18 -15,1,5 -5,14,9"; - solve_d16_p2(INPUT); +#[aoc(day16, part2, avx2)] +fn solve_d16_p2_avx2(input: &str) -> usize { + unsafe { avx2::solve_d16_p2(input) } +} + +#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] +mod avx2 { + use super::{split_once, Rule}; + use std::arch::x86_64::*; + + struct RuleEval { + lo_start: [__m256i; 2], + lo_end: [__m256i; 2], + hi_start: [__m256i; 2], + hi_end: [__m256i; 2], + } + + impl RuleEval { + #[target_feature(enable = "avx2")] + unsafe fn new(rules: &[Rule]) -> Option { + if rules.len() > 31 { + return None; + } + let lo_start = { + let mut lo_start = [_mm256_setzero_si256(); 2]; + let lo_start_16 = &mut *(&mut lo_start as *mut [__m256i; 2] as *mut [i16; 32]); + let mut rules_iter = rules.iter(); + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_start_16[idx] = *rule.a.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_start_16[idx + 16] = *rule.a.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_start_16[idx + 8] = *rule.a.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_start_16[idx + 24] = *rule.a.start() as i16 - 1; + } + lo_start + }; + let lo_end = { + let mut lo_end = [_mm256_setzero_si256(); 2]; + let lo_end_16 = &mut *(&mut lo_end as *mut [__m256i; 2] as *mut [i16; 32]); + let mut rules_iter = rules.iter(); + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_end_16[idx] = *rule.a.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_end_16[idx + 16] = *rule.a.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_end_16[idx + 8] = *rule.a.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + lo_end_16[idx + 24] = *rule.a.end() as i16 + 1; + } + lo_end + }; + let hi_start = { + let mut hi_start = [_mm256_setzero_si256(); 2]; + let hi_start_16 = &mut *(&mut hi_start as *mut [__m256i; 2] as *mut [i16; 32]); + let mut rules_iter = rules.iter(); + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_start_16[idx] = *rule.b.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_start_16[idx + 16] = *rule.b.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_start_16[idx + 8] = *rule.b.start() as i16 - 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_start_16[idx + 24] = *rule.b.start() as i16 - 1; + } + hi_start + }; + let hi_end = { + let mut hi_end = [_mm256_setzero_si256(); 2]; + let hi_end_16 = &mut *(&mut hi_end as *mut [__m256i; 2] as *mut [i16; 32]); + let mut rules_iter = rules.iter(); + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_end_16[idx] = *rule.b.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_end_16[idx + 16] = *rule.b.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_end_16[idx + 8] = *rule.b.end() as i16 + 1; + } + for (idx, rule) in rules_iter.by_ref().take(8).enumerate() { + hi_end_16[idx + 24] = *rule.b.end() as i16 + 1; + } + hi_end + }; + Some(RuleEval { + lo_start, + lo_end, + hi_start, + hi_end, + }) + } + + #[target_feature(enable = "avx2")] + unsafe fn eval(&self, value: u16) -> u32 { + let value = _mm256_set1_epi16(value as i16); + let within_lo = [ + _mm256_and_si256( + _mm256_cmpgt_epi16(self.lo_end[0], value), + _mm256_cmpgt_epi16(value, self.lo_start[0]), + ), + _mm256_and_si256( + _mm256_cmpgt_epi16(self.lo_end[1], value), + _mm256_cmpgt_epi16(value, self.lo_start[1]), + ), + ]; + let within_hi = [ + _mm256_and_si256( + _mm256_cmpgt_epi16(self.hi_end[0], value), + _mm256_cmpgt_epi16(value, self.hi_start[0]), + ), + _mm256_and_si256( + _mm256_cmpgt_epi16(self.hi_end[1], value), + _mm256_cmpgt_epi16(value, self.hi_start[1]), + ), + ]; + let valid = [ + _mm256_or_si256(within_lo[0], within_hi[0]), + _mm256_or_si256(within_lo[1], within_hi[1]), + ]; + let packed = _mm256_packs_epi16(valid[0], valid[1]); + _mm256_movemask_epi8(packed) as u32 + } + } + + #[target_feature(enable = "avx2")] + pub unsafe fn solve_d16_p2(input: &str) -> usize { + let (rules, rem) = split_once(input, "\n\nyour ticket:\n").unwrap(); + let (my_ticket, nearby_tickets) = split_once(rem, "\n\nnearby tickets:\n").unwrap(); + + let rules: Vec<_> = rules.split('\n').map(|x| Rule::parse(x).unwrap()).collect(); + assert!(rules.len() < 32); + let rule_eval = RuleEval::new(&rules).unwrap(); + let mut candidates = [_mm256_set1_epi32((1i32 << rules.len()) - 1); 4]; + let mut scratch_space = [_mm256_set1_epi32(1); 4]; + for line in nearby_tickets.split('\n') { + let scratch_slice = &mut * (&mut scratch_space as *mut _ as *mut [u32; 32]); + for (field, scratch) in line.split(',').zip(scratch_slice.iter_mut()) { + let field = field.parse().unwrap(); + *scratch = rule_eval.eval(field); + } + + if scratch_space.iter().copied().any(|elem| { + _mm256_movemask_epi8(_mm256_cmpeq_epi32(elem, _mm256_set1_epi32(0))) != 0 + }) { + continue; + } + + for (candidate, valid_bitmask) in candidates.iter_mut().zip(scratch_space.iter()) { + *candidate = _mm256_and_si256(*candidate, *valid_bitmask); + } + } + + let candidates = &mut * (&mut candidates as *mut _ as *mut [u32; 32]); + let candidates = &mut candidates[..rules.len()]; + while candidates.iter().copied().any(|x| x.count_ones() > 1) { + for idx in 0..candidates.len() { + let candidate = candidates[idx]; + if candidate.count_ones() == 1 { + let mask = !candidate; + for before in &mut candidates[..idx] { + *before &= mask; + } + for after in &mut candidates[idx + 1..] { + *after &= mask; + } + } + } + } + my_ticket + .split(',') + .map(|x| x.parse::().unwrap()) + .zip(candidates.into_iter().map(|x| x.trailing_zeros() as usize)) + .filter_map(|(field, rule_idx)| { + if rules[rule_idx].name.starts_with("departure") { + Some(field) + } else { + None + } + }) + .product() + } }