Skip to content

Commit

Permalink
Clean up while loop dependency logic.
Browse files Browse the repository at this point in the history
- Move to ir_emission_utils and give it a better name
- Add unit tests
- Clarify dependency requirements (check for side effects)
- Fix accidental breakage of service/gpu/tests/gpu_copy_test.cc.
  • Loading branch information
jreiffers committed Feb 25, 2025
1 parent 5f1e20f commit c9f39f8
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 153 deletions.
167 changes: 19 additions & 148 deletions xla/backends/gpu/codegen/fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,145 +58,6 @@ bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) {
});
}

// Will include the instruction itself.
// TODO(jreiffers): there's probably something in
// xla/hlo/analysis/hlo_reachability.h that we can use.
void GetDependencies(const HloInstruction* instruction,
absl::flat_hash_set<const HloInstruction*>& result) {
if (result.insert(instruction).second) {
for (auto* operand : instruction->operands()) {
GetDependencies(operand, result);
}
}
}

// TODO(jreiffers): give this a better name.
struct WhileLoopSpec {
const HloInstruction* loop;
const HloInstruction* induction_var;
const HloInstruction* slice_arg;
};

// TODO(jreiffers): Move this to copy.cc?
std::optional<WhileLoopSpec> GetDefiningWhileLoop(
const IrEmitterCallStack& call_stack, const HloFusionInstruction* fusion,
const HloInstruction* parameter) {
if (call_stack.instructions().empty()) {
return std::nullopt;
}

VLOG(5) << "Looking for defining while loop of " << parameter->name()
<< " in " << fusion->name();

// Walk up the call stack, tracking the origin of `parameter`.
const HloInstruction* argument = parameter;
auto call_stack_it = call_stack.instructions().rbegin();
auto call_stack_end = call_stack.instructions().rend();
for (; call_stack_it != call_stack.instructions().rend() &&
argument->opcode() == HloOpcode::kParameter &&
((*call_stack_it)->opcode() == HloOpcode::kFusion ||
(*call_stack_it)->opcode() == HloOpcode::kAsyncStart ||
(*call_stack_it)->opcode() == HloOpcode::kCall);
++call_stack_it) {
argument = (*call_stack_it)->operand(argument->parameter_number());
}

if (call_stack_it == call_stack.instructions().rend()) {
return std::nullopt;
}

VLOG(5) << "Arrived at " << argument->name() << " in "
<< (*call_stack_it)->name();

// We should now be in a call (if command buffers are enabled) or a while.

// Find all the dependencies of the argument.
absl::flat_hash_set<const HloInstruction*> deps;
GetDependencies(argument, deps);

// Find a unique parameter and a gte.
const HloInstruction* unique_param = nullptr;
const HloInstruction* unique_gte = nullptr;

for (const auto* dep : deps) {
if (dep->opcode() == HloOpcode::kParameter) {
if (unique_param || !dep->shape().IsTuple()) {
VLOG(5) << "Found wrong parameters.";
return std::nullopt;
}
unique_param = dep;
}

if (dep->opcode() == HloOpcode::kGetTupleElement) {
if (unique_gte) {
VLOG(5) << "Found non-unique GTEs.";
return std::nullopt;
}
unique_gte = dep;
}
}

if (!unique_param || !unique_gte || unique_gte->operand(0) != unique_param) {
VLOG(5) << "Did not find a parameter or GTE or they don't match.";
return std::nullopt;
}

VLOG(5) << "Parameter and GTE: " << unique_param->name() << ", "
<< unique_gte->name();

// Continue walking up through call instructions.
while (call_stack_it != call_stack_end &&
(*call_stack_it)->opcode() == HloOpcode::kCall &&
unique_param->opcode() == HloOpcode::kParameter) {
unique_param = (*call_stack_it)->operand(unique_param->parameter_number());
++call_stack_it;
}

// Find the while loop for 'unique_param'.
auto while_instr_it = std::find_if(
call_stack_it, call_stack.instructions().rend(),
[&](const HloInstruction* instr) {
if (instr->opcode() != HloOpcode::kWhile) {
VLOG(5) << "Not a loop: " << instr->name();
return false;
}

// Verify that this GTE is the induction variable of the loop.
if (unique_param != instr->while_body()->parameter_instruction(0)) {
VLOG(5) << "Parameter mismatch: " << unique_param->name() << " vs "
<< instr->while_body()->parameter_instruction(0)->name();
VLOG(5) << instr->while_body()->ToString();
VLOG(5) << unique_param->parent()->ToString();
return false;
}

auto config = instr->backend_config<xla::WhileLoopBackendConfig>();
if (!config.ok()) {
VLOG(5) << "Loop has no WhileLoopBackendConfig.";
return false;
}
if (!config->has_known_trip_count() || !config->has_known_init_step() ||
!config->has_known_induction_variable()) {
VLOG(5) << "Loop has no known trip count, known init/step or no "
"known induction variable.";
return false;
}
if (unique_gte->tuple_index() !=
config->known_induction_variable().tuple_index()) {
VLOG(5) << "The offset does not depend on the induction variable.";
return false;
}
return true;
});
if (while_instr_it == call_stack.instructions().rend()) {
VLOG(5) << "Did not find a while loop.";
return std::nullopt;
}
VLOG(5) << "While loop for " << parameter->name() << " in " << fusion->name()
<< ": " << (*while_instr_it)->name();
return WhileLoopSpec{*while_instr_it, unique_gte, argument};
}

