Skip to content

Commit

Permalink
Test TensorNetwork interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 1, 2025
1 parent e22a4c3 commit f099ba9
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 70 deletions.
231 changes: 174 additions & 57 deletions test/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,120 +3,218 @@ using Tenet
using Tenet: ninds, ntensors

# TensorNetwork interface
function test_tensornetwork(tn)
test_tensornetwork_inds(tn)
test_tensornetwork_ninds(tn)
test_tensornetwork_tensors(tn)
test_tensornetwork_ntensors(tn)
test_tensornetwork_arrays(tn)
test_tensornetwork_size(tn)
test_tensornetwork_in(tn)
test_tensornetwork_replace!(tn)
test_tensornetwork_contract(tn)
test_tensornetwork_contract!(tn)
function test_tensornetwork(
tn;
inds=true,
ninds=true,
tensors=true,
ntensors=true,
arrays=true,
size=true,
inclusion=true,
replace=true,
contract=true,
)
inds && test_tensornetwork_inds(tn)
ninds && test_tensornetwork_ninds(tn)
tensors && test_tensornetwork_tensors(tn)
ntensors && test_tensornetwork_ntensors(tn)
arrays && test_tensornetwork_arrays(tn)
size && test_tensornetwork_size(tn)
inclusion && test_tensornetwork_in(tn)
replace && test_tensornetwork_replace!(tn)
contract && test_tensornetwork_contract(tn)
contract && test_tensornetwork_contract!(tn)
end

function test_tensornetwork_inds(tn)
# `inds` returns a list of the indices in the Tensor Network
@testimpl inds(tn) isa AbstractVector{Symbol}
@test inds(tn) isa AbstractVector{Symbol}

# `inds(; set = :all)` is equal to naive `inds`
@testimpl inds(tn; set=:all) == inds(tn)
@test inds(tn; set=:all) == inds(tn)

# `inds(; set = :open)` returns a list of indices of the Tensor Network
@testimpl inds(tn; set=:open) isa AbstractVector{Symbol}
@test inds(tn; set=:open) isa AbstractVector{Symbol}

# `inds(; set = :inner)` returns a list of indices of the Tensor Network
@testimpl inds(tn; set=:inner) isa AbstractVector{Symbol}
@test inds(tn; set=:inner) isa AbstractVector{Symbol}

# `inds(; set = :hyper)` returns a list of indices of the Tensor Network
@testimpl inds(tn; set=:hyper) isa AbstractVector{Symbol}
@test inds(tn; set=:hyper) isa AbstractVector{Symbol}

# `inds(; parallelto)` returns a list of indices parallel to `i` in the graph
@testimpl inds(tn; parallelto=first(inds(tn))) isa AbstractVector{Symbol}
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
inds(tn; parallelto=first(_inds)) isa AbstractVector{Symbol}
end
end

function test_tensornetwork_ninds(tn)
# `ninds` returns the number of indices in the Tensor Network
@testimpl ninds(tn) == length(inds(tn))
@test ninds(tn) == length(inds(tn))

# `ninds(; set = :all)` is equal to naive `ninds`
@testimpl ninds(tn; set=:all) == ninds(tn)
@test ninds(tn; set=:all) == ninds(tn)

# `ninds(; set = :open)` returns the number of open indices in the Tensor Network
@testimpl ninds(tn; set=:open) == length(inds(tn; set=:open))
@test ninds(tn; set=:open) == length(inds(tn; set=:open))

# `ninds(; set = :inner)` returns the number of inner indices in the Tensor Network
@testimpl ninds(tn; set=:inner) == length(inds(tn; set=:inner))
@test ninds(tn; set=:inner) == length(inds(tn; set=:inner))

# `ninds(; set = :hyper)` returns the number of hyper indices in the Tensor Network
@testimpl ninds(tn; set=:hyper) == length(inds(tn; set=:hyper))
@test ninds(tn; set=:hyper) == length(inds(tn; set=:hyper))
end

function test_tensornetwork_tensors(tn)
# `tensors` returns a list of the tensors in the Tensor Network
@testimpl tensors(tn) isa AbstractVector{<:Tensor}
@test tensors(tn) isa AbstractVector{<:Tensor}

# `tensors(; contains = i)` returns a list of tensors containing index `i`
@testimpl tensors(tn; contains=first(inds(tn))) isa AbstractVector{<:Tensor}
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
tensors(tn; contains=first(_inds)) isa AbstractVector{<:Tensor}
end

