Skip to content

Commit

Permalink
Fix dispatch of methods on Reactant with new eltype(::TracedRArray)
Browse files Browse the repository at this point in the history
… behavior
  • Loading branch information
mofeing committed Jan 7, 2025
1 parent b829a56 commit e030f88
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TenetReactantExt
using Tenet
using EinExprs
using Reactant
using Reactant: TracedRArray
using Reactant: TracedRArray, TracedRNumber
const MLIR = Reactant.MLIR
const stablehlo = MLIR.Dialects.stablehlo

Expand Down Expand Up @@ -122,8 +122,13 @@ function Reactant.set_act!(inp::Enzyme.Annotation{TensorNetwork}, path, reverse,
end

function Tenet.contract(
a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb,TracedRArray{Tb,Nb}}; dims=((inds(a), inds(b))), out=nothing
a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}, b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}}; kwargs...
) where {Ta,Na,Tb,Nb}
dims = get(kwargs, :dims) do
(inds(a), inds(b))
end
out = get(kwargs, :out, nothing)

ia, ib = collect(inds(a)), collect(inds(b))
@assert allunique(ia) "can't perform unary einsum operations on binary einsum"
@assert allunique(ib) "can't perform unary einsum operations on binary einsum"
Expand Down Expand Up @@ -167,12 +172,16 @@ function Tenet.contract(
return Tensor(data, ic)
end

function Tenet.contract(a::Tensor{T,N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing) where {T,N}
function Tenet.contract(
a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing
) where {T,N}
error("compilation of unary einsum operations are not yet supported")
end

Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{T,N,TracedRArray{T,N}}, b::Tensor; kwargs...) where {T,N}
function Tenet.contract(a::Tensor, b::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}; kwargs...) where {T,N}
contract(b, a; kwargs...)
end
function Tenet.contract(a::Tensor{TracedRNumber{T},N,TracedRArray{T,N}}, b::Tensor; kwargs...) where {T,N}
return contract(a, Tensor(Reactant.Ops.constant(parent(b)), inds(b)); kwargs...)
end

Expand Down

0 comments on commit e030f88

Please sign in to comment.