Skip to content

Commit 25ce5ee

Browse files
Replace Unrolled with UnrolledUtilities
1 parent 4e7c44d commit 25ce5ee

11 files changed

+37
-141
lines changed

Project.toml

+5-7
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2929
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
30-
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
30+
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
3131

3232
[weakdeps]
3333
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -45,13 +45,13 @@ AssociatedLegendrePolynomials = "1"
4545
BandedMatrices = "0.17, 1"
4646
BenchmarkTools = "1"
4747
BlockArrays = "0.16, 1"
48+
CUDA = "5"
4849
ClimaComms = "0.6"
4950
Combinatorics = "1"
5051
CountFlops = "0.1"
5152
CubedSphere = "0.2, 0.3"
52-
CUDA = "5"
53-
Dates = "1"
5453
DataStructures = "0.18"
54+
Dates = "1"
5555
DocStringExtensions = "0.8, 0.9"
5656
FastBroadcast = "0.3"
5757
ForwardDiff = "0.10"
@@ -63,8 +63,8 @@ IntervalSets = "0.5, 0.6, 0.7"
6363
JET = "0.9"
6464
Krylov = "0.9"
6565
KrylovKit = "0.6, 0.7, 0.8"
66-
LinearAlgebra = "1"
6766
LazyBroadcast = "0.1"
67+
LinearAlgebra = "1"
6868
Logging = "1"
6969
MPI = "0.20"
7070
MultiBroadcastFusion = "0.3"
@@ -82,7 +82,6 @@ Statistics = "1"
8282
StatsBase = "0.34"
8383
TerminalLoggers = "0.1"
8484
Test = "1"
85-
Unrolled = "0.1"
8685
julia = "1.9"
8786

8887
[extras]
@@ -94,8 +93,8 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
9493
CountFlops = "1db9610d-79e1-487a-8d40-77f3295c7593"
9594
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
9695
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
97-
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
9896
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
97+
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
9998
LazyBroadcast = "9dccce8e-a116-406d-9fcc-a88ed4f510c8"
10099
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
101100
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
@@ -109,4 +108,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
109108

110109
[targets]
111110
test = ["Aqua", "ArgParse", "AssociatedLegendrePolynomials", "BenchmarkTools", "Combinatorics", "CountFlops", "Dates", "FastBroadcast", "Krylov", "JET", "LazyBroadcast", "Logging", "MPI", "OrderedCollections", "PrettyTables", "Random", "SafeTestsets", "StatsBase", "TerminalLoggers", "Test"]
112-

docs/src/api.md

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ CurrentModule = ClimaCore
99
```@docs
1010
Utilities.PlusHalf
1111
Utilities.half
12-
Utilities.UnrolledFunctions
1312
```
1413

1514
### Utilities.Cache

ext/cuda/matrix_fields_multiple_field_solve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ClimaCore.MatrixFields
66
import ClimaCore.MatrixFields: _single_field_solve!
77
import ClimaCore.MatrixFields: multiple_field_solve!
88
import ClimaCore.MatrixFields: is_CuArray_type
9-
import ClimaCore.Utilities.UnrolledFunctions: unrolled_map
9+
import UnrolledUtilities: unrolled_map
1010

1111
is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true
1212

src/Fields/Fields.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ import ..Quadratures
1717
import ..Grids: ColumnIndex, local_geometry_type
1818
import ..Spaces: Spaces, AbstractSpace, AbstractPointSpace, cuda_synchronize
1919
import ..Geometry: Geometry, Cartesian12Vector
20-
import ..Utilities: PlusHalf, half, UnrolledFunctions
20+
import ..Utilities: PlusHalf, half
21+
import UnrolledUtilities
2122

2223
using ..RecursiveApply
2324
using ClimaComms

src/Fields/fieldvector.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ LinearAlgebra.ldiv!(A::LinearAlgebra.LU, x::FieldVector) =
352352
x .= LinearAlgebra.ldiv!(A, Vector(x))
353353

