142 lines
3.7 KiB
Go
142 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
// stubVerifier is a test double for the token verifier consumed by NewAuth.
// It remembers the last token it was asked about and answers with either a
// fixed user ID or a fixed error.
type stubVerifier struct {
	userID uint   // ID returned when err is nil
	err    error  // when non-nil, returned instead of userID
	token  string // last token passed to VerifyToken
}

// VerifyToken records token and then reports the configured outcome: the
// configured error with a zero ID if err is set, otherwise userID and nil.
func (sv *stubVerifier) VerifyToken(token string) (uint, error) {
	sv.token = token
	if failure := sv.err; failure != nil {
		return 0, failure
	}
	return sv.userID, nil
}
|
|
|
|
func TestNewAuthWithoutAuthorization(t *testing.T) {
|
|
verifier := &stubVerifier{userID: 42}
|
|
called := false
|
|
|
|
middleware := NewAuth(verifier)
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
if id := GetUserIDFromContext(r.Context()); id != 0 {
|
|
t.Fatalf("unexpected user id %d", id)
|
|
}
|
|
}))
|
|
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if called {
|
|
t.Fatal("expected next handler NOT to be called when no authorization header")
|
|
}
|
|
|
|
if recorder.Result().StatusCode != http.StatusUnauthorized {
|
|
t.Fatalf("expected status 401, got %d", recorder.Result().StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestNewAuthValidToken(t *testing.T) {
|
|
verifier := &stubVerifier{userID: 99}
|
|
middleware := NewAuth(verifier)
|
|
|
|
handlerCalled := false
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
if id := GetUserIDFromContext(r.Context()); id != 99 {
|
|
t.Fatalf("expected user id 99, got %d", id)
|
|
}
|
|
}))
|
|
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "/secure", nil)
|
|
request.Header.Set("Authorization", "Bearer token-123")
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if !handlerCalled {
|
|
t.Fatal("expected handler to be called for valid token")
|
|
}
|
|
|
|
if verifier.token != "token-123" {
|
|
t.Fatalf("expected verifier to receive token-123, got %q", verifier.token)
|
|
}
|
|
|
|
if recorder.Result().StatusCode != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", recorder.Result().StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestNewAuthInvalidHeaders(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
header string
|
|
status int
|
|
}{
|
|
{name: "MissingBearer", header: "Token value", status: http.StatusUnauthorized},
|
|
{name: "EmptyToken", header: "Bearer ", status: http.StatusUnauthorized},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
verifier := &stubVerifier{userID: 1}
|
|
middleware := NewAuth(verifier)
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("handler should not be called")
|
|
}))
|
|
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
request.Header.Set("Authorization", tc.header)
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Result().StatusCode != tc.status {
|
|
t.Fatalf("expected status %d, got %d", tc.status, recorder.Result().StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewAuthVerifierError(t *testing.T) {
|
|
verifier := &stubVerifier{err: http.ErrNoCookie}
|
|
middleware := NewAuth(verifier)
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatal("handler should not be called when verifier fails")
|
|
}))
|
|
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
request.Header.Set("Authorization", "Bearer token-xyz")
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Result().StatusCode != http.StatusUnauthorized {
|
|
t.Fatalf("expected 401 when verifier fails, got %d", recorder.Result().StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestGetUserIDFromContext(t *testing.T) {
|
|
ctx := context.WithValue(context.Background(), UserIDKey, uint(55))
|
|
|
|
if id := GetUserIDFromContext(ctx); id != 55 {
|
|
t.Fatalf("expected id 55, got %d", id)
|
|
}
|
|
|
|
if id := GetUserIDFromContext(context.Background()); id != 0 {
|
|
t.Fatalf("expected zero when id missing, got %d", id)
|
|
}
|
|
}
|