Skip to content

Commit

Permalink
Reverts 6b75e60
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729462063
  • Loading branch information
Google-ML-Automation committed Feb 21, 2025
1 parent 4915ff4 commit 29142ed
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
5 changes: 3 additions & 2 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,10 @@ cc_library(
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_serialization",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/translate:stablehlo",
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/mlir/utils:error_util",
"//xla/mlir_hlo:mhlo_passes",
"//xla/mlir_hlo",
"//xla/mlir_hlo:all_passes",
"//xla/pjrt:mlir_to_hlo",
"//xla/pjrt:status_casters",
"//xla/service/llvm_ir:llvm_util",
Expand Down
30 changes: 24 additions & 6 deletions xla/python/mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <string>

#include "mhlo/transforms/passes.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -37,7 +38,8 @@ limitations under the License.
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "stablehlo/dialect/Serialization.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/translate/stablehlo.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/status_casters.h"
Expand Down Expand Up @@ -77,16 +79,32 @@ void EnablePrintBeforeAndAfter(mlir::PassManager& pm) {
pm.enableIRPrinting(print_before, print_after);
}

// Converts an XlaComputation to a StableHLO mlir::Module string.
// Converts an XlaComputation to an MHLO or StableHLO mlir::Module string.
// Exists for backwards compatibility.
// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules
// instead and delete this function.
absl::StatusOr<std::string> PyXlaComputationToMlirModule(
const XlaComputation& computation) {
const XlaComputation& computation, bool emit_stable_hlo) {
mlir::MLIRContext context;
if (VLOG_IS_ON(3)) context.disableMultithreading();
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ConvertHloToStablehlo(context, &computation.proto()));
mlir::OwningOpRef<mlir::ModuleOp> module =
llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&context));
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
context.appendDialectRegistry(registry);

TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, &computation.proto(),
/*import_all_computations=*/true));
mlir::PassManager pm(&context);
if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm);
if (emit_stable_hlo) {
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
}
if (!mlir::succeeded(pm.run(*module))) {
return tsl::errors::InvalidArgument("MHLO => StableHLO failed");
}
return PrintModule(*module);
}

Expand Down Expand Up @@ -184,7 +202,7 @@ void BuildMlirSubmodule(nb::module_& m) {

mlir_module.def("xla_computation_to_mlir_module",
xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule),
nb::arg("computation"));
nb::arg("computation"), nb::arg("emit_stable_hlo") = true);
mlir_module.def(
"mlir_module_to_xla_computation",
[](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) {
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
_version = 316

# Version number for MLIR:Python components.
mlir_api_version = 58
mlir_api_version = 57

xla_platform_names = {
'cpu': 'Host',
Expand Down
4 changes: 3 additions & 1 deletion xla/python/xla_extension/mlir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from typing import Union
from . import XlaComputation

def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ...
def xla_computation_to_mlir_module(
computation: XlaComputation, emit_stable_hlo: bool = ...
) -> str: ...
def mlir_module_to_xla_computation(
mlir_module: Union[bytes, str],
use_tuple_args: bool = ...,
Expand Down

0 comments on commit 29142ed

Please sign in to comment.