Kotlin: Define DiagnosticTrapWriter, for type safety

In some cases, we were writing diagnostics to TRAP files where they
shouldn't be written. Such TRAP files don't define #compilation, so TRAP
import gave errors.

Now we use DiagnosticTrapWriter to get the type system to check that we
are writing diagnostics to the right place.
This commit is contained in:
Ian Lynagh
2023-06-21 18:36:39 +01:00
parent 61a3f86f0f
commit bfd0a19d85
5 changed files with 82 additions and 47 deletions

View File

@@ -14,7 +14,7 @@ import java.util.ArrayList
import java.util.HashSet
import java.util.zip.GZIPOutputStream
class ExternalDeclExtractor(val logger: FileLogger, val invocationTrapFile: String, val sourceFilePath: String, val primitiveTypeMapping: PrimitiveTypeMapping, val pluginContext: IrPluginContext, val globalExtensionState: KotlinExtractorGlobalState, val diagnosticTrapWriter: TrapWriter) {
class ExternalDeclExtractor(val logger: FileLogger, val invocationTrapFile: String, val sourceFilePath: String, val primitiveTypeMapping: PrimitiveTypeMapping, val pluginContext: IrPluginContext, val globalExtensionState: KotlinExtractorGlobalState, val diagnosticTrapWriter: DiagnosticTrapWriter) {
val declBinaryNames = HashMap<IrDeclaration, String>()
val externalDeclsDone = HashSet<Pair<String, String>>()
@@ -95,8 +95,8 @@ class ExternalDeclExtractor(val logger: FileLogger, val invocationTrapFile: Stri
val binaryPath = getIrClassBinaryPath(containingClass)
// We want our comments to be the first thing in the file,
// so start off with a mere TrapWriter
val tw = TrapWriter(logger.loggerBase, TrapLabelManager(), trapFileBW, diagnosticTrapWriter)
// so start off with a PlainTrapWriter
val tw = PlainTrapWriter(logger.loggerBase, TrapLabelManager(), trapFileBW, diagnosticTrapWriter)
tw.writeComment("Generated by the CodeQL Kotlin extractor for external dependencies")
tw.writeComment("Part of invocation $invocationTrapFile")
if (signature != possiblyLongSignature) {

View File

@@ -127,7 +127,7 @@ class KotlinExtractorExtension(
val lm = TrapLabelManager()
val logCounter = LogCounter()
val loggerBase = LoggerBase(logCounter)
val tw = TrapWriter(loggerBase, lm, invocationTrapFileBW, null)
val tw = DiagnosticTrapWriter(loggerBase, lm, invocationTrapFileBW)
// The interceptor has already defined #compilation = *
val compilation: Label<DbCompilation> = StringLabel("compilation")
tw.writeCompilation_started(compilation)
@@ -324,13 +324,13 @@ private fun doFile(
trapFileWriter.getTempWriter().use { trapFileBW ->
// We want our comments to be the first thing in the file,
// so start off with a mere TrapWriter
val tw = TrapWriter(loggerBase, TrapLabelManager(), trapFileBW, fileTrapWriter)
val tw = PlainTrapWriter(loggerBase, TrapLabelManager(), trapFileBW, fileTrapWriter.getDiagnosticTrapWriter())
tw.writeComment("Generated by the CodeQL Kotlin extractor for kotlin source code")
tw.writeComment("Part of invocation $invocationTrapFile")
// Now elevate to a SourceFileTrapWriter, and populate the
// file information
val sftw = tw.makeSourceFileTrapWriter(srcFile, true)
val externalDeclExtractor = ExternalDeclExtractor(logger, invocationTrapFile, srcFilePath, primitiveTypeMapping, pluginContext, globalExtensionState, fileTrapWriter)
val externalDeclExtractor = ExternalDeclExtractor(logger, invocationTrapFile, srcFilePath, primitiveTypeMapping, pluginContext, globalExtensionState, fileTrapWriter.getDiagnosticTrapWriter())
val linesOfCode = LinesOfCode(logger, sftw, srcFile)
val fileExtractor = KotlinFileExtractor(logger, sftw, linesOfCode, srcFilePath, null, externalDeclExtractor, primitiveTypeMapping, pluginContext, KotlinFileExtractor.DeclarationStack(), globalExtensionState)

View File

@@ -139,13 +139,13 @@ open class KotlinUsesExtractor(
if (clsFile == null || isExternalDeclaration(cls)) {
val filePath = getIrClassBinaryPath(cls)
val newTrapWriter = tw.makeFileTrapWriter(filePath, true)
val newLoggerTrapWriter = logger.tw.makeFileTrapWriter(filePath, false)
val newLoggerTrapWriter = logger.dtw.makeFileTrapWriter(filePath, false)
val newLogger = FileLogger(logger.loggerBase, newLoggerTrapWriter)
return KotlinFileExtractor(newLogger, newTrapWriter, null, filePath, dependencyCollector, externalClassExtractor, primitiveTypeMapping, pluginContext, newDeclarationStack, globalExtensionState)
}
val newTrapWriter = tw.makeSourceFileTrapWriter(clsFile, true)
val newLoggerTrapWriter = logger.tw.makeSourceFileTrapWriter(clsFile, false)
val newLoggerTrapWriter = logger.dtw.makeSourceFileTrapWriter(clsFile, false)
val newLogger = FileLogger(logger.loggerBase, newLoggerTrapWriter)
return KotlinFileExtractor(newLogger, newTrapWriter, null, clsFile.path, dependencyCollector, externalClassExtractor, primitiveTypeMapping, pluginContext, newDeclarationStack, globalExtensionState)
}

View File

@@ -57,7 +57,9 @@ class TrapLabelManager {
* share the same `TrapLabelManager` and `BufferedWriter`.
*/
// TODO lm was `protected` before anonymousTypeMapping and locallyVisibleFunctionLabelMapping moved into it. Should we re-protect it and provide accessors?
open class TrapWriter (protected val loggerBase: LoggerBase, val lm: TrapLabelManager, private val bw: BufferedWriter, val diagnosticTrapWriter: TrapWriter?) {
abstract class TrapWriter (protected val loggerBase: LoggerBase, val lm: TrapLabelManager, private val bw: BufferedWriter) {
abstract fun getDiagnosticTrapWriter(): DiagnosticTrapWriter
/**
* Returns the label that is defined to be the given key, if such
* a label exists, and `null` otherwise. Most users will want to use
@@ -223,7 +225,7 @@ open class TrapWriter (protected val loggerBase: LoggerBase, val lm: TrapLabelMa
val len = str.length
val newLen = UTF8Util.encodablePrefixLength(str, MAX_STRLEN)
if (newLen < len) {
loggerBase.warn(diagnosticTrapWriter ?: this,
loggerBase.warn(this.getDiagnosticTrapWriter(),
"Truncated string of length $len",
"Truncated string of length $len, starting '${str.take(100)}', ending '${str.takeLast(100)}'")
return str.take(newLen)
@@ -237,14 +239,43 @@ open class TrapWriter (protected val loggerBase: LoggerBase, val lm: TrapLabelMa
* writer etc), but using the given `filePath` for locations.
*/
fun makeFileTrapWriter(filePath: String, populateFileTables: Boolean) =
FileTrapWriter(loggerBase, lm, bw, diagnosticTrapWriter, filePath, populateFileTables)
FileTrapWriter(loggerBase, lm, bw, this.getDiagnosticTrapWriter(), filePath, populateFileTables)
/**
* Gets a FileTrapWriter like this one (using the same label manager,
* writer etc), but using the given `IrFile` for locations.
*/
fun makeSourceFileTrapWriter(file: IrFile, populateFileTables: Boolean) =
SourceFileTrapWriter(loggerBase, lm, bw, diagnosticTrapWriter, file, populateFileTables)
SourceFileTrapWriter(loggerBase, lm, bw, this.getDiagnosticTrapWriter(), file, populateFileTables)
}
/**
* A `PlainTrapWriter` has no additional context of its own.
*/
class PlainTrapWriter (
loggerBase: LoggerBase,
lm: TrapLabelManager,
bw: BufferedWriter,
val dtw: DiagnosticTrapWriter
): TrapWriter (loggerBase, lm, bw) {
override fun getDiagnosticTrapWriter(): DiagnosticTrapWriter {
return dtw
}
}
/**
* A `DiagnosticTrapWriter` is a TrapWriter that diagnostics can be
* written to; i.e. it has the #compilation label defined. In practice,
* this means that it is a TrapWriter for the invocation TRAP file.
*/
class DiagnosticTrapWriter (
loggerBase: LoggerBase,
lm: TrapLabelManager,
bw: BufferedWriter
): TrapWriter (loggerBase, lm, bw) {
override fun getDiagnosticTrapWriter(): DiagnosticTrapWriter {
return this
}
}
/**
@@ -259,16 +290,20 @@ open class FileTrapWriter (
loggerBase: LoggerBase,
lm: TrapLabelManager,
bw: BufferedWriter,
diagnosticTrapWriter: TrapWriter?,
val dtw: DiagnosticTrapWriter,
val filePath: String,
populateFileTables: Boolean
): TrapWriter (loggerBase, lm, bw, diagnosticTrapWriter) {
): TrapWriter (loggerBase, lm, bw) {
/**
* The ID for the file that we are extracting from.
*/
val fileId = mkFileId(filePath, populateFileTables)
override fun getDiagnosticTrapWriter(): DiagnosticTrapWriter {
return dtw
}
private fun offsetMinOf(default: Int, vararg options: Int?): Int {
if (default == UNDEFINED_OFFSET || default == SYNTHETIC_OFFSET) {
return default
@@ -349,10 +384,10 @@ class SourceFileTrapWriter (
loggerBase: LoggerBase,
lm: TrapLabelManager,
bw: BufferedWriter,
diagnosticTrapWriter: TrapWriter?,
dtw: DiagnosticTrapWriter,
val irFile: IrFile,
populateFileTables: Boolean) :
FileTrapWriter(loggerBase, lm, bw, diagnosticTrapWriter, irFile.path, populateFileTables) {
FileTrapWriter(loggerBase, lm, bw, dtw, irFile.path, populateFileTables) {
/**
* The file entry for the file that we are extracting from.
@@ -363,14 +398,14 @@ class SourceFileTrapWriter (
override fun getLocation(startOffset: Int, endOffset: Int): Label<DbLocation> {
if (startOffset == UNDEFINED_OFFSET || endOffset == UNDEFINED_OFFSET) {
if (startOffset != endOffset) {
loggerBase.warn(this, "Location with inconsistent offsets (start $startOffset, end $endOffset)", null)
loggerBase.warn(dtw, "Location with inconsistent offsets (start $startOffset, end $endOffset)", null)
}
return getWholeFileLocation()
}
if (startOffset == SYNTHETIC_OFFSET || endOffset == SYNTHETIC_OFFSET) {
if (startOffset != endOffset) {
loggerBase.warn(this, "Location with inconsistent offsets (start $startOffset, end $endOffset)", null)
loggerBase.warn(dtw, "Location with inconsistent offsets (start $startOffset, end $endOffset)", null)
}
return getWholeFileLocation()
}
@@ -390,14 +425,14 @@ class SourceFileTrapWriter (
override fun getLocationString(e: IrElement): String {
if (e.startOffset == UNDEFINED_OFFSET || e.endOffset == UNDEFINED_OFFSET) {
if (e.startOffset != e.endOffset) {
loggerBase.warn(this, "Location with inconsistent offsets (start ${e.startOffset}, end ${e.endOffset})", null)
loggerBase.warn(dtw, "Location with inconsistent offsets (start ${e.startOffset}, end ${e.endOffset})", null)
}
return "<unknown location while processing $filePath>"
}
if (e.startOffset == SYNTHETIC_OFFSET || e.endOffset == SYNTHETIC_OFFSET) {
if (e.startOffset != e.endOffset) {
loggerBase.warn(this, "Location with inconsistent offsets (start ${e.startOffset}, end ${e.endOffset})", null)
loggerBase.warn(dtw, "Location with inconsistent offsets (start ${e.startOffset}, end ${e.endOffset})", null)
}
return "<synthetic location while processing $filePath>"
}

View File

@@ -107,7 +107,7 @@ open class LoggerBase(val logCounter: LogCounter) {
file_number_diagnostic_number = 0
}
fun diagnostic(tw: TrapWriter, severity: Severity, msg: String, extraInfo: String?, locationString: String? = null, mkLocationId: () -> Label<DbLocation> = { tw.unknownLocation }) {
fun diagnostic(dtw: DiagnosticTrapWriter, severity: Severity, msg: String, extraInfo: String?, locationString: String? = null, mkLocationId: () -> Label<DbLocation> = { dtw.unknownLocation }) {
val diagnosticLoc = getDiagnosticLocation()
val diagnosticLocStr = if(diagnosticLoc == null) "<unknown location>" else diagnosticLoc
val suffix =
@@ -121,7 +121,7 @@ open class LoggerBase(val logCounter: LogCounter) {
// counting machinery
if (verbosity >= 1) {
val message = "Severity mismatch ($severity vs ${oldInfo.first}) at $diagnosticLoc"
emitDiagnostic(tw, Severity.Error, "Inconsistency", message, message)
emitDiagnostic(dtw, Severity.Error, "Inconsistency", message, message)
}
}
val newCount = oldInfo.second + 1
@@ -149,18 +149,18 @@ open class LoggerBase(val logCounter: LogCounter) {
fullMsgBuilder.append(suffix)
val fullMsg = fullMsgBuilder.toString()
emitDiagnostic(tw, severity, diagnosticLocStr, msg, fullMsg, locationString, mkLocationId)
emitDiagnostic(dtw, severity, diagnosticLocStr, msg, fullMsg, locationString, mkLocationId)
}
private fun emitDiagnostic(tw: TrapWriter, severity: Severity, diagnosticLocStr: String, msg: String, fullMsg: String, locationString: String? = null, mkLocationId: () -> Label<DbLocation> = { tw.unknownLocation }) {
private fun emitDiagnostic(dtw: DiagnosticTrapWriter, severity: Severity, diagnosticLocStr: String, msg: String, fullMsg: String, locationString: String? = null, mkLocationId: () -> Label<DbLocation> = { dtw.unknownLocation }) {
val locStr = if (locationString == null) "" else "At " + locationString + ": "
val kind = if (severity <= Severity.WarnHigh) "WARN" else "ERROR"
val logMessage = LogMessage(kind, "Diagnostic($diagnosticLocStr): $locStr$fullMsg")
// We don't actually make the location until after the `return` above
val locationId = mkLocationId()
val diagLabel = tw.getFreshIdLabel<DbDiagnostic>()
tw.writeDiagnostics(diagLabel, "CodeQL Kotlin extractor", severity.sev, "", msg, "${logMessage.timestamp} $fullMsg", locationId)
tw.writeDiagnostic_for(diagLabel, StringLabel("compilation"), file_number, file_number_diagnostic_number++)
val diagLabel = dtw.getFreshIdLabel<DbDiagnostic>()
dtw.writeDiagnostics(diagLabel, "CodeQL Kotlin extractor", severity.sev, "", msg, "${logMessage.timestamp} $fullMsg", locationId)
dtw.writeDiagnostic_for(diagLabel, StringLabel("compilation"), file_number, file_number_diagnostic_number++)
logStream.write(logMessage.toJsonLine())
}
@@ -188,18 +188,18 @@ open class LoggerBase(val logCounter: LogCounter) {
}
}
fun warn(tw: TrapWriter, msg: String, extraInfo: String?) {
fun warn(dtw: DiagnosticTrapWriter, msg: String, extraInfo: String?) {
if (verbosity >= 2) {
diagnostic(tw, Severity.Warn, msg, extraInfo)
diagnostic(dtw, Severity.Warn, msg, extraInfo)
}
}
fun error(tw: TrapWriter, msg: String, extraInfo: String?) {
fun error(dtw: DiagnosticTrapWriter, msg: String, extraInfo: String?) {
if (verbosity >= 1) {
diagnostic(tw, Severity.Error, msg, extraInfo)
diagnostic(dtw, Severity.Error, msg, extraInfo)
}
}
fun printLimitedDiagnosticCounts(tw: TrapWriter) {
fun printLimitedDiagnosticCounts(dtw: DiagnosticTrapWriter) {
for((caller, info) in logCounter.diagnosticInfo) {
val severity = info.first
val count = info.second
@@ -209,7 +209,7 @@ open class LoggerBase(val logCounter: LogCounter) {
// to be an error regardless.
val message = "Total of $count diagnostics (reached limit of ${logCounter.diagnosticLimit}) from $caller."
if (verbosity >= 1) {
emitDiagnostic(tw, severity, "Limit", message, message)
emitDiagnostic(dtw, severity, "Limit", message, message)
}
}
}
@@ -224,28 +224,28 @@ open class LoggerBase(val logCounter: LogCounter) {
}
}
open class Logger(val loggerBase: LoggerBase, open val tw: TrapWriter) {
open class Logger(val loggerBase: LoggerBase, open val dtw: DiagnosticTrapWriter) {
fun flush() {
tw.flush()
dtw.flush()
loggerBase.flush()
}
fun trace(msg: String) {
loggerBase.trace(tw, msg)
loggerBase.trace(dtw, msg)
}
fun trace(msg: String, exn: Throwable) {
trace(msg + "\n" + exn.stackTraceToString())
}
fun debug(msg: String) {
loggerBase.debug(tw, msg)
loggerBase.debug(dtw, msg)
}
fun info(msg: String) {
loggerBase.info(tw, msg)
loggerBase.info(dtw, msg)
}
private fun warn(msg: String, extraInfo: String?) {
loggerBase.warn(tw, msg, extraInfo)
loggerBase.warn(dtw, msg, extraInfo)
}
fun warn(msg: String, exn: Throwable) {
warn(msg, exn.stackTraceToString())
@@ -255,7 +255,7 @@ open class Logger(val loggerBase: LoggerBase, open val tw: TrapWriter) {
}
private fun error(msg: String, extraInfo: String?) {
loggerBase.error(tw, msg, extraInfo)
loggerBase.error(dtw, msg, extraInfo)
}
fun error(msg: String) {
error(msg, null)
@@ -265,16 +265,16 @@ open class Logger(val loggerBase: LoggerBase, open val tw: TrapWriter) {
}
}
class FileLogger(loggerBase: LoggerBase, override val tw: FileTrapWriter): Logger(loggerBase, tw) {
class FileLogger(loggerBase: LoggerBase, val ftw: FileTrapWriter): Logger(loggerBase, ftw.getDiagnosticTrapWriter()) {
fun warnElement(msg: String, element: IrElement, exn: Throwable? = null) {
val locationString = tw.getLocationString(element)
val mkLocationId = { tw.getLocation(element) }
loggerBase.diagnostic(tw, Severity.Warn, msg, exn?.stackTraceToString(), locationString, mkLocationId)
val locationString = ftw.getLocationString(element)
val mkLocationId = { ftw.getLocation(element) }
loggerBase.diagnostic(ftw.getDiagnosticTrapWriter(), Severity.Warn, msg, exn?.stackTraceToString(), locationString, mkLocationId)
}
fun errorElement(msg: String, element: IrElement, exn: Throwable? = null) {
val locationString = tw.getLocationString(element)
val mkLocationId = { tw.getLocation(element) }
loggerBase.diagnostic(tw, Severity.Error, msg, exn?.stackTraceToString(), locationString, mkLocationId)
val locationString = ftw.getLocationString(element)
val mkLocationId = { ftw.getLocation(element) }
loggerBase.diagnostic(ftw.getDiagnosticTrapWriter(), Severity.Error, msg, exn?.stackTraceToString(), locationString, mkLocationId)
}
}