Skip to content

Commit

Permalink
Merge branch 'master' into test/mpo-canonization
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Jan 24, 2025
2 parents 80a59eb + b8b4aa2 commit 5768a36
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,25 @@ const stablehlo = MLIR.Dialects.stablehlo

const Enzyme = Reactant.Enzyme

function Reactant.make_tracer(
seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {RT<:Tensor}
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tensor}), seen, mode, @nospecialize(track_numbers)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(T), seen, mode, track_numbers)
return Tensor{eltype(T),ndims(T),A_traced}
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), seen, mode, @nospecialize(track_numbers)
)
return T
end

function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor}
tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...)
return Tensor(tracedata, inds(prev))
end

function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::TensorNetwork, @nospecialize(path), mode; kwargs...)
tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev))
for (i, tensor) in enumerate(tensors(prev))
tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...)
Expand All @@ -26,24 +37,24 @@ end

Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i]

function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Quantum, @nospecialize(path), mode; kwargs...)
tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Quantum(tracetn, copy(prev.sites))
end

function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Ansatz, @nospecialize(path), mode; kwargs...)
tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Ansatz(tracetn, copy(Tenet.lattice(prev)))
end

# TODO try rely on generic fallback for ansatzes
function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Tenet.Product, @nospecialize(path), mode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Tenet.Product(tracetn)
end

function Reactant.make_tracer(
seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs...
seen, @nospecialize(prev::A), @nospecialize(path), mode; kwargs...
) where {A<:Tenet.AbstractMPO}
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return A(tracetn, copy(form(prev)))
Expand Down Expand Up @@ -122,7 +133,9 @@ function Reactant.set_act!(inp::Enzyme.Annotation{TensorNetwork}, path, reverse,
end

function Tenet.contract(
a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}, b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}}; kwargs...
@nospecialize(a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}),
@nospecialize(b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}});
kwargs...,
) where {Ta,Na,Tb,Nb}
dims = get(kwargs, :dims) do
(inds(a), inds(b))
Expand Down Expand Up @@ -173,15 +186,19 @@ function Tenet.contract(
end

function Tenet.contract(
a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing
@nospecialize(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}); dims=nonunique(inds(a)), out=nothing
) where {T,N}
error("compilation of unary einsum operations are not yet supported")
end

function Tenet.contract(a::Tensor, b::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; kwargs...) where {T,N}
function Tenet.contract(
@nospecialize(a::Tensor), @nospecialize(b::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}); kwargs...
) where {T,N}
contract(b, a; kwargs...)
end
function Tenet.contract(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}, b::Tensor; kwargs...) where {T,N}
function Tenet.contract(
@nospecialize(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}), @nospecialize(b::Tensor); kwargs...
) where {T,N}
return contract(a, Tensor(Reactant.Ops.constant(parent(b)), inds(b)); kwargs...)
end

Expand Down

0 comments on commit 5768a36

Please sign in to comment.