std::optional<DynamicMemcpyThunk::MemcpyDescriptor> GetDynamicMemcpyDescriptor(
const HloFusionAnalysis& analysis, const HloFusionInstruction* fusion,
const IrEmitterCallStack& call_stack) {
Expand Down Expand Up @@ -252,17 +113,27 @@ std::optional<DynamicMemcpyThunk::MemcpyDescriptor> GetDynamicMemcpyDescriptor(
continue;
}

auto loop = GetDefiningWhileLoop(call_stack, fusion, operand);
if (loop) {
VLOG(5) << "Offset for dimension " << i << " is dynamic.";
descriptor.src_dynamic_offsets.emplace_back() = {
loop->loop, loop->induction_var, loop->slice_arg, (*strides)[i]};
continue;
auto functional_dependency = ResolveFunctionalDependencyOnInductionVariable(
call_stack.instructions(), operand);
if (!functional_dependency) {
VLOG(5) << "Offset for dimension " << i << " is not statically known.";
return std::nullopt;
}

VLOG(5) << "Offset for dimension " << i
<< " is unsupported: " << operand->name();
return std::nullopt;
// The while loop must actually be a for loop.
auto loop_config = functional_dependency->loop
->backend_config<xla::WhileLoopBackendConfig>();
if (!loop_config.ok() || !loop_config->has_known_init_step() ||
!loop_config->has_known_trip_count()) {
VLOG(5) << "Offset for dimension " << i
<< " depends on loop with unknown behavior.";
return std::nullopt;
}

VLOG(5) << "Offset for dimension " << i << " is dynamic.";
descriptor.src_dynamic_offsets.emplace_back() = {
functional_dependency->loop, functional_dependency->induction_var,
functional_dependency->derived_value, (*strides)[i]};
}

return descriptor;
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/gpu/codegen/fusions.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class HloFusionInfo : public FusionInfo {
const HloFusionInstruction* instr,
const BufferAssignment* buffer_assignment,
const CallGraph& call_graph,
const IrEmitterCallStack& call_stack,)
const IrEmitterCallStack& call_stack)
: FusionInfo(analysis),
instr_(instr),
buffer_assignment_(buffer_assignment),
Expand Down
113 changes: 113 additions & 0 deletions xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -664,5 +664,118 @@ bool IsDynamicSliceFusion(const HloInstruction* instr) {
name == kDynamicSliceFusionWithDynamicAddressComputationConfigName;
}

