From c757253649e6b87411ecc699088c0b1358a3d232 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 27 Jan 2025 19:04:03 -0800 Subject: [PATCH] Move updateMaxProducerPos to MaxPosCalculator --- csrc/ir/interface_nodes.h | 2 +- csrc/scheduler/tools/inlining.cpp | 66 ++++++++++++++++++++++++ csrc/scheduler/tools/inlining.h | 5 ++ csrc/tensor_view.cpp | 84 +++---------------------------- 4 files changed, 80 insertions(+), 77 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 98236fd3c5f..a60abf65e87 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -785,7 +785,7 @@ class NVF_API TensorView : public Val { // Update the max producer position of the current tensor. This is required // when we modify producer-consumer relationship of a scheduled tensor, for // example, grouping multiple reductions. - void updateMaxProducerPosition(); + void updateMaxProducerPosition(MaxPosCalculator* calc = nullptr); // Commit the current changes in loop domain into rFactor domain. This // function can be used to do implicit transpose and view, but today, only diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 6064f4e6e7c..3eeea51fe0d 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -329,6 +329,72 @@ size_t MaxPosCalculator::getMaxPosAll( return max_pos; } +// Try to find the aligned position on consumer's domain corresponding to a +// position of producer domain. No checking on actual +// producer-consumer relationship. +int64_t MaxPosCalculator::getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer, + int64_t producer_pos) { + // Locate consumer's position that aligns with + // the producer's position. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + int64_t consumer_pos = consumer->nDims(); + + const bool may_need_forwarding = + ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) && + ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer); + + if (may_need_forwarding) { + auto disjoint_sets = BestEffortReplay::replayPasC( + producer, + consumer, + -1, + PairwiseLogicalDomainMap(producer, consumer)) + .getIterDomainEquivalence(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. + + while (consumer_pos > 0) { + auto consumer_id = consumer->axis(consumer_pos - 1); + const auto& p_dom = producer->getLoopDomain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer_pos, + [&consumer_id, &disjoint_sets](IterDomain* p_id) { + return disjoint_sets.permissiveAreMapped(consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + } else { + while (consumer_pos > 0) { + auto consumer_id = consumer->axis(consumer_pos - 1); + const auto& p_dom = producer->getLoopDomain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer_pos, + [&](IterDomain* p_id) { + return inliningGraph().disjointValSets().strictAreMapped( + consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + } + + return consumer_pos; +} + void inlineMost(const std::unordered_set& uninlinable_ids) { inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids); } diff --git a/csrc/scheduler/tools/inlining.h b/csrc/scheduler/tools/inlining.h index 0c23bce1340..24eca2ee80b 100644 --- a/csrc/scheduler/tools/inlining.h +++ b/csrc/scheduler/tools/inlining.h @@ -71,6 +71,11 @@ class MaxPosCalculator { bool best_effort = false, bool check_siblings = true); + int64_t getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer, + int64_t producer_pos); + MaxPosCalculator( std::unordered_set uninlinable_ids = {}, bool compute_at_only = false); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index f04eee6a7d9..bf4fcd2f640 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -193,89 +193,21 @@ void TensorView::inlineAt( } for (auto consumer : ir_utils::consumerTvsOf(this)) { - consumer->updateMaxProducerPosition(); + consumer->updateMaxProducerPosition(calc); } } -namespace { - -// Try to find the aligned position on consumer's domain corresponding to a -// position of producer domain. No checking on actual -// producer-consumer relationship. -int64_t getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer, - int64_t producer_pos) { - // Locate consumer's position that aligns with - // the producer's position. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - int64_t consumer_pos = consumer->nDims(); - - const bool may_need_forwarding = - ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) && - ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer); - - if (may_need_forwarding) { - auto disjoint_sets = BestEffortReplay::replayPasC( - producer, - consumer, - -1, - PairwiseLogicalDomainMap(producer, consumer)) - .getIterDomainEquivalence(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. - - while (consumer_pos > 0) { - auto consumer_id = consumer->axis(consumer_pos - 1); - const auto& p_dom = producer->getLoopDomain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer_pos, - [&consumer_id, &disjoint_sets](IterDomain* p_id) { - return disjoint_sets.permissiveAreMapped(consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } - } else { - IdModel id_model({consumer->definition()}, {}, false); - id_model.buildBroadcastGraph(); - const auto& inlining_graph = id_model.idGraph(IdMappingMode::BROADCAST); - - while (consumer_pos > 0) { - auto consumer_id = consumer->axis(consumer_pos - 1); - const auto& p_dom = producer->getLoopDomain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer_pos, - [&consumer_id, &inlining_graph](IterDomain* p_id) { - return inlining_graph.disjointValSets().strictAreMapped( - consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } +void TensorView::updateMaxProducerPosition(MaxPosCalculator* calc) { + std::unique_ptr calc_owner; + if (calc == nullptr) { + calc_owner = std::make_unique(); + calc = calc_owner.get(); } - return consumer_pos; -} - -} // namespace - -void TensorView::updateMaxProducerPosition() { for (auto producer : ir_utils::producerTvsOf(this)) { max_producer_pos_ = std::max( max_producer_pos_, - getConsumerPosAlignedToProducerCA( + calc->getConsumerPosAlignedToProducerCA( this, producer, producer->getComputePosition(this))); } @@ -290,7 +222,7 @@ void TensorView::updateMaxProducerPosition() { if (producer->hasComputeWith() && !producer->hasResolvedComputeWith()) { maybe_max_producer_pos_ = std::max( maybe_max_producer_pos_, - getConsumerPosAlignedToProducerCA( + calc->getConsumerPosAlignedToProducerCA( this, producer, producer->getComputeWithPosition())); } }