Skip to content

Commit 3bed155

Browse files
Add HF datalayouts
1 parent 96bb8f4 commit 3bed155

39 files changed

+1857
-191
lines changed

.buildkite/pipeline.yml

+25
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,20 @@ steps:
16011601
agents:
16021602
slurm_gpus: 1
16031603

1604+
- label: ":computer: Float 32 3D sphere baroclinic wave (ρe) HF datalayout GPU"
1605+
key: "gpu_baroclinic_wave_rho_e_float32_hf"
1606+
command:
1607+
- "julia --color=yes --project=.buildkite examples/hybrid/driver.jl"
1608+
artifact_paths:
1609+
- "examples/hybrid/sphere/output/baroclinic_wave_rhoe_hf/Float32/*"
1610+
env:
1611+
TEST_NAME: "sphere/baroclinic_wave_rhoe_hf"
1612+
FLOAT_TYPE: "Float32"
1613+
HorizontalLayout: "IJHF"
1614+
CLIMACOMMS_DEVICE: "CUDA"
1615+
agents:
1616+
slurm_gpus: 1
1617+
16041618
- label: ":computer: 3D Box limiters advection slotted spheres"
16051619
key: "cpu_box_advection_limiter_slotted_spheres"
16061620
command:
@@ -1870,6 +1884,17 @@ steps:
18701884
TEST_NAME: "sphere/baroclinic_wave_rhoe"
18711885
FLOAT_TYPE: "Float64"
18721886

1887+
- label: ":computer: Float 64 3D sphere baroclinic wave (ρe) HF datalayout"
1888+
key: "cpu_baroclinic_wave_rho_e_float64_hf"
1889+
command:
1890+
- "julia --color=yes --project=.buildkite examples/hybrid/driver.jl"
1891+
artifact_paths:
1892+
- "examples/hybrid/sphere/output/baroclinic_wave_rhoe_hf/Float64/*"
1893+
env:
1894+
TEST_NAME: "sphere/baroclinic_wave_rhoe_hf"
1895+
FLOAT_TYPE: "Float64"
1896+
HorizontalLayout: "IJHF"
1897+
18731898
- label: ":computer: 3D sphere baroclinic wave (ρe)"
18741899
key: "cpu_baroclinic_wave_rho_e"
18751900
command:

