C++: Better support for complex numbers in IR and AST

This PR adds better support for differentiating complex and imaginary floating-point types from real floating-point types, in both the AST and in the IR type system.

*AST Changes*
- Introduces the new class `TypeDomain`, which can be either `RealDomain`, `ImaginaryDomain` or `ComplexDomain`. "type domain" is the term used for this concept in the C standard, and I couldn't think of a better one.
- Introduces `FloatingPointType.getDomain()`, to get the type domain of the type.
- Introduces `FloatingPointType.getBase()`, to get the numeric base of the type (either 2 or 10).
- Introduces three new subtypes of `FloatingPointType`: `RealNumberType`, `ComplexNumberType`, and `ImaginaryNumberType`, which differentiate between the types based on their type domain. Note that the decimal types (e.g., `_Decimal32`) are included in `RealNumberType`.
- Introduces two new subtypes of `FloatingPointType`: `BinaryFloatingPointType` and `DecimalFloatingPointType`, which differentiate between the types based on their numeric base, independent of type domain.

*IR Changes*
- `IRFloatingPointType` now has two additional parameters: the base and the type domain.
- New test that ensures that C++ types get mapped to the correct IR types.
- New IR test that verifies the IR for some basic usage of complex FP types.
This commit is contained in:
Dave Bartolomeo
2020-03-27 18:08:14 -04:00
parent 1baf5df342
commit c3a6ca0d9a
14 changed files with 526 additions and 44 deletions

View File

