
Commit

add yolov5s post process test
1. fix the interpreter progress bar exceeding its total
2. add yolov5s post process test
3. fix multithread conflicts between parallel test cases

Change-Id: I98e805589dde4cc83c27865d9f658a90b3bf4ff7
HarmonyHu committed May 20, 2023
1 parent dd3a276 commit 7d3a5d4
Showing 15 changed files with 105 additions and 49 deletions.
7 changes: 4 additions & 3 deletions lib/Support/MathUtils.cpp
@@ -562,14 +562,15 @@ void pad_tensor(float *p_after_pad, float *src, int n, int c, int d, int h,
void pad_tensor_for_deconv(float *p_after_pad, float *src, int n, int c, int d,
int h, int w, int kd, int kh, int kw, int dd, int dh,
int dw, int sd, int sh, int sw, int pdf, int pdb,
- int pht, int phb, int pwl, int pwr, int opd, int oph, int opw,
- float pad_value) {
+ int pht, int phb, int pwl, int pwr, int opd, int oph,
+ int opw, float pad_value) {
int nc = n * c;
int od = (d - 1) * sd + 1 + dd * (2 * kd - 2 - pdf - pdb) + opd;
int oh = (h - 1) * sh + 1 + dh * (2 * kh - 2 - pht - phb) + oph;
int ow = (w - 1) * sw + 1 + dw * (2 * kw - 2 - pwl - pwr) + opw;
int pst[3] = {(kd - 1) * dd - pdf, (kh - 1) * dh - pht, (kw - 1) * dw - pwl};
- int ped[3] = {(kd - 1) * dd - pdb + opd, (kh - 1) * dh - phb + oph, (kw - 1) * dw - pwr + opw};
+ int ped[3] = {(kd - 1) * dd - pdb + opd, (kh - 1) * dh - phb + oph,
+ (kw - 1) * dw - pwr + opw};
for (int i = 0; i < nc; i++) {
for (int m = 0; m < od; m++) {
for (int j = 0; j < oh; j++) {
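As a quick sanity check of the output-size formulas used by pad_tensor_for_deconv above, here is a small Python calculation with toy 2-D numbers (values chosen only for illustration; the variable names mirror the C++ parameters):

h, w = 4, 4                  # input height/width
kh, kw = 3, 3                # kernel size
sh, sw = 2, 2                # stride
dh, dw = 1, 1                # dilation
pht = phb = pwl = pwr = 1    # deconv padding (top/bottom/left/right)
oph = opw = 0                # output_padding

# Size of the zero-inserted, padded buffer the deconv is then evaluated over:
oh = (h - 1) * sh + 1 + dh * (2 * kh - 2 - pht - phb) + oph   # (4-1)*2+1 + 2 = 9
ow = (w - 1) * sw + 1 + dw * (2 * kw - 2 - pwl - pwr) + opw   # 9
print(oh, ow)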
35 changes: 21 additions & 14 deletions lib/Support/ModuleInterpreter.cpp
@@ -374,27 +374,32 @@ void ModuleInterpreter::invoke_all_in_mem(bool express_type) {
std::string if_name;
for (auto func : module.getOps<FuncOp>()) {
WalkResult result = func.walk<WalkOrder::PreOrder>([&](Operation *op) {
- bar.update();
if (isa<func::FuncOp>(*op)) {
return WalkResult::advance();
}
std::string name;
- if (op->getLoc().isa<NameLoc>() || op->getLoc().isa<FusedLoc>())
+ if (op->getLoc().isa<NameLoc>() || op->getLoc().isa<FusedLoc>()) {
name = module::getName(op).str();
+ }
LLVM_DEBUG(llvm::dbgs() << "compute: '" << op << "'\n");
- if (flag && isa<func::FuncOp>(*(op->getParentOp())))
- flag = 0; //clear
+ if (flag && isa<func::FuncOp>(*(op->getParentOp()))) {
+ flag = 0; // clear
+ }
if (isa<tpu::IfOp, top::IfOp>(op)) {
- std::optional<RegisteredOperationName> info = op->getName().getRegisteredInfo();
+ std::optional<RegisteredOperationName> info =
+ op->getName().getRegisteredInfo();
if_name = name;
- auto *inferInterface = info->getInterface<tpu_mlir::InferenceInterface>();
- if (failed(inferInterface->inference(inferInterface, op, *inference_map[name]))) {
- flag = 2; //else branch
+ auto *inferInterface =
+ info->getInterface<tpu_mlir::InferenceInterface>();
+ if (failed(inferInterface->inference(inferInterface, op,
+ *inference_map[name]))) {
+ flag = 2; // else branch
} else {
- flag = 1;//then branch
+ flag = 1; // then branch
}
return WalkResult::advance();
- } else if (isa<tpu_mlir::InferenceInterface>(op) && !flag) {
+ } else if (isa<tpu_mlir::InferenceInterface>(op) && 0 == flag) {
+ bar.update();
auto infer_op = dyn_cast<InferenceInterface>(op);
if (failed(infer_op.inference(*inference_map[name]))) {
infer_op.dump();
@@ -414,7 +419,8 @@ void ModuleInterpreter::invoke_all_in_mem(bool express_type) {
name = module::getName(op->getOperand(k).getDefiningOp()).str();
#pragma omp parallel for schedule(static, omp_schedule(num_element))
for (int i = 0; i < num_element; i++)
- inference_map[if_name]->outputs[k][i] = inference_map[name]->outputs[k][i];
+ inference_map[if_name]->outputs[k][i] =
+ inference_map[name]->outputs[k][i];
}
}
}
@@ -781,9 +787,10 @@ bool ModuleInterpreter::getTensorQuantInfo(const std::string name,
}
scale = 1.0;
zp = 0;
- }
- else {
- dtype = std::string("UK"); scale = 1.0; zp = 0;
+ } else {
+ dtype = std::string("UK");
+ scale = 1.0;
+ zp = 0;
}
return true;
}
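The first hunk above is the "interpreter bar exceeded" fix from the commit message: bar.update() was previously called for every walked operation (FuncOps, IfOps and ops inside a skipped If branch included), so the bar could be updated more times than its total; it now advances only for ops that actually run inference. A minimal Python sketch of the same idea, using hypothetical op objects in place of the MLIR walk:

from tqdm import tqdm  # stand-in for the C++ progress bar

def run_all(ops, inference_map):
    # Count only the ops that will call inference(); updating the bar for
    # every walked op would overrun this total -- the bug fixed above.
    infer_ops = [op for op in ops if op.has_inference]
    bar = tqdm(total=len(infer_ops))
    for op in infer_ops:
        op.inference(inference_map[op.name])
        bar.update(1)
    bar.close()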
4 changes: 3 additions & 1 deletion python/test/test_onnx.py
@@ -287,6 +287,8 @@ def test_single(self, case: str):
torch.manual_seed(7)
print("Test: {}".format(case))
if case in self.test_cases:
+ os.makedirs(case, exist_ok=True)
+ os.chdir(case)
func, _, _, _, _ = self.test_cases[case]
func(case)
print("====== TEST {} Success ======".format(case))
@@ -396,7 +398,7 @@ def inference_and_compare(self,
elif quant_mode == "int4":
ref_tpu_tolerance = "0.90,0.60"
elif quant_mode == "bf16":
- ref_tpu_tolerance = "0.95,0.85"
+ ref_tpu_tolerance = "0.95,0.80"
tpu_npz = tpu_mlir.replace(".mlir", "_tpu_out.npz")
file_mark(tpu_npz)
show_fake_cmd(input_npz, tpu_mlir, tpu_npz)
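The two added lines in test_single() are the "multithread conflict" fix: each case now gets its own working directory, so cases running in parallel no longer overwrite each other's intermediate .mlir/.npz/.bmodel files; the same two lines are added to test_tflite.py, test_torch.py and test_tpulang.py below. A minimal sketch of the pattern (hypothetical driver; assumes cases run in separate processes, since os.chdir() affects the whole process):

import os
from multiprocessing import Pool

def run_case(case: str):
    # Give every case a private working directory before generating files.
    os.makedirs(case, exist_ok=True)
    os.chdir(case)
    # ... run the actual transform/deploy steps for this case here ...

if __name__ == "__main__":
    with Pool(4) as pool:
        pool.map(run_case, ["CaseA", "CaseB", "CaseC"])  # example case names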
2 changes: 2 additions & 0 deletions python/test/test_tflite.py
@@ -98,6 +98,8 @@ def __init__(self, chip: str = "bm1684x", mode: str = "all"):
def test_single(self, case: str):
np.random.seed(0)
if case in self.test_function:
+ os.makedirs(case, exist_ok=True)
+ os.chdir(case)
print("Test: {}".format(case))
self.test_function[case](case)
print("====== TEST {} Success ======".format(case))
4 changes: 3 additions & 1 deletion python/test/test_torch.py
@@ -169,6 +169,8 @@ def test_single(self, case: str):
TORCH_IR_TESTER.CURRENT_CASE = case
print("Test: {}".format(case))
if case in self.test_cases:
+ os.makedirs(case, exist_ok=True)
+ os.chdir(case)
func, _, _, _ = self.test_cases[case]
func()
print("====== TEST {} Success ======".format(case))
@@ -297,7 +299,7 @@ def inference_and_compare(self,
elif quant_mode == "int4":
ref_tpu_tolerance = "0.90,0.60"
elif quant_mode == "bf16":
- ref_tpu_tolerance = "0.95,0.85"
+ ref_tpu_tolerance = "0.95,0.80"
input_data = np.load(input_npz)
# tpu mlir inference and compare
tpu_npz = tpu_mlir.replace(".mlir", "_tpu_out.npz")
2 changes: 2 additions & 0 deletions python/test/test_tpulang.py
@@ -59,6 +59,8 @@ def test_single(self, case: str):
TPULANG_IR_TESTER.ID = 0
print("Test: {}".format(case))
if case in self.test_function:
+ os.makedirs(case, exist_ok=True)
+ os.chdir(case)
func, _ = self.test_function[case]
func(case)
print("====== TEST {} Success ======".format(case))
2 changes: 0 additions & 2 deletions regression/config/ssd-12.ini
@@ -12,8 +12,6 @@ output_names=Concat_659,Slice_676
do_int8_asym=0 # has problem
do_f16=0
do_bf16=0
#app=detect_ssd-12.py
do_post_handle=0 #1: build post op into bmodel 0: don't build post op into bmodel
post_type=ssd

[bm1684x]
4 changes: 0 additions & 4 deletions regression/config/yolov5s.ini
@@ -7,14 +7,10 @@ keep_aspect_ratio=1
mean=0.0,0.0,0.0
scale=0.0039216,0.0039216,0.0039216
pixel_format=rgb
#output_names=350,498,646
output_names=326,474,622
do_dynamic=0
do_post_handle=0 #1: build post op into bmodel 0: don't build post op into bmodel
post_type=yolo
dynamic_shapes=[[1,3,320,320]]
app=detect_yolov5.py
# pad_value=114
pad_type=normal

[bm1684x]
37 changes: 27 additions & 10 deletions regression/generic_cvimodel_sample.py
@@ -11,7 +11,10 @@
import time
from run_model import MODEL_RUN
from utils.mlir_shell import _os_system


class CviSampleGenerator:

def __init__(self, dest_dir):
self.chips = ["cv180x", "cv181x", "cv182x", "cv183x"]
#self.chips = ["cv183x"]
@@ -36,17 +39,28 @@ def cmd_exec(self, cmd_str):
else:
raise RuntimeError("[!Error]: {}".format(cmd_str))

- def run_sample_net(self, model, dst_model, chip, quant_type, customization_format = "", fuse_pre = False,
- aligned_input = False, merge_weight = False):
+ def run_sample_net(self,
+ model,
+ dst_model,
+ chip,
+ quant_type,
+ customization_format="",
+ fuse_pre=False,
+ aligned_input=False,
+ merge_weight=False):
info = f"{chip} {model} {quant_type} fuse_preprocess:{fuse_pre} {customization_format} aligned_input:{aligned_input} merge_weight:{merge_weight}"
print(f"run_model.py {info}")
dir = os.path.expandvars(f"$REGRESSION_PATH/regression_out/{model}_{chip}")
os.makedirs(dir, exist_ok=True)
os.chdir(dir)
- regressor = MODEL_RUN(model, chip, quant_type, \
- dyn_mode=False, do_post_handle=False, merge_weight=merge_weight, \
- fuse_preprocess=fuse_pre, customization_format=customization_format, \
- aligned_input=aligned_input)
+ regressor = MODEL_RUN(model,
+ chip,
+ quant_type,
+ dyn_mode=False,
+ merge_weight=merge_weight,
+ fuse_preprocess=fuse_pre,
+ customization_format=customization_format,
+ aligned_input=aligned_input)
ret = regressor.run_full()
if ret == 0:
print(f"{info} Success.")
@@ -59,6 +73,7 @@ def run_sample_net(self, model, dst_model, chip, quant_type, customization_forma
def generate_chip_models(self, chip):
tmp_dir = os.path.expandvars(f"$REGRESSION_PATH/regression_out/cvimodel_samples")
os.makedirs(tmp_dir, exist_ok=True)
+ # yapf: disable
self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2.cvimodel"), chip, "int8_sym")
self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2_bf16.cvimodel"), chip, "bf16")
self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2_fused_preprocess.cvimodel"), chip, "int8_sym", "BGR_PLANAR", True)
@@ -76,12 +91,14 @@ def generate_chip_models(self, chip):
self.run_sample_net("alphapose_res50", os.path.join(tmp_dir, "alphapose_fused_preprocess.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True)
self.run_sample_net("arcface_res50", os.path.join(tmp_dir, "arcface_res50_fused_preprocess.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True)
self.run_sample_net("arcface_res50", os.path.join(tmp_dir, "arcface_res50_fused_preprocess_aligned_input.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True, True)

+ # yapf: enable
#gen merged cvimodel
full_bs1_model = os.path.join(tmp_dir, "mobilenet_v2_bs1.cvimodel")
full_bs4_model = os.path.join(tmp_dir, "mobilenet_v2_bs4.cvimodel")
- self.run_sample_net("mobilenet_v2_cf", full_bs1_model, chip, "int8_sym", "", False, False, True)
- self.run_sample_net("mobilenet_v2_cf_bs4", full_bs4_model, chip, "int8_sym", "", False, False, True)
+ self.run_sample_net("mobilenet_v2_cf", full_bs1_model, chip, "int8_sym", "", False, False,
+ True)
+ self.run_sample_net("mobilenet_v2_cf_bs4", full_bs4_model, chip, "int8_sym", "", False,
+ False, True)
merge_model = os.path.join(tmp_dir, "mobilenet_v2_bs1_bs4.cvimodel")
merge_cmd = f"model_tool --combine {full_bs1_model} {full_bs4_model} -o {merge_model}"
self.cmd_exec(merge_cmd)
@@ -94,11 +111,11 @@ def generate_chip_models(self, chip):
os.chdir(self.current_dir)
shutil.rmtree(tmp_dir)


def generate_models(self):
for chip in self.chips:
self.generate_chip_models(chip)


if __name__ == "__main__":
t1 = time.time()
parser = argparse.ArgumentParser()
17 changes: 4 additions & 13 deletions regression/run_model.py
@@ -26,7 +26,6 @@ def __init__(self,
chip: str = "bm1684x",
mode: str = "all",
dyn_mode: bool = False,
- do_post_handle: bool = False,
merge_weight: bool = False,
fuse_preprocess: bool = False,
customization_format: str = "",
@@ -36,7 +35,6 @@
self.model_name = model_name
self.chip = chip
self.mode = mode
- self.do_post_handle = do_post_handle
self.dyn_mode = dyn_mode
self.fuse_pre = fuse_preprocess
self.customization_format = customization_format
@@ -70,7 +68,7 @@ def __init__(self,
self.tolerance = {
"f32": config.get(self.arch, "f32_tolerance", fallback="0.99,0.99"),
"f16": config.get(self.arch, "f16_tolerance", fallback="0.95,0.85"),
- "bf16": config.get(self.arch, "bf16_tolerance", fallback="0.95,0.85"),
+ "bf16": config.get(self.arch, "bf16_tolerance", fallback="0.95,0.80"),
"int8_sym": config.get(self.arch, "int8_sym_tolerance", fallback="0.8,0.5"),
"int8_asym": config.get(self.arch, "int8_asym_tolerance", fallback="0.8,0.5"),
"int4_sym": config.get(self.arch, "int4_sym_tolerance", fallback="0.8,0.5"),
@@ -158,8 +156,6 @@ def run_model_transform(self, model_name: str, dynamic: bool = False):
cmd += ["--output_names {}".format(self.ini_content["output_names"])]
if "excepts" in self.ini_content:
cmd += ["--excepts {}".format(self.ini_content["excepts"])]
- if self.do_post_handle and "post_type" in self.ini_content:
- cmd += ["--post_handle_type {}".format(self.ini_content["post_type"])]
_os_system(cmd, self.save_log)

def make_calibration_table(self):
@@ -250,8 +246,6 @@ def run_model_deploy(self,

# add according to arguments
model_file = f"{model_name}_{self.chip}_{quant_mode}"
- if self.do_post_handle:
- cmd += ["--post_op"]
if self.fuse_pre:
cmd += ["--fuse_preprocess"]
model_file += "_fuse_preprocess"
@@ -338,8 +332,6 @@ def run_dynamic(self, quant_mode: str):
"model_runner.py", f"--input {static_model_name}_in_f32.npz",
f"--model {static_model_file}", f"--output {static_out}"
]
- if self.do_post_handle:
- cmd += ["--post_op"]
_os_system(cmd, self.save_log)
cmd[2], cmd[3] = f"--model {dyn_model_file}", f"--output {dyn_out}"
_os_system(cmd, self.save_log)
@@ -433,7 +425,6 @@ def run_full(self):
choices=['all', 'basic', 'f32', 'f16', 'bf16', 'int8_sym', 'int8_asym', 'int4_sym'],
help="quantize mode, 'all' runs all modes except int4, 'baisc' runs f16 and int8 sym only")
parser.add_argument("--dyn_mode", default='store_true', help="dynamic mode")
- parser.add_argument("--do_post_handle", action='store_true', help="whether to do post handle")
parser.add_argument("--merge_weight", action="store_true",
help="merge weights into one weight binary with previous generated cvimodel")
# fuse preprocess
@@ -453,7 +444,6 @@ def run_full(self):
dir = os.path.expandvars(out_dir)
os.makedirs(dir, exist_ok=True)
os.chdir(dir)
- runner = MODEL_RUN(args.model_name, args.chip, args.mode, args.dyn_mode, args.do_post_handle,
- args.merge_weight, args.fuse_preprocess, args.customization_format,
- args.aligned_input, args.save_log, args.disable_thread)
+ runner = MODEL_RUN(args.model_name, args.chip, args.mode, args.dyn_mode, args.merge_weight,
+ args.fuse_preprocess, args.customization_format, args.aligned_input,
+ args.save_log, args.disable_thread)
runner.run_full()
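The bf16 fallback change above affects every model whose .ini does not pin bf16_tolerance itself, because the tolerances are looked up per chip section with configparser fallbacks. A small sketch of that lookup (file name and section assumed for illustration):

import configparser

# The .ini files above carry inline '#' comments, so strip them while parsing
# (assumed here via inline_comment_prefixes).
config = configparser.ConfigParser(inline_comment_prefixes=("#",))
config.read("yolov5s.ini")

# The section value wins when present; otherwise the hard-coded fallback is
# used, which is why changing the fallback to "0.95,0.80" is a global default.
bf16_tol = config.get("bm1684x", "bf16_tolerance", fallback="0.95,0.80")
print(bf16_tol)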
3 changes: 2 additions & 1 deletion regression/script_test/run.sh
@@ -9,6 +9,7 @@ pushd $TEST_DIR

$DIR/test1.sh
$DIR/test2.sh

$DIR/test3.sh
+ $DIR/test4.sh

popd
1 change: 1 addition & 0 deletions regression/script_test/test1.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+ # test case: test batch 4, calibration by npz, preprocess, tosa convert
set -ex

DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
1 change: 1 addition & 0 deletions regression/script_test/test2.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+ # test case: test quantization with qtable
set -ex

DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
1 change: 1 addition & 0 deletions regression/script_test/test3.sh
@@ -1,4 +1,5 @@
#!/bin/bash
+ # test case: test torch input int32 or int16 situation
set -ex

DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
34 changes: 34 additions & 0 deletions regression/script_test/test4.sh
@@ -0,0 +1,34 @@
#!/bin/bash
# test case: test yolov5s preprocess and postprocess
set -ex

DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"

model_transform.py \
--model_name yolov5s \
--model_def ${REGRESSION_PATH}/model/yolov5s.onnx \
--input_shapes=[[1,3,192,1024]] \
--output_names=350,498,646 \
--scale=0.0039216,0.0039216,0.0039216 \
--pixel_format=rgb \
--test_input=${REGRESSION_PATH}/image/dog.jpg \
--test_result=yolov5s_top_outputs.npz \
--post_handle_type=yolo \
--mlir yolov5s.mlir

run_calibration.py yolov5s.mlir \
--dataset ${REGRESSION_PATH}/dataset/COCO2017 \
--input_num 100 \
-o yolov5s_cali_table

model_deploy.py \
--mlir yolov5s.mlir \
--quantize INT8 \
--chip bm1684x \
--calibration_table yolov5s_cali_table \
--fuse_preprocess \
--test_input ${REGRESSION_PATH}/image/dog.jpg \
--test_reference yolov5s_top_outputs.npz \
--except "yolo_post" \
--compare_all \
--model yolov5s_int8.bmodel

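test4.sh builds the YOLO post-processing into the bmodel via --post_handle_type=yolo and leaves the "yolo_post" tensor out of the numerical comparison. For orientation only, below is the standard open-source YOLOv5 head decode that such a post handle typically performs before NMS; it is a generic illustration, not the tpu-mlir yolo_post implementation:

import numpy as np

def decode_yolov5_head(pred, stride, anchors, conf_thres=0.25):
    # pred: (3, H, W, 85) raw output of one detection head (3 anchors,
    # 4 box terms + 1 objectness + 80 classes); anchors: (3, 2) in pixels.
    sig = 1.0 / (1.0 + np.exp(-pred))                  # sigmoid everything
    _, h, w, _ = sig.shape
    gy, gx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    grid = np.stack([gx, gy], axis=-1)                 # (H, W, 2) cell offsets
    xy = (sig[..., 0:2] * 2.0 - 0.5 + grid) * stride   # box centers in pixels
    wh = (sig[..., 2:4] * 2.0) ** 2 * anchors[:, None, None, :]
    scores = sig[..., 4:5] * sig[..., 5:]              # objectness * class prob
    boxes = np.concatenate([xy - wh / 2, xy + wh / 2], axis=-1).reshape(-1, 4)
    scores = scores.reshape(-1, scores.shape[-1])
    keep = scores.max(axis=1) > conf_thres
    return boxes[keep], scores[keep]                   # NMS would follow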