Skip to content

Commit

Permalink
resize scheduler fix
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Jan 28, 2025
1 parent 4a87662 commit 34a57d5
Showing 1 changed file with 54 additions and 4 deletions.
58 changes: 54 additions & 4 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,13 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// which has the repeat ID, the full loop domain is propagated only
// to the post-repeat group. For the pre-repeat group, the repeat ID
// is dropped and only the remaining loop domain is propagated.

// Divide all tvs into the pre-repeat and post-repeat groups
std::vector<TensorView*> post_repeat_tvs;
std::vector<TensorView*> pre_repeat_tvs;
if (repeat_id_moved_to_outermost) {
// Divide all tvs into the pre-repeat and post-repeat groups
auto all_tvs = fusion->allTvs();
std::vector<TensorView*> post_repeat_tvs;
post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size());
std::vector<TensorView*> pre_repeat_tvs;
pre_repeat_tvs.reserve(
all_tvs.size() - static_repeat_info->repeat_tvs.size());
for (auto tv : all_tvs) {
Expand All @@ -457,7 +458,9 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
pre_repeat_tvs.push_back(tv);
}
}
}

if (repeat_id_moved_to_outermost) {
// The repeat ID should be located at the outermost position
std::vector<IterDomain*> non_repeated_loop{
ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()};
Expand Down Expand Up @@ -517,7 +520,54 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
}
}

inlineMost();
if (!repeat_id_moved_to_outermost) {
inlineMost();
} else {
// Inline the pre-repeat tvs to the tensor that is an input to the
// repeat ops.
auto it = std::find_if(
static_repeat_info->repeat_tvs.begin(),
static_repeat_info->repeat_tvs.end(),
[](TensorView* tv) { return tv->definition()->isA<BroadcastOp>(); });
NVF_ERROR(it != static_repeat_info->repeat_tvs.end());
auto broadcast_inp = (*it)->definition()->input(0)->as<TensorView>();

inlineMost({broadcast_inp->axis(0)});

for (auto producer_tv_of_broadcast_inp :
ir_utils::producerTvsOf(broadcast_inp)) {
int64_t pos = producer_tv_of_broadcast_inp->axis(-1)->getParallelType() ==
ParallelType::Vectorize
? -2
: -1;
producer_tv_of_broadcast_inp->inlineAt(pos);
}

// The loop domain of the broadcast_inp should be exact mapped
// with the ref tensor minus the outermost repeat ID. Since the
// inlining is blocked there, the parallel types need to be
// explicitly set
NVF_ERROR(
ref_tv->getLoopDomain().size() ==
broadcast_inp->getLoopDomain().size() + 1);
for (const auto i : c10::irange(broadcast_inp->getLoopDomain().size())) {
auto ref_tv_ptype = ref_tv->getLoopDomain().at(i + 1)->getParallelType();
if (ref_tv_ptype == ParallelType::Vectorize) {
continue;
}
broadcast_inp->getLoopDomain().at(i)->parallelize(ref_tv_ptype);
}
}

fusion->printMath();

std::cerr << "Final\n";
for (auto tv : fusion->allTvs()) {
std::cerr << tv->toString();
for (auto expr : tv->domain()->allExprs()) {
std::cerr << expr->toString();
}
}

markAliases(fusion);
}
Expand Down

0 comments on commit 34a57d5

Please sign in to comment.