Skip to content

Commit

Permalink
feat: add single session per user with tags support (#1297)
Browse files Browse the repository at this point in the history
Enforces a single session per user with optional tags. If a session has
a tag, only the most recently refreshed session with the same tag can be
refreshed. If no tags are configured, then only the most recently
refreshed session of all of the user's sessions will be refreshed.

Sessions that are invalid due to inactivity or timeboxing won't be
considered.
  • Loading branch information
hf authored Nov 15, 2023
1 parent 3c72faf commit 69feebc
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 20 deletions.
81 changes: 61 additions & 20 deletions internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,18 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}

if session != nil {
var notAfter time.Time
result := session.CheckValidity(retryStart, &token.UpdatedAt, config.Sessions.Timebox, config.Sessions.InactivityTimeout)

if session.NotAfter != nil {
notAfter = *session.NotAfter
}

if config.Sessions.Timebox != nil {
sessionEndsAt := session.CreatedAt.Add((*config.Sessions.Timebox).Abs())
switch result {
case models.SessionValid:
// do nothing

if notAfter.IsZero() || notAfter.After(sessionEndsAt) {
notAfter = sessionEndsAt
}
}
case models.SessionTimedOut:
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")

if !notAfter.IsZero() && a.Now().After(notAfter) {
default:
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired")
}

if config.Sessions.InactivityTimeout != nil {
timesOutAt := session.LastRefreshedAt(&token.UpdatedAt).Add(*config.Sessions.InactivityTimeout)

if timesOutAt.Before(a.Now()) {
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Inactivity)")
}
}
}

// Basic checks above passed, now we need to serialize access
Expand Down Expand Up @@ -120,6 +107,60 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError(terr.Error())
}

if a.config.Sessions.SinglePerUser {
sessions, terr := models.FindAllSessionsForUser(tx, user.ID, true /* forUpdate */)
if models.IsNotFoundError(terr) {
// because forUpdate was set, and the
// previous check outside the
// transaction found a user and
// session, but now we're getting a
// IsNotFoundError, this means that the
// user is locked and we need to retry
// in a few milliseconds
retry = true
return terr
} else if terr != nil {
return internalServerError(terr.Error())
}

sessionTag := session.DetermineTag(config.Sessions.Tags)

// go through all sessions of the user and
// check if the current session is the user's
// most recently refreshed valid session
for _, s := range sessions {
if s.ID == session.ID {
// current session, skip it
continue
}

if s.CheckValidity(retryStart, nil, config.Sessions.Timebox, config.Sessions.InactivityTimeout) != models.SessionValid {
// session is not valid so it
// can't be regarded as active
// on the user
continue
}

if s.DetermineTag(config.Sessions.Tags) != sessionTag {
// if tags are specified,
// ignore sessions with a
// mismatching tag
continue
}

// since token is not the refresh token
// of s, we can't use it's UpdatedAt
// time to compare!
if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) {
// session is not the most
// recently active one
return oauthError("invalid_grant", "Invalid Refresh Token: Session Expired (Revoked by Newer Login)")
}
}

// this session is the user's active session
}

// refresh token row and session are locked at this
// point, cannot be concurrently refreshed

Expand Down
43 changes: 43 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (ts *TokenTestSuite) TestSessionTimebox() {

defer func() {
ts.API.overrideTime = nil
ts.API.config.Sessions.Timebox = nil
}()

var buffer bytes.Buffer
Expand Down Expand Up @@ -176,6 +177,48 @@ func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() {
assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken)
}

func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() {
ts.API.config.Sessions.SinglePerUser = true
defer func() {
ts.API.config.Sessions.SinglePerUser = false
}()

firstRefreshToken := ts.RefreshToken

// just in case to give some delay between first and second session creation
time.Sleep(10 * time.Millisecond)

secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{})

require.NoError(ts.T(), err)

require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId)
require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID)

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": firstRefreshToken.Token,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()

ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser)

var firstResult struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult))
assert.Equal(ts.T(), "invalid_grant", firstResult.Error)
assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.ErrorDescription)
}

func (ts *TokenTestSuite) TestRateLimitTokenRefresh() {
var buffer bytes.Buffer
req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer)
Expand Down
3 changes: 3 additions & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func (a *APIConfiguration) Validate() error {
type SessionsConfiguration struct {
Timebox *time.Duration `json:"timebox"`
InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"`

SinglePerUser bool `json:"single_per_user" split_words:"true"`
Tags []string `json:"tags,omitempty"`
}

func (c *SessionsConfiguration) Validate() error {
Expand Down
5 changes: 5 additions & 0 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type GrantParams struct {
FactorID *uuid.UUID

SessionNotAfter *time.Time
SessionTag *string

UserAgent string
IP string
Expand Down Expand Up @@ -145,6 +146,10 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
session.IP = &params.IP
}

if params.SessionTag != nil && *params.SessionTag != "" {
session.Tag = params.SessionTag
}

if err := tx.Create(session); err != nil {
return nil, errors.Wrap(err, "error creating new session")
}
Expand Down
79 changes: 79 additions & 0 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type Session struct {
RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"`
UserAgent *string `json:"user_agent,omitempty" db:"user_agent"`
IP *string `json:"ip,omitempty" db:"ip"`

Tag *string `json:"tag" db:"tag"`
}

func (Session) TableName() string {
Expand Down Expand Up @@ -104,6 +106,54 @@ func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error {
return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip")
}

type SessionValidityReason = int

const (
SessionValid SessionValidityReason = iota
SessionPastNotAfter = iota
SessionPastTimebox = iota
SessionTimedOut = iota
)

func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason {
if s.NotAfter != nil && now.After(*s.NotAfter) {
return SessionPastNotAfter
}

if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) {
return SessionPastTimebox
}

if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) {
return SessionTimedOut
}

return SessionValid
}

func (s *Session) DetermineTag(tags []string) string {
if len(tags) == 0 {
return ""
}

if s.Tag == nil {
return tags[0]
}

tag := *s.Tag
if tag == "" {
return tags[0]
}

for _, t := range tags {
if t == tag {
return tag
}
}

return tags[0]
}

func NewSession() (*Session, error) {
id := uuid.Must(uuid.NewV4())

Expand Down Expand Up @@ -168,6 +218,35 @@ func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Sess
return sessions, nil
}

// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is
// set, it will first lock on the user row which can be used to prevent issues
// with concurrency. If the lock is acquired, it will return a
// UserNotFoundError and the operation should be retried. If there are no
// sessions for the user, a nil result is returned without an error.
func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) {
if forUpdate {
user := &User{}
if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, UserNotFoundError{}
}

return nil, err
}
}

var sessions []*Session
if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, nil
}

return nil, err
}

return sessions, nil
}

func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error {
return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec()
}
Expand Down
2 changes: 2 additions & 0 deletions migrations/20231114161723_add_sessions_tag.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
alter table if exists {{ index .Options "Namespace" }}.sessions
add column if not exists tag text;

0 comments on commit 69feebc

Please sign in to comment.