From a307d0609c2b32b4e02837ca679879a0df4f0ce8 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 13 Feb 2025 14:07:34 -0800 Subject: [PATCH] add SSTORE and SLOAD to traces (#456) --- src/halmos/__main__.py | 6 ++++ src/halmos/config.py | 41 ++++++++++++++++++++++--- src/halmos/sevm.py | 23 ++++++++++++-- src/halmos/traces.py | 69 +++++++++++++++++++++++++++++++++++------- src/halmos/utils.py | 2 +- tests/test_config.py | 30 +++++++++--------- 6 files changed, 138 insertions(+), 33 deletions(-) diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 16d33c10..60fb9f85 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -24,6 +24,8 @@ unsat, ) +import halmos.traces + from .build import ( build_output_iterator, import_libs, @@ -401,6 +403,9 @@ def run_test(ctx: FunctionContext) -> TestResult: if args.verbose >= 1: print(f"Executing {funname}") + # set the config for every trace rendered in this test + halmos.traces.config_context.set(args) + # # prepare calldata # @@ -742,6 +747,7 @@ def run_contract(ctx: ContractContext) -> list[TestResult]: contract_ctx=ctx, ) + halmos.traces.config_context.set(setup_config) setup_ex = setup(setup_ctx) except Exception as err: error(f"{setup_info.sig} failed: {type(err).__name__}: {err}") diff --git a/src/halmos/config.py b/src/halmos/config.py index eaa78270..582d79ac 100644 --- a/src/halmos/config.py +++ b/src/halmos/config.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Generator from dataclasses import MISSING, dataclass, fields from dataclasses import field as dataclass_field +from enum import Enum from pathlib import Path from typing import Any @@ -26,6 +27,12 @@ ) +class TraceEvent(Enum): + LOG = "LOG" + SSTORE = "SSTORE" + SLOAD = "SLOAD" + + def find_venv_root() -> Path | None: # If the environment variable is set, use that if "VIRTUAL_ENV" in os.environ: @@ -88,9 +95,28 @@ def parse_csv(values: str, sep: str = ",") -> Generator[Any, None, None]: return (x for _x in values.split(sep) if (x := _x.strip())) -class ParseCSV(argparse.Action): +class ParseCSVTraceEvent(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): - values = ParseCSV.parse(values) + values = ParseCSVTraceEvent.parse(values) + setattr(namespace, self.dest, values) + + @staticmethod + def parse(values: str) -> list[TraceEvent]: + # empty list is ok + try: + return [TraceEvent(x) for x in parse_csv(values)] + except ValueError as e: + valid = ", ".join([e.value for e in TraceEvent]) + raise ValueError(f"the list of valid trace events is: {valid}") from e + + @staticmethod + def unparse(values: list[TraceEvent]) -> str: + return ",".join([x.value for x in values]) + + +class ParseCSVInt(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + values = ParseCSVInt.parse(values) setattr(namespace, self.dest, values) @staticmethod @@ -277,14 +303,14 @@ class Config: help="set default lengths for dynamic-sized arrays (excluding bytes and string) not specified in --array-lengths", global_default="0,1,2", metavar="LENGTH1,LENGTH2,...", - action=ParseCSV, + action=ParseCSVInt, ) default_bytes_lengths: str = arg( help="set default lengths for bytes and string types not specified in --array-lengths", global_default="0,65,1024", # 65 is ECDSA signature size metavar="LENGTH1,LENGTH2,...", - action=ParseCSV, + action=ParseCSVInt, ) storage_layout: str = arg( @@ -424,6 +450,13 @@ class Config: group=debugging, ) + trace_events: str = arg( + help="include specific events in traces", + global_default=",".join([e.value for e in TraceEvent]), + metavar="EVENT1,EVENT2,...", + action=ParseCSVTraceEvent, + ) + ### Build options forge_build_out: str = arg( diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index aec19af5..c74ce3c4 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -13,6 +13,7 @@ ForwardRef, Optional, TypeVar, + Union, ) import rich @@ -344,6 +345,20 @@ class EventLog: data: Bytes | None +@dataclass(frozen=True) +class StorageWrite: + address: Address + slot: Word + value: Word + + +@dataclass(frozen=True) +class StorageRead: + address: Address + slot: Word + value: Word + + @dataclass(frozen=True) class Message: target: Address @@ -387,7 +402,7 @@ class CallOutput: # - gas_left -TraceElement = ForwardRef("CallContext") | EventLog +TraceElement = Union["CallContext", EventLog, StorageRead, StorageWrite] @dataclass @@ -2037,9 +2052,13 @@ def mk_storagedata(self) -> StorageData: return self.storage_model.mk_storagedata() def sload(self, ex: Exec, addr: Any, loc: Word) -> Word: - return self.storage_model.load(ex, addr, loc) + val = self.storage_model.load(ex, addr, loc) + ex.context.trace.append(StorageRead(addr, loc, val)) + return val def sstore(self, ex: Exec, addr: Any, loc: Any, val: Any) -> None: + ex.context.trace.append(StorageWrite(addr, loc, val)) + if ex.message().is_static: raise WriteInStaticContext(ex.context_str()) diff --git a/src/halmos/traces.py b/src/halmos/traces.py index c5d58bc8..79db2742 100644 --- a/src/halmos/traces.py +++ b/src/halmos/traces.py @@ -1,23 +1,29 @@ import io import sys +from contextvars import ContextVar from z3 import Z3_OP_CONCAT, BitVecNumRef, BitVecRef, is_app from halmos.bytevec import ByteVec +from halmos.config import Config, TraceEvent from halmos.exceptions import HalmosException from halmos.mapper import DeployAddressMapper, Mapper -from halmos.sevm import CallContext, EventLog, mnemonic +from halmos.sevm import CallContext, EventLog, StorageRead, StorageWrite, mnemonic from halmos.utils import ( + Address, byte_length, cyan, green, hexify, is_bv, + magenta, red, unbox_int, yellow, ) +config_context: ContextVar[Config | None] = ContextVar("config", default=None) + def rendered_initcode(context: CallContext) -> str: message = context.message @@ -73,6 +79,16 @@ def render_output(context: CallContext, file=sys.stdout) -> None: ) +def rendered_address(addr: Address) -> str: + addr = unbox_int(addr) + addr_str = str(addr) if is_bv(addr) else hex(addr) + + # check if we have a contract name for this address in our deployment mapper + addr_str = DeployAddressMapper().get_deployed_contract(addr_str) + + return addr_str + + def rendered_log(log: EventLog) -> str: opcode_str = f"LOG{len(log.topics)}" topics = [ @@ -84,6 +100,28 @@ def rendered_log(log: EventLog) -> str: return f"{opcode_str}({args_str})" +def rendered_slot(slot: Address) -> str: + slot = unbox_int(slot) + + if is_bv(slot): + return magenta(hexify(slot)) + + if slot < 2**16: + return magenta(str(slot)) + + return magenta(hex(slot)) + + +def rendered_sstore(update: StorageWrite) -> str: + slot_str = rendered_slot(update.slot) + return f"{cyan('SSTORE')} @{slot_str} ← {hexify(update.value)}" + + +def rendered_sload(read: StorageRead) -> str: + slot_str = rendered_slot(read.slot) + return f"{cyan('SLOAD')} @{slot_str} → {hexify(read.value)}" + + def rendered_trace(context: CallContext) -> str: with io.StringIO() as output: render_trace(context, file=output) @@ -106,11 +144,12 @@ def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> st def render_trace(context: CallContext, file=sys.stdout) -> None: + config: Config = config_context.get() + if config is None: + raise HalmosException("config not set") + message = context.message - addr = unbox_int(message.target) - addr_str = str(addr) if is_bv(addr) else hex(addr) - # check if we have a contract name for this address in our deployment mapper - addr_str = DeployAddressMapper().get_deployed_contract(addr_str) + addr_str = rendered_address(message.target) value = unbox_int(message.value) value_str = f" (value: {value})" if is_bv(value) or value > 0 else "" @@ -147,12 +186,20 @@ def render_trace(context: CallContext, file=sys.stdout) -> None: log_indent = (context.depth + 1) * " " for trace_element in context.trace: - if isinstance(trace_element, CallContext): - render_trace(trace_element, file=file) - elif isinstance(trace_element, EventLog): - print(f"{log_indent}{rendered_log(trace_element)}", file=file) - else: - raise HalmosException(f"unexpected trace element: {trace_element}") + match trace_element: + case CallContext(): + render_trace(trace_element, file=file) + case EventLog(): + if TraceEvent.LOG in config.trace_events: + print(f"{log_indent}{rendered_log(trace_element)}", file=file) + case StorageRead(): + if TraceEvent.SLOAD in config.trace_events: + print(f"{log_indent}{rendered_sload(trace_element)}", file=file) + case StorageWrite(): + if TraceEvent.SSTORE in config.trace_events: + print(f"{log_indent}{rendered_sstore(trace_element)}", file=file) + case _: + raise HalmosException(f"unexpected trace element: {trace_element}") render_output(context, file=file) diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 5b9e3326..d185610f 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -590,7 +590,7 @@ def cyan(text: str) -> str: def magenta(text: str) -> str: - return f"\033[35m{text}\033[0m" + return f"\033[95m{text}\033[0m" color_good = green diff --git a/tests/test_config.py b/tests/test_config.py index 93499ed6..72f126e7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ from halmos.config import ( Config, ParseArrayLengths, - ParseCSV, + ParseCSVInt, ParseErrorCodes, arg_parser, default_config, @@ -192,21 +192,21 @@ def test_config_pickle(config, parser): def test_parse_csv(): with pytest.raises(ValueError): - ParseCSV.parse("") - ParseCSV.parse(" ") - ParseCSV.parse(",") - assert ParseCSV.parse("0") == [0] - assert ParseCSV.parse("0,") == [0] - assert ParseCSV.parse("1,2,3") == [1, 2, 3] - assert ParseCSV.parse("1,2,3,") == [1, 2, 3] - assert ParseCSV.parse(" 1 , 2 , 3 ") == [1, 2, 3] - assert ParseCSV.parse(" , 1 , 2 , 3 , ") == [1, 2, 3] + ParseCSVInt.parse("") + ParseCSVInt.parse(" ") + ParseCSVInt.parse(",") + assert ParseCSVInt.parse("0") == [0] + assert ParseCSVInt.parse("0,") == [0] + assert ParseCSVInt.parse("1,2,3") == [1, 2, 3] + assert ParseCSVInt.parse("1,2,3,") == [1, 2, 3] + assert ParseCSVInt.parse(" 1 , 2 , 3 ") == [1, 2, 3] + assert ParseCSVInt.parse(" , 1 , 2 , 3 , ") == [1, 2, 3] def test_unparse_csv(): - assert ParseCSV.unparse([]) == "" - assert ParseCSV.unparse([0]) == "0" - assert ParseCSV.unparse([1, 2, 3]) == "1,2,3" + assert ParseCSVInt.unparse([]) == "" + assert ParseCSVInt.unparse([0]) == "0" + assert ParseCSVInt.unparse([1, 2, 3]) == "1,2,3" def test_parse_csv_roundtrip(): @@ -216,8 +216,8 @@ def test_parse_csv_roundtrip(): ] for original in test_cases: - unparsed = ParseCSV.unparse(original) - parsed = ParseCSV.parse(unparsed) + unparsed = ParseCSVInt.unparse(original) + parsed = ParseCSVInt.parse(unparsed) assert parsed == original, f"Roundtrip failed for {original}"