Skip to content

Commit

Permalink
clear code, reduce unnecessary changes
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinfromTJ committed Jan 24, 2024
1 parent ce925e4 commit b35c363
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 27 deletions.
9 changes: 2 additions & 7 deletions dicp/dicp/dynamo_bridge/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,12 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
else:
continue
if 'val' in n.meta and test_infer:
restore_needed = False
if not isinstance(n.meta['val'],(Tuple,List)):
n.meta['val'],fake_value = (n.meta['val'],),(fake_value,)
restore_needed = True
for i,(meta_i,fv_i) in enumerate(zip(n.meta['val'],fake_value)):
(n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n_meta_val, fake_val)
for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
assert meta_i.size() == fv_i.size(), "check infer size failed"
assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"
if restore_needed:
n.meta['val'],fake_value=n.meta['val'][0],fake_value[0]
if 'val' not in n.meta:
n.meta['val'] = fake_value
n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
Expand Down
28 changes: 8 additions & 20 deletions dicp/test/op/test__native_batch_norm_legit_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class OpModule(torch.nn.Module):
def forward(self, a, running_mean,running_var,device='cpu'):
def forward(self, a, device="cpu"):
m = torch.nn.BatchNorm2d(100)
m.to(device)
res_default = m(a)
Expand All @@ -23,33 +23,21 @@ def forward(self, a, running_mean,running_var,device='cpu'):
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestNativeBatchNormLegitFunctional:
class TestNativeBatchNormLegitFunctional():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
"sizes",
[
Size((20, 100, 35, 45), (20, 100, 35, 45)),
Size((30, 100, 45, 35), (30, 100, 45, 35))
],
)
@pytest.mark.parametrize("sizes", [Size((20, 100, 35, 45), (20, 100, 35, 45)),
Size((30, 100, 45, 35), (30, 100, 45, 35))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch__native_batch_norm_legit_functional(
self, sizes, dtype, compiled_model
):
def test_torch__native_batch_norm_legit_functional(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

running_mean,running_var =torch.zeros([100]),torch.zeros([100])
dicp_input1 = input1.to(device)

output = model(input1,running_mean,running_var)

dicp_running_mean,dicp_running_var=running_mean.to(device),running_var.to(device)
output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1, dicp_running_mean,dicp_running_var,device)
dicp_output = compiled_model.model(dicp_input1, device)

assert torch.allclose(
output.detach(), dicp_output.cpu().detach(), equal_nan=True
)
assert torch.allclose(output.detach(), dicp_output.cpu().detach(), equal_nan=True)

0 comments on commit b35c363

Please sign in to comment.