Compare commits

..

3 Commits

Author SHA1 Message Date
Taus
cdd557f877 Python: hotfix - disable instanceFieldStep to avoid type-tracker blowup
The `instanceFieldStep` disjunct of `TypeTrackingInput::levelStepCall`
that was added in 7.2.0 uses `classInstanceTracker(cls)` -- which is
itself a type-tracker -- inside `levelStepCall`. That creates a
structural mutual recursion between the main type-tracker fixpoint and
`classInstanceTracker`, causing the type-tracker delta to blow up to
~100M tuples per iteration on some OOP-heavy Python codebases.
Verified on the python/mypy database: SSRF query wall time goes from
~12s before the offending commit to >40 minutes after it.

This hotfix temporarily drops the `instanceFieldStep` disjunct and
keeps only `inheritedFieldStep`, which does not pull on the call
graph and is well-behaved (verified at ~12s on mypy). The
`instanceFieldStep` helper predicate itself is kept in place, and
the `levelStepCall` body has a commented-out call to it so the
change is trivial to re-enable once the recursion issue is properly
addressed.
2026-07-01 14:31:00 +01:00
Tom Hvitved
2bf6031c0f Python: Update inline test expectations 2026-07-01 13:10:41 +02:00
Tom Hvitved
a5444b573a Python: Improve some flow summaries 2026-07-01 12:05:53 +02:00
38 changed files with 857 additions and 1162 deletions

View File

@@ -28,6 +28,7 @@
/swift/extractor/ @github/codeql-swift @github/code-scanning-language-coverage /swift/extractor/ @github/codeql-swift @github/code-scanning-language-coverage
/misc/codegen/ @github/codeql-swift /misc/codegen/ @github/codeql-swift
/java/kotlin-extractor/ @github/codeql-kotlin @github/code-scanning-language-coverage /java/kotlin-extractor/ @github/codeql-kotlin @github/code-scanning-language-coverage
/java/ql/test-kotlin1/ @github/codeql-kotlin
/java/ql/test-kotlin2/ @github/codeql-kotlin /java/ql/test-kotlin2/ @github/codeql-kotlin
# Experimental CodeQL cryptography # Experimental CodeQL cryptography

View File

