Address review comments

This commit is contained in:
Tom Hvitved
2023-09-21 12:50:17 +02:00
parent 2343e5ecd8
commit 04c4e739ac
4 changed files with 42 additions and 62 deletions

View File

@@ -2,7 +2,7 @@ using Semmle.Extraction.CSharp.DependencyFetching;
using Semmle.Extraction.CSharp.StubGenerator;
using Semmle.Util.Logging;
var logger = new ConsoleLogger(Verbosity.Info);
var logger = new ConsoleLogger(Verbosity.Info, logThreadId: false);
using var dependencyManager = new DependencyManager(".", DependencyOptions.Default, logger);
StubGenerator.GenerateStubs(logger, dependencyManager.ReferenceFiles, "codeql_csharp_stubs");

View File

@@ -27,11 +27,12 @@ public static class StubGenerator
var threads = EnvironmentVariables.GetDefaultNumberOfThreads();
using var references = new BlockingCollection<(MetadataReference Reference, string Path)>();
var referenceResolveTasks = GetResolvedReferenceTasks(referencesPaths, references);
Parallel.Invoke(
new ParallelOptions { MaxDegreeOfParallelism = threads },
referenceResolveTasks.ToArray());
Parallel.ForEach(referencesPaths, new ParallelOptions { MaxDegreeOfParallelism = threads }, path =>
{
var reference = MetadataReference.CreateFromFile(path);
references.Add((reference, path));
});
logger.Log(Severity.Info, $"Generating stubs for {references.Count} assemblies.");
@@ -41,43 +42,33 @@ public static class StubGenerator
references.Select(tuple => tuple.Item1),
new CSharpCompilationOptions(OutputKind.ConsoleApplication, allowUnsafe: true));
var referenceStubTasks = references.Select(@ref => (Action)(() => StubReference(compilation, outputPath, @ref.Reference, @ref.Path)));
Parallel.Invoke(
new ParallelOptions { MaxDegreeOfParallelism = threads },
referenceStubTasks.ToArray());
Parallel.ForEach(references, new ParallelOptions { MaxDegreeOfParallelism = threads }, @ref =>
{
StubReference(logger, compilation, outputPath, @ref.Reference, @ref.Path);
});
stopWatch.Stop();
logger.Log(Severity.Info, $"Stub generation took {stopWatch.Elapsed}.");
}
private static IEnumerable<Action> GetResolvedReferenceTasks(IEnumerable<string> referencePaths, BlockingCollection<(MetadataReference, string)> references)
private static void StubReference(ILogger logger, CSharpCompilation compilation, string outputPath, MetadataReference reference, string path)
{
return referencePaths.Select<string, Action>(path => () =>
if (compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol assembly)
return;
using var fileStream = new FileStream(FileUtils.NestPaths(logger, outputPath, path.Replace(".dll", ".cs")), FileMode.Create, FileAccess.Write);
using var writer = new StreamWriter(fileStream, new UTF8Encoding(false));
writer.WriteLine("// This file contains auto-generated code.");
writer.WriteLine($"// Generated from `{assembly.Identity}`.");
var visitor = new StubVisitor(assembly, writer);
visitor.StubAttributes(assembly.GetAttributes(), "assembly: ");
foreach (var module in assembly.Modules)
{
var reference = MetadataReference.CreateFromFile(path);
references.Add((reference, path));
});
}
private static void StubReference(CSharpCompilation compilation, string outputPath, MetadataReference reference, string path)
{
if (compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol assembly)
{
var logger = new ConsoleLogger(Verbosity.Info);
using var fileStream = new FileStream(FileUtils.NestPaths(logger, outputPath, path.Replace(".dll", ".cs")), FileMode.Create, FileAccess.Write);
using var writer = new StreamWriter(fileStream, new UTF8Encoding(false));
writer.WriteLine("// This file contains auto-generated code.");
writer.WriteLine($"// Generated from `{assembly.Identity}`.");
var visitor = new StubVisitor(assembly, writer);
visitor.StubAttributes(assembly.GetAttributes(), "assembly: ");
foreach (var module in assembly.Modules)
{
module.GlobalNamespace.Accept(new StubVisitor(assembly, writer));
}
module.GlobalNamespace.Accept(visitor);
}
}
}

