Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 committed Jan 25, 2024
1 parent 7bd8cc4 commit 837f289
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion dicp/dicp/dynamo_bridge/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
else:
continue
if 'val' in n.meta and test_infer:
(n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
(n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
if not isinstance(fv_i, FakeTensor):
continue
assert meta_i.size() == fv_i.size(), "check infer size failed"
assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
Expand Down
2 changes: 1 addition & 1 deletion dicp/test/op/test_scalar_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TestScalarTensor():
@pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
redundant_input = torch.randn(1, dtype=dtype)
redundant_input = torch.ones(1, dtype=dtype)
_, output = model(inputs, redundant_input, dtype)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
Expand Down

0 comments on commit 837f289

Please sign in to comment.