Skip to content

Commit

Permalink
feat: add an optional burstable rate limiter (#1924)
Browse files Browse the repository at this point in the history
The existing rate limiter was moved to a separate package and renamed to
IntervalLimiter. Added BurstLimiter which is a wrapper around the
"golang.org/x/time/rate" package.

The conf.Rate type now has a private `typ` field that indicates if it is
a `"interval"` or `"burst"` rate limiter. If the config value is in the
form of `"<burst>/<rate>"` we set it to `"burst"`, otherwise
`"interval"`. The `conf.Rate.GetRateType()` method is then called from
the `ratelimit.New` function to determine the underlying type of
`ratelimit.Limiter` it returns.

Then changed `api.NewLimiterOptions` to call `ratelimit.New` instead of
creating a specific type of rate limiter.

---------

Co-authored-by: Chris Stockton <chris.stockton@supabase.io>
  • Loading branch information
cstockton and Chris Stockton authored Jan 27, 2025
1 parent 50eb69b commit 1f06f58
Show file tree
Hide file tree
Showing 12 changed files with 578 additions and 199 deletions.
10 changes: 6 additions & 4 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/ratelimit"
)

type Option interface {
apply(*API)
}

type LimiterOptions struct {
Email *RateLimiter
Phone *RateLimiter
Email ratelimit.Limiter
Phone ratelimit.Limiter

Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Expand All @@ -36,8 +37,9 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = newRateLimiter(gc.RateLimitEmailSent)
o.Phone = newRateLimiter(gc.RateLimitSmsSent)
o.Email = ratelimit.New(gc.RateLimitEmailSent)
o.Phone = ratelimit.New(gc.RateLimitSmsSent)

o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down
30 changes: 30 additions & 0 deletions internal/api/options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package api

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/supabase/auth/internal/conf"
)

func TestNewLimiterOptions(t *testing.T) {
cfg := &conf.GlobalConfiguration{}
cfg.ApplyDefaults()

rl := NewLimiterOptions(cfg)
assert.NotNil(t, rl.Email)
assert.NotNil(t, rl.Phone)
assert.NotNil(t, rl.Signups)
assert.NotNil(t, rl.AnonymousSignIns)
assert.NotNil(t, rl.Recover)
assert.NotNil(t, rl.Resend)
assert.NotNil(t, rl.MagicLink)
assert.NotNil(t, rl.Otp)
assert.NotNil(t, rl.Token)
assert.NotNil(t, rl.Verify)
assert.NotNil(t, rl.User)
assert.NotNil(t, rl.FactorVerify)
assert.NotNil(t, rl.FactorChallenge)
assert.NotNil(t, rl.SSO)
assert.NotNil(t, rl.SAMLAssertion)
}
49 changes: 0 additions & 49 deletions internal/api/ratelimits.go

This file was deleted.

125 changes: 0 additions & 125 deletions internal/api/ratelimits_test.go

This file was deleted.

17 changes: 12 additions & 5 deletions internal/conf/rate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@ import (

const defaultOverTime = time.Hour

const (
BurstRateType = "burst"
IntervalRateType = "interval"
)

type Rate struct {
Events float64 `json:"events,omitempty"`
OverTime time.Duration `json:"over_time,omitempty"`
typ string
}

func (r *Rate) EventsPerSecond() float64 {
d := r.OverTime
if d == 0 {
d = defaultOverTime
func (r *Rate) GetRateType() string {
if r.typ == "" {
return IntervalRateType
}
return r.Events / d.Seconds()
return r.typ
}

// Decode is used by envconfig to parse the env-config string to a Rate value.
func (r *Rate) Decode(value string) error {
if f, err := strconv.ParseFloat(value, 64); err == nil {
r.typ = IntervalRateType
r.Events = f
r.OverTime = defaultOverTime
return nil
Expand All @@ -45,6 +51,7 @@ func (r *Rate) Decode(value string) error {
return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err)
}

r.typ = BurstRateType
r.Events = float64(e)
r.OverTime = d
return nil
Expand Down
44 changes: 28 additions & 16 deletions internal/conf/rate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,39 @@ import (
func TestRateDecode(t *testing.T) {
cases := []struct {
str string
eps float64
exp Rate
err string
}{
{str: "1800", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
{str: "1800.0", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}},
{str: "3600/1h", eps: 1, exp: Rate{Events: 3600, OverTime: time.Hour}},
{str: "1800",
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
{str: "1800.0",
exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}},
{str: "3600/1h",
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
{str: "3600/1h0m0s",
exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}},
{str: "100/24h",
eps: 0.0011574074074074073,
exp: Rate{Events: 100, OverTime: time.Hour * 24}},
{str: "", eps: 1, exp: Rate{},
exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}},
{str: "", exp: Rate{},
err: `rate: value does not match`},
{str: "1h", eps: 1, exp: Rate{},
{str: "1h", exp: Rate{},
err: `rate: value does not match`},
{str: "/", eps: 1, exp: Rate{},
{str: "/", exp: Rate{},
err: `rate: events part of rate value`},
{str: "/1h", eps: 1, exp: Rate{},
{str: "/1h", exp: Rate{},
err: `rate: events part of rate value`},
{str: "3600.0/1h", eps: 1, exp: Rate{},
{str: "3600.0/1h", exp: Rate{},
err: `rate: events part of rate value "3600.0/1h" failed to parse`},
{str: "100/", eps: 1, exp: Rate{},
{str: "100/", exp: Rate{},
err: `rate: over-time part of rate value`},
{str: "100/1", eps: 1, exp: Rate{},
{str: "100/1", exp: Rate{},
err: `rate: over-time part of rate value`},

// zero events
{str: "0/1h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour}},
{str: "0/24h", eps: 0.0, exp: Rate{Events: 0, OverTime: time.Hour * 24}},
{str: "0/1h",
exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}},
{str: "0/24h",
exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}},
}
for idx, tc := range cases {
var r Rate
Expand All @@ -51,6 +56,13 @@ func TestRateDecode(t *testing.T) {
}
require.NoError(t, err)
require.Equal(t, tc.exp, r)
require.Equal(t, tc.eps, r.EventsPerSecond())
require.Equal(t, tc.exp.typ, r.GetRateType())
}

// GetRateType() zero value
require.Equal(t, IntervalRateType, (&Rate{}).GetRateType())

// String()
require.Equal(t, "0.000000", (&Rate{}).String())
require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String())
}
Loading

0 comments on commit 1f06f58

Please sign in to comment.