Skip to content

Commit

Permalink
#0:shard linear ops
Browse files Browse the repository at this point in the history
  • Loading branch information
vguduruTT committed Mar 9, 2025
1 parent 9c940dc commit 7ee67a9
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 504 deletions.
20 changes: 14 additions & 6 deletions models/experimental/functional_UFLD_v2/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ttnn.model_preprocessing import (
preprocess_model_parameters,
fold_batch_norm2d_into_conv2d,
infer_ttnn_module_args,
preprocess_linear_weight,
preprocess_linear_bias,
)
Expand Down Expand Up @@ -319,22 +320,23 @@ def attempt_download(file, key="dumps"):
@pytest.mark.parametrize(
"batch_size,input_channels,height,width",
[
(1, 3, 320, 800),
(2, 3, 320, 800),
],
)
@pytest.mark.parametrize(
"use_pretrained_weight",
[
False,
# True # uncomment to run the model for real weights
# False,
True # uncomment to run the model for real weights
],
ids=[
"pretrained_weight_false",
# "pretrained_weight_true", # uncomment to run the model for real weights
# "pretrained_weight_false",
"pretrained_weight_true", # uncomment to run the model for real weights
],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
def test_tu_simple_res34_inference(batch_size, input_channels, height, width, device, use_pretrained_weight):
torch_input_tensor = torch.randn((batch_size, input_channels, height, width))
reference_model = Tu_Simple(input_height=height, input_width=width)
if use_pretrained_weight:
logger.info(f"Demo Inference using Pre-trained Weights")
Expand All @@ -359,6 +361,7 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
cfg.crop_ratio,
cfg.train_width,
cfg.train_height,
batch_size=batch_size,
row_anchor=cfg.row_anchor,
col_anchor=cfg.col_anchor,
device=None,
Expand All @@ -369,7 +372,11 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
custom_preprocessor=custom_preprocessor,
device=device,
)
ttnn_model = ttnn_UFLD_V2(conv_args=reference_model, conv_pth=parameters, device=device)
parameters.conv_args = {}
parameters.conv_args = infer_ttnn_module_args(
model=reference_model, run_model=lambda model: reference_model(torch_input_tensor), device=device
)
ttnn_model = ttnn_UFLD_V2(conv_args=parameters.conv_args, conv_pth=parameters, device=device)
run_test_tusimple(
ttnn_model,
cfg.data_root,
Expand All @@ -379,6 +386,7 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
cfg.crop_ratio,
cfg.train_width,
cfg.train_height,
batch_size=batch_size,
row_anchor=cfg.row_anchor,
col_anchor=cfg.col_anchor,
device=device,
Expand Down

This file was deleted.

This file was deleted.

Loading

0 comments on commit 7ee67a9

Please sign in to comment.