From edfdc9812fba4e722726e636f452e9c84ee65f8d Mon Sep 17 00:00:00 2001 From: Michael Nebel Date: Wed, 4 Feb 2026 15:57:16 +0100 Subject: [PATCH] C#: Extract extension types and members. Replacing invocations to static generated methods with invocation of extension type member. --- .../SymbolExtensions.cs | 132 +++++++++++++++++- .../Entities/Expressions/Invocation.cs | 61 ++++++-- .../Entities/Method.cs | 27 +++- .../Entities/OrdinaryMethod.cs | 6 +- .../Entities/Types/NamedType.cs | 21 +++ .../Entities/Types/Type.cs | 3 +- .../Semmle.Extraction.CSharp/Trap/Tuples.cs | 3 + .../Semmle.Util/IEnumerableExtensions.cs | 23 +++ 8 files changed, 256 insertions(+), 20 deletions(-) diff --git a/csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs b/csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs index 72f78f16059..0c0c17df125 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using Microsoft.CodeAnalysis; +using Semmle.Util; using Semmle.Extraction.CSharp.Entities; namespace Semmle.Extraction.CSharp @@ -164,6 +165,7 @@ namespace Semmle.Extraction.CSharp case TypeKind.Enum: case TypeKind.Delegate: case TypeKind.Error: + case TypeKind.Extension: var named = (INamedTypeSymbol)type; named.BuildNamedTypeId(cx, trapFile, symbolBeingDefined, constructUnderlyingTupleType); return; @@ -275,6 +277,20 @@ namespace Semmle.Extraction.CSharp public static IEnumerable GetTupleElementsMaybeNull(this INamedTypeSymbol type) => type.TupleElements; + private static void BuildExtensionTypeId(this INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile) + { + trapFile.Write("extension("); + if (named.ExtensionMarkerName is not null) + { + trapFile.Write(named.ExtensionMarkerName); + } + else + { + trapFile.Write("unknown"); + } + trapFile.Write(")"); + } + private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined) { if (named.ContainingType is not null) @@ -289,8 +305,18 @@ namespace Semmle.Extraction.CSharp named.ContainingNamespace.BuildNamespace(cx, trapFile); } - var name = named.IsFileLocal ? named.MetadataName : named.Name; - trapFile.Write(name); + if (named.IsFileLocal) + { + trapFile.Write(named.MetadataName); + } + else if (named.IsExtension) + { + named.BuildExtensionTypeId(cx, trapFile); + } + else + { + trapFile.Write(named.Name); + } } private static void BuildTupleId(INamedTypeSymbol named, Context cx, EscapingTextWriter trapFile, ISymbol symbolBeingDefined) @@ -391,6 +417,7 @@ namespace Semmle.Extraction.CSharp case TypeKind.Enum: case TypeKind.Delegate: case TypeKind.Error: + case TypeKind.Extension: var named = (INamedTypeSymbol)type; named.BuildNamedTypeDisplayName(cx, trapFile, constructUnderlyingTupleType); return; @@ -465,6 +492,20 @@ namespace Semmle.Extraction.CSharp private static void BuildFunctionPointerTypeDisplayName(this IFunctionPointerTypeSymbol funptr, Context cx, TextWriter trapFile) => BuildFunctionPointerSignature(funptr, trapFile, s => s.BuildDisplayName(cx, trapFile)); + private static void BuildExtensionTypeDisplayName(this INamedTypeSymbol named, Context cx, TextWriter trapFile) + { + trapFile.Write("extension("); + if (named.ExtensionParameter?.Type is ITypeSymbol type) + { + type.BuildDisplayName(cx, trapFile); + } + else + { + trapFile.Write("unknown"); + } + trapFile.Write(")"); + } + private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, Context cx, TextWriter trapFile, bool constructUnderlyingTupleType) { if (!constructUnderlyingTupleType && namedType.IsTupleType) @@ -484,6 +525,12 @@ namespace Semmle.Extraction.CSharp return; } + if (namedType.IsExtension) + { + namedType.BuildExtensionTypeDisplayName(cx, trapFile); + return; + } + if (namedType.IsAnonymousType) { namedType.BuildAnonymousName(cx, trapFile); @@ -596,6 +643,87 @@ namespace Semmle.Extraction.CSharp return true; } + /// + /// Return true if this method is a compiler-generated extension method. + /// + public static bool IsCompilerGeneratedExtensionMethod(this IMethodSymbol method) => + method.TryGetExtensionMethod(out _); + + /// + /// Returns true if this method is a compiler-generated extension method, + /// and outputs the original extension method declaration. + /// + public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodSymbol? declaration) + { + declaration = null; + if (method.IsImplicitlyDeclared && method.ContainingSymbol is INamedTypeSymbol containingType) + { + // Extension types are declared within the same type as the generated + // extension method implementation. + var extensions = containingType.GetMembers() + .OfType() + .Where(t => t.IsExtension); + // Find the (possibly unbound) original extension method that maps to this implementation (if any). + var unboundDeclaration = extensions.SelectMany(e => e.GetMembers()) + .OfType() + .FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom)); + + var isFullyConstructed = method.IsBoundGenericMethod(); + if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType) + { + try + { + // Use the type arguments from the constructed extension method to construct the extension type. + var arguments = method.TypeArguments.ToArray(); + var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length); + + // Construct the extension type. + var boundExtensionType = extensionType.IsUnboundGenericType() + ? extensionType.Construct(extensionTypeArguments.ToArray()) + : extensionType; + + // Find the extension method declaration within the constructed extension type. + var extensionDeclaration = boundExtensionType.GetMembers() + .OfType() + .First(c => SymbolEqualityComparer.Default.Equals(c.OriginalDefinition, unboundDeclaration)); + + // If the extension declaration is unbound apply the remaning type arguments and construct it. + declaration = extensionDeclaration.IsUnboundGenericMethod() + ? extensionDeclaration.Construct(extensionMethodArguments.ToArray()) + : extensionDeclaration; + } + catch + { + // If anything goes wrong, fall back to the unbound declaration. + declaration = unboundDeclaration; + } + } + else + { + declaration = unboundDeclaration; + } + + } + return declaration is not null; + } + + /// + /// Returns true if this method is an unbound generic method. + /// + public static bool IsUnboundGenericMethod(this IMethodSymbol method) => + method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method); + + /// + /// Returns true if this method is a bound generic method. + /// + public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !method.IsUnboundGenericMethod(); + + /// + /// Returns true if this type is an unbound generic type. + /// + public static bool IsUnboundGenericType(this INamedTypeSymbol type) => + type.IsGenericType && SymbolEqualityComparer.Default.Equals(type.ConstructedFrom, type); + /// /// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base /// types of type parameters as well as `object` base types. diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Invocation.cs b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Invocation.cs index a6272974c22..26d64339ef0 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Invocation.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Invocation.cs @@ -24,6 +24,16 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions private bool IsExplicitDelegateInvokeCall() => Kind == ExprKind.DELEGATE_INVOCATION && Context.GetModel(Syntax.Expression).GetSymbolInfo(Syntax.Expression).Symbol is IMethodSymbol m && m.MethodKind == MethodKind.DelegateInvoke; + private bool IsOperatorCall() => Kind == ExprKind.OPERATOR_INVOCATION; + + private bool IsValidMemberAccessKind() + { + return Kind == ExprKind.METHOD_INVOCATION || + IsEventDelegateCall() || + IsExplicitDelegateInvokeCall() || + IsOperatorCall(); + } + protected override void PopulateExpression(TextWriter trapFile) { if (IsNameof(Syntax)) @@ -37,7 +47,7 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions var target = TargetSymbol; switch (Syntax.Expression) { - case MemberAccessExpressionSyntax memberAccess when Kind == ExprKind.METHOD_INVOCATION || IsEventDelegateCall() || IsExplicitDelegateInvokeCall(): + case MemberAccessExpressionSyntax memberAccess when IsValidMemberAccessKind(): memberName = memberAccess.Name.Identifier.Text; if (Syntax.Expression.Kind() == SyntaxKind.SimpleMemberAccessExpression) // Qualified method call; `x.M()` @@ -113,14 +123,31 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions public SymbolInfo SymbolInfo => info.SymbolInfo; + private static bool IsOperatorLikeCall(ExpressionNodeInfo info) + { + return info.SymbolInfo.Symbol is IMethodSymbol method && + method.TryGetExtensionMethod(out var original) && + original!.MethodKind == MethodKind.UserDefinedOperator; + } + public IMethodSymbol? TargetSymbol { get { var si = SymbolInfo; - if (si.Symbol is not null) - return si.Symbol as IMethodSymbol; + if (si.Symbol is ISymbol symbol) + { + var method = symbol as IMethodSymbol; + // Case for compiler-generated extension methods. + if (method is not null && + method.TryGetExtensionMethod(out var original)) + { + return original; + } + + return method; + } if (si.CandidateReason == CandidateReason.OverloadResolutionFailure) { @@ -196,15 +223,25 @@ namespace Semmle.Extraction.CSharp.Entities.Expressions private static ExprKind GetKind(ExpressionNodeInfo info) { - return IsNameof((InvocationExpressionSyntax)info.Node) - ? ExprKind.NAMEOF - : IsDelegateLikeCall(info) - ? IsDelegateInvokeCall(info) - ? ExprKind.DELEGATE_INVOCATION - : ExprKind.FUNCTION_POINTER_INVOCATION - : IsLocalFunctionInvocation(info) - ? ExprKind.LOCAL_FUNCTION_INVOCATION - : ExprKind.METHOD_INVOCATION; + if (IsNameof((InvocationExpressionSyntax)info.Node)) + { + return ExprKind.NAMEOF; + } + if (IsDelegateLikeCall(info)) + { + return IsDelegateInvokeCall(info) + ? ExprKind.DELEGATE_INVOCATION + : ExprKind.FUNCTION_POINTER_INVOCATION; + } + if (IsLocalFunctionInvocation(info)) + { + return ExprKind.LOCAL_FUNCTION_INVOCATION; + } + if (IsOperatorLikeCall(info)) + { + return ExprKind.OPERATOR_INVOCATION; + } + return ExprKind.METHOD_INVOCATION; } private static bool IsNameof(InvocationExpressionSyntax syntax) diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Method.cs b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Method.cs index c92c561f31b..5a437c3d358 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Method.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Method.cs @@ -14,9 +14,28 @@ namespace Semmle.Extraction.CSharp.Entities protected Method(Context cx, IMethodSymbol init) : base(cx, init) { } + private SyntheticExtensionParameter? SyntheticParameter { get; set; } + + private int SynthesizeExtensionParameter() + { + // Synthesize implicit parameter for extension methods declared using extension(...) syntax. + if (Symbol.ContainingSymbol is INamedTypeSymbol type && + type.IsExtension && type.ExtensionParameter is IParameterSymbol parameter && + !string.IsNullOrEmpty(parameter.Name) && !Symbol.IsStatic) + { + var originalSyntheticParam = OriginalDefinition.SyntheticParameter; + SyntheticParameter = SyntheticExtensionParameter.Create(Context, this, parameter, originalSyntheticParam); + return 1; + } + + return 0; + } + protected void PopulateParameters() { var originalMethod = OriginalDefinition; + var positionOffset = SynthesizeExtensionParameter(); + IEnumerable parameters = Symbol.Parameters; IEnumerable originalParameters = originalMethod.Symbol.Parameters; @@ -24,8 +43,8 @@ namespace Semmle.Extraction.CSharp.Entities { var original = SymbolEqualityComparer.Default.Equals(p.paramSymbol, p.originalParam) ? null - : Parameter.Create(Context, p.originalParam, originalMethod); - Parameter.Create(Context, p.paramSymbol, this, original); + : Parameter.Create(Context, p.originalParam, originalMethod, null, positionOffset); + Parameter.Create(Context, p.paramSymbol, this, original, positionOffset); } if (Symbol.IsVararg) @@ -302,9 +321,9 @@ namespace Semmle.Extraction.CSharp.Entities /// /// Whether this method has unbound type parameters. /// - public bool IsUnboundGeneric => IsGeneric && SymbolEqualityComparer.Default.Equals(Symbol.ConstructedFrom, Symbol); + public bool IsUnboundGeneric => Symbol.IsUnboundGenericMethod(); - public bool IsBoundGeneric => IsGeneric && !IsUnboundGeneric; + public bool IsBoundGeneric => Symbol.IsBoundGenericMethod(); protected IMethodSymbol ConstructedFromSymbol => Symbol.ConstructedFrom; diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Entities/OrdinaryMethod.cs b/csharp/extractor/Semmle.Extraction.CSharp/Entities/OrdinaryMethod.cs index 22bcd1dce2c..2fb148358e8 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Entities/OrdinaryMethod.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Entities/OrdinaryMethod.cs @@ -23,7 +23,11 @@ namespace Semmle.Extraction.CSharp.Entities ? Symbol.ContainingType.GetSymbolLocation() : BodyDeclaringSymbol.GetSymbolLocation(); - public override bool NeedsPopulation => base.NeedsPopulation || IsCompilerGeneratedDelegate(); + public override bool NeedsPopulation => + (base.NeedsPopulation || IsCompilerGeneratedDelegate()) && + // Exclude compiler-generated extension methods. A call to such a method + // is replaced by a call to the defining extension method. + !Symbol.IsCompilerGeneratedExtensionMethod(); public override void Populate(TextWriter trapFile) { diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/NamedType.cs b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/NamedType.cs index dcf2bffe095..078ccf01c10 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/NamedType.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/NamedType.cs @@ -20,6 +20,8 @@ namespace Semmle.Extraction.CSharp.Entities public static NamedType Create(Context cx, INamedTypeSymbol type) => NamedTypeFactory.Instance.CreateEntityFromSymbol(cx, type); + public NamedType OriginalDefinition => Create(Context, Symbol.OriginalDefinition); + /// /// Creates a named type entity from a tuple type. Unlike , this /// will create an entity for the underlying `System.ValueTuple` struct. @@ -90,6 +92,25 @@ namespace Semmle.Extraction.CSharp.Entities { trapFile.anonymous_types(this); } + + if (Symbol.IsExtension && Symbol.ExtensionParameter is IParameterSymbol parameter) + { + // For some reason an extension type has a receiver parameter with an empty name + // even when there is no parameter. + if (!string.IsNullOrEmpty(parameter.Name)) + { + var originalType = OriginalDefinition; + // In case this is a constructed generic, we also need to create the unbound parameter. + var originalParameter = SymbolEqualityComparer.Default.Equals(Symbol, originalType.Symbol.ExtensionParameter) || originalType.Symbol.ExtensionParameter is null + ? null + : Parameter.Create(Context, originalType.Symbol.ExtensionParameter, originalType); + Parameter.Create(Context, parameter, this, originalParameter); + } + + // Use the parameter type as the receiver type. + var receiverType = Type.Create(Context, parameter.Type).TypeRef; + trapFile.extension_receiver_type(this, receiverType); + } } private readonly Lazy typeArgumentsLazy; diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/Type.cs b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/Type.cs index 3e79a8f8101..0f28a1153e2 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/Type.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Entities/Types/Type.cs @@ -105,6 +105,7 @@ namespace Semmle.Extraction.CSharp.Entities case TypeKind.Pointer: return Kinds.TypeKind.POINTER; case TypeKind.FunctionPointer: return Kinds.TypeKind.FUNCTION_POINTER; case TypeKind.Error: return Kinds.TypeKind.UNKNOWN; + case TypeKind.Extension: return Kinds.TypeKind.EXTENSION; default: cx.ModelError(Symbol, $"Unhandled type kind '{Symbol.TypeKind}'"); return Kinds.TypeKind.UNKNOWN; @@ -366,7 +367,7 @@ namespace Semmle.Extraction.CSharp.Entities private DelegateTypeParameter(Context cx, IParameterSymbol init, IEntity parent, Parameter? original) : base(cx, init, parent, original) { } - public static new DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) => + public static DelegateTypeParameter Create(Context cx, IParameterSymbol param, IEntity parent, Parameter? original = null) => // We need to use a different cache key than `param` to avoid mixing up // `DelegateTypeParameter`s and `Parameter`s DelegateTypeParameterFactory.Instance.CreateEntity(cx, (typeof(DelegateTypeParameter), new SymbolEqualityWrapper(param)), (param, parent, original)); diff --git a/csharp/extractor/Semmle.Extraction.CSharp/Trap/Tuples.cs b/csharp/extractor/Semmle.Extraction.CSharp/Trap/Tuples.cs index c1d082bbee2..1a25da058bd 100644 --- a/csharp/extractor/Semmle.Extraction.CSharp/Trap/Tuples.cs +++ b/csharp/extractor/Semmle.Extraction.CSharp/Trap/Tuples.cs @@ -202,6 +202,9 @@ namespace Semmle.Extraction.CSharp internal static void extend(this TextWriter trapFile, Type type, Type super) => trapFile.WriteTuple("extend", type, super); + internal static void extension_receiver_type(this TextWriter trapFile, Type @extension, Type receiverType) => + trapFile.WriteTuple("extension_receiver_type", extension, receiverType); + internal static void anonymous_types(this TextWriter trapFile, Type type) => trapFile.WriteTuple("anonymous_types", type); diff --git a/csharp/extractor/Semmle.Util/IEnumerableExtensions.cs b/csharp/extractor/Semmle.Util/IEnumerableExtensions.cs index 1ca676f0ce6..f4cc97a4d48 100644 --- a/csharp/extractor/Semmle.Util/IEnumerableExtensions.cs +++ b/csharp/extractor/Semmle.Util/IEnumerableExtensions.cs @@ -119,5 +119,28 @@ namespace Semmle.Util /// public static IEnumerable WhereNotNull(this IEnumerable items) where T : class => items.Where(i => i is not null)!; + + /// + /// Splits the sequence at the given index. + /// + public static (IEnumerable, IEnumerable) SplitAt(this IEnumerable items, int index) + { + var left = new List(); + var right = new List(); + var i = 0; + foreach (var item in items) + { + if (i < index) + { + left.Add(item); + } + else + { + right.Add(item); + } + i++; + } + return (left, right); + } } }