package middleware_test import ( "fmt" "net/http" "net/http/httptest" "testing" "time" "github.com/food-ai/backend/internal/middleware" "github.com/golang-jwt/jwt/v5" ) // testJWTClaims mirrors auth.Claims for test token generation without importing auth. type testJWTClaims struct { UserID string `json:"user_id"` Plan string `json:"plan"` jwt.RegisteredClaims } func generateTestToken(secret string, userID, plan string, duration time.Duration) string { claims := testJWTClaims{ UserID: userID, Plan: plan, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte(secret)) return tokenString } // testAccessValidator implements middleware.AccessTokenValidator for tests. type testAccessValidator struct { secret string } func (v *testAccessValidator) ValidateAccessToken(tokenStr string) (*middleware.TokenClaims, error) { token, parseError := jwt.ParseWithClaims(tokenStr, &testJWTClaims{}, func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method") } return []byte(v.secret), nil }) if parseError != nil { return nil, parseError } claims, ok := token.Claims.(*testJWTClaims) if !ok || !token.Valid { return nil, fmt.Errorf("invalid token") } return &middleware.TokenClaims{UserID: claims.UserID, Plan: claims.Plan}, nil } // failingValidator always returns an error. type failingValidator struct{} func (v *failingValidator) ValidateAccessToken(tokenStr string) (*middleware.TokenClaims, error) { return nil, fmt.Errorf("invalid token") } func TestAuth_ValidToken(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} token := generateTestToken("test-secret", "user-1", "free", 15*time.Minute) handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userID := middleware.UserIDFromCtx(r.Context()) if userID != "user-1" { t.Errorf("expected user-1, got %s", userID) } plan := middleware.UserPlanFromCtx(r.Context()) if plan != "free" { t.Errorf("expected free, got %s", plan) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } } func TestAuth_MissingHeader(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rr.Code) } } func TestAuth_InvalidBearerFormat(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Basic abc123") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rr.Code) } } func TestAuth_ExpiredToken(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} token := generateTestToken("test-secret", "user-1", "free", -1*time.Second) handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rr.Code) } } func TestAuth_InvalidToken(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer invalid-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rr.Code) } } func TestAuth_PaidPlan(t *testing.T) { validator := &testAccessValidator{secret: "test-secret"} token := generateTestToken("test-secret", "user-1", "paid", 15*time.Minute) handler := middleware.Auth(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { plan := middleware.UserPlanFromCtx(r.Context()) if plan != "paid" { t.Errorf("expected paid, got %s", plan) } w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected 200, got %d", rr.Code) } } func TestAuth_EmptyBearer(t *testing.T) { handler := middleware.Auth(&failingValidator{})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer ") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rr.Code) } }