Formatting

This commit is contained in:
Alvaro Muñoz
2023-09-11 16:55:32 +02:00
parent 81de79e4c8
commit 361fa7c833
5 changed files with 122 additions and 122 deletions

View File

@@ -5,16 +5,16 @@ Copyright © 2023 NAME HERE <EMAIL ADDRESS>
package cmd package cmd
import ( import (
"sync" "errors"
"errors"
"fmt" "fmt"
"os" "github.com/GitHubSecurityLab/gh-mrva/config"
"path/filepath" "github.com/GitHubSecurityLab/gh-mrva/models"
"strings" "github.com/GitHubSecurityLab/gh-mrva/utils"
"log" "log"
"github.com/GitHubSecurityLab/gh-mrva/utils" "os"
"github.com/GitHubSecurityLab/gh-mrva/models" "path/filepath"
"github.com/GitHubSecurityLab/gh-mrva/config" "strings"
"sync"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -22,9 +22,9 @@ import (
var downloadCmd = &cobra.Command{ var downloadCmd = &cobra.Command{
Use: "download", Use: "download",
Short: "Downloads the artifacts associated to a given session.", Short: "Downloads the artifacts associated to a given session.",
Long: `Downloads the artifacts associated to a given session.`, Long: `Downloads the artifacts associated to a given session.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
downloadArtifacts() downloadArtifacts()
}, },
} }

View File

@@ -22,33 +22,35 @@ THE SOFTWARE.
package cmd package cmd
import ( import (
"github.com/GitHubSecurityLab/gh-mrva/utils"
"log"
"os" "os"
"log"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"path/filepath" "path/filepath"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
sessionNameFlag string sessionNameFlag string
sessionPrefixFlag string runIdFlag int
outputDirFlag string sessionPrefixFlag string
downloadDBsFlag bool outputDirFlag string
nwoFlag string outputFilenameFlag string
jsonFlag bool downloadDBsFlag bool
languageFlag string nwoFlag string
listFileFlag string jsonFlag bool
listFlag string languageFlag string
codeqlPathFlag string listFileFlag string
controllerFlag string listFlag string
queryFileFlag string codeqlPathFlag string
querySuiteFileFlag string controllerFlag string
) queryFileFlag string
querySuiteFileFlag string
)
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
Use: "gh-mrva", Use: "gh-mrva",
Short: "Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)", Short: "Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)",
Long: `Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)`, Long: `Run CodeQL queries at scale using GitHub's Multi-Repository Variant Analysis (MRVA)`,
} }
func Execute() { func Execute() {
@@ -67,10 +69,10 @@ func init() {
} }
configPath = filepath.Join(homePath, ".config") configPath = filepath.Join(homePath, ".config")
} }
configFilePath := filepath.Join(configPath, "gh-mrva", "config.yml") configFilePath := filepath.Join(configPath, "gh-mrva", "config.yml")
utils.SetConfigFilePath(configFilePath) utils.SetConfigFilePath(configFilePath)
sessionsFilePath := filepath.Join(configPath, "gh-mrva", "sessions.yml") sessionsFilePath := filepath.Join(configPath, "gh-mrva", "sessions.yml")
if _, err := os.Stat(sessionsFilePath); os.IsNotExist(err) { if _, err := os.Stat(sessionsFilePath); os.IsNotExist(err) {
err := os.MkdirAll(filepath.Dir(sessionsFilePath), os.ModePerm) err := os.MkdirAll(filepath.Dir(sessionsFilePath), os.ModePerm)
if err != nil { if err != nil {
@@ -83,5 +85,5 @@ func init() {
} }
sessionsFile.Close() sessionsFile.Close()
} }
utils.SetSessionsFilePath(sessionsFilePath) utils.SetSessionsFilePath(sessionsFilePath)
} }

View File

