From 08fc199e0a1b8c1c76977bbaa123dfbdcb96d663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 24 Jan 2025 07:39:13 +0100 Subject: [PATCH] Fix `Reactant.traced_type_inner` not defined --- ext/TenetReactantExt.jl | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index f247d8311..3d433c3b5 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -9,24 +9,26 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme -# we specify `mode` and `track_numbers` types due to ambiguity -Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) -) - A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers) - T = eltype(A_traced) - N = ndims(TT) - return Tensor{T,N,A_traced} -end - -# we specify `mode` and `track_numbers` types due to ambiguity -Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), - seen, - mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type) -) - return T +@static if isdefined(Reactant, :traced_type_inner) + # we specify `mode` and `track_numbers` types due to ambiguity + Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) + ) + A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers) + T = eltype(A_traced) + N = ndims(TT) + return Tensor{T,N,A_traced} + end + + # we specify `mode` and `track_numbers` types due to ambiguity + Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) + ) + return T + end end function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor}