Skip to content

Commit

Permalink
feat: fix refresh token reuse revocation (#1312)
Browse files Browse the repository at this point in the history
Refresh token reuse revocation was broken, as an error was returned from
the transaction where the revocation took place, which rolled back any
changes. This went unnoticed as the reuse error was sent. Ouch.
  • Loading branch information
hf authored Nov 20, 2023
1 parent 8565d26 commit 6e313f8
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 76 deletions.
4 changes: 2 additions & 2 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h

if token.Revoked {
activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx)
if terr != nil {
if terr != nil && !models.IsNotFoundError(terr) {
return internalServerError(terr.Error())
}

Expand Down Expand Up @@ -199,7 +199,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}
}

return oauthError("invalid_grant", "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)
return storage.NewCommitWithError(oauthError("invalid_grant", "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID))
}
}
}
Expand Down
174 changes: 100 additions & 74 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,84 +362,110 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
u, err := models.NewUser("", "foo@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving foo user")
func (ts *TokenTestSuite) TestRefreshTokenReuseRevocation() {
originalSecurity := ts.API.config.Security

first, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
require.NoError(ts.T(), err)
second, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, first)
require.NoError(ts.T(), err)
ts.API.config.Security.RefreshTokenRotationEnabled = true
ts.API.config.Security.RefreshTokenReuseInterval = 0

cases := []struct {
desc string
refreshTokenRotationEnabled bool
reuseInterval int
refreshToken string
expectedCode int
expectedBody map[string]interface{}
}{
{
desc: "Valid refresh within reuse interval",
refreshTokenRotationEnabled: true,
reuseInterval: 30,
refreshToken: second.Token,
expectedCode: http.StatusOK,
expectedBody: map[string]interface{}{
"refresh_token": "some-new-refresh-token",
},
},
{
desc: "Invalid refresh outside reuse interval",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: first.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token: Already Used",
},
},
{
desc: "Invalid refresh, revoke third token",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: first.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token: Already Used",
},
},
defer func() {
ts.API.config.Security = originalSecurity
}()

refreshTokens := []string{
ts.RefreshToken.Token,
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.Security.RefreshTokenRotationEnabled = c.refreshTokenRotationEnabled
ts.Config.Security.RefreshTokenReuseInterval = c.reuseInterval
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": c.refreshToken,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), c.expectedCode, w.Code)

data := make(map[string]interface{})
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
for k, v := range c.expectedBody {
if k == "refresh_token" {
require.NotEmpty(ts.T(), v, data[k])
} else {
require.Equal(ts.T(), v, data[k])
}
}
})
for i := 0; i < 3; i += 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[len(refreshTokens)-1],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusOK, w.Code)

var response struct {
RefreshToken string `json:"refresh_token"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))

refreshTokens = append(refreshTokens, response.RefreshToken)
}

// ensure that the 4 refresh tokens are setup correctly
for i, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)

if i == len(refreshTokens)-1 {
require.False(ts.T(), token.Revoked)
} else {
require.True(ts.T(), token.Revoked)
}
}

// try to reuse the first (earliest) refresh token which should trigger the family revocation logic
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[0],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusBadRequest, w.Code)

var response struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), response.Error, "invalid_grant")
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used")

// ensure that the refresh tokens are marked as revoked in the database
for _, refreshToken := range refreshTokens {
_, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false)
require.NoError(ts.T(), err)

require.True(ts.T(), token.Revoked)
}

// finally ensure that none of the refresh tokens can be reused any
// more, starting with the previously valid one
for i := len(refreshTokens) - 1; i >= 0; i -= 1 {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": refreshTokens[i],
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i)

var response struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response))
require.Equal(ts.T(), response.Error, "invalid_grant", "For refresh token %d", i)
require.Equal(ts.T(), response.ErrorDescription, "Invalid Refresh Token: Already Used", "For refresh token %d", i)
}
}

Expand Down
4 changes: 4 additions & 0 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error {
update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec()
}
if err != nil {
if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) {
return nil
}

return err
}
return nil
Expand Down
4 changes: 4 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*Refr
var activeRefreshToken RefreshToken

if err := tx.Q().Where("session_id = ? and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil {
if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) {
return nil, RefreshTokenNotFoundError{}
}

return nil, err
}

Expand Down

0 comments on commit 6e313f8

Please sign in to comment.