Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtlynch committed Feb 15, 2025
1 parent 6e58239 commit b230612
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 11 deletions.
35 changes: 27 additions & 8 deletions handlers/auth/shared_secret/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@ package http
import (
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"time"

"github.com/mtlynch/picoshare/v2/handlers/auth/shared_secret/kdf"
)

const authCookieName = "sharedSecret"

var (
// ErrInvalidCredentials indicates that the provided credentials are incorrect.
ErrInvalidCredentials = errors.New("incorrect shared secret")

// ErrEmptyCredentials indicates that no credentials were provided.
ErrEmptyCredentials = errors.New("invalid shared secret")

// ErrMalformedRequest indicates that the request body is malformed.
ErrMalformedRequest = errors.New("malformed request")
)

// Authenticator handles HTTP authentication using shared secrets.
type Authenticator struct {
kdf kdf.KDF
Expand All @@ -35,12 +46,17 @@ func New(sharedSecretKey string) (*Authenticator, error) {
func (a *Authenticator) StartSession(w http.ResponseWriter, r *http.Request) {
secret, err := a.sharedSecretFromRequest(r)
if err != nil {
http.Error(w, "Invalid shared secret", http.StatusBadRequest)
switch err {
case ErrMalformedRequest, ErrEmptyCredentials:
http.Error(w, err.Error(), http.StatusBadRequest)
default:
http.Error(w, ErrInvalidCredentials.Error(), http.StatusUnauthorized)
}
return
}

if !a.kdf.Compare(secret, a.secret) {
http.Error(w, "Incorrect shared secret", http.StatusUnauthorized)
http.Error(w, ErrInvalidCredentials.Error(), http.StatusUnauthorized)
return
}

Expand Down Expand Up @@ -70,7 +86,7 @@ func (a *Authenticator) ClearSession(w http.ResponseWriter) {
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Expires: time.Unix(0, 0),
MaxAge: -1,
})
}

Expand All @@ -79,9 +95,12 @@ func (a *Authenticator) sharedSecretFromRequest(r *http.Request) ([]byte, error)
SharedSecretKey string `json:"sharedSecretKey"`
}{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&body)
if err != nil {
return nil, err
if err := decoder.Decode(&body); err != nil {
return nil, ErrMalformedRequest
}

if body.SharedSecretKey == "" {
return nil, ErrEmptyCredentials
}

return a.kdf.DeriveFromKey([]byte(body.SharedSecretKey))
Expand All @@ -94,6 +113,6 @@ func (a *Authenticator) createCookie(w http.ResponseWriter) {
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Expires: time.Now().Add(time.Hour * 24 * 30),
MaxAge: 30 * 24 * 60 * 60, // 30 days in seconds
})
}
217 changes: 217 additions & 0 deletions handlers/auth/shared_secret/http/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package http_test

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

httpAuth "github.com/mtlynch/picoshare/v2/handlers/auth/shared_secret/http"
)

func TestStartSession(t *testing.T) {
for _, tt := range []struct {
description string
secretKey string
requestBody string
err error
}{
{
description: "accept valid credentials",
secretKey: "mysecret",
requestBody: `{"sharedSecretKey": "mysecret"}`,
err: nil,
},
{
description: "reject invalid credentials",
secretKey: "mysecret",
requestBody: `{"sharedSecretKey": "wrongsecret"}`,
err: httpAuth.ErrInvalidCredentials,
},
{
description: "reject empty credentials",
secretKey: "mysecret",
requestBody: `{"sharedSecretKey": ""}`,
err: httpAuth.ErrEmptyCredentials,
},
{
description: "reject malformed JSON",
secretKey: "mysecret",
requestBody: `{malformed`,
err: httpAuth.ErrMalformedRequest,
},
} {
t.Run(fmt.Sprintf("%s [%s]", tt.description, tt.requestBody), func(t *testing.T) {
auth, err := httpAuth.New(tt.secretKey)
if err != nil {
t.Fatalf("failed to create authenticator: %v", err)
}

req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString(tt.requestBody))
w := httptest.NewRecorder()

auth.StartSession(w, req)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)
if got, want := getError(resp.StatusCode, strings.TrimSpace(string(body))), tt.err; got != want {
t.Fatalf("err=%v, want=%v", got, want)
}

if tt.err == nil {
cookie := getCookie(t, resp)
if got, want := cookie.Name, "sharedSecret"; got != want {
t.Errorf("cookie name=%v, want=%v", got, want)
}
if !cookie.HttpOnly {
t.Error("cookie is not HTTP-only")
}
if got, want := cookie.MaxAge, 30*24*60*60; got != want {
t.Errorf("cookie MaxAge=%v, want=%v", got, want)
}
}
})
}
}

