This repository has been archived by the owner on Jan 15, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathhttp.go
294 lines (251 loc) · 9.28 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
package awsconsoleauth
import (
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/crowdmob/goamz/sts"
	"github.com/dgrijalva/jwt-go"
	"github.com/drone/config"
	"golang.org/x/oauth2"
)
var (
	// trustXForwarded makes getOriginURL and getRemoteAddress honour the
	// X-Forwarded-Proto / X-Forwarded-For request headers. Only enable it when
	// a proxy in front of this service sets those headers.
	trustXForwarded = config.Bool("trust-x-forwarded", false)

	// loginTimeout is how long an OAuth state token produced by
	// generateState remains valid (it becomes the token's "exp" claim).
	loginTimeout = config.Duration("google-login-timeout", 2*time.Minute)

	// secret signs and verifies state tokens. We reuse the Google client
	// secret as the web secret.
	secret = googleClientSecret
)
// getOriginURL returns the HTTP origin string for r (e.g.
// https://alice.example.com, or http://localhost:8000).
//
// The scheme is "https" when the request arrived over TLS and "http"
// otherwise. When trust-x-forwarded is enabled, a valid X-Forwarded-Proto
// header ("http" or "https", case-insensitive; for a comma-separated list
// the first entry is used) overrides that. A missing or unrecognised header
// is ignored rather than producing an empty or attacker-chosen scheme in the
// OAuth redirect URL.
func getOriginURL(r *http.Request) string {
	scheme := "https"
	if r.TLS == nil {
		scheme = "http"
	}
	if *trustXForwarded {
		proto := r.Header.Get("X-Forwarded-Proto")
		if i := strings.IndexByte(proto, ','); i >= 0 {
			proto = proto[:i]
		}
		proto = strings.ToLower(strings.TrimSpace(proto))
		if proto == "http" || proto == "https" {
			scheme = proto
		}
	}
	return fmt.Sprintf("%s://%s", scheme, r.Host)
}
// getRemoteAddress returns the client address for r, without any port.
//
// When trust-x-forwarded is enabled, the last entry of X-Forwarded-For is
// used. That is the hop appended by the proxy in front of us; earlier entries
// are client-controlled. If the header is absent or its last entry is empty,
// the connection's address is used instead.
//
// The port is stripped because this value is bound into the OAuth state token
// and compared again on /oauth2callback. The callback usually arrives on a new
// TCP connection with a different source port, so keeping the port would make
// state validation fail spuriously.
func getRemoteAddress(r *http.Request) string {
	if *trustXForwarded {
		forwardedFor := r.Header.Get("X-Forwarded-For")
		last := strings.TrimSpace(forwardedFor[strings.LastIndexByte(forwardedFor, ',')+1:])
		if last != "" {
			return last
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// Not in host:port form (e.g. a unix socket); use it verbatim.
		return r.RemoteAddr
	}
	return host
}
// newTokenFromRefreshToken turns a previously issued OAuth refresh token into
// a fresh *oauth2.Token by asking a TokenSource built from a copy of
// googleOauthConfig.
func newTokenFromRefreshToken(refreshToken string) (*oauth2.Token, error) {
	// The seed has no access token and an expiry of "now", so the
	// TokenSource is expected to perform a refresh when Token() is called.
	seed := &oauth2.Token{
		RefreshToken: refreshToken,
		Expiry:       time.Now(),
	}
	cfg := *googleOauthConfig
	return cfg.TokenSource(oauth2.NoContext, seed).Token()
}
// GetRoot handles requests for '/'.
//
// If a refresh_token parameter is supplied and can be exchanged for a new
// token, credentials are issued immediately. This avoids the interactive
// flow and allows automation. Otherwise the client is redirected to the
// Google OAuth URL. The request's query string travels through the OAuth
// flow inside the signed state token (see generateState) and is recovered by
// /oauth2callback.
func GetRoot(w http.ResponseWriter, r *http.Request) {
	if refreshToken := r.FormValue("refresh_token"); refreshToken != "" {
		token, err := newTokenFromRefreshToken(refreshToken)
		if err == nil {
			FetchCredentialsForToken(w, r, token, r.URL.RawQuery)
			return
		}
		// An expired or revoked refresh token is not fatal. Fall back to the
		// interactive flow below, but record why the shortcut failed.
		log.Printf("refresh_token exchange failed, falling back to OAuth flow: %s", err)
	}

	state, err := generateState(r)
	if err != nil {
		log.Printf("ERROR: cannot generate state: %s", err)
		http.Error(w, "Authentication failed", http.StatusInternalServerError)
		return
	}

	oauthConfig := *googleOauthConfig
	oauthConfig.RedirectURL = fmt.Sprintf("%s/oauth2callback", getOriginURL(r))
	// Offline access asks Google for a refresh token as well.
	authURL := oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
	http.Redirect(w, r, authURL, http.StatusFound)
}
// generateState returns an HS256-signed JWT used as the OAuth `state`
// parameter to mitigate CSRF. The token binds the login attempt to the
// caller's User-Agent and remote address and expires after *loginTimeout.
// It also carries the original query string, so that GetCallback can restore
// it after the round trip through Google. valididateState performs the
// matching checks on the callback.
func generateState(r *http.Request) (string, error) {
	t := jwt.New(jwt.GetSigningMethod("HS256"))
	expiry := time.Now().Add(*loginTimeout)
	t.Claims["exp"] = expiry.Unix()
	t.Claims["query"] = r.URL.RawQuery
	t.Claims["ra"] = getRemoteAddress(r)
	t.Claims["ua"] = r.Header.Get("User-Agent")
	return t.SignedString([]byte(*secret))
}
// valididateState checks that state is a token produced by generateState for
// the client making request r. It must be unexpired, HS256-signed with our
// secret, and issued to the same remote address and User-Agent. On success
// it returns the original query string carried in the token. Otherwise it
// returns a non-nil error.
//
// Do not show the returned error to the user; it may contain
// security-sensitive information. (The misspelled name is kept because
// callers depend on it.)
func valididateState(r *http.Request, state string) (string, error) {
	token, err := jwt.Parse(state, func(t *jwt.Token) (interface{}, error) {
		// Only accept the algorithm generateState signs with, so the HMAC
		// secret is never handed to a different verifier.
		if t.Method == nil || t.Method.Alg() != "HS256" {
			return nil, fmt.Errorf("unexpected signing method %v", t.Header["alg"])
		}
		return []byte(*secret), nil
	})
	if err != nil {
		return "", err
	}
	if !token.Valid {
		// TODO(ross): this branch is never executed, I think because if the
		// token is invalid `err` is never nil from Parse()
		return "", fmt.Errorf("Invalid Token")
	}

	// Use comma-ok assertions so a malformed claim is rejected instead of
	// panicking the handler.
	remoteAddr := getRemoteAddress(r)
	ra, _ := token.Claims["ra"].(string)
	if ra != remoteAddr {
		return "", fmt.Errorf("Wrong remote address. Expected %#v, got %#v",
			ra, remoteAddr)
	}
	userAgent := r.Header.Get("User-Agent")
	ua, _ := token.Claims["ua"].(string)
	if ua != userAgent {
		return "", fmt.Errorf("Wrong user agent. Expected %#v, got %#v",
			ua, userAgent)
	}
	query, ok := token.Claims["query"].(string)
	if !ok {
		return "", fmt.Errorf("state token has no query claim")
	}
	return query, nil
}
// GetCallback handles requests for '/oauth2callback' by validating the oauth
// response, determining the user's group membership, determining the user's
// AWS policy, fetching credentials and (optionally) redirecting to the console.
//
// The returned document is controlled by the `view` argument passed to the
// root URL.
//
// - if view=sh is specified, then a bash-compatible script is returned with the
// credentials
// - if view=csh is specified, then a csh-compatible script is returned with
// the credentials
// - if view=fish is specified, then a fish-compatible script is returned with
// the credentials
// - otherwise we redirect to the AWS Console. If `uri` is specified it is
// appended to the end of the aws console url.
//
// For example:
//
// - https://aws.example.com/?view=sh -> returns a bash script
//
// - https://aws.example.com/?uri=/s3/home -> redirects to https://console.aws.amazon.com/s3/home
//
func GetCallback(w http.ResponseWriter, r *http.Request) {
oauthConfig := *googleOauthConfig
oauthConfig.RedirectURL = fmt.Sprintf("%s/oauth2callback", getOriginURL(r))
token, err := oauthConfig.Exchange(oauth2.NoContext, r.FormValue("code"))
if err != nil {
log.Printf("oauth exchange failed: %s", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
rawQuery, err := valididateState(r, r.FormValue("state"))
if err != nil {
log.Printf("ERROR: state: %s", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
FetchCredentialsForToken(w, r, token, rawQuery)
}
// FetchCredentialsForToken obtains AWS credentials for the Google user
// identified by token and passes them to RespondWithCredentials.
//
// It first identifies the user from the token's id_token and looks up their
// Google group memberships. It then maps user and groups to an AWS policy
// and requests 12-hour credentials for that policy.
//
// rawQuery is the query string of the original request to '/'. It controls
// how the credentials are presented (view / action / uri). Every failure is
// logged and answered with 400 Bad Request.
func FetchCredentialsForToken(w http.ResponseWriter, r *http.Request,
	token *oauth2.Token, rawQuery string) {
	// Parse the presentation options first, so a malformed query never
	// causes credentials to be issued and then thrown away.
	query, err := url.ParseQuery(rawQuery)
	if err != nil {
		log.Printf("ERROR: parse query: %s", err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	// Use the comma-ok form so a token response without an id_token is
	// reported instead of panicking the handler.
	idToken, ok := token.Extra("id_token").(string)
	if !ok || idToken == "" {
		log.Printf("oauth token response did not include an id_token")
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	user, err := GetUserFromGoogleOauthToken(idToken)
	if err != nil {
		log.Printf("failed to parse google id_token: %s", err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	groups, err := GetUserGroups(user)
	if err != nil {
		log.Printf("failed to fetch google group membership for %s: %s", user, err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	policy, err := MapUserAndGroupsToPolicy(user, groups)
	if err != nil {
		log.Printf("failed to determine policy for %s: %s", user, err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	if policy == nil {
		log.Printf("no matching policy for %s", user)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	credentials, err := GetCredentials(user, policy.Policy, 12*time.Hour)
	if err != nil {
		log.Printf("failed to get credentials for %s: %s", user, err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	// Audit line for every successful login.
	fmt.Printf("login %s from %s with policy %s key %s\n", user, getRemoteAddress(r),
		policy.Name, credentials.AccessKeyId)
	RespondWithCredentials(w, r, credentials, query, token)
}
// formatEnvVariableForShell renders one line that sets environment variable
// name to value in the given shell's syntax:
//
//	csh  -> setenv NAME "value"
//	fish -> set -x NAME "value"
//	else -> export NAME="value"   (POSIX sh / bash)
//
// value is inserted verbatim between double quotes, without escaping.
func formatEnvVariableForShell(name, value, shell string) string {
	layout := "export %s=\"%s\"\n"
	if shell == "csh" {
		layout = "setenv %s \"%s\"\n"
	} else if shell == "fish" {
		layout = "set -x %s \"%s\"\n"
	}
	return fmt.Sprintf(layout, name, value)
}
// RespondWithCredentials responds to the OAuth callback (or refresh-token
// request) with credentials, formatted according to query:
//
//   - view=sh, view=csh or view=fish (action=key implies view=sh) returns a
//     plain-text script that sets AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
//     AWS_SESSION_TOKEN, plus OAUTH_REFRESH_TOKEN when token carries a
//     refresh token;
//   - otherwise the client is redirected to the console URL returned by
//     GetAWSConsoleURL for credentials and query's `uri` value.
func RespondWithCredentials(w http.ResponseWriter, r *http.Request,
	credentials *sts.Credentials, query url.Values, token *oauth2.Token) {
	// Both response forms are derived from live credentials, so forbid
	// browsers and intermediaries from caching them.
	w.Header().Set("Cache-Control", "no-store")

	view := query.Get("view")
	if query.Get("action") == "key" {
		view = "sh"
	}

	switch view {
	case "sh", "csh", "fish":
		// "text-plain" is not a valid MIME type; use the real one and stop
		// browsers from sniffing the body as something else.
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.Header().Set("X-Content-Type-Options", "nosniff")
		fmt.Fprintf(w, "# expires %s\n", credentials.Expiration)
		fmt.Fprint(w, formatEnvVariableForShell("AWS_ACCESS_KEY_ID", credentials.AccessKeyId, view))
		fmt.Fprint(w, formatEnvVariableForShell("AWS_SECRET_ACCESS_KEY", credentials.SecretAccessKey, view))
		fmt.Fprint(w, formatEnvVariableForShell("AWS_SESSION_TOKEN", credentials.SessionToken, view))
		if token.RefreshToken != "" {
			fmt.Fprint(w, "# add this as a refresh_token param in the future to avoid the OAuth flow\n")
			fmt.Fprint(w, formatEnvVariableForShell("OAUTH_REFRESH_TOKEN", token.RefreshToken, view))
		}
		return
	}

	redirectURL, err := GetAWSConsoleURL(credentials, query.Get("uri"))
	if err != nil {
		log.Printf("ERROR: %s", err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	http.Redirect(w, r, redirectURL, http.StatusFound)
}
// Initialize runs the Google group, Google login and AWS setup steps, in that
// order. It then registers the authorization service's handlers ('/',
// '/oauth2callback' and a 404 for '/favicon.ico') on http.DefaultServeMux.
// The first failing step aborts initialization, and its error is returned
// prefixed with the step's name.
func Initialize() error {
	steps := []struct {
		name string
		run  func() error
	}{
		{"InitializeGoogleGroup", InitializeGoogleGroup},
		{"InitializeGoogleLogin", InitializeGoogleLogin},
		{"InitializeAWS", InitializeAWS},
	}
	for _, step := range steps {
		if err := step.run(); err != nil {
			return fmt.Errorf("%s: %s", step.name, err)
		}
	}

	// TODO(ross): Strict-Transport-Security: max-age=31536000; includeSubDomains; preload
	http.HandleFunc("/", GetRoot)
	http.HandleFunc("/oauth2callback", GetCallback)
	http.HandleFunc("/favicon.ico", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "Not Found", http.StatusNotFound)
	})
	return nil
}