Skip to content

Commit

Permalink
move loop scheduling into Context
Browse files Browse the repository at this point in the history
  • Loading branch information
MingweiSamuel committed Jan 24, 2025
1 parent 4d93042 commit c13aa77
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 116 deletions.
1 change: 1 addition & 0 deletions dfir_lang/src/graph/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ declare_ops![
join_multiset::JOIN_MULTISET,
fold_keyed::FOLD_KEYED,
reduce_keyed::REDUCE_KEYED,
repeat_n::REPEAT_N,
lattice_bimorphism::LATTICE_BIMORPHISM,
_lattice_fold_batch::_LATTICE_FOLD_BATCH,
lattice_fold::LATTICE_FOLD,
Expand Down
57 changes: 57 additions & 0 deletions dfir_lang/src/graph/ops/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use quote::quote_spanned;

use super::{OperatorConstraints, OperatorWriteOutput, WriteContextArgs};

/// TODO(mingwei): docs
pub const REPEAT_N: OperatorConstraints = OperatorConstraints {
name: "repeat_n",
num_args: 1,
write_fn: |wc @ &WriteContextArgs {
context,
hydroflow,
op_span,
arguments,
..
},
diagnostics| {
let OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
} = (super::all_once::ALL_ONCE.write_fn)(wc, diagnostics)?;

let count_ident = wc.make_ident("count");

let write_prologue = quote_spanned! {op_span=>
#write_prologue

let #count_ident = #hydroflow.add_state(::std::cell::Cell::new(0_usize));
#hydroflow.set_state_tick_hook(#count_ident, move |cell| { cell.take(); });
};

// Reschedule, to repeat.
let count_arg = &arguments[0];
let write_iterator_after = quote_spanned! {op_span=>
#write_iterator_after

{
let count_ref = #context.state_ref(#count_ident);
if #context.is_first_loop_iteration() {
count_ref.set(0);
}
let count = count_ref.get() + 1;
if count < #count_arg {
count_ref.set(count);
#context.reschedule_loop_block();
}
}
};

Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
})
},
..super::all_once::ALL_ONCE
};
13 changes: 11 additions & 2 deletions dfir_rs/src/scheduled/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use tokio::task::JoinHandle;
use web_time::SystemTime;

use super::state::StateHandle;
use super::{StateId, SubgraphId};
use super::{LoopId, LoopTag, StateId, SubgraphId};
use crate::scheduled::ticks::TickInstant;
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};

