C#: Extract extension types and members. Replacing invocations to static generated methods with invocation of extension type member.

This commit is contained in:
Michael Nebel
2026-02-04 15:57:16 +01:00
parent ab505e3281
commit edfdc9812f
8 changed files with 256 additions and 20 deletions

View File

@@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Semmle.Util;
using Semmle.Extraction.CSharp.Entities;
namespace Semmle.Extraction.CSharp
@@ -164,6 +165,7 @@ namespace Semmle.Extraction.CSharp
case TypeKind.Enum:
case TypeKind.Delegate:
case TypeKind.Error:
case TypeKind.Extension:
var named = (INamedTypeSymbol)type;
named.BuildNamedTypeId(cx, trapFile, symbolBeingDefined, constructUnderlyingTupleType);
return;
@@ -275,6 +277,20 @@ namespace Semmle.Extraction.CSharp
public static IEnumerable<IFieldSymbol?> GetTupleElementsMaybeNull(this INamedTypeSymbol type) =>
type.TupleElements;
private static void BuildExtensionTypeId(this INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile)
{
trapFile.Write("extension(");
if (named.ExtensionMarkerName is not null)
{
trapFile.Write(named.ExtensionMarkerName);
}
else
{
trapFile.Write("unknown");
}
trapFile.Write(")");
}
private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
{
if (named.ContainingType is not null)
@@ -289,8 +305,18 @@ namespace Semmle.Extraction.CSharp
named.ContainingNamespace.BuildNamespace(cx, trapFile);
}
var name = named.IsFileLocal ? named.MetadataName : named.Name;
trapFile.Write(name);
if (named.IsFileLocal)
{
trapFile.Write(named.MetadataName);
}
else if (named.IsExtension)
{
named.BuildExtensionTypeId(cx, trapFile);
}
else
{
trapFile.Write(named.Name);
}
}
private static void BuildTupleId(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined)
@@ -391,6 +417,7 @@ namespace Semmle.Extraction.CSharp
case TypeKind.Enum:
case TypeKind.Delegate:
case TypeKind.Error:
case TypeKind.Extension:
var named = (INamedTypeSymbol)type;
named.BuildNamedTypeDisplayName(cx, trapFile, constructUnderlyingTupleType);
return;
@@ -465,6 +492,20 @@ namespace Semmle.Extraction.CSharp
private static void BuildFunctionPointerTypeDisplayName(this IFunctionPointerTypeSymbol funptr, Context cx, TextWriter trapFile) =>
BuildFunctionPointerSignature(funptr, trapFile, s => s.BuildDisplayName(cx, trapFile));
private static void BuildExtensionTypeDisplayName(this INamedTypeSymbol named, Context cx, TextWriter trapFile)
{
trapFile.Write("extension(");
if (named.ExtensionParameter?.Type is ITypeSymbol type)
{
type.BuildDisplayName(cx, trapFile);
}
else
{
trapFile.Write("unknown");
}
trapFile.Write(")");
}
private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, Context cx, TextWriter trapFile, bool constructUnderlyingTupleType)
{
if (!constructUnderlyingTupleType && namedType.IsTupleType)
@@ -484,6 +525,12 @@ namespace Semmle.Extraction.CSharp
return;
}
if (namedType.IsExtension)
{
namedType.BuildExtensionTypeDisplayName(cx, trapFile);
return;
}
if (namedType.IsAnonymousType)
{
namedType.BuildAnonymousName(cx, trapFile);
@@ -596,6 +643,87 @@ namespace Semmle.Extraction.CSharp
return true;
}
/// <summary>
/// Return true if this method is a compiler-generated extension method.
/// </summary>
public static bool IsCompilerGeneratedExtensionMethod(this IMethodSymbol method) =>
method.TryGetExtensionMethod(out _);
/// <summary>
/// Returns true if this method is a compiler-generated extension method,
/// and outputs the original extension method declaration.
/// </summary>
public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodSymbol? declaration)
{
declaration = null;
if (method.IsImplicitlyDeclared && method.ContainingSymbol is INamedTypeSymbol containingType)
{
// Extension types are declared within the same type as the generated
// extension method implementation.
var extensions = containingType.GetMembers()
.OfType<INamedTypeSymbol>()
.Where(t => t.IsExtension);
// Find the (possibly unbound) original extension method that maps to this implementation (if any).
var unboundDeclaration = extensions.SelectMany(e => e.GetMembers())
.OfType<IMethodSymbol>()
.FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom));
var isFullyConstructed = method.IsBoundGenericMethod();
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType)
{
try
{
// Use the type arguments from the constructed extension method to construct the extension type.
var arguments = method.TypeArguments.ToArray();
var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length);
// Construct the extension type.
var boundExtensionType = extensionType.IsUnboundGenericType()
? extensionType.Construct(extensionTypeArguments.ToArray())
: extensionType;
// Find the extension method declaration within the constructed extension type.
var extensionDeclaration = boundExtensionType.GetMembers()
.OfType<IMethodSymbol>()
.First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration));
// If the extension declaration is unbound apply the remaning type arguments and construct it.
declaration = extensionDeclaration.IsUnboundGenericMethod()
? extensionDeclaration.Construct(extensionMethodArguments.ToArray())
: extensionDeclaration;
}
catch
{
// If anything goes wrong, fall back to the unbound declaration.
declaration = unboundDeclaration;
}
}
else
{
declaration = unboundDeclaration;
}
}
return declaration is not null;
}
/// <summary>
/// Returns true if this method is an unbound generic method.
/// </summary>
public static bool IsUnboundGenericMethod(this IMethodSymbol method) =>
method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method);
/// <summary>
/// Returns true if this method is a bound generic method.
/// </summary>
public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !method.IsUnboundGenericMethod();
/// <summary>
/// Returns true if this type is an unbound generic type.
/// </summary>
public static bool IsUnboundGenericType(this INamedTypeSymbol type) =>
type.IsGenericType && SymbolEqualityComparer.Default.Equals(type.ConstructedFrom, type);
/// <summary>
/// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base
/// types of type parameters as well as `object` base types.

