diff --git a/pkg/queue/queue_rabbitmq.go b/pkg/queue/queue_rabbitmq.go index bd674cd..dee1e02 100644 --- a/pkg/queue/queue_rabbitmq.go +++ b/pkg/queue/queue_rabbitmq.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "time" amqp "github.com/rabbitmq/amqp091-go" @@ -15,6 +16,9 @@ type RabbitMQQueue struct { results chan AnalyzeResult conn *amqp.Connection channel *amqp.Channel + + mu sync.Mutex + connString string } // NewRabbitMQQueue initializes a RabbitMQ queue. @@ -89,10 +93,12 @@ func NewRabbitMQQueue( } result := RabbitMQQueue{ - conn: conn, - channel: ch, - jobs: make(chan AnalyzeJob), - results: make(chan AnalyzeResult), + conn: conn, + channel: ch, + jobs: make(chan AnalyzeJob), + results: make(chan AnalyzeResult), + mu: sync.Mutex{}, + connString: rabbitMQURL, } if isAgent { @@ -125,6 +131,49 @@ func (q *RabbitMQQueue) Close() { q.conn.Close() } +func (q *RabbitMQQueue) reconnectIfNeeded() error { + q.mu.Lock() + defer q.mu.Unlock() + + if q.conn != nil && !q.conn.IsClosed() && q.channel != nil { + return nil // still valid + } + + // Recreate everything + conn, err := amqp.Dial(q.connString) + if err != nil { + return fmt.Errorf("failed to reconnect: %w", err) + } + + ch, err := conn.Channel() + if err != nil { + conn.Close() + return fmt.Errorf("failed to open channel: %w", err) + } + + // Optional: redeclare queues here + // _, _ = ch.QueueDeclare(...) + + q.conn = conn + q.channel = ch + return nil +} + +func (q *RabbitMQQueue) invalidateConnection() { + q.mu.Lock() + defer q.mu.Unlock() + + if q.channel != nil { + _ = q.channel.Close() + } + if q.conn != nil { + _ = q.conn.Close() + } + + q.channel = nil + q.conn = nil +} + func (q *RabbitMQQueue) ConsumeJobs(queueName string) { const pollInterval = 5 * time.Second @@ -135,9 +184,17 @@ func (q *RabbitMQQueue) ConsumeJobs(queueName string) { // | Connection lost | msg = zero, ok = false, err = non-nil | for { + + if err := q.reconnectIfNeeded(); err != nil { + slog.Error("failed to reconnect", slog.Any("error", err)) + time.Sleep(10 * time.Second) + continue + } + msg, ok, err := q.channel.Get(queueName, false) // false = manual ack if err != nil { slog.Error("polling error while getting job", slog.Any("error", err)) + q.invalidateConnection() time.Sleep(pollInterval) continue }