Skip to content

Commit

Permalink
feat(dfir_lang): Allow state_by to use a factory function. (#1682)
Browse files Browse the repository at this point in the history
Currently, state_by uses Default::default to instantiate the backing
storage. Accepting a factory function will allow the storage to be
tweaked per instance of state. Example usage: pre-allocating memory for
the data structures.
  • Loading branch information
rohitkulshreshtha authored Jan 29, 2025
1 parent 19784f5 commit a9762c5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
2 changes: 1 addition & 1 deletion datastores/gossip_kv/kv/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ where
new_writes -> for_each(|x| trace!("NEW WRITE: {:?}", x));

// Step 1: Put the new writes in a map, with the write as the key and a SetBoundedLattice as the value.
infecting_writes = union() -> state::<'static, MapUnionHashMap<MessageId, InfectingWrite>>();
infecting_writes = union() -> state_by::<'static, MapUnionHashMap<MessageId, InfectingWrite>>(std::convert::identity, std::default::Default::default);

new_writes -> map(|write| {
// Ideally, the write itself is the key, but writes are a hashmap and hashmaps don't
Expand Down
4 changes: 3 additions & 1 deletion dfir_lang/src/graph/ops/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use super::{
/// -> map(SetUnionSingletonSet::new_from)
/// -> state::<SetUnionHashSet<usize>>();
/// ```
/// The `state` operator is equivalent to `state_by` called with an identity mapping function
/// (`std::convert::identity`) and `Default::default` as the factory function.
pub const STATE: OperatorConstraints = OperatorConstraints {
name: "state",
categories: &[OperatorCategory::Persistence],
Expand All @@ -39,7 +41,7 @@ pub const STATE: OperatorConstraints = OperatorConstraints {
diagnostics| {

let wc = WriteContextArgs {
arguments: &parse_quote_spanned!(op_span => ::std::convert::identity),
arguments: &parse_quote_spanned!(op_span => ::std::convert::identity, ::std::default::Default::default),
..wc.clone()
};

Expand Down
46 changes: 35 additions & 11 deletions dfir_lang/src/graph/ops/state_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use super::{
};
use crate::diagnostic::{Diagnostic, Level};

/// List state operator, but with a closure to map the input to the state lattice.
/// Like the `state` operator, but with a closure to map the input to the state lattice and a
/// factory function to initialize the internal data structure.
///
/// The emitted outputs (both the referencable singleton and the optional pass-through stream) are
/// of the same type as the inputs to the state_by operator and are not required to be a lattice
Expand All @@ -15,19 +16,37 @@ use crate::diagnostic::{Diagnostic, Level};
/// ```dfir
/// use std::collections::HashSet;
///
///
/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
///
/// my_state = source_iter(0..3)
/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from);
/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, std::default::Default::default);
/// ```
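/// The second argument to `state_by` is a factory function that supplies a custom initial value
/// for the backing state. The initial value must still be bottom; this is checked with a
/// `debug_assert!`, i.e. only when debug assertions are enabled. This is useful for things like
/// pre-allocating buffers. In the above example, the factory is simply `Default::default`.
///
/// Note that with tick persistence the state is reset at the end of each tick via
/// `RefCell::take`, i.e. to `Default::default()`; the factory function is not called again.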
///
/// An example of pre-allocating the capacity of the backing hash set:
///
/// ```dfir
/// use std::collections::HashSet;
/// use lattices::set_union::{SetUnion, CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
///
/// my_state = source_iter(0..3)
/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from, {|| SetUnion::new(HashSet::<usize>::with_capacity(1_000)) });
/// ```
///
/// The `state` operator is equivalent to `state_by` called with an identity mapping function
/// (`std::convert::identity`) and `Default::default` as the factory function.
pub const STATE_BY: OperatorConstraints = OperatorConstraints {
name: "state_by",
categories: &[OperatorCategory::Persistence],
hard_range_inn: RANGE_1,
soft_range_inn: RANGE_1,
hard_range_out: &(0..=1),
soft_range_out: &(0..=1),
num_args: 1,
num_args: 2,
persistence_args: &(0..=1),
type_args: &(0..=1),
is_external_input: false,
Expand Down Expand Up @@ -80,19 +99,24 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints {
_ => unreachable!(),
};


let state_ident = singleton_output_ident;
let mut write_prologue = quote_spanned! {op_span=>
let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new(
<#lattice_type as ::std::default::Default>::default()
));
let factory_fn = &arguments[1];

let mut write_prologue = quote_spanned! { op_span=>
let #state_ident = {
let data_struct : #lattice_type = (#factory_fn)();
::std::debug_assert!(::lattices::IsBot::is_bot(&data_struct));
#hydroflow.add_state(::std::cell::RefCell::new(data_struct))
};
};
if Persistence::Tick == persistence {
write_prologue.extend(quote_spanned! {op_span=>
#hydroflow.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
});
}

let func = &arguments[0];
let by_fn = &arguments[0];

// TODO(mingwei): deduplicate codegen
let write_iterator = if is_pull {
Expand All @@ -117,7 +141,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints {
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
})
}
check_input::<_, _, _, _, #lattice_type>(#input, #func, #state_ident, #context)
check_input::<_, _, _, _, #lattice_type>(#input, #by_fn, #state_ident, #context)
};
}
} else if let Some(output) = outputs.first() {
Expand All @@ -141,7 +165,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints {
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
}, push)
}
check_output::<_, _, _, _, #lattice_type>(#output, #func, #state_ident, #context)
check_output::<_, _, _, _, #lattice_type>(#output, #by_fn, #state_ident, #context)
};
}
} else {
Expand All @@ -164,7 +188,7 @@ pub const STATE_BY: OperatorConstraints = OperatorConstraints {
#root::lattices::Merge::merge(&mut *state, (mapfn)(item));
})
}
check_output::<_, _, _, #lattice_type>(#state_ident, #func, #context)
check_output::<_, _, _, #lattice_type>(#state_ident, #by_fn, #context)
};
}
};
Expand Down

0 comments on commit a9762c5

Please sign in to comment.