From 7708c955fbf4e3b9378423398da6d265b9373942 Mon Sep 17 00:00:00 2001 From: vearne Date: Sun, 15 Sep 2024 09:12:23 +0800 Subject: [PATCH 1/7] update example/search_client/client.go --- example/search_client/client.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/example/search_client/client.go b/example/search_client/client.go index 6da5c48..e73c3e7 100644 --- a/example/search_client/client.go +++ b/example/search_client/client.go @@ -28,8 +28,8 @@ func main() { defer conn.Close() client := pb.NewSearchServiceClient(conn) - counter := 0 - timeCounter := 0 + //counter := 0 + //timeCounter := 0 muchCounter := 0 for i := 0; i < 1000000; i++ { //value := rand.Intn(100) @@ -43,14 +43,14 @@ func main() { // timeCounter++ // sendCurrTime(client, uint64(timeCounter)) //} - counter++ - sendSearch(client, counter) - timeCounter++ - sendCurrTime(client, uint64(timeCounter)) - muchCounter++ + //counter++ + //sendSearch(client, counter) + //timeCounter++ + //sendCurrTime(client, uint64(timeCounter)) + //muchCounter++ sendMuch(client, muchCounter) //time.Sleep(100 * time.Millisecond) - time.Sleep(3 * time.Second) + time.Sleep(10 * time.Second) } } From 162d048d227f42068ed4bcf234df582381b028bd Mon Sep 17 00:00:00 2001 From: vearne Date: Sun, 15 Sep 2024 11:05:50 +0800 Subject: [PATCH 2/7] fix some bug --- example/search_client/client.go | 37 ++++++------ http2/http2.go | 24 ++------ http2/processor.go | 36 +++++++++--- http2/{socket_buffer.go => tcp_buffer.go} | 57 +++++++------------ ...cket_buffer_test.go => tcp_buffer_test.go} | 0 plugin/tcp_kill.go | 1 - 6 files changed, 70 insertions(+), 85 deletions(-) rename http2/{socket_buffer.go => tcp_buffer.go} (62%) rename http2/{socket_buffer_test.go => tcp_buffer_test.go} (100%) diff --git a/example/search_client/client.go b/example/search_client/client.go index e73c3e7..bea7614 100644 --- a/example/search_client/client.go +++ b/example/search_client/client.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "log" + "math/rand" "strconv" "time" ) @@ -28,27 +29,21 @@ func main() { defer conn.Close() client := pb.NewSearchServiceClient(conn) - //counter := 0 - //timeCounter := 0 + counter := 0 + timeCounter := 0 muchCounter := 0 for i := 0; i < 1000000; i++ { - //value := rand.Intn(100) - //if value <= 5 { - // counter++ - // sendSearch(client, counter) - //} else if value <= 20 { - //muchCounter++ - // sendMuch(client, uint64(muchCounter)) - //} else { - // timeCounter++ - // sendCurrTime(client, uint64(timeCounter)) - //} - //counter++ - //sendSearch(client, counter) - //timeCounter++ - //sendCurrTime(client, uint64(timeCounter)) - //muchCounter++ - sendMuch(client, muchCounter) + value := rand.Intn(100) + if value <= 5 { + counter++ + sendSearch(client, counter) + } else if value <= 20 { + muchCounter++ + sendMuch(client, muchCounter) + } else { + timeCounter++ + sendCurrTime(client, timeCounter) + } //time.Sleep(100 * time.Millisecond) time.Sleep(10 * time.Second) } @@ -89,7 +84,7 @@ func sendSearch(client pb.SearchServiceClient, i int) { log.Println("resp:", string(bt)) } -func sendCurrTime(client pb.SearchServiceClient, id uint64) { +func sendCurrTime(client pb.SearchServiceClient, id int) { md := metadata.New(map[string]string{ "testkey3": "testvalue3", "testkey4": "testvalue4", @@ -97,7 +92,7 @@ func sendCurrTime(client pb.SearchServiceClient, id uint64) { ctx := metadata.NewOutgoingContext(context.Background(), md) resp, err := client.CurrentTime( ctx, - &pb.TimeRequest{RequestId: id}, + &pb.TimeRequest{RequestId: uint64(id)}, ) if err != nil { log.Fatalf("client.CurrentTime err: %v", err) diff --git a/http2/http2.go b/http2/http2.go index 551a5b8..20af38a 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -10,8 +10,6 @@ import ( "github.com/vearne/grpcreplay/protocol" slog "github.com/vearne/simplelog" "golang.org/x/net/http2/hpack" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" "io" "strings" "time" @@ -121,7 +119,7 @@ type Http2Conn struct { MaxHeaderStringLen uint32 HeaderDecoder *hpack.Decoder Streams [StreamArraySize]*Stream - SocketBuffer *SocketBuffer + TCPBuffer *TCPBuffer Reader *bufio.Reader Processor *Processor } @@ -136,8 +134,8 @@ func NewHttp2Conn(conn DirectConn, maxDynamicTableSize uint32, p *Processor) *Ht for i := 0; i < StreamArraySize; i++ { hc.Streams[i] = NewStream() } - hc.SocketBuffer = NewSocketBuffer() - hc.Reader = bufio.NewReaderSize(hc.SocketBuffer, ReadBufferSize) + hc.TCPBuffer = NewTCPBuffer() + hc.Reader = bufio.NewReaderSize(hc.TCPBuffer, ReadBufferSize) hc.Processor = p go hc.deal() @@ -379,6 +377,7 @@ func NewStream() *Stream { func (s *Stream) toMsg(finder *PBMessageFinder) (*protocol.Message, error) { var msg protocol.Message + var err error id := uuid.Must(uuid.NewUUID()) msg.Meta.Version = 1 msg.Meta.UUID = id.String() @@ -395,20 +394,9 @@ func (s *Stream) toMsg(finder *PBMessageFinder) (*protocol.Message, error) { return nil, errors.New("method is empty") } else if !strings.Contains(s.Method, "grpc.reflection") { // Note: Temporarily only handle the case where the encoding method is Protobuf - pbMsg, err := finder.FindMethodInputWithCache(s.Method) - if err != nil { - slog.Error("finder.FindMethodInputWithCache, error:%v", err) - return nil, err - } - err = proto.Unmarshal(s.DataBuf.Bytes(), pbMsg) - if err != nil { - slog.Error("method:%v, proto.Unmarshal:%v", s.Method, err) - return nil, err - } - - s.Request, err = protojson.Marshal(pbMsg) + s.Request, err = finder.HandleRequestToJson(s.Method, s.DataBuf.Bytes()) if err != nil { - slog.Error("method:%v, json.Marshal:%v", s.Method, err) + slog.Error("method:%v, HandleRequestToJson:%v", s.Method, err) return nil, err } } diff --git a/http2/processor.go b/http2/processor.go index 8a6209c..433bfb0 100644 --- a/http2/processor.go +++ b/http2/processor.go @@ -6,6 +6,7 @@ import ( "github.com/fullstorydev/grpcurl" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/grpcreflect" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" @@ -55,13 +56,13 @@ func (p *Processor) ProcessTCPPkg() { // connection preface if IsConnPreface(payload) { - hc.SocketBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload)) - hc.SocketBuffer.leftPointer = hc.SocketBuffer.expectedSeq + hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload)) + hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq continue } slog.Debug("[AddTCP]Connection:%v, seq:%v, length:%v", dc.String(), pkg.TCP.Seq, len(payload)) - hc.SocketBuffer.AddTCP(pkg.TCP) + hc.TCPBuffer.AddTCP(pkg.TCP) } } @@ -87,7 +88,7 @@ func getCodecType(headers map[string]string) int { } type PBMessageFinder struct { - cacheMu sync.RWMutex + cacheMu sync.Mutex symbolMsg map[string]proto.Message // server address addr string @@ -101,12 +102,33 @@ func NewPBMessageFinder(addr string) *PBMessageFinder { return &f } +func (f *PBMessageFinder) HandleRequestToJson(svcAndMethod string, data []byte) ([]byte, error) { + f.cacheMu.Lock() + defer f.cacheMu.Unlock() + + pbMsg, err := f.FindMethodInputWithCache(svcAndMethod) + if err != nil { + slog.Error("finder.FindMethodInputWithCache, error:%v", err) + return nil, err + } + err = proto.Unmarshal(data, pbMsg) + if err != nil { + slog.Error("method:%v, proto.Unmarshal:%v", svcAndMethod, err) + return nil, err + } + + result, err := protojson.Marshal(pbMsg) + if err != nil { + slog.Error("method:%v, json.Marshal:%v", svcAndMethod, err) + return nil, err + } + return result, nil +} + func (f *PBMessageFinder) FindMethodInputWithCache(svcAndMethod string) (proto.Message, error) { slog.Debug("FindMethodInputWithCache, svcAndMethod:%v", svcAndMethod) - f.cacheMu.RLock() m, ok := f.symbolMsg[svcAndMethod] - f.cacheMu.RUnlock() if ok { slog.Debug("FindMethodInputWithCache,svcAndMethod:%v, hit cache", svcAndMethod) return m, nil @@ -117,9 +139,7 @@ func (f *PBMessageFinder) FindMethodInputWithCache(svcAndMethod string) (proto.M return nil, err } - f.cacheMu.Lock() f.symbolMsg[svcAndMethod] = msg - f.cacheMu.Unlock() return msg, nil } diff --git a/http2/socket_buffer.go b/http2/tcp_buffer.go similarity index 62% rename from http2/socket_buffer.go rename to http2/tcp_buffer.go index 8e130fa..49814b2 100644 --- a/http2/socket_buffer.go +++ b/http2/tcp_buffer.go @@ -6,15 +6,12 @@ import ( "github.com/huandu/skiplist" slog "github.com/vearne/simplelog" "math" - "sync" - "sync/atomic" ) -type SocketBuffer struct { - lock sync.Mutex +type TCPBuffer struct { //The number of bytes of data currently cached size uint32 - actualCanReadSize atomic.Uint32 + actualCanReadSize uint32 List *skiplist.SkipList expectedSeq int64 // The sliding window contains the leftPointer @@ -24,11 +21,11 @@ type SocketBuffer struct { dataChannel chan []byte } -func NewSocketBuffer() *SocketBuffer { - var sb SocketBuffer +func NewTCPBuffer() *TCPBuffer { + var sb TCPBuffer sb.List = skiplist.New(skiplist.Uint32) sb.size = 0 - sb.actualCanReadSize.Store(0) + sb.actualCanReadSize = 0 sb.expectedSeq = -1 sb.leftPointer = -1 sb.dataChannel = make(chan []byte, 10) @@ -36,44 +33,33 @@ func NewSocketBuffer() *SocketBuffer { } // may block -func (sb *SocketBuffer) Read(p []byte) (n int, err error) { - slog.Debug("[start]SocketBuffer.Read") - var data []byte - if sb.actualCanReadSize.Load() > 0 { - slog.Debug("SocketBuffer.Read, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) - data = sb.getData() - } else { - data = <-sb.dataChannel - } +func (sb *TCPBuffer) Read(p []byte) (n int, err error) { + data := <-sb.dataChannel n = copy(p, data) - slog.Debug("[end]SocketBuffer.Read, got:%v bytes", n) + slog.Debug("SocketBuffer.Read, got:%v bytes", n) return n, nil } -func (sb *SocketBuffer) AddTCP(tcpPkg *layers.TCP) { +func (sb *TCPBuffer) AddTCP(tcpPkg *layers.TCP) { sb.addTCP(tcpPkg) - if sb.actualCanReadSize.Load() > 0 { + if sb.actualCanReadSize > 0 { slog.Debug("SocketBuffer.AddTCP, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) + sb.size, sb.actualCanReadSize, sb.expectedSeq) data := sb.getData() slog.Debug("push to channel: %v bytes", len(data)) sb.dataChannel <- data } } -func (sb *SocketBuffer) addTCP(tcpPkg *layers.TCP) { +func (sb *TCPBuffer) addTCP(tcpPkg *layers.TCP) { slog.Debug("[start]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) - - sb.lock.Lock() - defer sb.lock.Unlock() + sb.size, sb.actualCanReadSize, sb.expectedSeq) // duplicate package if int64(tcpPkg.Seq) < sb.leftPointer || sb.List.Get(tcpPkg.Seq) != nil { slog.Debug("[end]SocketBuffer.addTCP-duplicate package, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) + sb.size, sb.actualCanReadSize, sb.expectedSeq) return } @@ -83,7 +69,7 @@ func (sb *SocketBuffer) addTCP(tcpPkg *layers.TCP) { for ele != nil && sb.expectedSeq == int64(tcpPkg.Seq) { // expect next sequence number sb.expectedSeq = int64((tcpPkg.Seq + uint32(len(tcpPkg.Payload))) % math.MaxUint32) - sb.actualCanReadSize.Store(sb.actualCanReadSize.Load() + uint32(len(tcpPkg.Payload))) + sb.actualCanReadSize += uint32(len(tcpPkg.Payload)) ele = ele.Next() if ele != nil { @@ -91,15 +77,12 @@ func (sb *SocketBuffer) addTCP(tcpPkg *layers.TCP) { } } slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) + sb.size, sb.actualCanReadSize, sb.expectedSeq) } -func (sb *SocketBuffer) getData() []byte { +func (sb *TCPBuffer) getData() []byte { slog.Debug("[start]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq) - - sb.lock.Lock() - defer sb.lock.Unlock() + sb.size, sb.actualCanReadSize, sb.expectedSeq) var tcpPkg *layers.TCP buf := bytes.NewBuffer([]byte{}) @@ -110,7 +93,7 @@ func (sb *SocketBuffer) getData() []byte { needRemoveList := make([]*skiplist.Element, 0) for ele != nil && int64(tcpPkg.Seq) <= sb.expectedSeq { - sb.actualCanReadSize.Store(sb.actualCanReadSize.Load() - uint32(len(tcpPkg.Payload))) + sb.actualCanReadSize -= uint32(len(tcpPkg.Payload)) sb.size -= uint32(len(tcpPkg.Payload)) sb.leftPointer += int64(len(tcpPkg.Payload)) @@ -129,6 +112,6 @@ func (sb *SocketBuffer) getData() []byte { } slog.Debug("[end]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v, data: %v bytes", - sb.size, sb.actualCanReadSize.Load(), sb.expectedSeq, buf.Len()) + sb.size, sb.actualCanReadSize, sb.expectedSeq, buf.Len()) return buf.Bytes() } diff --git a/http2/socket_buffer_test.go b/http2/tcp_buffer_test.go similarity index 100% rename from http2/socket_buffer_test.go rename to http2/tcp_buffer_test.go diff --git a/plugin/tcp_kill.go b/plugin/tcp_kill.go index 4b88706..5fb120d 100644 --- a/plugin/tcp_kill.go +++ b/plugin/tcp_kill.go @@ -20,7 +20,6 @@ const ( URG ) -// srcAddr, dstAddr: TCP地址 func sendFakePkg(seq uint32, srcAddr string, srcPort uint16, dstAddr string, dstPort uint16, flag uint8) { var ( From 35a01fd1296c29ba0a3f667aa79b280b4704289c Mon Sep 17 00:00:00 2001 From: vearne Date: Sun, 15 Sep 2024 11:07:06 +0800 Subject: [PATCH 3/7] update version --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 16d4862..3f63b0c 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := v0.1.9 +VERSION := v0.2.0 BIN_NAME = grpcr CONTAINER = grpcr From 9702c4e4f91da5c3a6d286b61f34241be7eaeced Mon Sep 17 00:00:00 2001 From: vearne Date: Sun, 15 Sep 2024 13:34:29 +0800 Subject: [PATCH 4/7] skip invalid device --- plugin/input_raw.go | 41 ++++++++++++++++++++++++----------------- plugin/tcp_kill.go | 2 +- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/plugin/input_raw.go b/plugin/input_raw.go index 8623b60..63389ff 100644 --- a/plugin/input_raw.go +++ b/plugin/input_raw.go @@ -68,9 +68,12 @@ func (l *DeviceListener) listen() error { if netPkg.Direction == http2.DirIncoming { if l.rawInput.connSet.Has(conn) { // history connection if netPkg.TCP.ACK { - slog.Debug("send RST, for connection:%v", &conn) + slog.Debug("receive Ack package, for connection:%v, expected seq:%v, window:%v", + &conn, netPkg.TCP.Ack, netPkg.TCP.Window) + actualSeq := rand.Uint32()%uint32(netPkg.TCP.Window) + netPkg.TCP.Ack + slog.Debug("send RST, for connection:%v, seq:%v", &conn, actualSeq) // forge a packet from local -> remote - sendFakePkg(netPkg.TCP.Ack, conn.DstAddr.IP, uint16(conn.DstAddr.Port), + sendFakePkg(actualSeq, conn.DstAddr.IP, uint16(conn.DstAddr.Port), conn.SrcAddr.IP, uint16(conn.SrcAddr.Port), RST) } else if netPkg.TCP.SYN { // // new connection slog.Debug("got SYN, remove %v from connSet", &conn) @@ -145,10 +148,12 @@ func NewRAWInput(address string) (*RAWInput, error) { // save all local IP addresses to determine the source of the packet later for _, itf := range itfStatList { - for _, addr := range itf.Addrs { - idx := strings.LastIndex(addr.Addr, "/") - //slog.Debug("addr: %v", addr.Addr[0:idx]) - i.ipSet.Add(addr.Addr[0:idx]) + if itf.MTU > 0 { + for _, addr := range itf.Addrs { + idx := strings.LastIndex(addr.Addr, "/") + //slog.Debug("addr: %v", addr.Addr[0:idx]) + i.ipSet.Add(addr.Addr[0:idx]) + } } } @@ -157,15 +162,19 @@ func NewRAWInput(address string) (*RAWInput, error) { host = strings.TrimSpace(host) if len(host) <= 0 || host == "0.0.0.0" { // all devices for _, itf := range itfStatList { - slog.Debug("interface:%v", itf.Name) - deviceList = append(deviceList, itf.Name) + if itf.MTU > 0 { + slog.Debug("interface:%v", itf.Name) + deviceList = append(deviceList, itf.Name) + } } } else { for _, itf := range itfStatList { - for _, addr := range itf.Addrs { - slog.Debug("interface:%v, addr:%v, host:%v", itf.Name, addr.Addr, host) - if strings.HasPrefix(addr.Addr, host) { - deviceList = append(deviceList, itf.Name) + if itf.MTU > 0 { + for _, addr := range itf.Addrs { + slog.Debug("interface:%v, addr:%v, host:%v", itf.Name, addr.Addr, host) + if strings.HasPrefix(addr.Addr, host) { + deviceList = append(deviceList, itf.Name) + } } } } @@ -211,12 +220,10 @@ func (i *RAWInput) Listen() { } // A∩B - B := http2.NewConnSet() - B.AddAll(cons) + newSet := http2.NewConnSet() + newSet.AddAll(cons) - A := i.connSet.Clone() - A.RemoveAll(B) - i.connSet.RemoveAll(A) + i.connSet = i.connSet.Intersection(newSet) for _, conn := range i.connSet.ToArray() { // trigger challenge ack // remote -> local diff --git a/plugin/tcp_kill.go b/plugin/tcp_kill.go index 5fb120d..9bfa5fb 100644 --- a/plugin/tcp_kill.go +++ b/plugin/tcp_kill.go @@ -75,7 +75,7 @@ func sendFakePkg(seq uint32, srcAddr string, srcPort uint16, addr.Addr = IPtoByte(dstAddr) addr.Port = int(dstPort) - slog.Debug("send %v , %v:%v -> %v:%v", flagStr(flag), + slog.Debug("send %v, %v:%v -> %v:%v", flagStr(flag), srcAddr, srcPort, dstAddr, dstPort) if err = syscall.Sendto(sockfd, buffer.Bytes(), 0, &addr); err != nil { slog.Error("Sendto() error: %v", err) From c36d972b356894d85aee3fa2baf9aad0814283b3 Mon Sep 17 00:00:00 2001 From: vearne Date: Sun, 15 Sep 2024 18:44:41 +0800 Subject: [PATCH 5/7] fix tcpkill --- http2/net_pkg.go | 6 +- plugin/input_raw.go | 34 ++++-- plugin/tcp_kill.go | 202 ++++++++++++------------------------ test/capture/main.go | 37 ++++--- test/string_deal/main.go | 1 + test/tcpkill/main.go | 217 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 337 insertions(+), 160 deletions(-) create mode 100644 test/tcpkill/main.go diff --git a/http2/net_pkg.go b/http2/net_pkg.go index 80436c7..697d71a 100644 --- a/http2/net_pkg.go +++ b/http2/net_pkg.go @@ -25,6 +25,7 @@ type NetPkg struct { SrcIP string DstIP string + Ethernet *layers.Ethernet IPv4 *layers.IPv4 IPv6 *layers.IPv6 TCP *layers.TCP @@ -33,12 +34,15 @@ type NetPkg struct { func ProcessPacket(packet gopacket.Packet, ipSet *util.StringSet, port int) (*NetPkg, error) { var p NetPkg + + ethernet := packet.Layer(layers.LayerTypeEthernet) ipLayerIPv4 := packet.Layer(layers.LayerTypeIPv4) ipLayerIPv6 := packet.Layer(layers.LayerTypeIPv6) - if ipLayerIPv4 == nil && ipLayerIPv6 == nil { + if ethernet == nil || (ipLayerIPv4 == nil && ipLayerIPv6 == nil) { return nil, errors.New("invalid IP package") } + p.Ethernet = ethernet.(*layers.Ethernet) if ipLayerIPv4 != nil { p.IPv4 = ipLayerIPv4.(*layers.IPv4) p.SrcIP = p.IPv4.SrcIP.String() diff --git a/plugin/input_raw.go b/plugin/input_raw.go index 63389ff..8ea257a 100644 --- a/plugin/input_raw.go +++ b/plugin/input_raw.go @@ -3,6 +3,7 @@ package plugin import ( "fmt" "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" psnet "github.com/shirou/gopsutil/v3/net" "github.com/vearne/grpcreplay/http2" @@ -20,6 +21,7 @@ const ( snapshotLen int32 = 1024 * 1024 promiscuous bool = false timeout time.Duration = 5 * time.Second + RSTNum = 3 ) type DeviceListener struct { @@ -67,14 +69,21 @@ func (l *DeviceListener) listen() error { &conn, http2.GetDirection(netPkg.Direction)) if netPkg.Direction == http2.DirIncoming { if l.rawInput.connSet.Has(conn) { // history connection - if netPkg.TCP.ACK { + if netPkg.TCP.ACK && netPkg.IPv4 != nil { slog.Debug("receive Ack package, for connection:%v, expected seq:%v, window:%v", &conn, netPkg.TCP.Ack, netPkg.TCP.Window) - actualSeq := rand.Uint32()%uint32(netPkg.TCP.Window) + netPkg.TCP.Ack - slog.Debug("send RST, for connection:%v, seq:%v", &conn, actualSeq) - // forge a packet from local -> remote - sendFakePkg(actualSeq, conn.DstAddr.IP, uint16(conn.DstAddr.Port), - conn.SrcAddr.IP, uint16(conn.SrcAddr.Port), RST) + actualSeq := netPkg.TCP.Ack + for i := 0; i < RSTNum; i++ { + actualSeq += uint32(netPkg.TCP.Window) * uint32(i) + slog.Debug("send RST, for connection:%v, seq:%v", &conn, actualSeq) + // forge a packet from local -> remote + err = SendRST(netPkg.Ethernet.DstMAC, netPkg.Ethernet.SrcMAC, + netPkg.IPv4.DstIP, netPkg.IPv4.SrcIP, netPkg.TCP.DstPort, + netPkg.TCP.SrcPort, actualSeq, l.handle) + if err != nil { + slog.Error("SendRST, for connection:%v, error:%v", &conn, err) + } + } } else if netPkg.TCP.SYN { // // new connection slog.Debug("got SYN, remove %v from connSet", &conn) l.rawInput.connSet.Remove(conn) @@ -229,8 +238,13 @@ func (i *RAWInput) Listen() { // remote -> local // src -> dst // Fake a packet from local -> remote - sendFakePkg(uint32(rand.Intn(100)), conn.DstAddr.IP, uint16(conn.DstAddr.Port), - conn.SrcAddr.IP, uint16(conn.SrcAddr.Port), SYN) + err = SendSYN(IPtoByte(conn.DstAddr.IP), IPtoByte(conn.SrcAddr.IP), + layers.TCPPort(conn.DstAddr.Port), + layers.TCPPort(conn.SrcAddr.Port), + uint32(rand.Intn(100)), nil) + if err != nil { + slog.Error("SendSYN, for connection:%v, error:%v", &conn, err) + } } slog.Debug("history connections:%v", i.connSet) } @@ -267,3 +281,7 @@ func listAllConns(port int) ([]http2.DirectConn, error) { } return conns, nil } + +func IPtoByte(ipStr string) []byte { + return net.ParseIP(ipStr).To4() +} diff --git a/plugin/tcp_kill.go b/plugin/tcp_kill.go index 9bfa5fb..78c6698 100644 --- a/plugin/tcp_kill.go +++ b/plugin/tcp_kill.go @@ -1,163 +1,93 @@ package plugin import ( - "bytes" - "encoding/binary" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" slog "github.com/vearne/simplelog" "net" - "strconv" - "strings" - "syscall" - "unsafe" ) -const ( - FIN = 1 << iota - SYN - RST - PSH - ACK - URG -) - -func sendFakePkg(seq uint32, srcAddr string, srcPort uint16, - dstAddr string, dstPort uint16, flag uint8) { - var ( - msg string - ipheader IPHeader - tcpheader TCPHeader - buffer bytes.Buffer - ) - - /*Fill IP header*/ - ipheader.SrcAddr = inetAddr(srcAddr) - ipheader.DstAddr = inetAddr(dstAddr) - ipheader.Zero = 0 - ipheader.ProtoType = syscall.IPPROTO_TCP - ipheader.TcpLength = uint16(unsafe.Sizeof(TCPHeader{})) + uint16(len(msg)) - - /*Filling TCP header*/ - tcpheader.SrcPort = srcPort - tcpheader.DstPort = dstPort - tcpheader.SeqNum = seq - tcpheader.AckNum = 0 - tcpheader.Offset = uint8(uint16(unsafe.Sizeof(TCPHeader{}))/4) << 4 - tcpheader.Flag |= flag - tcpheader.Window = 60000 - tcpheader.Checksum = 0 - - // calculate Checksum - // nolint: errcheck - binary.Write(&buffer, binary.BigEndian, ipheader) - // nolint: errcheck - binary.Write(&buffer, binary.BigEndian, tcpheader) - tcpheader.Checksum = CheckSum(buffer.Bytes()) +func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { + slog.Info("send %v:%v > %v:%v [SYN] seq %v", srcIp.String(), srcPort.String(), + dstIp.String(), dstPort.String(), seq) + iPv4 := layers.IPv4{ + SrcIP: srcIp, + DstIP: dstIp, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + } - // Next, clear the buffer and fill it with the part that is actually to be sent. - buffer.Reset() - //nolint:all - binary.Write(&buffer, binary.BigEndian, tcpheader) - //nolint:all - binary.Write(&buffer, binary.BigEndian, msg) + tcp := layers.TCP{ + SrcPort: srcPort, + DstPort: dstPort, + Seq: seq, + SYN: true, + } - /*raw socket*/ - var ( - sockfd int - addr syscall.SockaddrInet4 - err error - ) - if sockfd, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_TCP); err != nil { - slog.Error("Socket() error: %v", err) - return + if err := tcp.SetNetworkLayerForChecksum(&iPv4); err != nil { + return err } - // nolint: errcheck - defer syscall.Shutdown(sockfd, syscall.SHUT_RDWR) - addr.Addr = IPtoByte(dstAddr) - addr.Port = int(dstPort) - slog.Debug("send %v, %v:%v -> %v:%v", flagStr(flag), - srcAddr, srcPort, dstAddr, dstPort) - if err = syscall.Sendto(sockfd, buffer.Bytes(), 0, &addr); err != nil { - slog.Error("Sendto() error: %v", err) - return + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + if err := gopacket.SerializeLayers(buffer, options, &tcp); err != nil { + return err } - slog.Debug("Send success!") -} -type TCPHeader struct { - SrcPort uint16 - DstPort uint16 - SeqNum uint32 - AckNum uint32 - Offset uint8 - Flag uint8 - Window uint16 - Checksum uint16 - UrgentPtr uint16 + err := handle.WritePacketData(buffer.Bytes()) + if err != nil { + return err + } + return nil } -type IPHeader struct { - SrcAddr uint32 - DstAddr uint32 - Zero uint8 - ProtoType uint8 - TcpLength uint16 -} +func SendRST(srcMac, dstMac net.HardwareAddr, srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, + seq uint32, handle *pcap.Handle) error { + slog.Info("send %v:%v > %v:%v [RST] seq %v", srcIp.String(), srcPort.String(), + dstIp.String(), dstPort.String(), seq) -func flagStr(flag uint8) string { - m := make(map[uint8]string) - m[FIN] = "FIN" - m[SYN] = "SYN" - m[RST] = "RST" - m[PSH] = "PSH" - m[ACK] = "ACK" - m[URG] = "URG" + eth := layers.Ethernet{ + SrcMAC: srcMac, + DstMAC: dstMac, + EthernetType: layers.EthernetTypeIPv4, + } - tmpList := make([]string, 0) - for _, f := range []uint8{FIN, SYN, RST, PSH, ACK, URG} { - if flag&f > 0 { - tmpList = append(tmpList, m[f]) - } + iPv4 := layers.IPv4{ + SrcIP: srcIp, + DstIP: dstIp, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, } - return strings.Join(tmpList, "|") -} -func IPtoByte(ipStr string) [4]byte { - var addr [4]byte - for i, bt := range net.ParseIP(ipStr).To4() { - addr[i] = bt + tcp := layers.TCP{ + SrcPort: srcPort, + DstPort: dstPort, + Seq: seq, + RST: true, } - return addr -} -func inetAddr(ipaddr string) uint32 { - var ( - segments []string = strings.Split(ipaddr, ".") - ip [4]uint64 - ret uint64 - ) - for i := 0; i < 4; i++ { - ip[i], _ = strconv.ParseUint(segments[i], 10, 64) + if err := tcp.SetNetworkLayerForChecksum(&iPv4); err != nil { + return err } - ret = ip[3]<<24 + ip[2]<<16 + ip[1]<<8 + ip[0] - return uint32(ret) -} -func CheckSum(data []byte) uint16 { - var ( - sum uint32 - length int = len(data) - index int - ) - for length > 1 { - sum += uint32(data[index])<<8 + uint32(data[index+1]) - index += 2 - length -= 2 + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, } - if length > 0 { - sum += uint32(data[index]) + if err := gopacket.SerializeLayers(buffer, options, ð, &iPv4, &tcp); err != nil { + return err } - sum += (sum >> 16) - return uint16(^sum) + err := handle.WritePacketData(buffer.Bytes()) + if err != nil { + return err + } + return nil } diff --git a/test/capture/main.go b/test/capture/main.go index b49c845..a2d411d 100644 --- a/test/capture/main.go +++ b/test/capture/main.go @@ -34,20 +34,23 @@ type ConnEntry struct { } var ( - device string = "lo0" - snapshotLen int32 = 1024 - promiscuous bool = false + device string = "en0" + //device string = "lo0" + snapshotLen int32 = 1024 + promiscuous bool = false err error timeout time.Duration = 5 * time.Second handle *pcap.Handle ) func main() { + port := 9000 + go func() { - killAllConn() + killAllConn(port) }() time.Sleep(5 * time.Second) - itemList := getAllConn(8080) + itemList := getAllConn(port) for _, item := range itemList { sendRST(10, item.dstIP, uint16(item.dstPort), item.srcIP, uint16(item.srcPort), SYN) } @@ -56,10 +59,10 @@ func main() { func getAllConn(localPort int) []*ConnEntry { result := make([]*ConnEntry, 0) - item1 := ConnEntry{srcIP: "127.0.0.1", srcPort: 63843, dstIP: "127.0.0.1", dstPort: 8080} + item1 := ConnEntry{srcIP: "192.168.8.233", srcPort: 34932, dstIP: "192.168.8.218", dstPort: 9000} result = append(result, &item1) - item2 := ConnEntry{srcIP: "127.0.0.1", srcPort: 63842, dstIP: "127.0.0.1", dstPort: 8080} - result = append(result, &item2) + //item2 := ConnEntry{srcIP: "127.0.0.1", srcPort: 63842, dstIP: "127.0.0.1", dstPort: 8080} + //result = append(result, &item2) return result } @@ -67,7 +70,7 @@ func fmtAddr(ipAddr net.IP, port int) string { return fmt.Sprintf("%v:%v", ipAddr.String(), port) } -func killAllConn() { +func killAllConn(port int) { // Open device handle, err = pcap.OpenLive(device, snapshotLen, promiscuous, timeout) if err != nil { @@ -75,7 +78,7 @@ func killAllConn() { } defer handle.Close() // Set filter - var filter string = "tcp and port 8080" + var filter string = fmt.Sprintf("tcp and port %v", port) err = handle.SetBPFFilter(filter) if err != nil { log.Fatal(err) @@ -109,15 +112,19 @@ func killAllConn() { tcp, _ := tcpLayer.(*layers.TCP) srcPort = tcp.SrcPort dstPort = tcp.DstPort - if dstPort == 8080 && tcp.ACK { + if int(tcp.DstPort) == port && tcp.ACK { ACKFlag = true seq = tcp.Ack + if ACKFlag { + //var i uint32 + for i := 0; i < 9; i++ { + seq += uint32(i) * uint32(tcp.Window) + sendRST(seq, dstIP.String(), uint16(dstPort), srcIP.String(), uint16(srcPort), RST) + } + } } } - if ACKFlag { - //var i uint32 - sendRST(seq, dstIP.String(), uint16(dstPort), srcIP.String(), uint16(srcPort), RST) - } + } } diff --git a/test/string_deal/main.go b/test/string_deal/main.go index 3c8dec5..85e6a2c 100644 --- a/test/string_deal/main.go +++ b/test/string_deal/main.go @@ -9,4 +9,5 @@ func main() { fmt.Println(strings.HasPrefix("192.168.8.128aa", "192.168.8.128")) fmt.Println(strings.HasPrefix("192.168.8.218 aaaa", "192.168.8.128")) + fmt.Println(8 >> 0) } diff --git a/test/tcpkill/main.go b/test/tcpkill/main.go new file mode 100644 index 0000000..9ab89bf --- /dev/null +++ b/test/tcpkill/main.go @@ -0,0 +1,217 @@ +package main + +import ( + "fmt" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" + "log" + "net" + "time" +) + +const ( + FIN = 1 << iota + SYN + RST + PSH + ACK + URG +) + +type ConnEntry struct { + srcIP string + dstIP string + srcPort layers.TCPPort + dstPort layers.TCPPort +} + +var ( + device string = "en0" + //device string = "lo0" + snapshotLen int32 = 1024 + promiscuous bool = false + err error + timeout time.Duration = 5 * time.Second + handle *pcap.Handle +) + +func main() { + port := 9000 + + go func() { + killAllConn(port) + }() + time.Sleep(5 * time.Second) + itemList := getAllConn(port) + for _, item := range itemList { + SendSYN(IPtoByte(item.srcIP), IPtoByte(item.dstIP), item.srcPort, item.dstPort, 10, nil) + } + time.Sleep(5 * time.Minute) +} + +func getAllConn(localPort int) []*ConnEntry { + result := make([]*ConnEntry, 0) + item1 := ConnEntry{srcIP: "192.168.8.233", srcPort: 34932, dstIP: "192.168.8.218", dstPort: 9000} + result = append(result, &item1) + //item2 := ConnEntry{srcIP: "127.0.0.1", srcPort: 63842, dstIP: "127.0.0.1", dstPort: 8080} + //result = append(result, &item2) + return result +} + +func fmtAddr(ipAddr net.IP, port int) string { + return fmt.Sprintf("%v:%v", ipAddr.String(), port) +} + +func killAllConn(port int) { + // Open device + handle, err = pcap.OpenLive(device, snapshotLen, promiscuous, timeout) + if err != nil { + log.Fatal(err) + } + defer handle.Close() + // Set filter + var filter string = fmt.Sprintf("tcp and port %v", port) + err = handle.SetBPFFilter(filter) + if err != nil { + log.Fatal(err) + } + fmt.Println("Only capturing TCP port 8080 packets.") + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + var ( + srcIP net.IP + dstIP net.IP + srcPort layers.TCPPort + dstPort layers.TCPPort + ACKFlag bool + seq uint32 + ) + for packet := range packetSource.Packets() { + ACKFlag = false + //packet.Data() + fmt.Println("got package") + + ethLayer := packet.Layer(layers.LayerTypeEthernet) + if ethLayer == nil { + continue + } + eth := ethLayer.(*layers.Ethernet) + ipLayer := packet.Layer(layers.LayerTypeIPv4) + if ipLayer == nil { + continue + } + ip, _ := ipLayer.(*layers.IPv4) + srcIP = ip.SrcIP + dstIP = ip.DstIP + + // Let's see if the packet is TCP + tcpLayer := packet.Layer(layers.LayerTypeTCP) + if tcpLayer != nil { + tcp, _ := tcpLayer.(*layers.TCP) + srcPort = tcp.SrcPort + dstPort = tcp.DstPort + if int(tcp.DstPort) == port && tcp.ACK { + ACKFlag = true + seq = tcp.Ack + if ACKFlag { + //var i uint32 + for i := 0; i < 3; i++ { + seq += uint32(i) * uint32(tcp.Window) + err = SendRST(eth.DstMAC, eth.SrcMAC, dstIP, srcIP, dstPort, srcPort, seq, handle) + if err != nil { + log.Println("SendNetPkg", err) + } + } + } + } + } + + } +} + +func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { + log.Printf("send %v:%v > %v:%v [SYN] seq %v", srcIp.String(), srcPort.String(), dstIp.String(), dstPort.String(), seq) + iPv4 := layers.IPv4{ + SrcIP: srcIp, + DstIP: dstIp, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + } + + tcp := layers.TCP{ + SrcPort: srcPort, + DstPort: dstPort, + Seq: seq, + SYN: true, + } + + if err := tcp.SetNetworkLayerForChecksum(&iPv4); err != nil { + return err + } + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + if err := gopacket.SerializeLayers(buffer, options, &tcp); err != nil { + return err + } + + err := handle.WritePacketData(buffer.Bytes()) + if err != nil { + return err + } + return nil +} + +func SendRST(srcMac, dstMac net.HardwareAddr, srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { + log.Printf("send %v:%v > %v:%v [RST] seq %v", srcIp.String(), srcPort.String(), dstIp.String(), dstPort.String(), seq) + + eth := layers.Ethernet{ + SrcMAC: srcMac, + DstMAC: dstMac, + EthernetType: layers.EthernetTypeIPv4, + } + + iPv4 := layers.IPv4{ + SrcIP: srcIp, + DstIP: dstIp, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + } + + tcp := layers.TCP{ + SrcPort: srcPort, + DstPort: dstPort, + Seq: seq, + RST: true, + } + + if err := tcp.SetNetworkLayerForChecksum(&iPv4); err != nil { + return err + } + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + if err := gopacket.SerializeLayers(buffer, options, ð, &iPv4, &tcp); err != nil { + return err + } + + err := handle.WritePacketData(buffer.Bytes()) + if err != nil { + return err + } + return nil +} + +func IPtoByte(ipStr string) []byte { + return net.ParseIP(ipStr).To4() +} From e414110538f944ba81c003c3bd0f1f92cbe21620 Mon Sep 17 00:00:00 2001 From: vearne Date: Mon, 16 Sep 2024 14:54:43 +0800 Subject: [PATCH 6/7] Optimize tcpkill function --- Makefile | 2 +- plugin/input_raw.go | 2 +- plugin/tcp_kill.go | 21 ++++++++++++++++-- test/tcpkill/main.go | 51 ++++++++++++++++++++++++++++++++------------ 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/Makefile b/Makefile index 3f63b0c..5788b71 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := v0.2.0 +VERSION := v0.2.1 BIN_NAME = grpcr CONTAINER = grpcr diff --git a/plugin/input_raw.go b/plugin/input_raw.go index 8ea257a..50bb8a7 100644 --- a/plugin/input_raw.go +++ b/plugin/input_raw.go @@ -241,7 +241,7 @@ func (i *RAWInput) Listen() { err = SendSYN(IPtoByte(conn.DstAddr.IP), IPtoByte(conn.SrcAddr.IP), layers.TCPPort(conn.DstAddr.Port), layers.TCPPort(conn.SrcAddr.Port), - uint32(rand.Intn(100)), nil) + uint32(rand.Intn(100))) if err != nil { slog.Error("SendSYN, for connection:%v, error:%v", &conn, err) } diff --git a/plugin/tcp_kill.go b/plugin/tcp_kill.go index 78c6698..ac4ecb6 100644 --- a/plugin/tcp_kill.go +++ b/plugin/tcp_kill.go @@ -6,11 +6,13 @@ import ( "github.com/google/gopacket/pcap" slog "github.com/vearne/simplelog" "net" + "syscall" ) -func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { +func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32) error { slog.Info("send %v:%v > %v:%v [SYN] seq %v", srcIp.String(), srcPort.String(), dstIp.String(), dstPort.String(), seq) + iPv4 := layers.IPv4{ SrcIP: srcIp, DstIP: dstIp, @@ -39,10 +41,25 @@ func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, h return err } - err := handle.WritePacketData(buffer.Bytes()) + // 创建一个原始套接字 + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_TCP) + if err != nil { + slog.Error("Socket creation error: %v", err) + return err + } + defer syscall.Close(fd) + // 设置目标地址 + addr := syscall.SockaddrInet4{ + Port: int(dstPort), // 目标端口 + } + copy(addr.Addr[:], dstIp) + // 发送 SYN 包 + err = syscall.Sendto(fd, buffer.Bytes(), 0, &addr) if err != nil { + slog.Error("Sendto error: %v", err) return err } + slog.Info("SYN packet sent") return nil } diff --git a/test/tcpkill/main.go b/test/tcpkill/main.go index 9ab89bf..1e48d70 100644 --- a/test/tcpkill/main.go +++ b/test/tcpkill/main.go @@ -2,12 +2,13 @@ package main import ( "fmt" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" "log" + "math/rand" "net" + "syscall" "time" ) @@ -46,24 +47,24 @@ func main() { time.Sleep(5 * time.Second) itemList := getAllConn(port) for _, item := range itemList { - SendSYN(IPtoByte(item.srcIP), IPtoByte(item.dstIP), item.srcPort, item.dstPort, 10, nil) + err := SendSYN(IPtoByte(item.srcIP), IPtoByte(item.dstIP), item.srcPort, + item.dstPort, uint32(rand.Int31n(100))) + if err != nil { + + } } time.Sleep(5 * time.Minute) } func getAllConn(localPort int) []*ConnEntry { result := make([]*ConnEntry, 0) - item1 := ConnEntry{srcIP: "192.168.8.233", srcPort: 34932, dstIP: "192.168.8.218", dstPort: 9000} + item1 := ConnEntry{srcIP: "192.168.8.223", srcPort: 52456, dstIP: "192.168.8.218", dstPort: 9000} result = append(result, &item1) - //item2 := ConnEntry{srcIP: "127.0.0.1", srcPort: 63842, dstIP: "127.0.0.1", dstPort: 8080} - //result = append(result, &item2) + item2 := ConnEntry{srcIP: "192.168.8.218", srcPort: 9000, dstIP: "192.168.8.223", dstPort: 52456} + result = append(result, &item2) return result } -func fmtAddr(ipAddr net.IP, port int) string { - return fmt.Sprintf("%v:%v", ipAddr.String(), port) -} - func killAllConn(port int) { // Open device handle, err = pcap.OpenLive(device, snapshotLen, promiscuous, timeout) @@ -91,15 +92,19 @@ func killAllConn(port int) { for packet := range packetSource.Packets() { ACKFlag = false //packet.Data() + for i, layer := range packet.Layers() { + fmt.Println(i, ":", layer.LayerType().String()) + } fmt.Println("got package") - ethLayer := packet.Layer(layers.LayerTypeEthernet) if ethLayer == nil { + log.Println("ethLayer == nil") continue } eth := ethLayer.(*layers.Ethernet) ipLayer := packet.Layer(layers.LayerTypeIPv4) if ipLayer == nil { + log.Println("ipLayer == nil") continue } ip, _ := ipLayer.(*layers.IPv4) @@ -131,8 +136,10 @@ func killAllConn(port int) { } } -func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { - log.Printf("send %v:%v > %v:%v [SYN] seq %v", srcIp.String(), srcPort.String(), dstIp.String(), dstPort.String(), seq) +func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32) error { + log.Printf("send %v:%v > %v:%v [SYN] seq %v", srcIp.String(), srcPort.String(), + dstIp.String(), dstPort.String(), seq) + iPv4 := layers.IPv4{ SrcIP: srcIp, DstIP: dstIp, @@ -161,14 +168,30 @@ func SendSYN(srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, h return err } - err := handle.WritePacketData(buffer.Bytes()) + // 创建一个原始套接字 + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_TCP) + if err != nil { + fmt.Printf("Socket creation error: %v\n", err) + return err + } + defer syscall.Close(fd) + // 设置目标地址 + addr := syscall.SockaddrInet4{ + Port: int(dstPort), // 目标端口 + } + copy(addr.Addr[:], dstIp) + // 发送 SYN 包 + err = syscall.Sendto(fd, buffer.Bytes(), 0, &addr) if err != nil { + fmt.Printf("Sendto error: %v\n", err) return err } + fmt.Println("SYN packet sent") return nil } -func SendRST(srcMac, dstMac net.HardwareAddr, srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, seq uint32, handle *pcap.Handle) error { +func SendRST(srcMac, dstMac net.HardwareAddr, srcIp, dstIp net.IP, srcPort, dstPort layers.TCPPort, + seq uint32, handle *pcap.Handle) error { log.Printf("send %v:%v > %v:%v [RST] seq %v", srcIp.String(), srcPort.String(), dstIp.String(), dstPort.String(), seq) eth := layers.Ethernet{ From 5f105c095aef636a4760367f0db4089c19144a66 Mon Sep 17 00:00:00 2001 From: vearne Date: Mon, 16 Sep 2024 14:58:15 +0800 Subject: [PATCH 7/7] update http2/tcp_buffer_test.go --- http2/tcp_buffer_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/http2/tcp_buffer_test.go b/http2/tcp_buffer_test.go index b06acbd..d7f9cd1 100644 --- a/http2/tcp_buffer_test.go +++ b/http2/tcp_buffer_test.go @@ -10,7 +10,7 @@ import ( func TestSocketBufferSequence1(t *testing.T) { slog.SetLevel(slog.DebugLevel) - buffer := NewSocketBuffer() + buffer := NewTCPBuffer() buffer.expectedSeq = 1000 buffer.leftPointer = 1000 @@ -42,7 +42,7 @@ func TestSocketBufferSequence1(t *testing.T) { func TestSocketBufferSequence2(t *testing.T) { slog.SetLevel(slog.DebugLevel) - buffer := NewSocketBuffer() + buffer := NewTCPBuffer() buffer.expectedSeq = 1000 buffer.leftPointer = 1000 @@ -78,7 +78,7 @@ func TestSocketBufferSequence2(t *testing.T) { func TestSocketBufferSequence3(t *testing.T) { slog.SetLevel(slog.DebugLevel) - buffer := NewSocketBuffer() + buffer := NewTCPBuffer() buffer.expectedSeq = 1000 buffer.leftPointer = 1000 @@ -114,7 +114,7 @@ func TestSocketBufferSequence3(t *testing.T) { func TestSocketBufferDuplicate(t *testing.T) { slog.SetLevel(slog.DebugLevel) - buffer := NewSocketBuffer() + buffer := NewTCPBuffer() buffer.expectedSeq = 1000 buffer.leftPointer = 1000