Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate with Mooncake #278

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Integrate with Mooncake #278

wants to merge 3 commits into from

Conversation

mofeing
Copy link
Member

@mofeing mofeing commented Jan 8, 2025

CC @willtebbutt

i'm having also a problem when trying to compute the gradient.

this code:

A = Tensor(rand(2,3), (:i,:j))
B = Tensor(rand(3,4), (:j,:k))

f(a,b) = sum(contract(a,b))

rule = build_rrule(Tuple{typeof(f), typeof(A), typeof(B)})
Mooncake.value_and_gradient!!(rule, contract, A, B)

gives this error:

ERROR: ArgumentError: signature of arguments, Tuple{Mooncake.CoDual{typeof(contract), NoFData}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}}, not equal to signature required by rule, Tuple{Mooncake.CoDual{typeof(f), NoFData}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}, Mooncake.CoDual{Tensor{Float64, 2, Matrix{Float64}}, Mooncake.FData{@NamedTuple{data::Matrix{Float64}, inds::Vector{NoTangent}}}}}.
Stacktrace:
 [1] __verify_sig(rule::Mooncake.DerivedRule{…}, fx::Tuple{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:27
 [2] __value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:81
 [3] value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Function, ::Tensor{…}, ::Tensor{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:145
 [4] top-level scope
   @ REPL[25]:1
Some type information was truncated. Use `show(err)` to see complete types.

@willtebbutt
Copy link

gives this error:

Is the problem just that f should be passed to value_and_gradient!!, rather than contract?

@mofeing
Copy link
Member Author

mofeing commented Jan 8, 2025

ah right 🤦

well, that uncovers another problem

julia> Mooncake.value_and_gradient!!(rule, f, A, B)
ERROR: MethodError: no method matching (::TenetChainRulesCoreExt.var"#contract_pullback#61"{…})(::ChainRulesCore.Tangent{…})
The function `contract_pullback` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::ChainRulesCore.AbstractThunk)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:117
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::Tensor)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:110
  (::TenetChainRulesCoreExt.var"#contract_pullback#61")(::AbstractVector)
   @ TenetChainRulesCoreExt ~/Developer/Tenet.jl/ext/TenetChainRulesCoreExt/rrules.jl:116
  ...

Stacktrace:
 [1] (::Mooncake.var"#pb!!#291"{…})(y_rdata::NoRData)
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/tools_for_rules.jl:295
 [2] (::Mooncake.RRuleWrapperPb{Mooncake.var"#pb!!#291"{…}, Mooncake.LazyZeroRData{…}})(dy::NoRData)
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:299
 [3] __run_rvs_pass!(::Type, ::Type{…}, ::Mooncake.RRuleWrapperPb{…}, ::Base.RefValue{…}, ::Nothing, ::Vararg{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:820
 [4] f
   @ ./REPL[42]:1 [inlined]
 [5] (::Tuple{Mooncake.Stack{…}, Base.RefValue{…}, Mooncake.RRuleZeroWrapper{…}, Mooncake.Stack{…}})(none::Any)
   @ Base.Experimental ./<missing>:0
 [6] Pullback
   @ ~/.julia/packages/Mooncake/oBjQd/src/interpreter/s2s_reverse_mode_ad.jl:855 [inlined]
 [7] __value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:85
 [8] value_and_gradient!!(::Mooncake.DerivedRule{…}, ::Function, ::Tensor{…}, ::Tensor{…})
   @ Mooncake ~/.julia/packages/Mooncake/oBjQd/src/interface.jl:145
 [9] top-level scope
   @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.

i probably need to add a method to contract_pullback in here

function contract_pullback(c̄::Tensor)
= @thunk proj_a(contract(c̄, conj(b); out=inds(a)))
= @thunk proj_b(contract(conj(a), c̄; out=inds(b)))
return (NoTangent(), ā, b̄)
end
contract_pullback(c̄::AbstractArray) = contract_pullback(Tensor(c̄, inds(c)))
contract_pullback(c̄::AbstractVector) = contract_pullback(Tensor(c̄, inds(c)))
contract_pullback(c̄::AbstractThunk) = contract_pullback(unthunk(c̄))
but i'm not sure what type of object is Mooncake passing to the ChainRules callbacks because it's an error i haven't found before with Zygote and ChainRulesTestUtils

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants