Kotlin: Avoid infinite recursion when extracting recursive interfaces

This commit is contained in:
Ian Lynagh
2025-10-30 13:18:48 +00:00
parent 26f59a8786
commit 1efecc099c
5 changed files with 86 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
package com.github.codeql
import com.github.codeql.utils.ClassInstanceStack
import com.github.codeql.utils.isExternalFileClassMember
import com.semmle.extractor.java.OdasaOutput
import com.semmle.util.data.StringDigestor
@@ -18,6 +19,7 @@ class ExternalDeclExtractor(
val compression: Compression,
val invocationTrapFile: String,
val sourceFilePath: String,
val classInstanceStack: ClassInstanceStack,
val primitiveTypeMapping: PrimitiveTypeMapping,
val pluginContext: IrPluginContext,
val globalExtensionState: KotlinExtractorGlobalState,
@@ -163,6 +165,7 @@ class ExternalDeclExtractor(
binaryPath,
manager,
this,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
KotlinFileExtractor.DeclarationStack(),

View File

@@ -1,5 +1,6 @@
package com.github.codeql
import com.github.codeql.utils.ClassInstanceStack
import com.github.codeql.utils.versions.usesK2
import com.semmle.util.files.FileUtil
import com.semmle.util.trap.pathtransformers.PathTransformer
@@ -151,6 +152,7 @@ class KotlinExtractorExtension(
}
val compression = getCompression(logger)
val classInstanceStack = ClassInstanceStack()
val primitiveTypeMapping = PrimitiveTypeMapping(logger, pluginContext)
// FIXME: FileUtil expects a static global logger
// which should be provided by SLF4J's factory facility. For now we set it here.
@@ -182,6 +184,7 @@ class KotlinExtractorExtension(
trapDir,
srcDir,
file,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState
@@ -358,6 +361,7 @@ private fun doFile(
dbTrapDir: File,
dbSrcDir: File,
srcFile: IrFile,
classInstanceStack: ClassInstanceStack,
primitiveTypeMapping: PrimitiveTypeMapping,
pluginContext: IrPluginContext,
globalExtensionState: KotlinExtractorGlobalState
@@ -415,6 +419,7 @@ private fun doFile(
compression,
invocationTrapFile,
srcFilePath,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState,
@@ -429,6 +434,7 @@ private fun doFile(
srcFilePath,
null,
externalDeclExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
KotlinFileExtractor.DeclarationStack(),

View File

@@ -62,6 +62,7 @@ open class KotlinFileExtractor(
val filePath: String,
dependencyCollector: OdasaOutput.TrapFileManager?,
externalClassExtractor: ExternalDeclExtractor,
classInstanceStack: ClassInstanceStack,
primitiveTypeMapping: PrimitiveTypeMapping,
pluginContext: IrPluginContext,
val declarationStack: DeclarationStack,
@@ -72,6 +73,7 @@ open class KotlinFileExtractor(
tw,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState
@@ -496,12 +498,17 @@ open class KotlinFileExtractor(
}
extractClassModifiers(c, id)
extractClassSupertypes(
c,
id,
if (argsIncludingOuterClasses == null) ExtractSupertypesMode.Raw
else ExtractSupertypesMode.Specialised(argsIncludingOuterClasses)
)
classInstanceStack.push(c)
try {
extractClassSupertypes(
c,
id,
if (argsIncludingOuterClasses == null) ExtractSupertypesMode.Raw
else ExtractSupertypesMode.Specialised(argsIncludingOuterClasses)
)
} finally {
classInstanceStack.pop()
}
val locId = getLocation(c, argsIncludingOuterClasses)
tw.writeHasLocation(id, locId)

View File

@@ -49,6 +49,7 @@ open class KotlinUsesExtractor(
open val tw: TrapWriter,
val dependencyCollector: OdasaOutput.TrapFileManager?,
val externalClassExtractor: ExternalDeclExtractor,
val classInstanceStack: ClassInstanceStack,
val primitiveTypeMapping: PrimitiveTypeMapping,
val pluginContext: IrPluginContext,
val globalExtensionState: KotlinExtractorGlobalState
@@ -182,6 +183,7 @@ open class KotlinUsesExtractor(
filePath,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
newDeclarationStack,
@@ -199,6 +201,7 @@ open class KotlinUsesExtractor(
clsFile.path,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
newDeclarationStack,
@@ -537,6 +540,19 @@ open class KotlinUsesExtractor(
return Pair(p?.first ?: c, p?.second ?: argsIncludingOuterClassesBeforeReplacement)
}
private fun avoidInfiniteRecursion(
pair: Pair<IrClass, List<IrTypeArgument>?>
): Pair<IrClass, List<IrTypeArgument>?> {
val c = pair.first
val args = pair.second
if (args != null && args.isNotEmpty() && classInstanceStack.possiblyCyclicExtraction(c, args)) {
logger.warn("Making use of ${c.name} a raw type to avoid infinite recursion")
return Pair(c, null)
} else {
return pair
}
}
// `typeArgs` can be null to describe a raw generic type.
// For non-generic types it will be zero-length list.
private fun addClassLabel(
@@ -545,7 +561,7 @@ open class KotlinUsesExtractor(
inReceiverContext: Boolean = false
): TypeResult<DbClassorinterface> {
val replaced =
tryReplaceType(cBeforeReplacement, argsIncludingOuterClassesBeforeReplacement)
avoidInfiniteRecursion(tryReplaceType(cBeforeReplacement, argsIncludingOuterClassesBeforeReplacement))
val replacedClass = replaced.first
val replacedArgsIncludingOuterClasses = replaced.second

View File

@@ -0,0 +1,47 @@
package com.github.codeql.utils
import java.util.Stack
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.types.*
class ClassInstanceStack {
private val stack: Stack<IrClass> = Stack()
fun push(c: IrClass) = stack.push(c)
fun pop() = stack.pop()
private fun checkTypeArgs(sym: IrClassSymbol, args: List<IrTypeArgument>): Boolean {
for (arg in args) {
if (arg is IrTypeProjection) {
if (checkType(sym, arg.type)) {
return true
}
}
}
return false
}
private fun checkType(sym: IrClassSymbol, type: IrType): Boolean {
if (type is IrSimpleType) {
val decl = type.classifier.owner
if (decl.symbol == sym) {
return true
}
if (checkTypeArgs(sym, type.arguments)) {
return true
}
}
return false
}
fun possiblyCyclicExtraction(classToCheck: IrClass, args: List<IrTypeArgument>): Boolean {
for (c in stack) {
if (c.symbol == classToCheck.symbol && checkTypeArgs(c.symbol, args)) {
return true
}
}
return false
}
}