NEWS.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7-
- Fixed world-age issue on Julia 1.11 issue [Julia#54780](https://github.com/JuliaLang/julia/issues/54780), PR [#2034](https://github.com/CliMA/ClimaCore.jl/pull/2034).
7+
- We've added new datalayouts: `VIJHF`,`IJHF`,`IHF`,`VIHF`, to explore their performance compared to our existing datalayouts: `VIJFH`,`IJFH`,`IFH`,`VIFH`. PR [#2055](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2055).
8+
- We've refactored some modules to use less internals. PR [#2053](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2052), [#2051](https://github.com/CliMA/ClimaCore.jl/pull/2051), [#2049](https://github.com/CliMA/ClimaCore.jl/pull/2049).
9+
- Some work was done in attempt to reduce specializations and compile time. PR [#2042](https://github.com/CliMA/ClimaCore.jl/pull/2042), [#2041](https://github.com/CliMA/ClimaCore.jl/pull/2041)
810

911
v0.14.19
1012
-------
1113

14+
- Fixed world-age issue on Julia 1.11 issue [Julia#54780](https://github.com/JuliaLang/julia/issues/54780), PR [#2034](https://github.com/CliMA/ClimaCore.jl/pull/2034).
15+
1216
### ![][badge-🐛bugfix] Fix undefined behavior in `DataLayout`s
1317

1418
PR [#2034](https://github.com/CliMA/ClimaCore.jl/pull/2034) fixes some undefined

docs/src/api.md

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ DataLayouts.IFH
3232
DataLayouts.IJFH
3333
DataLayouts.VIFH
3434
DataLayouts.VIJFH
35+
DataLayouts.IHF
36+
DataLayouts.IJHF
37+
DataLayouts.VIHF
38+
DataLayouts.VIJHF
3539
```
3640

3741
## Geometry

examples/common_spaces.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ function make_horizontal_space(
3535
mesh,
3636
npoly,
3737
context::ClimaComms.SingletonCommsContext,
38+
HorizontalLayout = DataLayouts.IJFH,
3839
)
3940
quad = Quadratures.GLL{npoly + 1}()
4041
if mesh isa Meshes.AbstractMesh1D
4142
topology = Topologies.IntervalTopology(ClimaComms.device(context), mesh)
4243
space = Spaces.SpectralElementSpace1D(topology, quad)
4344
elseif mesh isa Meshes.AbstractMesh2D
4445
topology = Topologies.Topology2D(context, mesh)
45-
space = Spaces.SpectralElementSpace2D(topology, quad)
46+
space = Spaces.SpectralElementSpace2D(topology, quad; HorizontalLayout)
4647
end
4748
return space
4849
end
@@ -51,13 +52,14 @@ function make_horizontal_space(
5152
mesh,
5253
npoly,
5354
comms_ctx::ClimaComms.MPICommsContext,
55+
HorizontalLayout = DataLayouts.IJFH,
5456
)
5557
quad = Quadratures.GLL{npoly + 1}()
5658
if mesh isa Meshes.AbstractMesh1D
5759
error("Distributed mode does not work with 1D horizontal spaces.")
5860
elseif mesh isa Meshes.AbstractMesh2D
5961
topology = Topologies.Topology2D(comms_ctx, mesh)
60-
space = Spaces.SpectralElementSpace2D(topology, quad)
62+
space = Spaces.SpectralElementSpace2D(topology, quad; HorizontalLayout)
6163
end
6264
return space
6365
end

examples/hybrid/driver.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ ClimaComms.@import_required_backends
3030
import SciMLBase
3131
const comms_ctx = ClimaComms.context()
3232
is_distributed = comms_ctx isa ClimaComms.MPICommsContext
33+
using ClimaCore: DataLayouts
3334

3435
using Logging
3536

@@ -91,7 +92,16 @@ if haskey(ENV, "RESTART_FILE")
9192
ᶠlocal_geometry = Fields.local_geometry_field(Y.f)
9293
else
9394
t_start = FT(0)
94-
h_space = make_horizontal_space(horizontal_mesh, npoly, comms_ctx)
95+
HorizontalLayouts = Dict()
96+
HorizontalLayouts["IJFH"] = DataLayouts.IJFH
97+
HorizontalLayouts["IJHF"] = DataLayouts.IJHF
98+
HorizontalLayout = HorizontalLayouts[get(ENV, "HorizontalLayout", "IJFH")]
99+
h_space = make_horizontal_space(
100+
horizontal_mesh,
101+
npoly,
102+
comms_ctx,
103+
HorizontalLayout,
104+
)
95105
center_space, face_space =
96106
make_hybrid_spaces(h_space, z_max, z_elem; z_stretch)
97107
ᶜlocal_geometry = Fields.local_geometry_field(center_space)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using ClimaCorePlots, Plots
2+
using ClimaCore.DataLayouts
3+
4+
include("baroclinic_wave_utilities.jl")
5+
6+
const sponge = false
7+
8+
# Variables required for driver.jl (modify as needed)
9+
horizontal_mesh = cubed_sphere_mesh(; radius = R, h_elem = 4)
10+
npoly = 4
11+
z_max = FT(30e3)
12+
z_elem = 10
13+
t_end = FT(60 * 60 * 24 * 10)
14+
dt = FT(400)
15+
dt_save_to_sol = FT(60 * 60 * 24)
16+
dt_save_to_disk = FT(0) # 0 means don't save to disk
17+
ode_algorithm = CTS.SSP333
18+
jacobian_flags = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK, ∂ᶠ𝕄ₜ∂ᶜρ_mode = :exact)
19+
20+
additional_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) = merge(
21+
hyperdiffusion_cache(ᶜlocal_geometry, ᶠlocal_geometry; κ₄ = FT(2e17)),
22+
sponge ? rayleigh_sponge_cache(ᶜlocal_geometry, ᶠlocal_geometry, dt) : (;),
23+
)
24+
function additional_tendency!(Yₜ, Y, p, t)
25+
hyperdiffusion_tendency!(Yₜ, Y, p, t)
26+
sponge && rayleigh_sponge_tendency!(Yₜ, Y, p, t)
27+
end
28+
29+
center_initial_condition(local_geometry) =
30+
center_initial_condition(local_geometry, Val(:ρe))
31+
function postprocessing(sol, output_dir)
32+
@info "L₂ norm of ρe at t = $(sol.t[1]): $(norm(sol.u[1].c.ρe))"
33+
@info "L₂ norm of ρe at t = $(sol.t[end]): $(norm(sol.u[end].c.ρe))"
34+
35+
anim = Plots.@animate for Y in sol.u
36+
ᶜv = Geometry.UVVector.(Y.c.uₕ).components.data.:2
37+
Plots.plot(ᶜv, level = 3, clim = (-6, 6))
38+
end
39+
Plots.mp4(anim, joinpath(output_dir, "v.mp4"), fps = 5)
40+
end

ext/cuda/data_layouts.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11

22
import ClimaCore.DataLayouts: AbstractData
33
import ClimaCore.DataLayouts: FusedMultiBroadcast
4-
import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF
4+
import ClimaCore.DataLayouts:
5+
IJKFVH, IJFH, IJHF, VIJFH, VIJHF, VIFH, VIHF, IFH, IHF, IJF, IF, VF, DataF
56
import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle
7+
import ClimaCore.DataLayouts: IJHFStyle, VIJHFStyle
68
import ClimaCore.DataLayouts: promote_parent_array_type
79
import ClimaCore.DataLayouts: parent_array_type
810
import ClimaCore.DataLayouts: isascalar

ext/cuda/data_layouts_mapreduce.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ end
2121
function mapreduce_cuda(
2222
f,
2323
op,
24-
data::Union{DataLayouts.VF, DataLayouts.IJFH, DataLayouts.VIJFH};
24+
data::Union{
25+
DataLayouts.VF,
26+
DataLayouts.IJFH,
27+
DataLayouts.IJHF,
28+
DataLayouts.VIJFH,
29+
DataLayouts.VIJHF,
30+
};
2531
weighted_jacobian = OnesArray(parent(data)),
2632
opargs...,
2733
)

ext/cuda/data_layouts_threadblock.jl

+40-16
Original file line numberDiff line numberDiff line change
@@ -47,24 +47,33 @@ bounds to ensure that the result of
4747
function is_valid_index end
4848

4949
##### VIJFH
50-
@inline function partition(data::DataLayouts.VIJFH, n_max_threads::Integer)
50+
@inline function partition(
51+
data::Union{DataLayouts.VIJFH, DataLayouts.VIJHF},
52+
n_max_threads::Integer,
53+
)
5154
(Nij, _, _, Nv, Nh) = DataLayouts.universal_size(data)
5255
Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv)
5356
Nv_blocks = cld(Nv, Nv_thread)
5457
@assert prod((Nv_thread, Nij, Nij)) n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)"
5558
return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh))
5659
end
57-
@inline function universal_index(::DataLayouts.VIJFH)
60+
@inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF})
5861
(tv, i, j) = CUDA.threadIdx()
5962
(bv, h) = CUDA.blockIdx()
6063
v = tv + (bv - 1) * CUDA.blockDim().x
6164
return CartesianIndex((i, j, 1, v, h))
6265
end
63-
@inline is_valid_index(::DataLayouts.VIJFH, I::CI5, us::UniversalSize) =
64-
1 I[4] DataLayouts.get_Nv(us)
66+
@inline is_valid_index(
67+
::Union{DataLayouts.VIJFH, DataLayouts.VIJHF},
68+
I::CI5,
69+
us::UniversalSize,
70+
) = 1 I[4] DataLayouts.get_Nv(us)
6571

6672
##### IJFH
67-
@inline function partition(data::DataLayouts.IJFH, n_max_threads::Integer)
73+
@inline function partition(
74+
data::Union{DataLayouts.IJFH, DataLayouts.IJHF},
75+
n_max_threads::Integer,
76+
)
6877
(Nij, _, _, _, Nh) = DataLayouts.universal_size(data)
6978
Nh_thread = min(
7079
Int(fld(n_max_threads, Nij * Nij)),
@@ -75,31 +84,40 @@ end
7584
@assert prod((Nij, Nij)) n_max_threads "threads,n_max_threads=($(prod((Nij, Nij))),$n_max_threads)"
7685
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
7786
end
78-
@inline function universal_index(::DataLayouts.IJFH)
87+
@inline function universal_index(::Union{DataLayouts.IJFH, DataLayouts.IJHF})
7988
(i, j, th) = CUDA.threadIdx()
8089
(bh,) = CUDA.blockIdx()
8190
h = th + (bh - 1) * CUDA.blockDim().z
8291
return CartesianIndex((i, j, 1, 1, h))
8392
end
84-
@inline is_valid_index(::DataLayouts.IJFH, I::CI5, us::UniversalSize) =
85-
1 I[5] DataLayouts.get_Nh(us)
93+
@inline is_valid_index(
94+
::Union{DataLayouts.IJFH, DataLayouts.IJHF},
95+
I::CI5,
96+
us::UniversalSize,
97+
) = 1 I[5] DataLayouts.get_Nh(us)
8698

8799
##### IFH
88-
@inline function partition(data::DataLayouts.IFH, n_max_threads::Integer)
100+
@inline function partition(
101+
data::Union{DataLayouts.IFH, DataLayouts.IHF},
102+
n_max_threads::Integer,
103+
)
89104
(Ni, _, _, _, Nh) = DataLayouts.universal_size(data)
90105
Nh_thread = min(Int(fld(n_max_threads, Ni)), Nh)
91106
Nh_blocks = cld(Nh, Nh_thread)
92107
@assert prod((Ni, Nh_thread)) n_max_threads "threads,n_max_threads=($(prod((Ni, Nh_thread))),$n_max_threads)"
93108
return (; threads = (Ni, Nh_thread), blocks = (Nh_blocks,))
94109
end
95-
@inline function universal_index(::DataLayouts.IFH)
110+
@inline function universal_index(::Union{DataLayouts.IFH, DataLayouts.IHF})
96111
(i, th) = CUDA.threadIdx()
97112
(bh,) = CUDA.blockIdx()
98113
h = th + (bh - 1) * CUDA.blockDim().y
99114
return CartesianIndex((i, 1, 1, 1, h))
100115
end
101-
@inline is_valid_index(::DataLayouts.IFH, I::CI5, us::UniversalSize) =
102-
1 I[5] DataLayouts.get_Nh(us)
116+
@inline is_valid_index(
117+
::Union{DataLayouts.IFH, DataLayouts.IHF},
118+
I::CI5,
119+
us::UniversalSize,
120+
) = 1 I[5] DataLayouts.get_Nh(us)
103121

104122
##### IJF
105123
@inline function partition(data::DataLayouts.IJF, n_max_threads::Integer)
@@ -126,21 +144,27 @@ end
126144
@inline is_valid_index(::DataLayouts.IF, I::CI5, us::UniversalSize) = true
127145

128146
##### VIFH
129-
@inline function partition(data::DataLayouts.VIFH, n_max_threads::Integer)
147+
@inline function partition(
148+
data::Union{DataLayouts.VIFH, DataLayouts.VIHF},
149+
n_max_threads::Integer,
150+
)
130151
(Ni, _, _, Nv, Nh) = DataLayouts.universal_size(data)
131152
Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv)
132153
Nv_blocks = cld(Nv, Nv_thread)
133154
@assert prod((Nv_thread, Ni)) n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)"
134155
return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh))
135156
end
136-
@inline function universal_index(::DataLayouts.VIFH)
157+
@inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF})
137158
(tv, i) = CUDA.threadIdx()
138159
(bv, h) = CUDA.blockIdx()
139160
v = tv + (bv - 1) * CUDA.blockDim().x
140161
return CartesianIndex((i, 1, 1, v, h))
141162
end
142-
@inline is_valid_index(::DataLayouts.VIFH, I::CI5, us::UniversalSize) =
143-
1 I[4] DataLayouts.get_Nv(us)
163+
@inline is_valid_index(
164+
::Union{DataLayouts.VIFH, DataLayouts.VIHF},
165+
I::CI5,
166+
us::UniversalSize,
167+
) = 1 I[4] DataLayouts.get_Nv(us)
144168

145169
##### VF
146170
@inline function partition(data::DataLayouts.VF, n_max_threads::Integer)

0 commit comments

Comments
 (0)