Skip to content

Commit

Permalink
[XLA:GPU] Operands couldn't be swapped for fusions with more than one…
Browse files Browse the repository at this point in the history
… parameter per operand.

The computation passed to `TritonFusionAnalysis` was not part of a module, and therefore all instruction ids were -1 (unique ids are assigned only to instructions that belong to a module).

That caused ConstHloInstructionSet to treat all instructions as identical.

PiperOrigin-RevId: 729040698
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Feb 20, 2025
1 parent 4ae7f87 commit 374ddc4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
4 changes: 2 additions & 2 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1732,14 +1732,14 @@ cc_library(
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"//xla/service/gpu:triton_fusion_analysis",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
12 changes: 6 additions & 6 deletions xla/service/gpu/transforms/gemm_fusion_swap_operands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
Expand All @@ -38,10 +38,10 @@ limitations under the License.
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -176,9 +176,9 @@ absl::StatusOr<bool> EmitterCanHandleSwappedOperands(
const HloInstruction* dot) {
auto tmp_module = HloModule("tmp", dot->parent()->parent()->config());
HloCloneContext clone_context(&tmp_module);
std::unique_ptr<HloComputation> cloned_computation =
dot->parent()->CloneInContext(clone_context);
TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation.get()));
HloComputation* cloned_computation = tmp_module.AddEntryComputation(
dot->parent()->CloneInContext(clone_context));
TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation));
return TritonFusionAnalysis::Execute(*cloned_computation).ok();
}

Expand Down
23 changes: 23 additions & 0 deletions xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,29 @@ ENTRY main {
EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value());
}

// Regression test for fusions whose dot operand is produced from more than
// one fusion parameter (here p1 and p2 feeding a concatenate). Previously the
// computation cloned for the swap check was not part of an HloModule, so every
// instruction's unique id was -1 and ConstHloInstructionSet treated distinct
// instructions as identical, blocking the swap. The pass should now report
// that it changed the module (operands were swapped).
TEST_F(SwapOperandsTest, MultipleParameterIsFine) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule MultipleParameterIsFine
fcomp {
p0 = bf16[8,1536]{1,0} parameter(0)
p1 = s8[1536,1536]{1,0} parameter(1)
p2 = s8[1536,1536]{1,0} parameter(2)
c1 = s8[1536,3072]{1,0} concatenate(s8[1536,1536]{1,0} p1, s8[1536,1536]{1,0} p2), dimensions={1}
c2 = bf16[1536,3072]{1,0} convert(s8[1536,3072]{1,0} c1)
ROOT %dot.2515 = bf16[8,3072]{1,0} dot(bf16[8,1536]{1,0} p0, bf16[1536,3072]{1,0} c2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY main {
p0 = bf16[8,1536]{1,0} parameter(0)
p1 = s8[1536,1536]{1,0} parameter(1)
p2 = s8[1536,1536]{1,0} parameter(2)
ROOT %micro_kernel = bf16[8,3072]{1,0} fusion(p0, p1, p2), kind=kCustom, calls=fcomp,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
})");
  // Run() returns true when the pass mutated the module, i.e. the swap
  // succeeded despite the multi-parameter operand.
  EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value());
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 374ddc4

Please sign in to comment.