Kotlin: Add Label.cast()

This commit is contained in:
Ian Lynagh
2022-03-29 10:50:00 +01:00
parent c89f3163f9
commit 86c31cb2e8
4 changed files with 38 additions and 65 deletions

View File

@@ -104,26 +104,22 @@ open class KotlinFileExtractor(
}
}
is IrFunction -> {
@Suppress("UNCHECKED_CAST")
val parentId = useDeclarationParent(declaration.parent, false) as Label<DbReftype>
val parentId = useDeclarationParent(declaration.parent, false).cast<DbReftype>()
extractFunction(declaration, parentId, true, null, listOf())
}
is IrAnonymousInitializer -> {
// Leaving this intentionally empty. init blocks are extracted during class extraction.
}
is IrProperty -> {
@Suppress("UNCHECKED_CAST")
val parentId = useDeclarationParent(declaration.parent, false) as Label<DbReftype>
val parentId = useDeclarationParent(declaration.parent, false).cast<DbReftype>()
extractProperty(declaration, parentId, true, null, listOf())
}
is IrEnumEntry -> {
@Suppress("UNCHECKED_CAST")
val parentId = useDeclarationParent(declaration.parent, false) as Label<DbReftype>
val parentId = useDeclarationParent(declaration.parent, false).cast<DbReftype>()
extractEnumEntry(declaration, parentId)
}
is IrField -> {
@Suppress("UNCHECKED_CAST")
val parentId = useDeclarationParent(declaration.parent, false) as Label<DbReftype>
val parentId = useDeclarationParent(declaration.parent, false).cast<DbReftype>()
extractField(declaration, parentId)
}
is IrTypeAlias -> extractTypeAlias(declaration)
@@ -252,16 +248,12 @@ open class KotlinFileExtractor(
// TODO: There's lots of duplication between this and extractClassSource.
// Can we share it?
if(kind == ClassKind.INTERFACE || kind == ClassKind.ANNOTATION_CLASS) {
@Suppress("UNCHECKED_CAST")
val interfaceId = id as Label<out DbInterface>
@Suppress("UNCHECKED_CAST")
val sourceInterfaceId = useClassSource(c) as Label<out DbInterface>
val interfaceId = id.cast<DbInterface>()
val sourceInterfaceId = useClassSource(c).cast<DbInterface>()
tw.writeInterfaces(interfaceId, cls, pkgId, sourceInterfaceId)
} else {
@Suppress("UNCHECKED_CAST")
val classId = id as Label<out DbClass>
@Suppress("UNCHECKED_CAST")
val sourceClassId = useClassSource(c) as Label<out DbClass>
val classId = id.cast<DbClass>()
val sourceClassId = useClassSource(c).cast<DbClass>()
tw.writeClasses(classId, cls, pkgId, sourceClassId)
if (kind == ClassKind.ENUM_CLASS) {
@@ -344,8 +336,7 @@ open class KotlinFileExtractor(
private fun extractLocalTypeDeclStmt(c: IrClass, callable: Label<out DbCallable>, parent: Label<out DbStmtparent>, idx: Int) {
val extractStaticInit = c.declarations.none { it is IrAnonymousInitializer }
@Suppress("UNCHECKED_CAST")
val id = extractClassSource(c, true, extractStaticInit) as Label<out DbClass>
val id = extractClassSource(c, true, extractStaticInit).cast<DbClass>()
extractLocalTypeDeclStmt(id, c, callable, parent, idx)
}
@@ -362,8 +353,7 @@ open class KotlinFileExtractor(
DeclarationStackAdjuster(c).use {
val id = if (c.isAnonymousObject) {
@Suppress("UNCHECKED_CAST")
useAnonymousClass(c).javaResult.id as Label<out DbClass>
useAnonymousClass(c).javaResult.id.cast<DbClass>()
} else {
useClassSource(c)
}
@@ -372,12 +362,10 @@ open class KotlinFileExtractor(
val pkgId = extractPackage(pkg)
val kind = c.kind
if (kind == ClassKind.INTERFACE || kind == ClassKind.ANNOTATION_CLASS) {
@Suppress("UNCHECKED_CAST")
val interfaceId = id as Label<out DbInterface>
val interfaceId = id.cast<DbInterface>()
tw.writeInterfaces(interfaceId, cls, pkgId, interfaceId)
} else {
@Suppress("UNCHECKED_CAST")
val classId = id as Label<out DbClass>
val classId = id.cast<DbClass>()
tw.writeClasses(classId, cls, pkgId, classId)
if (kind == ClassKind.ENUM_CLASS) {
@@ -410,8 +398,7 @@ open class KotlinFileExtractor(
tw.writeFieldsKotlinType(instance.id, type.kotlinResult.id)
tw.writeHasLocation(instance.id, locId)
addModifiers(instance.id, "public", "static", "final")
@Suppress("UNCHECKED_CAST")
tw.writeClass_object(id as Label<DbClass>, instance.id)
tw.writeClass_object(id.cast<DbClass>(), instance.id)
}
extractClassModifiers(c, id)
@@ -430,8 +417,7 @@ open class KotlinFileExtractor(
if (parent is IrClass) {
val parentId =
if (parent.isAnonymousObject) {
@Suppress("UNCHECKED_CAST")
useAnonymousClass(parent).javaResult.id as Label<out DbClass>
useAnonymousClass(parent).javaResult.id.cast<DbClass>()
} else {
useClassInstance(parent, parentClassTypeArguments).typeResult.id
}
@@ -447,8 +433,7 @@ open class KotlinFileExtractor(
tw.writeFieldsKotlinType(instance.id, type.kotlinResult.id)
tw.writeHasLocation(instance.id, innerLocId)
addModifiers(instance.id, "public", "static", "final")
@Suppress("UNCHECKED_CAST")
tw.writeType_companion_object(parentId, instance.id, innerId as Label<DbClass>)
tw.writeType_companion_object(parentId, instance.id, innerId.cast<DbClass>())
}
}
@@ -660,8 +645,7 @@ open class KotlinFileExtractor(
}
val allParamTypes = if (extReceiver != null) {
val extendedType = useType(extReceiver.type)
@Suppress("UNCHECKED_CAST")
tw.writeKtExtensionFunctions(id as Label<DbMethod>, extendedType.javaResult.id, extendedType.kotlinResult.id)
tw.writeKtExtensionFunctions(id.cast<DbMethod>(), extendedType.javaResult.id, extendedType.kotlinResult.id)
val t = extractValueParameter(extReceiver, id, 0, null, sourceDeclaration, classTypeArgsIncludingOuterClasses)
listOf(t) + paramTypes
@@ -680,15 +664,13 @@ open class KotlinFileExtractor(
typeSubstitution != null -> useType(substReturnType).javaResult.shortName
else -> f.returnType.classFqName?.shortName()?.asString() ?: f.name.asString()
}
@Suppress("UNCHECKED_CAST")
val constrId = id as Label<DbConstructor>
val constrId = id.cast<DbConstructor>()
tw.writeConstrs(constrId, shortName, "$shortName$paramsSignature", unitType.javaResult.id, parentId, sourceDeclaration as Label<DbConstructor>)
tw.writeConstrsKotlinType(constrId, unitType.kotlinResult.id)
} else {
val returnType = useType(substReturnType, TypeContext.RETURN)
val shortName = getFunctionShortName(f)
@Suppress("UNCHECKED_CAST")
val methodId = id as Label<DbMethod>
val methodId = id.cast<DbMethod>()
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", returnType.javaResult.id, parentId, sourceDeclaration as Label<DbMethod>)
tw.writeMethodsKotlinType(methodId, returnType.kotlinResult.id)
}
@@ -755,8 +737,7 @@ open class KotlinFileExtractor(
val setter = p.setter
if (getter != null) {
@Suppress("UNCHECKED_CAST")
val getterId = extractFunction(getter, parentId, extractBackingField, typeSubstitution, classTypeArgs) as Label<out DbMethod>?
val getterId = extractFunction(getter, parentId, extractBackingField, typeSubstitution, classTypeArgs)?.cast<DbMethod>()
if (getterId != null) {
tw.writeKtPropertyGetters(id, getterId)
}
@@ -770,8 +751,7 @@ open class KotlinFileExtractor(
if (!p.isVar) {
logger.errorElement("!isVar property with a setter", p)
}
@Suppress("UNCHECKED_CAST")
val setterId = extractFunction(setter, parentId, extractBackingField, typeSubstitution, classTypeArgs) as Label<out DbMethod>?
val setterId = extractFunction(setter, parentId, extractBackingField, typeSubstitution, classTypeArgs)?.cast<DbMethod>()
if (setterId != null) {
tw.writeKtPropertySetters(id, setterId)
}
@@ -1937,8 +1917,7 @@ open class KotlinFileExtractor(
val id = extractNewExpr(e.symbol.owner, (e.type as? IrSimpleType)?.arguments, type, locId, parent, idx, callable, enclosingStmt)
if (isAnonymous) {
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(type.javaResult.id as Label<DbClass>, id)
tw.writeIsAnonymClass(type.javaResult.id.cast<DbClass>(), id)
}
extractCallValueArguments(id, e, enclosingStmt, callable, 0)
@@ -2147,8 +2126,7 @@ open class KotlinFileExtractor(
val methodId = useFunction<DbConstructor>(e.symbol.owner)
tw.writeHasLocation(id, locId)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(id as Label<DbCaller>, methodId)
tw.writeCallableBinding(id.cast<DbCaller>(), methodId)
extractCallValueArguments(id, e, id, callable, 0)
val dr = e.dispatchReceiver
if (dr != null) {
@@ -2919,8 +2897,7 @@ open class KotlinFileExtractor(
writeExpressionMetadataToTrapFile(callId, labels.methodId, retId)
val callableId = useFunction<DbCallable>(target.owner.realOverrideTarget, classTypeArgsIncludingOuterClasses)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, callableId)
tw.writeCallableBinding(callId.cast<DbCaller>(), callableId)
val useFirstArgAsDispatch: Boolean
if (dispatchReceiver != null) {
@@ -3399,8 +3376,7 @@ open class KotlinFileExtractor(
locId: Label<DbLocation>,
parameters: List<IrValueParameter>
) {
@Suppress("UNCHECKED_CAST")
val funLabels = addFunctionNInvoke(tw.getFreshIdLabel(), lambda.returnType, ids.type.javaResult.id as Label<DbReftype>, locId)
val funLabels = addFunctionNInvoke(tw.getFreshIdLabel(), lambda.returnType, ids.type.javaResult.id.cast<DbReftype>(), locId)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
@@ -3911,8 +3887,7 @@ open class KotlinFileExtractor(
val idNewexpr = extractNewExpr(ids.constructor, ids.type, locId, id, 1, callable, enclosingStmt)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
tw.writeIsAnonymClass(ids.type.javaResult.id.cast<DbClass>(), idNewexpr)
extractTypeAccessRecursive(e.typeOperand, locId, idNewexpr, -3, callable, enclosingStmt)
@@ -3964,8 +3939,7 @@ open class KotlinFileExtractor(
currentDeclaration: IrDeclaration
): Label<out DbClass> {
// Write class
@Suppress("UNCHECKED_CAST")
val id = ids.type.javaResult.id as Label<out DbClass>
val id = ids.type.javaResult.id.cast<DbClass>()
val pkgId = extractPackage("")
tw.writeClasses(id, "", pkgId, id)
tw.writeHasLocation(id, locId)
@@ -3990,8 +3964,7 @@ open class KotlinFileExtractor(
val baseConstructorId = useFunction<DbConstructor>(baseConstructor as IrFunction)
tw.writeHasLocation(superCallId, locId)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(superCallId as Label<DbCaller>, baseConstructorId)
tw.writeCallableBinding(superCallId.cast<DbCaller>(), baseConstructorId)
addModifiers(id, "final")
addVisibilityModifierToLocalOrAnonymousClass(id)

View File

@@ -838,8 +838,7 @@ open class KotlinUsesExtractor(
fun <T: DbCallable> useFunction(f: IrFunction, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>? = null): Label<out T> {
if (f.isLocalFunction()) {
val ids = getLocallyVisibleFunctionLabels(f)
@Suppress("UNCHECKED_CAST")
return ids.function as Label<out T>
return ids.function.cast<T>()
} else {
return useFunctionCommon<T>(f, getFunctionLabel(f, classTypeArgsIncludingOuterClasses))
}
@@ -864,14 +863,12 @@ open class KotlinUsesExtractor(
// Note this function doesn't return a signature because type arguments are never incorporated into function signatures.
return when (arg) {
is IrStarProjection -> {
@Suppress("UNCHECKED_CAST")
val anyTypeLabel = useType(pluginContext.irBuiltIns.anyType).javaResult.id as Label<out DbReftype>
val anyTypeLabel = useType(pluginContext.irBuiltIns.anyType).javaResult.id.cast<DbReftype>()
TypeResult(extractBoundedWildcard(1, "@\"wildcard;\"", "?", anyTypeLabel), null, "?")
}
is IrTypeProjection -> {
val boundResults = useType(arg.type, TypeContext.GENERIC_ARGUMENT)
@Suppress("UNCHECKED_CAST")
val boundLabel = boundResults.javaResult.id as Label<out DbReftype>
val boundLabel = boundResults.javaResult.id.cast<DbReftype>()
return if(arg.variance == Variance.INVARIANT)
@Suppress("UNCHECKED_CAST")
@@ -949,8 +946,7 @@ open class KotlinUsesExtractor(
fun useClassSource(c: IrClass): Label<out DbClassorinterface> {
if (c.isAnonymousObject) {
@Suppress("UNCHECKED_CAST")
return useAnonymousClass(c).javaResult.id as Label<DbClass>
return useAnonymousClass(c).javaResult.id.cast<DbClass>()
}
// For source classes, the label doesn't include and type arguments

View File

@@ -6,7 +6,12 @@ import java.io.StringWriter
/**
* This represents a label (`#...`) in a TRAP file.
*/
interface Label<T>
interface Label<T> {
fun <U> cast(): Label<U> {
@Suppress("UNCHECKED_CAST")
return this as Label<U>
}
}
/**
* The label `#i`, e.g. `#123`. Most labels we generate are of this

View File

@@ -57,8 +57,7 @@ open class TrapWriter (protected val loggerBase: LoggerBase, val lm: TrapLabelMa
* initialised.
*/
fun <T> getExistingLabelFor(key: String): Label<T>? {
@Suppress("UNCHECKED_CAST")
return lm.labelMapping.get(key) as Label<T>?
return lm.labelMapping.get(key)?.cast<T>()
}
/**
* Returns the label for the given key, if one exists.