Modernize inconsistent equality

This commit is contained in:
Joe Farebrother
2025-07-08 15:33:59 +01:00
parent 4c5c4e06c3
commit eb1b5a35d7
6 changed files with 92 additions and 54 deletions

View File

@@ -91,6 +91,12 @@ class Class extends Class_, Scope, AstNode {
/** Gets a method defined in this class */ /** Gets a method defined in this class */
Function getAMethod() { result.getScope() = this } Function getAMethod() { result.getScope() = this }
/** Gets the method defined in this class with the specified name, if any. */
Function getMethod(string name) {
result = this.getAMethod() and
result.getName() = name
}
override Location getLocation() { py_scope_location(result, this) } override Location getLocation() { py_scope_location(result, this) }
/** Gets the scope (module, class or function) in which this class is defined */ /** Gets the scope (module, class or function) in which this class is defined */

View File

@@ -0,0 +1,10 @@
/** Helper definitions for reasoning about comparison methods. */
import python
import semmle.python.ApiGraphs
/** Holds if `cls` has the `functools.total_ordering` decorator. */
predicate totalOrdering(Class cls) {
cls.getADecorator() =
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
}

View File

@@ -2,7 +2,8 @@
* @name Inconsistent equality and inequality * @name Inconsistent equality and inequality
* @description Defining only an equality method or an inequality method for a class violates the object model. * @description Defining only an equality method or an inequality method for a class violates the object model.
* @kind problem * @kind problem
* @tags reliability * @tags quality
* reliability
* correctness * correctness
* @problem.severity warning * @problem.severity warning
* @sub-severity high * @sub-severity high
@@ -11,38 +12,29 @@
*/ */
import python import python
import Equality import Comparisons
import semmle.python.dataflow.new.internal.DataFlowDispatch
import Classes.Equality
string equals_or_ne() { result = "__eq__" or result = "__ne__" } predicate missingEquality(Class cls, Function defined, string missing) {
defined = cls.getMethod("__ne__") and
predicate total_ordering(Class cls) { not exists(cls.getMethod("__eq__")) and
exists(Attribute a | a = cls.getADecorator() | a.getName() = "total_ordering") missing = "__eq__"
or or
exists(Name n | n = cls.getADecorator() | n.getId() = "total_ordering") // In python 3, __ne__ automatically delegates to __eq__ if its not defined in the hierarchy
// However if it is defined in a superclass (and isn't a delegation method) then it will use the superclass method (which may be incorrect)
defined = cls.getMethod("__eq__") and
not exists(cls.getMethod("__ne__")) and
exists(Function neMeth |
neMeth = getADirectSuperclass+(cls).getMethod("__ne__") and
not neMeth instanceof DelegatingEqualityMethod
) and
missing = "__ne__"
} }
CallableValue implemented_method(ClassValue c, string name) { from Class cls, Function defined, string missing
result = c.declaredAttribute(name) and name = equals_or_ne() where
} not totalOrdering(cls) and
missingEquality(cls, defined, missing)
string unimplemented_method(ClassValue c) { select cls, "This class implements $@, but does not implement " + missing + ".", defined,
not c.declaresAttribute(result) and result = equals_or_ne() defined.getName()
}
predicate violates_equality_contract(
ClassValue c, string present, string missing, CallableValue method
) {
missing = unimplemented_method(c) and
method = implemented_method(c, present) and
not c.failedInference(_) and
not total_ordering(c.getScope()) and
/* Python 3 automatically implements __ne__ if __eq__ is defined, but not vice-versa */
not (major_version() = 3 and present = "__eq__" and missing = "__ne__") and
not method.getScope() instanceof DelegatingEqualityMethod and
not c.lookup(missing).(CallableValue).getScope() instanceof DelegatingEqualityMethod
}
from ClassValue c, string present, string missing, CallableValue method
where violates_equality_contract(c, present, missing, method)
select method, "Class $@ implements " + present + " but does not implement " + missing + ".", c,
c.getName()

View File

@@ -14,29 +14,20 @@
import python import python
import semmle.python.dataflow.new.internal.DataFlowDispatch import semmle.python.dataflow.new.internal.DataFlowDispatch
import semmle.python.ApiGraphs import semmle.python.ApiGraphs
import Comparisons
predicate totalOrdering(Class cls) {
cls.getADecorator() =
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
}
Function getMethod(Class cls, string name) {
result = cls.getAMethod() and
result.getName() = name
}
predicate definesStrictOrdering(Class cls, Function meth) { predicate definesStrictOrdering(Class cls, Function meth) {
meth = getMethod(cls, "__lt__") meth = cls.getMethod("__lt__")
or or
not exists(getMethod(cls, "__lt__")) and not exists(cls.getMethod("__lt__")) and
meth = getMethod(cls, "__gt__") meth = cls.getMethod("__gt__")
} }
predicate definesNonStrictOrdering(Class cls, Function meth) { predicate definesNonStrictOrdering(Class cls, Function meth) {
meth = getMethod(cls, "__le__") meth = cls.getMethod("__le__")
or or
not exists(getMethod(cls, "__le__")) and not exists(cls.getMethod("__le__")) and
meth = getMethod(cls, "__ge__") meth = cls.getMethod("__ge__")
} }
predicate missingComparison(Class cls, Function defined, string missing) { predicate missingComparison(Class cls, Function defined, string missing) {
@@ -53,5 +44,5 @@ from Class cls, Function defined, string missing
where where
not totalOrdering(cls) and not totalOrdering(cls) and
missingComparison(cls, defined, missing) missingComparison(cls, defined, missing)
select cls, "This class implements $@, but does not implement an " + missing + " method.", defined, select cls, "This class implements $@, but does not implement " + missing + ".", defined,
defined.getName() defined.getName()

View File

@@ -30,3 +30,27 @@ class PointUpdated(object):
def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing) def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing)
return not self == other return not self == other
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
print("A eq")
return self.a == other.a
def __ne__(self, other):
print("A ne")
return self.a != other.a
class B(A):
def __init__(self, a, b):
self.a = a
self.b = b
def __eq__(self, other):
print("B eq")
return self.a == other.a and self.b == other.b
print(B(1,2) != B(1,3))

View File

@@ -1,4 +1,7 @@
/** Utility definitions for reasoning about equality methods. */
import python import python
import semmle.python.dataflow.new.DataFlow
private Attribute dictAccess(LocalVariable var) { private Attribute dictAccess(LocalVariable var) {
result.getName() = "__dict__" and result.getName() = "__dict__" and
@@ -59,16 +62,28 @@ class IdentityEqMethod extends Function {
/** An (in)equality method that delegates to its complement */ /** An (in)equality method that delegates to its complement */
class DelegatingEqualityMethod extends Function { class DelegatingEqualityMethod extends Function {
DelegatingEqualityMethod() { DelegatingEqualityMethod() {
exists(Return ret, UnaryExpr not_, Compare comp, Cmpop op, Parameter p0, Parameter p1 | exists(Return ret, UnaryExpr not_, Expr comp, Parameter p0, Parameter p1 |
ret.getScope() = this and ret.getScope() = this and
ret.getValue() = not_ and ret.getValue() = not_ and
not_.getOp() instanceof Not and not_.getOp() instanceof Not and
not_.getOperand() = comp and not_.getOperand() = comp
comp.compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
| |
this.getName() = "__eq__" and op instanceof NotEq exists(Cmpop op |
comp.(Compare).compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
|
this.getName() = "__eq__" and op instanceof NotEq
or
this.getName() = "__ne__" and op instanceof Eq
)
or or
this.getName() = "__ne__" and op instanceof Eq exists(DataFlow::MethodCallNode call, string name |
call.calls(DataFlow::exprNode(p0.getVariable().getAnAccess()), name) and
call.getArg(0).asExpr() = p1.getVariable().getAnAccess()
|
this.getName() = "__eq__" and name = "__ne__"
or
this.getName() = "__ne__" and name = "__eq__"
)
) )
} }
} }