@@ -697,28 +697,187 @@ class Int128Type extends IntegralType {
override string getCanonicalQLClass() { result = "Int128Type" }
}
private newtype TTypeDomain =
TRealDomain() or
TComplexDomain() or
TImaginaryDomain()
/**
* The C/C++ floating point types. See 4.5. This includes `float`,
* `double` and `long double` types.
* ```
* float f;
* double d;
* long double ld;
* ```
* The type domain of a floating-point type. One of `RealDomain`, `ComplexDomain`, or
* `ImaginaryDomain`.
*/
class TypeDomain extends TTypeDomain {
string toString() { none() }
}
/**
* The type domain of a floating-point type that represents a real number.
*/
class RealDomain extends TypeDomain, TRealDomain {
final override string toString() { result = "real" }
}
/**
* The type domain of a floating-point type that represents a complex number.
*/
class ComplexDomain extends TypeDomain, TComplexDomain {
final override string toString() { result = "complex" }
}
/**
* The type domain of a floating-point type that represents an imaginary number.
*/
class ImaginaryDomain extends TypeDomain, TImaginaryDomain {
final override string toString() { result = "imaginary" }
}
/**
* Data for floating-point types.
*
* kind: The original type kind. Can be any floating-point type kind.
* base: The numeric base of the number's representation. Can be 2 (binary) or 10 (decimal).
* domain: The type domain of the type. Can be `RealDomain`, `ComplexDomain`, or `ImaginaryDomain`.
* realKind: The type kind of the corresponding real type. For example, the corresponding real type
* of `_Complex double` is `double`.
* extended: `true` if the number is an extended-precision floating-point number, such as
* `_Float32x`.
*/
private predicate floatingPointTypeMapping(
int kind, int base, TTypeDomain domain, int realKind, boolean extended
) {
// float
kind = 24 and base = 2 and domain = TRealDomain() and realKind = 24 and extended = false
or
// double
kind = 25 and base = 2 and domain = TRealDomain() and realKind = 25 and extended = false
or
// long double
kind = 26 and base = 2 and domain = TRealDomain() and realKind = 26 and extended = false
or
// _Complex float
kind = 27 and base = 2 and domain = TComplexDomain() and realKind = 24 and extended = false
or
// _Complex double
kind = 28 and base = 2 and domain = TComplexDomain() and realKind = 25 and extended = false
or
// _Complex long double
kind = 29 and base = 2 and domain = TComplexDomain() and realKind = 26 and extended = false
or
// _Imaginary float
kind = 30 and base = 2 and domain = TImaginaryDomain() and realKind = 24 and extended = false
or
// _Imaginary double
kind = 31 and base = 2 and domain = TImaginaryDomain() and realKind = 25 and extended = false
or
// _Imaginary long double
kind = 32 and base = 2 and domain = TImaginaryDomain() and realKind = 26 and extended = false
or
// __float128
kind = 38 and base = 2 and domain = TRealDomain() and realKind = 38 and extended = false
or
// _Complex __float128
kind = 39 and base = 2 and domain = TComplexDomain() and realKind = 38 and extended = false
or
// _Decimal32
kind = 40 and base = 10 and domain = TRealDomain() and realKind = 40 and extended = false
or
// _Decimal64
kind = 41 and base = 10 and domain = TRealDomain() and realKind = 41 and extended = false
or
// _Decimal128
kind = 42 and base = 10 and domain = TRealDomain() and realKind = 42 and extended = false
or
// _Float32
kind = 45 and base = 2 and domain = TRealDomain() and realKind = 45 and extended = false
or
// _Float32x
kind = 46 and base = 2 and domain = TRealDomain() and realKind = 46 and extended = true
or
// _Float64
kind = 47 and base = 2 and domain = TRealDomain() and realKind = 47 and extended = false
or
// _Float64x
kind = 48 and base = 2 and domain = TRealDomain() and realKind = 48 and extended = true
or
// _Float128
kind = 49 and base = 2 and domain = TRealDomain() and realKind = 49 and extended = false
or
// _Float128x
kind = 50 and base = 2 and domain = TRealDomain() and realKind = 50 and extended = true
}
/**
* The C/C++ floating point types. See 4.5. This includes `float`, `double` and `long double`, the
* fixed-size floating-point types like `_Float32`, the extended-precision floating-point types like
* `_Float64x`, and the decimal floating-point types like `_Decimal32`. It also includes the complex
* and imaginary versions of all of these types.
*/
class FloatingPointType extends ArithmeticType {
final int base;
final TypeDomain domain;
final int realKind;
final boolean extended;
FloatingPointType() {
exists(int kind |
builtintypes(underlyingElement(this), _, kind, _, _, _) and
(
kind >= 24 and kind <= 32
or
kind >= 38 and kind <= 42
or
kind >= 45 and kind <= 50
)
floatingPointTypeMapping(kind, base, domain, realKind, extended)
)
}
/** Gets the numeric base of this type's representation: 2 (binary) or 10 (decimal). */
final int getBase() { result = base }
/**
* Gets the type domain of this type. Can be `RealDomain`, `ComplexDomain`, or `ImaginaryDomain`.
*/
final TypeDomain getDomain() { result = domain }
/**
* Gets the corresponding real type of this type. For example, the corresponding real type of
* `_Complex double` is `double`.
*/
final RealNumberType getRealType() {
builtintypes(unresolveElement(result), _, realKind, _, _, _)
}
/** Holds if this type is an extended precision floating-point type, such as `_Float32x`. */
final predicate isExtendedPrecision() { extended = true }
}
/**
* A floating-point type representing a real number.
*/
class RealNumberType extends FloatingPointType {
RealNumberType() { domain instanceof RealDomain }
}
/**
* A floating-point type representing a complex number.
*/
class ComplexNumberType extends FloatingPointType {
ComplexNumberType() { domain instanceof ComplexDomain }
}
/**
* A floating-point type representing an imaginary number.
*/
class ImaginaryNumberType extends FloatingPointType {
ImaginaryNumberType() { domain instanceof ImaginaryDomain }
}
/**
* A floating-point type whose representation is base 2.
*/
class BinaryFloatingPointType extends FloatingPointType {
BinaryFloatingPointType() { base = 2 }
}
/**
* A floating-point type whose representation is base 10.
*/
class DecimalFloatingPointType extends FloatingPointType {
DecimalFloatingPointType() { base = 10 }
}
/**
@@ -727,7 +886,7 @@ class FloatingPointType extends ArithmeticType {
* float f;
* ```
*/
class FloatType extends FloatingPointType {
class FloatType extends RealNumberType, BinaryFloatingPointType {
FloatType() { builtintypes(underlyingElement(this), _, 24, _, _, _) }
override string getCanonicalQLClass() { result = "FloatType" }
@@ -739,7 +898,7 @@ class FloatType extends FloatingPointType {
* double d;
* ```
*/
class DoubleType extends FloatingPointType {
class DoubleType extends RealNumberType, BinaryFloatingPointType {
DoubleType() { builtintypes(underlyingElement(this), _, 25, _, _, _) }
override string getCanonicalQLClass() { result = "DoubleType" }
@@ -751,7 +910,7 @@ class DoubleType extends FloatingPointType {
* long double ld;
* ```
*/
class LongDoubleType extends FloatingPointType {
class LongDoubleType extends RealNumberType, BinaryFloatingPointType {
LongDoubleType() { builtintypes(underlyingElement(this), _, 26, _, _, _) }
override string getCanonicalQLClass() { result = "LongDoubleType" }
@@ -763,7 +922,7 @@ class LongDoubleType extends FloatingPointType {
* __float128 f128;
* ```
*/
class Float128Type extends FloatingPointType {
class Float128Type extends RealNumberType, BinaryFloatingPointType {
Float128Type() { builtintypes(underlyingElement(this), _, 38, _, _, _) }
override string getCanonicalQLClass() { result = "Float128Type" }
@@ -775,7 +934,7 @@ class Float128Type extends FloatingPointType {
* _Decimal32 d32;
* ```
*/
class Decimal32Type extends FloatingPointType {
class Decimal32Type extends RealNumberType, DecimalFloatingPointType {
Decimal32Type() { builtintypes(underlyingElement(this), _, 40, _, _, _) }
override string getCanonicalQLClass() { result = "Decimal32Type" }
@@ -787,7 +946,7 @@ class Decimal32Type extends FloatingPointType {
* _Decimal64 d64;
* ```
*/
class Decimal64Type extends FloatingPointType {
class Decimal64Type extends RealNumberType, DecimalFloatingPointType {
Decimal64Type() { builtintypes(underlyingElement(this), _, 41, _, _, _) }
override string getCanonicalQLClass() { result = "Decimal64Type" }
@@ -799,7 +958,7 @@ class Decimal64Type extends FloatingPointType {
* _Decimal128 d128;
* ```
*/
class Decimal128Type extends FloatingPointType {
class Decimal128Type extends RealNumberType, DecimalFloatingPointType {
Decimal128Type() { builtintypes(underlyingElement(this), _, 42, _, _, _) }
override string getCanonicalQLClass() { result = "Decimal128Type" }

View File

@@ -12,7 +12,9 @@ private newtype TIRType =
TIRBooleanType(int byteSize) { Language::hasBooleanType(byteSize) } or
TIRSignedIntegerType(int byteSize) { Language::hasSignedIntegerType(byteSize) } or
TIRUnsignedIntegerType(int byteSize) { Language::hasUnsignedIntegerType(byteSize) } or
TIRFloatingPointType(int byteSize) { Language::hasFloatingPointType(byteSize) } or
TIRFloatingPointType(int byteSize, int base, Language::TypeDomain domain) {
Language::hasFloatingPointType(byteSize, base, domain)
} or
TIRAddressType(int byteSize) { Language::hasAddressType(byteSize) } or
TIRFunctionAddressType(int byteSize) { Language::hasFunctionAddressType(byteSize) } or
TIROpaqueType(Language::OpaqueTypeTag tag, int byteSize) {
@@ -104,7 +106,7 @@ private class IRSizedType extends IRType {
this = TIRBooleanType(byteSize) or
this = TIRSignedIntegerType(byteSize) or
this = TIRUnsignedIntegerType(byteSize) or
this = TIRFloatingPointType(byteSize) or
this = TIRFloatingPointType(byteSize, _, _) or
this = TIRAddressType(byteSize) or
this = TIRFunctionAddressType(byteSize) or
this = TIROpaqueType(_, byteSize)
@@ -133,7 +135,7 @@ class IRNumericType extends IRSizedType {
IRNumericType() {
this = TIRSignedIntegerType(byteSize) or
this = TIRUnsignedIntegerType(byteSize) or
this = TIRFloatingPointType(byteSize)
this = TIRFloatingPointType(byteSize, _, _)
}
}
@@ -171,14 +173,45 @@ class IRUnsignedIntegerType extends IRNumericType, TIRUnsignedIntegerType {
* A floating-point type.
*/
class IRFloatingPointType extends IRNumericType, TIRFloatingPointType {
final override string toString() { result = "float" + byteSize.toString() }
private final int base;
private final Language::TypeDomain domain;
IRFloatingPointType() {
this = TIRFloatingPointType(_, base, domain)
}
final override string toString() {
result = getDomainPrefix() + getBaseString() + byteSize.toString()
}
final override Language::LanguageType getCanonicalLanguageType() {
result = Language::getCanonicalFloatingPointType(byteSize)
result = Language::getCanonicalFloatingPointType(byteSize, base, domain)
}
pragma[noinline]
final override int getByteSize() { result = byteSize }
/** Gets the numeric base of the type. Can be either 2 (binary) or 10 (decimal). */
final int getBase() { result = base }
/**
* Gets the type domain of the type. Can be `RealDomain`, `ComplexDomain`, or `ImaginaryDomain`.
*/
final Language::TypeDomain getDomain() { result = domain }
private string getBaseString() {
base = 2 and result = "float"
or
base = 10 and result = "decimal"
}
private string getDomainPrefix() {
domain instanceof Language::RealDomain and result = ""
or
domain instanceof Language::ComplexDomain and result = "c"
or
domain instanceof Language::ImaginaryDomain and result = "i"
}
}
/**

View File

@@ -86,9 +86,15 @@ predicate hasUnsignedIntegerType(int byteSize) {
}
/**
* Holds if an `IRFloatingPointType` with the specified `byteSize` should exist.
* Holds if an `IRFloatingPointType` with the specified size, base, and type domain should exist.
*/
predicate hasFloatingPointType(int byteSize) { byteSize = any(FloatingPointType type).getSize() }
predicate hasFloatingPointType(int byteSize, int base, TypeDomain domain) {
exists(FloatingPointType type |
byteSize = type.getSize() and
base = type.getBase() and
domain = type.getDomain()
)
}
private predicate isPointerIshType(Type type) {
type instanceof PointerType
@@ -159,8 +165,13 @@ private IRType getIRTypeForPRValue(Type type) {
isUnsignedIntegerType(unspecifiedType) and
result.(IRUnsignedIntegerType).getByteSize() = type.getSize()
or
unspecifiedType instanceof FloatingPointType and
result.(IRFloatingPointType).getByteSize() = type.getSize()
exists(FloatingPointType floatType, IRFloatingPointType irFloatType |
floatType = unspecifiedType and
irFloatType = result and
irFloatType.getByteSize() = floatType.getSize() and
irFloatType.getBase() = floatType.getBase() and
irFloatType.getDomain() = floatType.getDomain()
)
or
isPointerIshType(unspecifiedType) and result.(IRAddressType).getByteSize() = getTypeSize(type)
or
@@ -438,15 +449,30 @@ CppPRValueType getCanonicalUnsignedIntegerType(int byteSize) {
}
/**
* Gets the `CppType` that is the canonical type for an `IRFloatingPointType` with the specified
* `byteSize`.
* Gets the sort priority of a `RealNumberType` base on its precision.
*/
CppPRValueType getCanonicalFloatingPointType(int byteSize) {
private int getPrecisionPriority(RealNumberType type) {
// Prefer `double`, `float`, `long double` in that order.
if type instanceof DoubleType then result = 4
else if type instanceof FloatType then result = 3
else if type instanceof LongDoubleType then result = 2
// If we get this far, prefer non-extended-precision types.
else if not type.isExtendedPrecision() then result = 1
else result = 0
}
/**
* Gets the `CppType` that is the canonical type for an `IRFloatingPointType` with the specified
* size, base, and type domain.
*/
CppPRValueType getCanonicalFloatingPointType(int byteSize, int base, TypeDomain domain) {
result =
TPRValueType(max(FloatingPointType type |
type.getSize() = byteSize
type.getSize() = byteSize and
type.getBase() = base and
type.getDomain() = domain
|
type order by type.toString() desc
type order by getPrecisionPriority(type.getRealType()), type.toString() desc
))
}

View File

@@ -9,6 +9,11 @@ class LanguageType = CppType;
class OpaqueTypeTag = Cpp::Type;
class TypeDomain = Cpp::TypeDomain;
class RealDomain = Cpp::RealDomain;
class ComplexDomain = Cpp::ComplexDomain;
class ImaginaryDomain = Cpp::ImaginaryDomain;
class Function = Cpp::Function;
class Location = Cpp::Location;

View File

@@ -234,6 +234,62 @@ clang.cpp:
# 6| 0: [VariableAccess] globalInt
# 6| Type = [IntType] int
# 6| ValueCategory = lvalue
complex.c:
# 1| [TopLevelFunction] void complex_math()
# 1| params:
# 1| body: [Block] { ... }
# 2| 0: [DeclStmt] declaration
# 2| 0: [VariableDeclarationEntry] definition of cf
# 2| Type = [ArithmeticType] _Complex float
# 2| init: [Initializer] initializer for cf
# 2| expr: [CStyleCast] (_Complex float)...
# 2| Conversion = [FloatingPointConversion] floating point conversion
# 2| Type = [ArithmeticType] _Complex float
# 2| ValueCategory = prvalue
# 2| expr: [AddExpr] ... + ...
# 2| Type = [ArithmeticType] _Complex double
# 2| ValueCategory = prvalue
# 2| 0: [CStyleCast] (_Complex double)...
# 2| Conversion = [FloatingPointConversion] floating point conversion
# 2| Type = [ArithmeticType] _Complex double
# 2| ValueCategory = prvalue
# 2| expr: [Literal] 2.0
# 2| Type = [DoubleType] double
# 2| Value = [Literal] 2.0
# 2| ValueCategory = prvalue
# 2| 1: [CStyleCast] (_Complex double)...
# 2| Conversion = [FloatingPointConversion] floating point conversion
# 2| Type = [ArithmeticType] _Complex double
# 2| ValueCategory = prvalue
# 2| expr: [Literal] (0.0,1.0i)
# 2| Type = [ArithmeticType] _Complex float
# 2| Value = [Literal] (0.0,1.0i)
# 2| ValueCategory = prvalue
# 3| 1: [DeclStmt] declaration
# 3| 0: [VariableDeclarationEntry] definition of cf2
# 3| Type = [ArithmeticType] _Complex float
# 3| init: [Initializer] initializer for cf2
# 3| expr: [MulExpr] ... * ...
# 3| Type = [ArithmeticType] _Complex float
# 3| ValueCategory = prvalue
# 3| 0: [VariableAccess] cf
# 3| Type = [ArithmeticType] _Complex float
# 3| ValueCategory = prvalue(load)
# 3| 1: [VariableAccess] cf
# 3| Type = [ArithmeticType] _Complex float
# 3| ValueCategory = prvalue(load)
# 4| 2: [DeclStmt] declaration
# 4| 0: [VariableDeclarationEntry] definition of d
# 4| Type = [DoubleType] double
# 4| init: [Initializer] initializer for d
# 4| expr: [CStyleCast] (double)...
# 4| Conversion = [FloatingPointConversion] floating point conversion
# 4| Type = [DoubleType] double
# 4| ValueCategory = prvalue
# 4| expr: [VariableAccess] cf2
# 4| Type = [ArithmeticType] _Complex float
# 4| ValueCategory = prvalue(load)
# 5| 3: [ReturnStmt] return ...
ir.cpp:
# 1| [TopLevelFunction] void Constants()
# 1| params:

View File

@@ -0,0 +1,5 @@
void complex_math(void) {
_Complex float cf = 2.0 + 1.0if;
_Complex float cf2 = cf * cf;
double d = cf2;
}

View File

@@ -124,6 +124,39 @@ clang.cpp:
# 5| v5_8(void) = AliasedUse : ~mu5_4
# 5| v5_9(void) = ExitFunction :
complex.c:
# 1| void complex_math()
# 1| Block 0
# 1| v1_1(void) = EnterFunction :
# 1| mu1_2(unknown) = AliasedDefinition :
# 1| mu1_3(unknown) = InitializeNonLocal :
# 1| mu1_4(unknown) = UnmodeledDefinition :
# 2| r2_1(glval<_Complex float>) = VariableAddress[cf] :
# 2| r2_2(double) = Constant[2.0] :
# 2| r2_3(_Complex double) = Convert : r2_2
# 2| r2_4(_Complex float) = Constant[(0.0,1.0i)] :
# 2| r2_5(_Complex double) = Convert : r2_4
# 2| r2_6(_Complex double) = Add : r2_3, r2_5
# 2| r2_7(_Complex float) = Convert : r2_6
# 2| mu2_8(_Complex float) = Store : &:r2_1, r2_7
# 3| r3_1(glval<_Complex float>) = VariableAddress[cf2] :
# 3| r3_2(glval<_Complex float>) = VariableAddress[cf] :
# 3| r3_3(_Complex float) = Load : &:r3_2, ~mu1_4
# 3| r3_4(glval<_Complex float>) = VariableAddress[cf] :
# 3| r3_5(_Complex float) = Load : &:r3_4, ~mu1_4
# 3| r3_6(_Complex float) = Mul : r3_3, r3_5
# 3| mu3_7(_Complex float) = Store : &:r3_1, r3_6
# 4| r4_1(glval<double>) = VariableAddress[d] :
# 4| r4_2(glval<_Complex float>) = VariableAddress[cf2] :
# 4| r4_3(_Complex float) = Load : &:r4_2, ~mu1_4
# 4| r4_4(double) = Convert : r4_3
# 4| mu4_5(double) = Store : &:r4_1, r4_4
# 5| v5_1(void) = NoOp :
# 1| v1_5(void) = ReturnVoid :
# 1| v1_6(void) = UnmodeledUse : mu*
# 1| v1_7(void) = AliasedUse : ~mu1_4
# 1| v1_8(void) = ExitFunction :
ir.cpp:
# 1| void Constants()
# 1| Block 0

View File

@@ -0,0 +1,6 @@
void Complex(void) {
_Complex float cf; //$irtype=cfloat8
_Complex double cd; //$irtype=cfloat16
_Complex long double cld; //$irtype=cfloat32
// _Complex __float128 cf128;
}

View File

@@ -0,0 +1,65 @@
struct A {
int f_a;
};
struct B {
double f_a;
float f_b;
};
enum E {
Zero,
One,
Two,
Three
};
enum class ScopedE {
Zero,
One,
Two,
Three
};
void IRTypes() {
char c; //$irtype=int1
signed char sc; //$irtype=int1
unsigned char uc; //$irtype=uint1
short s; //$irtype=int2
signed short ss; //$irtype=int2
unsigned short us; //$irtype=uint2
int i; //$irtype=int4
signed int si; //$irtype=int4
unsigned int ui; //$irtype=uint4
long l; //$irtype=int8
signed long sl; //$irtype=int8
unsigned long ul; //$irtype=uint8
long long ll; //$irtype=int8
signed long long sll; //$irtype=int8
unsigned long long ull; //$irtype=uint8
bool b; //$irtype=bool1
float f; //$irtype=float4
double d; //$irtype=float8
long double ld; //$irtype=float16
__float128 f128; //$irtype=float16
wchar_t wc; //$irtype=uint4
// char8_t c8; //$irtype=uint1
char16_t c16; //$irtype=uint2
char32_t c32; //$irtype=uint4
int* pi; //$irtype=addr8
int& ri = i; //$irtype=addr8
void (*pfn)() = nullptr; //$irtype=func8
void (&rfn)() = IRTypes; //$irtype=func8
A s_a; //$irtype=opaque4{A}
B s_b; //$irtype=opaque16{B}
E e; //$irtype=uint4
ScopedE se; //$irtype=uint4
B a_b[10]; //$irtype=opaque160{B[10]}
}
// semmle-extractor-options: -std=c++17 --clang

View File

@@ -0,0 +1,18 @@
private import cpp
private import semmle.code.cpp.ir.implementation.raw.IR
import TestUtilities.InlineExpectationsTest
class IRTypesTest extends InlineExpectationsTest {
IRTypesTest() { this = "IRTypesTest" }
override string getARelevantTag() { result = "irtype" }
override predicate hasActualResult(Location location, string element, string tag, string value) {
exists(IRUserVariable irVar |
location = irVar.getLocation() and
element = irVar.toString() and
tag = "irtype" and
value = irVar.getIRType().toString()
)
}
}