Support whitelist of dynamic sources (pytorch#147979)
This PR introduces the ability to whitelist sources as dynamic. This is particularly useful for large models with graph breaks: because source names stay consistent across graph breaks, the whitelisted dynamism is preserved in every resulting graph. Additionally, you can use this to mark ints as dynamic.

NB: I intentionally didn't complicate the interface by supporting per-dimension dynamism. There is virtue in keeping true to the standard way of representing sources (e.g. L['x']). If we find in practice that we need more fine-grained control, we can explore further affordances at that time.
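
For example, a minimal usage sketch (the "eager" backend and the function below are illustrative, not part of this commit):

    import torch

    # Whitelist the source L['x'] as dynamic. The same comma-delimited string
    # can instead be supplied via the TORCH_COMPILE_DYNAMIC_SOURCES env var.
    with torch.compiler.config.patch(dynamic_sources="L['x']"):

        @torch.compile(backend="eager")
        def fn(x):
            return x * x

        fn(torch.randn(2))
        fn(torch.randn(3))  # no recompile: L['x'] is whitelisted as dynamic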

Pull Request resolved: pytorch#147979
Approved by: https://github.com/Mingming-Ding
bobrenjc93 authored and pytorchmergebot committed Feb 28, 2025
1 parent 0a948f7 commit 4708cfd
Showing 3 changed files with 141 additions and 2 deletions.
104 changes: 104 additions & 0 deletions test/dynamo/test_misc.py
@@ -51,6 +51,7 @@
     unsupported,
 )
 from torch._dynamo.utils import counters, ifdynstaticdefault
+from torch._dynamo.variables import builder
 from torch._inductor.utils import run_and_get_code
 from torch.ao.quantization import MinMaxObserver
 from torch.ao.quantization.fake_quantize import FakeQuantize
@@ -7772,6 +7773,109 @@ def my_dyn_fn(x, y):
         with self.assertRaises(ConstraintViolationError):
             torch.compile(my_dyn_fn, backend="eager")(y, y)
 
+    @torch._dynamo.config.patch(force_parameter_static_shapes=True)
+    @torch._dynamo.config.patch(force_nn_module_property_static_shapes=True)
+    @torch.compiler.config.patch(
+        dynamic_sources="L['x'],L['y'],L['self']._modules['y'].x,L['self']._modules['y']._modules['c']._parameters['weight'],L['self']._modules['y']._modules['c']._parameters['bias']"
+    )
+    def test_dynamic_sources_force_parameter_static_shapes_and_property_static_shapes_override(
+        self,
+    ):
+        builder._DYNAMIC_SOURCES = None
+
+        counter = CompileCounter()
+
+        class Y(torch.nn.Module):
+            def __init__(self, n_input, n_output):
+                super().__init__()
+                self.c = torch.nn.Linear(n_input, n_output)
+                self.x = n_input
+
+            def forward(self, x):
+                return self.c(x) * self.x
+
+        class M(torch.nn.Module):
+            def __init__(self, n_input, n_output):
+                self.n_input = n_input
+                self.n_output = n_output
+                super().__init__()
+                self.y = Y(n_input, n_output)
+
+            @torch.compile(backend=counter)
+            def forward(self, x, y):
+                return self.y(x) * y
+
+        model = M(3210, 30)
+        model(torch.randn(1, 3210), 2)
+        model = M(3211, 30)
+        model(torch.randn(1, 3211), 3)
+        model = M(3212, 30)
+        model(torch.randn(1, 3212), 4)
+
+        self.assertEqual(counter.frame_count, 1)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_sources_int(self):
+        counter = CompileCounter()
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            return torch.randn(5) * x
+
+        fn(1)
+        fn(2)
+        fn(3)
+
+        self.assertEqual(counter.frame_count, 1)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_sources_tensor(self):
+        counter = CompileCounter()
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            return x * x
+
+        fn(torch.randn(2))
+        fn(torch.randn(3))
+        fn(torch.randn(4))
+
+        self.assertEqual(counter.frame_count, 1)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x']")
+    def test_dynamic_sources_graph_break(self):
+        counter = CompileCounter()
+
+        def foo(x):
+            return x * x
+
+        @torch.compile(backend=counter)
+        def fn(x):
+            x = x * x
+            torch._dynamo.graph_break()
+            return foo(x)
+
+        fn(torch.randn(2))
+        fn(torch.randn(3))
+        fn(torch.randn(4))
+
+        # 2 since graph break produces 2 graphs. NB: there are no recompiles
+        self.assertEqual(counter.frame_count, 2)
+
+    @torch.compiler.config.patch(dynamic_sources="L['x'], L['y']")
+    def test_dynamic_sources_dynamic_override(self):
+        counter = CompileCounter()
+
+        @torch.compile(dynamic=False, backend=counter)
+        def fn(x, y):
+            return x * y
+
+        fn(2, torch.randn(2))
+        fn(3, torch.randn(3))
+        fn(4, torch.randn(4))
+
+        self.assertEqual(counter.frame_count, 1)
+
     def test_cannot_trace_mark_dynamic(self):
         y = torch.randn([3, 3, 3])
 
28 changes: 26 additions & 2 deletions torch/_dynamo/variables/builder.py
@@ -1973,7 +1973,10 @@ def wrap_symint(self, value):
             # know if bare integers are actually going to be sizevars
             # and it is inappropriate to eagerly duck size them with
             # real sizevars
-            if (
+            if self.source.name() in get_dynamic_sources():
+                log.debug("%s marked dynamic via source whitelist", self.source.name())
+                dynamic_dim = DimDynamic.DYNAMIC
+            elif (
                 config.automatic_dynamic_shapes
                 and frame_state_entry.scalar is auto_dynamic
             ):
@@ -2664,6 +2667,21 @@ def get_automatic_dynamic_shapes_mark_as():
     )
 
 
+_DYNAMIC_SOURCES: Optional[set[str]] = None
+
+
+def get_dynamic_sources() -> set[str]:
+    global _DYNAMIC_SOURCES
+    if _DYNAMIC_SOURCES is not None:
+        return _DYNAMIC_SOURCES
+
+    _DYNAMIC_SOURCES = set(
+        torch.compiler.config.dynamic_sources.replace(" ", "").split(",")
+    )
+
+    return _DYNAMIC_SOURCES
+
+
 # Tracks the sources of all fake tensors we wrap in Dynamo.
 # Used by shape guard computation.
 @dataclasses.dataclass
@@ -2694,6 +2712,7 @@ def _automatic_dynamic(
         unimplemented("torch.compile does not support strided NestedTensor")
 
     name = source.name()
+    dynamic_sources = get_dynamic_sources()
     prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
     shape_env_to_source_to_symbol_cache = (
         prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
@@ -2732,7 +2751,7 @@
             inner_contexts=inner_contexts,
         )
 
-    if static_shapes:
+    if static_shapes and name not in dynamic_sources:
         return StatefulSymbolicContext(
             dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
             dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
@@ -2853,6 +2872,11 @@ def update_dim2constraint(dim, constraint_range, name):
             config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i)
         )
 
+        if name in dynamic_sources:
+            log.debug("%s marked dynamic via source whitelist", name)
+            automatic_dynamic_size = True
+            automatic_dynamic_stride = True
+
         automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
 
         # We will process constraints first, as they will imply that we
11 changes: 11 additions & 0 deletions torch/compiler/config.py
@@ -63,5 +63,16 @@
 A common use case for such a tag is to break caches.
 """
 
+dynamic_sources: str = Config(
+    env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default=""
+)
+"""
+Comma-delimited list of sources that should be marked as dynamic. Primarily useful for large
+models with graph breaks, where you need intermediate tensors and ints to be marked dynamic.
+This whitelist takes precedence over all other flags: dynamic=False, force_nn_module_property_static_shapes,
+and force_parameter_static_shapes.
+"""
+
+
 install_config_module(sys.modules[__name__])
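
As a usage sketch of the environment-variable route (hypothetical sources and function; assuming, per env_name_default above, that the value is read when torch.compiler.config is installed, so it must be set before importing torch):

    import os

    # Whitelist sources without code changes; set before importing torch,
    # since the config default is taken from the env var at import time.
    os.environ["TORCH_COMPILE_DYNAMIC_SOURCES"] = "L['x'],L['y']"

    import torch  # imported after setting the env var (see assumption above)

    @torch.compile
    def fn(x, y):
        return x * y

    fn(2, torch.randn(2))
    fn(3, torch.randn(3))  # no recompile: both whitelisted sources are dynamic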
