Commit
remove some dlpack logic
yitongh committed Jan 7, 2025
1 parent 5182324 commit b13b136
Showing 4 changed files with 21 additions and 38 deletions.
32 changes: 8 additions & 24 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -149,8 +149,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
                                   target_device: torch.device) -> tuple:
   assert target_device, "Moving tensors to None device not supported"

-  device_id = None
-
   moved_tensors = []
   for tensor in tensors:
     if not isinstance(tensor, torch.Tensor):
@@ -161,19 +159,15 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       moved_tensors.append(tensor)
       continue

-    # if dynamo_debug:
-    #   print("Moving Tensor {} to device {}".format(tensor, target_device))
+    if dynamo_debug:
+      print("Moving Tensor {} to device {}".format(tensor, target_device))

     zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
     if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
       # If the input cuda tensor requires gradient, we need to call detach. Otherwise, we'd get the error "RuntimeError: Can't export tensors that require gradient, use tensor.detach()"
-      device_type, device_id = tensor.__dlpack_device__()
       moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
     elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
       moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
-      # HACK: The `torch_xla._XLAC._get_stream_for_cuda_device` requires a local device index, while the device index for xla tensors is always 0.
-      # Meanwhile, dlpack uses the actual device index, so we use the device index of the converted CUDA tensor.
-      device_id = moved_tensor.device.index
     else:
       # Have to move to CPU before moving it to target device.
       cpu_device: torch.device = torch.device("cpu")
@@ -185,17 +179,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
     moved_tensor.requires_grad = tensor.requires_grad
     moved_tensors.append(moved_tensor)

-  if zero_copy_enabled and device_id is not None:
-    stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
-    stream = 1 if stream == 0 else stream
-    assert stream is None or type(stream) is int
-    external_stream = torch.cuda.ExternalStream(stream)
-    current_stream = torch.cuda.current_stream()
-    if external_stream != current_stream:
-      event = torch.cuda.Event()
-      event.record(current_stream)
-      external_stream.wait_event(event)
-
   return tuple(moved_tensors)
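
For orientation, here is a minimal sketch (not part of the commit) of the zero-copy CUDA-to-XLA path that remains after this change. It assumes a CUDA-enabled torch_xla build with the ZERO_COPY_ENABLED knob from the hunk above turned on; move_cuda_to_xla_zero_copy is a hypothetical name.

import torch
import torch_xla.utils.dlpack as torch_xla_dlpack

def move_cuda_to_xla_zero_copy(tensor: torch.Tensor) -> torch.Tensor:
  # DLPack cannot export tensors that require grad, hence the detach;
  # requires_grad is then restored on the result, as the bridge code does.
  assert tensor.device.type == 'cuda'
  moved = torch_xla_dlpack.from_dlpack(tensor.detach())
  moved.requires_grad = tensor.requires_grad
  return moved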


@@ -566,12 +549,14 @@ def optimized_mod(*args: tuple):
     nonlocal sym_constants_to_graph_vars
     nonlocal graph_hash

-
     original_device: torch.device = _get_input_arg_device(args)
     is_cuda_args: bool = False
     if original_device:
       is_cuda_args = original_device.type == "cuda"
-
+    else:
+      is_cuda_args = config.outside_on_cuda
+      if is_cuda_args:
+        original_device = torch.device(torch.cuda.current_device())

     # See [Note: Dynamo real-time input-shape cache look-up] above.
     xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
@@ -671,9 +656,7 @@ def optimized_mod(*args: tuple):

     none_remover.add_nones(result)

-    # TODO: better fix this, input is not cuda tensor, output is cuda tensor
     if is_cuda_args:
-      original_device = torch.device(torch.cuda.current_device())
       result = _maybe_move_tensors_to_device(tuple(result), original_device)

     if len(result) == 1:
@@ -904,7 +887,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         node.replace_all_uses_with(new_node)
         partitioned_graph.graph.erase_node(node)

-  XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
+  if config.outside_on_cuda:
+    XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
   partitioned_graph.recompile()

   return partitioned_graph
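
For intuition, a rough and purely hypothetical sketch of what a constructor-mover pass does over an FX graph; the real XLAConstructorMoverPass is more involved, this only illustrates the retargeting idea.

import torch
import torch.fx

def move_constructors_to_cuda(graph: torch.fx.Graph) -> None:
  # Hypothetical illustration, not the real pass: retarget tensor
  # constructors that were asked to allocate on XLA to CUDA instead.
  constructors = {torch.empty, torch.zeros, torch.ones, torch.full}
  for node in graph.nodes:
    if node.op == "call_function" and node.target in constructors:
      device = node.kwargs.get("device")
      if device is not None and torch.device(device).type == "xla":
        kwargs = dict(node.kwargs)
        kwargs["device"] = torch.device("cuda")
        node.kwargs = kwargs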
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/dl_convertor.cpp
@@ -138,7 +138,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     auto external_ref = pjrt_buffer->AcquireExternalReference();
     XLA_CHECK_OK(external_ref.status());
     pack->external_reference = std::move(external_ref.value());
-    // XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
+    XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
   }
   pack->buffer_reference = pjrt_buffer;
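
Re-enabling the Await means toDLPack only returns once the buffer's data is valid. A sketch of the consumer-side effect, assuming the CUDA PJRT plugin:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.utils.dlpack import to_dlpack

t = torch.ones(4, device=xm.xla_device())
capsule = to_dlpack(t)  # now blocks until the PJRT buffer is ready
cuda_view = torch.utils.dlpack.from_dlpack(capsule)  # safe to read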

2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient {
         xla::PjRtLocalDeviceId(local_device_id));
     XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
     absl::StatusOr<std::intptr_t> stream =
-        pjrt_device.value()->GetLocalComputeStream();
+        pjrt_device.value()->GetStreamForExternalReadyEvents();
     XLA_CHECK(stream.ok()) << "Failed to get a stream.";
     return stream.value();
   }
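
This stream is what the torch_xla._XLAC._get_stream_for_cuda_device binding used in dlpack.py below surfaces to Python. A hedged probe, assuming a CUDA device at local index 0:

import torch_xla

stream = torch_xla._XLAC._get_stream_for_cuda_device(0)
assert isinstance(stream, int)  # opaque CUDA stream handle value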
23 changes: 11 additions & 12 deletions torch_xla/utils/dlpack.py
@@ -15,8 +15,7 @@ def from_dlpack(ext_tensor: Any):
       ext_tensor, '__dlpack__'):
     device_type, device_id = ext_tensor.__dlpack_device__()
     if device_type == DLDeviceType.kDLGPU:
-      # stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
-      stream = None
+      stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
      dlpack = ext_tensor.__dlpack__(stream=stream)
     else:
       dlpack = ext_tensor.__dlpack__()
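
A minimal usage sketch of the restored path, assuming torch_xla runs on the same GPU as the source tensor:

import torch
from torch_xla.utils.dlpack import from_dlpack

cuda_t = torch.arange(8, device='cuda')
xla_t = from_dlpack(cuda_t)  # zero-copy import; the producer syncs onto XLA's stream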
@@ -38,16 +37,16 @@ def from_xla_cuda_to_cuda(tensor):
   # https://github.com/pytorch/pytorch/blob/b0ef363972203b163cddc95e4c6054b8221c2300/torch/utils/dlpack.py#L114-L115
   # The array API specify that the default legacy stream must be passed
   # with a value of 1 for CUDA
-  # device_id = tensor.device.index
-  # stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
-  # stream = 1 if stream == 0 else stream
-  # assert stream is None or type(stream) is int
-  # external_stream = torch.cuda.ExternalStream(stream)
-  # current_stream = torch.cuda.current_stream()
-  # if external_stream != current_stream:
-  #   event = torch.cuda.Event()
-  #   event.record(current_stream)
-  #   external_stream.wait_event(event)
+  device_id = tensor.device.index
+  stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+  stream = 1 if stream == 0 else stream
+  assert stream is None or type(stream) is int
+  external_stream = torch.cuda.ExternalStream(stream)
+  current_stream = torch.cuda.current_stream()
+  if external_stream != current_stream:
+    event = torch.cuda.Event()
+    event.record(current_stream)
+    external_stream.wait_event(event)
   dlpack = to_dlpack(tensor)
   cuda_tensor = torch.utils.dlpack.from_dlpack(dlpack)
   return cuda_tensor
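
And a hypothetical round trip through the function above, assuming the CUDA PJRT plugin. The restored event/wait dance orders the export against the current stream (using 1, the array API's legacy default stream value, when XLA reports stream 0) before to_dlpack runs:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.utils.dlpack import from_xla_cuda_to_cuda

xla_t = torch.ones(4, device=xm.xla_device())
cuda_t = from_xla_cuda_to_cuda(xla_t)  # aliases xla_t's device memory
assert cuda_t.device.type == 'cuda'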
