Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix seq wrap around #57

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION := v0.2.3
VERSION := v0.2.4

BIN_NAME = grpcr
CONTAINER = grpcr
Expand Down
9 changes: 5 additions & 4 deletions http2/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
"math"

"github.com/vearne/grpcreplay/protocol"
slog "github.com/vearne/simplelog"
Expand Down Expand Up @@ -49,23 +50,23 @@ func (p *Processor) ProcessTCPPkg() {
}
hc := p.ConnRepository[dc]

payloadSize := uint32(len(payload))

// SYN/ACK/FIN
if len(payload) <= 0 {
if pkg.TCP.FIN {
slog.Info("got Fin package, close connection:%v", dc.String())
hc.TCPBuffer.Close()
delete(p.ConnRepository, dc)
} else {
hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq
hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32
}
continue
}

// connection preface
if IsConnPreface(payload) {
hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq
hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32
continue
}

Expand Down
88 changes: 26 additions & 62 deletions http2/tcp_buffer.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package http2

import (
"bytes"
"github.com/google/gopacket/layers"
"github.com/huandu/skiplist"
slog "github.com/vearne/simplelog"
"math"
"net"
"sync/atomic"
)

type TCPBuffer struct {
//The number of bytes of data currently cached
size uint32
actualCanReadSize uint32
size atomic.Int64
actualCanReadSize atomic.Int64
List *skiplist.SkipList
expectedSeq int64
// The sliding window contains the leftPointer
leftPointer int64

expectedSeq uint32
//There is at most one reader to read
dataChannel chan []byte
closeChan chan struct{}
Expand All @@ -26,11 +23,10 @@ type TCPBuffer struct {
func NewTCPBuffer() *TCPBuffer {
var sb TCPBuffer
sb.List = skiplist.New(skiplist.Uint32)
sb.size = 0
sb.actualCanReadSize = 0
sb.expectedSeq = -1
sb.leftPointer = -1
sb.dataChannel = make(chan []byte, 10)
sb.size.Store(0)
sb.actualCanReadSize.Store(0)
sb.expectedSeq = 0
sb.dataChannel = make(chan []byte, 100)
sb.closeChan = make(chan struct{})
return &sb
}
Expand All @@ -47,72 +43,41 @@ func (sb *TCPBuffer) Read(p []byte) (n int, err error) {
err = net.ErrClosed
case data = <-sb.dataChannel:
n = copy(p, data)
dataSize := int64(len(data))
sb.size.Add(dataSize * -1)
sb.actualCanReadSize.Add(dataSize * -1)
}
slog.Debug("SocketBuffer.Read, got:%v bytes", n)
return n, err
}

func (sb *TCPBuffer) AddTCP(tcpPkg *layers.TCP) {
sb.addTCP(tcpPkg)

if sb.actualCanReadSize > 0 {
slog.Debug("SocketBuffer.AddTCP, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
data := sb.getData()
slog.Debug("push to channel: %v bytes", len(data))
sb.dataChannel <- data
}
}

func (sb *TCPBuffer) addTCP(tcpPkg *layers.TCP) {
slog.Debug("[start]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)

// duplicate package
if int64(tcpPkg.Seq) < sb.leftPointer || sb.List.Get(tcpPkg.Seq) != nil {
if sb.List.Get(tcpPkg.Seq) != nil {
slog.Debug("[end]SocketBuffer.addTCP-duplicate package, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)
return
}

ele := sb.List.Set(tcpPkg.Seq, tcpPkg)
sb.size += uint32(len(tcpPkg.Payload))
sb.size.Add(int64(len(tcpPkg.Payload)))
needRemoveList := make([]*skiplist.Element, 0)

for ele != nil && sb.expectedSeq == int64(tcpPkg.Seq) {
for ele != nil && sb.expectedSeq == tcpPkg.Seq {
// expect next sequence number
sb.expectedSeq = int64((tcpPkg.Seq + uint32(len(tcpPkg.Payload))) % math.MaxUint32)
sb.actualCanReadSize += uint32(len(tcpPkg.Payload))
// sequence numbers may wrap around
payloadSize := uint32(len(tcpPkg.Payload))
sb.actualCanReadSize.Add(int64(payloadSize))
sb.expectedSeq = (tcpPkg.Seq + payloadSize) % math.MaxUint32

ele = ele.Next()
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}
}
slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
}

func (sb *TCPBuffer) getData() []byte {
slog.Debug("[start]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)

var tcpPkg *layers.TCP
buf := bytes.NewBuffer([]byte{})
ele := sb.List.Front()
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}

needRemoveList := make([]*skiplist.Element, 0)
for ele != nil && int64(tcpPkg.Seq) <= sb.expectedSeq {
sb.actualCanReadSize -= uint32(len(tcpPkg.Payload))
sb.size -= uint32(len(tcpPkg.Payload))
sb.leftPointer += int64(len(tcpPkg.Payload))

buf.Write(tcpPkg.Payload)
// push to channel
sb.dataChannel <- tcpPkg.Payload
needRemoveList = append(needRemoveList, ele)

ele = ele.Next()
ele = sb.List.Get(sb.expectedSeq)
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}
Expand All @@ -123,7 +88,6 @@ func (sb *TCPBuffer) getData() []byte {
sb.List.RemoveElement(element)
}

slog.Debug("[end]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v, data: %v bytes",
sb.size, sb.actualCanReadSize, sb.expectedSeq, buf.Len())
return buf.Bytes()
slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)
}
75 changes: 66 additions & 9 deletions http2/tcp_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ func TestSocketBufferSequence1(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -44,7 +43,6 @@ func TestSocketBufferSequence2(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -80,7 +78,6 @@ func TestSocketBufferSequence3(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -112,22 +109,21 @@ func TestSocketBufferSequence3(t *testing.T) {
assert.Nil(t, err)
}

func TestSocketBufferDuplicate(t *testing.T) {
func TestSocketBufferWrapAround1(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 1010
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 1020
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgA)
Expand All @@ -143,3 +139,64 @@ func TestSocketBufferDuplicate(t *testing.T) {
// assert for nil (good for errors)
assert.Nil(t, err)
}

func TestSocketBufferWrapAround2(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgB)
buffer.AddTCP(&tcpPkgA)
buffer.AddTCP(&tcpPkgC)
buffer.AddTCP(&tcpPkgA)

buf := make([]byte, 1024)
n, err := io.ReadAtLeast(buffer, buf, 30)
// assert equality
assert.Equal(t, 30, n, "read data")
assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data")
// assert for nil (good for errors)
assert.Nil(t, err)
}

func TestSocketBufferWrapAround3(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgA)
buffer.AddTCP(&tcpPkgB)
buffer.AddTCP(&tcpPkgC)

buf := make([]byte, 1024)
n, err := io.ReadAtLeast(buffer, buf, 30)
// assert equality
assert.Equal(t, 30, n, "read data")
assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data")
// assert for nil (good for errors)
assert.Nil(t, err)
}
Loading