From a615f31d339fbc93a5e28e4a21d3bc7c63e8aea0 Mon Sep 17 00:00:00 2001 From: Defelo Date: Thu, 19 Dec 2024 14:32:55 +0100 Subject: [PATCH] Rust/2024/19: improve solution --- Rust/2024/19.rs | 24 +++++---- Rust/lib/lib.rs | 1 + Rust/lib/trie.rs | 132 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 10 deletions(-) create mode 100644 Rust/lib/trie.rs diff --git a/Rust/2024/19.rs b/Rust/2024/19.rs index 9f2e0f8..4bfa47e 100644 --- a/Rust/2024/19.rs +++ b/Rust/2024/19.rs @@ -1,30 +1,34 @@ #![feature(test)] +use aoc::trie::Trie; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; #[derive(Debug)] struct Input { - patterns: Vec, - designs: Vec, + patterns: Trie, + designs: Vec>, } fn setup(input: &str) -> Input { let mut lines = input.trim().lines(); - let patterns = lines.next().unwrap().split(", ").map(Into::into).collect(); + let patterns = lines + .next() + .unwrap() + .split(", ") + .map(|p| p.bytes()) + .collect(); let designs = lines.skip(1).map(Into::into).collect(); Input { patterns, designs } } -fn count(design: &str, patterns: &[String]) -> usize { +fn count(design: &[u8], patterns: &Trie) -> usize { let mut dp = vec![0; design.len() + 1]; dp[design.len()] = 1; for i in (0..design.len()).rev() { - for p in patterns { - if i + p.len() > design.len() || !design[i..].starts_with(p) { - continue; - } - dp[i] += dp[i + p.len()]; - } + dp[i] = patterns + .prefix_matches(design[i..].iter().copied()) + .map(|len| dp[i + len]) + .sum(); } dp[0] } diff --git a/Rust/lib/lib.rs b/Rust/lib/lib.rs index ddca927..37098a2 100644 --- a/Rust/lib/lib.rs +++ b/Rust/lib/lib.rs @@ -10,6 +10,7 @@ pub mod iter_ext; pub mod math; pub mod parsing; pub mod range; +pub mod trie; pub mod tuples; extern crate test; diff --git a/Rust/lib/trie.rs b/Rust/lib/trie.rs new file mode 100644 index 0000000..b04f098 --- /dev/null +++ b/Rust/lib/trie.rs @@ -0,0 +1,132 @@ +use std::hash::Hash; + +use rustc_hash::FxHashMap; + +#[derive(Debug, Clone)] +pub struct Trie(Vec>); + +#[derive(Debug, Clone)] +struct TrieNode { + flag: bool, + next: FxHashMap, +} + +impl Trie { + pub fn new() -> Self { + Self::default() + } +} + +impl Trie { + pub fn insert(&mut self, item: impl IntoIterator) -> bool { + let mut node = 0; + for x in item { + match self.0[node].next.get(&x) { + Some(&next) => node = next, + None => { + let next = self.0.len(); + self.0.push(TrieNode::default()); + self.0[node].next.insert(x, next); + node = next; + } + } + } + !std::mem::replace(&mut self.0[node].flag, true) + } + + pub fn contains(&self, item: impl IntoIterator) -> bool { + let mut node = 0; + for x in item { + match self.0[node].next.get(&x) { + Some(&next) => node = next, + None => return false, + } + } + self.0[node].flag + } + + pub fn prefix_matches>( + &self, + item: U, + ) -> impl Iterator + use<'_, T, U> { + self.0[0].flag.then_some(0).into_iter().chain( + item.into_iter() + .scan(0, |node, x| { + self.0[*node].next.get(&x).map(|&next| { + *node = next; + self.0[*node].flag + }) + }) + .enumerate() + .flat_map(|(i, flag)| flag.then_some(i + 1)), + ) + } +} + +impl Default for Trie { + fn default() -> Self { + Self(vec![TrieNode::default()]) + } +} + +impl Default for TrieNode { + fn default() -> Self { + Self { + flag: false, + next: Default::default(), + } + } +} + +impl FromIterator for Trie +where + I2: IntoIterator, + I2::Item: Eq + Hash, +{ + fn from_iter>(iter: I1) -> Self { + let mut trie = Self::new(); + for item in iter { + trie.insert(item); + } + trie + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn trie() { + let mut t = ["foo", "bar", "baz"] + .into_iter() + .map(|x| x.chars()) + .collect::>(); + + assert!(t.contains("foo".chars())); + assert!(t.contains("bar".chars())); + assert!(t.contains("baz".chars())); + assert!(!t.contains("test".chars())); + assert!(!t.contains("baa".chars())); + + assert!(t.insert("test".chars())); + assert!(t.contains("test".chars())); + assert!(!t.insert("test".chars())); + assert!(t.contains("test".chars())); + } + + #[test] + fn prefix_matches() { + let t = ["", "123", "12345", "1234567", "test", "12xy"] + .into_iter() + .map(|x| x.chars()) + .collect::>(); + + assert_eq!( + t.prefix_matches("123456789".chars()).collect::>(), + [0, 3, 5, 7] + ); + + assert_eq!(t.prefix_matches("".chars()).collect::>(), [0]); + } +}