From 393f7272fbe18c8d80de81b6514ef37758dc0dfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 24 Jan 2025 07:09:37 +0100 Subject: [PATCH] Fix `Reactant.traced_type_inner` implementation --- ext/TenetReactantExt.jl | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index a3282ad4e..f247d8311 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -9,15 +9,22 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme +# we specify `mode` and `track_numbers` types due to ambiguity Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(T::Type{<:Tensor}), seen, mode, @nospecialize(track_numbers) + @nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) ) - A_traced = Reactant.traced_type_inner(Tenet.parenttype(T), seen, mode, track_numbers) - return Tensor{eltype(T),ndims(T),A_traced} + A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers) + T = eltype(A_traced) + N = ndims(TT) + return Tensor{T,N,A_traced} end +# we specify `mode` and `track_numbers` types due to ambiguity Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), seen, mode, @nospecialize(track_numbers) + @nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type) ) return T end @@ -132,11 +139,15 @@ function Reactant.set_act!(inp::Enzyme.Annotation{TensorNetwork}, path, reverse, end end -function Tenet.contract( +Base.@nospecializeinfer @noinline function Tenet.contract( @nospecialize(a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}), @nospecialize(b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}}); kwargs..., ) where {Ta,Na,Tb,Nb} + # return Base.inferencebarrier(__contract_binary)(a, b; kwargs...) + # end + + # function __contract_binary(@nospecialize(a), @nospecialize(b); kwargs...) dims = get(kwargs, :dims) do ∩(inds(a), inds(b)) end @@ -171,8 +182,8 @@ function Tenet.contract( # TODO replace for `Ops.convert`/`adapt` when it's available (there can be problems with nested array structures) T = Base.promote_eltype(a, b) - da = eltype(a) != T ? TracedRArray{T,Na}(parent(a)) : parent(a) - db = eltype(b) != T ? TracedRArray{T,Nb}(parent(b)) : parent(b) + da = eltype(a) != T ? TracedRArray{T,ndims(a)}(parent(a)) : parent(a) + db = eltype(b) != T ? TracedRArray{T,ndims(b)}(parent(b)) : parent(b) data = Reactant.Ops.dot_general(da, db; contracting_dimensions, batching_dimensions)