354354
function LinearAlgebra.norm_sqr(x::FieldVector)
355-
value_norm_sqrs = UnrolledFunctions.unrolled_map(_values(x)) do value
355+
value_norm_sqrs = UnrolledUtilities.unrolled_map(_values(x)) do value
356356
LinearAlgebra.norm_sqr(backing_array(value))
357357
end
358358
return sum(value_norm_sqrs; init = zero(eltype(x)))
@@ -364,7 +364,7 @@ end
364364
import ClimaComms
365365

366366
ClimaComms.array_type(x::FieldVector) = promote_type(
367-
UnrolledFunctions.unrolled_map(ClimaComms.array_type, _values(x))...,
367+
UnrolledUtilities.unrolled_map(ClimaComms.array_type, _values(x))...,
368368
)
369369

370370
function __rprint_diff(

src/MatrixFields/MatrixFields.jl

+14-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,20 @@ import ..Spaces: local_geometry_type
6666
import ..Fields
6767
import ..Operators
6868

69-
using ..Utilities.UnrolledFunctions
69+
using UnrolledUtilities:
70+
unrolled_map,
71+
unrolled_take,
72+
unrolled_drop,
73+
unrolled_any,
74+
unrolled_foreach,
75+
unrolled_filter,
76+
unrolled_flatmap,
77+
unrolled_all,
78+
unrolled_in,
79+
unrolled_product,
80+
unrolled_unique,
81+
unrolled_split
82+
7083
using ..Geometry:
7184
rmul_with_projection,
7285
mul_with_projection,

src/MatrixFields/field_name.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,18 @@ function child_names(name, tree)
144144
subtree isa FieldNameTreeNode || error("$name does not have child names")
145145
return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees)
146146
end
147-
get_subtree_at_name(name, tree) =
147+
function get_subtree_at_name(name, tree)
148148
if name == tree.name
149149
tree
150150
else
151-
subtree = unrolled_findonly(tree.subtrees) do subtree
151+
filtered_values = unrolled_filter(tree.subtrees) do subtree
152152
is_valid_name(name, subtree)
153153
end
154+
@assert length(filtered_values) == 1
155+
subtree = filtered_values[1]
154156
get_subtree_at_name(name, subtree)
155157
end
158+
end
156159

157160
################################################################################
158161

src/MatrixFields/field_name_dict.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ Base.:(==)(dict1::FieldNameDict, dict2::FieldNameDict) =
120120

121121
function Base.getindex(dict::FieldNameDict, key)
122122
key in keys(dict) || throw(KeyError(key))
123-
key′, entry′ =
124-
unrolled_findonly(pair -> is_child_value(key, pair[1]), pairs(dict))
123+
filtered_values = unrolled_filter(pairs(dict)) do pair
124+
is_child_value(key, pair[1])
125+
end
126+
@assert length(filtered_values) == 1
127+
key′, entry′ = filtered_values[1]
125128
return get_internal_entry(entry′, get_internal_key(key, key′))
126129
end
127130

src/Operators/finitedifference.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import ..Utilities: PlusHalf, half, UnrolledFunctions
1+
import ..Utilities: PlusHalf, half
2+
import UnrolledUtilities
23

34
const AllFiniteDifferenceSpace =
45
Union{Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace}
@@ -2624,7 +2625,7 @@ end
26242625
function Adapt.adapt_structure(to, op::FiniteDifferenceOperator)
26252626
if hasfield(typeof(op), :bcs)
26262627
bcs_adapted = NamedTuple{keys(op.bcs)}(
2627-
UnrolledFunctions.unrolled_map(
2628+
UnrolledUtilities.unrolled_map(
26282629
bc -> Adapt.adapt_structure(to, bc),
26292630
values(op.bcs),
26302631
),

src/Utilities/Utilities.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module Utilities
22

33
include("plushalf.jl")
4-
include("unrolled_functions.jl")
54
include("cache.jl")
65

76
"""

src/Utilities/unrolled_functions.jl

-121
This file was deleted.

0 commit comments

Comments
 (0)