mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
Modernize inconsistent equality
This commit is contained in:
@@ -91,6 +91,12 @@ class Class extends Class_, Scope, AstNode {
|
||||
/** Gets a method defined in this class */
|
||||
Function getAMethod() { result.getScope() = this }
|
||||
|
||||
/** Gets the method defined in this class with the specified name, if any. */
|
||||
Function getMethod(string name) {
|
||||
result = this.getAMethod() and
|
||||
result.getName() = name
|
||||
}
|
||||
|
||||
override Location getLocation() { py_scope_location(result, this) }
|
||||
|
||||
/** Gets the scope (module, class or function) in which this class is defined */
|
||||
|
||||
10
python/ql/src/Classes/Comparisons/Comparisons.qll
Normal file
10
python/ql/src/Classes/Comparisons/Comparisons.qll
Normal file
@@ -0,0 +1,10 @@
|
||||
/** Helper definitions for reasoning about comparison methods. */
|
||||
|
||||
import python
|
||||
import semmle.python.ApiGraphs
|
||||
|
||||
/** Holds if `cls` has the `functools.total_ordering` decorator. */
|
||||
predicate totalOrdering(Class cls) {
|
||||
cls.getADecorator() =
|
||||
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
|
||||
}
|
||||
@@ -2,7 +2,8 @@
|
||||
* @name Inconsistent equality and inequality
|
||||
* @description Defining only an equality method or an inequality method for a class violates the object model.
|
||||
* @kind problem
|
||||
* @tags reliability
|
||||
* @tags quality
|
||||
* reliability
|
||||
* correctness
|
||||
* @problem.severity warning
|
||||
* @sub-severity high
|
||||
@@ -11,38 +12,29 @@
|
||||
*/
|
||||
|
||||
import python
|
||||
import Equality
|
||||
import Comparisons
|
||||
import semmle.python.dataflow.new.internal.DataFlowDispatch
|
||||
import Classes.Equality
|
||||
|
||||
string equals_or_ne() { result = "__eq__" or result = "__ne__" }
|
||||
|
||||
predicate total_ordering(Class cls) {
|
||||
exists(Attribute a | a = cls.getADecorator() | a.getName() = "total_ordering")
|
||||
predicate missingEquality(Class cls, Function defined, string missing) {
|
||||
defined = cls.getMethod("__ne__") and
|
||||
not exists(cls.getMethod("__eq__")) and
|
||||
missing = "__eq__"
|
||||
or
|
||||
exists(Name n | n = cls.getADecorator() | n.getId() = "total_ordering")
|
||||
// In python 3, __ne__ automatically delegates to __eq__ if its not defined in the hierarchy
|
||||
// However if it is defined in a superclass (and isn't a delegation method) then it will use the superclass method (which may be incorrect)
|
||||
defined = cls.getMethod("__eq__") and
|
||||
not exists(cls.getMethod("__ne__")) and
|
||||
exists(Function neMeth |
|
||||
neMeth = getADirectSuperclass+(cls).getMethod("__ne__") and
|
||||
not neMeth instanceof DelegatingEqualityMethod
|
||||
) and
|
||||
missing = "__ne__"
|
||||
}
|
||||
|
||||
CallableValue implemented_method(ClassValue c, string name) {
|
||||
result = c.declaredAttribute(name) and name = equals_or_ne()
|
||||
}
|
||||
|
||||
string unimplemented_method(ClassValue c) {
|
||||
not c.declaresAttribute(result) and result = equals_or_ne()
|
||||
}
|
||||
|
||||
predicate violates_equality_contract(
|
||||
ClassValue c, string present, string missing, CallableValue method
|
||||
) {
|
||||
missing = unimplemented_method(c) and
|
||||
method = implemented_method(c, present) and
|
||||
not c.failedInference(_) and
|
||||
not total_ordering(c.getScope()) and
|
||||
/* Python 3 automatically implements __ne__ if __eq__ is defined, but not vice-versa */
|
||||
not (major_version() = 3 and present = "__eq__" and missing = "__ne__") and
|
||||
not method.getScope() instanceof DelegatingEqualityMethod and
|
||||
not c.lookup(missing).(CallableValue).getScope() instanceof DelegatingEqualityMethod
|
||||
}
|
||||
|
||||
from ClassValue c, string present, string missing, CallableValue method
|
||||
where violates_equality_contract(c, present, missing, method)
|
||||
select method, "Class $@ implements " + present + " but does not implement " + missing + ".", c,
|
||||
c.getName()
|
||||
from Class cls, Function defined, string missing
|
||||
where
|
||||
not totalOrdering(cls) and
|
||||
missingEquality(cls, defined, missing)
|
||||
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
|
||||
defined.getName()
|
||||
|
||||
@@ -14,29 +14,20 @@
|
||||
import python
|
||||
import semmle.python.dataflow.new.internal.DataFlowDispatch
|
||||
import semmle.python.ApiGraphs
|
||||
|
||||
predicate totalOrdering(Class cls) {
|
||||
cls.getADecorator() =
|
||||
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
|
||||
}
|
||||
|
||||
Function getMethod(Class cls, string name) {
|
||||
result = cls.getAMethod() and
|
||||
result.getName() = name
|
||||
}
|
||||
import Comparisons
|
||||
|
||||
predicate definesStrictOrdering(Class cls, Function meth) {
|
||||
meth = getMethod(cls, "__lt__")
|
||||
meth = cls.getMethod("__lt__")
|
||||
or
|
||||
not exists(getMethod(cls, "__lt__")) and
|
||||
meth = getMethod(cls, "__gt__")
|
||||
not exists(cls.getMethod("__lt__")) and
|
||||
meth = cls.getMethod("__gt__")
|
||||
}
|
||||
|
||||
predicate definesNonStrictOrdering(Class cls, Function meth) {
|
||||
meth = getMethod(cls, "__le__")
|
||||
meth = cls.getMethod("__le__")
|
||||
or
|
||||
not exists(getMethod(cls, "__le__")) and
|
||||
meth = getMethod(cls, "__ge__")
|
||||
not exists(cls.getMethod("__le__")) and
|
||||
meth = cls.getMethod("__ge__")
|
||||
}
|
||||
|
||||
predicate missingComparison(Class cls, Function defined, string missing) {
|
||||
@@ -53,5 +44,5 @@ from Class cls, Function defined, string missing
|
||||
where
|
||||
not totalOrdering(cls) and
|
||||
missingComparison(cls, defined, missing)
|
||||
select cls, "This class implements $@, but does not implement an " + missing + " method.", defined,
|
||||
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
|
||||
defined.getName()
|
||||
|
||||
@@ -30,3 +30,27 @@ class PointUpdated(object):
|
||||
def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing)
|
||||
return not self == other
|
||||
|
||||
|
||||
|
||||
class A:
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
def __eq__(self, other):
|
||||
print("A eq")
|
||||
return self.a == other.a
|
||||
|
||||
def __ne__(self, other):
|
||||
print("A ne")
|
||||
return self.a != other.a
|
||||
|
||||
class B(A):
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def __eq__(self, other):
|
||||
print("B eq")
|
||||
return self.a == other.a and self.b == other.b
|
||||
|
||||
print(B(1,2) != B(1,3))
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
/** Utility definitions for reasoning about equality methods. */
|
||||
|
||||
import python
|
||||
import semmle.python.dataflow.new.DataFlow
|
||||
|
||||
private Attribute dictAccess(LocalVariable var) {
|
||||
result.getName() = "__dict__" and
|
||||
@@ -59,16 +62,28 @@ class IdentityEqMethod extends Function {
|
||||
/** An (in)equality method that delegates to its complement */
|
||||
class DelegatingEqualityMethod extends Function {
|
||||
DelegatingEqualityMethod() {
|
||||
exists(Return ret, UnaryExpr not_, Compare comp, Cmpop op, Parameter p0, Parameter p1 |
|
||||
exists(Return ret, UnaryExpr not_, Expr comp, Parameter p0, Parameter p1 |
|
||||
ret.getScope() = this and
|
||||
ret.getValue() = not_ and
|
||||
not_.getOp() instanceof Not and
|
||||
not_.getOperand() = comp and
|
||||
comp.compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
|
||||
not_.getOperand() = comp
|
||||
|
|
||||
this.getName() = "__eq__" and op instanceof NotEq
|
||||
exists(Cmpop op |
|
||||
comp.(Compare).compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
|
||||
|
|
||||
this.getName() = "__eq__" and op instanceof NotEq
|
||||
or
|
||||
this.getName() = "__ne__" and op instanceof Eq
|
||||
)
|
||||
or
|
||||
this.getName() = "__ne__" and op instanceof Eq
|
||||
exists(DataFlow::MethodCallNode call, string name |
|
||||
call.calls(DataFlow::exprNode(p0.getVariable().getAnAccess()), name) and
|
||||
call.getArg(0).asExpr() = p1.getVariable().getAnAccess()
|
||||
|
|
||||
this.getName() = "__eq__" and name = "__ne__"
|
||||
or
|
||||
this.getName() = "__ne__" and name = "__eq__"
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user