feat: support 1684 masked fill static
Change-Id: I9644c27dc9d5cd096ade6f3fb2613e9a49d135ec
LuTaoChen authored and Korbin-chen committed Jun 19, 2023
1 parent 34e83f5 commit 5721ed0
Showing 11 changed files with 153 additions and 23 deletions.
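For context: masked fill writes a constant into every position of a tensor where a boolean mask (broadcast against the input) is true and keeps the original value elsewhere. A minimal PyTorch illustration of the semantics this commit lowers to the BM1684 backend (shapes chosen for illustration, not taken from the tests below):

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
mask = torch.tensor([[True, False, True]])              # broadcasts over dim 0
y = torch.masked_fill(x, mask, 5.0)
# y -> [[5., 1., 5.],
#       [5., 4., 5.]]
```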
4 changes: 4 additions & 0 deletions include/tpu_mlir/Backend/BM168x/BM1684.h
@@ -203,6 +203,8 @@ typedef void (*nodechip_space2batch_fix8b)(uint64_t input_global_addr, uint64_t
typedef void (*nodechip_batch2space)(uint64_t input_global_addr, uint64_t output_global_addr, uint64_t buffer_global_addr, const int* input_shape, const int input_dim, const int* block_sizes, const int* pad_sizes, int* output_shape, CMD_ID_NODE *pid_node);
typedef void (*nodechip_batch2space_fix8b)(uint64_t input_global_addr, uint64_t output_global_addr, uint64_t buffer_global_addr, uint64_t imm_global_addr, uint64_t *buffer_size, const int* input_shape, const int input_dim, int in_store_mode, int out_store_mode, const int* block_sizes, const int* crops_sizes, int* output_shape, CMD_ID_NODE *pid_node);
typedef void (*nodechip_unary)(uint64_t bottom_global_addr, uint64_t top_global_addr, uint64_t length, UNARY_FUNC_TYPE type, void *param, CMD_ID_NODE *pid_node);
typedef void (*nodechip_masked_fill_global)(uint64_t input_global_addr, uint64_t mask_global_addr, uint64_t output_global_addr, uint32_t *input_shape, uint32_t *mask_shape, int input_dims, int mask_dims, uint32_t value, CMD_ID_NODE *pid_node);
typedef void (*nodechip_masked_fill_local)(uint32_t input_addr, uint32_t mask_addr, uint32_t buffer_addr, uint32_t output_addr, const int *input_shape, const int *mask_shape, int input_dims, int mask_dims, uint32_t value, CMD_ID_NODE *pid_node);

// clang-format on
namespace tpu_mlir {
@@ -417,6 +419,8 @@ class BM1684 : public BM168x {
nodechip_batch2space dl_nodechip_batch2space;
nodechip_batch2space_fix8b dl_nodechip_batch2space_fix8b;
nodechip_unary dl_nodechip_unary;
nodechip_masked_fill_global dl_nodechip_masked_fill_global;
nodechip_masked_fill_local dl_nodechip_masked_fill_local;
// clang-format on
public:
virtual uint32_t get_bdc_len(int bdc_num, int group_id) override;
1 change: 1 addition & 0 deletions include/tpu_mlir/Conversion/TopToTpu/LoweringBM1684.h
@@ -103,5 +103,6 @@ LOWERING_BM1684(Compare)
LOWERING_BM1684(CompareConst)
LOWERING_BM1684(Mish)
LOWERING_BM1684(Softsign)
LOWERING_BM1684(MaskedFill)
} // namespace bm1684
} // namespace tpu_mlir
2 changes: 2 additions & 0 deletions lib/Backend/BM168x/BM1684.cpp
@@ -262,5 +262,7 @@ void BM1684::load_functions() {
CAST_FUNCTION(nodechip_batch2space);
CAST_FUNCTION(nodechip_batch2space_fix8b);
CAST_FUNCTION(nodechip_unary);
CAST_FUNCTION(nodechip_masked_fill_global);
CAST_FUNCTION(nodechip_masked_fill_local);
// clang-format on
}
25 changes: 25 additions & 0 deletions lib/Conversion/TopToTpu/BM1684/MaskedFill.cpp
@@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include "tpu_mlir/Conversion/TopToTpu/LoweringBM1684.h"

namespace tpu_mlir {
namespace bm1684 {

void MaskedFillLowering::LoweringF32(PatternRewriter &rewriter, top::MaskedFillOp op) const {
lowering_common_f32<tpu::MaskedFillOp>(rewriter, op, 2);
}

void MaskedFillLowering::LoweringINT8(PatternRewriter &rewriter, top::MaskedFillOp op,
bool asymmetric) const {
lowering_common_f32<tpu::MaskedFillOp>(rewriter, op, 2);
}

} // namespace bm1684
} // namespace tpu_mlir
3 changes: 2 additions & 1 deletion lib/Conversion/TopToTpu/LoweringBM1684.cpp
@@ -90,7 +90,8 @@ void populateTopToTpuConversionPatterns(RewritePatternSet *patterns) {
CompareLowering,
CompareConstLowering,
MishLowering,
SoftsignLowering
SoftsignLowering,
MaskedFillLowering
// clang-format on
>(patterns->getContext());
}
84 changes: 65 additions & 19 deletions lib/Dialect/Tpu/Interfaces/BM1684/MaskedFill.cpp
@@ -7,40 +7,86 @@
//
//===----------------------------------------------------------------------===//

#include "tpu_mlir/Dialect/Tpu/IR/TpuOps.h"
#include "tpu_mlir/Backend/BM168x/BM1684.h"

#include "tpu_mlir/Dialect/Tpu/IR/TpuOps.h"
#include "tpu_mlir/Dialect/Tpu/Transforms/Codegen/Dynamic/DynamicLayer.hpp"
#include "tpu_mlir/Support/MathUtils.h"
#include "tpu_mlir/Support/Module.h"



using namespace tpu_mlir::backend;

void tpu::MaskedFillOp::codegen_global_bm1684() {
llvm_unreachable("Not Implemented");
auto input = getBrn();
auto mask = getCond();
auto mask_addr = module::getAddress(mask);
auto input_addr = module::getAddress(input);
auto top_addr = module::getAddress(getOutput());
float value = getConstVal().convertToDouble();
int *input_shape = new int[MAX_SHAPE_DIMS];
int *mask_shape = new int[MAX_SHAPE_DIMS];
int input_dims = module::getShape(input).size();
int mask_dims = module::getShape(mask).size();

for (auto v : llvm::enumerate(module::getShape(input)))
input_shape[v.index()] = (int)v.value();
for (auto v : llvm::enumerate(module::getShape(mask)))
mask_shape[v.index()] = (int)v.value();

BM1684::instance().dl_nodechip_masked_fill_global(
input_addr, mask_addr, top_addr, (uint32_t *)input_shape,
(uint32_t *)mask_shape, input_dims, mask_dims, *((uint32_t *)&value),
(CMD_ID_NODE *)BM1684::instance().cmdid_node);
}

int64_t tpu::MaskedFillOp::getBufferSize_bm1684(int64_t in_lmem_bytes,
int64_t out_lmem_bytes,
int64_t in_nslice, int64_t in_hslice,
int64_t out_nslice,
int64_t out_hslice) {
return 0;
int64_t tpu::MaskedFillOp::getBufferSize_bm1684(
int64_t in_lmem_bytes, int64_t out_lmem_bytes, int64_t in_nslice,
int64_t in_hslice, int64_t out_nslice, int64_t out_hslice) {
int64_t buffer_size = 0;
int64_t n, c, h, w;
module::getNCHW(getOutput(), n, c, h, w);
auto EU_NUM = BM1684::eu_num(sizeof(int32_t));
buffer_size = out_nslice * ceiling_func(c, BM1684::NPU_NUM) *
align_up(out_hslice * w, EU_NUM) * sizeof(float);

return buffer_size;
}
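A rough arithmetic sketch of the buffer-size formula above, with `npu_num` and `eu_num` as illustrative stand-ins for `BM1684::NPU_NUM` and `BM1684::eu_num(sizeof(int32_t))` (the real values come from the backend, so treat the numbers below as an example only):

```python
def masked_fill_local_buffer_bytes(out_nslice, c, out_hslice, w,
                                   npu_num=64, eu_num=16):
    # Mirrors the C++ above: nslice * ceil(c / NPU_NUM) * align_up(hslice * w, EU_NUM) * 4
    ceiling = lambda a, b: (a + b - 1) // b
    align_up = lambda a, b: ceiling(a, b) * b
    return out_nslice * ceiling(c, npu_num) * align_up(out_hslice * w, eu_num) * 4

print(masked_fill_local_buffer_bytes(1, 3, 32, 300))  # 1 * 1 * 9600 * 4 = 38400
```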

void tpu::MaskedFillOp::codegen_local_bm1684(int64_t n_step, int64_t h_step, local_sec_info_t &sec_info) {
llvm_unreachable("Not Implemented");
void tpu::MaskedFillOp::codegen_local_bm1684(int64_t n_step, int64_t h_step,
local_sec_info_t &sec_info) {
if (module::isUniformQuantized(getOutput())) {
llvm_unreachable("Not Implemented");
}
auto input_dims = module::getShape(getBrn()).size();
auto mask_dims = module::getShape(getCond()).size();
auto output_dims = module::getShape(getCond()).size();
assert(input_dims == mask_dims);
assert(input_dims == output_dims);
assert(output_dims == mask_dims);
int input_shape[input_dims], mask_shape[mask_dims];
module::getLocalShape(getBrn(), n_step, h_step, input_shape);
module::getLocalShape(getCond(), n_step, h_step, mask_shape);
auto top_ginfo = getGroupInfo(n_step, h_step, 0, 0, 0);
auto input_ginfo =
LocalGenInterface::getGroupInfo(getBrn(), n_step, h_step, 0, 0, 0);
auto mask_ginfo =
LocalGenInterface::getGroupInfo(getCond(), n_step, h_step, 0, 0, 0);
auto output_ginfo =
LocalGenInterface::getGroupInfo(getOutput(), n_step, h_step, 0, 0, 0);

float value = getConstVal().convertToDouble();
BM1684::instance().dl_nodechip_masked_fill_local(
input_ginfo.out_addr, mask_ginfo.out_addr, top_ginfo.buffer_addr,
top_ginfo.out_addr, input_shape, mask_shape, 4, 4, *((uint32_t *)&value),
(CMD_ID_NODE *)BM1684::instance().bdc_node);
}

uint32_t tpu::MaskedFillOp::dyn_codegen_global_bm1684(void* ir_layer_info) {
uint32_t tpu::MaskedFillOp::dyn_codegen_global_bm1684(void *ir_layer_info) {
llvm_unreachable("Not Implemented");
return 0;
}
int64_t tpu::MaskedFillOp::get_fw_type_bm1684() {
return -1;
}
int64_t tpu::MaskedFillOp::get_fw_type_bm1684() { return -1; }

int32_t tpu::MaskedFillOp::dyn_codegen_local_bm1684(void* ir_layer_info) {
int32_t tpu::MaskedFillOp::dyn_codegen_local_bm1684(void *ir_layer_info) {
llvm_unreachable("Not Implemented");
return 0;
}
}
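As a reference for what the new global and local calls above are expected to compute, here is a NumPy sketch of masked fill with mask broadcasting (our own helper, not backend code). Note that the codegen passes the float fill value to the backend by reinterpreting its bits as a uint32 (`*((uint32_t *)&value)`), not by converting it to an integer:

```python
import numpy as np

def masked_fill_ref(x: np.ndarray, mask: np.ndarray, value: float) -> np.ndarray:
    # Broadcast the mask against x and overwrite masked positions with value.
    return np.where(mask.astype(bool), np.float32(value), x.astype(np.float32))

x = np.zeros((1, 3, 128, 300), np.float32)
mask = np.ones((1, 1, 128, 300), bool)
out = masked_fill_ref(x, mask, 5.0)  # shape (1, 3, 128, 300), all elements 5.0
```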
39 changes: 38 additions & 1 deletion python/test/test_torch.py
@@ -86,6 +86,7 @@ def __init__(self,
"Linear": (self.test_Linear, N, Y, Y, Y),
"LogSoftmax": (self.test_LogSoftmax, N, Y, Y, Y),
"LSTM": (self.test_LSTM, N, Y, Y, Y),
"MaskedFill": (self.test_MaskedFill, Y, Y, Y, N),
"Math": (self.test_Math, N, Y, Y, N),
"MatMul": (self.test_MatMul, N, Y, Y, Y),
"Max": (self.test_Max, N, Y, Y, N),
@@ -203,7 +204,8 @@ def cosine_similarity(self, x, y):
def compare(self, ref_out, targe_out):
if ref_out.dtype in [np.int64, np.int32, np.int16, np.int8]:
cos = self.cosine_similarity(ref_out, targe_out)
assert (cos > 0.997)
assert (cos > 0.997 or (np.linalg.norm(ref_out) == 0
and np.linalg.norm(targe_out) == 0))
else:
np.testing.assert_allclose(ref_out, targe_out, rtol=1e-5, atol=1e-01)

@@ -1419,6 +1421,41 @@ def forward(self, x):
if not self.is_cv18xx:
_test_t((32, ))

#######################################################################
# MaskedFill
# ------------
def test_MaskedFill(self):
def _test_masked_fill(in_shape, mask_shape, is_local):
class Model(nn.Module):

def __init__(self):
super(Model, self).__init__()

def forward(self, x, mask):
if is_local:
x += x
x -= 2
x *= 2
x += 1
x = torch.masked_fill(x, mask, 5)
if is_local:
x += 1
return x

self.trace_and_test([in_shape, mask_shape], Model(),
[self.Desc('float', -10, 10), self.Desc('int', 0, 2)])

dims = [3, 4, 5]
shape = [1, 3, 128, 300, 2]
for dim in dims:
shapes = [shape[: dim], shape[: dim]]
odd = True
for i in range(dim):
shapes[odd][i] = 1
odd = not odd
_test_masked_fill(tuple(shapes[0]), tuple(shapes[1]), False)
_test_masked_fill(([1, 3, 1, 300]), ([1, 1, 128, 300]), True)
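The shape loop above builds complementary broadcast pairs from the base shape `[1, 3, 128, 300, 2]`; reproducing it outside the test harness shows which `(in_shape, mask_shape)` pairs actually get exercised:

```python
base = [1, 3, 128, 300, 2]
for dim in [3, 4, 5]:
    shapes = [base[:dim], base[:dim]]
    odd = True
    for i in range(dim):
        shapes[odd][i] = 1  # alternate which operand gets squeezed to 1
        odd = not odd
    print(tuple(shapes[0]), tuple(shapes[1]))
# (1, 1, 128)        (1, 3, 1)
# (1, 1, 128, 1)     (1, 3, 1, 300)
# (1, 1, 128, 1, 2)  (1, 3, 1, 300, 1)
```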

#######################################################################
# Math: cos/sin/tan/tanh
# ------------
14 changes: 14 additions & 0 deletions python/transform/TorchConverter.py
@@ -120,6 +120,7 @@ def __init__(self,
"aten::log_softmax": lambda node: self.convert_softmax_op(node, log=True),
"aten::lstm": lambda node: self.convert_lstm_op(node),
"aten::lt": lambda node: self.convert_compare_op(node, "Less"),
"aten::masked_fill": lambda node: self.convert_masked_fill(node),
"aten::matmul": lambda node: self.convert_matmul_op(node),
"aten::max": lambda node: self.convert_max_op(node),
"aten::max_pool1d": lambda node: self.convert_maxpool_op(node),
@@ -1408,6 +1409,19 @@ def convert_gelu_op(self, torch_node: TorchNode):
ip=self.mlir.insert_point).output
self.addOperand(torch_node.name, new_op)

def convert_masked_fill(self, torch_node: TorchNode):
x = self.getOp(torch_node.inputs[0])
mask = self.getOp(torch_node.inputs[1])
const_val = self.const_val[torch_node.inputs[2]]
new_op = top.MaskedFillOp(self.unranked_type,
mask,
x,
inversed=True,
const_val=const_val,
loc=self.get_loc(torch_node.name),
ip=self.mlir.insert_point).output
self.addOperand(torch_node.name, new_op)
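The converter reads the fill value from `self.const_val`, i.e. it expects the third input of `aten::masked_fill` to be a compile-time constant, which is what tracing a literal produces. A minimal sketch of a module whose traced graph contains such a node (module name and shapes are illustrative):

```python
import torch
import torch.nn as nn

class FillModel(nn.Module):
    def forward(self, x, mask):
        # The literal 5.0 is captured as a constant in the traced graph,
        # feeding the third input of aten::masked_fill.
        return torch.masked_fill(x, mask, 5.0)

traced = torch.jit.trace(FillModel(),
                         (torch.randn(1, 3, 32, 32),
                          torch.zeros(1, 1, 32, 32, dtype=torch.bool)))
print(traced.graph)  # shows an aten::masked_fill node
```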

def convert_matmul_op(self, torch_node: TorchNode):
op0 = self.getOp(torch_node.inputs[0])
op1 = self.getOp(torch_node.inputs[1])
4 changes: 2 additions & 2 deletions third_party/nntoolchain/README.md
@@ -1,5 +1,5 @@
## TPU1684 2023-06-05
sha256: 647b030b051283d3ebcd4b59799fe47f29cd82da
## TPU1684 2023-06-16
sha256: 9540a9b1bff45c9ffb11507c6b2388f95f21994a

``` bash
cd nntoolchain/net_compiler/
Binary file modified third_party/nntoolchain/lib/libbackend_1684.so
Binary file modified third_party/nntoolchain/lib/libcmodel_1684.so