View File

@@ -36,7 +36,7 @@ internal sealed class StubVisitor : SymbolVisitor
IsRelevantBaseType(symbol) &&
SymbolEqualityComparer.Default.Equals(symbol.ContainingAssembly, assembly);
private bool IsRelevantNamespace(INamespaceSymbol symbol) => isRelevantNamespace[symbol];
private bool IsRelevantNamespace(INamespaceSymbol symbol) => isRelevantNamespace.Invoke(symbol);
private void StubExplicitInterface(ISymbol symbol, ISymbol? explicitInterfaceSymbol, bool writeName = true)
{
@@ -109,7 +109,7 @@ internal sealed class StubVisitor : SymbolVisitor
case Accessibility.Internal:
stubWriter.Write("internal ");
break;
case Accessibility.ProtectedAndInternal or Accessibility.ProtectedOrInternal:
case Accessibility.ProtectedAndInternal:
stubWriter.Write("protected internal ");
break;
default:
@@ -156,7 +156,7 @@ internal sealed class StubVisitor : SymbolVisitor
stubWriter.Write("extern ");
}
public void StubTypedConstant(TypedConstant c)
private void StubTypedConstant(TypedConstant c)
{
switch (c.Kind)
{
@@ -221,7 +221,9 @@ internal sealed class StubVisitor : SymbolVisitor
if (!attributeAllowList.Contains(qualifiedName))
return;
stubWriter.Write($"[{prefix}{qualifiedName.AsSpan(0, @class.GetQualifiedName().Length - 9)}");
if (qualifiedName.EndsWith("Attribute"))
qualifiedName = qualifiedName[..^9];
stubWriter.Write($"[{prefix}{qualifiedName}");
if (a.ConstructorArguments.Any())
{
stubWriter.Write("(");
@@ -295,12 +297,8 @@ internal sealed class StubVisitor : SymbolVisitor
"volatile", "while"
};
private static string EscapeIdentifier(string identifier)
{
if (keywords.Contains(identifier))
return "@" + identifier;
return identifier;
}
private static string EscapeIdentifier(string identifier) =>
keywords.Contains(identifier) ? "@" + identifier : identifier;
public override void VisitField(IFieldSymbol symbol)
{
@@ -739,7 +737,7 @@ internal sealed class StubVisitor : SymbolVisitor
else
{
var seenCtor = false;
foreach (var childSymbol in symbol.GetMembers())
foreach (var childSymbol in symbol.GetMembers().OrderBy(m => m.GetName()))
{
seenCtor |= childSymbol is IMethodSymbol method && method.MethodKind == MethodKind.Constructor;
childSymbol.Accept(this);
@@ -768,7 +766,7 @@ internal sealed class StubVisitor : SymbolVisitor
if (!isGlobal)
stubWriter.WriteLine($"namespace {symbol.Name} {{");
foreach (var childSymbol in symbol.GetMembers())
foreach (var childSymbol in symbol.GetMembers().OrderBy(m => m.GetName()))
{
childSymbol.Accept(this);
}

View File

@@ -14,17 +14,14 @@ public class MemoizedFunc<T1, T2> where T1 : notnull
this.f = f;
}
public T2 this[T1 s]
public T2 Invoke(T1 s)
{
get
if (!cache.TryGetValue(s, out var t))
{
if (!cache.TryGetValue(s, out var t))
{
t = f(s);
cache[s] = t;
}
return t;
t = f(s);
cache[s] = t;
}
return t;
}
}
@@ -38,11 +35,5 @@ public class ConcurrentMemoizedFunc<T1, T2> where T1 : notnull
this.f = f;
}
public T2 this[T1 s]
{
get
{
return cache.GetOrAdd(s, f);
}
}
public T2 Invoke(T1 s) => cache.GetOrAdd(s, f);
}