View File

@@ -24,6 +24,16 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions
private bool IsExplicitDelegateInvokeCall() => Kind == ExprKind.DELEGATE_INVOCATION && Context.GetModel(Syntax.Expression).GetSymbolInfo(Syntax.Expression).Symbol is IMethodSymbol m && m.MethodKind == MethodKind.DelegateInvoke;
private bool IsOperatorCall() => Kind == ExprKind.OPERATOR_INVOCATION;
private bool IsValidMemberAccessKind()
{
return Kind == ExprKind.METHOD_INVOCATION ||
IsEventDelegateCall() ||
IsExplicitDelegateInvokeCall() ||
IsOperatorCall();
}
protected override void PopulateExpression(TextWriter trapFile)
{
if (IsNameof(Syntax))
@@ -37,7 +47,7 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions
var target = TargetSymbol;
switch (Syntax.Expression)
{
case MemberAccessExpressionSyntax memberAccess when Kind == ExprKind.METHOD_INVOCATION || IsEventDelegateCall() || IsExplicitDelegateInvokeCall():
case MemberAccessExpressionSyntax memberAccess when IsValidMemberAccessKind():
memberName = memberAccess.Name.Identifier.Text;
if (Syntax.Expression.Kind() == SyntaxKind.SimpleMemberAccessExpression)
// Qualified method call; `x.M()`
@@ -113,14 +123,31 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions
public SymbolInfo SymbolInfo => info.SymbolInfo;
private static bool IsOperatorLikeCall(ExpressionNodeInfo info)
{
return info.SymbolInfo.Symbol is IMethodSymbol method &&
method.TryGetExtensionMethod(out var original) &&
original!.MethodKind == MethodKind.UserDefinedOperator;
}
public IMethodSymbol? TargetSymbol
{
get
{
var si = SymbolInfo;
if (si.Symbol is not null)
return si.Symbol as IMethodSymbol;
if (si.Symbol is ISymbol symbol)
{
var method = symbol as IMethodSymbol;
// Case for compiler-generated extension methods.
if (method is not null &&
method.TryGetExtensionMethod(out var original))
{
return original;
}
return method;
}
if (si.CandidateReason == CandidateReason.OverloadResolutionFailure)
{
@@ -196,15 +223,25 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions
private static ExprKind GetKind(ExpressionNodeInfo info)
{
return IsNameof((InvocationExpressionSyntax)info.Node)
? ExprKind.NAMEOF
: IsDelegateLikeCall(info)
? IsDelegateInvokeCall(info)
? ExprKind.DELEGATE_INVOCATION
: ExprKind.FUNCTION_POINTER_INVOCATION
: IsLocalFunctionInvocation(info)
? ExprKind.LOCAL_FUNCTION_INVOCATION
: ExprKind.METHOD_INVOCATION;
if (IsNameof((InvocationExpressionSyntax)info.Node))
{
return ExprKind.NAMEOF;
}
if (IsDelegateLikeCall(info))
{
return IsDelegateInvokeCall(info)
? ExprKind.DELEGATE_INVOCATION
: ExprKind.FUNCTION_POINTER_INVOCATION;
}
if (IsLocalFunctionInvocation(info))
{
return ExprKind.LOCAL_FUNCTION_INVOCATION;
}
if (IsOperatorLikeCall(info))
{
return ExprKind.OPERATOR_INVOCATION;
}
return ExprKind.METHOD_INVOCATION;
}
private static bool IsNameof(InvocationExpressionSyntax syntax)

View File

@@ -14,9 +14,28 @@ namespace Semmle.Extraction.CSharp.Entities
protected Method(Context cx, IMethodSymbol init)
: base(cx, init) { }
private SyntheticExtensionParameter? SyntheticParameter { get; set; }
private int SynthesizeExtensionParameter()
{
// Synthesize implicit parameter for extension methods declared using extension(...) syntax.
if (Symbol.ContainingSymbol is INamedTypeSymbol type &&
type.IsExtension && type.ExtensionParameter is IParameterSymbol parameter &&
!string.IsNullOrEmpty(parameter.Name) && !Symbol.IsStatic)
{
var originalSyntheticParam = OriginalDefinition.SyntheticParameter;
SyntheticParameter = SyntheticExtensionParameter.Create(Context, this, parameter, originalSyntheticParam);
return 1;
}
return 0;
}
protected void PopulateParameters()
{
var originalMethod = OriginalDefinition;
var positionOffset = SynthesizeExtensionParameter();
IEnumerable<IParameterSymbol> parameters = Symbol.Parameters;
IEnumerable<IParameterSymbol> originalParameters = originalMethod.Symbol.Parameters;
@@ -24,8 +43,8 @@ namespace Semmle.Extraction.CSharp.Entities
{
var original = SymbolEqualityComparer.Default.Equals(p.paramSymbol, p.originalParam)
? null
: Parameter.Create(Context, p.originalParam, originalMethod);
Parameter.Create(Context, p.paramSymbol, this, original);
: Parameter.Create(Context, p.originalParam, originalMethod, null, positionOffset);
Parameter.Create(Context, p.paramSymbol, this, original, positionOffset);
}
if (Symbol.IsVararg)
@@ -302,9 +321,9 @@ namespace Semmle.Extraction.CSharp.Entities
/// <summary>
/// Whether this method has unbound type parameters.
/// </summary>
public bool IsUnboundGeneric => IsGeneric && SymbolEqualityComparer.Default.Equals(Symbol.ConstructedFrom, Symbol);
public bool IsUnboundGeneric => Symbol.IsUnboundGenericMethod();
public bool IsBoundGeneric => IsGeneric && !IsUnboundGeneric;
public bool IsBoundGeneric => Symbol.IsBoundGenericMethod();
protected IMethodSymbol ConstructedFromSymbol => Symbol.ConstructedFrom;

View File

@@ -23,7 +23,11 @@ namespace Semmle.Extraction.CSharp.Entities
? Symbol.ContainingType.GetSymbolLocation()
: BodyDeclaringSymbol.GetSymbolLocation();
public override bool NeedsPopulation => base.NeedsPopulation || IsCompilerGeneratedDelegate();
public override bool NeedsPopulation =>
(base.NeedsPopulation || IsCompilerGeneratedDelegate()) &&
// Exclude compiler-generated extension methods. A call to such a method
// is replaced by a call to the defining extension method.
!Symbol.IsCompilerGeneratedExtensionMethod();
public override void Populate(TextWriter trapFile)
{

View File

@@ -20,6 +20,8 @@ namespace Semmle.Extraction.CSharp.Entities
public static NamedType Create(Context cx, INamedTypeSymbol type) =>
NamedTypeFactory.Instance.CreateEntityFromSymbol(cx, type);
public NamedType OriginalDefinition => Create(Context, Symbol.OriginalDefinition);
/// <summary>
/// Creates a named type entity from a tuple type. Unlike <see cref="Create"/>, this
/// will create an entity for the underlying `System.ValueTuple` struct.
@@ -90,6 +92,25 @@ namespace Semmle.Extraction.CSharp.Entities
{
trapFile.anonymous_types(this);
}
if (Symbol.IsExtension && Symbol.ExtensionParameter is IParameterSymbol parameter)
{
// For some reason an extension type has a receiver parameter with an empty name
// even when there is no parameter.
if (!string.IsNullOrEmpty(parameter.Name))
{
var originalType = OriginalDefinition;
// In case this is a constructed generic, we also need to create the unbound parameter.
var originalParameter = SymbolEqualityComparer.Default.Equals(Symbol, originalType.Symbol.ExtensionParameter) || originalType.Symbol.ExtensionParameter is null
? null
: Parameter.Create(Context, originalType.Symbol.ExtensionParameter, originalType);
Parameter.Create(Context, parameter, this, originalParameter);
}
// Use the parameter type as the receiver type.
var receiverType = Type.Create(Context, parameter.Type).TypeRef;
trapFile.extension_receiver_type(this, receiverType);
}
}
private readonly Lazy<Type[]> typeArgumentsLazy;

View File

@@ -105,6 +105,7 @@ namespace Semmle.Extraction.CSharp.Entities
case TypeKind.Pointer: return Kinds.TypeKind.POINTER;
case TypeKind.FunctionPointer: return Kinds.TypeKind.FUNCTION_POINTER;
case TypeKind.Error: return Kinds.TypeKind.UNKNOWN;
case TypeKind.Extension: return Kinds.TypeKind.EXTENSION;
default:
cx.ModelError(Symbol, $"Unhandled type kind '{Symbol.TypeKind}'");
return Kinds.TypeKind.UNKNOWN;
@@ -366,7 +367,7 @@ namespace Semmle.Extraction.CSharp.Entities
private DelegateTypeParameter(Context cx, IParameterSymbol init, IEntity parent, Parameter? original)
: base(cx, init, parent, original) { }
public static new DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) =>
public static DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) =>
// We need to use a different cache key than `param` to avoid mixing up
// `DelegateTypeParameter`s and `Parameter`s
DelegateTypeParameterFactory.Instance.CreateEntity(cx, (typeof(DelegateTypeParameter), new SymbolEqualityWrapper(param)), (param, parent, original));

View File

@@ -202,6 +202,9 @@ namespace Semmle.Extraction.CSharp
internal static void extend(this TextWriter trapFile, Type type, Type super) =>
trapFile.WriteTuple("extend", type, super);
internal static void extension_receiver_type(this TextWriter trapFile, Type @extension, Type receiverType) =>
trapFile.WriteTuple("extension_receiver_type", extension, receiverType);
internal static void anonymous_types(this TextWriter trapFile, Type type) =>
trapFile.WriteTuple("anonymous_types", type);

View File

@@ -119,5 +119,28 @@ namespace Semmle.Util
/// </summary>
public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> items) where T : class =>
items.Where(i => i is not null)!;
/// <summary>
/// Splits the sequence at the given index.
/// </summary>
public static (IEnumerable<T>, IEnumerable<T>) SplitAt<T>(this IEnumerable<T> items, int index)
{
var left = new List<T>();
var right = new List<T>();
var i = 0;
foreach (var item in items)
{
if (i < index)
{
left.Add(item);
}
else
{
right.Add(item);
}
i++;
}
return (left, right);
}
}
}