From 36be0cb4f9e7743a7020f29568aae1a9c95133ca Mon Sep 17 00:00:00 2001
From: Linsho Kaku
Date: Mon, 13 May 2024 14:27:06 +0900
Subject: [PATCH 1/5] Update the method of checking for calls to the optimizer

---
 .../training/extensions/lr_scheduler.py | 23 ++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index 64e2ce30..747f4901 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -1,12 +1,14 @@
 from typing import Any, Dict, Optional
+
 from pytorch_pfn_extras.training import extension
 from pytorch_pfn_extras.training import trigger as trigger_module
 from pytorch_pfn_extras.training._manager_protocol import (
     ExtensionsManagerProtocol,
 )
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-
+from torch.optim import Optimizer
+from pytorch_pfn_extras._torch_version import requires
 
 def _get_value_from_log_report(
     manager: ExtensionsManagerProtocol, key: Any
@@ -32,6 +34,22 @@ def _default_stepper(
     else:
         scheduler.step()
 
+def check_optimizer_is_called(optimizer: Optimizer) -> bool:
+    if requires("2.4.0.dev"):
+        # https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/optim/lr_scheduler.py#L172-L191
+        # TODO: Rewrite this URL when pytorch 2.4.0 is released.
+        if hasattr(optimizer.step, "_wrapped_by_lr_sched"):
+            return getattr(optimizer, "_opt_called", False)
+        else:
+            return True
+    else:
+        # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
+        if hasattr(optimizer.step, "_with_counter"):
+            return optimizer._step_count >= 1
+        else:
+            return True
+
+
 class LRScheduler(extension.Extension):
     """Trainer extension to adjust the learning rate using PyTorch's learning
     rate scheduler.
@@ -72,8 +90,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
         # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
         if (
             self.wait_for_first_optimizer_step
-            and hasattr(self.scheduler.optimizer.step, "_with_counter")
-            and self.scheduler.optimizer._step_count < 1
+            and not check_optimizer_is_called(self.scheduler.optimizer)
         ):
             return
         self.stepper(manager, self.scheduler)

From 736d971de20f9676a8a401e00e233b962cb898ca Mon Sep 17 00:00:00 2001
From: Linsho Kaku
Date: Mon, 13 May 2024 14:42:57 +0900
Subject: [PATCH 2/5] Fix lint

---
 pytorch_pfn_extras/training/extensions/lr_scheduler.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index 747f4901..dc6c081f 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -1,14 +1,14 @@
 from typing import Any, Dict, Optional
-
+from pytorch_pfn_extras._torch_version import requires
 from pytorch_pfn_extras.training import extension
 from pytorch_pfn_extras.training import trigger as trigger_module
 from pytorch_pfn_extras.training._manager_protocol import (
     ExtensionsManagerProtocol,
 )
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim import Optimizer
-from pytorch_pfn_extras._torch_version import requires
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
 
 def _get_value_from_log_report(
     manager: ExtensionsManagerProtocol, key: Any
@@ -34,6 +34,7 @@ def _default_stepper(
     else:
         scheduler.step()
 
+
 def check_optimizer_is_called(optimizer: Optimizer) -> bool:
     if requires("2.4.0.dev"):
         # https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/optim/lr_scheduler.py#L172-L191
@@ -50,7 +51,6 @@ def check_optimizer_is_called(optimizer: Optimizer) -> bool:
             return True
 
 
-
 class LRScheduler(extension.Extension):
     """Trainer extension to adjust the learning rate using PyTorch's learning
     rate scheduler.

From d5b16e6fde0d87f036149e300da69001c353a355 Mon Sep 17 00:00:00 2001
From: Linsho Kaku
Date: Mon, 13 May 2024 15:06:43 +0900
Subject: [PATCH 3/5] Follow the new GraphContext interface

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 359a39d0..ad6135a7 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -693,11 +693,16 @@ def run_symbolic_function(self, g: torch._C.Graph, n: torch._C.Node, sym_func: C
             node_inputs.extend(n.scalar_args())
         if "module" in attrs:
             del attrs["module"]
-        if pytorch_pfn_extras.requires("1.13"):
+        if pytorch_pfn_extras.requires("2.4.0.dev"):
             g_ctx = GraphContext(
                 graph=g, block=n.owningBlock(),
                 opset=self.opset_version, original_node=n,
-                params_dict=self.vars, env=self.torch2onnx_var)
+                params_dict=self.vars, env=self.torch2onnx_var, values_in_env=set())
+        elif pytorch_pfn_extras.requires("1.13"):
+            g_ctx = GraphContext(
+                graph=g, block=n.owningBlock(),
+                opset=self.opset_version, original_node=n,
+                params_dict=self.vars, env=self.torch2onnx_var)
         else:
             g_ctx = g  # type: ignore
         if (

From e3b8aab63529e30802df27f9f883f49a57144245 Mon Sep 17 00:00:00 2001
From: Linsho Kaku
Date: Mon, 13 May 2024 15:27:56 +0900
Subject: [PATCH 4/5] Fix type error

---
 pytorch_pfn_extras/training/extensions/lr_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index dc6c081f..ead7420b 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -46,7 +46,7 @@ def check_optimizer_is_called(optimizer: Optimizer) -> bool:
     else:
         # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
         if hasattr(optimizer.step, "_with_counter"):
-            return optimizer._step_count >= 1
+            return bool(optimizer._step_count >= 1)  # type: ignore[attr-defined]
         else:
             return True

From ddf118acafc4d2cb37ccc5383d5d839e04e9f9a0 Mon Sep 17 00:00:00 2001
From: Linsho Kaku
Date: Mon, 13 May 2024 16:18:11 +0900
Subject: [PATCH 5/5] Fix type error

---
 pytorch_pfn_extras/onnx/pfto_exporter/export.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index ad6135a7..746494ef 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -693,16 +693,15 @@ def run_symbolic_function(self, g: torch._C.Graph, n: torch._C.Node, sym_func: C
             node_inputs.extend(n.scalar_args())
         if "module" in attrs:
             del attrs["module"]
-        if pytorch_pfn_extras.requires("2.4.0.dev"):
-            g_ctx = GraphContext(
-                graph=g, block=n.owningBlock(),
-                opset=self.opset_version, original_node=n,
-                params_dict=self.vars, env=self.torch2onnx_var, values_in_env=set())
-        elif pytorch_pfn_extras.requires("1.13"):
+        if pytorch_pfn_extras.requires("1.13"):
+            if pytorch_pfn_extras.requires("2.4.0.dev"):
+                g_ctx_kwargs: Dict[str, Any] = {"values_in_env": set()}
+            else:
+                g_ctx_kwargs = {}
             g_ctx = GraphContext(
                 graph=g, block=n.owningBlock(),
                 opset=self.opset_version, original_node=n,
-                params_dict=self.vars, env=self.torch2onnx_var)
+                params_dict=self.vars, env=self.torch2onnx_var, **g_ctx_kwargs)
         else:
             g_ctx = g  # type: ignore
         if (