Add dynamic worker management
@@ -1,24 +1,29 @@
 package main
 
 import (
-    "mrvacommander/pkg/agent"
-    "mrvacommander/pkg/queue"
-    "os/signal"
-    "strconv"
-    "syscall"
-
+    "context"
     "flag"
     "os"
+    "os/signal"
     "runtime"
+    "strconv"
     "sync"
+    "syscall"
+    "time"
 
     "github.com/elastic/go-sysinfo"
     "golang.org/x/exp/slog"
+
+    "mrvacommander/pkg/agent"
+    "mrvacommander/pkg/queue"
 )
+
+const (
+    workerMemoryMB     = 2048 // 2 GB
+    monitorIntervalSec = 10   // Monitor every 10 seconds
+)
 
 func calculateWorkers() int {
-    const workerMemoryMB = 2048 // 2 GB
-
     host, err := sysinfo.Host()
     if err != nil {
         slog.Error("failed to get host info", "error", err)
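
The body of calculateWorkers falls between this hunk and the next, so most of it is not shown. As a hedged, self-contained sketch — only the sysinfo.Host() call and the 2048 MB per-worker budget appear in the commit; the min(memory, CPU) policy and the single-worker fallback are assumptions — it plausibly looks like:

package main

import (
    "fmt"
    "runtime"

    "github.com/elastic/go-sysinfo"
)

const workerMemoryMB = 2048 // 2 GB per worker, as in the commit

// calculateWorkers: hedged sketch. Fits as many 2 GB workers as available
// memory allows, capped at the logical CPU count.
func calculateWorkers() int {
    host, err := sysinfo.Host()
    if err != nil {
        return 1 // assumption: fall back to a single worker
    }
    mem, err := host.Memory()
    if err != nil {
        return 1
    }
    workers := int(mem.Available / (workerMemoryMB * 1024 * 1024))
    if cpus := runtime.NumCPU(); workers > cpus {
        workers = cpus
    }
    if workers < 1 {
        workers = 1
    }
    return workers
}

func main() {
    fmt.Println("workers:", calculateWorkers())
}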
@@ -49,6 +54,60 @@ func calculateWorkers() int {
     return workers
 }
 
+func startAndMonitorWorkers(ctx context.Context, queue queue.Queue, desiredWorkerCount int, wg *sync.WaitGroup) {
+    currentWorkerCount := 0
+    stopChans := make([]chan struct{}, 0)
+
+    if desiredWorkerCount != 0 {
+        slog.Info("Starting workers", slog.Int("count", desiredWorkerCount))
+        for i := 0; i < desiredWorkerCount; i++ {
+            stopChan := make(chan struct{})
+            stopChans = append(stopChans, stopChan)
+            wg.Add(1)
+            go agent.RunWorker(ctx, stopChan, queue, wg)
+        }
+        return
+    }
+
+    slog.Info("Worker count not specified, managing based on available memory and CPU")
+
+    for {
+        select {
+        case <-ctx.Done():
+            // signal all workers to stop
+            for _, stopChan := range stopChans {
+                close(stopChan)
+            }
+            return
+        default:
+            newWorkerCount := calculateWorkers()
+
+            if newWorkerCount != currentWorkerCount {
+                slog.Info(
+                    "Modifying worker count",
+                    slog.Int("current", currentWorkerCount),
+                    slog.Int("new", newWorkerCount))
+            }
+
+            if newWorkerCount > currentWorkerCount {
+                for i := currentWorkerCount; i < newWorkerCount; i++ {
+                    stopChan := make(chan struct{})
+                    stopChans = append(stopChans, stopChan)
+                    wg.Add(1)
+                    go agent.RunWorker(ctx, stopChan, queue, wg)
+                }
+            } else if newWorkerCount < currentWorkerCount {
+                for i := newWorkerCount; i < currentWorkerCount; i++ {
+                    close(stopChans[i])
+                }
+                stopChans = stopChans[:newWorkerCount]
+            }
+            currentWorkerCount = newWorkerCount
+            time.Sleep(monitorIntervalSec * time.Second)
+        }
+    }
+}
+
 func main() {
     slog.Info("Starting agent")
 
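
startAndMonitorWorkers scales the pool down by closing the newest workers' stop channels: a receive from a closed channel fires immediately, so close(ch) tells exactly one goroutine to exit without any extra bookkeeping. A minimal, self-contained illustration of that pattern (names hypothetical, not taken from the commit):

package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, stop <-chan struct{}, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case <-stop: // fires as soon as the channel is closed
            fmt.Println("worker", id, "stopping")
            return
        default:
            time.Sleep(100 * time.Millisecond) // stand-in for real work
        }
    }
}

func main() {
    var wg sync.WaitGroup
    stops := make([]chan struct{}, 0)
    for i := 0; i < 4; i++ {
        stop := make(chan struct{})
        stops = append(stops, stop)
        wg.Add(1)
        go worker(i, stop, &wg)
    }
    // Scale down from 4 to 2 workers: close the last two stop channels.
    for _, stop := range stops[2:] {
        close(stop)
    }
    stops = stops[:2]
    time.Sleep(time.Second)
    for _, stop := range stops {
        close(stop)
    }
    wg.Wait()
}

One property of the monitor loop above worth noting: the default branch recomputes the worker count and then sleeps for the full interval, so ctx.Done() is observed at most once per monitorIntervalSec.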
@@ -77,7 +136,6 @@ func main() {
     rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD")
 
     rmqPortAsInt, err := strconv.Atoi(rmqPort)
-
     if err != nil {
         slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err))
         os.Exit(1)
@@ -92,29 +150,23 @@
     }
     defer rabbitMQQueue.Close()
 
-    if *workerCount == 0 {
-        *workerCount = calculateWorkers()
-    }
-
-    slog.Info("Starting workers", slog.Int("count", *workerCount))
     var wg sync.WaitGroup
-    for i := 0; i < *workerCount; i++ {
-        wg.Add(1)
-        go agent.RunWorker(rabbitMQQueue, &wg)
-    }
+    ctx, cancel := context.WithCancel(context.Background())
 
-    slog.Info("Agent startup complete")
+    go startAndMonitorWorkers(ctx, rabbitMQQueue, *workerCount, &wg)
 
-    // Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup)
+    slog.Info("Agent started")
+
+    // Gracefully exit on SIGINT/SIGTERM
     sigChan := make(chan os.Signal, 1)
     signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
 
-    go func() {
-        <-sigChan
-        slog.Info("Shutting down agent")
-        rabbitMQQueue.Close()
-        os.Exit(0)
-    }()
-
-    select {}
+    <-sigChan
+    slog.Info("Shutting down agent")
+
+    // TODO: fix this to gracefully terminate agent workers during jobs
+    cancel()
+    wg.Wait()
+
+    slog.Info("Agent shutdown complete")
 }
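
The shutdown flow changes shape here: instead of a background goroutine that calls os.Exit(0) on the first signal, main itself blocks on sigChan, then cancels the shared context and waits for every worker to drain via wg.Wait(). A hedged sketch of the same flow using the stdlib's signal.NotifyContext (Go 1.16+; an alternative the commit does not use):

package main

import (
    "context"
    "log/slog"
    "os/signal"
    "sync"
    "syscall"
)

func main() {
    // Signal delivery cancels ctx directly; no separate sigChan needed.
    ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    defer stop()

    var wg sync.WaitGroup
    // go startAndMonitorWorkers(ctx, queue, workerCount, &wg) // as in the diff

    <-ctx.Done() // blocks until SIGINT/SIGTERM
    slog.Info("Shutting down agent")
    wg.Wait()
    slog.Info("Agent shutdown complete")
}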
@@ -1,6 +1,7 @@
 package agent
 
 import (
+    "context"
     "fmt"
     "log/slog"
     "mrvacommander/pkg/codeql"
@@ -111,8 +112,7 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) {
     }
 
     // TODO: Upload the archive to storage
-    slog.Info("Results archive size", slog.Int("size", len(resultsArchive)))
-    slog.Info("Analysis job successful.")
+    slog.Debug("Results archive size", slog.Int("size", len(resultsArchive)))
 
     result = common.AnalyzeResult{
         RequestId: job.RequestId,
@@ -125,16 +125,43 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) {
 }
 
 // RunWorker runs a worker that processes jobs from queue
-func RunWorker(queue queue.Queue, wg *sync.WaitGroup) {
+func RunWorker(ctx context.Context, stopChan chan struct{}, queue queue.Queue, wg *sync.WaitGroup) {
+    const (
+        WORKER_COUNT_STOP_MESSAGE   = "Worker stopping due to reduction in worker count"
+        WORKER_CONTEXT_STOP_MESSAGE = "Worker stopping due to context cancellation"
+    )
+
     defer wg.Done()
-    for job := range queue.Jobs() {
-        slog.Info("Running analysis job", slog.Any("job", job))
-        result, err := RunAnalysisJob(job)
-        if err != nil {
-            slog.Error("Failed to run analysis job", slog.Any("error", err))
-            continue
-        }
-        slog.Info("Analysis job completed", slog.Any("result", result))
-        queue.Results() <- result
-    }
+    slog.Info("Worker started")
+    for {
+        select {
+        case <-stopChan:
+            slog.Info(WORKER_COUNT_STOP_MESSAGE)
+            return
+        case <-ctx.Done():
+            slog.Info(WORKER_CONTEXT_STOP_MESSAGE)
+            return
+        default:
+            select {
+            case job, ok := <-queue.Jobs():
+                if !ok {
+                    return
+                }
+                slog.Info("Running analysis job", slog.Any("job", job))
+                result, err := RunAnalysisJob(job)
+                if err != nil {
+                    slog.Error("Failed to run analysis job", slog.Any("error", err))
+                    continue
+                }
+                slog.Info("Analysis job completed", slog.Any("result", result))
+                queue.Results() <- result
+            case <-stopChan:
+                slog.Info(WORKER_COUNT_STOP_MESSAGE)
+                return
+            case <-ctx.Done():
+                slog.Info(WORKER_CONTEXT_STOP_MESSAGE)
+                return
+            }
+        }
+    }
 }
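
The doubled stop cases in RunWorker are deliberate: the outer select gives stopChan and ctx.Done() priority when they are already ready, while the inner select repeats them so that a worker blocked waiting on queue.Jobs() can still be interrupted. Without the inner copies, a worker parked on an empty queue would never see the stop signal. A stripped-down, self-contained version of the pattern (channel and function names hypothetical):

package main

import (
    "context"
    "fmt"
    "time"
)

// drain consumes jobs until either the context is cancelled or stop closes,
// even while blocked waiting for the next job.
func drain(ctx context.Context, stop <-chan struct{}, jobs <-chan int) {
    for {
        select {
        case <-stop:
            fmt.Println("stopped")
            return
        case <-ctx.Done():
            fmt.Println("cancelled")
            return
        default:
            select { // blocking wait, but still interruptible
            case job, ok := <-jobs:
                if !ok {
                    return
                }
                fmt.Println("got job", job)
            case <-stop:
                fmt.Println("stopped while idle")
                return
            case <-ctx.Done():
                fmt.Println("cancelled while idle")
                return
            }
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    jobs := make(chan int)
    go func() { jobs <- 1; time.Sleep(50 * time.Millisecond); cancel() }()
    drain(ctx, make(chan struct{}), jobs)
}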
@@ -115,7 +115,7 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
         databaseSHA = *dbMetadata.CreationMetadata.SHA
     }
 
-    cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=1024", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath)
+    cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=2048", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath)
     if output, err := cmd.CombinedOutput(); err != nil {
         return nil, fmt.Errorf("failed to run queries: %v\nOutput: %s", err, output)
     }
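
The --ram bump from 1024 to 2048 lines up with the new workerMemoryMB = 2048 budget in the agent. A hedged sketch of keeping the two values in sync through one shared constant — buildRunQueriesCmd is a hypothetical helper, not part of the commit, which hard-codes the literal:

package codeqlrunner // hypothetical package name for this sketch

import (
    "fmt"
    "os/exec"
)

const workerMemoryMB = 2048 // the same per-worker budget the agent uses

// buildRunQueriesCmd mirrors the exec.Command call in the diff but derives
// --ram from the shared constant, so the CodeQL memory limit and the
// per-worker budget cannot drift apart.
func buildRunQueriesCmd(codeqlPath, databasePath, queryPackPath string) *exec.Cmd {
    return exec.Command(codeqlPath, "database", "run-queries",
        fmt.Sprintf("--ram=%d", workerMemoryMB),
        "--additional-packs", queryPackPath,
        "--", databasePath, queryPackPath)
}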
@@ -133,9 +133,9 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
     shouldGenerateSarif := queryPackSupportsSarif(queryPackRunResults)
 
     if shouldGenerateSarif {
-        slog.Info("Query pack supports SARIF")
+        slog.Debug("Query pack supports SARIF")
     } else {
-        slog.Info("Query pack does not support SARIF")
+        slog.Debug("Query pack does not support SARIF")
     }
 
     var resultCount int
@@ -147,17 +147,17 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string)
             return nil, fmt.Errorf("failed to generate SARIF: %v", err)
         }
         resultCount = getSarifResultCount(sarif)
-        slog.Info("Generated SARIF", "resultCount", resultCount)
+        slog.Debug("Generated SARIF", "resultCount", resultCount)
         sarifFilePath = filepath.Join(resultsDir, "results.sarif")
         if err := os.WriteFile(sarifFilePath, sarif, 0644); err != nil {
             return nil, fmt.Errorf("failed to write SARIF file: %v", err)
         }
     } else {
         resultCount = queryPackRunResults.TotalResultsCount
-        slog.Info("Did not generate SARIF", "resultCount", resultCount)
+        slog.Debug("Did not generate SARIF", "resultCount", resultCount)
     }
 
-    slog.Info("Adjusting BQRS files")
+    slog.Debug("Adjusting BQRS files")
     bqrsFilePaths, err := adjustBqrsFiles(queryPackRunResults, resultsDir)
     if err != nil {
         return nil, fmt.Errorf("failed to adjust BQRS files: %v", err)
@@ -151,7 +151,7 @@ func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) {
     return
     }
 
-    slog.Info("Publishing result", slog.String("result", string(resultBytes)))
+    slog.Debug("Publishing result", slog.String("result", string(resultBytes)))
     err = q.channel.PublishWithContext(ctx, "", queueName, false, false,
         amqp.Publishing{
             ContentType: "application/json",
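
This hunk and the ones above demote several per-job log lines from Info to Debug: they remain available for troubleshooting without flooding the default logger. With log/slog (and likewise golang.org/x/exp/slog), Debug records are dropped unless the handler's level is lowered, for example:

package main

import (
    "log/slog"
    "os"
)

func main() {
    // The default level is Info, so slog.Debug(...) is normally dropped.
    // Opting in to Debug output:
    logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level: slog.LevelDebug,
    }))
    slog.SetDefault(logger)

    slog.Debug("Publishing result", slog.String("result", "{...}")) // now visible
}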