Skip to content

Commit

Permalink
Enhance simple_update! for MPS in the Canonical form (#255)
Browse files Browse the repository at this point in the history
* First round of fixes on simple_update

* Fix tests

* Add simple_update_2site! for MixedCanonical form

* Format code

* Renormalize mps in truncation when recanonize kwarg is true

* Enhance tests

* Change default recanonize kwarg to false in truncate! function

* Refactor normalize functions

* Enhance normalize tests

* Define LinearAlgebra.normalize for AbstractQuantum

* Fix normalize functions for MPS

* Enhance tests

* Fix normalize for Canonical MPS

* Format code

* Update normalization step on evolve

* Change normalization to all lambdas for Canonical form

* Format code

* Fix truncate by adding renormalize kwarg

* Small enhancements on normalize! functions

* Enhance tests

* Change default kwargs in truncate

* Fix evolve kwargs

* Fix normalize! by putting replace! instead of inplace modification for NonCanonical

* Enhance tests

* Fix aesthetic suggestions, improve kwarg definition

* Update comment
  • Loading branch information
jofrevalles authored Nov 22, 2024
1 parent 4b8c356 commit 4247207
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 47 deletions.
104 changes: 78 additions & 26 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ mixed_canonize(tn::AbstractAnsatz, args...; kwargs...) = mixed_canonize!(deepcop

canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)

"""
normalize!(ψ::AbstractAnsatz, at)
Normalize the state at a given [`Site`](@ref) or bond in a [`AbstractAnsatz`](@ref) Tensor Network.
"""
LinearAlgebra.normalize::AbstractAnsatz, site) = normalize!(copy(ψ), site)

"""
isisometry(tn::AbstractAnsatz, site; dir, kwargs...)
Expand Down Expand Up @@ -274,8 +281,9 @@ Truncate the dimension of the virtual `bond` of a [`NonCanonical`](@ref) Tensor
- `threshold`: The threshold to truncate the bond dimension.
- `maxdim`: The maximum bond dimension to keep.
- `compute_local_svd`: Whether to compute the local SVD of the bond. If `true`, it will contract the bond and perform a SVD to get the local singular values. Defaults to `true`.
- `normalize`: Whether to normalize the state at the bond after truncation. Defaults to `false`.
"""
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true)
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true, normalize=false)
virtualind = inds(tn; bond)

if compute_local_svd
Expand Down Expand Up @@ -305,26 +313,31 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
end

slice!(tn, virtualind, extent)
sliced_bond = tensors(tn; bond)

# Note: Inplace normalization of the inner arrays may be more efficient
normalize && replace!(tn, sliced_bond => sliced_bond ./ norm(tn))

return tn
end

function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim)
function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, normalize=false)
# move orthogonality center to bond
mixed_canonize!(tn, bond)
return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true)

return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true, normalize)
end