func TestAuthenticate(t *testing.T) {
for _, tt := range []struct {
description string
secretKey string
cookieVal string
want bool
}{
{
description: "accept valid cookie",
secretKey: "mysecret",
cookieVal: createValidCookie(t, "mysecret"),
want: true,
},
{
description: "reject invalid cookie",
secretKey: "mysecret",
cookieVal: createValidCookie(t, "wrongsecret"),
want: false,
},
{
description: "reject empty cookie",
secretKey: "mysecret",
cookieVal: "",
want: false,
},
{
description: "reject malformed base64 cookie",
secretKey: "mysecret",
cookieVal: "not-base64!",
want: false,
},
} {
t.Run(fmt.Sprintf("%s [%s]", tt.description, tt.cookieVal), func(t *testing.T) {
auth, err := httpAuth.New(tt.secretKey)
if err != nil {
t.Fatalf("failed to create authenticator: %v", err)
}

req := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.cookieVal != "" {
req.AddCookie(&http.Cookie{
Name: "sharedSecret",
Value: tt.cookieVal,
})
}

if got, want := auth.Authenticate(req), tt.want; got != want {
t.Errorf("got=%v, want=%v", got, want)
}
})
}
}

func TestClearSession(t *testing.T) {
auth, err := httpAuth.New("mysecret")
if err != nil {
t.Fatalf("failed to create authenticator: %v", err)
}

w := httptest.NewRecorder()
auth.ClearSession(w)

resp := w.Result()
cookie := getCookie(t, resp)

if got, want := cookie.Name, "sharedSecret"; got != want {
t.Errorf("cookie name=%v, want=%v", got, want)
}
if got, want := cookie.Value, ""; got != want {
t.Errorf("cookie value=%v, want=%v", got, want)
}
if !cookie.HttpOnly {
t.Error("cookie is not HTTP-only")
}
if got, want := cookie.MaxAge, -1; got != want {
t.Errorf("cookie MaxAge=%v, want=%v", got, want)
}
}

// Helper function to get error from status code and response body
func getError(statusCode int, body string) error {
switch statusCode {
case http.StatusOK:
return nil
case http.StatusUnauthorized:
return httpAuth.ErrInvalidCredentials
case http.StatusBadRequest:
switch body {
case httpAuth.ErrEmptyCredentials.Error():
return httpAuth.ErrEmptyCredentials
default:
return httpAuth.ErrMalformedRequest
}
default:
return fmt.Errorf("unexpected status code: %d", statusCode)
}
}

// Helper function to get cookie from response
func getCookie(t *testing.T, resp *http.Response) *http.Cookie {
t.Helper()
cookies := resp.Cookies()
if len(cookies) != 1 {
t.Fatalf("got %d cookies, want 1", len(cookies))
}
return cookies[0]
}

// Helper function to create a valid cookie value for testing
func createValidCookie(t *testing.T, secret string) string {
t.Helper()
auth, err := httpAuth.New(secret)
if err != nil {
t.Fatalf("failed to create authenticator: %v", err)
}

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/auth", createJSONBody(t, secret))
auth.StartSession(w, req)

return getCookie(t, w.Result()).Value
}

// Helper function to create a JSON request body
func createJSONBody(t *testing.T, secret string) *bytes.Buffer {
t.Helper()
body := struct {
SharedSecretKey string `json:"sharedSecretKey"`
}{
SharedSecretKey: secret,
}
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(body); err != nil {
t.Fatalf("failed to encode JSON: %v", err)
}
return &buf
}
14 changes: 11 additions & 3 deletions handlers/auth/shared_secret/kdf/kdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ import (
"golang.org/x/crypto/pbkdf2"
)

var (
// ErrInvalidKey indicates that the provided key is empty or invalid.
ErrInvalidKey = errors.New("invalid shared secret key")

// ErrInvalidBase64 indicates that the provided base64 string is empty or malformed.
ErrInvalidBase64 = errors.New("invalid shared secret")
)

// KDF defines the interface for key derivation operations.
type KDF interface {
DeriveFromKey(key []byte) ([]byte, error)
Expand Down Expand Up @@ -37,7 +45,7 @@ func New() KDF {
// DeriveFromKey derives a key using PBKDF2.
func (k *pbkdf2KDF) DeriveFromKey(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, errors.New("invalid shared secret key")
return nil, ErrInvalidKey
}

dk := pbkdf2.Key(key, k.salt, k.iter, k.keyLength, sha256.New)
Expand All @@ -47,12 +55,12 @@ func (k *pbkdf2KDF) DeriveFromKey(key []byte) ([]byte, error) {
// FromBase64 decodes a base64-encoded key.
func (k *pbkdf2KDF) FromBase64(b64encoded string) ([]byte, error) {
if len(b64encoded) == 0 {
return nil, errors.New("invalid shared secret")
return nil, ErrInvalidBase64
}

decoded, err := base64.StdEncoding.DecodeString(b64encoded)
if err != nil {
return nil, err
return nil, ErrInvalidBase64
}

return decoded, nil
Expand Down
Loading

0 comments on commit b230612

Please sign in to comment.