diff --git a/Project.toml b/Project.toml index c3f02be2d..9596a57de 100644 --- a/Project.toml +++ b/Project.toml @@ -10,11 +10,11 @@ DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] @@ -70,7 +70,6 @@ KeywordDispatch = "0.3" KrylovKit = "0.7, 0.8" LinearAlgebra = "1.9" Makie = "0.18,0.19,0.20, 0.21" -OMEinsum = "0.7, 0.8" PythonCall = "0.9" Quac = "0.3" Random = "1.9" @@ -78,6 +77,7 @@ Reactant = "0.2" ScopedValues = "1" Serialization = "1.9" SparseArrays = "1.9" +TensorOperations = "4, 5" UUIDs = "1.9" Yao = "0.8, 0.9" julia = "1.9" diff --git a/ext/TenetChainRulesCoreExt/non_diff.jl b/ext/TenetChainRulesCoreExt/non_diff.jl index 9886b295c..c06849390 100644 --- a/ext/TenetChainRulesCoreExt/non_diff.jl +++ b/ext/TenetChainRulesCoreExt/non_diff.jl @@ -1,6 +1,3 @@ -# NOTE fix problem with vector generator in `contract` -@non_differentiable Tenet.__omeinsum_sym2str(x) - # WARN type-piracy @non_differentiable setdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) @non_differentiable union(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) diff --git a/src/Numerics.jl b/src/Numerics.jl index 0e9a7c02c..421e9f029 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -1,4 +1,4 @@ -using OMEinsum +using TensorOperations using LinearAlgebra using UUIDs: uuid4 using SparseArrays @@ -14,10 +14,6 @@ function Base.literal_pow(f, a::Tensor{T,0}, ::Val{p}) where {T,p} return Tensor(fill(Base.literal_pow(f, only(a), Val(p)))) end -# NOTE used for marking non-differentiability -# NOTE use `String[...]` code instead of `map` or broadcasting to set eltype in empty cases -__omeinsum_sym2str(x) = String[string(i) for i in x] - function Base.:(+)(a::Tensor, b::Tensor) issetequal(inds(a), inds(b)) || throw(ArgumentError("indices must be equal")) perm = __find_index_permutation(inds(a), inds(b)) @@ -46,9 +42,8 @@ function contract(a::Tensor, b::Tensor; dims=(∩(inds(a), inds(b))), out=nothin out end - data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero=false) - c = Tensor(data, ic) - return contract!(c, a, b) + data = tensorcontract(Tuple(ic), parent(a), Tuple(inds(a)), false, parent(b), Tuple(inds(b)), false) + return Tensor(data, ic) end function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing) @@ -61,9 +56,9 @@ function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing) out end - data = OMEinsum.get_output_array((parent(a),), [size(a, i) for i in ic]; fillzero=false) - c = Tensor(data, ic) - return contract!(c, a) + # TODO might fail on partial trace + data = tensortrace(Tuple(ic), parent(a), Tuple(inds(a)), false) + return Tensor(data, ic) end contract(a::Union{T,AbstractArray{T,0}}, b::Tensor{T}) where {T} = contract(Tensor(a), b) @@ -73,22 +68,14 @@ contract(a::Number, b::Number) = contract(fill(a), fill(b)) contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs...), tensors) function contract!(c::Tensor, a::Tensor, b::Tensor) - ixs = (inds(a), inds(b)) - iy = inds(c) - xs = (parent(a), parent(b)) - y = parent(c) - size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...) - - einsum!(ixs, iy, xs, y, true, false, size_dict) + pA, pB, pAB = contract_indices(Tuple(inds(a)), Tuple(inds(b)), Tuple(inds(c))) + tensorcontract!(parent(c), parent(a), pA, false, parent(b), pB, false, pAB) return c end function contract!(y::Tensor, x::Tensor) - ixs = (inds(x),) - iy = inds(y) - size_dict = Dict{Symbol,Int}(inds(x) .=> size(x)) - - einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict) + p, q = TensorOperations.trace_indices(Tuple(inds(x)), Tuple(inds(y))) + tensortrace!(parent(y), parent(x), p, q, false) return y end diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index fbdddc81e..67fe32361 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -1,7 +1,6 @@ using Base: AbstractVecOrTuple using Random using EinExprs -using OMEinsum using LinearAlgebra using ScopedValues using Serialization diff --git a/src/Transformations.jl b/src/Transformations.jl index 779a7eefa..6ba80d257 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -1,6 +1,5 @@ using DeltaArrays using EinExprs -using OMEinsum using UUIDs: uuid4 using Tenet: parenttype using Combinatorics: combinations diff --git a/test/Numerics_test.jl b/test/Numerics_test.jl index 49fa9856d..4c97b0586 100644 --- a/test/Numerics_test.jl +++ b/test/Numerics_test.jl @@ -21,7 +21,8 @@ A = Tensor(rand(2, 3, 4), (:i, :j, :k)) C = contract(A; dims=(:i,)) - C_ein = ein"ijk -> jk"(parent(A)) + # C_ein = ein"ijk -> jk"(parent(A)) + @tensor C_ein[j, k] := A[i, j, k] @test inds(C) == [:j, :k] @test size(C) == size(C_ein) == (3, 4) @test parent(C) ≈ C_ein @@ -31,7 +32,8 @@ A = Tensor(rand(2, 3, 2), (:i, :j, :i)) C = contract(A; dims=()) - C_ein = ein"iji -> ij"(parent(A)) + # C_ein = ein"iji -> ij"(parent(A)) + @tensor C_ein[i, j] := A[i, j, i] @test inds(C) == [:i, :j] @test size(C) == size(C_ein) == (2, 3) @test parent(C) ≈ C_ein @@ -41,7 +43,8 @@ A = Tensor(rand(2, 3, 2), (:i, :j, :i)) C = contract(A; dims=(:i,)) - C_ein = ein"iji -> j"(parent(A)) + # C_ein = ein"iji -> j"(parent(A)) + @tensor C_ein[j] := A[i, j, i] @test inds(C) == [:j] @test size(C) == size(C_ein) == (3,) @test parent(C) ≈ C_ein @@ -74,7 +77,8 @@ B = Tensor(rand(2, 2), (:k, :l)) C = contract(A, B) - C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B)) + # C_ein = ein"ij, kl -> ijkl"(parent(A), parent(B)) + @tensor C_ein[i, j, k, l] := A[i, j] * B[k, l] @test size(C) == (2, 2, 2, 2) == size(C_ein) @test inds(C) == [:i, :j, :k, :l] @test parent(C) ≈ C_ein @@ -101,14 +105,16 @@ # Contraction of all common indices C = contract(A, B; dims=(:j, :k)) - C_ein = ein"ijk, klj -> il"(parent(A), parent(B)) + # C_ein = ein"ijk, klj -> il"(parent(A), parent(B)) + @tensor C_ein[i, l] := A[i, j, k] * B[k, l, j] @test inds(C) == [:i, :l] @test size(C) == (2, 5) == size(C_ein) @test parent(C) ≈ C_ein # Contraction of not all common indices C = contract(A, B; dims=(:j,)) - C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B)) + # C_ein = ein"ijk, klj -> ikl"(parent(A), parent(B)) + @tensor C_ein[i, k, l] := A[i, j, k] * B[k, l, j] @test inds(C) == [:i, :k, :l] @test size(C) == (2, 4, 5) == size(C_ein) @test parent(C) ≈ C_ein @@ -118,7 +124,8 @@ B = Tensor(rand(Complex{Float64}, 4, 5, 3), (:k, :l, :j)) C = contract(A, B; dims=(:j, :k)) - C_ein = ein"ijk, klj -> il"(parent(A), parent(B)) + # C_ein = ein"ijk, klj -> il"(parent(A), parent(B)) + @tensor C_ein[i, l] := A[i, j, k] * B[k, l, j] @test inds(C) == [:i, :l] @test size(C) == (2, 5) == size(C_ein) @test parent(C) ≈ C_ein diff --git a/test/Project.toml b/test/Project.toml index f8588675a..e4288d21e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,10 +13,10 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" -OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/test/runtests.jl b/test/runtests.jl index 6a9bc8e71..44e93cd7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using Test using Tenet -using OMEinsum +using TensorOperations @testset "Unit tests" verbose = true begin include("Helpers_test.jl")