diff --git a/NEWS.md b/NEWS.md index 963737d8a1..d43cacf0b5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes main ------- + - A `strict = true` keyword was added to `rcompare`, which checks that the types match. If `strict = false`, then `rcompare` will return `true` for `FieldVector`s and `NamedTuple`s with the same properties but permuted order. For example: + - `rcompare((;a=1,b=2), (;b=2,a=1); strict = true)` will return `false` and + - `rcompare((;a=1,b=2), (;b=2,a=1); strict = false)` will return `true` - We've added new datalayouts: `VIJHF`,`IJHF`,`IHF`,`VIHF`, to explore their performance compared to our existing datalayouts: `VIJFH`,`IJFH`,`IFH`,`VIFH`. PR [#2055](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2055). - We've refactored some modules to use less internals. PR [#2053](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2052), [#2051](https://github.com/CliMA/ClimaCore.jl/pull/2051), [#2049](https://github.com/CliMA/ClimaCore.jl/pull/2049). - Some work was done in attempt to reduce specializations and compile time. PR [#2042](https://github.com/CliMA/ClimaCore.jl/pull/2042), [#2041](https://github.com/CliMA/ClimaCore.jl/pull/2041) diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index 520a86349e..cdb8ce8184 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -469,36 +469,54 @@ end # Recursively compare contents of similar fieldvectors -_rcompare(pass, x::T, y::T) where {T <: Field} = - pass && _rcompare(pass, field_values(x), field_values(y)) -_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} = +_rcompare(pass, x::T, y::T; strict) where {T <: Field} = + pass && _rcompare(pass, field_values(x), field_values(y); strict) +_rcompare(pass, x::T, y::T; strict) where {T <: DataLayouts.AbstractData} = pass && (parent(x) == parent(y)) -_rcompare(pass, x::T, y::T) where {T} = pass && (x == y) +_rcompare(pass, x::T, y::T; strict) where {T} = pass && (x == y) -function _rcompare(pass, x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} +_rcompare(pass, x::NamedTuple, y::NamedTuple; strict) = + _rcompare_nt(pass, x, y; strict) +_rcompare(pass, x::FieldVector, y::FieldVector; strict) = + _rcompare_nt(pass, x, y; strict) + +function _rcompare_nt(pass, x, y; strict) + length(propertynames(x)) ≠ length(propertynames(y)) && return false + if strict + typeof(x) == typeof(y) || return false + end for pn in propertynames(x) - pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn)) + pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn); strict) end return pass end """ - rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} + rcompare(x::T, y::T; strict = true) where {T <: Union{FieldVector, NamedTuple}} Recursively compare given fieldvectors via `==`. Returns `true` if `x == y` recursively. FieldVectors with different types are considered different. """ -rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} = - _rcompare(true, x, y) +rcompare( + x::T, + y::T; + strict = true, +) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y; strict) -rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y) +rcompare(x::T, y::T; strict = true) where {T <: FieldVector} = + _rcompare(true, x, y; strict) -rcompare(x::T, y::T) where {T <: NamedTuple} = _rcompare(true, x, y) +rcompare(x::T, y::T; strict = true) where {T <: NamedTuple} = + _rcompare(true, x, y; strict) # FieldVectors with different types are always different -rcompare(x::FieldVector, y::FieldVector) = false +rcompare(x::FieldVector, y::FieldVector; strict::Bool = true) = + strict ? false : _rcompare(true, x, y; strict) + +rcompare(x::NamedTuple, y::NamedTuple; strict::Bool = true) = + strict ? false : _rcompare(true, x, y; strict) # Define == to call rcompare for two fieldvectors -Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y) +Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y; strict = true) diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index b0362c3b0d..ad17f2dcb2 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -11,6 +11,7 @@ ClimaComms.@import_required_backends using OrderedCollections using StaticArrays, IntervalSets import ClimaCore +import ClimaCore.InputOutput import ClimaCore.Utilities: PlusHalf import ClimaCore.DataLayouts import ClimaCore.DataLayouts: IJFH @@ -330,6 +331,47 @@ end @test occursin("==================== Difference found:", s) end +@testset "Nested FieldVector broadcasting with permuted order" begin + FT = Float32 + context = ClimaComms.context() + vertdomain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(-3.5), + Geometry.ZPoint{FT}(0); + boundary_names = (:bottom, :top), + ) + vertmesh = Meshes.IntervalMesh(vertdomain; nelems = 10) + device = ClimaComms.device() + vert_center_space = Spaces.CenterFiniteDifferenceSpace(device, vertmesh) + horzdomain = Domains.SphereDomain(FT(100)) + horzmesh = Meshes.EquiangularCubedSphere(horzdomain, 1) + horztopology = Topologies.Topology2D(context, horzmesh) + quad = Spaces.Quadratures.GLL{2}() + space = Spaces.SpectralElementSpace2D(horztopology, quad) + + vars1 = (; # order is different! + bucket = (; # nesting is needed! + T = Fields.Field(FT, space), + W = Fields.Field(FT, space), + ) + ) + vars2 = (; # order is different! + bucket = (; # nesting is needed! + W = Fields.Field(FT, space), + T = Fields.Field(FT, space), + ) + ) + Y1 = Fields.FieldVector(; vars1...) + Y1.bucket.T .= 280.0 + Y1.bucket.W .= 0.05 + + Y2 = Fields.FieldVector(; vars2...) + Y2.bucket.T .= 280.0 + Y2.bucket.W .= 0.05 + + Y1 .= Y2 # FieldVector broadcasting + @test Fields.rcompare(Y1, Y2; strict = false) +end + # https://github.com/CliMA/ClimaCore.jl/issues/1465 @testset "Diagonal FieldVector broadcast expressions" begin FT = Float64