From a0f3eab8252e385ea8662723b9c66558fa0afd7c Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Tue, 25 Feb 2025 15:14:04 -0800 Subject: [PATCH] Replace UnrolledFunctions with UnrolledUtilities --- .buildkite/Manifest.toml | 23 ++-- Project.toml | 4 +- benchmarks/bickleyjet/Manifest.toml | 17 ++- docs/src/api.md | 1 - .../matrix_fields_multiple_field_solve.jl | 1 - src/Fields/Fields.jl | 3 +- src/Fields/fieldvector.jl | 7 +- src/MatrixFields/MatrixFields.jl | 3 +- src/MatrixFields/field_name.jl | 5 +- src/MatrixFields/field_name_dict.jl | 2 +- src/Operators/finitedifference.jl | 8 +- src/Utilities/Utilities.jl | 1 - src/Utilities/unrolled_functions.jl | 121 ------------------ 13 files changed, 38 insertions(+), 158 deletions(-) delete mode 100644 src/Utilities/unrolled_functions.jl diff --git a/.buildkite/Manifest.toml b/.buildkite/Manifest.toml index c6d180439b..bb0f1616bc 100644 --- a/.buildkite/Manifest.toml +++ b/.buildkite/Manifest.toml @@ -327,10 +327,10 @@ weakdeps = ["CUDA", "MPI"] ClimaCommsMPIExt = "MPI" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.24" +version = "0.14.26" weakdeps = ["CUDA", "Krylov"] [deps.ClimaCore.extensions] @@ -344,7 +344,7 @@ uuid = "cf7c7e5a-b407-4c48-9047-11a94a308626" version = "0.2.11" [[deps.ClimaCoreTempestRemap]] -deps = ["ClimaComms", "ClimaCore", "CommonDataModel", "Dates", "LinearAlgebra", "NCDatasets", "PkgVersion", "TempestRemap_jll"] +deps = ["ClimaComms", "ClimaCore", "CommonDataModel", "Dates", "DiskArrays", "LinearAlgebra", "NCDatasets", "PkgVersion", "TempestRemap_jll"] path = "../lib/ClimaCoreTempestRemap" uuid = "d934ef94-cdd4-4710-83d6-720549644b70" version = "0.3.18" @@ -671,9 +671,9 @@ version = "0.6.40" [[deps.DiskArrays]] deps = ["LRUCache", "Mmap", "OffsetArrays"] -git-tree-sha1 = "64650943240652ebedc6c43d03cccda247b327a3" +git-tree-sha1 = "4687e77a603fcd86738a92758086717cd06cdaae" uuid = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3" -version = "0.4.9" +version = "0.4.8" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -2281,11 +2281,14 @@ git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.4" -[[deps.Unrolled]] -deps = ["MacroTools"] -git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" -uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" -version = "0.1.5" +[[deps.UnrolledUtilities]] +git-tree-sha1 = "5caf11dfadeee25daafa7caabb3f252a977ffe72" +uuid = "0fe1646c-419e-43be-ac14-22321958931b" +version = "0.1.6" +weakdeps = ["StaticArrays"] + + [deps.UnrolledUtilities.extensions] + UnrolledUtilitiesStaticArraysExt = "StaticArrays" [[deps.UnsafeAtomics]] git-tree-sha1 = "b13c4edda90890e5b04ba24e20a310fbe6f249ff" diff --git a/Project.toml b/Project.toml index 3fb6c5ca17..aa7b8eff4d 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -77,7 +77,7 @@ Statistics = "1" StatsBase = "0.34" TerminalLoggers = "0.1" Test = "1" -Unrolled = "0.1.5" +UnrolledUtilities = "0.1.6" julia = "1.10" [extras] diff --git a/benchmarks/bickleyjet/Manifest.toml b/benchmarks/bickleyjet/Manifest.toml index e330b5b745..572de0308a 100644 --- a/benchmarks/bickleyjet/Manifest.toml +++ b/benchmarks/bickleyjet/Manifest.toml @@ -232,10 +232,10 @@ version = "0.6.6" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "UnrolledUtilities"] path = "../.." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.24" +version = "0.14.26" [deps.ClimaCore.extensions] ClimaCoreCUDAExt = "CUDA" @@ -1490,11 +1490,14 @@ git-tree-sha1 = "975c354fcd5f7e1ddcc1f1a23e6e091d99e99bc8" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.4" -[[deps.Unrolled]] -deps = ["MacroTools"] -git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" -uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" -version = "0.1.5" +[[deps.UnrolledUtilities]] +git-tree-sha1 = "5caf11dfadeee25daafa7caabb3f252a977ffe72" +uuid = "0fe1646c-419e-43be-ac14-22321958931b" +version = "0.1.6" +weakdeps = ["StaticArrays"] + + [deps.UnrolledUtilities.extensions] + UnrolledUtilitiesStaticArraysExt = "StaticArrays" [[deps.UnsafeAtomics]] git-tree-sha1 = "b13c4edda90890e5b04ba24e20a310fbe6f249ff" diff --git a/docs/src/api.md b/docs/src/api.md index fb4d8b2ea5..77073fe8c1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,7 +9,6 @@ CurrentModule = ClimaCore ```@docs Utilities.PlusHalf Utilities.half -Utilities.UnrolledFunctions ``` ### Utilities.Cache diff --git a/ext/cuda/matrix_fields_multiple_field_solve.jl b/ext/cuda/matrix_fields_multiple_field_solve.jl index 76c0b6e3eb..dadc7f5fbc 100644 --- a/ext/cuda/matrix_fields_multiple_field_solve.jl +++ b/ext/cuda/matrix_fields_multiple_field_solve.jl @@ -6,7 +6,6 @@ import ClimaCore.MatrixFields import ClimaCore.MatrixFields: _single_field_solve! import ClimaCore.MatrixFields: multiple_field_solve! import ClimaCore.MatrixFields: is_CuArray_type -import ClimaCore.Utilities.UnrolledFunctions: unrolled_map is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index 94f0fd73e3..6bb7f5da37 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -23,11 +23,12 @@ import ..Grids: ColumnIndex, local_geometry_type import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace, cuda_synchronize import ..Spaces: nlevels, ncolumns import ..Geometry: Geometry, Cartesian12Vector -import ..Utilities: PlusHalf, half, UnrolledFunctions +import ..Utilities: PlusHalf, half using ..RecursiveApply using ClimaComms import Adapt +import UnrolledUtilities: unrolled_map import StaticArrays, LinearAlgebra, Statistics, InteractiveUtils diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index b88fd37ef4..a824f38bc8 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -429,7 +429,7 @@ LinearAlgebra.ldiv!(A::LinearAlgebra.LU, x::FieldVector) = x .= LinearAlgebra.ldiv!(A, Vector(x)) function LinearAlgebra.norm_sqr(x::FieldVector) - value_norm_sqrs = UnrolledFunctions.unrolled_map(_values(x)) do value + value_norm_sqrs = unrolled_map(_values(x)) do value LinearAlgebra.norm_sqr(backing_array(value)) end return sum(value_norm_sqrs; init = zero(eltype(x))) @@ -440,9 +440,8 @@ end import ClimaComms -ClimaComms.array_type(x::FieldVector) = promote_type( - UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))..., -) +ClimaComms.array_type(x::FieldVector) = + promote_type(unrolled_map(ClimaComms.array_type, _values(x))...) function __rprint_diff( io::IO, diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 1a63772d3f..3d4725ca48 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -52,6 +52,7 @@ import KrylovKit import ClimaComms import NVTX import Adapt +using UnrolledUtilities import ..Utilities: PlusHalf, half import ..RecursiveApply: @@ -65,8 +66,6 @@ import ..Spaces import ..Spaces: local_geometry_type import ..Fields import ..Operators - -using ..Utilities.UnrolledFunctions using ..Geometry: rmul_with_projection, mul_with_projection, diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index d961f375a4..bd05f844c1 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -160,10 +160,11 @@ get_subtree_at_name(name, tree) = if name == tree.name tree else - subtree = unrolled_findonly(tree.subtrees) do subtree + subtrees_at_name = unrolled_filter(tree.subtrees) do subtree is_valid_name(name, subtree) end - get_subtree_at_name(name, subtree) + @assert length(subtrees_at_name) == 1 + get_subtree_at_name(name, subtrees_at_name[1]) end ################################################################################ diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index f5ad66642b..c827d7fd77 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -137,7 +137,7 @@ Base.:(==)(dict1::FieldNameDict, dict2::FieldNameDict) = function Base.getindex(dict::FieldNameDict, key) key in keys(dict) || throw(KeyError(key)) key′, entry′ = - unrolled_findonly(pair -> is_child_value(key, pair[1]), pairs(dict)) + unrolled_filter(pair -> is_child_value(key, pair[1]), pairs(dict))[1] internal_key = get_internal_key(key, key′) return get_internal_entry(entry′, internal_key, KeyError(key)) end diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index f9f4dd49f3..4187fe4c03 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -1,4 +1,5 @@ -import ..Utilities: PlusHalf, half, UnrolledFunctions +import ..Utilities: PlusHalf, half +import UnrolledUtilities: unrolled_map const AllFiniteDifferenceSpace = Union{Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace} @@ -3324,10 +3325,7 @@ Adapt.adapt_structure(to, op::FiniteDifferenceOperator) = unionall_type(typeof(op))(; adapt_bcs(to, bcs)...) @inline adapt_bcs(to, bcs) = NamedTuple{keys(bcs)}( - UnrolledFunctions.unrolled_map( - bc -> Adapt.adapt_structure(to, bc), - values(bcs), - ), + unrolled_map(bc -> Adapt.adapt_structure(to, bc), values(bcs)), ) """ diff --git a/src/Utilities/Utilities.jl b/src/Utilities/Utilities.jl index d74b310fc6..3c6a4bcd9e 100644 --- a/src/Utilities/Utilities.jl +++ b/src/Utilities/Utilities.jl @@ -1,7 +1,6 @@ module Utilities include("plushalf.jl") -include("unrolled_functions.jl") include("cache.jl") """ diff --git a/src/Utilities/unrolled_functions.jl b/src/Utilities/unrolled_functions.jl deleted file mode 100644 index 95577fc74e..0000000000 --- a/src/Utilities/unrolled_functions.jl +++ /dev/null @@ -1,121 +0,0 @@ -""" - UnrolledFunctions - -A collection of generated functions that get unrolled during compilation, which -make it possible to iterate over nonuniform collections without sacrificing -type-stability. - -The functions exported by this module are -- `unrolled_map(f, values, [values2])`: alternative to `map` -- `unrolled_any(f, values)`: alternative to `any` -- `unrolled_all(f, values)`: alternative to `all` -- `unrolled_filter(f, values)`: alternative to `filter` -- `unrolled_foreach(f, values)`: alternative to `foreach` -- `unrolled_in(value, values)`: alternative to `in` -- `unrolled_unique(values)`: alternative to `unique` -- `unrolled_flatten(values)`: alternative to `Iterators.flatten` -- `unrolled_flatmap(f, values)`: alternative to `Iterators.flatmap` -- `unrolled_product(values1, values2)`: alternative to `Iterators.product` -- `unrolled_findonly(f, values)`: checks that only one value satisfies `f`, and - then returns that value -- `unrolled_split(f, values)`: returns a tuple that contains the result of - calling `unrolled_filter` with `f` and the result of calling it with `!f` -- `unrolled_take(values, ::Val{N})`: alternative to `Iterators.take`, but with - an `Int` wrapped in a `Val` as the second argument instead of a regular `Int`; - this usually compiles more quickly than `values[1:N]` -- `unrolled_drop(values, ::Val{N})`: alternative to `Iterators.drop`, but with - an `Int` wrapped in a `Val` as the second argument instead of a regular `Int`; - this usually compiles more quickly than `values[(end - N + 1):end]` -""" -module UnrolledFunctions - -import Unrolled -import Unrolled: @unroll - -export unrolled_map, - unrolled_any, - unrolled_all, - unrolled_filter, - unrolled_foreach, - unrolled_in, - unrolled_unique, - unrolled_flatten, - unrolled_flatmap, - unrolled_product, - unrolled_findonly, - unrolled_split, - unrolled_take, - unrolled_drop - -# The definitions of unrolled_map and unrolled_any are copied over from -# Unrolled.jl, but their recursion limits are disabled here. As of Julia 1.9, we -# need to remove their recursion limits so that we can use them to implement -# recursion in other functions without any type-instabilities. For example, if a -# function f needs to map over some values, and if the computation for each -# value recursively calls f, then the map can be implemented using unrolled_map. - -@generated unrolled_map(f, values) = - :(tuple($((:(f(values[$i])) for i in 1:Unrolled.type_length(values))...))) - -@generated function unrolled_map(f, values1, values2) - N = Unrolled.type_length(values1) - @assert N == Unrolled.type_length(values2) - :(tuple($((:(f(values1[$i], values2[$i])) for i in 1:N)...))) -end - -@unroll function unrolled_any(f, values) - @unroll for value in values - f(value) && return true - end - return false -end - -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_map) - m.recursion_relation = dont_limit - end - for m in methods(unrolled_any) - m.recursion_relation = dont_limit - end -end - -const unrolled_all = Unrolled.unrolled_all -const unrolled_filter = Unrolled.unrolled_filter -const unrolled_foreach = Unrolled.unrolled_foreach -const unrolled_in = Unrolled.unrolled_in - -# Note: Unrolled.unrolled_reduce passes the arguments to its input function in -# reverse order (as of version 0.1 of Unrolled.jl). - -unrolled_unique(values) = - Unrolled.unrolled_reduce((), values) do value, unique_values - unrolled_in(value, unique_values) ? unique_values : - (unique_values..., value) - end - -unrolled_flatten(values) = - Unrolled.unrolled_reduce((tup2, tup1) -> (tup1..., tup2...), (), values) - -unrolled_flatmap(f::F, values) where {F} = - unrolled_flatten(unrolled_map(f, values)) - -unrolled_product(values1, values2) = - unrolled_flatmap(values1) do value1 - unrolled_map(value2 -> (value1, value2), values2) - end - -function unrolled_findonly(f::F, values) where {F} - filtered_values = unrolled_filter(f, values) - return length(filtered_values) == 1 ? filtered_values[1] : - error("unrolled_findonly requires that exactly 1 value makes f true") -end - -unrolled_split(f::F, values) where {F} = - (unrolled_filter(f, values), unrolled_filter(!f, values)) - -unrolled_take(values, ::Val{N}) where {N} = ntuple(i -> values[i], Val(N)) -unrolled_drop(values, ::Val{N}) where {N} = - ntuple(i -> values[N + i], Val(length(values) - N)) - -end # module