From 77c80e6defc15ae737927ae3469d03503c303861 Mon Sep 17 00:00:00 2001 From: Patrick Sanders Date: Wed, 17 Mar 2021 16:02:27 -0700 Subject: [PATCH] Improve error handling (#51) --- cache/cache.go | 3 +-- creds/consoleme.go | 10 ++++++++++ creds/refreshable.go | 4 ---- errors/errors.go | 3 +++ metadata/metadata.go | 12 ++++++------ metadata/metadata_test.go | 4 ++-- metadata/types.go | 12 ++++++------ server/credentialsHandler.go | 13 ++++++++++--- server/ecsCredentialsHandler.go | 9 ++++++--- server/server.go | 13 ++++++++++++- server/utils.go | 21 +++++++++++++++++++++ util/util.go | 17 +++++------------ 12 files changed, 82 insertions(+), 39 deletions(-) create mode 100644 server/utils.go diff --git a/cache/cache.go b/cache/cache.go index 42e35a2..743b9d0 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -17,7 +17,6 @@ package cache import ( - "fmt" "strings" "sync" @@ -122,7 +121,7 @@ func (cc *CredentialCache) get(slug string) (*creds.RefreshableProvider, bool) { func (cc *CredentialCache) set(client creds.HTTPClient, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) { c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false) if err != nil { - return nil, fmt.Errorf("could not generate creds: %w", err) + return nil, err } cc.Lock() defer cc.Unlock() diff --git a/creds/consoleme.go b/creds/consoleme.go index 8ebcca5..6db366b 100644 --- a/creds/consoleme.go +++ b/creds/consoleme.go @@ -186,8 +186,18 @@ func parseError(statusCode int, rawErrorResponse []byte) error { } switch errorResponse.Code { + case "899": + return werrors.InvalidArn + case "900": + return werrors.NoMatchingRoles case "901": return werrors.MultipleMatchingRoles + case "902": + return werrors.CredentialRetrievalError + case "903": + return werrors.NoMatchingRoles + case "904": + return werrors.MalformedRequestError case "905": return werrors.MutualTLSCertNeedsRefreshError case "invalid_jwt": diff --git a/creds/refreshable.go b/creds/refreshable.go index dd9eed6..9e9623d 100644 --- a/creds/refreshable.go +++ b/creds/refreshable.go @@ -102,15 +102,11 @@ RetryLoop: rp.client.CloseIdleConnections() time.Sleep(retryDelay) } - case errors.MultipleMatchingRoles: - return err default: - log.Errorf("failed to get refreshed credentials: %s", err.Error()) return err } } if err != nil { - log.Errorf("Unable to retrieve credentials from ConsoleMe: %v", err) return err } diff --git a/errors/errors.go b/errors/errors.go index 7c53bc4..e40aecd 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -26,6 +26,9 @@ const ( CredentialGenerationFailed = Error("credential generation failed") CredentialRetrievalError = Error("failed to retrieve credentials from broker") InvalidJWT = Error("JWT is invalid") + InvalidArn = Error("requested ARN is invalid") MutualTLSCertNeedsRefreshError = Error("mTLS cert needs to be refreshed") MultipleMatchingRoles = Error("more than one matching role for search string") + NoMatchingRoles = Error("no matching roles for search string") + MalformedRequestError = Error("malformed request sent to broker") ) diff --git a/metadata/metadata.go b/metadata/metadata.go index 7e44288..da4dc16 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -41,12 +41,12 @@ func init() { func GetInstanceInfo() *InstanceInfo { currentTime := time.Now() currentInstanceInfo := &InstanceInfo{ - Hostname: hostname(), - Username: username(), - CertAgeSeconds: elapsedSeconds(certCreationTime, currentTime), - CertFingerprint: certFingerprint, - WeepVersion: Version, - WeepMethod: weepMethod, + Hostname: hostname(), + Username: username(), + CertAgeSeconds: elapsedSeconds(certCreationTime, currentTime), + CertFingerprintSHA256: certFingerprint, + WeepVersion: Version, + WeepMethod: weepMethod, } return currentInstanceInfo } diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 66098e7..25f6159 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -24,8 +24,8 @@ func TestGetInstanceInfo(t *testing.T) { if !(result.CertAgeSeconds >= 4 && result.CertAgeSeconds <= 6) { t.Errorf("cert age seconds: expected 4 <= x <= 6, got %d", result.CertAgeSeconds) } - if result.CertFingerprint != certFingerprint { - t.Errorf("cert fingerprint: expected %s, got %s", certFingerprint, result.CertFingerprint) + if result.CertFingerprintSHA256 != certFingerprint { + t.Errorf("cert fingerprint: expected %s, got %s", certFingerprint, result.CertFingerprintSHA256) } if result.WeepVersion != Version { t.Errorf("weep version: expected %s, got %s", Version, result.WeepVersion) diff --git a/metadata/types.go b/metadata/types.go index 5fa110d..719658f 100644 --- a/metadata/types.go +++ b/metadata/types.go @@ -17,10 +17,10 @@ package metadata type InstanceInfo struct { - Hostname string `json:"hostname"` - Username string `json:"username"` - CertAgeSeconds int `json:"cert_age,omitempty"` - CertFingerprint string `json:"cert_fingerprint,omitempty"` - WeepVersion string `json:"weep_version"` - WeepMethod string `json:"weep_method"` + Hostname string `json:"hostname"` + Username string `json:"username"` + CertAgeSeconds int `json:"cert_age_seconds,omitempty"` + CertFingerprintSHA256 string `json:"cert_fingerprint_sha256,omitempty"` + WeepVersion string `json:"weep_version"` + WeepMethod string `json:"weep_method"` } diff --git a/server/credentialsHandler.go b/server/credentialsHandler.go index 6e44609..e49cd2a 100644 --- a/server/credentialsHandler.go +++ b/server/credentialsHandler.go @@ -21,27 +21,34 @@ import ( "fmt" "net/http" + "github.com/netflix/weep/util" + "github.com/netflix/weep/cache" ) func RoleHandler(w http.ResponseWriter, r *http.Request) { defaultRole, err := cache.GlobalCache.GetDefault() if err != nil { - fmt.Fprint(w, "error") + util.WriteError(w, "error", 500) return } - fmt.Fprint(w, defaultRole.Role) + if _, err := w.Write([]byte(defaultRole.Role)); err != nil { + log.Errorf("failed to write response: %v", err) + } } func IMDSHandler(w http.ResponseWriter, r *http.Request) { - c, err := cache.GlobalCache.GetDefault() if err != nil { log.Errorf("could not get credentials from cache: %e", err) + util.WriteError(w, err.Error(), http.StatusBadRequest) + return } credentials, err := c.Retrieve() if err != nil { log.Errorf("could not get credentials: %e", err) + util.WriteError(w, err.Error(), http.StatusBadRequest) + return } credentialResponse := MetaDataCredentialResponse{ diff --git a/server/ecsCredentialsHandler.go b/server/ecsCredentialsHandler.go index 48bbfcc..757b275 100644 --- a/server/ecsCredentialsHandler.go +++ b/server/ecsCredentialsHandler.go @@ -54,9 +54,10 @@ func parseAssumeRoleQuery(r *http.Request) ([]string, error) { func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - var client, err = creds.GetClient("") + var client, err = creds.GetClient(region) if err != nil { log.Error(err) + util.WriteError(w, err.Error(), http.StatusBadRequest) return } assume, err := parseAssumeRoleQuery(r) @@ -68,16 +69,18 @@ func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request vars := mux.Vars(r) requestedRole := vars["role"] - cached, err := cache.GlobalCache.GetOrSet(client, requestedRole, "", assume) + cached, err := cache.GlobalCache.GetOrSet(client, requestedRole, region, assume) if err != nil { // TODO: handle error better and return a helpful response/status - log.Errorf("failed to get credentials: %s", err.Error()) + log.Errorf("failed to get credentials: %s", err) + util.WriteError(w, err.Error(), http.StatusBadRequest) return } cachedCredentials, err := cached.Retrieve() if err != nil { // TODO: handle error better and return a helpful response/status log.Errorf("failed to get credentials: %s", err.Error()) + util.WriteError(w, err.Error(), http.StatusBadRequest) return } diff --git a/server/server.go b/server/server.go index 7a776f4..6418174 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "net" "net/http" "os" + "time" "github.com/gorilla/mux" "github.com/netflix/weep/cache" @@ -48,7 +49,17 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er go func() { log.Info("Starting weep on ", listenAddr) - log.Info(http.ListenAndServe(listenAddr, router)) + srv := &http.Server{ + ReadTimeout: 1 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, + ReadHeaderTimeout: 2 * time.Second, + Addr: listenAddr, + Handler: router, + } + if err := srv.ListenAndServe(); err != nil { + log.Fatalf("server failed: %v", err) + } }() // Check for interrupt signal and exit cleanly diff --git a/server/utils.go b/server/utils.go new file mode 100644 index 0000000..2885ffb --- /dev/null +++ b/server/utils.go @@ -0,0 +1,21 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +type httpError struct { + Message string `json:"message"` + Code string `json:"code"` +} + +func errorResponse(w http.ResponseWriter, message, code string) { + resp := httpError{ + Message: message, + Code: code, + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Errorf("failed to write error response: %v", err) + } +} diff --git a/util/util.go b/util/util.go index b13ab4a..d46f4b0 100644 --- a/util/util.go +++ b/util/util.go @@ -17,7 +17,6 @@ package util import ( - "encoding/json" "fmt" "net/http" "os" @@ -40,7 +39,8 @@ type AwsArn struct { } type ErrorResponse struct { - Error string `json:"error"` + Message string `json:"message"` + Code string `json:"code"` } func validate(arn string, pieces []string) error { @@ -91,21 +91,14 @@ func FileExists(path string) bool { return err == nil } -// WriteError writes a status code and JSON-formatted error to the provided http.ResponseWriter. +// WriteError writes a status code and plaintext error to the provided http.ResponseWriter. +// The error is written as plaintext so AWS SDKs will display it inline with an error message. func WriteError(w http.ResponseWriter, message string, status int) { log.Debugf("writing HTTP error response: %s", message) - resp := ErrorResponse{Error: message} - respBytes, err := json.Marshal(resp) - if err != nil { - log.Errorf("could not marshal error response: %s", err) - w.WriteHeader(http.StatusInternalServerError) - return - } w.WriteHeader(status) - _, err = w.Write(respBytes) + _, err := w.Write([]byte(message)) if err != nil { log.Errorf("could not write error response: %s", err) w.WriteHeader(http.StatusInternalServerError) - return } }