C++: Use [, ...] syntax more widely.

This commit is contained in:
Geoffrey White
2020-10-02 17:57:02 +01:00
parent fce76e2799
commit 3536d84bdf
23 changed files with 87 additions and 200 deletions

View File

@@ -23,10 +23,7 @@ import semmle.code.cpp.security.TaintTracking
* ```
*/
predicate sourceSized(FunctionCall fc, Expr src) {
exists(string name |
(name = "strncpy" or name = "strncat" or name = "memcpy" or name = "memmove") and
fc.getTarget().hasGlobalOrStdName(name)
) and
fc.getTarget().hasGlobalOrStdName(["strncpy", "strncat", "memcpy", "memmove"]) and
exists(Expr dest, Expr size, Variable v |
fc.getArgument(0) = dest and
fc.getArgument(1) = src and

View File

@@ -16,10 +16,7 @@ import cpp
class Allocation extends FunctionCall {
Allocation() {
exists(string name |
this.getTarget().hasGlobalOrStdName(name) and
(name = "malloc" or name = "calloc" or name = "realloc")
)
this.getTarget().hasGlobalOrStdName(["malloc", "calloc", "realloc"])
}
private string getName() { this.getTarget().hasGlobalOrStdName(result) }

View File

@@ -13,14 +13,7 @@
import cpp
class ForbiddenFunction extends Function {
ForbiddenFunction() {
exists(string name | name = this.getName() |
name = "setjmp" or
name = "longjmp" or
name = "sigsetjmp" or
name = "siglongjmp"
)
}
ForbiddenFunction() { this.getName() = ["setjmp", "longjmp", "sigsetjmp", "siglongjmp"] }
}
from FunctionCall call

View File

@@ -40,9 +40,7 @@ class DateStructModifiedFieldAccess extends LeapYearFieldAccess {
*/
class SafeTimeGatheringFunction extends Function {
SafeTimeGatheringFunction() {
this.getQualifiedName() = "GetFileTime" or
this.getQualifiedName() = "GetSystemTime" or
this.getQualifiedName() = "NtQuerySystemTime"
this.getQualifiedName() = ["GetFileTime", "GetSystemTime", "NtQuerySystemTime"]
}
}
@@ -51,15 +49,11 @@ class SafeTimeGatheringFunction extends Function {
*/
class TimeConversionFunction extends Function {
TimeConversionFunction() {
this.getQualifiedName() = "FileTimeToSystemTime" or
this.getQualifiedName() = "SystemTimeToFileTime" or
this.getQualifiedName() = "SystemTimeToTzSpecificLocalTime" or
this.getQualifiedName() = "SystemTimeToTzSpecificLocalTimeEx" or
this.getQualifiedName() = "TzSpecificLocalTimeToSystemTime" or
this.getQualifiedName() = "TzSpecificLocalTimeToSystemTimeEx" or
this.getQualifiedName() = "RtlLocalTimeToSystemTime" or
this.getQualifiedName() = "RtlTimeToSecondsSince1970" or
this.getQualifiedName() = "_mkgmtime"
this.getQualifiedName() =
["FileTimeToSystemTime", "SystemTimeToFileTime", "SystemTimeToTzSpecificLocalTime",
"SystemTimeToTzSpecificLocalTimeEx", "TzSpecificLocalTimeToSystemTime",
"TzSpecificLocalTimeToSystemTimeEx", "RtlLocalTimeToSystemTime",
"RtlTimeToSecondsSince1970", "_mkgmtime"]
}
}

View File

@@ -10,13 +10,8 @@ import cpp
*/
class SALMacro extends Macro {
SALMacro() {
exists(string filename | filename = this.getFile().getBaseName() |
filename = "sal.h" or
filename = "specstrings_strict.h" or
filename = "specstrings.h" or
filename = "w32p.h" or
filename = "minwindef.h"
) and
this.getFile().getBaseName() =
["sal.h", "specstrings_strict.h", "specstrings.h", "w32p.h", "minwindef.h"] and
(
// Dialect for Windows 8 and above
this.getName().matches("\\_%\\_")
@@ -58,10 +53,7 @@ class SALAnnotation extends MacroInvocation {
*/
class SALCheckReturn extends SALAnnotation {
SALCheckReturn() {
exists(SALMacro m | m = this.getMacro() |
m.getName() = "_Check_return_" or
m.getName() = "_Must_inspect_result_"
)
this.getMacro().(SALMacro).getName() = ["_Check_return_", "_Must_inspect_result_"]
}
}

View File

@@ -56,7 +56,7 @@ class VarargsFunction extends Function {
}
string normalTerminator(int cnt) {
(result = "0" or result = "-1") and
result = ["0", "-1"] and
cnt = trailingArgValueCount(result) and
2 * cnt > totalCount() and
not exists(FunctionCall fc, int index |

View File

@@ -66,19 +66,14 @@ class IFStream extends Type {
*/
class CinVariable extends NamespaceVariable {
CinVariable() {
(
getName() = "cin" or
getName() = "wcin"
) and
getNamespace().getName() = "std"
this.hasQualifiedName("std", ["cin", "wcin"])
}
}
/** A call to `std::operator>>`. */
class OperatorRShiftCall extends FunctionCall {
OperatorRShiftCall() {
getTarget().getNamespace().getName() = "std" and
getTarget().hasName("operator>>")
getTarget().hasQualifiedName("std", "operator>>")
}
/*

View File

@@ -14,12 +14,7 @@ import cpp
predicate potentiallyDangerousFunction(Function f, string message) {
exists(string name | f.hasGlobalName(name) |
(
name = "gmtime" or
name = "localtime" or
name = "ctime" or
name = "asctime"
) and
name = ["gmtime", "localtime", "ctime", "asctime"] and
message = "Call to " + name + " is potentially dangerous"
)
}

View File

@@ -19,12 +19,7 @@ predicate worldWritableCreation(FileCreationExpr fc, int mode) {
}
predicate setWorldWritable(FunctionCall fc, int mode) {
exists(string name | fc.getTarget().getName() = name |
name = "chmod" or
name = "fchmod" or
name = "_chmod" or
name = "_wchmod"
) and
fc.getTarget().getName() = ["chmod", "fchmod", "_chmod", "_wchmod"] and
mode = fc.getArgument(1).getValue().toInt() and
sets(mode, s_iwoth())
}

View File

@@ -31,11 +31,7 @@ predicate sets(int mask, int fields) { mask.bitAnd(fields) != 0 }
* one of the `umask` family of functions.
*/
private int umask(FunctionCall fc) {
exists(string name | name = fc.getTarget().getName() |
name = "umask" or
name = "_umask" or
name = "_umask_s"
) and
fc.getTarget().getName() = ["umask", "_umask", "_umask_s"] and
result = fc.getArgument(0).getValue().toInt()
}
@@ -89,11 +85,7 @@ abstract class FileCreationExpr extends FunctionCall {
class OpenCreationExpr extends FileCreationExpr {
OpenCreationExpr() {
exists(string name | name = this.getTarget().getName() |
name = "open" or
name = "_open" or
name = "_wopen"
) and
this.getTarget().getName() = ["open", "_open", "_wopen"] and
sets(this.getArgument(1).getValue().toInt(), o_creat())
}
@@ -134,14 +126,9 @@ private int fopenMode() {
class FopenCreationExpr extends FileCreationExpr {
FopenCreationExpr() {
exists(string name | name = this.getTarget().getName() |
name = "fopen" or
name = "_wfopen" or
name = "fsopen" or
name = "_wfsopen"
) and
this.getTarget().getName() = ["fopen", "_wfopen", "fsopen", "_wfsopen"] and
exists(string mode |
(mode = "w" or mode = "a") and
mode = ["w", "a"] and
this.getArgument(1).getValue().matches(mode + "%")
)
}

View File

@@ -18,6 +18,6 @@ import cpp
from File f
where
(f.getExtension().toLowerCase() = "h" or f.getExtension().toLowerCase() = "hpp") and
f.getExtension().toLowerCase() = ["h", "hpp"] and
f.getExtension() != "h"
select f, "AV Rule 53: Header files will always have a file name extension of .h."

View File

@@ -21,8 +21,8 @@ import cpp
*/
class WarningLateTemplateSpecialization extends CompilerWarning {
WarningLateTemplateSpecialization() {
this.getTag() = "partial_spec_after_instantiation" or
this.getTag() = "partial_spec_after_instantiation_ambiguous"
this.getTag() =
["partial_spec_after_instantiation", "partial_spec_after_instantiation_ambiguous"]
}
}

View File

@@ -29,7 +29,7 @@ private predicate readsEnvironment(Expr read, string sourceDescription) {
exists(FunctionCall call, string name |
read = call and
call.getTarget().hasGlobalOrStdName(name) and
(name = "getenv" or name = "secure_getenv" or name = "_wgetenv") and
name = ["getenv", "secure_getenv", "_wgetenv"] and
sourceDescription = name
)
}

View File

@@ -9,10 +9,7 @@ import semmle.code.cpp.models.interfaces.FormattingFunction
import semmle.code.cpp.models.implementations.Printf
class PrintfFormatAttribute extends FormatAttribute {
PrintfFormatAttribute() {
getArchetype() = "printf" or
getArchetype() = "__printf__"
}
PrintfFormatAttribute() { getArchetype() = ["printf", "__printf__"] }
}
/**
@@ -601,12 +598,12 @@ class FormatLiteral extends Literal {
or
len = "l" and result = this.getLongType()
or
(len = "ll" or len = "L" or len = "q") and
len = ["ll", "L", "q"] and
result instanceof LongLongType
or
len = "j" and result = this.getIntmax_t()
or
(len = "z" or len = "Z") and
len = ["z", "Z"] and
(result = this.getSize_t() or result = this.getSsize_t())
or
len = "t" and result = this.getPtrdiff_t()
@@ -639,12 +636,12 @@ class FormatLiteral extends Literal {
or
len = "l" and result = this.getLongType()
or
(len = "ll" or len = "L" or len = "q") and
len = ["ll", "L", "q"] and
result instanceof LongLongType
or
len = "j" and result = this.getIntmax_t()
or
(len = "z" or len = "Z") and
len = ["z", "Z"] and
(result = this.getSize_t() or result = this.getSsize_t())
or
len = "t" and result = this.getPtrdiff_t()
@@ -670,9 +667,7 @@ class FormatLiteral extends Literal {
FloatingPointType getFloatingPointConversion(int n) {
exists(string len |
len = this.getLength(n) and
if len = "L" or len = "ll"
then result instanceof LongDoubleType
else result instanceof DoubleType
if len = ["L", "ll"] then result instanceof LongDoubleType else result instanceof DoubleType
)
}
@@ -689,7 +684,7 @@ class FormatLiteral extends Literal {
or
len = "l" and base = this.getLongType()
or
(len = "ll" or len = "L") and
len = ["ll", "L"] and
base instanceof LongLongType
or
len = "q" and base instanceof LongLongType
@@ -736,12 +731,12 @@ class FormatLiteral extends Literal {
exists(string len, string conv |
this.parseConvSpec(n, _, _, _, _, _, len, conv) and
(
(conv = "c" or conv = "C") and
conv = ["c", "C"] and
len = "h" and
result instanceof PlainCharType
or
(conv = "c" or conv = "C") and
(len = "l" or len = "w") and
conv = ["c", "C"] and
len = ["l", "w"] and
result = getWideCharType()
or
conv = "c" and
@@ -781,12 +776,12 @@ class FormatLiteral extends Literal {
exists(string len, string conv |
this.parseConvSpec(n, _, _, _, _, _, len, conv) and
(
(conv = "s" or conv = "S") and
conv = ["s", "S"] and
len = "h" and
result.(PointerType).getBaseType() instanceof PlainCharType
or
(conv = "s" or conv = "S") and
(len = "l" or len = "w") and
conv = ["s", "S"] and
len = ["l", "w"] and
result.(PointerType).getBaseType() = getWideCharType()
or
conv = "s" and
@@ -823,10 +818,7 @@ class FormatLiteral extends Literal {
private Type getConversionType9(int n) {
this.getConversionChar(n) = "Z" and
(
this.getLength(n) = "l" or
this.getLength(n) = "w"
) and
this.getLength(n) = ["l", "w"] and
exists(Type t |
t.getName() = "UNICODE_STRING" and
result.(PointerType).getBaseType() = t
@@ -979,10 +971,7 @@ class FormatLiteral extends Literal {
len = (afterdot.maximum(1) + 6).maximum(1 + 1 + dot + afterdot + 1 + 1 + 3)
) // (e.g. "-1.59203e-319")
or
(
this.getConversionChar(n).toLowerCase() = "d" or
this.getConversionChar(n).toLowerCase() = "i"
) and
this.getConversionChar(n).toLowerCase() = ["d", "i"] and
// e.g. -2^31 = "-2147483648"
exists(int sizeBits |
sizeBits =

View File

@@ -8,14 +8,13 @@ import cpp
*/
class StrcatFunction extends Function {
StrcatFunction() {
exists(string name | name = getName() |
name = "strcat" or // strcat(dst, src)
name = "strncat" or // strncat(dst, src, max_amount)
name = "wcscat" or // wcscat(dst, src)
name = "_mbscat" or // _mbscat(dst, src)
name = "wcsncat" or // wcsncat(dst, src, max_amount)
name = "_mbsncat" or // _mbsncat(dst, src, max_amount)
name = "_mbsncat_l" // _mbsncat_l(dst, src, max_amount, locale)
)
// strcat(dst, src)
// strncat(dst, src, max_amount)
// wcscat(dst, src)
// _mbscat(dst, src)
// wcsncat(dst, src, max_amount)
// _mbsncat(dst, src, max_amount)
// _mbsncat_l(dst, src, max_amount, locale)
getName() = ["strcat", "strncat", "wcscat", "_mbscat", "wcsncat", "_mbsncat", "_mbsncat_l"]
}
}

View File

@@ -16,11 +16,10 @@ import semmle.code.cpp.models.interfaces.FlowSource
class GetsFunction extends DataFlowFunction, TaintFunction, ArrayFunction, AliasFunction,
SideEffectFunction, RemoteFlowFunction {
GetsFunction() {
exists(string name | hasGlobalOrStdName(name) |
name = "gets" or // gets(str)
name = "fgets" or // fgets(str, num, stream)
name = "fgetws" // fgetws(wstr, num, stream)
)
// gets(str)
// fgets(str, num, stream)
// fgetws(wstr, num, stream)
hasGlobalOrStdName(["gets", "fgets", "fgetws"])
}
override predicate hasDataFlow(FunctionInput input, FunctionOutput output) {

View File

@@ -4,16 +4,11 @@ import semmle.code.cpp.models.interfaces.DataFlow
import semmle.code.cpp.models.interfaces.SideEffect
/**
* The standard function templates `std::move` and `std::identity`
* The standard function templates `std::move` and `std::forward`.
*/
class IdentityFunction extends DataFlowFunction, SideEffectFunction, AliasFunction {
IdentityFunction() {
this.getNamespace().getParentNamespace() instanceof GlobalNamespace and
this.getNamespace().getName() = "std" and
(
this.getName() = "move" or
this.getName() = "forward"
)
this.hasQualifiedName("std", ["move", "forward"])
}
override predicate hasOnlySpecificReadSideEffects() { any() }

View File

@@ -13,43 +13,26 @@ import semmle.code.cpp.models.interfaces.SideEffect
*/
class StrcpyFunction extends ArrayFunction, DataFlowFunction, TaintFunction, SideEffectFunction {
StrcpyFunction() {
exists(string name | name = getName() |
// strcpy(dst, src)
name = "strcpy"
or
// wcscpy(dst, src)
name = "wcscpy"
or
// _mbscpy(dst, src)
name = "_mbscpy"
or
(
name = "strcpy_s" or // strcpy_s(dst, max_amount, src)
name = "wcscpy_s" or // wcscpy_s(dst, max_amount, src)
name = "_mbscpy_s" // _mbscpy_s(dst, max_amount, src)
) and
// exclude the 2-parameter template versions
// that find the size of a fixed size destination buffer.
getNumberOfParameters() = 3
or
// strncpy(dst, src, max_amount)
name = "strncpy"
or
// _strncpy_l(dst, src, max_amount, locale)
name = "_strncpy_l"
or
// wcsncpy(dst, src, max_amount)
name = "wcsncpy"
or
// _wcsncpy_l(dst, src, max_amount, locale)
name = "_wcsncpy_l"
or
// _mbsncpy(dst, src, max_amount)
name = "_mbsncpy"
or
// _mbsncpy_l(dst, src, max_amount, locale)
name = "_mbsncpy_l"
)
// strcpy(dst, src)
// wcscpy(dst, src)
// _mbscpy(dst, src)
// strncpy(dst, src, max_amount)
// _strncpy_l(dst, src, max_amount, locale)
// wcsncpy(dst, src, max_amount)
// _wcsncpy_l(dst, src, max_amount, locale)
// _mbsncpy(dst, src, max_amount)
// _mbsncpy_l(dst, src, max_amount, locale)
getName() =
["strcpy", "wcscpy", "_mbscpy", "strncpy", "_strncpy_l", "wcsncpy", "_wcsncpy_l", "_mbsncpy",
"_mbsncpy_l"]
or
// strcpy_s(dst, max_amount, src)
// wcscpy_s(dst, max_amount, src)
// _mbscpy_s(dst, max_amount, src)
getName() = ["strcpy_s", "wcscpy_s", "_mbscpy_s"] and
// exclude the 2-parameter template versions
// that find the size of a fixed size destination buffer.
getNumberOfParameters() = 3
}
/**

View File

@@ -354,11 +354,10 @@ class SnprintfBW extends BufferWriteCall {
*/
class GetsBW extends BufferWriteCall {
GetsBW() {
exists(TopLevelFunction fn, string name | fn = getTarget() and name = fn.getName() |
name = "gets" or // gets(dst)
name = "fgets" or // fgets(dst, max_amount, src_stream)
name = "fgetws" // fgetws(dst, max_amount, src_stream)
)
// gets(dst)
// fgets(dst, max_amount, src_stream)
// fgetws(dst, max_amount, src_stream)
getTarget().(TopLevelFunction).getName() = ["gets", "fgets", "fgetws"]
}
/**

View File

@@ -123,9 +123,7 @@ class WriteFunctionCall extends ChainedOutputCall {
private predicate fileStreamChain(ChainedOutputCall out, Expr source, Expr dest) {
source = out.getSource() and
dest = out.getEndDest() and
exists(string nme | nme = "basic_ofstream" or nme = "basic_fstream" |
dest.getUnderlyingType().(Class).getSimpleName() = nme
)
dest.getUnderlyingType().(Class).getSimpleName() = ["basic_ofstream", "basic_fstream"]
}
/**
@@ -139,15 +137,7 @@ private predicate fileWrite(Call write, Expr source, Expr dest) {
// named functions
name = "fwrite" and s = 0 and d = 3
or
(
name = "fputs" or
name = "fputws" or
name = "fputc" or
name = "fputwc" or
name = "putc" or
name = "putwc" or
name = "putw"
) and
name = ["fputs", "fputws", "fputc", "fputwc", "putc", "putwc", "putw"] and
s = 0 and
d = 1
)

View File

@@ -48,10 +48,7 @@ private predicate outputFile(Expr e) {
name = e.(VariableAccess).getTarget().(GlobalVariable).toString() or
name = e.findRootCause().(Macro).getName()
) and
(
name = "stdout" or
name = "stderr"
)
name = ["stdout", "stderr"]
)
}

View File

@@ -252,11 +252,10 @@ private predicate insideFunctionValueMoveTo(Element src, Element dest) {
copyValueBetweenArguments(c.getTarget(), sourceArg, destArg) and
// Only consider copies from `printf`-like functions if the format is a string
(
exists(FormattingFunctionCall ffc, FormatLiteral format, string argFormat |
exists(FormattingFunctionCall ffc, FormatLiteral format |
ffc = c and
format = ffc.getFormat() and
format.getConversionChar(sourceArg - ffc.getTarget().getNumberOfParameters()) = argFormat and
(argFormat = "s" or argFormat = "S")
format.getConversionChar(sourceArg - ffc.getTarget().getNumberOfParameters()) = ["s", "S"]
)
or
not exists(FormatLiteral fl | fl = c.(FormattingFunctionCall).getFormat())
@@ -273,12 +272,11 @@ private predicate insideFunctionValueMoveTo(Element src, Element dest) {
dest = c
)
or
exists(FormattingFunctionCall formattingSend, int arg, FormatLiteral format, string argFormat |
exists(FormattingFunctionCall formattingSend, int arg, FormatLiteral format |
dest = formattingSend and
formattingSend.getArgument(arg) = src and
format = formattingSend.getFormat() and
format.getConversionChar(arg - formattingSend.getTarget().getNumberOfParameters()) = argFormat and
(argFormat = "s" or argFormat = "S" or argFormat = "@")
format.getConversionChar(arg - formattingSend.getTarget().getNumberOfParameters()) = ["s", "S", "@"]
)
or
// Expressions computed from tainted data are also tainted

View File

@@ -65,15 +65,8 @@ class UMLElement extends XMLElement {
*/
class UMLType extends UMLElement {
UMLType() {
exists(string type |
this.getName() = "packagedElement" and
this.getAttribute("type").getValue() = type and
(
type = "uml:Class" or
type = "uml:Interface" or
type = "uml:PrimitiveType"
)
)
this.getName() = "packagedElement" and
this.getAttribute("type").getValue() = ["uml:Class", "uml:Interface", "uml:PrimitiveType"]
}
/**