From a1baafa81e117bfb439a1fe38cc3ea820441b2c1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 4 Feb 2025 12:02:21 -0800 Subject: [PATCH] Move getConsumerPosAlignedToProducerCA to MaxPosCalculator (#3766) Moved `getConsumerPosAlignedToProducerCA` to `MaxPosCalculator` to reuse an `IdModel` built for the entire fusion. No output should be affected, but there are two main reasons for the change: one is efficiency and the other is functionality. - Previously, each call to `TensorView::updateMaxProducerPosition` created a (small) IdModel, which is inefficient since in common cases `updateMaxProducerPosition` is called from routines that already hold an IdModel through `MaxPosCalculator`. Moving the function to `MaxPosCalculator` lets it reuse the same IdModel. - A local IdModel that is created by just looking at a single expression may miss exact mappings previously registered through `Fusion::registerExactMapping`. For example: ``` t0: [i0] t1: [i1, i2] t2 = broadcast(t0, {false, true}); // t2: [i3, b4] t3 = t1 + t2 // [i5, i6] t4 = t3 + 1 // [i7, i8] ``` In this case, there are exact groups of `{i0, i1, i3, i5, i7}`, `{i2, i6, i8}` and `{b4}`. We could set the loop domain of `t2` to look like `t4` by `scheduleLoopDomainsBy({t2}, t4->getLoopDomain())`, which would clone `i8` and set the loop domain of `t2` as `{i3, i9}`, where `i9` is a clone of `i8` and is registered as exact mapped with `i8`. This should all be fine as long as we looked at the whole fusion, but if an IdModel was built only for `t3 = t1 + t2`, the loop domain of `t2`, `{i3, i9}`, wouldn't be grouped together with the loop domain of `t3`, `{i5, i6}`. It would discover the mapping of `i3` and `i5`, but it wouldn't know anything about the relationship between `i9` and `i6`. The registered exact mapping wouldn't help because it's registered with a pair of `i8` and `i9`. That `i8` is mapped with `i6` becomes apparent only when the IdModel analysis also looks at the `t4` definition. 
When an IdModel is created only for `t3 = t1 + t2`, the registered mapping would do nothing. This problem can be avoided by using a whole-fusion IdModel. --- csrc/ir/interface_nodes.h | 2 +- csrc/scheduler/tools/inlining.cpp | 66 ++++++++++++++++++++++++ csrc/scheduler/tools/inlining.h | 5 ++ csrc/tensor_view.cpp | 84 +++---------------------------- 4 files changed, 80 insertions(+), 77 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 98236fd3c5f..a60abf65e87 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -785,7 +785,7 @@ class NVF_API TensorView : public Val { // Update the max producer position of the current tensor. This is required // when we modify producer-consumer relationship of a scheduled tensor, for // example, grouping multiple reductions. - void updateMaxProducerPosition(); + void updateMaxProducerPosition(MaxPosCalculator* calc = nullptr); // Commit the current changes in loop domain into rFactor domain. This // function can be used to do implicit transpose and view, but today, only diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 6064f4e6e7c..3eeea51fe0d 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -329,6 +329,72 @@ size_t MaxPosCalculator::getMaxPosAll( return max_pos; } +// Try to find the aligned position on consumer's domain corresponding to a +// position of producer domain. No checking on actual +// producer-consumer relationship. +int64_t MaxPosCalculator::getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer, + int64_t producer_pos) { + // Locate consumer's position that aligns with + // the producer's position. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. 
For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + int64_t consumer_pos = consumer->nDims(); + + const bool may_need_forwarding = + ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) && + ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer); + + if (may_need_forwarding) { + auto disjoint_sets = BestEffortReplay::replayPasC( + producer, + consumer, + -1, + PairwiseLogicalDomainMap(producer, consumer)) + .getIterDomainEquivalence(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. + + while (consumer_pos > 0) { + auto consumer_id = consumer->axis(consumer_pos - 1); + const auto& p_dom = producer->getLoopDomain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer_pos, + [&consumer_id, &disjoint_sets](IterDomain* p_id) { + return disjoint_sets.permissiveAreMapped(consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + } else { + while (consumer_pos > 0) { + auto consumer_id = consumer->axis(consumer_pos - 1); + const auto& p_dom = producer->getLoopDomain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer_pos, + [&](IterDomain* p_id) { + return inliningGraph().disjointValSets().strictAreMapped( + consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + } + + return consumer_pos; +} + void inlineMost(const std::unordered_set& uninlinable_ids) { inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids); } diff --git a/csrc/scheduler/tools/inlining.h b/csrc/scheduler/tools/inlining.h index 0c23bce1340..24eca2ee80b 100644 --- a/csrc/scheduler/tools/inlining.h +++ b/csrc/scheduler/tools/inlining.h @@ -71,6 +71,11 @@ class MaxPosCalculator { bool best_effort = false, bool check_siblings = true); + 
int64_t getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer, + int64_t producer_pos); + MaxPosCalculator( std::unordered_set uninlinable_ids = {}, bool compute_at_only = false); diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 07bc474c8c7..6d83176706e 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -196,89 +196,21 @@ void TensorView::inlineAt( } for (auto consumer : ir_utils::consumerTvsOf(this)) { - consumer->updateMaxProducerPosition(); + consumer->updateMaxProducerPosition(calc); } } -namespace { - -// Try to find the aligned position on consumer's domain corresponding to a -// position of producer domain. No checking on actual -// producer-consumer relationship. -int64_t getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer, - int64_t producer_pos) { - // Locate consumer's position that aligns with - // the producer's position. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - int64_t consumer_pos = consumer->nDims(); - - const bool may_need_forwarding = - ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) && - ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer); - - if (may_need_forwarding) { - auto disjoint_sets = BestEffortReplay::replayPasC( - producer, - consumer, - -1, - PairwiseLogicalDomainMap(producer, consumer)) - .getIterDomainEquivalence(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. 
- - while (consumer_pos > 0) { - auto consumer_id = consumer->axis(consumer_pos - 1); - const auto& p_dom = producer->getLoopDomain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer_pos, - [&consumer_id, &disjoint_sets](IterDomain* p_id) { - return disjoint_sets.permissiveAreMapped(consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } - } else { - IdModel id_model({consumer->definition()}, {}, false); - id_model.buildBroadcastGraph(); - const auto& inlining_graph = id_model.idGraph(IdMappingMode::BROADCAST); - - while (consumer_pos > 0) { - auto consumer_id = consumer->axis(consumer_pos - 1); - const auto& p_dom = producer->getLoopDomain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer_pos, - [&consumer_id, &inlining_graph](IterDomain* p_id) { - return inlining_graph.disjointValSets().strictAreMapped( - consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } +void TensorView::updateMaxProducerPosition(MaxPosCalculator* calc) { + std::unique_ptr<MaxPosCalculator> calc_owner; + if (calc == nullptr) { + calc_owner = std::make_unique<MaxPosCalculator>(); + calc = calc_owner.get(); } - return consumer_pos; -} - -} // namespace - -void TensorView::updateMaxProducerPosition() { for (auto producer : ir_utils::producerTvsOf(this)) { max_producer_pos_ = std::max( max_producer_pos_, - getConsumerPosAlignedToProducerCA( + calc->getConsumerPosAlignedToProducerCA( this, producer, producer->getComputePosition(this))); } @@ -293,7 +225,7 @@ void TensorView::updateMaxProducerPosition() { if (producer->hasComputeWith() && !producer->hasResolvedComputeWith()) { maybe_max_producer_pos_ = std::max( maybe_max_producer_pos_, - getConsumerPosAlignedToProducerCA( + calc->getConsumerPosAlignedToProducerCA( this, producer, producer->getComputeWithPosition())); } }