From 8db7c3356fbedbdc1dc31fad46cad5d653e75f85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 15 Feb 2024 14:28:01 +0100 Subject: [PATCH 1/8] Fix rightindex and leftindex logic --- src/Ansatz/Chain.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 8f88128..8fa8fec 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -112,15 +112,19 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end +rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) +rightsite(::Open, tn::Chain, site::Site) = Site(site.id + 1) + +leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) +leftsite(::Open, tn::Chain, site::Site) = Site(site.id - 1) + leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end-1] function leftindex(::Open, tn::Chain, site::Site) if site == site"1" nothing - elseif site == Site(nsites(tn)) # TODO review - (select(tn, :tensor, site)|>inds)[end] else - (select(tn, :tensor, site)|>inds)[end-1] + (select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, leftsite(tn, site))|>inds) |> only end end @@ -130,7 +134,7 @@ function rightindex(::Open, tn::Chain, site::Site) if site == Site(nsites(tn)) # TODO review nothing else - (select(tn, :tensor, site)|>inds)[end] + (select(tn, :tensor, site)|>inds) ∩ (select(tn, :tensor, rightsite(tn, site))|>inds) |> only end end From 55be4467ecadf70ce2ac78a0dc06a203e22fcc78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 15 Feb 2024 14:28:34 +0100 Subject: [PATCH 2/8] Fix canonize! function and enhance syntax --- src/Ansatz/Chain.jl | 29 ++++++++++++++++++----------- src/Qrochet.jl | 2 +- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 8fa8fec..00c9017 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -138,35 +138,42 @@ function rightindex(::Open, tn::Chain, site::Site) end end +canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...) canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) -# NOTE spectral weights are stored in a vector connected to the now virtual hyperindex! -function canonize!(::Open, tn::Chain, site::Site; direction::Symbol) +# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! +function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode::Symbol = :qr) left_inds = Symbol[] right_inds = Symbol[] virtualind = if direction === :left - site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor")) - push!(left_inds, leftindex(tn, site)) - - site == Site(nsites(tn)) || push!(right_inds, rightindex(tn, site)) - push!(right_inds, Quantum(tn)[site]) - - only(left_inds) - elseif direction === :right site == Site(nsites(tn)) && throw(ArgumentError("Cannot right-canonize right-most tensor")) push!(right_inds, rightindex(tn, site)) site == Site(1) || push!(left_inds, leftindex(tn, site)) push!(left_inds, Quantum(tn)[site]) + only(right_inds) + elseif direction === :right + site == Site(1) && throw(ArgumentError("Cannot left-canonize left-most tensor")) + push!(right_inds, leftindex(tn, site)) + + site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site)) + push!(left_inds, Quantum(tn)[site]) + only(right_inds) else throw(ArgumentError("Unknown direction=:$direction")) end tmpind = gensym(:tmp) - qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + if mode == :qr + qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + elseif mode == :svd + svd!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind) + else + throw(ArgumentError("Unknown mode=:$mode")) + end contract!(TensorNetwork(tn), virtualind) replace!(TensorNetwork(tn), tmpind => virtualind) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index d0387a0..72f4eaa 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -17,7 +17,7 @@ export Product include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO -export leftindex, rightindex, canonize! +export leftindex, rightindex, canonize, canonize! # reexports from Tenet using Tenet From 193e95f5a2886177a92a0b26c4caddd6d1399e55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 15 Feb 2024 14:37:45 +0100 Subject: [PATCH 3/8] Fix typo and add tests --- src/Ansatz/Chain.jl | 2 +- test/Ansatz/Chain_test.jl | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 00c9017..ecf73a5 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -142,7 +142,7 @@ canonize(tn::Chain, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwarg canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) # NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! -function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode::Symbol = :qr) +function canonize!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr) left_inds = Symbol[] right_inds = Symbol[] diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 2112dd6..25fe934 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -26,4 +26,43 @@ @test noutputs(qtn) == 3 @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) @test boundary(qtn) == Open() + + @testset "canonize" begin + function is_left_canonical(qtn, s::Site) + label_r = rightindex(qtn, s) + A = select(qtn, :tensor, s) + try + contracted = contract(A, replace(conj(A), label_r => :new_ind_name)) + return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol=1e-12) + catch + return false + end + end + + function is_right_canonical(qtn, s::Site) + label_l = leftindex(qtn, s) + A = select(qtn, :tensor, s) + try + contracted = contract(A, replace(conj(A), label_l => :new_ind_name)) + return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol=1e-12) + catch + return false + end + end + + qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize!(qtn, Site(1); direction=:right) + @test_throws ArgumentError canonize!(qtn, Site(3); direction=:left) + + for mode in [:qr, :svd] + for i in 1:length(sites(qtn)) + if i != 1 + @test is_right_canonical(canonize(qtn, Site(i); direction=:right, mode=mode), Site(i)) + elseif i != length(sites(qtn)) + @test is_left_canonical(canonize(qtn, Site(i); direction=:left, mode=mode), Site(i)) + end + end + end + end end From d6eb2d044bc4f19ee96f8bfea3720365867c503f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 15 Feb 2024 14:46:03 +0100 Subject: [PATCH 4/8] Enhance tests --- test/Ansatz/Chain_test.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 25fe934..4f190eb 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -28,6 +28,8 @@ @test boundary(qtn) == Open() @testset "canonize" begin + using Tenet + function is_left_canonical(qtn, s::Site) label_r = rightindex(qtn, s) A = select(qtn, :tensor, s) @@ -58,11 +60,18 @@ for mode in [:qr, :svd] for i in 1:length(sites(qtn)) if i != 1 - @test is_right_canonical(canonize(qtn, Site(i); direction=:right, mode=mode), Site(i)) + canonized = canonize(qtn, Site(i); direction=:right, mode=mode) + @test is_right_canonical(canonized, Site(i)) + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn))) elseif i != length(sites(qtn)) - @test is_left_canonical(canonize(qtn, Site(i); direction=:left, mode=mode), Site(i)) + canonized = canonize(qtn, Site(i); direction=:left, mode=mode) + @test is_left_canonical(canonized, Site(i)) + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), contract(TensorNetwork(qtn))) end end end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize(qtn, Site(2); direction=:right, mode=:svd))) == 4 end end From f8fa5e0be0de632077271836d07abcdf91fbd6f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 15 Feb 2024 15:13:59 +0100 Subject: [PATCH 5/8] Update code for Periodic boundary --- src/Ansatz/Chain.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index ecf73a5..f23bd5e 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -113,14 +113,13 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) end rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -rightsite(::Open, tn::Chain, site::Site) = Site(site.id + 1) +rightsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id + 1) leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -leftsite(::Open, tn::Chain, site::Site) = Site(site.id - 1) +leftsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id - 1) leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end-1] -function leftindex(::Open, tn::Chain, site::Site) +function leftindex(::Union{Open, Periodic}, tn::Chain, site::Site) if site == site"1" nothing else @@ -129,8 +128,7 @@ function leftindex(::Open, tn::Chain, site::Site) end rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -rightindex(::Periodic, tn::Chain, site::Site) = (select(tn, :tensor, site)|>inds)[end] -function rightindex(::Open, tn::Chain, site::Site) +function rightindex(::Union{Open, Periodic}, tn::Chain, site::Site) if site == Site(nsites(tn)) # TODO review nothing else From df93790f0f6732ab02faf0e927b89f531d8f6e0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 16 Feb 2024 10:06:49 +0100 Subject: [PATCH 6/8] Fix leftsite and rightsite function --- src/Ansatz/Chain.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index f23bd5e..1dfa744 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -112,11 +112,19 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -rightsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id + 1) - leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -leftsite(::Union{Open, Periodic}, tn::Chain, site::Site) = Site(site.id - 1) +function leftsite(::Open, tn::Chain, site::Site) + (site.id > length(sites(tn)) || site.id <= 1) && throw(ArgumentError("Invalid site $site")) + Site(site.id - 1) +end +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, length(sites(tn)))) + +rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) +function rightsite(::Open, tn::Chain, site::Site) + (site.id > length(sites(tn))-1 || site.id < 1) && throw(ArgumentError("Invalid site $site")) + Site(site.id + 1) +end +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, length(sites(tn)))) leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) function leftindex(::Union{Open, Periodic}, tn::Chain, site::Site) From fa4bfbb442b43e0b52286942d48cdc4c5ce7b549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 16 Feb 2024 10:07:07 +0100 Subject: [PATCH 7/8] Add Site testset --- test/Ansatz/Chain_test.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 4f190eb..af94f79 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -27,6 +27,30 @@ @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) @test boundary(qtn) == Open() + @testset "Site" begin + using Qrochet: leftsite, rightsite + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + + @test leftsite(qtn, Site(1)) == Site(3) + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(1)) == Site(2) + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(3)) == Site(1) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test_throws ArgumentError leftsite(qtn, Site(1)) + @test_throws ArgumentError rightsite(qtn, Site(3)) + + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(1)) == Site(2) + end + @testset "canonize" begin using Tenet From 656e4d5595f84c17eb49dbf4a77836fcc5f0566c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 16 Feb 2024 10:58:57 +0100 Subject: [PATCH 8/8] Update syntax --- src/Ansatz/Chain.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 1dfa744..6c3beec 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -114,14 +114,14 @@ end leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) function leftsite(::Open, tn::Chain, site::Site) - (site.id > length(sites(tn)) || site.id <= 1) && throw(ArgumentError("Invalid site $site")) + site.id ∉ range(2, length(sites(tn))) && throw(ArgumentError("Invalid site $site")) Site(site.id - 1) end leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id - 1, length(sites(tn)))) rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) function rightsite(::Open, tn::Chain, site::Site) - (site.id > length(sites(tn))-1 || site.id < 1) && throw(ArgumentError("Invalid site $site")) + site.id ∉ range(1, length(sites(tn))-1) && throw(ArgumentError("Invalid site $site")) Site(site.id + 1) end rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(site.id + 1, length(sites(tn))))