diff --git a/python/ql/lib/semmle/python/Class.qll b/python/ql/lib/semmle/python/Class.qll index 52c6c5aa389..58a6504b547 100644 --- a/python/ql/lib/semmle/python/Class.qll +++ b/python/ql/lib/semmle/python/Class.qll @@ -91,6 +91,12 @@ class Class extends Class_, Scope, AstNode { /** Gets a method defined in this class */ Function getAMethod() { result.getScope() = this } + /** Gets the method defined in this class with the specified name, if any. */ + Function getMethod(string name) { + result = this.getAMethod() and + result.getName() = name + } + override Location getLocation() { py_scope_location(result, this) } /** Gets the scope (module, class or function) in which this class is defined */ diff --git a/python/ql/src/Classes/Comparisons/Comparisons.qll b/python/ql/src/Classes/Comparisons/Comparisons.qll new file mode 100644 index 00000000000..b835b07ef44 --- /dev/null +++ b/python/ql/src/Classes/Comparisons/Comparisons.qll @@ -0,0 +1,10 @@ +/** Helper definitions for reasoning about comparison methods. */ + +import python +import semmle.python.ApiGraphs + +/** Holds if `cls` has the `functools.total_ordering` decorator. */ +predicate totalOrdering(Class cls) { + cls.getADecorator() = + API::moduleImport("functools").getMember("total_ordering").asSource().asExpr() +} diff --git a/python/ql/src/Classes/Comparisons/EqualsOrNotEquals.ql b/python/ql/src/Classes/Comparisons/EqualsOrNotEquals.ql index adac5a20e87..feeada86682 100644 --- a/python/ql/src/Classes/Comparisons/EqualsOrNotEquals.ql +++ b/python/ql/src/Classes/Comparisons/EqualsOrNotEquals.ql @@ -2,7 +2,8 @@ * @name Inconsistent equality and inequality * @description Defining only an equality method or an inequality method for a class violates the object model. * @kind problem - * @tags reliability + * @tags quality + * reliability * correctness * @problem.severity warning * @sub-severity high @@ -11,38 +12,29 @@ */ import python -import Equality +import Comparisons +import semmle.python.dataflow.new.internal.DataFlowDispatch +import Classes.Equality -string equals_or_ne() { result = "__eq__" or result = "__ne__" } - -predicate total_ordering(Class cls) { - exists(Attribute a | a = cls.getADecorator() | a.getName() = "total_ordering") +predicate missingEquality(Class cls, Function defined, string missing) { + defined = cls.getMethod("__ne__") and + not exists(cls.getMethod("__eq__")) and + missing = "__eq__" or - exists(Name n | n = cls.getADecorator() | n.getId() = "total_ordering") + // In python 3, __ne__ automatically delegates to __eq__ if its not defined in the hierarchy + // However if it is defined in a superclass (and isn't a delegation method) then it will use the superclass method (which may be incorrect) + defined = cls.getMethod("__eq__") and + not exists(cls.getMethod("__ne__")) and + exists(Function neMeth | + neMeth = getADirectSuperclass+(cls).getMethod("__ne__") and + not neMeth instanceof DelegatingEqualityMethod + ) and + missing = "__ne__" } -CallableValue implemented_method(ClassValue c, string name) { - result = c.declaredAttribute(name) and name = equals_or_ne() -} - -string unimplemented_method(ClassValue c) { - not c.declaresAttribute(result) and result = equals_or_ne() -} - -predicate violates_equality_contract( - ClassValue c, string present, string missing, CallableValue method -) { - missing = unimplemented_method(c) and - method = implemented_method(c, present) and - not c.failedInference(_) and - not total_ordering(c.getScope()) and - /* Python 3 automatically implements __ne__ if __eq__ is defined, but not vice-versa */ - not (major_version() = 3 and present = "__eq__" and missing = "__ne__") and - not method.getScope() instanceof DelegatingEqualityMethod and - not c.lookup(missing).(CallableValue).getScope() instanceof DelegatingEqualityMethod -} - -from ClassValue c, string present, string missing, CallableValue method -where violates_equality_contract(c, present, missing, method) -select method, "Class $@ implements " + present + " but does not implement " + missing + ".", c, - c.getName() +from Class cls, Function defined, string missing +where + not totalOrdering(cls) and + missingEquality(cls, defined, missing) +select cls, "This class implements $@, but does not implement " + missing + ".", defined, + defined.getName() diff --git a/python/ql/src/Classes/Comparisons/IncompleteOrdering.ql b/python/ql/src/Classes/Comparisons/IncompleteOrdering.ql index bbb6ca5cf6d..882321cc3f5 100644 --- a/python/ql/src/Classes/Comparisons/IncompleteOrdering.ql +++ b/python/ql/src/Classes/Comparisons/IncompleteOrdering.ql @@ -14,29 +14,20 @@ import python import semmle.python.dataflow.new.internal.DataFlowDispatch import semmle.python.ApiGraphs - -predicate totalOrdering(Class cls) { - cls.getADecorator() = - API::moduleImport("functools").getMember("total_ordering").asSource().asExpr() -} - -Function getMethod(Class cls, string name) { - result = cls.getAMethod() and - result.getName() = name -} +import Comparisons predicate definesStrictOrdering(Class cls, Function meth) { - meth = getMethod(cls, "__lt__") + meth = cls.getMethod("__lt__") or - not exists(getMethod(cls, "__lt__")) and - meth = getMethod(cls, "__gt__") + not exists(cls.getMethod("__lt__")) and + meth = cls.getMethod("__gt__") } predicate definesNonStrictOrdering(Class cls, Function meth) { - meth = getMethod(cls, "__le__") + meth = cls.getMethod("__le__") or - not exists(getMethod(cls, "__le__")) and - meth = getMethod(cls, "__ge__") + not exists(cls.getMethod("__le__")) and + meth = cls.getMethod("__ge__") } predicate missingComparison(Class cls, Function defined, string missing) { @@ -53,5 +44,5 @@ from Class cls, Function defined, string missing where not totalOrdering(cls) and missingComparison(cls, defined, missing) -select cls, "This class implements $@, but does not implement an " + missing + " method.", defined, +select cls, "This class implements $@, but does not implement " + missing + ".", defined, defined.getName() diff --git a/python/ql/src/Classes/Comparisons/examples/EqualsOrNotEquals.py b/python/ql/src/Classes/Comparisons/examples/EqualsOrNotEquals.py index 7e1ece7685c..32bc26d4737 100644 --- a/python/ql/src/Classes/Comparisons/examples/EqualsOrNotEquals.py +++ b/python/ql/src/Classes/Comparisons/examples/EqualsOrNotEquals.py @@ -30,3 +30,27 @@ class PointUpdated(object): def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing) return not self == other + + +class A: + def __init__(self, a): + self.a = a + + def __eq__(self, other): + print("A eq") + return self.a == other.a + + def __ne__(self, other): + print("A ne") + return self.a != other.a + +class B(A): + def __init__(self, a, b): + self.a = a + self.b = b + + def __eq__(self, other): + print("B eq") + return self.a == other.a and self.b == other.b + +print(B(1,2) != B(1,3)) diff --git a/python/ql/src/Classes/Equality.qll b/python/ql/src/Classes/Equality.qll index 347f5057c38..08162399e3e 100644 --- a/python/ql/src/Classes/Equality.qll +++ b/python/ql/src/Classes/Equality.qll @@ -1,4 +1,7 @@ +/** Utility definitions for reasoning about equality methods. */ + import python +import semmle.python.dataflow.new.DataFlow private Attribute dictAccess(LocalVariable var) { result.getName() = "__dict__" and @@ -59,16 +62,28 @@ class IdentityEqMethod extends Function { /** An (in)equality method that delegates to its complement */ class DelegatingEqualityMethod extends Function { DelegatingEqualityMethod() { - exists(Return ret, UnaryExpr not_, Compare comp, Cmpop op, Parameter p0, Parameter p1 | + exists(Return ret, UnaryExpr not_, Expr comp, Parameter p0, Parameter p1 | ret.getScope() = this and ret.getValue() = not_ and not_.getOp() instanceof Not and - not_.getOperand() = comp and - comp.compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess()) + not_.getOperand() = comp | - this.getName() = "__eq__" and op instanceof NotEq + exists(Cmpop op | + comp.(Compare).compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess()) + | + this.getName() = "__eq__" and op instanceof NotEq + or + this.getName() = "__ne__" and op instanceof Eq + ) or - this.getName() = "__ne__" and op instanceof Eq + exists(DataFlow::MethodCallNode call, string name | + call.calls(DataFlow::exprNode(p0.getVariable().getAnAccess()), name) and + call.getArg(0).asExpr() = p1.getVariable().getAnAccess() + | + this.getName() = "__eq__" and name = "__ne__" + or + this.getName() = "__ne__" and name = "__eq__" + ) ) } }