diff --git a/internal/polkavm/interpreter/mutator.go b/internal/polkavm/interpreter/mutator.go index 1865662c..cf4e805a 100644 --- a/internal/polkavm/interpreter/mutator.go +++ b/internal/polkavm/interpreter/mutator.go @@ -1,9 +1,10 @@ package interpreter import ( - "github.com/eigerco/strawberry/pkg/serialization/codec/jam" "math" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + "github.com/eigerco/strawberry/internal/polkavm" ) @@ -44,8 +45,13 @@ func load[T number](m *Mutator, dst polkavm.Reg, base *polkavm.Reg, offset uint3 } address += offset value := T(0) - slice := make([]byte, jam.IntLength(value)) - err := m.instance.memory.Read(address, slice) + l, err := jam.IntLength(value) + if err != nil { + return err + } + + slice := make([]byte, l) + err = m.instance.memory.Read(address, slice) if err != nil { return err } diff --git a/pkg/serialization/codec/jam/common.go b/pkg/serialization/codec/jam/common.go new file mode 100644 index 00000000..3d91ee85 --- /dev/null +++ b/pkg/serialization/codec/jam/common.go @@ -0,0 +1,33 @@ +package jam + +import ( + "fmt" + "strings" +) + +func IntLength(in any) (uint, error) { + switch in.(type) { + case uint8, int8: + return 1, nil + case uint16, int16: + return 2, nil + case uint32, int32: + return 4, nil + case uint64, int64: + return 8, nil + default: + return 0, fmt.Errorf(ErrUnsupportedType, in) + } +} + +func parseTag(tag string) map[string]string { + result := make(map[string]string) + pairs := strings.Split(tag, ",") + for _, pair := range pairs { + kv := strings.Split(pair, "=") + if len(kv) == 2 { + result[kv[0]] = kv[1] + } + } + return result +} diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go index 31b193f3..3b7ca955 100644 --- a/pkg/serialization/codec/jam/decode.go +++ b/pkg/serialization/codec/jam/decode.go @@ -8,6 +8,7 @@ import ( "math" "math/bits" "reflect" + "strconv" "github.com/eigerco/strawberry/internal/crypto" ) @@ -42,7 +43,11 @@ func (br *byteReader) unmarshal(value reflect.Value) error { case int, uint: return br.decodeUint(value) case int8, uint8, int16, uint16, int32, uint32, int64, uint64: - return br.decodeFixedWidthInt(value) + l, err := IntLength(value.Interface()) + if err != nil { + return err + } + return br.decodeFixedWidth(value, l) case []byte: return br.decodeBytes(value) case bool: @@ -154,38 +159,26 @@ func (br *byteReader) decodeEnum(enum EnumType) error { } func (br *byteReader) decodePointer(value reflect.Value) error { - rb, err := br.ReadOctet() + isNil, err := br.readPointerMarker() if err != nil { return err } - switch rb { - case 0x00: - // Handle the nil pointer case by setting the destination to nil if necessary + if isNil { + // Set the pointer to nil if !value.IsNil() { value.Set(reflect.Zero(value.Type())) } - case 0x01: - // Check if the destination is a non-nil pointer - if !value.IsZero() { - // If it's a pointer to another pointer, we need to handle it recursively - if value.Elem().Kind() == reflect.Ptr { - return br.unmarshal(value.Elem().Elem()) - } - return br.unmarshal(value.Elem()) - } + return nil + } - // If value is nil or zero, we need to create a new instance - elemType := value.Type().Elem() - tempElem := reflect.New(elemType) - if err := br.unmarshal(tempElem.Elem()); err != nil { - return err - } - value.Set(tempElem) - default: - return ErrInvalidPointer + // Allocate space for the pointer if it's nil + if value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) } - return nil + + // Decode the dereferenced value + return br.unmarshal(value.Elem()) } func (br *byteReader) decodeSlice(value reflect.Value) error { @@ -291,6 +284,19 @@ func (br *byteReader) decodeStruct(value reflect.Value) error { if tag == "-" { continue } + tagValues := parseTag(tag) + if length, found := tagValues["length"]; found { + size, err := strconv.ParseUint(length, 10, 64) + if err != nil { + return fmt.Errorf(ErrInvalidLengthValue, fieldType.Name, err) + } + + err = br.decodeFixedWidth(field, uint(size)) + if err != nil { + return fmt.Errorf(ErrEncodingStructField, fieldType.Name, err) + } + continue + } } // Decode the field value @@ -390,57 +396,91 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error { return nil } -func (br *byteReader) decodeFixedWidthInt(dstv reflect.Value) error { - in := dstv.Interface() - var buf []byte - length := IntLength(in) +func (br *byteReader) decodeFixedWidth(dstv reflect.Value, length uint) error { + typ := dstv.Type() + + // Handle pointers + if typ.Kind() == reflect.Ptr { + isNil, err := br.readPointerMarker() + if err != nil { + return err + } + if isNil { + dstv.Set(reflect.Zero(typ)) + + return nil + } + if dstv.IsNil() { + dstv.Set(reflect.New(typ.Elem())) + } + dstv = dstv.Elem() + typ = typ.Elem() + } - // Read the appropriate number of bytes - buf = make([]byte, length) + // Read the data + buf := make([]byte, length) _, err := br.Read(buf) if err != nil { return fmt.Errorf(ErrReadingByte, err) } - // Deserialize the value - switch in.(type) { - case uint8: + switch typ.Kind() { + case reflect.Uint8: var temp uint8 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(temp)) - case uint16: + dstv.Set(reflect.ValueOf(temp).Convert(typ)) + case reflect.Uint16: var temp uint16 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(temp)) - case uint32: + dstv.Set(reflect.ValueOf(temp).Convert(typ)) + case reflect.Uint32: var temp uint32 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(temp)) - case uint64: + dstv.Set(reflect.ValueOf(temp).Convert(typ)) + case reflect.Uint64: var temp uint64 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(temp)) - case int8: + dstv.Set(reflect.ValueOf(temp).Convert(typ)) + case reflect.Int8: var temp uint8 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(int8(temp))) - case int16: + dstv.Set(reflect.ValueOf(int8(temp)).Convert(typ)) + case reflect.Int16: var temp uint16 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(int16(temp))) - case int32: + dstv.Set(reflect.ValueOf(int16(temp)).Convert(typ)) + case reflect.Int32: var temp uint32 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(int32(temp))) - case int64: + dstv.Set(reflect.ValueOf(int32(temp)).Convert(typ)) + case reflect.Int64: var temp uint64 deserializeTrivialNatural(buf, &temp) - dstv.Set(reflect.ValueOf(int64(temp))) + dstv.Set(reflect.ValueOf(int64(temp)).Convert(typ)) + default: + return fmt.Errorf(ErrUnsupportedType, typ) } return nil } +func (br *byteReader) readPointerMarker() (bool, error) { + var marker [1]byte + _, err := br.Read(marker[:]) + if err != nil { + return false, err + } + + switch marker[0] { + case 0x00: + return true, nil // Nil pointer + case 0x01: + return false, nil // Non-nil pointer + default: + return false, ErrInvalidPointer + } +} + // indirect recursively dereferences pointers and interfaces, // allocating new pointers as needed, until it reaches a non-pointer value. func indirect(v reflect.Value) reflect.Value { @@ -461,18 +501,3 @@ func indirect(v reflect.Value) reflect.Value { } } } - -func IntLength(in any) int { - switch in.(type) { - case uint8, int8: - return 1 - case uint16, int16: - return 2 - case uint32, int32: - return 4 - case uint64, int64: - return 8 - default: - panic(fmt.Errorf(ErrUnsupportedType, in)) - } -} diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go index 17a08d49..dda603f7 100644 --- a/pkg/serialization/codec/jam/encode.go +++ b/pkg/serialization/codec/jam/encode.go @@ -7,6 +7,7 @@ import ( "io" "reflect" "sort" + "strconv" ) func Marshal(v interface{}) ([]byte, error) { @@ -33,19 +34,23 @@ func (bw *byteWriter) marshal(in interface{}) error { return bw.encodeEnumType(v) } - switch in := in.(type) { + switch v := in.(type) { case int: - return bw.encodeUint(uint(in)) + return bw.encodeUint(uint(v)) case uint: - return bw.encodeUint(in) + return bw.encodeUint(v) case uint8, uint16, uint32, uint64: - return bw.encodeFixedWidthUint(in) + l, err := IntLength(v) + if err != nil { + return err + } + return bw.encodeFixedWidth(v, l) case []byte: - return bw.encodeBytes(in) + return bw.encodeBytes(v) case bool: - return bw.encodeBool(in) + return bw.encodeBool(v) default: - return bw.handleReflectTypes(in) + return bw.handleReflectTypes(v) } } @@ -56,18 +61,14 @@ func (bw *byteWriter) handleReflectTypes(in interface{}) error { reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return bw.encodeCustomPrimitive(in) case reflect.Ptr: - elem := reflect.ValueOf(in).Elem() - switch elem.IsValid() { - case false: - _, err := bw.Write([]byte{0}) + err := bw.writePointerMarker(val.IsNil()) + if err != nil { return err - default: - _, err := bw.Write([]byte{1}) - if err != nil { - return err - } - return bw.marshal(elem.Interface()) } + if val.IsNil() { + return nil + } + return bw.marshal(val.Elem().Interface()) case reflect.Struct: return bw.encodeStruct(in) case reflect.Array: @@ -287,23 +288,38 @@ func (bw *byteWriter) encodeBytes(b []byte) error { return err } -func (bw *byteWriter) encodeFixedWidthUint(i interface{}) error { - var data []byte - - switch v := i.(type) { - case uint8: - data = serializeTrivialNatural(v, 1) - case uint16: - data = serializeTrivialNatural(v, 2) - case uint32: - data = serializeTrivialNatural(v, 4) - case uint64: - data = serializeTrivialNatural(v, 8) +func (bw *byteWriter) encodeFixedWidth(i interface{}, l uint) error { + val := reflect.ValueOf(i) + + // Handle pointers + if val.Kind() == reflect.Ptr { + err := bw.writePointerMarker(val.IsNil()) + if err != nil { + return err + } + if val.IsNil() { + return nil + } + val = val.Elem() // Dereference non-nil pointer + } + + typ := val.Type() + switch typ.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + data := serializeTrivialNatural(val.Uint(), l) + _, err := bw.Write(data) + return err default: return fmt.Errorf(ErrUnsupportedType, i) } +} - _, err := bw.Write(data) +func (bw *byteWriter) writePointerMarker(isNil bool) error { + marker := byte(0x00) + if !isNil { + marker = byte(0x01) + } + _, err := bw.Write([]byte{marker}) return err } @@ -324,6 +340,20 @@ func (bw *byteWriter) encodeStruct(in interface{}) error { if tag == "-" { continue } + + tagValues := parseTag(tag) + if length, found := tagValues["length"]; found { + size, err := strconv.ParseUint(length, 10, 64) + if err != nil { + return fmt.Errorf(ErrInvalidLengthValue, fieldType.Name, err) + } + + err = bw.encodeFixedWidth(field.Interface(), uint(size)) + if err != nil { + return fmt.Errorf(ErrEncodingStructField, fieldType.Name, err) + } + continue + } } // Marshal and encode the field value diff --git a/pkg/serialization/codec/jam/encode_decode_jam_test.go b/pkg/serialization/codec/jam/encode_decode_jam_test.go index d547838f..4fcfa57c 100644 --- a/pkg/serialization/codec/jam/encode_decode_jam_test.go +++ b/pkg/serialization/codec/jam/encode_decode_jam_test.go @@ -86,3 +86,63 @@ func TestMarshalUnmarshalWithPointer(t *testing.T) { assert.Equal(t, original, unmarshaled) } + +func TestLengthTag(t *testing.T) { + // simple struct without tags + type NoTag struct { + Uint32 uint32 + } + noTag := NoTag{10} + marshaledData, err := jam.Marshal(noTag) + require.NoError(t, err) + require.Len(t, marshaledData, 4) + require.Equal(t, []byte{10, 0, 0, 0}, marshaledData) + + var noTagUnmarshaled NoTag + err = jam.Unmarshal(marshaledData, &noTagUnmarshaled) + require.NoError(t, err) + assert.Equal(t, noTag, noTagUnmarshaled) + + // simple struct with tag + type WithTag struct { + Uint32 uint32 `jam:"length=32"` + } + withTag := WithTag{50} + marshaledData, err = jam.Marshal(withTag) + require.NoError(t, err) + require.Len(t, marshaledData, 32) + expectedBytes := append([]byte{50}, make([]byte, 31)...) + assert.Equal(t, expectedBytes, marshaledData) + + var withTagUnmarshaled WithTag + err = jam.Unmarshal(marshaledData, &withTagUnmarshaled) + require.NoError(t, err) + assert.Equal(t, withTag, withTagUnmarshaled) + + // more complex struct to check alias and pointers + type Alias uint16 + type CustomStruct struct { + Alias Alias `jam:"length=6"` + Uint32 uint32 `jam:"length=32"` + NilPointer *uint8 `jam:"length=4"` + Pointer *uint64 `jam:"length=10"` + Bool bool + } + + p := uint64(40) + original := CustomStruct{ + Alias: 5, + Uint32: 50, + Pointer: &p, + Bool: true, + } + + marshaledData, err = jam.Marshal(original) + require.NoError(t, err) + + var unmarshaled CustomStruct + err = jam.Unmarshal(marshaledData, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, original, unmarshaled) +} diff --git a/pkg/serialization/codec/jam/errors.go b/pkg/serialization/codec/jam/errors.go index a57dbc7f..79efaff8 100644 --- a/pkg/serialization/codec/jam/errors.go +++ b/pkg/serialization/codec/jam/errors.go @@ -23,4 +23,5 @@ var ( ErrDecodingMapValue = "error decoding map value: %v" ErrEncodingStructField = "encoding struct field '%s': %w" ErrDecodingStructField = "decoding struct field '%s': %w" + ErrInvalidLengthValue = "invalid length value in jam tag for field %s: %v" ) diff --git a/pkg/serialization/codec/jam/trivial_natural.go b/pkg/serialization/codec/jam/trivial_natural.go index 1b3b6fe2..d805dc18 100644 --- a/pkg/serialization/codec/jam/trivial_natural.go +++ b/pkg/serialization/codec/jam/trivial_natural.go @@ -4,9 +4,9 @@ import ( "math" ) -func serializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint8) []byte { +func serializeTrivialNatural[T ~uint8 | ~uint16 | ~uint32 | ~uint64](x T, l uint) []byte { bytes := make([]byte, l) - for i := uint8(0); i < l; i++ { + for i := uint(0); i < l; i++ { bytes[i] = byte((x >> (8 * i)) & T(math.MaxUint8)) } return bytes diff --git a/pkg/serialization/codec/jam/trivial_natural_test.go b/pkg/serialization/codec/jam/trivial_natural_test.go index 9f8e78d6..c1af2e25 100644 --- a/pkg/serialization/codec/jam/trivial_natural_test.go +++ b/pkg/serialization/codec/jam/trivial_natural_test.go @@ -12,7 +12,7 @@ import ( func TestSerializationTrivialNatural(t *testing.T) { testCases := []struct { x any - l uint8 + l uint expected []byte }{ {uint8(0), 0, []byte{}},