From 7d3a5d4fde58cb594fdad9101816a88902ec6bf2 Mon Sep 17 00:00:00 2001 From: "pengchao.hu" Date: Tue, 16 May 2023 15:29:46 +0800 Subject: [PATCH] add yolov5s post process test 1. fix interpreter bar exceeded 2. add yolov5s post process test 3. fix multithread conflict Change-Id: I98e805589dde4cc83c27865d9f658a90b3bf4ff7 --- lib/Support/MathUtils.cpp | 7 ++--- lib/Support/ModuleInterpreter.cpp | 35 +++++++++++++++---------- python/test/test_onnx.py | 4 ++- python/test/test_tflite.py | 2 ++ python/test/test_torch.py | 4 ++- python/test/test_tpulang.py | 2 ++ regression/config/ssd-12.ini | 2 -- regression/config/yolov5s.ini | 4 --- regression/generic_cvimodel_sample.py | 37 +++++++++++++++++++-------- regression/run_model.py | 17 +++--------- regression/script_test/run.sh | 3 ++- regression/script_test/test1.sh | 1 + regression/script_test/test2.sh | 1 + regression/script_test/test3.sh | 1 + regression/script_test/test4.sh | 34 ++++++++++++++++++++++++ 15 files changed, 105 insertions(+), 49 deletions(-) create mode 100755 regression/script_test/test4.sh diff --git a/lib/Support/MathUtils.cpp b/lib/Support/MathUtils.cpp index c423f587b..bda87c647 100644 --- a/lib/Support/MathUtils.cpp +++ b/lib/Support/MathUtils.cpp @@ -562,14 +562,15 @@ void pad_tensor(float *p_after_pad, float *src, int n, int c, int d, int h, void pad_tensor_for_deconv(float *p_after_pad, float *src, int n, int c, int d, int h, int w, int kd, int kh, int kw, int dd, int dh, int dw, int sd, int sh, int sw, int pdf, int pdb, - int pht, int phb, int pwl, int pwr, int opd, int oph, int opw, - float pad_value) { + int pht, int phb, int pwl, int pwr, int opd, int oph, + int opw, float pad_value) { int nc = n * c; int od = (d - 1) * sd + 1 + dd * (2 * kd - 2 - pdf - pdb) + opd; int oh = (h - 1) * sh + 1 + dh * (2 * kh - 2 - pht - phb) + oph; int ow = (w - 1) * sw + 1 + dw * (2 * kw - 2 - pwl - pwr) + opw; int pst[3] = {(kd - 1) * dd - pdf, (kh - 1) * dh - pht, (kw - 1) * dw - pwl}; - int ped[3] = {(kd - 1) * dd - pdb + opd, (kh - 1) * dh - phb + oph, (kw - 1) * dw - pwr + opw}; + int ped[3] = {(kd - 1) * dd - pdb + opd, (kh - 1) * dh - phb + oph, + (kw - 1) * dw - pwr + opw}; for (int i = 0; i < nc; i++) { for (int m = 0; m < od; m++) { for (int j = 0; j < oh; j++) { diff --git a/lib/Support/ModuleInterpreter.cpp b/lib/Support/ModuleInterpreter.cpp index 1d9e476da..b7c356a0c 100644 --- a/lib/Support/ModuleInterpreter.cpp +++ b/lib/Support/ModuleInterpreter.cpp @@ -374,27 +374,32 @@ void ModuleInterpreter::invoke_all_in_mem(bool express_type) { std::string if_name; for (auto func : module.getOps()) { WalkResult result = func.walk([&](Operation *op) { - bar.update(); if (isa(*op)) { return WalkResult::advance(); } std::string name; - if (op->getLoc().isa() || op->getLoc().isa()) + if (op->getLoc().isa() || op->getLoc().isa()) { name = module::getName(op).str(); + } LLVM_DEBUG(llvm::dbgs() << "compute: '" << op << "'\n"); - if (flag && isa(*(op->getParentOp()))) - flag = 0; //clear + if (flag && isa(*(op->getParentOp()))) { + flag = 0; // clear + } if (isa(op)) { - std::optional info = op->getName().getRegisteredInfo(); + std::optional info = + op->getName().getRegisteredInfo(); if_name = name; - auto *inferInterface = info->getInterface(); - if (failed(inferInterface->inference(inferInterface, op, *inference_map[name]))) { - flag = 2; //else branch + auto *inferInterface = + info->getInterface(); + if (failed(inferInterface->inference(inferInterface, op, + *inference_map[name]))) { + flag = 2; // else branch } else { - flag = 1;//then branch + flag = 1; // then branch } return WalkResult::advance(); - } else if (isa(op) && !flag) { + } else if (isa(op) && 0 == flag) { + bar.update(); auto infer_op = dyn_cast(op); if (failed(infer_op.inference(*inference_map[name]))) { infer_op.dump(); @@ -414,7 +419,8 @@ void ModuleInterpreter::invoke_all_in_mem(bool express_type) { name = module::getName(op->getOperand(k).getDefiningOp()).str(); #pragma omp parallel for schedule(static, omp_schedule(num_element)) for (int i = 0; i < num_element; i++) - inference_map[if_name]->outputs[k][i] = inference_map[name]->outputs[k][i]; + inference_map[if_name]->outputs[k][i] = + inference_map[name]->outputs[k][i]; } } } @@ -781,9 +787,10 @@ bool ModuleInterpreter::getTensorQuantInfo(const std::string name, } scale = 1.0; zp = 0; - } - else { - dtype = std::string("UK"); scale = 1.0; zp = 0; + } else { + dtype = std::string("UK"); + scale = 1.0; + zp = 0; } return true; } diff --git a/python/test/test_onnx.py b/python/test/test_onnx.py index 25eb8dd9a..d4a71383c 100755 --- a/python/test/test_onnx.py +++ b/python/test/test_onnx.py @@ -287,6 +287,8 @@ def test_single(self, case: str): torch.manual_seed(7) print("Test: {}".format(case)) if case in self.test_cases: + os.makedirs(case, exist_ok=True) + os.chdir(case) func, _, _, _, _ = self.test_cases[case] func(case) print("====== TEST {} Success ======".format(case)) @@ -396,7 +398,7 @@ def inference_and_compare(self, elif quant_mode == "int4": ref_tpu_tolerance = "0.90,0.60" elif quant_mode == "bf16": - ref_tpu_tolerance = "0.95,0.85" + ref_tpu_tolerance = "0.95,0.80" tpu_npz = tpu_mlir.replace(".mlir", "_tpu_out.npz") file_mark(tpu_npz) show_fake_cmd(input_npz, tpu_mlir, tpu_npz) diff --git a/python/test/test_tflite.py b/python/test/test_tflite.py index 919b215d3..a8d7cbfc4 100755 --- a/python/test/test_tflite.py +++ b/python/test/test_tflite.py @@ -98,6 +98,8 @@ def __init__(self, chip: str = "bm1684x", mode: str = "all"): def test_single(self, case: str): np.random.seed(0) if case in self.test_function: + os.makedirs(case, exist_ok=True) + os.chdir(case) print("Test: {}".format(case)) self.test_function[case](case) print("====== TEST {} Success ======".format(case)) diff --git a/python/test/test_torch.py b/python/test/test_torch.py index a6771d10f..229271bfa 100755 --- a/python/test/test_torch.py +++ b/python/test/test_torch.py @@ -169,6 +169,8 @@ def test_single(self, case: str): TORCH_IR_TESTER.CURRENT_CASE = case print("Test: {}".format(case)) if case in self.test_cases: + os.makedirs(case, exist_ok=True) + os.chdir(case) func, _, _, _ = self.test_cases[case] func() print("====== TEST {} Success ======".format(case)) @@ -297,7 +299,7 @@ def inference_and_compare(self, elif quant_mode == "int4": ref_tpu_tolerance = "0.90,0.60" elif quant_mode == "bf16": - ref_tpu_tolerance = "0.95,0.85" + ref_tpu_tolerance = "0.95,0.80" input_data = np.load(input_npz) # tpu mlir inference and compare tpu_npz = tpu_mlir.replace(".mlir", "_tpu_out.npz") diff --git a/python/test/test_tpulang.py b/python/test/test_tpulang.py index d6901234a..80b15ce73 100755 --- a/python/test/test_tpulang.py +++ b/python/test/test_tpulang.py @@ -59,6 +59,8 @@ def test_single(self, case: str): TPULANG_IR_TESTER.ID = 0 print("Test: {}".format(case)) if case in self.test_function: + os.makedirs(case, exist_ok=True) + os.chdir(case) func, _ = self.test_function[case] func(case) print("====== TEST {} Success ======".format(case)) diff --git a/regression/config/ssd-12.ini b/regression/config/ssd-12.ini index 53957c500..97c146780 100644 --- a/regression/config/ssd-12.ini +++ b/regression/config/ssd-12.ini @@ -12,8 +12,6 @@ output_names=Concat_659,Slice_676 do_int8_asym=0 # has problem do_f16=0 do_bf16=0 -#app=detect_ssd-12.py -do_post_handle=0 #1: build post op into bmodel 0: don't build post op into bmodel post_type=ssd [bm1684x] diff --git a/regression/config/yolov5s.ini b/regression/config/yolov5s.ini index f7ccc72ea..e14929c76 100644 --- a/regression/config/yolov5s.ini +++ b/regression/config/yolov5s.ini @@ -7,14 +7,10 @@ keep_aspect_ratio=1 mean=0.0,0.0,0.0 scale=0.0039216,0.0039216,0.0039216 pixel_format=rgb -#output_names=350,498,646 output_names=326,474,622 do_dynamic=0 -do_post_handle=0 #1: build post op into bmodel 0: don't build post op into bmodel -post_type=yolo dynamic_shapes=[[1,3,320,320]] app=detect_yolov5.py -# pad_value=114 pad_type=normal [bm1684x] diff --git a/regression/generic_cvimodel_sample.py b/regression/generic_cvimodel_sample.py index e0bd71a89..05f33775a 100755 --- a/regression/generic_cvimodel_sample.py +++ b/regression/generic_cvimodel_sample.py @@ -11,7 +11,10 @@ import time from run_model import MODEL_RUN from utils.mlir_shell import _os_system + + class CviSampleGenerator: + def __init__(self, dest_dir): self.chips = ["cv180x", "cv181x", "cv182x", "cv183x"] #self.chips = ["cv183x"] @@ -36,17 +39,28 @@ def cmd_exec(self, cmd_str): else: raise RuntimeError("[!Error]: {}".format(cmd_str)) - def run_sample_net(self, model, dst_model, chip, quant_type, customization_format = "", fuse_pre = False, - aligned_input = False, merge_weight = False): + def run_sample_net(self, + model, + dst_model, + chip, + quant_type, + customization_format="", + fuse_pre=False, + aligned_input=False, + merge_weight=False): info = f"{chip} {model} {quant_type} fuse_preprocess:{fuse_pre} {customization_format} aligned_input:{aligned_input} merge_weight:{merge_weight}" print(f"run_model.py {info}") dir = os.path.expandvars(f"$REGRESSION_PATH/regression_out/{model}_{chip}") os.makedirs(dir, exist_ok=True) os.chdir(dir) - regressor = MODEL_RUN(model, chip, quant_type, \ - dyn_mode=False, do_post_handle=False, merge_weight=merge_weight, \ - fuse_preprocess=fuse_pre, customization_format=customization_format, \ - aligned_input=aligned_input) + regressor = MODEL_RUN(model, + chip, + quant_type, + dyn_mode=False, + merge_weight=merge_weight, + fuse_preprocess=fuse_pre, + customization_format=customization_format, + aligned_input=aligned_input) ret = regressor.run_full() if ret == 0: print(f"{info} Success.") @@ -59,6 +73,7 @@ def run_sample_net(self, model, dst_model, chip, quant_type, customization_forma def generate_chip_models(self, chip): tmp_dir = os.path.expandvars(f"$REGRESSION_PATH/regression_out/cvimodel_samples") os.makedirs(tmp_dir, exist_ok=True) + # yapf: disable self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2.cvimodel"), chip, "int8_sym") self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2_bf16.cvimodel"), chip, "bf16") self.run_sample_net("mobilenet_v2_cf", os.path.join(tmp_dir, "mobilenet_v2_fused_preprocess.cvimodel"), chip, "int8_sym", "BGR_PLANAR", True) @@ -76,12 +91,14 @@ def generate_chip_models(self, chip): self.run_sample_net("alphapose_res50", os.path.join(tmp_dir, "alphapose_fused_preprocess.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True) self.run_sample_net("arcface_res50", os.path.join(tmp_dir, "arcface_res50_fused_preprocess.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True) self.run_sample_net("arcface_res50", os.path.join(tmp_dir, "arcface_res50_fused_preprocess_aligned_input.cvimodel"), chip, "int8_sym", "RGB_PLANAR", True, True) - + # yapf: enable #gen merged cvimodel full_bs1_model = os.path.join(tmp_dir, "mobilenet_v2_bs1.cvimodel") full_bs4_model = os.path.join(tmp_dir, "mobilenet_v2_bs4.cvimodel") - self.run_sample_net("mobilenet_v2_cf", full_bs1_model, chip, "int8_sym", "", False, False, True) - self.run_sample_net("mobilenet_v2_cf_bs4", full_bs4_model, chip, "int8_sym", "", False, False, True) + self.run_sample_net("mobilenet_v2_cf", full_bs1_model, chip, "int8_sym", "", False, False, + True) + self.run_sample_net("mobilenet_v2_cf_bs4", full_bs4_model, chip, "int8_sym", "", False, + False, True) merge_model = os.path.join(tmp_dir, "mobilenet_v2_bs1_bs4.cvimodel") merge_cmd = f"model_tool --combine {full_bs1_model} {full_bs4_model} -o {merge_model}" self.cmd_exec(merge_cmd) @@ -94,11 +111,11 @@ def generate_chip_models(self, chip): os.chdir(self.current_dir) shutil.rmtree(tmp_dir) - def generate_models(self): for chip in self.chips: self.generate_chip_models(chip) + if __name__ == "__main__": t1 = time.time() parser = argparse.ArgumentParser() diff --git a/regression/run_model.py b/regression/run_model.py index 01c04b549..ed0683774 100755 --- a/regression/run_model.py +++ b/regression/run_model.py @@ -26,7 +26,6 @@ def __init__(self, chip: str = "bm1684x", mode: str = "all", dyn_mode: bool = False, - do_post_handle: bool = False, merge_weight: bool = False, fuse_preprocess: bool = False, customization_format: str = "", @@ -36,7 +35,6 @@ def __init__(self, self.model_name = model_name self.chip = chip self.mode = mode - self.do_post_handle = do_post_handle self.dyn_mode = dyn_mode self.fuse_pre = fuse_preprocess self.customization_format = customization_format @@ -70,7 +68,7 @@ def __init__(self, self.tolerance = { "f32": config.get(self.arch, "f32_tolerance", fallback="0.99,0.99"), "f16": config.get(self.arch, "f16_tolerance", fallback="0.95,0.85"), - "bf16": config.get(self.arch, "bf16_tolerance", fallback="0.95,0.85"), + "bf16": config.get(self.arch, "bf16_tolerance", fallback="0.95,0.80"), "int8_sym": config.get(self.arch, "int8_sym_tolerance", fallback="0.8,0.5"), "int8_asym": config.get(self.arch, "int8_asym_tolerance", fallback="0.8,0.5"), "int4_sym": config.get(self.arch, "int4_sym_tolerance", fallback="0.8,0.5"), @@ -158,8 +156,6 @@ def run_model_transform(self, model_name: str, dynamic: bool = False): cmd += ["--output_names {}".format(self.ini_content["output_names"])] if "excepts" in self.ini_content: cmd += ["--excepts {}".format(self.ini_content["excepts"])] - if self.do_post_handle and "post_type" in self.ini_content: - cmd += ["--post_handle_type {}".format(self.ini_content["post_type"])] _os_system(cmd, self.save_log) def make_calibration_table(self): @@ -250,8 +246,6 @@ def run_model_deploy(self, # add according to arguments model_file = f"{model_name}_{self.chip}_{quant_mode}" - if self.do_post_handle: - cmd += ["--post_op"] if self.fuse_pre: cmd += ["--fuse_preprocess"] model_file += "_fuse_preprocess" @@ -338,8 +332,6 @@ def run_dynamic(self, quant_mode: str): "model_runner.py", f"--input {static_model_name}_in_f32.npz", f"--model {static_model_file}", f"--output {static_out}" ] - if self.do_post_handle: - cmd += ["--post_op"] _os_system(cmd, self.save_log) cmd[2], cmd[3] = f"--model {dyn_model_file}", f"--output {dyn_out}" _os_system(cmd, self.save_log) @@ -433,7 +425,6 @@ def run_full(self): choices=['all', 'basic', 'f32', 'f16', 'bf16', 'int8_sym', 'int8_asym', 'int4_sym'], help="quantize mode, 'all' runs all modes except int4, 'baisc' runs f16 and int8 sym only") parser.add_argument("--dyn_mode", default='store_true', help="dynamic mode") - parser.add_argument("--do_post_handle", action='store_true', help="whether to do post handle") parser.add_argument("--merge_weight", action="store_true", help="merge weights into one weight binary with previous generated cvimodel") # fuse preprocess @@ -453,7 +444,7 @@ def run_full(self): dir = os.path.expandvars(out_dir) os.makedirs(dir, exist_ok=True) os.chdir(dir) - runner = MODEL_RUN(args.model_name, args.chip, args.mode, args.dyn_mode, args.do_post_handle, - args.merge_weight, args.fuse_preprocess, args.customization_format, - args.aligned_input, args.save_log, args.disable_thread) + runner = MODEL_RUN(args.model_name, args.chip, args.mode, args.dyn_mode, args.merge_weight, + args.fuse_preprocess, args.customization_format, args.aligned_input, + args.save_log, args.disable_thread) runner.run_full() diff --git a/regression/script_test/run.sh b/regression/script_test/run.sh index 1c5b38547..74bb37383 100755 --- a/regression/script_test/run.sh +++ b/regression/script_test/run.sh @@ -9,6 +9,7 @@ pushd $TEST_DIR $DIR/test1.sh $DIR/test2.sh - +$DIR/test3.sh +$DIR/test4.sh popd diff --git a/regression/script_test/test1.sh b/regression/script_test/test1.sh index 1518afc39..031be060f 100755 --- a/regression/script_test/test1.sh +++ b/regression/script_test/test1.sh @@ -1,4 +1,5 @@ #!/bin/bash +# test case: test batch 4, calibration by npz, preprocess, tosa convert set -ex DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" diff --git a/regression/script_test/test2.sh b/regression/script_test/test2.sh index f2f2d670d..c126c776a 100755 --- a/regression/script_test/test2.sh +++ b/regression/script_test/test2.sh @@ -1,4 +1,5 @@ #!/bin/bash +# test case: test quantization with qtable set -ex DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" diff --git a/regression/script_test/test3.sh b/regression/script_test/test3.sh index 5aecc599a..064eae607 100755 --- a/regression/script_test/test3.sh +++ b/regression/script_test/test3.sh @@ -1,4 +1,5 @@ #!/bin/bash +# test case: test torch input int32 or int16 situation set -ex DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" diff --git a/regression/script_test/test4.sh b/regression/script_test/test4.sh new file mode 100755 index 000000000..92ab07897 --- /dev/null +++ b/regression/script_test/test4.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# test case: test yolov5s preprocess and postprocess +set -ex + +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + +model_transform.py \ + --model_name yolov5s \ + --model_def ${REGRESSION_PATH}/model/yolov5s.onnx \ + --input_shapes=[[1,3,192,1024]] \ + --output_names=350,498,646 \ + --scale=0.0039216,0.0039216,0.0039216 \ + --pixel_format=rgb \ + --test_input=${REGRESSION_PATH}/image/dog.jpg \ + --test_result=yolov5s_top_outputs.npz \ + --post_handle_type=yolo \ + --mlir yolov5s.mlir + +run_calibration.py yolov5s.mlir \ + --dataset ${REGRESSION_PATH}/dataset/COCO2017 \ + --input_num 100 \ + -o yolov5s_cali_table + +model_deploy.py \ + --mlir yolov5s.mlir \ + --quantize INT8 \ + --chip bm1684x \ + --calibration_table yolov5s_cali_table \ + --fuse_preprocess \ + --test_input ${REGRESSION_PATH}/image/dog.jpg \ + --test_reference yolov5s_top_outputs.npz \ + --except "yolo_post" \ + --compare_all \ + --model yolov5s_int8.bmodel