KE2: Make the shared stuff threadsafe

This commit is contained in:
Ian Lynagh
2024-10-04 16:11:26 +01:00
parent 8711099de2
commit f5033d1e88
3 changed files with 53 additions and 17 deletions

View File

@@ -96,7 +96,7 @@ open class KotlinUsesExtractor(
val pkg = f.packageFqName.asString()
val jvmName = getFileClassName(f)
val id = extractFileClass(pkg, jvmName)
if (tw.lm.fileClassLocationsExtracted.add(f)) {
if (tw.lm.markFileClassLocationAsExtracted(f)) {
val fileId = tw.mkFileId(f.virtualFilePath, false)
val locId = tw.getWholeFileLocation(fileId)
tw.writeHasLocation(id, locId)

View File

@@ -9,6 +9,8 @@ import com.semmle.extractor.java.PopulateFile
import com.semmle.util.unicode.UTF8Util
import java.io.BufferedWriter
import java.io.File
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/*
OLD: KE1
import org.jetbrains.kotlin.ir.IrElement
@@ -28,16 +30,34 @@ import org.jetbrains.kotlin.psi.*
* names, and maintains a mapping from keys (`@"..."`) to labels.
*/
class TrapLabelManager {
/**
* The lock that controls access to the label manager state.
* While we can make a thread-safe MutableMap with
* Collections.synchronizedMap and use getOrPut, that doesn't
* guarantee not to run the `defaultValue` function when it isn't
* necessary, which makes it useless for our purposes.
* TODO: We only actually need this for the diagnostic TRAP file. Make it optional?
*/
private val lock = ReentrantLock()
/** The next integer to use as a label name. */
private var nextInt: Int = 100
/** Returns a fresh label. */
fun <T : AnyDbType> getFreshLabel(): Label<T> {
return IntLabel(nextInt++)
lock.withLock {
return IntLabel(nextInt++)
}
}
/** A mapping from a key (`@"..."`) to the label defined to be that key, if any. */
val labelMapping: MutableMap<String, Label<*>> = mutableMapOf<String, Label<*>>()
private val labelMapping: MutableMap<String, Label<*>> = mutableMapOf<String, Label<*>>()
fun <T> withLabelMapping(action: (MutableMap<String, Label<*>>) -> T): T {
lock.withLock {
return action(labelMapping)
}
}
/*
OLD: KE1
@@ -60,13 +80,25 @@ class TrapLabelManager {
* This allows us to keep track of whether we've written the location already in this TRAP file,
* to avoid duplication.
*/
val fileClassLocationsExtracted = HashSet<KtFile>()
private val fileClassLocationsExtracted = HashSet<KtFile>()
/**
* Indicate that we want `file`'s file class location marked as extracted.
* Returns true if we need to actually write the TRAP for it, or false
* if it's already been done.
*/
fun markFileClassLocationAsExtracted(file: KtFile): Boolean {
lock.withLock {
return fileClassLocationsExtracted.add(file)
}
}
}
/**
* A `TrapWriter` is used to write TRAP to a particular TRAP file. There may be multiple
* `TrapWriter`s for the same file, as different instances will have different additional state, but
* they must all share the same `TrapLabelManager` and `BufferedWriter`.
* `BasicLogger`s, `TrapLabelManager` and `BufferedWriter` are threadsafe, so `TrapWriter`s are too.
*/
abstract class TrapWriter(
protected val basicLogger: BasicLogger,
@@ -85,7 +117,7 @@ abstract class TrapWriter(
TODO: Inline this if it can remain private
*/
private fun <T : AnyDbType> getExistingLabelFor(key: String): Label<T>? {
return lm.labelMapping.get(key)?.cast<T>()
return lm.withLabelMapping { labelMapping -> labelMapping.get(key)?.cast<T>() }
}
/**
@@ -94,15 +126,17 @@ abstract class TrapWriter(
*/
@JvmOverloads // Needed so Java can call a method with an optional argument
fun <T : AnyDbType> getLabelFor(key: String, initialise: (Label<T>) -> Unit = {}): Label<T> {
val maybeLabel: Label<T>? = getExistingLabelFor(key)
if (maybeLabel == null) {
val label: Label<T> = lm.getFreshLabel()
lm.labelMapping.put(key, label)
writeTrap("$label = $key\n")
initialise(label)
return label
} else {
return maybeLabel
return lm.withLabelMapping { labelMapping ->
val maybeLabel: Label<T>? = getExistingLabelFor(key)
if (maybeLabel == null) {
val label: Label<T> = lm.getFreshLabel()
labelMapping.put(key, label)
writeTrap("$label = $key\n")
initialise(label)
label
} else {
maybeLabel
}
}
}

View File

@@ -1,6 +1,7 @@
package com.github.codeql
import com.intellij.psi.PsiElement
import java.io.BufferedWriter
import java.io.File
import java.io.FileWriter
import java.io.OutputStreamWriter
@@ -127,15 +128,16 @@ class LoggerBase(val diagnosticCounter: DiagnosticCounter) : BasicLogger {
verbosity = System.getenv("CODEQL_EXTRACTOR_KOTLIN_VERBOSITY")?.toIntOrNull() ?: 3
}
private val logStream: Writer
// Use BufferedWriter as it is threadsafe
private val logStream: BufferedWriter
init {
val extractorLogDir = System.getenv("CODEQL_EXTRACTOR_JAVA_LOG_DIR")
if (extractorLogDir == null || extractorLogDir == "") {
logStream = OutputStreamWriter(System.out)
logStream = BufferedWriter(OutputStreamWriter(System.out))
} else {
val logFile = File.createTempFile("kotlin-extractor.", ".log", File(extractorLogDir))
logStream = FileWriter(logFile)
logStream = BufferedWriter(FileWriter(logFile))
}
}