Modernize inconsistent equality

This commit is contained in:
Joe Farebrother
2025-07-08 15:33:59 +01:00
parent 4c5c4e06c3
commit eb1b5a35d7
6 changed files with 92 additions and 54 deletions

View File

@@ -91,6 +91,12 @@ class Class extends Class_, Scope, AstNode {
/** Gets a method defined in this class */
Function getAMethod() { result.getScope() = this }
/** Gets the method defined in this class with the specified name, if any. */
Function getMethod(string name) {
result = this.getAMethod() and
result.getName() = name
}
override Location getLocation() { py_scope_location(result, this) }
/** Gets the scope (module, class or function) in which this class is defined */

View File

@@ -0,0 +1,10 @@
/** Helper definitions for reasoning about comparison methods. */
import python
import semmle.python.ApiGraphs
/** Holds if `cls` has the `functools.total_ordering` decorator. */
predicate totalOrdering(Class cls) {
cls.getADecorator() =
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
}

View File

@@ -2,7 +2,8 @@
* @name Inconsistent equality and inequality
* @description Defining only an equality method or an inequality method for a class violates the object model.
* @kind problem
* @tags reliability
* @tags quality
* reliability
* correctness
* @problem.severity warning
* @sub-severity high
@@ -11,38 +12,29 @@
*/
import python
import Equality
import Comparisons
import semmle.python.dataflow.new.internal.DataFlowDispatch
import Classes.Equality
string equals_or_ne() { result = "__eq__" or result = "__ne__" }
predicate total_ordering(Class cls) {
exists(Attribute a | a = cls.getADecorator() | a.getName() = "total_ordering")
predicate missingEquality(Class cls, Function defined, string missing) {
defined = cls.getMethod("__ne__") and
not exists(cls.getMethod("__eq__")) and
missing = "__eq__"
or
exists(Name n | n = cls.getADecorator() | n.getId() = "total_ordering")
// In python 3, __ne__ automatically delegates to __eq__ if its not defined in the hierarchy
// However if it is defined in a superclass (and isn't a delegation method) then it will use the superclass method (which may be incorrect)
defined = cls.getMethod("__eq__") and
not exists(cls.getMethod("__ne__")) and
exists(Function neMeth |
neMeth = getADirectSuperclass+(cls).getMethod("__ne__") and
not neMeth instanceof DelegatingEqualityMethod
) and
missing = "__ne__"
}
CallableValue implemented_method(ClassValue c, string name) {
result = c.declaredAttribute(name) and name = equals_or_ne()
}
string unimplemented_method(ClassValue c) {
not c.declaresAttribute(result) and result = equals_or_ne()
}
predicate violates_equality_contract(
ClassValue c, string present, string missing, CallableValue method
) {
missing = unimplemented_method(c) and
method = implemented_method(c, present) and
not c.failedInference(_) and
not total_ordering(c.getScope()) and
/* Python 3 automatically implements __ne__ if __eq__ is defined, but not vice-versa */
not (major_version() = 3 and present = "__eq__" and missing = "__ne__") and
not method.getScope() instanceof DelegatingEqualityMethod and
not c.lookup(missing).(CallableValue).getScope() instanceof DelegatingEqualityMethod
}
from ClassValue c, string present, string missing, CallableValue method
where violates_equality_contract(c, present, missing, method)
select method, "Class $@ implements " + present + " but does not implement " + missing + ".", c,
c.getName()
from Class cls, Function defined, string missing
where
not totalOrdering(cls) and
missingEquality(cls, defined, missing)
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
defined.getName()

View File

@@ -14,29 +14,20 @@
import python
import semmle.python.dataflow.new.internal.DataFlowDispatch
import semmle.python.ApiGraphs
predicate totalOrdering(Class cls) {
cls.getADecorator() =
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
}
Function getMethod(Class cls, string name) {
result = cls.getAMethod() and
result.getName() = name
}
import Comparisons
predicate definesStrictOrdering(Class cls, Function meth) {
meth = getMethod(cls, "__lt__")
meth = cls.getMethod("__lt__")
or
not exists(getMethod(cls, "__lt__")) and
meth = getMethod(cls, "__gt__")
not exists(cls.getMethod("__lt__")) and
meth = cls.getMethod("__gt__")
}
predicate definesNonStrictOrdering(Class cls, Function meth) {
meth = getMethod(cls, "__le__")
meth = cls.getMethod("__le__")
or
not exists(getMethod(cls, "__le__")) and
meth = getMethod(cls, "__ge__")
not exists(cls.getMethod("__le__")) and
meth = cls.getMethod("__ge__")
}
predicate missingComparison(Class cls, Function defined, string missing) {
@@ -53,5 +44,5 @@ from Class cls, Function defined, string missing
where
not totalOrdering(cls) and
missingComparison(cls, defined, missing)
select cls, "This class implements $@, but does not implement an " + missing + " method.", defined,
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
defined.getName()

View File

@@ -30,3 +30,27 @@ class PointUpdated(object):
def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing)
return not self == other
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
print("A eq")
return self.a == other.a
def __ne__(self, other):
print("A ne")
return self.a != other.a
class B(A):
def __init__(self, a, b):
self.a = a
self.b = b
def __eq__(self, other):
print("B eq")
return self.a == other.a and self.b == other.b
print(B(1,2) != B(1,3))

View File

@@ -1,4 +1,7 @@
/** Utility definitions for reasoning about equality methods. */
import python
import semmle.python.dataflow.new.DataFlow
private Attribute dictAccess(LocalVariable var) {
result.getName() = "__dict__" and
@@ -59,16 +62,28 @@ class IdentityEqMethod extends Function {
/** An (in)equality method that delegates to its complement */
class DelegatingEqualityMethod extends Function {
DelegatingEqualityMethod() {
exists(Return ret, UnaryExpr not_, Compare comp, Cmpop op, Parameter p0, Parameter p1 |
exists(Return ret, UnaryExpr not_, Expr comp, Parameter p0, Parameter p1 |
ret.getScope() = this and
ret.getValue() = not_ and
not_.getOp() instanceof Not and
not_.getOperand() = comp and
comp.compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
not_.getOperand() = comp
|
this.getName() = "__eq__" and op instanceof NotEq
exists(Cmpop op |
comp.(Compare).compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
|
this.getName() = "__eq__" and op instanceof NotEq
or
this.getName() = "__ne__" and op instanceof Eq
)
or
this.getName() = "__ne__" and op instanceof Eq
exists(DataFlow::MethodCallNode call, string name |
call.calls(DataFlow::exprNode(p0.getVariable().getAnAccess()), name) and
call.getArg(0).asExpr() = p1.getVariable().getAnAccess()
|
this.getName() = "__eq__" and name = "__ne__"
or
this.getName() = "__ne__" and name = "__eq__"
)
)
}
}