From 080b03e09599a6efc60ac87ac58d1a1f4e56671e Mon Sep 17 00:00:00 2001 From: Rohit Kulshreshtha Date: Tue, 28 Jan 2025 16:00:14 -0800 Subject: [PATCH] feat(dfir_lang): Allow state_by to use a factory function. --- datastores/gossip_kv/kv/server.rs | 2 +- dfir_lang/src/graph/ops/state.rs | 2 +- dfir_lang/src/graph/ops/state_by.rs | 46 ++++++++++++++++++++++------- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/datastores/gossip_kv/kv/server.rs b/datastores/gossip_kv/kv/server.rs index f3d2afc199e..af8132b0399 100644 --- a/datastores/gossip_kv/kv/server.rs +++ b/datastores/gossip_kv/kv/server.rs @@ -274,7 +274,7 @@ where new_writes -> for_each(|x| trace!("NEW WRITE: {:?}", x)); // Step 1: Put the new writes in a map, with the write as the key and a SetBoundedLattice as the value. - infecting_writes = union() -> state::<'static, MapUnionHashMap>(); + infecting_writes = union() -> state_by::<'static, MapUnionHashMap>(std::convert::identity, std::default::Default::default); new_writes -> map(|write| { // Ideally, the write itself is the key, but writes are a hashmap and hashmaps don't diff --git a/dfir_lang/src/graph/ops/state.rs b/dfir_lang/src/graph/ops/state.rs index eadca0eccd1..cf4207c27ac 100644 --- a/dfir_lang/src/graph/ops/state.rs +++ b/dfir_lang/src/graph/ops/state.rs @@ -39,7 +39,7 @@ pub const STATE: OperatorConstraints = OperatorConstraints { diagnostics| { let wc = WriteContextArgs { - arguments: &parse_quote_spanned!(op_span => ::std::convert::identity), + arguments: &parse_quote_spanned!(op_span => ::std::convert::identity, ::std::default::Default::default), ..wc.clone() }; diff --git a/dfir_lang/src/graph/ops/state_by.rs b/dfir_lang/src/graph/ops/state_by.rs index c84730dd169..4ac5210a112 100644 --- a/dfir_lang/src/graph/ops/state_by.rs +++ b/dfir_lang/src/graph/ops/state_by.rs @@ -6,7 +6,8 @@ use super::{ }; use crate::diagnostic::{Diagnostic, Level}; -/// List state operator, but with a closure to map the input to the state lattice. 
+/// List state operator, but with a closure to map the input to the state lattice and a factory +/// function to initialize the internal data structure. /// /// The emitted outputs (both the referencable singleton and the optional pass-through stream) are /// of the same type as the inputs to the state_by operator and are not required to be a lattice @@ -18,8 +19,26 @@ use crate::diagnostic::{Diagnostic, Level}; /// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet}; /// /// my_state = source_iter(0..3) -/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from); +/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, std::default::Default::default); /// ``` +/// The 2nd argument into `state_by` is a factory function that can be used to supply a custom +/// initial value for the backing state. The initial value is still expected to be bottom (and will +/// be checked in debug builds). This is useful for doing things like pre-allocating buffers, etc. In the above +/// example, it is just using `Default::default()`. +/// +/// An example of preallocating the capacity in a hash set: +/// +///```dfir +/// use std::collections::HashSet; +/// +/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet}; +/// +/// my_state = source_iter(0..3) +/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, {|| SetUnion::new(HashSet::with_capacity(1_000)) }); +///``` +/// +/// The `state` operator is equivalent to `state_by` used with an identity mapping operator with +/// `Default::default` providing the factory function. 
pub const STATE_BY: OperatorConstraints = OperatorConstraints { name: "state_by", categories: &[OperatorCategory::Persistence], @@ -27,7 +46,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { soft_range_inn: RANGE_1, hard_range_out: &(0..=1), soft_range_out: &(0..=1), - num_args: 1, + num_args: 2, persistence_args: &(0..=1), type_args: &(0..=1), is_external_input: false, @@ -80,11 +99,16 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { _ => unreachable!(), }; + let state_ident = singleton_output_ident; - let mut write_prologue = quote_spanned! {op_span=> - let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new( - <#lattice_type as ::std::default::Default>::default() - )); + let factory_fn = &arguments[1]; + + let mut write_prologue = quote_spanned! { op_span=> + let #state_ident = { + let data_struct : #lattice_type = (#factory_fn)(); + ::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct)); + #hydroflow.add_state(::std::cell::RefCell::new(data_struct)) + }; }; if Persistence::Tick == persistence { write_prologue.extend(quote_spanned! 
{op_span=> @@ -92,7 +116,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { }); } - let func = &arguments[0]; + let by_fn = &arguments[0]; // TODO(mingwei): deduplicate codegen let write_iterator = if is_pull { @@ -117,7 +141,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item))) }) } - check_input::<_, _, _, _, #lattice_type>(#input, #func, #state_ident, #context) + check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context) }; } } else if let Some(output) = outputs.first() { @@ -141,7 +165,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { #root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item))) }, push) } - check_output::<_, _, _, _, #lattice_type>(#output, #func, #state_ident, #context) + check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context) }; } } else { @@ -164,7 +188,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints { #root::lattices::Merge::merge(&mut *state, (mapfn)(item)); }) } - check_output::<_, _, _, #lattice_type>(#state_ident, #func, #context) + check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context) }; } };