Skip to content

Commit

Permalink
add SSTORE and SLOAD to traces (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xkarmacoma authored Feb 13, 2025
1 parent e65bc04 commit a307d06
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 33 deletions.
6 changes: 6 additions & 0 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
unsat,
)

import halmos.traces

from .build import (
build_output_iterator,
import_libs,
Expand Down Expand Up @@ -401,6 +403,9 @@ def run_test(ctx: FunctionContext) -> TestResult:
if args.verbose >= 1:
print(f"Executing {funname}")

# set the config for every trace rendered in this test
halmos.traces.config_context.set(args)

#
# prepare calldata
#
Expand Down Expand Up @@ -742,6 +747,7 @@ def run_contract(ctx: ContractContext) -> list[TestResult]:
contract_ctx=ctx,
)

halmos.traces.config_context.set(setup_config)
setup_ex = setup(setup_ctx)
except Exception as err:
error(f"{setup_info.sig} failed: {type(err).__name__}: {err}")
Expand Down
41 changes: 37 additions & 4 deletions src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable, Generator
from dataclasses import MISSING, dataclass, fields
from dataclasses import field as dataclass_field
from enum import Enum
from pathlib import Path
from typing import Any

Expand All @@ -26,6 +27,12 @@
)


class TraceEvent(Enum):
LOG = "LOG"
SSTORE = "SSTORE"
SLOAD = "SLOAD"


def find_venv_root() -> Path | None:
# If the environment variable is set, use that
if "VIRTUAL_ENV" in os.environ:
Expand Down Expand Up @@ -88,9 +95,28 @@ def parse_csv(values: str, sep: str = ",") -> Generator[Any, None, None]:
return (x for _x in values.split(sep) if (x := _x.strip()))


