Skip to content

Commit

Permalink
[SYCLomatic][Bug][PyTorch Migration] Fixing the stream() base method …
Browse files Browse the repository at this point in the history
…not-migrated bug (#2560)
  • Loading branch information
TejaX-Alaghari authored Dec 13, 2024
1 parent 3afdb88 commit a55fb39
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
6 changes: 3 additions & 3 deletions clang/test/dpct/pytorch/ATen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
// RUN: cp -r %S/pytorch_cuda_inc %T/pytorch/ATen/
// RUN: cd %T/pytorch/ATen
// RUN: mkdir dpct_out
// RUN: dpct -out-root dpct_out src/ATen.cu --extra-arg="-I./pytorch_cuda_inc" --cuda-include-path="%cuda-path/include" --rule-file=user_defined_rule_pytorch.yaml -- -x cuda --cuda-host-only
// RUN: FileCheck --input-file dpct_out/ATen.dp.cpp --match-full-lines src/ATen.cu
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST dpct_out/ATen.dp.cpp -o dpct_out/ATen.dp.o %}
// RUN: dpct -out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_cuda_inc" --cuda-include-path="%cuda-path/include" --rule-file=%T/pytorch/ATen/user_defined_rule_pytorch.yaml -- -x cuda --cuda-host-only
// RUN: FileCheck --input-file %T/pytorch/ATen/dpct_out/ATen.dp.cpp --match-full-lines %T/pytorch/ATen/src/ATen.cu
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/pytorch/ATen/dpct_out/ATen.dp.cpp -o %T/pytorch/ATen/dpct_out/ATen.dp.o %}

#ifndef NO_BUILD_TEST
#include <iostream>
Expand Down
12 changes: 7 additions & 5 deletions clang/test/dpct/pytorch/c10.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
// RUN: cp -r %S/pytorch_cuda_inc %T/pytorch/c10/
// RUN: cd %T/pytorch/c10
// RUN: mkdir dpct_out
// RUN: dpct -out-root dpct_out src/c10.cu --extra-arg="-I./pytorch_cuda_inc" --cuda-include-path="%cuda-path/include" --rule-file=user_defined_rule_pytorch.yaml -- -x cuda --cuda-host-only
// RUN: FileCheck --input-file dpct_out/c10.dp.cpp --match-full-lines src/c10.cu
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST dpct_out/c10.dp.cpp -o dpct_out/c10.dp.o %}
// RUN: dpct -out-root dpct_out %T/pytorch/c10/src/c10.cu --extra-arg="-I%T/pytorch/c10/pytorch_cuda_inc" --cuda-include-path="%cuda-path/include" --rule-file=%T/pytorch/c10/user_defined_rule_pytorch.yaml -- -x cuda --cuda-host-only
// RUN: FileCheck --input-file %T/pytorch/c10/dpct_out/c10.dp.cpp --match-full-lines %T/pytorch/c10/src/c10.cu
// RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/pytorch/c10/dpct_out/c10.dp.cpp -o %T/pytorch/c10/dpct_out/c10.dp.o %}

#ifndef NO_BUILD_TEST
#include <iostream>
Expand All @@ -30,8 +30,10 @@ int main() {
// CHECK: auto currentStream = c10::xpu::getCurrentXPUStream();
auto currentStream = c10::cuda::getCurrentCUDAStream();

// CHECK: std::cout << "Current Stream (Default Device): " << currentStream.queue() << std::endl;
std::cout << "Current Stream (Default Device): " << currentStream.stream() << std::endl;
// CHECK: dpct::queue_ptr curr_cuda_st = &(currentStream.queue());
// CHECK-NEXT: curr_cuda_st = &(c10::xpu::getCurrentXPUStream().queue());
cudaStream_t curr_cuda_st = currentStream.stream();
curr_cuda_st = c10::cuda::getCurrentCUDAStream().stream();

// CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(0);
auto deviceStream = c10::cuda::getCurrentCUDAStream(0);
Expand Down
17 changes: 10 additions & 7 deletions clang/test/dpct/pytorch/user_defined_rule_pytorch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
In: c10/cuda/CUDAGuard.h
Out: "<c10/core/DeviceGuard.h>"

- Rule: rule_c10_cuda_CUDAStream
Kind: Class
Priority: Takeover
In: c10::cuda::CUDAStream
Out: c10::xpu::XPUStream
Includes: ["<c10/xpu/XPUStream.h>"]
Methods:
- In: stream
Out: "&($method_base queue())"

- Rule: rule_c10_cuda_OptionalCUDAGuard
Kind: Type
Priority: Takeover
Expand All @@ -37,10 +47,3 @@
Out: c10::xpu::getCurrentXPUStream($1)
Includes: ["<c10/xpu/XPUStream.h>"]

- Rule: rule_c10_cuda_CUDAStream_stream
Kind: PatternRewriter
Priority: Takeover
In: ${prefix}.stream()
Out: ${prefix}.queue()
Includes: []

0 comments on commit a55fb39

Please sign in to comment.