From dcac6291a03ed8153676e60aa5698040d30fb92d Mon Sep 17 00:00:00 2001
From: fanyunqian
Date: Mon, 13 Jun 2022 21:23:48 +0800
Subject: [PATCH] [Fix] advanced ptq

Skip layers that carry no weight_fake_quant attribute, return early when
a subgraph has no weight parameters to optimize, and, when a boundary
node has no cached module output, cache the input of its first user
instead (use_next_input / use_next_input_).

---
 mqbench/advanced_ptq.py | 54 +++++++++++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 10 deletions(-)

diff --git a/mqbench/advanced_ptq.py b/mqbench/advanced_ptq.py
index f9ede726..3f5374e0 100644
--- a/mqbench/advanced_ptq.py
+++ b/mqbench/advanced_ptq.py
@@ -290,7 +290,10 @@ def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
     a_para = []
     for name, layer in subgraph.named_modules():
         if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
-            weight_quantizer = layer.weight_fake_quant
+            if hasattr(layer, 'weight_fake_quant'):
+                weight_quantizer = layer.weight_fake_quant
+            else:
+                continue
             # assert isinstance(weight_quantizer, adaround_quantizer) is True
             weight_quantizer.init(layer.weight.data, config.round_mode)
             w_para += [weight_quantizer.alpha]
@@ -304,6 +307,8 @@
         a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=config.max_count, eta_min=0.)
     else:
         a_opt, a_scheduler = None, None
+    if w_para == []:
+        return
     w_opt = torch.optim.Adam(w_para)
     loss_func = LossFunction(subgraph=subgraph, weight=config.weight, max_count=config.max_count, b_range=config.b_range,
@@ -633,7 +638,15 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
             continue
         logger.info('the node list is below!')
         logger.info(layer_node_list)
-        fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
+        if layer_node_list[-1] in qnode2fpnode_dict:
+            fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
+            use_next_input = False
+        else:
+            out_node = layer_node_list[-1]
+            inp_node = list(out_node.users)[0]
+            qinp_node = qnode2fpnode_dict[inp_node]
+            fp32_module = fp32_modules[qinp_node]
+            use_next_input = True
         fp32_all_inps = []
         quant_all_inps = []
         fp32_final_oups = None
@@ -642,15 +655,35 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
             if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
                 continue
             else:
-                fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
-                quant_module = quant_modules[_node]
+                if _node in quant_modules:
+                    fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
+                    quant_module = quant_modules[_node]
+                    use_next_input_ = False
+                else:
+                    _node = list(_node.users)[0]
+                    fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
+                    quant_module = quant_modules[_node]
+                    use_next_input_ = True
                 # fp32 inps: [out_b1, out_b2, ...]
-                _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
-                                                 store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
-                _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
-                                                 store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
-                _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
-                                                  store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
+                if use_next_input is False:
+                    _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
+                                                     store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
+                else:
+                    fp32_oups, _ = save_inp_oup_data(fp32_model, fp32_module, None, cali_data,
+                                                     store_inp=(not out_is_cached), store_oup=False, keep_gpu=config.keep_gpu)
+                    fp32_oups = sum(fp32_oups, [])
+                if use_next_input_ is False:
+                    _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
+                                                     store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
+                    _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
+                                                      store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
+                else:
+                    fp32_inps, _ = save_inp_oup_data(fp32_model, fp32_inp_module, None, cali_data,
+                                                     store_inp=(config.prob < 1.0), store_oup=False, keep_gpu=config.keep_gpu)
+                    quant_inps, _ = save_inp_oup_data(quant_model, quant_module, None, cali_data,
+                                                      store_inp=True, store_oup=False, keep_gpu=config.keep_gpu)
+                    fp32_inps = sum(fp32_inps, [])
+                    quant_inps = sum(quant_inps, [])
                 fp32_all_inps.append(fp32_inps)
                 quant_all_inps.append(quant_inps)
                 if not out_is_cached:
@@ -674,3 +707,4 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
         enable_quantization(quant_modules[node])
         logger.info(f'set the node {node.target} in quant')
     return quant_model
+
\ No newline at end of file
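
Reviewer note: the use_next_input / use_next_input_ fallback added above
relies on the fact that a torch.fx node's output tensors are exactly the
input tensors of its first user, so a boundary node with no entry in
qnode2fpnode_dict can still be cached by hooking the next node with
store_inp=True (the extra sum(..., []) flattens the per-batch argument
lists that input caching returns). A minimal sketch of that equivalence
with plain PyTorch forward hooks; the two-layer model is illustrative,
not MQBench code:

    import torch
    import torch.nn as nn

    # Two chained modules: the output of `a` is, by construction, the input of `b`.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    a, b = model[0], model[1]

    cached = {}
    # A hook on `a` stores its output; a hook on `b` stores its first input.
    a.register_forward_hook(lambda mod, inp, out: cached.update(a_out=out))
    b.register_forward_hook(lambda mod, inp, out: cached.update(b_in=inp[0]))

    model(torch.randn(4, 8))

    # Caching the consumer's input recovers the tensors the producer emitted,
    # which is what the use_next_input branches obtain via save_inp_oup_data.
    assert torch.equal(cached['a_out'], cached['b_in'])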