Skip to content

Commit

Permalink
Add import host function (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
pantrif authored Dec 13, 2024
1 parent 83fc552 commit d7c80f1
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 3 deletions.
2 changes: 2 additions & 0 deletions internal/polkavm/host_call/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
SolicitCost
ForgetCost
HistoricalLookupCost
ImportCost
)

const (
Expand All @@ -43,6 +44,7 @@ const (
SolicitID
ForgetID
HistoricalLookupID
ImportID
)

type Code uint64
Expand Down
42 changes: 39 additions & 3 deletions internal/polkavm/host_call/refine_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"math"

"github.com/eigerco/strawberry/internal/block"
"github.com/eigerco/strawberry/internal/common"
"github.com/eigerco/strawberry/internal/crypto"
"github.com/eigerco/strawberry/internal/jamtime"
. "github.com/eigerco/strawberry/internal/polkavm"
"github.com/eigerco/strawberry/internal/service"
)

// HistoricalLookup ΩH(ϱ, ω, µ, (m, e), s,d, t)
// HistoricalLookup ΩH(ϱ, ω, µ, (m, e), s, d, t)
func HistoricalLookup(
gas Gas,
regs Registers,
Expand Down Expand Up @@ -53,8 +54,7 @@ func HistoricalLookup(
v := a.LookupPreimage(t, h)

if len(v) == 0 {
regs[A0] = uint64(NONE)
return gas, regs, mem, ctxPair, nil
return gas, withCode(regs, NONE), mem, ctxPair, nil
}

if uint64(len(v)) > bz {
Expand All @@ -70,3 +70,39 @@ func HistoricalLookup(

return gas, regs, mem, ctxPair, nil
}

// Import ΩY(ϱ, ω, µ, (m, e), i)
func Import(
gas Gas,
regs Registers,
mem Memory,
ctxPair RefineContextPair,
importedSegments []Segment,
) (Gas, Registers, Memory, RefineContextPair, error) {
if gas < ImportCost {
return gas, regs, mem, ctxPair, ErrOutOfGas
}
gas -= ImportCost

index := regs[A0] // ω7
offset := regs[A1] // ω8
length := regs[A2] // ω9

// v = ∅
if index >= uint64(len(importedSegments)) {
// v = ∅, return NONE
return gas, withCode(regs, NONE), mem, ctxPair, nil
}

// v = i[ω7]
v := importedSegments[index][:]

l := min(length, common.SizeOfSegment)

segmentToWrite := v[:l]
if err := mem.Write(uint32(offset), segmentToWrite); err != nil {
return gas, withCode(regs, OOB), mem, ctxPair, nil
}

return gas, withCode(regs, OK), mem, ctxPair, nil
}
64 changes: 64 additions & 0 deletions internal/polkavm/host_call/refine_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/eigerco/strawberry/internal/block"
"github.com/eigerco/strawberry/internal/common"
"github.com/eigerco/strawberry/internal/crypto"
"github.com/eigerco/strawberry/internal/jamtime"
"github.com/eigerco/strawberry/internal/polkavm"
Expand Down Expand Up @@ -104,3 +105,66 @@ func TestHistoricalLookup(t *testing.T) {
expectedGasRemaining := polkavm.Gas(initialGas) - host_call.HistoricalLookupCost - polkavm.GasCosts[polkavm.Ecalli] - polkavm.GasCosts[polkavm.JumpIndirect]
assert.Equal(t, expectedGasRemaining, gasRemaining)
}

func TestImport(t *testing.T) {
pp := &polkavm.Program{
Instructions: []polkavm.Instruction{
{Opcode: polkavm.Ecalli, Imm: []uint32{0}, Offset: 0, Length: 1},
{Opcode: polkavm.JumpIndirect, Imm: []uint32{0}, Reg: []polkavm.Reg{polkavm.RA}, Offset: 1, Length: 2},
},
}

memoryMap, err := polkavm.NewMemoryMap(0, 256, 512, 0)
require.NoError(t, err)

initialGas := uint64(100)

segmentData := [common.SizeOfSegment]byte{}
for i := range segmentData {
segmentData[i] = byte('A')
}
importedSegments := []polkavm.Segment{segmentData}

bo := memoryMap.RWDataAddress + 100
bz := uint32(50)

initialRegs := polkavm.Registers{
polkavm.RA: polkavm.VmAddressReturnToHost,
polkavm.SP: uint64(memoryMap.StackAddressHigh),
polkavm.A0: uint64(0),
polkavm.A1: uint64(bo),
polkavm.A2: uint64(bz),
}

mem := memoryMap.NewMemory(nil, nil, nil)

hostCall := func(hostCall uint32, gasCounter polkavm.Gas, regs polkavm.Registers, mem polkavm.Memory, x service.ServiceAccount) (polkavm.Gas, polkavm.Registers, polkavm.Memory, service.ServiceAccount, error) {
gasCounterOut, regsOut, memOut, _, err := host_call.Import(
gasCounter,
regs,
mem,
polkavm.RefineContextPair{},
importedSegments,
)
require.NoError(t, err)
return gasCounterOut, regsOut, memOut, x, err
}

gasRemaining, regsOut, memOut, _, err := interpreter.InvokeHostCall(pp, memoryMap, 0, initialGas, initialRegs, mem, hostCall, service.ServiceAccount{})
require.ErrorIs(t, err, polkavm.ErrHalt)

actualValue := make([]byte, bz)
err = memOut.Read(bo, actualValue)
require.NoError(t, err)

expectedData := make([]byte, bz)
for i := range expectedData {
expectedData[i] = 'A'
}

assert.Equal(t, expectedData, actualValue)
assert.Equal(t, uint64(host_call.OK), regsOut[polkavm.A0])

expectedGasRemaining := polkavm.Gas(initialGas) - host_call.ImportCost - polkavm.GasCosts[polkavm.Ecalli] - polkavm.GasCosts[polkavm.JumpIndirect]
assert.Equal(t, expectedGasRemaining, gasRemaining)
}

0 comments on commit d7c80f1

Please sign in to comment.