From b8b4aa2d86949727ad577fd1895e9ffbc137caf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 23 Jan 2025 17:54:19 +0100 Subject: [PATCH] Speedup Tenet+Reactant compile time by avoiding over-specialization --- ext/TenetReactantExt.jl | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index e1c073d9b..a3282ad4e 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -9,14 +9,25 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme -function Reactant.make_tracer( - seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... -) where {RT<:Tensor} +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(T::Type{<:Tensor}), seen, mode, @nospecialize(track_numbers) +) + A_traced = Reactant.traced_type_inner(Tenet.parenttype(T), seen, mode, track_numbers) + return Tensor{eltype(T),ndims(T),A_traced} +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), seen, mode, @nospecialize(track_numbers) +) + return T +end + +function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor} tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...) return Tensor(tracedata, inds(prev)) end -function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::TensorNetwork, @nospecialize(path), mode; kwargs...) tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev)) for (i, tensor) in enumerate(tensors(prev)) tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...) @@ -26,24 +37,24 @@ end Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i] -function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Quantum, @nospecialize(path), mode; kwargs...) tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Quantum(tracetn, copy(prev.sites)) end -function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Ansatz, @nospecialize(path), mode; kwargs...) tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Ansatz(tracetn, copy(Tenet.lattice(prev))) end # TODO try rely on generic fallback for ansatzes -function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Tenet.Product, @nospecialize(path), mode; kwargs...) tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Tenet.Product(tracetn) end function Reactant.make_tracer( - seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs... + seen, @nospecialize(prev::A), @nospecialize(path), mode; kwargs... ) where {A<:Tenet.AbstractMPO} tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return A(tracetn, copy(form(prev))) @@ -122,7 +133,9 @@ function Reactant.set_act!(inp::Enzyme.Annotation{TensorNetwork}, path, reverse, end function Tenet.contract( - a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}, b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}}; kwargs... + @nospecialize(a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}), + @nospecialize(b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}}); + kwargs..., ) where {Ta,Na,Tb,Nb} dims = get(kwargs, :dims) do ∩(inds(a), inds(b)) @@ -173,15 +186,19 @@ function Tenet.contract( end function Tenet.contract( - a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing + @nospecialize(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}); dims=nonunique(inds(a)), out=nothing ) where {T,N} error("compilation of unary einsum operations are not yet supported") end -function Tenet.contract(a::Tensor, b::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; kwargs...) where {T,N} +function Tenet.contract( + @nospecialize(a::Tensor), @nospecialize(b::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}); kwargs... +) where {T,N} contract(b, a; kwargs...) end -function Tenet.contract(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}, b::Tensor; kwargs...) where {T,N} +function Tenet.contract( + @nospecialize(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}), @nospecialize(b::Tensor); kwargs... +) where {T,N} return contract(a, Tensor(Reactant.Ops.constant(parent(b)), inds(b)); kwargs...) end