# `tensors(; intersects = i)` returns a list of tensors intersecting index `i`
@testimpl tensors(tn; intersects=first(inds(tn))) isa AbstractVector{<:Tensor}
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
tensors(tn; intersects=first(_inds)) isa AbstractVector{<:Tensor}
end
end

function test_tensornetwork_ntensors(tn)
#`ntensors` returns the number of tensors in the Tensor Network
@testimpl ntensors(tn) == length(tensors(tn))
@test ntensors(tn) == length(tensors(tn))

#`ntensors(; contains = i)` returns the number of tensors containing index `i`
@testimpl ntensors(tn; contains=first(inds(tn))) == length(tensors(tn; contains=first(inds(tn))))
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
ntensors(tn; contains=first(_inds)) == length(tensors(tn; contains=first(_inds)))
end

#`ntensors(; intersects = i)` returns the number of tensors intersecting index `i`
@testimpl ntensors(tn; intersects=first(inds(tn))) == length(tensors(tn; contains=first(inds(tn))))
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
ntensors(tn; intersects=first(_inds)) == length(tensors(tn; contains=first(_inds)))
end
end

function test_tensornetwork_arrays(tn)
# `arrays` returns a list of the arrays in the Tensor Network
@testimpl arrays(tn) == parent.(tensors(tn))
@test arrays(tn) == parent.(tensors(tn))

# `arrays(; contains = i)` returns a list of arrays containing index `i`
@testimpl arrays(tn; contains=first(inds(tn))) == parent.(tensors(tn; contains=first(inds(tn))))
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
arrays(tn; contains=first(_inds)) == parent.(tensors(tn; contains=first(_inds)))
end

# `arrays(; intersects = i)` returns a list of arrays intersecting index `i`
@testimpl arrays(tn; intersects=first(inds(tn))) == parent.(tensors(tn; contains=first(inds(tn))))
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
arrays(tn; intersects=first(_inds)) == parent.(tensors(tn; contains=first(_inds)))
end
end

function test_tensornetwork_size(tn)
# `size` returns a mapping from indices to their dimensionalities
@testimpl size(tn) isa AbstractDict{Symbol,Int}
@test size(tn) isa AbstractDict{Symbol,Int}

# `size` on Symbol returns the dimensionality of that index
@testimpl size(tn, first(inds(tn))) isa Int
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
size(tn, first(_inds)) isa Int
end
end

function test_tensornetwork_in(tn)
# `in` on `Symbol` returns if the index is present in the Tensor Network
@testimpl in(first(inds(tn)), tn) == true
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
in(first(_inds), tn) == true
end

# `in` on `Tensor` returns if that exact object is present in the Tensor Network
@testimpl in(first(tensors(tn)), tn) == true
@test let tn = tn
_tensors = tensors(tn)
if isempty(_tensors)
# TODO it should skip here but let's just return true for now
return true
end
in(first(_tensors), tn) == true
end

# `in` on copied `Tensor` is never included
@testimpl in(copy(first(tensors(tn))), tn) == false
@test let tn = tn
_tensors = tensors(tn)
if isempty(_tensors)
# TODO it should skip here but let's just return true for now
return true
end
in(copy(first(_tensors)), tn) == false
end
end

function test_tensornetwork_replace!(tn)
# `replace!` on `Symbol` replaces an index in the Tensor Network
@testimpl let tn = deepcopy(tn)
ind = first(inds(tn))
@test let tn = deepcopy(tn)
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
ind = first(_inds)
new_ind = gensym(:new)
replace!(tn, ind => new_ind)
new_ind tn
end

# `replace!` on `Tensor` replaces a tensor in the Tensor Network
@testimpl let tn = deepcopy(tn)
tensor = first(tensors(tn))
@test let tn = deepcopy(tn)
_tensors = tensors(tn)
if isempty(_tensors)
# TODO it should skip here but let's just return true for now
return true
end
tensor = first(_tensors)
new_tensor = copy(tensor)
replace!(tn, tensor => new_tensor)
new_tensor tn
Expand All @@ -125,13 +223,18 @@ end

function test_tensornetwork_contract(tn)
# `contract` returns a `Tensor`
@testimpl contract(tn) isa Tensor
@test contract(tn) isa Tensor
end

