initial refactor to use Cobra

This commit is contained in:
Alvaro Muñoz
2023-04-21 13:43:20 +02:00
parent f589c06ce7
commit d60ef906fe
14 changed files with 1320 additions and 1179 deletions

12
LICENSE
View File

@@ -1,6 +1,6 @@
MIT License
The MIT License (MIT)
Copyright (c) 2023 Alvaro Muñoz
Copyright © 2023 Alvaro Munoz pwntester@github.com
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -9,13 +9,13 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

57
cmd/delete.go Normal file
View File

@@ -0,0 +1,57 @@
/*
Copyright © 2023 sessionNameFlag HERE <EMAIL ADDRESS>
*/
package cmd
import (
"errors"
"fmt"
"gopkg.in/yaml.v3"
"io/ioutil"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/spf13/cobra"
)
var deleteCmd = &cobra.Command{
Use: "delete",
Short: "Delete a saved session.",
Long: `Delete a saved session.`,
Run: func(cmd *cobra.Command, args []string) {
deleteSession()
},
}
func init() {
rootCmd.AddCommand(deleteCmd)
deleteCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name be deleted")
deleteCmd.MarkFlagRequired("session")
}
func deleteSession() error {
sessions, err := utils.GetSessions()
if err != nil {
return err
}
if sessions == nil {
return errors.New("No sessions found")
}
// delete session if it exists
if _, ok := sessions[sessionNameFlag]; ok {
delete(sessions, sessionNameFlag)
// marshal sessions to yaml
sessionsYaml, err := yaml.Marshal(sessions)
if err != nil {
return err
}
// write sessions to file
err = ioutil.WriteFile(utils.GetSessionsFilePath(), sessionsYaml, 0755)
if err != nil {
return err
}
return nil
}
return errors.New(fmt.Sprintf("Session '%s' does not exist", sessionNameFlag))
}

150
cmd/download.go Normal file
View File

@@ -0,0 +1,150 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
import (
"sync"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"log"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/models"
"github.com/GitHubSecurityLab/gh-mrva/config"
"github.com/spf13/cobra"
)
var downloadCmd = &cobra.Command{
Use: "download",
Short: "Downloads the artifacts associated to a given session.",
Long: `Downloads the artifacts associated to a given session.`,
Run: func(cmd *cobra.Command, args []string) {
downloadArtifacts()
},
}
func init() {
rootCmd.AddCommand(downloadCmd)
downloadCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be downloaded")
downloadCmd.Flags().StringVarP(&outputDirFlag, "output-dir", "o", "", "Output directory")
downloadCmd.Flags().BoolVarP(&downloadDBsFlag, "download-dbs", "d", false, "Download databases (optional)")
downloadCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "Repository to download artifacts for (optional)")
statusCmd.MarkFlagRequired("session")
statusCmd.MarkFlagRequired("output-dir")
}
func downloadArtifacts() {
// if outputDirFlag does not exist, create it
if _, err := os.Stat(outputDirFlag); os.IsNotExist(err) {
err := os.MkdirAll(outputDirFlag, 0755)
if err != nil {
log.Fatal(err)
}
}
controller, runs, language, err := utils.LoadSession(sessionNameFlag)
if err != nil {
fmt.Println(err)
} else if len(runs) == 0 {
fmt.Println("No runs found for sessions" + sessionNameFlag)
}
var downloadTasks []models.DownloadTask
for _, run := range runs {
runDetails, err := utils.GetRunDetails(controller, run.Id)
if err != nil {
log.Fatal(err)
}
if runDetails["status"] == "in_progress" {
log.Printf("Run %d is not complete yet. Please try again later.", run.Id)
return
}
for _, r := range runDetails["scanned_repositories"].([]interface{}) {
repo := r.(map[string]interface{})
result_count := repo["result_count"]
repoInfo := repo["repository"].(map[string]interface{})
nwo := repoInfo["full_name"].(string)
// if nwoFlag is set, only download artifacts for that repository
if nwoFlag != "" && nwoFlag != nwo {
continue
}
if result_count != nil && result_count.(float64) > 0 {
// check if the SARIF or BQRS file already exists
dnwo := strings.Replace(nwo, "/", "_", -1)
sarifPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.sarif", dnwo))
bqrsPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.bqrs", dnwo))
targetPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s_%s_db.zip", dnwo, language))
_, bqrsErr := os.Stat(bqrsPath)
_, sarifErr := os.Stat(sarifPath)
if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) {
downloadTasks = append(downloadTasks, models.DownloadTask{
RunId: run.Id,
Nwo: nwo,
Controller: controller,
Artifact: "artifact",
Language: language,
OutputDir: outputDirFlag,
})
}
if downloadDBsFlag {
// check if the database already exists
if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) {
downloadTasks = append(downloadTasks, models.DownloadTask{
RunId: run.Id,
Nwo: nwo,
Controller: controller,
Artifact: "database",
Language: language,
OutputDir: outputDirFlag,
})
}
}
}
}
}
wg := new(sync.WaitGroup)
taskChannel := make(chan models.DownloadTask)
resultChannel := make(chan models.DownloadTask, len(downloadTasks))
// Start the workers
for i := 0; i < config.WORKERS; i++ {
wg.Add(1)
go utils.DownloadWorker(wg, taskChannel, resultChannel)
}
// Send jobs to worker
for _, downloadTask := range downloadTasks {
taskChannel <- downloadTask
}
close(taskChannel)
count := 0
progressDone := make(chan bool)
go func() {
for value := range resultChannel {
count++
fmt.Printf("Downloaded %s for %s (%d/%d)\n", value.Artifact, value.Nwo, count, len(downloadTasks))
}
fmt.Println(fmt.Sprintf("%d artifacts downloaded", count))
progressDone <- true
}()
// wait for all workers to finish
wg.Wait()
// close the result channel
close(resultChannel)
// drain the progress channel
<-progressDone
}

