diff --git a/src/array/darray.jl b/src/array/darray.jl index 11feb53c..37c61a93 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -191,6 +191,8 @@ function Base.collect(d::DArray; tree=false) end end +Base.wait(A::DArray) = foreach(wait, A.chunks) + ### show #= FIXME diff --git a/src/array/operators.jl b/src/array/operators.jl index 4c3ff3fb..4dbf398b 100644 --- a/src/array/operators.jl +++ b/src/array/operators.jl @@ -114,10 +114,17 @@ Base.last(A::DArray) = A[end] # In-place operations +function imap!(f, A) + for idx in eachindex(A) + A[idx] = f(A[idx]) + end + return A +end + function Base.map!(f, a::DArray{T}) where T Dagger.spawn_datadeps() do for ca in chunks(a) - Dagger.@spawn map!(f, InOut(ca), ca) + Dagger.@spawn imap!(f, InOut(ca)) end end return a diff --git a/src/array/random.jl b/src/array/random.jl index e6698470..705f3e33 100644 --- a/src/array/random.jl +++ b/src/array/random.jl @@ -9,7 +9,7 @@ function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T Dagger.spawn_datadeps() do for Ac in chunks(A) rng = randfork(rng, part_sz) - Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac) + Dagger.@spawn imap!(InOut(_->rand(rng, T)), InOut(Ac)) end end return A @@ -19,7 +19,7 @@ function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T Dagger.spawn_datadeps() do for Ac in chunks(A) rng = randfork(rng, part_sz) - Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac) + Dagger.@spawn imap!(InOut(_->randn(rng, T)), InOut(Ac)) end end return A diff --git a/src/datadeps.jl b/src/datadeps.jl index 43c4c384..42778fce 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -147,18 +147,22 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState # The mapping of memory space to remote argument copies remote_args::Dict{MemorySpace,IdDict{Any,Any}} + # Cache of whether each argument supports in-place move + supports_inplace_cache::IdDict{Any,Bool} + # The aliasing analysis state alias_state::State function
DataDepsState(aliasing::Bool) dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[] remote_args = Dict{MemorySpace,IdDict{Any,Any}}() + supports_inplace_cache = IdDict{Any,Bool}() if aliasing state = DataDepsAliasingState() else state = DataDepsNonAliasingState() end - return new{typeof(state)}(aliasing, dependencies, remote_args, state) + return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) end end @@ -168,6 +172,12 @@ function aliasing(astate::DataDepsAliasingState, arg, dep_mod) end end +function supports_inplace_move(state::DataDepsState, arg) + return get!(state.supports_inplace_cache, arg) do + return supports_inplace_move(arg) + end +end + # Determine which arguments could be written to, and thus need tracking "Whether `arg` has any writedep in this datadeps region." @@ -323,6 +333,30 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t astate.data_origin[task] = space end +""" + supports_inplace_move(x) -> Bool + +Returns `false` if `x` doesn't support being copied into from another object +like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting +to copy between values which don't support mutation or otherwise don't have an +implemented `move!` and want to skip in-place copies. When this returns +`false`, datadeps will instead perform out-of-place copies for each non-local +use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` +region returns. 
+""" +supports_inplace_move(x) = true +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +function supports_inplace_move(c::Chunk) + # FIXME: Use MemPool.access_ref + pid = root_worker_id(c.processor) + if pid == myid() + return supports_inplace_move(poolget(c.handle)) + else + return remotecall_fetch(supports_inplace_move, pid, c) + end +end +supports_inplace_move(::Function) = false + # Read/write dependency management function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) @@ -677,8 +711,15 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Is the data written previously or now? arg, deps = unwrap_inout(arg) arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)" + if !type_may_alias(typeof(arg)) + @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)" + spec.args[idx] = pos => arg + continue + end + + # Is the data writeable? + if !supports_inplace_move(state, arg) + @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (non-writeable)" spec.args[idx] = pos => arg continue end @@ -738,7 +779,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Validate that we're not accidentally performing a copy for (idx, (_, arg)) in enumerate(spec.args) _, deps = unwrap_inout(task_args[idx][2]) - if is_writedep(arg, deps, task) + # N.B. 
We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) arg_space = memory_space(arg) @assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" end @@ -750,6 +794,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) arg, deps = unwrap_inout(arg) arg = arg isa DTask ? fetch(arg; raw=true) : arg type_may_alias(typeof(arg)) || continue + supports_inplace_move(state, arg) || continue if queue.aliasing for (dep_mod, _, writedep) in deps ainfo = aliasing(astate, arg, dep_mod) @@ -830,6 +875,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue) continue end + # Skip non-writeable arguments + if !supports_inplace_move(state, arg) + @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" + continue + end + # Get the set of writers ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg) @@ -877,8 +928,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for arg in keys(astate.data_origin) # Is the data previously written? arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)" + if !type_may_alias(typeof(arg)) + @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" + end + + # Can the data be written back to? + if !supports_inplace_move(state, arg) + @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" end # Is the source of truth elsewhere? @@ -912,7 +968,7 @@ Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or argument, respectively. 
These argument dependencies will be used to specify which tasks depend on each other based on the following rules: -- Dependencies across different arguments are independent; only dependencies on the same argument synchronize with each other ("same-ness" is determined based on `isequal`) +- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other - `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects - Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel - An `Out` dependency synchronizes with any previous `In` and `Out` dependencies diff --git a/src/thunk.jl b/src/thunk.jl index e173e1c2..b24806f8 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -570,7 +570,14 @@ function show_thunk(io::IO, t) end print(io, ")") end -Base.show(io::IO, t::Thunk) = show_thunk(io, t) +function Base.show(io::IO, t::Thunk) + lazy_level = parse(Int, get(ENV, "JULIA_DAGGER_SHOW_THUNK_VERBOSITY", "0")) + if lazy_level == 0 + show_thunk(io, t) + else + show_thunk(IOContext(io, :lazy_level => lazy_level), t) + end +end Base.summary(t::Thunk) = repr(t) inputs(x::Thunk) = x.inputs diff --git a/test/datadeps.jl b/test/datadeps.jl index 8d16c65b..fa1bcd7e 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -44,7 +44,7 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int) end error("Task $tid not found in logs") end -function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=true) +function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false) g = SimpleDiGraph() tid_to_v = Dict{Int,Int}() seen = Set{Int}() @@ -165,7 +165,7 @@ function test_datadeps(;args_chunks::Bool, end tid_1, tid_2 = task_id.(ts) test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, 
tid_2, []; all_tids=[tid_1, tid_2]) + test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false) # R->W Aliasing ts = []