diff --git a/internal/api/options.go b/internal/api/options.go index 7c755b78ec..9053c2f978 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -6,6 +6,7 @@ import ( "github.com/didip/tollbooth/v5" "github.com/didip/tollbooth/v5/limiter" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/ratelimit" ) type Option interface { @@ -13,8 +14,8 @@ type Option interface { } type LimiterOptions struct { - Email *RateLimiter - Phone *RateLimiter + Email ratelimit.Limiter + Phone ratelimit.Limiter Signups *limiter.Limiter AnonymousSignIns *limiter.Limiter @@ -36,8 +37,9 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { o := &LimiterOptions{} - o.Email = newRateLimiter(gc.RateLimitEmailSent) - o.Phone = newRateLimiter(gc.RateLimitSmsSent) + o.Email = ratelimit.New(gc.RateLimitEmailSent) + o.Phone = ratelimit.New(gc.RateLimitSmsSent) + o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, diff --git a/internal/api/options_test.go b/internal/api/options_test.go new file mode 100644 index 0000000000..c4c1d1623a --- /dev/null +++ b/internal/api/options_test.go @@ -0,0 +1,30 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/supabase/auth/internal/conf" +) + +func TestNewLimiterOptions(t *testing.T) { + cfg := &conf.GlobalConfiguration{} + cfg.ApplyDefaults() + + rl := NewLimiterOptions(cfg) + assert.NotNil(t, rl.Email) + assert.NotNil(t, rl.Phone) + assert.NotNil(t, rl.Signups) + assert.NotNil(t, rl.AnonymousSignIns) + assert.NotNil(t, rl.Recover) + assert.NotNil(t, rl.Resend) + assert.NotNil(t, rl.MagicLink) + assert.NotNil(t, rl.Otp) + assert.NotNil(t, rl.Token) + assert.NotNil(t, rl.Verify) + assert.NotNil(t, rl.User) + assert.NotNil(t, rl.FactorVerify) + assert.NotNil(t, rl.FactorChallenge) + assert.NotNil(t, rl.SSO) + assert.NotNil(t, rl.SAMLAssertion) +} diff --git a/internal/api/ratelimits.go b/internal/api/ratelimits.go deleted file mode 100644 index 349c6a01b7..0000000000 --- a/internal/api/ratelimits.go +++ /dev/null @@ -1,49 +0,0 @@ -package api - -import ( - "sync" - "time" - - "github.com/supabase/auth/internal/conf" -) - -// RateLimiter will limit the number of calls to Allow per interval. -type RateLimiter struct { - mu sync.Mutex - ival time.Duration // Count is reset and time updated every ival. - limit int // Limit calls to Allow() per ival. - - // Guarded by mu. - last time.Time // When the limiter was last reset. - count int // Total calls to Allow() since time. -} - -// newRateLimiter returns a rate limiter configured using the given conf.Rate. -func newRateLimiter(r conf.Rate) *RateLimiter { - return &RateLimiter{ - ival: r.OverTime, - limit: int(r.Events), - last: time.Now(), - } -} - -func (rl *RateLimiter) Allow() bool { - rl.mu.Lock() - defer rl.mu.Unlock() - - now := time.Now() - return rl.allowAt(now) -} - -func (rl *RateLimiter) allowAt(at time.Time) bool { - since := at.Sub(rl.last) - if ivals := int64(since / rl.ival); ivals > 0 { - rl.last = rl.last.Add(time.Duration(ivals) * rl.ival) - rl.count = 0 - } - if rl.count < rl.limit { - rl.count++ - return true - } - return false -} diff --git a/internal/api/ratelimits_test.go b/internal/api/ratelimits_test.go deleted file mode 100644 index b30e4244c9..0000000000 --- a/internal/api/ratelimits_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package api - -import ( - "fmt" - "testing" - "time" - - "github.com/supabase/auth/internal/conf" -) - -func Example_newRateLimiter() { - now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") - cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} - rl := newRateLimiter(cfg) - rl.last = now - - cur := now - allowed := 0 - - for days := 0; days < 2; days++ { - // First 100 events succeed. - for i := 0; i < 100; i++ { - allow := rl.allowAt(cur) - cur = cur.Add(time.Second) - - if !allow { - fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed) - return - } - allowed++ - } - fmt.Printf("true @ %v for last %v events...\n", cur, allowed) - - // We try hourly until it allows us to make requests again. - denied := 0 - for i := 0; i < 23; i++ { - cur = cur.Add(time.Hour) - allow := rl.allowAt(cur) - if allow { - fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur) - return - } - denied++ - } - fmt.Printf("false @ %v for last %v events...\n", cur, denied) - - cur = cur.Add(time.Hour) - } - - // Output: - // true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events... - // false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events... - // true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events... - // false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events... -} - -func TestNewRateLimiter(t *testing.T) { - now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") - - type event struct { - ok bool - at time.Time - r int - } - cases := []struct { - cfg conf.Rate - now time.Time - evts []event - }{ - { - cfg: conf.Rate{Events: 100, OverTime: time.Hour * 24}, - now: now, - evts: []event{ - {true, now, 0}, - {true, now.Add(time.Minute), 98}, - {false, now.Add(time.Minute), 0}, - {false, now.Add(time.Minute * 14), 0}, - {false, now.Add(time.Minute * 15), 0}, - {false, now.Add(time.Minute * 16), 0}, - {false, now.Add(time.Minute * 17), 0}, - {false, now.Add(time.Minute * 17), 0}, - {true, now.Add(time.Hour * 24), 0}, - {true, now.Add(time.Hour * 25), 0}, - }, - }, - { - cfg: conf.Rate{Events: 0, OverTime: time.Hour}, - now: now, - evts: []event{ - {false, now.Add(-time.Hour), 0}, - {false, now, 0}, - {false, now.Add(time.Minute), 0}, - {false, now.Add(time.Hour), 0}, - {false, now.Add(time.Hour), 12}, - {false, now.Add(time.Hour * 24), 0}, - {false, now.Add(time.Hour * 24 * 2), 0}, - }, - }, - { - cfg: conf.Rate{Events: 0, OverTime: time.Hour * 24}, - now: now, - evts: []event{ - {false, now.Add(-time.Hour), 0}, - {false, now, 0}, - {false, now.Add(time.Minute), 0}, - {false, now.Add(time.Hour), 0}, - {false, now.Add(time.Hour), 12}, - {false, now.Add(time.Hour * 24), 0}, - {false, now.Add(time.Hour * 24 * 2), 0}, - }, - }, - } - for _, tc := range cases { - rl := newRateLimiter(tc.cfg) - rl.last = tc.now - - for _, evt := range tc.evts { - for i := 0; i <= evt.r; i++ { - if exp, got := evt.ok, rl.allowAt(evt.at); exp != got { - t.Fatalf("exp AllowN(%v, 1) to be %v; got %v", evt.at, exp, got) - } - } - } - } -} diff --git a/internal/conf/rate.go b/internal/conf/rate.go index ebe7ba475b..059ed65f08 100644 --- a/internal/conf/rate.go +++ b/internal/conf/rate.go @@ -9,22 +9,28 @@ import ( const defaultOverTime = time.Hour +const ( + BurstRateType = "burst" + IntervalRateType = "interval" +) + type Rate struct { Events float64 `json:"events,omitempty"` OverTime time.Duration `json:"over_time,omitempty"` + typ string } -func (r *Rate) EventsPerSecond() float64 { - d := r.OverTime - if d == 0 { - d = defaultOverTime +func (r *Rate) GetRateType() string { + if r.typ == "" { + return IntervalRateType } - return r.Events / d.Seconds() + return r.typ } // Decode is used by envconfig to parse the env-config string to a Rate value. func (r *Rate) Decode(value string) error { if f, err := strconv.ParseFloat(value, 64); err == nil { + r.typ = IntervalRateType r.Events = f r.OverTime = defaultOverTime return nil @@ -45,6 +51,7 @@ func (r *Rate) Decode(value string) error { return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err) } + r.typ = BurstRateType r.Events = float64(e) r.OverTime = d return nil diff --git a/internal/conf/rate_test.go b/internal/conf/rate_test.go index 51b73df02b..378dedac7c 100644 --- a/internal/conf/rate_test.go +++ b/internal/conf/rate_test.go @@ -10,34 +10,39 @@ import ( func TestRateDecode(t *testing.T) { cases := []struct { str string - eps float64 exp Rate err string }{ - {str: "1800", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}}, - {str: "1800.0", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}}, - {str: "3600/1h", eps: 1, exp: Rate{Events: 3600, OverTime: time.Hour}}, + {str: "1800", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "1800.0", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "3600/1h", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, + {str: "3600/1h0m0s", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, {str: "100/24h", - eps: 0.0011574074074074073, - exp: Rate{Events: 100, OverTime: time.Hour * 24}}, - {str: "", eps: 1, exp: Rate{}, + exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}}, + {str: "", exp: Rate{}, err: `rate: value does not match`}, - {str: "1h", eps: 1, exp: Rate{}, + {str: "1h", exp: Rate{}, err: `rate: value does not match`}, - {str: "/", eps: 1, exp: Rate{}, + {str: "/", exp: Rate{}, err: `rate: events part of rate value`}, - {str: "/1h", eps: 1, exp: Rate{}, + {str: "/1h", exp: Rate{}, err: `rate: events part of rate value`}, - {str: "3600.0/1h", eps: 1, exp: Rate{}, + {str: "3600.0/1h", exp: Rate{}, err: `rate: events part of rate value "3600.0/1h" failed to parse`}, - {str: "100/", eps: 1, exp: Rate{}, + {str: "100/", exp: Rate{}, err: `rate: over-time part of rate value`}, - {str: "100/1", eps: 1, exp: Rate{}, + {str: "100/1", exp: Rate{}, err: `rate: over-time part of rate value`}, // zero events - {str: "0/1h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour}}, - {str: "0/24h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour * 24}}, + {str: "0/1h", + exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}}, + {str: "0/24h", + exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}}, } for idx, tc := range cases { var r Rate @@ -51,6 +56,13 @@ func TestRateDecode(t *testing.T) { } require.NoError(t, err) require.Equal(t, tc.exp, r) - require.Equal(t, tc.eps, r.EventsPerSecond()) + require.Equal(t, tc.exp.typ, r.GetRateType()) } + + // GetRateType() zero value + require.Equal(t, IntervalRateType, (&Rate{}).GetRateType()) + + // String() + require.Equal(t, "0.000000", (&Rate{}).String()) + require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String()) } diff --git a/internal/ratelimit/burst.go b/internal/ratelimit/burst.go new file mode 100644 index 0000000000..6ae0ef58b3 --- /dev/null +++ b/internal/ratelimit/burst.go @@ -0,0 +1,60 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/time/rate" +) + +const defaultOverTime = time.Hour + +// BurstLimiter wraps the golang.org/x/time/rate package. +type BurstLimiter struct { + rl *rate.Limiter +} + +// NewBurstLimiter returns a rate limiter configured using the given conf.Rate. +// +// The returned Limiter will be configured with a token bucket containing a +// single token, which will fill up at a rate of 1 event per r.OverTime with +// an initial burst amount of r.Events. +// +// For example: +// - 1/10s is 1 events per 10 seconds with burst of 1. +// - 1/2s is 1 events per 2 seconds with burst of 1. +// - 10/10s is 1 events per 10 seconds with burst of 10. +// +// If Rate.Events is <= 0, the burst amount will be set to 1. +// +// See Example_newBurstLimiter for a visualization. +func NewBurstLimiter(r conf.Rate) *BurstLimiter { + // The rate limiter deals in events per second. + d := r.OverTime + if d <= 0 { + d = defaultOverTime + } + + e := r.Events + if e <= 0 { + e = 0 + } + + // BurstLimiter will have an initial token bucket of size `e`. It will + // be refilled at a rate of 1 per duration `d` indefinitely. + rl := &BurstLimiter{ + rl: rate.NewLimiter(rate.Every(d), int(e)), + } + return rl +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (l *BurstLimiter) Allow() bool { + return l.AllowAt(time.Now()) +} + +// AllowAt implements Limiter by calling the underlying x/time/rate.Limiter +// with the given time. +func (l *BurstLimiter) AllowAt(at time.Time) bool { + return l.rl.AllowN(at, 1) +} diff --git a/internal/ratelimit/burst_test.go b/internal/ratelimit/burst_test.go new file mode 100644 index 0000000000..b854e3b274 --- /dev/null +++ b/internal/ratelimit/burst_test.go @@ -0,0 +1,214 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newBurstLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + { + cfg := conf.Rate{Events: 10, OverTime: time.Second * 20} + rl := NewBurstLimiter(cfg) + cur := now + for i := 0; i < 20; i++ { + allowed := rl.AllowAt(cur) + fmt.Printf("%-5v @ %v\n", allowed, cur) + cur = cur.Add(time.Second * 5) + } + } + + // Output: + // true @ 2024-09-24 10:00:00 +0000 UTC + // true @ 2024-09-24 10:00:05 +0000 UTC + // true @ 2024-09-24 10:00:10 +0000 UTC + // true @ 2024-09-24 10:00:15 +0000 UTC + // true @ 2024-09-24 10:00:20 +0000 UTC + // true @ 2024-09-24 10:00:25 +0000 UTC + // true @ 2024-09-24 10:00:30 +0000 UTC + // true @ 2024-09-24 10:00:35 +0000 UTC + // true @ 2024-09-24 10:00:40 +0000 UTC + // true @ 2024-09-24 10:00:45 +0000 UTC + // true @ 2024-09-24 10:00:50 +0000 UTC + // true @ 2024-09-24 10:00:55 +0000 UTC + // true @ 2024-09-24 10:01:00 +0000 UTC + // false @ 2024-09-24 10:01:05 +0000 UTC + // false @ 2024-09-24 10:01:10 +0000 UTC + // false @ 2024-09-24 10:01:15 +0000 UTC + // true @ 2024-09-24 10:01:20 +0000 UTC + // false @ 2024-09-24 10:01:25 +0000 UTC + // false @ 2024-09-24 10:01:30 +0000 UTC + // false @ 2024-09-24 10:01:35 +0000 UTC +} + +func TestBurstLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) + + t.Run("AllowAt", func(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + + type event struct { + ok bool + at time.Time + + // Event should be `ok` at `at` for `i` times + i int + } + + type testCase struct { + cfg conf.Rate + now time.Time + evts []event + } + cases := []testCase{ + { + cfg: conf.Rate{Events: 20, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 20 is permitted + {true, now, 19}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + + // allow tokens to be built up still + {true, now.Add(time.Hour), 19}, + }, + }, + + { + cfg: conf.Rate{Events: 1, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 1 is permitted + {true, now, 0}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + }, + }, + + // 1 event per second + { + cfg: conf.Rate{Events: 1, OverTime: time.Second}, + now: now, + evts: []event{ + {true, now, 0}, + {true, now.Add(time.Second), 0}, + {false, now.Add(time.Second), 0}, + {true, now.Add(time.Second * 2), 0}, + }, + }, + + // 1 events per second and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 1, OverTime: 0}, + now: now, + evts: []event{ + {true, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {true, now.Add(time.Hour), 0}, + {true, now.Add(time.Hour * 2), 0}, + }, + }, + + // zero value for Events = 0 event per second + { + cfg: conf.Rate{Events: 0, OverTime: time.Second}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(-time.Second), 0}, + {false, now.Add(time.Second), 0}, + {false, now.Add(time.Second * 2), 0}, + }, + }, + + // zero value for both Events and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 0, OverTime: 0}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {false, now.Add(-time.Hour), 0}, + {false, now.Add(time.Hour), 0}, + {false, now.Add(time.Hour * 2), 0}, + }, + }, + } + + for _, tc := range cases { + rl := NewBurstLimiter(tc.cfg) + for _, evt := range tc.evts { + for i := 0; i <= evt.i; i++ { + if exp, got := evt.ok, rl.AllowAt(evt.at); exp != got { + t.Fatalf("exp AllowAt(%v) to be %v; got %v", evt.at, exp, got) + } + } + } + } + }) +} diff --git a/internal/ratelimit/interval.go b/internal/ratelimit/interval.go new file mode 100644 index 0000000000..a72302f748 --- /dev/null +++ b/internal/ratelimit/interval.go @@ -0,0 +1,63 @@ +package ratelimit + +import ( + "sync" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// IntervalLimiter will limit the number of calls to Allow per interval. +type IntervalLimiter struct { + mu sync.Mutex + ival time.Duration // Count is reset and time updated every ival. + limit int // Limit calls to Allow() per ival. + + // Guarded by mu. + last time.Time // When the limiter was last reset. + count int // Total calls to Allow() since time. +} + +// NewIntervalLimiter returns a rate limiter using the given conf.Rate. +func NewIntervalLimiter(r conf.Rate) *IntervalLimiter { + return &IntervalLimiter{ + ival: r.OverTime, + limit: int(r.Events), + last: time.Now(), + } +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (rl *IntervalLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(time.Now()) +} + +// AllowAt implements Limiter by checking if the current number of permitted +// events within this interval would permit 1 additional event at the current +// time. +// +// When called with a time outside the current active interval the counter is +// reset, meaning it can be vulnerable at the edge of it's intervals so avoid +// small intervals. +func (rl *IntervalLimiter) AllowAt(at time.Time) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(at) +} + +func (rl *IntervalLimiter) allowAt(at time.Time) bool { + since := at.Sub(rl.last) + if ivals := int64(since / rl.ival); ivals > 0 { + rl.last = rl.last.Add(time.Duration(ivals) * rl.ival) + rl.count = 0 + } + if rl.count < rl.limit { + rl.count++ + return true + } + return false +} diff --git a/internal/ratelimit/interval_test.go b/internal/ratelimit/interval_test.go new file mode 100644 index 0000000000..835ee82568 --- /dev/null +++ b/internal/ratelimit/interval_test.go @@ -0,0 +1,81 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newIntervalLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} + rl := NewIntervalLimiter(cfg) + rl.last = now + + cur := now + allowed := 0 + + for days := 0; days < 2; days++ { + // First 100 events succeed. + for i := 0; i < 100; i++ { + allow := rl.allowAt(cur) + cur = cur.Add(time.Second) + + if !allow { + fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed) + return + } + allowed++ + } + fmt.Printf("true @ %v for last %v events...\n", cur, allowed) + + // We try hourly until it allows us to make requests again. + denied := 0 + for i := 0; i < 23; i++ { + cur = cur.Add(time.Hour) + allow := rl.AllowAt(cur) + if allow { + fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur) + return + } + denied++ + } + fmt.Printf("false @ %v for last %v events...\n", cur, denied) + + cur = cur.Add(time.Hour) + } + + // Output: + // true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events... + // false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events... + // true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events... + // false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events... +} + +func TestNewIntervalLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewIntervalLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + + // should accept a negative burst. + cfg := conf.Rate{Events: 10, OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := 0; y < 10; y++ { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000000..35fbf9bc47 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,34 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" +) + +// Limiter is the interface implemented by rate limiters. +// +// Implementations of Limiter must be safe for concurrent use. +type Limiter interface { + + // Allow should return true if an event should be allowed at the time + // which it was called, or false otherwise. + Allow() bool + + // AllowAt should return true if an event should be allowed at the given + // time, or false otherwise. + AllowAt(at time.Time) bool +} + +// New returns a new Limiter based on the given config. +// +// When the type is conf.BurstRateType it returns a BurstLimiter, otherwise +// New returns an IntervalLimiter. +func New(r conf.Rate) Limiter { + switch r.GetRateType() { + case conf.BurstRateType: + return NewBurstLimiter(r) + default: + return NewIntervalLimiter(r) + } +} diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000000..3bac1dca2b --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,50 @@ +package ratelimit + +import ( + "testing" + + "github.com/supabase/auth/internal/conf" +) + +func TestNew(t *testing.T) { + + // IntervalLimiter + { + var r conf.Rate + err := r.Decode("100") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + { + var r conf.Rate + err := r.Decode("100.123") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + + // BurstLimiter + { + var r conf.Rate + err := r.Decode("20/200s") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*BurstLimiter); !ok { + t.Fatalf("exp type *BurstLimiter; got %T", rl) + } + } +}