@@ -6,8 +6,6 @@ import com.github.codeql.utils.*
import com.github.codeql.utils.versions.* import com.github.codeql.utils.versions.*
import com.semmle.extractor.java.OdasaOutput import com.semmle.extractor.java.OdasaOutput
import java.io.Closeable import java.io.Closeable
import java.nio.file.Files
import java.nio.file.Path
import java.util.* import java.util.*
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
@@ -52,7 +50,6 @@ import org.jetbrains.kotlin.load.java.structure.JavaMethod
import org.jetbrains.kotlin.load.java.structure.JavaTypeParameter import org.jetbrains.kotlin.load.java.structure.JavaTypeParameter
import org.jetbrains.kotlin.load.java.structure.JavaTypeParameterListOwner import org.jetbrains.kotlin.load.java.structure.JavaTypeParameterListOwner
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaClass import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaClass
import org.jetbrains.kotlin.fir.java.VirtualFileBasedSourceElement
import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.types.Variance import org.jetbrains.kotlin.types.Variance
import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.util.OperatorNameConventions
@@ -164,60 +161,11 @@ open class KotlinFileExtractor(
} }
} }
private fun javaBinaryDeclaresMethod(c: IrClass, name: String): Boolean? { private fun javaBinaryDeclaresMethod(c: IrClass, name: String) =
// K1 path: source is JavaSourceElement wrapping a BinaryJavaClass - inspect class metadata ((c.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass)?.methods?.any {
val binaryJavaClass = (c.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass it.name.asString() == name
if (binaryJavaClass != null) {
return binaryJavaClass.methods.any { it.name.asString() == name }
} }
// K2 path: binary Java classes use VirtualFileBasedSourceElement instead of
// JavaSourceElement. The BinaryJavaClass is not stored in the source element, so we parse
// the class bytes directly using ASM to check if the method is explicitly declared.
val virtualFile = (c.source as? VirtualFileBasedSourceElement)?.virtualFile
if (virtualFile != null) {
if (!virtualFile.name.endsWith(".class")) return null
return try {
val bytes = virtualFile.contentsToByteArray()
var found = false
var hasKotlinMetadata = false
val reader = org.jetbrains.org.objectweb.asm.ClassReader(bytes)
reader.accept(
object : org.jetbrains.org.objectweb.asm.ClassVisitor(
org.jetbrains.org.objectweb.asm.Opcodes.ASM9
) {
override fun visitAnnotation(
descriptor: String,
visible: Boolean
): org.jetbrains.org.objectweb.asm.AnnotationVisitor? {
if (descriptor == "Lkotlin/Metadata;") hasKotlinMetadata = true
return null
}
override fun visitMethod(
access: Int,
methodName: String,
descriptor: String,
signature: String?,
exceptions: Array<String>?
): org.jetbrains.org.objectweb.asm.MethodVisitor? {
if (methodName == name) found = true
return null
}
},
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
)
if (hasKotlinMetadata) false else found
} catch (e: Exception) {
logger.warn("Failed to check binary class methods for ${c.fqNameWhenAvailable}: $e")
null
}
}
return null
}
private fun isJavaBinaryDeclaration(f: IrFunction) = private fun isJavaBinaryDeclaration(f: IrFunction) =
f.parentClassOrNull?.let { javaBinaryDeclaresMethod(it, f.name.asString()) } ?: false f.parentClassOrNull?.let { javaBinaryDeclaresMethod(it, f.name.asString()) } ?: false
@@ -227,14 +175,7 @@ open class KotlinFileExtractor(
when (d.name.asString()) { when (d.name.asString()) {
"toString" -> d.codeQlValueParameters.isEmpty() "toString" -> d.codeQlValueParameters.isEmpty()
"hashCode" -> d.codeQlValueParameters.isEmpty() "hashCode" -> d.codeQlValueParameters.isEmpty()
// Under K2 (language version 2.0+), the Object.equals(Object) parameter is "equals" -> d.codeQlValueParameters.singleOrNull()?.type?.isNullableAny() ?: false
// typed as Any (non-nullable) rather than Any? (nullable). Under K1 it is Any?.
// Accept both so the redeclaration is recovered consistently across compilers.
"equals" ->
d.codeQlValueParameters
.singleOrNull()
?.type
?.let { it.isNullableAny() || it.isAny() } ?: false
else -> false else -> false
} && isJavaBinaryDeclaration(d) } && isJavaBinaryDeclaration(d)
else -> false else -> false
@@ -1371,28 +1312,27 @@ open class KotlinFileExtractor(
): TypeResults { ): TypeResults {
with("value parameter", vp) { with("value parameter", vp) {
val location = locOverride ?: getLocation(vp, classTypeArgsIncludingOuterClasses) val location = locOverride ?: getLocation(vp, classTypeArgsIncludingOuterClasses)
val parentFunction = vp.parent as? IrFunction
val javaCallable = parentFunction?.let { getJavaCallable(it) }
val maybeAlteredType = val maybeAlteredType =
parentFunction?.let { (vp.parent as? IrFunction)?.let {
if (overridesCollectionsMethodWithAlteredParameterTypes(it)) if (overridesCollectionsMethodWithAlteredParameterTypes(it))
eraseCollectionsMethodParameterType(vp.type, it.name.asString(), idx) eraseCollectionsMethodParameterType(vp.type, it.name.asString(), idx)
else if ( else if (
(parentFunction as? IrConstructor)?.parentClassOrNull?.kind == (vp.parent as? IrConstructor)?.parentClassOrNull?.kind ==
ClassKind.ANNOTATION_CLASS ClassKind.ANNOTATION_CLASS
) )
kClassToJavaClass(vp.type) kClassToJavaClass(vp.type)
else null else null
} ?: vp.type } ?: vp.type
val javaType = javaCallable?.let { jCallable -> getJavaValueParameterType(jCallable, idx) } val javaType =
val addParameterWildcardsByDefault = (vp.parent as? IrFunction)?.let {
!getInnermostWildcardSupppressionAnnotation(vp) && getJavaCallable(it)?.let { jCallable ->
!(javaCallable == null && getJavaValueParameterType(jCallable, idx)
parentFunction?.origin == IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB) }
}
val typeWithWildcards = val typeWithWildcards =
addJavaLoweringWildcards( addJavaLoweringWildcards(
maybeAlteredType, maybeAlteredType,
addParameterWildcardsByDefault, !getInnermostWildcardSupppressionAnnotation(vp),
javaType javaType
) )
val substitutedType = val substitutedType =
@@ -1406,9 +1346,9 @@ open class KotlinFileExtractor(
vp.origin == IrDeclarationOrigin.UNDERSCORE_PARAMETER || vp.origin == IrDeclarationOrigin.UNDERSCORE_PARAMETER ||
((vp.parent as? IrFunction)?.let { hasSynthesizedParameterNames(it) } ?: true) ((vp.parent as? IrFunction)?.let { hasSynthesizedParameterNames(it) } ?: true)
val javaParameter = val javaParameter =
when (javaCallable) { when (val callable = (vp.parent as? IrFunction)?.let { getJavaCallable(it) }) {
is JavaConstructor -> javaCallable.valueParameters.getOrNull(idx) is JavaConstructor -> callable.valueParameters.getOrNull(idx)
is JavaMethod -> javaCallable.valueParameters.getOrNull(idx) is JavaMethod -> callable.valueParameters.getOrNull(idx)
else -> null else -> null
} }
val extraAnnotations = val extraAnnotations =
@@ -2934,52 +2874,6 @@ open class KotlinFileExtractor(
return v return v
} }
private val sourceTextCache = mutableMapOf<String, String?>()
private fun getCurrentFileSourceText() =
sourceTextCache.getOrPut(filePath) {
runCatching { Files.readString(Path.of(filePath)) }.getOrNull()
}
private fun getVariableNameLocation(v: IrVariable): Label<DbLocation>? {
if (v.startOffset < 0 || v.endOffset < v.startOffset) return null
val source = getCurrentFileSourceText() ?: return null
if (v.startOffset >= source.length) return null
val name = v.name.asString()
if (name.isEmpty()) return null
val endExclusive = minOf(v.endOffset + 1, source.length)
val declarationText = source.substring(v.startOffset, endExclusive)
val nameOffsetInDeclaration = declarationText.indexOf(name)
if (nameOffsetInDeclaration < 0) return null
val nameStartOffset = v.startOffset + nameOffsetInDeclaration
// getLocation treats the end offset as exclusive (matching IR's getEndOffset), so the
// identifier span is [nameStartOffset, nameStartOffset + name.length).
val nameEndOffset = nameStartOffset + name.length
return tw.getLocation(nameStartOffset, nameEndOffset)
}
private fun shouldUseVariableNameLocation(v: IrVariable): Boolean {
// For a variable initialised by an IMPLICIT_NOTNULL coercion (a platform-type not-null
// assertion), the K2 frontend widens the IrVariable span to cover the coercion, which would
// shift the location away from the identifier. Anchor those to the name token instead.
// Variables without this coercion keep the location-provider span, which already points at
// the identifier.
val initializer = v.initializer
return initializer is IrTypeOperatorCall && initializer.operator == IrTypeOperator.IMPLICIT_NOTNULL
}
private fun getVariableLocation(v: IrVariable): Label<DbLocation> {
if (shouldUseVariableNameLocation(v)) {
val nameLocation = getVariableNameLocation(v)
if (nameLocation != null) return nameLocation
}
return tw.getLocation(getVariableLocationProvider(v))
}
private fun extractVariable( private fun extractVariable(
v: IrVariable, v: IrVariable,
callable: Label<out DbCallable>, callable: Label<out DbCallable>,
@@ -2988,7 +2882,7 @@ open class KotlinFileExtractor(
) { ) {
with("variable", v) { with("variable", v) {
val stmtId = tw.getFreshIdLabel<DbLocalvariabledeclstmt>() val stmtId = tw.getFreshIdLabel<DbLocalvariabledeclstmt>()
val locId = getVariableLocation(v) val locId = tw.getLocation(getVariableLocationProvider(v))
tw.writeStmts_localvariabledeclstmt(stmtId, parent, idx, callable) tw.writeStmts_localvariabledeclstmt(stmtId, parent, idx, callable)
tw.writeHasLocation(stmtId, locId) tw.writeHasLocation(stmtId, locId)
extractVariableExpr(v, callable, stmtId, 1, stmtId) extractVariableExpr(v, callable, stmtId, 1, stmtId)
@@ -3006,7 +2900,7 @@ open class KotlinFileExtractor(
with("variable expr", v) { with("variable expr", v) {
val varId = useVariable(v) val varId = useVariable(v)
val exprId = tw.getFreshIdLabel<DbLocalvariabledeclexpr>() val exprId = tw.getFreshIdLabel<DbLocalvariabledeclexpr>()
val locId = getVariableLocation(v) val locId = tw.getLocation(getVariableLocationProvider(v))
val type = useType(v.type) val type = useType(v.type)
tw.writeLocalvars(varId, v.name.asString(), type.javaResult.id, exprId) tw.writeLocalvars(varId, v.name.asString(), type.javaResult.id, exprId)
tw.writeLocalvarsKotlinType(varId, type.kotlinResult.id) tw.writeLocalvarsKotlinType(varId, type.kotlinResult.id)
@@ -4172,28 +4066,6 @@ open class KotlinFileExtractor(
else -> false else -> false
} }
private fun getCallResultType(c: IrCall, syntacticCallTarget: IrFunction): IrType {
if (syntacticCallTarget.origin != IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB) {
return c.type
}
val primitiveInfo =
(c.type as? IrSimpleType)?.let { primitiveTypeMapping.getPrimitiveInfo(it) } ?: return c.type
val parentClass = syntacticCallTarget.parentClassOrNull ?: return c.type
val returnIsClassifier =
javaBinaryMethodReturnIsClassifierType(
parentClass,
getFunctionShortName(syntacticCallTarget).nameInDB,
syntacticCallTarget.codeQlValueParameters.size,
syntacticCallTarget is IrConstructor
)
return if (returnIsClassifier == true) {
primitiveInfo.javaClass.symbol.typeWith()
} else {
c.type
}
}
private fun isGenericArrayType(typeName: String) = private fun isGenericArrayType(typeName: String) =
when (typeName) { when (typeName) {
"Array" -> true "Array" -> true
@@ -4239,7 +4111,7 @@ open class KotlinFileExtractor(
extractRawMethodAccess( extractRawMethodAccess(
syntacticCallTarget, syntacticCallTarget,
c, c,
getCallResultType(c, syntacticCallTarget), c.type,
callable, callable,
parent, parent,
idx, idx,

View File

@@ -36,7 +36,6 @@ import org.jetbrains.kotlin.load.java.BuiltinMethodsWithSpecialGenericSignature
import org.jetbrains.kotlin.load.java.JvmAbi import org.jetbrains.kotlin.load.java.JvmAbi
import org.jetbrains.kotlin.load.java.sources.JavaSourceElement import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
import org.jetbrains.kotlin.load.java.structure.* import org.jetbrains.kotlin.load.java.structure.*
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaClass
import org.jetbrains.kotlin.load.java.typeEnhancement.hasEnhancedNullability import org.jetbrains.kotlin.load.java.typeEnhancement.hasEnhancedNullability
import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.NameUtils import org.jetbrains.kotlin.name.NameUtils
@@ -997,20 +996,7 @@ open class KotlinUsesExtractor(
) )
return null return null
} }
val fileClassId = extractFileClass(fqName) return extractFileClass(fqName)
// Under K2, external file class members sit directly under IrExternalPackageFragment
// rather than under their IrClass parent. In that case the file class entity won't
// get a location set through the normal extractClassSource path.
if (d is IrMemberWithContainerSource && tw.lm.externalFileClassLocationsExtracted.add(fqName)) {
val binaryPath =
getContainerSourceBinaryPath(d.containerSource)
?.let { normalizeExternalFileClassBinaryPath(it, fqName) }
if (binaryPath != null && shouldUseConcreteExternalFileClassLocation(binaryPath)) {
val fileId = tw.mkFileId(binaryPath, true)
tw.writeHasLocation(fileClassId, tw.getWholeFileLocation(fileId))
}
}
return fileClassId
} }
return useDeclarationParent(parent, canBeTopLevel, classTypeArguments, inReceiverContext) return useDeclarationParent(parent, canBeTopLevel, classTypeArguments, inReceiverContext)
} }
@@ -1385,13 +1371,8 @@ open class KotlinUsesExtractor(
parentId: Label<out DbElement>, parentId: Label<out DbElement>,
classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?,
maybeParameterList: List<IrValueParameter>? = null maybeParameterList: List<IrValueParameter>? = null
): String { ): String =
val javaCallable = getJavaCallable(f) getFunctionLabel(
val addParameterWildcardsByDefault =
!getInnermostWildcardSupppressionAnnotation(f) &&
!(javaCallable == null && f.origin == IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB)
return getFunctionLabel(
f.parent, f.parent,
parentId, parentId,
getFunctionShortName(f).nameInDB, getFunctionShortName(f).nameInDB,
@@ -1401,10 +1382,9 @@ open class KotlinUsesExtractor(
getFunctionTypeParameters(f), getFunctionTypeParameters(f),
classTypeArgsIncludingOuterClasses, classTypeArgsIncludingOuterClasses,
overridesCollectionsMethodWithAlteredParameterTypes(f), overridesCollectionsMethodWithAlteredParameterTypes(f),
javaCallable, getJavaCallable(f),
addParameterWildcardsByDefault !getInnermostWildcardSupppressionAnnotation(f)
) )
}
/* /*
* This function actually generates the label for a function. * This function actually generates the label for a function.
@@ -1491,41 +1471,15 @@ open class KotlinUsesExtractor(
// Finally, mimic the Java extractor's behaviour by naming functions with type // Finally, mimic the Java extractor's behaviour by naming functions with type
// parameters for their erased types; // parameters for their erased types;
// those without type parameters are named for the generic type. // those without type parameters are named for the generic type.
var maybeErased = val maybeErased =
if (functionTypeParameters.isEmpty()) maybeSubbed else erase(maybeSubbed) if (functionTypeParameters.isEmpty()) maybeSubbed else erase(maybeSubbed)
// K2 compatibility: under K2, Java @NotNull reference types such as @NotNull Integer
// are enhanced to Kotlin primitives (e.g. kotlin.Int). But the Java extractor uses
// the original reference type (java.lang.Integer) in callable labels. When we detect
// that the original Java parameter type is a reference (classifier) type but the
// Kotlin IR type is a primitive, revert to the boxed Java class so both extractors
// produce matching callable IDs.
if (functionTypeParameters.isEmpty()) {
val primitiveInfo = (maybeErased as? IrSimpleType)?.let {
primitiveTypeMapping.getPrimitiveInfo(it)
}
if (primitiveInfo != null) {
val parentClass = parent as? IrClass
if (parentClass != null) {
val isClassifierType = javaBinaryMethodParamIsClassifierType(
parentClass,
name,
allParamTypes.size,
name == "<init>",
it.index
)
if (isClassifierType == true) {
maybeErased = primitiveInfo.javaClass.symbol.typeWith()
}
}
}
}
"{${useType(maybeErased).javaResult.id}}" "{${useType(maybeErased).javaResult.id}}"
} }
val paramTypeIds = val paramTypeIds =
allParamTypes allParamTypes
.withIndex() .withIndex()
.joinToString(separator = ",", transform = getIdForFunctionLabel) .joinToString(separator = ",", transform = getIdForFunctionLabel)
var labelReturnType = val labelReturnType =
if (name == "<init>") pluginContext.irBuiltIns.unitType if (name == "<init>") pluginContext.irBuiltIns.unitType
else else
erase( erase(
@@ -1535,28 +1489,6 @@ open class KotlinUsesExtractor(
pluginContext pluginContext
) )
) )
// K2 compatibility: same as for parameters, if the Java binary method return type is a
// reference type but K2 enhanced it to a Kotlin primitive, use the boxed Java class.
if (functionTypeParameters.isEmpty() && name != "<init>") {
val primitiveInfo = (labelReturnType as? IrSimpleType)?.let {
primitiveTypeMapping.getPrimitiveInfo(it)
}
if (primitiveInfo != null) {
val parentClass = parent as? IrClass
if (parentClass != null) {
val returnIsClassifier =
javaBinaryMethodReturnIsClassifierType(
parentClass,
name,
allParamTypes.size,
false
)
if (returnIsClassifier == true) {
labelReturnType = primitiveInfo.javaClass.symbol.typeWith()
}
}
}
}
// Note that `addJavaLoweringWildcards` is not required here because the return type used to // Note that `addJavaLoweringWildcards` is not required here because the return type used to
// form the function // form the function
// label is always erased. // label is always erased.
@@ -1662,23 +1594,9 @@ open class KotlinUsesExtractor(
} }
@OptIn(ObsoleteDescriptorBasedAPI::class) @OptIn(ObsoleteDescriptorBasedAPI::class)
fun getJavaCallable(f: IrFunction): JavaMember? { fun getJavaCallable(f: IrFunction) =
val fromDescriptor = (f.descriptor.source as? JavaSourceElement)?.javaElement as? JavaMember (f.descriptor.source as? JavaSourceElement)?.javaElement as? JavaMember
if (fromDescriptor != null) return fromDescriptor
// K2 fallback: under K2, descriptor.source may not carry JavaSourceElement for binary Java
// methods. Try to get the JavaMember from the parent class's binary class directly.
val parentClass = f.parentClassOrNull ?: return null
val binaryJavaClass = (parentClass.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass
?: return null
val name = getFunctionShortName(f).nameInDB
val nParams = f.codeQlValueParameters.size
return if (f is IrConstructor) {
binaryJavaClass.constructors.find { it.valueParameters.size == nParams }
} else {
binaryJavaClass.methods.find { it.name.asString() == name && it.valueParameters.size == nParams }
}
}
fun getJavaValueParameterType(m: JavaMember, idx: Int) = fun getJavaValueParameterType(m: JavaMember, idx: Int) =
when (m) { when (m) {
is JavaMethod -> m.valueParameters[idx].type is JavaMethod -> m.valueParameters[idx].type

View File

@@ -51,13 +51,6 @@ class TrapLabelManager {
* to avoid duplication. * to avoid duplication.
*/ */
val fileClassLocationsExtracted = HashSet<IrFile>() val fileClassLocationsExtracted = HashSet<IrFile>()
/**
* Tracks external file classes (by FqName) whose location has been set from a binary path.
* Used to avoid writing duplicate hasLocation facts for external file class entities extracted
* through the K2 code path where declarations sit directly under IrExternalPackageFragment.
*/
val externalFileClassLocationsExtracted = HashSet<org.jetbrains.kotlin.name.FqName>()
} }
/** /**

View File

@@ -17,7 +17,6 @@ import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement
import org.jetbrains.kotlin.load.kotlin.VirtualFileKotlinClass import org.jetbrains.kotlin.load.kotlin.VirtualFileKotlinClass
import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource
// Adapted from Kotlin's interpreter/Utils.kt function 'internalName' // Adapted from Kotlin's interpreter/Utils.kt function 'internalName'
// Translates class names into their JLS section 13.1 binary name, // Translates class names into their JLS section 13.1 binary name,
@@ -177,238 +176,15 @@ fun getIrDeclarationBinaryPath(d: IrDeclaration): String? {
// This is in a file class. // This is in a file class.
val fqName = getFileClassFqName(d) val fqName = getFileClassFqName(d)
if (fqName != null) { if (fqName != null) {
if (d is IrMemberWithContainerSource) {
val containerBinaryPath = getContainerSourceBinaryPath(d.containerSource)
if (containerBinaryPath != null) {
return normalizeExternalFileClassBinaryPath(containerBinaryPath, fqName)
}
}
return getUnknownBinaryLocation(fqName.asString()) return getUnknownBinaryLocation(fqName.asString())
} }
} }
return null return null
} }
/**
* Attempts to get the binary file path from a container source (typically a
* [JvmPackagePartSource]). Returns null if the path is unavailable.
*/
fun getContainerSourceBinaryPath(containerSource: org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource?): String? {
if (containerSource !is JvmPackagePartSource) return null
val binaryClass = containerSource.knownJvmBinaryClass ?: return null
return when (binaryClass) {
is VirtualFileKotlinClass -> {
val vf = binaryClass.file
val path = vf.path
if (vf.fileSystem.protocol == StandardFileSystems.JRT_PROTOCOL)
"/${path.split("!/", limit = 2)[1]}"
else path
}
else -> binaryClass.location.takeIf { it.isNotEmpty() }
}
}
private fun getUnknownBinaryLocation(s: String): String { private fun getUnknownBinaryLocation(s: String): String {
return "/!unknown-binary-location/${s.replace(".", "/")}.class" return "/!unknown-binary-location/${s.replace(".", "/")}.class"
} }
fun normalizeExternalFileClassBinaryPath(path: String, fqName: FqName): String {
if (path.contains(".kotlinc_installed")) {
return getUnknownBinaryLocation(fqName.asString())
}
val normalizedPath = path.replace('\\', '/')
val classInternalPath = "${fqName.asString().replace(".", "/")}.class"
val classSuffix = "/$classInternalPath"
if (normalizedPath.endsWith(classSuffix)) {
val classpathRoot = normalizedPath.removeSuffix(classSuffix).substringAfterLast('/')
if (classpathRoot.isNotEmpty()) {
return "$classpathRoot/$classInternalPath"
}
}
return path
}
fun shouldUseConcreteExternalFileClassLocation(path: String): Boolean {
val normalizedPath = path.replace('\\', '/')
return normalizedPath.contains("/") &&
!normalizedPath.startsWith("/!unknown-binary-location/")
}
fun getJavaEquivalentClassId(c: IrClass) = fun getJavaEquivalentClassId(c: IrClass) =
c.fqNameWhenAvailable?.toUnsafe()?.let { JavaToKotlinClassMap.mapKotlinToJava(it) } c.fqNameWhenAvailable?.toUnsafe()?.let { JavaToKotlinClassMap.mapKotlinToJava(it) }
/**
* Checks whether a specific parameter of a Java binary method (identified by [methodName] and
* [paramIndex]) is a reference type (as opposed to a Java primitive). This is used to detect
* cases where K2 FIR has enhanced a reference type parameter (e.g. `@NotNull Integer`) to a
* Kotlin primitive (e.g. `kotlin.Int`), so that callable labels can use the original reference
* type and remain compatible with the Java extractor's callable IDs.
*
* Under K1, binary Java classes use [JavaSourceElement] and we can check [BinaryJavaClass.methods]
* directly. Under K2, they use [VirtualFileBasedSourceElement] and we fall back to reading the
* class bytes with ASM.
*
* Returns `null` if the information cannot be determined.
*/
fun javaBinaryMethodParamIsClassifierType(
parentClass: IrClass,
methodName: String,
nParams: Int,
isConstructor: Boolean,
paramIndex: Int
): Boolean? {
// K1 path: binary Java class has JavaSourceElement with a BinaryJavaClass.
val k1ParamKinds =
((parentClass.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass)?.let {
binaryJavaClass ->
if (isConstructor)
binaryJavaClass.constructors
.asSequence()
.filter { it.valueParameters.size == nParams }
.mapNotNull { it.valueParameters.getOrNull(paramIndex)?.type }
.map { it is org.jetbrains.kotlin.load.java.structure.JavaClassifierType }
.toSet()
else
binaryJavaClass.methods
.asSequence()
.filter { it.name.asString() == methodName && it.valueParameters.size == nParams }
.mapNotNull { it.valueParameters.getOrNull(paramIndex)?.type }
.map { it is org.jetbrains.kotlin.load.java.structure.JavaClassifierType }
.toSet()
}
if (k1ParamKinds != null && k1ParamKinds.isNotEmpty()) {
return k1ParamKinds.singleOrNull()
}
// K2 path: binary Java class has VirtualFileBasedSourceElement
val k2Source = parentClass.source as? VirtualFileBasedSourceElement ?: return null
val vf = k2Source.virtualFile
if (!vf.name.endsWith(".class")) return null
return try {
val bytes = vf.contentsToByteArray()
val expectedMethodName = if (isConstructor) "<init>" else methodName
val descriptorKinds = mutableSetOf<Boolean>()
val reader = org.jetbrains.org.objectweb.asm.ClassReader(bytes)
reader.accept(
object : org.jetbrains.org.objectweb.asm.ClassVisitor(
org.jetbrains.org.objectweb.asm.Opcodes.ASM9
) {
override fun visitMethod(
access: Int,
name: String,
descriptor: String,
signature: String?,
exceptions: Array<String>?
): org.jetbrains.org.objectweb.asm.MethodVisitor? {
if (name != expectedMethodName) return null
val paramDescriptors = parseAsmMethodDescriptorParams(descriptor)
if (paramDescriptors.size != nParams) return null
val paramDesc = paramDescriptors.getOrNull(paramIndex) ?: return null
// Reference types start with 'L' or '['; Java primitives are single chars
descriptorKinds.add(paramDesc.startsWith("L") || paramDesc.startsWith("["))
return null
}
},
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
)
descriptorKinds.singleOrNull()
} catch (e: Exception) {
null
}
}
/**
* Checks whether the return type of a Java binary method (identified by [methodName] and
* [nParams]) is a reference type (as opposed to a Java primitive).
*
* Returns `null` if the information cannot be determined.
*/
fun javaBinaryMethodReturnIsClassifierType(
parentClass: IrClass,
methodName: String,
nParams: Int,
isConstructor: Boolean
): Boolean? {
if (isConstructor) return false
// K1 path: binary Java class has JavaSourceElement with a BinaryJavaClass.
val k1ReturnKinds =
((parentClass.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass)?.methods
?.asSequence()
?.filter { it.name.asString() == methodName && it.valueParameters.size == nParams }
?.map { it.returnType is org.jetbrains.kotlin.load.java.structure.JavaClassifierType }
?.toSet()
if (k1ReturnKinds != null && k1ReturnKinds.isNotEmpty()) {
return k1ReturnKinds.singleOrNull()
}
// K2 path: binary Java class has VirtualFileBasedSourceElement
val k2Source = parentClass.source as? VirtualFileBasedSourceElement ?: return null
val vf = k2Source.virtualFile
if (!vf.name.endsWith(".class")) return null
return try {
val bytes = vf.contentsToByteArray()
val returnKinds = mutableSetOf<Boolean>()
val reader = org.jetbrains.org.objectweb.asm.ClassReader(bytes)
reader.accept(
object : org.jetbrains.org.objectweb.asm.ClassVisitor(
org.jetbrains.org.objectweb.asm.Opcodes.ASM9
) {
override fun visitMethod(
access: Int,
name: String,
descriptor: String,
signature: String?,
exceptions: Array<String>?
): org.jetbrains.org.objectweb.asm.MethodVisitor? {
if (name != methodName) return null
if (parseAsmMethodDescriptorParams(descriptor).size != nParams) return null
val returnDescriptor = descriptor.substring(descriptor.lastIndexOf(')') + 1)
returnKinds.add(
returnDescriptor.startsWith("L") || returnDescriptor.startsWith("[")
)
return null
}
},
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
)
returnKinds.singleOrNull()
} catch (e: Exception) {
null
}
}
private fun parseAsmMethodDescriptorParams(descriptor: String): List<String> {
val params = mutableListOf<String>()
var i = descriptor.indexOf('(') + 1
val end = descriptor.lastIndexOf(')')
while (i < end) {
when (val c = descriptor[i]) {
'L' -> {
val semi = descriptor.indexOf(';', i)
params.add(descriptor.substring(i, semi + 1))
i = semi + 1
}
'[' -> {
var j = i + 1
while (j < end && descriptor[j] == '[') j++
if (descriptor[j] == 'L') {
val semi = descriptor.indexOf(';', j)
params.add(descriptor.substring(i, semi + 1))
i = semi + 1
} else {
params.add(descriptor.substring(i, j + 1))
i = j + 1
}
}
else -> { params.add(c.toString()); i++ }
}
}
return params
}

View File

@@ -1,11 +1,11 @@
import pathlib import pathlib
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
java_srcs = " ".join([str(s) for s in pathlib.Path().glob("*.java")]) java_srcs = " ".join([str(s) for s in pathlib.Path().glob("*.java")])
codeql.database.create( codeql.database.create(
command=[ command=[
f"javac {java_srcs} -d build", f"javac {java_srcs} -d build",
"kotlinc -language-version 2.0 user.kt -cp build", "kotlinc -language-version 1.9 user.kt -cp build",
] ]
) )

View File

@@ -1,6 +1,6 @@
import commands import commands
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
commands.run("kotlinc -language-version 2.0 test.kt -d lib") commands.run("kotlinc -language-version 1.9 test.kt -d lib")
codeql.database.create(command="kotlinc -language-version 2.0 user.kt -cp lib") codeql.database.create(command="kotlinc -language-version 1.9 user.kt -cp lib")

View File

@@ -9,4 +9,4 @@
| Percentage of calls with call target | 100 | | Percentage of calls with call target | 100 |
| Total number of lines | 3 | | Total number of lines | 3 |
| Total number of lines with extension kt | 3 | | Total number of lines with extension kt | 3 |
| Uses Kotlin 2: true | 1 | | Uses Kotlin 2: false | 1 |

View File

@@ -1,2 +1,2 @@
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
codeql.database.create(command="kotlinc -J-Xmx2G -language-version 2.0 SomeClass.kt") codeql.database.create(command=f"kotlinc -J-Xmx2G -language-version 1.9 SomeClass.kt")

View File

@@ -1,6 +1,6 @@
import commands import commands
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
commands.run("kotlinc -language-version 2.0 A.kt") commands.run("kotlinc -language-version 1.9 A.kt")
codeql.database.create(command="kotlinc -cp . -language-version 2.0 B.kt C.kt") codeql.database.create(command="kotlinc -cp . -language-version 1.9 B.kt C.kt")

View File

@@ -1,6 +1,6 @@
import commands import commands
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
commands.run(["javac", "Test.java", "-d", "bin"]) commands.run(["javac", "Test.java", "-d", "bin"])
codeql.database.create(command="kotlinc -language-version 2.0 user.kt -cp bin") codeql.database.create(command="kotlinc -language-version 1.9 user.kt -cp bin")

View File

@@ -1,13 +1,13 @@
import commands import commands
def test(codeql, java_full): def test(codeql, java_full, kotlinc_2_3_20):
# Compile the JavaDefns2 copy outside tracing, to make sure the Kotlin view of it matches the Java view seen by the traced javac compilation of JavaDefns.java below. # Compile the JavaDefns2 copy outside tracing, to make sure the Kotlin view of it matches the Java view seen by the traced javac compilation of JavaDefns.java below.
commands.run(["javac", "JavaDefns2.java"]) commands.run(["javac", "JavaDefns2.java"])
codeql.database.create( codeql.database.create(
command=[ command=[
"kotlinc kotlindefns.kt", "kotlinc kotlindefns.kt",
"javac JavaUser.java JavaDefns.java -cp .", "javac JavaUser.java JavaDefns.java -cp .",
"kotlinc -language-version 2.0 -cp . kotlinuser.kt", "kotlinc -language-version 1.9 -cp . kotlinuser.kt",
] ]
) )

View File

@@ -0,0 +1,5 @@
---
category: minorAnalysis
---
- Temporarily disabled the `instanceFieldStep` disjunct of the internal `TypeTrackingInput::levelStepCall` predicate, which was introduced in 7.2.0 and caused catastrophic query slowdowns on some OOP-heavy Python codebases (e.g. `mypy` and `dask`).

View File

@@ -1138,7 +1138,9 @@ predicate clearsContent(Node n, ContentSet cs) {
* Holds if the value that is being tracked is expected to be stored inside content `c` * Holds if the value that is being tracked is expected to be stored inside content `c`
* at node `n`. * at node `n`.
*/ */
predicate expectsContent(Node n, ContentSet c) { none() } predicate expectsContent(Node n, ContentSet c) {
FlowSummaryImpl::Private::Steps::summaryExpectsContent(n.(FlowSummaryNode).getSummaryNode(), c)
}
/** /**
* Holds if values stored inside attribute `c` are cleared at node `n`. * Holds if values stored inside attribute `c` are cleared at node `n`.

View File

@@ -91,6 +91,8 @@ module Input implements InputSig<Location, DataFlowImplSpecific::PythonDataFlow>
cs.isAnyTupleOrDictionaryElement() and result = "AnyTupleOrDictionaryElement" and arg = "" cs.isAnyTupleOrDictionaryElement() and result = "AnyTupleOrDictionaryElement" and arg = ""
} }
string encodeWithContent(ContentSet c, string arg) { result = "With" + encodeContent(c, arg) }
bindingset[token] bindingset[token]
ParameterPosition decodeUnknownParameterPosition(AccessPath::AccessPathTokenBase token) { ParameterPosition decodeUnknownParameterPosition(AccessPath::AccessPathTokenBase token) {
// needed to support `Argument[x..y]` ranges // needed to support `Argument[x..y]` ranges

View File

@@ -170,7 +170,13 @@ module TypeTrackingInput implements Shared::TypeTrackingInput<Location> {
/** Holds if there is a level step from `nodeFrom` to `nodeTo`, which may depend on the call graph. */ /** Holds if there is a level step from `nodeFrom` to `nodeTo`, which may depend on the call graph. */
predicate levelStepCall(Node nodeFrom, LocalSourceNode nodeTo) { predicate levelStepCall(Node nodeFrom, LocalSourceNode nodeTo) {
instanceFieldStep(nodeFrom, nodeTo) // HOTFIX: `instanceFieldStep` is temporarily disabled (via `and none()`).
// It uses `classInstanceTracker(cls)` -- itself a type-tracker run --
// from inside `levelStepCall`, creating a structural mutual recursion
// that causes catastrophic query slowdowns on some OOP-heavy Python
// codebases (e.g. mypy and dask). The `and none()` should be removed
// once that recursion is redesigned.
instanceFieldStep(nodeFrom, nodeTo) and none()
or or
inheritedFieldStep(nodeFrom, nodeTo) inheritedFieldStep(nodeFrom, nodeTo)
} }

View File

@@ -4199,11 +4199,9 @@ module StdlibPrivate {
// The positional argument contains a mapping. // The positional argument contains a mapping.
// TODO: these values can be overwritten by keyword arguments // TODO: these values can be overwritten by keyword arguments
// - dict mapping // - dict mapping
exists(DataFlow::DictionaryElementContent dc, string key | key = dc.getKey() | input = "Argument[0].WithAnyDictionaryElement" and
input = "Argument[0].DictionaryElement[" + key + "]" and output = "ReturnValue" and
output = "ReturnValue.DictionaryElement[" + key + "]" and preservesValue = true
preservesValue = true
)
or or
// - list-of-pairs mapping // - list-of-pairs mapping
input = "Argument[0].ListElement.TupleElement[1]" and input = "Argument[0].ListElement.TupleElement[1]" and
@@ -4240,9 +4238,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
// Element content is mutated into list element content // Element content is mutated into list element content
@@ -4266,11 +4262,9 @@ module StdlibPrivate {
} }
override predicate propagatesFlow(string input, string output, boolean preservesValue) { override predicate propagatesFlow(string input, string output, boolean preservesValue) {
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].WithAnyTupleElement" and
input = "Argument[0].TupleElement[" + i.toString() + "]" and output = "ReturnValue" and
output = "ReturnValue.TupleElement[" + i.toString() + "]" and preservesValue = true
preservesValue = true
)
or or
input = "Argument[0].ListElement" and input = "Argument[0].ListElement" and
output = "ReturnValue" and output = "ReturnValue" and
@@ -4294,9 +4288,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue.SetElement" and output = "ReturnValue.SetElement" and
@@ -4342,9 +4334,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue.ListElement" and output = "ReturnValue.ListElement" and
@@ -4372,9 +4362,7 @@ module StdlibPrivate {
or or
content = "SetElement" content = "SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | content = "AnyTupleElement"
content = "TupleElement[" + i.toString() + "]"
)
| |
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
input = "Argument[0]." + content and input = "Argument[0]." + content and
@@ -4404,9 +4392,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue.ListElement" and output = "ReturnValue.ListElement" and
@@ -4434,9 +4420,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue" and output = "ReturnValue" and
@@ -4468,9 +4452,7 @@ module StdlibPrivate {
// We reduce generality slightly by not tracking tuple contents on list arguments beyond the first, for performance. // We reduce generality slightly by not tracking tuple contents on list arguments beyond the first, for performance.
// TODO: Once we have TupleElementAny, this generality can be increased. // TODO: Once we have TupleElementAny, this generality can be increased.
i = 0 and i = 0 and
exists(DataFlow::TupleElementContent tc, int j | j = tc.getIndex() | input = "Argument[1].AnyTupleElement"
input = "Argument[1].TupleElement[" + j.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "Argument[0].Parameter[" + i.toString() + "]" and output = "Argument[0].Parameter[" + i.toString() + "]" and
@@ -4499,9 +4481,7 @@ module StdlibPrivate {
or or
input = "Argument[1].SetElement" input = "Argument[1].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[1].AnyTupleElement"
input = "Argument[1].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
(output = "Argument[0].Parameter[0]" or output = "ReturnValue.ListElement") and (output = "Argument[0].Parameter[0]" or output = "ReturnValue.ListElement") and
@@ -4525,9 +4505,7 @@ module StdlibPrivate {
or or
input = "Argument[0].SetElement" input = "Argument[0].SetElement"
or or
exists(DataFlow::TupleElementContent tc, int i | i = tc.getIndex() | input = "Argument[0].AnyTupleElement"
input = "Argument[0].TupleElement[" + i.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue.ListElement.TupleElement[1]" and output = "ReturnValue.ListElement.TupleElement[1]" and
@@ -4552,12 +4530,7 @@ module StdlibPrivate {
or or
input = "Argument[" + i.toString() + "].SetElement" input = "Argument[" + i.toString() + "].SetElement"
or or
// We reduce generality slightly by not tracking tuple contents on arguments beyond the first two, for performance. input = "Argument[" + i.toString() + "].AnyTupleElement"
// TODO: Once we have TupleElementAny, this generality can be increased.
i in [0 .. 1] and
exists(DataFlow::TupleElementContent tc, int j | j = tc.getIndex() |
input = "Argument[" + i.toString() + "].TupleElement[" + j.toString() + "]"
)
// TODO: Once we have DictKeyContent, we need to transform that into ListElementContent // TODO: Once we have DictKeyContent, we need to transform that into ListElementContent
) and ) and
output = "ReturnValue.ListElement.TupleElement[" + i.toString() + "]" and output = "ReturnValue.ListElement.TupleElement[" + i.toString() + "]" and
@@ -4580,12 +4553,6 @@ module StdlibPrivate {
override DataFlow::ArgumentNode getACallback() { none() } override DataFlow::ArgumentNode getACallback() { none() }
override predicate propagatesFlow(string input, string output, boolean preservesValue) { override predicate propagatesFlow(string input, string output, boolean preservesValue) {
exists(DataFlow::Content c |
input = "Argument[self]." + c.getMaDRepresentation() and
output = "ReturnValue." + c.getMaDRepresentation() and
preservesValue = true
)
or
input = "Argument[self]" and input = "Argument[self]" and
output = "ReturnValue" and output = "ReturnValue" and
preservesValue = true preservesValue = true
@@ -4741,12 +4708,10 @@ module StdlibPrivate {
override DataFlow::ArgumentNode getACallback() { none() } override DataFlow::ArgumentNode getACallback() { none() }
override predicate propagatesFlow(string input, string output, boolean preservesValue) { override predicate propagatesFlow(string input, string output, boolean preservesValue) {
exists(DataFlow::DictionaryElementContent dc, string key | key = dc.getKey() | input = "Argument[self].AnyDictionaryElement" and
input = "Argument[self].DictionaryElement[" + key + "]" and output = "ReturnValue.TupleElement[1]" and
output = "ReturnValue.TupleElement[1]" and preservesValue = true
preservesValue = true // TODO: put `key` into "ReturnValue.TupleElement[0]"
// TODO: put `key` into "ReturnValue.TupleElement[0]"
)
} }
} }
@@ -4825,11 +4790,9 @@ module StdlibPrivate {
} }
override predicate propagatesFlow(string input, string output, boolean preservesValue) { override predicate propagatesFlow(string input, string output, boolean preservesValue) {
exists(DataFlow::DictionaryElementContent dc, string key | key = dc.getKey() | input = "Argument[self].AnyDictionaryElement" and
input = "Argument[self].DictionaryElement[" + key + "]" and output = "ReturnValue.ListElement" and
output = "ReturnValue.ListElement" and preservesValue = true
preservesValue = true
)
or or
input = "Argument[self]" and input = "Argument[self]" and
output = "ReturnValue" and output = "ReturnValue" and
@@ -4876,11 +4839,9 @@ module StdlibPrivate {
} }
override predicate propagatesFlow(string input, string output, boolean preservesValue) { override predicate propagatesFlow(string input, string output, boolean preservesValue) {
exists(DataFlow::DictionaryElementContent dc, string key | key = dc.getKey() | input = "Argument[self].AnyDictionaryElement" and
input = "Argument[self].DictionaryElement[" + key + "]" and output = "ReturnValue.ListElement.TupleElement[1]" and
output = "ReturnValue.ListElement.TupleElement[1]" and preservesValue = true
preservesValue = true
)
or or
// TODO: Add the keys to output list // TODO: Add the keys to output list
input = "Argument[self]" and input = "Argument[self]" and

View File

@@ -589,11 +589,11 @@ def test_zip_tuple():
SINK(z[0][0]) # $ flow="SOURCE, l:-7 -> z[0][0]" SINK(z[0][0]) # $ flow="SOURCE, l:-7 -> z[0][0]"
SINK(z[0][1]) # $ flow="SOURCE, l:-7 -> z[0][1]" SINK(z[0][1]) # $ flow="SOURCE, l:-7 -> z[0][1]"
SINK_F(z[0][2]) SINK_F(z[0][2]) # $ SPURIOUS: flow="SOURCE, l:-7 -> z[0][2]"
SINK_F(z[0][3]) SINK_F(z[0][3])
SINK(z[1][0]) # $ flow="SOURCE, l:-11 -> z[1][0]" SINK(z[1][0]) # $ flow="SOURCE, l:-11 -> z[1][0]"
SINK_F(z[1][1]) # $ SPURIOUS: flow="SOURCE, l:-11 -> z[1][1]" SINK_F(z[1][1]) # $ SPURIOUS: flow="SOURCE, l:-11 -> z[1][1]"
SINK(z[1][2]) # $ MISSING: flow="SOURCE, l:-11 -> z[1][2]" # Tuple contents are not tracked beyond the first two arguments for performance. SINK(z[1][2]) # $ flow="SOURCE, l:-11 -> z[1][2]"
SINK_F(z[1][3]) SINK_F(z[1][3])
@expects(4) @expects(4)

View File

@@ -157,7 +157,7 @@ class MyClass2(object):
print(self.foo) # $ tracked MISSING: tracked=foo print(self.foo) # $ tracked MISSING: tracked=foo
instance = MyClass2() instance = MyClass2()
print(instance.foo) # $ tracked MISSING: tracked=foo print(instance.foo) # $ MISSING: tracked=foo tracked
instance.print_foo() # $ MISSING: tracked=foo instance.print_foo() # $ MISSING: tracked=foo
@@ -195,7 +195,7 @@ class Sub1(Base1):
sub1 = Sub1() sub1 = Sub1()
sub1.read_foo() sub1.read_foo()
print(sub1.foo) # $ tracked MISSING: tracked=foo print(sub1.foo) # $ MISSING: tracked=foo tracked
# attribute written in a subclass method, read in an inherited base class method # attribute written in a subclass method, read in an inherited base class method
@@ -210,7 +210,7 @@ class Sub2(Base2):
sub2 = Sub2() sub2 = Sub2()
sub2.read_bar() sub2.read_bar()
print(sub2.bar) # $ tracked MISSING: tracked=bar print(sub2.bar) # $ MISSING: tracked=bar tracked
# attribute written in a base class method, read on an instance of the subclass # attribute written in a base class method, read on an instance of the subclass
@@ -223,4 +223,4 @@ class Sub3(Base3):
pass pass
sub3 = Sub3() sub3 = Sub3()
print(sub3.baz) # $ tracked MISSING: tracked=baz print(sub3.baz) # $ MISSING: tracked=baz tracked

View File

@@ -362,7 +362,7 @@ def test_load_in_bulk():
# see https://docs.djangoproject.com/en/4.0/ref/models/querysets/#in-bulk # see https://docs.djangoproject.com/en/4.0/ref/models/querysets/#in-bulk
d = TestLoad.objects.in_bulk([1]) d = TestLoad.objects.in_bulk([1])
for val in d.values(): for val in d.values():
SINK(val.text) # $ MISSING: flow SINK(val.text) # $ flow="SOURCE, l:-65 -> val.text"
SINK(d[1].text) # $ flow="SOURCE, l:-66 -> d[1].text" SINK(d[1].text) # $ flow="SOURCE, l:-66 -> d[1].text"

View File

@@ -1,7 +1,6 @@
#select #select
| app.py:23:20:23:24 | ControlFlowNode for query | app.py:20:18:20:21 | ControlFlowNode for name | app.py:23:20:23:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:20:18:20:21 | ControlFlowNode for name | user-provided value | | app.py:23:20:23:24 | ControlFlowNode for query | app.py:20:18:20:21 | ControlFlowNode for name | app.py:23:20:23:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:20:18:20:21 | ControlFlowNode for name | user-provided value |
| app.py:30:20:30:24 | ControlFlowNode for query | app.py:27:19:27:22 | ControlFlowNode for name | app.py:30:20:30:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:27:19:27:22 | ControlFlowNode for name | user-provided value | | app.py:30:20:30:24 | ControlFlowNode for query | app.py:27:19:27:22 | ControlFlowNode for name | app.py:30:20:30:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:27:19:27:22 | ControlFlowNode for name | user-provided value |
| app.py:37:20:37:24 | ControlFlowNode for query | app.py:34:19:34:22 | ControlFlowNode for name | app.py:37:20:37:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:34:19:34:22 | ControlFlowNode for name | user-provided value |
| app.py:44:20:44:24 | ControlFlowNode for query | app.py:41:19:41:22 | ControlFlowNode for name | app.py:44:20:44:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:41:19:41:22 | ControlFlowNode for name | user-provided value | | app.py:44:20:44:24 | ControlFlowNode for query | app.py:41:19:41:22 | ControlFlowNode for name | app.py:44:20:44:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:41:19:41:22 | ControlFlowNode for name | user-provided value |
| app.py:51:20:51:24 | ControlFlowNode for query | app.py:48:19:48:22 | ControlFlowNode for name | app.py:51:20:51:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:48:19:48:22 | ControlFlowNode for name | user-provided value | | app.py:51:20:51:24 | ControlFlowNode for query | app.py:48:19:48:22 | ControlFlowNode for name | app.py:51:20:51:24 | ControlFlowNode for query | This SQL query depends on a $@. | app.py:48:19:48:22 | ControlFlowNode for name | user-provided value |
| sql_injection.py:21:24:21:77 | ControlFlowNode for BinaryExpr | sql_injection.py:14:15:14:22 | ControlFlowNode for username | sql_injection.py:21:24:21:77 | ControlFlowNode for BinaryExpr | This SQL query depends on a $@. | sql_injection.py:14:15:14:22 | ControlFlowNode for username | user-provided value | | sql_injection.py:21:24:21:77 | ControlFlowNode for BinaryExpr | sql_injection.py:14:15:14:22 | ControlFlowNode for username | sql_injection.py:21:24:21:77 | ControlFlowNode for BinaryExpr | This SQL query depends on a $@. | sql_injection.py:14:15:14:22 | ControlFlowNode for username | user-provided value |
@@ -25,8 +24,6 @@ edges
| app.py:21:5:21:9 | ControlFlowNode for query | app.py:23:20:23:24 | ControlFlowNode for query | provenance | | | app.py:21:5:21:9 | ControlFlowNode for query | app.py:23:20:23:24 | ControlFlowNode for query | provenance | |
| app.py:27:19:27:22 | ControlFlowNode for name | app.py:28:5:28:9 | ControlFlowNode for query | provenance | | | app.py:27:19:27:22 | ControlFlowNode for name | app.py:28:5:28:9 | ControlFlowNode for query | provenance | |
| app.py:28:5:28:9 | ControlFlowNode for query | app.py:30:20:30:24 | ControlFlowNode for query | provenance | | | app.py:28:5:28:9 | ControlFlowNode for query | app.py:30:20:30:24 | ControlFlowNode for query | provenance | |
| app.py:34:19:34:22 | ControlFlowNode for name | app.py:35:5:35:9 | ControlFlowNode for query | provenance | |
| app.py:35:5:35:9 | ControlFlowNode for query | app.py:37:20:37:24 | ControlFlowNode for query | provenance | |
| app.py:41:19:41:22 | ControlFlowNode for name | app.py:42:5:42:9 | ControlFlowNode for query | provenance | | | app.py:41:19:41:22 | ControlFlowNode for name | app.py:42:5:42:9 | ControlFlowNode for query | provenance | |
| app.py:42:5:42:9 | ControlFlowNode for query | app.py:44:20:44:24 | ControlFlowNode for query | provenance | | | app.py:42:5:42:9 | ControlFlowNode for query | app.py:44:20:44:24 | ControlFlowNode for query | provenance | |
| app.py:48:19:48:22 | ControlFlowNode for name | app.py:49:5:49:9 | ControlFlowNode for query | provenance | | | app.py:48:19:48:22 | ControlFlowNode for name | app.py:49:5:49:9 | ControlFlowNode for query | provenance | |
@@ -54,9 +51,6 @@ nodes
| app.py:27:19:27:22 | ControlFlowNode for name | semmle.label | ControlFlowNode for name | | app.py:27:19:27:22 | ControlFlowNode for name | semmle.label | ControlFlowNode for name |
| app.py:28:5:28:9 | ControlFlowNode for query | semmle.label | ControlFlowNode for query | | app.py:28:5:28:9 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |
| app.py:30:20:30:24 | ControlFlowNode for query | semmle.label | ControlFlowNode for query | | app.py:30:20:30:24 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |
| app.py:34:19:34:22 | ControlFlowNode for name | semmle.label | ControlFlowNode for name |
| app.py:35:5:35:9 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |
| app.py:37:20:37:24 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |
| app.py:41:19:41:22 | ControlFlowNode for name | semmle.label | ControlFlowNode for name | | app.py:41:19:41:22 | ControlFlowNode for name | semmle.label | ControlFlowNode for name |
| app.py:42:5:42:9 | ControlFlowNode for query | semmle.label | ControlFlowNode for query | | app.py:42:5:42:9 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |
| app.py:44:20:44:24 | ControlFlowNode for query | semmle.label | ControlFlowNode for query | | app.py:44:20:44:24 | ControlFlowNode for query | semmle.label | ControlFlowNode for query |

View File

@@ -31,10 +31,10 @@ async def unsafe2(name: str): # $ Source
cursor.close() cursor.close()
@app.get("/unsafe3/") @app.get("/unsafe3/")
async def unsafe3(name: str): # $ Source async def unsafe3(name: str): # $ MISSING: Source
query = "select * from users where name=" + name query = "select * from users where name=" + name
cursor = hdb_con3.cursor() cursor = hdb_con3.cursor()
cursor.execute(query) # $ Alert cursor.execute(query) # $ MISSING: Alert
cursor.close() cursor.close()
@app.get("/unsafe4/") @app.get("/unsafe4/")

View File

@@ -28,6 +28,8 @@ nodes
| string_flow.rb:227:10:227:10 | a | semmle.label | a | | string_flow.rb:227:10:227:10 | a | semmle.label | a |
subpaths subpaths
testFailures testFailures
| string_flow.rb:85:10:85:10 | a | Unexpected result: hasValueFlow=a |
| string_flow.rb:227:10:227:10 | a | Unexpected result: hasValueFlow=a |
#select #select
| string_flow.rb:3:10:3:22 | call to new | string_flow.rb:2:9:2:18 | call to source | string_flow.rb:3:10:3:22 | call to new | $@ | string_flow.rb:2:9:2:18 | call to source | call to source | | string_flow.rb:3:10:3:22 | call to new | string_flow.rb:2:9:2:18 | call to source | string_flow.rb:3:10:3:22 | call to new | $@ | string_flow.rb:2:9:2:18 | call to source | call to source |
| string_flow.rb:85:10:85:10 | a | string_flow.rb:83:9:83:18 | call to source | string_flow.rb:85:10:85:10 | a | $@ | string_flow.rb:83:9:83:18 | call to source | call to source | | string_flow.rb:85:10:85:10 | a | string_flow.rb:83:9:83:18 | call to source | string_flow.rb:85:10:85:10 | a | $@ | string_flow.rb:83:9:83:18 | call to source | call to source |

View File

@@ -82,7 +82,7 @@ end
def m_clear def m_clear
a = source "a" a = source "a"
a.clear a.clear
sink a # $ SPURIOUS: hasValueFlow=a sink a
end end
# concat and prepend omitted because they clash with the summaries for # concat and prepend omitted because they clash with the summaries for
@@ -224,7 +224,7 @@ def m_replace
b = source "b" b = source "b"
sink a.replace(b) # $ hasTaintFlow=b sink a.replace(b) # $ hasTaintFlow=b
# TODO: currently we get value flow for a, because we don't clear content # TODO: currently we get value flow for a, because we don't clear content
sink a # $ hasTaintFlow=b SPURIOUS: hasValueFlow=a sink a # $ hasTaintFlow=b
end end
def m_reverse def m_reverse

View File

@@ -18,7 +18,7 @@ class OneController < ActionController::Base
end end
def c def c
sink @foo # $ hasTaintFlow sink @foo
end end
end end
@@ -35,7 +35,7 @@ class TwoController < ActionController::Base
end end
def c def c
sink @foo # $ SPURIOUS: hasTaintFlow sink @foo
end end
end end
@@ -52,7 +52,7 @@ class ThreeController < ActionController::Base
end end
def c def c
sink @foo # $ SPURIOUS: hasTaintFlow sink @foo
end end
end end
@@ -68,7 +68,7 @@ class FourController < ActionController::Base
end end
def c def c
sink(@foo.bar) # $ hasTaintFlow sink(@foo.bar)
end end
end end
@@ -84,7 +84,7 @@ class FiveController < ActionController::Base
end end
def c def c
sink @foo # $ hasTaintFlow sink @foo
end end
def taint_foo def taint_foo

View File

@@ -270,6 +270,11 @@ nodes
| params_flow.rb:205:10:205:10 | a | semmle.label | a | | params_flow.rb:205:10:205:10 | a | semmle.label | a |
subpaths subpaths
testFailures testFailures
| filter_flow.rb:21:10:21:13 | @foo | Unexpected result: hasTaintFlow |
| filter_flow.rb:38:10:38:13 | @foo | Unexpected result: hasTaintFlow |
| filter_flow.rb:55:10:55:13 | @foo | Unexpected result: hasTaintFlow |
| filter_flow.rb:71:10:71:17 | call to bar | Unexpected result: hasTaintFlow |
| filter_flow.rb:87:11:87:14 | @foo | Unexpected result: hasTaintFlow |
#select #select
| filter_flow.rb:21:10:21:13 | @foo | filter_flow.rb:14:12:14:17 | call to params | filter_flow.rb:21:10:21:13 | @foo | $@ | filter_flow.rb:14:12:14:17 | call to params | call to params | | filter_flow.rb:21:10:21:13 | @foo | filter_flow.rb:14:12:14:17 | call to params | filter_flow.rb:21:10:21:13 | @foo | $@ | filter_flow.rb:14:12:14:17 | call to params | call to params |
| filter_flow.rb:38:10:38:13 | @foo | filter_flow.rb:30:12:30:17 | call to params | filter_flow.rb:38:10:38:13 | @foo | $@ | filter_flow.rb:30:12:30:17 | call to params | call to params | | filter_flow.rb:38:10:38:13 | @foo | filter_flow.rb:30:12:30:17 | call to params | filter_flow.rb:38:10:38:13 | @foo | $@ | filter_flow.rb:30:12:30:17 | call to params | call to params |

View File

@@ -66,7 +66,7 @@ impl<'a> AstNode for Node<'a> {
impl AstNode for yeast::Node { impl AstNode for yeast::Node {
fn kind(&self) -> &str { fn kind(&self) -> &str {
yeast::Node::kind_name(self) yeast::Node::kind(self)
} }
fn is_named(&self) -> bool { fn is_named(&self) -> bool {
yeast::Node::is_named(self) yeast::Node::is_named(self)
@@ -882,6 +882,7 @@ fn emit_extras_in(visitor: &mut Visitor, node: Node<'_>) {
} }
fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) { fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) {
use yeast::Cursor;
let mut cursor = tree.walk(); let mut cursor = tree.walk();
visitor.enter_node(cursor.node()); visitor.enter_node(cursor.node());
let mut recurse = true; let mut recurse = true;

View File

@@ -41,14 +41,22 @@ pub fn query(input: TokenStream) -> TokenStream {
/// (kind "literal") - leaf with static content /// (kind "literal") - leaf with static content
/// (kind #{expr}) - leaf with computed content (expr.to_string()) /// (kind #{expr}) - leaf with computed content (expr.to_string())
/// (kind $fresh) - leaf with auto-generated unique name /// (kind $fresh) - leaf with auto-generated unique name
/// {expr} - embed a Rust expression, dispatched via /// {expr} - embed a Rust expression returning Id
/// the `IntoFieldIds` trait: `Id` pushes a /// {..expr} - splice an iterable of Id (in child/field position)
/// single id; iterables (`Vec<Id>`, /// field: {..expr} - splice into a named field
/// `Option<Id>`, iterator chains) splice /// {expr}.map(p -> tpl) - apply tpl to each element; splice result
/// their elements /// {expr}.reduce_left(f -> init, acc, e -> fold)
/// field: {expr} - extend a named field with `{expr}`'s ids /// - fold with per-element init; splice 0 or 1 result
/// ``` /// ```
/// ///
/// Chain syntax after `{expr}` or `{..expr}`:
/// - `.map(param -> template)` — one output node per input element.
/// - `.reduce_left(first -> init, acc, elem -> fold)` — fold left; the first
/// element is converted by `init`, subsequent elements are folded by `fold`
/// with the accumulator bound to `acc`. An empty iterable yields nothing.
/// - Chains always splice (the result is iterable).
/// - Multiple chains can be chained, e.g. `.map(...).reduce_left(...)`.
///
/// Can be called with an explicit context or using the implicit context /// Can be called with an explicit context or using the implicit context
/// from an enclosing `rule!`: /// from an enclosing `rule!`:
/// ///
@@ -92,7 +100,7 @@ pub fn trees(input: TokenStream) -> TokenStream {
/// rule!( /// rule!(
/// (query_pattern field: (_) @name (kind)* @repeated (_)? @optional) /// (query_pattern field: (_) @name (kind)* @repeated (_)? @optional)
/// => /// =>
/// (output_template field: {name} {repeated}) /// (output_template field: {name} {..repeated})
/// ) /// )
/// ///
/// // Shorthand: captures become fields on the output node /// // Shorthand: captures become fields on the output node
@@ -113,3 +121,37 @@ pub fn rule(input: TokenStream) -> TokenStream {
Err(err) => err.to_compile_error().into(), Err(err) => err.to_compile_error().into(),
} }
} }
/// Define a desugaring rule whose transform is a hand-written Rust block.
///
/// Use `manual_rule!` when the transform needs control over capture
/// translation timing — for example, when an outer rule needs to set
/// state in `ctx` (the `BuildCtx`'s user context) before recursive
/// translation reaches inner rules that read that state.
///
/// ```text
/// manual_rule!(
/// (query_pattern field: (_) @name)
/// {
/// // `ctx` is a `&mut BuildCtx<'_, C>`; capture variables
/// // (`name: NodeRef`, etc.) are bound from the query.
/// let translated = ctx.translate(name)?;
/// Ok(translated)
/// }
/// )
/// ```
///
/// Differences from [`rule!`]:
/// - Captures are **not** auto-translated before the body runs; they
/// refer to raw input-schema nodes. Use [`BuildCtx::translate`] (or
/// [`BuildCtx::translate_opt`]) to translate them when you choose.
/// - The body is plain Rust returning `Result<Vec<Id>, String>` — no
/// tree template, no `Ok(...)` wrap.
#[proc_macro]
pub fn manual_rule(input: TokenStream) -> TokenStream {
let input2: TokenStream2 = input.into();
match parse::parse_manual_rule_top(input2) {
Ok(output) => output.into(),
Err(err) => err.to_compile_error().into(),
}
}

View File

@@ -22,9 +22,10 @@ pub fn parse_query_top(input: TokenStream) -> Result<TokenStream> {
/// Parse a single query node (possibly with a trailing `@capture`). /// Parse a single query node (possibly with a trailing `@capture`).
fn parse_query_node(tokens: &mut Tokens) -> Result<TokenStream> { fn parse_query_node(tokens: &mut Tokens) -> Result<TokenStream> {
let base = parse_query_atom(tokens)?; let base = parse_query_atom(tokens)?;
// Check for trailing @capture or @@capture // Check for trailing @capture
if peek_is_at(tokens) { if peek_is_at(tokens) {
let capture_name = consume_capture_marker(tokens)?; tokens.next(); // consume @
let capture_name = expect_ident(tokens, "expected capture name after @")?;
let name_str = capture_name.to_string(); let name_str = capture_name.to_string();
Ok(quote! { Ok(quote! {
yeast::query::QueryNode::Capture { yeast::query::QueryNode::Capture {
@@ -158,7 +159,8 @@ fn parse_query_fields(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
push_field_elem(&mut field_order, &mut field_elems, field_str, elem); push_field_elem(&mut field_order, &mut field_elems, field_str, elem);
} else { } else {
let child = if peek_is_at(tokens) { let child = if peek_is_at(tokens) {
let capture_name = consume_capture_marker(tokens)?; tokens.next();
let capture_name = expect_ident(tokens, "expected capture name after @")?;
let name_str = capture_name.to_string(); let name_str = capture_name.to_string();
quote! { quote! {
yeast::query::QueryNode::Capture { yeast::query::QueryNode::Capture {
@@ -304,8 +306,7 @@ fn parse_ctx_or_implicit(tokens: &mut Tokens) -> Ident {
&& matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == ','); && matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == ',');
if is_explicit { if is_explicit {
let ctx = expect_ident(tokens, "unreachable: ident was just peeked") let ctx = expect_ident(tokens, "").unwrap();
.expect("unreachable: ident was just peeked");
let _ = tokens.next(); // consume comma let _ = tokens.next(); // consume comma
ctx ctx
} else { } else {
@@ -343,7 +344,7 @@ pub fn parse_trees_top(input: TokenStream) -> Result<TokenStream> {
} }
Ok(quote! { Ok(quote! {
{ {
let mut __nodes: Vec<yeast::Id> = Vec::new(); let mut __nodes: Vec<usize> = Vec::new();
#(#items)* #(#items)*
__nodes __nodes
} }
@@ -357,7 +358,7 @@ fn parse_direct_node(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStream> {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => { Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
let group = expect_group(tokens, Delimiter::Brace)?; let group = expect_group(tokens, Delimiter::Brace)?;
let expr = group.stream(); let expr = group.stream();
Ok(quote! { ::std::convert::Into::<yeast::Id>::into({ #expr }) }) Ok(quote! { ::std::convert::Into::<usize>::into({ #expr }) })
} }
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => { Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => {
let group = expect_group(tokens, Delimiter::Parenthesis)?; let group = expect_group(tokens, Delimiter::Parenthesis)?;
@@ -430,24 +431,49 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
); );
field_counter += 1; field_counter += 1;
// Plain `field: {expr}` — trait-dispatched extend. // Check for field: {..expr}.chain or field: {expr}.chain — splice a Vec<Id> into the field
if peek_is_group(tokens, Delimiter::Brace) { if peek_is_group(tokens, Delimiter::Brace) {
let group = expect_group(tokens, Delimiter::Brace)?; let group_clone = tokens.clone().next().unwrap();
let expr = group.stream(); if let TokenTree::Group(g) = &group_clone {
stmts.push(quote! { let mut inner_check = g.stream().into_iter();
let mut #temp: Vec<yeast::Id> = Vec::new(); let is_splice = matches!(inner_check.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
yeast::IntoFieldIds::extend_into({ #expr }, &mut #temp); && matches!(inner_check.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
}); // Determine if a chain (.map(..)) follows the `{}` group.
// An empty `{expr}` means the field is absent — skip it let mut after = tokens.clone();
// entirely rather than emitting an empty named field. after.next(); // skip the brace group
field_args.push(quote! { let has_chain =
if !#temp.is_empty() { __fields.push((#field_str, #temp)); } matches!(after.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
});
continue; if is_splice || has_chain {
let group = expect_group(tokens, Delimiter::Brace)?;
let base: TokenStream = if is_splice {
let mut inner = group.stream().into_iter().peekable();
inner.next(); // consume first .
inner.next(); // consume second .
let expr: TokenStream = inner.collect();
quote! {
{ #expr }.into_iter().map(::std::convert::Into::<usize>::into)
}
} else {
let expr = group.stream();
quote! { { #expr }.into_iter() }
};
let chained = parse_chain_suffix(tokens, ctx, base)?;
stmts.push(quote! {
let #temp: Vec<usize> = #chained.collect();
});
// An empty splice means the field is absent — skip it
// entirely rather than emitting an empty named field.
field_args.push(quote! {
if !#temp.is_empty() { __fields.push((#field_str, #temp)); }
});
continue;
}
}
} }
let value = parse_direct_node(tokens, ctx)?; let value = parse_direct_node(tokens, ctx)?;
stmts.push(quote! { let #temp: yeast::Id = #value; }); stmts.push(quote! { let #temp: usize = #value; });
field_args.push(quote! { __fields.push((#field_str, vec![#temp])); }); field_args.push(quote! { __fields.push((#field_str, vec![#temp])); });
} }
@@ -464,13 +490,101 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
Ok(quote! { Ok(quote! {
{ {
#(#stmts)* #(#stmts)*
let mut __fields: Vec<(&str, Vec<yeast::Id>)> = Vec::new(); let mut __fields: Vec<(&str, Vec<usize>)> = Vec::new();
#(#field_args)* #(#field_args)*
#ctx.node(#kind_str, __fields) #ctx.node(#kind_str, __fields)
} }
}) })
} }
/// Parse a chain of `.method(args)` suffixes after a `{expr}` or `{..expr}`
/// placeholder in tree templates. Currently supports:
///
/// ```text
/// .map(param -> template) -- iterator map: produces Vec<usize>
/// ```
///
/// The chain may be empty (returns `base` unchanged). Multiple chained calls
/// are supported, e.g. `.map(p -> ...).map(q -> ...)`.
///
/// Each call expects the receiver to be an iterator. The `base` argument
/// should therefore already be an iterator (use `.into_iter()` on it before
/// calling this function).
fn parse_chain_suffix(tokens: &mut Tokens, ctx: &Ident, base: TokenStream) -> Result<TokenStream> {
let mut current = base;
while matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.') {
tokens.next(); // consume .
let method = expect_ident(tokens, "expected method name after `.`")?;
let method_str = method.to_string();
let args_group = expect_group(tokens, Delimiter::Parenthesis)?;
match method_str.as_str() {
"map" => {
let mut inner = args_group.stream().into_iter().peekable();
let param = expect_ident(&mut inner, "expected lambda parameter name")?;
expect_punct(&mut inner, '-', "expected `->` after lambda parameter")?;
expect_punct(&mut inner, '>', "expected `->` after lambda parameter")?;
let body = parse_direct_node(&mut inner, ctx)?;
if let Some(tok) = inner.next() {
return Err(syn::Error::new_spanned(
tok,
"unexpected token after lambda body",
));
}
current = quote! {
#current.map(|#param| #body)
};
}
"reduce_left" => {
// Syntax: reduce_left(first -> init_tpl, acc, elem -> fold_tpl)
// - first -> init_tpl : converts the first element to the initial accumulator
// - acc, elem -> fold_tpl : fold step (acc = current accumulator, elem = next element)
// Empty iterator produces an empty iterator; non-empty produces a single-element iterator.
let mut inner = args_group.stream().into_iter().peekable();
let init_param = expect_ident(&mut inner, "expected initial lambda parameter")?;
expect_punct(&mut inner, '-', "expected `->` after init parameter")?;
expect_punct(&mut inner, '>', "expected `->` after init parameter")?;
let init_body = parse_direct_node(&mut inner, ctx)?;
expect_punct(&mut inner, ',', "expected `,` after init template")?;
let acc_param = expect_ident(&mut inner, "expected accumulator parameter")?;
expect_punct(&mut inner, ',', "expected `,` after accumulator parameter")?;
let elem_param = expect_ident(&mut inner, "expected element parameter")?;
expect_punct(&mut inner, '-', "expected `->` after element parameter")?;
expect_punct(&mut inner, '>', "expected `->` after element parameter")?;
let fold_body = parse_direct_node(&mut inner, ctx)?;
if let Some(tok) = inner.next() {
return Err(syn::Error::new_spanned(
tok,
"unexpected token after fold template",
));
}
current = quote! {
{
let mut __iter = #current;
let __result: Option<usize> = if let Some(#init_param) = __iter.next() {
let mut __acc: usize = #init_body;
for #elem_param in __iter {
let #acc_param: usize = __acc;
__acc = #fold_body;
}
Some(__acc)
} else {
None
};
__result.into_iter()
}
};
}
_ => {
return Err(syn::Error::new_spanned(
method,
format!("unknown builtin method `.{method_str}()`"),
));
}
}
}
Ok(current)
}
/// Parse the top-level list of a `trees!` template. /// Parse the top-level list of a `trees!` template.
/// Each item is a node template or `{expr}` splice. /// Each item is a node template or `{expr}` splice.
fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream>> { fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream>> {
@@ -491,14 +605,35 @@ fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream
continue; continue;
} }
// `{expr}` — extend `__nodes` via `IntoFieldIds`, which handles // {expr} or {..expr} (with optional .chain) — single node or splice
// single ids and iterables uniformly.
if peek_is_group(tokens, Delimiter::Brace) { if peek_is_group(tokens, Delimiter::Brace) {
let group = expect_group(tokens, Delimiter::Brace)?; let group = expect_group(tokens, Delimiter::Brace)?;
let expr = group.stream(); let has_chain =
items.push(quote! { matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
yeast::IntoFieldIds::extend_into({ #expr }, &mut __nodes); let mut inner = group.stream().into_iter().peekable();
}); let is_splice = peek_is_dotdot(&inner);
if is_splice || has_chain {
let base: TokenStream = if is_splice {
inner.next(); // consume first .
inner.next(); // consume second .
let expr: TokenStream = inner.collect();
quote! {
{ #expr }.into_iter().map(::std::convert::Into::<usize>::into)
}
} else {
let expr = group.stream();
quote! { { #expr }.into_iter() }
};
let chained = parse_chain_suffix(tokens, ctx, base)?;
items.push(quote! {
__nodes.extend(#chained);
});
} else {
let expr = group.stream();
items.push(quote! {
__nodes.push(::std::convert::Into::<usize>::into({ #expr }));
});
}
continue; continue;
} }
@@ -515,9 +650,6 @@ fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream
struct CaptureInfo { struct CaptureInfo {
name: String, name: String,
multiplicity: CaptureMultiplicity, multiplicity: CaptureMultiplicity,
/// `true` for `@@name` captures: the auto-translate prefix skips them,
/// so the bound `Id` refers to the raw (input-schema) node.
raw: bool,
} }
#[derive(Clone, Copy, PartialEq)] #[derive(Clone, Copy, PartialEq)]
@@ -576,14 +708,6 @@ fn extract_captures_inner(
extract_captures_inner(&mut inner, captures, child_mult); extract_captures_inner(&mut inner, captures, child_mult);
} }
TokenTree::Punct(p) if p.as_char() == '@' => { TokenTree::Punct(p) if p.as_char() == '@' => {
// `@@name` marks the capture as raw (skip auto-translate).
let raw = matches!(
tokens.peek(),
Some(TokenTree::Punct(p)) if p.as_char() == '@'
);
if raw {
tokens.next(); // consume the second `@`
}
if let Some(TokenTree::Ident(name)) = tokens.next() { if let Some(TokenTree::Ident(name)) = tokens.next() {
let mult = if parent_mult == CaptureMultiplicity::Repeated let mult = if parent_mult == CaptureMultiplicity::Repeated
|| last_mult == CaptureMultiplicity::Repeated || last_mult == CaptureMultiplicity::Repeated
@@ -599,7 +723,6 @@ fn extract_captures_inner(
captures.push(CaptureInfo { captures.push(CaptureInfo {
name: name.to_string(), name: name.to_string(),
multiplicity: mult, multiplicity: mult,
raw,
}); });
} }
last_mult = CaptureMultiplicity::Single; last_mult = CaptureMultiplicity::Single;
@@ -653,14 +776,6 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
// Parse query // Parse query
let query_code = parse_query_top(query_stream.clone())?; let query_code = parse_query_top(query_stream.clone())?;
// Capture names marked `@@name` (raw) — passed to the auto-translate
// prefix as a skip list so those captures keep their input-schema ids.
let raw_capture_names: Vec<&str> = captures
.iter()
.filter(|c| c.raw)
.map(|c| c.name.as_str())
.collect();
// Generate capture bindings // Generate capture bindings
let ctx_ident = Ident::new(IMPLICIT_CTX, Span::call_site()); let ctx_ident = Ident::new(IMPLICIT_CTX, Span::call_site());
let bindings: Vec<TokenStream> = captures let bindings: Vec<TokenStream> = captures
@@ -671,17 +786,22 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
match cap.multiplicity { match cap.multiplicity {
CaptureMultiplicity::Repeated => { CaptureMultiplicity::Repeated => {
quote! { quote! {
let #name: Vec<yeast::Id> = __captures.get_all(#name_str); let #name: Vec<yeast::NodeRef> = __captures.get_all(#name_str)
.into_iter()
.map(yeast::NodeRef)
.collect();
} }
} }
CaptureMultiplicity::Optional => { CaptureMultiplicity::Optional => {
quote! { quote! {
let #name: Option<yeast::Id> = __captures.get_opt(#name_str); let #name: Option<yeast::NodeRef> =
__captures.get_opt(#name_str).map(yeast::NodeRef);
} }
} }
CaptureMultiplicity::Single => { CaptureMultiplicity::Single => {
quote! { quote! {
let #name: yeast::Id = __captures.get_var(#name_str).unwrap(); let #name: yeast::NodeRef =
yeast::NodeRef(__captures.get_var(#name_str).unwrap());
} }
} }
} }
@@ -712,7 +832,7 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
__fields.insert( __fields.insert(
__field_id, __field_id,
#name.into_iter() #name.into_iter()
.map(::std::convert::Into::<yeast::Id>::into) .map(::std::convert::Into::<usize>::into)
.collect(), .collect(),
); );
}, },
@@ -721,14 +841,14 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
.unwrap_or_else(|| panic!("field '{}' not found", #name_str)); .unwrap_or_else(|| panic!("field '{}' not found", #name_str));
if let Some(__id) = #name { if let Some(__id) = #name {
__fields.entry(__field_id).or_insert_with(Vec::new) __fields.entry(__field_id).or_insert_with(Vec::new)
.push(::std::convert::Into::<yeast::Id>::into(__id)); .push(::std::convert::Into::<usize>::into(__id));
} }
}, },
CaptureMultiplicity::Single => quote! { CaptureMultiplicity::Single => quote! {
let __field_id = #ctx_ident.ast.field_id_for_name(#name_str) let __field_id = #ctx_ident.ast.field_id_for_name(#name_str)
.unwrap_or_else(|| panic!("field '{}' not found", #name_str)); .unwrap_or_else(|| panic!("field '{}' not found", #name_str));
__fields.entry(__field_id).or_insert_with(Vec::new) __fields.entry(__field_id).or_insert_with(Vec::new)
.push(::std::convert::Into::<yeast::Id>::into(#name)); .push(::std::convert::Into::<usize>::into(#name));
}, },
} }
}) })
@@ -760,7 +880,7 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
} }
quote! { quote! {
let mut __nodes: Vec<yeast::Id> = Vec::new(); let mut __nodes: Vec<usize> = Vec::new();
#(#transform_items)* #(#transform_items)*
__nodes __nodes
} }
@@ -771,23 +891,120 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
let __query = #query_code; let __query = #query_code;
yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, mut __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option<tree_sitter::Range>, __user_ctx: &mut _, __translator: yeast::TranslatorHandle<'_, _>| { yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, mut __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option<tree_sitter::Range>, __user_ctx: &mut _, __translator: yeast::TranslatorHandle<'_, _>| {
// Auto-translation prefix: recursively translate every // Auto-translation prefix: recursively translate every
// captured node before invoking the user's transform body, // captured node before invoking the user's transform body.
// except for `@@name` captures listed in `__skip` which the
// body consumes raw.
// For OneShot rules this preserves the legacy behaviour // For OneShot rules this preserves the legacy behaviour
// (input-schema captures translated to output-schema // (input-schema captures translated to output-schema
// nodes); for Repeating rules it is a no-op. // nodes); for Repeating rules it is a no-op.
let __skip: &[&str] = &[#(#raw_capture_names),*]; __translator.auto_translate_captures(&mut __captures, __ast, __user_ctx)?;
__translator.auto_translate_captures(&mut __captures, __ast, __user_ctx, __skip)?;
#(#bindings)* #(#bindings)*
let mut #ctx_ident = yeast::build::BuildCtx::with_translator(__ast, &__captures, __fresh, __source_range, __user_ctx, __translator); let mut #ctx_ident = yeast::build::BuildCtx::with_translator(__ast, &__captures, __fresh, __source_range, __user_ctx, __translator);
let __result: Vec<yeast::Id> = { #transform_body }; let __result: Vec<usize> = { #transform_body };
Ok(__result) Ok(__result)
})) }))
} }
}) })
} }
/// Parse `manual_rule!( query { body } )`.
///
/// Like [`parse_rule_top`] but:
/// - Expects a Rust block `{ ... }` after the query (no `=>` arrow).
/// - Generates code that does NOT auto-translate captures before
/// running the body. Capture variables refer to raw (input-schema)
/// nodes; the body is responsible for explicit translation via
/// `ctx.translate(...)`.
/// - The body is included verbatim and must evaluate to
/// `Result<Vec<usize>, String>`.
pub fn parse_manual_rule_top(input: TokenStream) -> Result<TokenStream> {
let mut tokens = input.into_iter().peekable();
// Collect query tokens up to the body block `{ ... }`.
let mut query_tokens = Vec::new();
loop {
match tokens.peek() {
None => {
return Err(syn::Error::new(
Span::call_site(),
"expected a Rust block `{ ... }` after the query in manual_rule!",
))
}
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => break,
_ => {
query_tokens.push(tokens.next().unwrap());
}
}
}
let query_stream: TokenStream = query_tokens.into_iter().collect();
// Extract captures from the query (same as in `rule!`).
let captures = extract_captures(&query_stream);
// Parse the query into the QueryNode-building expression.
let query_code = parse_query_top(query_stream)?;
// Generate capture bindings (same as in `rule!`).
let ctx_ident = Ident::new(IMPLICIT_CTX, Span::call_site());
let bindings: Vec<TokenStream> = captures
.iter()
.map(|cap| {
let name = Ident::new(&cap.name, Span::call_site());
let name_str = &cap.name;
match cap.multiplicity {
CaptureMultiplicity::Repeated => quote! {
let #name: Vec<yeast::NodeRef> = __captures.get_all(#name_str)
.into_iter()
.map(yeast::NodeRef)
.collect();
},
CaptureMultiplicity::Optional => quote! {
let #name: Option<yeast::NodeRef> =
__captures.get_opt(#name_str).map(yeast::NodeRef);
},
CaptureMultiplicity::Single => quote! {
let #name: yeast::NodeRef =
yeast::NodeRef(__captures.get_var(#name_str).unwrap());
},
}
})
.collect();
// Consume the body block.
let body_group = match tokens.next() {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => g,
other => {
return Err(syn::Error::new(
Span::call_site(),
format!(
"expected a Rust block `{{ ... }}` after the query in manual_rule!, found: {other:?}"
),
))
}
};
let body_stream = body_group.stream();
// No tokens should follow the body.
if let Some(tok) = tokens.next() {
return Err(syn::Error::new_spanned(
tok,
"unexpected token after manual_rule! body",
));
}
Ok(quote! {
{
let __query = #query_code;
yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option<tree_sitter::Range>, __user_ctx: &mut _, __translator: yeast::TranslatorHandle<'_, _>| {
// No auto-translate prefix for manual rules — the body
// is responsible for translating captures explicitly.
#(#bindings)*
let mut #ctx_ident = yeast::build::BuildCtx::with_translator(__ast, &__captures, __fresh, __source_range, __user_ctx, __translator);
#body_stream
}))
}
})
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Token utilities // Token utilities
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -796,16 +1013,6 @@ fn peek_is_at(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '@') matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '@')
} }
/// Consume an `@` or `@@` capture marker and the following name ident.
/// Caller has already verified `peek_is_at(tokens)`.
fn consume_capture_marker(tokens: &mut Tokens) -> Result<Ident> {
tokens.next(); // consume the first `@`
if peek_is_at(tokens) {
tokens.next(); // consume the second `@` of `@@`
}
expect_ident(tokens, "expected capture name after `@` or `@@`")
}
fn peek_is_literal(tokens: &mut Tokens) -> bool { fn peek_is_literal(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Literal(_))) matches!(tokens.peek(), Some(TokenTree::Literal(_)))
} }
@@ -818,6 +1025,13 @@ fn peek_is_hash(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '#') matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '#')
} }
/// Check for `..` (two consecutive dot punctuation tokens).
fn peek_is_dotdot(tokens: &Tokens) -> bool {
let mut lookahead = tokens.clone();
matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
&& matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
}
fn peek_is_underscore(tokens: &mut Tokens) -> bool { fn peek_is_underscore(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Ident(id)) if *id == "_") matches!(tokens.peek(), Some(TokenTree::Ident(id)) if *id == "_")
} }
@@ -899,7 +1113,8 @@ fn expect_repetition(tokens: &mut Tokens) -> Result<TokenStream> {
fn maybe_wrap_capture(tokens: &mut Tokens, base: TokenStream) -> Result<TokenStream> { fn maybe_wrap_capture(tokens: &mut Tokens, base: TokenStream) -> Result<TokenStream> {
if peek_is_at(tokens) { if peek_is_at(tokens) {
let name = consume_capture_marker(tokens)?; tokens.next(); // consume @
let name = expect_ident(tokens, "expected capture name after @")?;
let name_str = name.to_string(); let name_str = name.to_string();
Ok(quote! { Ok(quote! {
yeast::query::QueryNode::Capture { yeast::query::QueryNode::Capture {
@@ -926,12 +1141,13 @@ fn maybe_wrap_repetition(tokens: &mut Tokens, single: TokenStream) -> Result<Tok
} }
} }
/// If `@name` (or `@@name`) follows a Repeated list element, wrap each /// If `@name` follows a Repeated list element, wrap each child SingleNode
/// child SingleNode inside the repetition with a Capture. This matches /// inside the repetition with a Capture. This matches tree-sitter semantics
/// tree-sitter semantics where `(_)* @name` captures each matched node. /// where `(_)* @name` captures each matched node.
fn maybe_wrap_list_capture(tokens: &mut Tokens, elem: TokenStream) -> Result<TokenStream> { fn maybe_wrap_list_capture(tokens: &mut Tokens, elem: TokenStream) -> Result<TokenStream> {
if peek_is_at(tokens) { if peek_is_at(tokens) {
let name = consume_capture_marker(tokens)?; tokens.next();
let name = expect_ident(tokens, "expected capture name after @")?;
let name_str = name.to_string(); let name_str = name.to_string();
// Re-parse the element isn't practical, so we generate a wrapper // Re-parse the element isn't practical, so we generate a wrapper
// that creates a new Repeated with each child wrapped in a capture. // that creates a new Repeated with each child wrapped in a capture.

View File

@@ -214,7 +214,7 @@ yeast::tree!(ctx,
```rust ```rust
yeast::trees!(ctx, yeast::trees!(ctx,
(assignment left: {tmp} right: {right}) (assignment left: {tmp} right: {right})
{body} {..body}
) )
``` ```
@@ -256,26 +256,12 @@ occurrences of the same `$name` within one `BuildCtx` share the same value:
### Embedded Rust expressions ### Embedded Rust expressions
`{expr}` embeds a Rust expression whose value is appended to the `{expr}` embeds a Rust expression that returns a single node `Id`:
enclosing field (or to the rule body's id list). Dispatch happens via
the [`IntoFieldIds`] trait, which is implemented for:
- `Id` — pushes the single id.
- Any `IntoIterator<Item: Into<Id>>` — extends with all yielded ids
(covers `Vec<Id>`, `Option<Id>`, iterator chains, etc.).
So the same `{expr}` syntax handles single ids, splices, and zero-or-many
options uniformly:
```rust ```rust
(assignment (assignment
left: {some_node_id} // a single Id left: {some_node_id} // insert a pre-built node
right: {rhs} // a captured value (inside rule!) right: {rhs} // insert a captured value (inside rule!)
)
yeast::trees!(ctx,
(assignment left: {tmp} right: {right})
{extra_nodes} // splices a Vec<Id>
) )
``` ```
@@ -291,47 +277,20 @@ expressions (with `let` bindings) work too:
}) })
``` ```
Inside `rule!`, captures are Rust variables — `{name}` works for `{..expr}` splices a `Vec<Id>` (or any iterable of `Id`); the contents
single, optional, and repeated captures alike: are likewise a Rust block, so the splice can be the result of arbitrary
computation:
```rust ```rust
rule!( yeast::trees!(ctx,
(assignment left: @lhs right: _* @parts) (assignment left: {tmp} right: {right})
=> {..extra_nodes} // splice a Vec<Id>
(assignment left: {lhs} right: (block stmt: {parts}))
) )
``` ```
### Raw captures (`@@name`) Inside `rule!`, captures are Rust variables, so `{name}` inserts a
single capture (`Id`) and `{..name}` splices a repeated capture
The default `@name` capture marker is *auto-translated*: in OneShot (`Vec<Id>`).
phases the macro recursively translates the captured node before
binding it, so `{name}` in the output template splices a node that
already conforms to the output schema.
For rules that need the raw (input-schema) capture — typically to read
its source text or to translate it explicitly with mutable context
state between calls — use `@@name` instead. The body sees the original
input-schema `Id`:
```rust
yeast::rule!(
(assignment left: (_) @@raw_lhs right: (_) @rhs)
=>
{
// raw_lhs is untranslated: read its original source text.
let text = ctx.ast.source_text(raw_lhs);
// rhs is already translated by the auto-translate prefix.
tree!((call
method: (identifier #{text.as_str()})
receiver: {rhs}))
}
);
```
Mix `@` and `@@` freely in the same rule. In a Repeating phase both
markers are equivalent (auto-translation is a no-op for repeating
rules).
## Complete example: for-loop desugaring ## Complete example: for-loop desugaring

View File

@@ -158,6 +158,15 @@ impl<'a, C> BuildCtx<'a, C> {
self.ast self.ast
.create_named_token_with_range(kind, generated, self.source_range) .create_named_token_with_range(kind, generated, self.source_range)
} }
/// Prepend a value to a field of an existing node.
pub fn prepend_field(&mut self, node_id: Id, field_name: &str, value_id: Id) {
let field_id = self
.ast
.field_id_for_name(field_name)
.unwrap_or_else(|| panic!("build: field '{field_name}' not found"));
self.ast.prepend_field_child(node_id, field_id, value_id);
}
} }
impl<C: Clone> BuildCtx<'_, C> { impl<C: Clone> BuildCtx<'_, C> {
@@ -167,6 +176,9 @@ impl<C: Clone> BuildCtx<'_, C> {
/// (translation is not meaningful when input and output share a /// (translation is not meaningful when input and output share a
/// schema). /// schema).
/// ///
/// Accepts any value convertible to [`Id`] (including [`crate::NodeRef`]),
/// so manual rules can pass capture bindings directly without unwrapping.
///
/// Errors if this `BuildCtx` was constructed by hand (without a /// Errors if this `BuildCtx` was constructed by hand (without a
/// translator handle) — for example, in unit tests that don't go /// translator handle) — for example, in unit tests that don't go
/// through the rule driver. /// through the rule driver.
@@ -177,6 +189,20 @@ impl<C: Clone> BuildCtx<'_, C> {
None => Err("translate() called on a BuildCtx without a translator handle".into()), None => Err("translate() called on a BuildCtx without a translator handle".into()),
} }
} }
/// Translate an optional capture, returning the first translated id or
/// `None`. Convenience for `?`-quantifier captures (`Option<NodeRef>`).
///
/// If the underlying translation produces multiple ids for a single
/// input, only the first is returned. For most use cases (e.g.
/// translating a single type annotation) this is what you want; if
/// you need all ids, use [`translate`] directly.
pub fn translate_opt<I: Into<Id>>(&mut self, id: Option<I>) -> Result<Option<Id>, String> {
match id {
Some(id) => Ok(self.translate(id)?.into_iter().next()),
None => Ok(None),
}
}
} }
impl<C> std::ops::Deref for BuildCtx<'_, C> { impl<C> std::ops::Deref for BuildCtx<'_, C> {

View File

@@ -54,24 +54,24 @@ impl Captures {
self.captures.entry(key).or_default().push(id); self.captures.entry(key).or_default().push(id);
} }
/// Apply a fallible function to every captured id, replacing each id pub fn map_captures(&mut self, kind: &str, f: &mut impl FnMut(Id) -> Id) {
/// with the results. A function returning an empty vector removes if let Some(ids) = self.captures.get_mut(kind) {
/// the capture; returning multiple ids splices them into the for id in ids {
/// capture's value list (suitable for `*`/`+` captures). Captures *id = f(*id);
/// whose name appears in `skip` are left untouched. Stops and }
/// returns the error on the first failure. }
/// }
/// Used by the `rule!` macro's auto-translate prefix to translate
/// every capture except those marked `@@name` (raw). /// Apply a fallible function to every captured id (across all keys),
pub fn try_map_captures_except<E>( /// replacing each id with the results. A function returning an empty
/// vector removes the capture; returning multiple ids splices them
/// into the capture's value list (suitable for `*`/`+` captures).
/// Stops and returns the error on the first failure.
pub fn try_map_all_captures<E>(
&mut self, &mut self,
skip: &[&str],
mut f: impl FnMut(Id) -> Result<Vec<Id>, E>, mut f: impl FnMut(Id) -> Result<Vec<Id>, E>,
) -> Result<(), E> { ) -> Result<(), E> {
for (name, ids) in self.captures.iter_mut() { for ids in self.captures.values_mut() {
if skip.contains(name) {
continue;
}
let mut new_ids = Vec::with_capacity(ids.len()); let mut new_ids = Vec::with_capacity(ids.len());
for &id in ids.iter() { for &id in ids.iter() {
new_ids.extend(f(id)?); new_ids.extend(f(id)?);
@@ -80,6 +80,12 @@ impl Captures {
} }
Ok(()) Ok(())
} }
pub fn map_captures_to(&mut self, from: &str, to: &'static str, f: &mut impl FnMut(Id) -> Id) {
if let Some(from_ids) = self.captures.get(from) {
let new_values = from_ids.iter().copied().map(f).collect();
self.captures.insert(to, new_values);
}
}
pub fn merge(&mut self, other: &Captures) { pub fn merge(&mut self, other: &Captures) {
for (key, ids) in &other.captures { for (key, ids) in &other.captures {

View File

@@ -0,0 +1,8 @@
pub trait Cursor<'a, T, N, F> {
fn node(&self) -> &'a N;
fn field_id(&self) -> Option<F>;
fn field_name(&self) -> Option<&'static str>;
fn goto_first_child(&mut self) -> bool;
fn goto_next_sibling(&mut self) -> bool;
fn goto_parent(&mut self) -> bool;
}

View File

@@ -1,6 +1,6 @@
use std::fmt::Write; use std::fmt::Write;
use crate::{schema::Schema, Ast, Id, Node, NodeContent, CHILD_FIELD}; use crate::{schema::Schema, Ast, Node, NodeContent, CHILD_FIELD};
/// Options for controlling AST dump output. /// Options for controlling AST dump output.
pub struct DumpOptions { pub struct DumpOptions {
@@ -34,11 +34,16 @@ impl Default for DumpOptions {
/// method: /// method:
/// identifier "foo" /// identifier "foo"
/// ``` /// ```
pub fn dump_ast(ast: &Ast, root: Id, source: &str) -> String { pub fn dump_ast(ast: &Ast, root: usize, source: &str) -> String {
dump_ast_with_options(ast, root, source, &DumpOptions::default()) dump_ast_with_options(ast, root, source, &DumpOptions::default())
} }
pub fn dump_ast_with_options(ast: &Ast, root: Id, source: &str, options: &DumpOptions) -> String { pub fn dump_ast_with_options(
ast: &Ast,
root: usize,
source: &str,
options: &DumpOptions,
) -> String {
let mut out = String::new(); let mut out = String::new();
dump_node(ast, root, source, options, 0, None, &mut out); dump_node(ast, root, source, options, 0, None, &mut out);
out out
@@ -48,7 +53,7 @@ pub fn dump_ast_with_options(ast: &Ast, root: Id, source: &str, options: &DumpOp
/// ///
/// Any node that does not match the expected type set for its parent field is /// Any node that does not match the expected type set for its parent field is
/// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line. /// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line.
pub fn dump_ast_with_type_errors(ast: &Ast, root: Id, source: &str, schema: &Schema) -> String { pub fn dump_ast_with_type_errors(ast: &Ast, root: usize, source: &str, schema: &Schema) -> String {
dump_ast_with_type_errors_and_options(ast, root, source, schema, &DumpOptions::default()) dump_ast_with_type_errors_and_options(ast, root, source, schema, &DumpOptions::default())
} }
@@ -58,7 +63,7 @@ pub fn dump_ast_with_type_errors(ast: &Ast, root: Id, source: &str, schema: &Sch
/// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line. /// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line.
pub fn dump_ast_with_type_errors_and_options( pub fn dump_ast_with_type_errors_and_options(
ast: &Ast, ast: &Ast,
root: Id, root: usize,
source: &str, source: &str,
schema: &Schema, schema: &Schema,
options: &DumpOptions, options: &DumpOptions,
@@ -171,7 +176,7 @@ fn expected_for_field<'a>(
fn dump_node( fn dump_node(
ast: &Ast, ast: &Ast,
id: Id, id: usize,
source: &str, source: &str,
options: &DumpOptions, options: &DumpOptions,
indent: usize, indent: usize,
@@ -310,7 +315,7 @@ fn dump_node(
/// Dump a leaf node inline (no newline prefix, caller provides context). /// Dump a leaf node inline (no newline prefix, caller provides context).
fn dump_node_inline( fn dump_node_inline(
ast: &Ast, ast: &Ast,
id: Id, id: usize,
source: &str, source: &str,
options: &DumpOptions, options: &DumpOptions,
type_check: Option<( type_check: Option<(

View File

@@ -7,6 +7,7 @@ use serde_json::{json, Value};
pub mod build; pub mod build;
pub mod captures; pub mod captures;
pub mod cursor;
pub mod dump; pub mod dump;
pub mod node_types_yaml; pub mod node_types_yaml;
pub mod query; pub mod query;
@@ -15,64 +16,35 @@ pub mod schema;
pub mod tree_builder; pub mod tree_builder;
mod visitor; mod visitor;
pub use yeast_macros::{query, rule, tree, trees}; pub use yeast_macros::{manual_rule, query, rule, tree, trees};
use captures::Captures; use captures::Captures;
pub use cursor::Cursor;
use query::QueryNode; use query::QueryNode;
/// Node id: an index into the [`Ast`] arena. A newtype around `usize` /// Node ids are indexes into the arena
/// rather than a bare alias so that it can carry its own pub type Id = usize;
/// [`YeastDisplay`] / [`YeastSourceRange`] / [`IntoFieldIds`] impls
/// without colliding with the impls for plain integers.
///
/// Use `id.0` (or `id.into()`) to obtain the raw arena index.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug, Hash, Serialize)]
pub struct Id(pub usize);
impl From<usize> for Id {
fn from(value: usize) -> Self {
Id(value)
}
}
impl From<Id> for usize {
fn from(value: Id) -> Self {
value.0
}
}
/// Field and Kind ids are provided by tree-sitter /// Field and Kind ids are provided by tree-sitter
type FieldId = u16; type FieldId = u16;
type KindId = u16; type KindId = u16;
/// Trait for values that can be appended to a field's id list inside a /// A typed reference to a node in an [`Ast`] arena. Wraps an [`Id`] but
/// `tree!`/`trees!`/`rule!` template (in `{expr}` placeholders). /// deliberately does not implement [`std::fmt::Display`]: rendering a node
/// /// requires the [`Ast`] it lives in (to resolve [`NodeContent::Range`] back
/// `Id` pushes a single id; the blanket impl for /// to source text). Use [`YeastDisplay::yeast_to_string`] to format it.
/// `IntoIterator<Item: Into<Id>>` handles `Vec<Id>`, `Option<Id>`, #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
/// arbitrary iterators yielding `Id`, etc. pub struct NodeRef(pub Id);
///
/// This lets `{expr}` interpolate any of these shapes without a
/// dedicated splice syntax — the macro emits the same trait-dispatched
/// call regardless of the value's type.
pub trait IntoFieldIds {
fn extend_into(self, out: &mut Vec<Id>);
}
impl IntoFieldIds for Id { impl NodeRef {
fn extend_into(self, out: &mut Vec<Id>) { pub fn id(self) -> Id {
out.push(self); self.0
} }
} }
impl<I, T> IntoFieldIds for I impl From<NodeRef> for Id {
where fn from(value: NodeRef) -> Self {
I: IntoIterator<Item = T>, value.0
T: Into<Id>,
{
fn extend_into(self, out: &mut Vec<Id>) {
out.extend(self.into_iter().map(Into::into));
} }
} }
@@ -89,21 +61,21 @@ pub trait YeastDisplay {
/// Optional source range for values used in `#{expr}` interpolations. /// Optional source range for values used in `#{expr}` interpolations.
/// ///
/// By default this returns `None`, so synthesized leaves inherit the matched /// By default this returns `None`, so synthesized leaves inherit the matched
/// rule's source range. `Id` returns the referenced node's range, letting /// rule's source range. `NodeRef` returns the referenced node's range, letting
/// `(kind #{capture})` carry the captured node's location. /// `(kind #{capture})` carry the captured node's location.
pub trait YeastSourceRange { pub trait YeastSourceRange {
fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range>; fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range>;
} }
impl YeastDisplay for Id { impl YeastDisplay for NodeRef {
fn yeast_to_string(&self, ast: &Ast) -> String { fn yeast_to_string(&self, ast: &Ast) -> String {
ast.source_text(*self) ast.source_text(self.0)
} }
} }
impl YeastSourceRange for Id { impl YeastSourceRange for NodeRef {
fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range> { fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range> {
ast.get_node(*self).and_then(|n| match &n.content { ast.get_node(self.0).and_then(|n| match &n.content {
NodeContent::Range(r) => Some(r.clone()), NodeContent::Range(r) => Some(r.clone()),
_ => n.source_range, _ => n.source_range,
}) })
@@ -172,36 +144,6 @@ impl<'a> AstCursor<'a> {
self.node_id self.node_id
} }
pub fn node(&self) -> &'a Node {
&self.ast.nodes[self.node_id.0]
}
pub fn field_id(&self) -> Option<FieldId> {
let (_, children) = self.parents.last()?;
children.current_field()
}
pub fn field_name(&self) -> Option<&'static str> {
if self.field_id() == Some(CHILD_FIELD) {
None
} else {
self.field_id()
.and_then(|id| self.ast.field_name_for_id(id))
}
}
pub fn goto_first_child(&mut self) -> bool {
self.goto_first_child_opt().is_some()
}
pub fn goto_next_sibling(&mut self) -> bool {
self.goto_next_sibling_opt().is_some()
}
pub fn goto_parent(&mut self) -> bool {
self.goto_parent_opt().is_some()
}
fn goto_next_sibling_opt(&mut self) -> Option<()> { fn goto_next_sibling_opt(&mut self) -> Option<()> {
self.node_id = self.parents.last_mut()?.1.next()?; self.node_id = self.parents.last_mut()?.1.next()?;
Some(()) Some(())
@@ -222,6 +164,37 @@ impl<'a> AstCursor<'a> {
Some(()) Some(())
} }
} }
impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
fn node(&self) -> &'a Node {
&self.ast.nodes[self.node_id]
}
fn field_id(&self) -> Option<FieldId> {
let (_, children) = self.parents.last()?;
children.current_field()
}
fn field_name(&self) -> Option<&'static str> {
if self.field_id() == Some(CHILD_FIELD) {
None
} else {
self.field_id()
.and_then(|id| self.ast.field_name_for_id(id))
}
}
fn goto_first_child(&mut self) -> bool {
self.goto_first_child_opt().is_some()
}
fn goto_next_sibling(&mut self) -> bool {
self.goto_next_sibling_opt().is_some()
}
fn goto_parent(&mut self) -> bool {
self.goto_parent_opt().is_some()
}
}
/// An iterator over the child Ids of a node. /// An iterator over the child Ids of a node.
#[derive(Debug)] #[derive(Debug)]
@@ -368,16 +341,16 @@ impl Ast {
/// ///
/// This reflects the effective AST after desugaring and excludes orphaned /// This reflects the effective AST after desugaring and excludes orphaned
/// arena nodes left behind by rewrite operations. /// arena nodes left behind by rewrite operations.
pub fn reachable_node_ids(&self) -> Vec<Id> { pub fn reachable_node_ids(&self) -> Vec<usize> {
let mut reachable = Vec::new(); let mut reachable = Vec::new();
let mut stack = vec![self.root]; let mut stack = vec![self.root];
let mut seen = vec![false; self.nodes.len()]; let mut seen = vec![false; self.nodes.len()];
while let Some(id) = stack.pop() { while let Some(id) = stack.pop() {
if id.0 >= self.nodes.len() || seen[id.0] { if id >= self.nodes.len() || seen[id] {
continue; continue;
} }
seen[id.0] = true; seen[id] = true;
reachable.push(id); reachable.push(id);
if let Some(node) = self.get_node(id) { if let Some(node) = self.get_node(id) {
@@ -401,11 +374,11 @@ impl Ast {
} }
pub fn get_node(&self, id: Id) -> Option<&Node> { pub fn get_node(&self, id: Id) -> Option<&Node> {
self.nodes.get(id.0) self.nodes.get(id)
} }
pub fn print(&self, source: &str, root_id: Id) -> Value { pub fn print(&self, source: &str, root_id: Id) -> Value {
let root = &self.nodes()[root_id.0]; let root = &self.nodes()[root_id];
self.print_node(root, source) self.print_node(root, source)
} }
@@ -448,7 +421,7 @@ impl Ast {
is_named, is_named,
source_range, source_range,
}); });
Id(id) id
} }
fn union_source_range_of_children( fn union_source_range_of_children(
@@ -515,6 +488,15 @@ impl Ast {
self.create_named_token_with_range(kind, content, None) self.create_named_token_with_range(kind, content, None)
} }
/// Prepend a child id to the given field of the given node.
pub fn prepend_field_child(&mut self, node_id: Id, field_id: FieldId, value_id: Id) {
let node = self
.nodes
.get_mut(node_id)
.expect("prepend_field_child: invalid node id");
node.fields.entry(field_id).or_default().insert(0, value_id);
}
pub fn create_named_token_with_range( pub fn create_named_token_with_range(
&mut self, &mut self,
kind: &'static str, kind: &'static str,
@@ -536,7 +518,7 @@ impl Ast {
fields: BTreeMap::new(), fields: BTreeMap::new(),
content: NodeContent::DynamicString(content), content: NodeContent::DynamicString(content),
}); });
Id(id) id
} }
pub fn field_name_for_id(&self, id: FieldId) -> Option<&'static str> { pub fn field_name_for_id(&self, id: FieldId) -> Option<&'static str> {
@@ -620,6 +602,10 @@ pub struct Node {
} }
impl Node { impl Node {
pub fn kind(&self) -> &'static str {
self.kind_name
}
pub fn kind_name(&self) -> &'static str { pub fn kind_name(&self) -> &'static str {
self.kind_name self.kind_name
} }
@@ -771,14 +757,13 @@ impl<'a, C: Clone> TranslatorHandle<'a, C> {
} }
/// Translate every captured node in `captures` in place (OneShot phase /// Translate every captured node in `captures` in place (OneShot phase
/// only), except for captures whose name appears in `skip` — those are /// only). In a Repeating phase this is a no-op — Repeating rules
/// left as raw (input-schema) ids for the rule body to consume /// receive raw captures.
/// directly. In a Repeating phase this is a no-op — Repeating rules
/// receive raw captures regardless of `skip`.
/// ///
/// Used by the `rule!` macro's generated prefix. `skip` is populated /// Used by the `rule!` macro's generated prefix to preserve the
/// from the macro's `@@name` capture markers; for plain `@name` /// pre-existing "auto-translate captures before running the transform
/// captures (and rules with no `@@` markers) it is empty. /// body" behavior. Manually-written transforms typically translate
/// captures selectively via [`translate`] instead.
/// ///
/// To avoid infinite recursion, a capture whose id matches the rule's /// To avoid infinite recursion, a capture whose id matches the rule's
/// matched root (e.g. from a `(_) @_` pattern) is left unchanged. /// matched root (e.g. from a `(_) @_` pattern) is left unchanged.
@@ -787,12 +772,11 @@ impl<'a, C: Clone> TranslatorHandle<'a, C> {
captures: &mut Captures, captures: &mut Captures,
ast: &mut Ast, ast: &mut Ast,
user_ctx: &mut C, user_ctx: &mut C,
skip: &[&str],
) -> Result<(), String> { ) -> Result<(), String> {
match &self.inner { match &self.inner {
TranslatorImpl::OneShot { matched_root, .. } => { TranslatorImpl::OneShot { matched_root, .. } => {
let root = *matched_root; let root = *matched_root;
captures.try_map_captures_except(skip, |cid| { captures.try_map_all_captures(|cid| {
if cid == root { if cid == root {
Ok(vec![cid]) Ok(vec![cid])
} else { } else {
@@ -964,7 +948,7 @@ fn apply_repeating_rules_inner<C: Clone>(
)); ));
} }
let node_kind = ast.get_node(id).map(|n| n.kind_name()).unwrap_or(""); let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or("");
for rule in index.rules_for_kind(node_kind) { for rule in index.rules_for_kind(node_kind) {
let rule_ptr = *rule as *const Rule<C>; let rule_ptr = *rule as *const Rule<C>;
if Some(rule_ptr) == skip_rule { if Some(rule_ptr) == skip_rule {
@@ -1016,7 +1000,7 @@ fn apply_repeating_rules_inner<C: Clone>(
// //
// Child traversal does not increment rewrite depth and starts fresh // Child traversal does not increment rewrite depth and starts fresh
// (no rule is skipped on child subtrees). // (no rule is skipped on child subtrees).
let mut fields = std::mem::take(&mut ast.nodes[id.0].fields); let mut fields = std::mem::take(&mut ast.nodes[id].fields);
for children in fields.values_mut() { for children in fields.values_mut() {
let mut new_children: Option<Vec<Id>> = None; let mut new_children: Option<Vec<Id>> = None;
for (i, &child_id) in children.iter().enumerate() { for (i, &child_id) in children.iter().enumerate() {
@@ -1049,7 +1033,7 @@ fn apply_repeating_rules_inner<C: Clone>(
*children = new; *children = new;
} }
} }
ast.nodes[id.0].fields = fields; ast.nodes[id].fields = fields;
Ok(vec![id]) Ok(vec![id])
} }
@@ -1083,7 +1067,7 @@ fn apply_one_shot_rules_inner<C: Clone>(
)); ));
} }
let node_kind = ast.get_node(id).map(|n| n.kind_name()).unwrap_or(""); let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or("");
for rule in index.rules_for_kind(node_kind) { for rule in index.rules_for_kind(node_kind) {
if let Some(captures) = rule.try_match(ast, id)? { if let Some(captures) = rule.try_match(ast, id)? {

View File

@@ -49,7 +49,7 @@ impl Visitor {
pub fn build_with_schema(self, schema: crate::schema::Schema) -> Ast { pub fn build_with_schema(self, schema: crate::schema::Schema) -> Ast {
Ast { Ast {
root: Id(0), root: 0,
schema, schema,
nodes: self.nodes.into_iter().map(|n| n.inner).collect(), nodes: self.nodes.into_iter().map(|n| n.inner).collect(),
source: Vec::new(), source: Vec::new(),
@@ -72,7 +72,7 @@ impl Visitor {
}, },
parent: self.current, parent: self.current,
}); });
Id(id) id
} }
fn enter_node(&mut self, node: tree_sitter::Node<'_>) -> bool { fn enter_node(&mut self, node: tree_sitter::Node<'_>) -> bool {
@@ -83,10 +83,10 @@ impl Visitor {
fn leave_node(&mut self, field_name: Option<&'static str>, _node: tree_sitter::Node<'_>) { fn leave_node(&mut self, field_name: Option<&'static str>, _node: tree_sitter::Node<'_>) {
let node_id = self.current.unwrap(); let node_id = self.current.unwrap();
let node_parent = self.nodes[node_id.0].parent; let node_parent = self.nodes[node_id].parent;
if let Some(parent_id) = node_parent { if let Some(parent_id) = node_parent {
let parent = self.nodes.get_mut(parent_id.0).unwrap(); let parent = self.nodes.get_mut(parent_id).unwrap();
if let Some(field) = field_name { if let Some(field) = field_name {
let field_id = self.language.field_id_for_name(field).unwrap().get(); let field_id = self.language.field_id_for_name(field).unwrap().get();
parent parent

View File

@@ -300,7 +300,7 @@ fn test_query_skips_extras_in_positional_match() {
let mut cursor = AstCursor::new(&ast); let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child(); cursor.goto_first_child();
let array_id = cursor.node_id(); let array_id = cursor.node_id();
assert_eq!(ast.get_node(array_id).unwrap().kind_name(), "array"); assert_eq!(ast.get_node(array_id).unwrap().kind(), "array");
// Two positional wildcards should bind to the two integers, skipping // Two positional wildcards should bind to the two integers, skipping
// the comment that sits between them. // the comment that sits between them.
@@ -309,15 +309,11 @@ fn test_query_skips_extras_in_positional_match() {
let matched = query.do_match(&ast, array_id, &mut captures).unwrap(); let matched = query.do_match(&ast, array_id, &mut captures).unwrap();
assert!(matched); assert!(matched);
assert_eq!( assert_eq!(
ast.get_node(captures.get_var("a").unwrap()) ast.get_node(captures.get_var("a").unwrap()).unwrap().kind(),
.unwrap()
.kind_name(),
"integer" "integer"
); );
assert_eq!( assert_eq!(
ast.get_node(captures.get_var("b").unwrap()) ast.get_node(captures.get_var("b").unwrap()).unwrap().kind(),
.unwrap()
.kind_name(),
"integer" "integer"
); );
} }
@@ -395,7 +391,7 @@ fn test_capture_unnamed_node_parenthesized() {
assert!(matched); assert!(matched);
let op_id = captures.get_var("op").unwrap(); let op_id = captures.get_var("op").unwrap();
let op_node = ast.get_node(op_id).unwrap(); let op_node = ast.get_node(op_id).unwrap();
assert_eq!(op_node.kind_name(), "="); assert_eq!(op_node.kind(), "=");
assert!(!op_node.is_named()); assert!(!op_node.is_named());
} }
@@ -418,7 +414,7 @@ fn test_capture_bare_underscore_repeated() {
let all = captures.get_all("all"); let all = captures.get_all("all");
assert_eq!(all.len(), 1); assert_eq!(all.len(), 1);
assert_eq!(ast.get_node(all[0]).unwrap().kind_name(), "="); assert_eq!(ast.get_node(all[0]).unwrap().kind(), "=");
assert!(!ast.get_node(all[0]).unwrap().is_named()); assert!(!ast.get_node(all[0]).unwrap().is_named());
} }
@@ -445,7 +441,7 @@ fn test_capture_unnamed_node_bare_literal() {
assert!(matched); assert!(matched);
let op_id = captures.get_var("op").unwrap(); let op_id = captures.get_var("op").unwrap();
let op_node = ast.get_node(op_id).unwrap(); let op_node = ast.get_node(op_id).unwrap();
assert_eq!(op_node.kind_name(), "="); assert_eq!(op_node.kind(), "=");
assert!(!op_node.is_named()); assert!(!op_node.is_named());
} }
@@ -483,7 +479,7 @@ fn test_bare_underscore_matches_unnamed() {
.unwrap(); .unwrap();
assert!(matched, "_ should match the unnamed `=`"); assert!(matched, "_ should match the unnamed `=`");
let any_node = ast.get_node(captures.get_var("any").unwrap()).unwrap(); let any_node = ast.get_node(captures.get_var("any").unwrap()).unwrap();
assert_eq!(any_node.kind_name(), "="); assert_eq!(any_node.kind(), "=");
assert!(!any_node.is_named()); assert!(!any_node.is_named());
} }
@@ -510,7 +506,7 @@ fn test_bare_forms_in_field_position() {
assert_eq!( assert_eq!(
ast.get_node(captures.get_var("lhs").unwrap()) ast.get_node(captures.get_var("lhs").unwrap())
.unwrap() .unwrap()
.kind_name(), .kind(),
"identifier" "identifier"
); );
@@ -520,7 +516,7 @@ fn test_bare_forms_in_field_position() {
let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap(); let matched = query.do_match(&ast, assignment_id, &mut captures).unwrap();
assert!(matched); assert!(matched);
let op = ast.get_node(captures.get_var("op").unwrap()).unwrap(); let op = ast.get_node(captures.get_var("op").unwrap()).unwrap();
assert_eq!(op.kind_name(), "="); assert_eq!(op.kind(), "=");
assert!(!op.is_named()); assert!(!op.is_named());
} }
@@ -539,7 +535,7 @@ fn test_forward_scan_finds_unnamed_token_late() {
let mut cursor = AstCursor::new(&ast); let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child(); // for cursor.goto_first_child(); // for
cursor.goto_first_child(); // do (the body) cursor.goto_first_child(); // do (the body)
while cursor.node().kind_name() != "do" || !cursor.node().is_named() { while cursor.node().kind() != "do" || !cursor.node().is_named() {
assert!(cursor.goto_next_sibling(), "expected to find named `do`"); assert!(cursor.goto_next_sibling(), "expected to find named `do`");
} }
let do_id = cursor.node_id(); let do_id = cursor.node_id();
@@ -549,7 +545,7 @@ fn test_forward_scan_finds_unnamed_token_late() {
let matched = query.do_match(&ast, do_id, &mut captures).unwrap(); let matched = query.do_match(&ast, do_id, &mut captures).unwrap();
assert!(matched, "forward-scan should find the `end` keyword"); assert!(matched, "forward-scan should find the `end` keyword");
let kw = ast.get_node(captures.get_var("kw").unwrap()).unwrap(); let kw = ast.get_node(captures.get_var("kw").unwrap()).unwrap();
assert_eq!(kw.kind_name(), "end"); assert_eq!(kw.kind(), "end");
assert!(!kw.is_named()); assert!(!kw.is_named());
} }
@@ -565,7 +561,7 @@ fn test_forward_scan_preserves_order() {
let mut cursor = AstCursor::new(&ast); let mut cursor = AstCursor::new(&ast);
cursor.goto_first_child(); cursor.goto_first_child();
cursor.goto_first_child(); cursor.goto_first_child();
while cursor.node().kind_name() != "do" || !cursor.node().is_named() { while cursor.node().kind() != "do" || !cursor.node().is_named() {
assert!(cursor.goto_next_sibling(), "expected to find named `do`"); assert!(cursor.goto_next_sibling(), "expected to find named `do`");
} }
let do_id = cursor.node_id(); let do_id = cursor.node_id();
@@ -639,7 +635,7 @@ fn ruby_rules() -> Vec<Rule> {
left: (identifier $tmp) left: (identifier $tmp)
right: {right} right: {right}
) )
{left.iter().enumerate().map(|(i, &lhs)| {..left.iter().enumerate().map(|(i, &lhs)|
yeast::tree!( yeast::tree!(
(assignment (assignment
left: {lhs} left: {lhs}
@@ -671,7 +667,7 @@ fn ruby_rules() -> Vec<Rule> {
left: {pat} left: {pat}
right: (identifier $tmp) right: (identifier $tmp)
) )
stmt: {body} stmt: {..body}
) )
) )
) )
@@ -907,7 +903,7 @@ fn one_shot_xeq1_rules() -> Vec<Rule> {
yeast::rule!( yeast::rule!(
(program (_)* @stmts) (program (_)* @stmts)
=> =>
(program stmt: {stmts}) (program stmt: {..stmts})
), ),
yeast::rule!( yeast::rule!(
(assignment left: (_) @left right: (_) @right) (assignment left: (_) @left right: (_) @right)
@@ -983,7 +979,7 @@ fn test_one_shot_recurses_into_returned_capture() {
yeast::rule!( yeast::rule!(
(program (_)* @stmts) (program (_)* @stmts)
=> =>
(program stmt: {stmts}) (program stmt: {..stmts})
), ),
// Returns the captured `left` verbatim, discarding `right`. // Returns the captured `left` verbatim, discarding `right`.
yeast::rule!( yeast::rule!(
@@ -1025,7 +1021,7 @@ fn test_one_shot_does_not_recurse_into_wrapper_output() {
yeast::rule!( yeast::rule!(
(program (_)* @stmts) (program (_)* @stmts)
=> =>
(program stmt: {stmts}) (program stmt: {..stmts})
), ),
// Wraps `left` in nested `first_node`/`second_node` output kinds. // Wraps `left` in nested `first_node`/`second_node` output kinds.
// Neither wrapper kind has a matching rule, so a buggy implementation // Neither wrapper kind has a matching rule, so a buggy implementation
@@ -1062,111 +1058,6 @@ fn test_one_shot_does_not_recurse_into_wrapper_output() {
); );
} }
/// Verify that `@@name` capture markers skip the auto-translate prefix:
/// the body sees the *raw* (input-schema) `Id` and can read its
/// source text or call `ctx.translate(...)` explicitly. Compare with
/// the bare `@name` form, where the auto-translate prefix runs the
/// same translation up front and the body sees the post-translate id.
#[test]
fn test_raw_capture_marker() {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema =
yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap();
let rules: Vec<Rule> = vec![
yeast::rule!(
(program (_)* @stmts)
=>
(program stmt: {stmts})
),
// `@@raw_lhs` is untranslated: the body reads its source text
// ("x") and embeds it directly as the identifier content. `@rhs`
// is auto-translated (rhs already points to (integer "INT")).
yeast::rule!(
(assignment left: (_) @@raw_lhs right: (_) @rhs)
=>
{
let text = ctx.ast.source_text(raw_lhs);
tree!((call
method: (identifier #{text.as_str()})
receiver: {rhs}))
}
),
yeast::rule!((identifier) => (identifier "ID")),
yeast::rule!((integer) => (integer "INT")),
];
let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)];
let runner: Runner = Runner::with_schema(lang, &schema, &phases);
let input = "x = 1";
let ast = runner.run(input).unwrap();
let dump = dump_ast(&ast, ast.get_root(), input);
// `method:` uses the raw source text ("x"); if `@@` were broken and
// auto-translation ran on `raw_lhs`, it would still produce the
// string "x" (source_text inherits the input range), so the dump
// wouldn't change here. The companion test
// `test_raw_capture_marker_explicit_translate` exercises the
// stronger property that `ctx.translate(raw_lhs)?` succeeds and
// produces the translated `(identifier "ID")`.
assert_dump_eq(
&dump,
r#"
program
stmt:
call
method: identifier "x"
receiver: integer "INT"
"#,
);
}
/// Companion to `test_raw_capture_marker`: confirms that calling
/// `ctx.translate(raw)` on a `@@`-captured `Id` from the rule body
/// produces the correctly-translated output-schema node. With `@`, the
/// translation has already happened, so `ctx.translate(...)` inside the
/// body would attempt to re-translate an output node (which has no
/// matching rule and would error).
#[test]
fn test_raw_capture_marker_explicit_translate() {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema =
yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap();
let rules: Vec<Rule> = vec![
yeast::rule!(
(program (_)* @stmts)
=>
(program stmt: {stmts})
),
yeast::rule!(
(assignment left: (_) @@raw_lhs right: (_) @rhs)
=>
{
let translated_lhs = ctx.translate(raw_lhs)?;
tree!((call
method: {translated_lhs}
receiver: {rhs}))
}
),
yeast::rule!((identifier) => (identifier "ID")),
yeast::rule!((integer) => (integer "INT")),
];
let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)];
let runner: Runner = Runner::with_schema(lang, &schema, &phases);
let input = "x = 1";
let ast = runner.run(input).unwrap();
let dump = dump_ast(&ast, ast.get_root(), input);
assert_dump_eq(
&dump,
r#"
program
stmt:
call
method: identifier "ID"
receiver: integer "INT"
"#,
);
}
// ---- Cursor tests ---- // ---- Cursor tests ----
#[test] #[test]
@@ -1176,11 +1067,11 @@ fn test_cursor_navigation() {
let mut cursor = AstCursor::new(&ast); let mut cursor = AstCursor::new(&ast);
// Start at root // Start at root
assert_eq!(cursor.node().kind_name(), "program"); assert_eq!(cursor.node().kind(), "program");
// Go to first child (assignment) // Go to first child (assignment)
assert!(cursor.goto_first_child()); assert!(cursor.goto_first_child());
assert_eq!(cursor.node().kind_name(), "assignment"); assert_eq!(cursor.node().kind(), "assignment");
// No sibling // No sibling
assert!(!cursor.goto_next_sibling()); assert!(!cursor.goto_next_sibling());
@@ -1191,10 +1082,10 @@ fn test_cursor_navigation() {
// Go back up // Go back up
assert!(cursor.goto_parent()); assert!(cursor.goto_parent());
assert_eq!(cursor.node().kind_name(), "assignment"); assert_eq!(cursor.node().kind(), "assignment");
assert!(cursor.goto_parent()); assert!(cursor.goto_parent());
assert_eq!(cursor.node().kind_name(), "program"); assert_eq!(cursor.node().kind(), "program");
// Can't go further up // Can't go further up
assert!(!cursor.goto_parent()); assert!(!cursor.goto_parent());
@@ -1239,8 +1130,10 @@ fn test_desugar_for_with_multiple_assignment() {
} }
/// Regression test: `#{capture}` in a template must render the *source text* /// Regression test: `#{capture}` in a template must render the *source text*
/// of the captured node, not its arena `Id`. Captures are bound as `Id`, /// of the captured node, not its arena `Id`. Previously, captures were bound
/// whose `YeastDisplay` impl resolves to the captured node's source text. /// as `usize`, so `#{cap}` printed the integer id (e.g. `"3"`) via `Display`.
/// Captures are now bound as `NodeRef`, which has no `Display` impl and
/// resolves to the captured node's source text via `YeastDisplay`.
#[test] #[test]
fn test_hash_brace_renders_capture_source_text() { fn test_hash_brace_renders_capture_source_text() {
let rule: Rule = rule!( let rule: Rule = rule!(
@@ -1268,7 +1161,7 @@ fn test_hash_brace_renders_capture_source_text() {
); );
} }
/// Regression test: non-`Id` values in `#{expr}` still render via their /// Regression test: non-`NodeRef` values in `#{expr}` still render via their
/// `Display` impl (covered by `YeastDisplay`'s blanket impls for primitives). /// `Display` impl (covered by `YeastDisplay`'s blanket impls for primitives).
#[test] #[test]
fn test_hash_brace_renders_integer_expression() { fn test_hash_brace_renders_integer_expression() {
@@ -1306,12 +1199,12 @@ fn test_hash_brace_uses_capture_location_for_leaf() {
let ast = run_and_ast("foo.bar()", vec![rule]); let ast = run_and_ast("foo.bar()", vec![rule]);
let mut bar_ids: Vec<yeast::Id> = Vec::new(); let mut bar_ids: Vec<usize> = Vec::new();
for id in ast.reachable_node_ids() { for id in ast.reachable_node_ids() {
let Some(node) = ast.get_node(id) else { let Some(node) = ast.get_node(id) else {
continue; continue;
}; };
if node.kind_name() == "identifier" && ast.source_text(id) == "bar" { if node.kind() == "identifier" && ast.source_text(id) == "bar" {
bar_ids.push(id); bar_ids.push(id);
} }
} }

View File

@@ -1,5 +1,5 @@
use codeql_extractor::extractor::simple; use codeql_extractor::extractor::simple;
use yeast::{ConcreteDesugarer, DesugaringConfig, PhaseKind, Rule, rule, tree}; use yeast::{ConcreteDesugarer, DesugaringConfig, PhaseKind, Rule, manual_rule, rule, tree};
/// User context propagated from outer rules down to the inner rules that /// User context propagated from outer rules down to the inner rules that
/// emit the corresponding output declarations, so that each emitted node /// emit the corresponding output declarations, so that each emitted node
@@ -45,7 +45,7 @@ struct SwiftContext {
/// Build a freshly-created `chained_declaration` modifier node if /// Build a freshly-created `chained_declaration` modifier node if
/// `ctx.is_chained`, else `None`. Used by inner declaration rules to /// `ctx.is_chained`, else `None`. Used by inner declaration rules to
/// emit the chained tag for non-first children of a flattening outer /// emit the chained tag for non-first children of a flattening outer
/// rule. Returns `Option<Id>` so it splices via `{…}` to 0 or 1 ids. /// rule. Returns `Option<Id>` so it splices via `{..…}` to 0 or 1 ids.
fn chained_modifier(ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>) -> Option<yeast::Id> { fn chained_modifier(ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>) -> Option<yeast::Id> {
if ctx.is_chained { if ctx.is_chained {
Some(ctx.literal("modifier", "chained_declaration")) Some(ctx.literal("modifier", "chained_declaration"))
@@ -63,10 +63,10 @@ fn chained_modifier(ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>) -> Optio
/// condition. /// condition.
fn and_chain( fn and_chain(
ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>, ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>,
conds: Vec<yeast::Id>, conds: Vec<yeast::NodeRef>,
) -> yeast::Id { ) -> yeast::Id {
conds conds.into_iter()
.into_iter() .map(yeast::Id::from)
.reduce(|acc, elem| { .reduce(|acc, elem| {
tree!((binary_expr operator: (infix_operator "&&") left: {acc} right: {elem})) tree!((binary_expr operator: (infix_operator "&&") left: {acc} right: {elem}))
}) })
@@ -79,7 +79,7 @@ fn and_chain(
/// guarantees at least one part. /// guarantees at least one part.
fn member_chain( fn member_chain(
ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>, ctx: &mut yeast::build::BuildCtx<'_, SwiftContext>,
parts: Vec<yeast::Id>, parts: Vec<yeast::NodeRef>,
) -> yeast::Id { ) -> yeast::Id {
let mut iter = parts.into_iter(); let mut iter = parts.into_iter();
let first = iter let first = iter
@@ -100,7 +100,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(source_file statement: _* @children) (source_file statement: _* @children)
=> =>
(top_level (top_level
body: (block stmt: {children}) body: (block stmt: {..children})
) )
), ),
// Declarations may be wrapped in local/global wrapper nodes. // Declarations may be wrapped in local/global wrapper nodes.
@@ -144,12 +144,12 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
rule!( rule!(
(operator_declaration "prefix" (referenceable_operator _ @op) (simple_identifier)? @prec) (operator_declaration "prefix" (referenceable_operator _ @op) (simple_identifier)? @prec)
=> =>
(operator_syntax_declaration name: (identifier #{op}) fixity: (fixity "prefix") precedence: {prec}) (operator_syntax_declaration name: (identifier #{op}) fixity: (fixity "prefix") precedence: {..prec})
), ),
rule!( rule!(
(operator_declaration "postfix" (referenceable_operator _ @op) (simple_identifier)? @prec) (operator_declaration "postfix" (referenceable_operator _ @op) (simple_identifier)? @prec)
=> =>
(operator_syntax_declaration name: (identifier #{op}) fixity: (fixity "postfix") precedence: {prec}) (operator_syntax_declaration name: (identifier #{op}) fixity: (fixity "postfix") precedence: {..prec})
), ),
rule!( rule!(
(operator_declaration "infix" (referenceable_operator _ @op) (simple_identifier)? @prec) (operator_declaration "infix" (referenceable_operator _ @op) (simple_identifier)? @prec)
@@ -157,7 +157,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(operator_syntax_declaration (operator_syntax_declaration
name: (identifier #{op}) name: (identifier #{op})
fixity: (fixity "infix") fixity: (fixity "infix")
precedence: {prec}) precedence: {..prec})
), ),
rule!((bitwise_operation lhs: @l op: @op rhs: @r) => (binary_expr left: {l} operator: (infix_operator #{op}) right: {r})), rule!((bitwise_operation lhs: @l op: @op rhs: @r) => (binary_expr left: {l} operator: (infix_operator #{op}) right: {r})),
rule!((nil_coalescing_expression value: @l if_nil: @r) => (binary_expr left: {l} operator: (infix_operator "??") right: {r})), rule!((nil_coalescing_expression value: @l if_nil: @r) => (binary_expr left: {l} operator: (infix_operator "??") right: {r})),
@@ -170,9 +170,9 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
rule!((postfix_expression operation: @op target: @operand) => (unary_expr operator: (postfix_operator #{op}) operand: {operand})), rule!((postfix_expression operation: @op target: @operand) => (unary_expr operator: (postfix_operator #{op}) operand: {operand})),
// TODO: Parenthesised single-value tuple is a grouping expression and should pass through. // TODO: Parenthesised single-value tuple is a grouping expression and should pass through.
// Multi-value tuples become tuple_expr. // Multi-value tuples become tuple_expr.
rule!((tuple_expression value: _* @v) => (tuple_expr element: {v})), rule!((tuple_expression value: _* @v) => (tuple_expr element: {..v})),
// Blocks contain statement* directly. // Blocks contain statement* directly.
rule!((block statement: _+ @stmts) => (block stmt: {stmts})), rule!((block statement: _+ @stmts) => (block stmt: {..stmts})),
rule!((block) => (block)), rule!((block) => (block)),
// ---- Variables ---- // ---- Variables ----
// property_binding rules — these produce variable_declaration and/or accessor_declaration // property_binding rules — these produce variable_declaration and/or accessor_declaration
@@ -192,15 +192,21 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// this whole property_binding is itself a non-first declarator // this whole property_binding is itself a non-first declarator
// of a containing property_declaration); subsequent accessors // of a containing property_declaration); subsequent accessors
// always emit `chained_declaration`. // always emit `chained_declaration`.
rule!( manual_rule!(
(property_binding (property_binding
name: @pattern name: @pattern
type: _? @ty type: _? @ty
computed_value: (computed_property accessor: _+ @@accessors)) computed_value: (computed_property accessor: _+ @accessors))
=> {
{{ // Translate `ty` first so the context holds an
ctx.property_name = Some(tree!((identifier #{pattern}))); // output-schema node id.
ctx.property_type = ty; let translated_ty = ctx.translate_opt(ty)?;
// Build the property-name identifier from the
// (untranslated) pattern leaf.
let name_id = tree!((identifier #{pattern}));
ctx.property_name = Some(name_id);
ctx.property_type = translated_ty;
let mut result = Vec::new(); let mut result = Vec::new();
for (i, acc) in accessors.into_iter().enumerate() { for (i, acc) in accessors.into_iter().enumerate() {
@@ -209,8 +215,8 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
} }
result.extend(ctx.translate(acc)?); result.extend(ctx.translate(acc)?);
} }
result Ok(result)
}} }
), ),
// Computed property: shorthand getter (no explicit get/set, just // Computed property: shorthand getter (no explicit get/set, just
// statements) → a single accessor_declaration with kind "get". // statements) → a single accessor_declaration with kind "get".
@@ -223,13 +229,13 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
computed_value: (computed_property statement: _* @body)) computed_value: (computed_property statement: _* @body))
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: (identifier #{name}) name: (identifier #{name})
type: {ty} type: {..ty}
accessor_kind: (accessor_kind "get") accessor_kind: (accessor_kind "get")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Stored property with willSet/didSet observers (initializer // Stored property with willSet/didSet observers (initializer
// optional) → a `variable_declaration` followed by one // optional) → a `variable_declaration` followed by one
@@ -242,22 +248,26 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// The `variable_declaration` itself inherits the outer rule's // The `variable_declaration` itself inherits the outer rule's
// chained state; observers always get `chained_declaration` // chained state; observers always get `chained_declaration`
// because they're subsequent outputs of this flattening rule. // because they're subsequent outputs of this flattening rule.
rule!( manual_rule!(
(property_binding (property_binding
name: (pattern bound_identifier: @name) name: (pattern bound_identifier: @name)
type: _? @ty type: _? @ty
value: _? @val value: _? @val
observers: (willset_didset_block willset: _? @@ws didset: _? @@ds)) observers: (willset_didset_block willset: _? @ws didset: _? @ds))
=> {
{{ // Translate ty and val so the variable_declaration
// below contains output-schema nodes.
let translated_ty = ctx.translate_opt(ty)?;
let translated_val = ctx.translate_opt(val)?;
let var_decl = tree!( let var_decl = tree!(
(variable_declaration (variable_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
type: {ty} type: {..translated_ty}
value: {val}) value: {..translated_val})
); );
// Publish the property name for the observer rules. // Publish the property name for the observer rules.
@@ -270,8 +280,8 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
for obs in ws.into_iter().chain(ds) { for obs in ws.into_iter().chain(ds) {
result.extend(ctx.translate(obs)?); result.extend(ctx.translate(obs)?);
} }
result Ok(result)
}} }
), ),
// property_binding with any pattern name (identifier or // property_binding with any pattern name (identifier or
// destructuring). Reads outer modifiers / chained tag from `ctx`. // destructuring). Reads outer modifiers / chained tag from `ctx`.
@@ -282,12 +292,12 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
value: _? @val) value: _? @val)
=> =>
(variable_declaration (variable_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
pattern: {pattern} pattern: {pattern}
type: {ty} type: {..ty}
value: {val}) value: {..val})
), ),
// property_declaration: flatten declarators (each may translate // property_declaration: flatten declarators (each may translate
// to multiple nodes — variable_declaration and/or // to multiple nodes — variable_declaration and/or
@@ -299,24 +309,27 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// inner declaration rules (`property_binding` variants, // inner declaration rules (`property_binding` variants,
// accessor inner rules) read these fields and emit complete // accessor inner rules) read these fields and emit complete
// `modifier:` lists from the start. // `modifier:` lists from the start.
rule!( manual_rule!(
(property_declaration (property_declaration
binding: (value_binding_pattern mutability: @@binding_kind) binding: (value_binding_pattern mutability: @binding_kind)
declarator: _* @@decls declarator: _* @decls
(modifiers)* @mods) (modifiers)* @mods)
=> {
{{ let binding_text = ctx.ast.source_text(binding_kind.0);
let binding_text = ctx.ast.source_text(binding_kind);
ctx.binding_modifier = Some(ctx.literal("modifier", &binding_text)); ctx.binding_modifier = Some(ctx.literal("modifier", &binding_text));
ctx.outer_modifiers = mods; let mut modifiers = Vec::new();
for m in mods {
modifiers.extend(ctx.translate(m)?);
}
ctx.outer_modifiers = modifiers;
let mut result = Vec::new(); let mut result = Vec::new();
for (i, decl) in decls.into_iter().enumerate() { for (i, decl) in decls.into_iter().enumerate() {
ctx.is_chained = i > 0; ctx.is_chained = i > 0;
result.extend(ctx.translate(decl)?); result.extend(ctx.translate(decl)?);
} }
result Ok(result)
}} }
), ),
// ---- Enums ---- // ---- Enums ----
// enum_type_parameter → parameter (with optional name as pattern). // enum_type_parameter → parameter (with optional name as pattern).
@@ -342,19 +355,19 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
data_contents: (enum_type_parameters parameter: _* @params)) data_contents: (enum_type_parameters parameter: _* @params))
=> =>
(class_like_declaration (class_like_declaration
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
modifier: (modifier "enum_case") modifier: (modifier "enum_case")
name: (identifier #{name}) name: (identifier #{name})
member: (constructor_declaration parameter: {params} body: (block))) member: (constructor_declaration parameter: {..params} body: (block)))
), ),
// enum_case_entry with explicit raw value → variable_declaration with that value. // enum_case_entry with explicit raw value → variable_declaration with that value.
rule!( rule!(
(enum_case_entry name: @name raw_value: @val) (enum_case_entry name: @name raw_value: @val)
=> =>
(variable_declaration (variable_declaration
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
modifier: (modifier "enum_case") modifier: (modifier "enum_case")
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
value: {val}) value: {val})
@@ -364,8 +377,8 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(enum_case_entry name: @name) (enum_case_entry name: @name)
=> =>
(variable_declaration (variable_declaration
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
modifier: (modifier "enum_case") modifier: (modifier "enum_case")
pattern: (name_pattern identifier: (identifier #{name}))) pattern: (name_pattern identifier: (identifier #{name})))
), ),
@@ -373,19 +386,22 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// into `ctx` and translate each case with `ctx.is_chained` // into `ctx` and translate each case with `ctx.is_chained`
// toggled per iteration so the inner `enum_case_entry` rules // toggled per iteration so the inner `enum_case_entry` rules
// emit complete `modifier:` lists from the start. // emit complete `modifier:` lists from the start.
rule!( manual_rule!(
(enum_entry case: _+ @@cases (modifiers)* @mods) (enum_entry case: _+ @cases (modifiers)* @mods)
=> {
{{ let mut modifiers = Vec::new();
ctx.outer_modifiers = mods; for m in mods {
modifiers.extend(ctx.translate(m)?);
}
ctx.outer_modifiers = modifiers;
let mut result = Vec::new(); let mut result = Vec::new();
for (i, case) in cases.into_iter().enumerate() { for (i, case) in cases.into_iter().enumerate() {
ctx.is_chained = i > 0; ctx.is_chained = i > 0;
result.extend(ctx.translate(case)?); result.extend(ctx.translate(case)?);
} }
result Ok(result)
}} }
), ),
// Plain assignment: `x = expr` // Plain assignment: `x = expr`
rule!( rule!(
@@ -418,7 +434,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(constructor_pattern (constructor_pattern
constructor: (member_access_expr base: {typ} member: (identifier #{name})) constructor: (member_access_expr base: {typ} member: (identifier #{name}))
element: {items}) element: {..items})
), ),
// case .foo(x,y) pattern // case .foo(x,y) pattern
rule!( rule!(
@@ -426,10 +442,10 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(constructor_pattern (constructor_pattern
constructor: (member_access_expr base: (inferred_type_expr #{dot}) member: (identifier #{name})) constructor: (member_access_expr base: (inferred_type_expr #{dot}) member: (identifier #{name}))
element: {items}) element: {..items})
), ),
// Tuple pattern and its (optionally named) items // Tuple pattern and its (optionally named) items
rule!((pattern kind: (tuple_pattern item: _* @elems)) => (tuple_pattern element: {elems})), rule!((pattern kind: (tuple_pattern item: _* @elems)) => (tuple_pattern element: {..elems})),
rule!((tuple_pattern_item name: @key pattern: @pat) => (pattern_element key: (identifier #{key}) pattern: {pat})), rule!((tuple_pattern_item name: @key pattern: @pat) => (pattern_element key: (identifier #{key}) pattern: {pat})),
rule!((tuple_pattern_item pattern: @pat) => (pattern_element pattern: {pat})), rule!((tuple_pattern_item pattern: @pat) => (pattern_element pattern: {pat})),
// Type casting pattern (TODO) // Type casting pattern (TODO)
@@ -452,21 +468,20 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(function_declaration (function_declaration
name: (identifier #{name}) name: (identifier #{name})
parameter: {params} parameter: {..params}
return_type: {ret} return_type: {..ret}
body: (block stmt: {body_stmts})) body: (block stmt: {..body_stmts}))
), ),
// Parameters are wrapped in function_parameter, which also carries // Parameters are wrapped in function_parameter, which also carries
// optional default values. Publishes the default value into `ctx` // optional default values. Publishes the default value into `ctx`
// before translating the inner `parameter` so the `parameter` // before translating the inner `parameter` so the `parameter`
// rules can include it as a `default:` field directly. // rules can include it as a `default:` field directly.
rule!( manual_rule!(
(function_parameter parameter: @@p default_value: _? @def) (function_parameter parameter: @p default_value: _? @def)
=> {
{{ ctx.default_value = ctx.translate_opt(def)?;
ctx.default_value = def; ctx.translate(p)
ctx.translate(p)? }
}}
), ),
// Parameter with external name and type // Parameter with external name and type
rule!( rule!(
@@ -475,7 +490,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(parameter (parameter
external_name: (identifier #{ext}) external_name: (identifier #{ext})
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
default: {ctx.default_value}) default: {..ctx.default_value})
), ),
rule!( rule!(
(parameter external_name: @ext name: @name type: @ty) (parameter external_name: @ext name: @name type: @ty)
@@ -484,7 +499,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
external_name: (identifier #{ext}) external_name: (identifier #{ext})
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
type: {ty} type: {ty}
default: {ctx.default_value}) default: {..ctx.default_value})
), ),
// Parameter with just name and type (no external name) // Parameter with just name and type (no external name)
rule!( rule!(
@@ -492,7 +507,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(parameter (parameter
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
default: {ctx.default_value}) default: {..ctx.default_value})
), ),
rule!( rule!(
(parameter name: @name type: @ty) (parameter name: @name type: @ty)
@@ -500,7 +515,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(parameter (parameter
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
type: {ty} type: {ty}
default: {ctx.default_value}) default: {..ctx.default_value})
), ),
// Reference to a function, f(x:y:z:). This is parsed as a call with a single argument with multiple reference_specifier labels. // Reference to a function, f(x:y:z:). This is parsed as a call with a single argument with multiple reference_specifier labels.
// We don't want downstream QL to try to handle this as a call_expr with a weird argument, so explicitly mark it as unsupported for now. // We don't want downstream QL to try to handle this as a call_expr with a weird argument, so explicitly mark it as unsupported for now.
@@ -514,7 +529,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
rule!( rule!(
(call_expression function: @func suffix: (call_suffix arguments: (value_arguments argument: (value_argument)* @args))) (call_expression function: @func suffix: (call_suffix arguments: (value_arguments argument: (value_argument)* @args)))
=> =>
(call_expr callee: {func} argument: {args}) (call_expr callee: {func} argument: {..args})
), ),
// Value argument with label (value: _ matches both named nodes and anonymous tokens like nil) // Value argument with label (value: _ matches both named nodes and anonymous tokens like nil)
rule!( rule!(
@@ -537,7 +552,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// Return / break / continue, one rule per keyword. // Return / break / continue, one rule per keyword.
// The anonymous "return"/"break"/"continue" keywords are matched as // The anonymous "return"/"break"/"continue" keywords are matched as
// string literals. // string literals.
rule!((control_transfer_statement kind: "return" result: _? @val) => (return_expr value: {val})), rule!((control_transfer_statement kind: "return" result: _? @val) => (return_expr value: {..val})),
rule!((control_transfer_statement kind: "break" result: @lbl) => (break_expr label: (identifier #{lbl}))), rule!((control_transfer_statement kind: "break" result: @lbl) => (break_expr label: (identifier #{lbl}))),
rule!((control_transfer_statement kind: "break") => (break_expr)), rule!((control_transfer_statement kind: "break") => (break_expr)),
rule!((control_transfer_statement kind: "continue" result: @lbl) => (continue_expr label: (identifier #{lbl}))), rule!((control_transfer_statement kind: "continue" result: @lbl) => (continue_expr label: (identifier #{lbl}))),
@@ -556,20 +571,20 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
statement: _* @body) statement: _* @body)
=> =>
(function_expr (function_expr
modifier: {attrs} modifier: {..attrs}
capture_declaration: {captures} capture_declaration: {..captures}
parameter: {params} parameter: {..params}
return_type: {ret} return_type: {..ret}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// capture_list_item with ownership modifier (e.g. [weak self], [unowned x]) // capture_list_item with ownership modifier (e.g. [weak self], [unowned x])
rule!( rule!(
(capture_list_item ownership: _? @ownership name: @name value: _? @val) (capture_list_item ownership: _? @ownership name: @name value: _? @val)
=> =>
(variable_declaration (variable_declaration
modifier: {ownership} modifier: {..ownership}
pattern: (name_pattern identifier: (identifier #{name})) pattern: (name_pattern identifier: (identifier #{name}))
value: {val}) value: {..val})
), ),
// Lambda parameter with type and optional external name // Lambda parameter with type and optional external name
rule!( rule!(
@@ -615,7 +630,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(if_expr (if_expr
condition: {and_chain(&mut ctx, cond)} condition: {and_chain(&mut ctx, cond)}
then: {then_body} then: {then_body}
else: {else_stmts}) else: {..else_stmts})
), ),
// Guard statement // Guard statement
rule!( rule!(
@@ -623,7 +638,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(guard_if_stmt (guard_if_stmt
condition: {and_chain(&mut ctx, cond)} condition: {and_chain(&mut ctx, cond)}
else: (block stmt: {else_stmts})) else: (block stmt: {..else_stmts}))
), ),
// Ternary expression → if_expr // Ternary expression → if_expr
rule!( rule!(
@@ -635,7 +650,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
rule!( rule!(
(switch_statement expr: @val entry: (switch_entry)* @cases) (switch_statement expr: @val entry: (switch_entry)* @cases)
=> =>
(switch_expr value: {val} case: {cases}) (switch_expr value: {val} case: {..cases})
), ),
// Switch entry with multiple patterns and body // Switch entry with multiple patterns and body
rule!( rule!(
@@ -644,19 +659,19 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
pattern: (switch_pattern pattern: @rest)+ pattern: (switch_pattern pattern: @rest)+
statement: _* @body) statement: _* @body)
=> =>
(switch_case pattern: (or_pattern pattern: {first} pattern: {rest}) body: (block stmt: {body})) (switch_case pattern: (or_pattern pattern: {first} pattern: {..rest}) body: (block stmt: {..body}))
), ),
// Switch entry with exactly one pattern and body // Switch entry with exactly one pattern and body
rule!( rule!(
(switch_entry pattern: (switch_pattern pattern: @pat) statement: _* @body) (switch_entry pattern: (switch_pattern pattern: @pat) statement: _* @body)
=> =>
(switch_case pattern: {pat} body: (block stmt: {body})) (switch_case pattern: {pat} body: (block stmt: {..body}))
), ),
// Switch entry: default case (no patterns) // Switch entry: default case (no patterns)
rule!( rule!(
(switch_entry default: (default_keyword) statement: _* @body) (switch_entry default: (default_keyword) statement: _* @body)
=> =>
(switch_case body: (block stmt: {body})) (switch_case body: (block stmt: {..body}))
), ),
// if case PATTERN = expr — preserve the pattern directly (no Optional wrapping) // if case PATTERN = expr — preserve the pattern directly (no Optional wrapping)
rule!( rule!(
@@ -702,8 +717,8 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(for_each_stmt (for_each_stmt
pattern: {pat} pattern: {pat}
iterable: {iter} iterable: {iter}
guard: {guard} guard: {..guard}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// While loop // While loop
rule!( rule!(
@@ -711,7 +726,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(while_stmt (while_stmt
condition: {and_chain(&mut ctx, cond)} condition: {and_chain(&mut ctx, cond)}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Repeat-while loop // Repeat-while loop
rule!( rule!(
@@ -719,28 +734,28 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(do_while_stmt (do_while_stmt
condition: {and_chain(&mut ctx, cond)} condition: {and_chain(&mut ctx, cond)}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Labeled statement (e.g. `outer: for ...`). Strip the trailing ':' from the label token. // Labeled statement (e.g. `outer: for ...`). Strip the trailing ':' from the label token.
rule!((labeled_statement label: (statement_label) @lbl statement: @stmt) => { rule!((labeled_statement label: (statement_label) @lbl statement: @stmt) => {
let text = ctx.ast.source_text(lbl); let text = ctx.ast.source_text(lbl.into());
let name = &text[..text.len() - 1]; let name = &text[..text.len() - 1];
tree!((labeled_stmt label: (identifier #{name}) stmt: {stmt})) tree!((labeled_stmt label: (identifier #{name}) stmt: {stmt}))
}), }),
// ---- Collections ---- // ---- Collections ----
// Array literal // Array literal
rule!((array_literal element: _* @elems) => (array_literal element: {elems})), rule!((array_literal element: _* @elems) => (array_literal element: {..elems})),
// Empty array literal // Empty array literal
rule!((array_literal) => (array_literal)), rule!((array_literal) => (array_literal)),
// Dictionary literal — zip keys and values into key_value_pairs // Dictionary literal — zip keys and values into key_value_pairs
rule!( rule!(
(dictionary_literal key: _* @keys value: _* @vals) (dictionary_literal key: _* @keys value: _* @vals)
=> =>
(map_literal element: {keys.into_iter().zip(vals).map(|(k, v)| (map_literal element: {..keys.into_iter().zip(vals).map(|(k, v)|
tree!((key_value_pair key: {k} value: {v})) tree!((key_value_pair key: {k} value: {v}))
)}) )})
), ),
rule!((dictionary_literal element: _* @elems) => (map_literal element: {elems})), rule!((dictionary_literal element: _* @elems) => (map_literal element: {..elems})),
rule!((dictionary_literal_item key: @k value: @v) => (key_value_pair key: {k} value: {v})), rule!((dictionary_literal_item key: @k value: @v) => (key_value_pair key: {k} value: {v})),
// ---- Optionals and errors ---- // ---- Optionals and errors ----
// Optional chaining — unwrap the marker // Optional chaining — unwrap the marker
@@ -753,8 +768,8 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(do_statement body: (block statement: _* @body) catch: (catch_block)* @catches) (do_statement body: (block statement: _* @body) catch: (catch_block)* @catches)
=> =>
(try_expr (try_expr
body: (block stmt: {body}) body: (block stmt: {..body})
catch_clause: {catches}) catch_clause: {..catches})
), ),
// Catch block with bound identifier; optional where-clause guard. // Catch block with bound identifier; optional where-clause guard.
rule!( rule!(
@@ -766,14 +781,14 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(catch_clause (catch_clause
pattern: {pattern} pattern: {pattern}
guard: {guard} guard: {..guard}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Catch block without error binding // Catch block without error binding
rule!( rule!(
(catch_block keyword: (catch_keyword) body: (block statement: _* @body)) (catch_block keyword: (catch_keyword) body: (block statement: _* @body))
=> =>
(catch_clause body: (block stmt: {body})) (catch_clause body: (block stmt: {..body}))
), ),
// Empty catch block: catch {} // Empty catch block: catch {}
rule!( rule!(
@@ -787,7 +802,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(catch_clause (catch_clause
pattern: {pat} pattern: {pat}
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// As expression (type cast) — as?, as! // As expression (type cast) — as?, as!
rule!((as_expression (as_operator) @op expr: @val type: @ty) => (type_cast_expr expr: {val} operator: (infix_operator #{op}) type: {ty})), rule!((as_expression (as_operator) @op expr: @val type: @ty) => (type_cast_expr expr: {val} operator: (infix_operator #{op}) type: {ty})),
@@ -812,7 +827,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
pattern: (name_pattern identifier: (identifier #{parts.last().unwrap()})) pattern: (name_pattern identifier: (identifier #{parts.last().unwrap()}))
imported_expr: {name} imported_expr: {name}
modifier: (modifier #{kind}) modifier: (modifier #{kind})
modifier: {mods}) modifier: {..mods})
), ),
// Non-scoped import declaration (for example `import Foundation`): // Non-scoped import declaration (for example `import Foundation`):
// flatten the identifier parts into a member_access_expr and use a // flatten the identifier parts into a member_access_expr and use a
@@ -823,7 +838,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(import_declaration (import_declaration
pattern: (bulk_importing_pattern) pattern: (bulk_importing_pattern)
imported_expr: {name} imported_expr: {name}
modifier: {mods}) modifier: {..mods})
), ),
// ---- Types and classes ---- // ---- Types and classes ----
// Self expression → name_expr // Self expression → name_expr
@@ -831,7 +846,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// Super expression → super_expr // Super expression → super_expr
rule!((super_expression) => (super_expr)), rule!((super_expression) => (super_expr)),
// Modifiers — unwrap to individual modifier children // Modifiers — unwrap to individual modifier children
rule!((modifiers _* @mods) => {mods}), rule!((modifiers _* @mods) => {..mods}),
rule!((attribute) @m => (modifier #{m})), rule!((attribute) @m => (modifier #{m})),
rule!((visibility_modifier) @m => (modifier #{m})), rule!((visibility_modifier) @m => (modifier #{m})),
rule!((function_modifier) @m => (modifier #{m})), rule!((function_modifier) @m => (modifier #{m})),
@@ -848,7 +863,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// Keep a conservative textual fallback to avoid dropping type information. // Keep a conservative textual fallback to avoid dropping type information.
rule!((user_type) @ty => (named_type_expr name: (identifier #{ty}))), rule!((user_type) @ty => (named_type_expr name: (identifier #{ty}))),
// Tuple type → tuple_type_expr // Tuple type → tuple_type_expr
rule!((tuple_type element: _* @elems) => (tuple_type_expr element: {elems})), rule!((tuple_type element: _* @elems) => (tuple_type_expr element: {..elems})),
rule!((tuple_type_item name: @name type: @ty) => (tuple_type_element name: (identifier #{name}) type: {ty})), rule!((tuple_type_item name: @name type: @ty) => (tuple_type_element name: (identifier #{name}) type: {ty})),
rule!((tuple_type_item type: @ty) => (tuple_type_element type: {ty})), rule!((tuple_type_item type: @ty) => (tuple_type_element type: {ty})),
// Array type `[T]` → generic_type_expr with Array base // Array type `[T]` → generic_type_expr with Array base
@@ -865,7 +880,7 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
base: (named_type_expr name: (identifier "Optional")) base: (named_type_expr name: (identifier "Optional"))
type_argument: {w})), type_argument: {w})),
// Function type `(Params) -> Ret` → function_type_expr. // Function type `(Params) -> Ret` → function_type_expr.
rule!((function_type parameter: _* @ps return_type: @ret) => (function_type_expr parameter: {ps} return_type: {ret})), rule!((function_type parameter: _* @ps return_type: @ret) => (function_type_expr parameter: {..ps} return_type: {ret})),
rule!((function_type_parameter name: @name type: @ty) => (parameter external_name: (identifier #{name}) type: {ty})), rule!((function_type_parameter name: @name type: @ty) => (parameter external_name: (identifier #{name}) type: {ty})),
rule!((function_type_parameter type: @ty) => (parameter type: {ty})), rule!((function_type_parameter type: @ty) => (parameter type: {ty})),
// Selector expression: `#selector(inner)` -- not yet supported // Selector expression: `#selector(inner)` -- not yet supported
@@ -889,10 +904,10 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(class_like_declaration (class_like_declaration
modifier: (modifier #{kind}) modifier: (modifier #{kind})
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
base_type: {bases} base_type: {..bases}
member: {members}) member: {..members})
), ),
// Enum class declaration: same as a regular class but with an enum body. // Enum class declaration: same as a regular class but with an enum body.
rule!( rule!(
@@ -905,10 +920,10 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(class_like_declaration (class_like_declaration
modifier: (modifier #{kind}) modifier: (modifier #{kind})
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
base_type: {bases} base_type: {..bases}
member: {members}) member: {..members})
), ),
// Class declaration with empty body // Class declaration with empty body
rule!( rule!(
@@ -921,9 +936,9 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(class_like_declaration (class_like_declaration
modifier: (modifier #{kind}) modifier: (modifier #{kind})
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
base_type: {bases}) base_type: {..bases})
), ),
// Protocol declaration // Protocol declaration
rule!( rule!(
@@ -935,10 +950,10 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(class_like_declaration (class_like_declaration
modifier: (modifier "protocol") modifier: (modifier "protocol")
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
base_type: {bases} base_type: {..bases}
member: {members}) member: {..members})
), ),
// Protocol function — return type and body statements both optional. // Protocol function — return type and body statements both optional.
rule!( rule!(
@@ -950,11 +965,11 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(modifiers)* @mods) (modifiers)* @mods)
=> =>
(function_declaration (function_declaration
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
parameter: {params} parameter: {..params}
return_type: {ret} return_type: {..ret}
body: (block stmt: {body_stmts})) body: (block stmt: {..body_stmts}))
), ),
// Init declaration → constructor_declaration. Body statements optional; // Init declaration → constructor_declaration. Body statements optional;
// body itself is also optional (protocol requirement). // body itself is also optional (protocol requirement).
@@ -965,9 +980,9 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(modifiers)* @mods) (modifiers)* @mods)
=> =>
(constructor_declaration (constructor_declaration
modifier: {mods} modifier: {..mods}
parameter: {params} parameter: {..params}
body: (block stmt: {body_stmts})) body: (block stmt: {..body_stmts}))
), ),
// Deinit declaration → destructor_declaration. Body statements optional. // Deinit declaration → destructor_declaration. Body statements optional.
rule!( rule!(
@@ -976,15 +991,15 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(modifiers)* @mods) (modifiers)* @mods)
=> =>
(destructor_declaration (destructor_declaration
modifier: {mods} modifier: {..mods}
body: (block stmt: {body_stmts})) body: (block stmt: {..body_stmts}))
), ),
// Typealias declaration // Typealias declaration
rule!( rule!(
(typealias_declaration name: @name value: @val (modifiers)* @mods) (typealias_declaration name: @name value: @val (modifiers)* @mods)
=> =>
(type_alias_declaration (type_alias_declaration
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
r#type: {val}) r#type: {val})
), ),
@@ -999,9 +1014,9 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(associatedtype_declaration name: @name inherits_from: _? @bound (modifiers)* @mods) (associatedtype_declaration name: @name inherits_from: _? @bound (modifiers)* @mods)
=> =>
(associated_type_declaration (associated_type_declaration
modifier: {mods} modifier: {..mods}
name: (identifier #{name}) name: (identifier #{name})
bound: {bound}) bound: {..bound})
), ),
// Protocol property declaration: translate each accessor // Protocol property declaration: translate each accessor
// requirement to an `accessor_declaration` carrying the property // requirement to an `accessor_declaration` carrying the property
@@ -1011,25 +1026,28 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
// inner `getter_specifier`/`setter_specifier` rules emit // inner `getter_specifier`/`setter_specifier` rules emit
// complete nodes from the start (including the // complete nodes from the start (including the
// `chained_declaration` tag for non-first accessors). // `chained_declaration` tag for non-first accessors).
rule!( manual_rule!(
(protocol_property_declaration (protocol_property_declaration
name: (pattern bound_identifier: @name) name: (pattern bound_identifier: @name)
requirements: (protocol_property_requirements accessor: _+ @@accessors) requirements: (protocol_property_requirements accessor: _+ @accessors)
type: _? @ty type: _? @ty
(modifiers)* @mods) (modifiers)* @mods)
=> {
{{
ctx.property_name = Some(tree!((identifier #{name}))); ctx.property_name = Some(tree!((identifier #{name})));
ctx.property_type = ty; ctx.property_type = ctx.translate_opt(ty)?;
ctx.outer_modifiers = mods; let mut modifiers = Vec::new();
for m in mods {
modifiers.extend(ctx.translate(m)?);
}
ctx.outer_modifiers = modifiers;
let mut result = Vec::new(); let mut result = Vec::new();
for (i, acc) in accessors.into_iter().enumerate() { for (i, acc) in accessors.into_iter().enumerate() {
ctx.is_chained = i > 0; ctx.is_chained = i > 0;
result.extend(ctx.translate(acc)?); result.extend(ctx.translate(acc)?);
} }
result Ok(result)
}} }
), ),
// getter_specifier / setter_specifier → bodyless accessor_declaration // getter_specifier / setter_specifier → bodyless accessor_declaration
// getter_specifier / setter_specifier → bodyless // getter_specifier / setter_specifier → bodyless
@@ -1040,23 +1058,23 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
=> =>
(accessor_declaration (accessor_declaration
name: {ctx.property_name.ok_or("getter_specifier outside protocol_property_declaration context")?} name: {ctx.property_name.ok_or("getter_specifier outside protocol_property_declaration context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "get") accessor_kind: (accessor_kind "get")
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)}) modifier: {..chained_modifier(&mut ctx)})
), ),
rule!( rule!(
(setter_specifier) (setter_specifier)
=> =>
(accessor_declaration (accessor_declaration
name: {ctx.property_name.ok_or("setter_specifier outside protocol_property_declaration context")?} name: {ctx.property_name.ok_or("setter_specifier outside protocol_property_declaration context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "set") accessor_kind: (accessor_kind "set")
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)}) modifier: {..chained_modifier(&mut ctx)})
), ),
// protocol_property_requirements wrapper — should be consumed by above; fallback // protocol_property_requirements wrapper — should be consumed by above; fallback
rule!((protocol_property_requirements accessor: _* @accs) => {accs}), rule!((protocol_property_requirements accessor: _* @accs) => {..accs}),
// Computed getter → accessor_declaration (body optional). // Computed getter → accessor_declaration (body optional).
// Reads property name/type from the outer property_binding rule // Reads property name/type from the outer property_binding rule
// and binding/outer modifiers + chained tag from the outer // and binding/outer modifiers + chained tag from the outer
@@ -1065,58 +1083,58 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(computed_getter body: (block statement: _* @body)?) (computed_getter body: (block statement: _* @body)?)
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("computed_getter outside property_binding context")?} name: {ctx.property_name.ok_or("computed_getter outside property_binding context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "get") accessor_kind: (accessor_kind "get")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Computed setter with explicit parameter name. // Computed setter with explicit parameter name.
rule!( rule!(
(computed_setter parameter: @param body: (block statement: _* @body)) (computed_setter parameter: @param body: (block statement: _* @body))
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("computed_setter outside property_binding context")?} name: {ctx.property_name.ok_or("computed_setter outside property_binding context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "set") accessor_kind: (accessor_kind "set")
parameter: (parameter pattern: (name_pattern identifier: (identifier #{param}))) parameter: (parameter pattern: (name_pattern identifier: (identifier #{param})))
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Computed setter without explicit parameter name; body optional. // Computed setter without explicit parameter name; body optional.
rule!( rule!(
(computed_setter body: (block statement: _* @body)?) (computed_setter body: (block statement: _* @body)?)
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("computed_setter outside property_binding context")?} name: {ctx.property_name.ok_or("computed_setter outside property_binding context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "set") accessor_kind: (accessor_kind "set")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Computed modify → accessor_declaration // Computed modify → accessor_declaration
rule!( rule!(
(computed_modify body: (block statement: _* @body)) (computed_modify body: (block statement: _* @body))
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("computed_modify outside property_binding context")?} name: {ctx.property_name.ok_or("computed_modify outside property_binding context")?}
type: {ctx.property_type} type: {..ctx.property_type}
accessor_kind: (accessor_kind "modify") accessor_kind: (accessor_kind "modify")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// willset/didset block — spread to children (only reachable as a // willset/didset block — spread to children (only reachable as a
// fallback; the outer property_binding manual rule normally // fallback; the outer property_binding manual rule normally
// captures the willset/didset clauses directly). // captures the willset/didset clauses directly).
rule!((willset_didset_block _* @clauses) => {clauses}), rule!((willset_didset_block _* @clauses) => {..clauses}),
// willset clause → accessor_declaration (body optional). Reads // willset clause → accessor_declaration (body optional). Reads
// `ctx.property_name` set by the outer property_binding rule and // `ctx.property_name` set by the outer property_binding rule and
// binding/outer modifiers + chained tag from the outer // binding/outer modifiers + chained tag from the outer
@@ -1125,24 +1143,24 @@ fn translation_rules() -> Vec<Rule<SwiftContext>> {
(willset_clause body: (block statement: _* @body)?) (willset_clause body: (block statement: _* @body)?)
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("willset_clause outside property_binding context")?} name: {ctx.property_name.ok_or("willset_clause outside property_binding context")?}
accessor_kind: (accessor_kind "willSet") accessor_kind: (accessor_kind "willSet")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// didset clause → accessor_declaration (body optional). // didset clause → accessor_declaration (body optional).
rule!( rule!(
(didset_clause body: (block statement: _* @body)?) (didset_clause body: (block statement: _* @body)?)
=> =>
(accessor_declaration (accessor_declaration
modifier: {ctx.binding_modifier} modifier: {..ctx.binding_modifier}
modifier: {ctx.outer_modifiers.clone()} modifier: {..ctx.outer_modifiers.clone()}
modifier: {chained_modifier(&mut ctx)} modifier: {..chained_modifier(&mut ctx)}
name: {ctx.property_name.ok_or("didset_clause outside property_binding context")?} name: {ctx.property_name.ok_or("didset_clause outside property_binding context")?}
accessor_kind: (accessor_kind "didSet") accessor_kind: (accessor_kind "didSet")
body: (block stmt: {body})) body: (block stmt: {..body}))
), ),
// Preprocessor conditionals — unsupported // Preprocessor conditionals — unsupported
rule!((diagnostic) => (unsupported_node)), rule!((diagnostic) => (unsupported_node)),