Convert SSRF tests to inline expectations tests

This commit is contained in:
Owen Mansel-Chan
2025-10-01 10:06:12 +01:00
parent 6e4dbe8e22
commit 8b04d0a2b9
3 changed files with 34 additions and 32 deletions

View File

@@ -1,2 +1,4 @@
query: experimental/CWE-918/SSRF.ql query: experimental/CWE-918/SSRF.ql
postprocess: utils/test/PrettyPrintModels.ql postprocess:
- utils/test/PrettyPrintModels.ql
- utils/test/InlineExpectationsTestQuery.ql

View File

@@ -16,10 +16,10 @@ import (
) )
func handler(w http.ResponseWriter, req *http.Request) { func handler(w http.ResponseWriter, req *http.Request) {
target := req.FormValue("target") target := req.FormValue("target") // $ Source
// BAD: `target` is controlled by the attacker // BAD: `target` is controlled by the attacker
_, err := http.Get("https://" + target + ".example.com/data/") _, err := http.Get("https://" + target + ".example.com/data/") // $ Alert
if err != nil { if err != nil {
// error handling // error handling
} }
@@ -80,12 +80,12 @@ func test() {
// x net websocket dial bad // x net websocket dial bad
http.HandleFunc("/ex2", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ex2", func(w http.ResponseWriter, r *http.Request) {
untrustedInput := r.Referer() untrustedInput := r.Referer() // $ Source
origin := "http://localhost/" origin := "http://localhost/"
// bad as input is directly passed to dial function // bad as input is directly passed to dial function
ws, _ := websocket.Dial(untrustedInput, "", origin) // SSRF ws, _ := websocket.Dial(untrustedInput, "", origin) // $ Alert
var msg = make([]byte, 512) var msg = make([]byte, 512)
var n int var n int
n, _ = ws.Read(msg) n, _ = ws.Read(msg)
@@ -94,12 +94,12 @@ func test() {
// x net websocket dialConfig bad // x net websocket dialConfig bad
http.HandleFunc("/ex3", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ex3", func(w http.ResponseWriter, r *http.Request) {
untrustedInput := r.Referer() untrustedInput := r.Referer() // $ Source
origin := "http://localhost/" origin := "http://localhost/"
// bad as input is directly used // bad as input is directly used
config, _ := websocket.NewConfig(untrustedInput, origin) // SSRF config, _ := websocket.NewConfig(untrustedInput, origin) // $ Sink
ws2, _ := websocket.DialConfig(config) ws2, _ := websocket.DialConfig(config) // $ Alert
var msg = make([]byte, 512) var msg = make([]byte, 512)
var n int var n int
n, _ = ws2.Read(msg) n, _ = ws2.Read(msg)
@@ -108,10 +108,10 @@ func test() {
// gorilla websocket Dialer.Dial bad // gorilla websocket Dialer.Dial bad
http.HandleFunc("/ex6", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ex6", func(w http.ResponseWriter, r *http.Request) {
untrustedInput := r.Referer() untrustedInput := r.Referer() // $ Source
dialer := gorilla.Dialer{} dialer := gorilla.Dialer{}
dialer.Dial(untrustedInput, r.Header) //SSRF dialer.Dial(untrustedInput, r.Header) // $ Alert
}) })
// gorilla websocket Dialer.Dial good // gorilla websocket Dialer.Dial good
@@ -126,10 +126,10 @@ func test() {
// gorilla websocket Dialer.DialContext bad // gorilla websocket Dialer.DialContext bad
http.HandleFunc("/ex8", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ex8", func(w http.ResponseWriter, r *http.Request) {
untrustedInput := r.Referer() untrustedInput := r.Referer() // $ Source
dialer := gorilla.Dialer{} dialer := gorilla.Dialer{}
dialer.DialContext(context.TODO(), untrustedInput, r.Header) //SSRF dialer.DialContext(context.TODO(), untrustedInput, r.Header) // $ Alert
}) })
// gorilla websocket Dialer.DialContext good // gorilla websocket Dialer.DialContext good

View File

@@ -23,20 +23,20 @@ func HandlerGin(c *gin.Context) {
safe string `binding:"alphanum"` safe string `binding:"alphanum"`
} }
err := c.ShouldBindJSON(&body) err := c.ShouldBindJSON(&body) // $ Source
http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK
http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK
http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert
if err == nil { if err == nil {
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK
} }
taintedParam := c.Param("id") taintedParam := c.Param("id") // $ Source
validate := validator.New() validate := validator.New()
err = validate.Var(taintedParam, "alpha") err = validate.Var(taintedParam, "alpha")
@@ -44,10 +44,10 @@ func HandlerGin(c *gin.Context) {
http.Get("http://example.com/" + taintedParam) // OK http.Get("http://example.com/" + taintedParam) // OK
} }
http.Get("http://example.com/" + taintedParam) //SSRF http.Get("http://example.com/" + taintedParam) // $ Alert
taintedQuery := c.Query("id") taintedQuery := c.Query("id") // $ Source
http.Get("http://example.com/" + taintedQuery) //SSRF http.Get("http://example.com/" + taintedQuery) // $ Alert
} }
func HandlerHttp(req *http.Request) { func HandlerHttp(req *http.Request) {
@@ -59,41 +59,41 @@ func HandlerHttp(req *http.Request) {
word string word string
safe string `validate:"alphanum"` safe string `validate:"alphanum"`
} }
reqBody, _ := ioutil.ReadAll(req.Body) reqBody, _ := ioutil.ReadAll(req.Body) // $ Source
json.Unmarshal(reqBody, &body) json.Unmarshal(reqBody, &body)
http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK http.Get(fmt.Sprintf("http://example.com/%d", body.integer)) // OK
http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.float)) // OK
http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK http.Get(fmt.Sprintf("http://example.com/%v", body.boolean)) // OK
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // $ Alert
validate := validator.New() validate := validator.New()
err := validate.Struct(body) err := validate.Struct(body)
if err == nil { if err == nil {
http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // SSRF http.Get(fmt.Sprintf("http://example.com/%s", body.word)) // $ Alert
http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK http.Get(fmt.Sprintf("http://example.com/%s", body.safe)) // OK
} }
taintedQuery := req.URL.Query().Get("param1") taintedQuery := req.URL.Query().Get("param1") // $ Source
http.Get("http://example.com/" + taintedQuery) // SSRF http.Get("http://example.com/" + taintedQuery) // $ Alert
taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/") taintedParam := strings.TrimPrefix(req.URL.Path, "/example-path/") // $ Source
http.Get("http://example.com/" + taintedParam) // SSRF http.Get("http://example.com/" + taintedParam) // $ Alert
} }
func HandlerMux(r *http.Request) { func HandlerMux(r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r) // $ Source
taintedParam := vars["id"] taintedParam := vars["id"]
http.Get("http://example.com/" + taintedParam) // SSRF http.Get("http://example.com/" + taintedParam) // $ Alert
numericID, _ := strconv.Atoi(taintedParam) numericID, _ := strconv.Atoi(taintedParam)
http.Get(fmt.Sprintf("http://example.com/%d", numericID)) // OK http.Get(fmt.Sprintf("http://example.com/%d", numericID)) // OK
} }
func HandlerChi(r *http.Request) { func HandlerChi(r *http.Request) {
taintedParam := chi.URLParam(r, "articleID") taintedParam := chi.URLParam(r, "articleID") // $ Source
http.Get("http://example.com/" + taintedParam) // SSRF http.Get("http://example.com/" + taintedParam) // $ Alert
b, _ := strconv.ParseBool(taintedParam) b, _ := strconv.ParseBool(taintedParam)
http.Get(fmt.Sprintf("http://example.com/%t", b)) // OK http.Get(fmt.Sprintf("http://example.com/%t", b)) // OK