diff --git a/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.cc b/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.cc index 2a87f4f5..5462d0a6 100644 --- a/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.cc +++ b/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.cc @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -27,6 +28,7 @@ #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -60,6 +62,7 @@ constexpr absl::string_view kGetTupleElementIndex = "get_tuple_element_index"; constexpr absl::string_view kUsers = "users"; constexpr absl::string_view kOperands = "operands"; constexpr absl::string_view kLiteral = "literal"; +constexpr absl::string_view kHide = "hide_node"; constexpr int kMaxUsersToRender = 16; @@ -67,6 +70,49 @@ constexpr int kMaxUsersToRender = 16; using OutputEdges = absl::flat_hash_map>; +// Detect if an instruction is an AsyncCollectiveFusion parameter that is +// implementation details. +bool IsAcfPrameter(const xla::HloInstruction* instruction) { + // Parameter is fused + if (instruction->opcode() != xla::HloOpcode::kParameter || + !instruction->IsFused()) + return false; + + // Parameter piped through and is only consumed by 1 user + // Parameter 0 consumed by both root and all-gather will always persist. + if (instruction->user_count() != 1) return false; + + const xla::HloComputation* parent_computation = instruction->parent(); + int64_t parameter_number = instruction->parameter_number(); + xla::HloInstruction* fusion_instruction = + parent_computation->FusionInstruction(); + const xla::HloInstruction* parameterOperand = + fusion_instruction->operand(parameter_number); + // Operand is get-tuple-element + if (parameterOperand->opcode() != xla::HloOpcode::kGetTupleElement) { + return false; + } + constexpr absl::string_view kAcfComputationName = "async_collective_fusion"; + constexpr absl::string_view kAcsInstructionName = "async-collective-start"; + constexpr absl::string_view kAcdInstructionName = "async-collective-done"; + + const xla::HloInstruction* gteOperand = parameterOperand->operand(0); + // Parameter is fused into AsyncCollectiveFusion, operand is gte from + // AsyncCollectiveStart and user is the root node of ACF + if (absl::StartsWith(parent_computation->name(), kAcfComputationName)) { + return absl::StartsWith(gteOperand->name(), kAcsInstructionName) && + instruction->users()[0] == parent_computation->root_instruction(); + } else if (absl::StartsWith(fusion_instruction->name(), + kAcdInstructionName)) { + // Parameter is fused into AsyncCollectiveDone, mapped from Params in + // AsyncCollectiveFusion - operand is gte from ACF + return absl::StartsWith( + gteOperand->fused_instructions_computation()->name(), + kAcfComputationName); + } + return false; +} + // Recursively include all instructions in the nested computations. void RecursiveIncludeNestedComputations( const xla::HloInstruction* instruction, @@ -311,6 +357,12 @@ void SetInstructionNodeAttributes(const xla::HloInstruction* instruction, xla::Cast(instruction)->HasLiteral()) { builder.AppendNodeAttribute(kLiteral, instruction->literal().ToString()); } + + if (options.hide_async_collective_fusion_parameter) { + if (IsAcfPrameter(instruction)) { + builder.AppendNodeAttribute(kHide, "1"); + } + } } absl::Status BuildHloInstructionNode( diff --git a/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.h b/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.h index 33c14c4d..0f3ae7aa 100644 --- a/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.h +++ b/src/builtin-adapter/hlo_adapter/direct_hlo_to_json_graph_convert.h @@ -36,6 +36,10 @@ struct HloAdapterOption { // "ROOT of fusion.0" -> "Parameter 0 of fusion.1, tuple element 1 of // fusion.1". bool get_tuple_element_folding = true; + // If a parameter node has input to async-collective-start and output to + // async-collective-done, mark it as implementation details and hide + // on visualization. + bool hide_async_collective_fusion_parameter = true; }; // Gets the instruction id.