Skip to content

Commit

Permalink
Support fixed-length encoding for integers in JAM codec (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
pantrif authored Jan 16, 2025
1 parent 1712821 commit ce23ec9
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 99 deletions.
12 changes: 9 additions & 3 deletions internal/polkavm/interpreter/mutator.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package interpreter

import (
"github.com/eigerco/strawberry/pkg/serialization/codec/jam"
"math"

"github.com/eigerco/strawberry/pkg/serialization/codec/jam"

"github.com/eigerco/strawberry/internal/polkavm"
)

Expand Down Expand Up @@ -44,8 +45,13 @@ func load[T number](m *Mutator, dst polkavm.Reg, base *polkavm.Reg, offset uint3
}
address += offset
value := T(0)
slice := make([]byte, jam.IntLength(value))
err := m.instance.memory.Read(address, slice)
l, err := jam.IntLength(value)
if err != nil {
return err
}

slice := make([]byte, l)
err = m.instance.memory.Read(address, slice)
if err != nil {
return err
}
Expand Down
33 changes: 33 additions & 0 deletions pkg/serialization/codec/jam/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package jam

import (
"fmt"
"strings"
)

func IntLength(in any) (uint, error) {
switch in.(type) {
case uint8, int8:
return 1, nil
case uint16, int16:
return 2, nil
case uint32, int32:
return 4, nil
case uint64, int64:
return 8, nil
default:
return 0, fmt.Errorf(ErrUnsupportedType, in)
}
}

func parseTag(tag string) map[string]string {
result := make(map[string]string)
pairs := strings.Split(tag, ",")
for _, pair := range pairs {
kv := strings.Split(pair, "=")
if len(kv) == 2 {
result[kv[0]] = kv[1]
}
}
return result
}
151 changes: 88 additions & 63 deletions pkg/serialization/codec/jam/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math"
"math/bits"
"reflect"
"strconv"

"github.com/eigerco/strawberry/internal/crypto"
)
Expand Down Expand Up @@ -42,7 +43,11 @@ func (br *byteReader) unmarshal(value reflect.Value) error {
case int, uint:
return br.decodeUint(value)
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
return br.decodeFixedWidthInt(value)
l, err := IntLength(value.Interface())
if err != nil {
return err
}
return br.decodeFixedWidth(value, l)
case []byte:
return br.decodeBytes(value)
case bool:
Expand Down Expand Up @@ -154,38 +159,26 @@ func (br *byteReader) decodeEnum(enum EnumType) error {
}

func (br *byteReader) decodePointer(value reflect.Value) error {
rb, err := br.ReadOctet()
isNil, err := br.readPointerMarker()
if err != nil {
return err
}

switch rb {
case 0x00:
// Handle the nil pointer case by setting the destination to nil if necessary
if isNil {
// Set the pointer to nil
if !value.IsNil() {
value.Set(reflect.Zero(value.Type()))
}
case 0x01:
// Check if the destination is a non-nil pointer
if !value.IsZero() {
// If it's a pointer to another pointer, we need to handle it recursively
if value.Elem().Kind() == reflect.Ptr {
return br.unmarshal(value.Elem().Elem())
}
return br.unmarshal(value.Elem())
}
return nil
}

// If value is nil or zero, we need to create a new instance
elemType := value.Type().Elem()
tempElem := reflect.New(elemType)
if err := br.unmarshal(tempElem.Elem()); err != nil {
return err
}
value.Set(tempElem)
default:
return ErrInvalidPointer
// Allocate space for the pointer if it's nil
if value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
return nil

// Decode the dereferenced value
return br.unmarshal(value.Elem())
}

func (br *byteReader) decodeSlice(value reflect.Value) error {
Expand Down Expand Up @@ -291,6 +284,19 @@ func (br *byteReader) decodeStruct(value reflect.Value) error {
if tag == "-" {
continue
}
tagValues := parseTag(tag)
if length, found := tagValues["length"]; found {
size, err := strconv.ParseUint(length, 10, 64)
if err != nil {
return fmt.Errorf(ErrInvalidLengthValue, fieldType.Name, err)
}

err = br.decodeFixedWidth(field, uint(size))
if err != nil {
return fmt.Errorf(ErrEncodingStructField, fieldType.Name, err)
}
continue
}
}

// Decode the field value
Expand Down Expand Up @@ -390,57 +396,91 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error {
return nil
}

func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error {
in := dstv.Interface()
var buf []byte
length := IntLength(in)
func (br *byteReader) decodeFixedWidth(dstv reflect.Value, length uint) error {
typ := dstv.Type()

// Handle pointers
if typ.Kind() == reflect.Ptr {
isNil, err := br.readPointerMarker()
if err != nil {
return err
}
if isNil {
dstv.Set(reflect.Zero(typ))

return nil
}
if dstv.IsNil() {
dstv.Set(reflect.New(typ.Elem()))
}
dstv = dstv.Elem()
typ = typ.Elem()
}

// Read the appropriate number of bytes
buf = make([]byte, length)
// Read the data
buf := make([]byte, length)
_, err := br.Read(buf)
if err != nil {
return fmt.Errorf(ErrReadingByte, err)
}

// Deserialize the value
switch in.(type) {
case uint8:
switch typ.Kind() {
case reflect.Uint8:
var temp uint8
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint16:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint16:
var temp uint16
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint32:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint32:
var temp uint32
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case uint64:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Uint64:
var temp uint64
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(temp))
case int8:
dstv.Set(reflect.ValueOf(temp).Convert(typ))
case reflect.Int8:
var temp uint8
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int8(temp)))
case int16:
dstv.Set(reflect.ValueOf(int8(temp)).Convert(typ))
case reflect.Int16:
var temp uint16
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int16(temp)))
case int32:
dstv.Set(reflect.ValueOf(int16(temp)).Convert(typ))
case reflect.Int32:
var temp uint32
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int32(temp)))
case int64:
dstv.Set(reflect.ValueOf(int32(temp)).Convert(typ))
case reflect.Int64:
var temp uint64
deserializeTrivialNatural(buf, &temp)
dstv.Set(reflect.ValueOf(int64(temp)))
dstv.Set(reflect.ValueOf(int64(temp)).Convert(typ))
default:
return fmt.Errorf(ErrUnsupportedType, typ)
}

return nil
}

func (br *byteReader) readPointerMarker() (bool, error) {
var marker [1]byte
_, err := br.Read(marker[:])
if err != nil {
return false, err
}

switch marker[0] {
case 0x00:
return true, nil // Nil pointer
case 0x01:
return false, nil // Non-nil pointer
default:
return false, ErrInvalidPointer
}
}

// indirect recursively dereferences pointers and interfaces,
// allocating new pointers as needed, until it reaches a non-pointer value.
func indirect(v reflect.Value) reflect.Value {
Expand All @@ -461,18 +501,3 @@ func indirect(v reflect.Value) reflect.Value {
}
}
}

func IntLength(in any) int {
switch in.(type) {
case uint8, int8:
return 1
case uint16, int16:
return 2
case uint32, int32:
return 4
case uint64, int64:
return 8
default:
panic(fmt.Errorf(ErrUnsupportedType, in))
}
}
Loading

0 comments on commit ce23ec9

Please sign in to comment.