@@ -1,37 +1,35 @@
/* /*
Copyright © 2023 NAME HERE <EMAIL ADDRESS> Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/ */
package cmd package cmd
import ( import (
"fmt" "fmt"
"log" "log"
"os" "os"
"github.com/GitHubSecurityLab/gh-mrva/config"
"github.com/GitHubSecurityLab/gh-mrva/models"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/GitHubSecurityLab/gh-mrva/utils"
"github.com/GitHubSecurityLab/gh-mrva/config"
"github.com/GitHubSecurityLab/gh-mrva/models"
) )
var ( var (
controller string controller string
codeqlPath string codeqlPath string
listFile string listFile string
listName string listName string
language string language string
sessionName string sessionName string
queryFile string queryFile string
querySuiteFile string querySuiteFile string
) )
var submitCmd = &cobra.Command{ var submitCmd = &cobra.Command{
Use: "submit", Use: "submit",
Short: "Submit a query or query suite to a MRVA controller.", Short: "Submit a query or query suite to a MRVA controller.",
Long: `Submit a query or query suite to a MRVA controller.`, Long: `Submit a query or query suite to a MRVA controller.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
submitQuery() submitQuery()
}, },
} }
@@ -40,18 +38,18 @@ func init() {
submitCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name") submitCmd.Flags().StringVarP(&sessionNameFlag, "session", "s", "", "Session name")
submitCmd.Flags().StringVarP(&languageFlag, "language", "l", "", "DB language") submitCmd.Flags().StringVarP(&languageFlag, "language", "l", "", "DB language")
submitCmd.Flags().StringVarP(&queryFileFlag, "query", "q", "", "Path to query file") submitCmd.Flags().StringVarP(&queryFileFlag, "query", "q", "", "Path to query file")
submitCmd.Flags().StringVarP(&querySuiteFileFlag, "query-suite","x", "", "Path to query suite file") submitCmd.Flags().StringVarP(&querySuiteFileFlag, "query-suite", "x", "", "Path to query suite file")
submitCmd.Flags().StringVarP(&controllerFlag, "controller", "c", "", "MRVA controller repository (overrides config file)") submitCmd.Flags().StringVarP(&controllerFlag, "controller", "c", "", "MRVA controller repository (overrides config file)")
submitCmd.Flags().StringVarP(&listFileFlag, "list-file", "f", "", "Path to repo list file (overrides config file)") submitCmd.Flags().StringVarP(&listFileFlag, "list-file", "f", "", "Path to repo list file (overrides config file)")
submitCmd.Flags().StringVarP(&listFlag, "list", "i", "", "Name of repo list") submitCmd.Flags().StringVarP(&listFlag, "list", "i", "", "Name of repo list")
submitCmd.Flags().StringVarP(&codeqlPathFlag, "codeql-path", "p", "", "Path to CodeQL distribution (overrides config file)") submitCmd.Flags().StringVarP(&codeqlPathFlag, "codeql-path", "p", "", "Path to CodeQL distribution (overrides config file)")
submitCmd.MarkFlagRequired("session") submitCmd.MarkFlagRequired("session")
submitCmd.MarkFlagRequired("language") submitCmd.MarkFlagRequired("language")
submitCmd.MarkFlagsMutuallyExclusive("query", "query-suite") submitCmd.MarkFlagsMutuallyExclusive("query", "query-suite")
} }
func submitQuery() { func submitQuery() {
configData, err := utils.GetConfig() configData, err := utils.GetConfig()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -88,26 +86,26 @@ func submitQuery() {
} }
if controller == "" { if controller == "" {
fmt.Println("Please specify a controller.") fmt.Println("Please specify a controller.")
os.Exit(1) os.Exit(1)
} }
if listFile == "" { if listFile == "" {
fmt.Println("Please specify a list file.") fmt.Println("Please specify a list file.")
os.Exit(1) os.Exit(1)
} }
if listName == "" { if listName == "" {
fmt.Println("Please specify a list name.") fmt.Println("Please specify a list name.")
os.Exit(1) os.Exit(1)
} }
if queryFile == "" && querySuiteFile == "" { if queryFile == "" && querySuiteFile == "" {
fmt.Println("Please specify a query or query suite.") fmt.Println("Please specify a query or query suite.")
os.Exit(1) os.Exit(1)
} }
if _, _, _, err := utils.LoadSession(sessionName); err == nil { if _, _, _, err := utils.LoadSession(sessionName); err == nil {
fmt.Println("Session already exists.") fmt.Println("Session already exists.")
os.Exit(1) os.Exit(1)
} }
// read list of target repositories // read list of target repositories
repositories, err := utils.ResolveRepositories(listFile, listName) repositories, err := utils.ResolveRepositories(listFile, listName)
@@ -159,4 +157,3 @@ func submitQuery() {
} }
fmt.Println("Done!") fmt.Println("Done!")
} }

View File

@@ -1,7 +1,7 @@
package models package models
import ( import (
"time" "time"
) )
type Run struct { type Run struct {

View File

@@ -1,27 +1,28 @@
package utils package utils
import ( import (
"archive/zip" "archive/zip"
"sync" "bufio"
"text/template" "bytes"
"github.com/google/uuid"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"gopkg.in/yaml.v3"
"io"
"log"
"os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"os/exec" "sync"
"bytes" "text/template"
"encoding/json" "time"
"fmt"
"time"
"os"
"gopkg.in/yaml.v3"
"io/ioutil"
"log"
"errors"
"github.com/GitHubSecurityLab/gh-mrva/models"
"github.com/cli/go-gh" "github.com/cli/go-gh"
"github.com/cli/go-gh/pkg/api" "github.com/cli/go-gh/pkg/api"
"github.com/GitHubSecurityLab/gh-mrva/models"
) )
var ( var (
@@ -30,23 +31,23 @@ var (
) )
func GetSessionsFilePath() string { func GetSessionsFilePath() string {
return sessionsFilePath return sessionsFilePath
} }
func SetSessionsFilePath(path string) { func SetSessionsFilePath(path string) {
sessionsFilePath = path sessionsFilePath = path
} }
func GetConfigFilePath() string { func GetConfigFilePath() string {
return configFilePath return configFilePath
} }
func SetConfigFilePath(path string) { func SetConfigFilePath(path string) {
configFilePath = path configFilePath = path
} }
func GetSessions() (map[string]models.Session, error) { func GetSessions() (map[string]models.Session, error) {
sessionsFile, err := ioutil.ReadFile(sessionsFilePath) sessionsFile, err := os.ReadFile(sessionsFilePath)
var sessions map[string]models.Session var sessions map[string]models.Session
if err != nil { if err != nil {
return sessions, err return sessions, err
@@ -72,19 +73,19 @@ func LoadSession(name string) (string, []models.Run, string, error) {
} }
func GetSessionsStartingWith(prefix string) ([]string, error) { func GetSessionsStartingWith(prefix string) ([]string, error) {
sessions, err := GetSessions() sessions, err := GetSessions()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var matchingSessions []string var matchingSessions []string
if sessions != nil { if sessions != nil {
for session := range sessions { for session := range sessions {
if strings.HasPrefix(session, prefix) { if strings.HasPrefix(session, prefix) {
matchingSessions = append(matchingSessions, session) matchingSessions = append(matchingSessions, session)
} }
} }
} }
return matchingSessions, nil return matchingSessions, nil
} }
func GetRunDetails(controller string, runId int) (map[string]interface{}, error) { func GetRunDetails(controller string, runId int) (map[string]interface{}, error) {
@@ -148,7 +149,7 @@ func SaveSession(name string, controller string, runs []models.Run, language str
return err return err
} }
// write sessions to file // write sessions to file
err = ioutil.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm) err = os.WriteFile(sessionsFilePath, sessionsYaml, os.ModePerm)
if err != nil { if err != nil {
return err return err
} }
@@ -189,7 +190,7 @@ func SubmitRun(controller string, language string, repoChunk []string, bundle st
} }
func GetConfig() (models.Config, error) { func GetConfig() (models.Config, error) {
configFile, err := ioutil.ReadFile(configFilePath) configFile, err := os.ReadFile(configFilePath)
var configData models.Config var configData models.Config
if err != nil { if err != nil {
return configData, err return configData, err
@@ -208,7 +209,7 @@ func ResolveRepositories(listFile string, list string) ([]string, error) {
return nil, err return nil, err
} }
defer jsonFile.Close() defer jsonFile.Close()
byteValue, _ := ioutil.ReadAll(jsonFile) byteValue, _ := io.ReadAll(jsonFile)
var repoLists map[string][]string var repoLists map[string][]string
err = json.Unmarshal(byteValue, &repoLists) err = json.Unmarshal(byteValue, &repoLists)
if err != nil { if err != nil {
@@ -259,14 +260,14 @@ func ResolveQueries(codeqlPath string, querySuite string) []string {
args := []string{"resolve", "queries", "--format=json", querySuite} args := []string{"resolve", "queries", "--format=json", querySuite}
jsonBytes, err := RunCodeQLCommand(codeqlPath, false, args...) jsonBytes, err := RunCodeQLCommand(codeqlPath, false, args...)
var queries []string var queries []string
if strings.TrimSpace(string(jsonBytes)) == "" { if strings.TrimSpace(string(jsonBytes)) == "" {
fmt.Println("No queries found in the specified query suite.") fmt.Println("No queries found in the specified query suite.")
os.Exit(1) os.Exit(1)
} }
err = json.Unmarshal(jsonBytes, &queries) err = json.Unmarshal(jsonBytes, &queries)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
return queries return queries
} }
@@ -288,7 +289,7 @@ func GenerateQueryPack(codeqlPath string, queryFile string, language string) (st
fmt.Printf("Generating query pack for %s\n", queryFile) fmt.Printf("Generating query pack for %s\n", queryFile)
// create a temporary directory to hold the query pack // create a temporary directory to hold the query pack
queryPackDir, err := ioutil.TempDir("", "query-pack-") queryPackDir, err := os.MkdirTemp("", "query-pack-")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -406,7 +407,7 @@ defaultSuite:
return "", "", fmt.Errorf("Failed to open bundle file: %v\n", err) return "", "", fmt.Errorf("Failed to open bundle file: %v\n", err)
} }
defer bundleFile.Close() defer bundleFile.Close()
bundleBytes, err := ioutil.ReadAll(bundleFile) bundleBytes, err := io.ReadAll(bundleFile)
if err != nil { if err != nil {
return "", "", fmt.Errorf("Failed to read bundle file: %v\n", err) return "", "", fmt.Errorf("Failed to read bundle file: %v\n", err)
} }
@@ -449,7 +450,7 @@ func FindPackRoot(queryFile string) string {
func FixPackFile(queryPackDir string, packRelativePath string) error { func FixPackFile(queryPackDir string, packRelativePath string) error {
packPath := filepath.Join(queryPackDir, "qlpack.yml") packPath := filepath.Join(queryPackDir, "qlpack.yml")
packFile, err := ioutil.ReadFile(packPath) packFile, err := os.ReadFile(packPath)
if err != nil { if err != nil {
return err return err
} }
@@ -490,7 +491,7 @@ func FixPackFile(queryPackDir string, packRelativePath string) error {
if err != nil { if err != nil {
return err return err
} }
err = ioutil.WriteFile(packPath, packFile, 0644) err = os.WriteFile(packPath, packFile, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -502,11 +503,11 @@ func CopyFile(srcPath string, targetPath string) error {
if err != nil { if err != nil {
return err return err
} }
bytesRead, err := ioutil.ReadFile(srcPath) bytesRead, err := os.ReadFile(srcPath)
if err != nil { if err != nil {
return err return err
} }
err = ioutil.WriteFile(targetPath, bytesRead, 0644) err = os.WriteFile(targetPath, bytesRead, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -538,7 +539,7 @@ func downloadArtifact(url string, outputDir string, nwo string) error {
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -557,7 +558,7 @@ func downloadArtifact(url string, outputDir string, nwo string) error {
log.Fatal(err) log.Fatal(err)
} }
defer f.Close() defer f.Close()
bytes, err := ioutil.ReadAll(f) bytes, err := io.ReadAll(f)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -569,7 +570,7 @@ func downloadArtifact(url string, outputDir string, nwo string) error {
extension = "sarif" extension = "sarif"
} }
resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension)) resultPath = filepath.Join(outputDir, fmt.Sprintf("%s.%s", strings.Replace(nwo, "/", "_", -1), extension))
err = ioutil.WriteFile(resultPath, bytes, os.ModePerm) err = os.WriteFile(resultPath, bytes, os.ModePerm)
if err != nil { if err != nil {
return err return err
} }
@@ -608,11 +609,11 @@ func DownloadDatabase(nwo string, language string, outputDir string) error {
} }
defer resp.Body.Close() defer resp.Body.Close()
bytes, err := ioutil.ReadAll(resp.Body) bytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return err
} }
err = ioutil.WriteFile(targetPath, bytes, os.ModePerm) err = os.WriteFile(targetPath, bytes, os.ModePerm)
return nil return nil
} }