Skip to content

Commit

Permalink
Rust/2024/19: improve solution
Browse files Browse the repository at this point in the history
  • Loading branch information
Defelo committed Dec 19, 2024
1 parent 0ffc103 commit a615f31
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 10 deletions.
24 changes: 14 additions & 10 deletions Rust/2024/19.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
#![feature(test)]

use aoc::trie::Trie;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

#[derive(Debug)]
struct Input {
patterns: Vec<String>,
designs: Vec<String>,
patterns: Trie<u8>,
designs: Vec<Vec<u8>>,
}

fn setup(input: &str) -> Input {
let mut lines = input.trim().lines();
let patterns = lines.next().unwrap().split(", ").map(Into::into).collect();
let patterns = lines
.next()
.unwrap()
.split(", ")
.map(|p| p.bytes())
.collect();
let designs = lines.skip(1).map(Into::into).collect();
Input { patterns, designs }
}

fn count(design: &str, patterns: &[String]) -> usize {
fn count(design: &[u8], patterns: &Trie<u8>) -> usize {
let mut dp = vec![0; design.len() + 1];
dp[design.len()] = 1;
for i in (0..design.len()).rev() {
for p in patterns {
if i + p.len() > design.len() || !design[i..].starts_with(p) {
continue;
}
dp[i] += dp[i + p.len()];
}
dp[i] = patterns
.prefix_matches(design[i..].iter().copied())
.map(|len| dp[i + len])
.sum();
}
dp[0]
}
Expand Down
1 change: 1 addition & 0 deletions Rust/lib/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod iter_ext;
pub mod math;
pub mod parsing;
pub mod range;
pub mod trie;
pub mod tuples;

extern crate test;
Expand Down
132 changes: 132 additions & 0 deletions Rust/lib/trie.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use std::hash::Hash;

use rustc_hash::FxHashMap;

#[derive(Debug, Clone)]
pub struct Trie<T>(Vec<TrieNode<T>>);

#[derive(Debug, Clone)]
struct TrieNode<T> {
flag: bool,
next: FxHashMap<T, usize>,
}

impl<T> Trie<T> {
pub fn new() -> Self {
Self::default()
}
}

impl<T: Eq + Hash> Trie<T> {
pub fn insert(&mut self, item: impl IntoIterator<Item = T>) -> bool {
let mut node = 0;
for x in item {
match self.0[node].next.get(&x) {
Some(&next) => node = next,
None => {
let next = self.0.len();
self.0.push(TrieNode::default());
self.0[node].next.insert(x, next);
node = next;
}
}
}
!std::mem::replace(&mut self.0[node].flag, true)
}

pub fn contains(&self, item: impl IntoIterator<Item = T>) -> bool {
let mut node = 0;
for x in item {
match self.0[node].next.get(&x) {
Some(&next) => node = next,
None => return false,
}
}
self.0[node].flag
}

pub fn prefix_matches<U: IntoIterator<Item = T>>(
&self,
item: U,
) -> impl Iterator<Item = usize> + use<'_, T, U> {
self.0[0].flag.then_some(0).into_iter().chain(
item.into_iter()
.scan(0, |node, x| {
self.0[*node].next.get(&x).map(|&next| {
*node = next;
self.0[*node].flag
})
})
.enumerate()
.flat_map(|(i, flag)| flag.then_some(i + 1)),
)
}
}

impl<T> Default for Trie<T> {
fn default() -> Self {
Self(vec![TrieNode::default()])
}
}

impl<T> Default for TrieNode<T> {
fn default() -> Self {
Self {
flag: false,
next: Default::default(),
}
}
}

impl<I2> FromIterator<I2> for Trie<I2::Item>
where
I2: IntoIterator,
I2::Item: Eq + Hash,
{
fn from_iter<I1: IntoIterator<Item = I2>>(iter: I1) -> Self {
let mut trie = Self::new();
for item in iter {
trie.insert(item);
}
trie
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn trie() {
let mut t = ["foo", "bar", "baz"]
.into_iter()
.map(|x| x.chars())
.collect::<Trie<_>>();

assert!(t.contains("foo".chars()));
assert!(t.contains("bar".chars()));
assert!(t.contains("baz".chars()));
assert!(!t.contains("test".chars()));
assert!(!t.contains("baa".chars()));

assert!(t.insert("test".chars()));
assert!(t.contains("test".chars()));
assert!(!t.insert("test".chars()));
assert!(t.contains("test".chars()));
}

#[test]
fn prefix_matches() {
let t = ["", "123", "12345", "1234567", "test", "12xy"]
.into_iter()
.map(|x| x.chars())
.collect::<Trie<_>>();

assert_eq!(
t.prefix_matches("123456789".chars()).collect::<Vec<_>>(),
[0, 3, 5, 7]
);

assert_eq!(t.prefix_matches("".chars()).collect::<Vec<_>>(), [0]);
}
}

0 comments on commit a615f31

Please sign in to comment.