Skip to content

Commit

Permalink
Replace OMEinsum code with TensorOperations
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 23, 2024
1 parent 4c73e8a commit 7671b0f
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 39 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
Expand Down Expand Up @@ -70,14 +70,14 @@ KeywordDispatch = "0.3"
KrylovKit = "0.7, 0.8"
LinearAlgebra = "1.9"
Makie = "0.18,0.19,0.20, 0.21"
OMEinsum = "0.7, 0.8"
PythonCall = "0.9"
Quac = "0.3"
Random = "1.9"
Reactant = "0.2"
ScopedValues = "1"
Serialization = "1.9"
SparseArrays = "1.9"
TensorOperations = "4, 5"
UUIDs = "1.9"
Yao = "0.8, 0.9"
julia = "1.9"
3 changes: 0 additions & 3 deletions ext/TenetChainRulesCoreExt/non_diff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# NOTE fix problem with vector generator in `contract`
@non_differentiable Tenet.__omeinsum_sym2str(x)

# WARN type-piracy
@non_differentiable setdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
@non_differentiable union(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
Expand Down
33 changes: 10 additions & 23 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OMEinsum
using TensorOperations
using LinearAlgebra
using UUIDs: uuid4
using SparseArrays
Expand All @@ -14,10 +14,6 @@ function Base.literal_pow(f, a::Tensor{T,0}, ::Val{p}) where {T,p}
return Tensor(fill(Base.literal_pow(f, only(a), Val(p))))
end

# NOTE used for marking non-differentiability
# NOTE use `String[...]` code instead of `map` or broadcasting to set eltype in empty cases
__omeinsum_sym2str(x) = String[string(i) for i in x]

function Base.:(+)(a::Tensor, b::Tensor)
issetequal(inds(a), inds(b)) || throw(ArgumentError("indices must be equal"))
perm = __find_index_permutation(inds(a), inds(b))
Expand Down Expand Up @@ -46,9 +42,8 @@ function contract(a::Tensor, b::Tensor; dims=(∩(inds(a), inds(b))), out=nothin
out
end

data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero=false)
c = Tensor(data, ic)
return contract!(c, a, b)
data = tensorcontract(Tuple(ic), parent(a), Tuple(inds(a)), false, parent(b), Tuple(inds(b)), false)
return Tensor(data, ic)
end

function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing)
Expand All @@ -61,9 +56,9 @@ function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing)
out
end

data = OMEinsum.get_output_array((parent(a),), [size(a, i) for i in ic]; fillzero=false)
c = Tensor(data, ic)
return contract!(c, a)
# TODO might fail on partial trace
data = tensortrace(Tuple(ic), parent(a), Tuple(inds(a)), false)
return Tensor(data, ic)
end

contract(a::Union{T,AbstractArray{T,0}}, b::Tensor{T}) where {T} = contract(Tensor(a), b)
Expand All @@ -73,22 +68,14 @@ contract(a::Number, b::Number) = contract(fill(a), fill(b))
contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs...), tensors)

function contract!(c::Tensor, a::Tensor, b::Tensor)
ixs = (inds(a), inds(b))
iy = inds(c)
xs = (parent(a), parent(b))
y = parent(c)
size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...)

einsum!(ixs, iy, xs, y, true, false, size_dict)
pA, pB, pAB = contract_indices(Tuple(inds(a)), Tuple(inds(b)), Tuple(inds(c)))
tensorcontract!(parent(c), parent(a), pA, false, parent(b), pB, false, pAB)
return c
end

function contract!(y::Tensor, x::Tensor)
ixs = (inds(x),)
iy = inds(y)
size_dict = Dict{Symbol,Int}(inds(x) .=> size(x))

einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict)
p, q = TensorOperations.trace_indices(Tuple(inds(x)), Tuple(inds(y)))
tensortrace!(parent(y), parent(x), p, q, false)
return y
end

Expand Down
1 change: 0 additions & 1 deletion src/TensorNetwork.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Base: AbstractVecOrTuple
using Random
using EinExprs
using OMEinsum
using LinearAlgebra
using ScopedValues
using Serialization
Expand Down
1 change: 0 additions & 1 deletion src/Transformations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using DeltaArrays
using EinExprs
using OMEinsum
using UUIDs: uuid4
using Tenet: parenttype
using Combinatorics: combinations
Expand Down
21 changes: 14 additions & 7 deletions test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
A = Tensor(rand(2, 3, 4), (:i, :j, :k))

C = contract(A; dims=(:i,))
C_ein = ein"ijk -> jk"(parent(A))
# C_ein = ein"ijk -> jk"(parent(A))
@tensor C_ein[j, k] := A[i, j, k]
@test inds(C) == [:j, :k]
@test size(C) == size(C_ein) == (3, 4)
@test parent(C) C_ein
Expand All @@ -31,7 +32,8 @@
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A; dims=())
C_ein = ein"iji -> ij"(parent(A))
# C_ein = ein"iji -> ij"(parent(A))
@tensor C_ein[i, j] := A[i, j, i]
@test inds(C) == [:i, :j]
@test size(C) == size(C_ein) == (2, 3)
@test parent(C) C_ein
Expand All @@ -41,7 +43,8 @@
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A; dims=(:i,))
C_ein = ein"iji -> j"(parent(A))
# C_ein = ein"iji -> j"(parent(A))
@tensor C_ein[j] := A[i, j, i]
@test inds(C) == [:j]
@test size(C) == size(C_ein) == (3,)
@test parent(C) C_ein
Expand Down Expand Up @@ -74,7 +77,8 @@
B = Tensor(rand(2, 2), (:k, :l))

C = contract(A, B)
C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B))
# C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B))
@tensor C_ein[i, j, k, l] := A[i, j] * B[k, l]
@test size(C) == (2, 2, 2, 2) == size(C_ein)
@test inds(C) == [:i, :j, :k, :l]
@test parent(C) C_ein
Expand All @@ -101,14 +105,16 @@

# Contraction of all common indices
C = contract(A, B; dims=(:j, :k))
C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
@tensor C_ein[i, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :l]
@test size(C) == (2, 5) == size(C_ein)
@test parent(C) C_ein

# Contraction of not all common indices
C = contract(A, B; dims=(:j,))
C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B))
@tensor C_ein[i, k, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :k, :l]
@test size(C) == (2, 4, 5) == size(C_ein)
@test parent(C) C_ein
Expand All @@ -118,7 +124,8 @@
B = Tensor(rand(Complex{Float64}, 4, 5, 3), (:k, :l, :j))

C = contract(A, B; dims=(:j, :k))
C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
# C_ein = ein"ijk, klj -> il"(parent(A), parent(B))
@tensor C_ein[i, l] := A[i, j, k] * B[k, l, j]
@test inds(C) == [:i, :l]
@test size(C) == (2, 5) == size(C_ein)
@test parent(C) C_ein
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using Tenet
using OMEinsum
using TensorOperations

@testset "Unit tests" verbose = true begin
include("Helpers_test.jl")
Expand Down

0 comments on commit 7671b0f

Please sign in to comment.