/// The main state and scheduler of the Hydroflow instance. Provided as the `context` API to each
/// subgraph/operator as it is run.
Expand Down Expand Up @@ -50,12 +51,16 @@ pub struct Context {
pub(super) current_tick_start: SystemTime,
pub(super) subgraph_last_tick_run_in: Option<TickInstant>,

// Depth of loop (zero for top-level).
pub(super) loop_depth: SlotVec<LoopTag, usize>,
// Map from `LoopId` to parent `LoopId` (or `None` for top-level).
pub(super) loop_parent: SecondarySlotVec<LoopTag, Option<LoopId>>,

/// The SubgraphId of the currently running operator. When this context is
/// not being forwarded to a running operator, this field is meaningless.
pub(super) subgraph_id: SubgraphId,

tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,

/// Join handles for spawned tasks.
task_join_handles: Vec<JoinHandle<()>>,
}
Expand Down Expand Up @@ -235,6 +240,7 @@ impl Default for Context {
fn default() -> Self {
let stratum_queues = vec![Default::default()]; // Always initialize stratum #0.
let (event_queue_send, event_queue_recv) = mpsc::unbounded_channel();
let (loop_depth, loop_parent) = Default::default();
Self {
states: Vec::new(),

Expand All @@ -252,6 +258,9 @@ impl Default for Context {
current_tick_start: SystemTime::now(),
subgraph_last_tick_run_in: None,

loop_depth,
loop_parent,

// Will be re-set before use.
subgraph_id: SubgraphId::from_raw(0),

Expand Down
15 changes: 5 additions & 10 deletions dfir_rs/src/scheduled/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use super::port::{RecvCtx, RecvPort, SendCtx, SendPort, RECV, SEND};
use super::reactor::Reactor;
use super::state::StateHandle;
use super::subgraph::Subgraph;
use super::{HandoffId, HandoffTag, LoopId, LoopTag, SubgraphId, SubgraphTag};
use super::{HandoffId, HandoffTag, LoopId, SubgraphId, SubgraphTag};
use crate::scheduled::ticks::{TickDuration, TickInstant};
use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
use crate::util::slot_vec::SlotVec;
use crate::Never;

/// A DFIR graph. Owns, schedules, and runs the compiled subgraphs.
Expand All @@ -32,11 +32,6 @@ pub struct Dfir<'a> {
pub(super) subgraphs: SlotVec<SubgraphTag, SubgraphData<'a>>,
pub(super) context: Context,

// Depth of loop (zero for top-level).
loop_depth: SlotVec<LoopTag, usize>,
// Map from `LoopId` to parent `LoopId` (or `None` for top-level).
loop_parent: SecondarySlotVec<LoopTag, Option<LoopId>>,

handoffs: SlotVec<HandoffTag, HandoffData>,

#[cfg(feature = "meta")]
Expand Down Expand Up @@ -817,9 +812,9 @@ impl<'a> Dfir<'a> {
///
/// TODO(mingwei): add loop names to ensure traceability while debugging?
pub fn add_loop(&mut self, parent: Option<LoopId>) -> LoopId {
let depth = parent.map_or(0, |p| self.loop_depth[p] + 1);
let loop_id = self.loop_depth.insert(depth);
self.loop_parent.insert(loop_id, parent);
let depth = parent.map_or(0, |p| self.context.loop_depth[p] + 1);
let loop_id = self.context.loop_depth.insert(depth);
self.context.loop_parent.insert(loop_id, parent);
loop_id
}
}
Expand Down
206 changes: 102 additions & 104 deletions dfir_rs/tests/surface_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,118 +72,116 @@ pub fn test_flo_nested() {
df.run_available();
}

/*
#[multiplatform_test]
pub fn test_flo_repeat_n() {
let mut df = dfir_syntax! {
users = source_iter(["alice", "bob"]);
messages = source_stream(iter_batches_stream(0..12, 3));
loop {
// TODO(mingwei): cross_join type negotion should allow us to eliminate `flatten()`.
users -> batch() -> flatten() -> [0]cp;
messages -> batch() -> flatten() -> [1]cp;
cp = cross_join::<'static, 'tick>();
loop {
cp
-> repeat_n(3)
-> map(|vec| (context.current_tick().0, vec))
-> inspect(|x| println!("{:?}", x))
-> assert_eq([
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
]);
}
}
};
assert_graphvis_snapshots!(df);
df.run_available();
}
#[multiplatform_test(test, wasm, env_tracing)]
pub fn test_flo_repeat_n_nested() {
let mut df = dfir_syntax! {
usrs1 = source_iter(["alice", "bob"]);
users = source_iter(["alice", "bob"]);
messages = source_stream(iter_batches_stream(0..12, 3));
loop {
// TODO(mingwei): cross_join type negotion should allow us to eliminate `flatten()`.
users -> batch() -> flatten() -> [0]cp;
messages -> batch() -> flatten() -> [1]cp;
cp = cross_join::<'static, 'tick>();
loop {
usrs2 = usrs1 -> batch() -> flatten();
loop {
usrs3 = usrs2 -> repeat_n(3) -> flatten()
-> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()));
loop {
usrs3 -> repeat_n(3)
-> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()))
-> assert_eq([
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
]);
}
cp
-> repeat_n(3)

Check failure on line 87 in dfir_rs/tests/surface_loop.rs

View workflow job for this annotation

GitHub Actions / Test Suite (WebAssembly) (pinned-nightly)

no method named `is_first_loop_iteration` found for mutable reference `&mut dfir_rs::scheduled::context::Context` in the current scope
-> map(|vec| (context.current_tick().0, vec))
-> inspect(|x| println!("{:?}", x))
-> assert_eq([
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(0, vec![("alice", 0), ("alice", 1), ("alice", 2), ("bob", 0), ("bob", 1), ("bob", 2)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(1, vec![("alice", 3), ("alice", 4), ("alice", 5), ("bob", 3), ("bob", 4), ("bob", 5)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(2, vec![("alice", 6), ("alice", 7), ("alice", 8), ("bob", 6), ("bob", 7), ("bob", 8)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
(3, vec![("alice", 9), ("alice", 10), ("alice", 11), ("bob", 9), ("bob", 10), ("bob", 11)]),
]);
}
}
};
assert_graphvis_snapshots!(df);
df.run_available();
}

#[multiplatform_test]
pub fn test_flo_repeat_n_multiple_nested() {
let mut df = dfir_syntax! {
usrs1 = source_iter(["alice", "bob"]);
loop {
usrs2 = usrs1 -> batch() -> flatten();
loop {
usrs3 = usrs2 -> repeat_n(3) -> flatten()
-> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()))
-> tee();
loop {
usrs3 -> repeat_n(3)
-> inspect(|x| println!("{} {:?} {}", line!(), x, context.is_first_loop_iteration()))
-> assert_eq([
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
]);
}
loop {
usrs3 -> repeat_n(3)
-> inspect(|x| println!("{} {:?} {}", line!(), x, context.is_first_loop_iteration()))
-> assert_eq([
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
vec!["alice", "bob"],
]);
}
}
}
};
assert_graphvis_snapshots!(df);
df.run_available();
}
*/
// #[multiplatform_test(test, wasm, env_tracing)]
// pub fn test_flo_repeat_n_nested() {
// let mut df = dfir_syntax! {
// usrs1 = source_iter(["alice", "bob"]);
// loop {
// usrs2 = usrs1 -> batch() -> flatten();
// loop {
// usrs3 = usrs2 -> repeat_n(3) -> flatten()
// -> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()));
// loop {
// usrs3 -> repeat_n(3)
// -> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()))
// -> assert_eq([
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// ]);
// }
// }
// }
// };
// assert_graphvis_snapshots!(df);
// df.run_available();
// }
//
// #[multiplatform_test]
// pub fn test_flo_repeat_n_multiple_nested() {
// let mut df = dfir_syntax! {
// usrs1 = source_iter(["alice", "bob"]);
// loop {
// usrs2 = usrs1 -> batch() -> flatten();
// loop {
// usrs3 = usrs2 -> repeat_n(3) -> flatten()
// -> inspect(|x| println!("{:?} {}", x, context.is_first_loop_iteration()))
// -> tee();
// loop {
// usrs3 -> repeat_n(3)
// -> inspect(|x| println!("{} {:?} {}", line!(), x, context.is_first_loop_iteration()))
// -> assert_eq([
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// ]);
// }
// loop {
// usrs3 -> repeat_n(3)
// -> inspect(|x| println!("{} {:?} {}", line!(), x, context.is_first_loop_iteration()))
// -> assert_eq([
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// vec!["alice", "bob"],
// ]);
// }
// }
// }
// };
// assert_graphvis_snapshots!(df);
// df.run_available();
// }

0 comments on commit c13aa77

Please sign in to comment.