Skip to content

Commit

Permalink
CopyPropagation, more tac builder, object support, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
kokifish committed Jan 17, 2025
1 parent 1bf8961 commit 6710e81
Show file tree
Hide file tree
Showing 15 changed files with 395 additions and 77 deletions.
24 changes: 20 additions & 4 deletions examples/dis_demo.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,35 @@
import argparse
import os
import subprocess

import ohre
from ohre.abcre.dis.PandaReverser import PandaReverser
from ohre.abcre.dis.DisFile import DisFile
from ohre.abcre.dis.PandaReverser import PandaReverser
from ohre.core import oh_app, oh_hap
from ohre.misc import Log

TMP_HAP_EXTRACT = "tmp_hap_extract"
TMP_APP_EXTRACT = "tmp_app_extract"
ARK_DISASM = "path2ark_disasm"

if __name__ == "__main__": # clear; pip install -e .; python3 examples/dis_demo.py name.abc.dis
Log.init_log("abcre", ".")
ohre.set_log_level("info")
ohre.set_log_print(True)
Log.info(f"START {__file__}")
parser = argparse.ArgumentParser()
parser.add_argument("dis_path", type=str, help="path to the dis file (ark_disasm-ed abc)")
parser.add_argument("in_path", type=str, help="path to the dis file (ark_disasm-ed abc) or hap/app")
arg = parser.parse_args()
dis_path = arg.dis_path
dis_file: DisFile = DisFile(dis_path)
in_path = arg.in_path
if (in_path.endswith(".dis")):
dis_file: DisFile = DisFile(in_path)
elif (in_path.endswith(".hap")):
hhap = oh_hap.oh_hap(in_path)
hhap.extract_all_to(TMP_HAP_EXTRACT)
abc_file = os.path.join(TMP_HAP_EXTRACT, "ets", "modules.abc")
dis_file = f"{os.path.splitext(os.path.basename(in_path))[0]}.abc.dis" # os.path.splitext(file_name)[0]
result = subprocess.run([ARK_DISASM, abc_file, dis_file], capture_output=True, text=True)
dis_file: DisFile = DisFile(dis_file)
panda_re: PandaReverser = PandaReverser(dis_file)
print(f"> panda_re: {panda_re}")

