diff --git a/LICENSE b/LICENSE index 9f89be4..af61d39 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ -MIT License +The MIT License (MIT) -Copyright (c) 2023 Alvaro Muñoz +Copyright © 2023 Alvaro Munoz pwntester@github.com Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -9,13 +9,13 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/cmd/delete.go b/cmd/delete.go
new file mode 100644
index 0000000..aeb4311
--- /dev/null
+++ b/cmd/delete.go
@@ -0,0 +1,61 @@
+/*
+Copyright © 2023 sessionNameFlag HERE
+*/
+package cmd
+
+import (
+	"errors"
+	"fmt"
+	"os"
+
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+	"gopkg.in/yaml.v3"
+)
+
+// deleteCmd removes a single saved session, selected with --session, from
+// the sessions file.
+var deleteCmd = &cobra.Command{
+	Use:   "delete",
+	Short: "Delete a saved session.",
+	Long:  `Delete a saved session.`,
+	RunE: func(cmd *cobra.Command, args []string) error {
+		// A failed deletion is a runtime error, not a usage error, so do
+		// not print the help text after it.
+		cmd.SilenceUsage = true
+		return deleteSession()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(deleteCmd)
+	deleteCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be deleted")
+	deleteCmd.MarkFlagRequired("session")
+}
+
+// deleteSession removes the session named by sessionNameFlag from the saved
+// sessions and writes the remaining sessions back to the sessions file.
+// It returns an error when the sessions cannot be loaded or written, when
+// there are no saved sessions, or when the named session does not exist.
+func deleteSession() error {
+	sessions, err := utils.GetSessions()
+	if err != nil {
+		return err
+	}
+	if sessions == nil {
+		return errors.New("No sessions found")
+	}
+	if _, ok := sessions[sessionNameFlag]; !ok {
+		return fmt.Errorf("Session '%s' does not exist", sessionNameFlag)
+	}
+	delete(sessions, sessionNameFlag)
+
+	// Persist the remaining sessions as YAML.
+	sessionsYaml, err := yaml.Marshal(sessions)
+	if err != nil {
+		return err
+	}
+	// The sessions file holds plain data, so it does not need the
+	// executable bit.
+	return os.WriteFile(utils.GetSessionsFilePath(), sessionsYaml, 0644)
+}
diff --git a/cmd/download.go b/cmd/download.go
new file mode 100644
index 0000000..8bf44e0
--- /dev/null
+++ b/cmd/download.go
@@ -0,0 +1,166 @@
+/*
+Copyright © 2023 NAME HERE
+
+*/
+package cmd
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+
+	"github.com/GitHubSecurityLab/gh-mrva/config"
+	"github.com/GitHubSecurityLab/gh-mrva/models"
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+)
+
+// downloadCmd downloads the result artifacts, and optionally the databases,
+// produced by the runs of a saved session.
+var downloadCmd = &cobra.Command{
+	Use:   "download",
+	Short: "Downloads the artifacts associated to a given session.",
+	Long:  `Downloads the artifacts associated to a given session.`,
+	Run: func(cmd *cobra.Command, args []string) {
+		downloadArtifacts()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(downloadCmd)
+	downloadCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to be downloaded")
+	downloadCmd.Flags().StringVarP(&outputDirFlag, "output-dir", "o", "", "Output directory")
+	downloadCmd.Flags().BoolVarP(&downloadDBsFlag, "download-dbs", "d", false, "Download databases (optional)")
+	downloadCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "Repository to download artifacts for (optional)")
+	// Required flags have to be marked on the command that defines them.
+	downloadCmd.MarkFlagRequired("session")
+	downloadCmd.MarkFlagRequired("output-dir")
+}
+
+// downloadArtifacts queues a download task for every artifact of the session
+// named by sessionNameFlag that is not yet present in outputDirFlag, then
+// runs the tasks on config.WORKERS concurrent workers and reports progress.
+// It returns without downloading anything when a run is still in progress.
+func downloadArtifacts() {
+	// Create the output directory if it does not exist.
+	if _, err := os.Stat(outputDirFlag); os.IsNotExist(err) {
+		if err := os.MkdirAll(outputDirFlag, 0755); err != nil {
+			log.Fatal(err)
+		}
+	}
+
+	controller, runs, language, err := utils.LoadSession(sessionNameFlag)
+	if err != nil {
+		// Without a session there is nothing to download.
+		log.Fatal(err)
+	}
+	if len(runs) == 0 {
+		fmt.Println("No runs found for session " + sessionNameFlag)
+		return
+	}
+
+	var downloadTasks []models.DownloadTask
+
+	for _, run := range runs {
+		runDetails, err := utils.GetRunDetails(controller, run.Id)
+		if err != nil {
+			log.Fatal(err)
+		}
+		if runDetails["status"] == "in_progress" {
+			log.Printf("Run %d is not complete yet. Please try again later.", run.Id)
+			return
+		}
+		for _, r := range runDetails["scanned_repositories"].([]interface{}) {
+			repo := r.(map[string]interface{})
+			repoInfo := repo["repository"].(map[string]interface{})
+			nwo := repoInfo["full_name"].(string)
+			// If nwoFlag is set, only download artifacts for that repository.
+			if nwoFlag != "" && nwoFlag != nwo {
+				continue
+			}
+			// Repositories without results have nothing to download; a
+			// missing result_count is treated as zero.
+			if resultCount, _ := repo["result_count"].(float64); resultCount <= 0 {
+				continue
+			}
+
+			dnwo := strings.Replace(nwo, "/", "_", -1)
+			sarifPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.sarif", dnwo))
+			bqrsPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s.bqrs", dnwo))
+			targetPath := filepath.Join(outputDirFlag, fmt.Sprintf("%s_%s_db.zip", dnwo, language))
+
+			// Queue the results artifact unless a SARIF or BQRS file for
+			// this repository already exists.
+			_, bqrsErr := os.Stat(bqrsPath)
+			_, sarifErr := os.Stat(sarifPath)
+			if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) {
+				downloadTasks = append(downloadTasks, models.DownloadTask{
+					RunId:      run.Id,
+					Nwo:        nwo,
+					Controller: controller,
+					Artifact:   "artifact",
+					Language:   language,
+					OutputDir:  outputDirFlag,
+				})
+			}
+			if !downloadDBsFlag {
+				continue
+			}
+			// Queue the database unless it has already been downloaded.
+			if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) {
+				downloadTasks = append(downloadTasks, models.DownloadTask{
+					RunId:      run.Id,
+					Nwo:        nwo,
+					Controller: controller,
+					Artifact:   "database",
+					Language:   language,
+					OutputDir:  outputDirFlag,
+				})
+			}
+		}
+	}
+
+	wg := new(sync.WaitGroup)
+
+	taskChannel := make(chan models.DownloadTask)
+	// One buffer slot per queued task.
+	resultChannel := make(chan models.DownloadTask, len(downloadTasks))
+
+	// Start the workers.
+	for i := 0; i < config.WORKERS; i++ {
+		wg.Add(1)
+		go utils.DownloadWorker(wg, taskChannel, resultChannel)
+	}
+
+	// Send the tasks to the workers.
+	for _, downloadTask := range downloadTasks {
+		taskChannel <- downloadTask
+	}
+	close(taskChannel)
+
+	count := 0
+	progressDone := make(chan bool)
+
+	// Report progress until the result channel is closed.
+	go func() {
+		for value := range resultChannel {
+			count++
+			fmt.Printf("Downloaded %s for %s (%d/%d)\n", value.Artifact, value.Nwo, count, len(downloadTasks))
+		}
+		fmt.Printf("%d artifacts downloaded\n", count)
+		progressDone <- true
+	}()
+
+	// Wait for all workers to finish.
+	wg.Wait()
+
+	// No more results can arrive; let the progress goroutine finish.
+	close(resultChannel)
+
+	// Wait for the final progress line to be printed.
+	<-progressDone
+}
diff --git a/cmd/flags.go b/cmd/flags.go
new file mode 100644
index 0000000..051cc02
--- /dev/null
+++ b/cmd/flags.go
@@ -0,0 +1,18 @@
+package cmd
+
+// Values bound to command-line flags. Several commands bind the same
+// variable (for example sessionNameFlag) in their init functions.
+var (
+	sessionNameFlag    string
+	outputDirFlag      string
+	downloadDBsFlag    bool
+	nwoFlag            string
+	jsonFlag           bool
+	languageFlag       string
+	listFileFlag       string
+	listFlag           string
+	codeqlPathFlag     string
+	controllerFlag     string
+	queryFileFlag      string
+	querySuiteFileFlag string
+)
diff --git a/cmd/list.go b/cmd/list.go
new file mode 100644
index 0000000..7c1a059
--- /dev/null
+++ b/cmd/list.go
@@ -0,0 +1,80 @@
+/*
+Copyright © 2023 NAME HERE
+
+*/
+package cmd
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+	"sort"
+
+	"github.com/GitHubSecurityLab/gh-mrva/models"
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+)
+
+// listCmd prints the sessions stored in the sessions file.
+var listCmd = &cobra.Command{
+	Use:   "list",
+	Short: "List saved sessions.",
+	Long:  `List saved sessions.`,
+	Run: func(cmd *cobra.Command, args []string) {
+		listSessions()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(listCmd)
+	listCmd.Flags().BoolVarP(&jsonFlag, "json", "j", false, "Output in JSON format (default: false)")
+}
+
+// listSessions prints every saved session, ordered by name, either as an
+// indented JSON array (when jsonFlag is set) or as a plain-text summary.
+// It prints nothing when utils.GetSessions returns a nil map.
+func listSessions() {
+	sessions, err := utils.GetSessions()
+	if err != nil {
+		log.Fatal(err)
+	}
+	if sessions == nil {
+		return
+	}
+
+	// Map iteration order is random; sort the names so that the output is
+	// the same on every invocation.
+	names := make([]string, 0, len(sessions))
+	for name := range sessions {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+
+	if jsonFlag {
+		sessionsList := make([]models.Session, 0, len(sessions))
+		for _, name := range names {
+			sessionsList = append(sessionsList, sessions[name])
+		}
+		data, err := json.MarshalIndent(sessionsList, "", " ")
+		if err != nil {
+			log.Fatal(err)
+		}
+		fmt.Println(string(data))
+		return
+	}
+
+	for _, name := range names {
+		entry := sessions[name]
+		fmt.Printf("%s (%v)\n", name, entry.Timestamp)
+		fmt.Printf(" Controller: %s\n", entry.Controller)
+		fmt.Printf(" Language: %s\n", entry.Language)
+		fmt.Printf(" List file: %s\n", entry.ListFile)
+		fmt.Printf(" List: %s\n", entry.List)
+		fmt.Printf(" Repository count: %d\n", entry.RepositoryCount)
+		fmt.Println(" Runs:")
+		for _, run := range entry.Runs {
+			fmt.Printf(" ID: %d\n", run.Id)
+			fmt.Printf(" Query: %s\n", run.Query)
+		}
+	}
+}
diff --git a/cmd/root.go b/cmd/root.go
new file mode 100644
index 0000000..6f6fe1e
--- /dev/null
+++ b/cmd/root.go
@@ -0,0 +1,72 @@
+/*
+Copyright © 2023 Alvaro Munoz pwntester@github.com
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+*/
+package cmd
+
+import (
+	"log"
+	"os"
+	"path/filepath"
+
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+)
+
+var rootCmd = &cobra.Command{
+	Use:   "gh-mrva",
+	Short: "Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)",
+	Long:  `Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)`,
+}
+
+// Execute runs the root command and exits with status 1 when it fails.
+func Execute() {
+	if err := rootCmd.Execute(); err != nil {
+		os.Exit(1)
+	}
+}
+
+// init resolves the config and sessions file paths and creates an empty sessions file if none exists.
+func init() {
+	configPath := os.Getenv("XDG_CONFIG_HOME")
+	if configPath == "" {
+		homePath := os.Getenv("HOME")
+		if homePath == "" {
+			log.Fatal("HOME environment variable not set")
+		}
+		configPath = filepath.Join(homePath, ".config")
+	}
+	configFilePath := filepath.Join(configPath, "gh-mrva", "config.yml")
+	utils.SetConfigFilePath(configFilePath)
+
+	sessionsFilePath := filepath.Join(configPath, "gh-mrva", "sessions.yml")
+	if _, err := os.Stat(sessionsFilePath); os.IsNotExist(err) {
+		if err := os.MkdirAll(filepath.Dir(sessionsFilePath), os.ModePerm); err != nil {
+			log.Fatal("Failed to create config directory: ", err)
+		}
+		// create empty file at sessionsFilePath
+		sessionsFile, err := os.Create(sessionsFilePath)
+		if err != nil {
+			log.Fatal("Failed to create sessions file: ", err)
+		}
+		sessionsFile.Close()
+	}
+	utils.SetSessionsFilePath(sessionsFilePath)
+}
diff --git a/cmd/status.go b/cmd/status.go
new file mode 100644
index 0000000..bb3a70c
--- /dev/null
+++ b/cmd/status.go
@@ -0,0 +1,136 @@
+/*
+Copyright © 2023 NAME HERE
+
+*/
+package cmd
+
+import (
+	"encoding/json"
+	"fmt"
+	"log"
+
+	"github.com/GitHubSecurityLab/gh-mrva/models"
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+)
+
+// statusCmd reports the progress and aggregated results of a saved session.
+var statusCmd = &cobra.Command{
+	Use:   "status",
+	Short: "Checks the status of a given session.",
+	Long:  `Checks the status of a given session.`,
+	Run: func(cmd *cobra.Command, args []string) {
+		sessionStatus()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(statusCmd)
+	statusCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name to check the status of")
+	statusCmd.Flags().BoolVarP(&jsonFlag, "json", "j", false, "Output in JSON format (default: false)")
+	statusCmd.MarkFlagRequired("session")
+}
+
+// sessionStatus fetches the details of every run in the session named by
+// sessionNameFlag, aggregates the scan, skip and finding counters across the
+// runs, and prints them as JSON (when jsonFlag is set) or as plain text.
+func sessionStatus() {
+	controller, runs, _, err := utils.LoadSession(sessionNameFlag)
+	if err != nil {
+		log.Fatal(err)
+	}
+	if len(runs) == 0 {
+		// log.Fatal concatenates adjacent string operands without a space,
+		// so format the message explicitly.
+		log.Fatalf("No runs found for run name %s", sessionNameFlag)
+	}
+
+	var results models.Results
+
+	for _, run := range runs {
+		runDetails, err := utils.GetRunDetails(controller, run.Id)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		status := runDetails["status"].(string)
+		// The failure reason is only read for failed runs.
+		var failureReason string
+		if status == "failed" {
+			failureReason = runDetails["failure_reason"].(string)
+		}
+
+		results.Runs = append(results.Runs, models.RunStatus{
+			Id:            run.Id,
+			Query:         run.Query,
+			Status:        status,
+			FailureReason: failureReason,
+		})
+
+		for _, r := range runDetails["scanned_repositories"].([]interface{}) {
+			repo := r.(map[string]interface{})
+			switch repo["analysis_status"].(string) {
+			case "succeeded":
+				results.TotalSuccessfulScans++
+				// A missing result_count is treated as zero findings.
+				resultCount, _ := repo["result_count"].(float64)
+				if resultCount <= 0 {
+					continue
+				}
+				repoInfo := repo["repository"].(map[string]interface{})
+				results.TotalRepositoriesWithFindings++
+				results.TotalFindingsCount += int(resultCount)
+				results.ResositoriesWithFindings = append(results.ResositoriesWithFindings, models.RepoWithFindings{
+					Nwo:   repoInfo["full_name"].(string),
+					Count: int(resultCount),
+					RunId: run.Id,
+					Stars: int(repoInfo["stargazers_count"].(float64)),
+				})
+			case "failed":
+				results.TotalFailedScans++
+			}
+		}
+
+		skipped := runDetails["skipped_repositories"].(map[string]interface{})
+		// skippedCount returns the repository_count of one skip category.
+		skippedCount := func(category string) int {
+			group := skipped[category].(map[string]interface{})
+			return int(group["repository_count"].(float64))
+		}
+		accessMismatch := skippedCount("access_mismatch_repos")
+		notFound := skippedCount("not_found_repos")
+		noDatabase := skippedCount("no_codeql_db_repos")
+		overLimit := skippedCount("over_limit_repos")
+
+		results.TotalSkippedAccessMismatchRepositories += accessMismatch
+		results.TotalSkippedNotFoundRepositories += notFound
+		results.TotalSkippedNoDatabaseRepositories += noDatabase
+		results.TotalSkippedOverLimitRepositories += overLimit
+		results.TotalSkippedRepositories += accessMismatch + notFound + noDatabase + overLimit
+	}
+
+	if jsonFlag {
+		data, err := json.MarshalIndent(results, "", " ")
+		if err != nil {
+			log.Fatal(err)
+		}
+		fmt.Println(string(data))
+		return
+	}
+
+	fmt.Println("Run name:", sessionNameFlag)
+	fmt.Println("Total runs:", len(results.Runs))
+	fmt.Println("Total successful scans:", results.TotalSuccessfulScans)
+	fmt.Println("Total failed scans:", results.TotalFailedScans)
+	fmt.Println("Total skipped repositories:", results.TotalSkippedRepositories)
+	fmt.Println("Total skipped repositories due to access mismatch:", results.TotalSkippedAccessMismatchRepositories)
+	fmt.Println("Total skipped repositories due to not found:", results.TotalSkippedNotFoundRepositories)
+	fmt.Println("Total skipped repositories due to no database:", results.TotalSkippedNoDatabaseRepositories)
+	fmt.Println("Total skipped repositories due to over limit:", results.TotalSkippedOverLimitRepositories)
+	fmt.Println("Total repositories with findings:", results.TotalRepositoriesWithFindings)
+	fmt.Println("Total findings:", results.TotalFindingsCount)
+	fmt.Println("Repositories with findings:")
+	for _, repo := range results.ResositoriesWithFindings {
+		fmt.Println(" ", repo.Nwo, ":", repo.Count)
+	}
+}
diff --git a/cmd/submit.go b/cmd/submit.go
new file mode 100644
index 0000000..ab03873
--- /dev/null
+++ b/cmd/submit.go
@@ -0,0 +1,157 @@
+/*
+Copyright © 2023 NAME HERE
+
+*/
+package cmd
+
+import (
+	"fmt"
+	"log"
+	"os"
+
+	"github.com/GitHubSecurityLab/gh-mrva/config"
+	"github.com/GitHubSecurityLab/gh-mrva/models"
+	"github.com/GitHubSecurityLab/gh-mrva/utils"
+	"github.com/spf13/cobra"
+)
+
+// Effective settings of the submit command, resolved by submitQuery from
+// the command-line flags and, for some of them, the config file.
+var (
+	controller     string
+	codeqlPath     string
+	listFile       string
+	listName       string
+	language       string
+	sessionName    string
+	queryFile      string
+	querySuiteFile string
+)
+
+// submitCmd submits a query, or every query of a query suite, to a MRVA
+// controller and saves the resulting runs as a session.
+var submitCmd = &cobra.Command{
+	Use:   "submit",
+	Short: "Submit a query or query suite to a MRVA controller.",
+	Long:  `Submit a query or query suite to a MRVA controller.`,
+	Run: func(cmd *cobra.Command, args []string) {
+		submitQuery()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(submitCmd)
+	submitCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name")
+	submitCmd.Flags().StringVarP(&languageFlag, "language", "l", "", "DB language")
+	submitCmd.Flags().StringVarP(&queryFileFlag, "query", "q", "", "Path to query file")
+	submitCmd.Flags().StringVarP(&querySuiteFileFlag, "query-suite", "x", "", "Path to query suite file")
+	submitCmd.Flags().StringVarP(&controllerFlag, "controller", "c", "", "MRVA controller repository (overrides config file)")
+	submitCmd.Flags().StringVarP(&listFileFlag, "list-file", "f", "", "Path to repo list file (overrides config file)")
+	submitCmd.Flags().StringVarP(&listFlag, "list", "i", "", "Name of repo list")
+	submitCmd.Flags().StringVarP(&codeqlPathFlag, "codeql-path", "p", "", "Path to CodeQL distribution (overrides config file)")
+	submitCmd.MarkFlagRequired("session")
+	submitCmd.MarkFlagRequired("language")
+	submitCmd.MarkFlagsMutuallyExclusive("query", "query-suite")
+}
+
+// firstNonEmpty returns the first non-empty string in values, or "" when
+// all of them are empty.
+func firstNonEmpty(values ...string) string {
+	for _, value := range values {
+		if value != "" {
+			return value
+		}
+	}
+	return ""
+}
+
+// submitQuery resolves the effective settings, generates a query pack for
+// each query, submits it for the target repositories in chunks of at most
+// config.MAX_MRVA_REPOSITORIES, and saves the resulting runs as a session.
+func submitQuery() {
+	configData, err := utils.GetConfig()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Flags take precedence over the config file.
+	controller = firstNonEmpty(controllerFlag, configData.Controller, controller)
+	listFile = firstNonEmpty(listFileFlag, configData.ListFile, listFile)
+	codeqlPath = firstNonEmpty(codeqlPathFlag, configData.CodeQLPath, codeqlPath)
+	language = firstNonEmpty(languageFlag, language)
+	sessionName = firstNonEmpty(sessionNameFlag, sessionName)
+	listName = firstNonEmpty(listFlag, listName)
+	queryFile = firstNonEmpty(queryFileFlag, queryFile)
+	querySuiteFile = firstNonEmpty(querySuiteFileFlag, querySuiteFile)
+
+	switch {
+	case controller == "":
+		fmt.Println("Please specify a controller.")
+		os.Exit(1)
+	case listFile == "":
+		fmt.Println("Please specify a list file.")
+		os.Exit(1)
+	case listName == "":
+		fmt.Println("Please specify a list name.")
+		os.Exit(1)
+	case queryFile == "" && querySuiteFile == "":
+		fmt.Println("Please specify a query or query suite.")
+		os.Exit(1)
+	}
+
+	// Read the list of target repositories.
+	repositories, err := utils.ResolveRepositories(listFile, listName)
+	if err != nil {
+		log.Fatal(err)
+	}
+	if len(repositories) == 0 {
+		log.Fatalf("No repositories found for list %q in %s", listName, listFile)
+	}
+
+	// A single query file is submitted as is; a query suite is resolved
+	// into the queries it contains.
+	var queries []string
+	if queryFile != "" {
+		queries = append(queries, queryFile)
+	} else {
+		queries = utils.ResolveQueries(codeqlPath, querySuiteFile)
+	}
+
+	// The repository chunks do not depend on the query, so they are
+	// computed once instead of once per query.
+	var chunks [][]string
+	for i := 0; i < len(repositories); i += config.MAX_MRVA_REPOSITORIES {
+		end := i + config.MAX_MRVA_REPOSITORIES
+		if end > len(repositories) {
+			end = len(repositories)
+		}
+		chunks = append(chunks, repositories[i:end])
+	}
+
+	fmt.Printf("Submitting %d queries for %d repositories\n", len(queries), len(repositories))
+	var runs []models.Run
+	for _, query := range queries {
+		encodedBundle, err := utils.GenerateQueryPack(codeqlPath, query, language)
+		if err != nil {
+			log.Fatal(err)
+		}
+		fmt.Printf("Generated encoded bundle for %s\n", query)
+
+		for _, chunk := range chunks {
+			id, err := utils.SubmitRun(controller, language, chunk, encodedBundle)
+			if err != nil {
+				log.Fatal(err)
+			}
+			runs = append(runs, models.Run{Id: id, Query: query})
+		}
+	}
+
+	// The session records the query suite when one was given, otherwise
+	// the query file.
+	querySource := firstNonEmpty(querySuiteFile, queryFile)
+	err = utils.SaveSession(sessionName, controller, runs, language, listFile, listName, querySource, len(repositories))
+	if err != nil {
+		log.Fatal(err)
+	}
+	fmt.Println("Done!")
+}
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..d00da4f
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,8 @@
+package config
+
+const (
+	// MAX_MRVA_REPOSITORIES is the largest number of repositories submitted in a single run.
+	MAX_MRVA_REPOSITORIES = 1000
+	// WORKERS is the number of concurrent download workers.
+	WORKERS = 10
+)
diff --git a/go.mod b/go.mod
index c89db20..54d592c 100644
--- a/go.mod
+++ b/go.mod
@@ -1,10 +1,15 @@
-module github.com/pwntester/gh-mrva
+module github.com/GitHubSecurityLab/gh-mrva
 
 go 1.19
 
 require github.com/cli/go-gh v1.2.1
 
-require github.com/aymanbagabas/go-osc52 v1.2.1 // indirect
+require (
+	github.com/aymanbagabas/go-osc52 v1.2.1 // indirect
+	github.com/inconshreveable/mousetrap v1.1.0 // indirect
+	github.com/spf13/cobra v1.7.0 // indirect
+	github.com/spf13/pflag v1.0.5 // indirect
+)
 
 require (
 	github.com/cli/safeexec v1.0.0 // indirect
diff --git a/go.sum b/go.sum
index a95410b..57fb44b 100644
--- a/go.sum
+++ b/go.sum
@@ -7,6 +7,7 @@ github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
 github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
 github.com/cli/shurcooL-graphql v0.0.2 h1:rwP5/qQQ2fM0TzkUTwtt6E2LbIYf6R+39cUXTa04NYk=
github.com/cli/shurcooL-graphql v0.0.2/go.mod h1:tlrLmw/n5Q/+4qSvosT+9/W5zc8ZMjnJeYBxSdb4nWA= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -14,6 +15,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/henvic/httpretty v0.0.6 h1:JdzGzKZBajBfnvlMALXXMVQWxWMF/ofTy8C3/OSUTxs= github.com/henvic/httpretty v0.0.6/go.mod h1:X38wLjWXHkXT7r2+uK8LjCMne9rsuNaBLJ+5cU2/Pmo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -29,6 +32,11 @@ github.com/muesli/termenv v0.14.0/go.mod h1:kG/pF1E7fh949Xhe156crRUrHNyK221IuGO7 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod 
h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI= diff --git a/main.go b/main.go index 6ab22f7..d4856be 100644 --- a/main.go +++ b/main.go @@ -1,1176 +1,28 @@ +/* +Copyright © 2023 Alvaro Munoz pwntester@github.com + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+*/ package main -import ( - "archive/zip" - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "flag" - "fmt" - "github.com/cli/go-gh" - "github.com/cli/go-gh/pkg/api" - "github.com/google/uuid" - "gopkg.in/yaml.v3" - "io/ioutil" - "log" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "text/template" - "time" -) - -const ( - MAX_MRVA_REPOSITORIES = 1000 - WORKERS = 10 -) - -var ( - configFilePath string - sessionsFilePath string -) - -func runCodeQLCommand(codeqlPath string, combined bool, args ...string) ([]byte, error) { - if !strings.Contains(strings.Join(args, " "), "packlist") { - args = append(args, fmt.Sprintf("--additional-packs=%s", codeqlPath)) - } - cmd := exec.Command("codeql", args...) - cmd.Env = os.Environ() - if combined { - return cmd.CombinedOutput() - } else { - return cmd.Output() - } -} -func resolveRepositories(listFile string, list string) ([]string, error) { - fmt.Printf("Resolving %s repositories from %s\n", list, listFile) - jsonFile, err := os.Open(listFile) - if err != nil { - fmt.Println(err) - return nil, err - } - defer jsonFile.Close() - byteValue, _ := ioutil.ReadAll(jsonFile) - var repoLists map[string][]string - err = json.Unmarshal(byteValue, &repoLists) - if err != nil { - log.Fatal(err) - } - return repoLists[list], nil -} - -func resolveQueries(codeqlPath string, querySuite string) []string { - args := []string{"resolve", "queries", "--format=json", querySuite} - jsonBytes, err := runCodeQLCommand(codeqlPath, false, args...) - var queries []string - err = json.Unmarshal(jsonBytes, &queries) - if err != nil { - log.Fatal(err) - } - return queries -} - -func packPacklist(codeqlPath string, dir string, includeQueries bool) []string { - // since 2.7.1, packlist returns an object with a "paths" property that is a list of packs. 
- args := []string{"pack", "packlist", "--format=json"} - if !includeQueries { - args = append(args, "--no-include-queries") - } - args = append(args, dir) - jsonBytes, err := runCodeQLCommand(codeqlPath, false, args...) - var packlist map[string][]string - err = json.Unmarshal(jsonBytes, &packlist) - if err != nil { - log.Fatal(err) - } - return packlist["paths"] -} - -func findPackRoot(queryFile string) string { - // Starting on the directory of queryPackDir, go down until a qlpack.yml find is found. return that directory - // If no qlpack.yml is found, return the directory of queryFile - currentDir := filepath.Dir(queryFile) - for currentDir != "/" { - if _, err := os.Stat(filepath.Join(currentDir, "qlpack.yml")); errors.Is(err, os.ErrNotExist) { - // qlpack.yml not found, go up one level - currentDir = filepath.Dir(currentDir) - } else { - return currentDir - } - } - return filepath.Dir(queryFile) -} - -func copyFile(srcPath string, targetPath string) error { - err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm) - if err != nil { - return err - } - bytesRead, err := ioutil.ReadFile(srcPath) - if err != nil { - return err - } - err = ioutil.WriteFile(targetPath, bytesRead, 0644) - if err != nil { - return err - } - return nil -} - -func fixPackFile(queryPackDir string, packRelativePath string) error { - packPath := filepath.Join(queryPackDir, "qlpack.yml") - packFile, err := ioutil.ReadFile(packPath) - if err != nil { - return err - } - var packData map[string]interface{} - err = yaml.Unmarshal(packFile, &packData) - if err != nil { - return err - } - // update the default suite - defaultSuiteFile := packData["defaultSuiteFile"] - if defaultSuiteFile != nil { - // remove the defaultSuiteFile property - delete(packData, "defaultSuiteFile") - } - packData["defaultSuite"] = map[string]string{ - "query": packRelativePath, - "description": "Query suite for Variant Analysis", - } - - // update the name - packData["name"] = "codeql-remote/query" - - // remove any 
`${workspace}` version references - dependencies := packData["dependencies"] - if dependencies != nil { - // for key and value in dependencies - for key, value := range dependencies.(map[string]interface{}) { - // if value is a string and value contains `${workspace}` - if value == "${workspace}" { - // replace the value with `*` - packData["dependencies"].(map[string]interface{})[key] = "*" - } - } - } - - // write the pack file - packFile, err = yaml.Marshal(packData) - if err != nil { - return err - } - err = ioutil.WriteFile(packPath, packFile, 0644) - if err != nil { - return err - } - return nil -} - -// Generate a query pack containing the given query file. -func generateQueryPack(codeqlPath string, queryFile string, language string) (string, error) { - fmt.Printf("Generating query pack for %s\n", queryFile) - - // create a temporary directory to hold the query pack - queryPackDir, err := ioutil.TempDir("", "query-pack-") - if err != nil { - log.Fatal(err) - } - defer os.RemoveAll(queryPackDir) - - queryFile, err = filepath.Abs(queryFile) - if err != nil { - log.Fatal(err) - } - if _, err := os.Stat(queryFile); errors.Is(err, os.ErrNotExist) { - log.Fatal(fmt.Sprintf("Query file %s does not exist", queryFile)) - } - originalPackRoot := findPackRoot(queryFile) - packRelativePath, _ := filepath.Rel(originalPackRoot, queryFile) - targetQueryFileName := filepath.Join(queryPackDir, packRelativePath) - - if _, err := os.Stat(filepath.Join(originalPackRoot, "qlpack.yml")); errors.Is(err, os.ErrNotExist) { - // qlpack.yml not found, generate a synthetic one - fmt.Printf("QLPack does not exist. 
Generating synthetic one for %s\n", queryFile) - // copy only the query file to the query pack directory - err := copyFile(queryFile, targetQueryFileName) - if err != nil { - log.Fatal(err) - } - // generate a synthetic qlpack.yml - td := struct { - Language string - Name string - Query string - }{ - Language: language, - Name: "codeql-remote/query", - Query: strings.Replace(packRelativePath, string(os.PathSeparator), "/", -1), - } - t, err := template.New("").Parse(`name: {{ .Name }} -version: 0.0.0 -dependencies: - codeql/{{ .Language }}-all: "*" -defaultSuite: - description: Query suite for variant analysis - query: {{ .Query }}`) - if err != nil { - log.Fatal(err) - } - - f, err := os.Create(filepath.Join(queryPackDir, "qlpack.yml")) - defer f.Close() - if err != nil { - log.Fatal(err) - } - err = t.Execute(f, td) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Copied QLPack files to %s\n", queryPackDir) - } else { - // don't include all query files in the QLPacks. We only want the queryFile to be copied. 
- fmt.Printf("QLPack exists, stripping all other queries from %s\n", originalPackRoot) - toCopy := packPacklist(codeqlPath, originalPackRoot, false) - // also copy the lock file (either new name or old name) and the query file itself (these are not included in the packlist) - lockFileNew := filepath.Join(originalPackRoot, "qlpack.lock.yml") - lockFileOld := filepath.Join(originalPackRoot, "codeql-pack.lock.yml") - candidateFiles := []string{lockFileNew, lockFileOld, queryFile} - for _, candidateFile := range candidateFiles { - if _, err := os.Stat(candidateFile); !errors.Is(err, os.ErrNotExist) { - // if the file exists, copy it - toCopy = append(toCopy, candidateFile) - } - } - // copy the files to the queryPackDir directory - fmt.Printf("Preparing stripped QLPack in %s\n", queryPackDir) - for _, srcPath := range toCopy { - relPath, _ := filepath.Rel(originalPackRoot, srcPath) - targetPath := filepath.Join(queryPackDir, relPath) - //fmt.Printf("Copying %s to %s\n", srcPath, targetPath) - err := copyFile(srcPath, targetPath) - if err != nil { - log.Fatal(err) - } - } - fmt.Printf("Fixing QLPack in %s\n", queryPackDir) - fixPackFile(queryPackDir, packRelativePath) - } - - // assuming we are using 2.11.3 or later so Qlx remote is supported - ccache := filepath.Join(originalPackRoot, ".cache") - precompilationOpts := []string{"--qlx", "--no-default-compilation-cache", "--compilation-cache=" + ccache} - bundlePath := filepath.Join(filepath.Dir(queryPackDir), fmt.Sprintf("qlpack-%s-generated.tgz", uuid.New().String())) - - // install the pack dependencies - fmt.Print("Installing QLPack dependencies\n") - args := []string{"pack", "install", queryPackDir} - stdouterr, err := runCodeQLCommand(codeqlPath, true, args...) 
- if err != nil { - fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr)) - return "", fmt.Errorf("Failed to install query pack: %v", err) - } - // bundle the query pack - fmt.Print("Compiling and bundling the QLPack (This may take a while)\n") - args = []string{"pack", "bundle", "-o", bundlePath, queryPackDir} - args = append(args, precompilationOpts...) - stdouterr, err = runCodeQLCommand(codeqlPath, true, args...) - if err != nil { - fmt.Printf("`codeql pack bundle` failed with error: %v\n", string(stdouterr)) - return "", fmt.Errorf("Failed to bundle query pack: %v\n", err) - } - - // open the bundle file and encode it as base64 - bundleFile, err := os.Open(bundlePath) - if err != nil { - return "", fmt.Errorf("Failed to open bundle file: %v\n", err) - } - defer bundleFile.Close() - bundleBytes, err := ioutil.ReadAll(bundleFile) - if err != nil { - return "", fmt.Errorf("Failed to read bundle file: %v\n", err) - } - bundleBase64 := base64.StdEncoding.EncodeToString(bundleBytes) - - return bundleBase64, nil -} - -// Requests a query to be run against `respositories` on the given `controller`. 
-func submitRun(controller string, language string, repoChunk []string, bundle string) (int, error) { - opts := api.ClientOptions{ - Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, - } - client, err := gh.RESTClient(&opts) - if err != nil { - return -1, err - } - body := struct { - Repositories []string `json:"repositories"` - Language string `json:"language"` - Pack string `json:"query_pack"` - Ref string `json:"action_repo_ref"` - }{ - Repositories: repoChunk, - Language: language, - Pack: bundle, - Ref: "main", - } - var buf bytes.Buffer - err = json.NewEncoder(&buf).Encode(body) - if err != nil { - return -1, err - } - response := make(map[string]interface{}) - err = client.Post(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses", controller), &buf, &response) - if err != nil { - return -1, err - } - id := int(response["id"].(float64)) - return id, nil -} - -func getRunDetails(controller string, runId int) (map[string]interface{}, error) { - opts := api.ClientOptions{ - Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, - } - client, err := gh.RESTClient(&opts) - if err != nil { - return nil, err - } - response := make(map[string]interface{}) - err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d", controller, runId), &response) - if err != nil { - return nil, err - } - return response, nil -} - -func getRunRepositoryDetails(controller string, runId int, nwo string) (map[string]interface{}, error) { - opts := api.ClientOptions{ - Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, - } - client, err := gh.RESTClient(&opts) - if err != nil { - return nil, err - } - response := make(map[string]interface{}) - err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d/repos/%s", controller, runId, nwo), &response) - if err != nil { - return nil, err - } - return response, nil -} - -type DownloadTask struct { - runId int - nwo string - controller 
string - artifact string - outputDir string - language string -} - -func downloadWorker(wg *sync.WaitGroup, taskChannel <-chan DownloadTask, resultChannel chan DownloadTask) { - defer wg.Done() - for task := range taskChannel { - if task.artifact == "artifact" { - downloadResults(task.controller, task.runId, task.nwo, task.outputDir) - resultChannel <- task - } else if task.artifact == "database" { - fmt.Println("Downloading database", task.nwo, task.language, task.outputDir) - downloadDatabase(task.nwo, task.language, task.outputDir) - resultChannel <- task - } - } -} - -func downloadArtifact(url string, outputDir string, nwo string) error { - client, err := gh.HTTPClient(nil) - if err != nil { - return err - } - resp, err := client.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Fatal(err) - } - - zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) - if err != nil { - log.Fatal(err) - } - - for _, zf := range zipReader.File { - if zf.Name != "results.sarif" && zf.Name != "results.bqrs" { - continue - } - f, err := zf.Open() - if err != nil { - log.Fatal(err) - } - defer f.Close() - bytes, err := ioutil.ReadAll(f) - if err != nil { - log.Fatal(err) - } - extension := "" - resultPath := "" - if zf.Name == "results.bqrs" { - extension = "bqrs" - } else if zf.Name == "results.sarif" { - extension = "sarif" - } - resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)) - err = ioutil.WriteFile(resultPath, bytes, os.ModePerm) - if err != nil { - return err - } - return nil - } - return errors.New("No results.sarif file found in artifact") -} - -func downloadResults(controller string, runId int, nwo string, outputDir string) error { - // download artifact (BQRS or SARIF) - runRepositoryDetails, err := getRunRepositoryDetails(controller, runId, nwo) - if err != nil { - return errors.New("Failed to get run 
repository details") - } - // download the results - err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo) - if err != nil { - return errors.New("Failed to download artifact") - } - return nil -} - -func downloadDatabase(nwo string, language string, outputDir string) error { - dnwo := strings.Replace(nwo, "/", "_", -1) - targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) - opts := api.ClientOptions{ - Headers: map[string]string{"Accept": "application/zip"}, - } - client, err := gh.HTTPClient(&opts) - if err != nil { - return err - } - resp, err := client.Get(fmt.Sprintf("https://api.github.com/repos/%s/code-scanning/codeql/databases/%s", nwo, language)) - if err != nil { - return err - } - defer resp.Body.Close() - - bytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - err = ioutil.WriteFile(targetPath, bytes, os.ModePerm) - return nil -} - -func deleteSession(name string) error { - sessions, err := getSessions() - if err != nil { - return err - } - if sessions == nil { - return errors.New("No sessions found") - } - // delete session if it exists - if _, ok := sessions[name]; ok { - - delete(sessions, name) - - // marshal sessions to yaml - sessionsYaml, err := yaml.Marshal(sessions) - if err != nil { - return err - } - // write sessions to file - err = ioutil.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm) - if err != nil { - return err - } - return nil - } - return errors.New(fmt.Sprintf("Session '%s' does not exist", name)) -} - -func saveSession(name string, controller string, runs []Run, language string, listFile string, list string, query string, count int) error { - sessions, err := getSessions() - if err != nil { - return err - } - if sessions == nil { - sessions = make(map[string]Session) - } - // add new session if it doesn't already exist - if _, ok := sessions[name]; ok { - return errors.New(fmt.Sprintf("Session '%s' already exists", name)) - } else { - 
sessions[name] = Session{ - Name: name, - Runs: runs, - Timestamp: time.Now(), - Controller: controller, - Language: language, - ListFile: listFile, - List: list, - RepositoryCount: count, - } - } - // marshal sessions to yaml - sessionsYaml, err := yaml.Marshal(sessions) - if err != nil { - return err - } - // write sessions to file - err = ioutil.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm) - if err != nil { - return err - } - return nil -} - -func loadSession(name string) (string, []Run, string, error) { - sessions, err := getSessions() - if err != nil { - return "", nil, "", err - } - if sessions != nil { - if entry, ok := sessions[name]; ok { - return entry.Controller, entry.Runs, entry.Language, nil - } - } - return "", nil, "", errors.New("No session found for " + name) -} - -func getSessions() (map[string]Session, error) { - sessionsFile, err := ioutil.ReadFile(sessionsFilePath) - var sessions map[string]Session - if err != nil { - return sessions, err - } - err = yaml.Unmarshal(sessionsFile, &sessions) - if err != nil { - log.Fatal(err) - } - return sessions, nil -} - -func getConfig() (Config, error) { - configFile, err := ioutil.ReadFile(configFilePath) - var configData Config - if err != nil { - return configData, err - } - err = yaml.Unmarshal(configFile, &configData) - if err != nil { - log.Fatal(err) - } - return configData, nil -} - -type Run struct { - Id int `yaml:"id"` - Query string `yaml:"query"` -} - -type Session struct { - Name string `yaml:"name" json:"name"` - Timestamp time.Time `yaml:"timestamp" json:"timestamp"` - Runs []Run `yaml:"runs" json:"runs"` - Controller string `yaml:"controller" json:"controller"` - ListFile string `yaml:"list_file" json:"list_file"` - List string `yaml:"list" json:"list"` - Language string `yaml:"language" json:"language"` - RepositoryCount int `yaml:"repository_count" json:"repository_count"` -} -type Config struct { - Controller string `yaml:"controller"` - ListFile string `yaml:"list_file"` - 
CodeQLPath string `yaml:"codeql_path"` -} +import "github.com/GitHubSecurityLab/gh-mrva/cmd" func main() { - configPath := os.Getenv("XDG_CONFIG_HOME") - if configPath == "" { - homePath := os.Getenv("HOME") - if homePath == "" { - log.Fatal("HOME environment variable not set") - } - configPath = filepath.Join(homePath, ".config") - } - configFilePath = filepath.Join(configPath, "gh-mrva", "config.yml") - configData, err := getConfig() - if err != nil { - log.Fatal(err) - } - - sessionsFilePath = filepath.Join(configPath, "gh-mrva", "sessions.yml") - if _, err := os.Stat(sessionsFilePath); os.IsNotExist(err) { - err := os.MkdirAll(filepath.Dir(sessionsFilePath), os.ModePerm) - if err != nil { - log.Fatal("Failed to create config directory") - } - // create empty file at sessionsFilePath - sessionsFile, err := os.Create(sessionsFilePath) - if err != nil { - log.Fatal("Failed to create sessions file") - } - sessionsFile.Close() - } - helpFlag := flag.String("help", "", "This help documentation.") - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva submit [--codeql-path ] [--controller ] --lang --name [--list-file ] --list [--query | --query-suite ] - - gh mrva download --name --output-dir [--download-dbs] [--nwo ] - - gh mrva status --name [--json] - - gh mrva list [--json] - - gh mrva delete --name -`) - } - - flag.Parse() - - if *helpFlag != "" { - flag.Usage() - os.Exit(0) - } - - args := flag.Args() - if len(args) == 0 { - flag.Usage() - os.Exit(0) - } - cmd, args := args[0], args[1:] - - switch cmd { - case "submit": - submit(configData, args) - case "download": - download(args) - case "status": - status(args) - case "list": - list(args) - case "delete": - del(args) - default: - log.Fatalf("Unrecognized command %q. 
"+ - "Command must be one of: submit, download", cmd) - } -} - -func status(args []string) { - flag := flag.NewFlagSet("mrva status", flag.ExitOnError) - nameFlag := flag.String("name", "", "Name of run") - jsonFlag := flag.Bool("json", false, "Output in JSON format (default: false)") - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva status --name [--json] - -`) - fmt.Fprintf(os.Stderr, "Flags:\n") - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\n") - } - - flag.Parse(args) - - var ( - runName = *nameFlag - jsonOutput = *jsonFlag - ) - - if runName == "" { - flag.Usage() - os.Exit(1) - } - - controller, runs, _, err := loadSession(runName) - if err != nil { - log.Fatal(err) - } - if len(runs) == 0 { - log.Fatal("No runs found for run name", runName) - } - - type RunStatus struct { - Id int `json:"id"` - Query string `json:"query"` - Status string `json:"status"` - FailureReason string `json:"failure_reason"` - } - - type RepoWithFindings struct { - Nwo string `json:"nwo"` - Count int `json:"count"` - RunId int `json:"run_id"` - Stars int `json:"stars"` - } - type Results struct { - Runs []RunStatus `json:"runs"` - ResositoriesWithFindings []RepoWithFindings `json:"repositories_with_findings"` - TotalFindingsCount int `json:"total_findings_count"` - TotalSuccessfulScans int `json:"total_successful_scans"` - TotalFailedScans int `json:"total_failed_scans"` - TotalRepositoriesWithFindings int `json:"total_repositories_with_findings"` - TotalSkippedRepositories int `json:"total_skipped_repositories"` - TotalSkippedAccessMismatchRepositories int `json:"total_skipped_access_mismatch_repositories"` - TotalSkippedNotFoundRepositories int `json:"total_skipped_not_found_repositories"` - TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"` - TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"` - } - - 
var results Results - - for _, run := range runs { - if err != nil { - log.Fatal(err) - } - runDetails, err := getRunDetails(controller, run.Id) - if err != nil { - log.Fatal(err) - } - - status := runDetails["status"].(string) - var failure_reason string - if status == "failed" { - failure_reason = runDetails["failure_reason"].(string) - } else { - failure_reason = "" - } - - results.Runs = append(results.Runs, RunStatus{ - Id: run.Id, - Query: run.Query, - Status: status, - FailureReason: failure_reason, - }) - - for _, repo := range runDetails["scanned_repositories"].([]interface{}) { - if repo.(map[string]interface{})["analysis_status"].(string) == "succeeded" { - results.TotalSuccessfulScans += 1 - if repo.(map[string]interface{})["result_count"].(float64) > 0 { - results.TotalRepositoriesWithFindings += 1 - results.TotalFindingsCount += int(repo.(map[string]interface{})["result_count"].(float64)) - repoInfo := repo.(map[string]interface{})["repository"].(map[string]interface{}) - results.ResositoriesWithFindings = append(results.ResositoriesWithFindings, RepoWithFindings{ - Nwo: repoInfo["full_name"].(string), - Count: int(repo.(map[string]interface{})["result_count"].(float64)), - RunId: run.Id, - Stars: int(repoInfo["stargazers_count"].(float64)), - }) - } - } else if repo.(map[string]interface{})["analysis_status"].(string) == "failed" { - results.TotalFailedScans += 1 - } - } - - skipped_repositories := runDetails["skipped_repositories"].(map[string]interface{}) - access_mismatch_repos := skipped_repositories["access_mismatch_repos"].(map[string]interface{}) - not_found_repos := skipped_repositories["not_found_repos"].(map[string]interface{}) - no_codeql_db_repos := skipped_repositories["no_codeql_db_repos"].(map[string]interface{}) - over_limit_repos := skipped_repositories["over_limit_repos"].(map[string]interface{}) - total_skipped_repos := access_mismatch_repos["repository_count"].(float64) + not_found_repos["repository_count"].(float64) + 
no_codeql_db_repos["repository_count"].(float64) + over_limit_repos["repository_count"].(float64) - - results.TotalSkippedAccessMismatchRepositories += int(access_mismatch_repos["repository_count"].(float64)) - results.TotalSkippedNotFoundRepositories += int(not_found_repos["repository_count"].(float64)) - results.TotalSkippedNoDatabaseRepositories += int(no_codeql_db_repos["repository_count"].(float64)) - results.TotalSkippedOverLimitRepositories += int(over_limit_repos["repository_count"].(float64)) - results.TotalSkippedRepositories += int(total_skipped_repos) - } - - if jsonOutput { - data, err := json.MarshalIndent(results, "", " ") - if err != nil { - log.Fatal(err) - } - fmt.Println(string(data)) - } else { - fmt.Println("Run name:", runName) - fmt.Println("Total runs:", len(results.Runs)) - fmt.Println("Total successful scans:", results.TotalSuccessfulScans) - fmt.Println("Total failed scans:", results.TotalFailedScans) - fmt.Println("Total skipped repositories:", results.TotalSkippedRepositories) - fmt.Println("Total skipped repositories due to access mismatch:", results.TotalSkippedAccessMismatchRepositories) - fmt.Println("Total skipped repositories due to not found:", results.TotalSkippedNotFoundRepositories) - fmt.Println("Total skipped repositories due to no database:", results.TotalSkippedNoDatabaseRepositories) - fmt.Println("Total skipped repositories due to over limit:", results.TotalSkippedOverLimitRepositories) - fmt.Println("Total repositories with findings:", results.TotalRepositoriesWithFindings) - fmt.Println("Total findings:", results.TotalFindingsCount) - fmt.Println("Repositories with findings:") - for _, repo := range results.ResositoriesWithFindings { - fmt.Println(" ", repo.Nwo, ":", repo.Count) - } - } -} - -func submit(configData Config, args []string) { - - flag := flag.NewFlagSet("mrva submit", flag.ExitOnError) - queryFileFlag := flag.String("query", "", "Path to query file") - querySuiteFileFlag := flag.String("query-suite", "", 
"Path to query suite file") - controllerFlag := flag.String("controller", "", "MRVA controller repository (overrides config file)") - codeqlPathFlag := flag.String("codeql-path", "", "Path to CodeQL distribution (overrides config file)") - listFileFlag := flag.String("list-file", "", "Path to repo list file (overrides config file)") - listFlag := flag.String("list", "", "Name of repo list") - langFlag := flag.String("lang", "", "DB language") - nameFlag := flag.String("name", "", "Name of run") - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva submit [--codeql-path ] [--controller ] --lang --name [--list-file ] --list [--query | --query-suite ] - -`) - fmt.Fprintf(os.Stderr, "Flags:\n") - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\n") - } - - flag.Parse(args) - - var ( - controller string - codeqlPath string - listFile string - list string - language string - runName string - queryFile string - querySuiteFile string - ) - if *controllerFlag != "" { - controller = *controllerFlag - } else if configData.Controller != "" { - controller = configData.Controller - } - if *listFileFlag != "" { - listFile = *listFileFlag - } else if configData.ListFile != "" { - listFile = configData.ListFile - } - if *codeqlPathFlag != "" { - codeqlPath = *codeqlPathFlag - } else if configData.CodeQLPath != "" { - codeqlPath = configData.CodeQLPath - } - if *langFlag != "" { - language = *langFlag - } - if *nameFlag != "" { - runName = *nameFlag - } - if *listFlag != "" { - list = *listFlag - } - if *queryFileFlag != "" { - queryFile = *queryFileFlag - } - if *querySuiteFileFlag != "" { - querySuiteFile = *querySuiteFileFlag - } - - if runName == "" || codeqlPath == "" || controller == "" || language == "" || listFile == "" || list == "" || (queryFile == "" && querySuiteFile == "") { - flag.Usage() - os.Exit(1) - } - - // read list of target repositories - repositories, err := 
resolveRepositories(listFile, list) - if err != nil { - log.Fatal(err) - } - - // if a query suite is specified, resolve the queries - queries := []string{} - if *queryFileFlag != "" { - queries = append(queries, *queryFileFlag) - } else if *querySuiteFileFlag != "" { - queries = resolveQueries(codeqlPath, querySuiteFile) - } - - fmt.Printf("Submitting %d queries for %d repositories\n", len(queries), len(repositories)) - var runs []Run - for _, query := range queries { - encodedBundle, err := generateQueryPack(codeqlPath, query, language) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Generated encoded bundle for %s\n", query) - - var chunks [][]string - for i := 0; i < len(repositories); i += MAX_MRVA_REPOSITORIES { - end := i + MAX_MRVA_REPOSITORIES - if end > len(repositories) { - end = len(repositories) - } - chunks = append(chunks, repositories[i:end]) - } - for _, chunk := range chunks { - id, err := submitRun(controller, language, chunk, encodedBundle) - if err != nil { - log.Fatal(err) - } - runs = append(runs, Run{Id: id, Query: query}) - } - - } - if querySuiteFile != "" { - err = saveSession(runName, controller, runs, language, listFile, list, querySuiteFile, len(repositories)) - } else if queryFile != "" { - err = saveSession(runName, controller, runs, language, listFile, list, queryFile, len(repositories)) - } - if err != nil { - log.Fatal(err) - } - fmt.Println("Done!") -} - -func del(args []string) { - flag := flag.NewFlagSet("mrva delete", flag.ExitOnError) - nameFlag := flag.String("name", "", "Session name to be deleted") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva delete --name - -`) - fmt.Fprintf(os.Stderr, "Flags:\n") - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\n") - } - - flag.Parse(args) - - if *nameFlag == "" { - flag.Usage() - os.Exit(1) - } - - err := deleteSession(*nameFlag) - if err == nil { - fmt.Println("Session 
deleted") - } else { - log.Fatal(err) - } -} - -func list(args []string) { - flag := flag.NewFlagSet("mrva list", flag.ExitOnError) - jsonFlag := flag.Bool("json", false, "Output in JSON format (default: false)") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva list [--json] - -`) - fmt.Fprintf(os.Stderr, "Flags:\n") - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\n") - } - - flag.Parse(args) - - var jsonOutput = *jsonFlag - - sessions, err := getSessions() - if err != nil { - log.Fatal(err) - } - if sessions != nil { - if jsonOutput { - sessions_list := make([]Session, 0, len(sessions)) - for _, session := range sessions { - sessions_list = append(sessions_list, session) - } - data, err := json.MarshalIndent(sessions_list, "", " ") - if err != nil { - log.Fatal(err) - } - fmt.Println(string(data)) - // w := &bytes.Buffer{} - // jsonpretty.Format(w, bytes.NewReader(data), " ", true) - // fmt.Println(w.String()) - } else { - for name, entry := range sessions { - fmt.Printf("%s (%v)\n", name, entry.Timestamp) - fmt.Printf(" Controller: %s\n", entry.Controller) - fmt.Printf(" Language: %s\n", entry.Language) - fmt.Printf(" List file: %s\n", entry.ListFile) - fmt.Printf(" List: %s\n", entry.List) - fmt.Printf(" Repository count: %d\n", entry.RepositoryCount) - fmt.Println(" Runs:") - for _, run := range entry.Runs { - fmt.Printf(" ID: %d\n", run.Id) - fmt.Printf(" Query: %s\n", run.Query) - } - } - } - } -} - -func download(args []string) { - flag := flag.NewFlagSet("mrva download", flag.ExitOnError) - nameFlag := flag.String("name", "", "Name of run") - outputDirFlag := flag.String("output-dir", "", "Output directory") - downloadDBsFlag := flag.Bool("download-dbs", false, "Download databases (optional)") - nwoFlag := flag.String("nwo", "", "Repository to download artifacts for (optional)") - - flag.Usage = func() { - fmt.Fprintf(os.Stderr, `gh mrva - Run CodeQL 
queries at scale using Multi-Repository Variant Analysis (MRVA) - -Usage: - gh mrva download --name --output-dir [--download-dbs] [--nwo ] - -`) - fmt.Fprintf(os.Stderr, "Flags:\n") - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, "\n") - } - - flag.Parse(args) - - var ( - runName = *nameFlag - outputDir = *outputDirFlag - downloadDBs = *downloadDBsFlag - targetNwo = *nwoFlag - ) - - if runName == "" || outputDir == "" { - flag.Usage() - os.Exit(1) - } - - // if outputDir does not exist, create it - if _, err := os.Stat(outputDir); os.IsNotExist(err) { - err := os.MkdirAll(outputDir, os.ModePerm) - if err != nil { - log.Fatal(err) - } - } - - controller, runs, language, err := loadSession(runName) - if err != nil { - log.Fatal(err) - } else if len(runs) == 0 { - log.Fatal("No runs found for name " + runName) - } - - var downloadTasks []DownloadTask - - for _, run := range runs { - runDetails, err := getRunDetails(controller, run.Id) - if err != nil { - log.Fatal(err) - } - if runDetails["status"] == "in_progress" { - log.Printf("Run %d is not complete yet. 
Please try again later.", run.Id) - return - } - for _, r := range runDetails["scanned_repositories"].([]interface{}) { - repo := r.(map[string]interface{}) - result_count := repo["result_count"] - repoInfo := repo["repository"].(map[string]interface{}) - nwo := repoInfo["full_name"].(string) - // if targetNwo is set, only download artifacts for that repository - if targetNwo != "" && targetNwo != nwo { - continue - } - if result_count != nil && result_count.(float64) > 0 { - // check if the SARIF or BQRS file already exists - dnwo := strings.Replace(nwo, "/", "_", -1) - sarifPath := filepath.Join(outputDir, fmt.Sprintf("%s.sarif", dnwo)) - bqrsPath := filepath.Join(outputDir, fmt.Sprintf("%s.bqrs", dnwo)) - targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) - _, bqrsErr := os.Stat(bqrsPath) - _, sarifErr := os.Stat(sarifPath) - if errors.Is(bqrsErr, os.ErrNotExist) && errors.Is(sarifErr, os.ErrNotExist) { - downloadTasks = append(downloadTasks, DownloadTask{ - runId: run.Id, - nwo: nwo, - controller: controller, - artifact: "artifact", - language: language, - outputDir: outputDir, - }) - } - if downloadDBs { - // check if the database already exists - if _, err := os.Stat(targetPath); errors.Is(err, os.ErrNotExist) { - downloadTasks = append(downloadTasks, DownloadTask{ - runId: run.Id, - nwo: nwo, - controller: controller, - artifact: "database", - language: language, - outputDir: outputDir, - }) - } - } - } - } - } - - wg := new(sync.WaitGroup) - - taskChannel := make(chan DownloadTask) - resultChannel := make(chan DownloadTask, len(downloadTasks)) - - // Start the workers - for i := 0; i < WORKERS; i++ { - wg.Add(1) - go downloadWorker(wg, taskChannel, resultChannel) - } - - // Send jobs to worker - for _, downloadTask := range downloadTasks { - taskChannel <- downloadTask - } - close(taskChannel) - - count := 0 - progressDone := make(chan bool) - - go func() { - for value := range resultChannel { - count++ - 
fmt.Printf("Downloaded %s for %s (%d/%d)\n", value.artifact, value.nwo, count, len(downloadTasks)) - } - fmt.Println(count, " artifacts downloaded") - progressDone <- true - }() - - // wait for all workers to finish - wg.Wait() - - // close the result channel - close(resultChannel) - - // drain the progress channel - <-progressDone + cmd.Execute() } diff --git a/models/models.go b/models/models.go new file mode 100644 index 0000000..8f3861c --- /dev/null +++ b/models/models.go @@ -0,0 +1,65 @@ +package models + +import ( + "time" +) + +type Run struct { + Id int `yaml:"id"` + Query string `yaml:"query"` +} + +type Session struct { + Name string `yaml:"name" json:"name"` + Timestamp time.Time `yaml:"timestamp" json:"timestamp"` + Runs []Run `yaml:"runs" json:"runs"` + Controller string `yaml:"controller" json:"controller"` + ListFile string `yaml:"list_file" json:"list_file"` + List string `yaml:"list" json:"list"` + Language string `yaml:"language" json:"language"` + RepositoryCount int `yaml:"repository_count" json:"repository_count"` +} + +type Config struct { + Controller string `yaml:"controller"` + ListFile string `yaml:"list_file"` + CodeQLPath string `yaml:"codeql_path"` +} + +type DownloadTask struct { + RunId int + Nwo string + Controller string + Artifact string + OutputDir string + Language string +} + +type RunStatus struct { + Id int `json:"id"` + Query string `json:"query"` + Status string `json:"status"` + FailureReason string `json:"failure_reason"` +} + +type RepoWithFindings struct { + Nwo string `json:"nwo"` + Count int `json:"count"` + RunId int `json:"run_id"` + Stars int `json:"stars"` +} + +type Results struct { + Runs []RunStatus `json:"runs"` + ResositoriesWithFindings []RepoWithFindings `json:"repositories_with_findings"` + TotalFindingsCount int `json:"total_findings_count"` + TotalSuccessfulScans int `json:"total_successful_scans"` + TotalFailedScans int `json:"total_failed_scans"` + TotalRepositoriesWithFindings int 
`json:"total_repositories_with_findings"` + TotalSkippedRepositories int `json:"total_skipped_repositories"` + TotalSkippedAccessMismatchRepositories int `json:"total_skipped_access_mismatch_repositories"` + TotalSkippedNotFoundRepositories int `json:"total_skipped_not_found_repositories"` + TotalSkippedNoDatabaseRepositories int `json:"total_skipped_no_database_repositories"` + TotalSkippedOverLimitRepositories int `json:"total_skipped_over_limit_repositories"` +} + diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..5ade6c2 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,560 @@ +package utils + +import ( + "archive/zip" + "sync" + "text/template" + "github.com/google/uuid" + "encoding/base64" + "path/filepath" + "strings" + "os/exec" + "bytes" + "encoding/json" + "fmt" + "time" + "os" + "gopkg.in/yaml.v3" + "io/ioutil" + "log" + "errors" + + "github.com/cli/go-gh" + "github.com/cli/go-gh/pkg/api" + "github.com/GitHubSecurityLab/gh-mrva/models" +) + +var ( + configFilePath string + sessionsFilePath string +) + +func GetSessionsFilePath() string { + return sessionsFilePath +} + +func SetSessionsFilePath(path string) { + sessionsFilePath = path +} + +func GetConfigFilePath() string { + return configFilePath +} + +func SetConfigFilePath(path string) { + configFilePath = path +} + +func GetSessions() (map[string]models.Session, error) { + sessionsFile, err := ioutil.ReadFile(sessionsFilePath) + var sessions map[string]models.Session + if err != nil { + return sessions, err + } + err = yaml.Unmarshal(sessionsFile, &sessions) + if err != nil { + log.Fatal(err) + } + return sessions, nil +} + +func LoadSession(name string) (string, []models.Run, string, error) { + sessions, err := GetSessions() + if err != nil { + return "", nil, "", err + } + if sessions != nil { + if entry, ok := sessions[name]; ok { + return entry.Controller, entry.Runs, entry.Language, nil + } + } + return "", nil, "", errors.New("No session found for " + name) +} + +func 
GetRunDetails(controller string, runId int) (map[string]interface{}, error) { + opts := api.ClientOptions{ + Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, + } + client, err := gh.RESTClient(&opts) + if err != nil { + return nil, err + } + response := make(map[string]interface{}) + err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d", controller, runId), &response) + if err != nil { + return nil, err + } + return response, nil +} + +func GetRunRepositoryDetails(controller string, runId int, nwo string) (map[string]interface{}, error) { + opts := api.ClientOptions{ + Headers: map[string]string{"Accept": "application/vnd.github.v3+json"}, + } + client, err := gh.RESTClient(&opts) + if err != nil { + return nil, err + } + response := make(map[string]interface{}) + err = client.Get(fmt.Sprintf("repos/%s/code-scanning/codeql/variant-analyses/%d/repos/%s", controller, runId, nwo), &response) + if err != nil { + return nil, err + } + return response, nil +} + +func SaveSession(name string, controller string, runs []models.Run, language string, listFile string, list string, query string, count int) error { + sessions, err := GetSessions() + if err != nil { + return err + } + if sessions == nil { + sessions = make(map[string]models.Session) + } + // add new session if it doesn't already exist + if _, ok := sessions[name]; ok { + return errors.New(fmt.Sprintf("Session '%s' already exists", name)) + } else { + sessions[name] = models.Session{ + Name: name, + Runs: runs, + Timestamp: time.Now(), + Controller: controller, + Language: language, + ListFile: listFile, + List: list, + RepositoryCount: count, + } + } + // marshal sessions to yaml + sessionsYaml, err := yaml.Marshal(sessions) + if err != nil { + return err + } + // write sessions to file + err = ioutil.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm) + if err != nil { + return err + } + return nil +} + +func SubmitRun(controller string, language string, 
// ResolveRepositories loads the JSON document at listFile, which must be an
// object mapping list names to arrays of strings, and returns the array stored
// under list.
//
// It returns an error when the file cannot be read, when it is not valid JSON
// of that shape, or when it has no entry called list. The last case used to
// return a nil slice with a nil error, which made a mistyped list name look
// like a valid but empty list.
func ResolveRepositories(listFile string, list string) ([]string, error) {
	fmt.Printf("Resolving %s repositories from %s\n", list, listFile)

	data, err := ioutil.ReadFile(listFile)
	if err != nil {
		return nil, err
	}

	var repoLists map[string][]string
	if err := json.Unmarshal(data, &repoLists); err != nil {
		return nil, fmt.Errorf("parsing repository list file %q: %w", listFile, err)
	}

	repos, ok := repoLists[list]
	if !ok {
		return nil, fmt.Errorf("list %q not found in %s", list, listFile)
	}
	return repos, nil
}
// RunCodeQLCommand runs the `codeql` executable found on PATH with args and
// returns what it printed.
//
// For every invocation except `codeql pack packlist`, the flag
// `--additional-packs=<codeqlPath>` is appended to the arguments (presumably
// that subcommand does not accept the flag — confirm against the CodeQL CLI
// manual). When combined is true the result holds stdout and stderr
// interleaved; otherwise only stdout is returned.
//
// The returned error is whatever os/exec reports, e.g. when the binary cannot
// be found or exits with a non-zero status.
func RunCodeQLCommand(codeqlPath string, combined bool, args ...string) ([]byte, error) {
	// Build a fresh slice so that appending the extra flag can never write into
	// spare capacity of the slice the caller passed with `args...`.
	cmdArgs := make([]string, 0, len(args)+1)
	cmdArgs = append(cmdArgs, args...)

	// Detect the subcommand by position. Searching the joined argument string
	// for "packlist" also matched unrelated arguments such as a query-suite
	// path containing that word, silently dropping the flag for them.
	isPacklist := len(args) >= 2 && args[0] == "pack" && args[1] == "packlist"
	if !isPacklist {
		cmdArgs = append(cmdArgs, fmt.Sprintf("--additional-packs=%s", codeqlPath))
	}

	cmd := exec.Command("codeql", cmdArgs...)
	cmd.Env = os.Environ()
	if combined {
		return cmd.CombinedOutput()
	}
	return cmd.Output()
}
Generating synthetic one for %s\n", queryFile) + // copy only the query file to the query pack directory + err := CopyFile(queryFile, targetQueryFileName) + if err != nil { + log.Fatal(err) + } + // generate a synthetic qlpack.yml + td := struct { + Language string + Name string + Query string + }{ + Language: language, + Name: "codeql-remote/query", + Query: strings.Replace(packRelativePath, string(os.PathSeparator), "/", -1), + } + t, err := template.New("").Parse(`name: {{ .Name }} +version: 0.0.0 +dependencies: + codeql/{{ .Language }}-all: "*" +defaultSuite: + description: Query suite for variant analysis + query: {{ .Query }}`) + if err != nil { + log.Fatal(err) + } + + f, err := os.Create(filepath.Join(queryPackDir, "qlpack.yml")) + defer f.Close() + if err != nil { + log.Fatal(err) + } + err = t.Execute(f, td) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Copied QLPack files to %s\n", queryPackDir) + } else { + // don't include all query files in the QLPacks. We only want the queryFile to be copied. 
+ fmt.Printf("QLPack exists, stripping all other queries from %s\n", originalPackRoot) + toCopy := PackPacklist(codeqlPath, originalPackRoot, false) + // also copy the lock file (either new name or old name) and the query file itself (these are not included in the packlist) + lockFileNew := filepath.Join(originalPackRoot, "qlpack.lock.yml") + lockFileOld := filepath.Join(originalPackRoot, "codeql-pack.lock.yml") + candidateFiles := []string{lockFileNew, lockFileOld, queryFile} + for _, candidateFile := range candidateFiles { + if _, err := os.Stat(candidateFile); !errors.Is(err, os.ErrNotExist) { + // if the file exists, copy it + toCopy = append(toCopy, candidateFile) + } + } + // copy the files to the queryPackDir directory + fmt.Printf("Preparing stripped QLPack in %s\n", queryPackDir) + for _, srcPath := range toCopy { + relPath, _ := filepath.Rel(originalPackRoot, srcPath) + targetPath := filepath.Join(queryPackDir, relPath) + //fmt.Printf("Copying %s to %s\n", srcPath, targetPath) + err := CopyFile(srcPath, targetPath) + if err != nil { + log.Fatal(err) + } + } + fmt.Printf("Fixing QLPack in %s\n", queryPackDir) + FixPackFile(queryPackDir, packRelativePath) + } + + // assuming we are using 2.11.3 or later so Qlx remote is supported + ccache := filepath.Join(originalPackRoot, ".cache") + precompilationOpts := []string{"--qlx", "--no-default-compilation-cache", "--compilation-cache=" + ccache} + bundlePath := filepath.Join(filepath.Dir(queryPackDir), fmt.Sprintf("qlpack-%s-generated.tgz", uuid.New().String())) + + // install the pack dependencies + fmt.Print("Installing QLPack dependencies\n") + args := []string{"pack", "install", queryPackDir} + stdouterr, err := RunCodeQLCommand(codeqlPath, true, args...) 
// FindPackRoot returns the directory of the query pack that contains
// queryFile.
//
// Starting at the directory of queryFile it walks up through the parent
// directories and returns the first one for which looking up "qlpack.yml"
// does not fail with a not-exist error. If the top of the path is reached
// without a match, the directory of queryFile is returned.
func FindPackRoot(queryFile string) string {
	startDir := filepath.Dir(queryFile)
	dir := startDir
	for {
		if _, err := os.Stat(filepath.Join(dir, "qlpack.yml")); !errors.Is(err, os.ErrNotExist) {
			return dir
		}
		// filepath.Dir returns its argument unchanged once there is nothing
		// left to strip ("/", ".", or a Windows volume root). Comparing with
		// the parent, rather than with the literal "/", is what guarantees
		// the loop terminates for relative paths and on Windows.
		parent := filepath.Dir(dir)
		if parent == dir {
			return startDir
		}
		dir = parent
	}
}
+ err = ioutil.WriteFile(targetPath, bytesRead, 0644) + if err != nil { + return err + } + return nil +} + +func DownloadWorker(wg *sync.WaitGroup, taskChannel <-chan models.DownloadTask, resultChannel chan models.DownloadTask) { + defer wg.Done() + for task := range taskChannel { + if task.Artifact == "artifact" { + DownloadResults(task.Controller, task.RunId, task.Nwo, task.OutputDir) + resultChannel <- task + } else if task.Artifact == "database" { + fmt.Println("Downloading database", task.Nwo, task.Language, task.OutputDir) + DownloadDatabase(task.Nwo, task.Language, task.OutputDir) + resultChannel <- task + } + } +} + +func downloadArtifact(url string, outputDir string, nwo string) error { + client, err := gh.HTTPClient(nil) + if err != nil { + return err + } + resp, err := client.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + log.Fatal(err) + } + + for _, zf := range zipReader.File { + if zf.Name != "results.sarif" && zf.Name != "results.bqrs" { + continue + } + f, err := zf.Open() + if err != nil { + log.Fatal(err) + } + defer f.Close() + bytes, err := ioutil.ReadAll(f) + if err != nil { + log.Fatal(err) + } + extension := "" + resultPath := "" + if zf.Name == "results.bqrs" { + extension = "bqrs" + } else if zf.Name == "results.sarif" { + extension = "sarif" + } + resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)) + err = ioutil.WriteFile(resultPath, bytes, os.ModePerm) + if err != nil { + return err + } + return nil + } + return errors.New("No results.sarif file found in artifact") +} + +func DownloadResults(controller string, runId int, nwo string, outputDir string) error { + // download artifact (BQRS or SARIF) + runRepositoryDetails, err := GetRunRepositoryDetails(controller, runId, nwo) + if err != 
nil { + return errors.New("Failed to get run repository details") + } + // download the results + err = downloadArtifact(runRepositoryDetails["artifact_url"].(string), outputDir, nwo) + if err != nil { + return errors.New("Failed to download artifact") + } + return nil +} + +func DownloadDatabase(nwo string, language string, outputDir string) error { + dnwo := strings.Replace(nwo, "/", "_", -1) + targetPath := filepath.Join(outputDir, fmt.Sprintf("%s_%s_db.zip", dnwo, language)) + opts := api.ClientOptions{ + Headers: map[string]string{"Accept": "application/zip"}, + } + client, err := gh.HTTPClient(&opts) + if err != nil { + return err + } + resp, err := client.Get(fmt.Sprintf("https://api.github.com/repos/%s/code-scanning/codeql/databases/%s", nwo, language)) + if err != nil { + return err + } + defer resp.Body.Close() + + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + err = ioutil.WriteFile(targetPath, bytes, os.ModePerm) + return nil +} +