Improve the structure of endpoint scoring

This commit is contained in:
tiferet
2022-12-06 12:28:49 -08:00
parent dfbfa5d27d
commit 9a8b0d7fb2
2 changed files with 39 additions and 25 deletions

View File

@@ -7,26 +7,38 @@
private import javascript
private import BaseScoring
private import EndpointFeatures as EndpointFeatures
private import PromptConfiguration
private import FeaturizationConfig
private import EndpointTypes
private import ModelPrompt as ModelPrompt
private string getACompatibleModelChecksum() {
availableMlModels(result, "javascript", _, "atm-endpoint-scoring")
}
// class RelevantFeaturizationConfig extends FeaturizationConfig {
// RelevantFeaturizationConfig() { this = "RelevantFeaturization" }
// override DataFlow::Node getAnEndpointToFeaturize() {
// getCfg().isEffectiveSource(result) and any(DataFlow::Configuration cfg).hasFlow(result, _)
// or
// getCfg().isEffectiveSink(result) and any(DataFlow::Configuration cfg).hasFlow(_, result)
// }
// }
module ModelScoring {
predicate getARequestedEndpoint(DataFlow::Node node, string prompt) {
exists(PromptConfiguration cfg |
cfg.getPrompt(node) = prompt and cfg.getAnEndpointToFeaturize() = node
)
/**
* A featurization config that only featurizes new candidate endpoints that are part of a flow
* path.
*/
class RelevantFeaturizationConfig extends FeaturizationConfig {
RelevantFeaturizationConfig() { this = "RelevantFeaturization" }
override DataFlow::Node getAnEndpointToFeaturize() {
getCfg().isEffectiveSource(result) and any(DataFlow::Configuration cfg).hasFlow(result, _)
or
getCfg().isEffectiveSink(result) and any(DataFlow::Configuration cfg).hasFlow(_, result)
}
}
DataFlow::Node getARequestedEndpoint() {
result = any(FeaturizationConfig cfg).getAnEndpointToFeaturize()
}
private int getARequestedEndpointType() { result = any(EndpointType type).getEncoding() }
predicate getEndpointPrompt(DataFlow::Node node, string prompt) {
node = getARequestedEndpoint() and
prompt = ModelPrompt::ModelPrompt::getPrompt(node)
}
predicate endpointScores(DataFlow::Node endpoint, int encodedEndpointType, float score) {
@@ -35,7 +47,7 @@ module ModelScoring {
}
predicate internalEnpointScores(DataFlow::Node endpoint, string endpointType) =
remoteScoreEndpoints(getARequestedEndpoint/2)(endpoint, endpointType)
remoteScoreEndpoints(getEndpointPrompt/2)(endpoint, endpointType)
private string mapEndpointType(int encodedEndpointType) {
result = "no sink" and encodedEndpointType = 0
@@ -160,13 +172,16 @@ class EndpointScoringResults extends ScoringResults {
)
}
}
// module Debugging {
// query predicate hopInputEndpoints(DataFlow::Node endpoint) {
// endpoint = ModelScoring::getARequestedEndpoint()
// }
// query predicate endpointScores = ModelScoring::endpointScores/3;
// query predicate shouldResultBeIncluded(DataFlow::Node source, DataFlow::Node sink) {
// any(ScoringResults scoringResults).shouldResultBeIncluded(source, sink) and
// any(DataFlow::Configuration cfg).hasFlow(source, sink)
// }
// }
module Debugging {
query predicate hopInputEndpoints(DataFlow::Node endpoint) {
endpoint = ModelScoring::getARequestedEndpoint()
}
query predicate endpointScores = ModelScoring::endpointScores/3;
query predicate shouldResultBeIncluded(DataFlow::Node source, DataFlow::Node sink) {
any(ScoringResults scoringResults).shouldResultBeIncluded(source, sink) and
any(DataFlow::Configuration cfg).hasFlow(source, sink)
}
}

View File

@@ -16,7 +16,6 @@
import experimental.adaptivethreatmodeling.SqlInjectionATM
import ATM::ResultsInfo
import DataFlow::PathGraph
import experimental.adaptivethreatmodeling.PromptConfiguration
from AtmConfig cfg, DataFlow::PathNode source, DataFlow::PathNode sink, float score
where cfg.hasBoostedFlowPath(source, sink, score)