Expand Down Expand Up @@ -45,6 +60,7 @@
# tac_total = panda_re.get_insts_total()
# for idx in range(panda_re.method_len()):
# panda_re._code_lifting_algorithms(method_id=idx)
# print(f">> [{idx}/{panda_re.method_len()}] after lift {panda_re.dis_file.methods[idx]._debug_vstr()}")
# todo_tac = panda_re.get_tac_unknown_count()
# final_tac_total = panda_re.get_insts_total()
# print(f"todo_tac {todo_tac}/{tac_total} {todo_tac/tac_total:.4f} / nac {nac_total} {todo_tac/nac_total:.4f}")
Expand Down
125 changes: 122 additions & 3 deletions ohre/abcre/dis/AsmArg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,24 @@ def __init__(self, arg_type: AsmTypes = AsmTypes.UNKNOWN,
# name: e.g. for v0, type is VAR, name is v0(stored without truncating the prefix v)
self.name: str = name
# value: may be set in the subsequent analysis
self.value = value # if type is ARRAY, value is AsmArg list
# type is ARRAY: value is list[AsmArg]
# type is OBJECT: value is list[AsmArg]: AsmArg(name:key, value:any value)
self.value = value
self.ref_base = ref_base # AsmArg
self.paras_len: Union[int, None] = paras_len # for method object, store paras len here
if (self.is_value_valid() == False):
Log.error(f"AsmArg value is NOT valid, type {self.type_str} value {type(value)} {value}")

@property
def len(self):
if (len(self.name) > 0):
return len(self.name)
return len(self.type)

@property
def type_str(self) -> str:
return AsmTypes.get_code_name(self.type)

def __len__(self) -> int:
return self.len

Expand All @@ -46,6 +54,31 @@ def __hash__(self):
def __repr__(self):
return f"Arg({self._debug_str()})"

def obj_has_key(self, key) -> bool:
# if self is OBJECT and key exists in self.value, return True
if (not isinstance(self.value, Iterable)):
return False
key_name_str: str = ""
if (isinstance(key, AsmArg)):
key_name_str = key.name
elif (isinstance(key, str)):
key_name_str = key
else:
Log.error(f"ERROR! obj_has_key key {type(key)} {key}")
for arg in self.value:
if (key_name_str == arg.name):
return True
return False

def set_object_key_value(self, key: str, value: str, create=False):
if (self.type != AsmTypes.OBJECT):
return False
for arg in self.value:
if (key == arg.name):
arg.value = value
return True
return False

def set_ref(self, ref_ed_arg):
self.ref_base = ref_ed_arg

Expand Down Expand Up @@ -87,6 +120,25 @@ def ACC(cls): # return AsmArg(AsmTypes.ACC)
def build_arr(cls, args: List, name: str = ""): # element of args should be AsmArg
return AsmArg(AsmTypes.ARRAY, name=name, value=list(args))

@classmethod
def build_object(cls, in_kv: Dict = None, name: str = "", ref_base=None): # element of args should be AsmArg
obj_value_l = list()
if (isinstance(in_kv, Iterable)):
for k, v in in_kv.items():
if (isinstance(v, int)):
obj_value_l.append(AsmArg(AsmTypes.IMM, name=k, value=v))
elif (isinstance(v, float)):
obj_value_l.append(AsmArg(AsmTypes.IMM, name=k, value=v))
elif (isinstance(v, str)):
obj_value_l.append(AsmArg(AsmTypes.STR, name=k, value=v))
elif (v is None):
obj_value_l.append(AsmArg(AsmTypes.UNDEFINED, name=k, value=None))
else:
Log.error(f"ERROR! build_object k {k} {type(k)} v {v} {type(v)} name {name}")
if (len(obj_value_l) == 0):
obj_value_l = None
return AsmArg(AsmTypes.OBJECT, name=name, value=obj_value_l, ref_base=ref_base)

@classmethod
def build_FunctionObject(cls):
# FunctionObject always stored at a0
Expand All @@ -110,14 +162,60 @@ def build_next_arg(self): # arg is AsmArg
num += 1
return AsmArg(self.type, f"{self.name[0]}{num}")

def is_value_valid(self) -> bool: # TODO: for some types, value is not valid, judge it
pass
def is_value_valid(self) -> bool:
if (self.value is None):
return True
if (self.type == AsmTypes.IMM):
if (isinstance(self.value, int) or isinstance(self.value, float)):
return True
return False
if (self.type == AsmTypes.STR or self.type == AsmTypes.LABEL):
if (isinstance(self.value, str)):
return True
return False
if (self.type == AsmTypes.METHOD_OBJ):
if (isinstance(self.value, str)):
return True
return False
if (self.type == AsmTypes.OBJECT):
if (isinstance(self.value, Iterable)):
return True
return False
if (self.type == AsmTypes.ARRAY):
if (isinstance(self.value, list)):
return True
return False
if (self.type == AsmTypes.NULL or self.type == AsmTypes.INF or self.type == AsmTypes.NAN
or self.type == AsmTypes.UNDEFINED or self.type == AsmTypes.HOLE):
return False
Log.error(f"is_value_valid NOT supported logic type {self.type_str} value {type(self.value)} {self.value}")
return True

def is_acc(self) -> bool:
if (self.type == AsmTypes.ACC):
return True
return False

def is_imm(self) -> bool:
if (self.type == AsmTypes.IMM):
return True
return False

def is_field(self) -> bool:
if (self.type == AsmTypes.FIELD):
return True
return False

def is_unknown(self) -> bool:
if (self.type == AsmTypes.UNKNOWN):
return True
return False

def is_temp_var_like(self) -> bool:
if ((self.type == AsmTypes.VAR or self.type == AsmTypes.ACC) and self.is_no_ref()):
return True
return False

def get_all_args_recursively(self, include_self: bool = True) -> List:
out = list()
if (include_self):
Expand All @@ -144,9 +242,28 @@ def _common_error_check(self):
if (len(self.name) == 0):
Log.error(f"[ArgCC] A label without name: len {len(self.name)}")

def _debug_str_obj(self, detail=False):
out = ""
if (self.ref_base is not None):
out += f"{self.ref_base}->"
if (detail):
out += f"OBJ:{self.name}"
else:
out += f"{self.name}"
if (isinstance(self.value, Iterable)):
out += "{"
for v_arg in self.value:
out += f"{v_arg.name}:{v_arg.value}, "
out += "}"
elif (self.value is not None):
out += "{" + self.value + "}"
return out

def _debug_str(self, print_ref: bool = True):
self._common_error_check()
out = ""
if (self.type == AsmTypes.OBJECT):
return self._debug_str_obj()
if (self.type == AsmTypes.FIELD):
if (print_ref and self.ref_base is not None):
out += f"{self.ref_base}[{self.name}]"
Expand All @@ -167,6 +284,8 @@ def _debug_str(self, print_ref: bool = True):
def _debug_vstr(self, print_ref: bool = True):
self._common_error_check()
out = ""
if (self.type == AsmTypes.OBJECT):
return self._debug_str_obj(detail=True)
if (self.type == AsmTypes.FIELD):
if (print_ref and self.ref_base is not None):
out += f"{self.ref_base}[{AsmTypes.get_code_name(self.type)}-{self.name}]"
Expand Down
66 changes: 59 additions & 7 deletions ohre/abcre/dis/AsmLiteral.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Iterable, List, Tuple, Union

from ohre.abcre.dis.enum.CODE_LV import CODE_LV
from ohre.abcre.dis.DebugBase import DebugBase
from ohre.abcre.dis.enum.AsmTypes import AsmTypes
from ohre.misc import Log, utils


Expand Down Expand Up @@ -40,11 +40,11 @@ def _process_normal_literal(self, lines: List[str]):
change_flag = 0
for i in element_content:
if i == '"':
change_flag = abs(1-change_flag)
change_flag = abs(1 - change_flag)
s_cnt += 1
elif i == ',' and change_flag == 1:
modified_content = modified_content[:s_cnt] + \
'<comma>'+modified_content[s_cnt+1:]
'<comma>' + modified_content[s_cnt + 1:]
s_cnt += 7
else:
s_cnt += 1
Expand All @@ -58,8 +58,8 @@ def _process_normal_literal(self, lines: List[str]):
if 'string' in array_split_list[cnt]:
method_string = array_split_list[cnt].split(':')[
1].strip()[1:-1]
method_name = array_split_list[cnt+1].split(':')[1].strip()
method_aff = array_split_list[cnt+2].split(':')[1].strip()
method_name = array_split_list[cnt + 1].split(':')[1].strip()
method_aff = array_split_list[cnt + 2].split(':')[1].strip()
method_dict[method_string] = {
'method': method_name, 'method_affiliate': method_aff}
cnt += 3
Expand All @@ -76,13 +76,13 @@ def _process_normal_literal(self, lines: List[str]):
variable_string = array_split_list[cnt].split(':')[1].strip()
if '"' in variable_string:
variable_string = variable_string.replace('"', '')
variable_value = array_split_list[cnt+1]
variable_value = array_split_list[cnt + 1]
if 'null_value' in variable_value:
variable_value = 'null_value'
else:
variable_value = variable_value.split(":")[1].strip()
if '"' in variable_value:
variable_value = variable_value.replace('"', '').replace('<comma>',',')
variable_value = variable_value.replace('"', '').replace('<comma>', ',')
cnt += 2
method_dict[variable_string] = variable_value
if element_amount % 2 == 1:
Expand Down Expand Up @@ -146,3 +146,55 @@ def _debug_vstr(self) -> str:
if (self.module_tags is not None):
out += f" module_tags({len(self.module_tags)}) {self.module_tags}"
return out

def _lit_split_by_comma(s: str) -> List[str]:
modified_content = s
s_cnt = 0
change_flag = 0
for i in s:
if i == '"':
change_flag = abs(1 - change_flag)
s_cnt += 1
elif i == "," and change_flag == 1:
modified_content = modified_content[:s_cnt] + \
"<comma>" + modified_content[s_cnt + 1:]
s_cnt += 7
else:
s_cnt += 1

array_split_list = [x.strip() for x in modified_content.strip().split(",") if len(x) > 0]
for i in range(len(array_split_list)):
array_split_list[i] = array_split_list[i].replace("<comma>", ",")
return array_split_list

@classmethod
def literal_get_key_value(cls, in_s: str) -> Dict:
ret = dict()
in_s = utils.strip_sted_str(in_s.strip(), start_str="{", end_str="}").strip()
e_idx = in_s.find("[")
element_amount_str = in_s[0:e_idx].strip()
if (not element_amount_str.isdigit()):
Log.error(f"Expected a digit for element amount, got {element_amount_str}")
return dict()
element_amount = int(element_amount_str)
in_s = in_s[e_idx:].strip()
in_s = utils.strip_sted_str(in_s, start_str="[", end_str="]")
kv: List[str] = cls._lit_split_by_comma(in_s)
for i in range(0, element_amount, 2):
key = kv[i].split(":")[1].strip()
key = utils.strip_sted_str(key, start_str="\"", end_str="\"")
if (key.startswith("\"") and key.endswith("\"")):
key = key[1:-1]
value_type = kv[i + 1].split(":")[0].strip()
value = kv[i + 1].split(":")[1].strip()
if ("null_value" in value_type):
value = None
elif (AsmTypes.is_int(value_type)):
if (value.isdigit()):
value = int(value,)
else:
Log.error(f"ERROR literal_get_key_value value_type {value_type} value {value}")
elif ("string" in value_type):
value = utils.strip_sted_str(value, start_str="\"", end_str="\"")
ret[key] = value
return ret
4 changes: 2 additions & 2 deletions ohre/abcre/dis/AsmMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def _process_createobjectwithbuffer(self, lines: str, l_n: int) -> Tuple[List[st
idx = utils.find_next_delimiter_single_line(line_concat, s_idx)
ret.append(line_concat[s_idx: idx].strip()) # reserved number

s_idx = line_concat.find("\{", idx) + 1
e_idx = line_concat.rfind("\}")
s_idx = line_concat.find("{", idx) + 1
e_idx = line_concat.rfind("}")
ret.append(line_concat[s_idx: e_idx])
return ret, l_n_end

Expand Down
36 changes: 36 additions & 0 deletions ohre/abcre/dis/CodeBlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,42 @@ def __init__(self, in_l: Union[List[List[str]], List[NAC], List[TAC]],

self.use_vars: set[AsmArg] = None
self.def_vars: set[AsmArg] = None
self.var2val: Dict[AsmArg, AsmArg] = dict()

def set_var2val(self, var2val: Dict[AsmArg, AsmArg]):
self.var2val = var2val

def get_var2val(self):
return self.var2val

def empty_var2val(self):
self.var2val = dict()

def get_all_prev_cbs_var2val(self, get_current_cb=False, definite_flag=True) -> Dict[AsmArg, AsmArg]:
# recursively
# definite_flag: if True, when var def more than 1 with different value, let var undef
ret = dict()
if (get_current_cb):
ret.update(self.get_var2val())
for cb in self.prev_cb_list:
prev_cbs_var2val = cb.get_all_prev_cbs_var2val(True, True)
for var, val in prev_cbs_var2val.items():
if (val is None and definite_flag): # val maybe a return value of call
if (val in ret.keys()): # val is None means val is undef-ed
del ret[var]
continue
if (val.is_unknown()):
continue # maybe a para of function
if (definite_flag):
if (var not in ret.keys()):
ret[var] = val
elif (var in ret.keys() and ret[var] == val): # same value
continue
else: # var exist but not same val
del ret[var]
else:
ret[var] = val
return ret

def get_slice_block(self, idx_start: int, idx_end: int):
return CodeBlock(copy.deepcopy(self.insts[idx_start: idx_end]))
Expand Down
Loading

0 comments on commit 6710e81

Please sign in to comment.