From 4bc0af8333b0d4f14417c0d015017b8a0d8678e7 Mon Sep 17 00:00:00 2001
From: vguduruTT
Date: Mon, 3 Mar 2025 18:05:36 +0000
Subject: [PATCH] #0:shard linear ops

---
 .../functional_UFLD_v2/demo/demo.py           |  20 +-
 .../tests/test_ttnn_UFLD_v2_basic_block.py    |  62 ----
 .../tests/test_ttnn_UFLD_v2_conv.py           |  59 ----
 .../tests/test_ttnn_UFLD_v2_model.py          | 321 ------------------
 .../functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py   | 100 +++---
 .../UFLD_v2/test_ttnn_UFLD_v2.py              |  18 +-
 6 files changed, 76 insertions(+), 504 deletions(-)
 delete mode 100644 models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_basic_block.py
 delete mode 100644 models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_conv.py
 delete mode 100644 models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_model.py

diff --git a/models/experimental/functional_UFLD_v2/demo/demo.py b/models/experimental/functional_UFLD_v2/demo/demo.py
index 404a4a38b0c..1a4564b5330 100644
--- a/models/experimental/functional_UFLD_v2/demo/demo.py
+++ b/models/experimental/functional_UFLD_v2/demo/demo.py
@@ -12,6 +12,7 @@
 from ttnn.model_preprocessing import (
     preprocess_model_parameters,
     fold_batch_norm2d_into_conv2d,
+    infer_ttnn_module_args,
     preprocess_linear_weight,
     preprocess_linear_bias,
 )
@@ -319,22 +320,23 @@ def attempt_download(file, key="dumps"):
 @pytest.mark.parametrize(
     "batch_size,input_channels,height,width",
     [
-        (1, 3, 320, 800),
+        (2, 3, 320, 800),
     ],
 )
 @pytest.mark.parametrize(
     "use_pretrained_weight",
     [
-        False,
-        # True # uncomment to run the model for real weights
+        # False,
+        True # uncomment to run the model for real weights
     ],
     ids=[
-        "pretrained_weight_false",
-        # "pretrained_weight_true", # uncomment to run the model for real weights
+        # "pretrained_weight_false",
+        "pretrained_weight_true", # uncomment to run the model for real weights
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
 def test_tu_simple_res34_inference(batch_size, input_channels, height, width, device, use_pretrained_weight):
+    torch_input_tensor = torch.randn((batch_size, input_channels, height, width))
     reference_model = Tu_Simple(input_height=height, input_width=width)
     if use_pretrained_weight:
         logger.info(f"Demo Inference using Pre-trained Weights")
@@ -359,6 +361,7 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
         cfg.crop_ratio,
         cfg.train_width,
         cfg.train_height,
+        batch_size=batch_size,
         row_anchor=cfg.row_anchor,
         col_anchor=cfg.col_anchor,
         device=None,
@@ -369,7 +372,11 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
         custom_preprocessor=custom_preprocessor,
         device=device,
     )
-    ttnn_model = ttnn_UFLD_V2(conv_args=reference_model, conv_pth=parameters, device=device)
+    parameters.conv_args = {}
+    parameters.conv_args = infer_ttnn_module_args(
+        model=reference_model, run_model=lambda model: reference_model(torch_input_tensor), device=device
+    )
+    ttnn_model = ttnn_UFLD_V2(conv_args=parameters.conv_args, conv_pth=parameters, device=device)
     run_test_tusimple(
         ttnn_model,
         cfg.data_root,
@@ -379,6 +386,7 @@ def test_tu_simple_res34_inference(batch_size, input_channels, height, width, de
         cfg.crop_ratio,
         cfg.train_width,
         cfg.train_height,
+        batch_size=batch_size,
         row_anchor=cfg.row_anchor,
         col_anchor=cfg.col_anchor,
         device=device,
diff --git a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_basic_block.py b/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_basic_block.py
deleted file mode 100644
index
5a60d34b8a3..00000000000 --- a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_basic_block.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import pytest -import torch -import torch.nn as nn -from models.experimental.functional_UFLD_v2.ttnn.ttnn_UFLD_v2 import ttnn_Basic_Block -from ttnn.model_preprocessing import preprocess_model_parameters, fold_batch_norm2d_into_conv2d, infer_ttnn_module_args -from models.experimental.functional_UFLD_v2.reference.UFLD_v2_model import Tu_Simple, BasicBlock -from tests.ttnn.utils_for_testing import assert_with_pcc - - -def custom_preprocessor(model, name): - parameters = {} - if isinstance(model, BasicBlock): - weight, bias = fold_batch_norm2d_into_conv2d(model.conv1, model.bn1) - parameters["conv1"] = {} - parameters["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - - weight, bias = fold_batch_norm2d_into_conv2d(model.conv2, model.bn2) - parameters["conv2"] = {} - parameters["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - - return parameters - - -@pytest.mark.parametrize( - "batch_size,input_channels,height,width", - [ - (1, 64, 80, 200), - ], -) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True) -def test_ufld_v2_basic_block(device, batch_size, input_channels, height, width): - torch_model = Tu_Simple(input_height=height, input_width=width).res_model.layer1[0] - torch_model.to(torch.bfloat16) - torch_model.eval() - torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.bfloat16) - ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) - ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) - parameters = preprocess_model_parameters( - initialize_model=lambda: torch_model, - custom_preprocessor=custom_preprocessor, - device=device, - ) - parameters.conv_args = {} - parameters.conv_args = infer_ttnn_module_args( - model=torch_model, run_model=lambda model: torch_model(torch_input_tensor), device=device - ) - ttnn_model = ttnn_Basic_Block(parameters.conv_args, parameters, device=device) - torch_out = torch_model(torch_input_tensor) - ttnn_output = ttnn_model(device=device, input=ttnn_input_tensor) - ttnn_output = ttnn.to_torch(ttnn_output) - ttnn_output = ttnn_output.permute(0, 3, 1, 2) - ttnn_output = ttnn_output.reshape(torch_out.shape) - assert_with_pcc(ttnn_output, torch_out, 0.9999999) diff --git a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_conv.py b/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_conv.py deleted file mode 100644 index 89c3e0e55d0..00000000000 --- a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_conv.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import pytest -import torch -import torch.nn as nn -from models.experimental.functional_UFLD_v2.ttnn.ttnn_UFLD_v2 import ttnn_UFLD_V2_Conv2D -from ttnn.model_preprocessing import preprocess_model_parameters, infer_ttnn_module_args -from models.experimental.functional_UFLD_v2.reference.UFLD_v2_model import Tu_Simple -from tests.ttnn.utils_for_testing import assert_with_pcc - - -def custom_preprocessor(model, name): - parameters = {} - if isinstance(model, nn.Conv2d): - parameters["weight"] = ttnn.from_torch(model.weight, dtype=ttnn.float32) - if model.bias is not None: - bias = model.bias.reshape((1, 1, 1, -1)) - parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - else: - parameters["bias"] = None - - return parameters - - -@pytest.mark.parametrize( - "batch_size,input_channels,height,width", - [ - (1, 3, 320, 800), - ], -) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True) -def test_UFLD_V2_conv(device, batch_size, input_channels, height, width): - torch_model = Tu_Simple(input_height=height, input_width=width).res_model.conv1 - torch_model.to(torch.bfloat16) - torch_model.eval() - torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.bfloat16) - ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) - ttnn_input_tensor = ttnn.from_torch( - ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device - ) - parameters = preprocess_model_parameters( - initialize_model=lambda: torch_model, - custom_preprocessor=custom_preprocessor, - device=device, - ) - parameters.conv_args = {} - parameters.conv_args = infer_ttnn_module_args( - model=torch_model, run_model=lambda model: torch_model(torch_input_tensor), device=None - ) - ttnn_model = ttnn_UFLD_V2_Conv2D(parameters.conv_args, parameters, activation="", device=device) - torch_out = torch_model(torch_input_tensor) - ttnn_output = ttnn_model(ttnn_input_tensor) - ttnn_output = ttnn.to_torch(ttnn_output[0]) - ttnn_output = ttnn_output.permute(0, 3, 1, 2) - ttnn_output = ttnn_output.reshape(torch_out.shape) - assert_with_pcc(ttnn_output, torch_out, 0.9999) diff --git a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_model.py b/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_model.py deleted file mode 100644 index 328add225af..00000000000 --- a/models/experimental/functional_UFLD_v2/tests/test_ttnn_UFLD_v2_model.py +++ /dev/null @@ -1,321 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import pytest -import torch -from models.experimental.functional_UFLD_v2.ttnn.ttnn_UFLD_v2 import ( - ttnn_UFLD_V2, -) -from ttnn.model_preprocessing import preprocess_model_parameters, fold_batch_norm2d_into_conv2d -from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias -from models.experimental.functional_UFLD_v2.reference.UFLD_v2_model import Tu_Simple -from tests.ttnn.utils_for_testing import assert_with_pcc - - -def custom_preprocessor(model, name): - parameters = {} - if isinstance(model, Tu_Simple): - # conv1,bn1 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.conv1, model.res_model.bn1) - parameters["res_model"] = {} - parameters["res_model"]["conv1"] = {} - parameters["res_model"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer0 - 0 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[0].conv1, model.res_model.layer1[0].bn1) - parameters["res_model"]["layer1_0"] = {} - parameters["res_model"]["layer1_0"]["conv1"] = {} - parameters["res_model"]["layer1_0"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_0"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[0].conv2, model.res_model.layer1[0].bn2) - parameters["res_model"]["layer1_0"]["conv2"] = {} - parameters["res_model"]["layer1_0"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_0"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer1 - 1 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[1].conv1, model.res_model.layer1[1].bn1) - parameters["res_model"]["layer1_1"] = {} - parameters["res_model"]["layer1_1"]["conv1"] = {} - parameters["res_model"]["layer1_1"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_1"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[1].conv2, model.res_model.layer1[1].bn2) - parameters["res_model"]["layer1_1"]["conv2"] = {} - parameters["res_model"]["layer1_1"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_1"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer1 - 2 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[2].conv1, model.res_model.layer1[2].bn1) - parameters["res_model"]["layer1_2"] = {} - parameters["res_model"]["layer1_2"]["conv1"] = {} - parameters["res_model"]["layer1_2"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_2"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer1[2].conv2, model.res_model.layer1[2].bn2) - parameters["res_model"]["layer1_2"]["conv2"] = {} - parameters["res_model"]["layer1_2"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer1_2"]["conv2"]["bias"] = 
ttnn.from_torch(bias, dtype=ttnn.float32) - # layer-2-0 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[0].conv1, model.res_model.layer2[0].bn1) - parameters["res_model"]["layer2_0"] = {} - parameters["res_model"]["layer2_0"]["conv1"] = {} - parameters["res_model"]["layer2_0"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_0"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[0].conv2, model.res_model.layer2[0].bn2) - parameters["res_model"]["layer2_0"]["conv2"] = {} - parameters["res_model"]["layer2_0"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_0"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer2 - 1 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[1].conv1, model.res_model.layer2[1].bn1) - parameters["res_model"]["layer2_1"] = {} - parameters["res_model"]["layer2_1"]["conv1"] = {} - parameters["res_model"]["layer2_1"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_1"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[1].conv2, model.res_model.layer2[1].bn2) - parameters["res_model"]["layer2_1"]["conv2"] = {} - parameters["res_model"]["layer2_1"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_1"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer2-2 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[2].conv1, model.res_model.layer2[2].bn1) - parameters["res_model"]["layer2_2"] = {} - parameters["res_model"]["layer2_2"]["conv1"] = {} - parameters["res_model"]["layer2_2"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_2"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[2].conv2, model.res_model.layer2[2].bn2) - parameters["res_model"]["layer2_2"]["conv2"] = {} - parameters["res_model"]["layer2_2"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_2"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer2-3 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[3].conv1, model.res_model.layer2[3].bn1) - parameters["res_model"]["layer2_3"] = {} - parameters["res_model"]["layer2_3"]["conv1"] = {} - parameters["res_model"]["layer2_3"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_3"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer2[3].conv2, model.res_model.layer2[3].bn2) - parameters["res_model"]["layer2_3"]["conv2"] = {} - parameters["res_model"]["layer2_3"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_3"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # downsample layer2[0] - if hasattr(model.res_model.layer2[0], 
"downsample") and model.res_model.layer2[0].downsample is not None: - downsample = model.res_model.layer2[0].downsample - if isinstance(downsample, torch.nn.Sequential): - conv_layer = downsample[0] - bn_layer = downsample[1] - weight, bias = fold_batch_norm2d_into_conv2d(conv_layer, bn_layer) - parameters["res_model"]["layer2_0"]["downsample"] = {} - parameters["res_model"]["layer2_0"]["downsample"]["weight"] = ttnn.from_torch( - weight, dtype=ttnn.float32 - ) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer2_0"]["downsample"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-0 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[0].conv1, model.res_model.layer3[0].bn1) - parameters["res_model"]["layer3_0"] = {} - parameters["res_model"]["layer3_0"]["conv1"] = {} - parameters["res_model"]["layer3_0"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_0"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[0].conv2, model.res_model.layer3[0].bn2) - parameters["res_model"]["layer3_0"]["conv2"] = {} - parameters["res_model"]["layer3_0"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_0"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-1 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[1].conv1, model.res_model.layer3[1].bn1) - parameters["res_model"]["layer3_1"] = {} - parameters["res_model"]["layer3_1"]["conv1"] = {} - parameters["res_model"]["layer3_1"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_1"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[1].conv2, model.res_model.layer3[1].bn2) - parameters["res_model"]["layer3_1"]["conv2"] = {} - parameters["res_model"]["layer3_1"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_1"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-2 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[2].conv1, model.res_model.layer3[2].bn1) - parameters["res_model"]["layer3_2"] = {} - parameters["res_model"]["layer3_2"]["conv1"] = {} - parameters["res_model"]["layer3_2"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_2"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[2].conv2, model.res_model.layer3[2].bn2) - parameters["res_model"]["layer3_2"]["conv2"] = {} - parameters["res_model"]["layer3_2"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_2"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-3 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[3].conv1, model.res_model.layer3[3].bn1) - parameters["res_model"]["layer3_3"] = {} - parameters["res_model"]["layer3_3"]["conv1"] = {} - parameters["res_model"]["layer3_3"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - 
parameters["res_model"]["layer3_3"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[3].conv2, model.res_model.layer3[3].bn2) - parameters["res_model"]["layer3_3"]["conv2"] = {} - parameters["res_model"]["layer3_3"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_3"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-4 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[4].conv1, model.res_model.layer3[4].bn1) - parameters["res_model"]["layer3_4"] = {} - parameters["res_model"]["layer3_4"]["conv1"] = {} - parameters["res_model"]["layer3_4"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_4"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[4].conv2, model.res_model.layer3[4].bn2) - parameters["res_model"]["layer3_4"]["conv2"] = {} - parameters["res_model"]["layer3_4"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_4"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer3-5 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[5].conv1, model.res_model.layer3[5].bn1) - parameters["res_model"]["layer3_5"] = {} - parameters["res_model"]["layer3_5"]["conv1"] = {} - parameters["res_model"]["layer3_5"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_5"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer3[5].conv2, model.res_model.layer3[5].bn2) - parameters["res_model"]["layer3_5"]["conv2"] = {} - parameters["res_model"]["layer3_5"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_5"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # downsample - layer3[0] - if hasattr(model.res_model.layer3[0], "downsample") and model.res_model.layer3[0].downsample is not None: - downsample = model.res_model.layer3[0].downsample - if isinstance(downsample, torch.nn.Sequential): - conv_layer = downsample[0] - bn_layer = downsample[1] - weight, bias = fold_batch_norm2d_into_conv2d(conv_layer, bn_layer) - parameters["res_model"]["layer3_0"]["downsample"] = {} - parameters["res_model"]["layer3_0"]["downsample"]["weight"] = ttnn.from_torch( - weight, dtype=ttnn.float32 - ) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer3_0"]["downsample"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer4-0 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[0].conv1, model.res_model.layer4[0].bn1) - parameters["res_model"]["layer4_0"] = {} - parameters["res_model"]["layer4_0"]["conv1"] = {} - parameters["res_model"]["layer4_0"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_0"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[0].conv2, model.res_model.layer4[0].bn2) - parameters["res_model"]["layer4_0"]["conv2"] = {} - 
parameters["res_model"]["layer4_0"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_0"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer4 - 1 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[1].conv1, model.res_model.layer4[1].bn1) - parameters["res_model"]["layer4_1"] = {} - parameters["res_model"]["layer4_1"]["conv1"] = {} - parameters["res_model"]["layer4_1"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_1"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[1].conv2, model.res_model.layer4[1].bn2) - parameters["res_model"]["layer4_1"]["conv2"] = {} - parameters["res_model"]["layer4_1"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_1"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # layer4-2 - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[2].conv1, model.res_model.layer4[2].bn1) - parameters["res_model"]["layer4_2"] = {} - parameters["res_model"]["layer4_2"]["conv1"] = {} - parameters["res_model"]["layer4_2"]["conv1"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_2"]["conv1"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - weight, bias = fold_batch_norm2d_into_conv2d(model.res_model.layer4[2].conv2, model.res_model.layer4[2].bn2) - parameters["res_model"]["layer4_2"]["conv2"] = {} - parameters["res_model"]["layer4_2"]["conv2"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_2"]["conv2"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - # downsample - layer3[0] - if hasattr(model.res_model.layer4[0], "downsample") and model.res_model.layer4[0].downsample is not None: - downsample = model.res_model.layer4[0].downsample - if isinstance(downsample, torch.nn.Sequential): - conv_layer = downsample[0] - bn_layer = downsample[1] - weight, bias = fold_batch_norm2d_into_conv2d(conv_layer, bn_layer) - parameters["res_model"]["layer4_0"]["downsample"] = {} - parameters["res_model"]["layer4_0"]["downsample"]["weight"] = ttnn.from_torch( - weight, dtype=ttnn.float32 - ) - bias = bias.reshape((1, 1, 1, -1)) - parameters["res_model"]["layer4_0"]["downsample"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - - # pool - parameters["pool"] = {} - parameters["pool"]["weight"] = ttnn.from_torch(model.pool.weight, dtype=ttnn.float32) - if model.pool.bias is not None: - bias = model.pool.bias.reshape((1, 1, 1, -1)) - parameters["pool"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) - else: - parameters["pool"]["bias"] = None - - parameters["cls"] = {} - parameters["cls"]["linear_1"] = {} - parameters["cls"]["linear_1"]["weight"] = preprocess_linear_weight(model.cls[1].weight, dtype=ttnn.bfloat16) - if model.cls[1].bias is not None: - parameters["cls"]["linear_1"]["bias"] = preprocess_linear_bias(model.cls[1].bias, dtype=ttnn.bfloat16) - else: - parameters["cls"]["linear_1"]["bias"] = None - - parameters["cls"]["linear_2"] = {} - parameters["cls"]["linear_2"]["weight"] = preprocess_linear_weight(model.cls[3].weight, dtype=ttnn.bfloat16) - if model.cls[3].bias is not None: - parameters["cls"]["linear_2"]["bias"] 
= preprocess_linear_bias(model.cls[3].bias, dtype=ttnn.bfloat16) - else: - parameters["cls"]["linear_2"]["bias"] = None - - return parameters - - -@pytest.mark.parametrize( - "batch_size,input_channels,height,width", - [ - (1, 3, 320, 800), - ], -) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True) -def test_UFD_V2_Model(device, batch_size, input_channels, height, width): - torch_model = Tu_Simple(input_height=height, input_width=width) - torch_model.to(torch.bfloat16) - torch_model.eval() - torch_input_tensor = torch.randn((batch_size, input_channels, height, width), dtype=torch.bfloat16) - torch_output = torch_model(torch_input_tensor) - ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) - ttnn_input_tensor = ttnn.from_torch( - ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device - ) - - parameters = preprocess_model_parameters( - initialize_model=lambda: torch_model, - custom_preprocessor=custom_preprocessor, - device=device, - ) - ttnn_model = ttnn_UFLD_V2(conv_args=torch_model, conv_pth=parameters, device=device) - torch_output, pred_list = torch_model(torch_input_tensor) - ttnn_output, tt_pred_list = ttnn_model(input=ttnn_input_tensor) - ttnn_output = ttnn.to_torch(ttnn_output) - tt_pred_list["loc_row"] = ttnn.to_torch(tt_pred_list["loc_row"]) - tt_pred_list["loc_col"] = ttnn.to_torch(tt_pred_list["loc_col"]) - tt_pred_list["exist_row"] = ttnn.to_torch(tt_pred_list["exist_row"]) - tt_pred_list["exist_col"] = ttnn.to_torch(tt_pred_list["exist_col"]) - assert_with_pcc(torch_output, ttnn_output, 0.999) - assert_with_pcc(tt_pred_list["loc_row"], pred_list["loc_row"], 0.999) - assert_with_pcc(tt_pred_list["loc_col"], pred_list["loc_col"], 0.999) - assert_with_pcc(tt_pred_list["exist_row"], pred_list["exist_row"], 0.999) - assert_with_pcc(tt_pred_list["exist_col"], pred_list["exist_col"], 0.999) diff --git a/models/experimental/functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py b/models/experimental/functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py index 0ffc4b5af9d..0c43d88b9ae 100644 --- a/models/experimental/functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py +++ b/models/experimental/functional_UFLD_v2/ttnn/ttnn_UFLD_v2.py @@ -106,21 +106,12 @@ def __init__(self, conv_args, conv_pth, device, is_downsample=False): self.conv_args.downsample[0], conv_pth.downsample, device=device, activation="" ) - def __call__(self, device, input): + def __call__(self, input): x_identity = input x, out_ht, out_wdth = self.conv1(input) - # if x.is_sharded(): - # x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) - # x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1])) # RESHAPING FROM (1,1,NHW,C) TO (N,H,W,C) TO AVOID OOM x, out_ht, out_wdth = self.conv2(x) - # if x.is_sharded(): - # x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) - # x = ttnn.reshape(x, (1, out_ht, out_wdth, x.shape[-1])) if self.is_downsample: x_identity, out_ht, out_wdth = self.downsample(input) - # if x_identity.is_sharded(): - # x_identity = ttnn.sharded_to_interleaved(x_identity, memory_config=ttnn.L1_MEMORY_CONFIG) - # x_identity = ttnn.reshape(x_identity, (1, out_ht, out_wdth, x_identity.shape[-1])) x = ttnn.add(x, x_identity, memory_config=x.memory_config()) x = ttnn.relu(x) @@ -129,6 +120,7 @@ def __call__(self, device, input): class ttnn_Resnet_34: def __init__(self, conv_args, conv_pth, device): + self.maxpool_args = conv_args.maxpool self.device = device self.conv1 = ttnn_UFLD_V2_Conv2D(conv_args.conv1, conv_pth.conv1, 
device=self.device, activation="relu")
         # layer-1
@@ -178,49 +170,48 @@ def __init__(self, conv_args, conv_pth, device):
             conv_args.layer4[2], conv_pth.layer4_2, device=self.device, is_downsample=False
         )

-    def __call__(self, x): # [1, 320, 800, 3]
-        batch_size = x.shape[0]
-        x, out_ht, out_wdth = self.conv1(x) # [1, 1, 64000, 64] #0.99974
+    def __call__(self, x, batch_size=1): # [1, 320, 800, 3]
+        x, out_ht, out_wdth = self.conv1(x)
         x = ttnn.max_pool2d(
             x,
             batch_size=batch_size,
             input_h=out_ht,
             input_w=out_wdth,
             channels=x.shape[-1],
-            kernel_size=[3, 3],
-            stride=[2, 2],
-            padding=[1, 1],
-            dilation=[1, 1],
+            kernel_size=[self.maxpool_args.kernel_size, self.maxpool_args.kernel_size],
+            stride=[self.maxpool_args.stride, self.maxpool_args.stride],
+            padding=[self.maxpool_args.padding, self.maxpool_args.padding],
+            dilation=[self.maxpool_args.dilation, self.maxpool_args.dilation],
         )
         if x.is_sharded():
             x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG)
         x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
-        # x = ttnn.reshape(x, (1, 80, 200, x.shape[-1]))
-        x = self.layer1_0(device=self.device, input=x)
-        x = self.layer1_1(device=self.device, input=x)
-        x = self.layer1_2(device=self.device, input=x)
+        x = self.layer1_0(x)
+        x = self.layer1_1(x)
+        x = self.layer1_2(x)

-        x = self.layer2_0(device=self.device, input=x)
-        x = self.layer2_1(device=self.device, input=x)
-        x = self.layer2_2(device=self.device, input=x)
-        x = self.layer2_3(device=self.device, input=x)
+        x = self.layer2_0(x)
+        x = self.layer2_1(x)
+        x = self.layer2_2(x)
+        x = self.layer2_3(x)

-        x = self.layer3_0(device=self.device, input=x)
-        x = self.layer3_1(device=self.device, input=x)
-        x = self.layer3_2(device=self.device, input=x)
-        x = self.layer3_3(device=self.device, input=x)
-        x = self.layer3_4(device=self.device, input=x)
-        x = self.layer3_5(device=self.device, input=x)
+        x = self.layer3_0(x)
+        x = self.layer3_1(x)
+        x = self.layer3_2(x)
+        x = self.layer3_3(x)
+        x = self.layer3_4(x)
+        x = self.layer3_5(x)

-        x = self.layer4_0(device=self.device, input=x)
-        x = self.layer4_1(device=self.device, input=x)
-        x = self.layer4_2(device=self.device, input=x)
+        x = self.layer4_0(input=x)
+        x = self.layer4_1(input=x)
+        x = self.layer4_2(input=x)
         return x


 class ttnn_UFLD_V2:
     def __init__(self, conv_args, conv_pth, device):
+        self.device = device
         self.input_height = 320
         self.input_width = 800
         self.num_grid_row = 100
@@ -239,33 +230,54 @@ def __init__(self, conv_args, conv_pth, device):
         self.input_height = self.input_height
         self.input_width = self.input_width
         self.input_dim = self.input_height // 32 * self.input_width // 32 * 8
-        self.device = device
         self.conv_pth = conv_pth
-        self.res_model = ttnn_Resnet_34(conv_args, conv_pth.res_model, device=device)
-        self.pool = ttnn_UFLD_V2_Conv2D(conv_args.pool, conv_pth.pool, activation="", device=device)
+        self.res_model = ttnn_Resnet_34(conv_args, conv_pth.res_model, device=self.device)
+        self.pool = ttnn_UFLD_V2_Conv2D(conv_args.pool, conv_pth.pool, activation="", device=self.device)

     def __call__(self, input):
-        fea = self.res_model(input) # 0.998
-        fea, out_h, out_w = self.pool(fea) # 0.979
+        batch_size = input.shape[0]
+        fea = self.res_model(input, batch_size=batch_size)
+        fea, out_h, out_w = self.pool(fea)
         if fea.is_sharded():
             fea = ttnn.sharded_to_interleaved(fea, ttnn.L1_MEMORY_CONFIG)
+        fea = ttnn.reshape(fea, (batch_size, out_h, out_w, fea.shape[-1]))
         fea = ttnn.permute(fea, (0, 3, 1, 2))
-        fea = ttnn.reshape(fea, (1, 2000))
+        fea = ttnn.reshape(
+            fea, (batch_size, int(fea.shape[0] * fea.shape[1] * fea.shape[2] * fea.shape[3]) // batch_size)
+        )
+        grid_size = (8, 8)
+        shard_grid = ttnn.CoreRangeSet(
+            {
+                ttnn.CoreRange(
+                    ttnn.CoreCoord(0, 0),
+                    ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1),
+                )
+            }
+        )
+        shard_shape = [32, 32]
+        print("shard shape is", shard_shape)
+        shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
+        width_sharded_mem_config = ttnn.MemoryConfig(
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, shard_spec
+        )
+        fea = ttnn.to_memory_config(fea, width_sharded_mem_config)
         out = ttnn.linear(
             fea,
             self.conv_pth.cls.linear_1.weight,
             bias=self.conv_pth.cls.linear_1.bias,
-            memory_config=ttnn.L1_MEMORY_CONFIG,
+            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
         )
-        out = ttnn.relu(out)
         out = ttnn.linear(
             out,
             self.conv_pth.cls.linear_2.weight,
             bias=self.conv_pth.cls.linear_2.bias,
-            memory_config=ttnn.L1_MEMORY_CONFIG,
+            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
         )
-        out = ttnn.to_layout(out, ttnn.ROW_MAJOR_LAYOUT)
+        if out.is_sharded():
+            out = ttnn.sharded_to_interleaved(out, ttnn.L1_MEMORY_CONFIG)
+        if out.layout != ttnn.ROW_MAJOR_LAYOUT:
+            out = ttnn.to_layout(out, ttnn.ROW_MAJOR_LAYOUT)
         loc_row, loc_col, exist_row, exist_col = (
             out[:, : self.dim1],
             out[:, self.dim1 : self.dim1 + self.dim2],
diff --git a/tests/ttnn/integration_tests/UFLD_v2/test_ttnn_UFLD_v2.py b/tests/ttnn/integration_tests/UFLD_v2/test_ttnn_UFLD_v2.py
index 1cabd243074..d3db05dca7f 100644
--- a/tests/ttnn/integration_tests/UFLD_v2/test_ttnn_UFLD_v2.py
+++ b/tests/ttnn/integration_tests/UFLD_v2/test_ttnn_UFLD_v2.py
@@ -392,19 +392,20 @@ def test_ufld_v2_basic_block(device, batch_size, input_channels, height, width):
 @pytest.mark.parametrize(
     "batch_size,input_channels,height,width",
     [
-        (1, 3, 320, 800),
+        # (1, 3, 320, 800),
+        (2, 3, 320, 800),
     ],
 )
 @skip_for_grayskull()
 @pytest.mark.parametrize(
     "use_pretrained_weight",
     [
-        False,
-        # True
+        # False,
+        True
     ],  # uncomment to run the model for real weights
     ids=[
-        "pretrained_weight_false",
-        # "pretrained_weight_true", # uncomment to run the model for real weights
+        # "pretrained_weight_false",
+        "pretrained_weight_true", # uncomment to run the model for real weights
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 79104}], indirect=True)
@@ -424,12 +425,6 @@ def test_UFD_V2_Model(device, batch_size, input_channels, height, width, use_pre
         torch_model.load_state_dict(new_state_dict)

     ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
-    ttnn_input_tensor = ttnn_input_tensor.reshape(
-        1,
-        1,
-        (ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2]),
-        ttnn_input_tensor.shape[-1],
-    )
     ttnn_input_tensor = ttnn.from_torch(
         ttnn_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
     )
@@ -442,7 +437,6 @@ def test_UFD_V2_Model(device, batch_size, input_channels, height, width, use_pre
     parameters.conv_args = infer_ttnn_module_args(
         model=torch_model, run_model=lambda model: torch_model(torch_input_tensor), device=device
     )
-
     ttnn_model = ttnn_UFLD_V2(conv_args=parameters.conv_args, conv_pth=parameters, device=device)
     torch_output, pred_list = torch_model(torch_input_tensor)
     ttnn_output, tt_pred_list = ttnn_model(input=ttnn_input_tensor)
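
For readers new to the sharding calls this patch introduces in `ttnn_UFLD_V2.__call__`, the snippet below is a minimal sketch of the width-sharded linear pattern, pulled out into a standalone helper. The helper name and its packaging are illustrative only; the 8x8 core grid and the [32, 32] shard shape are the values hard-coded in the patch, and `weight`/`bias` are assumed to have been prepared with `preprocess_linear_weight`/`preprocess_linear_bias` as in the existing custom preprocessor.

```python
import ttnn


def width_sharded_linear(fea, weight, bias):
    """Illustrative helper (not part of the patch): re-shard a flattened
    (batch_size, features) activation into L1 width shards and run ttnn.linear
    on it, mirroring the classifier-head change in ttnn_UFLD_V2.__call__."""
    # Build a shard spec over an 8x8 core grid; each core holds a 32x32 shard.
    grid_size = (8, 8)
    shard_grid = ttnn.CoreRangeSet(
        {
            ttnn.CoreRange(
                ttnn.CoreCoord(0, 0),
                ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1),
            )
        }
    )
    shard_spec = ttnn.ShardSpec(shard_grid, [32, 32], ttnn.ShardOrientation.ROW_MAJOR)
    width_sharded_mem_config = ttnn.MemoryConfig(
        ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, shard_spec
    )

    # Move the interleaved activation into the width-sharded L1 layout.
    fea = ttnn.to_memory_config(fea, width_sharded_mem_config)

    # Keep the matmul output width-sharded in L1 as well.
    out = ttnn.linear(
        fea,
        weight,
        bias=bias,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
    )

    # Convert back to an interleaved, row-major tensor before any slicing.
    if out.is_sharded():
        out = ttnn.sharded_to_interleaved(out, ttnn.L1_MEMORY_CONFIG)
    if out.layout != ttnn.ROW_MAJOR_LAYOUT:
        out = ttnn.to_layout(out, ttnn.ROW_MAJOR_LAYOUT)
    return out
```

With 64 cores and 32-wide shards the grid covers 2048 columns, which is enough for the 2000-element flattened pooled feature (`input_dim = 320 // 32 * 800 // 32 * 8`); each core keeps its slice of the activation in local L1, and the final interleaved/row-major conversion restores the layout that the `loc_row`/`loc_col`/`exist_row`/`exist_col` slicing expects.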