Skip to content

Commit

Permalink
Replace OMEinsum code with TensorOperations
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 8, 2025
1 parent e030f88 commit 385440e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 42 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
Expand Down Expand Up @@ -75,14 +75,14 @@ KeywordDispatch = "0.3"
KrylovKit = "0.7, 0.8"
LinearAlgebra = "1.10"
Makie = "0.18,0.19,0.20, 0.21"
OMEinsum = "0.7, 0.8"
PythonCall = "0.9"
Quac = "0.3"
Random = "1.10"
Reactant = "0.2.9"
ScopedValues = "1"
Serialization = "1.10"
SparseArrays = "1.10"
TensorOperations = "4, 5"
UUIDs = "1.10"
YaoBlocks = "0.13"
julia = "1.10"
39 changes: 10 additions & 29 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OMEinsum
using TensorOperations
using LinearAlgebra
using UUIDs: uuid4
using SparseArrays
Expand Down Expand Up @@ -46,14 +46,7 @@ Perform a binary tensor contraction operation.
- `dims`: indices to contract over. Defaults to the set intersection of the indices of `a` and `b`.
- `out`: indices of the output tensor. Defaults to the set difference of the indices of `a` and `b`.
"""
function contract(a::Tensor, b::Tensor; kwargs...)
c = allocate_result(contract, a, b; kwargs...)
return contract!(c, a, b)
end

function allocate_result(
::typeof(contract), a::Tensor, b::Tensor; fillzero=false, dims=((inds(a), inds(b))), out=nothing
)
function contract(a::Tensor, b::Tensor; dims=((inds(a), inds(b))), out=nothing)
ia = collect(inds(a))
ib = collect(inds(b))
i = (dims, ia, ib)
Expand All @@ -64,7 +57,7 @@ function allocate_result(
out
end

data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero)
data = tensorcontract(Tuple(ic), parent(a), Tuple(inds(a)), false, parent(b), Tuple(inds(b)), false)
return Tensor(data, ic)
end

Expand All @@ -78,12 +71,7 @@ Perform a unary tensor contraction operation.
- `dims`: indices to contract over. Defaults to the repeated indices.
- `out`: indices of the output tensor. Defaults to the unique indices.
"""
function contract(a::Tensor; kwargs...)
c = allocate_result(contract, a; kwargs...)
return contract!(c, a)
end

function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=nonunique(inds(a)), out=nothing)
function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing)
ia = inds(a)
i = (dims, ia)

Expand All @@ -93,7 +81,8 @@ function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=non
out
end

data = OMEinsum.get_output_array((parent(a),), [size(a, i) for i in ic]; fillzero)
# TODO might fail on partial trace
data = tensortrace(Tuple(ic), parent(a), Tuple(inds(a)), false)
return Tensor(data, ic)
end

Expand All @@ -109,13 +98,8 @@ contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs
Perform a binary tensor contraction operation between `a` and `b` and store the result in `c`.
"""
function contract!(c::Tensor, a::Tensor, b::Tensor)
ixs = (inds(a), inds(b))
iy = inds(c)
xs = (parent(a), parent(b))
y = parent(c)
size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...)

einsum!(ixs, iy, xs, y, true, false, size_dict)
pA, pB, pAB = contract_indices(Tuple(inds(a)), Tuple(inds(b)), Tuple(inds(c)))
tensorcontract!(parent(c), parent(a), pA, false, parent(b), pB, false, pAB)
return c
end

Expand All @@ -125,11 +109,8 @@ end
Perform a unary tensor contraction operation on `a` and store the result in `c`.
"""
function contract!(y::Tensor, x::Tensor)
ixs = (inds(x),)
iy = inds(y)
size_dict = Dict{Symbol,Int}(inds(x) .=> size(x))

einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict)
p, q = TensorOperations.trace_indices(Tuple(inds(x)), Tuple(inds(y)))
tensortrace!(parent(y), parent(x), p, q, false)
return y
end

Expand Down
1 change: 0 additions & 1 deletion src/TensorNetwork.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Base: AbstractVecOrTuple
using Random
using EinExprs
using OMEinsum
using LinearAlgebra
using ScopedValues
using Serialization
Expand Down
1 change: 0 additions & 1 deletion src/Transformations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using DeltaArrays
using EinExprs
using OMEinsum
using UUIDs: uuid4
using Tenet: parenttype
using Combinatorics: combinations
Expand Down
21 changes: 14 additions & 7 deletions test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
A = Tensor(rand(2, 3, 4), (:i, :j, :k))

C = contract(A; dims=(:i,))
C_ein = ein"ijk -> jk"(parent(A))
# C_ein = ein"ijk -> jk"(parent(A))
@tensor C_ein[j, k] := A[i, j, k]
@test inds(C) == [:j, :k]
@test size(C) == size(C_ein) == (3, 4)
@test parent(C) C_ein
Expand All @@ -31,7 +32,8 @@
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A; dims=())
C_ein = ein"iji -> ij"(parent(A))
# C_ein = ein"iji -> ij"(parent(A))
@tensor C_ein[i, j] := A[i, j, i]
@test inds(C) == [:i, :j]
@test size(C) == size(C_ein) == (2, 3)
@test parent(C) C_ein
Expand All @@ -41,7 +43,8 @@
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A; dims=(:i,))
C_ein = ein"iji -> j"(parent(A))
# C_ein = ein"iji -> j"(parent(A))
@tensor C_ein[j] := A[i, j, i]
@test inds(C) == [:j]
@test size(C) == size(C_ein) == (3,)
@test parent(C) C_ein
Expand Down Expand Up @@ -74,7 +77,8 @@
B = Tensor(rand(2, 2), (:k, :l))

C = contract(A, B)
C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B))
# C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B))
@tensor C_ein[i, j, k, l] := A[i, j] * B[k, l]
@test size(C) == (2, 2, 2, 2) == size(C_ein)
@test inds(C) == [:i, :j, :k, :l]
@test parent(C) C_ein
Expand All @@ -101,14 +105,16 @@

# Contraction of all common indices
C = contract(A, B; dims=(:j, :k))
C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
@tensor C_ein[i, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :l]
@test size(C) == (2, 5) == size(C_ein)
@test parent(C) C_ein

# Contraction of not all common indices
C = contract(A, B; dims=(:j,))
C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B))
@tensor C_ein[i, k, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :k, :l]
@test size(C) == (2, 4, 5) == size(C_ein)
@test parent(C) C_ein
Expand All @@ -118,7 +124,8 @@
B = Tensor(rand(Complex{Float64}, 4, 5, 3), (:k, :l, :j))

C = contract(A, B; dims=(:j, :k))
C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
@tensor C_ein[i, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :l]
@test size(C) == (2, 5) == size(C_ein)
@test parent(C) C_ein
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using Tenet
using OMEinsum
using TensorOperations
using Reactant
using Adapt

Expand Down

0 comments on commit 385440e

Please sign in to comment.