mirror of
https://github.com/github/codeql.git
synced 2026-04-30 11:15:13 +02:00
C#: Extract extension types and members. Replacing invocations to static generated methods with invocation of extension type member.
This commit is contained in:
@@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Semmle.Util;
|
||||
using Semmle.Extraction.CSharp.Entities;
|
||||
|
||||
namespace Semmle.Extraction.CSharp
|
||||
@@ -164,6 +165,7 @@ namespace Semmle.Extraction.CSharp
|
||||
case TypeKind.Enum:
|
||||
case TypeKind.Delegate:
|
||||
case TypeKind.Error:
|
||||
case TypeKind.Extension:
|
||||
var named = (INamedTypeSymbol)type;
|
||||
named.BuildNamedTypeId(cx, trapFile, symbolBeingDefined, constructUnderlyingTupleType);
|
||||
return;
|
||||
@@ -275,6 +277,20 @@ namespace Semmle.Extraction.CSharp
|
||||
public static IEnumerable<IFieldSymbol?> GetTupleElementsMaybeNull(this INamedTypeSymbol type) =>
|
||||
type.TupleElements;
|
||||
|
||||
private static void BuildExtensionTypeId(this INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile)
|
||||
{
|
||||
trapFile.Write("extension(");
|
||||
if (named.ExtensionMarkerName is not null)
|
||||
{
|
||||
trapFile.Write(named.ExtensionMarkerName);
|
||||
}
|
||||
else
|
||||
{
|
||||
trapFile.Write("unknown");
|
||||
}
|
||||
trapFile.Write(")");
|
||||
}
|
||||
|
||||
private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
|
||||
{
|
||||
if (named.ContainingType is not null)
|
||||
@@ -289,8 +305,18 @@ namespace Semmle.Extraction.CSharp
|
||||
named.ContainingNamespace.BuildNamespace(cx, trapFile);
|
||||
}
|
||||
|
||||
var name = named.IsFileLocal ? named.MetadataName : named.Name;
|
||||
trapFile.Write(name);
|
||||
if (named.IsFileLocal)
|
||||
{
|
||||
trapFile.Write(named.MetadataName);
|
||||
}
|
||||
else if (named.IsExtension)
|
||||
{
|
||||
named.BuildExtensionTypeId(cx, trapFile);
|
||||
}
|
||||
else
|
||||
{
|
||||
trapFile.Write(named.Name);
|
||||
}
|
||||
}
|
||||
|
||||
private static void BuildTupleId(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
|
||||
@@ -391,6 +417,7 @@ namespace Semmle.Extraction.CSharp
|
||||
case TypeKind.Enum:
|
||||
case TypeKind.Delegate:
|
||||
case TypeKind.Error:
|
||||
case TypeKind.Extension:
|
||||
var named = (INamedTypeSymbol)type;
|
||||
named.BuildNamedTypeDisplayName(cx, trapFile, constructUnderlyingTupleType);
|
||||
return;
|
||||
@@ -465,6 +492,20 @@ namespace Semmle.Extraction.CSharp
|
||||
private static void BuildFunctionPointerTypeDisplayName(this IFunctionPointerTypeSymbol funptr, Context cx, TextWriter trapFile) =>
|
||||
BuildFunctionPointerSignature(funptr, trapFile, s => s.BuildDisplayName(cx, trapFile));
|
||||
|
||||
private static void BuildExtensionTypeDisplayName(this INamedTypeSymbol named, Context cx, TextWriter trapFile)
|
||||
{
|
||||
trapFile.Write("extension(");
|
||||
if (named.ExtensionParameter?.Type is ITypeSymbol type)
|
||||
{
|
||||
type.BuildDisplayName(cx, trapFile);
|
||||
}
|
||||
else
|
||||
{
|
||||
trapFile.Write("unknown");
|
||||
}
|
||||
trapFile.Write(")");
|
||||
}
|
||||
|
||||
private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, Context cx, TextWriter trapFile, bool constructUnderlyingTupleType)
|
||||
{
|
||||
if (!constructUnderlyingTupleType && namedType.IsTupleType)
|
||||
@@ -484,6 +525,12 @@ namespace Semmle.Extraction.CSharp
|
||||
return;
|
||||
}
|
||||
|
||||
if (namedType.IsExtension)
|
||||
{
|
||||
namedType.BuildExtensionTypeDisplayName(cx, trapFile);
|
||||
return;
|
||||
}
|
||||
|
||||
if (namedType.IsAnonymousType)
|
||||
{
|
||||
namedType.BuildAnonymousName(cx, trapFile);
|
||||
@@ -596,6 +643,87 @@ namespace Semmle.Extraction.CSharp
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return true if this method is a compiler-generated extension method.
|
||||
/// </summary>
|
||||
public static bool IsCompilerGeneratedExtensionMethod(this IMethodSymbol method) =>
|
||||
method.TryGetExtensionMethod(out _);
|
||||
|
||||
/// <summary>
|
||||
/// Returns true if this method is a compiler-generated extension method,
|
||||
/// and outputs the original extension method declaration.
|
||||
/// </summary>
|
||||
public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodSymbol? declaration)
|
||||
{
|
||||
declaration = null;
|
||||
if (method.IsImplicitlyDeclared && method.ContainingSymbol is INamedTypeSymbol containingType)
|
||||
{
|
||||
// Extension types are declared within the same type as the generated
|
||||
// extension method implementation.
|
||||
var extensions = containingType.GetMembers()
|
||||
.OfType<INamedTypeSymbol>()
|
||||
.Where(t => t.IsExtension);
|
||||
// Find the (possibly unbound) original extension method that maps to this implementation (if any).
|
||||
var unboundDeclaration = extensions.SelectMany(e => e.GetMembers())
|
||||
.OfType<IMethodSymbol>()
|
||||
.FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom));
|
||||
|
||||
var isFullyConstructed = method.IsBoundGenericMethod();
|
||||
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Use the type arguments from the constructed extension method to construct the extension type.
|
||||
var arguments = method.TypeArguments.ToArray();
|
||||
var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length);
|
||||
|
||||
// Construct the extension type.
|
||||
var boundExtensionType = extensionType.IsUnboundGenericType()
|
||||
? extensionType.Construct(extensionTypeArguments.ToArray())
|
||||
: extensionType;
|
||||
|
||||
// Find the extension method declaration within the constructed extension type.
|
||||
var extensionDeclaration = boundExtensionType.GetMembers()
|
||||
.OfType<IMethodSymbol>()
|
||||
.First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration));
|
||||
|
||||
// If the extension declaration is unbound apply the remaning type arguments and construct it.
|
||||
declaration = extensionDeclaration.IsUnboundGenericMethod()
|
||||
? extensionDeclaration.Construct(extensionMethodArguments.ToArray())
|
||||
: extensionDeclaration;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// If anything goes wrong, fall back to the unbound declaration.
|
||||
declaration = unboundDeclaration;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
declaration = unboundDeclaration;
|
||||
}
|
||||
|
||||
}
|
||||
return declaration is not null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns true if this method is an unbound generic method.
|
||||
/// </summary>
|
||||
public static bool IsUnboundGenericMethod(this IMethodSymbol method) =>
|
||||
method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method);
|
||||
|
||||
/// <summary>
|
||||
/// Returns true if this method is a bound generic method.
|
||||
/// </summary>
|
||||
public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !method.IsUnboundGenericMethod();
|
||||
|
||||
/// <summary>
|
||||
/// Returns true if this type is an unbound generic type.
|
||||
/// </summary>
|
||||
public static bool IsUnboundGenericType(this INamedTypeSymbol type) =>
|
||||
type.IsGenericType && SymbolEqualityComparer.Default.Equals(type.ConstructedFrom, type);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base
|
||||
/// types of type parameters as well as `object` base types.
|
||||
|
||||
Reference in New Issue
Block a user