From 903ca5673e68f9a44a32bfc598d2b77e761759bd Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sun, 16 Jun 2024 12:21:54 +0200 Subject: [PATCH] Add dynamic worker management --- cmd/agent/main.go | 106 +++++++++++++++++++++++++++++++----------- pkg/agent/agent.go | 49 ++++++++++++++----- pkg/codeql/codeql.go | 12 ++--- pkg/queue/rabbitmq.go | 2 +- 4 files changed, 124 insertions(+), 45 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 2df182f..7cd3e33 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,24 +1,29 @@ package main import ( - "mrvacommander/pkg/agent" - "mrvacommander/pkg/queue" - "os/signal" - "strconv" - "syscall" - + "context" "flag" "os" + "os/signal" "runtime" + "strconv" "sync" + "syscall" + "time" "github.com/elastic/go-sysinfo" "golang.org/x/exp/slog" + + "mrvacommander/pkg/agent" + "mrvacommander/pkg/queue" +) + +const ( + workerMemoryMB = 2048 // 2 GB + monitorIntervalSec = 10 // Monitor every 10 seconds ) func calculateWorkers() int { - const workerMemoryMB = 2048 // 2 GB - host, err := sysinfo.Host() if err != nil { slog.Error("failed to get host info", "error", err) @@ -49,6 +54,60 @@ func calculateWorkers() int { return workers } +func startAndMonitorWorkers(ctx context.Context, queue queue.Queue, desiredWorkerCount int, wg *sync.WaitGroup) { + currentWorkerCount := 0 + stopChans := make([]chan struct{}, 0) + + if desiredWorkerCount != 0 { + slog.Info("Starting workers", slog.Int("count", desiredWorkerCount)) + for i := 0; i < desiredWorkerCount; i++ { + stopChan := make(chan struct{}) + stopChans = append(stopChans, stopChan) + wg.Add(1) + go agent.RunWorker(ctx, stopChan, queue, wg) + } + return + } + + slog.Info("Worker count not specified, managing based on available memory and CPU") + + for { + select { + case <-ctx.Done(): + // signal all workers to stop + for _, stopChan := range stopChans { + close(stopChan) + } + return + default: + newWorkerCount := calculateWorkers() + + if newWorkerCount != currentWorkerCount { + slog.Info( + "Modifying worker count", + slog.Int("current", currentWorkerCount), + slog.Int("new", newWorkerCount)) + } + + if newWorkerCount > currentWorkerCount { + for i := currentWorkerCount; i < newWorkerCount; i++ { + stopChan := make(chan struct{}) + stopChans = append(stopChans, stopChan) + wg.Add(1) + go agent.RunWorker(ctx, stopChan, queue, wg) + } + } else if newWorkerCount < currentWorkerCount { + for i := newWorkerCount; i < currentWorkerCount; i++ { + close(stopChans[i]) + } + stopChans = stopChans[:newWorkerCount] + } + currentWorkerCount = newWorkerCount + time.Sleep(monitorIntervalSec * time.Second) + } + } +} + func main() { slog.Info("Starting agent") @@ -77,7 +136,6 @@ func main() { rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD") rmqPortAsInt, err := strconv.Atoi(rmqPort) - if err != nil { slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err)) os.Exit(1) @@ -92,29 +150,23 @@ func main() { } defer rabbitMQQueue.Close() - if *workerCount == 0 { - *workerCount = calculateWorkers() - } - - slog.Info("Starting workers", slog.Int("count", *workerCount)) var wg sync.WaitGroup - for i := 0; i < *workerCount; i++ { - wg.Add(1) - go agent.RunWorker(rabbitMQQueue, &wg) - } + ctx, cancel := context.WithCancel(context.Background()) - slog.Info("Agent startup complete") + go startAndMonitorWorkers(ctx, rabbitMQQueue, *workerCount, &wg) - // Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup) + slog.Info("Agent started") + + // Gracefully exit on SIGINT/SIGTERM sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-sigChan - slog.Info("Shutting down agent") - rabbitMQQueue.Close() - os.Exit(0) - }() + <-sigChan + slog.Info("Shutting down agent") - select {} + // TODO: fix this to gracefully terminate agent workers during jobs + cancel() + wg.Wait() + + slog.Info("Agent shutdown complete") } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 018d759..b42ab51 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "log/slog" "mrvacommander/pkg/codeql" @@ -111,8 +112,7 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { } // TODO: Upload the archive to storage - slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) - slog.Info("Analysis job successful.") + slog.Debug("Results archive size", slog.Int("size", len(resultsArchive))) result = common.AnalyzeResult{ RequestId: job.RequestId, @@ -125,16 +125,43 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { } // RunWorker runs a worker that processes jobs from queue -func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { +func RunWorker(ctx context.Context, stopChan chan struct{}, queue queue.Queue, wg *sync.WaitGroup) { + const ( + WORKER_COUNT_STOP_MESSAGE = "Worker stopping due to reduction in worker count" + WORKER_CONTEXT_STOP_MESSAGE = "Worker stopping due to context cancellation" + ) + defer wg.Done() - for job := range queue.Jobs() { - slog.Info("Running analysis job", slog.Any("job", job)) - result, err := RunAnalysisJob(job) - if err != nil { - slog.Error("Failed to run analysis job", slog.Any("error", err)) - continue + slog.Info("Worker started") + for { + select { + case <-stopChan: + slog.Info(WORKER_COUNT_STOP_MESSAGE) + return + case <-ctx.Done(): + slog.Info(WORKER_CONTEXT_STOP_MESSAGE) + return + default: + select { + case job, ok := <-queue.Jobs(): + if !ok { + return + } + slog.Info("Running analysis job", slog.Any("job", job)) + result, err := RunAnalysisJob(job) + if err != nil { + slog.Error("Failed to run analysis job", slog.Any("error", err)) + continue + } + slog.Info("Analysis job completed", slog.Any("result", result)) + queue.Results() <- result + case <-stopChan: + slog.Info(WORKER_COUNT_STOP_MESSAGE) + return + case <-ctx.Done(): + slog.Info(WORKER_CONTEXT_STOP_MESSAGE) + return + } } - slog.Info("Analysis job completed", slog.Any("result", result)) - queue.Results() <- result } } diff --git a/pkg/codeql/codeql.go b/pkg/codeql/codeql.go index a44efd8..2c4f45c 100644 --- a/pkg/codeql/codeql.go +++ b/pkg/codeql/codeql.go @@ -115,7 +115,7 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) databaseSHA = *dbMetadata.CreationMetadata.SHA } - cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=1024", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath) + cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=2048", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath) if output, err := cmd.CombinedOutput(); err != nil { return nil, fmt.Errorf("failed to run queries: %v\nOutput: %s", err, output) } @@ -133,9 +133,9 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) shouldGenerateSarif := queryPackSupportsSarif(queryPackRunResults) if shouldGenerateSarif { - slog.Info("Query pack supports SARIF") + slog.Debug("Query pack supports SARIF") } else { - slog.Info("Query pack does not support SARIF") + slog.Debug("Query pack does not support SARIF") } var resultCount int @@ -147,17 +147,17 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) return nil, fmt.Errorf("failed to generate SARIF: %v", err) } resultCount = getSarifResultCount(sarif) - slog.Info("Generated SARIF", "resultCount", resultCount) + slog.Debug("Generated SARIF", "resultCount", resultCount) sarifFilePath = filepath.Join(resultsDir, "results.sarif") if err := os.WriteFile(sarifFilePath, sarif, 0644); err != nil { return nil, fmt.Errorf("failed to write SARIF file: %v", err) } } else { resultCount = queryPackRunResults.TotalResultsCount - slog.Info("Did not generate SARIF", "resultCount", resultCount) + slog.Debug("Did not generate SARIF", "resultCount", resultCount) } - slog.Info("Adjusting BQRS files") + slog.Debug("Adjusting BQRS files") bqrsFilePaths, err := adjustBqrsFiles(queryPackRunResults, resultsDir) if err != nil { return nil, fmt.Errorf("failed to adjust BQRS files: %v", err) diff --git a/pkg/queue/rabbitmq.go b/pkg/queue/rabbitmq.go index 7d32b1f..f0047f3 100644 --- a/pkg/queue/rabbitmq.go +++ b/pkg/queue/rabbitmq.go @@ -151,7 +151,7 @@ func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { return } - slog.Info("Publishing result", slog.String("result", string(resultBytes))) + slog.Debug("Publishing result", slog.String("result", string(resultBytes))) err = q.channel.PublishWithContext(ctx, "", queueName, false, false, amqp.Publishing{ ContentType: "application/json",