Skip to content

Commit

Permalink
style(dipu): run format on dipu codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Mar 28, 2024
1 parent 4574652 commit 79e71b4
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
29 changes: 16 additions & 13 deletions dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,10 +772,9 @@ def functions_code_gen(fun_config):
)

if fun_config.get("print_func_call_info", False) == True:
fun_config[
"custom_code_at_the_beginning"
] = create_code_to_print_fun_call_info_from_schema(fun_config) + fun_config.get(
"custom_code_at_the_beginning", ""
fun_config["custom_code_at_the_beginning"] = (
create_code_to_print_fun_call_info_from_schema(fun_config)
+ fun_config.get("custom_code_at_the_beginning", "")
)

if fun_config.get("print_op_args", False) == True:
Expand Down Expand Up @@ -881,13 +880,15 @@ def functions_code_gen(fun_config):
],
call_backward_impl_code=[
(
"auto result = "
+ create_call_cpp_function_code_from_schema(
fun_config["backward_schema"]
).replace("; ", ";\n")
(
"auto result = "
+ create_call_cpp_function_code_from_schema(
fun_config["backward_schema"]
).replace("; ", ";\n")
)
if "backward_schema" in fun_config
else ""
)
if "backward_schema" in fun_config
else ""
],
backward_return_code=[
fun_config.get("backward_return_code", "").replace("; ", ";\n")
Expand Down Expand Up @@ -955,9 +956,11 @@ def functions_code_gen(fun_config):
)
],
force_fallback=[
"false"
if fun_config.get("force_fallback", False) in [False, "False"]
else "true"
(
"false"
if fun_config.get("force_fallback", False) in [False, "False"]
else "true"
)
],
fallbackFunc=[
"dipu::native::"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __init__(self, config_yaml):
detail = config[interface]
assert isinstance(detail, dict)
if "layout" in detail:
self.convert_dict[interface][
"layout"
] = self.layout2memoryformat(detail["layout"])
self.convert_dict[interface]["layout"] = (
self.layout2memoryformat(detail["layout"])
)

def layout2memoryformat(self, layout):
# used when pasing convert_config.yaml, return the memory format based on NCHW/NHWC and other layout.
Expand Down
2 changes: 1 addition & 1 deletion dipu/torch_dipu/csrc_dipu/aten/ops/OpUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ inline std::string dumpArg(const at::Tensor& tensor) {
<< ", storage_data_ptr: " << tensor.storage().data_ptr().get()
<< ", storage_offset: " << tensor.storage_offset();
if (dumpOpArgLevel() > 2) {
stream << '\n' <<toCpuTensorWithoutDiopiCopy(tensor);
stream << '\n' << toCpuTensorWithoutDiopiCopy(tensor);
}
} else {
stream << "undefined";
Expand Down
1 change: 1 addition & 0 deletions dipu/torch_dipu/dipu/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class StreamContext:
current device, this function will also change the current device to
match the stream.
"""

cur_stream: Optional["torch_dipu.dipu.Stream"]

def __init__(self, stream):
Expand Down
8 changes: 5 additions & 3 deletions dipu/torch_dipu/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,11 @@ def trim_path(path, src_column_width):
_format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total),
evt.self_cpu_time_total_str, # Self CPU total
# CPU total %, 0 for async events.
_format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
if not evt.is_async
else 0,
(
_format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
if not evt.is_async
else 0
),
evt.cpu_time_total_str, # CPU total
evt.cpu_time_str, # CPU time avg
]
Expand Down
6 changes: 2 additions & 4 deletions dipu/torch_dipu/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ def skipOn(vendor: str, reason: str):


@overload
def onlyOn(vendor: str):
...
def onlyOn(vendor: str): ...


@overload
def onlyOn(vendor: List[str]):
...
def onlyOn(vendor: List[str]): ...


def onlyOn(vendor):
Expand Down

0 comments on commit 79e71b4

Please sign in to comment.