Add dynamic worker management

Nicolas Will
2024-06-16 12:21:54 +02:00
parent 7ea45cb176
commit 903ca5673e
4 changed files with 124 additions and 45 deletions

View File

@@ -1,24 +1,29 @@
 package main
 
 import (
-	"mrvacommander/pkg/agent"
-	"mrvacommander/pkg/queue"
-	"os/signal"
-	"strconv"
-	"syscall"
+	"context"
 	"flag"
 	"os"
+	"os/signal"
 	"runtime"
+	"strconv"
 	"sync"
+	"syscall"
+	"time"
 
 	"github.com/elastic/go-sysinfo"
 	"golang.org/x/exp/slog"
+
+	"mrvacommander/pkg/agent"
+	"mrvacommander/pkg/queue"
+)
+
+const (
+	workerMemoryMB     = 2048 // 2 GB
+	monitorIntervalSec = 10   // Monitor every 10 seconds
 )
 
 func calculateWorkers() int {
-	const workerMemoryMB = 2048 // 2 GB
 	host, err := sysinfo.Host()
 	if err != nil {
 		slog.Error("failed to get host info", "error", err)
@@ -49,6 +54,60 @@ func calculateWorkers() int {
 	return workers
 }
 
+func startAndMonitorWorkers(ctx context.Context, queue queue.Queue, desiredWorkerCount int, wg *sync.WaitGroup) {
+	currentWorkerCount := 0
+	stopChans := make([]chan struct{}, 0)
+
+	if desiredWorkerCount != 0 {
+		slog.Info("Starting workers", slog.Int("count", desiredWorkerCount))
+		for i := 0; i < desiredWorkerCount; i++ {
+			stopChan := make(chan struct{})
+			stopChans = append(stopChans, stopChan)
+			wg.Add(1)
+			go agent.RunWorker(ctx, stopChan, queue, wg)
+		}
+		return
+	}
+
+	slog.Info("Worker count not specified, managing based on available memory and CPU")
+	for {
+		select {
+		case <-ctx.Done():
+			// signal all workers to stop
+			for _, stopChan := range stopChans {
+				close(stopChan)
+			}
+			return
+		default:
+			newWorkerCount := calculateWorkers()
+			if newWorkerCount != currentWorkerCount {
+				slog.Info(
+					"Modifying worker count",
+					slog.Int("current", currentWorkerCount),
+					slog.Int("new", newWorkerCount))
+			}
+			if newWorkerCount > currentWorkerCount {
+				for i := currentWorkerCount; i < newWorkerCount; i++ {
+					stopChan := make(chan struct{})
+					stopChans = append(stopChans, stopChan)
+					wg.Add(1)
+					go agent.RunWorker(ctx, stopChan, queue, wg)
+				}
+			} else if newWorkerCount < currentWorkerCount {
+				for i := newWorkerCount; i < currentWorkerCount; i++ {
+					close(stopChans[i])
+				}
+				stopChans = stopChans[:newWorkerCount]
+			}
+			currentWorkerCount = newWorkerCount
+			time.Sleep(monitorIntervalSec * time.Second)
+		}
+	}
+}
+
 func main() {
 	slog.Info("Starting agent")
@@ -77,7 +136,6 @@ func main() {
 	rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD")
 	rmqPortAsInt, err := strconv.Atoi(rmqPort)
 	if err != nil {
 		slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err))
 		os.Exit(1)
@@ -92,29 +150,23 @@ func main() {
 	}
 	defer rabbitMQQueue.Close()
 
-	if *workerCount == 0 {
-		*workerCount = calculateWorkers()
-	}
-
-	slog.Info("Starting workers", slog.Int("count", *workerCount))
-
 	var wg sync.WaitGroup
-	for i := 0; i < *workerCount; i++ {
-		wg.Add(1)
-		go agent.RunWorker(rabbitMQQueue, &wg)
-	}
+	ctx, cancel := context.WithCancel(context.Background())
 
-	slog.Info("Agent startup complete")
+	go startAndMonitorWorkers(ctx, rabbitMQQueue, *workerCount, &wg)
+	slog.Info("Agent started")
 
-	// Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup)
+	// Gracefully exit on SIGINT/SIGTERM
 	sigChan := make(chan os.Signal, 1)
 	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-	go func() {
-		<-sigChan
-		slog.Info("Shutting down agent")
-		rabbitMQQueue.Close()
-		os.Exit(0)
-	}()
+	<-sigChan
 
-	select {} // TODO: fix this to gracefully terminate agent workers during jobs
+	slog.Info("Shutting down agent")
+	cancel()
+	wg.Wait()
+	slog.Info("Agent shutdown complete")
 }
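
Note: the body of calculateWorkers is only partially visible in the hunks above. The sketch below shows the kind of memory- and CPU-based sizing it performs, reusing the imports and the workerMemoryMB constant from this file; the exact formula and the fallback value of 1 are assumptions for illustration, not taken from this commit.

// Illustrative sketch only — not the actual calculateWorkers implementation.
// Assumes one worker per workerMemoryMB (2048 MB) of available memory,
// capped at the number of CPUs and floored at one worker.
func calculateWorkersSketch() int {
	host, err := sysinfo.Host()
	if err != nil {
		slog.Error("failed to get host info", "error", err)
		return 1 // assumed fallback
	}
	mem, err := host.Memory()
	if err != nil {
		slog.Error("failed to get memory info", "error", err)
		return 1 // assumed fallback
	}
	availableMB := int(mem.Available / (1024 * 1024))
	workers := availableMB / workerMemoryMB
	if cpus := runtime.NumCPU(); workers > cpus {
		workers = cpus
	}
	if workers < 1 {
		workers = 1
	}
	return workers
}

When no fixed worker count is given, startAndMonitorWorkers re-runs this calculation every monitorIntervalSec (10 s) and grows or shrinks the worker pool to match the result.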

View File

@@ -1,6 +1,7 @@
 package agent
 
 import (
+	"context"
 	"fmt"
 	"log/slog"
 	"mrvacommander/pkg/codeql"
@@ -111,8 +112,7 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) {
 	}
 
 	// TODO: Upload the archive to storage
-	slog.Info("Results archive size", slog.Int("size", len(resultsArchive)))
-	slog.Info("Analysis job successful.")
+	slog.Debug("Results archive size", slog.Int("size", len(resultsArchive)))
 
 	result = common.AnalyzeResult{
 		RequestId: job.RequestId,
@@ -125,16 +125,43 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) {
 }
 
 // RunWorker runs a worker that processes jobs from queue
-func RunWorker(queue queue.Queue, wg *sync.WaitGroup) {
+func RunWorker(ctx context.Context, stopChan chan struct{}, queue queue.Queue, wg *sync.WaitGroup) {
+	const (
+		WORKER_COUNT_STOP_MESSAGE   = "Worker stopping due to reduction in worker count"
+		WORKER_CONTEXT_STOP_MESSAGE = "Worker stopping due to context cancellation"
+	)
+
 	defer wg.Done()
-	for job := range queue.Jobs() {
-		slog.Info("Running analysis job", slog.Any("job", job))
-		result, err := RunAnalysisJob(job)
-		if err != nil {
-			slog.Error("Failed to run analysis job", slog.Any("error", err))
-			continue
+	slog.Info("Worker started")
+	for {
+		select {
+		case <-stopChan:
+			slog.Info(WORKER_COUNT_STOP_MESSAGE)
+			return
+		case <-ctx.Done():
+			slog.Info(WORKER_CONTEXT_STOP_MESSAGE)
+			return
+		default:
+			select {
+			case job, ok := <-queue.Jobs():
+				if !ok {
+					return
+				}
+				slog.Info("Running analysis job", slog.Any("job", job))
+				result, err := RunAnalysisJob(job)
+				if err != nil {
+					slog.Error("Failed to run analysis job", slog.Any("error", err))
+					continue
+				}
+				slog.Info("Analysis job completed", slog.Any("result", result))
+				queue.Results() <- result
+			case <-stopChan:
+				slog.Info(WORKER_COUNT_STOP_MESSAGE)
+				return
+			case <-ctx.Done():
+				slog.Info(WORKER_CONTEXT_STOP_MESSAGE)
+				return
+			}
 		}
-		slog.Info("Analysis job completed", slog.Any("result", result))
-		queue.Results() <- result
 	}
 }
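
Note: RunWorker only touches two methods of the queue. The sketch below gives the minimal interface it relies on, inferred from the calls in the hunk above; it is not the actual declaration in mrvacommander/pkg/queue, which may expose more (for example the Close used in main).

// Inferred from usage above; illustrative only.
type workerQueue interface {
	Jobs() chan common.AnalyzeJob       // received from in the inner select
	Results() chan common.AnalyzeResult // completed results are sent here
}

The nested select lets stopChan and ctx.Done() be observed before the worker blocks waiting on Jobs(), and the ok check returns cleanly if the jobs channel is closed.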

View File

@@ -115,7 +115,7 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
 		databaseSHA = *dbMetadata.CreationMetadata.SHA
 	}
 
-	cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=1024", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath)
+	cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=2048", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath)
 	if output, err := cmd.CombinedOutput(); err != nil {
 		return nil, fmt.Errorf("failed to run queries: %v\nOutput: %s", err, output)
 	}
@@ -133,9 +133,9 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
 	shouldGenerateSarif := queryPackSupportsSarif(queryPackRunResults)
 	if shouldGenerateSarif {
-		slog.Info("Query pack supports SARIF")
+		slog.Debug("Query pack supports SARIF")
 	} else {
-		slog.Info("Query pack does not support SARIF")
+		slog.Debug("Query pack does not support SARIF")
 	}
 
 	var resultCount int
@@ -147,17 +147,17 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
 			return nil, fmt.Errorf("failed to generate SARIF: %v", err)
 		}
 		resultCount = getSarifResultCount(sarif)
-		slog.Info("Generated SARIF", "resultCount", resultCount)
+		slog.Debug("Generated SARIF", "resultCount", resultCount)
 		sarifFilePath = filepath.Join(resultsDir, "results.sarif")
 		if err := os.WriteFile(sarifFilePath, sarif, 0644); err != nil {
 			return nil, fmt.Errorf("failed to write SARIF file: %v", err)
 		}
 	} else {
 		resultCount = queryPackRunResults.TotalResultsCount
-		slog.Info("Did not generate SARIF", "resultCount", resultCount)
+		slog.Debug("Did not generate SARIF", "resultCount", resultCount)
 	}
 
-	slog.Info("Adjusting BQRS files")
+	slog.Debug("Adjusting BQRS files")
 	bqrsFilePaths, err := adjustBqrsFiles(queryPackRunResults, resultsDir)
 	if err != nil {
 		return nil, fmt.Errorf("failed to adjust BQRS files: %v", err)

View File

@@ -151,7 +151,7 @@ func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) {
 		return
 	}
 
-	slog.Info("Publishing result", slog.String("result", string(resultBytes)))
+	slog.Debug("Publishing result", slog.String("result", string(resultBytes)))
 	err = q.channel.PublishWithContext(ctx, "", queueName, false, false,
 		amqp.Publishing{
 			ContentType: "application/json",