Guards: Generalise wrapper guards.

This commit is contained in:
Anders Schack-Mulligen
2025-07-23 12:59:59 +02:00
parent 6e52df1639
commit 0c31a80f3c
4 changed files with 128 additions and 29 deletions

View File

@@ -344,11 +344,8 @@ private module GuardsInput implements SharedGuards::InputSig<Location> {
final private class FinalMethod = Method;
class BooleanMethod extends FinalMethod {
BooleanMethod() {
super.getReturnType().(PrimitiveType).hasName("boolean") and
not super.isOverridable()
}
class NonOverridableMethod extends FinalMethod {
NonOverridableMethod() { not super.isOverridable() }
Parameter getParameter(ParameterPosition ppos) {
super.getParameter(ppos) = result and
@@ -363,14 +360,14 @@ private module GuardsInput implements SharedGuards::InputSig<Location> {
}
}
private predicate booleanMethodCall(MethodCall call, BooleanMethod m) {
private predicate nonOverridableMethodCall(MethodCall call, NonOverridableMethod m) {
call.getMethod().getSourceDeclaration() = m
}
class BooleanMethodCall extends GuardsInput::Expr instanceof MethodCall {
BooleanMethodCall() { booleanMethodCall(this, _) }
class NonOverridableMethodCall extends GuardsInput::Expr instanceof MethodCall {
NonOverridableMethodCall() { nonOverridableMethodCall(this, _) }
BooleanMethod getMethod() { booleanMethodCall(this, result) }
NonOverridableMethod getMethod() { nonOverridableMethodCall(this, result) }
GuardsInput::Expr getArgument(ArgumentPosition apos) { result = super.getArgument(apos) }
}

View File

@@ -143,4 +143,63 @@ public class Guards {
chk(); // $ guarded=found:true guarded='i < a.length:false'
}
}
public static boolean testNotNull1(String input) {
return input != null && input.length() > 0;
}
public static boolean testNotNull2(String input) {
if (input == null) return false;
return input.length() > 0;
}
public static int getNumOrDefault(Integer number) {
return number == null ? 0 : number;
}
public static String concatNonNull(String s1, String s2) {
if (s1 == null || s2 == null) return null;
return s1 + s2;
}
public static Status testEnumWrapper(boolean flag) {
return flag ? Status.SUCCESS : Status.FAILURE;
}
enum Status { SUCCESS, FAILURE }
void testWrappers(String s, Integer i) {
if (testNotNull1(s)) {
chk(); // $ guarded='s:not null' guarded=testNotNull1(...):true
} else {
chk(); // $ guarded=testNotNull1(...):false
}
if (testNotNull2(s)) {
chk(); // $ guarded='s:not null' guarded=testNotNull2(...):true
} else {
chk(); // $ guarded=testNotNull2(...):false
}
if (0 == getNumOrDefault(i)) {
chk(); // $ guarded='0 == getNumOrDefault(...):true' guarded='getNumOrDefault(...):0'
} else {
chk(); // $ guarded='0 == getNumOrDefault(...):false' guarded='getNumOrDefault(...):not 0' guarded='i:not 0' guarded='i:not null'
}
if (null == concatNonNull(s, "suffix")) {
chk(); // $ guarded='concatNonNull(...):null' guarded='null == concatNonNull(...):true'
} else {
chk(); // $ guarded='concatNonNull(...):not null' guarded='null == concatNonNull(...):false' guarded='s:not null'
}
switch (testEnumWrapper(g(1))) {
case SUCCESS:
chk(); // $ guarded='testEnumWrapper(...):SUCCESS' guarded='testEnumWrapper(...):match SUCCESS' guarded=g(1):true
break;
case FAILURE:
chk(); // $ guarded='testEnumWrapper(...):FAILURE' guarded='testEnumWrapper(...):match FAILURE' guarded=g(1):false
break;
}
}
}

View File

@@ -89,3 +89,26 @@
| Guards.java:139:9:139:13 | chk(...) | found:true |
| Guards.java:143:7:143:11 | chk(...) | 'i < a.length:false' |
| Guards.java:143:7:143:11 | chk(...) | found:true |
| Guards.java:173:7:173:11 | chk(...) | 's:not null' |
| Guards.java:173:7:173:11 | chk(...) | testNotNull1(...):true |
| Guards.java:175:7:175:11 | chk(...) | testNotNull1(...):false |
| Guards.java:179:7:179:11 | chk(...) | 's:not null' |
| Guards.java:179:7:179:11 | chk(...) | testNotNull2(...):true |
| Guards.java:181:7:181:11 | chk(...) | testNotNull2(...):false |
| Guards.java:185:7:185:11 | chk(...) | '0 == getNumOrDefault(...):true' |
| Guards.java:185:7:185:11 | chk(...) | 'getNumOrDefault(...):0' |
| Guards.java:187:7:187:11 | chk(...) | '0 == getNumOrDefault(...):false' |
| Guards.java:187:7:187:11 | chk(...) | 'getNumOrDefault(...):not 0' |
| Guards.java:187:7:187:11 | chk(...) | 'i:not 0' |
| Guards.java:187:7:187:11 | chk(...) | 'i:not null' |
| Guards.java:191:7:191:11 | chk(...) | 'concatNonNull(...):null' |
| Guards.java:191:7:191:11 | chk(...) | 'null == concatNonNull(...):true' |
| Guards.java:193:7:193:11 | chk(...) | 'concatNonNull(...):not null' |
| Guards.java:193:7:193:11 | chk(...) | 'null == concatNonNull(...):false' |
| Guards.java:193:7:193:11 | chk(...) | 's:not null' |
| Guards.java:198:9:198:13 | chk(...) | 'testEnumWrapper(...):SUCCESS' |
| Guards.java:198:9:198:13 | chk(...) | 'testEnumWrapper(...):match SUCCESS' |
| Guards.java:198:9:198:13 | chk(...) | g(1):true |
| Guards.java:201:9:201:13 | chk(...) | 'testEnumWrapper(...):FAILURE' |
| Guards.java:201:9:201:13 | chk(...) | 'testEnumWrapper(...):match FAILURE' |
| Guards.java:201:9:201:13 | chk(...) | g(1):false |

View File

@@ -234,16 +234,16 @@ signature module InputSig<LocationSig Location> {
*/
predicate parameterMatch(ParameterPosition ppos, ArgumentPosition apos);
/** A non-overridable method with a boolean return value. */
class BooleanMethod {
/** A non-overridable method. */
class NonOverridableMethod {
Parameter getParameter(ParameterPosition ppos);
/** Gets an expression being returned by this method. */
Expr getAReturnExpr();
}
class BooleanMethodCall extends Expr {
BooleanMethod getMethod();
class NonOverridableMethodCall extends Expr {
NonOverridableMethod getMethod();
Expr getArgument(ArgumentPosition apos);
}
@@ -998,17 +998,32 @@ module Make<LocationSig Location, InputSig<Location> Input> {
final private class FinalExpr = Expr;
private class ReturnExpr extends FinalExpr {
ReturnExpr() { any(BooleanMethod m).getAReturnExpr() = this }
ReturnExpr() { any(NonOverridableMethod m).getAReturnExpr() = this }
NonOverridableMethod getMethod() { result.getAReturnExpr() = this }
pragma[nomagic]
BasicBlock getBasicBlock() { result = super.getBasicBlock() }
}
private predicate booleanReturnGuard(Guard guard, GuardValue val) {
guard instanceof ReturnExpr and exists(val.asBooleanValue())
private predicate relevantCallValue(NonOverridableMethodCall call, GuardValue val) {
BranchImplies::guardControls(call, val, _, _) or
ReturnImplies::guardControls(call, val, _, _)
}
private module ReturnImplies = ImpliesTC<booleanReturnGuard/2>;
private predicate relevantReturnValue(NonOverridableMethod m, GuardValue val) {
exists(NonOverridableMethodCall call |
relevantCallValue(call, val) and
call.getMethod() = m and
not val instanceof TException
)
}
private predicate returnGuard(Guard guard, GuardValue val) {
relevantReturnValue(guard.(ReturnExpr).getMethod(), val)
}
private module ReturnImplies = ImpliesTC<returnGuard/2>;
/**
* Holds if `ret` is a return expression in a non-overridable method that
@@ -1016,32 +1031,36 @@ module Make<LocationSig Location, InputSig<Location> Input> {
* parameter has the value `val`.
*/
private predicate validReturnInCustomGuard(
ReturnExpr ret, ParameterPosition ppos, boolean retval, GuardValue val
ReturnExpr ret, ParameterPosition ppos, GuardValue retval, GuardValue val
) {
exists(BooleanMethod m, SsaDefinition param |
exists(NonOverridableMethod m, SsaDefinition param |
m.getAReturnExpr() = ret and
parameterDefinition(m.getParameter(ppos), param)
|
exists(Guard g0, GuardValue v0 |
g0.directlyValueControls(ret.getBasicBlock(), v0) and
BranchImplies::ssaControls(param, val, g0, v0) and
retval = [true, false]
relevantReturnValue(m, retval)
)
or
ReturnImplies::ssaControls(param, val, ret,
any(GuardValue r | r.asBooleanValue() = retval))
ReturnImplies::ssaControls(param, val, ret, retval)
)
}
/**
* Gets a non-overridable method with a boolean return value that performs a check
* on the `ppos`th parameter. A return value equal to `retval` allows us to conclude
* Gets a non-overridable method that performs a check on the `ppos`th
* parameter. A return value equal to `retval` allows us to conclude
* that the argument has the value `val`.
*/
private BooleanMethod customGuard(ParameterPosition ppos, boolean retval, GuardValue val) {
private NonOverridableMethod customGuard(
ParameterPosition ppos, GuardValue retval, GuardValue val
) {
forex(ReturnExpr ret |
result.getAReturnExpr() = ret and
not ret.(ConstantExpr).asBooleanValue() = retval.booleanNot()
not exists(GuardValue notRetval |
exprHasValue(ret, notRetval) and
disjointValues(notRetval, retval)
)
|
validReturnInCustomGuard(ret, ppos, retval, val)
)
@@ -1056,11 +1075,12 @@ module Make<LocationSig Location, InputSig<Location> Input> {
* custom guard wrappers.
*/
predicate additionalImpliesStep(PreGuard g1, GuardValue v1, PreGuard g2, GuardValue v2) {
exists(BooleanMethodCall call, ParameterPosition ppos, ArgumentPosition apos |
exists(NonOverridableMethodCall call, ParameterPosition ppos, ArgumentPosition apos |
g1 = call and
call.getMethod() = customGuard(ppos, v1.asBooleanValue(), v2) and
call.getMethod() = customGuard(ppos, v1, v2) and
call.getArgument(apos) = g2 and
parameterMatch(pragma[only_bind_out](ppos), pragma[only_bind_out](apos))
parameterMatch(pragma[only_bind_out](ppos), pragma[only_bind_out](apos)) and
not exprHasValue(g2, v2) // disregard trivial guard
)
}
}