Skip to content

Commit

Permalink
Merge pull request #17 from dingraha/cs_safe
Browse files Browse the repository at this point in the history
Add complex-step safe versions of a few functions
  • Loading branch information
andrewning authored Jun 22, 2022
2 parents 9dba266 + 966a962 commit 319b287
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.3']
julia-version: ['1.6']
julia-arch: [x64]
os: [ubuntu-latest, windows-latest, macOS-latest]

Expand All @@ -21,4 +21,4 @@ jobs:
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
- uses: julia-actions/julia-runtest@master
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "FLOWMath"
uuid = "6cb5d3fb-0fe8-4cc2-bd89-9fe0b19a99d3"
authors = ["Andrew Ning <aning@byu.edu>"]
version = "0.3.2"
version = "0.3.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

[compat]
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ Smoothing
- sigmoid blending
- cubic/quintic polynomial blending

[Complex step safe](https://doi.org/10.1145/838250.838251) versions of
- `abs`: `abs_cs_safe`
- `abs2`: `abs2_cs_safe`
- `norm`: `norm_cs_safe`
- `dot`: `dot_cs_safe`
- `atan` (two argument form): `atan_cs_safe`

### Install

```julia
Expand Down
22 changes: 22 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,25 @@ savefig("cubic.svg"); nothing # hide
cubic_blend
quintic_blend
```

### Complex-step safe functions
The [complex-step derivative approximation](https://doi.org/10.1145/838250.838251) can be used to easily and accurately approximate first derivatives.
However, the function `f` one wishes to differentiate must be composed of functions that are compatible with the method.
Most elementary functions are, but a few common ones are not:

* `abs`
* `abs2`
* `norm`
* `dot`
* the two argument form of `atan` (often called `atan2` or `arctan2` in other languages)

FLOWMath provides complex-step safe versions of these functions.
These functions use Julia's multiple dispatch to fall back on the standard implementations when given real arguments, and so shouldn't impose any performance penalty when not used with the complex step method.

```@docs
abs_cs_safe
abs2_cs_safe
norm_cs_safe
dot_cs_safe
atan_cs_safe
```
5 changes: 4 additions & 1 deletion src/FLOWMath.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
module FLOWMath

include("cs_safe.jl")
export abs_cs_safe
export abs_cs_safe, abs2_cs_safe
export norm_cs_safe
export dot_cs_safe
export atan_cs_safe

include("quadrature.jl")
export trapz
Expand Down
102 changes: 100 additions & 2 deletions src/cs_safe.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,105 @@
function abs_cs_safe(x::T) where {T<:Complex}
using LinearAlgebra: norm, dot


"""
abs_cs_safe(x)
Calculate the absolute value of `x` in a manner compatible with the complex-step derivative approximation.
See also: [`abs`](@ref).
"""
abs_cs_safe

@inline function abs_cs_safe(x::T) where {T<:Complex}
return x*sign(real(x))
end

function abs_cs_safe(x)
@inline function abs_cs_safe(x)
return abs(x)
end

"""
abs2_cs_safe(x)
Calculate the squared absolute value of `x` in a manner compatible with the complex-step derivative approximation.
See also: [`abs2`](@ref).
"""
abs2_cs_safe

@inline function abs2_cs_safe(x::T) where {T<:Complex}
return abs_cs_safe(x)^2
end

@inline function abs2_cs_safe(x)
return abs2(x)
end

"""
norm_cs_safe(x, p)
Calculate the `p`-norm value of iterable `x` in a manner compatible with the complex-step derivative approximation.
See also: [`norm`](@ref).
"""
norm_cs_safe

@inline function norm_cs_safe(x::AbstractArray{T}, p::Real=2) where {T<:Complex}
return sum(x.^p)^(1/p)
end

@inline function norm_cs_safe(x, p::Real=2)
return norm(x, p)
end

"""
dot_cs_safe(a, b)
Calculate the dot product of vectors `a` and `b` in a manner compatible with the complex-step derivative approximation.
See also: [`norm`](@ref).
"""
dot_cs_safe

@inline function dot_cs_safe(a::AbstractVector{T}, b) where {T<:Complex}
# `dot` conjugates its first argument, so we only need to worry about the case where the first argument is complex.
# return sum(a.*b)
return dot(conj.(a), b)
end

@inline function dot_cs_safe(a, b)
return dot(a, b)
end


"""
atan_cs_safe(y, x)
Calculate the two-argument arctangent function in a manner compatible with the complex-step derivative approximation.
See also: [`atan`](@ref).
"""
atan_cs_safe

@inline function atan_cs_safe(y, x)
return atan_cs_safe(promote(y, x)...)
end

@inline function atan_cs_safe(y::T, x::T) where {T<:Complex}
# Stolen from openmdao/utils/cs_safe.py
# a = np.real(y)
# b = np.imag(y)
# c = np.real(x)
# d = np.imag(x)
# return np.arctan2(a, c) + 1j * (c * b - a * d) / (a**2 + c**2)
a = real(y)
b = imag(y)
c = real(x)
d = imag(x)
return complex(atan(a, c), (c * b - a * d) / (a^2 + c^2))
end

@inline function atan_cs_safe(y::T, x::T) where {T}
return atan(y, x)
end

3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
FiniteDiff = "2.13"
84 changes: 83 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,92 @@ using FLOWMath
using Test
import ForwardDiff
import FiniteDiff
using LinearAlgebra: diag
using LinearAlgebra: diag, norm, dot

@testset "FLOWMath.jl" begin

# ------ complex-step safe --------
#
# abs_cs_safe
f(x) = 2*cos(abs(x)) + 3*sin(x)
f_cs_safe(x) = 2*cos(abs_cs_safe(x)) + 3*sin(x)
# Check for positive and negative arguments to abs_cs_safe.
for x0 in [2.5, -2.5]
@test f_cs_safe(x0) f(x0)
dfdx_fd = ForwardDiff.derivative(f_cs_safe, x0)
dfdx_cs = FiniteDiff.finite_difference_derivative(f_cs_safe, x0, Val{:complex})
dfdx_not_cs_safe = FiniteDiff.finite_difference_derivative(f, x0, Val{:complex})
@test dfdx_cs dfdx_fd
@test !(dfdx_not_cs_safe dfdx_fd)
end

# abs2_cs_safe
f(x) = 2*cos(abs2(x)) + 3*sin(x)
f_cs_safe(x) = 2*cos(abs2_cs_safe(x)) + 3*sin(x)
for x0 in [2.5, -2.5]
@test f_cs_safe(x0) f(x0)
dfdx_fd = ForwardDiff.derivative(f_cs_safe, x0)
dfdx_cs = FiniteDiff.finite_difference_derivative(f_cs_safe, x0, Val{:complex})
dfdx_not_cs_safe = FiniteDiff.finite_difference_derivative(f, x0, Val{:complex})
@test dfdx_cs dfdx_fd
@test !(dfdx_not_cs_safe dfdx_fd)
end

# norm_cs_safe
f(x, p) = norm(2 .* cos.(x) .+ 3 .* sin.(x), p)
f_cs_safe(x, p) = norm_cs_safe(2 .* cos.(x) .+ 3 .* sin.(x), p)
x0 = rand(3, 4)
for p in [1, 2, 3]
fp(x) = f(x, p)
fp_cs_safe(x) = f_cs_safe(x, p)
@test fp_cs_safe(x0) fp(x0)
dfdx_fd = ForwardDiff.gradient(fp_cs_safe, x0)
dfdx_cs = FiniteDiff.finite_difference_gradient(fp_cs_safe, x0, Val{:complex})
dfdx_not_cs_safe = FiniteDiff.finite_difference_gradient(fp, x0, Val{:complex})
@test dfdx_cs dfdx_fd
@test !(dfdx_not_cs_safe dfdx_fd)
end

# dot_cs_safe
f(x) = 3*dot(x, sin.(x))^2
f_cs_safe(x) = 3*dot_cs_safe(x, sin.(x))^2
x0 = rand(4)
@test f_cs_safe(x0) f(x0)
dfdx_fd = ForwardDiff.gradient(f_cs_safe, x0)
dfdx_cs = FiniteDiff.finite_difference_gradient(f_cs_safe, x0, Val{:complex})
dfdx_not_cs_safe = FiniteDiff.finite_difference_gradient(f, x0, Val{:complex})
@test dfdx_cs dfdx_fd
@test !(dfdx_not_cs_safe dfdx_fd)

# Test that we don't need to modify the second argument to `dot`:
y = rand(4)
f(x) = 3*dot(y, x)^2
f_cs_safe(x) = 3*dot_cs_safe(y, x)^2
x0 = rand(4)
@test f_cs_safe(x0) f(x0)
dfdx_fd = ForwardDiff.gradient(f_cs_safe, x0)
dfdx_cs = FiniteDiff.finite_difference_gradient(f_cs_safe, x0, Val{:complex})
dfdx_not_cs_safe = FiniteDiff.finite_difference_gradient(f, x0, Val{:complex})
@test dfdx_cs dfdx_fd
# Using a real input to the first argument of dot is actually complex-step safe, so these should be the same.
@test dfdx_not_cs_safe dfdx_fd

# two-argument atan_cs_safe
# Need to test all four quadrants for atan.
f(sign1,sign2) = x->3*atan(sign1*(x+2), sign2*x)^2
f_cs_safe(sign1,sign2) = x->3*atan_cs_safe(sign1*(x+2), sign2*x)^2
for sign1 in [-1, 1]
for sign2 in [-1, 1]
fs1s2 = f(sign1, sign2)
fs1s2_cs_safe = f_cs_safe(sign1, sign2)
x0 = rand()
@test fs1s2_cs_safe(x0) fs1s2(x0)
dfdx_fd = ForwardDiff.derivative(fs1s2, x0)
dfdx_cs = FiniteDiff.finite_difference_derivative(fs1s2_cs_safe, x0, Val{:complex})
@test dfdx_cs dfdx_fd
# Can't check that the non-complex-step-safe version of atan doesn't work since it's not implemented for two complex arguments.
end
end

# ------ trapz --------
# tests from matlab trapz docs
Expand Down

2 comments on commit 319b287

@andrewning
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62899

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.3 -m "<description of version>" 319b2870740a78ca062bf58dcb36d98faae48ad3
git push origin v0.3.3

Please sign in to comment.