Skip to content

Commit

Permalink
Try fix parenttype on unionall of Tensor (#309)
Browse files Browse the repository at this point in the history
* Try fix `parenttype` on unionall of `Tensor`

* Bump Reactant compat constraint to v0.2.22

* Specify `traced_type_inner` manually for all `Tensor` union alls

* Fix syntax

* Remove `@invoke` calls in Quantum to fix infinite recursion in Reactant

* Stop Reactant overlay of `replace!(::AbstractQuantum)`

* Format code
  • Loading branch information
mofeing authored Feb 3, 2025
1 parent 7421ef1 commit c96ca18
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ OMEinsum = "0.7, 0.8"
PythonCall = "0.9"
Quac = "0.3"
Random = "1.10"
Reactant = "0.2.18"
Reactant = "0.2.22"
ScopedValues = "1"
Serialization = "1.10"
SparseArrays = "1.10"
Expand Down
60 changes: 34 additions & 26 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,41 @@ const stablehlo = MLIR.Dialects.stablehlo

const Enzyme = Reactant.Enzyme

@static if isdefined(Reactant, :traced_type_inner)
# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
T = eltype(A_traced)
N = ndims(TT)
return Tensor{T,N,A_traced}
end
# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(_::Type{Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
return Tensor
end

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return T
end
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(_::Type{Tensor{T}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T}
return Tensor{TracedRNumber{T}}
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(_::Type{Tensor{T,N}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T,N}
return Tensor{TracedRNumber{T,N}}
end

Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(_::Type{Tensor{T,N,A}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
) where {T,N,A}
A_traced = Reactant.traced_type_inner(A, seen, mode, track_numbers)
T_traced = eltype(A_traced)
return Tensor{T_traced,N,A_traced}
end

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return T
end

function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor}
Expand Down Expand Up @@ -221,11 +236,4 @@ function Base.conj(@nospecialize(x::Tensor{TracedRNumber{T},N,<:TracedRArray}))
Tensor(conj(parent(x)), inds(x))
end

# fix infinite recursion on Reactant rewrite of invoke/call step
@reactant_overlay @noinline function Base.replace!(
tn::Tenet.AbstractQuantum, old_new::Base.AbstractVecOrTuple{Pair{Symbol,Symbol}}
)
Base.inferencebarrier(Base.replace!)(tn, old_new)
end

end
11 changes: 5 additions & 6 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ Base.similar(tn::Quantum) = Quantum(similar(TensorNetwork(tn)), copy(tn.sites))
Base.zero(tn::Quantum) = Quantum(zero(TensorNetwork(tn)), copy(tn.sites))

function Base.:(==)(a::AbstractQuantum, b::AbstractQuantum)
return Quantum(a).sites == Quantum(b).sites && @invoke ==(a::AbstractTensorNetwork, b::AbstractTensorNetwork)
return Quantum(a).sites == Quantum(b).sites && ==(TensorNetwork(a), TensorNetwork(b))
end
function Base.isapprox(a::AbstractQuantum, b::AbstractQuantum; kwargs...)
return Quantum(a).sites == Quantum(b).sites &&
@invoke isapprox(a::AbstractTensorNetwork, b::AbstractTensorNetwork; kwargs...)
return Quantum(a).sites == Quantum(b).sites && isapprox(TensorNetwork(a), TensorNetwork(b); kwargs...)
end

Base.summary(io::IO, tn::AbstractQuantum) = print(io, "$(ntensors(tn))-tensors Quantum")
Expand Down Expand Up @@ -119,7 +118,7 @@ end

# `pop!` / `delete!` methods call this method
function Base.pop!(tn::AbstractQuantum, tensor::Tensor)
@invoke pop!(tn::AbstractTensorNetwork, tensor)
pop!(TensorNetwork(tn), tensor)

# TODO replace with `inds(tn; set=:physical)` when implemented
targets = values(Quantum(tn).sites) inds(tensor)
Expand All @@ -134,7 +133,7 @@ function Base.replace!(tn::AbstractQuantum, old_new::Pair{Symbol,Symbol})
tn = Quantum(tn)

# replace indices in underlying Tensor Network
@invoke replace!(tn::AbstractTensorNetwork, old_new)
replace!(TensorNetwork(tn), old_new)

# replace indices in site information
site = sites(tn; at=first(old_new))
Expand All @@ -151,7 +150,7 @@ function Base.replace!(tn::AbstractQuantum, old_new::Base.AbstractVecOrTuple{Pai
tn = Quantum(tn)

# replace indices in underlying Tensor Network
@invoke replace!(tn::AbstractTensorNetwork, old_new)
replace!(TensorNetwork(tn), old_new)

# replace indices in site information
from, to = first.(old_new), last.(old_new)
Expand Down
3 changes: 3 additions & 0 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ Return the underlying array of the tensor.
"""
Base.parent(t::Tensor) = t.data
parenttype(::Type{Tensor{T,N,A}}) where {T,N,A} = A
parenttype(::Type{Tensor{T,N}}) where {T,N} = AbstractArray{T,N}
parenttype(::Type{Tensor{T}}) where {T} = AbstractArray{T}
parenttype(::Type{Tensor}) = AbstractArray
parenttype(::T) where {T<:Tensor} = parenttype(T)

"""
Expand Down

0 comments on commit c96ca18

Please sign in to comment.