-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move getConsumerPosAlignedToProducerCA to MaxPosCalculator #3766
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,89 +193,21 @@ void TensorView::inlineAt( | |
} | ||
|
||
for (auto consumer : ir_utils::consumerTvsOf(this)) { | ||
consumer->updateMaxProducerPosition(); | ||
consumer->updateMaxProducerPosition(calc); | ||
} | ||
} | ||
|
||
namespace { | ||
|
||
// Try to find the aligned position on consumer's domain corresponding to a | ||
// position of producer domain. No checking on actual | ||
// producer-consumer relationship. | ||
int64_t getConsumerPosAlignedToProducerCA( | ||
TensorView* consumer, | ||
TensorView* producer, | ||
int64_t producer_pos) { | ||
// Locate consumer's position that aligns with | ||
// the producer's position. We need broadcast axes forwarded so we | ||
// need to replay PasC as CasP will not forward broadcast dims. For example | ||
// if we have: | ||
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) | ||
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will | ||
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to | ||
// NVFuserTest.FusionComplexBCast1_CUDA | ||
|
||
int64_t consumer_pos = consumer->nDims(); | ||
|
||
const bool may_need_forwarding = | ||
ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) && | ||
ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer); | ||
|
||
if (may_need_forwarding) { | ||
auto disjoint_sets = BestEffortReplay::replayPasC( | ||
producer, | ||
consumer, | ||
-1, | ||
PairwiseLogicalDomainMap(producer, consumer)) | ||
.getIterDomainEquivalence(); | ||
|
||
// Find the innermost position of consumer that has | ||
// been mapped within the producer ca axis. | ||
|
||
while (consumer_pos > 0) { | ||
auto consumer_id = consumer->axis(consumer_pos - 1); | ||
const auto& p_dom = producer->getLoopDomain(); | ||
if (std::any_of( | ||
p_dom.begin(), | ||
p_dom.begin() + producer_pos, | ||
[&consumer_id, &disjoint_sets](IterDomain* p_id) { | ||
return disjoint_sets.permissiveAreMapped(consumer_id, p_id); | ||
})) { | ||
break; | ||
} | ||
consumer_pos--; | ||
} | ||
} else { | ||
IdModel id_model({consumer->definition()}, {}, false); | ||
id_model.buildBroadcastGraph(); | ||
const auto& inlining_graph = id_model.idGraph(IdMappingMode::BROADCAST); | ||
Comment on lines -252 to -254
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This part is removed when moved to `MaxPosCalculator`. |
||
|
||
while (consumer_pos > 0) { | ||
auto consumer_id = consumer->axis(consumer_pos - 1); | ||
const auto& p_dom = producer->getLoopDomain(); | ||
if (std::any_of( | ||
p_dom.begin(), | ||
p_dom.begin() + producer_pos, | ||
[&consumer_id, &inlining_graph](IterDomain* p_id) { | ||
return inlining_graph.disjointValSets().strictAreMapped( | ||
consumer_id, p_id); | ||
})) { | ||
break; | ||
} | ||
consumer_pos--; | ||
} | ||
void TensorView::updateMaxProducerPosition(MaxPosCalculator* calc) { | ||
std::unique_ptr<MaxPosCalculator> calc_owner; | ||
if (calc == nullptr) { | ||
calc_owner = std::make_unique<MaxPosCalculator>(); | ||
calc = calc_owner.get(); | ||
} | ||
|
||
return consumer_pos; | ||
} | ||
|
||
} // namespace | ||
|
||
void TensorView::updateMaxProducerPosition() { | ||
for (auto producer : ir_utils::producerTvsOf(this)) { | ||
max_producer_pos_ = std::max( | ||
max_producer_pos_, | ||
getConsumerPosAlignedToProducerCA( | ||
calc->getConsumerPosAlignedToProducerCA( | ||
this, producer, producer->getComputePosition(this))); | ||
} | ||
|
||
|
@@ -290,7 +222,7 @@ void TensorView::updateMaxProducerPosition() { | |
if (producer->hasComputeWith() && !producer->hasResolvedComputeWith()) { | ||
maybe_max_producer_pos_ = std::max( | ||
maybe_max_producer_pos_, | ||
getConsumerPosAlignedToProducerCA( | ||
calc->getConsumerPosAlignedToProducerCA( | ||
this, producer, producer->getComputeWithPosition())); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is moved to `MaxPosCalculator`, except for one difference noted below.