package recognition import ( "context" "encoding/json" "fmt" "log/slog" "net/http" "sync" "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/food-ai/backend/internal/infra/middleware" ) // ProductSSEBroker manages Server-Sent Events for product recognition job status updates. // It listens on the PostgreSQL "product_job_update" NOTIFY channel and fans out events // to all HTTP clients currently streaming a given job. type ProductSSEBroker struct { pool *pgxpool.Pool productJobRepo ProductJobRepository mu sync.RWMutex clients map[string][]chan sseEvent } // NewProductSSEBroker creates a new ProductSSEBroker. func NewProductSSEBroker(pool *pgxpool.Pool, productJobRepo ProductJobRepository) *ProductSSEBroker { return &ProductSSEBroker{ pool: pool, productJobRepo: productJobRepo, clients: make(map[string][]chan sseEvent), } } // Start launches the PostgreSQL LISTEN loop in a background goroutine. func (broker *ProductSSEBroker) Start(brokerContext context.Context) { go broker.listenLoop(brokerContext) } func (broker *ProductSSEBroker) listenLoop(brokerContext context.Context) { conn, acquireError := broker.pool.Acquire(brokerContext) if acquireError != nil { slog.Error("ProductSSEBroker: acquire PG connection", "err", acquireError) return } defer conn.Release() if _, listenError := conn.Exec(brokerContext, "LISTEN product_job_update"); listenError != nil { slog.Error("ProductSSEBroker: LISTEN product_job_update", "err", listenError) return } for { notification, waitError := conn.Conn().WaitForNotification(brokerContext) if brokerContext.Err() != nil { return } if waitError != nil { slog.Error("ProductSSEBroker: wait for notification", "err", waitError) return } broker.fanOut(brokerContext, notification.Payload) } } func (broker *ProductSSEBroker) subscribe(jobID string) chan sseEvent { channel := make(chan sseEvent, 10) broker.mu.Lock() broker.clients[jobID] = append(broker.clients[jobID], channel) broker.mu.Unlock() return channel } func (broker *ProductSSEBroker) unsubscribe(jobID string, channel chan sseEvent) { broker.mu.Lock() defer broker.mu.Unlock() existing := broker.clients[jobID] for index, existingChannel := range existing { if existingChannel == channel { broker.clients[jobID] = append(broker.clients[jobID][:index], broker.clients[jobID][index+1:]...) break } } if len(broker.clients[jobID]) == 0 { delete(broker.clients, jobID) } } func (broker *ProductSSEBroker) fanOut(fanContext context.Context, jobID string) { job, fetchError := broker.productJobRepo.GetProductJobByID(fanContext, jobID) if fetchError != nil { slog.Warn("ProductSSEBroker: get job for fan-out", "job_id", jobID, "err", fetchError) return } event, ok := productJobToSSEEvent(job) if !ok { return } broker.mu.RLock() channels := make([]chan sseEvent, len(broker.clients[jobID])) copy(channels, broker.clients[jobID]) broker.mu.RUnlock() for _, channel := range channels { select { case channel <- event: default: // channel full; skip this delivery } } } func productJobToSSEEvent(job *ProductJob) (sseEvent, bool) { switch job.Status { case JobStatusProcessing: return sseEvent{name: "processing", data: "{}"}, true case JobStatusDone: resultJSON, marshalError := json.Marshal(job.Result) if marshalError != nil { return sseEvent{}, false } return sseEvent{name: "done", data: string(resultJSON)}, true case JobStatusFailed: errMsg := "recognition failed, please try again" if job.Error != nil { errMsg = *job.Error } errorData, _ := json.Marshal(map[string]string{"error": errMsg}) return sseEvent{name: "failed", data: string(errorData)}, true default: return sseEvent{}, false } } // ServeSSE handles GET /ai/product-jobs/{id}/stream — streams SSE events until the job completes. func (broker *ProductSSEBroker) ServeSSE(responseWriter http.ResponseWriter, request *http.Request) { jobID := chi.URLParam(request, "id") userID := middleware.UserIDFromCtx(request.Context()) job, fetchError := broker.productJobRepo.GetProductJobByID(request.Context(), jobID) if fetchError != nil { writeErrorJSON(responseWriter, request, http.StatusNotFound, "job not found") return } if job.UserID != userID { writeErrorJSON(responseWriter, request, http.StatusForbidden, "forbidden") return } flusher, supported := responseWriter.(http.Flusher) if !supported { writeErrorJSON(responseWriter, request, http.StatusInternalServerError, "streaming not supported") return } responseWriter.Header().Set("Content-Type", "text/event-stream") responseWriter.Header().Set("Cache-Control", "no-cache") responseWriter.Header().Set("Connection", "keep-alive") responseWriter.Header().Set("X-Accel-Buffering", "no") if job.Status == JobStatusDone || job.Status == JobStatusFailed { if event, ok := productJobToSSEEvent(job); ok { fmt.Fprintf(responseWriter, "event: %s\ndata: %s\n\n", event.name, event.data) flusher.Flush() } return } eventChannel := broker.subscribe(jobID) defer broker.unsubscribe(jobID, eventChannel) position, _ := broker.productJobRepo.ProductQueuePosition(request.Context(), job.UserPlan, job.CreatedAt) estimatedSeconds := (position + 1) * 6 queuedData, _ := json.Marshal(map[string]any{ "position": position, "estimated_seconds": estimatedSeconds, }) fmt.Fprintf(responseWriter, "event: queued\ndata: %s\n\n", queuedData) flusher.Flush() for { select { case event := <-eventChannel: fmt.Fprintf(responseWriter, "event: %s\ndata: %s\n\n", event.name, event.data) flusher.Flush() if event.name == "done" || event.name == "failed" { return } case <-request.Context().Done(): return } } }