diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index f247d8311..4d2caf9ef 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -11,12 +11,29 @@ const Enzyme = Reactant.Enzyme # we specify `mode` and `track_numbers` types due to ambiguity Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(::Type{Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) ) - A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers) - T = eltype(A_traced) - N = ndims(TT) - return Tensor{T,N,A_traced} + return Tensor +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(::Type{Tensor{T}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) +) where {T} + return Tensor{TracedRNumber{T}} +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(::Type{Tensor{T,N}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) +) where {T,N} + return Tensor{TracedRNumber{T,N}} +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(::Type{Tensor{T,N,A}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) +) where {T,N,A} + A_traced = Reactant.traced_type_inner(A, seen, mode, track_numbers) + T_traced = eltype(A_traced) + return Tensor{T_traced,N,A_traced} end # we specify `mode` and `track_numbers` types due to ambiguity