package recognition

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"slices"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/food-ai/backend/internal/infra/middleware"
)

// sseEvent is one Server-Sent Event waiting to be written to a streaming client.
type sseEvent struct {
	name string // written on the "event:" line
	data string // written on the "data:" line (a JSON document)
}

const (
	// sseSubscriberBuffer is how many undelivered events a single stream may
	// have pending before fanOut starts discarding the oldest ones.
	sseSubscriberBuffer = 10

	// sseListenRetryMinDelay and sseListenRetryMaxDelay bound the exponential
	// backoff used when the LISTEN session has to be re-established.
	sseListenRetryMinDelay = time.Second
	sseListenRetryMaxDelay = 30 * time.Second
)

// SSEBroker manages Server-Sent Events for job status updates.
// It listens on the PostgreSQL "job_update" NOTIFY channel (whose payload is
// the job ID) and fans out events to all HTTP clients currently streaming that
// job. If the LISTEN session is lost, the broker reconnects with backoff and
// re-publishes the current state of every job that still has subscribers.
type SSEBroker struct {
	pool    *pgxpool.Pool
	jobRepo JobRepository

	mu      sync.RWMutex               // guards clients
	clients map[string][]chan sseEvent // job ID -> subscriber channels
}

// NewSSEBroker creates a new SSEBroker. Call Start to begin listening.
func NewSSEBroker(pool *pgxpool.Pool, jobRepo JobRepository) *SSEBroker {
	return &SSEBroker{
		pool:    pool,
		jobRepo: jobRepo,
		clients: make(map[string][]chan sseEvent),
	}
}

// Start launches the PostgreSQL LISTEN loop in a background goroutine. The
// goroutine keeps running, reconnecting as needed, until brokerContext is
// cancelled.
func (broker *SSEBroker) Start(brokerContext context.Context) {
	go broker.listenLoop(brokerContext)
}

