From aa4a1ff027316c6ad850180cfc20c3962eafb4e4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 16 Jan 2025 21:29:50 +0000 Subject: [PATCH] Revert "Prevent _legacy_load with weights_only=True (#144914)" This reverts commit 7c3aa1da1c97812af54d41f3f0eff2ef922c0f32. Reverted https://github.com/pytorch/pytorch/pull/144914 on behalf of https://github.com/izaitsevfb due to breaking inductor on trunk ([comment](https://github.com/pytorch/pytorch/pull/144914#issuecomment-2596922781)) --- .../bc/test_backward_compatibility.py | 6 +- test/test_serialization.py | 58 ++++++------------- torch/serialization.py | 13 +++-- 3 files changed, 29 insertions(+), 48 deletions(-) diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py index 601e7d080341bd..ca236e9a27b5b8 100644 --- a/test/quantization/bc/test_backward_compatibility.py +++ b/test/quantization/bc/test_backward_compatibility.py @@ -110,14 +110,12 @@ def _test_op( torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file) torch.save(qmodule(input_tensor), expected_file) - # weights_only=False as file was saved in .tar format - input_tensor = torch.load(input_file, weights_only=False) + input_tensor = torch.load(input_file) # weights_only = False as sometimes get ScriptObject here qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False)) qmodule_scripted = torch.jit.load(scripted_module_file) qmodule_traced = torch.jit.load(traced_module_file) - # weights_only=False as file was saved in .tar format - expected = torch.load(expected_file, weights_only=False) + expected = torch.load(expected_file) self.assertEqual(qmodule(input_tensor), expected, atol=prec) self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) diff --git a/test/test_serialization.py b/test/test_serialization.py index 7451317bb0037c..aea2cf1a6f053d 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -227,6 +227,9 @@ def _test_serialization(self, weights_only): def test_serialization(self): self._test_serialization(False) + def test_serialization_safe(self): + self._test_serialization(True) + def test_serialization_filelike(self): # Test serialization (load and save) with a filelike object b = self._test_serialization_data() @@ -363,6 +366,9 @@ def _test_serialization(conversion): def test_serialization_sparse(self): self._test_serialization(False) + def test_serialization_sparse_safe(self): + self._test_serialization(True) + def test_serialization_sparse_invalid(self): x = torch.zeros(3, 3) x[1][1] = 1 @@ -508,6 +514,9 @@ def __reduce__(self): def test_serialization_backwards_compat(self): self._test_serialization_backwards_compat(False) + def test_serialization_backwards_compat_safe(self): + self._test_serialization_backwards_compat(True) + def test_serialization_save_warnings(self): with warnings.catch_warnings(record=True) as warns: with tempfile.NamedTemporaryFile() as checkpoint: @@ -552,8 +561,7 @@ def load_bytes(): def check_map_locations(map_locations, dtype, intended_device): for fileobject_lambda in fileobject_lambdas: for map_location in map_locations: - # weigts_only=False as the downloaded file path uses the old serialization format - tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False) + tensor = torch.load(fileobject_lambda(), map_location=map_location) self.assertEqual(tensor.device, intended_device) self.assertEqual(tensor.dtype, dtype) @@ -596,8 +604,7 @@ def test_load_nonexistent_device(self): error_msg = r'Attempting to deserialize object on a CUDA device' with self.assertRaisesRegex(RuntimeError, error_msg): - # weights_only=False as serialized is in legacy format - _ = torch.load(buf, weights_only=False) + _ = torch.load(buf) @unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681") def test_serialization_filelike_api_requirements(self): @@ -717,8 +724,7 @@ def test_serialization_storage_slice(self): b'\x00\x00\x00\x00') buf = io.BytesIO(serialized) - # serialized was saved with PyTorch 0.3.1 - (s1, s2) = torch.load(buf, weights_only=False) + (s1, s2) = torch.load(buf) self.assertEqual(s1[0], 0) self.assertEqual(s2[0], 0) self.assertEqual(s1.data_ptr() + 4, s2.data_ptr()) @@ -835,24 +841,6 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save - -# used to set weights_only=False in _use_new_zipfile_serialization=False tests -class load_method: - def __init__(self, weights_only): - self.weights_only = weights_only - self.torch_load = torch.load - - def __enter__(self, *args, **kwargs): - def wrapper(*args, **kwargs): - kwargs['weights_only'] = self.weights_only - return self.torch_load(*args, **kwargs) - - torch.load = wrapper - - def __exit__(self, *args, **kwargs): - torch.load = self.torch_load - - Point = namedtuple('Point', ['x', 'y']) class ClassThatUsesBuildInstruction: @@ -889,7 +877,7 @@ def test(f_new, f_old): torch.save(x, f_old, _use_new_zipfile_serialization=False) f_old.seek(0) - x_old_load = torch.load(f_old, weights_only=False) + x_old_load = torch.load(f_old, weights_only=weights_only) self.assertEqual(x_old_load, x_new_load) with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w: @@ -897,17 +885,6 @@ def test(f_new, f_old): test(f_new, f_old) self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}") - def test_old_serialization_fails_with_weights_only(self): - a = torch.randn(5, 5) - with BytesIOContext() as f: - torch.save(a, f, _use_new_zipfile_serialization=False) - f.seek(0) - with self.assertRaisesRegex( - RuntimeError, - "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6." - ): - torch.load(f, weights_only=True) - class TestOldSerialization(TestCase, SerializationMixin): # unique_key is necessary because on Python 2.7, if a warning passed to @@ -983,7 +960,8 @@ def test_serialization_offset(self): self.assertEqual(i, i_loaded) self.assertEqual(j, j_loaded) - def test_serialization_offset_filelike(self): + @parametrize('weights_only', (True, False)) + def test_serialization_offset_filelike(self, weights_only): a = torch.randn(5, 5) b = torch.randn(1024, 1024, 512, dtype=torch.float32) i, j = 41, 43 @@ -995,16 +973,16 @@ def test_serialization_offset_filelike(self): self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024) f.seek(0) i_loaded = pickle.load(f) - a_loaded = torch.load(f) + a_loaded = torch.load(f, weights_only=weights_only) j_loaded = pickle.load(f) - b_loaded = torch.load(f) + b_loaded = torch.load(f, weights_only=weights_only) self.assertTrue(torch.equal(a, a_loaded)) self.assertTrue(torch.equal(b, b_loaded)) self.assertEqual(i, i_loaded) self.assertEqual(j, j_loaded) def run(self, *args, **kwargs): - with serialization_method(use_zip=False), load_method(weights_only=False): + with serialization_method(use_zip=False): return super().run(*args, **kwargs) diff --git a/torch/serialization.py b/torch/serialization.py index 94f2316bfe883e..0a4d067b6ab091 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1501,10 +1501,15 @@ def _get_wo_message(message: str) -> str: "please torch.save your checkpoint with this option in order to use mmap." ) if weights_only: - raise RuntimeError( - "Cannot use ``weights_only=True`` with files saved in the " - ".tar format used before version 1.6. " + UNSAFE_MESSAGE - ) + try: + return _legacy_load( + opened_file, + map_location, + _weights_only_unpickler, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None return _legacy_load( opened_file, map_location, pickle_module, **pickle_load_args )