From 5b54ffbf91810f5190fccf20b7f2ef6e539a946e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alvaro=20Mu=C3=B1oz?= Date: Mon, 11 Sep 2023 16:56:37 +0200 Subject: [PATCH] Add `download run` subcommand and --output-filename option --- cmd/download.go | 56 ++++++++++++++++++++++++++++++++---------------- models/models.go | 38 ++++++++++++++++---------------- utils/utils.go | 54 +++++++++++++++++++++++++++++++--------------- 3 files changed, 93 insertions(+), 55 deletions(-) diff --git a/cmd/download.go b/cmd/download.go index c148045..07922c9 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -1,6 +1,5 @@ /* Copyright © 2023 NAME HERE - */ package cmd @@ -31,11 +30,13 @@ var downloadCmd = &cobra.Command{ func init() { rootCmd.AddCommand(downloadCmd) downloadCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be downloaded") + downloadCmd.Flags().IntVarP(&runIdFlag, "run", "r", 0, "Run ID to be downloaded") downloadCmd.Flags().StringVarP(&outputDirFlag, "output-dir", "o", "", "Output directory") + downloadCmd.Flags().StringVarP(&outputFilenameFlag, "output-filename", "f", "", "Output filename") downloadCmd.Flags().BoolVarP(&downloadDBsFlag, "download-dbs", "d", false, "Download databases (optional)") downloadCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "Repository to download artifacts for (optional)") - downloadCmd.MarkFlagRequired("session") downloadCmd.MarkFlagRequired("output-dir") + downloadCmd.MarkFlagsMutuallyExclusive("session", "run") } func downloadArtifacts() { @@ -48,11 +49,26 @@ func downloadArtifacts() { } } - controller, runs, language, err := utils.LoadSession(sessionNameFlag) - if err != nil { - fmt.Println(err) - } else if len(runs) == 0 { - fmt.Println("No runs found for sessions" + sessionNameFlag) + controller := "" + language := "" + runs := []models.Run{} + err := error(nil) + + if sessionNameFlag != "" { + controller, runs, language, err = utils.LoadSession(sessionNameFlag) + if err != nil { + fmt.Println(err) + } else if len(runs) == 0 { + fmt.Println("No runs found for sessions" + sessionNameFlag) + } + } else if runIdFlag > 0 { + controller, runs, language, err = utils.LoadRun(runIdFlag) + if err != nil { + fmt.Println(err) + } + } else { + fmt.Println("Please specify a session or run to download artifacts for") + return } var downloadTasks []models.DownloadTask @@ -85,24 +101,26 @@ func downloadArtifacts() { _, sarifErr := os.Stat(sarifPath) if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) { downloadTasks = append(downloadTasks, models.DownloadTask{ - RunId: run.Id, - Nwo: nwo, - Controller: controller, - Artifact: "artifact", - Language: language, - OutputDir: outputDirFlag, + RunId: run.Id, + Nwo: nwo, + Controller: controller, + Artifact: "artifact", + Language: language, + OutputDir: outputDirFlag, + OutputFilename: outputFilenameFlag, }) } if downloadDBsFlag { // check if the database already exists if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) { downloadTasks = append(downloadTasks, models.DownloadTask{ - RunId: run.Id, - Nwo: nwo, - Controller: controller, - Artifact: "database", - Language: language, - OutputDir: outputDirFlag, + RunId: run.Id, + Nwo: nwo, + Controller: controller, + Artifact: "database", + Language: language, + OutputDir: outputDirFlag, + OutputFilename: outputFilenameFlag, }) } } diff --git a/models/models.go b/models/models.go index 2fe7c24..413b660 100644 --- a/models/models.go +++ b/models/models.go @@ -28,12 +28,13 @@ type Config struct { } type DownloadTask struct { - RunId int - Nwo string - Controller string - Artifact string - OutputDir string - Language string + RunId int + Nwo string + Controller string + Artifact string + OutputDir string + OutputFilename string + Language string } type RunStatus struct { @@ -55,17 +56,16 @@ type RepoWithFindings struct { type Results struct { Name string `json:"name"` - Status string `json:"status"` - Runs []RunStatus `json:"runs"` - ResositoriesWithFindings []RepoWithFindings `json:"repositories_with_findings"` - TotalFindingsCount int `json:"total_findings_count"` - TotalSuccessfulScans int `json:"total_successful_scans"` - TotalFailedScans int `json:"total_failed_scans"` - TotalRepositoriesWithFindings int `json:"total_repositories_with_findings"` - TotalSkippedRepositories int `json:"total_skipped_repositories"` - TotalSkippedAccessMismatchRepositories int `json:"total_skipped_access_mismatch_repositories"` - TotalSkippedNotFoundRepositories int `json:"total_skipped_not_found_repositories"` - TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"` - TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"` + Status string `json:"status"` + Runs []RunStatus `json:"runs"` + ResositoriesWithFindings []RepoWithFindings `json:"repositories_with_findings"` + TotalFindingsCount int `json:"total_findings_count"` + TotalSuccessfulScans int `json:"total_successful_scans"` + TotalFailedScans int `json:"total_failed_scans"` + TotalRepositoriesWithFindings int `json:"total_repositories_with_findings"` + TotalSkippedRepositories int `json:"total_skipped_repositories"` + TotalSkippedAccessMismatchRepositories int `json:"total_skipped_access_mismatch_repositories"` + TotalSkippedNotFoundRepositories int `json:"total_skipped_not_found_repositories"` + TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"` + TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"` } - diff --git a/utils/utils.go b/utils/utils.go index 1edf2ca..5ab5947 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -59,6 +59,23 @@ func GetSessions() (map[string]models.Session, error) { return sessions, nil } +func LoadRun(id int) (string, []models.Run, string, error) { + sessions, err := GetSessions() + if err != nil { + return "", nil, "", err + } + if sessions != nil { + for _, session := range sessions { + for _, run := range session.Runs { + if run.Id == id { + return session.Controller, []models.Run{run}, session.Language, nil + } + } + } + } + return "", nil, "", errors.New("No run found for " + fmt.Sprint(id)) +} + func LoadSession(name string) (string, []models.Run, string, error) { sessions, err := GetSessions() if err != nil { @@ -416,7 +433,6 @@ defaultSuite: return bundleBase64, queryId, nil } - func PackPacklist(codeqlPath string, dir string, includeQueries bool) []string { // since 2.7.1, packlist returns an object with a "paths" property that is a list of packs. args := []string{"pack", "packlist", "--format=json"} @@ -518,17 +534,17 @@ func DownloadWorker(wg *sync.WaitGroup, taskChannel <-chan models.DownloadTask, defer wg.Done() for task := range taskChannel { if task.Artifact == "artifact" { - DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir) + DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir, task.OutputFilename) resultChannel <- task } else if task.Artifact == "database" { - fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir) - DownloadDatabase(task.Nwo, task.Language, task.OutputDir) + fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir, task.OutputFilename) + DownloadDatabase(task.Nwo, task.Language, task.OutputDir, task.OutputFilename) resultChannel <- task } } } -func downloadArtifact(url string, outputDir string, nwo string) error { +func downloadArtifact(url string, outputDir string, nwo string, outputFilename string) error { client, err := gh.HTTPClient(nil) if err != nil { return err @@ -562,14 +578,16 @@ func downloadArtifact(url string, outputDir string, nwo string) error { if err != nil { log.Fatal(err) } - extension := "" - resultPath := "" - if zf.Name == "results.bqrs" { - extension = "bqrs" - } else if zf.Name == "results.sarif" { - extension = "sarif" + if outputFilename == "" { + extension := "" + if zf.Name == "results.bqrs" { + extension = "bqrs" + } else if zf.Name == "results.sarif" { + extension = "sarif" + } + outputFilename = fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension) } - resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)) + resultPath := filepath.Join(outputDir, outputFilename) err = os.WriteFile(resultPath, bytes, os.ModePerm) if err != nil { return err @@ -579,23 +597,26 @@ func downloadArtifact(url string, outputDir string, nwo string) error { return errors.New("No results.sarif file found in artifact") } -func DownloadResults(controller string, runId int, nwo string, outputDir string) error { +func DownloadResults(controller string, runId int, nwo string, outputDir string, outputFilename string) error { // download artifact (BQRS or SARIF) runRepositoryDetails, err := GetRunRepositoryDetails(controller, runId, nwo) if err != nil { return errors.New("Failed to get run repository details") } // download the results - err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo) + err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo, outputFilename) if err != nil { return errors.New("Failed to download artifact") } return nil } -func DownloadDatabase(nwo string, language string, outputDir string) error { +func DownloadDatabase(nwo string, language string, outputDir string, outputFilename string) error { dnwo := strings.Replace(nwo, "/", "_", -1) - targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) + if outputFilename == "" { + outputFilename = fmt.Sprintf("%s_%s_db.zip", dnwo, language) + } + targetPath := filepath.Join(outputDir, outputFilename) opts := api.ClientOptions{ Headers: map[string]string{"Accept": "application/zip"}, } @@ -616,4 +637,3 @@ func DownloadDatabase(nwo string, language string, outputDir string) error { err = os.WriteFile(targetPath, bytes, os.ModePerm) return nil } -