Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce use of DataLayouts internals #1934

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions ext/cuda/cuda_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ import ClimaCore.Fields
import ClimaCore.DataLayouts
import ClimaCore.DataLayouts: empty_kernel_stats

# `get_n_items(x)` returns the item count of `x` as the product of a size tuple.
# `auto_launch!` uses this count when choosing the number of threads.

# Base case: the argument is already a size tuple, so multiply it out.
function get_n_items(tup::Tuple)
    return prod(tup)
end

# Arrays are measured through `parent(arr)`, not through `arr` itself.
function get_n_items(arr::AbstractArray)
    dims = size(parent(arr))
    return get_n_items(dims)
end

# Data layouts are measured by their own `size`.
function get_n_items(data::DataLayouts.AbstractData)
    dims = size(data)
    return get_n_items(dims)
end

# Fields forward to the method for their underlying field values.
function get_n_items(field::Fields.Field)
    values = Fields.field_values(field)
    return get_n_items(values)
end

# Module-level dictionary; `auto_launch!` consults it with
# `haskey(reported_stats, key)` before computing a kernel's launch configuration.
const reported_stats = Dict()

# Call via ClimaCore.DataLayouts.empty_kernel_stats()
# Removes every entry from `reported_stats` and returns the emptied dictionary.
function empty_kernel_stats(::ClimaComms.CUDADevice)
    return empty!(reported_stats)