16
cmd/flags.go Normal file
View File

@@ -0,0 +1,16 @@
package cmd
var (
sessionNameFlag string
outputDirFlag string
downloadDBsFlag bool
nwoFlag string
jsonFlag bool
languageFlag string
listFileFlag string
listFlag string
codeqlPathFlag string
controllerFlag string
queryFileFlag string
querySuiteFileFlag string
)

66
cmd/list.go Normal file
View File

@@ -0,0 +1,66 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
import (
"log"
"fmt"
"encoding/json"
"github.com/spf13/cobra"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/models"
)
var listCmd = &cobra.Command{
Use: "list",
Short: "List saved sessions.",
Long: `List saved sessions.`,
Run: func(cmd *cobra.Command, args []string) {
listSessions()
},
}
func init() {
rootCmd.AddCommand(listCmd)
listCmd.Flags().BoolVarP(&jsonFlag, "json", "j", false, "Output in JSON format (default: false)")
}
func listSessions() {
sessions, err := utils.GetSessions()
if err != nil {
log.Fatal(err)
}
if sessions != nil {
if jsonFlag {
sessions_list := make([]models.Session, 0, len(sessions))
for _, session := range sessions {
sessions_list = append(sessions_list, session)
}
data, err := json.MarshalIndent(sessions_list, "", " ")
if err != nil {
log.Fatal(err)
}
fmt.Println(string(data))
// w := &bytes.Buffer{}
// jsonpretty.Format(w, bytes.NewReader(data), " ", true)
// fmt.Println(w.String())
} else {
for name, entry := range sessions {
fmt.Printf("%s (%v)\n", name, entry.Timestamp)
fmt.Printf(" Controller: %s\n", entry.Controller)
fmt.Printf(" Language: %s\n", entry.Language)
fmt.Printf(" List file: %s\n", entry.ListFile)
fmt.Printf(" List: %s\n", entry.List)
fmt.Printf(" Repository count: %d\n", entry.RepositoryCount)
fmt.Println(" Runs:")
for _, run := range entry.Runs {
fmt.Printf(" ID: %d\n", run.Id)
fmt.Printf(" Query: %s\n", run.Query)
}
}
}
}
}

72
cmd/root.go Normal file
View File

