diff --git a/api/verify.go b/api/verify.go index 7ea9cb7962..49f17b2c62 100644 --- a/api/verify.go +++ b/api/verify.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "net/http" "net/url" "strconv" @@ -13,6 +14,11 @@ import ( "github.com/sethvargo/go-password/password" ) +var ( + // used below to specify need to return answer to user via specific redirect + redirectWithQueryError = errors.New("need return answer with query params") +) + const ( signupVerification = "signup" recoveryVerification = "recovery" @@ -74,6 +80,14 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } if terr != nil { + var e *HTTPError + if errors.As(terr, &e) { + if errors.Is(e.InternalError, redirectWithQueryError) { + rurl := a.prepErrorRedirectURL(e, r) + http.Redirect(w, r, rurl, http.StatusFound) + return nil + } + } return terr } @@ -172,14 +186,14 @@ func (a *API) recoverVerify(ctx context.Context, conn *storage.Connection, param user, err := models.FindUserByRecoveryToken(conn, params.Token) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError(err.Error()).WithInternalError(redirectWithQueryError) } return nil, internalServerError("Database error finding user").WithInternalError(err) } nextDay := user.RecoverySentAt.Add(24 * time.Hour) if user.RecoverySentAt != nil && time.Now().After(nextDay) { - return nil, expiredTokenError("Recovery token expired") + return nil, expiredTokenError("Recovery token expired").WithInternalError(redirectWithQueryError) } err = conn.Transaction(func(tx *storage.Connection) error { @@ -207,3 +221,19 @@ func (a *API) recoverVerify(ctx context.Context, conn *storage.Connection, param } return user, nil } + +func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request) string { + ctx := r.Context() + rurl := a.getConfig(ctx).SiteURL + q := url.Values{} + + log := getLogEntry(r) + log.Error(err.Message) + + if str, ok := oauthErrorMap[err.Code]; ok { + q.Set("error", str) + } + q.Set("error_code", strconv.Itoa(err.Code)) + q.Set("error_description", err.Message) + return rurl + "#" + q.Encode() +}