From 9ad993f058ec7dbcf9143b91488ae41bd9d7a27d Mon Sep 17 00:00:00 2001 From: twata Date: Thu, 13 Apr 2023 06:13:29 +0000 Subject: [PATCH 01/11] Parameterize onnx grad test with pfto --- .../onnx_tests/test_grad.py | 132 ++++++++++++------ 1 file changed, 89 insertions(+), 43 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py index 5a7a449b7..1c80f3061 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py @@ -61,9 +61,10 @@ def forward(self, x): assert y.shape == (1, 1, 32, 20) +# @pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_grad(): +def test_grad(use_pfto: bool = False): if not pytorch_pfn_extras.requires('1.8.0'): pytest.skip('skip for PyTorch 1.7 or earlier') @@ -96,7 +97,7 @@ def forward(self, x): x, 'grad', enable_onnx_checker=False, - use_pfto=False, + use_pfto=use_pfto, ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) @@ -110,21 +111,34 @@ def forward(self, x): assert 'Gradient_4' in named_nodes assert 'MatMul_6' in named_nodes - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v10_MatMul", "Gradient_y_0", "Gradient_x_0_0" - ] - y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in + if use_pfto: + assert list([v.name for v in actual_onnx.graph.output]) == [ + "linear.72", "Gradient_y_0", "Gradient_x_0_0" + ] + y_in, _ = _get_name(actual_onnx.graph, "input.1") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in + else: + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y_in else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y_in - - + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v10_MatMul", "Gradient_y_0", "Gradient_x_0_0" + ] + y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in + else: + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y_in + + +@pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_grad_multiple_times(): +def test_grad_multiple_times(use_pfto: bool): if not pytorch_pfn_extras.requires("1.8.0"): pytest.skip('skip for PyTorch 1.7 or earlier') @@ -184,26 +198,44 @@ def forward(self, x): assert 'Gradient_9' in named_nodes assert 'MatMul_12' in named_nodes - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" - ] - y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - y1_in, _ = _get_name(actual_onnx.graph, 
"Gradient_y_1") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in + if use_pfto: + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" + ] + y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in + else: + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y0_in + assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" + assert named_nodes["Conv_7"].output[0] == y1_in else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y0_in - assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" - assert named_nodes["Conv_7"].output[0] == y1_in - - + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" + ] + y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in + else: + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y0_in + assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" + assert named_nodes["Conv_7"].output[0] == y1_in + + +# @pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_grad_with_multiple_inputs(): +def test_grad_with_multiple_inputs(use_pfto: bool = False): if not pytorch_pfn_extras.requires("1.8.0"): pytest.skip('skip for PyTorch 1.7 or earlier') @@ -238,7 +270,7 @@ def forward(self, x): x, 'grad', enable_onnx_checker=False, - use_pfto=False, + use_pfto=use_pfto, ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) @@ -252,15 +284,29 @@ def forward(self, x): assert 'Gradient_7' in named_nodes assert 'MatMul_9' in named_nodes - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" - ] - y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in + if use_pfto: + assert list([v.name for v in 
actual_onnx.graph.output]) == [
+            "linear.87", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
+        ]
+        y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
+        if pytorch_pfn_extras.requires("1.13"):
+            assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0"
+            assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0"
+            assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
+        else:
+            assert named_nodes["Concat_4"].input[0] == "x0"
+            assert named_nodes["Concat_4"].input[1] == "x1"
+            assert named_nodes["Conv_5"].output[0] == "conv.output.1"
     else:
-        assert named_nodes["Concat_4"].input[0] == "Gradient_x_0_0"
-        assert named_nodes["Concat_4"].input[1] == "Gradient_x_1_0"
-        assert named_nodes["Conv_5"].output[0] == y_in
+        assert list([v.name for v in actual_onnx.graph.output]) == [
+            "v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
+        ]
+        y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
+        if pytorch_pfn_extras.requires("1.13"):
+            assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0"
+            assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0"
+            assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
+        else:
+            assert named_nodes["Concat_4"].input[0] == "Gradient_x_0_0"
+            assert named_nodes["Concat_4"].input[1] == "Gradient_x_1_0"
+            assert named_nodes["Conv_5"].output[0] == y_in

From bb75b2b0663a1cbd716829aadf2cb4364683e5c6 Mon Sep 17 00:00:00 2001
From: Takeshi Watanabe
Date: Thu, 20 Apr 2023 11:32:23 +0900
Subject: [PATCH 02/11] Remove okdshin-san from code owners

---
 .github/CODEOWNERS | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 839bc1d9b..aa34069a1 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,3 +1,3 @@
 * @asi1024 @emcastillo @kmaehashi @linshokaku
-/pytorch_pfn_extras/onnx @xuzijian629 @okdshin @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku
-/tests/pytorch_pfn_extras_tests/onnx_tests @xuzijian629 @okdshin @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku
+/pytorch_pfn_extras/onnx @xuzijian629 @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku
+/tests/pytorch_pfn_extras_tests/onnx_tests @xuzijian629 @take-cheeze @asi1024 @emcastillo @kmaehashi @linshokaku

From ca3b936cd8d5e94f1d832442983cba24d241825c Mon Sep 17 00:00:00 2001
From: Hiroaki Mikami
Date: Wed, 19 Apr 2023 16:01:54 +0900
Subject: [PATCH 03/11] Use state_dict to decide names of initializers

---
 .../onnx/pfto_exporter/export.py | 38 ++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 10e2ac7df..52616d7b0 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -240,7 +240,29 @@ def _run_trace(self) -> None:
         self.original_outputs = self.original_model(*self.inputs)
         self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
         self.g: torch._C.Graph = self.traced.inlined_graph
-        self.vars: Dict[str, torch.IValue] = {_remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict().items()}
+        """
+        `torch.jit.trace` ignores any override of the `state_dict` method in `self.original_model`.
+        Thus, key names may differ between the state dicts of `self.traced` and `self.original_model`.
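+
+        For example (an illustrative sketch; `M` below is a hypothetical module,
+        not part of pfto)::
+
+            class M(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.linear = torch.nn.Linear(2, 2)
+
+                def state_dict(self, *args, **kwargs):  # override that renames keys
+                    d = super().state_dict(*args, **kwargs)
+                    return {f"custom.{k}": v for k, v in d.items()}
+
+            m = M()
+            traced = torch.jit.trace(m, (torch.rand(1, 2),))
+            sorted(m.state_dict())       # ["custom.linear.bias", "custom.linear.weight"]
+            sorted(traced.state_dict())  # ["linear.bias", "linear.weight"]; override ignored
+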
+        pfto uses the key names of `self.original_model.state_dict()` as the parameter names in ONNX.
+
+        To implement this behavior, we prepare a mapping from names in the `self.traced` state dict
+        to names in the `self.original_model` state dict.
+        """
+        self.name_from_trace: Dict[str, str] = {}
+        vars_in_traced: Dict[str, torch.IValue] = {
+            _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict().items()
+        }
+        if isinstance(self.original_model, torch.nn.Module):
+            vars_tmp: Dict[str, torch.IValue] = {
+                _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict(keep_vars=True).items()
+            }
+            v_to_name = {v: k for k, v in self.original_model.state_dict(keep_vars=True).items()}
+            for name, v in vars_tmp.items():
+                self.name_from_trace[name] = v_to_name[v]
+        else:
+            for name in vars_in_traced.keys():
+                self.name_from_trace[name] = name
+        self.vars: Dict[str, torch.IValue] = {self.name_from_trace[name]: v for name, v in vars_in_traced.items()}
         self.torch2onnx_var: Dict[torch._C.Value, torch._C.Value] = {
             i: i for i in self.g.inputs()
         }
@@ -422,19 +444,15 @@ def gen_const(g: torch._C.Graph, value: Any = None) -> torch._C.Value:

     def handle_getattr(self, g: torch._C.Graph, n: torch._C.Node) -> None:
         if self.is_self(n.input()) or self.attrs[_unique_id(n.input())] == _ppe_ignore_scope:
-            self.attrs[_unique_id(n.output())] = ONNXValueID(n.s("name"))
+            var_name = n.s("name")
         else:
-            self.attrs[_unique_id(n.output())] = ONNXValueID(
-                "%s.%s"
-                % (
-                    self.attrs[_unique_id(n.input())],
-                    n.s("name"),
-                )
-            )
-        var_name = self.attrs[_unique_id(n.output())]
+            var_name = "%s.%s" % (self.attrs[_unique_id(n.input())], n.s("name"))
+        if var_name in self.name_from_trace:
+            var_name = self.name_from_trace[var_name]
         if var_name in self.vars:
             assert isinstance(self.vars[var_name], torch.Tensor)
             n.output().inferTypeFrom(cast(torch.Tensor, self.vars[var_name]))
+        self.attrs[_unique_id(n.output())] = ONNXValueID(var_name)

     def handle_list_construct(self, g: torch._C.Graph, n: torch._C.Node) -> None:
         # Concat if int type input

From 1dcfb3f4df5ea187cee8ceedf7476d115495e65b Mon Sep 17 00:00:00 2001
From: Hiroaki Mikami
Date: Wed, 19 Apr 2023 17:39:48 +0900
Subject: [PATCH 04/11] Add type hint

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 52616d7b0..edab29f06 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -253,10 +253,10 @@ def _run_trace(self) -> None:
             _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict().items()
         }
         if isinstance(self.original_model, torch.nn.Module):
-            vars_tmp: Dict[str, torch.IValue] = {
+            vars_tmp: Dict[str, Any] = {
                 _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict(keep_vars=True).items()
             }
-            v_to_name = {v: k for k, v in self.original_model.state_dict(keep_vars=True).items()}
+            v_to_name: Dict[Any, str] = {v: k for k, v in self.original_model.state_dict(keep_vars=True).items()}
             for name, v in vars_tmp.items():
                 self.name_from_trace[name] = v_to_name[v]
         else:

From 823b23a21e1c111d2e56d9f74e14ef5847e9d17c Mon Sep 17 00:00:00 2001
From: Takeshi Watanabe
Date: Fri, 21 Apr 2023 15:43:18 +0900
Subject: [PATCH 05/11] [pfto] Run CSE to reduce unnecessary constants

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 2 ++
 1 file changed, 2 insertions(+)
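
A minimal sketch of the effect of this pass (illustrative only; it calls the
private `torch._C._jit_pass_cse` API directly and is not part of the patch):

    import torch

    def f(x):
        # Tracing materializes the scalar 2.0 as two separate prim::Constant nodes.
        return x * 2.0 + x * 2.0

    traced = torch.jit.trace(f, (torch.rand(3),))
    graph = traced.graph
    torch._C._jit_pass_cse(graph)  # common-subexpression elimination merges the duplicates
    print(graph)  # a single Constant node now feeds both multiplications

diff 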
--git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index edab29f06..72ca51028 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -298,6 +298,8 @@ def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph: # left behind by things like symbolic_override run_jit_pass(torch._C._jit_pass_dce, graph) + run_jit_pass(torch._C._jit_pass_cse, graph) + run_jit_pass(torch._C._jit_pass_canonicalize_graph_fuser_ops, graph) # type: ignore[attr-defined] torch._C._jit_pass_peephole(graph, True) # type: ignore[attr-defined] run_jit_pass(torch._C._jit_pass_fuse_addmm, graph) # type: ignore[attr-defined] From 25356b28c451ac998a8ad87f99f703acaaa8d8dc Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 2 May 2023 06:13:03 +0000 Subject: [PATCH 06/11] [WIP][pfto] Test op normalization --- .../onnx_tests/test_export.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 9cf017ed7..5536db1ae 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -252,3 +252,28 @@ def forward(self, x): onnx_scalar_type_analysis=False, skip_oxrt=True, # Add op in ONNX spec doesn't support complex input ) + + +def test_op_norm(): + class Clip(torch.nn.Module): + def __init__(self): + super(Clip, self).__init__() + self.a = torch.rand(32, 32) + + def forward(self, x): + return torch.clip(x, -2.0, 4.0) + torch.clip(self.a, 10, 20) + + class Proxy(torch.nn.Module): + def __init__(self): + super(Proxy, self).__init__() + self.c = Clip() + + def forward(self, x): + return self.c(x + 1) + + x = torch.rand(32, 32) + run_model_test( + Proxy(), + (x,), + do_constant_folding=False, + ) From 2ed583eac4cc4d5c284aa93cf3f484e47760da41 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 2 May 2023 06:55:45 +0000 Subject: [PATCH 07/11] Run op norm inside pfto --- .../onnx/pfto_exporter/export.py | 85 +++++++++++++++++++ .../onnx_tests/test_export.py | 65 ++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 72ca51028..85b4b87d7 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -29,6 +29,88 @@ _ppe_ignore_scope: str = "_ppe_as_out_module" _list_create_ops: List[str] = ["prim::ListConstruct", "onnx::SequenceConstruct", "onnx::SequenceEmpty"] +# Original from https://github.com/pytorch/pytorch/blob/52a36a98d9425479f62b6e2d1a59e434b85f7f7e/torch/csrc/jit/passes/normalize_ops.cpp#L85-L162 +_op_normalize_table: Dict[str, str] = { + "absolute": "abs", + "absolute_": "abs_", + "clip": "clamp", + "clip_": "clamp_", + "det": "linalg_det", + "matrix_power": "linalg_matrix_power", + "matrix_exp": "linalg_matrix_exp", + "ger": "outer", + "arccos": "acos", + "arccos_": "acos_", + "arcsin": "asin", + "arcsin_": "asin_", + "arctan": "atan", + "arctan_": "atan_", + "arctan2": "atan2", + "arctan2_": "atan2_", + "arccosh": "acosh", + "arccosh_": "acosh_", + "arcsinh": "asinh", + "arcsinh_": "asinh_", + "arctanh": "atanh", + "arctanh_": "atanh_", + "fix": "trunc", + "fix_": "trunc_", + "negative": "neg", + "negative_": "neg_", + "subtract": "sub", + "subtract_": "sub_", + "greater_equal": "ge", + "greater_equal_": "ge_", + "greater": "gt", + "greater_": 
"gt_", + "less_equal": "le", + "less_equal_": "le_", + "less": "lt", + "less_": "lt_", + "not_equal": "ne", + "not_equal_": "ne_", + "divide": "div", + "divide_": "div_", + "multiply": "mul", + "multiply_": "mul_", + "linalg_matmul": "matmul", + "inverse": "linalg_inv", + "true_divide": "div", + "true_divide_": "div_", + "concat": "cat", + "concatenate": "cat", + "row_stack": "vstack", + "swapdims": "transpose", + "swapdims_": "transpose_", + "swapaxes": "transpose", + "swapaxes_": "transpose_", + "moveaxis": "movedim", + "special_erf": "erf", + "special_erfc": "erfc", + "special_erfinv": "erfinv", + "special_expit": "sigmoid", + "special_exp2": "exp2", + "special_expm1": "expm1", + "special_logit": "logit", + "special_logsumexp": "logsumexp", + "special_round": "round", + "special_log1p": "log1p", + "special_sinc": "sinc", + "special_digamma": "digamma", + "special_psi": "digamma", + "special_i0": "i0", + "special_xlogy": "xlogy", + "special_log_softmax": "log_softmax", + "orgqr": "linalg_householder_product", + "adjoint": "mH", + "special_multigammaln": "mvlgamma", + "special_polygamma": "polygamma", + "special_softmax": "softmax", + "special_gammainc": "igamma", + "special_gammaincc": "igammac", + "special_gammaln": "lgamma", +} + if pytorch_pfn_extras.requires("1.13"): from torch.onnx._internal import jit_utils GraphContext = jit_utils.GraphContext @@ -539,6 +621,9 @@ def symbolic_function(self, n: torch._C.Node) -> Optional[Callable]: import pytorch_pfn_extras.onnx.symbolic_registry as sym_reg + if not sym_reg.is_registered_op(op, domain, self.opset_version) and op in _op_normalize_table: + op = _op_normalize_table[op] + if sym_reg.is_registered_op(op, domain, self.opset_version): # type: ignore[no-untyped-call] return cast( # type: ignore[redundant-cast] Callable, sym_reg.get_registered_op(op, domain, self.opset_version) # type: ignore[no-untyped-call] diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 5536db1ae..81cb2c68b 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -2,6 +2,7 @@ import pytest import torch +import pytorch_pfn_extras as ppe from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test @@ -255,6 +256,70 @@ def forward(self, x): def test_op_norm(): + if ppe.requires("1.9.0"): + import torch.onnx.symbolic_helper as sym_help + + @torch.onnx.symbolic_helper.parse_args("v", "v") + def clamp_min(g, self, min): + # dtype = self.type().scalarType() + # Type info may be lost here. + # https://github.com/pfnet/pytorch-pfn-extras/issues/578 + # min = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[dtype]) + if sym_help._get_tensor_rank(min) == 0: + max = torch.onnx.symbolic_opset9.unused(g) + return g.op("Clip", self, min, max) + else: + return g.op("Max", self, min) + + @torch.onnx.symbolic_helper.parse_args("v", "v") + def clamp_max(g, self, max): + # dtype = self.type().scalarType() + # Type info may be lost here. 
+ # https://github.com/pfnet/pytorch-pfn-extras/issues/578 + # max = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[dtype]) + if sym_help._get_tensor_rank(max) == 0: + min = torch.onnx.symbolic_opset9.unused(g) + return g.op("Clip", self, min, max) + else: + return g.op("Min", self, max) + + torch.onnx.symbolic_opset11.clamp_min = clamp_min + torch.onnx.symbolic_opset11.clamp_max = clamp_max + + @torch.onnx.symbolic_helper.parse_args("v", "v", "v") + def clamp(g, self, min, max): + dtype = self.type().scalarType() + + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not sym_help._is_none(tensor): + return g.op( + "Cast", tensor, to_i=sym_help.cast_pytorch_to_onnx[dtype] + ) + else: + return tensor + + # pfto loses type info after Cast. + # https://github.com/pfnet/pytorch-pfn-extras/issues/578 + orig_min = min + orig_max = max + + if dtype is not None: + min = _cast_if_not_none(min, dtype) + max = _cast_if_not_none(max, dtype) + + if sym_help._is_none(min): + return clamp_max(g, self, max) + elif sym_help._is_none(max): + return clamp_min(g, self, min) + else: + if ( + sym_help._get_tensor_rank(orig_min) == 0 + and sym_help._get_tensor_rank(orig_max) == 0 + ): + return g.op("Clip", self, min, max) + else: + return clamp_max(g, clamp_min(g, self, min), max) + class Clip(torch.nn.Module): def __init__(self): super(Clip, self).__init__() From 02ddabf21e996cae8ce5f53ff794115796804098 Mon Sep 17 00:00:00 2001 From: twata Date: Mon, 8 May 2023 10:10:44 +0000 Subject: [PATCH 08/11] Make CI fail first --- .../onnx_tests/test_grad.py | 123 ++++++------------ 1 file changed, 40 insertions(+), 83 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py index 1c80f3061..320fab015 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py @@ -61,10 +61,10 @@ def forward(self, x): assert y.shape == (1, 1, 32, 20) -# @pytest.mark.parametrize("use_pfto", [False, True]) +@pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_grad(use_pfto: bool = False): +def test_grad(use_pfto: bool): if not pytorch_pfn_extras.requires('1.8.0'): pytest.skip('skip for PyTorch 1.7 or earlier') @@ -111,28 +111,16 @@ def forward(self, x): assert 'Gradient_4' in named_nodes assert 'MatMul_6' in named_nodes - if use_pfto: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "linear.72", "Gradient_y_0", "Gradient_x_0_0" - ] - y_in, _ = _get_name(actual_onnx.graph, "input.1") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in - else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y_in + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v10_MatMul", "Gradient_y_0", "Gradient_x_0_0" + ] + y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in else: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v10_MatMul", "Gradient_y_0", "Gradient_x_0_0" - ] - 
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in - else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y_in + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y_in @pytest.mark.parametrize("use_pfto", [False, True]) @@ -198,44 +186,27 @@ def forward(self, x): assert 'Gradient_9' in named_nodes assert 'MatMul_12' in named_nodes - if use_pfto: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" - ] - y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in - else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y0_in - assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" - assert named_nodes["Conv_7"].output[0] == y1_in + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" + ] + y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" + assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in else: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" - ] - y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" - assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].output[0] == y1_in - else: - assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" - assert named_nodes["Conv_2"].output[0] == y0_in - assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" - assert named_nodes["Conv_7"].output[0] == y1_in - - -# @pytest.mark.parametrize("use_pfto", [False, True]) + assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0" + assert named_nodes["Conv_2"].output[0] == y0_in + assert named_nodes["Conv_7"].input[0] == "Gradient_x_0_1" + assert named_nodes["Conv_7"].output[0] == y1_in + + +@pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_grad_with_multiple_inputs(use_pfto: bool = 
False): +def test_grad_with_multiple_inputs(use_pfto: bool): if not pytorch_pfn_extras.requires("1.8.0"): pytest.skip('skip for PyTorch 1.7 or earlier') @@ -284,29 +255,15 @@ def forward(self, x): assert 'Gradient_7' in named_nodes assert 'MatMul_9' in named_nodes - if use_pfto: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "linear.87", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" - ] - y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in - else: - assert named_nodes["Concat_4"].input[0] == "x0" - assert named_nodes["Concat_4"].input[1] == "x1" - assert named_nodes["Conv_5"].output[0] == "conv.output.1" + assert list([v.name for v in actual_onnx.graph.output]) == [ + "v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" + ] + y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") + if pytorch_pfn_extras.requires("1.13"): + assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0" + assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0" + assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in else: - assert list([v.name for v in actual_onnx.graph.output]) == [ - "v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" - ] - y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): - assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0" - assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0" - assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in - else: - assert named_nodes["Concat_4"].input[0] == "Gradient_x_0_0" - assert named_nodes["Concat_4"].input[1] == "Gradient_x_1_0" - assert named_nodes["Conv_5"].output[0] == y_in + assert named_nodes["Concat_4"].input[0] == "Gradient_x_0_0" + assert named_nodes["Concat_4"].input[1] == "Gradient_x_1_0" + assert named_nodes["Conv_5"].output[0] == y_in From ccde11d4ea32a3b6b52c85c9c02fbe2c177215bc Mon Sep 17 00:00:00 2001 From: twata Date: Mon, 8 May 2023 10:44:54 +0000 Subject: [PATCH 09/11] Set output_names --- .../onnx_tests/test_grad.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py index 320fab015..b710c3404 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py @@ -63,6 +63,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad(use_pfto: bool): if not pytorch_pfn_extras.requires('1.8.0'): @@ -98,11 +99,12 @@ def forward(self, x): 'grad', enable_onnx_checker=False, use_pfto=use_pfto, + output_names=["h"], ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert '/_ppe_as_out_module/conv/Conv' in 
named_nodes assert '/_ppe_as_out_module/Gradient' in named_nodes assert '/_ppe_as_out_module/linear/MatMul' in named_nodes @@ -112,10 +114,10 @@ def forward(self, x): assert 'MatMul_6' in named_nodes assert list([v.name for v in actual_onnx.graph.output]) == [ - "v10_MatMul", "Gradient_y_0", "Gradient_x_0_0" + "h", "Gradient_y_0", "Gradient_x_0_0" ] y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in else: @@ -125,6 +127,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad_multiple_times(use_pfto: bool): if not pytorch_pfn_extras.requires("1.8.0"): @@ -168,12 +171,13 @@ def forward(self, x): x, 'grad', enable_onnx_checker=False, - use_pfto=False, + use_pfto=use_pfto, + output_names=["h"], ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert '/_ppe_as_out_module/conv/Conv' in named_nodes assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes assert '/_ppe_as_out_module/Gradient' in named_nodes @@ -187,11 +191,11 @@ def forward(self, x): assert 'MatMul_12' in named_nodes assert list([v.name for v in actual_onnx.graph.output]) == [ - "v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" + "h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1" ] y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1") - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0" assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1" @@ -205,6 +209,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) @pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad_with_multiple_inputs(use_pfto: bool): if not pytorch_pfn_extras.requires("1.8.0"): @@ -242,11 +247,12 @@ def forward(self, x): 'grad', enable_onnx_checker=False, use_pfto=use_pfto, + output_names=["h"], ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert '/_ppe_as_out_module/conv/Conv' in named_nodes assert '/_ppe_as_out_module/Gradient' in named_nodes assert '/_ppe_as_out_module/linear/MatMul' in named_nodes @@ -256,10 +262,10 @@ def forward(self, x): assert 'MatMul_9' in named_nodes assert list([v.name for v in actual_onnx.graph.output]) == [ - "v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" + "h", 
"Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0" ] y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0") - if pytorch_pfn_extras.requires("1.13"): + if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0" assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0" assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in From 771b38ba1f33e1ccc8a84b3c1809ca44dd996519 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 9 May 2023 03:05:37 +0000 Subject: [PATCH 10/11] Reorder trace and original outputs getting to fix as_output feature --- pytorch_pfn_extras/onnx/pfto_exporter/export.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 85b4b87d7..f7efd11da 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -302,9 +302,16 @@ def _restore_state(self) -> None: if torch.cuda.is_available(): torch.cuda.set_rng_state_all(self.cuda_rng_state) + # TODO(twata): Use `self.traced` instead or use traced result outputs + def _get_original_outputs(self) -> None: + self._restore_state() + self.original_outputs = self.original_model(*self.inputs) + self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0]) + def _run_trace(self) -> None: # TODO(twata): Use `torch._C._craete_graph_by_tracing` instead. # So that we don't need to run heavy models multiple times + self._restore_state() self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore self.original_model, self.inputs, @@ -317,10 +324,6 @@ def _run_trace(self) -> None: # Model: {self.traced.original_name} """ - # TODO(twata): Use `self.traced` instead or use traced result outputs - self._restore_state() - self.original_outputs = self.original_model(*self.inputs) - self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0]) self.g: torch._C.Graph = self.traced.inlined_graph """ `self.trace` ignores the override of `state_dict` method in `self.original_model`. 
@@ -1061,6 +1064,7 @@ def _convert(self) -> None: sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call] False # TODO(twata): Use `self.onnx_shape_inference` ) + self._get_original_outputs() self._run_trace() self.model: onnx.ModelProto = self.generate_onnx() finally: From 0aff084e94fc28882f5fbfb108992d027f9b1ab0 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 9 May 2023 09:48:28 +0000 Subject: [PATCH 11/11] Fix name by initializing grad state --- .../onnx/pfto_exporter/export.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 217648be5..9ca5c870c 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -12,6 +12,7 @@ import onnx.shape_inference import pytorch_pfn_extras import pytorch_pfn_extras.onnx._constants +from pytorch_pfn_extras.onnx import _grad as grad from pytorch_pfn_extras.onnx._globals import GLOBALS from pytorch_pfn_extras.torchscript import run_jit_pass import torch @@ -321,7 +322,7 @@ def _restore_state(self) -> None: # TODO(twata): Use `self.traced` instead or use traced result outputs def _get_original_outputs(self) -> None: self._restore_state() - with _force_tracing(): + with _force_tracing(), grad.init_grad_state(): self.original_outputs = self.original_model(*self.inputs) self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0]) @@ -329,13 +330,14 @@ def _run_trace(self) -> None: # TODO(twata): Use `torch._C._craete_graph_by_tracing` instead. # So that we don't need to run heavy models multiple times self._restore_state() - self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore - self.original_model, - self.inputs, - check_trace=self.check_trace, - strict=self.strict_trace, - _force_outplace=self.force_outplace_trace, - ) + with grad.init_grad_state(): + self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore + self.original_model, + self.inputs, + check_trace=self.check_trace, + strict=self.strict_trace, + _force_outplace=self.force_outplace_trace, + ) self.graph_doc_string = f""" # Model: {self.traced.original_name}