From c361898aed262dfb0e3a356f4564c3d293f68ea2 Mon Sep 17 00:00:00 2001 From: vearne Date: Wed, 21 Aug 2024 15:24:02 +0800 Subject: [PATCH] add rate limit --- biz/emitter.go | 20 ++++++++++++++------ biz/itf.go | 4 ++++ biz/ratelimit.go | 14 ++++++++++++++ config/settings.go | 5 +++++ go.mod | 1 + go.sum | 2 ++ main.go | 7 ++++++- 7 files changed, 46 insertions(+), 7 deletions(-) create mode 100644 biz/ratelimit.go diff --git a/biz/emitter.go b/biz/emitter.go index 12e5937..449e06f 100644 --- a/biz/emitter.go +++ b/biz/emitter.go @@ -12,12 +12,14 @@ type Emitter struct { sync.WaitGroup plugins *InOutPlugins filterChain filter.Filter + limiter Limiter } // NewEmitter creates and initializes new Emitter object. -func NewEmitter(f filter.Filter) *Emitter { +func NewEmitter(f filter.Filter, lim Limiter) *Emitter { var e Emitter e.filterChain = f + e.limiter = lim return &e } @@ -58,11 +60,17 @@ func (e *Emitter) CopyMulty(src PluginReader, writers ...PluginWriter) error { continue } msg, ok := e.filterChain.Filter(msg) - if ok { - for _, dst := range writers { - if err = dst.Write(msg); err != nil { - slog.Error("dst.Write:%v", err) - } + if !ok { + continue + } + + if e.limiter != nil && !e.limiter.Allow() { + continue + } + + for _, dst := range writers { + if err = dst.Write(msg); err != nil { + slog.Error("dst.Write:%v", err) } } } diff --git a/biz/itf.go b/biz/itf.go index a8f9905..1fc1a18 100644 --- a/biz/itf.go +++ b/biz/itf.go @@ -16,3 +16,7 @@ type PluginWriter interface { io.Closer Write(msg *protocol.Message) (err error) } + +type Limiter interface { + Allow() bool +} diff --git a/biz/ratelimit.go b/biz/ratelimit.go new file mode 100644 index 0000000..8380649 --- /dev/null +++ b/biz/ratelimit.go @@ -0,0 +1,14 @@ +package biz + +import ( + "github.com/vearne/grpcreplay/config" + "golang.org/x/time/rate" +) + +func NewRateLimit(settings *config.AppSettings) Limiter { + if settings.RateLimitQPS > 0 { + value := settings.RateLimitQPS + return rate.NewLimiter(rate.Limit(value), value) + } + return nil +} diff --git a/config/settings.go b/config/settings.go index 399325e..eb335e7 100644 --- a/config/settings.go +++ b/config/settings.go @@ -92,6 +92,11 @@ type AppSettings struct { // --- filter --- IncludeFilterMethodMatch string `json:"include-filter-method-match"` + + // --- rate limit --- + // Query per second + RateLimitQPS int `json:"rate-limit-qps"` + // --- other --- Codec string `json:"codec"` } diff --git a/go.mod b/go.mod index 8ca26ef..15571b3 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect + golang.org/x/time v0.6.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect stathat.com/c/consistent v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index c5f9dd5..0adfe81 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/main.go b/main.go index 6eeca3c..8662a74 100644 --- a/main.go +++ b/main.go @@ -94,6 +94,10 @@ func init() { flag.StringVar(&settings.IncludeFilterMethodMatch, "include-filter-method-match", "", `filter requests when the method matches the specified regular expression`) + // rate limit + flag.IntVar(&settings.RateLimitQPS, "rate-limit-qps", -1, + `the capture rate per second limit for Query`) + // rocketmq flag.Var(&config.MultiStringOption{Params: &settings.OutputRocketMQNameServer}, "output-rocketmq-name-server", @@ -129,7 +133,8 @@ func main() { if err != nil { slog.Fatal("create FilterChain error:%v", err) } - emitter := biz.NewEmitter(filterChain) + limiter := biz.NewRateLimit(&settings) + emitter := biz.NewEmitter(filterChain, limiter) plugins := biz.NewPlugins(&settings) slog.Info("plugins:%v", plugins)