From 3218f64bcf17f00fe5fa9eb7c5b28193a7b68ccd Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Fri, 14 Jun 2024 12:48:33 +0200 Subject: [PATCH 1/9] Move archive functions into utils package --- pkg/agent/agent.go | 110 ++------------------------------------------ utils/archive.go | 111 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 107 deletions(-) create mode 100644 utils/archive.go diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 3a62fbe..7e78b5c 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -6,14 +6,11 @@ import ( "mrvacommander/pkg/qpstore" "mrvacommander/pkg/queue" "mrvacommander/pkg/storage" + "mrvacommander/utils" "log/slog" - "archive/tar" - "archive/zip" - "compress/gzip" "fmt" - "io" "path/filepath" "os" @@ -106,7 +103,7 @@ func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { return "", err } - if err := unzipFile(dbZip, dbExtract); err != nil { + if err := utils.UnzipFile(dbZip, dbExtract); err != nil { slog.Error("Failed to unzip DB", dbZip, err) return "", err } @@ -117,7 +114,7 @@ func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { return "", err } - if err := untarGz(queryPack, queryExtract); err != nil { + if err := utils.UntarGz(queryPack, queryExtract); err != nil { slog.Error("Failed to extract querypack %s: %v", queryPack, err) return "", err } @@ -145,104 +142,3 @@ func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { // Return result path return queryOutFile, nil } - -// unzipFile extracts a zip file to the specified destination -func unzipFile(zipFile, dest string) error { - r, err := zip.OpenReader(zipFile) - if err != nil { - return err - } - defer r.Close() - - for _, f := range r.File { - fPath := filepath.Join(dest, f.Name) - if f.FileInfo().IsDir() { - if err := os.MkdirAll(fPath, os.ModePerm); err != nil { - return err - } - continue - } - - if err := os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil { - return err - } - - outFile, err := os.OpenFile(fPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return err - } - - rc, err := f.Open() - if err != nil { - outFile.Close() - return err - } - - _, err = io.Copy(outFile, rc) - - outFile.Close() - rc.Close() - - if err != nil { - return err - } - } - return nil -} - -// untarGz extracts a tar.gz file to the specified destination. -func untarGz(tarGzFile, dest string) error { - file, err := os.Open(tarGzFile) - if err != nil { - return err - } - defer file.Close() - - gzr, err := gzip.NewReader(file) - if err != nil { - return err - } - defer gzr.Close() - - return untar(gzr, dest) -} - -// untar extracts a tar archive to the specified destination. 
-func untar(r io.Reader, dest string) error { - tr := tar.NewReader(r) - - for { - header, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return err - } - - fPath := filepath.Join(dest, header.Name) - if header.Typeflag == tar.TypeDir { - if err := os.MkdirAll(fPath, os.ModePerm); err != nil { - return err - } - } else { - if err := os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil { - return err - } - - outFile, err := os.OpenFile(fPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) - if err != nil { - return err - } - - if _, err := io.Copy(outFile, tr); err != nil { - outFile.Close() - return err - } - - outFile.Close() - } - } - - return nil -} diff --git a/utils/archive.go b/utils/archive.go new file mode 100644 index 0000000..6b4edf7 --- /dev/null +++ b/utils/archive.go @@ -0,0 +1,111 @@ +package utils + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "io" + "os" + "path/filepath" +) + +// UnzipFile extracts a zip file to the specified destination +func UnzipFile(zipFile, dest string) error { + r, err := zip.OpenReader(zipFile) + if err != nil { + return err + } + defer r.Close() + + for _, f := range r.File { + fPath := filepath.Join(dest, f.Name) + if f.FileInfo().IsDir() { + if err := os.MkdirAll(fPath, os.ModePerm); err != nil { + return err + } + continue + } + + if err := os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil { + return err + } + + outFile, err := os.OpenFile(fPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + if err != nil { + return err + } + + rc, err := f.Open() + if err != nil { + outFile.Close() + return err + } + + _, err = io.Copy(outFile, rc) + + outFile.Close() + rc.Close() + + if err != nil { + return err + } + } + return nil +} + +// UntarGz extracts a tar.gz file to the specified destination. +func UntarGz(tarGzFile, dest string) error { + file, err := os.Open(tarGzFile) + if err != nil { + return err + } + defer file.Close() + + gzr, err := gzip.NewReader(file) + if err != nil { + return err + } + defer gzr.Close() + + return Untar(gzr, dest) +} + +// Untar extracts a tar archive to the specified destination. 
+func Untar(r io.Reader, dest string) error { + tr := tar.NewReader(r) + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + fPath := filepath.Join(dest, header.Name) + if header.Typeflag == tar.TypeDir { + if err := os.MkdirAll(fPath, os.ModePerm); err != nil { + return err + } + } else { + if err := os.MkdirAll(filepath.Dir(fPath), os.ModePerm); err != nil { + return err + } + + outFile, err := os.OpenFile(fPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return err + } + + if _, err := io.Copy(outFile, tr); err != nil { + outFile.Close() + return err + } + + outFile.Close() + } + } + + return nil +} From c29daab045212a3b88d1a5b357cdb02078cc7fa1 Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Fri, 14 Jun 2024 12:55:45 +0200 Subject: [PATCH 2/9] Standardize NameWithOwner and Visible naming Acronyms are now "NWO" and "Vis" respectively --- cmd/postgres/pgmin.go | 4 +- pkg/agent/agent.go | 12 +++--- pkg/common/types.go | 27 +++++++----- pkg/qpstore/container.go | 6 +-- pkg/qpstore/interfaces.go | 4 +- pkg/queue/interfaces.go | 2 +- pkg/queue/queue.go | 9 ++-- pkg/server/server.go | 88 +++++++++++++++++++-------------------- pkg/server/types.go | 14 +++---- pkg/storage/container.go | 6 +-- pkg/storage/interfaces.go | 4 +- pkg/storage/storage.go | 24 +++++------ 12 files changed, 102 insertions(+), 98 deletions(-) diff --git a/cmd/postgres/pgmin.go b/cmd/postgres/pgmin.go index ece850a..66189da 100644 --- a/cmd/postgres/pgmin.go +++ b/cmd/postgres/pgmin.go @@ -22,11 +22,11 @@ func main() { } // Migrate the schema: create the 'owner_repo' table from the struct - err = db.AutoMigrate(&common.OwnerRepo{}) + err = db.AutoMigrate(&common.NameWithOwner{}) if err != nil { panic("failed to migrate database") } // Create an entry in the database - db.Create(&common.OwnerRepo{Owner: "foo", Repo: "foo/bar"}) + db.Create(&common.NameWithOwner{Owner: "foo", Repo: "foo/bar"}) } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 7e78b5c..b88e4c8 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -52,7 +52,7 @@ func (r *RunnerSingle) worker(wid int) { slog.Debug("Picking up job", "job", job, "worker", wid) slog.Debug("Analysis: running", "job", job) - storage.SetStatus(job.QueryPackId, job.ORepo, common.StatusQueued) + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) resultFile, err := r.RunAnalysis(job) if err != nil { @@ -66,8 +66,8 @@ func (r *RunnerSingle) worker(wid int) { RunAnalysisBQRS: "", // FIXME ? 
} r.queue.Results() <- res - storage.SetStatus(job.QueryPackId, job.ORepo, common.StatusSuccess) - storage.SetResult(job.QueryPackId, job.ORepo, res) + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) + storage.SetResult(job.QueryPackId, job.NWO, res) } } @@ -75,9 +75,9 @@ func (r *RunnerSingle) worker(wid int) { func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { // TODO Add multi-language tests including queryLanguage // queryPackID, queryLanguage, dbOwner, dbRepo := - // job.QueryPackId, job.QueryLanguage, job.ORL.Owner, job.ORL.Repo + // job.QueryPackId, job.QueryLanguage, job.NWO.Owner, job.NWO.Repo queryPackID, dbOwner, dbRepo := - job.QueryPackId, job.ORepo.Owner, job.ORepo.Repo + job.QueryPackId, job.NWO.Owner, job.NWO.Repo serverRoot := os.Getenv("MRVA_SERVER_ROOT") @@ -135,7 +135,7 @@ func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { if err := cmd.Run(); err != nil { slog.Error("codeql database analyze failed:", "error", err, "job", job) - storage.SetStatus(job.QueryPackId, job.ORepo, common.StatusError) + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusError) return "", err } diff --git a/pkg/common/types.go b/pkg/common/types.go index de7575f..1d78b7d 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -1,24 +1,29 @@ package common -type AnalyzeJob struct { - MirvaRequestID int - - QueryPackId int - QueryLanguage string - - ORepo OwnerRepo -} - -type OwnerRepo struct { +// NameWithOwner represents a repository name and its owner name. +type NameWithOwner struct { Owner string Repo string } +// AnalyzeJob represents a job specifying a repository and a query pack to analyze it with. +// This is the message format that the agent receives from the queue. +type AnalyzeJob struct { + MRVARequestID int + QueryPackId int + QueryPackURL string + QueryLanguage string + NWO NameWithOwner +} + +// AnalyzeResult represents the result of an analysis job. +// This is the message format that the agent sends to the queue. type AnalyzeResult struct { RunAnalysisSARIF string RunAnalysisBQRS string } +// Status represents the status of a job. 
type Status int const ( @@ -48,5 +53,5 @@ func (s Status) ToExternalString() string { type JobSpec struct { JobID int - OwnerRepo + NameWithOwner } diff --git a/pkg/qpstore/container.go b/pkg/qpstore/container.go index 120b805..864f025 100644 --- a/pkg/qpstore/container.go +++ b/pkg/qpstore/container.go @@ -57,10 +57,10 @@ func (s *StorageContainer) SaveQueryPack(tgz []byte, sessionID int) (storagePath return "todo:no-path-yet", nil } -func (s *StorageContainer) FindAvailableDBs(analysisReposRequested []common.OwnerRepo) (notFoundRepos []common.OwnerRepo, analysisRepos *map[common.OwnerRepo]qldbstore.DBLocation) { +func (s *StorageContainer) FindAvailableDBs(analysisReposRequested []common.NameWithOwner) (notFoundRepos []common.NameWithOwner, analysisRepos *map[common.NameWithOwner]qldbstore.DBLocation) { // TODO s.FindAvailableDBs() via postgres - analysisRepos = &map[common.OwnerRepo]qldbstore.DBLocation{} - notFoundRepos = []common.OwnerRepo{} + analysisRepos = &map[common.NameWithOwner]qldbstore.DBLocation{} + notFoundRepos = []common.NameWithOwner{} return notFoundRepos, analysisRepos } diff --git a/pkg/qpstore/interfaces.go b/pkg/qpstore/interfaces.go index 26081fc..08eb240 100644 --- a/pkg/qpstore/interfaces.go +++ b/pkg/qpstore/interfaces.go @@ -8,6 +8,6 @@ import ( type Storage interface { NextID() int SaveQueryPack(tgz []byte, sessionID int) (storagePath string, error error) - FindAvailableDBs(analysisReposRequested []common.OwnerRepo) (not_found_repos []common.OwnerRepo, - analysisRepos *map[common.OwnerRepo]qldbstore.DBLocation) + FindAvailableDBs(analysisReposRequested []common.NameWithOwner) (not_found_repos []common.NameWithOwner, + analysisRepos *map[common.NameWithOwner]qldbstore.DBLocation) } diff --git a/pkg/queue/interfaces.go b/pkg/queue/interfaces.go index ed9c02e..c978878 100644 --- a/pkg/queue/interfaces.go +++ b/pkg/queue/interfaces.go @@ -8,7 +8,7 @@ import ( type Queue interface { Jobs() chan common.AnalyzeJob Results() chan common.AnalyzeResult - StartAnalyses(analysis_repos *map[common.OwnerRepo]storage.DBLocation, + StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) } diff --git a/pkg/queue/queue.go b/pkg/queue/queue.go index c228965..424c4d1 100644 --- a/pkg/queue/queue.go +++ b/pkg/queue/queue.go @@ -14,19 +14,18 @@ func (q *QueueSingle) Results() chan common.AnalyzeResult { return q.results } -func (q *QueueSingle) StartAnalyses(analysis_repos *map[common.OwnerRepo]storage.DBLocation, session_id int, +func (q *QueueSingle) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { slog.Debug("Queueing codeql database analyze jobs") - for orl := range *analysis_repos { + for nwo := range *analysis_repos { info := common.AnalyzeJob{ QueryPackId: session_id, QueryLanguage: session_language, - - ORepo: orl, + NWO: nwo, } q.jobs <- info - storage.SetStatus(session_id, orl, common.StatusQueued) + storage.SetStatus(session_id, nwo, common.StatusQueued) storage.AddJob(session_id, info) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 603bb15..7592ad7 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -23,33 +23,33 @@ import ( func (c *CommanderSingle) Setup(st *Visibles) { r := mux.NewRouter() - c.st = st + c.vis = st // // First are the API endpoints that mirror those used in the github API // - r.HandleFunc("/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses", c.MirvaRequest) - // /repos/hohn 
/mirva-controller/code-scanning/codeql/variant-analyses + r.HandleFunc("/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses", c.MRVARequest) + // /repos/hohn /mrva-controller/code-scanning/codeql/variant-analyses // Or via - r.HandleFunc("/{repository_id}/code-scanning/codeql/variant-analyses", c.MirvaRequestID) + r.HandleFunc("/{repository_id}/code-scanning/codeql/variant-analyses", c.MRVARequestID) r.HandleFunc("/", c.RootHandler) // This is the standalone status request. // It's also the first request made when downloading; the difference is on the // client side's handling. - r.HandleFunc("/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}", c.MirvaStatus) + r.HandleFunc("/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}", c.MRVAStatus) - r.HandleFunc("/repos/{controller_owner}/{controller_repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}/repos/{repo_owner}/{repo_name}", c.MirvaDownloadArtifact) + r.HandleFunc("/repos/{controller_owner}/{controller_repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}/repos/{repo_owner}/{repo_name}", c.MRVADownloadArtifact) // Not implemented: - // r.HandleFunc("/codeql-query-console/codeql-variant-analysis-repo-tasks/{codeql_variant_analysis_id}/{repo_id}/{owner_id}/{controller_repo_id}", MirvaDownLoad3) - // r.HandleFunc("/github-codeql-query-console-prod/codeql-variant-analysis-repo-tasks/{codeql_variant_analysis_id}/{repo_id}", MirvaDownLoad4) + // r.HandleFunc("/codeql-query-console/codeql-variant-analysis-repo-tasks/{codeql_variant_analysis_id}/{repo_id}/{owner_id}/{controller_repo_id}", MRVADownLoad3) + // r.HandleFunc("/github-codeql-query-console-prod/codeql-variant-analysis-repo-tasks/{codeql_variant_analysis_id}/{repo_id}", MRVADownLoad4) // // Now some support API endpoints // - r.HandleFunc("/download-server/{local_path:.*}", c.MirvaDownloadServe) + r.HandleFunc("/download-server/{local_path:.*}", c.MRVADownloadServe) // // Bind to a port and pass our router in @@ -64,13 +64,13 @@ func (c *CommanderSingle) StatusResponse(w http.ResponseWriter, js common.JobSpe all_scanned := []common.ScannedRepo{} jobs := storage.GetJobList(js.JobID) for _, job := range jobs { - astat := storage.GetStatus(js.JobID, job.ORepo).ToExternalString() + astat := storage.GetStatus(js.JobID, job.NWO).ToExternalString() all_scanned = append(all_scanned, common.ScannedRepo{ Repository: common.Repository{ ID: 0, - Name: job.ORepo.Repo, - FullName: fmt.Sprintf("%s/%s", job.ORepo.Owner, job.ORepo.Repo), + Name: job.NWO.Repo, + FullName: fmt.Sprintf("%s/%s", job.NWO.Owner, job.NWO.Repo), Private: false, StargazersCount: 0, UpdatedAt: ji.UpdatedAt, @@ -82,7 +82,7 @@ func (c *CommanderSingle) StatusResponse(w http.ResponseWriter, js common.JobSpe ) } - astat := storage.GetStatus(js.JobID, js.OwnerRepo).ToExternalString() + astat := storage.GetStatus(js.JobID, js.NameWithOwner).ToExternalString() status := common.StatusResponse{ SessionId: js.JobID, @@ -116,9 +116,9 @@ func (c *CommanderSingle) RootHandler(w http.ResponseWriter, r *http.Request) { slog.Info("Request on /") } -func (c *CommanderSingle) MirvaStatus(w http.ResponseWriter, r *http.Request) { +func (c *CommanderSingle) MRVAStatus(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - slog.Info("mrva status request for ", + slog.Info("MRVA status request for ", "owner", vars["owner"], "repo", vars["repo"], "codeql_variant_analysis_id", vars["codeql_variant_analysis_id"]) @@ 
-142,8 +142,8 @@ func (c *CommanderSingle) MirvaStatus(w http.ResponseWriter, r *http.Request) { job := spec[0] js := common.JobSpec{ - JobID: job.QueryPackId, - OwnerRepo: job.ORepo, + JobID: job.QueryPackId, + NameWithOwner: job.NWO, } ji := storage.GetJobInfo(js) @@ -152,7 +152,7 @@ func (c *CommanderSingle) MirvaStatus(w http.ResponseWriter, r *http.Request) { } // Download artifacts -func (c *CommanderSingle) MirvaDownloadArtifact(w http.ResponseWriter, r *http.Request) { +func (c *CommanderSingle) MRVADownloadArtifact(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) slog.Info("MRVA artifact download", "controller_owner", vars["controller_owner"], @@ -170,7 +170,7 @@ func (c *CommanderSingle) MirvaDownloadArtifact(w http.ResponseWriter, r *http.R } js := common.JobSpec{ JobID: vaid, - OwnerRepo: common.OwnerRepo{ + NameWithOwner: common.NameWithOwner{ Owner: vars["repo_owner"], Repo: vars["repo_name"], }, @@ -181,7 +181,7 @@ func (c *CommanderSingle) MirvaDownloadArtifact(w http.ResponseWriter, r *http.R func (c *CommanderSingle) DownloadResponse(w http.ResponseWriter, js common.JobSpec, vaid int) { slog.Debug("Forming download response", "session", vaid, "job", js) - astat := storage.GetStatus(vaid, js.OwnerRepo) + astat := storage.GetStatus(vaid, js.NameWithOwner) var dlr common.DownloadResponse if astat == common.StatusSuccess { @@ -234,7 +234,7 @@ func (c *CommanderSingle) DownloadResponse(w http.ResponseWriter, js common.JobS } -func (c *CommanderSingle) MirvaDownloadServe(w http.ResponseWriter, r *http.Request) { +func (c *CommanderSingle) MRVADownloadServe(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) slog.Info("File download request", "local_path", vars["local_path"]) @@ -266,16 +266,16 @@ func FileDownload(w http.ResponseWriter, path string) { } -func (c *CommanderSingle) MirvaRequestID(w http.ResponseWriter, r *http.Request) { +func (c *CommanderSingle) MRVARequestID(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) slog.Info("New mrva using repository_id=%v\n", vars["repository_id"]) } -func (c *CommanderSingle) MirvaRequest(w http.ResponseWriter, r *http.Request) { +func (c *CommanderSingle) MRVARequest(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) slog.Info("New mrva run ", "owner", vars["owner"], "repo", vars["repo"]) - session_id := c.st.ServerStore.NextID() + session_id := c.vis.ServerStore.NextID() session_owner := vars["owner"] session_controller_repo := vars["repo"] slog.Info("new run", "id: ", fmt.Sprint(session_id), session_owner, session_controller_repo) @@ -284,9 +284,9 @@ func (c *CommanderSingle) MirvaRequest(w http.ResponseWriter, r *http.Request) { return } - not_found_repos, analysisRepos := c.st.ServerStore.FindAvailableDBs(session_repositories) + not_found_repos, analysisRepos := c.vis.ServerStore.FindAvailableDBs(session_repositories) - c.st.Queue.StartAnalyses(analysisRepos, session_id, session_language) + c.vis.Queue.StartAnalyses(analysisRepos, session_id, session_language) si := SessionInfo{ ID: session_id, @@ -316,10 +316,10 @@ func (c *CommanderSingle) MirvaRequest(w http.ResponseWriter, r *http.Request) { w.Write(submit_response) } -func ORToArr(aor []common.OwnerRepo) ([]string, int) { +func nwoToNwoStringArray(nwo []common.NameWithOwner) ([]string, int) { repos := []string{} - count := len(aor) - for _, repo := range aor { + count := len(nwo) + for _, repo := range nwo { repos = append(repos, fmt.Sprintf("%s/%s", repo.Owner, repo.Repo)) } return repos, count @@ -330,17 +330,17 @@ func 
submit_response(sn SessionInfo) ([]byte, error) { var m_cr common.ControllerRepo var m_ac common.Actor - repos, count := ORToArr(sn.NotFoundRepos) + repos, count := nwoToNwoStringArray(sn.NotFoundRepos) r_nfr := common.NotFoundRepos{RepositoryCount: count, RepositoryFullNames: repos} - repos, count = ORToArr(sn.AccessMismatchRepos) + repos, count = nwoToNwoStringArray(sn.AccessMismatchRepos) r_amr := common.AccessMismatchRepos{RepositoryCount: count, Repositories: repos} - repos, count = ORToArr(sn.NoCodeqlDBRepos) + repos, count = nwoToNwoStringArray(sn.NoCodeqlDBRepos) r_ncd := common.NoCodeqlDBRepos{RepositoryCount: count, Repositories: repos} // TODO fill these with real values? - repos, count = ORToArr(sn.NoCodeqlDBRepos) + repos, count = nwoToNwoStringArray(sn.NoCodeqlDBRepos) r_olr := common.OverLimitRepos{RepositoryCount: count, Repositories: repos} m_skip := common.SkippedRepositories{ @@ -366,8 +366,8 @@ func submit_response(sn SessionInfo) ([]byte, error) { for _, job := range joblist { storage.SetJobInfo(common.JobSpec{ - JobID: sn.ID, - OwnerRepo: job.ORepo, + JobID: sn.ID, + NameWithOwner: job.NWO, }, common.JobInfo{ QueryLanguage: sn.Language, CreatedAt: m_sr.CreatedAt, @@ -387,28 +387,28 @@ func submit_response(sn SessionInfo) ([]byte, error) { } -func (c *CommanderSingle) collectRequestInfo(w http.ResponseWriter, r *http.Request, sessionId int) (string, []common.OwnerRepo, string, error) { +func (c *CommanderSingle) collectRequestInfo(w http.ResponseWriter, r *http.Request, sessionId int) (string, []common.NameWithOwner, string, error) { slog.Debug("Collecting session info") if r.Body == nil { err := errors.New("missing request body") log.Println(err) http.Error(w, err.Error(), http.StatusNoContent) - return "", []common.OwnerRepo{}, "", err + return "", []common.NameWithOwner{}, "", err } buf, err := io.ReadAll(r.Body) if err != nil { var w http.ResponseWriter slog.Error("Error reading MRVA submission body", "error", err.Error()) http.Error(w, err.Error(), http.StatusBadRequest) - return "", []common.OwnerRepo{}, "", err + return "", []common.NameWithOwner{}, "", err } msg, err := TrySubmitMsg(buf) if err != nil { // Unknown message slog.Error("Unknown MRVA submission body format") http.Error(w, err.Error(), http.StatusBadRequest) - return "", []common.OwnerRepo{}, "", err + return "", []common.NameWithOwner{}, "", err } // Decompose the SubmitMsg and keep information @@ -417,19 +417,19 @@ func (c *CommanderSingle) collectRequestInfo(w http.ResponseWriter, r *http.Requ slog.Error("MRVA submission body querypack has invalid format") err := errors.New("MRVA submission body querypack has invalid format") http.Error(w, err.Error(), http.StatusBadRequest) - return "", []common.OwnerRepo{}, "", err + return "", []common.NameWithOwner{}, "", err } session_tgz_ref, err := c.extract_tgz(msg.QueryPack, sessionId) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) - return "", []common.OwnerRepo{}, "", err + return "", []common.NameWithOwner{}, "", err } // 2. Save the language session_language := msg.Language // 3. 
Save the repositories - var session_repositories []common.OwnerRepo + var session_repositories []common.NameWithOwner for _, v := range msg.Repositories { t := strings.Split(v, "/") @@ -439,7 +439,7 @@ func (c *CommanderSingle) collectRequestInfo(w http.ResponseWriter, r *http.Requ http.Error(w, err, http.StatusBadRequest) } session_repositories = append(session_repositories, - common.OwnerRepo{Owner: t[0], Repo: t[1]}) + common.NameWithOwner{Owner: t[0], Repo: t[1]}) } return session_language, session_repositories, session_tgz_ref, nil } @@ -492,7 +492,7 @@ func (c *CommanderSingle) extract_tgz(qp string, sessionID int) (string, error) return "", err } - session_query_pack_tgz_filepath, err := c.st.ServerStore.SaveQueryPack(tgz, sessionID) + session_query_pack_tgz_filepath, err := c.vis.ServerStore.SaveQueryPack(tgz, sessionID) if err != nil { return "", err } diff --git a/pkg/server/types.go b/pkg/server/types.go index 927d2c4..ff78b3e 100644 --- a/pkg/server/types.go +++ b/pkg/server/types.go @@ -15,18 +15,18 @@ type SessionInfo struct { QueryPack string Language string - Repositories []common.OwnerRepo + Repositories []common.NameWithOwner - AccessMismatchRepos []common.OwnerRepo - NotFoundRepos []common.OwnerRepo - NoCodeqlDBRepos []common.OwnerRepo - OverLimitRepos []common.OwnerRepo + AccessMismatchRepos []common.NameWithOwner + NotFoundRepos []common.NameWithOwner + NoCodeqlDBRepos []common.NameWithOwner + OverLimitRepos []common.NameWithOwner - AnalysisRepos *map[common.OwnerRepo]storage.DBLocation + AnalysisRepos *map[common.NameWithOwner]storage.DBLocation } type CommanderSingle struct { - st *Visibles + vis *Visibles } func NewCommanderSingle() *CommanderSingle { diff --git a/pkg/storage/container.go b/pkg/storage/container.go index 3ba0214..e012032 100644 --- a/pkg/storage/container.go +++ b/pkg/storage/container.go @@ -24,10 +24,10 @@ func (s *StorageContainer) SaveQueryPack(tgz []byte, sessionID int) (storagePath return "todo:no-path-yet", nil } -func (s *StorageContainer) FindAvailableDBs(analysisReposRequested []common.OwnerRepo) (notFoundRepos []common.OwnerRepo, analysisRepos *map[common.OwnerRepo]DBLocation) { +func (s *StorageContainer) FindAvailableDBs(analysisReposRequested []common.NameWithOwner) (notFoundRepos []common.NameWithOwner, analysisRepos *map[common.NameWithOwner]DBLocation) { // TODO s.FindAvailableDBs() via postgres - analysisRepos = &map[common.OwnerRepo]DBLocation{} - notFoundRepos = []common.OwnerRepo{} + analysisRepos = &map[common.NameWithOwner]DBLocation{} + notFoundRepos = []common.NameWithOwner{} return notFoundRepos, analysisRepos } diff --git a/pkg/storage/interfaces.go b/pkg/storage/interfaces.go index 5f16b62..7faf585 100644 --- a/pkg/storage/interfaces.go +++ b/pkg/storage/interfaces.go @@ -5,6 +5,6 @@ import "mrvacommander/pkg/common" type Storage interface { NextID() int SaveQueryPack(tgz []byte, sessionID int) (storagePath string, error error) - FindAvailableDBs(analysisReposRequested []common.OwnerRepo) (not_found_repos []common.OwnerRepo, - analysisRepos *map[common.OwnerRepo]DBLocation) + FindAvailableDBs(analysisReposRequested []common.NameWithOwner) (not_found_repos []common.NameWithOwner, + analysisRepos *map[common.NameWithOwner]DBLocation) } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index cfb19f9..6f420ec 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -67,8 +67,8 @@ func (s *StorageSingle) SaveQueryPack(tgz []byte, sessionId int) (string, error) // Determine for which repositories codeql 
databases are available. // // Those will be the analysis_repos. The rest will be skipped. -func (s *StorageSingle) FindAvailableDBs(analysisReposRequested []common.OwnerRepo) (not_found_repos []common.OwnerRepo, - analysisRepos *map[common.OwnerRepo]DBLocation) { +func (s *StorageSingle) FindAvailableDBs(analysisReposRequested []common.NameWithOwner) (not_found_repos []common.NameWithOwner, + analysisRepos *map[common.NameWithOwner]DBLocation) { slog.Debug("Looking for available CodeQL databases") cwd, err := os.Getwd() @@ -77,9 +77,9 @@ func (s *StorageSingle) FindAvailableDBs(analysisReposRequested []common.OwnerRe return } - analysisRepos = &map[common.OwnerRepo]DBLocation{} + analysisRepos = &map[common.NameWithOwner]DBLocation{} - not_found_repos = []common.OwnerRepo{} + not_found_repos = []common.NameWithOwner{} for _, rep := range analysisReposRequested { dbPrefix := filepath.Join(cwd, "codeql", "dbs", rep.Owner, rep.Repo) @@ -110,7 +110,7 @@ func ArtifactURL(js common.JobSpec, vaid int) (string, error) { return "", nil } - zfpath, err := PackageResults(ar, js.OwnerRepo, vaid) + zfpath, err := PackageResults(ar, js.NameWithOwner, vaid) if err != nil { slog.Error("Error packaging results:", "error", err) return "", err @@ -129,13 +129,13 @@ func GetResult(js common.JobSpec) common.AnalyzeResult { return ar } -func SetResult(sessionid int, orl common.OwnerRepo, ar common.AnalyzeResult) { +func SetResult(sessionid int, orl common.NameWithOwner, ar common.AnalyzeResult) { mutex.Lock() defer mutex.Unlock() - result[common.JobSpec{JobID: sessionid, OwnerRepo: orl}] = ar + result[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] = ar } -func PackageResults(ar common.AnalyzeResult, owre common.OwnerRepo, vaid int) (zipPath string, e error) { +func PackageResults(ar common.AnalyzeResult, owre common.NameWithOwner, vaid int) (zipPath string, e error) { slog.Debug("Readying zip file with .sarif/.bqrs", "analyze-result", ar) cwd, err := os.Getwd() @@ -210,10 +210,10 @@ func SetJobInfo(js common.JobSpec, ji common.JobInfo) { info[js] = ji } -func GetStatus(sessionid int, orl common.OwnerRepo) common.Status { +func GetStatus(sessionid int, orl common.NameWithOwner) common.Status { mutex.Lock() defer mutex.Unlock() - return status[common.JobSpec{JobID: sessionid, OwnerRepo: orl}] + return status[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] } func ResultAsFile(path string) (string, []byte, error) { @@ -231,10 +231,10 @@ func ResultAsFile(path string) (string, []byte, error) { return fpath, file, nil } -func SetStatus(sessionid int, orl common.OwnerRepo, s common.Status) { +func SetStatus(sessionid int, orl common.NameWithOwner, s common.Status) { mutex.Lock() defer mutex.Unlock() - status[common.JobSpec{JobID: sessionid, OwnerRepo: orl}] = s + status[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] = s } func AddJob(sessionid int, job common.AnalyzeJob) { From ec4d2b3eac5aa1565fd766c93d837559a71220cb Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 00:19:24 +0200 Subject: [PATCH 3/9] Move postgres-init-scripts to init/postgres --- docker-compose.yml | 2 +- {postgres-init-scripts => init/postgres}/dbinit.sh | 0 test/storage_test.go | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename {postgres-init-scripts => init/postgres}/dbinit.sh (100%) delete mode 100644 test/storage_test.go diff --git a/docker-compose.yml b/docker-compose.yml index e8a5e03..b57ffec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,7 +10,7 @@ services: POSTGRES_DB: 
exampledb volumes: - postgres_data:/var/lib/postgresql/data - - ./postgres-init-scripts:/docker-entrypoint-initdb.d + - ./init/postgres:/docker-entrypoint-initdb.d ports: - "5432:5432" # Exposing PostgreSQL to the host expose: diff --git a/postgres-init-scripts/dbinit.sh b/init/postgres/dbinit.sh similarity index 100% rename from postgres-init-scripts/dbinit.sh rename to init/postgres/dbinit.sh diff --git a/test/storage_test.go b/test/storage_test.go deleted file mode 100644 index e69de29..0000000 From 3b06e2061fd65721ce271de4e02945cbc79274a6 Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 00:23:14 +0200 Subject: [PATCH 4/9] Add RabbitMQ agent and containers --- cmd/agent/Dockerfile | 49 ++++ cmd/agent/main.go | 318 ++++++++++++++++++++++++- docker-compose.yml | 48 ++-- go.mod | 28 ++- go.sum | 71 ++++-- init/rabbitmq/definitions.json | 43 ++++ init/rabbitmq/rabbitmq.conf | 1 + pkg/agent/agent.go | 9 +- pkg/codeql/codeql.go | 421 +++++++++++++++++++++++++++++++++ pkg/codeql/types.go | 81 +++++++ pkg/common/types.go | 17 +- pkg/storage/storage.go | 53 +++-- 12 files changed, 1050 insertions(+), 89 deletions(-) create mode 100644 cmd/agent/Dockerfile create mode 100644 init/rabbitmq/definitions.json create mode 100644 init/rabbitmq/rabbitmq.conf create mode 100644 pkg/codeql/codeql.go create mode 100644 pkg/codeql/types.go diff --git a/cmd/agent/Dockerfile b/cmd/agent/Dockerfile new file mode 100644 index 0000000..3082169 --- /dev/null +++ b/cmd/agent/Dockerfile @@ -0,0 +1,49 @@ +FROM golang:1.22 AS builder + +# Copy the entire project +WORKDIR /app +COPY . . + +# Download dependencies +RUN go mod download + +# Set the working directory to the cmd/agent subproject +WORKDIR /app/cmd/agent + +# Build the agent +RUN go build -o /bin/mrva_agent ./main.go + +FROM ubuntu:24.10 as runner +ENV DEBIAN_FRONTEND=noninteractive + +# Build argument for CodeQL version, defaulting to the latest release +ARG CODEQL_VERSION=latest + +# Install packages +RUN apt-get update && apt-get install --no-install-recommends --assume-yes \ + unzip \ + curl \ + ca-certificates + +# If the version is 'latest', get the latest release version from GitHub, unzip the bundle into /opt, and delete the archive +RUN if [ "$CODEQL_VERSION" = "latest" ]; then \ + CODEQL_VERSION=$(curl -s https://api.github.com/repos/github/codeql-cli-binaries/releases/latest | grep '"tag_name"' | sed -E 's/.*"([^"]+)".*/\1/'); \ + fi && \ + echo "Using CodeQL version $CODEQL_VERSION" && \ + curl -L "https://github.com/github/codeql-cli-binaries/releases/download/$CODEQL_VERSION/codeql-linux64.zip" -o /tmp/codeql.zip && \ + unzip /tmp/codeql.zip -d /opt && \ + rm /tmp/codeql.zip + +# Set environment variables for CodeQL +ENV CODEQL_CLI_PATH=/opt/codeql + +# Set environment variable for CodeQL for `codeql database analyze` support on ARM +# This env var has no functional effect on CodeQL when running on x86_64 linux +ENV CODEQL_JAVA_HOME=/usr/lib/jvm/ + +# Copy built agent binary from the builder stage +WORKDIR /app +COPY --from=builder /bin/mrva_agent ./mrva_agent + +# Run the agent +ENTRYPOINT ["./mrva_agent"] \ No newline at end of file diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 4883155..83ae80a 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1 +1,317 @@ -package agent +package main + +import ( + "io" + "log" + "mrvacommander/pkg/codeql" + "mrvacommander/pkg/common" + "mrvacommander/pkg/queue" + "mrvacommander/pkg/storage" + "mrvacommander/utils" + "net/http" + "path/filepath" + "runtime" + + 
"context" + "encoding/json" + "flag" + "fmt" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/google/uuid" + amqp "github.com/rabbitmq/amqp091-go" + "golang.org/x/exp/slog" + + "github.com/elastic/go-sysinfo" +) + +func downloadFile(url string, dest string) error { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + defer resp.Body.Close() + + out, err := os.Create(dest) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return fmt.Errorf("failed to copy file content: %w", err) + } + + return nil +} + +func calculateWorkers() int { + const workerMemoryGB = 2 + + host, err := sysinfo.Host() + if err != nil { + log.Fatalf("failed to get host info: %v", err) + } + + memInfo, err := host.Memory() + if err != nil { + log.Fatalf("failed to get memory info: %v", err) + } + + // Convert total memory to GB + totalMemoryGB := memInfo.Available / (1024 * 1024 * 1024) + + // Ensure we have at least one worker + workers := int(totalMemoryGB / workerMemoryGB) + if workers < 1 { + workers = 1 + } + + // Limit the number of workers to the number of CPUs + cpuCount := runtime.NumCPU() + if workers > cpuCount { + workers = max(cpuCount, 1) + } + + return workers +} + +type RabbitMQQueue struct { + jobs chan common.AnalyzeJob + results chan common.AnalyzeResult + conn *amqp.Connection + channel *amqp.Channel +} + +func InitializeQueue(jobsQueueName, resultsQueueName string) (*RabbitMQQueue, error) { + rabbitMQHost := os.Getenv("MRVA_RABBITMQ_HOST") + rabbitMQPort := os.Getenv("MRVA_RABBITMQ_PORT") + rabbitMQUser := os.Getenv("MRVA_RABBITMQ_USER") + rabbitMQPassword := os.Getenv("MRVA_RABBITMQ_PASSWORD") + + if rabbitMQHost == "" || rabbitMQPort == "" || rabbitMQUser == "" || rabbitMQPassword == "" { + return nil, fmt.Errorf("RabbitMQ environment variables not set") + } + + rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%s/", rabbitMQUser, rabbitMQPassword, rabbitMQHost, rabbitMQPort) + + conn, err := amqp.Dial(rabbitMQURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to RabbitMQ: %w", err) + } + + ch, err := conn.Channel() + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to open a channel: %w", err) + } + + _, err = ch.QueueDeclare(jobsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to declare tasks queue: %w", err) + } + + _, err = ch.QueueDeclare(resultsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to declare results queue: %w", err) + } + + err = ch.Qos(1, 0, false) + + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to set QoS: %w", err) + } + + return &RabbitMQQueue{ + conn: conn, + channel: ch, + jobs: make(chan common.AnalyzeJob), + results: make(chan common.AnalyzeResult), + }, nil +} + +func (q *RabbitMQQueue) Jobs() chan common.AnalyzeJob { + return q.jobs +} + +func (q *RabbitMQQueue) Results() chan common.AnalyzeResult { + return q.results +} + +func (q *RabbitMQQueue) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { + slog.Info("Queueing codeql database analyze jobs") +} + +func (q *RabbitMQQueue) Close() { + q.channel.Close() + q.conn.Close() +} + +func (q *RabbitMQQueue) ConsumeJobs(queueName string) { + msgs, err := q.channel.Consume(queueName, "", true, false, false, false, nil) + 
if err != nil { + slog.Error("failed to register a consumer", slog.Any("error", err)) + } + + for msg := range msgs { + job := common.AnalyzeJob{} + err := json.Unmarshal(msg.Body, &job) + if err != nil { + slog.Error("failed to unmarshal job", slog.Any("error", err)) + continue + } + q.jobs <- job + } + close(q.jobs) +} + +func (q *RabbitMQQueue) PublishResults(queueName string) { + for result := range q.results { + q.publishResult(queueName, result) + } +} + +func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resultBytes, err := json.Marshal(result) + if err != nil { + slog.Error("failed to marshal result", slog.Any("error", err)) + return + } + + slog.Info("Publishing result", slog.String("result", string(resultBytes))) + err = q.channel.PublishWithContext(ctx, "", queueName, false, false, + amqp.Publishing{ + ContentType: "application/json", + Body: resultBytes, + }) + if err != nil { + slog.Error("failed to publish result", slog.Any("error", err)) + } +} + +func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { + var result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: 0, + ResultArchiveURL: "", + Status: common.StatusError, + } + + // Log job info + slog.Info("Running analysis job", slog.Any("job", job)) + + // Create a temporary directory + tempDir := filepath.Join(os.TempDir(), uuid.New().String()) + if err := os.MkdirAll(tempDir, 0755); err != nil { + return result, fmt.Errorf("failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Extract the query pack + // TODO: download from the 'job' query pack URL + utils.UntarGz("qp-54674.tgz", filepath.Join(tempDir, "qp-54674")) + + // Perform the CodeQL analysis + runResult, err := codeql.RunQuery("google_flatbuffers_db.zip", "cpp", "qp-54674", tempDir) + if err != nil { + return result, fmt.Errorf("failed to run analysis: %w", err) + } + + // Generate a ZIP archive containing SARIF and BQRS files + resultsArchive, err := codeql.GenerateResultsZipArchive(runResult) + if err != nil { + return result, fmt.Errorf("failed to generate results archive: %w", err) + } + + // TODO: Upload the archive to storage + slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) + slog.Info("Analysis job successful.") + + result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: runResult.ResultCount, + ResultArchiveURL: "REPLACE_THIS_WITH_STORED_RESULTS_ARCHIVE", + Status: common.StatusSuccess, + } + + return result, nil +} + +func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { + defer wg.Done() + for job := range queue.Jobs() { + result, err := RunAnalysisJob(job) + if err != nil { + slog.Error("failed to run analysis job", slog.Any("error", err)) + continue + } + queue.Results() <- result + } +} + +func main() { + slog.Info("Starting agent") + + workerCount := flag.Int("workers", 0, "number of workers") + flag.Parse() + + requiredEnvVars := []string{ + "MRVA_RABBITMQ_HOST", + "MRVA_RABBITMQ_PORT", + "MRVA_RABBITMQ_USER", + "MRVA_RABBITMQ_PASSWORD", + "CODEQL_JAVA_HOME", + "CODEQL_CLI_PATH", + } + + for _, envVar := range requiredEnvVars { + if os.Getenv(envVar) == "" { + log.Fatalf("Fatal: Missing required environment variable %s", envVar) + } + } + + slog.Info("Initializing RabbitMQ connection") + rabbitMQQueue, err := InitializeQueue("tasks", "results") + if err != nil { + slog.Error("failed to initialize RabbitMQ", 
slog.Any("error", err)) + os.Exit(1) + } + defer rabbitMQQueue.Close() + + if *workerCount == 0 { + *workerCount = calculateWorkers() + } + + slog.Info("Starting workers", slog.Int("count", *workerCount)) + var wg sync.WaitGroup + for i := 0; i < *workerCount; i++ { + wg.Add(1) + go RunWorker(rabbitMQQueue, &wg) + } + + slog.Info("Starting tasks consumer") + go rabbitMQQueue.ConsumeJobs("tasks") + + slog.Info("Starting results publisher") + go rabbitMQQueue.PublishResults("results") + + slog.Info("Agent startup complete") + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + slog.Info("Shutting down agent") + close(rabbitMQQueue.results) +} diff --git a/docker-compose.yml b/docker-compose.yml index b57ffec..97e3a8e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3.8' - services: postgres: image: postgres:16.3-bookworm @@ -18,32 +16,32 @@ services: networks: - backend - rabbitmq: - image: rabbitmq:3.13-management + image: rabbitmq:3-management + hostname: rabbitmq container_name: rabbitmq - environment: - RABBITMQ_DEFAULT_USER: user - RABBITMQ_DEFAULT_PASS: password + volumes: + - ./init/rabbitmq/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf:ro + - ./init/rabbitmq/definitions.json:/etc/rabbitmq/definitions.json:ro expose: - "5672" - "15672" ports: + - "5672:5672" - "15672:15672" networks: - backend - + healthcheck: + test: [ "CMD", "rabbitmqctl", "status" ] + interval: 1s server: - image: server-image + build: + context: ./cmd/server + dockerfile: Dockerfile container_name: server - environment: - - MRVA_SERVER_ROOT=/mrva/mrvacommander/cmd/server - command: sh -c "tail -f /dev/null" ports: - - "8080:8080" - volumes: - - /Users/hohn/work-gh/mrva/mrvacommander:/mrva/mrvacommander + - "8080:8080" depends_on: - postgres - rabbitmq @@ -63,6 +61,22 @@ services: volumes: - minio-data:/data + agent: + build: + context: . 
+ dockerfile: ./cmd/agent/Dockerfile + container_name: agent + depends_on: + - rabbitmq + - minio + environment: + MRVA_RABBITMQ_HOST: rabbitmq + MRVA_RABBITMQ_PORT: 5672 + MRVA_RABBITMQ_USER: user + MRVA_RABBITMQ_PASSWORD: password + networks: + - backend + volumes: minio-data: postgres_data: @@ -71,7 +85,3 @@ volumes: networks: backend: driver: bridge - - - - diff --git a/go.mod b/go.mod index 7501e8c..115193e 100644 --- a/go.mod +++ b/go.mod @@ -3,24 +3,32 @@ module mrvacommander go 1.22.0 require ( + github.com/BurntSushi/toml v1.4.0 + github.com/elastic/go-sysinfo v1.14.0 + github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 - github.com/hohn/ghes-mirva-server v0.0.0-20240313191620-9917867ea540 - github.com/spf13/cobra v1.8.0 + github.com/rabbitmq/amqp091-go v1.10.0 + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/postgres v1.5.9 + gorm.io/gorm v1.25.10 ) require ( - github.com/BurntSushi/toml v1.3.2 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/elastic/go-windows v1.0.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/crypto v0.23.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + golang.org/x/crypto v0.24.0 // indirect golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.15.0 // indirect - gorm.io/driver/postgres v1.5.7 // indirect - gorm.io/gorm v1.25.10 // indirect + golang.org/x/sys v0.21.0 // indirect + golang.org/x/text v0.16.0 // indirect + howett.net/plist v1.0.1 // indirect ) diff --git a/go.sum b/go.sum index b0e5edf..88085e0 100644 --- a/go.sum +++ b/go.sum @@ -1,44 +1,75 @@ -github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= -github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= +github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/elastic/go-sysinfo v1.14.0 h1:dQRtiqLycoOOla7IflZg3aN213vqJmP0lpVpKQ9lUEY= +github.com/elastic/go-sysinfo v1.14.0/go.mod h1:FKUXnZWhnYI0ueO7jhsGV3uQJ5hiz8OqM5b3oGyaRr8= +github.com/elastic/go-windows v1.0.1 h1:AlYZOldA+UJ0/2nBuqWdo90GFCgG9xuyw9SYzGUtJm0= +github.com/elastic/go-windows v1.0.1/go.mod h1:FoVvqWSun28vaDQPbj2Elfc0JahhPB7WQEGa3c814Ss= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 
h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/hohn/ghes-mirva-server v0.0.0-20240313191620-9917867ea540 h1:ohnDVLM/VvVCVfjvSYKAPZIQhOPRKk1ZcZcMzf4yT8k= -github.com/hohn/ghes-mirva-server v0.0.0-20240313191620-9917867ea540/go.mod h1:ircD+yE4AxWL/DufgcLDi191c+JM9ge/C3yiT/0zL+U= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/rabbitmq/amqp091-go v1.10.0 
h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= +github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= +gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= +gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM= +howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= diff --git a/init/rabbitmq/definitions.json 
b/init/rabbitmq/definitions.json new file mode 100644 index 0000000..cb2124e --- /dev/null +++ b/init/rabbitmq/definitions.json @@ -0,0 +1,43 @@ +{ + "users": [ + { + "name": "user", + "password": "password", + "tags": "administrator" + } + ], + "vhosts": [ + { + "name": "/" + } + ], + "queues": [ + { + "name": "tasks", + "vhost": "/", + "durable": false, + "persistent": false, + "arguments": { + "x-queue-type": "classic" + } + }, + { + "name": "results", + "vhost": "/", + "durable": false, + "persistent": false, + "arguments": { + "x-queue-type": "classic" + } + } + ], + "permissions": [ + { + "user": "user", + "vhost": "/", + "configure": ".*", + "write": ".*", + "read": ".*" + } + ] +} \ No newline at end of file diff --git a/init/rabbitmq/rabbitmq.conf b/init/rabbitmq/rabbitmq.conf new file mode 100644 index 0000000..916cb91 --- /dev/null +++ b/init/rabbitmq/rabbitmq.conf @@ -0,0 +1 @@ +management.load_definitions = /etc/rabbitmq/definitions.json \ No newline at end of file diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index b88e4c8..01fa8ee 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -54,17 +54,14 @@ func (r *RunnerSingle) worker(wid int) { slog.Debug("Analysis: running", "job", job) storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) - resultFile, err := r.RunAnalysis(job) + _, err := RunAnalysis(job) if err != nil { continue } slog.Debug("Analysis run finished", "job", job) - res := common.AnalyzeResult{ - RunAnalysisSARIF: resultFile, - RunAnalysisBQRS: "", // FIXME ? - } + res := common.AnalyzeResult{} r.queue.Results() <- res storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) storage.SetResult(job.QueryPackId, job.NWO, res) @@ -72,7 +69,7 @@ func (r *RunnerSingle) worker(wid int) { } } -func (r *RunnerSingle) RunAnalysis(job common.AnalyzeJob) (string, error) { +func RunAnalysis(job common.AnalyzeJob) (string, error) { // TODO Add multi-language tests including queryLanguage // queryPackID, queryLanguage, dbOwner, dbRepo := // job.QueryPackId, job.QueryLanguage, job.NWO.Owner, job.NWO.Repo diff --git a/pkg/codeql/codeql.go b/pkg/codeql/codeql.go new file mode 100644 index 0000000..d96373e --- /dev/null +++ b/pkg/codeql/codeql.go @@ -0,0 +1,421 @@ +package codeql + +import ( + "archive/zip" + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "log/slog" + "mrvacommander/utils" + "os" + "os/exec" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +// Helper Functions +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// Main Functions +func getCodeQLCLIPath() (string, error) { + // get the CODEQL_CLI_PATH environment variable + codeqlCliPath := os.Getenv("CODEQL_CLI_PATH") + if codeqlCliPath == "" { + return "", fmt.Errorf("CODEQL_CLI_PATH environment variable not set") + } + return codeqlCliPath, nil +} + +func GenerateResultsZipArchive(runQueryResult *RunQueryResult) ([]byte, error) { + buffer := new(bytes.Buffer) + zipWriter := zip.NewWriter(buffer) + + if runQueryResult.SarifFilePath != "" { + err := addFileToZip(zipWriter, runQueryResult.SarifFilePath, "results.sarif") + if err != nil { + return nil, fmt.Errorf("failed to add SARIF file to zip: %v", err) + } + } + + for _, relativePath := range runQueryResult.BqrsFilePaths.RelativeFilePaths { + fullPath := filepath.Join(runQueryResult.BqrsFilePaths.BasePath, relativePath) + err := addFileToZip(zipWriter, fullPath, relativePath) + if err != nil { + return nil, fmt.Errorf("failed to add BQRS file 
to zip: %v", err) + } + } + + err := zipWriter.Close() + if err != nil { + return nil, fmt.Errorf("failed to close zip writer: %v", err) + } + + return buffer.Bytes(), nil +} + +func addFileToZip(zipWriter *zip.Writer, filePath, zipPath string) error { + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file %s: %v", filePath, err) + } + defer file.Close() + + w, err := zipWriter.Create(zipPath) + if err != nil { + return fmt.Errorf("failed to create zip entry for %s: %v", zipPath, err) + } + + _, err = io.Copy(w, file) + if err != nil { + return fmt.Errorf("failed to copy file content to zip entry for %s: %v", zipPath, err) + } + + return nil +} + +func RunQuery(database string, nwo string, queryPackPath string, tempDir string) (*RunQueryResult, error) { + path, err := getCodeQLCLIPath() + + if err != nil { + return nil, fmt.Errorf("failed to get codeql cli path: %v", err) + } + + codeql := CodeqlCli{path} + + resultsDir := filepath.Join(tempDir, "results") + if err = os.Mkdir(resultsDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create results directory: %v", err) + } + + databasePath := filepath.Join(tempDir, "db") + if utils.UnzipFile(database, databasePath) != nil { + return nil, fmt.Errorf("failed to unzip database: %v", err) + } + + dbMetadata, err := getDatabaseMetadata(databasePath) + if err != nil { + return nil, fmt.Errorf("failed to get database metadata: %v", err) + } + + // Check if the database has CreationMetadata / a SHA + var databaseSHA string + if dbMetadata.CreationMetadata == nil || dbMetadata.CreationMetadata.SHA == nil { + // If the database does not have a SHA, we can proceed regardless + slog.Warn("Database does not have a SHA") + databaseSHA = "" + } else { + databaseSHA = *dbMetadata.CreationMetadata.SHA + } + + cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=1024", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath) + if output, err := cmd.CombinedOutput(); err != nil { + return nil, fmt.Errorf("failed to run queries: %v\nOutput: %s", err, output) + } + + queryPackRunResults, err := getQueryPackRunResults(codeql, databasePath, queryPackPath) + if err != nil { + return nil, fmt.Errorf("failed to get query pack run results: %v", err) + } + + sourceLocationPrefix, err := getSourceLocationPrefix(codeql, databasePath) + if err != nil { + return nil, fmt.Errorf("failed to get source location prefix: %v", err) + } + + shouldGenerateSarif := queryPackSupportsSarif(queryPackRunResults) + + if shouldGenerateSarif { + slog.Info("Query pack supports SARIF") + } else { + slog.Info("Query pack does not support SARIF") + } + + var resultCount int + var sarifFilePath string + + if shouldGenerateSarif { + sarif, err := generateSarif(codeql, nwo, databasePath, queryPackPath, databaseSHA, resultsDir) + if err != nil { + return nil, fmt.Errorf("failed to generate SARIF: %v", err) + } + resultCount = getSarifResultCount(sarif) + slog.Info("Generated SARIF", "resultCount", resultCount) + sarifFilePath = filepath.Join(resultsDir, "results.sarif") + if err := os.WriteFile(sarifFilePath, sarif, 0644); err != nil { + return nil, fmt.Errorf("failed to write SARIF file: %v", err) + } + } else { + resultCount = queryPackRunResults.TotalResultsCount + slog.Info("Did not generate SARIF", "resultCount", resultCount) + } + + slog.Info("Adjusting BQRS files") + bqrsFilePaths, err := adjustBqrsFiles(queryPackRunResults, resultsDir) + if err != nil { + return nil, fmt.Errorf("failed to adjust BQRS files: %v", 
err) + } + + return &RunQueryResult{ + ResultCount: resultCount, + DatabaseSHA: databaseSHA, + SourceLocationPrefix: sourceLocationPrefix, + BqrsFilePaths: bqrsFilePaths, + SarifFilePath: sarifFilePath, + }, nil +} + +func getDatabaseMetadata(databasePath string) (*DatabaseMetadata, error) { + data, err := os.ReadFile(filepath.Join(databasePath, "codeql-database.yml")) + if err != nil { + return nil, fmt.Errorf("failed to read database metadata: %v", err) + } + + var metadata DatabaseMetadata + if err := yaml.Unmarshal(data, &metadata); err != nil { + return nil, fmt.Errorf("failed to unmarshal database metadata: %v", err) + } + + return &metadata, nil +} + +func runCommand(command []string) (CodeQLCommandOutput, error) { + slog.Info("Running command", "command", command) + cmd := exec.Command(command[0], command[1:]...) + stdout, err := cmd.Output() + if err != nil { + return CodeQLCommandOutput{ExitCode: 1}, err + } + return CodeQLCommandOutput{ExitCode: 0, Stdout: string(stdout)}, nil +} + +func validateQueryMetadataObject(data []byte) (QueryMetadata, error) { + var queryMetadata QueryMetadata + if err := json.Unmarshal(data, &queryMetadata); err != nil { + return QueryMetadata{}, err + } + return queryMetadata, nil +} + +func validateBQRSInfoObject(data []byte) (BQRSInfo, error) { + var bqrsInfo BQRSInfo + if err := json.Unmarshal(data, &bqrsInfo); err != nil { + return BQRSInfo{}, err + } + return bqrsInfo, nil +} + +func getBqrsInfo(codeql CodeqlCli, bqrs string) (BQRSInfo, error) { + bqrsInfoOutput, err := runCommand([]string{codeql.Path, "bqrs", "info", "--format=json", bqrs}) + if err != nil { + return BQRSInfo{}, fmt.Errorf("unable to run codeql bqrs info. Error: %v", err) + } + if bqrsInfoOutput.ExitCode != 0 { + return BQRSInfo{}, fmt.Errorf("unable to run codeql bqrs info. Exit code: %d", bqrsInfoOutput.ExitCode) + } + return validateBQRSInfoObject([]byte(bqrsInfoOutput.Stdout)) +} + +func getQueryMetadata(codeql CodeqlCli, query string) (QueryMetadata, error) { + queryMetadataOutput, err := runCommand([]string{codeql.Path, "resolve", "metadata", "--format=json", query}) + if err != nil { + return QueryMetadata{}, fmt.Errorf("unable to run codeql resolve metadata. Error: %v", err) + } + if queryMetadataOutput.ExitCode != 0 { + return QueryMetadata{}, fmt.Errorf("unable to run codeql resolve metadata. 
Exit code: %d", queryMetadataOutput.ExitCode) + } + return validateQueryMetadataObject([]byte(queryMetadataOutput.Stdout)) +} + +func getQueryPackRunResults(codeql CodeqlCli, databasePath, queryPackPath string) (*QueryPackRunResults, error) { + resultsBasePath := filepath.Join(databasePath, "results") + + queryPaths := []string{} // Replace with actual query paths resolution logic + + var queries []Query + for _, queryPath := range queryPaths { + relativeBqrsFilePath := filepath.Join(queryPackPath, queryPath) + bqrsFilePath := filepath.Join(resultsBasePath, relativeBqrsFilePath) + + if _, err := os.Stat(bqrsFilePath); os.IsNotExist(err) { + return nil, fmt.Errorf("could not find BQRS file for query %s at %s", queryPath, bqrsFilePath) + } + + bqrsInfo, err := getBqrsInfo(codeql, bqrsFilePath) + if err != nil { + return nil, fmt.Errorf("failed to get BQRS info: %v", err) + } + + queryMetadata, err := getQueryMetadata(codeql, queryPath) + if err != nil { + return nil, fmt.Errorf("failed to get query metadata: %v", err) + } + + queries = append(queries, Query{ + QueryPath: queryPath, + QueryMetadata: queryMetadata, + RelativeBqrsFilePath: relativeBqrsFilePath, + BqrsInfo: bqrsInfo, + }) + } + + totalResultsCount := 0 + for _, query := range queries { + count, err := getBqrsResultCount(query.BqrsInfo) + if err != nil { + return nil, fmt.Errorf("failed to get BQRS result count: %v", err) + } + totalResultsCount += count + } + + return &QueryPackRunResults{ + Queries: queries, + TotalResultsCount: totalResultsCount, + ResultsBasePath: resultsBasePath, + }, nil +} + +func adjustBqrsFiles(queryPackRunResults *QueryPackRunResults, resultsDir string) (BqrsFilePaths, error) { + if len(queryPackRunResults.Queries) == 1 { + currentBqrsFilePath := filepath.Join(queryPackRunResults.ResultsBasePath, queryPackRunResults.Queries[0].RelativeBqrsFilePath) + newBqrsFilePath := filepath.Join(resultsDir, "results.bqrs") + + if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil { + return BqrsFilePaths{}, err + } + + if err := os.Rename(currentBqrsFilePath, newBqrsFilePath); err != nil { + return BqrsFilePaths{}, err + } + + return BqrsFilePaths{BasePath: resultsDir, RelativeFilePaths: []string{"results.bqrs"}}, nil + } + + relativeFilePaths := make([]string, len(queryPackRunResults.Queries)) + for i, query := range queryPackRunResults.Queries { + relativeFilePaths[i] = query.RelativeBqrsFilePath + } + + return BqrsFilePaths{ + BasePath: queryPackRunResults.ResultsBasePath, + RelativeFilePaths: relativeFilePaths, + }, nil +} + +func getSourceLocationPrefix(codeql CodeqlCli, databasePath string) (string, error) { + cmd := exec.Command(codeql.Path, "resolve", "database", databasePath) + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to resolve database: %v\nOutput: %s", err, output) + } + + var resolvedDatabase ResolvedDatabase + if err := json.Unmarshal(output, &resolvedDatabase); err != nil { + return "", fmt.Errorf("failed to unmarshal resolved database: %v", err) + } + + return resolvedDatabase.SourceLocationPrefix, nil +} + +func queryPackSupportsSarif(queryPackRunResults *QueryPackRunResults) bool { + for _, query := range queryPackRunResults.Queries { + if !querySupportsSarif(query.QueryMetadata, query.BqrsInfo) { + return false + } + } + return true +} + +func querySupportsSarif(queryMetadata QueryMetadata, bqrsInfo BQRSInfo) bool { + return getSarifOutputType(queryMetadata, bqrsInfo.CompatibleQueryKinds) != "" +} + +func getSarifOutputType(queryMetadata 
QueryMetadata, compatibleQueryKinds []string) string {
+	if (*queryMetadata.Kind == "path-problem" || *queryMetadata.Kind == "path-alert") && contains(compatibleQueryKinds, "PathProblem") {
+		return "path-problem"
+	}
+	if (*queryMetadata.Kind == "problem" || *queryMetadata.Kind == "alert") && contains(compatibleQueryKinds, "Problem") {
+		return "problem"
+	}
+	return ""
+}
+
+func generateSarif(codeql CodeqlCli, nwo, databasePath, queryPackPath, databaseSHA string, resultsDir string) ([]byte, error) {
+	sarifFile := filepath.Join(resultsDir, "results.sarif")
+	cmd := exec.Command(codeql.Path, "database", "interpret-results", "--format=sarif-latest", "--output="+sarifFile, "--sarif-add-snippets", "--no-group-results", databasePath, queryPackPath)
+	if output, err := cmd.CombinedOutput(); err != nil {
+		return nil, fmt.Errorf("failed to generate SARIF: %v\nOutput: %s", err, output)
+	}
+
+	sarifData, err := os.ReadFile(sarifFile)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read SARIF file: %v", err)
+	}
+
+	var sarif Sarif
+	if err := json.Unmarshal(sarifData, &sarif); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal SARIF: %v", err)
+	}
+
+	injectVersionControlInfo(&sarif, nwo, databaseSHA)
+	sarifBytes, err := json.Marshal(sarif)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal SARIF: %v", err)
+	}
+
+	return sarifBytes, nil
+}
+
+func injectVersionControlInfo(sarif *Sarif, nwo, databaseSHA string) {
+	// Index into Runs so the appended entry lands on the slice element itself;
+	// ranging by value would append to a copy and drop the provenance entry.
+	for i := range sarif.Runs {
+		sarif.Runs[i].VersionControlProvenance = append(sarif.Runs[i].VersionControlProvenance, map[string]interface{}{
+			"repositoryUri": fmt.Sprintf("%s/%s", os.Getenv("GITHUB_SERVER_URL"), nwo),
+			"revisionId":    databaseSHA,
+		})
+	}
+}
+
+// getSarifResultCount returns the number of results in the SARIF file.
+func getSarifResultCount(sarif []byte) int {
+	var sarifData Sarif
+	if err := json.Unmarshal(sarif, &sarifData); err != nil {
+		log.Printf("failed to unmarshal SARIF for result count: %v", err)
+		return 0
+	}
+	count := 0
+	for _, run := range sarifData.Runs {
+		count += len(run.Results)
+	}
+	return count
+}
+
+// Known result set names
+var KnownResultSetNames = []string{"#select", "problems"}
+
+// getBqrsResultCount returns the number of results in the BQRS file.
+func getBqrsResultCount(bqrsInfo BQRSInfo) (int, error) {
+	for _, name := range KnownResultSetNames {
+		for _, resultSet := range bqrsInfo.ResultSets {
+			if resultSet.Name == name {
+				return resultSet.Rows, nil
+			}
+		}
+	}
+	var resultSetNames []string
+	for _, resultSet := range bqrsInfo.ResultSets {
+		resultSetNames = append(resultSetNames, resultSet.Name)
+	}
+	return 0, fmt.Errorf(
+		"BQRS does not contain any result sets matching known names. 
Expected one of %s but found %s", + KnownResultSetNames, resultSetNames, + ) +} diff --git a/pkg/codeql/types.go b/pkg/codeql/types.go new file mode 100644 index 0000000..e8b156a --- /dev/null +++ b/pkg/codeql/types.go @@ -0,0 +1,81 @@ +package codeql + +// Types +type CodeqlCli struct { + Path string +} + +type RunQueryResult struct { + ResultCount int + DatabaseSHA string + SourceLocationPrefix string + BqrsFilePaths BqrsFilePaths + SarifFilePath string +} + +type BqrsFilePaths struct { + BasePath string `json:"basePath"` + RelativeFilePaths []string `json:"relativeFilePaths"` +} + +type SarifOutputType string + +const ( + Problem SarifOutputType = "problem" + PathProblem SarifOutputType = "path-problem" +) + +type SarifRun struct { + VersionControlProvenance []interface{} `json:"versionControlProvenance,omitempty"` + Results []interface{} `json:"results"` +} + +type Sarif struct { + Runs []SarifRun `json:"runs"` +} + +type CreationMetadata struct { + SHA *string `yaml:"sha,omitempty"` + CLIVersion *string `yaml:"cliVersion,omitempty"` +} + +type DatabaseMetadata struct { + CreationMetadata *CreationMetadata `yaml:"creationMetadata,omitempty"` +} + +type QueryMetadata struct { + ID *string `json:"id,omitempty"` + Kind *string `json:"kind,omitempty"` +} + +type ResultSet struct { + Name string `json:"name"` + Rows int `json:"rows"` +} + +type BQRSInfo struct { + ResultSets []ResultSet `json:"resultSets"` + CompatibleQueryKinds []string `json:"compatibleQueryKinds"` +} + +type Query struct { + QueryPath string `json:"queryPath"` + QueryMetadata QueryMetadata `json:"queryMetadata"` + RelativeBqrsFilePath string `json:"relativeBqrsFilePath"` + BqrsInfo BQRSInfo `json:"bqrsInfo"` +} + +type QueryPackRunResults struct { + Queries []Query `json:"queries"` + TotalResultsCount int `json:"totalResultsCount"` + ResultsBasePath string `json:"resultsBasePath"` +} + +type ResolvedDatabase struct { + SourceLocationPrefix string `json:"sourceLocationPrefix"` +} + +type CodeQLCommandOutput struct { + ExitCode int `json:"exitCode"` + Stdout string `json:"stdout"` +} diff --git a/pkg/common/types.go b/pkg/common/types.go index 1d78b7d..8c8955b 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -9,18 +9,21 @@ type NameWithOwner struct { // AnalyzeJob represents a job specifying a repository and a query pack to analyze it with. // This is the message format that the agent receives from the queue. type AnalyzeJob struct { - MRVARequestID int - QueryPackId int - QueryPackURL string - QueryLanguage string - NWO NameWithOwner + RequestId int // json:"request_id" + QueryPackId int // json:"query_pack_id" + QueryPackURL string // json:"query_pack_url" + QueryLanguage string // json:"query_language" + NWO NameWithOwner // json:"nwo" } // AnalyzeResult represents the result of an analysis job. // This is the message format that the agent sends to the queue. +// Status will only ever be StatusSuccess or StatusError when sent in a result. type AnalyzeResult struct { - RunAnalysisSARIF string - RunAnalysisBQRS string + Status Status // json:"status" + RequestId int // json:"request_id" + ResultCount int // json:"result_count" + ResultArchiveURL string // json:"result_archive_url" } // Status represents the status of a job. 
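Note on the AnalyzeJob/AnalyzeResult changes above: the intended JSON field names appear only in trailing comments, not as real struct tags, so encoding/json will use the exported Go field names when these messages cross the RabbitMQ queues (as the later patches in this series do with json.Marshal/json.Unmarshal). A minimal, self-contained sketch of the resulting wire format follows; the types are mirrored here purely for illustration, Status is simplified to an int, and the archive URL is a placeholder.

package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative mirror of pkg/common.AnalyzeResult; in the real package
// Status is its own enum type rather than a plain int.
type AnalyzeResult struct {
	Status           int
	RequestId        int
	ResultCount      int
	ResultArchiveURL string
}

func main() {
	res := AnalyzeResult{Status: 1, RequestId: 42, ResultCount: 3, ResultArchiveURL: "https://storage.example/results/42.zip"}

	// With no struct tags, encoding/json emits the exported field names:
	// {"Status":1,"RequestId":42,"ResultCount":3,"ResultArchiveURL":"https://storage.example/results/42.zip"}
	wire, err := json.Marshal(res)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(wire))

	// Round trip back into the struct, as the queue consumer does.
	var back AnalyzeResult
	if err := json.Unmarshal(wire, &back); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", back)
}

As long as both the server and the agent marshal the same Go struct, the field-name choice is symmetric and nothing breaks; if the snake_case names from the comments (for example request_id) are ever expected on the wire, real struct tags such as `json:"request_id"` would need to be added.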
diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 6f420ec..3dc0477 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -4,7 +4,6 @@ import ( "archive/zip" "errors" "fmt" - "io" "io/fs" "log/slog" "os" @@ -129,10 +128,10 @@ func GetResult(js common.JobSpec) common.AnalyzeResult { return ar } -func SetResult(sessionid int, orl common.NameWithOwner, ar common.AnalyzeResult) { +func SetResult(sessionid int, nwo common.NameWithOwner, ar common.AnalyzeResult) { mutex.Lock() defer mutex.Unlock() - result[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] = ar + result[common.JobSpec{JobID: sessionid, NameWithOwner: nwo}] = ar } func PackageResults(ar common.AnalyzeResult, owre common.NameWithOwner, vaid int) (zipPath string, e error) { @@ -166,29 +165,31 @@ func PackageResults(ar common.AnalyzeResult, owre common.NameWithOwner, vaid int defer zwriter.Close() // Add each result file to the zip archive - names := []([]string){{ar.RunAnalysisSARIF, "results.sarif"}} - for _, fpath := range names { - file, err := os.Open(fpath[0]) - if err != nil { - return "", err - } - defer file.Close() + /* + names := []([]string){{ar.RunAnalysisSARIF, "results.sarif"}} + for _, fpath := range names { + file, err := os.Open(fpath[0]) + if err != nil { + return "", err + } + defer file.Close() - // Create a new file in the zip archive with custom name - // The client is very specific: - // if zf.Name != "results.sarif" && zf.Name != "results.bqrs" { continue } + // Create a new file in the zip archive with custom name + // The client is very specific: + // if zf.Name != "results.sarif" && zf.Name != "results.bqrs" { continue } - zipEntry, err := zwriter.Create(fpath[1]) - if err != nil { - return "", err - } + zipEntry, err := zwriter.Create(fpath[1]) + if err != nil { + return "", err + } - // Copy the contents of the file to the zip entry - _, err = io.Copy(zipEntry, file) - if err != nil { - return "", err + // Copy the contents of the file to the zip entry + _, err = io.Copy(zipEntry, file) + if err != nil { + return "", err + } } - } + */ return zpath, nil } @@ -210,10 +211,10 @@ func SetJobInfo(js common.JobSpec, ji common.JobInfo) { info[js] = ji } -func GetStatus(sessionid int, orl common.NameWithOwner) common.Status { +func GetStatus(sessionid int, nwo common.NameWithOwner) common.Status { mutex.Lock() defer mutex.Unlock() - return status[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] + return status[common.JobSpec{JobID: sessionid, NameWithOwner: nwo}] } func ResultAsFile(path string) (string, []byte, error) { @@ -231,10 +232,10 @@ func ResultAsFile(path string) (string, []byte, error) { return fpath, file, nil } -func SetStatus(sessionid int, orl common.NameWithOwner, s common.Status) { +func SetStatus(sessionid int, nwo common.NameWithOwner, s common.Status) { mutex.Lock() defer mutex.Unlock() - status[common.JobSpec{JobID: sessionid, NameWithOwner: orl}] = s + status[common.JobSpec{JobID: sessionid, NameWithOwner: nwo}] = s } func AddJob(sessionid int, job common.AnalyzeJob) { From ea10403f6c3a3503adfd3d8c30cb7c761cc94e22 Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 00:23:30 +0200 Subject: [PATCH 5/9] Bump server Ubuntu version to 24.10 --- cmd/server/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/server/Dockerfile b/cmd/server/Dockerfile index d12ebc8..23ce81a 100644 --- a/cmd/server/Dockerfile +++ b/cmd/server/Dockerfile @@ -1,5 +1,5 @@ # Use the ubuntu 22.04 base image -FROM ubuntu:22.04 +FROM 
ubuntu:24.10 # Set architecture to arm64 ARG ARCH=arm64 From 1a574c2f7f3c2ebf6a54dd59ef2e63e574130736 Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 00:39:21 +0200 Subject: [PATCH 6/9] Add RabbitMQ connect retry and healthcheck --- cmd/agent/main.go | 23 ++++++++++++++++++++++- docker-compose.yml | 6 ++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 83ae80a..6326323 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -100,11 +100,32 @@ func InitializeQueue(jobsQueueName, resultsQueueName string) (*RabbitMQQueue, er rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%s/", rabbitMQUser, rabbitMQPassword, rabbitMQHost, rabbitMQPort) - conn, err := amqp.Dial(rabbitMQURL) + const ( + tryCount = 5 + retryDelaySec = 3 + ) + + var conn *amqp.Connection + var err error + + for i := 0; i < tryCount; i++ { + slog.Info("Attempting to connect to RabbitMQ", slog.Int("attempt", i+1)) + conn, err = amqp.Dial(rabbitMQURL) + if err != nil { + slog.Warn("Failed to connect to RabbitMQ: %w", err) + if i < tryCount-1 { + slog.Info("Retrying in %d seconds", retryDelaySec) + time.Sleep(retryDelaySec * time.Second) + } + } + } + if err != nil { return nil, fmt.Errorf("failed to connect to RabbitMQ: %w", err) } + slog.Info("Connected to RabbitMQ") + ch, err := conn.Channel() if err != nil { conn.Close() diff --git a/docker-compose.yml b/docker-compose.yml index 97e3a8e..a74cbac 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -32,8 +32,10 @@ services: networks: - backend healthcheck: - test: [ "CMD", "rabbitmqctl", "status" ] - interval: 1s + test: [ "CMD", "nc", "-z", "localhost", "5672" ] + interval: 5s + timeout: 15s + retries: 1 server: build: From e107f6cf80dd05a96aab5b550cd60af174885bdb Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 23:10:37 +0200 Subject: [PATCH 7/9] Fix ENV CODEQL_JAVA_HOME in Dockerfile --- cmd/agent/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/agent/Dockerfile b/cmd/agent/Dockerfile index 3082169..6b506f7 100644 --- a/cmd/agent/Dockerfile +++ b/cmd/agent/Dockerfile @@ -39,7 +39,7 @@ ENV CODEQL_CLI_PATH=/opt/codeql # Set environment variable for CodeQL for `codeql database analyze` support on ARM # This env var has no functional effect on CodeQL when running on x86_64 linux -ENV CODEQL_JAVA_HOME=/usr/lib/jvm/ +ENV CODEQL_JAVA_HOME=/usr/ # Copy built agent binary from the builder stage WORKDIR /app From 7ea45cb17603adc01a3d68170993c4d7073a6306 Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sat, 15 Jun 2024 23:12:11 +0200 Subject: [PATCH 8/9] Separate queue and agent logic and refactor --- cmd/agent/main.go | 304 ++++++------------------------------------ pkg/agent/agent.go | 187 +++++++++++++------------- pkg/codeql/codeql.go | 2 - pkg/queue/rabbitmq.go | 163 ++++++++++++++++++++++ utils/download.go | 29 ++++ 5 files changed, 328 insertions(+), 357 deletions(-) create mode 100644 pkg/queue/rabbitmq.go create mode 100644 utils/download.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 6326323..2df182f 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,73 +1,41 @@ package main import ( - "io" - "log" - "mrvacommander/pkg/codeql" - "mrvacommander/pkg/common" + "mrvacommander/pkg/agent" "mrvacommander/pkg/queue" - "mrvacommander/pkg/storage" - "mrvacommander/utils" - "net/http" - "path/filepath" - "runtime" - - "context" - "encoding/json" - "flag" - "fmt" - "os" "os/signal" - "sync" + "strconv" "syscall" - "time" - 
"github.com/google/uuid" - amqp "github.com/rabbitmq/amqp091-go" - "golang.org/x/exp/slog" + "flag" + "os" + "runtime" + "sync" "github.com/elastic/go-sysinfo" + "golang.org/x/exp/slog" ) -func downloadFile(url string, dest string) error { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to download file: %w", err) - } - defer resp.Body.Close() - - out, err := os.Create(dest) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to copy file content: %w", err) - } - - return nil -} - func calculateWorkers() int { - const workerMemoryGB = 2 + const workerMemoryMB = 2048 // 2 GB host, err := sysinfo.Host() if err != nil { - log.Fatalf("failed to get host info: %v", err) + slog.Error("failed to get host info", "error", err) + os.Exit(1) } memInfo, err := host.Memory() if err != nil { - log.Fatalf("failed to get memory info: %v", err) + slog.Error("failed to get memory info", "error", err) + os.Exit(1) } - // Convert total memory to GB - totalMemoryGB := memInfo.Available / (1024 * 1024 * 1024) + // Get available memory in MB + totalMemoryMB := memInfo.Available / (1024 * 1024) // Ensure we have at least one worker - workers := int(totalMemoryGB / workerMemoryGB) + workers := int(totalMemoryMB / workerMemoryMB) if workers < 1 { workers = 1 } @@ -81,206 +49,6 @@ func calculateWorkers() int { return workers } -type RabbitMQQueue struct { - jobs chan common.AnalyzeJob - results chan common.AnalyzeResult - conn *amqp.Connection - channel *amqp.Channel -} - -func InitializeQueue(jobsQueueName, resultsQueueName string) (*RabbitMQQueue, error) { - rabbitMQHost := os.Getenv("MRVA_RABBITMQ_HOST") - rabbitMQPort := os.Getenv("MRVA_RABBITMQ_PORT") - rabbitMQUser := os.Getenv("MRVA_RABBITMQ_USER") - rabbitMQPassword := os.Getenv("MRVA_RABBITMQ_PASSWORD") - - if rabbitMQHost == "" || rabbitMQPort == "" || rabbitMQUser == "" || rabbitMQPassword == "" { - return nil, fmt.Errorf("RabbitMQ environment variables not set") - } - - rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%s/", rabbitMQUser, rabbitMQPassword, rabbitMQHost, rabbitMQPort) - - const ( - tryCount = 5 - retryDelaySec = 3 - ) - - var conn *amqp.Connection - var err error - - for i := 0; i < tryCount; i++ { - slog.Info("Attempting to connect to RabbitMQ", slog.Int("attempt", i+1)) - conn, err = amqp.Dial(rabbitMQURL) - if err != nil { - slog.Warn("Failed to connect to RabbitMQ: %w", err) - if i < tryCount-1 { - slog.Info("Retrying in %d seconds", retryDelaySec) - time.Sleep(retryDelaySec * time.Second) - } - } - } - - if err != nil { - return nil, fmt.Errorf("failed to connect to RabbitMQ: %w", err) - } - - slog.Info("Connected to RabbitMQ") - - ch, err := conn.Channel() - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to open a channel: %w", err) - } - - _, err = ch.QueueDeclare(jobsQueueName, false, false, false, true, nil) - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to declare tasks queue: %w", err) - } - - _, err = ch.QueueDeclare(resultsQueueName, false, false, false, true, nil) - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to declare results queue: %w", err) - } - - err = ch.Qos(1, 0, false) - - if err != nil { - conn.Close() - return nil, fmt.Errorf("failed to set QoS: %w", err) - } - - return &RabbitMQQueue{ - conn: conn, - channel: ch, - jobs: make(chan common.AnalyzeJob), - results: make(chan common.AnalyzeResult), - }, nil -} - -func 
(q *RabbitMQQueue) Jobs() chan common.AnalyzeJob { - return q.jobs -} - -func (q *RabbitMQQueue) Results() chan common.AnalyzeResult { - return q.results -} - -func (q *RabbitMQQueue) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { - slog.Info("Queueing codeql database analyze jobs") -} - -func (q *RabbitMQQueue) Close() { - q.channel.Close() - q.conn.Close() -} - -func (q *RabbitMQQueue) ConsumeJobs(queueName string) { - msgs, err := q.channel.Consume(queueName, "", true, false, false, false, nil) - if err != nil { - slog.Error("failed to register a consumer", slog.Any("error", err)) - } - - for msg := range msgs { - job := common.AnalyzeJob{} - err := json.Unmarshal(msg.Body, &job) - if err != nil { - slog.Error("failed to unmarshal job", slog.Any("error", err)) - continue - } - q.jobs <- job - } - close(q.jobs) -} - -func (q *RabbitMQQueue) PublishResults(queueName string) { - for result := range q.results { - q.publishResult(queueName, result) - } -} - -func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - resultBytes, err := json.Marshal(result) - if err != nil { - slog.Error("failed to marshal result", slog.Any("error", err)) - return - } - - slog.Info("Publishing result", slog.String("result", string(resultBytes))) - err = q.channel.PublishWithContext(ctx, "", queueName, false, false, - amqp.Publishing{ - ContentType: "application/json", - Body: resultBytes, - }) - if err != nil { - slog.Error("failed to publish result", slog.Any("error", err)) - } -} - -func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { - var result = common.AnalyzeResult{ - RequestId: job.RequestId, - ResultCount: 0, - ResultArchiveURL: "", - Status: common.StatusError, - } - - // Log job info - slog.Info("Running analysis job", slog.Any("job", job)) - - // Create a temporary directory - tempDir := filepath.Join(os.TempDir(), uuid.New().String()) - if err := os.MkdirAll(tempDir, 0755); err != nil { - return result, fmt.Errorf("failed to create temporary directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Extract the query pack - // TODO: download from the 'job' query pack URL - utils.UntarGz("qp-54674.tgz", filepath.Join(tempDir, "qp-54674")) - - // Perform the CodeQL analysis - runResult, err := codeql.RunQuery("google_flatbuffers_db.zip", "cpp", "qp-54674", tempDir) - if err != nil { - return result, fmt.Errorf("failed to run analysis: %w", err) - } - - // Generate a ZIP archive containing SARIF and BQRS files - resultsArchive, err := codeql.GenerateResultsZipArchive(runResult) - if err != nil { - return result, fmt.Errorf("failed to generate results archive: %w", err) - } - - // TODO: Upload the archive to storage - slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) - slog.Info("Analysis job successful.") - - result = common.AnalyzeResult{ - RequestId: job.RequestId, - ResultCount: runResult.ResultCount, - ResultArchiveURL: "REPLACE_THIS_WITH_STORED_RESULTS_ARCHIVE", - Status: common.StatusSuccess, - } - - return result, nil -} - -func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { - defer wg.Done() - for job := range queue.Jobs() { - result, err := RunAnalysisJob(job) - if err != nil { - slog.Error("failed to run analysis job", slog.Any("error", err)) - continue - } - queue.Results() <- result - } -} - func main() { slog.Info("Starting agent") @@ -297,13 +65,27 @@ 
func main() { } for _, envVar := range requiredEnvVars { - if os.Getenv(envVar) == "" { - log.Fatalf("Fatal: Missing required environment variable %s", envVar) + if _, ok := os.LookupEnv(envVar); !ok { + slog.Error("Missing required environment variable %s", envVar) + os.Exit(1) } } - slog.Info("Initializing RabbitMQ connection") - rabbitMQQueue, err := InitializeQueue("tasks", "results") + rmqHost := os.Getenv("MRVA_RABBITMQ_HOST") + rmqPort := os.Getenv("MRVA_RABBITMQ_PORT") + rmqUser := os.Getenv("MRVA_RABBITMQ_USER") + rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD") + + rmqPortAsInt, err := strconv.Atoi(rmqPort) + + if err != nil { + slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err)) + os.Exit(1) + } + + slog.Info("Initializing RabbitMQ queue") + + rabbitMQQueue, err := queue.InitializeRabbitMQQueue(rmqHost, int16(rmqPortAsInt), rmqUser, rmqPass) if err != nil { slog.Error("failed to initialize RabbitMQ", slog.Any("error", err)) os.Exit(1) @@ -318,21 +100,21 @@ func main() { var wg sync.WaitGroup for i := 0; i < *workerCount; i++ { wg.Add(1) - go RunWorker(rabbitMQQueue, &wg) + go agent.RunWorker(rabbitMQQueue, &wg) } - slog.Info("Starting tasks consumer") - go rabbitMQQueue.ConsumeJobs("tasks") - - slog.Info("Starting results publisher") - go rabbitMQQueue.PublishResults("results") - slog.Info("Agent startup complete") + // Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - <-sigChan - slog.Info("Shutting down agent") - close(rabbitMQQueue.results) + go func() { + <-sigChan + slog.Info("Shutting down agent") + rabbitMQQueue.Close() + os.Exit(0) + }() + + select {} } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 01fa8ee..018d759 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -1,20 +1,20 @@ package agent import ( + "fmt" + "log/slog" + "mrvacommander/pkg/codeql" "mrvacommander/pkg/common" "mrvacommander/pkg/logger" "mrvacommander/pkg/qpstore" "mrvacommander/pkg/queue" "mrvacommander/pkg/storage" "mrvacommander/utils" - - "log/slog" - - "fmt" - "path/filepath" - "os" - "os/exec" + "path/filepath" + "sync" + + "github.com/google/uuid" ) type RunnerSingle struct { @@ -40,102 +40,101 @@ type Visibles struct { } func (c *RunnerSingle) Setup(st *Visibles) { - return + // TODO: implement } func (r *RunnerSingle) worker(wid int) { - var job common.AnalyzeJob + // TODO: reimplement this later + /* + var job common.AnalyzeJob - for { - job = <-r.queue.Jobs() + for { + job = <-r.queue.Jobs() - slog.Debug("Picking up job", "job", job, "worker", wid) + slog.Debug("Picking up job", "job", job, "worker", wid) - slog.Debug("Analysis: running", "job", job) - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) + slog.Debug("Analysis: running", "job", job) + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusQueued) - _, err := RunAnalysis(job) + resultFile, err := RunAnalysis(job) + if err != nil { + continue + } + + slog.Debug("Analysis run finished", "job", job) + + // TODO: FIX THIS + res := common.AnalyzeResult{ + RunAnalysisSARIF: resultFile, + RunAnalysisBQRS: "", // FIXME ? 
+ } + r.queue.Results() <- res + storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) + storage.SetResult(job.QueryPackId, job.NWO, res) + + } + */ +} + +// RunAnalysisJob runs a CodeQL analysis job (AnalyzeJob) returning an AnalyzeResult +func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { + var result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: 0, + ResultArchiveURL: "", + Status: common.StatusError, + } + + // Create a temporary directory + tempDir := filepath.Join(os.TempDir(), uuid.New().String()) + if err := os.MkdirAll(tempDir, 0755); err != nil { + return result, fmt.Errorf("failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Extract the query pack + // TODO: download from the 'job' query pack URL + // utils.downloadFile + queryPackPath := filepath.Join(tempDir, "qp-54674") + utils.UntarGz("qp-54674.tgz", queryPackPath) + + // Perform the CodeQL analysis + runResult, err := codeql.RunQuery("google_flatbuffers_db.zip", "cpp", queryPackPath, tempDir) + if err != nil { + return result, fmt.Errorf("failed to run analysis: %w", err) + } + + // Generate a ZIP archive containing SARIF and BQRS files + resultsArchive, err := codeql.GenerateResultsZipArchive(runResult) + if err != nil { + return result, fmt.Errorf("failed to generate results archive: %w", err) + } + + // TODO: Upload the archive to storage + slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) + slog.Info("Analysis job successful.") + + result = common.AnalyzeResult{ + RequestId: job.RequestId, + ResultCount: runResult.ResultCount, + ResultArchiveURL: "REPLACE_THIS_WITH_STORED_RESULTS_ARCHIVE", // TODO + Status: common.StatusSuccess, + } + + return result, nil +} + +// RunWorker runs a worker that processes jobs from queue +func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { + defer wg.Done() + for job := range queue.Jobs() { + slog.Info("Running analysis job", slog.Any("job", job)) + result, err := RunAnalysisJob(job) if err != nil { + slog.Error("Failed to run analysis job", slog.Any("error", err)) continue } - - slog.Debug("Analysis run finished", "job", job) - - res := common.AnalyzeResult{} - r.queue.Results() <- res - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusSuccess) - storage.SetResult(job.QueryPackId, job.NWO, res) - + slog.Info("Analysis job completed", slog.Any("result", result)) + queue.Results() <- result } } - -func RunAnalysis(job common.AnalyzeJob) (string, error) { - // TODO Add multi-language tests including queryLanguage - // queryPackID, queryLanguage, dbOwner, dbRepo := - // job.QueryPackId, job.QueryLanguage, job.NWO.Owner, job.NWO.Repo - queryPackID, dbOwner, dbRepo := - job.QueryPackId, job.NWO.Owner, job.NWO.Repo - - serverRoot := os.Getenv("MRVA_SERVER_ROOT") - - // Set up derived paths - dbPath := filepath.Join(serverRoot, "var/codeql/dbs", dbOwner, dbRepo) - dbZip := filepath.Join(serverRoot, "codeql/dbs", dbOwner, dbRepo, - fmt.Sprintf("%s_%s_db.zip", dbOwner, dbRepo)) - dbExtract := filepath.Join(serverRoot, "var/codeql/dbs", dbOwner, dbRepo) - - queryPack := filepath.Join(serverRoot, - "var/codeql/querypacks", fmt.Sprintf("qp-%d.tgz", queryPackID)) - queryExtract := filepath.Join(serverRoot, - "var/codeql/querypacks", fmt.Sprintf("qp-%d", queryPackID)) - - queryOutDir := filepath.Join(serverRoot, - "var/codeql/sarif/localrun", dbOwner, dbRepo) - queryOutFile := filepath.Join(queryOutDir, - fmt.Sprintf("%s_%s.sarif", dbOwner, dbRepo)) - - // Prepare directory, 
extract database - if err := os.MkdirAll(dbExtract, 0755); err != nil { - slog.Error("Failed to create DB directory %s: %v", dbExtract, err) - return "", err - } - - if err := utils.UnzipFile(dbZip, dbExtract); err != nil { - slog.Error("Failed to unzip DB", dbZip, err) - return "", err - } - - // Prepare directory, extract query pack - if err := os.MkdirAll(queryExtract, 0755); err != nil { - slog.Error("Failed to create query pack directory %s: %v", queryExtract, err) - return "", err - } - - if err := utils.UntarGz(queryPack, queryExtract); err != nil { - slog.Error("Failed to extract querypack %s: %v", queryPack, err) - return "", err - } - - // Prepare query result directory - if err := os.MkdirAll(queryOutDir, 0755); err != nil { - slog.Error("Failed to create query result directory %s: %v", queryOutDir, err) - return "", err - } - - // Run database analyze - cmd := exec.Command("codeql", "database", "analyze", - "--format=sarif-latest", "--rerun", "--output", queryOutFile, - "-j8", dbPath, queryExtract) - cmd.Dir = serverRoot - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - slog.Error("codeql database analyze failed:", "error", err, "job", job) - storage.SetStatus(job.QueryPackId, job.NWO, common.StatusError) - return "", err - } - - // Return result path - return queryOutFile, nil -} diff --git a/pkg/codeql/codeql.go b/pkg/codeql/codeql.go index d96373e..a44efd8 100644 --- a/pkg/codeql/codeql.go +++ b/pkg/codeql/codeql.go @@ -16,7 +16,6 @@ import ( "gopkg.in/yaml.v3" ) -// Helper Functions func contains(slice []string, item string) bool { for _, s := range slice { if s == item { @@ -26,7 +25,6 @@ func contains(slice []string, item string) bool { return false } -// Main Functions func getCodeQLCLIPath() (string, error) { // get the CODEQL_CLI_PATH environment variable codeqlCliPath := os.Getenv("CODEQL_CLI_PATH") diff --git a/pkg/queue/rabbitmq.go b/pkg/queue/rabbitmq.go new file mode 100644 index 0000000..7d32b1f --- /dev/null +++ b/pkg/queue/rabbitmq.go @@ -0,0 +1,163 @@ +package queue + +import ( + "mrvacommander/pkg/common" + "mrvacommander/pkg/storage" + + "context" + "encoding/json" + "fmt" + "log" + "time" + + amqp "github.com/rabbitmq/amqp091-go" + "golang.org/x/exp/slog" +) + +type RabbitMQQueue struct { + jobs chan common.AnalyzeJob + results chan common.AnalyzeResult + conn *amqp.Connection + channel *amqp.Channel +} + +func InitializeRabbitMQQueue( + host string, + port int16, + user string, + password string, +) (*RabbitMQQueue, error) { + const ( + tryCount = 5 + retryDelaySec = 3 + jobsQueueName = "tasks" + resultsQueueName = "results" + ) + + var conn *amqp.Connection + var err error + + rabbitMQURL := fmt.Sprintf("amqp://%s:%s@%s:%d/", user, password, host, port) + + for i := 0; i < tryCount; i++ { + slog.Info("Attempting to connect to RabbitMQ", slog.Int("attempt", i+1)) + conn, err = amqp.Dial(rabbitMQURL) + if err != nil { + slog.Warn("Failed to connect to RabbitMQ", "error", err) + if i < tryCount-1 { + slog.Info("Retrying", "seconds", retryDelaySec) + time.Sleep(retryDelaySec * time.Second) + } + } else { + // successfully connected to RabbitMQ + break + } + } + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + slog.Info("Connected to RabbitMQ") + + ch, err := conn.Channel() + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to open a channel: %w", err) + } + + _, err = ch.QueueDeclare(jobsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, 
fmt.Errorf("failed to declare tasks queue: %w", err) + } + + _, err = ch.QueueDeclare(resultsQueueName, false, false, false, true, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to declare results queue: %w", err) + } + + err = ch.Qos(1, 0, false) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to set QoS: %w", err) + } + + result := RabbitMQQueue{ + conn: conn, + channel: ch, + jobs: make(chan common.AnalyzeJob), + results: make(chan common.AnalyzeResult), + } + + slog.Info("Starting tasks consumer") + go result.ConsumeJobs(jobsQueueName) + + slog.Info("Starting results publisher") + go result.PublishResults(resultsQueueName) + + return &result, nil +} + +func (q *RabbitMQQueue) Jobs() chan common.AnalyzeJob { + return q.jobs +} + +func (q *RabbitMQQueue) Results() chan common.AnalyzeResult { + return q.results +} + +func (q *RabbitMQQueue) StartAnalyses(analysis_repos *map[common.NameWithOwner]storage.DBLocation, session_id int, session_language string) { + // TODO: Implement + log.Fatal("unimplemented") +} + +func (q *RabbitMQQueue) Close() { + q.channel.Close() + q.conn.Close() +} + +func (q *RabbitMQQueue) ConsumeJobs(queueName string) { + msgs, err := q.channel.Consume(queueName, "", true, false, false, false, nil) + if err != nil { + slog.Error("failed to register a consumer", slog.Any("error", err)) + } + + for msg := range msgs { + job := common.AnalyzeJob{} + err := json.Unmarshal(msg.Body, &job) + if err != nil { + slog.Error("failed to unmarshal job", slog.Any("error", err)) + continue + } + q.jobs <- job + } + close(q.jobs) +} + +func (q *RabbitMQQueue) PublishResults(queueName string) { + for result := range q.results { + q.publishResult(queueName, result) + } +} + +func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resultBytes, err := json.Marshal(result) + if err != nil { + slog.Error("failed to marshal result", slog.Any("error", err)) + return + } + + slog.Info("Publishing result", slog.String("result", string(resultBytes))) + err = q.channel.PublishWithContext(ctx, "", queueName, false, false, + amqp.Publishing{ + ContentType: "application/json", + Body: resultBytes, + }) + if err != nil { + slog.Error("failed to publish result", slog.Any("error", err)) + } +} diff --git a/utils/download.go b/utils/download.go new file mode 100644 index 0000000..ec1ae42 --- /dev/null +++ b/utils/download.go @@ -0,0 +1,29 @@ +package utils + +import ( + "fmt" + "io" + "net/http" + "os" +) + +func downloadFile(url string, dest string) error { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + defer resp.Body.Close() + + out, err := os.Create(dest) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return fmt.Errorf("failed to copy file content: %w", err) + } + + return nil +} From 903ca5673e68f9a44a32bfc598d2b77e761759bd Mon Sep 17 00:00:00 2001 From: Nicolas Will Date: Sun, 16 Jun 2024 12:21:54 +0200 Subject: [PATCH 9/9] Add dynamic worker management --- cmd/agent/main.go | 106 +++++++++++++++++++++++++++++++----------- pkg/agent/agent.go | 49 ++++++++++++++----- pkg/codeql/codeql.go | 12 ++--- pkg/queue/rabbitmq.go | 2 +- 4 files changed, 124 insertions(+), 45 deletions(-) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 2df182f..7cd3e33 100644 --- 
a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,24 +1,29 @@ package main import ( - "mrvacommander/pkg/agent" - "mrvacommander/pkg/queue" - "os/signal" - "strconv" - "syscall" - + "context" "flag" "os" + "os/signal" "runtime" + "strconv" "sync" + "syscall" + "time" "github.com/elastic/go-sysinfo" "golang.org/x/exp/slog" + + "mrvacommander/pkg/agent" + "mrvacommander/pkg/queue" +) + +const ( + workerMemoryMB = 2048 // 2 GB + monitorIntervalSec = 10 // Monitor every 10 seconds ) func calculateWorkers() int { - const workerMemoryMB = 2048 // 2 GB - host, err := sysinfo.Host() if err != nil { slog.Error("failed to get host info", "error", err) @@ -49,6 +54,60 @@ func calculateWorkers() int { return workers } +func startAndMonitorWorkers(ctx context.Context, queue queue.Queue, desiredWorkerCount int, wg *sync.WaitGroup) { + currentWorkerCount := 0 + stopChans := make([]chan struct{}, 0) + + if desiredWorkerCount != 0 { + slog.Info("Starting workers", slog.Int("count", desiredWorkerCount)) + for i := 0; i < desiredWorkerCount; i++ { + stopChan := make(chan struct{}) + stopChans = append(stopChans, stopChan) + wg.Add(1) + go agent.RunWorker(ctx, stopChan, queue, wg) + } + return + } + + slog.Info("Worker count not specified, managing based on available memory and CPU") + + for { + select { + case <-ctx.Done(): + // signal all workers to stop + for _, stopChan := range stopChans { + close(stopChan) + } + return + default: + newWorkerCount := calculateWorkers() + + if newWorkerCount != currentWorkerCount { + slog.Info( + "Modifying worker count", + slog.Int("current", currentWorkerCount), + slog.Int("new", newWorkerCount)) + } + + if newWorkerCount > currentWorkerCount { + for i := currentWorkerCount; i < newWorkerCount; i++ { + stopChan := make(chan struct{}) + stopChans = append(stopChans, stopChan) + wg.Add(1) + go agent.RunWorker(ctx, stopChan, queue, wg) + } + } else if newWorkerCount < currentWorkerCount { + for i := newWorkerCount; i < currentWorkerCount; i++ { + close(stopChans[i]) + } + stopChans = stopChans[:newWorkerCount] + } + currentWorkerCount = newWorkerCount + time.Sleep(monitorIntervalSec * time.Second) + } + } +} + func main() { slog.Info("Starting agent") @@ -77,7 +136,6 @@ func main() { rmqPass := os.Getenv("MRVA_RABBITMQ_PASSWORD") rmqPortAsInt, err := strconv.Atoi(rmqPort) - if err != nil { slog.Error("Failed to parse RabbitMQ port", slog.Any("error", err)) os.Exit(1) @@ -92,29 +150,23 @@ func main() { } defer rabbitMQQueue.Close() - if *workerCount == 0 { - *workerCount = calculateWorkers() - } - - slog.Info("Starting workers", slog.Int("count", *workerCount)) var wg sync.WaitGroup - for i := 0; i < *workerCount; i++ { - wg.Add(1) - go agent.RunWorker(rabbitMQQueue, &wg) - } + ctx, cancel := context.WithCancel(context.Background()) - slog.Info("Agent startup complete") + go startAndMonitorWorkers(ctx, rabbitMQQueue, *workerCount, &wg) - // Gracefully exit on SIGINT/SIGTERM (TODO: add job cleanup) + slog.Info("Agent started") + + // Gracefully exit on SIGINT/SIGTERM sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-sigChan - slog.Info("Shutting down agent") - rabbitMQQueue.Close() - os.Exit(0) - }() + <-sigChan + slog.Info("Shutting down agent") - select {} + // TODO: fix this to gracefully terminate agent workers during jobs + cancel() + wg.Wait() + + slog.Info("Agent shutdown complete") } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 018d759..b42ab51 100644 --- a/pkg/agent/agent.go +++ 
b/pkg/agent/agent.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "log/slog" "mrvacommander/pkg/codeql" @@ -111,8 +112,7 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { } // TODO: Upload the archive to storage - slog.Info("Results archive size", slog.Int("size", len(resultsArchive))) - slog.Info("Analysis job successful.") + slog.Debug("Results archive size", slog.Int("size", len(resultsArchive))) result = common.AnalyzeResult{ RequestId: job.RequestId, @@ -125,16 +125,43 @@ func RunAnalysisJob(job common.AnalyzeJob) (common.AnalyzeResult, error) { } // RunWorker runs a worker that processes jobs from queue -func RunWorker(queue queue.Queue, wg *sync.WaitGroup) { +func RunWorker(ctx context.Context, stopChan chan struct{}, queue queue.Queue, wg *sync.WaitGroup) { + const ( + WORKER_COUNT_STOP_MESSAGE = "Worker stopping due to reduction in worker count" + WORKER_CONTEXT_STOP_MESSAGE = "Worker stopping due to context cancellation" + ) + defer wg.Done() - for job := range queue.Jobs() { - slog.Info("Running analysis job", slog.Any("job", job)) - result, err := RunAnalysisJob(job) - if err != nil { - slog.Error("Failed to run analysis job", slog.Any("error", err)) - continue + slog.Info("Worker started") + for { + select { + case <-stopChan: + slog.Info(WORKER_COUNT_STOP_MESSAGE) + return + case <-ctx.Done(): + slog.Info(WORKER_CONTEXT_STOP_MESSAGE) + return + default: + select { + case job, ok := <-queue.Jobs(): + if !ok { + return + } + slog.Info("Running analysis job", slog.Any("job", job)) + result, err := RunAnalysisJob(job) + if err != nil { + slog.Error("Failed to run analysis job", slog.Any("error", err)) + continue + } + slog.Info("Analysis job completed", slog.Any("result", result)) + queue.Results() <- result + case <-stopChan: + slog.Info(WORKER_COUNT_STOP_MESSAGE) + return + case <-ctx.Done(): + slog.Info(WORKER_CONTEXT_STOP_MESSAGE) + return + } } - slog.Info("Analysis job completed", slog.Any("result", result)) - queue.Results() <- result } } diff --git a/pkg/codeql/codeql.go b/pkg/codeql/codeql.go index a44efd8..2c4f45c 100644 --- a/pkg/codeql/codeql.go +++ b/pkg/codeql/codeql.go @@ -115,7 +115,7 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) databaseSHA = *dbMetadata.CreationMetadata.SHA } - cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=1024", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath) + cmd := exec.Command(codeql.Path, "database", "run-queries", "--ram=2048", "--additional-packs", queryPackPath, "--", databasePath, queryPackPath) if output, err := cmd.CombinedOutput(); err != nil { return nil, fmt.Errorf("failed to run queries: %v\nOutput: %s", err, output) } @@ -133,9 +133,9 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) shouldGenerateSarif := queryPackSupportsSarif(queryPackRunResults) if shouldGenerateSarif { - slog.Info("Query pack supports SARIF") + slog.Debug("Query pack supports SARIF") } else { - slog.Info("Query pack does not support SARIF") + slog.Debug("Query pack does not support SARIF") } var resultCount int @@ -147,17 +147,17 @@ func RunQuery(database string, nwo string, queryPackPath string, tempDir string) return nil, fmt.Errorf("failed to generate SARIF: %v", err) } resultCount = getSarifResultCount(sarif) - slog.Info("Generated SARIF", "resultCount", resultCount) + slog.Debug("Generated SARIF", "resultCount", resultCount) sarifFilePath = filepath.Join(resultsDir, "results.sarif") if err := 
os.WriteFile(sarifFilePath, sarif, 0644); err != nil { return nil, fmt.Errorf("failed to write SARIF file: %v", err) } } else { resultCount = queryPackRunResults.TotalResultsCount - slog.Info("Did not generate SARIF", "resultCount", resultCount) + slog.Debug("Did not generate SARIF", "resultCount", resultCount) } - slog.Info("Adjusting BQRS files") + slog.Debug("Adjusting BQRS files") bqrsFilePaths, err := adjustBqrsFiles(queryPackRunResults, resultsDir) if err != nil { return nil, fmt.Errorf("failed to adjust BQRS files: %v", err) diff --git a/pkg/queue/rabbitmq.go b/pkg/queue/rabbitmq.go index 7d32b1f..f0047f3 100644 --- a/pkg/queue/rabbitmq.go +++ b/pkg/queue/rabbitmq.go @@ -151,7 +151,7 @@ func (q *RabbitMQQueue) publishResult(queueName string, result interface{}) { return } - slog.Info("Publishing result", slog.String("result", string(resultBytes))) + slog.Debug("Publishing result", slog.String("result", string(resultBytes))) err = q.channel.PublishWithContext(ctx, "", queueName, false, false, amqp.Publishing{ ContentType: "application/json",