Merge pull request #20726 from igfoo/igfoo/ClassInstanceStack

Kotlin: Avoid infinite recursion when extracting recursive interfaces
This commit is contained in:
Ian Lynagh
2025-10-31 16:18:39 +00:00
committed by GitHub
16 changed files with 175 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
package com.github.codeql
import com.github.codeql.utils.ClassInstanceStack
import com.github.codeql.utils.isExternalFileClassMember
import com.semmle.extractor.java.OdasaOutput
import com.semmle.util.data.StringDigestor
@@ -18,6 +19,7 @@ class ExternalDeclExtractor(
val compression: Compression,
val invocationTrapFile: String,
val sourceFilePath: String,
val classInstanceStack: ClassInstanceStack,
val primitiveTypeMapping: PrimitiveTypeMapping,
val pluginContext: IrPluginContext,
val globalExtensionState: KotlinExtractorGlobalState,
@@ -163,6 +165,7 @@ class ExternalDeclExtractor(
binaryPath,
manager,
this,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
KotlinFileExtractor.DeclarationStack(),

View File

@@ -1,5 +1,6 @@
package com.github.codeql
import com.github.codeql.utils.ClassInstanceStack
import com.github.codeql.utils.versions.usesK2
import com.semmle.util.files.FileUtil
import com.semmle.util.trap.pathtransformers.PathTransformer
@@ -151,6 +152,7 @@ class KotlinExtractorExtension(
}
val compression = getCompression(logger)
val classInstanceStack = ClassInstanceStack()
val primitiveTypeMapping = PrimitiveTypeMapping(logger, pluginContext)
// FIXME: FileUtil expects a static global logger
// which should be provided by SLF4J's factory facility. For now we set it here.
@@ -182,6 +184,7 @@ class KotlinExtractorExtension(
trapDir,
srcDir,
file,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState
@@ -358,6 +361,7 @@ private fun doFile(
dbTrapDir: File,
dbSrcDir: File,
srcFile: IrFile,
classInstanceStack: ClassInstanceStack,
primitiveTypeMapping: PrimitiveTypeMapping,
pluginContext: IrPluginContext,
globalExtensionState: KotlinExtractorGlobalState
@@ -415,6 +419,7 @@ private fun doFile(
compression,
invocationTrapFile,
srcFilePath,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState,
@@ -429,6 +434,7 @@ private fun doFile(
srcFilePath,
null,
externalDeclExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
KotlinFileExtractor.DeclarationStack(),

View File

@@ -62,6 +62,7 @@ open class KotlinFileExtractor(
val filePath: String,
dependencyCollector: OdasaOutput.TrapFileManager?,
externalClassExtractor: ExternalDeclExtractor,
classInstanceStack: ClassInstanceStack,
primitiveTypeMapping: PrimitiveTypeMapping,
pluginContext: IrPluginContext,
val declarationStack: DeclarationStack,
@@ -72,6 +73,7 @@ open class KotlinFileExtractor(
tw,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
globalExtensionState
@@ -496,12 +498,17 @@ open class KotlinFileExtractor(
}
extractClassModifiers(c, id)
extractClassSupertypes(
c,
id,
if (argsIncludingOuterClasses == null) ExtractSupertypesMode.Raw
else ExtractSupertypesMode.Specialised(argsIncludingOuterClasses)
)
classInstanceStack.push(c)
try {
extractClassSupertypes(
c,
id,
if (argsIncludingOuterClasses == null) ExtractSupertypesMode.Raw
else ExtractSupertypesMode.Specialised(argsIncludingOuterClasses)
)
} finally {
classInstanceStack.pop()
}
val locId = getLocation(c, argsIncludingOuterClasses)
tw.writeHasLocation(id, locId)

View File

@@ -49,6 +49,7 @@ open class KotlinUsesExtractor(
open val tw: TrapWriter,
val dependencyCollector: OdasaOutput.TrapFileManager?,
val externalClassExtractor: ExternalDeclExtractor,
val classInstanceStack: ClassInstanceStack,
val primitiveTypeMapping: PrimitiveTypeMapping,
val pluginContext: IrPluginContext,
val globalExtensionState: KotlinExtractorGlobalState
@@ -182,6 +183,7 @@ open class KotlinUsesExtractor(
filePath,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
newDeclarationStack,
@@ -199,6 +201,7 @@ open class KotlinUsesExtractor(
clsFile.path,
dependencyCollector,
externalClassExtractor,
classInstanceStack,
primitiveTypeMapping,
pluginContext,
newDeclarationStack,
@@ -537,6 +540,19 @@ open class KotlinUsesExtractor(
return Pair(p?.first ?: c, p?.second ?: argsIncludingOuterClassesBeforeReplacement)
}
private fun avoidInfiniteRecursion(
pair: Pair<IrClass, List<IrTypeArgument>?>
): Pair<IrClass, List<IrTypeArgument>?> {
val c = pair.first
val args = pair.second
if (args != null && args.isNotEmpty() && classInstanceStack.possiblyCyclicExtraction(c, args)) {
logger.warn("Making use of ${c.name} a raw type to avoid infinite recursion")
return Pair(c, null)
} else {
return pair
}
}
// `typeArgs` can be null to describe a raw generic type.
// For non-generic types it will be zero-length list.
private fun addClassLabel(
@@ -545,7 +561,7 @@ open class KotlinUsesExtractor(
inReceiverContext: Boolean = false
): TypeResult<DbClassorinterface> {
val replaced =
tryReplaceType(cBeforeReplacement, argsIncludingOuterClassesBeforeReplacement)
avoidInfiniteRecursion(tryReplaceType(cBeforeReplacement, argsIncludingOuterClassesBeforeReplacement))
val replacedClass = replaced.first
val replacedArgsIncludingOuterClasses = replaced.second

View File

@@ -0,0 +1,47 @@
package com.github.codeql.utils
import java.util.Stack
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.types.*
class ClassInstanceStack {
private val stack: Stack<IrClass> = Stack()
fun push(c: IrClass) = stack.push(c)
fun pop() = stack.pop()
private fun checkTypeArgs(sym: IrClassSymbol, args: List<IrTypeArgument>): Boolean {
for (arg in args) {
if (arg is IrTypeProjection) {
if (checkType(sym, arg.type)) {
return true
}
}
}
return false
}
private fun checkType(sym: IrClassSymbol, type: IrType): Boolean {
if (type is IrSimpleType) {
val decl = type.classifier.owner
if (decl.symbol == sym) {
return true
}
if (checkTypeArgs(sym, type.arguments)) {
return true
}
}
return false
}
fun possiblyCyclicExtraction(classToCheck: IrClass, args: List<IrTypeArgument>): Boolean {
for (c in stack) {
if (c.symbol == classToCheck.symbol && checkTypeArgs(c.symbol, args)) {
return true
}
}
return false
}
}

View File

@@ -0,0 +1,4 @@
package somepkg;
public interface IfaceA<T> extends IfaceB<T> {}

View File

@@ -0,0 +1,4 @@
package somepkg;
public interface IfaceB<T> extends IfaceC<IfaceA<IfaceB<T>>> {}

View File

@@ -0,0 +1,4 @@
package somepkg;
public interface IfaceC<T> {}

View File

@@ -0,0 +1,6 @@
package somepkg;
public interface IfaceZ {
public <T> IfaceA<String> someFun();
}

View File

@@ -0,0 +1,5 @@
package mypkg
import somepkg.IfaceZ
class SomeClass(private val myVal: IfaceZ) { }

View File

@@ -0,0 +1,6 @@
import commands
def test(codeql, java_full):
codeql.database.create(
command=["kotlinc somepkg/IfaceA.java somepkg/IfaceB.java somepkg/IfaceC.java somepkg/IfaceZ.java test.kt"]
)

View File

@@ -0,0 +1,19 @@
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<> |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<IfaceB<>> |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<IfaceB<String>> |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<IfaceB<T>> |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<IfaceB> |
| file:///!unknown-binary-location/somepkg/IfaceA.class:0:0:0:0 | IfaceA<String> |
| file:///!unknown-binary-location/somepkg/IfaceB.class:0:0:0:0 | IfaceB |
| file:///!unknown-binary-location/somepkg/IfaceB.class:0:0:0:0 | IfaceB<> |
| file:///!unknown-binary-location/somepkg/IfaceB.class:0:0:0:0 | IfaceB<IfaceB> |
| file:///!unknown-binary-location/somepkg/IfaceB.class:0:0:0:0 | IfaceB<String> |
| file:///!unknown-binary-location/somepkg/IfaceB.class:0:0:0:0 | IfaceB<T> |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC<> |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC<IfaceA<IfaceB<>>> |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC<IfaceA<IfaceB<String>>> |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC<IfaceA<IfaceB<T>>> |
| file:///!unknown-binary-location/somepkg/IfaceC.class:0:0:0:0 | IfaceC<IfaceA<IfaceB>> |
| file:///!unknown-binary-location/somepkg/IfaceZ.class:0:0:0:0 | IfaceZ |

View File

@@ -0,0 +1,5 @@
import java
from Type t
where t.getName().matches("Iface%")
select t

View File

@@ -0,0 +1,16 @@
import java.util.Stack;
// Diagnostic Matches: %Making use of Stack a raw type to avoid infinite recursion%
class MyType
fun foo1(x: List<List<List<List<MyType>>>>) { }
fun foo2(x: Stack<Stack<Stack<Stack<MyType>>>>) { }
class MkT<T> { }
fun foo3(x: MkT<MkT<MkT<MkT<MyType>>>>) { }

View File

@@ -0,0 +1,13 @@
| file:///!unknown-binary-location/MkT.class:0:0:0:0 | MkT<MkT<MkT<MkT<MyType>>>> |
| file:///!unknown-binary-location/MkT.class:0:0:0:0 | MkT<MkT<MkT<MyType>>> |
| file:///!unknown-binary-location/MkT.class:0:0:0:0 | MkT<MkT<MyType>> |
| file:///!unknown-binary-location/MkT.class:0:0:0:0 | MkT<MyType> |
| file:///modules/java.base/java/util/List.class:0:0:0:0 | List<? extends List<? extends List<? extends List<MyType>>>> |
| file:///modules/java.base/java/util/List.class:0:0:0:0 | List<? extends List<? extends List<MyType>>> |
| file:///modules/java.base/java/util/List.class:0:0:0:0 | List<? extends List<MyType>> |
| file:///modules/java.base/java/util/List.class:0:0:0:0 | List<MyType> |
| file:///modules/java.base/java/util/List.class:0:0:0:0 | List<Stack<MyType>> |
| file:///modules/java.base/java/util/Stack.class:0:0:0:0 | Stack<MyType> |
| file:///modules/java.base/java/util/Stack.class:0:0:0:0 | Stack<Stack<MyType>> |
| file:///modules/java.base/java/util/Stack.class:0:0:0:0 | Stack<Stack<Stack<MyType>>> |
| file:///modules/java.base/java/util/Stack.class:0:0:0:0 | Stack<Stack<Stack<Stack<MyType>>>> |

View File

@@ -0,0 +1,7 @@
import java
from Type t
where
t.getName().matches("%MyType%") and
t.getName().matches(["List<%", "Stack<%", "MkT<%"])
select t