Skip to content

Commit

Permalink
Enable hiding ACF parameters node that are implementation details in …
Browse files Browse the repository at this point in the history
…the hlo graph.

PiperOrigin-RevId: 730725629
  • Loading branch information
zzzaries authored and copybara-github committed Mar 7, 2025
1 parent 076bccb commit 26baddb
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <sys/types.h>

#include <cstdint>
#include <deque>
#include <string>
#include <utility>
Expand All @@ -27,6 +28,7 @@
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -60,13 +62,57 @@ constexpr absl::string_view kGetTupleElementIndex = "get_tuple_element_index";
// NOTE(review): kUsers / kOperands look like attribute keys, but their uses
// are outside this view — confirm against the rest of the file.
constexpr absl::string_view kUsers = "users";
constexpr absl::string_view kOperands = "operands";
// Node attribute key under which a constant instruction's literal string is
// stored (see SetInstructionNodeAttributes).
constexpr absl::string_view kLiteral = "literal";
// Node attribute key set to "1" on nodes that should be hidden, i.e.
// AsyncCollectiveFusion parameters that are implementation details.
constexpr absl::string_view kHide = "hide_node";

// NOTE(review): presumably an upper bound on how many users are rendered for
// a node; the use site is not visible here — confirm.
constexpr int kMaxUsersToRender = 16;

// OutputEdges is a map from source instruction id to a list of its users.
using OutputEdges =
absl::flat_hash_map<std::string, std::vector<const xla::HloInstruction*>>;

// Detects whether `instruction` is an AsyncCollectiveFusion (ACF) parameter
// that is purely an implementation detail of the async collective pipeline.
//
// Returns true only when all of the following hold:
//   * `instruction` is a fused parameter with exactly one user;
//   * the fusion operand feeding this parameter is a get-tuple-element; and
//   * either
//       - the parameter lives in a computation whose name starts with
//         "async_collective_fusion", the get-tuple-element reads from an
//         instruction whose name starts with "async-collective-start", and
//         the parameter's single user is the computation's root; or
//       - the enclosing fusion instruction's name starts with
//         "async-collective-done" and the get-tuple-element reads from a
//         fusion whose fused computation name starts with
//         "async_collective_fusion".
// Returns false in every other case.
bool IsAcfPrameter(const xla::HloInstruction* instruction) {
  // Only parameters that live inside a fusion computation are candidates.
  if (instruction->opcode() != xla::HloOpcode::kParameter ||
      !instruction->IsFused()) {
    return false;
  }

  // Parameter piped through and is only consumed by 1 user.
  // Parameter 0 consumed by both root and all-gather will always persist.
  if (instruction->user_count() != 1) return false;

  const xla::HloComputation* parent_computation = instruction->parent();
  const int64_t parameter_number = instruction->parameter_number();
  const xla::HloInstruction* fusion_instruction =
      parent_computation->FusionInstruction();
  // The fusion operand at the same index is what feeds this fused parameter.
  const xla::HloInstruction* parameter_operand =
      fusion_instruction->operand(parameter_number);
  // Operand must be a get-tuple-element.
  if (parameter_operand->opcode() != xla::HloOpcode::kGetTupleElement) {
    return false;
  }

  constexpr absl::string_view kAcfComputationName = "async_collective_fusion";
  constexpr absl::string_view kAcsInstructionName = "async-collective-start";
  constexpr absl::string_view kAcdInstructionName = "async-collective-done";

  const xla::HloInstruction* gte_operand = parameter_operand->operand(0);

  // Parameter is fused into AsyncCollectiveFusion: operand is a gte from
  // AsyncCollectiveStart and the single user is the root node of the ACF.
  if (absl::StartsWith(parent_computation->name(), kAcfComputationName)) {
    return absl::StartsWith(gte_operand->name(), kAcsInstructionName) &&
           instruction->users()[0] == parent_computation->root_instruction();
  }

  // Parameter is fused into AsyncCollectiveDone, mapped from params in
  // AsyncCollectiveFusion: operand is a gte from the ACF.
  if (absl::StartsWith(fusion_instruction->name(), kAcdInstructionName)) {
    // fused_instructions_computation() is only valid on fusion instructions,
    // so check the opcode first instead of calling it on an arbitrary
    // producer of the tuple.
    return gte_operand->opcode() == xla::HloOpcode::kFusion &&
           absl::StartsWith(
               gte_operand->fused_instructions_computation()->name(),
               kAcfComputationName);
  }
  return false;
}

// Recursively include all instructions in the nested computations.
void RecursiveIncludeNestedComputations(
const xla::HloInstruction* instruction,
Expand Down Expand Up @@ -311,6 +357,12 @@ void SetInstructionNodeAttributes(const xla::HloInstruction* instruction,
xla::Cast<xla::HloConstantInstruction>(instruction)->HasLiteral()) {
builder.AppendNodeAttribute(kLiteral, instruction->literal().ToString());
}

if (options.hide_async_collective_fusion_parameter) {
if (IsAcfPrameter(instruction)) {
builder.AppendNodeAttribute(kHide, "1");
}
}
}

absl::Status BuildHloInstructionNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct HloAdapterOption {
// "ROOT of fusion.0" -> "Parameter 0 of fusion.1, tuple element 1 of
// fusion.1".
bool get_tuple_element_folding = true;
// If a parameter node has input to async-collective-start and output to
// async-collective-done, mark it as an implementation detail and hide it
// in the visualization.
bool hide_async_collective_fusion_parameter = true;
};

// Gets the instruction id.
Expand Down

0 comments on commit 26baddb

Please sign in to comment.