From 35c03b247b5180e15b9b356c794ac51825fdcce6 Mon Sep 17 00:00:00 2001 From: "daniel.vladco" Date: Mon, 13 Jan 2025 15:36:27 +0200 Subject: [PATCH] feat: invoke host call --- internal/polkavm/host_call/common.go | 11 ++- .../polkavm/host_call/refine_functions.go | 91 ++++++++++++++++- .../host_call/refine_functions_test.go | 98 +++++++++++++++++++ 3 files changed, 198 insertions(+), 2 deletions(-) diff --git a/internal/polkavm/host_call/common.go b/internal/polkavm/host_call/common.go index f4f0c266..22ea79e9 100644 --- a/internal/polkavm/host_call/common.go +++ b/internal/polkavm/host_call/common.go @@ -79,6 +79,15 @@ const ( OK Code = 0 ) +// Inner pvm invocations have their own set of result codes +const ( + HALT = 0 // The invocation completed and halted normally. + PANIC = 1 // The invocation completed with a panic. + FAULT = 2 // The invocation completed with a page fault. + HOST = 3 // The invocation completed with a host-call fault. + OOG = 4 // The invocation completed by running out of gas. +) + func (r Code) String() string { switch r { case NONE: @@ -107,7 +116,7 @@ func (r Code) String() string { return "unknown" } -func readNumber[U interface{ ~uint32 | ~uint64 }](mem Memory, addr uint32, length int) (u U, err error) { +func readNumber[U interface{ ~uint32 | ~uint64 | ~int64 }](mem Memory, addr uint32, length int) (u U, err error) { b := make([]byte, length) if err = mem.Read(addr, b); err != nil { return diff --git a/internal/polkavm/host_call/refine_functions.go b/internal/polkavm/host_call/refine_functions.go index f3e945e3..d6e3f2cf 100644 --- a/internal/polkavm/host_call/refine_functions.go +++ b/internal/polkavm/host_call/refine_functions.go @@ -1,6 +1,9 @@ package host_call import ( + "bytes" + "errors" + "log" "math" "github.com/eigerco/strawberry/internal/block" @@ -8,8 +11,10 @@ import ( "github.com/eigerco/strawberry/internal/crypto" "github.com/eigerco/strawberry/internal/jamtime" . "github.com/eigerco/strawberry/internal/polkavm" + "github.com/eigerco/strawberry/internal/polkavm/interpreter" "github.com/eigerco/strawberry/internal/service" "github.com/eigerco/strawberry/internal/work" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" ) // HistoricalLookup ΩH(ϱ, ω, µ, (m, e), s, d, t) @@ -366,7 +371,91 @@ func Invoke( mem Memory, ctxPair RefineContextPair, ) (Gas, Registers, Memory, RefineContextPair, error) { - return gas, regs, mem, ctxPair, nil + if gas < InvokeCost { + return gas, regs, mem, ctxPair, ErrOutOfGas + } + gas -= InvokeCost + // let [n, o] = ω7,8 + pvmKey, addr := regs[A0], regs[A1] + + // let (g, w) = (g, w) ∶ E8(g) ⌢ E#8(w) = μo⋅⋅⋅+112 if No⋅⋅⋅+112 ⊂ V∗μ + invokeGas, err := readNumber[Gas](mem, uint32(addr), 8) + if err != nil { + return gas, withCode(regs, OOB), mem, ctxPair, nil + } + var invokeRegs Registers // w + for i := range 13 { + invokeReg, err := readNumber[uint64](mem, uint32(addr+(uint64(i+1)*8)), 8) + if err != nil { + return gas, withCode(regs, OOB), mem, ctxPair, nil + } + invokeRegs[i] = invokeReg + } + + // let (c, i′, g′, w′, u′) = Ψ(m[n]p, m[n]i, g, w, m[n]u) + pvm, ok := ctxPair.IntegratedPVMMap[pvmKey] + if !ok { // if n ∉ m + return gas, withCode(regs, WHO), mem, ctxPair, nil // (WHO, ω8, μ, m) + } + updateIntegratedPVM := func(isHostCall bool, resultInstr uint32, resultMem Memory) { + pvm.Ram = resultMem + if isHostCall { + // m*[n]i = i′ + 1 if c ∈ {̵h} × NR + pvm.InstructionCounter = resultInstr + 1 + } else { + // m*[n]i = i′ + pvm.InstructionCounter = resultInstr + } + ctxPair.IntegratedPVMMap[pvmKey] = pvm + } + + // we only parse the code and jump table as we are not expected to invoke a full program + program := &Program{} + if err := ParseCodeAndJumpTable(uint32(len(pvm.Code)), NewReader(bytes.NewReader(pvm.Code)), program); err != nil { + return gas, withCode(regs, PANIC), mem, ctxPair, nil + } + + log.Println("invokeGas", invokeGas) + log.Println("invokeRegs", invokeRegs) + resultInstr, resultGas, resultRegs, resultMem, hostCall, invokeErr := interpreter.Invoke(program, nil, pvm.InstructionCounter, invokeGas, invokeRegs, pvm.Ram) + + if bb, err := jam.Marshal([14]uint64(append([]uint64{uint64(resultGas)}, resultRegs[:]...))); err != nil { + return gas, withCode(regs, OOB), mem, ctxPair, nil // (OOB, ω8, μ, m) + } else if err := mem.Write(uint32(addr), bb); err != nil { + return gas, withCode(regs, OOB), mem, ctxPair, nil // (OOB, ω8, μ, m) + } + if invokeErr != nil { + if errors.Is(invokeErr, ErrOutOfGas) { + updateIntegratedPVM(false, resultInstr, resultMem) + return gas, withCode(regs, OOG), mem, ctxPair, nil // (OOG, ω8, μ*, m*) + } + if errors.Is(invokeErr, ErrHalt) { + updateIntegratedPVM(false, resultInstr, resultMem) + return gas, withCode(regs, HALT), mem, ctxPair, nil // (HALT, ω8, μ*, m*) + } + if errors.Is(invokeErr, ErrHostCall) { + updateIntegratedPVM(true, resultInstr, resultMem) + regs[A1] = uint64(hostCall) + return gas, withCode(regs, HOST), mem, ctxPair, nil // (HOST, h, μ*, m*) + } + pageFault := &ErrPageFault{} + if errors.As(invokeErr, &pageFault) { + updateIntegratedPVM(false, resultInstr, resultMem) + regs[A1] = uint64(pageFault.Address) + return gas, withCode(regs, FAULT), mem, ctxPair, nil + } + panicErr := &ErrPanic{} + if errors.As(invokeErr, &panicErr) { + updateIntegratedPVM(false, resultInstr, resultMem) + return gas, withCode(regs, PANIC), mem, ctxPair, nil + } + + // must never occur + panic(invokeErr) + } + + updateIntegratedPVM(false, resultInstr, resultMem) + return gas, withCode(regs, HALT), mem, ctxPair, nil // (HALT, ω8, μ*, m*) } // Expunge ΩX(ϱ, ω, µ, (m, e)) diff --git a/internal/polkavm/host_call/refine_functions_test.go b/internal/polkavm/host_call/refine_functions_test.go index 9531838b..635b72ef 100644 --- a/internal/polkavm/host_call/refine_functions_test.go +++ b/internal/polkavm/host_call/refine_functions_test.go @@ -15,6 +15,7 @@ import ( "github.com/eigerco/strawberry/internal/polkavm/host_call" "github.com/eigerco/strawberry/internal/polkavm/interpreter" "github.com/eigerco/strawberry/internal/service" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" ) var initialGas = uint64(100) @@ -603,6 +604,103 @@ func TestVoid(t *testing.T) { assert.Equal(t, expectedGasRemaining, gasRemaining) } +func TestInvoke(t *testing.T) { + pp := &polkavm.Program{ + Instructions: []polkavm.Instruction{ + {Opcode: polkavm.Ecalli, Imm: []uint32{0}, Offset: 0, Length: 1}, + {Opcode: polkavm.JumpIndirect, Imm: []uint32{0}, Reg: []polkavm.Reg{polkavm.RA}, Offset: 1, Length: 2}, + }, + } + + memoryMap, err := polkavm.NewMemoryMap(0, 128*1024, 0, 0) + require.NoError(t, err) + + mem := memoryMap.NewMemory(nil, nil, nil) + + bb, err := jam.Marshal([14]uint64{ + 10000, // gas + 0, // regs + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 2, + 0, + 0, + 0, + 0, + }) + require.NoError(t, err) + + addr := memoryMap.RWDataAddress + if err := mem.Write(addr, bb); err != nil { + t.Fatal(err) + } + + pvmKey := uint64(0) + + ctxPair := polkavm.RefineContextPair{ + IntegratedPVMMap: map[uint64]polkavm.IntegratedPVM{pvmKey: { + Code: addInstrProgram, + Ram: polkavm.Memory{}, // we don't use memory in tests yet + InstructionCounter: 0, + }}, + } + + initialRegs := polkavm.Registers{ + polkavm.RA: polkavm.VmAddressReturnToHost, + polkavm.SP: uint64(memoryMap.StackAddressHigh), + polkavm.A0: pvmKey, + polkavm.A1: uint64(addr), + } + + hostCall := func(hc uint32, gasCounter polkavm.Gas, regs polkavm.Registers, + mm polkavm.Memory, x struct{}, + ) (polkavm.Gas, polkavm.Registers, polkavm.Memory, struct{}, error) { + gasOut, regsOut, memOut, ctxOut, err := host_call.Invoke(gasCounter, regs, mm, ctxPair) + require.NoError(t, err) + ctxPair = ctxOut + return gasOut, regsOut, memOut, x, err + } + + gasRemaining, regsOut, _, _, err := interpreter.InvokeHostCall( + pp, + memoryMap, + 0, + initialGas, + initialRegs, + mem, + hostCall, + struct{}{}, + ) + require.ErrorIs(t, err, polkavm.ErrHalt) + + assert.Equal(t, uint64(host_call.PANIC), regsOut[polkavm.A0]) + + expectedGasRemaining := polkavm.Gas(initialGas) - + host_call.InvokeCost - + polkavm.GasCosts[polkavm.Ecalli] - + polkavm.GasCosts[polkavm.JumpIndirect] + assert.Equal(t, expectedGasRemaining, gasRemaining) + + invokeResult := make([]byte, 112) + err = mem.Read(addr, invokeResult) + require.NoError(t, err) + + invokeGasAndRegs := [14]uint64{} + err = jam.Unmarshal(invokeResult, &invokeGasAndRegs) + require.NoError(t, err) + + assert.Equal(t, uint32(3), ctxPair.IntegratedPVMMap[pvmKey].InstructionCounter) + assert.Equal(t, uint64(9998), invokeGasAndRegs[0]) + assert.Equal(t, []uint64{0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 0}, invokeGasAndRegs[1:]) +} + +var addInstrProgram = []byte{0, 0, 3, 8, 135, 9, 1} // copied from testvectors + func TestExpunge(t *testing.T) { pp := &polkavm.Program{ Instructions: []polkavm.Instruction{