Skip to content

Commit

Permalink
Specify traced_type_inner manually for all Tensor union alls
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 3, 2025
1 parent 236816a commit 7381aae
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,29 @@ const Enzyme = Reactant.Enzyme

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
@nospecialize(::Type{Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
T = eltype(A_traced)
N = ndims(TT)
return Tensor{T,N,A_traced}
return Tensor
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(::Type{Tensor{T}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T}
return Tensor{TracedRNumber{T}}
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(::Type{Tensor{T,N}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T,N}
return Tensor{TracedRNumber{T,N}}
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(::Type{Tensor{T,N,A}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T,N,A}
A_traced = Reactant.traced_type_inner(A, seen, mode, track_numbers)
T_traced = eltype(A_traced)
return Tensor{T_traced,N,A_traced}
end

# we specify `mode` and `track_numbers` types due to ambiguity
Expand Down

0 comments on commit 7381aae

Please sign in to comment.