Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support fixed-length encoding for integers in JAM codec #226

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions internal/polkavm/interpreter/mutator.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package interpreter

import (
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"math"

"github.com/eigerco/strawberry/pkg/serialization/codec/jam"

"github.com/eigerco/strawberry/internal/polkavm"
)

Expand Down Expand Up @@ -44,8 +45,13 @@ func load[T number](m *Mutator, dst polkavm.Reg, base *polkavm.Reg, offset uint3
}
address += offset
value := T(0)
slice := make([]byte, jam.IntLength(value))
err := m.instance.memory.Read(address, slice)
l, err := jam.IntLength(value)
if err != nil {
return err
}

slice := make([]byte, l)
err = m.instance.memory.Read(address, slice)
if err != nil {
return err
}
Expand Down
33 changes: 33 additions & 0 deletions pkg/serialization/codec/jam/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package jam

import (
"fmt"
"strings"
)

func IntLength(in any) (uint, error) {
switch in.(type) {
case uint8, int8:
return 1, nil
case uint16, int16:
return 2, nil
case uint32, int32:
return 4, nil
case uint64, int64:
return 8, nil
default:
return 0, fmt.Errorf(ErrUnsupportedType, in)
}
}

func parseTag(tag string) map[string]string {
result := make(map[string]string)
pairs := strings.Split(tag, ",")
for _, pair := range pairs {
kv := strings.Split(pair, "=")
if len(kv) == 2 {
result[kv[0]] = kv[1]
}
}
return result
}
151 changes: 88 additions & 63 deletions pkg/serialization/codec/jam/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math"
"math/bits"
"reflect"
"strconv"

"github.com/eigerco/strawberry/internal/crypto"
)
Expand Down Expand Up @@ -42,7 +43,11 @@ func (br *byteReader) unmarshal(value reflect.Value) error {
case int, uint:
return br.decodeUint(value)
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
return br.decodeFixedWidthInt(value)
l, err := IntLength(value.Interface())
if err != nil {
return err
}
return br.decodeFixedWidth(value, l)
case []byte:
return br.decodeBytes(value)
case bool:
Expand Down Expand Up @@ -154,38 +159,26 @@ func (br *byteReader) decodeEnum(enum EnumType) error {
}

func (br *byteReader) decodePointer(value reflect.Value) error {
rb, err := br.ReadOctet()
isNil, err := br.readPointerMarker()
if err != nil {
return err
}

switch rb {
case 0x00:
// Handle the nil pointer case by setting the destination to nil if necessary
if isNil {
// Set the pointer to nil
if !value.IsNil() {
value.Set(reflect.Zero(value.Type()))
}
case 0x01:
// Check if the destination is a non-nil pointer
if !value.IsZero() {
// If it's a pointer to another pointer, we need to handle it recursively
if value.Elem().Kind() == reflect.Ptr {
return br.unmarshal(value.Elem().Elem())
}
return br.unmarshal(value.Elem())
}
return nil
}

// If value is nil or zero, we need to create a new instance
elemType := value.Type().Elem()
tempElem := reflect.New(elemType)
if err := br.unmarshal(tempElem.Elem()); err != nil {
return err
}
value.Set(tempElem)
default:
return ErrInvalidPointer
// Allocate space for the pointer if it's nil
if value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
return nil

// Decode the dereferenced value
return br.unmarshal(value.Elem())
}

func (br *byteReader) decodeSlice(value reflect.Value) error {
Expand Down Expand Up @@ -291,6 +284,19 @@ func (br *byteReader) decodeStruct(value reflect.Value) error {
if tag == "-" {
continue
}
tagValues := parseTag(tag)
if length, found := tagValues["length"]; found {
size, err := strconv.ParseUint(length, 10, 64)
if err != nil {
return fmt.Errorf(ErrInvalidLengthValue, fieldType.Name, err)
}

err = br.decodeFixedWidth(field, uint(size))
if err != nil {
return fmt.Errorf(ErrEncodingStructField, fieldType.Name, err)
}
continue
}
}

// Decode the field value
Expand Down Expand Up @@ -390,57 +396,91 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error {
return nil
}

func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error {
in := dstv.Interface()
var buf []byte
length := IntLength(in)
func (br *byteReader) decodeFixedWidth(dstv reflect.Value, length uint) error {
typ := dstv.Type()

// Handle pointers
if typ.Kind() == reflect.Ptr {
isNil, err := br.readPointerMarker()
if err != nil {
return err
}
if isNil {
dstv.Set(reflect.Zero(typ))

return nil
}
if dstv.IsNil() {
dstv.Set(reflect.New(typ.Elem()))
}
dstv = dstv.Elem()
typ = typ.Elem()
}

// Read the appropriate number of bytes
buf = make([]byte, length)
// Read the data
buf := make([]byte, length)
_, err := br.Read(buf)
if err != nil {
return fmt.Errorf(ErrReadingByte, err)
}

// Deserialize the value
switch in.(type) {
case uint8:
switch typ.Kind() {
case reflect.Uint8:
var temp uint8
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint16:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint16:
var temp uint16
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint32:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint32:
var temp uint32
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint64:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint64:
var temp uint64
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case int8:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Int8:
var temp uint8
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int8(temp)))
case int16:
dstv.Set(reflect.ValueOf(int8(temp)).Convert(typ))
case reflect.Int16:
var temp uint16
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int16(temp)))
case int32:
dstv.Set(reflect.ValueOf(int16(temp)).Convert(typ))
case reflect.Int32:
var temp uint32
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int32(temp)))
case int64:
dstv.Set(reflect.ValueOf(int32(temp)).Convert(typ))
case reflect.Int64:
var temp uint64
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int64(temp)))
dstv.Set(reflect.ValueOf(int64(temp)).Convert(typ))
default:
return fmt.Errorf(ErrUnsupportedType, typ)
}

return nil
}

func (br *byteReader) readPointerMarker() (bool, error) {
var marker [1]byte
_, err := br.Read(marker[:])
if err != nil {
return false, err
}

switch marker[0] {
case 0x00:
return true, nil // Nil pointer
case 0x01:
return false, nil // Non-nil pointer
default:
return false, ErrInvalidPointer
}
}

// indirect recursively dereferences pointers and interfaces,
// allocating new pointers as needed, until it reaches a non-pointer value.
func indirect(v reflect.Value) reflect.Value {
Expand All @@ -461,18 +501,3 @@ func indirect(v reflect.Value) reflect.Value {
}
}
}

func IntLength(in any) int {
switch in.(type) {
case uint8, int8:
return 1
case uint16, int16:
return 2
case uint32, int32:
return 4
case uint64, int64:
return 8
default:
panic(fmt.Errorf(ErrUnsupportedType, in))
}
}
Loading
Loading