Move getConsumerPosAlignedToProducerCA to MaxPosCalculator (#3766)
Moved `getConsumerPosAlignedToProducerCA` to `MaxPosCalculator` to reuse
an `IdModel` built for the entire fusion.

No output should be affected. There are two main reasons for the
change: efficiency and functionality.

- Previously, each call to `TensorView::updateMaxProducerPosition`
created a (small) IdModel, which is inefficient since in common cases
`updateMaxProducerPosition` is called from routines that already hold
an IdModel through `MaxPosCalculator`. Moving the function to
`MaxPosCalculator` lets it reuse that IdModel.
- A local IdModel that is created by just looking at a single expression
may miss exact mappings previously registered through
`Fusion::registerExactMapping`. For example:

```
t0: [i0]
t1: [i1, i2]

t2 = broadcast(t0, {false, true});
// t2: [i3, b4]

t3 = t1 + t2
// [i5, i6]

t4 = t3 + 1
// [i7, i8]
```

In this case, there are exact groups of `{i0, i1, i3, i5, i7}`, `{i2,
i6, i8}` and `{b4}`. We could make the loop domain of `t2` look like
that of `t4` by `scheduleLoopDomainsBy({t2}, t4->getLoopDomain())`,
which would clone `i8` and set the loop domain of `t2` to `{i3, i9}`,
where `i9` is a clone of `i8` and is registered as exact mapped with
`i8`. This is all fine as long as we look at the whole fusion, but if an
IdModel is built only for `t3 = t1 + t2`, the loop domain of `t2`,
`{i3, i9}`, isn't grouped together with the loop domain of `t3`,
`{i5, i6}`. The IdModel would discover the mapping of `i3` and `i5`, but
it wouldn't know anything about the relation between `i9` and `i6`. The
registered exact mapping doesn't help because it is registered for the
pair of `i8` and `i9`. That `i8` is mapped with `i6` becomes apparent
only when the IdModel analysis also looks at the definition of `t4`.
When an IdModel is created only for `t3 = t1 + t2`, the registered
mapping does nothing. This problem can be avoided by using a
whole-fusion IdModel.
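
The example above can be reproduced with a toy model. The following is a minimal sketch, not nvFuser code: `ToyExactGraph`, `buildGraph` and the two query functions are hypothetical names, and the exact graph is approximated by a plain union-find over IterDomain names. The second axis of `t3` is called `i6` here so that it is distinct from `i7`, the first axis of `t4`. The registered pair `(i8, i9)` is added in both cases, yet only the graph that also sees `t4 = t3 + 1` can connect `i9` with `i6`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an exact graph: a union-find over IterDomain names.
// Illustration only; this is not nvFuser's IdModel.
class ToyExactGraph {
 public:
  void mapIds(const std::string& a, const std::string& b) {
    const std::string root_a = find(a);
    const std::string root_b = find(b);
    parent_[root_a] = root_b;
  }

  bool areMapped(const std::string& a, const std::string& b) {
    return find(a) == find(b);
  }

 private:
  std::string find(const std::string& id) {
    auto it = parent_.find(id);
    if (it == parent_.end()) {
      parent_[id] = id;  // first time we see this ID: it is its own root
      return id;
    }
    if (it->second == id) {
      return id;
    }
    const std::string next = it->second;
    const std::string root = find(next);
    parent_[id] = root;  // path compression
    return root;
  }

  std::map<std::string, std::string> parent_;
};

using IdPair = std::pair<std::string, std::string>;

// Builds a graph from the mappings implied by the given expressions, then
// adds the registered pair (i8, i9), which is available in both scenarios.
ToyExactGraph buildGraph(const std::vector<IdPair>& expr_mappings) {
  ToyExactGraph graph;
  for (const IdPair& p : expr_mappings) {
    graph.mapIds(p.first, p.second);
  }
  graph.mapIds("i8", "i9");
  return graph;
}

// Graph built only from t3 = t1 + t2.
bool localGraphMapsI9WithI6() {
  ToyExactGraph graph =
      buildGraph({{"i1", "i5"}, {"i3", "i5"}, {"i2", "i6"}});
  return graph.areMapped("i9", "i6");
}

// Graph built from all three expressions of the fusion.
bool wholeFusionGraphMapsI9WithI6() {
  ToyExactGraph graph = buildGraph({
      {"i0", "i3"},                              // t2 = broadcast(t0)
      {"i1", "i5"}, {"i3", "i5"}, {"i2", "i6"},  // t3 = t1 + t2
      {"i5", "i7"}, {"i6", "i8"},                // t4 = t3 + 1
  });
  return graph.areMapped("i9", "i6");
}
```

In this sketch `localGraphMapsI9WithI6()` returns false and `wholeFusionGraphMapsI9WithI6()` returns true, which mirrors the functional reason for the change.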
naoyam authored Feb 4, 2025
1 parent 3ac19f0 commit a1baafa
Showing 4 changed files with 80 additions and 77 deletions.
2 changes: 1 addition & 1 deletion csrc/ir/interface_nodes.h
@@ -785,7 +785,7 @@ class NVF_API TensorView : public Val {
   // Update the max producer position of the current tensor. This is required
   // when we modify producer-consumer relationship of a scheduled tensor, for
   // example, grouping multiple reductions.
-  void updateMaxProducerPosition();
+  void updateMaxProducerPosition(MaxPosCalculator* calc = nullptr);

   // Commit the current changes in loop domain into rFactor domain. This
   // function can be used to do implicit transpose and view, but today, only
66 changes: 66 additions & 0 deletions csrc/scheduler/tools/inlining.cpp
@@ -329,6 +329,72 @@ size_t MaxPosCalculator::getMaxPosAll(
   return max_pos;
 }

+// Try to find the aligned position on consumer's domain corresponding to a
+// position of producer domain. No checking on actual
+// producer-consumer relationship.
+int64_t MaxPosCalculator::getConsumerPosAlignedToProducerCA(
+    TensorView* consumer,
+    TensorView* producer,
+    int64_t producer_pos) {
+  // Locate consumer's position that aligns with
+  // the producer's position. We need broadcast axes forwarded so we
+  // need to replay PasC as CasP will not forward braodcast dims. For example
+  // if we have:
+  // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
+  // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
+  // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
+  // NVFuserTest.FusionComplexBCast1_CUDA
+
+  int64_t consumer_pos = consumer->nDims();
+
+  const bool may_need_forwarding =
+      ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) &&
+      ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer);
+
+  if (may_need_forwarding) {
+    auto disjoint_sets = BestEffortReplay::replayPasC(
+                             producer,
+                             consumer,
+                             -1,
+                             PairwiseLogicalDomainMap(producer, consumer))
+                             .getIterDomainEquivalence();
+
+    // Find the innermost position of consumer that has
+    // been mapped within the producer ca axis.
+
+    while (consumer_pos > 0) {
+      auto consumer_id = consumer->axis(consumer_pos - 1);
+      const auto& p_dom = producer->getLoopDomain();
+      if (std::any_of(
+              p_dom.begin(),
+              p_dom.begin() + producer_pos,
+              [&consumer_id, &disjoint_sets](IterDomain* p_id) {
+                return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
+              })) {
+        break;
+      }
+      consumer_pos--;
+    }
+  } else {
+    while (consumer_pos > 0) {
+      auto consumer_id = consumer->axis(consumer_pos - 1);
+      const auto& p_dom = producer->getLoopDomain();
+      if (std::any_of(
+              p_dom.begin(),
+              p_dom.begin() + producer_pos,
+              [&](IterDomain* p_id) {
+                return inliningGraph().disjointValSets().strictAreMapped(
+                    consumer_id, p_id);
+              })) {
+        break;
+      }
+      consumer_pos--;
+    }
+  }
+
+  return consumer_pos;
+}
+
 void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
   inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids);
 }
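
Both branches of `getConsumerPosAlignedToProducerCA` share the same backward scan and differ only in which graph answers the "are these two IDs mapped" query. The following is a minimal sketch of that scan under simplifying assumptions: IterDomains are plain `int`s, the graph query is passed in as a predicate, and `alignedConsumerPos` and `demoAlignedPos` are hypothetical names that are not part of nvFuser.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Walk the consumer loop domain from the innermost axis outward and stop at
// the first axis that maps to any of the first `producer_pos` producer axes.
// Returns the position just after that axis, or 0 if nothing maps.
int64_t alignedConsumerPos(
    const std::vector<int>& consumer_loop,
    const std::vector<int>& producer_loop,
    int64_t producer_pos,
    const std::function<bool(int, int)>& are_mapped) {
  assert(producer_pos >= 0);
  assert(producer_pos <= static_cast<int64_t>(producer_loop.size()));

  int64_t consumer_pos = static_cast<int64_t>(consumer_loop.size());
  while (consumer_pos > 0) {
    const int consumer_id = consumer_loop[consumer_pos - 1];
    const bool mapped = std::any_of(
        producer_loop.begin(),
        producer_loop.begin() + producer_pos,
        [&](int producer_id) { return are_mapped(consumer_id, producer_id); });
    if (mapped) {
      break;
    }
    --consumer_pos;
  }
  return consumer_pos;
}

// Hypothetical demo setup: consumer axis (10 + k) maps to producer axis k.
int64_t demoAlignedPos(int64_t producer_pos) {
  const std::vector<int> consumer_loop = {10, 11, 12};
  const std::vector<int> producer_loop = {0, 1, 2};
  return alignedConsumerPos(
      consumer_loop,
      producer_loop,
      producer_pos,
      [](int consumer_id, int producer_id) {
        return consumer_id == producer_id + 10;
      });
}
```

With this setup, `demoAlignedPos(2)` yields 2: the innermost consumer axis maps only to the third producer axis, which lies outside the first two producer positions, so the scan moves one step outward before it finds a match.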
5 changes: 5 additions & 0 deletions csrc/scheduler/tools/inlining.h
@@ -71,6 +71,11 @@ class MaxPosCalculator {
       bool best_effort = false,
       bool check_siblings = true);

+  int64_t getConsumerPosAlignedToProducerCA(
+      TensorView* consumer,
+      TensorView* producer,
+      int64_t producer_pos);
+
   MaxPosCalculator(
       std::unordered_set<IterDomain*> uninlinable_ids = {},
       bool compute_at_only = false);
84 changes: 8 additions & 76 deletions csrc/tensor_view.cpp
@@ -196,89 +196,21 @@ void TensorView::inlineAt(
   }

   for (auto consumer : ir_utils::consumerTvsOf(this)) {
-    consumer->updateMaxProducerPosition();
+    consumer->updateMaxProducerPosition(calc);
   }
 }

-namespace {
-
-// Try to find the aligned position on consumer's domain corresponding to a
-// position of producer domain. No checking on actual
-// producer-consumer relationship.
-int64_t getConsumerPosAlignedToProducerCA(
-    TensorView* consumer,
-    TensorView* producer,
-    int64_t producer_pos) {
-  // Locate consumer's position that aligns with
-  // the producer's position. We need broadcast axes forwarded so we
-  // need to replay PasC as CasP will not forward braodcast dims. For example
-  // if we have:
-  // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
-  // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
-  // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
-  // NVFuserTest.FusionComplexBCast1_CUDA
-
-  int64_t consumer_pos = consumer->nDims();
-
-  const bool may_need_forwarding =
-      ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(producer) &&
-      ir_utils::isLoopDomainFullyDerivedFromLogicalDomain(consumer);
-
-  if (may_need_forwarding) {
-    auto disjoint_sets = BestEffortReplay::replayPasC(
-                             producer,
-                             consumer,
-                             -1,
-                             PairwiseLogicalDomainMap(producer, consumer))
-                             .getIterDomainEquivalence();
-
-    // Find the innermost position of consumer that has
-    // been mapped within the producer ca axis.
-
-    while (consumer_pos > 0) {
-      auto consumer_id = consumer->axis(consumer_pos - 1);
-      const auto& p_dom = producer->getLoopDomain();
-      if (std::any_of(
-              p_dom.begin(),
-              p_dom.begin() + producer_pos,
-              [&consumer_id, &disjoint_sets](IterDomain* p_id) {
-                return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
-              })) {
-        break;
-      }
-      consumer_pos--;
-    }
-  } else {
-    IdModel id_model({consumer->definition()}, {}, false);
-    id_model.buildBroadcastGraph();
-    const auto& inlining_graph = id_model.idGraph(IdMappingMode::BROADCAST);
-
-    while (consumer_pos > 0) {
-      auto consumer_id = consumer->axis(consumer_pos - 1);
-      const auto& p_dom = producer->getLoopDomain();
-      if (std::any_of(
-              p_dom.begin(),
-              p_dom.begin() + producer_pos,
-              [&consumer_id, &inlining_graph](IterDomain* p_id) {
-                return inlining_graph.disjointValSets().strictAreMapped(
-                    consumer_id, p_id);
-              })) {
-        break;
-      }
-      consumer_pos--;
-    }
+void TensorView::updateMaxProducerPosition(MaxPosCalculator* calc) {
+  std::unique_ptr<MaxPosCalculator> calc_owner;
+  if (calc == nullptr) {
+    calc_owner = std::make_unique<MaxPosCalculator>();
+    calc = calc_owner.get();
   }

-  return consumer_pos;
-}
-
-} // namespace
-
-void TensorView::updateMaxProducerPosition() {
   for (auto producer : ir_utils::producerTvsOf(this)) {
     max_producer_pos_ = std::max(
         max_producer_pos_,
-        getConsumerPosAlignedToProducerCA(
+        calc->getConsumerPosAlignedToProducerCA(
             this, producer, producer->getComputePosition(this)));
   }

@@ -293,7 +225,7 @@ void TensorView::updateMaxProducerPosition() {
     if (producer->hasComputeWith() && !producer->hasResolvedComputeWith()) {
       maybe_max_producer_pos_ = std::max(
           maybe_max_producer_pos_,
-          getConsumerPosAlignedToProducerCA(
+          calc->getConsumerPosAlignedToProducerCA(
               this, producer, producer->getComputeWithPosition()));
     }
   }
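
The new `updateMaxProducerPosition` accepts an optional, non-owning calculator and builds a temporary one only when the caller passes none. The following is a minimal standalone sketch of that pattern, assuming a hypothetical `ToyCalculator` that counts its own constructions so the reuse is observable. None of the names below are nvFuser APIs.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for an expensive-to-build helper.
struct ToyCalculator {
  static int& constructed() {
    static int count = 0;
    return count;
  }
  ToyCalculator() {
    ++constructed();
  }
  int query(int x) const {
    return x;
  }
};

// Borrow the caller's calculator when one is given; otherwise create a
// temporary one whose lifetime ends with this call.
int updateWithOptionalCalc(int x, ToyCalculator* calc = nullptr) {
  std::unique_ptr<ToyCalculator> calc_owner;
  if (calc == nullptr) {
    calc_owner = std::make_unique<ToyCalculator>();
    calc = calc_owner.get();
  }
  assert(calc != nullptr);
  return calc->query(x);
}

// Number of calculators built when all calls share one instance.
int constructionsForSharedCalls(int n_calls) {
  const int before = ToyCalculator::constructed();
  ToyCalculator shared;
  for (int i = 0; i < n_calls; ++i) {
    updateWithOptionalCalc(i, &shared);
  }
  return ToyCalculator::constructed() - before;
}

// Number of calculators built when every call falls back to its own.
int constructionsForStandaloneCalls(int n_calls) {
  const int before = ToyCalculator::constructed();
  for (int i = 0; i < n_calls; ++i) {
    updateWithOptionalCalc(i);
  }
  return ToyCalculator::constructed() - before;
}
```

Five shared calls build one calculator, while five standalone calls build five. Keeping the default argument at `nullptr` means existing call sites compile unchanged and simply take the fallback path.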
