Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730725629
  • Loading branch information
zzzaries authored and copybara-github committed Feb 26, 2025
1 parent 5152fa5 commit 4983fcd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <sys/types.h>

#include <cstdint>
#include <deque>
#include <string>
#include <utility>
Expand Down Expand Up @@ -59,13 +60,43 @@ constexpr absl::string_view kGetTupleElementIndex = "get_tuple_element_index";
constexpr absl::string_view kUsers = "users";
constexpr absl::string_view kOperands = "operands";
constexpr absl::string_view kLiteral = "literal";
constexpr absl::string_view kHide = "hide_node";

constexpr int kMaxUsersToRender = 16;

// OutputEdges is a map from source instruction id to a list of its users.
using OutputEdges =
absl::flat_hash_map<std::string, std::vector<const xla::HloInstruction*>>;

// The current detection logic is best effort, could be improved if dedicated
// xla metadata is available.
bool IsAcfPrameter(const xla::HloInstruction* instruction) {
// Parameter is fused
if (instruction->opcode() != xla::HloOpcode::kParameter ||
!instruction->IsFused())
return false;

// Fused into ACF
int64_t parameter_number = instruction->parameter_number();
xla::HloInstruction* fusion_instruction =
instruction->parent()->FusionInstruction();
if (!instruction->parent()->name().starts_with("async_collective_fusion"))
return false;

// Input is gte from AsyncCollectiveStart
const xla::HloInstruction* parameterOperand =
fusion_instruction->operand(parameter_number);
if (parameterOperand->opcode() != xla::HloOpcode::kGetTupleElement) {
return false;
}
const xla::HloInstruction* gteOperand = parameterOperand->operand(0);
if (!gteOperand->name().starts_with("async-collective-start")) {
return false;
}
// Parameter only have one user
return instruction->user_count() == 1;
}

// Recursively include all instructions in the nested computations.
void RecursiveIncludeNestedComputations(
const xla::HloInstruction* instruction,
Expand Down Expand Up @@ -279,11 +310,25 @@ void SetInstructionNodeAttributes(const xla::HloInstruction* instruction,
}

// Attach get-tuple-element index if its define is a GTE and folded.
for (const xla::HloInstruction* operand : instruction->operands()) {
if (IsGetTupleElement(options, operand)) {
builder.AppendNodeAttribute(kGetTupleElementIndex,
absl::StrCat(operand->tuple_index(), " of ",
operand->operand(0)->name()));
if (options.get_tuple_element_folding) {
absl::flat_hash_map<std::string, std::vector<std::string>>
tuple_indexes_by_operand;
for (const xla::HloInstruction* operand : instruction->operands()) {
if (IsGetTupleElement(options, operand)) {
tuple_indexes_by_operand[operand->operand(0)->name()].push_back(
absl::StrCat(operand->tuple_index()));
}
}
std::string tuple_indexes_string;
for (const auto& [operand_name, tuple_indexes] : tuple_indexes_by_operand) {
if (!tuple_indexes_string.empty()) {
tuple_indexes_string += ";";
}
tuple_indexes_string +=
absl::StrCat(absl::StrJoin(tuple_indexes, ","), " of ", operand_name);
}
if (!tuple_indexes_string.empty()) {
builder.AppendNodeAttribute(kGetTupleElementIndex, tuple_indexes_string);
}
}

Expand All @@ -292,6 +337,12 @@ void SetInstructionNodeAttributes(const xla::HloInstruction* instruction,
xla::Cast<xla::HloConstantInstruction>(instruction)->HasLiteral()) {
builder.AppendNodeAttribute(kLiteral, instruction->literal().ToString());
}

if (options.hide_async_collective_fusion_parameter) {
if (IsAcfPrameter(instruction)) {
builder.AppendNodeAttribute(kHide, "1");
}
}
}

absl::Status BuildHloInstructionNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct HloAdapterOption {
// "ROOT of fusion.0" -> "Parameter 0 of fusion.1, tuple element 1 of
// fusion.1".
bool get_tuple_element_folding = true;
// If a parameter node has input to async-collective-start and output to
// async-collective-done, mark it as implementation details and hide
// on visualization.
bool hide_async_collective_fusion_parameter = true;
};

// Gets the instruction id.
Expand Down

0 comments on commit 4983fcd

Please sign in to comment.