Files
gh-mrva/cmd/download.go
2023-04-21 13:43:20 +02:00

151 lines
4.2 KiB
Go

/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
import (
"sync"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"log"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/models"
"github.com/GitHubSecurityLab/gh-mrva/config"
"github.com/spf13/cobra"
)
var downloadCmd = &cobra.Command{
Use: "download",
Short: "Downloads the artifacts associated to a given session.",
Long: `Downloads the artifacts associated to a given session.`,
Run: func(cmd *cobra.Command, args []string) {
downloadArtifacts()
},
}
func init() {
rootCmd.AddCommand(downloadCmd)
downloadCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be downloaded")
downloadCmd.Flags().StringVarP(&outputDirFlag, "output-dir", "o", "", "Output directory")
downloadCmd.Flags().BoolVarP(&downloadDBsFlag, "download-dbs", "d", false, "Download databases (optional)")
downloadCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "Repository to download artifacts for (optional)")
statusCmd.MarkFlagRequired("session")
statusCmd.MarkFlagRequired("output-dir")
}
func downloadArtifacts() {
// if outputDirFlag does not exist, create it
if _, err := os.Stat(outputDirFlag); os.IsNotExist(err) {
err := os.MkdirAll(outputDirFlag, 0755)
if err != nil {
log.Fatal(err)
}
}
controller, runs, language, err := utils.LoadSession(sessionNameFlag)
if err != nil {
fmt.Println(err)
} else if len(runs) == 0 {
fmt.Println("No runs found for sessions" + sessionNameFlag)
}
var downloadTasks []models.DownloadTask
for _, run := range runs {
runDetails, err := utils.GetRunDetails(controller, run.Id)
if err != nil {
log.Fatal(err)
}
if runDetails["status"] == "in_progress" {
log.Printf("Run %d is not complete yet. Please try again later.", run.Id)
return
}
for _, r := range runDetails["scanned_repositories"].([]interface{}) {
repo := r.(map[string]interface{})
result_count := repo["result_count"]
repoInfo := repo["repository"].(map[string]interface{})
nwo := repoInfo["full_name"].(string)
// if nwoFlag is set, only download artifacts for that repository
if nwoFlag != "" && nwoFlag != nwo {
continue
}
if result_count != nil && result_count.(float64) > 0 {
// check if the SARIF or BQRS file already exists
dnwo := strings.Replace(nwo, "/", "_", -1)
sarifPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.sarif", dnwo))
bqrsPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.bqrs", dnwo))
targetPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s_%s_db.zip", dnwo, language))
_, bqrsErr := os.Stat(bqrsPath)
_, sarifErr := os.Stat(sarifPath)
if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) {
downloadTasks = append(downloadTasks, models.DownloadTask{
RunId: run.Id,
Nwo: nwo,
Controller: controller,
Artifact: "artifact",
Language: language,
OutputDir: outputDirFlag,
})
}
if downloadDBsFlag {
// check if the database already exists
if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) {
downloadTasks = append(downloadTasks, models.DownloadTask{
RunId: run.Id,
Nwo: nwo,
Controller: controller,
Artifact: "database",
Language: language,
OutputDir: outputDirFlag,
})
}
}
}
}
}
wg := new(sync.WaitGroup)
taskChannel := make(chan models.DownloadTask)
resultChannel := make(chan models.DownloadTask, len(downloadTasks))
// Start the workers
for i := 0; i < config.WORKERS; i++ {
wg.Add(1)
go utils.DownloadWorker(wg, taskChannel, resultChannel)
}
// Send jobs to worker
for _, downloadTask := range downloadTasks {
taskChannel <- downloadTask
}
close(taskChannel)
count := 0
progressDone := make(chan bool)
go func() {
for value := range resultChannel {
count++
fmt.Printf("Downloaded %s for %s (%d/%d)\n", value.Artifact, value.Nwo, count, len(downloadTasks))
}
fmt.Println(fmt.Sprintf("%d artifacts downloaded", count))
progressDone <- true
}()
// wait for all workers to finish
wg.Wait()
// close the result channel
close(resultChannel)
// drain the progress channel
<-progressDone
}