diff --git a/internal/api/api.go b/internal/api/api.go
index 49b8106963..cfff819167 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -258,7 +258,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
})
r.Route("/sso", func(r *router) {
- r.Use(api.requireSAMLEnabled)
+ r.Use(api.requireSSOEnabled)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{
@@ -267,6 +267,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
)).With(api.verifyCaptcha).Post("/", api.SingleSignOn)
r.Route("/saml", func(r *router) {
+ r.Use(api.requireSSOSAMLEnabled)
r.Get("/metadata", api.SAMLMetadata)
r.With(api.limitHandler(
@@ -276,6 +277,18 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
}).SetBurst(30),
)).Post("/acs", api.SamlAcs)
})
+
+ r.Route("/oidc", func(r *router) {
+ r.Use(api.requireSSOOIDCEnabled)
+ r.Route("/callback", func(r *router) {
+ r.Use(api.isValidExternalHost)
+ r.Use(api.loadSSOOIDCFlowState)
+
+ r.Get("/", api.ExternalProviderCallback)
+ r.Post("/", api.ExternalProviderCallback)
+ })
+ })
+
})
r.Route("/admin", func(r *router) {
@@ -320,6 +333,19 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.Put("/", api.adminSSOProvidersUpdate)
r.Delete("/", api.adminSSOProvidersDelete)
})
+
+ r.Route("/oidc", func(r *router) {
+ r.Get("/", api.adminOIDCSSOProvidersList)
+ r.Post("/", api.adminOIDCSSOProvidersCreate)
+
+ r.Route("/{idp_id}", func(r *router) {
+ r.Use(api.loadOIDCSSOProvider)
+
+ r.Get("/", api.adminOIDCSSOProvidersGet)
+ // r.Put("/", api.adminOIDCSSOProvidersUpdate)
+ r.Delete("/", api.adminOIDCSSOProvidersDelete)
+ })
+ })
})
})
diff --git a/internal/api/context.go b/internal/api/context.go
index 3047f3dd6a..f7738e17c1 100644
--- a/internal/api/context.go
+++ b/internal/api/context.go
@@ -5,6 +5,7 @@ import (
"net/url"
jwt "github.com/golang-jwt/jwt/v5"
+ "github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
)
@@ -15,22 +16,23 @@ func (c contextKey) String() string {
}
const (
- tokenKey = contextKey("jwt")
- inviteTokenKey = contextKey("invite_token")
- signatureKey = contextKey("signature")
- externalProviderTypeKey = contextKey("external_provider_type")
- userKey = contextKey("user")
- targetUserKey = contextKey("target_user")
- factorKey = contextKey("factor")
- sessionKey = contextKey("session")
- externalReferrerKey = contextKey("external_referrer")
- functionHooksKey = contextKey("function_hooks")
- adminUserKey = contextKey("admin_user")
- oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
- oauthVerifierKey = contextKey("oauth_verifier")
- ssoProviderKey = contextKey("sso_provider")
- externalHostKey = contextKey("external_host")
- flowStateKey = contextKey("flow_state_id")
+ tokenKey = contextKey("jwt")
+ inviteTokenKey = contextKey("invite_token")
+ signatureKey = contextKey("signature")
+ externalProviderTypeKey = contextKey("external_provider_type")
+ userKey = contextKey("user")
+ targetUserKey = contextKey("target_user")
+ factorKey = contextKey("factor")
+ sessionKey = contextKey("session")
+ externalReferrerKey = contextKey("external_referrer")
+ functionHooksKey = contextKey("function_hooks")
+ adminUserKey = contextKey("admin_user")
+ oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token
+ oauthVerifierKey = contextKey("oauth_verifier")
+ ssoProviderKey = contextKey("sso_provider")
+ externalHostKey = contextKey("external_host")
+ flowStateKey = contextKey("flow_state_id")
+ genericProviderConfigKey = contextKey("generic_provider_config")
)
// withToken adds the JWT token to the context.
@@ -241,3 +243,16 @@ func getExternalHost(ctx context.Context) *url.URL {
}
return obj.(*url.URL)
}
+
+func withGenericProviderConfig(ctx context.Context, token *conf.GenericOAuthProviderConfiguration) context.Context {
+ return context.WithValue(ctx, genericProviderConfigKey, token)
+}
+
+func getGenericProviderConfig(ctx context.Context) *conf.GenericOAuthProviderConfiguration {
+ obj := ctx.Value(genericProviderConfigKey)
+ if obj == nil {
+ return nil
+ }
+
+ return obj.(*conf.GenericOAuthProviderConfiguration)
+}
diff --git a/internal/api/external.go b/internal/api/external.go
index 4df4c65026..7db0652d39 100644
--- a/internal/api/external.go
+++ b/internal/api/external.go
@@ -578,6 +578,9 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
return provider.NewWorkOSProvider(config.External.WorkOS)
case "zoom":
return provider.NewZoomProvider(config.External.Zoom)
+ case "sso/oidc":
+ config := getGenericProviderConfig(ctx)
+ return provider.NewGenericProvider(*config, scopes)
default:
return nil, fmt.Errorf("Provider %s could not be found", name)
}
diff --git a/internal/api/helpers.go b/internal/api/helpers.go
index 6921392525..96f3b22d27 100644
--- a/internal/api/helpers.go
+++ b/internal/api/helpers.go
@@ -65,6 +65,7 @@ func getBodyBytes(req *http.Request) ([]byte, error) {
type RequestParams interface {
AdminUserParams |
CreateSSOProviderParams |
+ CreateOIDCSSOProviderParams |
EnrollFactorParams |
GenerateLinkParams |
IdTokenGrantParams |
diff --git a/internal/api/middleware.go b/internal/api/middleware.go
index 8caa5d8855..edfc23338c 100644
--- a/internal/api/middleware.go
+++ b/internal/api/middleware.go
@@ -223,7 +223,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con
return withExternalHost(ctx, u), nil
}
-func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
+func (a *API) requireSSOSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
if !a.config.SAML.Enabled {
return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled")
@@ -231,6 +231,22 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont
return ctx, nil
}
+func (a *API) requireSSOOIDCEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
+ ctx := req.Context()
+ if !a.config.OIDC.Enabled {
+ return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "OIDC is disabled")
+ }
+ return ctx, nil
+}
+
+func (a *API) requireSSOEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
+ ctx := req.Context()
+ if !(a.config.OIDC.Enabled || a.config.SAML.Enabled) {
+ return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "Either SAML or OIDC for SSO need to be enabled")
+ }
+ return ctx, nil
+}
+
func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) {
ctx := req.Context()
if !a.config.Security.ManualLinkingEnabled {
diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go
index eb8c5da3b5..4c65709109 100644
--- a/internal/api/middleware_test.go
+++ b/internal/api/middleware_test.go
@@ -287,7 +287,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() {
req := httptest.NewRequest("GET", "http://localhost", nil)
w := httptest.NewRecorder()
- _, err := ts.API.requireSAMLEnabled(w, req)
+ _, err := ts.API.requireSSOSAMLEnabled(w, req)
require.Equal(ts.T(), c.expectedErr, err)
})
}
diff --git a/internal/api/provider/generic.go b/internal/api/provider/generic.go
new file mode 100644
index 0000000000..25c1363404
--- /dev/null
+++ b/internal/api/provider/generic.go
@@ -0,0 +1,280 @@
+package provider
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "github.com/supabase/auth/internal/conf"
+ "github.com/supabase/auth/internal/utilities"
+ "golang.org/x/oauth2"
+)
+
+type genericProvider struct {
+ *oauth2.Config
+ Issuer string
+ UserInfoURL string
+ UserDataMapping map[string]string
+}
+
+func (p genericProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
+ return p.Exchange(context.Background(), code)
+}
+
+func (p genericProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
+ var u map[string]interface{}
+
+ // Perform http request manually, because we need to vary it based on the provider config
+ req, err := http.NewRequest("GET", p.UserInfoURL, nil)
+
+ if err != nil {
+ return nil, err
+ }
+
+ // set headers
+ req.Header.Set("Client-Id", p.ClientID)
+ req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
+
+ client := &http.Client{Timeout: defaultTimeout}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer utilities.SafeClose(resp.Body)
+
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("a %v error occurred with retrieving user from OAuth2 provider via %s", resp.StatusCode, p.UserInfoURL)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ err = json.Unmarshal(body, &u)
+ if err != nil {
+ return nil, err
+ }
+
+ // Read user data as specified in the JSON mapping
+ mapping := p.UserDataMapping
+
+ email, err := getStringFieldByPath(u, mapping["Email"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ emailVerified, err := getBooleanFieldByPath(u, mapping["EmailVerified"], email != "")
+ if err != nil {
+ return nil, err
+ }
+
+ emailPrimary, err := getBooleanFieldByPath(u, mapping["EmailPrimary"], email != "")
+ if err != nil {
+ return nil, err
+ }
+
+ issuer, err := getStringFieldByPath(u, mapping["Issuer"], p.Issuer)
+ if err != nil {
+ return nil, err
+ }
+
+ subject, err := getStringFieldByPath(u, mapping["Subject"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ name, err := getStringFieldByPath(u, mapping["Name"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ familyName, err := getStringFieldByPath(u, mapping["FamilyName"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ givenName, err := getStringFieldByPath(u, mapping["GivenName"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ middleName, err := getStringFieldByPath(u, mapping["MiddleName"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ nickName, err := getStringFieldByPath(u, mapping["NickName"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ preferredUsername, err := getStringFieldByPath(u, mapping["PreferredUsername"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ profile, err := getStringFieldByPath(u, mapping["Profile"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ picture, err := getStringFieldByPath(u, mapping["Picture"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ website, err := getStringFieldByPath(u, mapping["Website"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ gender, err := getStringFieldByPath(u, mapping["Gender"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ birthdate, err := getStringFieldByPath(u, mapping["Birthdate"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ zoneInfo, err := getStringFieldByPath(u, mapping["ZoneInfo"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ locale, err := getStringFieldByPath(u, mapping["Locale"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ updatedAt, err := getStringFieldByPath(u, mapping["UpdatedAt"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ phone, err := getStringFieldByPath(u, mapping["Phone"], "")
+ if err != nil {
+ return nil, err
+ }
+
+ phoneVerified, err := getBooleanFieldByPath(u, mapping["PhoneVerified"], phone != "")
+ if err != nil {
+ return nil, err
+ }
+
+ data := &UserProvidedData{
+ Emails: []Email{
+ {
+ Email: email,
+ Verified: emailVerified,
+ Primary: emailPrimary,
+ },
+ },
+ Metadata: &Claims{
+ Issuer: issuer,
+ Subject: subject,
+ Name: name,
+ FamilyName: familyName,
+ GivenName: givenName,
+ MiddleName: middleName,
+ NickName: nickName,
+ PreferredUsername: preferredUsername,
+ Profile: profile,
+ Picture: picture,
+ Website: website,
+ Gender: gender,
+ Birthdate: birthdate,
+ ZoneInfo: zoneInfo,
+ Locale: locale,
+ UpdatedAt: updatedAt,
+ Email: email,
+ EmailVerified: emailVerified,
+ Phone: phone,
+ PhoneVerified: phoneVerified,
+ },
+ }
+
+ return data, nil
+}
+
+func getFieldByPath(obj map[string]interface{}, path string, fallback interface{}) (interface{}, error) {
+ value := obj
+
+ pathParts := strings.Split(path, ".")
+ for index, field := range pathParts {
+ fieldValue, ok := value[field]
+ if !ok {
+ return fallback, nil
+ }
+
+ if index == len(pathParts)-1 {
+ return fieldValue, nil
+ }
+
+ value = fieldValue.(map[string]interface{})
+ }
+
+ return nil, nil
+}
+
+func getStringFieldByPath(obj map[string]interface{}, path string, fallback string) (string, error) {
+ value, err := getFieldByPath(obj, path, fallback)
+ if err != nil {
+ return "", err
+ }
+ if result, ok := value.(string); ok {
+ return result, nil
+ } else if intValue, ok := value.(int); ok {
+ return strconv.Itoa(intValue), nil
+ } else if floatValue, ok := value.(float64); ok {
+ return strconv.Itoa(int(math.Round(floatValue))), nil
+ } else if value == nil {
+ return "", nil
+ } else {
+ return "", fmt.Errorf("unable to read field as string: %q %q", path, value)
+ }
+}
+
+func getBooleanFieldByPath(obj map[string]interface{}, path string, fallback bool) (bool, error) {
+ value, err := getFieldByPath(obj, path, fallback)
+ if err != nil {
+ return false, err
+ }
+ if result, ok := value.(bool); ok {
+ return result, nil
+ } else {
+ return false, fmt.Errorf("unable to read field as boolean: %q", path)
+ }
+}
+
+// NewGenericProvider creates an OAuth provider according to the config specified by the user
+func NewGenericProvider(ext conf.GenericOAuthProviderConfiguration, scopes string) (OAuthProvider, error) {
+ if err := ext.ValidateOAuth(); err != nil {
+ return nil, err
+ }
+
+ oauthScopes := strings.Split(scopes, ",")
+
+ return &genericProvider{
+ Config: &oauth2.Config{
+ ClientID: ext.ClientID[0],
+ ClientSecret: ext.Secret,
+ Endpoint: oauth2.Endpoint{
+ AuthURL: ext.AuthURL,
+ TokenURL: ext.TokenURL,
+ },
+ RedirectURL: ext.RedirectURI,
+ Scopes: oauthScopes,
+ },
+ Issuer: ext.Issuer,
+ UserInfoURL: ext.UserInfoURL,
+ UserDataMapping: ext.UserDataMapping,
+ }, nil
+}
diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go
index 907efcd4c0..2e73c04774 100644
--- a/internal/api/samlacs.go
+++ b/internal/api/samlacs.go
@@ -170,7 +170,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
logentry.Warn("SAML Metadata for identity provider will expire soon! Update its metadata_xml!")
}
- } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, ssoProvider.SAMLProvider) {
+ } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, *ssoProvider.SAMLProvider) {
rawMetadata, err := fetchSAMLMetadata(ctx, *ssoProvider.SAMLProvider.MetadataURL)
if err != nil {
// Fail silently but raise warning and continue with existing metadata
diff --git a/internal/api/sso.go b/internal/api/sso.go
index 10034075c2..080d3df708 100644
--- a/internal/api/sso.go
+++ b/internal/api/sso.go
@@ -2,11 +2,10 @@ package api
import (
"net/http"
+ "net/url"
- "github.com/crewjam/saml"
"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/models"
- "github.com/supabase/auth/internal/storage"
)
type SingleSignOnParams struct {
@@ -57,16 +56,6 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
return err
}
- flowType := getFlowFromChallenge(params.CodeChallenge)
- var flowStateID *uuid.UUID
- flowStateID = nil
- if isPKCEFlow(flowType) {
- flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil)
- if err != nil {
- return err
- }
- flowStateID = &flowState.ID
- }
var ssoProvider *models.SSOProvider
@@ -86,48 +75,37 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
}
}
- entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor()
- if err != nil {
- return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)
- }
-
- serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */)
-
- authnRequest, err := serviceProvider.MakeAuthenticationRequest(
- serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding),
- saml.HTTPRedirectBinding,
- saml.HTTPPostBinding,
- )
- if err != nil {
- return internalServerError("Error creating SAML Authentication Request").WithInternalError(err)
- }
-
- // Some IdPs do not support the use of the `persistent` NameID format,
- // and require a different format to be sent to work.
- if ssoProvider.SAMLProvider.NameIDFormat != nil {
- authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat
- }
-
- relayState := models.SAMLRelayState{
- SSOProviderID: ssoProvider.ID,
- RequestID: authnRequest.ID,
- RedirectTo: params.RedirectTo,
- FlowStateID: flowStateID,
+ var authMethod models.AuthenticationMethod
+ var providerType string
+ // providerType, authMethod := "", models.AuthenticationMethod
+ if ssoProvider.OIDCProvider == nil || ssoProvider.OIDCProvider.ClientId == "" {
+ providerType, authMethod = models.SSOSAML.String(), models.SSOSAML
+ } else {
+ providerType, authMethod = models.SSOOIDC.String(), models.SSOOIDC
}
- if err := db.Transaction(func(tx *storage.Connection) error {
- if terr := tx.Create(&relayState); terr != nil {
- return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err)
+ flowType := getFlowFromChallenge(params.CodeChallenge)
+ var flowStateID *uuid.UUID
+ flowStateID = nil
+ if isPKCEFlow(flowType) {
+ flowState, err := generateFlowState(db, providerType, authMethod, codeChallengeMethod, codeChallenge, nil)
+ if err != nil {
+ return err
}
-
- return nil
- }); err != nil {
- return err
+ flowStateID = &flowState.ID
}
- ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider)
- if err != nil {
- return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err)
+ var ssoRedirectURL *url.URL
+ if authMethod == models.SSOSAML {
+ ssoRedirectURL, err = GenerateRedirectWithSAML(a, db, ssoProvider, flowStateID, params)
+ if err != nil {
+ return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err)
+ }
+ } else if authMethod == models.SSOOIDC {
+ ssoRedirectURL, err = GenerateRedirectWithOIDC(a, db, ssoProvider, flowStateID, params)
+ if err != nil {
+ return internalServerError("Error creating OIDC authentication request redirect URL").WithInternalError(err)
+ }
}
skipHTTPRedirect := false
diff --git a/internal/api/sso_oidc.go b/internal/api/sso_oidc.go
new file mode 100644
index 0000000000..dcd77f9ad2
--- /dev/null
+++ b/internal/api/sso_oidc.go
@@ -0,0 +1,124 @@
+package api
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/gofrs/uuid"
+ "github.com/supabase/auth/internal/api/provider"
+ "github.com/supabase/auth/internal/models"
+ "github.com/supabase/auth/internal/storage"
+)
+
+// GenerateRandomState generates a random state string for OAuth2
+func GenerateRandomState(length int) (string, error) {
+ // Create a byte slice to hold the random bytes
+ bytes := make([]byte, length)
+
+ // Read random bytes into the slice
+ if _, err := rand.Read(bytes); err != nil {
+ return "", err
+ }
+
+ // Encode the random bytes into a URL-safe base64 string
+ return base64.URLEncoding.EncodeToString(bytes), nil
+}
+
+func GenerateRedirectWithOIDC(a *API, db *storage.Connection, ssoProvider *models.SSOProvider, flowStateID *uuid.UUID, params *SingleSignOnParams) (*url.URL, error) {
+ oidcProviderConfig, err := ssoProvider.OIDCProvider.GenericProviderConfig()
+ if err != nil {
+ return &url.URL{}, internalServerError("Error creating generic OIDC provider config").WithInternalError(err)
+ }
+
+ oidcProviderConfig.RedirectURI = fmt.Sprintf("%s/sso/oidc/callback", a.config.API.ExternalURL)
+
+ provider, err := provider.NewGenericProvider(oidcProviderConfig, "openid")
+ if err != nil {
+ return &url.URL{}, internalServerError("Error creating generic OIDC provider").WithInternalError(err)
+ }
+
+ state, err := GenerateRandomState(32)
+ if err != nil {
+ return &url.URL{}, internalServerError("Error creating state").WithInternalError(err)
+ }
+
+ relayState := models.OIDCFlowState{
+ SSOProviderID: ssoProvider.ID,
+ State: state,
+ RedirectTo: params.RedirectTo,
+ FlowStateID: flowStateID,
+ }
+
+ if err := db.Transaction(func(tx *storage.Connection) error {
+ if terr := tx.Create(&relayState); terr != nil {
+ return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err)
+ }
+
+ return nil
+ }); err != nil {
+ return &url.URL{}, err
+ }
+
+ link := provider.AuthCodeURL(state)
+
+ parsedUrl, err := url.Parse(link)
+ if err != nil {
+ return &url.URL{}, internalServerError("Error creating generic auth URL").WithInternalError(err)
+ }
+
+ return parsedUrl, nil
+}
+
+// loadFlowState parses the `state` query parameter as a JWS payload,
+// extracting the provider requested
+func (a *API) loadSSOOIDCFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
+ var state string
+ if r.Method == http.MethodPost {
+ state = r.FormValue("state")
+ } else {
+ state = r.URL.Query().Get("state")
+ }
+
+ if state == "" {
+ return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
+ }
+
+ ctx := r.Context()
+ oauthToken := r.URL.Query().Get("oauth_token")
+ if oauthToken != "" {
+ ctx = withRequestToken(ctx, oauthToken)
+ }
+ oauthVerifier := r.URL.Query().Get("oauth_verifier")
+ if oauthVerifier != "" {
+ ctx = withOAuthVerifier(ctx, oauthVerifier)
+ }
+ return a.loadSSOIDCState(ctx, state)
+}
+
+func (a *API) loadSSOIDCState(ctx context.Context, state string) (context.Context, error) {
+ db := a.db.WithContext(ctx)
+
+ flowState, err := models.FindOIDCFlowStateByID(db, state)
+ if err != nil {
+ return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
+ }
+
+ ctx = withFlowStateID(ctx, flowState.FlowStateID.String())
+
+ ssoProvider, err := models.FindSSOProviderByID(db, flowState.SSOProviderID)
+ if err != nil {
+ return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback provider not found").WithInternalError(err)
+ }
+ config, err := ssoProvider.OIDCProvider.GenericProviderConfig()
+
+ config.RedirectURI = fmt.Sprintf("%s/sso/oidc/callback", a.config.API.ExternalURL)
+
+ ctx = withGenericProviderConfig(ctx, &config)
+ ctx = withExternalProviderType(ctx, "sso/oidc")
+
+ return ctx, err
+}
diff --git a/internal/api/sso_saml.go b/internal/api/sso_saml.go
new file mode 100644
index 0000000000..b27b6e69c8
--- /dev/null
+++ b/internal/api/sso_saml.go
@@ -0,0 +1,55 @@
+package api
+
+import (
+ "net/url"
+
+ "github.com/crewjam/saml"
+ "github.com/gofrs/uuid"
+ "github.com/supabase/auth/internal/models"
+ "github.com/supabase/auth/internal/storage"
+)
+
+func GenerateRedirectWithSAML(a *API, db *storage.Connection, ssoProvider *models.SSOProvider, flowStateID *uuid.UUID, params *SingleSignOnParams) (*url.URL, error) {
+ entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor()
+ if err != nil {
+ return &url.URL{}, internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)
+ }
+
+ serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */)
+
+ authnRequest, err := serviceProvider.MakeAuthenticationRequest(
+ serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding),
+ saml.HTTPRedirectBinding,
+ saml.HTTPPostBinding,
+ )
+ if err != nil {
+ return &url.URL{}, internalServerError("Error creating SAML Authentication Request").WithInternalError(err)
+ }
+
+ // Some IdPs do not support the use of the `persistent` NameID format,
+ // and require a different format to be sent to work.
+ if ssoProvider.SAMLProvider.NameIDFormat != nil {
+ authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat
+ }
+
+ relayState := models.SAMLRelayState{
+ SSOProviderID: ssoProvider.ID,
+ RequestID: authnRequest.ID,
+ RedirectTo: params.RedirectTo,
+ FlowStateID: flowStateID,
+ }
+
+ if err := db.Transaction(func(tx *storage.Connection) error {
+ if terr := tx.Create(&relayState); terr != nil {
+ return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err)
+ }
+
+ return nil
+ }); err != nil {
+ return &url.URL{}, err
+ }
+
+ ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider)
+
+ return ssoRedirectURL, err
+}
diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go
index 20fd8b9c5d..d8c72a820a 100644
--- a/internal/api/ssoadmin.go
+++ b/internal/api/ssoadmin.go
@@ -223,7 +223,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
provider := &models.SSOProvider{
// TODO handle Name, Description, Attribute Mapping
- SAMLProvider: models.SAMLProvider{
+ SAMLProvider: &models.SAMLProvider{
EntityID: metadata.EntityID,
MetadataXML: string(rawMetadata),
},
@@ -390,7 +390,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
}
if updateAttributeMapping || updateSAMLProvider {
- if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil {
+ if terr := tx.Eager().Update(provider.SAMLProvider); terr != nil {
return terr
}
}
diff --git a/internal/api/ssooidcadmin.go b/internal/api/ssooidcadmin.go
new file mode 100644
index 0000000000..1fbca22339
--- /dev/null
+++ b/internal/api/ssooidcadmin.go
@@ -0,0 +1,476 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+ "github.com/gofrs/uuid"
+ "github.com/supabase/auth/internal/conf"
+ "github.com/supabase/auth/internal/models"
+ "github.com/supabase/auth/internal/observability"
+ "github.com/supabase/auth/internal/storage"
+ "github.com/supabase/auth/internal/utilities"
+)
+
+// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider
+// with that ID (or resource ID) and adds it to the context.
+func (a *API) loadOIDCSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) {
+ ctx := r.Context()
+ db := a.db.WithContext(ctx)
+
+ idpParam := chi.URLParam(r, "idp_id")
+
+ idpID, err := uuid.FromString(idpParam)
+ if err != nil {
+ // idpParam is not UUIDv4
+ return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found")
+ }
+
+ // idpParam is a UUIDv4
+ provider, err := models.FindSSOProviderByID(db, idpID)
+ if err != nil {
+ if models.IsNotFoundError(err) {
+ return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found")
+ } else {
+ return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err)
+ }
+ }
+
+ observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String())
+
+ return withSSOProvider(r.Context(), provider), nil
+}
+
+// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does
+// not deal with pagination at this time.
+func (a *API) adminOIDCSSOProvidersList(w http.ResponseWriter, r *http.Request) error {
+ ctx := r.Context()
+ db := a.db.WithContext(ctx)
+
+ providers, err := models.FindAllSAMLProviders(db)
+ if err != nil {
+ return err
+ }
+
+ for i := range providers {
+ // remove metadata XML so that the returned JSON is not ginormous
+ providers[i].SAMLProvider.MetadataXML = ""
+ }
+
+ return sendJSON(w, http.StatusOK, map[string]interface{}{
+ "items": providers,
+ })
+}
+
+type CreateOIDCSSOProviderParams struct {
+ Type string `json:"type"`
+
+ ClientId string `json:"client_id"`
+ Secret string `json:"secret"`
+ AuthURL string `json:"auth_url"`
+ TokenURL string `json:"token_url"`
+ UserinfoURL string `json:"userinfo_url"`
+ // MetadataURL string `json:"metadata_url"`
+ // MetadataXML string `json:"metadata_xml"`
+
+ DiscoveryURL string `json:"discover_url"`
+
+ Domains []string `json:"domains"`
+ AttributeMapping models.UserDataMapping `json:"attribute_mapping"`
+ // NameIDFormat string `json:"name_id_format"`
+}
+
+func (p *CreateOIDCSSOProviderParams) validate(forUpdate bool) error {
+ if !forUpdate && p.Type != "oidc" {
+ return badRequestError(ErrorCodeValidationFailed, "Only 'oidc' supported for SSO provider type")
+ }
+ // } else if p.MetadataURL != "" && p.MetadataXML != "" {
+ // return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set")
+ // } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" {
+ // return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set")
+ // } else if p.MetadataURL != "" {
+ // metadataURL, err := url.ParseRequestURI(p.MetadataURL)
+ // if err != nil {
+ // return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL")
+ // }
+
+ // if metadataURL.Scheme != "https" {
+ // return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL")
+ // }
+ // }
+
+ // switch p.NameIDFormat {
+ // case "",
+ // string(saml.PersistentNameIDFormat),
+ // string(saml.EmailAddressNameIDFormat),
+ // string(saml.TransientNameIDFormat),
+ // string(saml.UnspecifiedNameIDFormat):
+ // // it's valid
+
+ // default:
+ // return badRequestError(ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{
+ // string(saml.PersistentNameIDFormat),
+ // string(saml.EmailAddressNameIDFormat),
+ // string(saml.TransientNameIDFormat),
+ // string(saml.UnspecifiedNameIDFormat),
+ // }, ", "))
+ // }
+
+ return nil
+}
+
+func (p *CreateOIDCSSOProviderParams) metadata(ctx context.Context) (*conf.GenericOAuthProviderConfiguration, error) {
+ var discover *OIDCDiscoveryResponse
+ var err error
+
+ var config *conf.GenericOAuthProviderConfiguration
+
+ if p.DiscoveryURL != "" {
+ discover, err = fetchOIDCMetadata(ctx, p.DiscoveryURL)
+ if err != nil {
+ return nil, err
+ }
+ config = &conf.GenericOAuthProviderConfiguration{
+ OAuthProviderConfiguration: &conf.OAuthProviderConfiguration{
+ ClientID: []string{p.ClientId},
+ Secret: p.Secret,
+ URL: discover.Issuer,
+ ApiURL: discover.UserInfoEndpoint,
+ RedirectURI: "", // TODO: figure out how to get the data
+ },
+ Issuer: discover.Issuer,
+ AuthURL: discover.AuthorizationEndpoint,
+ TokenURL: discover.TokenEndpoint,
+ UserInfoURL: discover.UserInfoEndpoint,
+ UserDataMapping: p.AttributeMapping.Keys,
+ }
+
+ log.Println(p.AttributeMapping)
+ } else if p.DiscoveryURL == "" && true {
+ config = &conf.GenericOAuthProviderConfiguration{
+ OAuthProviderConfiguration: &conf.OAuthProviderConfiguration{},
+ }
+ } else {
+ // impossible situation if you called validate() prior
+ return nil, nil
+ }
+
+ // metadata, err := parseSAMLMetadata(rawMetadata)
+ // if err != nil {
+ // return nil, err
+ // }
+
+ return config, nil
+}
+
+// func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) {
+// if !utf8.Valid(rawMetadata) {
+// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time")
+// }
+
+// metadata, err := samlsp.ParseMetadata(rawMetadata)
+// if err != nil {
+// return nil, err
+// }
+
+// if metadata.EntityID == "" {
+// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID")
+// }
+
+// if len(metadata.IDPSSODescriptors) < 1 {
+// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor")
+// }
+
+// if len(metadata.IDPSSODescriptors) > 1 {
+// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors")
+// }
+
+// return metadata, nil
+// }
+
+type OIDCDiscoveryResponse struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ UserInfoEndpoint string `json:"userinfo_endpoint"`
+ JWKSURI string `json:"jwks_uri"`
+ ScopesSupported []string `json:"scopes_supported"`
+ ResponseTypesSupported []string `json:"response_types_supported"`
+}
+
+func fetchOIDCMetadata(ctx context.Context, issuerURL string) (*OIDCDiscoveryResponse, error) {
+ // Construct the well-known URL
+ discoveryURL := fmt.Sprintf("%s/.well-known/openid-configuration", issuerURL)
+
+ req, err := http.NewRequest(http.MethodGet, discoveryURL, nil)
+ if err != nil {
+ return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err)
+ }
+
+ req = req.WithContext(ctx)
+
+ // req.Header.Set("Accept", "application/xml;charset=UTF-8")
+ req.Header.Set("Accept-Charset", "UTF-8")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ defer utilities.SafeClose(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching OIDC Metadata from URL '%s'", resp.StatusCode, issuerURL)
+ }
+
+ // Decode the JSON response into a struct
+ var config OIDCDiscoveryResponse
+ if err := json.NewDecoder(resp.Body).Decode(&config); err != nil {
+ return nil, err
+ }
+
+ return &config, nil
+}
+
+// adminSSOProvidersCreate creates a new SAML Identity Provider in the system.
+func (a *API) adminOIDCSSOProvidersCreate(w http.ResponseWriter, r *http.Request) error {
+ ctx := r.Context()
+ db := a.db.WithContext(ctx)
+
+ params := &CreateOIDCSSOProviderParams{}
+ if err := retrieveRequestParams(r, params); err != nil {
+ return err
+ }
+
+ if err := params.validate(false /* <- forUpdate */); err != nil {
+ return err
+ }
+
+ log.Println("20")
+ config, err := params.metadata(ctx)
+ if err != nil {
+ return err
+ }
+
+ log.Println("21")
+ existingProvider, err := models.FindOIDCProviderByEntityID(db, params.ClientId, params.AuthURL)
+ if err != nil && !models.IsNotFoundError(err) {
+ return err
+ }
+ log.Println("22")
+ if existingProvider != nil {
+ return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "OIDC Identity Provider with this ClientID (%s) and AuthURL (%s) already exists", params.ClientId, params.AuthURL)
+ }
+ log.Println("23")
+ provider := &models.SSOProvider{
+ // TODO handle Name, Description, Attribute Mapping
+ SAMLProvider: nil,
+ OIDCProvider: &models.OIDCProvider{
+ Issuer: config.Issuer,
+ ClientId: config.ClientID[0],
+ AuthURL: config.AuthURL,
+ TokenURL: config.TokenURL,
+ UserInfoURL: config.UserInfoURL,
+ Secret: config.Secret,
+ RedirectURI: config.RedirectURI,
+ AttributeMapping: models.UserDataMapping{Keys: config.UserDataMapping},
+ },
+ }
+ log.Println("24")
+
+ // if params.MetadataURL != "" {
+ // provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL
+ // }
+
+ // if params.NameIDFormat != "" {
+ // provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat
+ // }
+
+ // provider.SAMLProvider.AttributeMapping = params.AttributeMapping
+
+ for _, domain := range params.Domains {
+ existingProvider, err := models.FindSSOProviderByDomain(db, domain)
+ if err != nil && !models.IsNotFoundError(err) {
+ return err
+ }
+ if existingProvider != nil {
+ return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String())
+ }
+
+ provider.SSODomains = append(provider.SSODomains, models.SSODomain{
+ Domain: domain,
+ })
+ }
+ log.Println("25")
+
+ if err := db.Transaction(func(tx *storage.Connection) error {
+
+ if terr := tx.Eager().Create(provider); terr != nil {
+ return terr
+ }
+
+ return tx.Eager().Load(provider)
+ }); err != nil {
+ return err
+ }
+ log.Println("26")
+
+ return sendJSON(w, http.StatusCreated, provider)
+}
+
+// adminSSOProvidersGet returns an existing SAML Identity Provider in the system.
+func (a *API) adminOIDCSSOProvidersGet(w http.ResponseWriter, r *http.Request) error {
+ provider := getSSOProvider(r.Context())
+
+ return sendJSON(w, http.StatusOK, provider)
+}
+
+// adminSSOProvidersUpdate updates a provider with the provided diff values.
+// func (a *API) adminOIDCSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) error {
+// ctx := r.Context()
+// db := a.db.WithContext(ctx)
+
+// params := &CreateSSOProviderParams{}
+// if err := retrieveRequestParams(r, params); err != nil {
+// return err
+// }
+
+// if err := params.validate(true /* <- forUpdate */); err != nil {
+// return err
+// }
+
+// modified := false
+// updateSAMLProvider := false
+
+// provider := getSSOProvider(ctx)
+
+// if params.MetadataXML != "" || params.MetadataURL != "" {
+// // metadata is being updated
+// rawMetadata, metadata, err := params.metadata(ctx)
+// if err != nil {
+// return err
+// }
+
+// if provider.SAMLProvider.EntityID != metadata.EntityID {
+// return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID)
+// }
+
+// if params.MetadataURL != "" {
+// provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL
+// }
+
+// provider.SAMLProvider.MetadataXML = string(rawMetadata)
+// updateSAMLProvider = true
+// modified = true
+// }
+
+// // domains are being "updated" only when params.Domains is not nil, if
+// // it was nil (but not `[]`) then the caller is expecting not to modify
+// // the domains
+// updateDomains := params.Domains != nil
+
+// var createDomains, deleteDomains []models.SSODomain
+// keepDomains := make(map[string]bool)
+
+// for _, domain := range params.Domains {
+// existingProvider, err := models.FindSSOProviderByDomain(db, domain)
+// if err != nil && !models.IsNotFoundError(err) {
+// return err
+// }
+// if existingProvider != nil {
+// if existingProvider.ID == provider.ID {
+// keepDomains[domain] = true
+// } else {
+// return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String())
+// }
+// } else {
+// modified = true
+// createDomains = append(createDomains, models.SSODomain{
+// Domain: domain,
+// SSOProviderID: provider.ID,
+// })
+// }
+// }
+
+// if updateDomains {
+// for i, domain := range provider.SSODomains {
+// if !keepDomains[domain.Domain] {
+// modified = true
+// deleteDomains = append(deleteDomains, provider.SSODomains[i])
+// }
+// }
+// }
+
+// updateAttributeMapping := false
+// if params.AttributeMapping.Keys != nil {
+// updateAttributeMapping = !provider.SAMLProvider.AttributeMapping.Equal(¶ms.AttributeMapping)
+// if updateAttributeMapping {
+// modified = true
+// provider.SAMLProvider.AttributeMapping = params.AttributeMapping
+// }
+// }
+
+// nameIDFormat := ""
+// if provider.SAMLProvider.NameIDFormat != nil {
+// nameIDFormat = *provider.SAMLProvider.NameIDFormat
+// }
+
+// if params.NameIDFormat != nameIDFormat {
+// modified = true
+
+// if params.NameIDFormat == "" {
+// provider.SAMLProvider.NameIDFormat = nil
+// } else {
+// provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat
+// }
+// }
+
+// if modified {
+// if err := db.Transaction(func(tx *storage.Connection) error {
+// if terr := tx.Eager().Update(provider); terr != nil {
+// return terr
+// }
+
+// if updateDomains {
+// if terr := tx.Destroy(deleteDomains); terr != nil {
+// return terr
+// }
+
+// if terr := tx.Eager().Create(createDomains); terr != nil {
+// return terr
+// }
+// }
+
+// if updateAttributeMapping || updateSAMLProvider {
+// if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil {
+// return terr
+// }
+// }
+
+// return tx.Eager().Load(provider)
+// }); err != nil {
+// return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err)
+// }
+// }
+
+// return sendJSON(w, http.StatusOK, provider)
+// }
+
+// adminSSOProvidersDelete deletes a SAML identity provider.
+func (a *API) adminOIDCSSOProvidersDelete(w http.ResponseWriter, r *http.Request) error {
+ ctx := r.Context()
+ db := a.db.WithContext(ctx)
+
+ provider := getSSOProvider(ctx)
+
+ if err := db.Transaction(func(tx *storage.Connection) error {
+ return tx.Eager().Destroy(provider)
+ }); err != nil {
+ return err
+ }
+
+ return sendJSON(w, http.StatusOK, provider)
+}
diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go
index 21216fedbf..e60fcf947a 100644
--- a/internal/conf/configuration.go
+++ b/internal/conf/configuration.go
@@ -65,6 +65,16 @@ type OAuthProviderConfiguration struct {
SkipNonceCheck bool `json:"skip_nonce_check" split_words:"true"`
}
+// GenericOAuthProviderConfiguration holds all config related to generic OAuth providers.
+type GenericOAuthProviderConfiguration struct {
+ *OAuthProviderConfiguration
+ AuthURL string `json:"auth_url" envconfig:"AUTH_URL"`
+ TokenURL string `json:"token_url" envconfig:"TOKEN_URL"`
+ Issuer string `json:"issuer"`
+ UserInfoURL string `json:"userinfo_url" split_words:"true"`
+ UserDataMapping map[string]string `json:"mapping" split_words:"true"`
+}
+
type AnonymousProviderConfiguration struct {
Enabled bool `json:"enabled" default:"false"`
}
@@ -263,6 +273,7 @@ type GlobalConfiguration struct {
Duration int `json:"duration"`
} `json:"cookies"`
SAML SAMLConfiguration `json:"saml"`
+ OIDC OIDCConfiguration `json:"oidc"`
CORS CORSConfiguration `json:"cors"`
}
diff --git a/internal/conf/oidc.go b/internal/conf/oidc.go
new file mode 100644
index 0000000000..a705e6b682
--- /dev/null
+++ b/internal/conf/oidc.go
@@ -0,0 +1,14 @@
+package conf
+
+// OIDCConfiguration holds configuration for native OIDC SSO support.
+type OIDCConfiguration struct {
+ Enabled bool `json:"enabled"`
+}
+
+func (c *OIDCConfiguration) Validate() error {
+ if c.Enabled {
+
+ }
+
+ return nil
+}
diff --git a/internal/models/factor.go b/internal/models/factor.go
index 7c6f6dd30c..58dd815812 100644
--- a/internal/models/factor.go
+++ b/internal/models/factor.go
@@ -49,6 +49,7 @@ const (
EmailChange
TokenRefresh
Anonymous
+ SSOOIDC
)
func (authMethod AuthenticationMethod) String() string {
@@ -67,6 +68,8 @@ func (authMethod AuthenticationMethod) String() string {
return "invite"
case SSOSAML:
return "sso/saml"
+ case SSOOIDC:
+ return "sso/oidc"
case MagicLink:
return "magiclink"
case EmailSignup:
@@ -102,6 +105,8 @@ func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error)
return Invite, nil
case "sso/saml":
return SSOSAML, nil
+ case "sso/oidc":
+ return SSOOIDC, nil
case "magiclink":
return MagicLink, nil
case "email/signup":
diff --git a/internal/models/sso.go b/internal/models/sso.go
index 28c2429acb..5526f0edd0 100644
--- a/internal/models/sso.go
+++ b/internal/models/sso.go
@@ -12,14 +12,16 @@ import (
"github.com/crewjam/saml/samlsp"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
+ "github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/storage"
)
type SSOProvider struct {
ID uuid.UUID `db:"id" json:"id"`
- SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"`
- SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"`
+ SAMLProvider *SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"`
+ OIDCProvider *OIDCProvider `has_one:"oidc_providers" fk_id:"sso_provider_id" json:"oidc,omitempty"`
+ SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
@@ -127,14 +129,87 @@ type SAMLProvider struct {
UpdatedAt time.Time `db:"updated_at" json:"-"`
}
+type UserDataMapping struct {
+ Keys map[string]string `json:"keys,omitempty"`
+}
+
+func (m *UserDataMapping) Scan(src interface{}) error {
+ b, ok := src.([]byte)
+ if !ok {
+ return errors.New("scan source was not []byte")
+ }
+ err := json.Unmarshal(b, m)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (m UserDataMapping) Value() (driver.Value, error) {
+ b, err := json.Marshal(m)
+ if err != nil {
+ return nil, err
+ }
+ return string(b), nil
+}
+
+type OIDCProvider struct {
+ ID uuid.UUID `db:"id" json:"-"`
+
+ SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"`
+ SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"`
+
+ Issuer string `db:"issuer" json:"issuer"`
+ ClientId string `db:"client_id" json:"client_id"`
+ Secret string `db:"secret" json:"secret"`
+ AuthURL string `db:"auth_url" json:"auth_url"`
+ TokenURL string `db:"token_url" json:"token_url"`
+ UserInfoURL string `db:"userinfo_url" json:"userinfo_url"`
+
+ RedirectURI string `db:"redirect_uri" json:"redirect_uri"`
+
+ AttributeMapping UserDataMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"`
+
+ CreatedAt time.Time `db:"created_at" json:"-"`
+ UpdatedAt time.Time `db:"updated_at" json:"-"`
+}
+
func (p SAMLProvider) TableName() string {
return "saml_providers"
}
+func (p OIDCProvider) TableName() string {
+ return "oidc_providers"
+}
+
func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) {
return samlsp.ParseMetadata([]byte(p.MetadataXML))
}
+func (p OIDCProvider) GenericProviderConfig() (conf.GenericOAuthProviderConfiguration, error) {
+
+ oauthConfig := &conf.OAuthProviderConfiguration{
+ ClientID: []string{p.ClientId},
+ Secret: p.Secret,
+ RedirectURI: "",
+ URL: p.Issuer,
+ ApiURL: p.UserInfoURL,
+ Enabled: true,
+ SkipNonceCheck: true,
+ }
+
+ providerConfig := conf.GenericOAuthProviderConfiguration{
+ OAuthProviderConfiguration: oauthConfig,
+ AuthURL: p.AuthURL,
+ TokenURL: p.TokenURL,
+ Issuer: p.Issuer,
+ UserInfoURL: p.UserInfoURL,
+ UserDataMapping: p.AttributeMapping.Keys,
+ }
+
+ return providerConfig, nil
+}
+
type SSODomain struct {
ID uuid.UUID `db:"id" json:"-"`
@@ -167,10 +242,29 @@ type SAMLRelayState struct {
FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"`
}
+type OIDCFlowState struct {
+ ID uuid.UUID `db:"id"`
+
+ SSOProviderID uuid.UUID `db:"sso_provider_id"`
+
+ State string `db:"state"`
+
+ RedirectTo string `db:"redirect_to"`
+
+ CreatedAt time.Time `db:"created_at" json:"-"`
+ UpdatedAt time.Time `db:"updated_at" json:"-"`
+ FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"`
+ FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"`
+}
+
func (s SAMLRelayState) TableName() string {
return "saml_relay_states"
}
+func (s OIDCFlowState) TableName() string {
+ return "oidc_relay_states"
+}
+
func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) {
var samlProvider SAMLProvider
if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil {
@@ -189,6 +283,24 @@ func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOPr
return &ssoProvider, nil
}
+func FindOIDCProviderByEntityID(tx *storage.Connection, clientId string, authUrl string) (*SSOProvider, error) {
+ var samlProvider OIDCProvider
+ if err := tx.Q().Where("client_id = ?", clientId).Where("auth_url = ?", clientId).First(&samlProvider); err != nil {
+ if errors.Cause(err) == sql.ErrNoRows {
+ return nil, SSOProviderNotFoundError{}
+ }
+
+ return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID")
+ }
+
+ var ssoProvider SSOProvider
+ if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil {
+ return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)")
+ }
+
+ return &ssoProvider, nil
+}
+
func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) {
var ssoProvider SSOProvider
@@ -247,6 +359,20 @@ func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) {
return providers, nil
}
+func FindAllOIDCProviders(tx *storage.Connection) ([]SSOProvider, error) {
+ var providers []SSOProvider
+
+ if err := tx.Eager().All(&providers); err != nil {
+ if errors.Cause(err) == sql.ErrNoRows {
+ return nil, nil
+ }
+
+ return nil, errors.Wrap(err, "error loading all OIDC SSO providers")
+ }
+
+ return providers, nil
+}
+
func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) {
var state SAMLRelayState
@@ -260,3 +386,17 @@ func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelaySta
return &state, nil
}
+
+func FindOIDCFlowStateByID(tx *storage.Connection, stateId string) (*OIDCFlowState, error) {
+ var state OIDCFlowState
+
+ if err := tx.Eager().Q().Where("state = ?", stateId).First(&state); err != nil {
+ if errors.Cause(err) == sql.ErrNoRows {
+ return nil, SAMLRelayStateNotFoundError{}
+ }
+
+ return nil, errors.Wrap(err, "error loading OIDC Flow State")
+ }
+
+ return &state, nil
+}
diff --git a/internal/models/sso_test.go b/internal/models/sso_test.go
index b6c9656309..9993c043a6 100644
--- a/internal/models/sso_test.go
+++ b/internal/models/sso_test.go
@@ -43,7 +43,7 @@ func (ts *SSOTestSuite) TestConstraints() {
examples := []exampleSpec{
{
Provider: &SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "",
MetadataXML: "",
},
@@ -51,7 +51,7 @@ func (ts *SSOTestSuite) TestConstraints() {
},
{
Provider: &SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
@@ -59,7 +59,7 @@ func (ts *SSOTestSuite) TestConstraints() {
},
{
Provider: &SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
@@ -79,7 +79,7 @@ func (ts *SSOTestSuite) TestConstraints() {
func (ts *SSOTestSuite) TestDomainUniqueness() {
require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata1",
MetadataXML: "",
},
@@ -91,7 +91,7 @@ func (ts *SSOTestSuite) TestDomainUniqueness() {
}))
require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata2",
MetadataXML: "",
},
@@ -105,7 +105,7 @@ func (ts *SSOTestSuite) TestDomainUniqueness() {
func (ts *SSOTestSuite) TestEntityIDUniqueness() {
require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
@@ -117,7 +117,7 @@ func (ts *SSOTestSuite) TestEntityIDUniqueness() {
}))
require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
@@ -131,7 +131,7 @@ func (ts *SSOTestSuite) TestEntityIDUniqueness() {
func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() {
provider := &SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
@@ -182,7 +182,7 @@ func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() {
func (ts *SSOTestSuite) TestFindSAMLProviderByEntityID() {
provider := &SSOProvider{
- SAMLProvider: SAMLProvider{
+ SAMLProvider: &SAMLProvider{
EntityID: "https://example.com/saml/metadata",
MetadataXML: "",
},
diff --git a/migrations/20240819081613_add_oidc_sso.up.sql b/migrations/20240819081613_add_oidc_sso.up.sql
new file mode 100644
index 0000000000..ad34b9d9f7
--- /dev/null
+++ b/migrations/20240819081613_add_oidc_sso.up.sql
@@ -0,0 +1,49 @@
+do $$
+begin
+ create table if not exists {{ index .Options "Namespace" }}.oidc_providers (
+ id uuid not null,
+ sso_provider_id uuid not null,
+ issuer text not null,
+ client_id text not null,
+ secret text not null,
+ auth_url text not null,
+ token_url text not null,
+ userinfo_url text not null,
+ redirect_uri text not null,
+ -- metadata_url text null,
+ attribute_mapping jsonb null,
+ created_at timestamptz null,
+ updated_at timestamptz null,
+ primary key (id),
+ foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade
+ -- constraint "metadata_xml not empty" check (char_length(metadata_xml) > 0),
+ -- constraint "metadata_url not empty" check (metadata_url = null or char_length(metadata_url) > 0),
+ -- constraint "entity_id not empty" check (char_length(entity_id) > 0)
+ );
+
+ create index if not exists oidc_providers_sso_provider_id_idx on {{ index .Options "Namespace" }}.oidc_providers (sso_provider_id);
+
+ comment on table {{ index .Options "Namespace" }}.oidc_providers is 'Auth: Manages OIDC Identity Provider connections.';
+
+ create table if not exists {{ index .Options "Namespace" }}.oidc_relay_states (
+ id uuid not null,
+ sso_provider_id uuid not null,
+ state text not null,
+ for_email text null,
+ redirect_to text null,
+ created_at timestamptz null,
+ updated_at timestamptz null,
+ flow_state_id uuid null,
+ primary key (id),
+ foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade,
+ foreign key (flow_state_id) references {{ index .Options "Namespace" }}.flow_state (id) on delete cascade,
+ constraint "state not empty" check(char_length(state) > 0)
+ );
+
+ create index if not exists oidc_relay_states_sso_provider_id_idx on {{ index .Options "Namespace" }}.oidc_relay_states (sso_provider_id);
+ create index if not exists oidc_relay_states_for_email_idx on {{ index .Options "Namespace" }}.oidc_relay_states (for_email);
+
+ comment on table {{ index .Options "Namespace" }}.oidc_relay_states is 'Auth: Contains OIDC Relay State information for each Service Provider initiated login.';
+
+
+end $$;