From 2b29fbd27427db387e26775a36ebb66ce48977a8 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Tue, 9 Apr 2024 14:56:11 +0200 Subject: [PATCH] Add more operators (#6) --- src/conversion.jl | 4 ++ src/operators.jl | 114 ++++++++++++++++++++++++++++++---------------- 2 files changed, 80 insertions(+), 38 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 0cb3ed37..eba952bb 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -2,6 +2,10 @@ Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer +Base.big(::Type{Tracer}) = Tracer +Base.widen(::Type{Tracer}) = Tracer +Base.widen(t::Tracer) = t + Base.convert(::Type{Tracer}, x::Number) = tracer() Base.convert(::Type{Tracer}, t::Tracer) = t Base.convert(::Type{<:Number}, t::Tracer) = t diff --git a/src/operators.jl b/src/operators.jl index d541a5a4..b05a872c 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1,53 +1,91 @@ -## Extent Base operators -for fn in (:+, :-, :*, :/) +## Operator definitions + +#! format: off +ops_2_to_1 = ( + :+, :-, :*, :/, + # division + :div, :fld, :cld, + # modulo + :mod, :rem, + # exponentials + :ldexp, + # sign + :copysign, :flipsign, + # other + :hypot, +) + +ops_1_to_1 = ( + # trigonometric functions + :deg2rad, :rad2deg, + :cos, :cosd, :cosh, :cospi, :cosc, + :sin, :sind, :sinh, :sinpi, :sinc, + :tan, :tand, :tanh, + # reciprocal trigonometric functions + :csc, :cscd, :csch, + :sec, :secd, :sech, + :cot, :cotd, :coth, + # inverse trigonometric functions + :acos, :acosd, :acosh, + :asin, :asind, :asinh, + :atan, :atand, :atanh, + :asec, :asech, + :acsc, :acsch, + :acot, :acoth, + # exponentials + :exp, :exp2, :exp10, :expm1, + :log, :log2, :log10, :log1p, + :abs, :abs2, + # roots + :sqrt, :cbrt, + # absolute values + :abs, :abs2, + # rounding + :floor, :ceil, :trunc, + # other + :inv, :signbit, :hypot, :sign, :mod2pi +) + +ops_1_to_2 = ( + # trigonometric + :sincos, + :sincosd, + :sincospi, + # exponentials + :frexp, +) +#! format: on + +for fn in ops_1_to_1 + @eval Base.$fn(t::Tracer) = t +end + +for fn in ops_1_to_2 + @eval Base.$fn(t::Tracer) = (t, t) +end + +for fn in ops_2_to_1 @eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) - for T in (:Number,) - @eval Base.$fn(t::Tracer, ::$T) = t - @eval Base.$fn(::$T, t::Tracer) = t - end + @eval Base.$fn(t::Tracer, ::Number) = t + @eval Base.$fn(::Number, t::Tracer) = t end +# Extra types required for exponent Base.:^(a::Tracer, b::Tracer) = tracer(a, b) -for T in (:Number, :Integer, :Rational) +for T in (:Real, :Integer, :Rational) @eval Base.:^(t::Tracer, ::$T) = t @eval Base.:^(::$T, t::Tracer) = t end Base.:^(t::Tracer, ::Irrational{:ℯ}) = t Base.:^(::Irrational{:ℯ}, t::Tracer) = t -## Two-argument functions -for fn in (:div, :fld, :cld) - @eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) - @eval Base.$fn(t::Tracer, ::Number) = t - @eval Base.$fn(::Number, t::Tracer) = t +## Precision operators create empty Tracer +for fn in (:eps, :nextfloat, :floatmin, :floatmax, :maxintfloat, :typemax) + @eval Base.$fn(::Tracer) = tracer() end -## Single-argument functions - -#! format: off -scalar_operations = ( - :exp2, :deg2rad, :rad2deg, - :cos, :cosd, :cosh, :cospi, :cosc, - :sin, :sind, :sinh, :sinpi, :sinc, - :tan, :tand, :tanh, - :csc, :cscd, :csch, - :sec, :secd, :sech, - :cot, :cotd, :coth, - :acos, :acosd, :acosh, - :asin, :asind, :asinh, - :atan, :atand, :atanh, - :asec, :asech, - :acsc, :acsch, - :acot, :acoth, - :exp, :expm1, :exp10, - :frexp, :ldexp, - :abs, :abs2, :sqrt -) -#! format: on - -for fn in scalar_operations - @eval Base.$fn(t::Tracer) = t -end +## Rounding +Base.round(t::Tracer, ::RoundingMode; kwargs...) = t ## Random numbers rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer()