Skip to content

Commit

Permalink
Merge pull request #49 from vearne/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
vearne authored Sep 16, 2024
2 parents b59e97c + 5f105c0 commit 5bce1e4
Show file tree
Hide file tree
Showing 12 changed files with 470 additions and 261 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION := v0.1.9
VERSION := v0.2.1

BIN_NAME = grpcr
CONTAINER = grpcr
Expand Down
35 changes: 15 additions & 20 deletions example/search_client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"log"
"math/rand"
"strconv"
"time"
)
Expand All @@ -32,25 +33,19 @@ func main() {
timeCounter := 0
muchCounter := 0
for i := 0; i < 1000000; i++ {
//value := rand.Intn(100)
//if value <= 5 {
// counter++
// sendSearch(client, counter)
//} else if value <= 20 {
//muchCounter++
// sendMuch(client, uint64(muchCounter))
//} else {
// timeCounter++
// sendCurrTime(client, uint64(timeCounter))
//}
counter++
sendSearch(client, counter)
timeCounter++
sendCurrTime(client, uint64(timeCounter))
muchCounter++
sendMuch(client, muchCounter)
value := rand.Intn(100)
if value <= 5 {
counter++
sendSearch(client, counter)
} else if value <= 20 {
muchCounter++
sendMuch(client, muchCounter)
} else {
timeCounter++
sendCurrTime(client, timeCounter)
}
//time.Sleep(100 * time.Millisecond)
time.Sleep(3 * time.Second)
time.Sleep(10 * time.Second)
}
}

Expand Down Expand Up @@ -89,15 +84,15 @@ func sendSearch(client pb.SearchServiceClient, i int) {
log.Println("resp:", string(bt))
}

