From adb9b244cee26fe77b3ad4377b8eada783245fe5 Mon Sep 17 00:00:00 2001 From: "daniel.vladco" Date: Tue, 4 Feb 2025 14:32:35 +0200 Subject: [PATCH] feat: add bits encoder as well --- pkg/serialization/codec/jam/decode.go | 59 +++++++++++++------ pkg/serialization/codec/jam/decode_test.go | 36 ++--------- pkg/serialization/codec/jam/encode.go | 25 ++++++++ .../codec/jam/encode_decode_jam_test.go | 10 ++++ 4 files changed, 80 insertions(+), 50 deletions(-) diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go index 0ec9550..a7c763e 100644 --- a/pkg/serialization/codec/jam/decode.go +++ b/pkg/serialization/codec/jam/decode.go @@ -54,14 +54,21 @@ func (d *Decoder) DecodeFixedLength(dst any, length uint) error { dstv = indirect(dstv) in := dstv.Interface() - switch in.(type) { + switch v := in.(type) { case int8, uint8, int16, uint16, int32, uint32, int64, uint64: return d.decodeFixedWidth(dstv, length) case []byte: return d.decodeBytesFixedLength(dstv, length) + case BitSequence: + if err := d.decodeBitsFixedLength(&v, length); err != nil { + return err + } + inType := reflect.TypeOf(in) + dstv.Set(reflect.ValueOf(v).Convert(inType)) default: return fmt.Errorf(ErrUnsupportedType, dst) } + return nil } type byteReader struct { @@ -77,7 +84,7 @@ func (br *byteReader) unmarshal(value reflect.Value) error { } in := value.Interface() - switch v := in.(type) { + switch in.(type) { case int, uint: return br.decodeUint(value) @@ -90,12 +97,7 @@ func (br *byteReader) unmarshal(value reflect.Value) error { case []byte: return br.decodeBytes(value) case BitSequence: - if err := br.decodeBits(&v); err != nil { - return err - } - inType := reflect.TypeOf(in) - value.Set(reflect.ValueOf(v).Convert(inType)) - return nil + return br.decodeBits(value) case bool: return br.decodeBool(value) default: @@ -118,6 +120,12 @@ func (br *byteReader) handleReflectTypes(value reflect.Value) error { if value.Type() == reflect.TypeOf(ed25519.PublicKey{}) { return br.decodeEd25519PublicKey(value) } + if value.Type() == reflect.TypeOf(BitSequence{}) { + return br.decodeBits(value) + } + if value.Type() == reflect.TypeOf([]byte{}) { + return br.decodeBytes(value) + } return br.decodeSlice(value) case reflect.Map: return br.decodeMap(value) @@ -445,20 +453,35 @@ func (br *byteReader) decodeBytesFixedLength(dstv reflect.Value, length uint) er dstv.Set(reflect.ValueOf(b).Convert(inType)) return nil } -func (br *byteReader) decodeBits(v *BitSequence) (err error) { - length := len(*v) - if length > math.MaxUint32 { + +// decodeBytes is used to decode with a destination of []byte +func (br *byteReader) decodeBits(dstv reflect.Value) error { + length, err := br.decodeLength() + if err != nil { + return err + } + var v BitSequence + if err := br.decodeBitsFixedLength(&v, length); err != nil { + return err + } + in := dstv.Interface() + inType := reflect.TypeOf(in) + dstv.Set(reflect.ValueOf(v).Convert(inType)) + return nil +} + +func (br *byteReader) decodeBitsFixedLength(v *BitSequence, bytesLength uint) (err error) { + if bytesLength > math.MaxUint32 { return ErrExceedingByteArrayLimit } - var b byte + bb := make([]byte, bytesLength) + if _, err = br.Reader.Read(bb); err != nil { + return err + } + *v = make(BitSequence, bytesLength*8) for i := range *v { mod := i % 8 - if mod == 0 { - b, err = br.ReadOctet() // take as many bytes as needed to fill the bit sequence - if err != nil { - return err - } - } + b := bb[i/8] pow2 := byte(1 << mod) // powers of 2 (*v)[i] = b&pow2 == pow2 // identify the bit } diff --git a/pkg/serialization/codec/jam/decode_test.go b/pkg/serialization/codec/jam/decode_test.go index a3a15b5..4980b52 100644 --- a/pkg/serialization/codec/jam/decode_test.go +++ b/pkg/serialization/codec/jam/decode_test.go @@ -2,7 +2,6 @@ package jam import ( "bytes" - "io" "testing" "github.com/stretchr/testify/assert" @@ -13,37 +12,26 @@ func TestDecodeBits(t *testing.T) { name string input []byte expect BitSequence - len int leftover int err error }{{ name: "empty", input: []byte{}, - len: 0, expect: BitSequence{}, }, { name: "1 bytes", input: []byte{255}, - len: 8, expect: BitSequence{true, true, true, true, true, true, true, true}, }, { name: "1.5 bytes", input: []byte{0, 255}, - len: 12, expect: BitSequence{ false, false, false, false, false, false, false, false, - true, true, true, true, + true, true, true, true, true, true, true, true, }, - }, { - name: "1 bytes 1 unused", - input: []byte{255, 255}, - len: 8, - expect: BitSequence{true, true, true, true, true, true, true, true}, - leftover: 1, }, { name: "5 bytes", input: []byte{17, 25, 0, 1, 2}, - len: 5 * 8, expect: BitSequence{ true, false, false, false, true, false, false, false, true, false, false, true, true, false, false, false, @@ -51,34 +39,18 @@ func TestDecodeBits(t *testing.T) { true, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false, }, - }, { - name: "empty byte array", - input: []byte{}, - expect: BitSequence{ - false, false, false, false, false, false, false, false, - }, - len: 8, - err: io.EOF, - }, { - name: "not enough bytes", - input: []byte{255}, - expect: BitSequence{ - true, true, true, true, true, true, true, true, - false, - }, - len: 9, - err: io.EOF, }} for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { buff := bytes.NewBuffer(tc.input) d := NewDecoder(buff) - actual := make(BitSequence, tc.len) - err := d.Decode(&actual) + actual := BitSequence{} + err := d.DecodeFixedLength(&actual, uint(len(tc.input))) if tc.err != nil { assert.Equal(t, tc.err, err) } + assert.Equal(t, tc.expect, actual) assert.Equal(t, tc.leftover, buff.Len()) }) diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go index dda603f..8a22fa3 100644 --- a/pkg/serialization/codec/jam/encode.go +++ b/pkg/serialization/codec/jam/encode.go @@ -47,11 +47,14 @@ func (bw *byteWriter) marshal(in interface{}) error { return bw.encodeFixedWidth(v, l) case []byte: return bw.encodeBytes(v) + case BitSequence: + return bw.encodeBits(v) case bool: return bw.encodeBool(v) default: return bw.handleReflectTypes(v) } + return nil } func (bw *byteWriter) handleReflectTypes(in interface{}) error { @@ -288,6 +291,28 @@ func (bw *byteWriter) encodeBytes(b []byte) error { return err } +func (bw *byteWriter) encodeBits(bitSequence BitSequence) error { + length := len(bitSequence) / 8 + if length%8 == 0 { + length += 1 + } + err := bw.encodeLength(length) + if err != nil { + return err + } + + bb := make([]byte, length) + for i, b := range bitSequence { + if b { + pow2 := byte(1 << (i % 8)) // powers of 2 + bb[i/8] |= pow2 // identify the bit + } + } + + _, err = bw.Write(bb) + return err +} + func (bw *byteWriter) encodeFixedWidth(i interface{}, l uint) error { val := reflect.ValueOf(i) diff --git a/pkg/serialization/codec/jam/encode_decode_jam_test.go b/pkg/serialization/codec/jam/encode_decode_jam_test.go index 4fcfa57..012974f 100644 --- a/pkg/serialization/codec/jam/encode_decode_jam_test.go +++ b/pkg/serialization/codec/jam/encode_decode_jam_test.go @@ -23,6 +23,7 @@ type TestStruct struct { LargeUint uint PubKey *ed25519.PublicKey InnerSlice []InnerStruct + Bits jam.BitSequence } func TestMarshalUnmarshal(t *testing.T) { @@ -40,6 +41,15 @@ func TestMarshalUnmarshal(t *testing.T) { {2, 3, 4, 5}, {3, 4, 5, 6}, }, + Bits: jam.BitSequence{ + true, true, true, true, true, true, true, true, + true, true, true, true, true, true, true, false, + true, true, true, true, true, true, false, false, + true, true, true, true, false, false, false, false, + true, true, false, false, false, false, false, false, + true, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, + }, } marshaledData, err := jam.Marshal(original)