Skip to content

Commit

Permalink
Fix Reactant.traced_type_inner implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 24, 2025
1 parent b8b4aa2 commit 393f727
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@ const stablehlo = MLIR.Dialects.stablehlo

const Enzyme = Reactant.Enzyme

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tensor}), seen, mode, @nospecialize(track_numbers)
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
)
A_traced = Reactant.traced_type_inner(Tenet.parenttype(T), seen, mode, track_numbers)
return Tensor{eltype(T),ndims(T),A_traced}
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
T = eltype(A_traced)
N = ndims(TT)
return Tensor{T,N,A_traced}
end

# we specify `mode` and `track_numbers` types due to ambiguity
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}), seen, mode, @nospecialize(track_numbers)
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
seen,
mode::Reactant.TraceMode,
@nospecialize(track_numbers::Type)
)
return T
end
Expand Down Expand Up @@ -132,11 +139,15 @@ function Reactant.set_act!(inp::Enzyme.Annotation{TensorNetwork}, path, reverse,
end
end

function Tenet.contract(
Base.@nospecializeinfer @noinline function Tenet.contract(
@nospecialize(a::Tensor{TracedRNumber{Ta},Na,TracedRArray{Ta,Na}}),
@nospecialize(b::Tensor{TracedRNumber{Tb},Nb,TracedRArray{Tb,Nb}});
kwargs...,
) where {Ta,Na,Tb,Nb}
# return Base.inferencebarrier(__contract_binary)(a, b; kwargs...)
# end

# function __contract_binary(@nospecialize(a), @nospecialize(b); kwargs...)
dims = get(kwargs, :dims) do
(inds(a), inds(b))
end
Expand Down Expand Up @@ -171,8 +182,8 @@ function Tenet.contract(

# TODO replace for `Ops.convert`/`adapt` when it's available (there can be problems with nested array structures)
T = Base.promote_eltype(a, b)
da = eltype(a) != T ? TracedRArray{T,Na}(parent(a)) : parent(a)
db = eltype(b) != T ? TracedRArray{T,Nb}(parent(b)) : parent(b)
da = eltype(a) != T ? TracedRArray{T,ndims(a)}(parent(a)) : parent(a)
db = eltype(b) != T ? TracedRArray{T,ndims(b)}(parent(b)) : parent(b)

data = Reactant.Ops.dot_general(da, db; contracting_dimensions, batching_dimensions)

Expand Down

0 comments on commit 393f727

Please sign in to comment.