class ParseCSV(argparse.Action):
class ParseCSVTraceEvent(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSV.parse(values)
values = ParseCSVTraceEvent.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
def parse(values: str) -> list[TraceEvent]:
# empty list is ok
try:
return [TraceEvent(x) for x in parse_csv(values)]
except ValueError as e:
valid = ", ".join([e.value for e in TraceEvent])
raise ValueError(f"the list of valid trace events is: {valid}") from e

@staticmethod
def unparse(values: list[TraceEvent]) -> str:
return ",".join([x.value for x in values])


class ParseCSVInt(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSVInt.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
Expand Down Expand Up @@ -277,14 +303,14 @@ class Config:
help="set default lengths for dynamic-sized arrays (excluding bytes and string) not specified in --array-lengths",
global_default="0,1,2",
metavar="LENGTH1,LENGTH2,...",
action=ParseCSV,
action=ParseCSVInt,
)

default_bytes_lengths: str = arg(
help="set default lengths for bytes and string types not specified in --array-lengths",
global_default="0,65,1024", # 65 is ECDSA signature size
metavar="LENGTH1,LENGTH2,...",
action=ParseCSV,
action=ParseCSVInt,
)

storage_layout: str = arg(
Expand Down Expand Up @@ -424,6 +450,13 @@ class Config:
group=debugging,
)

trace_events: str = arg(
help="include specific events in traces",
global_default=",".join([e.value for e in TraceEvent]),
metavar="EVENT1,EVENT2,...",
action=ParseCSVTraceEvent,
)

### Build options

forge_build_out: str = arg(
Expand Down
23 changes: 21 additions & 2 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ForwardRef,
Optional,
TypeVar,
Union,
)

import rich
Expand Down Expand Up @@ -344,6 +345,20 @@ class EventLog:
data: Bytes | None


@dataclass(frozen=True)
class StorageWrite:
address: Address
slot: Word
value: Word


@dataclass(frozen=True)
class StorageRead:
address: Address
slot: Word
value: Word


@dataclass(frozen=True)
class Message:
target: Address
Expand Down Expand Up @@ -387,7 +402,7 @@ class CallOutput:
# - gas_left


TraceElement = ForwardRef("CallContext") | EventLog
TraceElement = Union["CallContext", EventLog, StorageRead, StorageWrite]


@dataclass
Expand Down Expand Up @@ -2037,9 +2052,13 @@ def mk_storagedata(self) -> StorageData:
return self.storage_model.mk_storagedata()

def sload(self, ex: Exec, addr: Any, loc: Word) -> Word:
return self.storage_model.load(ex, addr, loc)
val = self.storage_model.load(ex, addr, loc)
ex.context.trace.append(StorageRead(addr, loc, val))
return val

def sstore(self, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
ex.context.trace.append(StorageWrite(addr, loc, val))

if ex.message().is_static:
raise WriteInStaticContext(ex.context_str())

Expand Down
69 changes: 58 additions & 11 deletions src/halmos/traces.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import io
import sys
from contextvars import ContextVar

from z3 import Z3_OP_CONCAT, BitVecNumRef, BitVecRef, is_app

from halmos.bytevec import ByteVec
from halmos.config import Config, TraceEvent
from halmos.exceptions import HalmosException
from halmos.mapper import DeployAddressMapper, Mapper
from halmos.sevm import CallContext, EventLog, mnemonic
from halmos.sevm import CallContext, EventLog, StorageRead, StorageWrite, mnemonic
from halmos.utils import (
Address,
byte_length,
cyan,
green,
hexify,
is_bv,
magenta,
red,
unbox_int,
yellow,
)

config_context: ContextVar[Config | None] = ContextVar("config", default=None)


def rendered_initcode(context: CallContext) -> str:
message = context.message
Expand Down Expand Up @@ -73,6 +79,16 @@ def render_output(context: CallContext, file=sys.stdout) -> None:
)


def rendered_address(addr: Address) -> str:
addr = unbox_int(addr)
addr_str = str(addr) if is_bv(addr) else hex(addr)

# check if we have a contract name for this address in our deployment mapper
addr_str = DeployAddressMapper().get_deployed_contract(addr_str)

return addr_str


def rendered_log(log: EventLog) -> str:
opcode_str = f"LOG{len(log.topics)}"
topics = [
Expand All @@ -84,6 +100,28 @@ def rendered_log(log: EventLog) -> str:
return f"{opcode_str}({args_str})"


def rendered_slot(slot: Address) -> str:
slot = unbox_int(slot)

if is_bv(slot):
return magenta(hexify(slot))

if slot < 2**16:
return magenta(str(slot))

return magenta(hex(slot))


def rendered_sstore(update: StorageWrite) -> str:
slot_str = rendered_slot(update.slot)
return f"{cyan('SSTORE')} @{slot_str}{hexify(update.value)}"


def rendered_sload(read: StorageRead) -> str:
slot_str = rendered_slot(read.slot)
return f"{cyan('SLOAD')} @{slot_str}{hexify(read.value)}"


def rendered_trace(context: CallContext) -> str:
with io.StringIO() as output:
render_trace(context, file=output)
Expand All @@ -106,11 +144,12 @@ def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> st


def render_trace(context: CallContext, file=sys.stdout) -> None:
config: Config = config_context.get()
if config is None:
raise HalmosException("config not set")

message = context.message
addr = unbox_int(message.target)
addr_str = str(addr) if is_bv(addr) else hex(addr)
# check if we have a contract name for this address in our deployment mapper
addr_str = DeployAddressMapper().get_deployed_contract(addr_str)
addr_str = rendered_address(message.target)

value = unbox_int(message.value)
value_str = f" (value: {value})" if is_bv(value) or value > 0 else ""
Expand Down Expand Up @@ -147,12 +186,20 @@ def render_trace(context: CallContext, file=sys.stdout) -> None:

log_indent = (context.depth + 1) * " "
for trace_element in context.trace:
if isinstance(trace_element, CallContext):
render_trace(trace_element, file=file)
elif isinstance(trace_element, EventLog):
print(f"{log_indent}{rendered_log(trace_element)}", file=file)
else:
raise HalmosException(f"unexpected trace element: {trace_element}")
match trace_element:
case CallContext():
render_trace(trace_element, file=file)
case EventLog():
if TraceEvent.LOG in config.trace_events:
print(f"{log_indent}{rendered_log(trace_element)}", file=file)
case StorageRead():
if TraceEvent.SLOAD in config.trace_events:
print(f"{log_indent}{rendered_sload(trace_element)}", file=file)
case StorageWrite():
if TraceEvent.SSTORE in config.trace_events:
print(f"{log_indent}{rendered_sstore(trace_element)}", file=file)
case _:
raise HalmosException(f"unexpected trace element: {trace_element}")

render_output(context, file=file)

Expand Down
2 changes: 1 addition & 1 deletion src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def cyan(text: str) -> str:


def magenta(text: str) -> str:
return f"\033[35m{text}\033[0m"
return f"\033[95m{text}\033[0m"


color_good = green
Expand Down
30 changes: 15 additions & 15 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from halmos.config import (
Config,
ParseArrayLengths,
ParseCSV,
ParseCSVInt,
ParseErrorCodes,
arg_parser,
default_config,
Expand Down Expand Up @@ -192,21 +192,21 @@ def test_config_pickle(config, parser):

def test_parse_csv():
with pytest.raises(ValueError):
ParseCSV.parse("")
ParseCSV.parse(" ")
ParseCSV.parse(",")
assert ParseCSV.parse("0") == [0]
assert ParseCSV.parse("0,") == [0]
assert ParseCSV.parse("1,2,3") == [1, 2, 3]
assert ParseCSV.parse("1,2,3,") == [1, 2, 3]
assert ParseCSV.parse(" 1 , 2 , 3 ") == [1, 2, 3]
assert ParseCSV.parse(" , 1 , 2 , 3 , ") == [1, 2, 3]
ParseCSVInt.parse("")
ParseCSVInt.parse(" ")
ParseCSVInt.parse(",")
assert ParseCSVInt.parse("0") == [0]
assert ParseCSVInt.parse("0,") == [0]
assert ParseCSVInt.parse("1,2,3") == [1, 2, 3]
assert ParseCSVInt.parse("1,2,3,") == [1, 2, 3]
assert ParseCSVInt.parse(" 1 , 2 , 3 ") == [1, 2, 3]
assert ParseCSVInt.parse(" , 1 , 2 , 3 , ") == [1, 2, 3]


def test_unparse_csv():
assert ParseCSV.unparse([]) == ""
assert ParseCSV.unparse([0]) == "0"
assert ParseCSV.unparse([1, 2, 3]) == "1,2,3"
assert ParseCSVInt.unparse([]) == ""
assert ParseCSVInt.unparse([0]) == "0"
assert ParseCSVInt.unparse([1, 2, 3]) == "1,2,3"


def test_parse_csv_roundtrip():
Expand All @@ -216,8 +216,8 @@ def test_parse_csv_roundtrip():
]

for original in test_cases:
unparsed = ParseCSV.unparse(original)
parsed = ParseCSV.parse(unparsed)
unparsed = ParseCSVInt.unparse(original)
parsed = ParseCSVInt.parse(unparsed)
assert parsed == original, f"Roundtrip failed for {original}"


Expand Down

0 comments on commit a307d06

Please sign in to comment.