package userproduct import ( "context" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) // ErrNotFound is returned when a user product is not found or does not belong to the user. var ErrNotFound = errors.New("user product not found") // Repository handles user product persistence. type Repository struct { pool *pgxpool.Pool } // NewRepository creates a new Repository. func NewRepository(pool *pgxpool.Pool) *Repository { return &Repository{pool: pool} } // expires_at is computed in SQL because TIMESTAMPTZ + INTERVAL is STABLE (not IMMUTABLE), // which prevents it from being used as a stored generated column. const selectCols = `id, user_id, primary_product_id, name, quantity, unit, category, storage_days, added_at, (added_at + storage_days * INTERVAL '1 day') AS expires_at` // List returns all user products sorted by expires_at ASC. func (r *Repository) List(requestContext context.Context, userID string) ([]*UserProduct, error) { rows, queryError := r.pool.Query(requestContext, ` SELECT `+selectCols+` FROM user_products WHERE user_id = $1 ORDER BY expires_at ASC`, userID) if queryError != nil { return nil, fmt.Errorf("list user products: %w", queryError) } defer rows.Close() return collectUserProducts(rows) } // Create inserts a new user product and returns the created record. func (r *Repository) Create(requestContext context.Context, userID string, req CreateRequest) (*UserProduct, error) { storageDays := req.StorageDays if storageDays <= 0 { storageDays = 7 } unit := req.Unit if unit == "" { unit = "pcs" } qty := req.Quantity if qty <= 0 { qty = 1 } // Accept both new and legacy field names. primaryID := req.PrimaryProductID if primaryID == nil { primaryID = req.MappingID } row := r.pool.QueryRow(requestContext, ` INSERT INTO user_products (user_id, primary_product_id, name, quantity, unit, category, storage_days) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING `+selectCols, userID, primaryID, req.Name, qty, unit, req.Category, storageDays, ) return scanUserProduct(row) } // BatchCreate inserts multiple user products sequentially and returns all created records. func (r *Repository) BatchCreate(requestContext context.Context, userID string, items []CreateRequest) ([]*UserProduct, error) { var result []*UserProduct for _, req := range items { userProduct, createError := r.Create(requestContext, userID, req) if createError != nil { return nil, fmt.Errorf("batch create user product %q: %w", req.Name, createError) } result = append(result, userProduct) } return result, nil } // Update modifies an existing user product. Only non-nil fields are changed. // Returns ErrNotFound if the product does not exist or belongs to a different user. func (r *Repository) Update(requestContext context.Context, id, userID string, req UpdateRequest) (*UserProduct, error) { row := r.pool.QueryRow(requestContext, ` UPDATE user_products SET name = COALESCE($3, name), quantity = COALESCE($4, quantity), unit = COALESCE($5, unit), category = COALESCE($6, category), storage_days = COALESCE($7, storage_days) WHERE id = $1 AND user_id = $2 RETURNING `+selectCols, id, userID, req.Name, req.Quantity, req.Unit, req.Category, req.StorageDays, ) userProduct, scanError := scanUserProduct(row) if errors.Is(scanError, pgx.ErrNoRows) { return nil, ErrNotFound } return userProduct, scanError } // DeleteAll removes all user products for the given user. func (r *Repository) DeleteAll(requestContext context.Context, userID string) error { _, execError := r.pool.Exec(requestContext, `DELETE FROM user_products WHERE user_id = $1`, userID) if execError != nil { return fmt.Errorf("delete all user products: %w", execError) } return nil } // Delete removes a user product. Returns ErrNotFound if it does not exist or belongs to a different user. func (r *Repository) Delete(requestContext context.Context, id, userID string) error { tag, execError := r.pool.Exec(requestContext, `DELETE FROM user_products WHERE id = $1 AND user_id = $2`, id, userID) if execError != nil { return fmt.Errorf("delete user product: %w", execError) } if tag.RowsAffected() == 0 { return ErrNotFound } return nil } // ListForPrompt returns a human-readable list of user's products for the AI prompt. // Expiring soon items are marked with ⚠. func (r *Repository) ListForPrompt(requestContext context.Context, userID string) ([]string, error) { rows, queryError := r.pool.Query(requestContext, ` WITH up AS ( SELECT name, quantity, unit, (added_at + storage_days * INTERVAL '1 day') AS expires_at FROM user_products WHERE user_id = $1 ) SELECT name, quantity, unit, expires_at FROM up ORDER BY expires_at ASC`, userID) if queryError != nil { return nil, fmt.Errorf("list user products for prompt: %w", queryError) } defer rows.Close() var lines []string now := time.Now() for rows.Next() { var name, unit string var qty float64 var expiresAt time.Time if scanError := rows.Scan(&name, &qty, &unit, &expiresAt); scanError != nil { return nil, fmt.Errorf("scan user product for prompt: %w", scanError) } daysLeft := int(expiresAt.Sub(now).Hours() / 24) line := fmt.Sprintf("- %s %.0f %s", name, qty, unit) switch { case daysLeft <= 0: line += " (expires today ⚠)" case daysLeft == 1: line += " (expires tomorrow ⚠)" case daysLeft <= 3: line += fmt.Sprintf(" (expires in %d days ⚠)", daysLeft) } lines = append(lines, line) } return lines, rows.Err() } func (r *Repository) ListForPromptByIDs(requestContext context.Context, userID string, ids []string) ([]string, error) { rows, queryError := r.pool.Query(requestContext, ` WITH up AS ( SELECT name, quantity, unit, (added_at + storage_days * INTERVAL '1 day') AS expires_at FROM user_products WHERE user_id = $1 AND id = ANY($2) ) SELECT name, quantity, unit, expires_at FROM up ORDER BY expires_at ASC`, userID, ids) if queryError != nil { return nil, fmt.Errorf("list user products by ids for prompt: %w", queryError) } defer rows.Close() var lines []string now := time.Now() for rows.Next() { var name, unit string var qty float64 var expiresAt time.Time if scanError := rows.Scan(&name, &qty, &unit, &expiresAt); scanError != nil { return nil, fmt.Errorf("scan user product for prompt: %w", scanError) } daysLeft := int(expiresAt.Sub(now).Hours() / 24) line := fmt.Sprintf("- %s %.0f %s", name, qty, unit) switch { case daysLeft <= 0: line += " (expires today ⚠)" case daysLeft == 1: line += " (expires tomorrow ⚠)" case daysLeft <= 3: line += fmt.Sprintf(" (expires in %d days ⚠)", daysLeft) } lines = append(lines, line) } return lines, rows.Err() } // --- helpers --- func scanUserProduct(row pgx.Row) (*UserProduct, error) { var userProduct UserProduct scanError := row.Scan( &userProduct.ID, &userProduct.UserID, &userProduct.PrimaryProductID, &userProduct.Name, &userProduct.Quantity, &userProduct.Unit, &userProduct.Category, &userProduct.StorageDays, &userProduct.AddedAt, &userProduct.ExpiresAt, ) if scanError != nil { return nil, scanError } computeDaysLeft(&userProduct) return &userProduct, nil } func collectUserProducts(rows pgx.Rows) ([]*UserProduct, error) { var result []*UserProduct for rows.Next() { var userProduct UserProduct if scanError := rows.Scan( &userProduct.ID, &userProduct.UserID, &userProduct.PrimaryProductID, &userProduct.Name, &userProduct.Quantity, &userProduct.Unit, &userProduct.Category, &userProduct.StorageDays, &userProduct.AddedAt, &userProduct.ExpiresAt, ); scanError != nil { return nil, fmt.Errorf("scan user product: %w", scanError) } computeDaysLeft(&userProduct) result = append(result, &userProduct) } return result, rows.Err() } func computeDaysLeft(userProduct *UserProduct) { days := int(time.Until(userProduct.ExpiresAt).Hours() / 24) if days < 0 { days = 0 } userProduct.DaysLeft = days userProduct.ExpiringSoon = days <= 3 }