Skip to content

Commit 4369a5a

Browse files
Merge pull request #1982 from CliMA/ck/inference_repro2
Add a broken inference test for field broadcasting
2 parents 42cd28b + 1f093c5 commit 4369a5a

File tree

1 file changed

+56
-20
lines changed

1 file changed

+56
-20
lines changed

test/Fields/field_opt.jl

+56-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#=
2+
julia --project
3+
using Revise; include(joinpath("test", "Fields", "field_opt.jl"))
4+
=#
15
# These tests require running with `--check-bounds=[auto|no]`
26
using Test
37
using StaticArrays, IntervalSets
@@ -307,27 +311,28 @@ end
307311
end
308312

309313
# https://github.com/CliMA/ClimaCore.jl/issues/1062
314+
function toy_sphere(::Type{FT}) where {FT}
315+
context = ClimaComms.context()
316+
helem = npoly = 2
317+
hdomain = Domains.SphereDomain(FT(1e7))
318+
hmesh = Meshes.EquiangularCubedSphere(hdomain, helem)
319+
htopology = Topologies.Topology2D(context, hmesh)
320+
quad = Quadratures.GLL{npoly + 1}()
321+
hspace = Spaces.SpectralElementSpace2D(htopology, quad)
322+
vdomain = Domains.IntervalDomain(
323+
Geometry.ZPoint{FT}(zero(FT)),
324+
Geometry.ZPoint{FT}(FT(1e4));
325+
boundary_names = (:bottom, :top),
326+
)
327+
vmesh = Meshes.IntervalMesh(vdomain, nelems = 4)
328+
vtopology = Topologies.IntervalTopology(context, vmesh)
329+
vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
330+
center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
331+
face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
332+
return (center_space, face_space)
333+
end
334+
310335
@testset "Allocations with copyto! on FieldVectors" begin
311-
function toy_sphere(::Type{FT}) where {FT}
312-
context = ClimaComms.context()
313-
helem = npoly = 2
314-
hdomain = Domains.SphereDomain(FT(1e7))
315-
hmesh = Meshes.EquiangularCubedSphere(hdomain, helem)
316-
htopology = Topologies.Topology2D(context, hmesh)
317-
quad = Quadratures.GLL{npoly + 1}()
318-
hspace = Spaces.SpectralElementSpace2D(htopology, quad)
319-
vdomain = Domains.IntervalDomain(
320-
Geometry.ZPoint{FT}(zero(FT)),
321-
Geometry.ZPoint{FT}(FT(1e4));
322-
boundary_names = (:bottom, :top),
323-
)
324-
vmesh = Meshes.IntervalMesh(vdomain, nelems = 4)
325-
vtopology = Topologies.IntervalTopology(context, vmesh)
326-
vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
327-
center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
328-
face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
329-
return (center_space, face_space)
330-
end
331336
function field_vec(center_space, face_space)
332337
Y = Fields.FieldVector(
333338
c = map(Fields.coordinate_field(center_space)) do coord
@@ -357,4 +362,35 @@ end
357362
palloc = @allocated foo!(obj)
358363
@test palloc == 0
359364
end
365+
366+
struct VarTimescaleAcnv{FT}
367+
τ::FT
368+
α::FT
369+
end
370+
Base.broadcastable(x::VarTimescaleAcnv) = tuple(x)
371+
function conv_q_liq_to_q_rai(
372+
(; τ, α)::VarTimescaleAcnv{FT},
373+
q_liq::FT,
374+
ρ::FT,
375+
N_d::FT,
376+
) where {FT}
377+
return max(0, q_liq) / (1 * (N_d / 1e8)^1)
378+
end
379+
function ifelsekernel!(Sᵖ, ρ)
380+
var = VarTimescaleAcnv(1.0, 2.0)
381+
@. Sᵖ = ifelse(false, 1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))
382+
return nothing
383+
end
384+
385+
using JET
386+
# https://github.com/CliMA/ClimaCore.jl/issues/1981
387+
# TODO: improve the testset name once we better under
388+
@testset "ifelse kernel" begin
389+
(cspace, fspace) = toy_sphere(Float64)
390+
ρ = Fields.Field(Float64, cspace)
391+
S = Fields.Field(Float64, cspace)
392+
ifelsekernel!(S, ρ)
393+
@test_opt broken = true ifelsekernel!(S, ρ)
394+
end
395+
360396
nothing

0 commit comments

Comments
 (0)