func sendCurrTime(client pb.SearchServiceClient, id uint64) {
func sendCurrTime(client pb.SearchServiceClient, id int) {
md := metadata.New(map[string]string{
"testkey3": "testvalue3",
"testkey4": "testvalue4",
})
ctx := metadata.NewOutgoingContext(context.Background(), md)
resp, err := client.CurrentTime(
ctx,
&pb.TimeRequest{RequestId: id},
&pb.TimeRequest{RequestId: uint64(id)},
)
if err != nil {
log.Fatalf("client.CurrentTime err: %v", err)
Expand Down
24 changes: 6 additions & 18 deletions http2/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"github.com/vearne/grpcreplay/protocol"
slog "github.com/vearne/simplelog"
"golang.org/x/net/http2/hpack"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"io"
"strings"
"time"
Expand Down Expand Up @@ -121,7 +119,7 @@ type Http2Conn struct {
MaxHeaderStringLen uint32
HeaderDecoder *hpack.Decoder
Streams [StreamArraySize]*Stream
SocketBuffer *SocketBuffer
TCPBuffer *TCPBuffer
Reader *bufio.Reader
Processor *Processor
}
Expand All @@ -136,8 +134,8 @@ func NewHttp2Conn(conn DirectConn, maxDynamicTableSize uint32, p *Processor) *Ht
for i := 0; i < StreamArraySize; i++ {
hc.Streams[i] = NewStream()
}
hc.SocketBuffer = NewSocketBuffer()
hc.Reader = bufio.NewReaderSize(hc.SocketBuffer, ReadBufferSize)
hc.TCPBuffer = NewTCPBuffer()
hc.Reader = bufio.NewReaderSize(hc.TCPBuffer, ReadBufferSize)
hc.Processor = p

go hc.deal()
Expand Down Expand Up @@ -379,6 +377,7 @@ func NewStream() *Stream {

func (s *Stream) toMsg(finder *PBMessageFinder) (*protocol.Message, error) {
var msg protocol.Message
var err error
id := uuid.Must(uuid.NewUUID())
msg.Meta.Version = 1
msg.Meta.UUID = id.String()
Expand All @@ -395,20 +394,9 @@ func (s *Stream) toMsg(finder *PBMessageFinder) (*protocol.Message, error) {
return nil, errors.New("method is empty")
} else if !strings.Contains(s.Method, "grpc.reflection") {
// Note: Temporarily only handle the case where the encoding method is Protobuf
pbMsg, err := finder.FindMethodInputWithCache(s.Method)
if err != nil {
slog.Error("finder.FindMethodInputWithCache, error:%v", err)
return nil, err
}
err = proto.Unmarshal(s.DataBuf.Bytes(), pbMsg)
if err != nil {
slog.Error("method:%v, proto.Unmarshal:%v", s.Method, err)
return nil, err
}

s.Request, err = protojson.Marshal(pbMsg)
s.Request, err = finder.HandleRequestToJson(s.Method, s.DataBuf.Bytes())
if err != nil {
slog.Error("method:%v, json.Marshal:%v", s.Method, err)
slog.Error("method:%v, HandleRequestToJson:%v", s.Method, err)
return nil, err
}
}
Expand Down
6 changes: 5 additions & 1 deletion http2/net_pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type NetPkg struct {
SrcIP string
DstIP string

Ethernet *layers.Ethernet
IPv4 *layers.IPv4
IPv6 *layers.IPv6
TCP *layers.TCP
Expand All @@ -33,12 +34,15 @@ type NetPkg struct {

func ProcessPacket(packet gopacket.Packet, ipSet *util.StringSet, port int) (*NetPkg, error) {
var p NetPkg

ethernet := packet.Layer(layers.LayerTypeEthernet)
ipLayerIPv4 := packet.Layer(layers.LayerTypeIPv4)
ipLayerIPv6 := packet.Layer(layers.LayerTypeIPv6)
if ipLayerIPv4 == nil && ipLayerIPv6 == nil {
if ethernet == nil || (ipLayerIPv4 == nil && ipLayerIPv6 == nil) {
return nil, errors.New("invalid IP package")
}

p.Ethernet = ethernet.(*layers.Ethernet)
if ipLayerIPv4 != nil {
p.IPv4 = ipLayerIPv4.(*layers.IPv4)
p.SrcIP = p.IPv4.SrcIP.String()
Expand Down
36 changes: 28 additions & 8 deletions http2/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/fullstorydev/grpcurl"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/grpcreflect"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
Expand Down Expand Up @@ -55,13 +56,13 @@ func (p *Processor) ProcessTCPPkg() {

// connection preface
if IsConnPreface(payload) {
hc.SocketBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.SocketBuffer.leftPointer = hc.SocketBuffer.expectedSeq
hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq
continue
}

slog.Debug("[AddTCP]Connection:%v, seq:%v, length:%v", dc.String(), pkg.TCP.Seq, len(payload))
hc.SocketBuffer.AddTCP(pkg.TCP)
hc.TCPBuffer.AddTCP(pkg.TCP)
}
}

Expand All @@ -87,7 +88,7 @@ func getCodecType(headers map[string]string) int {
}

type PBMessageFinder struct {
cacheMu sync.RWMutex
cacheMu sync.Mutex
symbolMsg map[string]proto.Message
// server address
addr string
Expand All @@ -101,12 +102,33 @@ func NewPBMessageFinder(addr string) *PBMessageFinder {
return &f
}

func (f *PBMessageFinder) HandleRequestToJson(svcAndMethod string, data []byte) ([]byte, error) {
f.cacheMu.Lock()
defer f.cacheMu.Unlock()

pbMsg, err := f.FindMethodInputWithCache(svcAndMethod)
if err != nil {
slog.Error("finder.FindMethodInputWithCache, error:%v", err)
return nil, err
}
err = proto.Unmarshal(data, pbMsg)
if err != nil {
slog.Error("method:%v, proto.Unmarshal:%v", svcAndMethod, err)
return nil, err
}

result, err := protojson.Marshal(pbMsg)
if err != nil {
slog.Error("method:%v, json.Marshal:%v", svcAndMethod, err)
return nil, err
}
return result, nil
}

func (f *PBMessageFinder) FindMethodInputWithCache(svcAndMethod string) (proto.Message, error) {
slog.Debug("FindMethodInputWithCache, svcAndMethod:%v", svcAndMethod)

f.cacheMu.RLock()
m, ok := f.symbolMsg[svcAndMethod]
f.cacheMu.RUnlock()
if ok {
slog.Debug("FindMethodInputWithCache,svcAndMethod:%v, hit cache", svcAndMethod)
return m, nil
Expand All @@ -117,9 +139,7 @@ func (f *PBMessageFinder) FindMethodInputWithCache(svcAndMethod string) (proto.M
return nil, err
}

f.cacheMu.Lock()
f.symbolMsg[svcAndMethod] = msg
f.cacheMu.Unlock()
return msg, nil
}

Expand Down
57 changes: 20 additions & 37 deletions http2/socket_buffer.go → http2/tcp_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ import (
"github.com/huandu/skiplist"
slog "github.com/vearne/simplelog"
"math"
"sync"
"sync/atomic"
)

type SocketBuffer struct {
lock sync.Mutex
type TCPBuffer struct {
//The number of bytes of data currently cached
size uint32
actualCanReadSize atomic.Uint32
actualCanReadSize uint32
List *skiplist.SkipList
expectedSeq int64
// The sliding window contains the leftPointer
Expand All @@ -24,56 +21,45 @@ type SocketBuffer struct {
dataChannel chan []byte
}

func NewSocketBuffer() *SocketBuffer {
var sb SocketBuffer
func NewTCPBuffer() *TCPBuffer {
var sb TCPBuffer
sb.List = skiplist.New(skiplist.Uint32)
sb.size = 0
sb.actualCanReadSize.Store(0)
sb.actualCanReadSize = 0
sb.expectedSeq = -1
sb.leftPointer = -1
sb.dataChannel = make(chan []byte, 10)
return &sb
}

// may block
func (sb *SocketBuffer) Read(p []byte) (n int, err error) {
slog.Debug("[start]SocketBuffer.Read")
var data []byte
if sb.actualCanReadSize.Load() > 0 {
slog.Debug("SocketBuffer.Read, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)
data = sb.getData()
} else {
data = <-sb.dataChannel
}
func (sb *TCPBuffer) Read(p []byte) (n int, err error) {
data := <-sb.dataChannel
n = copy(p, data)
slog.Debug("[end]SocketBuffer.Read, got:%v bytes", n)
slog.Debug("SocketBuffer.Read, got:%v bytes", n)
return n, nil
}

func (sb *SocketBuffer) AddTCP(tcpPkg *layers.TCP) {
func (sb *TCPBuffer) AddTCP(tcpPkg *layers.TCP) {
sb.addTCP(tcpPkg)

if sb.actualCanReadSize.Load() > 0 {
if sb.actualCanReadSize > 0 {
slog.Debug("SocketBuffer.AddTCP, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)
sb.size, sb.actualCanReadSize, sb.expectedSeq)
data := sb.getData()
slog.Debug("push to channel: %v bytes", len(data))
sb.dataChannel <- data
}
}

func (sb *SocketBuffer) addTCP(tcpPkg *layers.TCP) {
func (sb *TCPBuffer) addTCP(tcpPkg *layers.TCP) {
slog.Debug("[start]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)

sb.lock.Lock()
defer sb.lock.Unlock()
sb.size, sb.actualCanReadSize, sb.expectedSeq)

// duplicate package
if int64(tcpPkg.Seq) < sb.leftPointer || sb.List.Get(tcpPkg.Seq) != nil {
slog.Debug("[end]SocketBuffer.addTCP-duplicate package, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)
sb.size, sb.actualCanReadSize, sb.expectedSeq)
return
}

Expand All @@ -83,23 +69,20 @@ func (sb *SocketBuffer) addTCP(tcpPkg *layers.TCP) {
for ele != nil && sb.expectedSeq == int64(tcpPkg.Seq) {
// expect next sequence number
sb.expectedSeq = int64((tcpPkg.Seq + uint32(len(tcpPkg.Payload))) % math.MaxUint32)
sb.actualCanReadSize.Store(sb.actualCanReadSize.Load() + uint32(len(tcpPkg.Payload)))
sb.actualCanReadSize += uint32(len(tcpPkg.Payload))

ele = ele.Next()
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}
}
slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)
sb.size, sb.actualCanReadSize, sb.expectedSeq)
}

func (sb *SocketBuffer) getData() []byte {
func (sb *TCPBuffer) getData() []byte {
slog.Debug("[start]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq)

sb.lock.Lock()
defer sb.lock.Unlock()
sb.size, sb.actualCanReadSize, sb.expectedSeq)

var tcpPkg *layers.TCP
buf := bytes.NewBuffer([]byte{})
Expand All @@ -110,7 +93,7 @@ func (sb *SocketBuffer) getData() []byte {

needRemoveList := make([]*skiplist.Element, 0)
for ele != nil && int64(tcpPkg.Seq) <= sb.expectedSeq {
sb.actualCanReadSize.Store(sb.actualCanReadSize.Load() - uint32(len(tcpPkg.Payload)))
sb.actualCanReadSize -= uint32(len(tcpPkg.Payload))
sb.size -= uint32(len(tcpPkg.Payload))
sb.leftPointer += int64(len(tcpPkg.Payload))

Expand All @@ -129,6 +112,6 @@ func (sb *SocketBuffer) getData() []byte {
}

slog.Debug("[end]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v, data: %v bytes",
sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq, buf.Len())
sb.size, sb.actualCanReadSize, sb.expectedSeq, buf.Len())
return buf.Bytes()
}
Loading

0 comments on commit 5bce1e4

Please sign in to comment.