function test_tensornetwork_contract!(tn)
# `contract!` on `Symbol` contracts an index in-place
@testimpl let tn = deepcopy(tn)
ind = first(inds(tn))
@test let tn = deepcopy(tn)
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
ind = first(_inds)
contract!(tn, ind)
ind tn
end
Expand All @@ -147,43 +250,57 @@ end

function test_pluggable_sites(tn)
# `sites` returns a list of the sites in the Tensor Network
@testimpl sites(tn) isa AbstractVector{<:Site}
@test sites(tn) isa AbstractVector{<:Site}

# `sites(; set = :all)` is equal to naive `sites`
@testimpl sites(tn; set=:all) == sites(tn)
@test sites(tn; set=:all) == sites(tn)

# `sites(; set = :inputs)` returns a list of input sites (i.e. dual) in the Tensor Network
@testimpl sites(tn; set=:inputs) isa AbstractVector{<:Site} && all(isdual, sites(tn; set=:inputs))
@test sites(tn; set=:inputs) isa AbstractVector{<:Site} && all(isdual, sites(tn; set=:inputs))

# `sites(; set = :outputs)` returns a list of output sites (i.e. non-dual) in the Tensor Network
@testimpl sites(tn; set=:outputs) isa AbstractVector{<:Site} && all(!isdual, sites(tn; set=:outputs))
@test sites(tn; set=:outputs) isa AbstractVector{<:Site} && all(!isdual, sites(tn; set=:outputs))

# `sites(; at::Symbol)` returns the site linked to the index
@testimpl sites(tn; at=first(inds(tn))) isa Site
@test let tn = tn
_inds = inds(tn)
if isempty(_inds)
# TODO it should skip here but let's just return true for now
return true
end
sites(tn; at=first(_inds)) isa Site
end
end

function test_pluggable_socket(tn)
# `socket` returns the socket of the Tensor Network
@testimpl socket(tn) isa Socket
@test socket(tn) isa Socket
end

function test_pluggable_inds(tn)
# `inds` returns a list of the indices in the Tensor Network
@testimpl inds(tn; at=first(sites(tn))) isa Site
@test let tn = tn
_sites = sites(tn)
if isempty(_sites)
# TODO it should skip here but let's just return true for now
return true
end
inds(tn; at=first(_sites)) isa Site
end
end

function test_pluggable_ninds(tn)
# `ninds` returns the number of sites in the Tensor Network
@testimpl nsites(tn) == length(sites(tn))
@test nsites(tn) == length(sites(tn))

# `ninds(; set = :all)` is equal to naive `ninds`
@testimpl nsites(tn; set=:all) == nsites(tn)
@test nsites(tn; set=:all) == nsites(tn)

# `ninds(; set = :inputs)` returns the number of input sites in the Tensor Network
@testimpl nsites(tn; set=:inputs) == length(sites(tn; set=:inputs))
@test nsites(tn; set=:inputs) == length(sites(tn; set=:inputs))

# `ninds(; set = :outputs)` returns the number of output sites in the Tensor Network
@testimpl nsites(tn; set=:outputs) == length(sites(tn; set=:inputs))
@test nsites(tn; set=:outputs) == length(sites(tn; set=:inputs))
end

# Ansatz interface
Expand All @@ -195,15 +312,15 @@ end

function test_ansatz_lanes(tn)
# `lanes` returns a list of the lanes in the Tensor Network
@testimpl lanes(tn) isa AbstractVector{<:Lane}
@test lanes(tn) isa AbstractVector{<:Lane}
end

function test_ansatz_lattice(tn)
# `lattice` returns the lattice of the Tensor Network
@testimpl lattice(tn) isa Lattice
@test lattice(tn) isa Lattice
end

function test_ansatz_tensors(tn)
# `tensors(; at::Lane)` returns the `Tensor` linked to a `Lane`
@testimpl tensors(tn; at=first(lanes(tn))) isa Tensor
@test tensors(tn; at=first(lanes(tn))) isa Tensor
end
20 changes: 9 additions & 11 deletions test/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,15 @@ end
Test `expr` but return `false` if a `MethodError` is thrown.
"""
macro testimpl(expr)
quote
@test begin
try
$expr
catch e
if e isa MethodError
return false
else
rethrow(e)
end
return Base.remove_linenums!(:(
try
@test $(esc(expr))
catch e
if e isa MethodError
return false
else
rethrow(e)
end
end
end
))
end
Loading

0 comments on commit f099ba9

Please sign in to comment.