From 8b04d0a2b9f7e41836475e51def398ede92a4f78 Mon Sep 17 00:00:00 2001 From: Owen Mansel-Chan Date: Wed, 1 Oct 2025 10:06:12 +0100 Subject: [PATCH] Convert SSRF tests to inline expectations tests --- go/ql/test/experimental/CWE-918/SSRF.qlref | 4 +- go/ql/test/experimental/CWE-918/builtin.go | 22 +++++------ go/ql/test/experimental/CWE-918/new-tests.go | 40 ++++++++++---------- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/go/ql/test/experimental/CWE-918/SSRF.qlref b/go/ql/test/experimental/CWE-918/SSRF.qlref index 7cba541836f..d68094fa2a0 100644 --- a/go/ql/test/experimental/CWE-918/SSRF.qlref +++ b/go/ql/test/experimental/CWE-918/SSRF.qlref @@ -1,2 +1,4 @@ query: experimental/CWE-918/SSRF.ql -postprocess: utils/test/PrettyPrintModels.ql +postprocess: + - utils/test/PrettyPrintModels.ql + - utils/test/InlineExpectationsTestQuery.ql diff --git a/go/ql/test/experimental/CWE-918/builtin.go b/go/ql/test/experimental/CWE-918/builtin.go index 5c65bc9d3de..463f4b3d09b 100644 --- a/go/ql/test/experimental/CWE-918/builtin.go +++ b/go/ql/test/experimental/CWE-918/builtin.go @@ -16,10 +16,10 @@ import ( ) func handler(w http.ResponseWriter, req *http.Request) { - target := req.FormValue("target") + target := req.FormValue("target") // $ Source // BAD: `target` is controlled by the attacker - _, err := http.Get("https://" + target + ".example.com/data/") + _, err := http.Get("https://" + target + ".example.com/data/") // $ Alert if err != nil { // error handling } @@ -80,12 +80,12 @@ func test() { // x net websocket dial bad http.HandleFunc("/ex2", func(w http.ResponseWriter, r *http.Request) { - untrustedInput := r.Referer() + untrustedInput := r.Referer() // $ Source origin := "http://localhost/" // bad as input is directly passed to dial function - ws, _ := websocket.Dial(untrustedInput, "", origin) // SSRF + ws, _ := websocket.Dial(untrustedInput, "", origin) // $ Alert var msg = make([]byte, 512) var n int n, _ = ws.Read(msg) @@ -94,12 +94,12 @@ func test() { // x net websocket dialConfig bad http.HandleFunc("/ex3", func(w http.ResponseWriter, r *http.Request) { - untrustedInput := r.Referer() + untrustedInput := r.Referer() // $ Source origin := "http://localhost/" // bad as input is directly used - config, _ := websocket.NewConfig(untrustedInput, origin) // SSRF - ws2, _ := websocket.DialConfig(config) + config, _ := websocket.NewConfig(untrustedInput, origin) // $ Sink + ws2, _ := websocket.DialConfig(config) // $ Alert var msg = make([]byte, 512) var n int n, _ = ws2.Read(msg) @@ -108,10 +108,10 @@ func test() { // gorilla websocket Dialer.Dial bad http.HandleFunc("/ex6", func(w http.ResponseWriter, r *http.Request) { - untrustedInput := r.Referer() + untrustedInput := r.Referer() // $ Source dialer := gorilla.Dialer{} - dialer.Dial(untrustedInput, r.Header) //SSRF + dialer.Dial(untrustedInput, r.Header) // $ Alert }) // gorilla websocket Dialer.Dial good @@ -126,10 +126,10 @@ func test() { // gorilla websocket Dialer.DialContext bad http.HandleFunc("/ex8", func(w http.ResponseWriter, r *http.Request) { - untrustedInput := r.Referer() + untrustedInput := r.Referer() // $ Source dialer := gorilla.Dialer{} - dialer.DialContext(context.TODO(), untrustedInput, r.Header) //SSRF + dialer.DialContext(context.TODO(), untrustedInput, r.Header) // $ Alert }) // gorilla websocket Dialer.DialContext good diff --git a/go/ql/test/experimental/CWE-918/new-tests.go b/go/ql/test/experimental/CWE-918/new-tests.go index 040bad48596..ddf94daa09e 100644 --- a/go/ql/test/experimental/CWE-918/new-tests.go +++ b/go/ql/test/experimental/CWE-918/new-tests.go @@ -23,20 +23,20 @@ func HandlerGin(c *gin.Context) { safe string `binding:"alphanum"` } - err := c.ShouldBindJSON(&body) + err := c.ShouldBindJSON(&body) // $ Source http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK - http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF - http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF + http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert + http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert if err == nil { - http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF + http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK } - taintedParam := c.Param("id") + taintedParam := c.Param("id") // $ Source validate := validator.New() err = validate.Var(taintedParam, "alpha") @@ -44,10 +44,10 @@ func HandlerGin(c *gin.Context) { http.Get("http://example.com/" + taintedParam) // OK } - http.Get("http://example.com/" + taintedParam) //SSRF + http.Get("http://example.com/" + taintedParam) // $ Alert - taintedQuery := c.Query("id") - http.Get("http://example.com/" + taintedQuery) //SSRF + taintedQuery := c.Query("id") // $ Source + http.Get("http://example.com/" + taintedQuery) // $ Alert } func HandlerHttp(req *http.Request) { @@ -59,41 +59,41 @@ func HandlerHttp(req *http.Request) { word string safe string `validate:"alphanum"` } - reqBody, _ := ioutil.ReadAll(req.Body) + reqBody, _ := ioutil.ReadAll(req.Body) // $ Source json.Unmarshal(reqBody, &body) http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK - http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF - http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF + http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert + http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert validate := validator.New() err := validate.Struct(body) if err == nil { - http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF + http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK } - taintedQuery := req.URL.Query().Get("param1") - http.Get("http://example.com/" + taintedQuery) // SSRF + taintedQuery := req.URL.Query().Get("param1") // $ Source + http.Get("http://example.com/" + taintedQuery) // $ Alert - taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/") - http.Get("http://example.com/" + taintedParam) // SSRF + taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/") // $ Source + http.Get("http://example.com/" + taintedParam) // $ Alert } func HandlerMux(r *http.Request) { - vars := mux.Vars(r) + vars := mux.Vars(r) // $ Source taintedParam := vars["id"] - http.Get("http://example.com/" + taintedParam) // SSRF + http.Get("http://example.com/" + taintedParam) // $ Alert numericID, _ := strconv.Atoi(taintedParam) http.Get(fmt.Sprintf("http://example.com/%d", numericID)) // OK } func HandlerChi(r *http.Request) { - taintedParam := chi.URLParam(r, "articleID") - http.Get("http://example.com/" + taintedParam) // SSRF + taintedParam := chi.URLParam(r, "articleID") // $ Source + http.Get("http://example.com/" + taintedParam) // $ Alert b, _ := strconv.ParseBool(taintedParam) http.Get(fmt.Sprintf("http://example.com/%t", b)) // OK