diff --git a/go.mod b/go.mod index dcf4dd6..c89db20 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.19 require github.com/cli/go-gh v1.2.1 +require github.com/aymanbagabas/go-osc52 v1.2.1 // indirect + require ( github.com/cli/safeexec v1.0.0 // indirect github.com/cli/shurcooL-graphql v0.0.2 // indirect @@ -11,14 +13,11 @@ require ( github.com/henvic/httpretty v0.0.6 // indirect github.com/kr/text v0.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect - github.com/mattn/go-runewidth v0.0.13 // indirect - github.com/muesli/termenv v0.12.0 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/muesli/termenv v0.14.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect - github.com/tidwall/gjson v1.14.4 // direct - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.1 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.5.0 // indirect golang.org/x/term v0.5.0 // indirect diff --git a/go.sum b/go.sum index cdf2330..a95410b 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= +github.com/aymanbagabas/go-osc52 v1.2.1 h1:q2sWUyDcozPLcLabEMd+a+7Ea2DitxZVN9hTxab9L4E= +github.com/aymanbagabas/go-osc52 v1.2.1/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4= github.com/cli/go-gh v1.2.1 h1:xFrjejSsgPiwXFP6VYynKWwxLQcNJy3Twbu82ZDlR/o= github.com/cli/go-gh v1.2.1/go.mod h1:Jxk8X+TCO4Ui/GarwY9tByWm/8zp4jJktzVZNlTW5VM= github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI= @@ -17,34 +19,24 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= -github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= -github.com/muesli/termenv v0.12.0 h1:KuQRUE3PgxRFWhq4gHvZtPSLCGDqM5q/cYr1pZ39ytc= -github.com/muesli/termenv v0.12.0/go.mod h1:WCCv32tusQ/EEZ5S8oUIIrC/nIuBcxCVqlN4Xfkv+7A= +github.com/muesli/termenv v0.14.0 h1:8x9NFfOe8lmIWK4pgy3IfVEy47f+ppe3tUqdPZG2Uy0= +github.com/muesli/termenv v0.14.0/go.mod h1:kG/pF1E7fh949Xhe156crRUrHNyK221IuGO7Ez60Uc8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI= -github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= -github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= -github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= golang.org/x/net v0.0.0-20220923203811-8be639271d50/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= diff --git a/main.go b/main.go index f271797..8f34781 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/cli/go-gh" "github.com/cli/go-gh/pkg/api" + "github.com/cli/go-gh/pkg/jsonpretty" "github.com/google/uuid" "gopkg.in/yaml.v3" "io/ioutil" @@ -18,21 +19,32 @@ import ( "os/exec" "path/filepath" "strings" + "sync" "text/template" + "time" ) const ( MAX_MRVA_REPOSITORIES = 1000 + WORKERS = 10 ) var ( - configFilePath = "" - controller = "" - language = "" - runName = "" - listFile = "" + configFilePath string ) +func runCodeQLCommand(codeqlPath string, combined bool, args ...string) ([]byte, error) { + if !strings.Contains(strings.Join(args, " "), "packlist") { + args = append(args, fmt.Sprintf("--additional-packs=%s", codeqlPath)) + } + cmd := exec.Command("codeql", args...) + cmd.Env = os.Environ() + if combined { + return cmd.CombinedOutput() + } else { + return cmd.Output() + } +} func resolveRepositories(listFile string, list string) ([]string, error) { fmt.Printf("Resolving %s repositories from %s\n", list, listFile) jsonFile, err := os.Open(listFile) @@ -50,9 +62,9 @@ func resolveRepositories(listFile string, list string) ([]string, error) { return repoLists[list], nil } -func resolveQueries(querySuite string) []string { +func resolveQueries(codeqlPath string, querySuite string) []string { args := []string{"resolve", "queries", "--format=json", querySuite} - jsonBytes, err := exec.Command("codeql", args...).Output() + jsonBytes, err := runCodeQLCommand(codeqlPath, false, args...) var queries []string err = json.Unmarshal(jsonBytes, &queries) if err != nil { @@ -61,14 +73,14 @@ func resolveQueries(querySuite string) []string { return queries } -func packPacklist(dir string, includeQueries bool) []string { +func packPacklist(codeqlPath string, dir string, includeQueries bool) []string { // since 2.7.1, packlist returns an object with a "paths" property that is a list of packs. args := []string{"pack", "packlist", "--format=json"} if !includeQueries { args = append(args, "--no-include-queries") } args = append(args, dir) - jsonBytes, err := exec.Command("codeql", args...).Output() + jsonBytes, err := runCodeQLCommand(codeqlPath, false, args...) var packlist map[string][]string err = json.Unmarshal(jsonBytes, &packlist) if err != nil { @@ -108,15 +120,6 @@ func copyFile(srcPath string, targetPath string) error { return nil } -// Fixes the qlpack.yml file to be correct in the context of the MRVA request. -// Performs the following fixes: -// - Updates the default suite of the query pack. This is used to ensure -// only the specified query is run. -// - Ensures the query pack name is set to the name expected by the server. -// - Removes any `${workspace}` version references from the qlpack.yml file. Converts them -// to `*` versions. -// @param queryPackDir The directory containing the query pack -// @param packRelativePath The relative path to the query pack from the root of the query pack func fixPackFile(queryPackDir string, packRelativePath string) error { packPath := filepath.Join(queryPackDir, "qlpack.yml") packFile, err := ioutil.ReadFile(packPath) @@ -168,7 +171,7 @@ func fixPackFile(queryPackDir string, packRelativePath string) error { } // Generate a query pack containing the given query file. -func generateQueryPack(queryFile string) (string, error) { +func generateQueryPack(codeqlPath string, queryFile string, language string) (string, error) { fmt.Printf("Generating query pack for %s\n", queryFile) // create a temporary directory to hold the query pack @@ -176,8 +179,7 @@ func generateQueryPack(queryFile string) (string, error) { if err != nil { log.Fatal(err) } - // TODO: uncomment this line when we're done debugging - //defer os.RemoveAll(queryPackDir) + defer os.RemoveAll(queryPackDir) queryFile, err = filepath.Abs(queryFile) if err != nil { @@ -232,7 +234,7 @@ defaultSuite: } else { // don't include all query files in the QLPacks. We only want the queryFile to be copied. fmt.Printf("QLPack exists, stripping all other queries from %s\n", originalPackRoot) - toCopy := packPacklist(originalPackRoot, false) + toCopy := packPacklist(codeqlPath, originalPackRoot, false) // also copy the lock file (either new name or old name) and the query file itself (these are not included in the packlist) lockFileNew := filepath.Join(originalPackRoot, "qlpack.lock.yml") lockFileOld := filepath.Join(originalPackRoot, "codeql-pack.lock.yml") @@ -266,7 +268,7 @@ defaultSuite: // install the pack dependencies fmt.Print("Installing QLPack dependencies\n") args := []string{"pack", "install", queryPackDir} - stdouterr, err := exec.Command("codeql", args...).CombinedOutput() + stdouterr, err := runCodeQLCommand(codeqlPath, true, args...) if err != nil { fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr)) return "", fmt.Errorf("Failed to install query pack: %v", err) @@ -275,7 +277,7 @@ defaultSuite: fmt.Print("Compiling and bundling the QLPack (This may take a while)\n") args = []string{"pack", "bundle", "-o", bundlePath, queryPackDir} args = append(args, precompilationOpts...) - stdouterr, err = exec.Command("codeql", args...).CombinedOutput() + stdouterr, err = runCodeQLCommand(codeqlPath, true, args...) if err != nil { fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr)) return "", fmt.Errorf("Failed to bundle query pack: %v\n", err) @@ -297,7 +299,7 @@ defaultSuite: } // Requests a query to be run against `respositories` on the given `controller`. -func submitRun(repoChunk []string, bundle string) (int, error) { +func submitRun(controller string, language string, repoChunk []string, bundle string) (int, error) { opts := api.ClientOptions{ Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, } @@ -330,7 +332,7 @@ func submitRun(repoChunk []string, bundle string) (int, error) { return id, nil } -func getRunDetails(runId int) (map[string]interface{}, error) { +func getRunDetails(controller string, runId int) (map[string]interface{}, error) { opts := api.ClientOptions{ Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, } @@ -346,7 +348,7 @@ func getRunDetails(runId int) (map[string]interface{}, error) { return response, nil } -func getRunRepositoryDetails(runId int, nwo string) (map[string]interface{}, error) { +func getRunRepositoryDetails(controller string, runId int, nwo string) (map[string]interface{}, error) { opts := api.ClientOptions{ Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, } @@ -362,14 +364,37 @@ func getRunRepositoryDetails(runId int, nwo string) (map[string]interface{}, err return response, nil } -func downloadArtifact(url string, outputDir string, nwo string) (string, error) { +type DownloadTask struct { + runId int + nwo string + controller string + artifact string + outputDir string + language string +} + +func downloadWorker(wg *sync.WaitGroup, taskChannel <-chan DownloadTask, resultChannel chan DownloadTask) { + defer wg.Done() + for task := range taskChannel { + if task.artifact == "artifact" { + downloadResults(task.controller, task.runId, task.nwo, task.outputDir) + resultChannel <- task + } else if task.artifact == "database" { + fmt.Println("Downloading database", task.nwo, task.language, task.outputDir) + downloadDatabase(task.nwo, task.language, task.outputDir) + resultChannel <- task + } + } +} + +func downloadArtifact(url string, outputDir string, nwo string) error { client, err := gh.HTTPClient(nil) if err != nil { - return "", err + return err } resp, err := client.Get(url) if err != nil { - return "", err + return err } defer resp.Body.Close() @@ -406,14 +431,30 @@ func downloadArtifact(url string, outputDir string, nwo string) (string, error) resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)) err = ioutil.WriteFile(resultPath, bytes, os.ModePerm) if err != nil { - return "", err + return err } - return resultPath, nil + return nil } - return "", errors.New("No results.sarif file found in artifact") + return errors.New("No results.sarif file found in artifact") } -func downloadDatabase(nwo string, lang string, targetPath string) error { +func downloadResults(controller string, runId int, nwo string, outputDir string) error { + // download artifact (BQRS or SARIF) + runRepositoryDetails, err := getRunRepositoryDetails(controller, runId, nwo) + if err != nil { + return errors.New("Failed to get run repository details") + } + // download the results + err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo) + if err != nil { + return errors.New("Failed to download artifact") + } + return nil +} + +func downloadDatabase(nwo string, language string, outputDir string) error { + dnwo := strings.Replace(nwo, "/", "_", -1) + targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) opts := api.ClientOptions{ Headers: map[string]string{"Accept": "application/zip"}, } @@ -421,7 +462,7 @@ func downloadDatabase(nwo string, lang string, targetPath string) error { if err != nil { return err } - resp, err := client.Get(fmt.Sprintf("https://api.github.com/repos/%s/code-scanning/codeql/databases/%s", nwo, lang)) + resp, err := client.Get(fmt.Sprintf("https://api.github.com/repos/%s/code-scanning/codeql/databases/%s", nwo, language)) if err != nil { return err } @@ -435,19 +476,29 @@ func downloadDatabase(nwo string, lang string, targetPath string) error { return nil } -func saveInCache(name string, ids []int) error { +func saveInHistory(name string, controller string, runIds []int, language string, listFile string, list string, query string, count int) error { configData, err := getConfig(configFilePath) if err != nil { return err } - cache := configData.Cache - if cache == nil { - cache = map[string][]int{} + if configData.History == nil { + configData.History = make(map[string]HistoryEntry) } - if cache[name] == nil { - cache[name] = ids + // add new history entry if it doesn't already exist + if _, ok := configData.History[name]; ok { + return errors.New("Name already exists in history") } else { - cache[name] = append(cache[name], ids...) + configData.History[name] = HistoryEntry{ + Name: name, + RunIds: runIds, + Timestamp: time.Now(), + Controller: controller, + Language: language, + ListFile: listFile, + List: list, + Query: query, + RepositoryCount: count, + } } // marshal config data to yaml configDataYaml, err := yaml.Marshal(configData) @@ -462,17 +513,17 @@ func saveInCache(name string, ids []int) error { return nil } -func loadFromCache(name string) ([]int, error) { +func loadFromHistory(name string) (string, []int, string, error) { configData, err := getConfig(configFilePath) if err != nil { - return nil, err + return "", nil, "", err } - if configData.Cache != nil { - if configData.Cache[name] != nil { - return configData.Cache[name], nil + if configData.History != nil { + if entry, ok := configData.History[name]; ok { + return entry.Controller, entry.RunIds, entry.Language, nil } } - return []int{}, nil + return "", nil, "", errors.New("No history entry found for " + name) } func getConfig(path string) (Config, error) { @@ -488,15 +539,25 @@ func getConfig(path string) (Config, error) { return configData, nil } +type HistoryEntry struct { + Name string `yaml:"name"` + Timestamp time.Time `yaml:"timestamp"` + RunIds []int `yaml:"runIds"` + Controller string `yaml:"controller"` + ListFile string `yaml:"listFile"` + List string `yaml:"list"` + Language string `yaml:"language"` + Query string `yaml:"query"` + RepositoryCount int `yaml:"repositoryCount"` +} type Config struct { - Controller string `yaml:"controller"` - ListFile string `yaml:"listFile"` - Cache map[string][]int `yaml:"cache"` + Controller string `yaml:"controller"` + ListFile string `yaml:"listFile"` + CodeQLPath string `yaml:"codeqlPath"` + History map[string]HistoryEntry `yaml:"history"` } func main() { - - // read config file configPath := os.Getenv("XDG_CONFIG_HOME") if configPath == "" { homePath := os.Getenv("HOME") @@ -505,10 +566,10 @@ func main() { } configPath = filepath.Join(homePath, ".config") } - configFilePath = filepath.Join(configPath, "mrva", "config.yml") + configFilePath = filepath.Join(configPath, "gh-mrva", "config.yml") if _, err := os.Stat(configFilePath); os.IsNotExist(err) { // create config file if it doesn't exist - // since we will use it for the name/ids cache + // since we will use it for storing the history err := os.MkdirAll(filepath.Dir(configFilePath), os.ModePerm) if err != nil { log.Println("Failed to create config file directory") @@ -524,12 +585,6 @@ func main() { if err != nil { log.Fatal(err) } - if configData.Controller != "" { - controller = configData.Controller - } - if configData.ListFile != "" { - listFile = configData.ListFile - } helpFlag := flag.String("help", "", "This help documentation.") @@ -538,10 +593,13 @@ func main() { gh mrva - submit and download CodeQL queries from MRVA Usage: - gh mrva submit --controller --lang [--name ] --list-file --list --query + gh mrva submit [--codeql-dist ] [--controller ] --lang --name [--list-file ] --list [--query | --query-suite ] - gh mrva download --run --lang --controller --output-dir [--name ] [--download-dbs] + gh mrva download --name --output-dir [--download-dbs] + gh mrva status --name [--json] + + gh mrva list [--json] `) } @@ -561,31 +619,30 @@ Usage: switch cmd { case "submit": - submit(args) + submit(configData, args) case "download": download(args) + case "status": + status(args) + case "list": + list(args) default: log.Fatalf("Unrecognized command %q. "+ "Command must be one of: submit, download", cmd) } } -func submit(args []string) { - flag := flag.NewFlagSet("mrva submit", flag.ExitOnError) - queryFileFlag := flag.String("query", "", "Path to query file") - querySuiteFileFlag := flag.String("query-suite", "", "Path to query suite file") - controllerFlag := flag.String("controller", "", "MRVA controller repository (overrides config file)") - listFileFlag := flag.String("list-file", "", "Path to repo list file (overrides config file)") - listFlag := flag.String("list", "", "Name of repo list") - langFlag := flag.String("lang", "", "DB language") - nameFlag := flag.String("name", "", "Name of run (optional)") +func status(args []string) { + flag := flag.NewFlagSet("mrva status", flag.ExitOnError) + nameFlag := flag.String("name", "", "Name of run") + jsonFlag := flag.Bool("json", false, "Output in JSON format (default: false)") flag.Usage = func() { fmt.Fprintf(os.Stderr, ` gh mrva - submit and download CodeQL queries from MRVA Usage: - gh mrva submit --controller --lang [--name ] --list-file --list [--query | --query-suite ] + gh mrva status --name [--json] `) fmt.Fprintf(os.Stderr, "Flags:\n") @@ -595,41 +652,223 @@ Usage: flag.Parse(args) + var ( + runName = *nameFlag + jsonOutput = *jsonFlag + ) + + if runName == "" { + flag.Usage() + os.Exit(1) + } + + controller, runIds, _, err := loadFromHistory(runName) + if err != nil { + log.Fatal(err) + } + if len(runIds) == 0 { + log.Fatal("No runs found for run name", runName) + } + + type Run struct { + Id int + Status string + FailureReason string + } + + type RepoWithFindings struct { + Nwo string + Count int + } + type Results struct { + Runs []Run + ResositoriesWithFindings []RepoWithFindings + TotalFindingsCount int + TotalSuccessfulScans int + TotalFailedScans int + TotalRepositoriesWithFindings int + TotalSkippedRepositories int + TotalSkippedAccessMismatchRepositories int + TotalSkippedNotFoundRepositories int + TotalSkippedNoDatabaseRepositories int + TotalSkippedOverLimitRepositories int + } + + var results Results + + for _, runId := range runIds { + if err != nil { + log.Fatal(err) + } + runDetails, err := getRunDetails(controller, runId) + if err != nil { + log.Fatal(err) + } + + status := runDetails["status"].(string) + var failure_reason string + if status == "failed" { + failure_reason = runDetails["failure_reason"].(string) + } else { + failure_reason = "" + } + + results.Runs = append(results.Runs, Run{ + Id: runId, + Status: status, + FailureReason: failure_reason, + }) + + for _, repo := range runDetails["scanned_repositories"].([]interface{}) { + if repo.(map[string]interface{})["analysis_status"].(string) == "succeeded" { + results.TotalSuccessfulScans += 1 + if repo.(map[string]interface{})["result_count"].(float64) > 0 { + results.TotalRepositoriesWithFindings += 1 + results.TotalFindingsCount += int(repo.(map[string]interface{})["result_count"].(float64)) + repoInfo := repo.(map[string]interface{})["repository"].(map[string]interface{}) + results.ResositoriesWithFindings = append(results.ResositoriesWithFindings, RepoWithFindings{ + Nwo: repoInfo["full_name"].(string), + Count: int(repo.(map[string]interface{})["result_count"].(float64)), + }) + } + } else if repo.(map[string]interface{})["analysis_status"].(string) == "failed" { + results.TotalFailedScans += 1 + } + } + + skipped_repositories := runDetails["skipped_repositories"].(map[string]interface{}) + access_mismatch_repos := skipped_repositories["access_mismatch_repos"].(map[string]interface{}) + not_found_repos := skipped_repositories["not_found_repos"].(map[string]interface{}) + no_codeql_db_repos := skipped_repositories["no_codeql_db_repos"].(map[string]interface{}) + over_limit_repos := skipped_repositories["over_limit_repos"].(map[string]interface{}) + total_skipped_repos := access_mismatch_repos["repository_count"].(float64) + not_found_repos["repository_count"].(float64) + no_codeql_db_repos["repository_count"].(float64) + over_limit_repos["repository_count"].(float64) + + results.TotalSkippedAccessMismatchRepositories += int(access_mismatch_repos["repository_count"].(float64)) + results.TotalSkippedNotFoundRepositories += int(not_found_repos["repository_count"].(float64)) + results.TotalSkippedNoDatabaseRepositories += int(no_codeql_db_repos["repository_count"].(float64)) + results.TotalSkippedOverLimitRepositories += int(over_limit_repos["repository_count"].(float64)) + results.TotalSkippedRepositories += int(total_skipped_repos) + } + + if jsonOutput { + data, err := json.MarshalIndent(results, "", " ") + if err != nil { + log.Fatal(err) + } + w := &bytes.Buffer{} + jsonpretty.Format(w, bytes.NewReader(data), " ", true) + fmt.Println(w.String()) + } else { + // Print results in a nice way + fmt.Println("Run name:", runName) + fmt.Println("Total runs:", len(results.Runs)) + fmt.Println("Total successful scans:", results.TotalSuccessfulScans) + fmt.Println("Total failed scans:", results.TotalFailedScans) + fmt.Println("Total skipped repositories:", results.TotalSkippedRepositories) + fmt.Println("Total skipped repositories due to access mismatch:", results.TotalSkippedAccessMismatchRepositories) + fmt.Println("Total skipped repositories due to not found:", results.TotalSkippedNotFoundRepositories) + fmt.Println("Total skipped repositories due to no database:", results.TotalSkippedNoDatabaseRepositories) + fmt.Println("Total skipped repositories due to over limit:", results.TotalSkippedOverLimitRepositories) + fmt.Println("Total repositories with findings:", results.TotalRepositoriesWithFindings) + fmt.Println("Total findings:", results.TotalFindingsCount) + fmt.Println("Repositories with findings:") + for _, repo := range results.ResositoriesWithFindings { + fmt.Println(" ", repo.Nwo, ":", repo.Count) + } + } +} + +func submit(configData Config, args []string) { + + flag := flag.NewFlagSet("mrva submit", flag.ExitOnError) + queryFileFlag := flag.String("query", "", "Path to query file") + querySuiteFileFlag := flag.String("query-suite", "", "Path to query suite file") + controllerFlag := flag.String("controller", "", "MRVA controller repository (overrides config file)") + codeqlPathFlag := flag.String("codeql-path", "", "Path to CodeQL distribution (overrides config file)") + listFileFlag := flag.String("list-file", "", "Path to repo list file (overrides config file)") + listFlag := flag.String("list", "", "Name of repo list") + langFlag := flag.String("lang", "", "DB language") + nameFlag := flag.String("name", "", "Name of run") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, ` +gh mrva - submit and download CodeQL queries from MRVA + +Usage: + gh mrva submit [--codeql-dist ] [--controller ] --lang --name [--list-file ] --list [--query | --query-suite ] + +`) + fmt.Fprintf(os.Stderr, "Flags:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\n") + } + + flag.Parse(args) + + var ( + controller string + codeqlPath string + listFile string + list string + language string + runName string + queryFile string + querySuiteFile string + ) + if *controllerFlag != "" { + controller = *controllerFlag + } else if configData.Controller != "" { + controller = configData.Controller + } + if *listFileFlag != "" { + listFile = *listFileFlag + } else if configData.ListFile != "" { + listFile = configData.ListFile + } + if *codeqlPathFlag != "" { + codeqlPath = *codeqlPathFlag + } else if configData.CodeQLPath != "" { + codeqlPath = configData.CodeQLPath + } if *langFlag != "" { language = *langFlag } if *nameFlag != "" { runName = *nameFlag } - if *controllerFlag != "" { - controller = *controllerFlag + if *listFlag != "" { + list = *listFlag } - if *listFileFlag != "" { - listFile = *listFileFlag + if *queryFileFlag != "" { + queryFile = *queryFileFlag + } + if *querySuiteFileFlag != "" { + querySuiteFile = *querySuiteFileFlag } - if controller == "" || language == "" || listFile == "" || *listFlag == "" || (*queryFileFlag == "" && *querySuiteFileFlag == "") { + if runName == "" || codeqlPath == "" || controller == "" || language == "" || listFile == "" || list == "" || (queryFile == "" && querySuiteFile == "") { flag.Usage() os.Exit(1) } // read list of target repositories - repositories, err := resolveRepositories(listFile, *listFlag) + repositories, err := resolveRepositories(listFile, list) if err != nil { log.Fatal(err) } + // if a query suite is specified, resolve the queries queries := []string{} if *queryFileFlag != "" { queries = append(queries, *queryFileFlag) } else if *querySuiteFileFlag != "" { - queries = resolveQueries(*querySuiteFileFlag) + queries = resolveQueries(codeqlPath, querySuiteFile) } fmt.Printf("Requesting running %d queries for %d repositories\n", len(queries), len(repositories)) - + var runIds []int for _, query := range queries { - encodedBundle, err := generateQueryPack(query) + encodedBundle, err := generateQueryPack(codeqlPath, query, language) if err != nil { log.Fatal(err) } @@ -643,39 +882,34 @@ Usage: } chunks = append(chunks, repositories[i:end]) } - var ids []int for _, chunk := range chunks { - id, err := submitRun(chunk, encodedBundle) - if err != nil { - log.Fatal(err) - } - ids = append(ids, id) - } - fmt.Printf("Submitted run %v\n", ids) - if runName != "" { - err = saveInCache(runName, ids) + id, err := submitRun(controller, language, chunk, encodedBundle) if err != nil { log.Fatal(err) } + runIds = append(runIds, id) } } + fmt.Printf("Submitted runs: %v\n", runIds) + if querySuiteFile != "" { + err = saveInHistory(runName, controller, runIds, language, listFile, list, querySuiteFile, len(repositories)) + } else if queryFile != "" { + err = saveInHistory(runName, controller, runIds, language, listFile, list, queryFile, len(repositories)) + } + if err != nil { + log.Fatal(err) + } } -func download(args []string) { - flag := flag.NewFlagSet("mrva submit", flag.ExitOnError) - runFlag := flag.Int("run", 0, "MRVA run ID") - outputDirFlag := flag.String("output-dir", "", "Output directory") - downloadDBsFlag := flag.Bool("download-dbs", false, "Download databases (optional)") - controllerFlag := flag.String("controller", "", "MRVA controller repository (overrides config file)") - langFlag := flag.String("lang", "", "DB language") - nameFlag := flag.String("name", "", "Name of run (optional)") - +func list(args []string) { + flag := flag.NewFlagSet("mrva list", flag.ExitOnError) + jsonFlag := flag.Bool("json", false, "Output in JSON format (default: false)") flag.Usage = func() { fmt.Fprintf(os.Stderr, ` gh mrva - submit and download CodeQL queries from MRVA Usage: - gh mrva download --run --lang --controller --output-dir [--name ] [--download-dbs] + gh mrva list [--json] `) fmt.Fprintf(os.Stderr, "Flags:\n") @@ -685,47 +919,91 @@ Usage: flag.Parse(args) - if *langFlag != "" { - language = *langFlag + var jsonOutput = *jsonFlag + + configData, err := getConfig(configFilePath) + if err != nil { + log.Fatal(err) } - if *nameFlag != "" { - runName = *nameFlag + if configData.History != nil { + if jsonOutput { + for _, entry := range configData.History { + data, err := json.MarshalIndent(entry, "", " ") + if err != nil { + log.Fatal(err) + } + + w := &bytes.Buffer{} + jsonpretty.Format(w, bytes.NewReader(data), " ", true) + fmt.Println(w.String()) + } + } else { + for name, entry := range configData.History { + fmt.Printf("%s (%v)\n", name, entry.Timestamp) + fmt.Printf(" Controller: %s\n", entry.Controller) + fmt.Printf(" Language: %s\n", entry.Language) + fmt.Printf(" List file: %s\n", entry.ListFile) + fmt.Printf(" List: %s\n", entry.List) + fmt.Printf(" Repository count: %d\n", entry.RepositoryCount) + fmt.Printf(" Query(s) : %s\n", entry.Query) + } + } } - if *controllerFlag != "" { - controller = *controllerFlag +} + +func download(args []string) { + flag := flag.NewFlagSet("mrva download", flag.ExitOnError) + nameFlag := flag.String("name", "", "Name of run") + outputDirFlag := flag.String("output-dir", "", "Output directory") + downloadDBsFlag := flag.Bool("download-dbs", false, "Download databases (optional)") + nwoFlag := flag.String("nwo", "", "Repository to download artifacts for (optional)") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, ` +gh mrva - submit and download CodeQL queries from MRVA + +Usage: + gh mrva download --name --output-dir [--download-dbs] [--nwo ] + +`) + fmt.Fprintf(os.Stderr, "Flags:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\n") } - if controller == "" || language == "" || (*runFlag == 0 && runName == "") || *outputDirFlag == "" { + flag.Parse(args) + + var ( + runName = *nameFlag + outputDir = *outputDirFlag + downloadDBs = *downloadDBsFlag + targetNwo = *nwoFlag + ) + + if runName == "" || outputDir == "" { flag.Usage() os.Exit(1) } - // if outputDirFlag does not exist, create it - if _, err := os.Stat(*outputDirFlag); os.IsNotExist(err) { - err := os.MkdirAll(*outputDirFlag, os.ModePerm) + // if outputDir does not exist, create it + if _, err := os.Stat(outputDir); os.IsNotExist(err) { + err := os.MkdirAll(outputDir, os.ModePerm) if err != nil { log.Fatal(err) } } - runIds := []int{} - if *runFlag > 0 { - runIds = []int{*runFlag} - } else if runName != "" { - ids, err := loadFromCache(runName) - if err != nil { - log.Fatal(err) - } - if len(ids) > 0 { - runIds = ids - } + controller, runIds, language, err := loadFromHistory(runName) + if err != nil { + log.Fatal(err) + } else if len(runIds) == 0 { + log.Fatal("No runs found for name " + runName) } + var downloadTasks []DownloadTask + for _, runId := range runIds { - fmt.Printf("Downloading MRVA results for %s (%d)\n", controller, runId) - // check if the run is complete - runDetails, err := getRunDetails(runId) - fmt.Printf("Status: %v\n", runDetails["status"]) + runDetails, err := getRunDetails(controller, runId) if err != nil { log.Fatal(err) } @@ -738,41 +1016,81 @@ Usage: result_count := repo["result_count"] repoInfo := repo["repository"].(map[string]interface{}) nwo := repoInfo["full_name"].(string) + // if targetNwo is set, only download artifacts for that repository + if targetNwo != "" && targetNwo != nwo { + continue + } if result_count != nil && result_count.(float64) > 0 { - fmt.Printf("Repo %s has %d results\n", nwo, int(result_count.(float64))) - sarifPath := filepath.Join(*outputDirFlag, fmt.Sprintf("%s.sarif", strings.Replace(nwo, "/", "_", -1))) - bqrsPath := filepath.Join(*outputDirFlag, fmt.Sprintf("%s.bqrs", strings.Replace(nwo, "/", "_", -1))) - + // check if the SARIF or BQRS file already exists + dnwo := strings.Replace(nwo, "/", "_", -1) + sarifPath := filepath.Join(outputDir, fmt.Sprintf("%s.sarif", dnwo)) + bqrsPath := filepath.Join(outputDir, fmt.Sprintf("%s.bqrs", dnwo)) + targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) _, bqrsErr := os.Stat(bqrsPath) _, sarifErr := os.Stat(sarifPath) if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) { - - // download artifact (BQRS or SARIF) - fmt.Printf("Downloading results for %s\n", repoInfo["full_name"]) - runRepositoryDetails, err := getRunRepositoryDetails(runId, nwo) - if err != nil { - log.Fatal(err) - } - // download the results - artifactPath, err := downloadArtifact(runRepositoryDetails["artifact_url"].(string), *outputDirFlag, nwo) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Artifact path: %s\n", artifactPath) + downloadTasks = append(downloadTasks, DownloadTask{ + runId: runId, + nwo: nwo, + controller: controller, + artifact: "artifact", + language: language, + outputDir: outputDir, + }) } - if *downloadDBsFlag { - // download database - targetPath := filepath.Join(*outputDirFlag, fmt.Sprintf("%s_%s_db.zip", strings.Replace(nwo, "/", "_", -1), language)) + if downloadDBs { + // check if the database already exists if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) { - fmt.Printf("Downloading database for %s\n", nwo) - err = downloadDatabase(nwo, language, targetPath) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Database path: %s\n", targetPath) + downloadTasks = append(downloadTasks, DownloadTask{ + runId: runId, + nwo: nwo, + controller: controller, + artifact: "database", + language: language, + outputDir: outputDir, + }) } } } } } + + wg := new(sync.WaitGroup) + + taskChannel := make(chan DownloadTask) + resultChannel := make(chan DownloadTask, len(downloadTasks)) + + // Start the workers + for i := 0; i < WORKERS; i++ { + wg.Add(1) + go downloadWorker(wg, taskChannel, resultChannel) + } + + // Send jobs to worker + for _, downloadTask := range downloadTasks { + taskChannel <- downloadTask + } + close(taskChannel) + + count := 0 + progressDone := make(chan bool) + + go func() { + for value := range resultChannel { + count++ + fmt.Printf("Downloaded %s for %s (%d/%d)\n", value.artifact, value.nwo, count, len(downloadTasks)) + } + fmt.Println(count, " artifacts downloaded") + progressDone <- true + }() + + // wait for all workers to finish + wg.Wait() + + // close the result channel + close(resultChannel) + + // drain the progress channel + <-progressDone + }