From f1a27187021ca4992db666d21e3671b70fbfb4b9 Mon Sep 17 00:00:00 2001 From: fanweixiao Date: Mon, 30 Aug 2021 12:01:58 +0800 Subject: [PATCH] fix(stream_decoder): io.EOF is not error when parsing (#7) --- parser.go | 20 +++++++++++++-- parser_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 parser_test.go diff --git a/parser.go b/parser.go index 8eae02a..6d2afdf 100644 --- a/parser.go +++ b/parser.go @@ -2,16 +2,24 @@ package y3 import ( "bytes" + "errors" "fmt" "io" "github.com/yomorun/y3/encoding" ) +var ( + ErrMalformed = errors.New("y3.ReadPacket: malformed") +) + // ReadPacket will try to read a Y3 encoded packet from the reader func ReadPacket(reader io.Reader) ([]byte, error) { tag, err := readByte(reader) if err != nil { + if err == io.EOF { + return nil, ErrMalformed + } return nil, err } // buf will contain a complete y3 encoded handshakeFrame @@ -26,6 +34,9 @@ func ReadPacket(reader io.Reader) ([]byte, error) { for { b, err := readByte(reader) if err != nil { + if err == io.EOF { + return nil, ErrMalformed + } return nil, err } lenbuf.WriteByte(b) @@ -39,7 +50,7 @@ func ReadPacket(reader io.Reader) ([]byte, error) { codec := encoding.VarCodec{} err = codec.DecodePVarInt32(lenbuf.Bytes(), &length) if err != nil { - return nil, err + return nil, ErrMalformed } // validate len decoded from stream @@ -66,6 +77,10 @@ func ReadPacket(reader io.Reader) ([]byte, error) { p, err := reader.Read(tmpbuf) count += p if err != nil { + if err == io.EOF { + valbuf.Write(tmpbuf[:p]) + break + } return nil, fmt.Errorf("y3 parse valbuf error: %v", err) } valbuf.Write(tmpbuf[:p]) @@ -75,7 +90,8 @@ func ReadPacket(reader io.Reader) ([]byte, error) { } if count < int(length) { - return nil, fmt.Errorf("[y3] p should == len when getting y3 value buffer, len=%d, p=%d", length, count) + // return nil, fmt.Errorf("[y3] p should == len when getting y3 value buffer, len=%d, p=%d", length, count) + return nil, ErrMalformed } // write y3.Value bytes buf.Write(valbuf.Bytes()) diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..58fc8dc --- /dev/null +++ b/parser_test.go @@ -0,0 +1,68 @@ +package y3 + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStreamParser1(t *testing.T) { + data := []byte{0x01, 0x03, 0x01, 0x02, 0x03} + reader := &pr{buf: data} + + p, err := ReadPacket(reader) + assert.NoError(t, err) + assert.Equal(t, data, p) +} + +func TestStreamParser2(t *testing.T) { + data := []byte{0x01, 0x03, 0x01, 0x02, 0x03, 0x04} + reader := &pr{buf: data} + + p, err := ReadPacket(reader) + assert.NoError(t, err) + assert.Equal(t, data[:5], p) +} + +func TestStreamParser3(t *testing.T) { + data := []byte{0x01, 0x03, 0x01, 0x02} + reader := &pr{buf: data} + + p, err := ReadPacket(reader) + assert.ErrorIs(t, err, ErrMalformed) + assert.Equal(t, []byte(nil), p) +} + +func TestStreamParser4(t *testing.T) { + data := []byte{} + reader := &pr{buf: data} + + p, err := ReadPacket(reader) + assert.ErrorIs(t, err, ErrMalformed) + assert.Equal(t, []byte(nil), p) +} + +func TestStreamParser5(t *testing.T) { + data := []byte{0x01} + reader := &pr{buf: data} + + p, err := ReadPacket(reader) + assert.ErrorIs(t, err, ErrMalformed) + assert.Equal(t, []byte(nil), p) +} + +type pr struct { + buf []byte + off int +} + +func (pr *pr) Read(buf []byte) (int, error) { + if pr.off >= len(pr.buf) { + return 0, io.EOF + } + + copy(buf, []byte{pr.buf[pr.off]}) + pr.off++ + return 1, nil +}