"""
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=true)
Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`.
**Schmidt coefficients** or those larger than `threshold`, and then canonizes the Tensor Network if `canonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, canonize=false, normalize=false)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false, normalize)

recanonize && canonize!(tn)
canonize && canonize!(tn)

return tn
end
Expand Down Expand Up @@ -354,7 +367,7 @@ function expect(ψ::AbstractAnsatz, observables::AbstractVecOrTuple; bra=copy(ψ
end

"""
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, renormalize = false)
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, normalize = false)
Evolve (through time) a [`AbstractAnsatz`](@ref) Tensor Network with a `gate` operator.
Expand All @@ -367,16 +380,16 @@ Evolve (through time) a [`AbstractAnsatz`](@ref) Tensor Network with a `gate` op
- `threshold`: The threshold to truncate the bond dimension.
- `maxdim`: The maximum bond dimension to keep.
- `renormalize`: Whether to renormalize the state after truncation.
- `normalize`: Whether to normalize the state after truncation.
# Notes
- The gate must act on neighboring sites according to the [`Lattice`](@ref) of the Tensor Network.
- The gate must have the same number of inputs and outputs.
- Currently only the "Simple Update" algorithm is used and the gate must be a 1-site or 2-site operator.
"""
function evolve!::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
return simple_update!(ψ, gate; threshold, maxdim, renormalize)
function evolve!::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false, kwargs...)
return simple_update!(ψ, gate; threshold, maxdim, normalize, kwargs...)
end

# by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!`
Expand All @@ -387,11 +400,11 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth

if nlanes(gate) == 1
return simple_update_1site!(ψ, gate)
elseif nlanes(gate) == 2
return simple_update_2site!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
else
throw(ArgumentError("Only 1-site and 2-site gates are currently supported"))
end

@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
end

# TODO a lot of problems with merging... maybe we shouldn't merge manually
Expand Down Expand Up @@ -419,9 +432,15 @@ function simple_update_1site!(ψ::AbstractAnsatz, gate)
return contract!(ψ, contracting_index)
end

# TODO remove `renormalize` argument?
function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false)
@assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
function simple_update_2site!(
::MixedCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false
)
return simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, normalize)
end

function simple_update_2site!(
::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, normalize=false
)
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
Expand Down Expand Up @@ -455,16 +474,49 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth

# truncate virtual index
if any(!isnothing, (threshold, maxdim))
truncate!(ψ, bond; threshold, maxdim)
renormalize && normalize!(ψ, bond[1])
truncate!(ψ, collect(bond); threshold, maxdim, normalize)
end

return ψ
end

# TODO remove `renormalize` argument?
# TODO optimize correctly -> avoid recanonization + use lateral Λs
function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
return canonize!(ψ)
# TODO remove `normalize` argument?
function simple_update_2site!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, normalize=false, canonize=true)
# Contract the exterior Λ tensors
sitel, siter = extrema(lanes(gate))
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)

simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, normalize=false)

# contract the updated tensors with the inverse of Λᵢ and Λᵢ₊₂, to get the new Γ tensors
U, Vt = tensors(ψ; at=sitel), tensors(ψ; at=siter)
Γᵢ₋₁ = if isnothing(Λᵢ₋₁)
U
else
contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=wrap_eps(eltype(U)))), inds(Λᵢ₋₁)); dims=())
end
Γᵢ = if isnothing(Λᵢ₊₁)
Vt
else
contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=wrap_eps(eltype(Vt)))), inds(Λᵢ₊₁)), Vt; dims=())
end

# Update the tensors in the tensor network
replace!(ψ, tensors(ψ; at=sitel) => Γᵢ₋₁)
replace!(ψ, tensors(ψ; at=siter) => Γᵢ)

if canonize
canonize!(ψ; normalize)
else
normalize && normalize!(ψ, collect((sitel, siter)))
end

return ψ
end
31 changes: 25 additions & 6 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr)
return ψ
end

function canonize!::AbstractMPO)
function canonize!::AbstractMPO; normalize=false)
Λ = Tensor[]

# right-to-left QR sweep, get right-canonical tensors
Expand All @@ -495,6 +495,7 @@ function canonize!(ψ::AbstractMPO)

# extract the singular values and contract them with the next tensor
Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1))))
normalize && (Λᵢ ./= norm(Λᵢ))
Aᵢ₊₁ = tensors(ψ; at=Site(i + 1))
replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=()))
push!(Λ, Λᵢ)
Expand Down Expand Up @@ -541,19 +542,37 @@ function mixed_canonize!(tn::AbstractMPO, orthog_center)
end

LinearAlgebra.normalize!::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...)
LinearAlgebra.normalize!::AbstractMPO, at::Site) = normalize!(form(ψ), ψ; at)
LinearAlgebra.normalize!::AbstractMPO, bond::Base.AbstractVecOrTuple{Site}) = normalize!(form(ψ), ψ; bond)

# NOTE: Inplace normalization of the arrays should be faster, but currently lead to problems for `copy` TensorNetworks
function LinearAlgebra.normalize!(::NonCanonical, ψ::AbstractMPO; at=Site(nsites(ψ) ÷ 2))
tensor = tensors(ψ; at)
tensor ./= norm(ψ)
if at isa Site
tensor = tensors(ψ; at)
replace!(ψ, tensor => tensor ./ norm(ψ))
else
normalize!(mixed_canonize!(ψ, at))
end

return ψ
end

LinearAlgebra.normalize!::AbstractMPO, site::Site) = normalize!(mixed_canonize!(ψ, site); at=site)

function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=config.orthog_center)
mixed_canonize!(ψ, at)
normalize!(tensors(ψ; at), 2)
return ψ
end

# TODO function LinearAlgebra.normalize!(::Canonical, ψ::AbstractMPO) end
function LinearAlgebra.normalize!(config::Canonical, ψ::AbstractMPO; bond=nothing)
if isnothing(bond) # Normalize all λ tensors
for i in 1:(nsites(ψ) - 1)
λ = tensors(ψ; between=(Site(i), Site(i + 1)))
replace!(ψ, λ => λ ./ norm(λ)^(1 / (nsites(ψ) - 1)))
end
else
λ = tensors(ψ; between=bond)
replace!(ψ, λ => λ ./ norm(λ))
end

return ψ
end
2 changes: 2 additions & 0 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ function Base.merge!(a::AbstractQuantum, b::AbstractQuantum; reset=true)
return a
end

LinearAlgebra.normalize::AbstractQuantum; kwargs...) = normalize!(copy(ψ); kwargs...)

function LinearAlgebra.norm::AbstractQuantum, p::Real=2; kwargs...)
p == 2 || throw(ArgumentError("only L2-norm is implemented yet"))
return LinearAlgebra.norm2(ψ; kwargs...)
Expand Down
91 changes: 76 additions & 15 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,33 @@ using LinearAlgebra
# If maxdim > size(spectrum), the bond dimension is not truncated
truncated = truncate(ψ, [site"2", site"3"]; maxdim=4)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
end

@testset "Canonical" begin
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
normalize!(ψ)
truncated = truncate(ψ, [site"2", site"3"]; maxdim=1, normalize=true)
@test norm(truncated) 1.0
end

@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=3)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 3

truncated = truncate(ψ, [site"2", site"3"]; maxdim=3, normalize=true)
@test norm(truncated) 1.0
end

@testset "Canonical" begin
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2, canonize=true, normalize=true)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
@test Tenet.check_form(truncated)
@test norm(truncated) 1.0

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2, canonize=false, normalize=true)
@test norm(truncated) 1.0
end
end

Expand All @@ -144,11 +156,42 @@ using LinearAlgebra
end

@testset "normalize!" begin
using LinearAlgebra: normalize!
using LinearAlgebra: normalize, normalize!

ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
normalize!(ψ, Site(3))
@test isapprox(norm(ψ), 1.0)
@testset "NonCanonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])

normalized = normalize(ψ)
@test norm(normalized) 1.0

normalize!(ψ, Site(3))
@test norm(ψ) 1.0
end

@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

# Perturb the state to make it non-normalized
t = tensors(ψ; at=site"3")
replace!(ψ, t => Tensor(rand(size(t)...), inds(t)))

normalized = normalize(ψ)
@test norm(normalized) 1.0

normalize!(ψ, Site(3))
@test norm(ψ) 1.0
end

@testset "Canonical" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonize!(ψ)

normalized = normalize(ψ)
@test norm(normalized) 1.0

normalize!(ψ, (Site(3), Site(4)))
@test norm(ψ) 1.0
end
end

@testset "canonize_site!" begin
Expand Down Expand Up @@ -303,14 +346,32 @@ using LinearAlgebra
@test length(tensors(ϕ)) == 5
@test issetequal(size.(tensors(ϕ)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)])
@test isapprox(contract(ϕ), contract(ψ))

evolved = evolve!(normalize(ψ), gate; maxdim=1, normalize=true)
@test norm(evolved) 1.0
end

@testset "Canonical" begin
ψ = deepcopy(ψ)
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])
normalize!(ψ)
ϕ = deepcopy(ψ)

canonize!(ψ)
evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14)
@test isapprox(contract(evolved), contract(ψ))
@test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)])

evolved = evolve!(deepcopy(ψ), gate)
@test Tenet.check_form(evolved)
@test isapprox(contract(evolved), contract(ϕ)) # Identity gate should not change the state

# Ensure that the original MixedCanonical state evolves into the same state as the canonicalized one
@test contract(ψ) contract(evolve!(ϕ, gate; threshold=1e-14))

evolved = evolve!(deepcopy(ψ), gate; maxdim=1, normalize=true, canonize=true)
@test norm(evolved) 1.0
@test Tenet.check_form(evolved)

evolved = evolve!(deepcopy(ψ), gate; maxdim=1, normalize=true, canonize=false)
@test norm(evolved) 1.0
@test_throws ArgumentError Tenet.check_form(evolved)
end
end
end
Expand Down

0 comments on commit 4247207

Please sign in to comment.