From f0453cf4773f0cc1a24455cb4d42d68673dded52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 8 Jan 2025 13:51:08 +0100 Subject: [PATCH 1/3] Init integration with Mooncake --- Project.toml | 2 ++ ext/TenetMooncakeExt.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 ext/TenetMooncakeExt.jl diff --git a/Project.toml b/Project.toml index 8b4c36012..4e547c423 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" @@ -49,6 +50,7 @@ TenetITensorMPSExt = ["ITensors", "ITensorMPS"] TenetITensorNetworksExt = "ITensorNetworks" TenetITensorsExt = "ITensors" TenetKrylovKitExt = ["KrylovKit"] +TenetMooncakeExt = "Mooncake" TenetPythonCallExt = "PythonCall" TenetQuacExt = "Quac" TenetReactantExt = "Reactant" diff --git a/ext/TenetMooncakeExt.jl b/ext/TenetMooncakeExt.jl new file mode 100644 index 000000000..503334a63 --- /dev/null +++ b/ext/TenetMooncakeExt.jl @@ -0,0 +1,14 @@ +module TenetMooncakeExt + +using Tenet +using Mooncake: Mooncake, @from_rrule, DefaultCtx + +@from_rrule DefaultCtx Tuple{typeof(contract),Tensor} +@from_rrule DefaultCtx Tuple{typeof(contract),Tensor,Tensor} + +# function Mooncake.tangent(fdata::Mooncake.FData{<:NamedTuple{(:data, :inds)}}, tensor::Tensor) +# @show fdata +# @show tensor +# end + +end # module From 4e1a9d993ede7d305dfc782070b9dceb4b0483f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 8 Jan 2025 14:42:14 +0100 Subject: [PATCH 2/3] Mark that imported `rrules` have kwargs --- ext/TenetMooncakeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TenetMooncakeExt.jl b/ext/TenetMooncakeExt.jl index 503334a63..f52aef560 100644 --- a/ext/TenetMooncakeExt.jl +++ b/ext/TenetMooncakeExt.jl @@ -3,8 +3,8 @@ module TenetMooncakeExt using Tenet using Mooncake: Mooncake, @from_rrule, DefaultCtx -@from_rrule DefaultCtx Tuple{typeof(contract),Tensor} -@from_rrule DefaultCtx Tuple{typeof(contract),Tensor,Tensor} +@from_rrule DefaultCtx Tuple{typeof(contract),Tensor} true +@from_rrule DefaultCtx Tuple{typeof(contract),Tensor,Tensor} true # function Mooncake.tangent(fdata::Mooncake.FData{<:NamedTuple{(:data, :inds)}}, tensor::Tensor) # @show fdata From 2c4c468c15705eb1427d94290ca3ba8b57e06b45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 8 Jan 2025 15:51:00 +0100 Subject: [PATCH 3/3] Implement `to_cr_tangent` for `Tensor` type --- ext/TenetMooncakeExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/TenetMooncakeExt.jl b/ext/TenetMooncakeExt.jl index f52aef560..78e08e975 100644 --- a/ext/TenetMooncakeExt.jl +++ b/ext/TenetMooncakeExt.jl @@ -3,6 +3,8 @@ module TenetMooncakeExt using Tenet using Mooncake: Mooncake, @from_rrule, DefaultCtx +Mooncake.to_cr_tangent(tensor::Tensor) = tensor + @from_rrule DefaultCtx Tuple{typeof(contract),Tensor} true @from_rrule DefaultCtx Tuple{typeof(contract),Tensor,Tensor} true