Add download run subcommand and --output-filename option

This commit is contained in:
Alvaro Muñoz
2023-09-11 16:56:37 +02:00
parent 361fa7c833
commit 5b54ffbf91
3 changed files with 93 additions and 55 deletions

View File

@@ -1,6 +1,5 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
@@ -31,11 +30,13 @@ var downloadCmd = &cobra.Command{
func init() {
rootCmd.AddCommand(downloadCmd)
downloadCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be downloaded")
downloadCmd.Flags().IntVarP(&runIdFlag, "run", "r", 0, "Run ID to be downloaded")
downloadCmd.Flags().StringVarP(&outputDirFlag, "output-dir", "o", "", "Output directory")
downloadCmd.Flags().StringVarP(&outputFilenameFlag, "output-filename", "f", "", "Output filename")
downloadCmd.Flags().BoolVarP(&downloadDBsFlag, "download-dbs", "d", false, "Download databases (optional)")
downloadCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "Repository to download artifacts for (optional)")
downloadCmd.MarkFlagRequired("session")
downloadCmd.MarkFlagRequired("output-dir")
downloadCmd.MarkFlagsMutuallyExclusive("session", "run")
}
func downloadArtifacts() {
@@ -48,12 +49,27 @@ func downloadArtifacts() {
}
}
controller, runs, language, err := utils.LoadSession(sessionNameFlag)
controller := ""
language := ""
runs := []models.Run{}
err := error(nil)
if sessionNameFlag != "" {
controller, runs, language, err = utils.LoadSession(sessionNameFlag)
if err != nil {
fmt.Println(err)
} else if len(runs) == 0 {
fmt.Println("No runs found for sessions" + sessionNameFlag)
}
} else if runIdFlag > 0 {
controller, runs, language, err = utils.LoadRun(runIdFlag)
if err != nil {
fmt.Println(err)
}
} else {
fmt.Println("Please specify a session or run to download artifacts for")
return
}
var downloadTasks []models.DownloadTask
@@ -91,6 +107,7 @@ func downloadArtifacts() {
Artifact: "artifact",
Language: language,
OutputDir: outputDirFlag,
OutputFilename: outputFilenameFlag,
})
}
if downloadDBsFlag {
@@ -103,6 +120,7 @@ func downloadArtifacts() {
Artifact: "database",
Language: language,
OutputDir: outputDirFlag,
OutputFilename: outputFilenameFlag,
})
}
}

View File

@@ -33,6 +33,7 @@ type DownloadTask struct {
Controller string
Artifact string
OutputDir string
OutputFilename string
Language string
}
@@ -68,4 +69,3 @@ type Results struct {
TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"`
TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"`
}

View File

@@ -59,6 +59,23 @@ func GetSessions() (map[string]models.Session, error) {
return sessions, nil
}
func LoadRun(id int) (string, []models.Run, string, error) {
sessions, err := GetSessions()
if err != nil {
return "", nil, "", err
}
if sessions != nil {
for _, session := range sessions {
for _, run := range session.Runs {
if run.Id == id {
return session.Controller, []models.Run{run}, session.Language, nil
}
}
}
}
return "", nil, "", errors.New("No run found for " + fmt.Sprint(id))
}
func LoadSession(name string) (string, []models.Run, string, error) {
sessions, err := GetSessions()
if err != nil {
@@ -416,7 +433,6 @@ defaultSuite:
return bundleBase64, queryId, nil
}
func PackPacklist(codeqlPath string, dir string, includeQueries bool) []string {
// since 2.7.1, packlist returns an object with a "paths" property that is a list of packs.
args := []string{"pack", "packlist", "--format=json"}
@@ -518,17 +534,17 @@ func DownloadWorker(wg *sync.WaitGroup, taskChannel <-chan models.DownloadTask,
defer wg.Done()
for task := range taskChannel {
if task.Artifact == "artifact" {
DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir)
DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir, task.OutputFilename)
resultChannel <- task
} else if task.Artifact == "database" {
fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir)
DownloadDatabase(task.Nwo, task.Language, task.OutputDir)
fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir, task.OutputFilename)
DownloadDatabase(task.Nwo, task.Language, task.OutputDir, task.OutputFilename)
resultChannel <- task
}
}
}
func downloadArtifact(url string, outputDir string, nwo string) error {
func downloadArtifact(url string, outputDir string, nwo string, outputFilename string) error {
client, err := gh.HTTPClient(nil)
if err != nil {
return err
@@ -562,14 +578,16 @@ func downloadArtifact(url string, outputDir string, nwo string) error {
if err != nil {
log.Fatal(err)
}
if outputFilename == "" {
extension := ""
resultPath := ""
if zf.Name == "results.bqrs" {
extension = "bqrs"
} else if zf.Name == "results.sarif" {
extension = "sarif"
}
resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension))
outputFilename = fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)
}
resultPath := filepath.Join(outputDir, outputFilename)
err = os.WriteFile(resultPath, bytes, os.ModePerm)
if err != nil {
return err
@@ -579,23 +597,26 @@ func downloadArtifact(url string, outputDir string, nwo string) error {
return errors.New("No results.sarif file found in artifact")
}
func DownloadResults(controller string, runId int, nwo string, outputDir string) error {
func DownloadResults(controller string, runId int, nwo string, outputDir string, outputFilename string) error {
// download artifact (BQRS or SARIF)
runRepositoryDetails, err := GetRunRepositoryDetails(controller, runId, nwo)
if err != nil {
return errors.New("Failed to get run repository details")
}
// download the results
err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo)
err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo, outputFilename)
if err != nil {
return errors.New("Failed to download artifact")
}
return nil
}
func DownloadDatabase(nwo string, language string, outputDir string) error {
func DownloadDatabase(nwo string, language string, outputDir string, outputFilename string) error {
dnwo := strings.Replace(nwo, "/", "_", -1)
targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language))
if outputFilename == "" {
outputFilename = fmt.Sprintf("%s_%s_db.zip", dnwo, language)
}
targetPath := filepath.Join(outputDir, outputFilename)
opts := api.ClientOptions{
Headers: map[string]string{"Accept": "application/zip"},
}
@@ -616,4 +637,3 @@ func DownloadDatabase(nwo string, language string, outputDir string) error {
err = os.WriteFile(targetPath, bytes, os.ModePerm)
return nil
}