diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 6326323..2df182f 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,73 +1,41 @@ package main import ( - "io" - "log" - "mrvacommander/pkg/codeql" - "mrvacommander/pkg/common" + "mrvacommander/pkg/agent" "mrvacommander/pkg/queue" - "mrvacommander/pkg/storage" - "mrvacommander/utils" - "net/http" - "path/filepath" - "runtime" - - "context" - "encoding/json" - "flag" - "fmt" - "os" "os/signal" - "sync" + "strconv" "syscall" - "time" - "github.com/google/uuid" - amqp "github.com/rabbitmq/amqp091-go" - "golang.org/x/exp/slog" + "flag" + "os" + "runtime" + "sync" "github.com/elastic/go-sysinfo" + "golang.org/x/exp/slog" ) -func downloadFile(url string, dest string) error { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to download file: %w", err) - } - defer resp.Body.Close() - - out, err := os.Create(dest) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to copy file content: %w", err) - } - - return nil -} - func calculateWorkers() int { - const workerMemoryGB = 2 + const workerMemoryMB = 2048 // 2 GB host, err := sysinfo.Host() if err != nil { - log.Fatalf("failed to get host info: %v", err) + slog.Error("failed to get host info", "error", err) + os.Exit(1) } memInfo, err := host.Memory() if err != nil { - log.Fatalf("failed to get memory info: %v", err) + slog.Error("failed to get memory info", "error", err) + os.Exit(1) } - // Convert total memory to GB - totalMemoryGB := memInfo.Available / (1024 * 1024 * 1024) + // Get available memory in MB + totalMemoryMB := memInfo.Available / (1024 * 1024) // Ensure we have at least one worker - workers := int(totalMemoryGB / workerMemoryGB) + workers := int(totalMemoryMB / workerMemoryMB) if workers < 1 { workers = 1 } @@ -81,206 +49,6 @@ func calculateWorkers() int { return workers } -type RabbitMQQueue 
struct { - jobs chan common.AnalyzeJob - results chan common.AnalyzeResult - conn *amqp.Connection - channel *amqp.Channel -} - -func InitializeQueue(jobsQueueName, resultsQueueName string) (*RabbitMQQueue, error) { - rabbitMQHost := os.Getenv("MRVA_RABBITMQ_HOST") - rabbitMQPort := os.Getenv("MRVA_RABBITMQ_PORT") - rabbitMQUser := os.Getenv("MRVA_RABBITMQ_USER") - rabbitMQPassword := os.Getenv("MRVA_RABBITMQ_PASSWORD") - - if rabbitMQHost == "" || rabbitMQPort == "" || rabbitMQUser == "" || rabbitMQPassword == "" { - return nil, fmt.Errorf("RabbitMQ environment variables not set") - } - - rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%s/", rabbitMQUser, rabbitMQPassword, rabbitMQHost, rabbitMQPort) - - const ( - tryCount = 5 - retryDelaySec = 3 - ) - - var conn *amqp.Connection - var err error - - for i := 0; i < tryCount; i++ { - slog.Info("Attempting to connect to RabbitMQ", slog.Int("attempt", i+1)) - conn, err = amqp.Dial(rabbitMQURL) - if err != nil { - slog.Warn("Failed to connect to RabbitMQ: %w", err) - if i < tryCount-1 { - slog.Info("Retrying in %d seconds", retryDelaySec) - time.Sleep(retryDelaySec * time.Second) - } - } - } - - if err != nil { - return nil, fmt.Errorf("failed to connect to RabbitMQ: %w", err) - } - - slog.Info("Connected to RabbitMQ") - - ch, err := conn.Channel() - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to open a channel: %w", err) - } - - _, err = ch.QueueDeclare(jobsQueueName, false, false, false, true, nil) - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to declare tasks queue: %w", err) - } - - _, err = ch.QueueDeclare(resultsQueueName, false, false, false, true, nil) - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to declare results queue: %w", err) - } - - err = ch.Qos(1, 0, false) - - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to set QoS: %w", err) - } - - return &RabbitMQQueue{ - conn: conn, - channel: ch, - jobs: make(chan common.AnalyzeJob), - 
results: make(chan common.AnalyzeResult), - }, nil -} - -func (q *RabbitMQQueue) Jobs() chan common.AnalyzeJob { - return q.jobs -} - -func (q *RabbitMQQueue) Results() chan common.AnalyzeResult { - return q.results -} - -func (q *RabbitMQQueue) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { - slog.Info("Queueing codeql database analyze jobs") -} - -func (q *RabbitMQQueue) Close() { - q.channel.Close() - q.conn.Close() -} - -func (q *RabbitMQQueue) ConsumeJobs(queueName string) { - msgs, err := q.channel.Consume(queueName, "", true, false, false, false, nil) - if err != nil { - slog.Error("failed to register a consumer", slog.Any("error", err)) - } - - for msg := range msgs { - job := common.AnalyzeJob{} - err := json.Unmarshal(msg.Body, &job) - if err != nil { - slog.Error("failed to unmarshal job", slog.Any("error", err)) - continue - } - q.jobs <- job - } - close(q.jobs) -} - -func (q *RabbitMQQueue) PublishResults(queueName string) { - for result := range q.results { - q.publishResult(queueName, result) - } -} - -func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - resultBytes, err := json.Marshal(result) - if err != nil { - slog.Error("failed to marshal result", slog.Any("error", err)) - return - } - - slog.Info("Publishing result", slog.String("result", string(resultBytes))) - err = q.channel.PublishWithContext(ctx, "", queueName, false, false, - amqp.Publishing{ - ContentType: "application/json", - Body: resultBytes, - }) - if err != nil { - slog.Error("failed to publish result", slog.Any("error", err)) - } -} - -func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { - var result = common.AnalyzeResult{ - RequestId: job.RequestId, - ResultCount: 0, - ResultArchiveURL: "", - Status: common.StatusError, - } - - // Log job info - slog.Info("Running 
analysis job", slog.Any("job", job)) - - // Create a temporary directory - tempDir := filepath.Join(os.TempDir(), uuid.New().String()) - if err := os.MkdirAll(tempDir, 0755); err != nil { - return result, fmt.Errorf("failed to create temporary directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Extract the query pack - // TODO: download from the 'job' query pack URL - utils.UntarGz("qp-54674.tgz", filepath.Join(tempDir, "qp-54674")) - - // Perform the CodeQL analysis - runResult, err := codeql.RunQuery("google_flatbuffers_db.zip", "cpp", "qp-54674", tempDir) - if err != nil { - return result, fmt.Errorf("failed to run analysis: %w", err) - } - - // Generate a ZIP archive containing SARIF and BQRS files - resultsArchive, err := codeql.GenerateResultsZipArchive(runResult) - if err != nil { - return result, fmt.Errorf("failed to generate results archive: %w", err) - } - - // TODO: Upload the archive to storage - slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) - slog.Info("Analysis job successful.") - - result = common.AnalyzeResult{ - RequestId: job.RequestId, - ResultCount: runResult.ResultCount, - ResultArchiveURL: "REPLACE_THIS_WITH_STORED_RESULTS_ARCHIVE", - Status: common.StatusSuccess, - } - - return result, nil -} - -func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { - defer wg.Done() - for job := range queue.Jobs() { - result, err := RunAnalysisJob(job) - if err != nil { - slog.Error("failed to run analysis job", slog.Any("error", err)) - continue - } - queue.Results() <- result - } -} - func main() { slog.Info("Starting agent") @@ -297,13 +65,27 @@ func main() { } for _, envVar := range requiredEnvVars { - if os.Getenv(envVar) == "" { - log.Fatalf("Fatal: Missing required environment variable %s", envVar) + if _, ok := os.LookupEnv(envVar); !ok { + slog.Error("Missing required environment variable %s", envVar) + os.Exit(1) } } - slog.Info("Initializing RabbitMQ connection") - rabbitMQQueue, err := 
InitializeQueue("tasks", "results") + rmqHost := os.Getenv("MRVA_RABBITMQ_HOST") + rmqPort := os.Getenv("MRVA_RABBITMQ_PORT") + rmqUser := os.Getenv("MRVA_RABBITMQ_USER") + rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD") + + rmqPortAsInt, err := strconv.Atoi(rmqPort) + + if err != nil { + slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err)) + os.Exit(1) + } + + slog.Info("Initializing RabbitMQ queue") + + rabbitMQQueue, err := queue.InitializeRabbitMQQueue(rmqHost, int16(rmqPortAsInt), rmqUser, rmqPass) if err != nil { slog.Error("failed to initialize RabbitMQ", slog.Any("error", err)) os.Exit(1) @@ -318,21 +100,21 @@ func main() { var wg sync.WaitGroup for i := 0; i < *workerCount; i++ { wg.Add(1) - go RunWorker(rabbitMQQueue, &wg) + go agent.RunWorker(rabbitMQQueue, &wg) } - slog.Info("Starting tasks consumer") - go rabbitMQQueue.ConsumeJobs("tasks") - - slog.Info("Starting results publisher") - go rabbitMQQueue.PublishResults("results") - slog.Info("Agent startup complete") + // Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - <-sigChan - slog.Info("Shutting down agent") - close(rabbitMQQueue.results) + go func() { + <-sigChan + slog.Info("Shutting down agent") + rabbitMQQueue.Close() + os.Exit(0) + }() + + select {} } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 01fa8ee..018d759 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -1,20 +1,20 @@ package agent import ( + "fmt" + "log/slog" + "mrvacommander/pkg/codeql" "mrvacommander/pkg/common" "mrvacommander/pkg/logger" "mrvacommander/pkg/qpstore" "mrvacommander/pkg/queue" "mrvacommander/pkg/storage" "mrvacommander/utils" - - "log/slog" - - "fmt" - "path/filepath" - "os" - "os/exec" + "path/filepath" + "sync" + + "github.com/google/uuid" ) type RunnerSingle struct { @@ -40,102 +40,101 @@ type Visibles struct { } func (c *RunnerSingle) Setup(st *Visibles) { - return + // TODO: 
implement } func (r *RunnerSingle) worker(wid int) { - var job common.AnalyzeJob + // TODO: reimplement this later + /* + var job common.AnalyzeJob - for { - job = <-r.queue.Jobs() + for { + job = <-r.queue.Jobs() - slog.Debug("Picking up job", "job", job, "worker", wid) + slog.Debug("Picking up job", "job", job, "worker", wid) - slog.Debug("Analysis: running", "job", job) - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) + slog.Debug("Analysis: running", "job", job) + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) - _, err := RunAnalysis(job) + resultFile, err := RunAnalysis(job) + if err != nil { + continue + } + + slog.Debug("Analysis run finished", "job", job) + + // TODO: FIX THIS + res := common.AnalyzeResult{ + RunAnalysisSARIF: resultFile, + RunAnalysisBQRS: "", // FIXME ? + } + r.queue.Results() <- res + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) + storage.SetResult(job.QueryPackId, job.NWO, res) + + } + */ +} + +// RunAnalysisJob runs a CodeQL analysis job (AnalyzeJob) returning an AnalyzeResult +func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { + var result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: 0, + ResultArchiveURL: "", + Status: common.StatusError, + } + + // Create a temporary directory + tempDir := filepath.Join(os.TempDir(), uuid.New().String()) + if err := os.MkdirAll(tempDir, 0755); err != nil { + return result, fmt.Errorf("failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Extract the query pack + // TODO: download from the 'job' query pack URL + // utils.downloadFile + queryPackPath := filepath.Join(tempDir, "qp-54674") + utils.UntarGz("qp-54674.tgz", queryPackPath) + + // Perform the CodeQL analysis + runResult, err := codeql.RunQuery("google_flatbuffers_db.zip", "cpp", queryPackPath, tempDir) + if err != nil { + return result, fmt.Errorf("failed to run analysis: %w", err) + } + + // Generate a 
ZIP archive containing SARIF and BQRS files + resultsArchive, err := codeql.GenerateResultsZipArchive(runResult) + if err != nil { + return result, fmt.Errorf("failed to generate results archive: %w", err) + } + + // TODO: Upload the archive to storage + slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) + slog.Info("Analysis job successful.") + + result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: runResult.ResultCount, + ResultArchiveURL: "REPLACE_THIS_WITH_STORED_RESULTS_ARCHIVE", // TODO + Status: common.StatusSuccess, + } + + return result, nil +} + +// RunWorker runs a worker that processes jobs from queue +func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { + defer wg.Done() + for job := range queue.Jobs() { + slog.Info("Running analysis job", slog.Any("job", job)) + result, err := RunAnalysisJob(job) if err != nil { + slog.Error("Failed to run analysis job", slog.Any("error", err)) continue } - - slog.Debug("Analysis run finished", "job", job) - - res := common.AnalyzeResult{} - r.queue.Results() <- res - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) - storage.SetResult(job.QueryPackId, job.NWO, res) - + slog.Info("Analysis job completed", slog.Any("result", result)) + queue.Results() <- result } } - -func RunAnalysis(job common.AnalyzeJob) (string, error) { - // TODO Add multi-language tests including queryLanguage - // queryPackID, queryLanguage, dbOwner, dbRepo := - // job.QueryPackId, job.QueryLanguage, job.NWO.Owner, job.NWO.Repo - queryPackID, dbOwner, dbRepo := - job.QueryPackId, job.NWO.Owner, job.NWO.Repo - - serverRoot := os.Getenv("MRVA_SERVER_ROOT") - - // Set up derived paths - dbPath := filepath.Join(serverRoot, "var/codeql/dbs", dbOwner, dbRepo) - dbZip := filepath.Join(serverRoot, "codeql/dbs", dbOwner, dbRepo, - fmt.Sprintf("%s_%s_db.zip", dbOwner, dbRepo)) - dbExtract := filepath.Join(serverRoot, "var/codeql/dbs", dbOwner, dbRepo) - - queryPack := filepath.Join(serverRoot, 
- "var/codeql/querypacks", fmt.Sprintf("qp-%d.tgz", queryPackID)) - queryExtract := filepath.Join(serverRoot, - "var/codeql/querypacks", fmt.Sprintf("qp-%d", queryPackID)) - - queryOutDir := filepath.Join(serverRoot, - "var/codeql/sarif/localrun", dbOwner, dbRepo) - queryOutFile := filepath.Join(queryOutDir, - fmt.Sprintf("%s_%s.sarif", dbOwner, dbRepo)) - - // Prepare directory, extract database - if err := os.MkdirAll(dbExtract, 0755); err != nil { - slog.Error("Failed to create DB directory %s: %v", dbExtract, err) - return "", err - } - - if err := utils.UnzipFile(dbZip, dbExtract); err != nil { - slog.Error("Failed to unzip DB", dbZip, err) - return "", err - } - - // Prepare directory, extract query pack - if err := os.MkdirAll(queryExtract, 0755); err != nil { - slog.Error("Failed to create query pack directory %s: %v", queryExtract, err) - return "", err - } - - if err := utils.UntarGz(queryPack, queryExtract); err != nil { - slog.Error("Failed to extract querypack %s: %v", queryPack, err) - return "", err - } - - // Prepare query result directory - if err := os.MkdirAll(queryOutDir, 0755); err != nil { - slog.Error("Failed to create query result directory %s: %v", queryOutDir, err) - return "", err - } - - // Run database analyze - cmd := exec.Command("codeql", "database", "analyze", - "--format=sarif-latest", "--rerun", "--output", queryOutFile, - "-j8", dbPath, queryExtract) - cmd.Dir = serverRoot - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - slog.Error("codeql database analyze failed:", "error", err, "job", job) - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusError) - return "", err - } - - // Return result path - return queryOutFile, nil -} diff --git a/pkg/codeql/codeql.go b/pkg/codeql/codeql.go index d96373e..a44efd8 100644 --- a/pkg/codeql/codeql.go +++ b/pkg/codeql/codeql.go @@ -16,7 +16,6 @@ import ( "gopkg.in/yaml.v3" ) -// Helper Functions func contains(slice []string, item string) bool { 
for _, s := range slice { if s == item { @@ -26,7 +25,6 @@ func contains(slice []string, item string) bool { return false } -// Main Functions func getCodeQLCLIPath() (string, error) { // get the CODEQL_CLI_PATH environment variable codeqlCliPath := os.Getenv("CODEQL_CLI_PATH") diff --git a/pkg/queue/rabbitmq.go b/pkg/queue/rabbitmq.go new file mode 100644 index 0000000..7d32b1f --- /dev/null +++ b/pkg/queue/rabbitmq.go @@ -0,0 +1,163 @@ +package queue + +import ( + "mrvacommander/pkg/common" + "mrvacommander/pkg/storage" + + "context" + "encoding/json" + "fmt" + "log" + "time" + + amqp "github.com/rabbitmq/amqp091-go" + "golang.org/x/exp/slog" +) + +type RabbitMQQueue struct { + jobs chan common.AnalyzeJob + results chan common.AnalyzeResult + conn *amqp.Connection + channel *amqp.Channel +} + +func InitializeRabbitMQQueue( + host string, + port int16, + user string, + password string, +) (*RabbitMQQueue, error) { + const ( + tryCount = 5 + retryDelaySec = 3 + jobsQueueName = "tasks" + resultsQueueName = "results" + ) + + var conn *amqp.Connection + var err error + + rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%d/", user, password, host, port) + + for i := 0; i < tryCount; i++ { + slog.Info("Attempting to connect to RabbitMQ", slog.Int("attempt", i+1)) + conn, err = amqp.Dial(rabbitMQURL) + if err != nil { + slog.Warn("Failed to connect to RabbitMQ", "error", err) + if i < tryCount-1 { + slog.Info("Retrying", "seconds", retryDelaySec) + time.Sleep(retryDelaySec * time.Second) + } + } else { + // successfully connected to RabbitMQ + break + } + } + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + slog.Info("Connected to RabbitMQ") + + ch, err := conn.Channel() + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to open a channel: %w", err) + } + + _, err = ch.QueueDeclare(jobsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to declare tasks queue: %w", err) + } + 
+ _, err = ch.QueueDeclare(resultsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to declare results queue: %w", err) + } + + err = ch.Qos(1, 0, false) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to set QoS: %w", err) + } + + result := RabbitMQQueue{ + conn: conn, + channel: ch, + jobs: make(chan common.AnalyzeJob), + results: make(chan common.AnalyzeResult), + } + + slog.Info("Starting tasks consumer") + go result.ConsumeJobs(jobsQueueName) + + slog.Info("Starting results publisher") + go result.PublishResults(resultsQueueName) + + return &result, nil +} + +func (q *RabbitMQQueue) Jobs() chan common.AnalyzeJob { + return q.jobs +} + +func (q *RabbitMQQueue) Results() chan common.AnalyzeResult { + return q.results +} + +func (q *RabbitMQQueue) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { + // TODO: Implement + log.Fatal("unimplemented") +} + +func (q *RabbitMQQueue) Close() { + q.channel.Close() + q.conn.Close() +} + +func (q *RabbitMQQueue) ConsumeJobs(queueName string) { + msgs, err := q.channel.Consume(queueName, "", true, false, false, false, nil) + if err != nil { + slog.Error("failed to register a consumer", slog.Any("error", err)) + } + + for msg := range msgs { + job := common.AnalyzeJob{} + err := json.Unmarshal(msg.Body, &job) + if err != nil { + slog.Error("failed to unmarshal job", slog.Any("error", err)) + continue + } + q.jobs <- job + } + close(q.jobs) +} + +func (q *RabbitMQQueue) PublishResults(queueName string) { + for result := range q.results { + q.publishResult(queueName, result) + } +} + +func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resultBytes, err := json.Marshal(result) + if err != nil { + slog.Error("failed to marshal result", slog.Any("error", err)) + return + 
// downloadFile fetches url with an HTTP GET request and writes the response
// body to the file at dest, creating the file or truncating an existing one.
//
// Only a 2xx response is treated as success; any other status is reported as
// an error and dest is not created or modified. If creating, writing or
// closing the file fails, the incomplete file is removed so that callers
// never pick up a truncated download.
//
// NOTE(review): the request uses http.Get, i.e. the default client without a
// timeout; a stalled server can block this call indefinitely.
func downloadFile(url string, dest string) (err error) {
	resp, err := http.Get(url)
	if err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	defer resp.Body.Close()

	// http.Get only fails on transport errors; a 404 or 500 still arrives
	// here with a body (usually an error page) that must not be saved.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("failed to download file: unexpected HTTP status %s", resp.Status)
	}

	out, err := os.Create(dest)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	// From here on dest exists; drop it again if anything below fails.
	defer func() {
		if err != nil {
			os.Remove(dest)
		}
	}()

	if _, err = io.Copy(out, resp.Body); err != nil {
		out.Close()
		return fmt.Errorf("failed to copy file content: %w", err)
	}

	// Close explicitly: buffered write errors can surface only at close time.
	if err = out.Close(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}

	return nil
}