#0: remove shard_to_interleaved
vguduruTT committed Mar 9, 2025
1 parent 93dd36d · commit 9c940dc
Showing 2 changed files with 38 additions and 24 deletions.
models/experimental/functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py (40 changes: 20 additions & 20 deletions)
@@ -41,7 +41,7 @@ def __init__(
device.arch(),
math_fidelity=ttnn.MathFidelity.LoFi,
fp32_dest_acc_en=False,
- packer_l1_acc=True,
+ packer_l1_acc=False,
math_approx_mode=True,
)
self.conv_config = ttnn.Conv2dConfig(
@@ -52,9 +52,9 @@ def __init__(
enable_act_double_buffer=False,
enable_split_reader=False,
enable_subblock_padding=False,
- reshard_if_not_optimal=True,
+ reshard_if_not_optimal=False,
activation=activation,
- input_channels_alignment=16,
+ input_channels_alignment=8,
)
config_override = None
if config_override and "act_block_h" in config_override:
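On the input_channels_alignment change above: the flag name suggests TTNN pads the conv's input-channel count up to this multiple, so with a 3-channel RGB input an alignment of 8 pads less than 16 would. A small illustration of that rounding arithmetic (what the flag controls is inferred from its name here, not stated in the diff):

    def align_up(channels: int, alignment: int) -> int:
        # Round a channel count up to the nearest multiple of `alignment`.
        return -(-channels // alignment) * alignment

    print(align_up(3, 16))  # 16 -> the old setting would pad the 3-channel input to 16
    print(align_up(3, 8))   #  8 -> the new setting pads it only to 8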
@@ -69,9 +69,9 @@ def __init__(
self.weight = weight

def __call__(self, x):
- input_height = x.shape[1]
- input_width = x.shape[2]
- batch_size = x.shape[0]
+ input_height = self.conv.input_height
+ input_width = self.conv.input_width
+ batch_size = self.conv.batch_size
[x, [output_height, output_width], [self.weight, self.bias]] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.weight,
@@ -91,7 +91,6 @@ def __call__(self, x):
compute_config=self.compute_config,
return_output_dim=True,
return_weights_and_bias=True,
- # memory_config = ttnn.DRAM_MEMORY_CONFIG
)
return x, output_height, output_width
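Regarding the shape-indexing change in __call__ above: the activation is carried between ops in TTNN's flattened (1, 1, N*H*W, C) layout, so x.shape[1] and x.shape[2] no longer hold the spatial height and width, and the wrapper now reads them from the conv arguments instead. A minimal host-side sketch of that bookkeeping (the SimpleNamespace stand-in and the concrete numbers are illustrative assumptions, not the repo's config object):

    from types import SimpleNamespace

    # Hypothetical conv args carrying the dims the wrapper now reads from config.
    conv = SimpleNamespace(batch_size=1, input_height=80, input_width=200)

    # Activation kept in the flattened TTNN layout: (1, 1, N*H*W, C).
    flat_shape = (1, 1, conv.batch_size * conv.input_height * conv.input_width, 64)

    # Indexing the flattened tensor's shape gives 1 and N*H*W, not H and W ...
    assert (flat_shape[1], flat_shape[2]) != (conv.input_height, conv.input_width)
    # ... so the spatial dims come from the conv args instead.
    input_height, input_width = conv.input_height, conv.input_width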

@@ -110,19 +109,19 @@ def __init__(self, conv_args, conv_pth, device, is_downsample=False):
def __call__(self, device, input):
x_identity = input
x, out_ht, out_wdth = self.conv1(input)
- if x.is_sharded():
- x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
- x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1])) # RESHAPING FROM (1,1,NHW,C) TO (N,H,W,C) TO AVOID OOM
+ # if x.is_sharded():
+ # x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1])) # RESHAPING FROM (1,1,NHW,C) TO (N,H,W,C) TO AVOID OOM
x, out_ht, out_wdth = self.conv2(x)
- if x.is_sharded():
- x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
- x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1]))
+ # if x.is_sharded():
+ # x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1]))
if self.is_downsample:
x_identity, out_ht, out_wdth = self.downsample(input)
- if x_identity.is_sharded():
- x_identity = ttnn.sharded_to_interleaved(x_identity, memory_config=ttnn.L1_MEMORY_CONFIG)
- x_identity = ttnn.reshape(x_identity, (1, out_ht, out_wdth, x_identity.shape[-1]))
- x = ttnn.add(x, x_identity, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # if x_identity.is_sharded():
+ # x_identity = ttnn.sharded_to_interleaved(x_identity, memory_config=ttnn.L1_MEMORY_CONFIG)
+ # x_identity = ttnn.reshape(x_identity, (1, out_ht, out_wdth, x_identity.shape[-1]))
+ x = ttnn.add(x, x_identity, memory_config=x.memory_config())
x = ttnn.relu(x)

return x
@@ -180,10 +179,11 @@ def __init__(self, conv_args, conv_pth, device):
)

def __call__(self, x): # [1, 320, 800, 3]
+ batch_size = x.shape[0]
x, out_ht, out_wdth = self.conv1(x) # [1, 1, 64000, 64] #0.99974
x = ttnn.max_pool2d(
x,
- batch_size=1,
+ batch_size=batch_size,
input_h=out_ht,
input_w=out_wdth,
channels=x.shape[-1],
@@ -195,7 +195,7 @@ def __call__(self, x): # [1, 320, 800, 3]
if x.is_sharded():
x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
- x = ttnn.reshape(x, (1, 80, 200, x.shape[-1]))
+ # x = ttnn.reshape(x, (1, 80, 200, x.shape[-1]))
x = self.layer1_0(device=self.device, input=x)
x = self.layer1_1(device=self.device, input=x)
x = self.layer1_2(device=self.device, input=x)
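The shape comments in this hunk line up with the usual ResNet-34 stem (a stride-2 conv followed by a stride-2 max pool); the strides are inferred from the numbers rather than stated in the diff, but the arithmetic checks out:

    h, w = 320, 800               # model input, per the "# [1, 320, 800, 3]" comment
    h1, w1 = h // 2, w // 2       # after conv1 (assumed stride 2): 160 x 400
    assert h1 * w1 == 64_000      # matches the "# [1, 1, 64000, 64]" comment
    h2, w2 = h1 // 2, w1 // 2     # after max_pool2d (stride 2): 80 x 200
    assert (h2, w2) == (80, 200)  # matches the now commented-out reshape to (1, 80, 200, C)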
@@ -241,7 +241,7 @@ def __init__(self, conv_args, conv_pth, device):
self.input_dim = self.input_height // 32 * self.input_width // 32 * 8
self.device = device
self.conv_pth = conv_pth
- self.res_model = ttnn_Resnet_34(conv_args.res_model, conv_pth.res_model, device=device)
+ self.res_model = ttnn_Resnet_34(conv_args, conv_pth.res_model, device=device)
self.pool = ttnn_UFLD_V2_Conv2D(conv_args.pool, conv_pth.pool, activation="", device=device)

def __call__(self, input):
tests/ttnn/integration_tests/UFLD_v2/test_ttnn_UFLD_v2.py (22 changes: 18 additions & 4 deletions)
@@ -370,7 +370,7 @@ def test_ufld_v2_basic_block(device, batch_size, input_channels, height, width):
torch_model.eval()
torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.bfloat16)
ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
- ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+ ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model,
custom_preprocessor=custom_preprocessor_basic_block,
@@ -398,10 +398,13 @@ def test_ufld_v2_basic_block(device, batch_size, input_channels, height, width):
@skip_for_grayskull()
@pytest.mark.parametrize(
"use_pretrained_weight",
- [False, True], # uncomment to run the model for real weights
+ [
+ False,
+ # True
+ ], # uncomment to run the model for real weights
ids=[
"pretrained_weight_false",
"pretrained_weight_true", # uncomment to run the model for real weights
# "pretrained_weight_true", # uncomment to run the model for real weights
],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
@@ -421,6 +424,12 @@ def test_UFD_V2_Model(device, batch_size, input_channels, height, width, use_pre
torch_model.load_state_dict(new_state_dict)

ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
+ ttnn_input_tensor = ttnn_input_tensor.reshape(
+ 1,
+ 1,
+ (ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2]),
+ ttnn_input_tensor.shape[-1],
+ )
ttnn_input_tensor = ttnn.from_torch(
ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
)
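The lines added above flatten the host tensor into the (1, 1, N*H*W, C) layout before it is handed to ttnn.from_torch. A standalone torch equivalent, using the model's nominal 320x800x3 input as an illustrative size (the test's actual dims come from its parametrize arguments):

    import torch

    x = torch.randn(1, 3, 320, 800, dtype=torch.bfloat16)  # NCHW, as in the test
    x = torch.permute(x, (0, 2, 3, 1))                      # NHWC: (1, 320, 800, 3)
    n, h, w, c = x.shape
    x = x.reshape(1, 1, n * h * w, c)                       # flattened: (1, 1, 256000, 3)
    assert x.shape == (1, 1, 256_000, 3)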
@@ -429,7 +438,12 @@ def test_UFD_V2_Model(device, batch_size, input_channels, height, width, use_pre
custom_preprocessor=custom_preprocessor_whole_model,
device=device,
)
- ttnn_model = ttnn_UFLD_V2(conv_args=torch_model, conv_pth=parameters, device=device)
+ parameters.conv_args = {}
+ parameters.conv_args = infer_ttnn_module_args(
+ model=torch_model, run_model=lambda model: torch_model(torch_input_tensor), device=device
+ )
+
+ ttnn_model = ttnn_UFLD_V2(conv_args=parameters.conv_args, conv_pth=parameters, device=device)
torch_output, pred_list = torch_model(torch_input_tensor)
ttnn_output, tt_pred_list = ttnn_model(input=ttnn_input_tensor)
ttnn_output = ttnn.to_torch(ttnn_output)