end
Expand Down Expand Up @@ -37,15 +32,15 @@ to benchmark and compare against auto-determined threads/blocks (if `auto=false`).
function auto_launch!(
f!::F!,
args,
data;
nitems::Union{Integer, Nothing} = nothing;
auto = false,
threads_s = nothing,
blocks_s = nothing,
always_inline = true,
caller = :unknown,
) where {F!}
if auto
nitems = get_n_items(data)
@assert !isnothing(nitems)
if nitems ≥ 0
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
config = CUDA.launch_configuration(kernel.fun)
Expand All @@ -64,7 +59,7 @@ function auto_launch!(
# CUDA.registers(kernel) > 50 || return nothing # for debugging
# occursin("single_field_solve_kernel", string(nameof(F!))) || return nothing
if !haskey(reported_stats, key)
nitems = get_n_items(data)
@assert !isnothing(nitems)
kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
Expand Down
20 changes: 6 additions & 14 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ function Base.copyto!(
if Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
(dest, bc);
threads_s = (Nij, Nij),
blocks_s = (Nh, 1),
)
Expand All @@ -42,8 +41,7 @@ function Base.copyto!(
Nv_blocks = cld(Nv, Nv_per_block)
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
(dest, bc);
threads_s = (Nij, Nij, Nv_per_block),
blocks_s = (Nh, Nv_blocks),
)
Expand All @@ -59,8 +57,7 @@ function Base.copyto!(
if Nv > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
(dest, bc);
threads_s = (1, 1),
blocks_s = (1, Nv),
)
Expand All @@ -73,13 +70,7 @@ function Base.copyto!(
bc::DataLayouts.BroadcastedUnionDataF{S},
::ToCUDA,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (1, 1),
)
auto_launch!(knl_copyto!, (dest, bc); threads_s = (1, 1), blocks_s = (1, 1))
return dest
end

Expand All @@ -100,7 +91,8 @@ function cuda_copyto!(dest::AbstractData, bc)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
nitems = prod(DataLayouts.universal_size(dest))
auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true)
end
return dest
end
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function cuda_fill!(dest::AbstractData, val)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_fill_flat!, (dest, val, us), dest; auto = true)
nitems = prod(DataLayouts.universal_size(dest))
auto_launch!(knl_fill_flat!, (dest, val, us), nitems; auto = true)
end
return dest
end
Expand Down
12 changes: 4 additions & 8 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ function fused_copyto!(
Nv_blocks = cld(Nv, Nv_per_block)
auto_launch!(
knl_fused_copyto!,
(fmbc,),
dest1;
(fmbc,);
threads_s = (Nij, Nij, Nv_per_block),
blocks_s = (Nh, Nv_blocks),
)
Expand All @@ -68,8 +67,7 @@ function fused_copyto!(
if Nh > 0
auto_launch!(
knl_fused_copyto!,
(fmbc,),
dest1;
(fmbc,);
threads_s = (Nij, Nij),
blocks_s = (Nh, 1),
)
Expand All @@ -85,8 +83,7 @@ function fused_copyto!(
if Nv > 0 && Nh > 0
auto_launch!(
knl_fused_copyto!,
(fmbc,),
dest1;
(fmbc,);
threads_s = (1, 1),
blocks_s = (Nh, Nv),
)
Expand All @@ -101,8 +98,7 @@ function fused_copyto!(
) where {S}
auto_launch!(
knl_fused_copyto!,
(fmbc,),
dest1;
(fmbc,);
threads_s = (1, 1),
blocks_s = (1, 1),
)
Expand Down
2 changes: 1 addition & 1 deletion ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function mapreduce_cuda(
pdata = parent(data)
T = eltype(pdata)
(Ni, Nj, Nk, Nv, Nh) = size(data)
Nf = div(length(pdata), prod(size(data))) # length of field dimension
Nf = DataLayouts.ncomponents(data) # length of field dimension
pwt = parent(weighted_jacobian)

nitems = Nv * Ni * Nj * Nk * Nh
Expand Down
70 changes: 22 additions & 48 deletions ext/cuda/limiters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,24 @@ function compute_element_bounds!(
ρ,
::ClimaComms.CUDADevice,
)
S = size(Fields.field_values(ρ))
(Ni, Nj, _, Nv, Nh) = S
ρ_values = Fields.field_values(Operators.strip_space(ρ, axes(ρ)))
ρq_values = Fields.field_values(Operators.strip_space(ρq, axes(ρq)))
(_, _, _, Nv, Nh) = DataLayouts.universal_size(ρ_values)
nthreads, nblocks = config_threadblock(Nv, Nh)

args = (
limiter,
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
Nv,
Nh,
Val(Ni),
Val(Nj),
)
args = (limiter, ρq_values, ρ_values)
auto_launch!(
compute_element_bounds_kernel!,
args,
ρ;
args;
threads_s = nthreads,
blocks_s = nblocks,
)
return nothing
end


function compute_element_bounds_kernel!(
limiter,
ρq,
ρ,
Nv,
Nh,
::Val{Ni},
::Val{Nj},
) where {Ni, Nj}
function compute_element_bounds_kernel!(limiter, ρq, ρ)
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(ρ)
n = (Nv, Nh)
tidx = thread_index()
@inbounds if valid_range(tidx, prod(n))
Expand Down Expand Up @@ -88,21 +73,18 @@ function compute_neighbor_bounds_local!(
::ClimaComms.CUDADevice,
)
topology = Spaces.topology(axes(ρ))
Ni, Nj, _, Nv, Nh = size(Fields.field_values(ρ))
us = DataLayouts.UniversalSize(Fields.field_values(ρ))
(_, _, _, Nv, Nh) = DataLayouts.universal_size(us)
nthreads, nblocks = config_threadblock(Nv, Nh)
args = (
limiter,
topology.local_neighbor_elem,
topology.local_neighbor_elem_offset,
Nv,
Nh,
Val(Ni),
Val(Nj),
us,
)
auto_launch!(
compute_neighbor_bounds_local_kernel!,
args,
ρ;
args;
threads_s = nthreads,
blocks_s = nblocks,
)
Expand All @@ -112,12 +94,9 @@ function compute_neighbor_bounds_local_kernel!(
limiter,
local_neighbor_elem,
local_neighbor_elem_offset,
Nv,
Nh,
::Val{Ni},
::Val{Nj},
) where {Ni, Nj}

us::DataLayouts.UniversalSize,
)
(_, _, _, Nv, Nh) = DataLayouts.universal_size(us)
n = (Nv, Nh)
tidx = thread_index()
@inbounds if valid_range(tidx, prod(n))
Expand Down Expand Up @@ -147,27 +126,24 @@ function apply_limiter!(
::ClimaComms.CUDADevice,
)
ρq_data = Fields.field_values(ρq)
(Ni, Nj, _, Nv, Nh) = size(ρq_data)
Nf = DataLayouts.ncomponents(ρq_data)
us = DataLayouts.UniversalSize(ρq_data)
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
maxiter = Ni * Nj
Nf = DataLayouts.ncomponents(ρq_data)
WJ = Spaces.local_geometry_data(axes(ρq)).WJ
nthreads, nblocks = config_threadblock(Nv, Nh)
args = (
limiter,
Fields.field_values(Operators.strip_space(ρq, axes(ρq))),
Fields.field_values(Operators.strip_space(ρ, axes(ρ))),
WJ,
Nv,
Nh,
us,
Val(Nf),
Val(Ni),
Val(Nj),
Val(maxiter),
)
auto_launch!(
apply_limiter_kernel!,
args,
ρ;
args;
threads_s = nthreads,
blocks_s = nblocks,
)
Expand All @@ -179,15 +155,13 @@ function apply_limiter_kernel!(
ρq_data,
ρ_data,
WJ_data,
Nv,
Nh,
us::DataLayouts.UniversalSize,
::Val{Nf},
::Val{Ni},
::Val{Nj},
::Val{maxiter},
) where {Nf, Ni, Nj, maxiter}
) where {Nf, maxiter}
(; q_bounds_nbr, rtol) = limiter
converged = true
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
n = (Nv, Nh)
tidx = thread_index()
@inbounds if valid_range(tidx, prod(n))
Expand Down
3 changes: 1 addition & 2 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ NVTX.@annotate function multiple_field_solve!(

auto_launch!(
multiple_field_solve_kernel!,
args,
x1;
args;
threads_s = nthreads,
blocks_s = nblocks,
always_inline = true,
Expand Down
3 changes: 1 addition & 2 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
args = (device, cache, x, A, b)
auto_launch!(
single_field_solve_kernel!,
args,
x;
args;
threads_s = nthreads,
blocks_s = nblocks,
)
Expand Down
3 changes: 1 addition & 2 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ function Base.copyto!(
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
auto_launch!(
copyto_stencil_kernel!,
args,
out;
args;
threads_s = (nthreads,),
blocks_s = (nblocks,),
)
Expand Down
4 changes: 2 additions & 2 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function column_reduce_device!(
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
auto_launch!(bycolumn_kernel!, args; threads_s, blocks_s)
end

function column_accumulate_device!(
Expand All @@ -52,7 +52,7 @@ function column_accumulate_device!(
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
auto_launch!(bycolumn_kernel!, args; threads_s, blocks_s)
end

bycolumn_kernel!(
Expand Down
3 changes: 1 addition & 2 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ function Base.copyto!(
)
auto_launch!(
copyto_spectral_kernel!,
args,
out;
args;
threads_s = (Nq, Nq, Nvthreads),
blocks_s = (Nh, Nvblocks),
)
Expand Down
3 changes: 1 addition & 2 deletions ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
args = (A, b)
auto_launch!(
thomas_algorithm_kernel!,
args,
size(Fields.field_values(A));
args;
threads_s = nthreads,
blocks_s = nblocks,
)
Expand Down
6 changes: 2 additions & 4 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ function _set_interpolated_values_device!(
)
auto_launch!(
set_interpolated_values_kernel!,
args,
out;
args;
threads_s = (nthreads),
blocks_s = (nblocks),
)
Expand Down Expand Up @@ -163,8 +162,7 @@ function _set_interpolated_values_device!(
)
auto_launch!(
set_interpolated_values_kernel!,
args,
out;
args;
threads_s = (nthreads),
blocks_s = (nblocks),
)
Expand Down
Loading
Loading