Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable hiding ACF parameters node that are implementation details in the hlo graph. #310

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <sys/types.h>

#include <cstdint>
#include <deque>
#include <string>
#include <utility>
Expand All @@ -27,6 +28,7 @@
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -60,13 +62,57 @@ constexpr absl::string_view kGetTupleElementIndex = "get_tuple_element_index";
constexpr absl::string_view kUsers = "users";
constexpr absl::string_view kOperands = "operands";
constexpr absl::string_view kLiteral = "literal";
constexpr absl::string_view kHide = "hide_node";

constexpr int kMaxUsersToRender = 16;

// OutputEdges is a map from source instruction id to a list of its users.
using OutputEdges =
absl::flat_hash_map<std::string, std::vector<const xla::HloInstruction*>>;

// Detect if an instruction is an AsyncCollectiveFusion parameter that is
// implementation details.
bool IsAcfPrameter(const xla::HloInstruction* instruction) {
// Parameter is fused
if (instruction->opcode() != xla::HloOpcode::kParameter ||
!instruction->IsFused())
return false;

// Parameter piped through and is only consumed by 1 user
// Parameter 0 consumed by both root and all-gather will always persist.
if (instruction->user_count() != 1) return false;

const xla::HloComputation* parent_computation = instruction->parent();
int64_t parameter_number = instruction->parameter_number();
xla::HloInstruction* fusion_instruction =
parent_computation->FusionInstruction();
const xla::HloInstruction* parameterOperand =
fusion_instruction->operand(parameter_number);
// Operand is get-tuple-element
if (parameterOperand->opcode() != xla::HloOpcode::kGetTupleElement) {
return false;
}
constexpr absl::string_view kAcfComputationName = "async_collective_fusion";
constexpr absl::string_view kAcsInstructionName = "async-collective-start";
constexpr absl::string_view kAcdInstructionName = "async-collective-done";

const xla::HloInstruction* gteOperand = parameterOperand->operand(0);
// Parameter is fused into AsyncCollectiveFusion, operand is gte from
// AsyncCollectiveStart and user is the root node of ACF
if (absl::StartsWith(parent_computation->name(), kAcfComputationName)) {
return absl::StartsWith(gteOperand->name(), kAcsInstructionName) &&
instruction->users()[0] == parent_computation->root_instruction();
} else if (absl::StartsWith(fusion_instruction->name(),
kAcdInstructionName)) {
// Parameter is fused into AsyncCollectiveDone, mapped from Params in
// AsyncCollectiveFusion - operand is gte from ACF
return absl::StartsWith(
gteOperand->fused_instructions_computation()->name(),
kAcfComputationName);
}
return false;
}

// Recursively include all instructions in the nested computations.
void RecursiveIncludeNestedComputations(
const xla::HloInstruction* instruction,
Expand Down Expand Up @@ -311,6 +357,12 @@ void SetInstructionNodeAttributes(const xla::HloInstruction* instruction,
xla::Cast<xla::HloConstantInstruction>(instruction)->HasLiteral()) {
builder.AppendNodeAttribute(kLiteral, instruction->literal().ToString());
}

if (options.hide_async_collective_fusion_parameter) {
if (IsAcfPrameter(instruction)) {
builder.AppendNodeAttribute(kHide, "1");
}
}
}

absl::Status BuildHloInstructionNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct HloAdapterOption {
// "ROOT of fusion.0" -> "Parameter 0 of fusion.1, tuple element 1 of
// fusion.1".
bool get_tuple_element_folding = true;
// If a parameter node has input to async-collective-start and output to
// async-collective-done, mark it as implementation details and hide
// on visualization.
bool hide_async_collective_fusion_parameter = true;
};

// Gets the instruction id.
Expand Down