std::optional<InductionVariableFunctionalDependency>
ResolveFunctionalDependencyOnInductionVariable(
absl::Span<const HloInstruction* const> call_stack,
const HloInstruction* parameter) {
if (call_stack.empty()) {
return std::nullopt;
}

VLOG(5) << "Looking for defining while loop of " << parameter->name();

// Walk up the call stack, tracking the origin of `parameter`.
const HloInstruction* argument = parameter;
auto call_stack_it = call_stack.rbegin();
auto call_stack_end = call_stack.rend();
for (; call_stack_it != call_stack_end &&
argument->opcode() == HloOpcode::kParameter &&
((*call_stack_it)->opcode() == HloOpcode::kFusion ||
(*call_stack_it)->opcode() == HloOpcode::kAsyncStart ||
(*call_stack_it)->opcode() == HloOpcode::kCall);
++call_stack_it) {
argument = (*call_stack_it)->operand(argument->parameter_number());
}

if (call_stack_it == call_stack_end) {
return std::nullopt;
}

VLOG(5) << "Arrived at " << argument->name() << " in "
<< (*call_stack_it)->name();

// Find a unique parameter and a gte in the transitive dependencies of
// `argument`.
const HloInstruction* unique_param = nullptr;
const HloInstruction* unique_gte = nullptr;
absl::flat_hash_set<const HloInstruction*> seen{argument};
std::queue<const HloInstruction*> queue;
queue.push(argument);
while (!queue.empty()) {
const auto* instruction = queue.front();
queue.pop();

if (instruction->opcode() == HloOpcode::kCustomCall ||
instruction->HasSideEffect()) {
VLOG(5) << "Found an unsafe operation.";
return std::nullopt;
}

if (instruction->opcode() == HloOpcode::kParameter) {
if (unique_param || !instruction->shape().IsTuple()) {
VLOG(5) << "Failed to match parameters.";
return std::nullopt;
}
unique_param = instruction;
}

if (instruction->opcode() == HloOpcode::kGetTupleElement) {
if (unique_gte) {
VLOG(5) << "Found non-unique GTEs.";
return std::nullopt;
}
unique_gte = instruction;
}

for (auto* operand : instruction->operands()) {
if (seen.insert(operand).second) {
queue.push(operand);
}
}
}

if (!unique_param || !unique_gte || unique_gte->operand(0) != unique_param) {
VLOG(5) << "Did not find a parameter or GTE or they don't match.";
return std::nullopt;
}

VLOG(5) << "Parameter and GTE: " << unique_param->name() << ", "
<< unique_gte->name();

// Continue walking up through call instructions.
while (call_stack_it != call_stack_end &&
(*call_stack_it)->opcode() == HloOpcode::kCall &&
unique_param->opcode() == HloOpcode::kParameter) {
unique_param = (*call_stack_it)->operand(unique_param->parameter_number());
++call_stack_it;
}

// Find the while loop for 'unique_param'.
auto while_instr_it = std::find_if(
call_stack_it, call_stack_end, [&](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kWhile &&
unique_param == instr->while_body()->parameter_instruction(0);
});

if (while_instr_it == call_stack_end) {
VLOG(5) << "Did not find a while loop.";
return std::nullopt;
}

auto config = (*while_instr_it)->backend_config<xla::WhileLoopBackendConfig>();
if (!config.ok() || !config->has_known_induction_variable() ||
unique_gte->tuple_index() !=
config->known_induction_variable().tuple_index()) {
VLOG(5) << "Failed to verify that the offset depends on the induction "
"variable.";
return std::nullopt;
}

VLOG(5) << "While loop for " << parameter->name() << ": "
<< (*while_instr_it)->name();
return InductionVariableFunctionalDependency{argument, *while_instr_it,
unique_gte};
}

} // namespace gpu
} // namespace xla
16 changes: 16 additions & 0 deletions xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,22 @@ absl::StatusOr<std::string> FingerprintWithBackendConfig(
", backend_config_fingerprint=", fingerprint);
}

struct InductionVariableFunctionalDependency {
// The value that is derived from the induction variable. This is guaranteed
// to have no other transitive dependencies (except constants).
const HloInstruction* derived_value;

// The loop and its induction variable that the value depends on.
const HloInstruction* loop;
const HloInstruction* induction_var;
};

// Checks if `parameter`'s value is a pure function of a while loop's induction
// variable.
std::optional<InductionVariableFunctionalDependency>
ResolveFunctionalDependencyOnInductionVariable(
absl::Span<const HloInstruction* const> call_stack, const HloInstruction* parameter);

} // namespace gpu
} // namespace xla

Expand Down
Loading

0 comments on commit c9f39f8

Please sign in to comment.