Skip to content

Commit

Permalink
Merge branch 'master' into test/mpo-canonization
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Jan 24, 2025
2 parents 44fcc36 + 08fc199 commit a18e2b5
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@ const stablehlo = MLIR.Dialects.stablehlo

const Enzyme = Reactant.Enzyme

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
T = eltype(A_traced)
N = ndims(TT)
return Tensor{T,N,A_traced}
end

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return T
@static if isdefined(Reactant, :traced_type_inner)
# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
T = eltype(A_traced)
N = ndims(TT)
return Tensor{T,N,A_traced}
end

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return T
end
end

function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor}
Expand Down

0 comments on commit a18e2b5

Please sign in to comment.