@@ -0,0 +1,72 @@
/*
Copyright © 2023 Alvaro Munoz pwntester@github.com
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
package cmd
import (
"os"
"log"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"path/filepath"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "gh-mrva",
Short: "Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)",
Long: `Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)`,
}
func Execute() {
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}
func init() {
configPath := os.Getenv("XDG_CONFIG_HOME")
if configPath == "" {
homePath := os.Getenv("HOME")
if homePath == "" {
log.Fatal("HOME environment variable not set")
}
configPath = filepath.Join(homePath, ".config")
}
configFilePath := filepath.Join(configPath, "gh-mrva", "config.yml")
utils.SetConfigFilePath(configFilePath)
sessionsFilePath := filepath.Join(configPath, "gh-mrva", "sessions.yml")
if _, err := os.Stat(sessionsFilePath); os.IsNotExist(err) {
err := os.MkdirAll(filepath.Dir(sessionsFilePath), os.ModePerm)
if err != nil {
log.Fatal("Failed to create config directory")
}
// create empty file at sessionsFilePath
sessionsFile, err := os.Create(sessionsFilePath)
if err != nil {
log.Fatal("Failed to create sessions file")
}
sessionsFile.Close()
}
utils.SetSessionsFilePath(sessionsFilePath)
}

126
cmd/status.go Normal file
View File

@@ -0,0 +1,126 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
import (
"encoding/json"
"fmt"
"log"
"github.com/spf13/cobra"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/models"
)
var statusCmd = &cobra.Command{
Use: "status",
Short: "Checks the status of a given session.",
Long: `Checks the status of a given session.`,
Run: func(cmd *cobra.Command, args []string) {
sessionStatus()
},
}
func init() {
rootCmd.AddCommand(statusCmd)
statusCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name be deleted")
statusCmd.Flags().BoolVarP(&jsonFlag, "json", "j", false, "Output in JSON format (default: false)")
statusCmd.MarkFlagRequired("session")
}
func sessionStatus() {
controller, runs, _, err := utils.LoadSession(sessionNameFlag)
if err != nil {
log.Fatal(err)
}
if len(runs) == 0 {
log.Fatal("No runs found for run name", sessionNameFlag)
}
var results models.Results
for _, run := range runs {
if err != nil {
log.Fatal(err)
}
runDetails, err := utils.GetRunDetails(controller, run.Id)
if err != nil {
log.Fatal(err)
}
status := runDetails["status"].(string)
var failure_reason string
if status == "failed" {
failure_reason = runDetails["failure_reason"].(string)
} else {
failure_reason = ""
}
results.Runs = append(results.Runs, models.RunStatus{
Id: run.Id,
Query: run.Query,
Status: status,
FailureReason: failure_reason,
})
for _, repo := range runDetails["scanned_repositories"].([]interface{}) {
if repo.(map[string]interface{})["analysis_status"].(string) == "succeeded" {
results.TotalSuccessfulScans += 1
if repo.(map[string]interface{})["result_count"].(float64) > 0 {
results.TotalRepositoriesWithFindings += 1
results.TotalFindingsCount += int(repo.(map[string]interface{})["result_count"].(float64))
repoInfo := repo.(map[string]interface{})["repository"].(map[string]interface{})
results.ResositoriesWithFindings = append(results.ResositoriesWithFindings, models.RepoWithFindings{
Nwo: repoInfo["full_name"].(string),
Count: int(repo.(map[string]interface{})["result_count"].(float64)),
RunId: run.Id,
Stars: int(repoInfo["stargazers_count"].(float64)),
})
}
} else if repo.(map[string]interface{})["analysis_status"].(string) == "failed" {
results.TotalFailedScans += 1
}
}
skipped_repositories := runDetails["skipped_repositories"].(map[string]interface{})
access_mismatch_repos := skipped_repositories["access_mismatch_repos"].(map[string]interface{})
not_found_repos := skipped_repositories["not_found_repos"].(map[string]interface{})
no_codeql_db_repos := skipped_repositories["no_codeql_db_repos"].(map[string]interface{})
over_limit_repos := skipped_repositories["over_limit_repos"].(map[string]interface{})
total_skipped_repos := access_mismatch_repos["repository_count"].(float64) + not_found_repos["repository_count"].(float64) + no_codeql_db_repos["repository_count"].(float64) + over_limit_repos["repository_count"].(float64)
results.TotalSkippedAccessMismatchRepositories += int(access_mismatch_repos["repository_count"].(float64))
results.TotalSkippedNotFoundRepositories += int(not_found_repos["repository_count"].(float64))
results.TotalSkippedNoDatabaseRepositories += int(no_codeql_db_repos["repository_count"].(float64))
results.TotalSkippedOverLimitRepositories += int(over_limit_repos["repository_count"].(float64))
results.TotalSkippedRepositories += int(total_skipped_repos)
}
if jsonFlag {
data, err := json.MarshalIndent(results, "", " ")
if err != nil {
log.Fatal(err)
}
fmt.Println(string(data))
} else {
fmt.Println("Run name:", sessionNameFlag)
fmt.Println("Total runs:", len(results.Runs))
fmt.Println("Total successful scans:", results.TotalSuccessfulScans)
fmt.Println("Total failed scans:", results.TotalFailedScans)
fmt.Println("Total skipped repositories:", results.TotalSkippedRepositories)
fmt.Println("Total skipped repositories due to access mismatch:", results.TotalSkippedAccessMismatchRepositories)
fmt.Println("Total skipped repositories due to not found:", results.TotalSkippedNotFoundRepositories)
fmt.Println("Total skipped repositories due to no database:", results.TotalSkippedNoDatabaseRepositories)
fmt.Println("Total skipped repositories due to over limit:", results.TotalSkippedOverLimitRepositories)
fmt.Println("Total repositories with findings:", results.TotalRepositoriesWithFindings)
fmt.Println("Total findings:", results.TotalFindingsCount)
fmt.Println("Repositories with findings:")
for _, repo := range results.ResositoriesWithFindings {
fmt.Println(" ", repo.Nwo, ":", repo.Count)
}
}
}

157
cmd/submit.go Normal file
View File

@@ -0,0 +1,157 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd
import (
"fmt"
"log"
"os"
"github.com/spf13/cobra"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/config"
"github.com/GitHubSecurityLab/gh-mrva/models"
)
var (
controller string
codeqlPath string
listFile string
listName string
language string
sessionName string
queryFile string
querySuiteFile string
)
var submitCmd = &cobra.Command{
Use: "submit",
Short: "Submit a query or query suite to a MRVA controller.",
Long: `Submit a query or query suite to a MRVA controller.`,
Run: func(cmd *cobra.Command, args []string) {
submitQuery()
},
}
func init() {
rootCmd.AddCommand(submitCmd)
submitCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name")
submitCmd.Flags().StringVarP(&languageFlag, "language", "l", "", "DB language")
submitCmd.Flags().StringVarP(&queryFileFlag, "query", "q", "", "Path to query file")
submitCmd.Flags().StringVarP(&querySuiteFileFlag, "query-suite","x", "", "Path to query suite file")
submitCmd.Flags().StringVarP(&controllerFlag, "controller", "c", "", "MRVA controller repository (overrides config file)")
submitCmd.Flags().StringVarP(&listFileFlag, "list-file", "f", "", "Path to repo list file (overrides config file)")
submitCmd.Flags().StringVarP(&listFlag, "list", "i", "", "Name of repo list")
submitCmd.Flags().StringVarP(&codeqlPathFlag, "codeql-path", "p", "", "Path to CodeQL distribution (overrides config file)")
submitCmd.MarkFlagRequired("session")
submitCmd.MarkFlagRequired("language")
submitCmd.MarkFlagsMutuallyExclusive("query", "query-suite")
}
func submitQuery() {
configData, err := utils.GetConfig()
if err != nil {
log.Fatal(err)
}
if controllerFlag != "" {
controller = controllerFlag
} else if configData.Controller != "" {
controller = configData.Controller
}
if listFileFlag != "" {
listFile = listFileFlag
} else if configData.ListFile != "" {
listFile = configData.ListFile
}
if codeqlPathFlag != "" {
codeqlPath = codeqlPathFlag
} else if configData.CodeQLPath != "" {
codeqlPath = configData.CodeQLPath
}
if languageFlag != "" {
language = languageFlag
}
if sessionNameFlag != "" {
sessionName = sessionNameFlag
}
if listFlag != "" {
listName = listFlag
}
if queryFileFlag != "" {
queryFile = queryFileFlag
}
if querySuiteFileFlag != "" {
querySuiteFile = querySuiteFileFlag
}
if controller == "" {
fmt.Println("Please specify a controller.")
os.Exit(1)
}
if listFile == "" {
fmt.Println("Please specify a list file.")
os.Exit(1)
}
if listName == "" {
fmt.Println("Please specify a list name.")
os.Exit(1)
}
if queryFile == "" && querySuiteFile == "" {
fmt.Println("Please specify a query or query suite.")
os.Exit(1)
}
// read list of target repositories
repositories, err := utils.ResolveRepositories(listFile, listName)
if err != nil {
log.Fatal(err)
}
// if a query suite is specified, resolve the queries
queries := []string{}
if queryFileFlag != "" {
queries = append(queries, queryFileFlag)
} else if querySuiteFileFlag != "" {
queries = utils.ResolveQueries(codeqlPath, querySuiteFile)
}
fmt.Printf("Submitting %d queries for %d repositories\n", len(queries), len(repositories))
var runs []models.Run
for _, query := range queries {
encodedBundle, err := utils.GenerateQueryPack(codeqlPath, query, language)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Generated encoded bundle for %s\n", query)
var chunks [][]string
for i := 0; i < len(repositories); i += config.MAX_MRVA_REPOSITORIES {
end := i + config.MAX_MRVA_REPOSITORIES
if end > len(repositories) {
end = len(repositories)
}
chunks = append(chunks, repositories[i:end])
}
for _, chunk := range chunks {
id, err := utils.SubmitRun(controller, language, chunk, encodedBundle)
if err != nil {
log.Fatal(err)
}
runs = append(runs, models.Run{Id: id, Query: query})
}
}
if querySuiteFile != "" {
err = utils.SaveSession(sessionName, controller, runs, language, listFile, listName, querySuiteFile, len(repositories))
} else if queryFile != "" {
err = utils.SaveSession(sessionName, controller, runs, language, listFile, listName, queryFile, len(repositories))
}
if err != nil {
log.Fatal(err)
}
fmt.Println("Done!")
}

7
config/config.go Normal file
View File

@@ -0,0 +1,7 @@
package config
const (
MAX_MRVA_REPOSITORIES = 1000
WORKERS = 10
)

9
go.mod
View File

@@ -1,10 +1,15 @@
module github.com/pwntester/gh-mrva
module github.com/GitHubSecurityLab/gh-mrva
go 1.19
require github.com/cli/go-gh v1.2.1
require github.com/aymanbagabas/go-osc52 v1.2.1 // indirect
require (
github.com/aymanbagabas/go-osc52 v1.2.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/cobra v1.7.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
)
require (
github.com/cli/safeexec v1.0.0 // indirect

8
go.sum
View File

@@ -7,6 +7,7 @@ github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/cli/shurcooL-graphql v0.0.2 h1:rwP5/qQQ2fM0TzkUTwtt6E2LbIYf6R+39cUXTa04NYk=
github.com/cli/shurcooL-graphql v0.0.2/go.mod h1:tlrLmw/n5Q/+4qSvosT+9/W5zc8ZMjnJeYBxSdb4nWA=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@@ -14,6 +15,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/henvic/httpretty v0.0.6 h1:JdzGzKZBajBfnvlMALXXMVQWxWMF/ofTy8C3/OSUTxs=
github.com/henvic/httpretty v0.0.6/go.mod h1:X38wLjWXHkXT7r2+uK8LjCMne9rsuNaBLJ+5cU2/Pmo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@@ -29,6 +32,11 @@ github.com/muesli/termenv v0.14.0/go.mod h1:kG/pF1E7fh949Xhe156crRUrHNyK221IuGO7
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8=
github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI=

1194
main.go

File diff suppressed because it is too large Load Diff

65
models/models.go Normal file
View File

@@ -0,0 +1,65 @@
package models
import (
"time"
)
type Run struct {
Id int `yaml:"id"`
Query string `yaml:"query"`
}
type Session struct {
Name string `yaml:"name" json:"name"`
Timestamp time.Time `yaml:"timestamp" json:"timestamp"`
Runs []Run `yaml:"runs" json:"runs"`
Controller string `yaml:"controller" json:"controller"`
ListFile string `yaml:"list_file" json:"list_file"`
List string `yaml:"list" json:"list"`
Language string `yaml:"language" json:"language"`
RepositoryCount int `yaml:"repository_count" json:"repository_count"`
}
type Config struct {
Controller string `yaml:"controller"`
ListFile string `yaml:"list_file"`
CodeQLPath string `yaml:"codeql_path"`
}
type DownloadTask struct {
RunId int
Nwo string
Controller string
Artifact string
OutputDir string
Language string
}
type RunStatus struct {
Id int `json:"id"`
Query string `json:"query"`
Status string `json:"status"`
FailureReason string `json:"failure_reason"`
}
type RepoWithFindings struct {
Nwo string `json:"nwo"`
Count int `json:"count"`
RunId int `json:"run_id"`
Stars int `json:"stars"`
}
type Results struct {
Runs []RunStatus `json:"runs"`
ResositoriesWithFindings []RepoWithFindings `json:"repositories_with_findings"`
TotalFindingsCount int `json:"total_findings_count"`
TotalSuccessfulScans int `json:"total_successful_scans"`
TotalFailedScans int `json:"total_failed_scans"`
TotalRepositoriesWithFindings int `json:"total_repositories_with_findings"`
TotalSkippedRepositories int `json:"total_skipped_repositories"`
TotalSkippedAccessMismatchRepositories int `json:"total_skipped_access_mismatch_repositories"`
TotalSkippedNotFoundRepositories int `json:"total_skipped_not_found_repositories"`
TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"`
TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"`
}

560
utils/utils.go Normal file
View File

@@ -0,0 +1,560 @@
package utils
import (
"archive/zip"
"sync"
"text/template"
"github.com/google/uuid"
"encoding/base64"
"path/filepath"
"strings"
"os/exec"
"bytes"
"encoding/json"
"fmt"
"time"
"os"
"gopkg.in/yaml.v3"
"io/ioutil"
"log"
"errors"
"github.com/cli/go-gh"
"github.com/cli/go-gh/pkg/api"
"github.com/GitHubSecurityLab/gh-mrva/models"
)
var (
configFilePath string
sessionsFilePath string
)
func GetSessionsFilePath() string {
return sessionsFilePath
}
func SetSessionsFilePath(path string) {
sessionsFilePath = path
}
func GetConfigFilePath() string {
return configFilePath
}
func SetConfigFilePath(path string) {
configFilePath = path
}
func GetSessions() (map[string]models.Session, error) {
sessionsFile, err := ioutil.ReadFile(sessionsFilePath)
var sessions map[string]models.Session
if err != nil {
return sessions, err
}
err = yaml.Unmarshal(sessionsFile, &sessions)
if err != nil {
log.Fatal(err)
}
return sessions, nil
}
func LoadSession(name string) (string, []models.Run, string, error) {
sessions, err := GetSessions()
if err != nil {
return "", nil, "", err
}
if sessions != nil {
if entry, ok := sessions[name]; ok {
return entry.Controller, entry.Runs, entry.Language, nil
}
}
return "", nil, "", errors.New("No session found for " + name)
}
func GetRunDetails(controller string, runId int) (map[string]interface{}, error) {
opts := api.ClientOptions{
Headers: map[string]string{"Accept": "application/vnd.github.v3+json"},
}
client, err := gh.RESTClient(&opts)
if err != nil {
return nil, err
}
response := make(map[string]interface{})
err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d", controller, runId), &response)
if err != nil {
return nil, err
}
return response, nil
}
func GetRunRepositoryDetails(controller string, runId int, nwo string) (map[string]interface{}, error) {
opts := api.ClientOptions{
Headers: map[string]string{"Accept": "application/vnd.github.v3+json"},
}
client, err := gh.RESTClient(&opts)
if err != nil {
return nil, err
}
response := make(map[string]interface{})
err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d/repos/%s", controller, runId, nwo), &response)
if err != nil {
return nil, err
}
return response, nil
}
func SaveSession(name string, controller string, runs []models.Run, language string, listFile string, list string, query string, count int) error {
sessions, err := GetSessions()
if err != nil {
return err
}
if sessions == nil {
sessions = make(map[string]models.Session)
}
// add new session if it doesn't already exist
if _, ok := sessions[name]; ok {
return errors.New(fmt.Sprintf("Session '%s' already exists", name))
} else {
sessions[name] = models.Session{
Name: name,
Runs: runs,
Timestamp: time.Now(),
Controller: controller,
Language: language,
ListFile: listFile,
List: list,
RepositoryCount: count,
}
}
// marshal sessions to yaml
sessionsYaml, err := yaml.Marshal(sessions)
if err != nil {
return err
}
// write sessions to file
err = ioutil.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm)
if err != nil {
return err
}
return nil
}
func SubmitRun(controller string, language string, repoChunk []string, bundle string) (int, error) {
opts := api.ClientOptions{
Headers: map[string]string{"Accept": "application/vnd.github.v3+json"},
}
client, err := gh.RESTClient(&opts)
if err != nil {
return -1, err
}
body := struct {
Repositories []string `json:"repositories"`
Language string `json:"language"`
Pack string `json:"query_pack"`
Ref string `json:"action_repo_ref"`
}{
Repositories: repoChunk,
Language: language,
Pack: bundle,
Ref: "main",
}
var buf bytes.Buffer
err = json.NewEncoder(&buf).Encode(body)
if err != nil {
return -1, err
}
response := make(map[string]interface{})
err = client.Post(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses", controller), &buf, &response)
if err != nil {
return -1, err
}
id := int(response["id"].(float64))
return id, nil
}
func GetConfig() (models.Config, error) {
configFile, err := ioutil.ReadFile(configFilePath)
var configData models.Config
if err != nil {
return configData, err
}
err = yaml.Unmarshal(configFile, &configData)
if err != nil {
log.Fatal(err)
}
return configData, nil
}
func ResolveRepositories(listFile string, list string) ([]string, error) {
fmt.Printf("Resolving %s repositories from %s\n", list, listFile)
jsonFile, err := os.Open(listFile)
if err != nil {
return nil, err
}
defer jsonFile.Close()
byteValue, _ := ioutil.ReadAll(jsonFile)
var repoLists map[string][]string
err = json.Unmarshal(byteValue, &repoLists)
if err != nil {
return nil, err
}
return repoLists[list], nil
}
func ResolveQueries(codeqlPath string, querySuite string) []string {
args := []string{"resolve", "queries", "--format=json", querySuite}
jsonBytes, err := RunCodeQLCommand(codeqlPath, false, args...)
var queries []string
if strings.TrimSpace(string(jsonBytes)) == "" {
fmt.Println("No queries found in the specified query suite.")
os.Exit(1)
}
err = json.Unmarshal(jsonBytes, &queries)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
return queries
}
func RunCodeQLCommand(codeqlPath string, combined bool, args ...string) ([]byte, error) {
if !strings.Contains(strings.Join(args, " "), "packlist") {
args = append(args, fmt.Sprintf("--additional-packs=%s", codeqlPath))
}
cmd := exec.Command("codeql", args...)
cmd.Env = os.Environ()
if combined {
return cmd.CombinedOutput()
} else {
return cmd.Output()
}
}
func GenerateQueryPack(codeqlPath string, queryFile string, language string) (string, error) {
fmt.Printf("Generating query pack for %s\n", queryFile)
// create a temporary directory to hold the query pack
queryPackDir, err := ioutil.TempDir("", "query-pack-")
if err != nil {
log.Fatal(err)
}
defer os.RemoveAll(queryPackDir)
queryFile, err = filepath.Abs(queryFile)
if err != nil {
log.Fatal(err)
}
if _, err := os.Stat(queryFile); errors.Is(err, os.ErrNotExist) {
log.Fatal(fmt.Sprintf("Query file %s does not exist", queryFile))
}
originalPackRoot := FindPackRoot(queryFile)
packRelativePath, _ := filepath.Rel(originalPackRoot, queryFile)
targetQueryFileName := filepath.Join(queryPackDir, packRelativePath)
if _, err := os.Stat(filepath.Join(originalPackRoot, "qlpack.yml")); errors.Is(err, os.ErrNotExist) {
// qlpack.yml not found, generate a synthetic one
fmt.Printf("QLPack does not exist. Generating synthetic one for %s\n", queryFile)
// copy only the query file to the query pack directory
err := CopyFile(queryFile, targetQueryFileName)
if err != nil {
log.Fatal(err)
}
// generate a synthetic qlpack.yml
td := struct {
Language string
Name string
Query string
}{
Language: language,
Name: "codeql-remote/query",
Query: strings.Replace(packRelativePath, string(os.PathSeparator), "/", -1),
}
t, err := template.New("").Parse(`name: {{ .Name }}
version: 0.0.0
dependencies:
codeql/{{ .Language }}-all: "*"
defaultSuite:
description: Query suite for variant analysis
query: {{ .Query }}`)
if err != nil {
log.Fatal(err)
}
f, err := os.Create(filepath.Join(queryPackDir, "qlpack.yml"))
defer f.Close()
if err != nil {
log.Fatal(err)
}
err = t.Execute(f, td)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Copied QLPack files to %s\n", queryPackDir)
} else {
// don't include all query files in the QLPacks. We only want the queryFile to be copied.
fmt.Printf("QLPack exists, stripping all other queries from %s\n", originalPackRoot)
toCopy := PackPacklist(codeqlPath, originalPackRoot, false)
// also copy the lock file (either new name or old name) and the query file itself (these are not included in the packlist)
lockFileNew := filepath.Join(originalPackRoot, "qlpack.lock.yml")
lockFileOld := filepath.Join(originalPackRoot, "codeql-pack.lock.yml")
candidateFiles := []string{lockFileNew, lockFileOld, queryFile}
for _, candidateFile := range candidateFiles {
if _, err := os.Stat(candidateFile); !errors.Is(err, os.ErrNotExist) {
// if the file exists, copy it
toCopy = append(toCopy, candidateFile)
}
}
// copy the files to the queryPackDir directory
fmt.Printf("Preparing stripped QLPack in %s\n", queryPackDir)
for _, srcPath := range toCopy {
relPath, _ := filepath.Rel(originalPackRoot, srcPath)
targetPath := filepath.Join(queryPackDir, relPath)
//fmt.Printf("Copying %s to %s\n", srcPath, targetPath)
err := CopyFile(srcPath, targetPath)
if err != nil {
log.Fatal(err)
}
}
fmt.Printf("Fixing QLPack in %s\n", queryPackDir)
FixPackFile(queryPackDir, packRelativePath)
}
// assuming we are using 2.11.3 or later so Qlx remote is supported
ccache := filepath.Join(originalPackRoot, ".cache")
precompilationOpts := []string{"--qlx", "--no-default-compilation-cache", "--compilation-cache=" + ccache}
bundlePath := filepath.Join(filepath.Dir(queryPackDir), fmt.Sprintf("qlpack-%s-generated.tgz", uuid.New().String()))
// install the pack dependencies
fmt.Print("Installing QLPack dependencies\n")
args := []string{"pack", "install", queryPackDir}
stdouterr, err := RunCodeQLCommand(codeqlPath, true, args...)
if err != nil {
fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr))
return "", fmt.Errorf("Failed to install query pack: %v", err)
}
// bundle the query pack
fmt.Print("Compiling and bundling the QLPack (This may take a while)\n")
args = []string{"pack", "bundle", "-o", bundlePath, queryPackDir}
args = append(args, precompilationOpts...)
stdouterr, err = RunCodeQLCommand(codeqlPath, true, args...)
if err != nil {
fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr))
return "", fmt.Errorf("Failed to bundle query pack: %v\n", err)
}
// open the bundle file and encode it as base64
bundleFile, err := os.Open(bundlePath)
if err != nil {
return "", fmt.Errorf("Failed to open bundle file: %v\n", err)
}
defer bundleFile.Close()
bundleBytes, err := ioutil.ReadAll(bundleFile)
if err != nil {
return "", fmt.Errorf("Failed to read bundle file: %v\n", err)
}
bundleBase64 := base64.StdEncoding.EncodeToString(bundleBytes)
return bundleBase64, nil
}
func PackPacklist(codeqlPath string, dir string, includeQueries bool) []string {
// since 2.7.1, packlist returns an object with a "paths" property that is a list of packs.
args := []string{"pack", "packlist", "--format=json"}
if !includeQueries {
args = append(args, "--no-include-queries")
}
args = append(args, dir)
jsonBytes, err := RunCodeQLCommand(codeqlPath, false, args...)
var packlist map[string][]string
err = json.Unmarshal(jsonBytes, &packlist)
if err != nil {
log.Fatal(err)
}
return packlist["paths"]
}
func FindPackRoot(queryFile string) string {
// Starting on the directory of queryPackDir, go down until a qlpack.yml find is found. return that directory
// If no qlpack.yml is found, return the directory of queryFile
currentDir := filepath.Dir(queryFile)
for currentDir != "/" {
if _, err := os.Stat(filepath.Join(currentDir, "qlpack.yml")); errors.Is(err, os.ErrNotExist) {
// qlpack.yml not found, go up one level
currentDir = filepath.Dir(currentDir)
} else {
return currentDir
}
}
return filepath.Dir(queryFile)
}
func FixPackFile(queryPackDir string, packRelativePath string) error {
packPath := filepath.Join(queryPackDir, "qlpack.yml")
packFile, err := ioutil.ReadFile(packPath)
if err != nil {
return err
}
var packData map[string]interface{}
err = yaml.Unmarshal(packFile, &packData)
if err != nil {
return err
}
// update the default suite
defaultSuiteFile := packData["defaultSuiteFile"]
if defaultSuiteFile != nil {
// remove the defaultSuiteFile property
delete(packData, "defaultSuiteFile")
}
packData["defaultSuite"] = map[string]string{
"query": packRelativePath,
"description": "Query suite for Variant Analysis",
}
// update the name
packData["name"] = "codeql-remote/query"
// remove any `${workspace}` version references
dependencies := packData["dependencies"]
if dependencies != nil {
// for key and value in dependencies
for key, value := range dependencies.(map[string]interface{}) {
// if value is a string and value contains `${workspace}`
if value == "${workspace}" {
// replace the value with `*`
packData["dependencies"].(map[string]interface{})[key] = "*"
}
}
}
// write the pack file
packFile, err = yaml.Marshal(packData)
if err != nil {
return err
}
err = ioutil.WriteFile(packPath, packFile, 0644)
if err != nil {
return err
}
return nil
}
func CopyFile(srcPath string, targetPath string) error {
err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm)
if err != nil {
return err
}
bytesRead, err := ioutil.ReadFile(srcPath)
if err != nil {
return err
}
err = ioutil.WriteFile(targetPath, bytesRead, 0644)
if err != nil {
return err
}
return nil
}
func DownloadWorker(wg *sync.WaitGroup, taskChannel <-chan models.DownloadTask, resultChannel chan models.DownloadTask) {
defer wg.Done()
for task := range taskChannel {
if task.Artifact == "artifact" {
DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir)
resultChannel <- task
} else if task.Artifact == "database" {
fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir)
DownloadDatabase(task.Nwo, task.Language, task.OutputDir)
resultChannel <- task
}
}
}
func downloadArtifact(url string, outputDir string, nwo string) error {
client, err := gh.HTTPClient(nil)
if err != nil {
return err
}
resp, err := client.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
if err != nil {
log.Fatal(err)
}
for _, zf := range zipReader.File {
if zf.Name != "results.sarif" && zf.Name != "results.bqrs" {
continue
}
f, err := zf.Open()
if err != nil {
log.Fatal(err)
}
defer f.Close()
bytes, err := ioutil.ReadAll(f)
if err != nil {
log.Fatal(err)
}
extension := ""
resultPath := ""
if zf.Name == "results.bqrs" {
extension = "bqrs"
} else if zf.Name == "results.sarif" {
extension = "sarif"
}
resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension))
err = ioutil.WriteFile(resultPath, bytes, os.ModePerm)
if err != nil {
return err
}
return nil
}
return errors.New("No results.sarif file found in artifact")
}
func DownloadResults(controller string, runId int, nwo string, outputDir string) error {
// download artifact (BQRS or SARIF)
runRepositoryDetails, err := GetRunRepositoryDetails(controller, runId, nwo)
if err != nil {
return errors.New("Failed to get run repository details")
}
// download the results
err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo)
if err != nil {
return errors.New("Failed to download artifact")
}
return nil
}
func DownloadDatabase(nwo string, language string, outputDir string) error {
dnwo := strings.Replace(nwo, "/", "_", -1)
targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language))
opts := api.ClientOptions{
Headers: map[string]string{"Accept": "application/zip"},
}
client, err := gh.HTTPClient(&opts)
if err != nil {
return err
}
resp, err := client.Get(fmt.Sprintf("https://api.github.com/repos/%s/code-scanning/codeql/databases/%s", nwo, language))
if err != nil {
return err
}
defer resp.Body.Close()
bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
err = ioutil.WriteFile(targetPath, bytes, os.ModePerm)
return nil
}