Skip to content

Commit

Permalink
Reverts 3b24dc9
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 728078996
  • Loading branch information
Google-ML-Automation committed Feb 18, 2025
1 parent 37a4198 commit 1409698
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 206 deletions.
2 changes: 1 addition & 1 deletion xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ absl::Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
auto [it_ignored, inserted] =
hlo_properties_.emplace(hlo, std::move(current_properties_));
current_properties_ = Properties();
TF_RET_CHECK(inserted) << hlo->name() << " already exists in hlo_properties_";
TF_RET_CHECK(inserted);

return absl::OkStatus();
}
Expand Down
9 changes: 1 addition & 8 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ xla_cc_test(
"//xla/service:hlo_buffer",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/service/heap_simulator",
"//xla/service/heap_simulator:allocation_block",
"//xla/tests:test_utils",
Expand Down Expand Up @@ -185,7 +184,6 @@ cc_library(
"//xla/service:hlo_buffer",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -353,8 +351,8 @@ cc_library(
"//xla/hlo/utils:hlo_live_range",
"//xla/service:call_graph",
"//xla/service:hlo_buffer",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/service/heap_simulator",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
Expand All @@ -374,7 +372,6 @@ xla_cc_test(
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_cost_analysis",
"//xla/service/cost_modelling:op_cost",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -420,7 +417,6 @@ xla_cc_test(
"//xla/hlo/utils:hlo_live_range",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/service/heap_simulator",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
Expand Down Expand Up @@ -470,7 +466,6 @@ cc_library(
"//xla/hlo/utils:hlo_live_range",
"//xla/service:call_graph",
"//xla/service:hlo_cost_analysis",
"//xla/service/cost_modelling:op_cost",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -537,7 +532,6 @@ xla_cc_test(
"//xla/service:buffer_value",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -655,7 +649,6 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_value",
"//xla/service/cost_modelling:op_cost",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
Expand Down
21 changes: 12 additions & 9 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2211,14 +2211,16 @@ MsaAlgorithm::GetInefficientAllocationSites(
if (!allocation->is_copy_like_allocation()) {
const HloPosition& defining_position =
allocation->defining_position();
int64_t accessed = options_.cost_analysis->OutputBytesAccessed(
*defining_position.instruction, defining_position.index);
int64_t accessed =
options_.cost_analysis->base_costs().OutputBytesAccessed(
*defining_position.instruction, defining_position.index);
VLOG(3) << " pos: " << defining_position.ToString()
<< ", accessed: " << accessed << " / " << size;
}
for (const HloUse& use : allocation->uses()) {
int64_t accessed = options_.cost_analysis->OperandBytesAccessed(
*use.instruction, use.operand_number, use.operand_index);
int64_t accessed =
options_.cost_analysis->base_costs().OperandBytesAccessed(
*use.instruction, use.operand_number, use.operand_index);
VLOG(3) << " use: " << use.ToString() << ", accessed: " << accessed
<< " / " << size;
}
Expand Down Expand Up @@ -2246,14 +2248,15 @@ MsaAlgorithm::GetInefficientAllocationSites(
copy_bytes += size;
}
if (position_memory_space == MemorySpace::kAlternate) {
use_bytes += options_.cost_analysis->OutputBytesAccessed(
use_bytes += options_.cost_analysis->base_costs().OutputBytesAccessed(
*allocation->defining_position().instruction,
allocation->defining_position().index);
}
if (allocation->memory_space() == MemorySpace::kAlternate) {
for (const HloUse& use : allocation->uses()) {
use_bytes += options_.cost_analysis->OperandBytesAccessed(
*use.instruction, use.operand_number, use.operand_index);
use_bytes +=
options_.cost_analysis->base_costs().OperandBytesAccessed(
*use.instruction, use.operand_number, use.operand_index);
}
}
}
Expand Down Expand Up @@ -4566,10 +4569,10 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
<< options_.cost_analysis->GetAlternateMemoryBenefit(
request.use->hlo_use);
VLOG(3) << "Definition bytes accessed = "
<< options_.cost_analysis->OutputBytesAccessed(
<< options_.cost_analysis->base_costs().OutputBytesAccessed(
*defining_position.instruction, defining_position.index)
<< ", use bytes accessed = "
<< options_.cost_analysis->OperandBytesAccessed(
<< options_.cost_analysis->base_costs().OperandBytesAccessed(
*use.instruction, use.operand_number, use.operand_index);
}

Expand Down
71 changes: 46 additions & 25 deletions xla/service/memory_space_assignment/cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,52 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/service/call_graph.h"
#include "xla/service/cost_modelling/op_cost.h"
#include "xla/service/heap_simulator/heap_simulator.h"
#include "xla/service/hlo_buffer.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_value.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace memory_space_assignment {

HloCostAnalysisCosts::HloCostAnalysisCosts(
const HloCostAnalysis& hlo_cost_analysis)
: hlo_cost_analysis_(hlo_cost_analysis) {}

float HloCostAnalysisCosts::BytesAccessed(const HloInstruction& instruction) {
return static_cast<float>(hlo_cost_analysis_.bytes_accessed(instruction));
}

float HloCostAnalysisCosts::OperandBytesAccessed(
const HloInstruction& instruction, int64_t operand_num,
const ShapeIndex& shape_index) {
return static_cast<float>(hlo_cost_analysis_.operand_bytes_accessed(
instruction, operand_num, shape_index));
}

float HloCostAnalysisCosts::OutputBytesAccessed(
const HloInstruction& instruction, const ShapeIndex& shape_index) {
return static_cast<float>(
hlo_cost_analysis_.output_bytes_accessed(instruction, shape_index));
}

float HloCostAnalysisCosts::ComputeSeconds(const HloInstruction& instruction) {
return std::max(
std::max(
hlo_cost_analysis_.min_latency_seconds(HloCostAnalysis::kFlopsKey),
static_cast<float>(hlo_cost_analysis_.flop_count(instruction)) /
hlo_cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey)),
static_cast<float>(hlo_cost_analysis_.transcendental_count(instruction)) /
hlo_cost_analysis_.per_second_rate(
HloCostAnalysis::kTranscendentalsKey));
}

/*static*/ absl::StatusOr<std::unique_ptr<CostAnalysis>> CostAnalysis::Create(
OpCostManager& op_cost_manager, const CostAnalysisOptions& options,
BaseCosts& base_costs, const CostAnalysisOptions& options,
const HloModule& module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
Expand All @@ -54,7 +87,7 @@ namespace memory_space_assignment {
auto call_graph = CallGraph::Build(&module);
// Using `new` to access a non-public constructor.
return absl::WrapUnique(
new CostAnalysis(op_cost_manager, options, std::move(alias_analysis),
new CostAnalysis(base_costs, options, std::move(alias_analysis),
std::move(hlo_live_range), std::move(call_graph)));
}

Expand All @@ -71,18 +104,6 @@ double CostAnalysis::DefaultMemBandwidthBytesPerSecond(
return options_.default_mem_bandwidth_bytes_per_second;
}

float CostAnalysis::OperandBytesAccessed(const HloInstruction& instruction,
int64_t operand_num,
const ShapeIndex& shape_index) const {
return op_cost_manager_.OperandBytesAccessed(instruction, operand_num,
shape_index);
}

float CostAnalysis::OutputBytesAccessed(const HloInstruction& instruction,
const ShapeIndex& shape_index) const {
return op_cost_manager_.OutputBytesAccessed(instruction, shape_index);
}

float CostAnalysis::GetAlternateMemoryBenefit(
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
CostAnalysis::Cache* cache) const {
Expand Down Expand Up @@ -277,7 +298,7 @@ float CostAnalysis::GetDefaultMemoryAccessOverhead(
// = (window_size / bytes_accessed) * compute_elapsed
const float window_size_bytes =
options_.pipeline_overhead_window_size_mib * 1024 * 1024;
const float bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction);
const float bytes_accessed = base_costs_.BytesAccessed(instruction);
const float default_memory_bytes_accessed =
bytes_accessed -
GetBytesAccessedFromAlternateMemory(
Expand All @@ -297,7 +318,7 @@ float CostAnalysis::GetDefaultMemoryBandwidthIdleTime(
absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
const float default_memory_bytes_accessed =
op_cost_manager_.TotalBytesAccessed(instruction) -
base_costs_.BytesAccessed(instruction) -
GetBytesAccessedFromAlternateMemory(
instruction, operands_in_alternate_mem, outputs_in_alternate_mem);
const float elapsed_due_to_default_mem =
Expand All @@ -313,14 +334,14 @@ float CostAnalysis::GetBytesAccessedFromAlternateMemory(
absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
float bytes_accessed_from_alternate_mem = 0.0;
for (auto& operand : operands_in_alternate_mem) {
const float operand_bytes_accessed = op_cost_manager_.OperandBytesAccessed(
const float operand_bytes_accessed = base_costs_.OperandBytesAccessed(
instruction, operand.first, operand.second);
bytes_accessed_from_alternate_mem += operand_bytes_accessed;
}

for (auto& shape_idx : outputs_in_alternate_mem) {
const float output_bytes_accessed =
op_cost_manager_.OutputBytesAccessed(instruction, shape_idx);
base_costs_.OutputBytesAccessed(instruction, shape_idx);
bytes_accessed_from_alternate_mem += output_bytes_accessed;
}
return bytes_accessed_from_alternate_mem;
Expand Down Expand Up @@ -349,7 +370,7 @@ float CostAnalysis::GetInstructionElapsedDueToCompute(
if (ExcludeInstructionFromElapsed(instruction)) {
return 0.0f;
}
return op_cost_manager_.ComputeSeconds(instruction);
return base_costs_.ComputeSeconds(instruction);
}

float CostAnalysis::GetInstructionElapsedDueToMemory(
Expand All @@ -359,7 +380,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory(
if (ExcludeInstructionFromElapsed(instruction)) {
return 0.0f;
}
float total_bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction);
float total_bytes_accessed = base_costs_.BytesAccessed(instruction);
float bytes_accessed_from_alternate_mem = GetBytesAccessedFromAlternateMemory(
instruction, operands_in_alternate_mem, outputs_in_alternate_mem);
float elapsed_due_to_alternate_mem =
Expand All @@ -377,7 +398,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory(
if (ExcludeInstructionFromElapsed(instruction)) {
return 0.0f;
}
float total_bytes_accessed = op_cost_manager_.TotalBytesAccessed(instruction);
float total_bytes_accessed = base_costs_.BytesAccessed(instruction);
float bytes_accessed_from_alternate_mem = 0.0;
for (int operand_num = 0; operand_num < instruction.operand_count();
++operand_num) {
Expand All @@ -389,8 +410,8 @@ float CostAnalysis::GetInstructionElapsedDueToMemory(
}
if (is_in_alternate_mem(operand_num, index, subshape)) {
bytes_accessed_from_alternate_mem +=
op_cost_manager_.OperandBytesAccessed(instruction, operand_num,
index);
base_costs_.OperandBytesAccessed(instruction, operand_num,
index);
}
});
}
Expand All @@ -401,7 +422,7 @@ float CostAnalysis::GetInstructionElapsedDueToMemory(
}
if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) {
bytes_accessed_from_alternate_mem +=
op_cost_manager_.OutputBytesAccessed(instruction, index);
base_costs_.OutputBytesAccessed(instruction, index);
}
});
float elapsed_due_to_alternate_mem =
Expand Down
Loading

0 comments on commit 1409698

Please sign in to comment.