diff --git a/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll b/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll index 9913a2fb21a..e5e90c5038f 100644 --- a/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll +++ b/javascript/ql/experimental/adaptivethreatmodeling/lib/experimental/adaptivethreatmodeling/EndpointScoring.qll @@ -7,26 +7,38 @@ private import javascript private import BaseScoring private import EndpointFeatures as EndpointFeatures -private import PromptConfiguration +private import FeaturizationConfig private import EndpointTypes +private import ModelPrompt as ModelPrompt private string getACompatibleModelChecksum() { availableMlModels(result, "javascript", _, "atm-endpoint-scoring") } -// class RelevantFeaturizationConfig extends FeaturizationConfig { -// RelevantFeaturizationConfig() { this = "RelevantFeaturization" } -// override DataFlow::Node getAnEndpointToFeaturize() { -// getCfg().isEffectiveSource(result) and any(DataFlow::Configuration cfg).hasFlow(result, _) -// or -// getCfg().isEffectiveSink(result) and any(DataFlow::Configuration cfg).hasFlow(_, result) -// } -// } module ModelScoring { - predicate getARequestedEndpoint(DataFlow::Node node, string prompt) { - exists(PromptConfiguration cfg | - cfg.getPrompt(node) = prompt and cfg.getAnEndpointToFeaturize() = node - ) + /** + * A featurization config that only featurizes new candidate endpoints that are part of a flow + * path. + */ + class RelevantFeaturizationConfig extends FeaturizationConfig { + RelevantFeaturizationConfig() { this = "RelevantFeaturization" } + + override DataFlow::Node getAnEndpointToFeaturize() { + getCfg().isEffectiveSource(result) and any(DataFlow::Configuration cfg).hasFlow(result, _) + or + getCfg().isEffectiveSink(result) and any(DataFlow::Configuration cfg).hasFlow(_, result) + } + } + + DataFlow::Node getARequestedEndpoint() { + result = any(FeaturizationConfig cfg).getAnEndpointToFeaturize() + } + + private int getARequestedEndpointType() { result = any(EndpointType type).getEncoding() } + + predicate getEndpointPrompt(DataFlow::Node node, string prompt) { + node = getARequestedEndpoint() and + prompt = ModelPrompt::ModelPrompt::getPrompt(node) } predicate endpointScores(DataFlow::Node endpoint, int encodedEndpointType, float score) { @@ -35,7 +47,7 @@ module ModelScoring { } predicate internalEnpointScores(DataFlow::Node endpoint, string endpointType) = - remoteScoreEndpoints(getARequestedEndpoint/2)(endpoint, endpointType) + remoteScoreEndpoints(getEndpointPrompt/2)(endpoint, endpointType) private string mapEndpointType(int encodedEndpointType) { result = "no sink" and encodedEndpointType = 0 @@ -160,13 +172,16 @@ class EndpointScoringResults extends ScoringResults { ) } } -// module Debugging { -// query predicate hopInputEndpoints(DataFlow::Node endpoint) { -// endpoint = ModelScoring::getARequestedEndpoint() -// } -// query predicate endpointScores = ModelScoring::endpointScores/3; -// query predicate shouldResultBeIncluded(DataFlow::Node source, DataFlow::Node sink) { -// any(ScoringResults scoringResults).shouldResultBeIncluded(source, sink) and -// any(DataFlow::Configuration cfg).hasFlow(source, sink) -// } -// } + +module Debugging { + query predicate hopInputEndpoints(DataFlow::Node endpoint) { + endpoint = ModelScoring::getARequestedEndpoint() + } + + query predicate endpointScores = ModelScoring::endpointScores/3; + + query predicate shouldResultBeIncluded(DataFlow::Node source, DataFlow::Node sink) { + any(ScoringResults scoringResults).shouldResultBeIncluded(source, sink) and + any(DataFlow::Configuration cfg).hasFlow(source, sink) + } +} diff --git a/javascript/ql/experimental/adaptivethreatmodeling/src/SqlInjectionATM.ql b/javascript/ql/experimental/adaptivethreatmodeling/src/SqlInjectionATM.ql index f3ab8531743..fdeb79de145 100644 --- a/javascript/ql/experimental/adaptivethreatmodeling/src/SqlInjectionATM.ql +++ b/javascript/ql/experimental/adaptivethreatmodeling/src/SqlInjectionATM.ql @@ -16,7 +16,6 @@ import experimental.adaptivethreatmodeling.SqlInjectionATM import ATM::ResultsInfo import DataFlow::PathGraph -import experimental.adaptivethreatmodeling.PromptConfiguration from AtmConfig cfg, DataFlow::PathNode source, DataFlow::PathNode sink, float score where cfg.hasBoostedFlowPath(source, sink, score)