package middleware

import (
	"context"
	"net/http"
	"strings"
)

// Context keys under which Auth stores the validated token claims.
const (
	userIDKey   contextKey = "user_id"
	userPlanKey contextKey = "user_plan"
)

// TokenClaims represents the result of validating an access token.
type TokenClaims struct {
	UserID string
	Plan   string
}

// AccessTokenValidator validates JWT access tokens.
type AccessTokenValidator interface {
	ValidateAccessToken(tokenStr string) (*TokenClaims, error)
}

// Auth returns middleware that authenticates requests carrying an
// "Authorization: Bearer <token>" header.
//
// The token is passed to validator. On success the claims' UserID and Plan
// are stored in the request context (retrievable with UserIDFromCtx and
// UserPlanFromCtx) and the request is forwarded to the next handler.
//
// The request is answered with 401 Unauthorized, and next is not called, when:
//   - the header is absent, uses another scheme, or carries an empty token
//     (body {"error":"unauthorized"});
//   - the validator returns an error or nil claims
//     (body {"error":"invalid token"}).
func Auth(validator AccessTokenValidator) func(http.Handler) http.Handler {
	// Authentication scheme names are case-insensitive, so the prefix is
	// compared with case folding rather than byte-for-byte.
	const scheme = "bearer "

	// reject writes a 401 with a JSON body and the matching Bearer challenge.
	// http.Error is not used because it forces Content-Type to text/plain,
	// which mislabels the JSON payload.
	reject := func(w http.ResponseWriter, challenge, body string) {
		h := w.Header()
		h.Set("Content-Type", "application/json; charset=utf-8")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("WWW-Authenticate", challenge)
		w.WriteHeader(http.StatusUnauthorized)
		// The write error is deliberately ignored: the status line has
		// already been sent and nothing more can be done for the client.
		_, _ = w.Write([]byte(body + "\n"))
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			header := r.Header.Get("Authorization")
			if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
				reject(w, `Bearer`, `{"error":"unauthorized"}`)
				return
			}

			// Tolerate extra whitespace between the scheme and the token.
			tokenStr := strings.TrimSpace(header[len(scheme):])
			if tokenStr == "" {
				reject(w, `Bearer`, `{"error":"unauthorized"}`)
				return
			}

			claims, err := validator.ValidateAccessToken(tokenStr)
			// A nil claims value with a nil error is treated as a failed
			// validation instead of being dereferenced.
			if err != nil || claims == nil {
				reject(w, `Bearer error="invalid_token"`, `{"error":"invalid token"}`)
				return
			}

			ctx := context.WithValue(r.Context(), userIDKey, claims.UserID)
			ctx = context.WithValue(ctx, userPlanKey, claims.Plan)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// UserIDFromCtx returns the user ID stored in ctx under userIDKey, or the
// empty string if no string value is stored there.
func UserIDFromCtx(ctx context.Context) string {
	id, _ := ctx.Value(userIDKey).(string)
	return id
}

// UserPlanFromCtx returns the plan stored in ctx under userPlanKey, or the
// empty string if no string value is stored there.
func UserPlanFromCtx(ctx context.Context) string {
	plan, _ := ctx.Value(userPlanKey).(string)
	return plan
}