-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathMatrixFields.jl
136 lines (126 loc) · 5.54 KB
/
MatrixFields.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
MatrixFields
This module adds support for defining and manipulating `Field`s that represent
matrices. Specifically, it adds the `BandMatrixRow` type, which can be used
to store the entries of a band matrix. A `Field` of `BandMatrixRow`s on a
`FiniteDifferenceSpace` can be interpreted as a band matrix by vertically
concatenating the `BandMatrixRow`s. Similarly, a `Field` of `BandMatrixRow`s on
an `ExtrudedFiniteDifferenceSpace` can be interpreted as a collection of band
matrices, one for each column of the `Field`. Such `Field`s are called
`ColumnwiseBandMatrixField`s, and this module adds the following functionality
for them:
- Constructors, e.g., `matrix_field = @. BidiagonalMatrixRow(field1, field2)`
- Linear combinations, e.g., `@. 3 * matrix_field1 + matrix_field2 / 3`
- Matrix-vector multiplication, e.g., `@. matrix_field ⋅ field`
- Matrix-matrix multiplication, e.g., `@. matrix_field1 ⋅ matrix_field2`
- Compatibility with `LinearAlgebra.I`, e.g., `@. matrix_field = (4I,)` or
`@. matrix_field - (4I,)`
- Integration with `RecursiveApply`, e.g., the entries of `matrix_field` can be
`Tuple`s or `NamedTuple`s instead of single values, which allows
`matrix_field` to represent multiple band matrices at the same time
- Integration with `Operators`, e.g., the `matrix_field` that represents the
action of any `FiniteDifferenceOperator` `op` on its argument can be obtained
using the `FiniteDifferenceOperator` `operator_matrix(op)`
- Conversions to native array types, e.g., `field2arrays(matrix_field)` can
convert each column of `matrix_field` into a `BandedMatrix` from
`BandedMatrices.jl`
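- Custom printing, e.g., a `matrix_field` with numeric entries gets displayed
as a `BandedMatrix`, specifically, as the `BandedMatrix` that corresponds to
its first column (matrix fields with non-numeric entries, such as `Tuple`s or
`AxisTensor`s, fall back to the default printing for `Field`s)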
This module also adds support for defining and manipulating sparse block
matrices of `Field`s. Specifically, it adds the `FieldMatrix` type, which is a
dictionary that maps pairs of `FieldName`s to `ColumnwiseBandMatrixField`s or
multiples of `LinearAlgebra.I`. This comes with the following functionality:
- Addition and subtraction, e.g., `@. field_matrix1 + field_matrix2`
- Matrix-vector multiplication, e.g., `@. field_matrix * field_vector`
- Matrix-matrix multiplication, e.g., `@. field_matrix1 * field_matrix2`
- Integration with `RecursiveApply`, e.g., the entries of `field_matrix` can be
specified either as matrix `Field`s of `Tuple`s or `NamedTuple`s, or as
separate matrix `Field`s of single values
- The ability to solve linear equations using `FieldMatrixSolver`, which is a
generalization of `ldiv!` that is designed to optimize solver performance
"""
module MatrixFields
import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec
import LinearAlgebra: inv, norm, ldiv!, mul!
import StaticArrays: SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import RecursiveArrayTools: recursive_bottom_eltype
import KrylovKit
import ClimaComms
import NVTX
import Adapt
using UnrolledUtilities
import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: ⊠, ⊞, ⊟
import ..DataLayouts: AbstractData
import ..DataLayouts: vindex
import ..Geometry
import ..Topologies
import ..Spaces
import ..Spaces: local_geometry_type
import ..Fields
import ..Operators
using ..Geometry:
rmul_with_projection,
mul_with_projection,
axis_tensor_type,
rmul_return_type
export DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow
export FieldVectorKeys, FieldMatrixKeys, FieldVectorView, FieldMatrix
export FieldMatrixWithSolver, ⋅
include("band_matrix_row.jl")
# Type alias for a `Field` whose values are `BandMatrixRow`s, i.e., a field that
# is interpreted as a band matrix (or, on an extruded space, as one band matrix
# per column).
#   - `V`: the field's data container; it must be an `AbstractData` whose
#     element type is some `BandMatrixRow`.
#   - `S`: the field's space; besides any `Spaces.AbstractSpace`, an
#     `Operators.PlaceholderSpace` is also accepted (see the inline note below).
const ColumnwiseBandMatrixField{V, S} = Fields.Field{
V,
S,
} where {
V <: AbstractData{<:BandMatrixRow},
S <: Union{
Spaces.AbstractSpace,
Operators.PlaceholderSpace, # so that this can exist inside cuda kernels
},
}
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")
include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
include("single_field_solver.jl")
include("multiple_field_solver.jl")
include("field_matrix_solver.jl")
include("field_matrix_iterative_solver.jl")
include("field_matrix_with_solver.jl")
# Print a `ColumnwiseBandMatrixField`.
#
# The output always begins with the field's element type followed by
# "-valued Field". What comes next depends on the matrix entries:
#   - Number entries: a header line naming the matrix shape (the type name of
#     `matrix_shape(field)`), then the band matrix of column `(1, 1, 1)` of the
#     field, printed via `Base.print_array` with `:compact` and `:limit` set,
#     inside `ClimaComms.allowscalar` for the field's device.
#   - Any other entries: a ":" and the generic compact `Field` printout.
function Base.show(io::IO, field::ColumnwiseBandMatrixField)
    print(io, eltype(field), "-valued Field")

    # Guard clause for non-number entries. Printing a BandedMatrix of such
    # entries is either unreadable (AxisTensor / AdjointAxisTensor entries) or
    # fails inside `isassigned` (Tuple / NamedTuple entries), so these fields
    # are shown with the default field printer instead.
    if !(eltype(eltype(field)) <: Number)
        print(io, ":")
        return Fields._show_compact_field(io, field, " ", true)
    end

    shape = typeof(matrix_shape(field)).name.name
    header = if field isa Fields.FiniteDifferenceField
        " that corresponds to the $shape matrix"
    else
        " whose first column corresponds to the $shape matrix"
    end
    println(io, header)

    first_column = Fields.column(field, 1, 1, 1)
    compact_io = IOContext(io, :compact => true, :limit => true)
    return ClimaComms.allowscalar(ClimaComms.device(field)) do
        Base.print_array(compact_io, column_field2array_view(first_column))
    end
end
end