Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cache calls to /.well-known/openid-configuration #1

Merged
merged 8 commits into from
Nov 27, 2024
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: ['1.21', '1.22']
go: ['1.22']
name: Linux Go ${{ matrix.go }}
steps:
- uses: actions/checkout@v4
19 changes: 12 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
module github.com/coreos/go-oidc/v3

go 1.21

toolchain go1.22.0
go 1.22

require (
github.com/go-jose/go-jose/v4 v4.0.2
golang.org/x/net v0.27.0
golang.org/x/oauth2 v0.21.0
github.com/dgraph-io/ristretto/v2 v2.0.0
github.com/go-jose/go-jose/v4 v4.0.4
golang.org/x/net v0.31.0
golang.org/x/oauth2 v0.24.0
)

require golang.org/x/crypto v0.25.0 // indirect
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/sys v0.27.0 // indirect
)
36 changes: 24 additions & 12 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/dgraph-io/ristretto/v2 v2.0.0 h1:l0yiSOtlJvc0otkqyMaDNysg8E9/F/TYZwMbxscNOAQ=
github.com/dgraph-io/ristretto/v2 v2.0.0/go.mod h1:FVFokF2dRqXyPyeMnK1YDy8Fc6aTe0IKgbcd03CYeEk=
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y=
github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E=
github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
63 changes: 42 additions & 21 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@ import (
"sync"
"time"

"github.com/dgraph-io/ristretto/v2"

"golang.org/x/oauth2"
)

@@ -201,35 +203,54 @@ func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider {
}
}

var oidcConfigCache, _ = ristretto.NewCache(&ristretto.Config[string, []byte]{
NumCounters: 10_000,
MaxCost: 1_000 * 1024, // 10k items, 1kB each
BufferItems: 64,
})

// NewProvider uses the OpenID Connect discovery mechanism to construct a Provider.
//
// The issuer is the URL identifier for the service. For example: "https://accounts.google.com"
// or "https://login.salesforce.com".
func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return nil, err
}
resp, err := doRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
var p providerJSON
var body []byte

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}
if cached, found := oidcConfigCache.Get(wellKnown); found {
err := json.Unmarshal(cached, &p)
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
}
body = cached
} else {
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return nil, err
}
resp, err := doRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var p providerJSON
err = unmarshalResp(resp, body, &p)
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
body, err = io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}
oidcConfigCache.Set(wellKnown, body, int64(len(body)))
oidcConfigCache.Wait()

err = unmarshalResp(resp, body, &p)
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
}
}

issuerURL, skipIssuerValidation := ctx.Value(issuerURLKey).(string)
@@ -239,7 +260,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
if p.Issuer != issuerURL && !skipIssuerValidation {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
}
var algs []string
algs := make([]string, 0, len(p.Algorithms))
for _, a := range p.Algorithms {
if supportedAlgorithms[a] {
algs = append(algs, a)
59 changes: 58 additions & 1 deletion oidc/oidc_test.go
Original file line number Diff line number Diff line change
@@ -290,7 +290,7 @@ func TestNewProvider(t *testing.T) {
return
}
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, strings.ReplaceAll(test.data, "ISSUER", issuer))
_, _ = io.WriteString(w, strings.ReplaceAll(test.data, "ISSUER", issuer))
}
s := httptest.NewServer(http.HandlerFunc(hf))
defer s.Close()
@@ -342,6 +342,63 @@ func TestNewProvider(t *testing.T) {
}
})
}

t.Run("caches openid config", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

nFetched := 0

hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
nFetched++
issuer := "http://" + r.Host

w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, fmt.Sprintf(`{
"issuer": "%[1]s",
"authorization_endpoint": "%[1]s/auth",
"token_endpoint": "%[1]s/token",
"jwks_uri": "%[1]s/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, issuer))
}
s0 := httptest.NewServer(http.HandlerFunc(hf))
defer s0.Close()
s1 := httptest.NewServer(http.HandlerFunc(hf))
defer s1.Close()

for i := range 200 {
s := []*httptest.Server{s0, s1}[i%2]

p, err := NewProvider(ctx, s.URL)
if err != nil {
t.Fatalf("NewProvider() failed: %v", err)
}
if p.issuer != s.URL {
t.Fatalf("NewProvider() unexpected issuer value, got=%s, want=%s", p.issuer, s.URL)
}
if p.authURL != s.URL+"/auth" {
t.Fatalf("NewProvider() unexpected authURL value, got=%s, want=%s/auth", p.authURL, s.URL)
}
if p.tokenURL != s.URL+"/token" {
t.Fatalf("NewProvider() unexpected tokenURL value, got=%s, want=%s/token", p.tokenURL, s.URL)
}
if p.jwksURL != s.URL+"/keys" {
t.Fatalf("NewProvider() unexpected jwksURL value, got=%s, want=%s/keys", p.jwksURL, s.URL)
}
if !reflect.DeepEqual(p.algorithms, []string{"RS256"}) {
t.Fatalf("NewProvider() unexpected algorithms value, got=%s, want=RS256", p.algorithms)
}
}

if nFetched != 2 {
t.Errorf("NewProvider() fetched openid config too often, got=%d, want=2", nFetched)
}
})
}

func TestGetClient(t *testing.T) {