diff --git a/Project.toml b/Project.toml index 87e780020..ff6e68767 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,6 @@ version = "0.4.1" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" @@ -21,19 +19,22 @@ ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" [extensions] TenetChainRulesCoreExt = "ChainRulesCore" TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"] TenetFiniteDifferencesExt = "FiniteDifferences" +TenetGraphMakieExt = ["GraphMakie", "Makie"] TenetMakieExt = "Makie" [compat] ChainRulesCore = "1.0" Combinatorics = "1.0" DeltaArrays = "0.1.1" -EinExprs = "0.5.5" +EinExprs = "0.5, 0.6" GraphMakie = "0.4,0.5" Graphs = "1.7" LinearAlgebra = "1.9" diff --git a/ext/TenetGraphMakieExt.jl b/ext/TenetGraphMakieExt.jl new file mode 100644 index 000000000..be408c1f4 --- /dev/null +++ b/ext/TenetGraphMakieExt.jl @@ -0,0 +1,144 @@ +module TenetGraphMakieExt + +function __init__() + try + Base.require(Main, :Graphs) + catch + @warn """Package Graphs or Graphs not found in current path. It is needed to plot `Tenet`s with `GraphMakie`. + - Run `import Pkg; Pkg.add(\"Graphs\")` or `]add Graphs` to install the Graphs package, then restart julia. + """ + end +end + +using Tenet +using Tenet: AbstractTensorNetwork +using Combinatorics: combinations +using Graphs +using Makie + +using GraphMakie + +""" + plot(tn::TensorNetwork; kwargs...) + plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...) + plot!(ax::Union{Axis,Axis3}, tn::TensorNetwork; kwargs...) + +Plot a [`TensorNetwork`](@ref) as a graph. + +# Keyword Arguments + + - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`. + - The rest of `kwargs` are passed to `GraphMakie.graphplot`. +""" +function Makie.plot(@nospecialize tn::AbstractTensorNetwork; kwargs...) + f = Figure() + ax, p = plot!(f[1, 1], tn; kwargs...) + return Makie.FigureAxisPlot(f, ax, p) +end + +# NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable +__networklayout_dim(x) = typeof(x).super.parameters |> first + +function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::AbstractTensorNetwork; kwargs...) + ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 + Axis3(f[1, 1]) + else + ax = Axis(f[1, 1]) + ax.aspect = DataAspect() + ax + end + + hidedecorations!(ax) + hidespines!(ax) + + p = plot!(ax, tn; kwargs...) + + return Makie.AxisPlot(ax, p) +end + +function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::AbstractTensorNetwork; labels = false, kwargs...) + hypermap = Tenet.hyperflatten(tn) + tn = transform(tn, Tenet.HyperindConverter) + + tensormap = IdDict(tensor => i for (i, tensor) in enumerate(tensors(tn))) + + # TODO how to mark multiedges? (i.e. parallel edges) + graph = SimpleGraph([ + Edge(map(Base.Fix1(getindex, tensormap), tensors)...) for (_, tensors) in tn.indexmap if length(tensors) > 1 + ]) + + # TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations + copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), keys(hypermap)), tensors(tn)) + ghostnodes = map(inds(tn, :open)) do index + # create new ghost node + add_vertex!(graph) + node = nv(graph) + + # connect ghost node + tensor = only(tn.indexmap[index]) + add_edge!(graph, node, tensormap[tensor]) + + return node + end + + # configure graphics + # TODO refactor hardcoded values into constants + kwargs = Dict{Symbol,Any}(kwargs) + + if haskey(kwargs, :node_size) + append!(kwargs[:node_size], zero(ghostnodes)) + else + kwargs[:node_size] = map(1:nv(graph)) do i + i ∈ ghostnodes ? 0 : max(15, log2(length(tensors(tn)[i]))) + end + end + + if haskey(kwargs, :node_marker) + append!(kwargs[:node_marker], fill(:circle, length(ghostnodes))) + else + kwargs[:node_marker] = map(i -> i ∈ copytensors ? :diamond : :circle, 1:nv(graph)) + end + + if haskey(kwargs, :node_color) + kwargs[:node_color] = vcat(kwargs[:node_color], fill(:black, length(ghostnodes))) + else + kwargs[:node_color] = map(1:nv(graph)) do v + v ∈ copytensors ? Makie.to_color(:black) : Makie.RGBf(240 // 256, 180 // 256, 100 // 256) + end + end + + get!(kwargs, :node_attr, (colormap = :viridis, strokewidth = 2.0, strokecolor = :black)) + + # configure labels + labels == true && get!(kwargs, :elabels) do + opentensors = findall(t -> !isdisjoint(inds(t), inds(tn, :open)), tensors(tn)) + opencounter = IdDict(tensor => 0 for tensor in opentensors) + + map(edges(graph)) do edge + # case: open edge + if any(∈(ghostnodes), [src(edge), dst(edge)]) + notghost = src(edge) ∈ ghostnodes ? dst(edge) : src(edge) + inds = Tenet.inds(tn, :open) ∩ Tenet.inds(tensors(tn)[notghost]) + opencounter[notghost] += 1 + return inds[opencounter[notghost]] |> string + end + + # case: hyperedge + if any(∈(copytensors), [src(edge), dst(edge)]) + i = src(edge) ∈ copytensors ? src(edge) : dst(edge) + # hyperindex = filter(p -> isdisjoint(inds(tensors)[i], p[2]), hypermap) |> only |> first + hyperindex = hypermap[Tenet.inds(tensors(tn)[i])] + return hyperindex |> string + end + + return join(Tenet.inds(tensors(tn)[src(edge)]) ∩ Tenet.inds(tensors(tn)[dst(edge)]), ',') + end + end + get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color) + get!(() -> repeat([17], ne(graph)), kwargs, :elabels_textsize) + + # plot graph + graphplot!(ax, graph; kwargs...) +end + +end diff --git a/ext/TenetMakieExt.jl b/ext/TenetMakieExt.jl index b8f5aacbc..4cc2f29ca 100644 --- a/ext/TenetMakieExt.jl +++ b/ext/TenetMakieExt.jl @@ -1,134 +1,13 @@ module TenetMakieExt -using Tenet -using Tenet: AbstractTensorNetwork -using Combinatorics: combinations -using Graphs -using Makie - -using GraphMakie - -""" - plot(tn::TensorNetwork; kwargs...) - plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...) - plot!(ax::Union{Axis,Axis3}, tn::TensorNetwork; kwargs...) - -Plot a [`TensorNetwork`](@ref) as a graph. - -# Keyword Arguments - - - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`. - - The rest of `kwargs` are passed to `GraphMakie.graphplot`. -""" -function Makie.plot(@nospecialize tn::AbstractTensorNetwork; kwargs...) - f = Figure() - ax, p = plot!(f[1, 1], tn; kwargs...) - return Makie.FigureAxisPlot(f, ax, p) -end - -# NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable -__networklayout_dim(x) = typeof(x).super.parameters |> first - -function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::AbstractTensorNetwork; kwargs...) - ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 - Axis3(f[1, 1]) - else - ax = Axis(f[1, 1]) - ax.aspect = DataAspect() - ax +function __init__() + try + Base.require(Main, :GraphMakie) + catch + @warn """Package GraphMakie not found in current path. It is needed to plot `Tenet`s with `Makie`. + - Run `import Pkg; Pkg.add(\"GraphMakie\")` or `]add GraphMakie` to install the GraphMakie package, then restart julia. + """ end - - hidedecorations!(ax) - hidespines!(ax) - - p = plot!(ax, tn; kwargs...) - - return Makie.AxisPlot(ax, p) -end - -function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::AbstractTensorNetwork; labels = false, kwargs...) - hypermap = Tenet.hyperflatten(tn) - tn = transform(tn, Tenet.HyperindConverter) - - tensormap = IdDict(tensor => i for (i, tensor) in enumerate(tensors(tn))) - - # TODO how to mark multiedges? (i.e. parallel edges) - graph = SimpleGraph([ - Edge(map(Base.Fix1(getindex, tensormap), tensors)...) for (_, tensors) in tn.indexmap if length(tensors) > 1 - ]) - - # TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations - copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), keys(hypermap)), tensors(tn)) - ghostnodes = map(inds(tn, :open)) do index - # create new ghost node - add_vertex!(graph) - node = nv(graph) - - # connect ghost node - tensor = only(tn.indexmap[index]) - add_edge!(graph, node, tensormap[tensor]) - - return node - end - - # configure graphics - # TODO refactor hardcoded values into constants - kwargs = Dict{Symbol,Any}(kwargs) - - if haskey(kwargs, :node_size) - append!(kwargs[:node_size], zero(ghostnodes)) - else - kwargs[:node_size] = map(1:nv(graph)) do i - i ∈ ghostnodes ? 0 : max(15, log2(length(tensors(tn)[i]))) - end - end - - if haskey(kwargs, :node_marker) - append!(kwargs[:node_marker], fill(:circle, length(ghostnodes))) - else - kwargs[:node_marker] = map(i -> i ∈ copytensors ? :diamond : :circle, 1:nv(graph)) - end - - if haskey(kwargs, :node_color) - kwargs[:node_color] = vcat(kwargs[:node_color], fill(:black, length(ghostnodes))) - else - kwargs[:node_color] = map(1:nv(graph)) do v - v ∈ copytensors ? Makie.to_color(:black) : Makie.RGBf(240 // 256, 180 // 256, 100 // 256) - end - end - - get!(kwargs, :node_attr, (colormap = :viridis, strokewidth = 2.0, strokecolor = :black)) - - # configure labels - labels == true && get!(kwargs, :elabels) do - opentensors = findall(t -> !isdisjoint(inds(t), inds(tn, :open)), tensors(tn)) - opencounter = IdDict(tensor => 0 for tensor in opentensors) - - map(edges(graph)) do edge - # case: open edge - if any(∈(ghostnodes), [src(edge), dst(edge)]) - notghost = src(edge) ∈ ghostnodes ? dst(edge) : src(edge) - inds = Tenet.inds(tn, :open) ∩ Tenet.inds(tensors(tn)[notghost]) - opencounter[notghost] += 1 - return inds[opencounter[notghost]] |> string - end - - # case: hyperedge - if any(∈(copytensors), [src(edge), dst(edge)]) - i = src(edge) ∈ copytensors ? src(edge) : dst(edge) - # hyperindex = filter(p -> isdisjoint(inds(tensors)[i], p[2]), hypermap) |> only |> first - hyperindex = hypermap[Tenet.inds(tensors(tn)[i])] - return hyperindex |> string - end - - return join(Tenet.inds(tensors(tn)[src(edge)]) ∩ Tenet.inds(tensors(tn)[dst(edge)]), ',') - end - end - get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color) - get!(() -> repeat([17], ne(graph)), kwargs, :elabels_textsize) - - # plot graph - graphplot!(ax, graph; kwargs...) end end