Skip to content

Commit 6004569

Browse files
Add fieldvector unit tests
1 parent 9a05ed2 commit 6004569

File tree

2 files changed

+93
-13
lines changed

2 files changed

+93
-13
lines changed

src/Fields/fieldvector.jl

+31-13
Original file line numberDiff line numberDiff line change
@@ -469,36 +469,54 @@ end
469469

470470

471471
# Recursively compare contents of similar fieldvectors
472-
_rcompare(pass, x::T, y::T) where {T <: Field} =
473-
pass && _rcompare(pass, field_values(x), field_values(y))
474-
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
472+
_rcompare(pass, x::T, y::T; strict) where {T <: Field} =
473+
pass && _rcompare(pass, field_values(x), field_values(y); strict)
474+
_rcompare(pass, x::T, y::T; strict) where {T <: DataLayouts.AbstractData} =
475475
pass && (parent(x) == parent(y))
476-
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)
476+
_rcompare(pass, x::T, y::T; strict) where {T} = pass && (x == y)
477477

478-
function _rcompare(pass, x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
478+
_rcompare(pass, x::NamedTuple, y::NamedTuple; strict) =
479+
_rcompare_nt(pass, x, y; strict)
480+
_rcompare(pass, x::FieldVector, y::FieldVector; strict) =
481+
_rcompare_nt(pass, x, y; strict)
482+
483+
function _rcompare_nt(pass, x, y; strict)
484+
length(propertynames(x)) length(propertynames(y)) && return false
485+
if strict
486+
typeof(x) == typeof(y) || return false
487+
end
479488
for pn in propertynames(x)
480-
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
489+
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn); strict)
481490
end
482491
return pass
483492
end
484493

485494
"""
486-
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
495+
rcompare(x::T, y::T; strict = true) where {T <: Union{FieldVector, NamedTuple}}
487496
488497
Recursively compare given fieldvectors via `==`.
489498
Returns `true` if `x == y` recursively.
490499
491500
FieldVectors with different types are considered different.
492501
"""
493-
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} =
494-
_rcompare(true, x, y)
502+
rcompare(
503+
x::T,
504+
y::T;
505+
strict = true,
506+
) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y; strict)
495507

496-
rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)
508+
rcompare(x::T, y::T; strict = true) where {T <: FieldVector} =
509+
_rcompare(true, x, y; strict)
497510

498-
rcompare(x::T, y::T) where {T <: NamedTuple} = _rcompare(true, x, y)
511+
rcompare(x::T, y::T; strict = true) where {T <: NamedTuple} =
512+
_rcompare(true, x, y; strict)
499513

500514
# FieldVectors with different types are always different
501-
rcompare(x::FieldVector, y::FieldVector) = false
515+
rcompare(x::FieldVector, y::FieldVector; strict::Bool = true) =
516+
strict ? false : _rcompare(true, x, y; strict)
517+
518+
rcompare(x::NamedTuple, y::NamedTuple; strict::Bool = true) =
519+
strict ? false : _rcompare(true, x, y; strict)
502520

503521
# Define == to call rcompare for two fieldvectors
504-
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y)
522+
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y; strict = true)

test/Fields/unit_field.jl

+62
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ClimaComms.@import_required_backends
1111
using OrderedCollections
1212
using StaticArrays, IntervalSets
1313
import ClimaCore
14+
import ClimaCore.InputOutput
1415
import ClimaCore.Utilities: PlusHalf
1516
import ClimaCore.DataLayouts
1617
import ClimaCore.DataLayouts: IJFH
@@ -330,6 +331,67 @@ end
330331
@test occursin("==================== Difference found:", s)
331332
end
332333

334+
@testset "FieldVector restart + broadcasting" begin
335+
FT = Float32
336+
root_path = "bucket_restart"
337+
rm(root_path; force = true, recursive = true)
338+
mkpath(root_path)
339+
340+
radius = FT(100)
341+
depth = FT(3.5)
342+
nelements = (1, 10)
343+
npolynomial = 1
344+
dz_tuple = nothing
345+
context = ClimaComms.context()
346+
347+
vertdomain = Domains.IntervalDomain(
348+
Geometry.ZPoint(FT(-depth)),
349+
Geometry.ZPoint(FT(0));
350+
boundary_names = (:bottom, :top),
351+
)
352+
vertmesh = Meshes.IntervalMesh(vertdomain; nelems = nelements[2])
353+
device = ClimaComms.device()
354+
vert_center_space = Spaces.CenterFiniteDifferenceSpace(device, vertmesh)
355+
horzdomain = Domains.SphereDomain(radius)
356+
horzmesh = Meshes.EquiangularCubedSphere(horzdomain, nelements[1])
357+
horztopology = Topologies.Topology2D(context, horzmesh)
358+
quad = Spaces.Quadratures.GLL{npolynomial + 1}()
359+
space = Spaces.SpectralElementSpace2D(horztopology, quad)
360+
361+
vars = (;
362+
bucket = (;
363+
W = Fields.Field(FT, space),
364+
T = Fields.Field(FT, space),
365+
Ws = Fields.Field(FT, space),
366+
σS = Fields.Field(FT, space),
367+
)
368+
)
369+
370+
Y = Fields.FieldVector(; vars...)
371+
Y.bucket.T .= 280.0
372+
Y.bucket.W .= 0.05
373+
Y.bucket.Ws .= 0.0
374+
Y.bucket.σS .= 0.08
375+
376+
output_file = joinpath(root_path, "day0.0.hdf5")
377+
hdfwriter = InputOutput.HDF5Writer(output_file, context)
378+
InputOutput.write_attributes!(
379+
hdfwriter,
380+
"/",
381+
Dict("time" => 0.0, "hash" => "1234"),
382+
)
383+
InputOutput.write!(hdfwriter, Y, "Y")
384+
Base.close(hdfwriter)
385+
386+
restart_file = output_file
387+
Y_restart = Fields.FieldVector(; vars...)
388+
hdfreader = InputOutput.HDF5Reader(restart_file, context)
389+
Y_restart = InputOutput.read_field(hdfreader, "Y")
390+
close(hdfreader)
391+
Y .= Y_restart
392+
@test Fields.rcompare(Y, Y_restart; strict = false)
393+
end
394+
333395
# https://github.com/CliMA/ClimaCore.jl/issues/1465
334396
@testset "Diagonal FieldVector broadcast expressions" begin
335397
FT = Float64

0 commit comments

Comments
 (0)