diff --git a/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll b/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll index e5e90c5038f..c6c295c24ef 100644 --- a/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll +++ b/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll @@ -34,45 +34,24 @@ module ModelScoring { result = any(FeaturizationConfig cfg).getAnEndpointToFeaturize() } - private int getARequestedEndpointType() { result = any(EndpointType type).getEncoding() } - predicate getEndpointPrompt(DataFlow::Node node, string prompt) { node = getARequestedEndpoint() and prompt = ModelPrompt::ModelPrompt::getPrompt(node) } predicate endpointScores(DataFlow::Node endpoint, int encodedEndpointType, float score) { - internalEnpointScores(endpoint, mapEndpointType(encodedEndpointType)) and - mapScore(score, mapEndpointType(encodedEndpointType)) - } - - predicate internalEnpointScores(DataFlow::Node endpoint, string endpointType) = - remoteScoreEndpoints(getEndpointPrompt/2)(endpoint, endpointType) - - private string mapEndpointType(int encodedEndpointType) { - result = "no sink" and encodedEndpointType = 0 - or - result = "xss sink" and encodedEndpointType = 1 - or - result = "nosql sink" and encodedEndpointType = 2 - or - result = "sql sink" and encodedEndpointType = 3 - or - result = "tainted path sink" and encodedEndpointType = 4 - } - - private predicate mapScore(float score, string endpointType) { - ( - ( - endpointType = "xss sink" or - endpointType = "nosql sink" or - endpointType = "sql sink" or - endpointType = "tainted path sink" - ) and + exists(EndpointType endpointType | + endpointType.getEncoding() = encodedEndpointType and + internalEnpointScores(endpoint, endpointType.getDescription()) and score = 1.0 ) - /*or (score = 0.0 and any(endpointType)) */ + or + encodedEndpointType = any(EndpointType endpointType).getEncoding() and + score = 0.0 } + + predicate internalEnpointScores(DataFlow::Node endpoint, string prediction) = + remoteScoreEndpoints(getEndpointPrompt/2)(endpoint, prediction) } /**