Skip to content

Commit

Permalink
Improve error handling (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Mar 17, 2021
1 parent 75b14fc commit 77c80e6
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 39 deletions.
3 changes: 1 addition & 2 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package cache

import (
"fmt"
"strings"
"sync"

Expand Down Expand Up @@ -122,7 +121,7 @@ func (cc *CredentialCache) get(slug string) (*creds.RefreshableProvider, bool) {
func (cc *CredentialCache) set(client creds.HTTPClient, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false)
if err != nil {
return nil, fmt.Errorf("could not generate creds: %w", err)
return nil, err
}
cc.Lock()
defer cc.Unlock()
Expand Down
10 changes: 10 additions & 0 deletions creds/consoleme.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,18 @@ func parseError(statusCode int, rawErrorResponse []byte) error {
}

switch errorResponse.Code {
case "899":
return werrors.InvalidArn
case "900":
return werrors.NoMatchingRoles
case "901":
return werrors.MultipleMatchingRoles
case "902":
return werrors.CredentialRetrievalError
case "903":
return werrors.NoMatchingRoles
case "904":
return werrors.MalformedRequestError
case "905":
return werrors.MutualTLSCertNeedsRefreshError
case "invalid_jwt":
Expand Down
4 changes: 0 additions & 4 deletions creds/refreshable.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,11 @@ RetryLoop:
rp.client.CloseIdleConnections()
time.Sleep(retryDelay)
}
case errors.MultipleMatchingRoles:
return err
default:
log.Errorf("failed to get refreshed credentials: %s", err.Error())
return err
}
}
if err != nil {
log.Errorf("Unable to retrieve credentials from ConsoleMe: %v", err)
return err
}

Expand Down
3 changes: 3 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ const (
CredentialGenerationFailed = Error("credential generation failed")
CredentialRetrievalError = Error("failed to retrieve credentials from broker")
InvalidJWT = Error("JWT is invalid")
InvalidArn = Error("requested ARN is invalid")
MutualTLSCertNeedsRefreshError = Error("mTLS cert needs to be refreshed")
MultipleMatchingRoles = Error("more than one matching role for search string")
NoMatchingRoles = Error("no matching roles for search string")
MalformedRequestError = Error("malformed request sent to broker")
)
12 changes: 6 additions & 6 deletions metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func init() {
func GetInstanceInfo() *InstanceInfo {
currentTime := time.Now()
currentInstanceInfo := &InstanceInfo{
Hostname: hostname(),
Username: username(),
CertAgeSeconds: elapsedSeconds(certCreationTime, currentTime),
CertFingerprint: certFingerprint,
WeepVersion: Version,
WeepMethod: weepMethod,
Hostname: hostname(),
Username: username(),
CertAgeSeconds: elapsedSeconds(certCreationTime, currentTime),
CertFingerprintSHA256: certFingerprint,
WeepVersion: Version,
WeepMethod: weepMethod,
}
return currentInstanceInfo
}
Expand Down
4 changes: 2 additions & 2 deletions metadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func TestGetInstanceInfo(t *testing.T) {
if !(result.CertAgeSeconds >= 4 && result.CertAgeSeconds <= 6) {
t.Errorf("cert age seconds: expected 4 <= x <= 6, got %d", result.CertAgeSeconds)
}
if result.CertFingerprint != certFingerprint {
t.Errorf("cert fingerprint: expected %s, got %s", certFingerprint, result.CertFingerprint)
if result.CertFingerprintSHA256 != certFingerprint {
t.Errorf("cert fingerprint: expected %s, got %s", certFingerprint, result.CertFingerprintSHA256)
}
if result.WeepVersion != Version {
t.Errorf("weep version: expected %s, got %s", Version, result.WeepVersion)
Expand Down
12 changes: 6 additions & 6 deletions metadata/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package metadata

type InstanceInfo struct {
Hostname string `json:"hostname"`
Username string `json:"username"`
CertAgeSeconds int `json:"cert_age,omitempty"`
CertFingerprint string `json:"cert_fingerprint,omitempty"`
WeepVersion string `json:"weep_version"`
WeepMethod string `json:"weep_method"`
Hostname string `json:"hostname"`
Username string `json:"username"`
CertAgeSeconds int `json:"cert_age_seconds,omitempty"`
CertFingerprintSHA256 string `json:"cert_fingerprint_sha256,omitempty"`
WeepVersion string `json:"weep_version"`
WeepMethod string `json:"weep_method"`
}
13 changes: 10 additions & 3 deletions server/credentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,34 @@ import (
"fmt"
"net/http"

"github.com/netflix/weep/util"

"github.com/netflix/weep/cache"
)

func RoleHandler(w http.ResponseWriter, r *http.Request) {
defaultRole, err := cache.GlobalCache.GetDefault()
if err != nil {
fmt.Fprint(w, "error")
util.WriteError(w, "error", 500)
return
}
fmt.Fprint(w, defaultRole.Role)
if _, err := w.Write([]byte(defaultRole.Role)); err != nil {
log.Errorf("failed to write response: %v", err)
}
}

func IMDSHandler(w http.ResponseWriter, r *http.Request) {

c, err := cache.GlobalCache.GetDefault()
if err != nil {
log.Errorf("could not get credentials from cache: %e", err)
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}
credentials, err := c.Retrieve()
if err != nil {
log.Errorf("could not get credentials: %e", err)
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}

credentialResponse := MetaDataCredentialResponse{
Expand Down
9 changes: 6 additions & 3 deletions server/ecsCredentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ func parseAssumeRoleQuery(r *http.Request) ([]string, error) {

func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var client, err = creds.GetClient("")
var client, err = creds.GetClient(region)
if err != nil {
log.Error(err)
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}
assume, err := parseAssumeRoleQuery(r)
Expand All @@ -68,16 +69,18 @@ func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request
vars := mux.Vars(r)
requestedRole := vars["role"]

cached, err := cache.GlobalCache.GetOrSet(client, requestedRole, "", assume)
cached, err := cache.GlobalCache.GetOrSet(client, requestedRole, region, assume)
if err != nil {
// TODO: handle error better and return a helpful response/status
log.Errorf("failed to get credentials: %s", err.Error())
log.Errorf("failed to get credentials: %s", err)
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}
cachedCredentials, err := cached.Retrieve()
if err != nil {
// TODO: handle error better and return a helpful response/status
log.Errorf("failed to get credentials: %s", err.Error())
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}

Expand Down
13 changes: 12 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"net/http"
"os"
"time"

"github.com/gorilla/mux"
"github.com/netflix/weep/cache"
Expand Down Expand Up @@ -48,7 +49,17 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er

go func() {
log.Info("Starting weep on ", listenAddr)
log.Info(http.ListenAndServe(listenAddr, router))
srv := &http.Server{
ReadTimeout: 1 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 2 * time.Second,
Addr: listenAddr,
Handler: router,
}
if err := srv.ListenAndServe(); err != nil {
log.Fatalf("server failed: %v", err)
}
}()

// Check for interrupt signal and exit cleanly
Expand Down
21 changes: 21 additions & 0 deletions server/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package server

import (
"encoding/json"
"net/http"
)

type httpError struct {
Message string `json:"message"`
Code string `json:"code"`
}

func errorResponse(w http.ResponseWriter, message, code string) {
resp := httpError{
Message: message,
Code: code,
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Errorf("failed to write error response: %v", err)
}
}
17 changes: 5 additions & 12 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package util

import (
"encoding/json"
"fmt"
"net/http"
"os"
Expand All @@ -40,7 +39,8 @@ type AwsArn struct {
}

type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Code string `json:"code"`
}

func validate(arn string, pieces []string) error {
Expand Down Expand Up @@ -91,21 +91,14 @@ func FileExists(path string) bool {
return err == nil
}

// WriteError writes a status code and JSON-formatted error to the provided http.ResponseWriter.
// WriteError writes a status code and plaintext error to the provided http.ResponseWriter.
// The error is written as plaintext so AWS SDKs will display it inline with an error message.
func WriteError(w http.ResponseWriter, message string, status int) {
log.Debugf("writing HTTP error response: %s", message)
resp := ErrorResponse{Error: message}
respBytes, err := json.Marshal(resp)
if err != nil {
log.Errorf("could not marshal error response: %s", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(status)
_, err = w.Write(respBytes)
_, err := w.Write([]byte(message))
if err != nil {
log.Errorf("could not write error response: %s", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

0 comments on commit 77c80e6

Please sign in to comment.