// listenLoop keeps a LISTEN job_update session alive until brokerContext is
// cancelled. Whenever a session fails it waits (exponential backoff, reset
// after a session that had successfully started listening) and tries again,
// instead of giving up and leaving every open stream without updates.
func (broker *SSEBroker) listenLoop(brokerContext context.Context) {
	delay := sseListenRetryMinDelay
	for {
		established, sessionError := broker.listenOnce(brokerContext)
		if brokerContext.Err() != nil {
			return
		}
		if established {
			delay = sseListenRetryMinDelay
		}
		slog.Error("SSEBroker: LISTEN session ended; retrying", "err", sessionError, "retry_in", delay)

		timer := time.NewTimer(delay)
		select {
		case <-brokerContext.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
		delay = min(delay*2, sseListenRetryMaxDelay)
	}
}

// listenOnce runs a single LISTEN session: it acquires a pooled connection,
// issues LISTEN job_update, catches up subscribed jobs, and then fans out
// notifications until an error occurs. established reports whether LISTEN
// succeeded before the session ended; the returned error is always non-nil.
func (broker *SSEBroker) listenOnce(brokerContext context.Context) (established bool, err error) {
	conn, acquireError := broker.pool.Acquire(brokerContext)
	if acquireError != nil {
		return false, fmt.Errorf("acquire PG connection: %w", acquireError)
	}
	// NOTE(review): on a clean shutdown this may hand a connection that is
	// still LISTENing back to the pool; consider UNLISTEN/Hijack if the pool
	// outlives the broker.
	defer conn.Release()

	if _, listenError := conn.Exec(brokerContext, "LISTEN job_update"); listenError != nil {
		return false, fmt.Errorf("LISTEN job_update: %w", listenError)
	}

	// Notifications sent while no session was listening are lost; re-read the
	// jobs that still have subscribers so their streams catch up.
	broker.resync(brokerContext)

	for {
		notification, waitError := conn.Conn().WaitForNotification(brokerContext)
		if waitError != nil {
			return true, fmt.Errorf("wait for notification: %w", waitError)
		}
		broker.fanOut(brokerContext, notification.Payload)
	}
}

// resync fans out the current state of every job that has at least one
// subscriber. Subscribers may therefore see a repeated "processing" event,
// which carries no data beyond the status itself.
func (broker *SSEBroker) resync(resyncContext context.Context) {
	broker.mu.RLock()
	jobIDs := make([]string, 0, len(broker.clients))
	for jobID := range broker.clients {
		jobIDs = append(jobIDs, jobID)
	}
	broker.mu.RUnlock()

	for _, jobID := range jobIDs {
		broker.fanOut(resyncContext, jobID)
	}
}

// subscribe registers a new buffered event channel for jobID and returns it.
// The caller must release it with unsubscribe.
func (broker *SSEBroker) subscribe(jobID string) chan sseEvent {
	channel := make(chan sseEvent, sseSubscriberBuffer)
	broker.mu.Lock()
	broker.clients[jobID] = append(broker.clients[jobID], channel)
	broker.mu.Unlock()
	return channel
}

// unsubscribe removes channel from jobID's subscribers and drops the map entry
// once no subscribers remain. Unknown job IDs or channels are ignored.
func (broker *SSEBroker) unsubscribe(jobID string, channel chan sseEvent) {
	broker.mu.Lock()
	defer broker.mu.Unlock()

	subscribers := broker.clients[jobID]
	if index := slices.Index(subscribers, channel); index >= 0 {
		// Deleting in place is safe: fanOut only ever sends on a cloned snapshot.
		subscribers = slices.Delete(subscribers, index, index+1)
	}
	if len(subscribers) == 0 {
		delete(broker.clients, jobID)
		return
	}
	broker.clients[jobID] = subscribers
}

// fanOut loads jobID's current state and delivers the matching event to every
// subscriber of that job without blocking. Jobs that cannot be loaded are
// logged and skipped; statuses without an event are ignored.
func (broker *SSEBroker) fanOut(fanContext context.Context, jobID string) {
	job, fetchError := broker.jobRepo.GetJobByID(fanContext, jobID)
	if fetchError != nil {
		slog.Warn("SSEBroker: get job for fan-out", "job_id", jobID, "err", fetchError)
		return
	}
	event, ok := jobToSSEEvent(job)
	if !ok {
		return
	}

	// Snapshot the subscriber list so no lock is held while sending.
	broker.mu.RLock()
	channels := slices.Clone(broker.clients[jobID])
	broker.mu.RUnlock()

	for _, channel := range channels {
		deliverLatest(channel, event)
	}
}

// deliverLatest sends event on channel without blocking. If the buffer is
// full, it discards the oldest pending event and tries once more. A slow
// reader then still receives the newest status, in particular a terminal
// done/failed event, instead of waiting for one that was dropped.
func deliverLatest(channel chan sseEvent, event sseEvent) {
	select {
	case channel <- event:
		return
	default:
	}
	select {
	case <-channel:
	default:
	}
	select {
	case channel <- event:
	default:
		slog.Warn("SSEBroker: subscriber buffer full; event dropped", "event", event.name)
	}
}

// jobToSSEEvent converts a job's current state into the SSE event clients
// should see. It reports ok=false for statuses that have no event (anything
// other than processing, done and failed) and when a done job's result cannot
// be marshalled to JSON.
func jobToSSEEvent(job *Job) (sseEvent, bool) {
	switch job.Status {
	case JobStatusProcessing:
		return sseEvent{name: "processing", data: "{}"}, true
	case JobStatusDone:
		resultJSON, marshalError := json.Marshal(job.Result)
		if marshalError != nil {
			slog.Error("SSEBroker: marshal job result", "err", marshalError)
			return sseEvent{}, false
		}
		return sseEvent{name: "done", data: string(resultJSON)}, true
	case JobStatusFailed:
		errMsg := "recognition failed, please try again"
		if job.Error != nil {
			errMsg = *job.Error
		}
		// Marshalling a map[string]string cannot fail.
		errorData, _ := json.Marshal(map[string]string{"error": errMsg})
		return sseEvent{name: "failed", data: string(errorData)}, true
	default:
		return sseEvent{}, false
	}
}

// ServeSSE handles GET /ai/jobs/{id}/stream. It responds 404 when the job
// cannot be loaded, 403 when it belongs to another user, and 500 when the
// ResponseWriter cannot flush; otherwise it streams job status events until
// the job reaches done/failed or the client disconnects.
func (broker *SSEBroker) ServeSSE(responseWriter http.ResponseWriter, request *http.Request) { jobID := chi.URLParam(request, "id") userID := middleware.UserIDFromCtx(request.Context()) job, fetchError := broker.jobRepo.GetJobByID(request.Context(), jobID) if fetchError != nil { writeErrorJSON(responseWriter, request, http.StatusNotFound, "job not found") return } if job.UserID != userID { writeErrorJSON(responseWriter, request, http.StatusForbidden, "forbidden") return } flusher, supported := responseWriter.(http.Flusher) if !supported { writeErrorJSON(responseWriter, request, http.StatusInternalServerError, "streaming not supported") return } responseWriter.Header().Set("Content-Type", "text/event-stream") responseWriter.Header().Set("Cache-Control", "no-cache") responseWriter.Header().Set("Connection", "keep-alive") responseWriter.Header().Set("X-Accel-Buffering", "no") // If the job is already in a terminal state, send the event immediately. if job.Status == JobStatusDone || job.Status == JobStatusFailed { if event, ok := jobToSSEEvent(job); ok { fmt.Fprintf(responseWriter, "event: %s\ndata: %s\n\n", event.name, event.data) flusher.Flush() } return } // Subscribe to future notifications before sending the queued event to // avoid a race where the job completes between reading the current state // and registering the subscriber. eventChannel := broker.subscribe(jobID) defer broker.unsubscribe(jobID, eventChannel) // Send initial queued event with estimated wait. 
position, _ := broker.jobRepo.QueuePosition(request.Context(), job.UserPlan, job.CreatedAt) estimatedSeconds := (position + 1) * 6 queuedData, _ := json.Marshal(map[string]any{ "position": position, "estimated_seconds": estimatedSeconds, }) fmt.Fprintf(responseWriter, "event: queued\ndata: %s\n\n", queuedData) flusher.Flush() for { select { case event := <-eventChannel: fmt.Fprintf(responseWriter, "event: %s\ndata: %s\n\n", event.name, event.data) flusher.Flush() if event.name == "done" || event.name == "failed" { return } case <-request.Context().Done(): return } } }