Fix asymmetrical padding issue in conv2d #1362

Draft: wants to merge 1 commit into base: main
13 changes: 10 additions & 3 deletions forge/csrc/passes/fuse_pad_conv2d.cpp
@@ -24,11 +24,19 @@ void fuse_pad_conv2d(graphlib::Graph *graph)
continue;

auto attrs = op->op_attrs();
if (std::get<int>(attrs[attrs.size() - 2]) != 0)

auto top_pad = std::get<int>(attrs[0]);
auto bottom_pad = std::get<int>(attrs[1]);
auto left_pad = std::get<int>(attrs[2]);
auto right_pad = std::get<int>(attrs[3]);

// Check if top_pad != bottom_pad or left_pad != right_pad
if (top_pad != bottom_pad || left_pad != right_pad)
{
// Second last attr must be 0 (constant mode), otherwise cannot fuse into conv2d
// Padding conditions are not equal, cannot fuse into conv2d
continue;
}

auto users = graph->users(node);

bool all_users_are_conv2d = true;
@@ -59,7 +67,6 @@ void fuse_pad_conv2d(graphlib::Graph *graph)
{
graphlib::OpNode *user_op = dynamic_cast<graphlib::OpNode *>(user);
auto conv_attrs = user_op->op_attrs();
TT_ASSERT(conv_attrs.size() == 13);
// Conv2d attributes are
// [stride[0],stride[1],dilation,groups,padding[0],padding[1],padding[2],padding[3],channel_last]
int pad_idx_offset = 4;
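For reference, the gate the updated pass applies can be sketched outside the C++ source. The following is a Python illustration, not the pass itself; it assumes the pad op's attributes begin with [top, bottom, left, right], matching the indices read in the hunk above.

```python
# Python sketch of the new fusion gate (illustrative only; the real pass
# is the C++ code in fuse_pad_conv2d.cpp above).
def can_fuse_pad_into_conv2d(pad_attrs):
    # Pad attrs assumed ordered [top, bottom, left, right, ...].
    top, bottom, left, right = pad_attrs[:4]
    # Asymmetric padding is no longer folded into conv2d; the explicit
    # pad node stays in the graph and the pass moves on.
    return top == bottom and left == right

assert can_fuse_pad_into_conv2d([1, 1, 2, 2, 0, 0])      # symmetric: fuse
assert not can_fuse_pad_into_conv2d([1, 2, 1, 2, 0, 0])  # asymmetric: skip
```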
201 changes: 0 additions & 201 deletions forge/forge/op/eval/forge/tm.py
@@ -1110,207 +1110,6 @@ def decompose(type, attr, dc, inputs):
dc.fuse(result)
return

if type == "pad":
if all([x == 0 for x in attr[0:-2]]):
# Pad size is 0
result = dc.op(Nop.create(), [inputs[0]])
dc.fuse(result)

activations = inputs[0]
mode_idx = attr[-2]
channel_last = attr[-1]
if channel_last:
r = activations.shape[-3]
c = activations.shape[-2]
else:
r = activations.shape[-2]
c = activations.shape[-1]

# Find out if padding exceeds tile boundary
# R, C are flipped because pytorch pad starts from last axis
if len(attr) == 4:
total_padding_c = attr[0] + attr[1]
total_padding_r = 0
all_around_padding = attr[:-2] + [0, 0]
elif len(attr) == 6:
total_padding_c = attr[0] + attr[1]
total_padding_r = attr[2] + attr[3]
all_around_padding = attr[:-2]
else:
raise RuntimeError("Forge only support Pad with either 2 or 4 attributes")

if (
((len(attr) == 4 and attr[0] == 0) or (len(attr) == 6 and attr[0] == 0 and attr[2] == 0))
and not channel_last
and math.ceil((total_padding_r + r) / TILE_DIM) == math.ceil(r / TILE_DIM)
and math.ceil((total_padding_c + c) / TILE_DIM) == math.ceil(c / TILE_DIM)
and mode_idx == 0 # 'constant' mode
):
# Pad does not exceed tile boundary and only on the end of axis
# Will be lowered into NOP
return

else:
# Lower into concats
left, right, top, bottom = 0, 0, 0, 0
if len(attr) == 4:
left, right, _, _ = attr

elif len(attr) == 6:
left, right, top, bottom, _, _ = attr
else:
raise RuntimeError("Forge only support Pad with either 3 or 5 attributes")

if mode_idx == 1: # 'replicate' mode
result = activations

if channel_last:
result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])

orig_shape = result.shape
result = dc.op("reshape", [result], (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1]))
result = dc.op(TransposeTM.create(-2, -1), [result])
spm = create_pad_replicate_sparse_picker(c, r, top, bottom, left, right)
spm = dc.tensor(spm)
result = dc.op("sparse_matmul", [spm, result])
result = dc.op(TransposeTM.create(-2, -1), [result])
result = dc.op(
"reshape",
[result],
(1, orig_shape[-3], orig_shape[-1] + total_padding_r, orig_shape[-2] + total_padding_c),
)

result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])
else:
orig_shape = result.shape
if len(orig_shape) == 2:
result = dc.op("reshape", [result], (1, orig_shape[-2] * orig_shape[-1]))
else:
result = dc.op("reshape", [result], (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1]))
result = dc.op(TransposeTM.create(-2, -1), [result])
spm = create_pad_replicate_sparse_picker(r, c, left, right, top, bottom)
spm = dc.tensor(spm)
result = dc.op("sparse_matmul", [spm, result])
result = dc.op(TransposeTM.create(-2, -1), [result])
if len(orig_shape) == 2:
result = dc.op(
"reshape", [result], (orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c)
)
else:
result = dc.op(
"reshape",
[result],
(1, orig_shape[-3], orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c),
)

dc.fuse(result)
return

elif mode_idx == 0: # 'constant' mode
c_dim_axis = -2 if channel_last else -1
r_dim_axis = -3 if channel_last else -2

# On right or bottom, we can concat all the way to TILE boundary
result = activations
if left > 0:
pad_shape = result.shape.as_list().copy()
pad_shape[c_dim_axis] = left
tensor = torch.zeros(pad_shape)
const_tensor = dc.tensor(tensor)
result = dc.op("concatenate", [const_tensor, result], [c_dim_axis])

if right > 0:
pad_shape = result.shape.as_list().copy()
pad_shape[c_dim_axis] = (
TILE_DIM if pad_shape[c_dim_axis] % TILE_DIM == 0 and right < TILE_DIM else right
)
tensor = torch.zeros(pad_shape)
const_tensor = dc.tensor(tensor)
result = dc.op("concatenate", [result, const_tensor], [c_dim_axis])

if top > 0:
pad_shape = result.shape.as_list().copy()
pad_shape[r_dim_axis] = top
tensor = torch.zeros(pad_shape)
const_tensor = dc.tensor(tensor)
result = dc.op("concatenate", [const_tensor, result], [r_dim_axis])

if bottom > 0:
pad_shape = result.shape.as_list().copy()
pad_shape[r_dim_axis] = (
TILE_DIM if pad_shape[r_dim_axis] % TILE_DIM == 0 and bottom < TILE_DIM else bottom
)
tensor = torch.zeros(pad_shape)
const_tensor = dc.tensor(tensor)
result = dc.op("concatenate", [result, const_tensor], [r_dim_axis])

result = dc.op("narrow", [result], (c_dim_axis, 0, total_padding_c + c, result.shape[c_dim_axis]))
if channel_last:
result = dc.op("select", [result], (r_dim_axis, 0, total_padding_r + r, result.shape[r_dim_axis]))
else:
result = dc.op("narrow", [result], (r_dim_axis, 0, total_padding_r + r, result.shape[r_dim_axis]))

dc.fuse(result)
return

elif mode_idx == 2:
# Reflect mode
result = activations

if channel_last:
result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])

orig_shape = result.shape
result = dc.op_with_named_attrs(
"reshape",
[result],
{"shape": (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1])},
(1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1]),
)
result = dc.op(TransposeTM.create(-2, -1), [result])
spm = create_pad_reflect_sparse_picker(c, r, top, bottom, left, right)
spm = dc.tensor(spm.to_dense())
result = dc.op("matmul", [spm, result])
result = dc.op(TransposeTM.create(-2, -1), [result])
result = dc.op_with_named_attrs(
"reshape",
[result],
{
"shape": (
1,
orig_shape[-3],
orig_shape[-1] + total_padding_r,
orig_shape[-2] + total_padding_c,
)
},
(1, orig_shape[-3], orig_shape[-1] + total_padding_r, orig_shape[-2] + total_padding_c),
)

result = dc.op(TransposeTM.create(-3, -1, result.shape[-3]), [result])
else:
orig_shape = result.shape
if len(orig_shape) == 2:
shape = (1, orig_shape[-2] * orig_shape[-1])
else:
shape = (1, 1, orig_shape[-3], orig_shape[-2] * orig_shape[-1])

result = dc.op_with_named_attrs("reshape", [result], {"shape": shape}, shape)
result = dc.op(TransposeTM.create(-2, -1), [result])
spm = create_pad_reflect_sparse_picker(r, c, left, right, top, bottom)
spm = dc.tensor(spm.to_dense())
result = dc.op("matmul", [spm, result])
result = dc.op(TransposeTM.create(-2, -1), [result])

if len(orig_shape) == 2:
shape = (orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c)
else:
shape = (1, orig_shape[-3], orig_shape[-2] + total_padding_r, orig_shape[-1] + total_padding_c)

result = dc.op_with_named_attrs("reshape", [result], {"shape": shape}, shape)

dc.fuse(result)
return

if type == "broadcast":
if attr[1] == 1:
dc.fuse(dc.op(Nop.create(), [inputs[0]]))
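The deleted block decomposed pad into NOPs, sparse matmuls, and concats. Its own comment notes that R and C are flipped because torch-style pad attributes start from the last axis, which is still useful context for reading the rest of the PR. In plain PyTorch (not Forge) that ordering looks like:

```python
# Plain PyTorch illustration of the pad-attribute ordering the removed
# decomposition handled: pad values are given starting from the last axis.
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)  # (N, C, H, W)

w_only = F.pad(x, (1, 2), mode="constant", value=0)         # pads W: 8 -> 11
w_and_h = F.pad(x, (1, 2, 1, 2), mode="constant", value=0)  # pads W and H

print(w_only.shape)   # torch.Size([1, 3, 8, 11])
print(w_and_h.shape)  # torch.Size([1, 3, 11, 11])
```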
9 changes: 4 additions & 5 deletions forge/forge/tvm_calls/relay/op/forge_passes.py
@@ -186,12 +186,11 @@ def callback(self, pre, post, node_map):

pad_mode = pad.attrs.pad_mode

if pad_mode == "constant":
padding = [top_pad, left_pad, bottom_pad, right_pad]

if pad_mode == "reflect":
act = tvm.relay.op.nn.pad(act, pad_width, pad_mode="reflect")
if top_pad != bottom_pad or left_pad != right_pad:
act = tvm.relay.op.nn.pad(act, pad_width, pad_mode=pad_mode)
padding = [0, 0, 0, 0]
else:
padding = [top_pad, left_pad, bottom_pad, right_pad]

op_attrs = {**conv_pool.attrs}
op_attrs["padding"] = padding
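The net effect of this change: symmetric pads are still folded into the conv/pool padding attribute, while asymmetric pads are kept as an explicit relay pad op (using the original pad_mode) and the fused op gets zero padding. A minimal Python sketch of that decision, assuming the four pad values are already extracted as in the surrounding callback:

```python
# Sketch (not the actual TVM pass) of the padding decision made in the
# updated callback.
def resolve_padding(top_pad, bottom_pad, left_pad, right_pad):
    """Return (keep_explicit_pad, conv_padding_attr)."""
    if top_pad != bottom_pad or left_pad != right_pad:
        # Asymmetric: keep the explicit pad op and give the
        # conv/pool op zero padding.
        return True, [0, 0, 0, 0]
    # Symmetric: fold the pad into the op's padding attribute.
    return False, [top_pad, left_pad, bottom_pad, right_pad]

assert resolve_padding(1, 1, 2, 2) == (False, [1, 2, 1, 2])
assert resolve_padding(1, 2, 1, 2) == (True, [0, 0, 0, 0])
```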
10 changes: 1 addition & 9 deletions forge/test/mlir/operators/nn/test_nn.py
@@ -407,8 +407,7 @@ def forward(self, x):
pytest.param(
(1, 2, 1, 2),
marks=pytest.mark.xfail(
reason="TTNN only supports padding height/width attributes. Thus, padding_top "
"must equal padding_bottom for the op to execute as expected."
reason="RuntimeError: Found Unsupported operations while lowering from TTForge to TTIR in forward graph - Pad"
),
),
],
@@ -425,13 +424,6 @@ def forward(self, x):
x = nn.functional.pad(x, self.padding, mode="constant", value=0)
return self.conv(x)

pad_top, pad_bottom, pad_left, pad_right = padding
if pad_top != pad_bottom or pad_left != pad_right:
pytest.xfail(
"TTNN only supports padding height/width attributes. Thus, padding_top "
"must equal padding_bottom for the op to execute as expected."
)

inputs = [torch.rand(shape)]

framework_model = PaddingAndConv2d(padding=padding)
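For context, the pattern this test exercises can be reproduced standalone. The conv hyperparameters below (3 input channels, 16 output channels, 3x3 kernel) are illustrative assumptions; the excerpt above does not show the ones the test actually uses.

```python
import torch
from torch import nn

class PaddingAndConv2d(nn.Module):
    # Mirrors the test module: explicit asymmetric constant pad, then a
    # conv2d with no implicit padding. Conv sizes here are assumptions.
    def __init__(self, padding):
        super().__init__()
        self.padding = padding  # (left, right, top, bottom)
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=0)

    def forward(self, x):
        x = nn.functional.pad(x, self.padding, mode="constant", value=0)
        return self.conv(x)

model = PaddingAndConv2d(padding=(1, 2, 1, 2))
out = model(torch.rand(1, 3, 32, 32))
print(out.shape)  # 32x32 padded to 35x35, 3x3 conv -> torch.Size([1, 16, 33, 33])
```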