Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
2e6bc6612c Update inline expectations and relearn affected tests 2026-06-26 08:07:15 +00:00
copilot-swe-agent[bot]
5c2614283d Initial plan 2026-06-26 07:28:39 +00:00
119 changed files with 1050 additions and 3831 deletions

View File

@@ -28,6 +28,7 @@
/swift/extractor/ @github/codeql-swift @github/code-scanning-language-coverage
/misc/codegen/ @github/codeql-swift
/java/kotlin-extractor/ @github/codeql-kotlin @github/code-scanning-language-coverage
/java/ql/test-kotlin1/ @github/codeql-kotlin
/java/ql/test-kotlin2/ @github/codeql-kotlin
# Experimental CodeQL cryptography

View File

@@ -3,13 +3,13 @@ class C
void Problems()
{
// correct expectation comment, but only for `problem-query`
var x = "Alert"; // $ Alert
var x = "Alert"; // $ Alert[problem-query]
// irrelevant expectation comment, will be ignored
x = "Not an alert"; // $ IrrelevantTag
// incorrect expectation comment
x = "Also not an alert"; // $ Alert
x = "Also not an alert"; // $ MISSING: Alert[problem-query]
// missing expectation comment, but only for `problem-query`
x = "Alert";

View File

@@ -13,8 +13,6 @@
| InlineTests.cs:88:13:88:23 | "Alert:0:1" | InlineTests.cs:88:13:88:23 | "Alert:0:1" | InlineTests.cs:87:16:87:21 | "Sink" | This is a problem |
edges
testFailures
| InlineTests.cs:6:26:6:35 | // ... | Missing result: Alert |
| InlineTests.cs:12:34:12:43 | // ... | Missing result: Alert |
| InlineTests.cs:37:28:37:38 | // ... | Missing result: Source |
| InlineTests.cs:38:24:38:32 | // ... | Missing result: Sink |
| InlineTests.cs:39:33:39:42 | // ... | Missing result: Alert |

View File

@@ -3,8 +3,6 @@
| InlineTests.cs:100:13:100:25 | "Alert:3:2:1" | InlineTests.cs:97:18:97:25 | "Source" | InlineTests.cs:98:16:98:21 | "Sink" | This is a problem with $@ | InlineTests.cs:99:19:99:27 | "Related" | a related location |
edges
testFailures
| InlineTests.cs:6:26:6:35 | // ... | Missing result: Alert |
| InlineTests.cs:12:34:12:43 | // ... | Missing result: Alert |
| InlineTests.cs:32:32:32:42 | // ... | Missing result: Source |
| InlineTests.cs:33:28:33:36 | // ... | Missing result: Sink |
| InlineTests.cs:34:30:34:39 | // ... | Missing result: Alert |

View File

@@ -3,7 +3,6 @@
| InlineTests.cs:15:13:15:19 | "Alert" | This is a problem |
| InlineTests.cs:18:13:18:19 | "Alert" | This is a problem |
testFailures
| InlineTests.cs:12:34:12:43 | // ... | Missing result: Alert |
| InlineTests.cs:15:13:15:19 | This is a problem | Unexpected result: Alert |
| InlineTests.cs:34:30:34:39 | // ... | Missing result: Alert |
| InlineTests.cs:39:33:39:42 | // ... | Missing result: Alert |

View File

@@ -2,8 +2,6 @@
| InlineTests.cs:22:13:22:21 | "Alert:1" | This is a problem with $@ | InlineTests.cs:21:23:21:31 | "Related" | a related location |
| InlineTests.cs:26:13:26:21 | "Alert:1" | This is a problem with $@ | InlineTests.cs:25:19:25:27 | "Related" | a related location |
testFailures
| InlineTests.cs:6:26:6:35 | // ... | Missing result: Alert |
| InlineTests.cs:12:34:12:43 | // ... | Missing result: Alert |
| InlineTests.cs:25:19:25:27 | "Related" | Unexpected result: RelatedLocation |
| InlineTests.cs:34:30:34:39 | // ... | Missing result: Alert |
| InlineTests.cs:39:33:39:42 | // ... | Missing result: Alert |

View File

@@ -33,11 +33,9 @@ module StoredXss {
walkFn.getACall().getArgument(1) = f.getASuccessor*()
)
or
// The return value of a call to `os.DirEntry.Name`, `os.FileInfo.Name`
// or `os.File.ReadDirNames`.
exists(DataFlow::CallNode cn, Method m | m = cn.getTarget() and this = cn.getResult(0) |
m.implements("io/fs", ["DirEntry", "FileInfo"], "Name") or
m.hasQualifiedName("os", "File", "ReadDirNames")
// A call to os.FileInfo.Name
exists(Method m | m.implements("io/fs", "FileInfo", "Name") |
m = this.(DataFlow::CallNode).getTarget()
)
}
}

View File

@@ -156,3 +156,12 @@ nodes
| websocketXss.go:54:3:54:38 | ... := ...[1] | semmle.label | ... := ...[1] |
| websocketXss.go:55:24:55:31 | gorilla3 | semmle.label | gorilla3 |
subpaths
testFailures
| websocketXss.go:30:32:30:60 | comment | Missing result: Source[go/reflected-xss] |
| websocketXss.go:31:11:31:14 | xnet [postupdate] | Unexpected result: Source |
| websocketXss.go:34:30:34:58 | comment | Missing result: Source[go/reflected-xss] |
| websocketXss.go:35:21:35:25 | xnet2 [postupdate] | Unexpected result: Source |
| websocketXss.go:46:38:46:66 | comment | Missing result: Source[go/reflected-xss] |
| websocketXss.go:47:26:47:35 | gorillaMsg [postupdate] | Unexpected result: Source |
| websocketXss.go:50:33:50:61 | comment | Missing result: Source[go/reflected-xss] |
| websocketXss.go:51:17:51:24 | gorilla2 [postupdate] | Unexpected result: Source |

View File

@@ -1,9 +1,7 @@
#select
| StoredXss.go:13:21:13:36 | ...+... | StoredXss.go:13:21:13:31 | call to Name | StoredXss.go:13:21:13:36 | ...+... | Stored cross-site scripting vulnerability due to $@. | StoredXss.go:13:21:13:31 | call to Name | stored value |
| stored.go:30:22:30:25 | name | stored.go:18:3:18:28 | ... := ...[0] | stored.go:30:22:30:25 | name | Stored cross-site scripting vulnerability due to $@. | stored.go:18:3:18:28 | ... := ...[0] | stored value |
| stored.go:61:22:61:25 | path | stored.go:59:30:59:33 | SSA def(path) | stored.go:61:22:61:25 | path | Stored cross-site scripting vulnerability due to $@. | stored.go:59:30:59:33 | SSA def(path) | stored value |
edges
| StoredXss.go:13:21:13:31 | call to Name | StoredXss.go:13:21:13:36 | ...+... | provenance | |
| stored.go:18:3:18:28 | ... := ...[0] | stored.go:25:14:25:17 | rows | provenance | Src:MaD:1 |
| stored.go:25:14:25:17 | rows | stored.go:25:29:25:33 | &... [postupdate] | provenance | FunctionModel |
| stored.go:25:29:25:33 | &... [postupdate] | stored.go:30:22:30:25 | name | provenance | |
@@ -11,8 +9,6 @@ edges
models
| 1 | Source: database/sql; DB; true; Query; ; ; ReturnValue[0]; database; manual |
nodes
| StoredXss.go:13:21:13:31 | call to Name | semmle.label | call to Name |
| StoredXss.go:13:21:13:36 | ...+... | semmle.label | ...+... |
| stored.go:18:3:18:28 | ... := ...[0] | semmle.label | ... := ...[0] |
| stored.go:25:14:25:17 | rows | semmle.label | rows |
| stored.go:25:29:25:33 | &... [postupdate] | semmle.label | &... [postupdate] |
@@ -20,3 +16,5 @@ nodes
| stored.go:59:30:59:33 | SSA def(path) | semmle.label | SSA def(path) |
| stored.go:61:22:61:25 | path | semmle.label | path |
subpaths
testFailures
| StoredXss.go:13:39:13:63 | comment | Missing result: Alert[go/stored-xss] |

View File

@@ -27,12 +27,12 @@ func xss(w http.ResponseWriter, r *http.Request) {
origin := "test"
{
ws, _ := websocket.Dial(uri, "", origin)
var xnet = make([]byte, 512)
ws.Read(xnet) // $ Source[go/reflected-xss]
var xnet = make([]byte, 512) // $ Source[go/reflected-xss]
ws.Read(xnet)
fmt.Fprintf(w, "%v", xnet) // $ Alert[go/reflected-xss]
codec := &websocket.Codec{Marshal: marshal, Unmarshal: unmarshal}
xnet2 := make([]byte, 512)
codec.Receive(ws, xnet2) // $ Source[go/reflected-xss]
xnet2 := make([]byte, 512) // $ Source[go/reflected-xss]
codec.Receive(ws, xnet2)
fmt.Fprintf(w, "%v", xnet2) // $ Alert[go/reflected-xss]
}
{
@@ -43,12 +43,12 @@ func xss(w http.ResponseWriter, r *http.Request) {
{
dialer := gorilla.Dialer{}
conn, _, _ := dialer.Dial(uri, nil)
var gorillaMsg = make([]byte, 512)
gorilla.ReadJSON(conn, gorillaMsg) // $ Source[go/reflected-xss]
fmt.Fprintf(w, "%v", gorillaMsg) // $ Alert[go/reflected-xss]
var gorillaMsg = make([]byte, 512) // $ Source[go/reflected-xss]
gorilla.ReadJSON(conn, gorillaMsg)
fmt.Fprintf(w, "%v", gorillaMsg) // $ Alert[go/reflected-xss]
gorilla2 := make([]byte, 512)
conn.ReadJSON(gorilla2) // $ Source[go/reflected-xss]
gorilla2 := make([]byte, 512) // $ Source[go/reflected-xss]
conn.ReadJSON(gorilla2)
fmt.Fprintf(w, "%v", gorilla2) // $ Alert[go/reflected-xss]
_, gorilla3, _ := conn.ReadMessage() // $ Source[go/reflected-xss]

View File

@@ -75,9 +75,6 @@ def get_version():
def install(version: str, quiet: bool):
if install_dir.exists():
return
if quiet:
info_out = subprocess.DEVNULL
info = lambda *args: None
@@ -86,6 +83,8 @@ def install(version: str, quiet: bool):
info = lambda *args: print(*args, file=sys.stderr)
file = file_template.format(version=version)
url = url_template.format(version=version)
if install_dir.exists():
shutil.rmtree(install_dir)
install_dir.mkdir()
zips_dir.mkdir(exist_ok=True)
zip = zips_dir / file
@@ -157,11 +156,8 @@ def main(opts, forwarded_opts):
selected_version = current_version or DEFAULT_VERSION
if selected_version != current_version:
# don't print information about install procedure unless explicitly using --select
if install_dir.exists():
shutil.rmtree(install_dir)
install(selected_version, quiet=opts.select is None)
version_file.write_text(selected_version)
# don't print information about install procedure unless explicitly using --select
install(selected_version, quiet=opts.select is None)
if opts.select and not forwarded_opts and not opts.version:
print(f"selected {selected_version}")
return

View File

@@ -6,8 +6,6 @@ import com.github.codeql.utils.*
import com.github.codeql.utils.versions.*
import com.semmle.extractor.java.OdasaOutput
import java.io.Closeable
import java.nio.file.Files
import java.nio.file.Path
import java.util.*
import kotlin.collections.ArrayList
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
@@ -52,7 +50,6 @@ import org.jetbrains.kotlin.load.java.structure.JavaMethod
import org.jetbrains.kotlin.load.java.structure.JavaTypeParameter
import org.jetbrains.kotlin.load.java.structure.JavaTypeParameterListOwner
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaClass
import org.jetbrains.kotlin.fir.java.VirtualFileBasedSourceElement
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.types.Variance
import org.jetbrains.kotlin.util.OperatorNameConventions
@@ -164,100 +161,23 @@ open class KotlinFileExtractor(
}
}
private fun javaBinaryDeclaresMethod(c: IrClass, name: String): Boolean? {
// K1 path: source is JavaSourceElement wrapping a BinaryJavaClass - inspect class metadata
val binaryJavaClass = (c.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass
if (binaryJavaClass != null) {
return binaryJavaClass.methods.any { it.name.asString() == name }
private fun javaBinaryDeclaresMethod(c: IrClass, name: String) =
((c.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass)?.methods?.any {
it.name.asString() == name
}
// K2 path: binary Java classes use VirtualFileBasedSourceElement instead of
// JavaSourceElement. The BinaryJavaClass is not stored in the source element, so we parse
// the class bytes directly using ASM to check if the method is explicitly declared.
if (c.source is VirtualFileBasedSourceElement) {
val virtualFile = (c.source as VirtualFileBasedSourceElement).virtualFile
if (!virtualFile.name.endsWith(".class")) return null
return try {
val bytes = virtualFile.contentsToByteArray()
var found = false
var hasKotlinMetadata = false
val reader = org.jetbrains.org.objectweb.asm.ClassReader(bytes)
reader.accept(
object : org.jetbrains.org.objectweb.asm.ClassVisitor(
org.jetbrains.org.objectweb.asm.Opcodes.ASM9
) {
override fun visitAnnotation(
descriptor: String,
visible: Boolean
): org.jetbrains.org.objectweb.asm.AnnotationVisitor? {
if (descriptor == "Lkotlin/Metadata;") hasKotlinMetadata = true
return null
}
override fun visitMethod(
access: Int,
methodName: String,
descriptor: String,
signature: String?,
exceptions: Array<String>?
): org.jetbrains.org.objectweb.asm.MethodVisitor? {
if (methodName == name) found = true
return null
}
},
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
)
if (hasKotlinMetadata) false else found
} catch (e: Exception) {
logger.warn("Failed to check binary class methods for ${c.fqNameWhenAvailable}: $e")
null
}
}
return null
}
/**
 * Reports whether [f] is explicitly declared by its parent class according to
 * [javaBinaryDeclaresMethod].
 *
 * Yields `false` when [f] has no parent class or when the lookup gives no answer (`null`).
 */
private fun isJavaBinaryDeclaration(f: IrFunction): Boolean {
    val owner = f.parentClassOrNull ?: return false
    // A `null` answer ("cannot tell") is treated the same as "not declared".
    return javaBinaryDeclaresMethod(owner, f.name.asString()) == true
}
/**
 * Reports whether the class containing [f] declares another function with the same name and the
 * same number of value parameters that is neither a fake override nor marked with the
 * `InvisibleFake` delegated visibility.
 *
 * Returns `false` when [f] has no parent class.
 */
private fun hasConcreteSiblingObjectMethod(f: IrFunction): Boolean {
    val owner = f.parentClassOrNull ?: return false
    for (candidate in owner.declarations) {
        // Only other functions with a matching name and arity are of interest.
        if (candidate !is IrFunction || candidate === f) continue
        if (candidate.name != f.name) continue
        if (candidate.codeQlValueParameters.size != f.codeQlValueParameters.size) continue
        if (candidate.isFakeOverride) continue
        val visibility = candidate.visibility
        val isInvisibleFake =
            visibility is DelegatedDescriptorVisibility &&
                visibility.delegate == Visibilities.InvisibleFake
        if (!isInvisibleFake) return true
    }
    return false
}
private fun isJavaBinaryObjectMethodRedeclaration(d: IrDeclaration) =
when (d) {
is IrFunction ->
d.parentClassOrNull?.typeParameters?.isEmpty() == true &&
when (d.name.asString()) {
"toString" -> d.codeQlValueParameters.isEmpty()
"hashCode" -> d.codeQlValueParameters.isEmpty()
// Under K2 (language version 2.0+), the Object.equals(Object) parameter is
// typed as Any (non-nullable) rather than Any? (nullable). Accept both.
"equals" ->
d.codeQlValueParameters
.singleOrNull()
?.type
?.let { it.isNullableAny() || it.isAny() } ?: false
"equals" -> d.codeQlValueParameters.singleOrNull()?.type?.isNullableAny() ?: false
else -> false
} &&
!hasConcreteSiblingObjectMethod(d) &&
isJavaBinaryDeclaration(d)
} && isJavaBinaryDeclaration(d)
else -> false
}
@@ -1392,28 +1312,27 @@ open class KotlinFileExtractor(
): TypeResults {
with("value parameter", vp) {
val location = locOverride ?: getLocation(vp, classTypeArgsIncludingOuterClasses)
val parentFunction = vp.parent as? IrFunction
val javaCallable = parentFunction?.let { getJavaCallable(it) }
val maybeAlteredType =
parentFunction?.let {
(vp.parent as? IrFunction)?.let {
if (overridesCollectionsMethodWithAlteredParameterTypes(it))
eraseCollectionsMethodParameterType(vp.type, it.name.asString(), idx)
else if (
(parentFunction as? IrConstructor)?.parentClassOrNull?.kind ==
(vp.parent as? IrConstructor)?.parentClassOrNull?.kind ==
ClassKind.ANNOTATION_CLASS
)
kClassToJavaClass(vp.type)
else null
} ?: vp.type
val javaType = javaCallable?.let { jCallable -> getJavaValueParameterType(jCallable, idx) }
val addParameterWildcardsByDefault =
!getInnermostWildcardSupppressionAnnotation(vp) &&
!(javaCallable == null &&
parentFunction?.origin == IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB)
val javaType =
(vp.parent as? IrFunction)?.let {
getJavaCallable(it)?.let { jCallable ->
getJavaValueParameterType(jCallable, idx)
}
}
val typeWithWildcards =
addJavaLoweringWildcards(
maybeAlteredType,
addParameterWildcardsByDefault,
!getInnermostWildcardSupppressionAnnotation(vp),
javaType
)
val substitutedType =
@@ -1427,9 +1346,9 @@ open class KotlinFileExtractor(
vp.origin == IrDeclarationOrigin.UNDERSCORE_PARAMETER ||
((vp.parent as? IrFunction)?.let { hasSynthesizedParameterNames(it) } ?: true)
val javaParameter =
when (javaCallable) {
is JavaConstructor -> javaCallable.valueParameters.getOrNull(idx)
is JavaMethod -> javaCallable.valueParameters.getOrNull(idx)
when (val callable = (vp.parent as? IrFunction)?.let { getJavaCallable(it) }) {
is JavaConstructor -> callable.valueParameters.getOrNull(idx)
is JavaMethod -> callable.valueParameters.getOrNull(idx)
else -> null
}
val extraAnnotations =
@@ -2955,45 +2874,6 @@ open class KotlinFileExtractor(
return v
}
// Source text keyed by file path. A stored `null` records that reading the file failed.
private val sourceTextCache = mutableMapOf<String, String?>()

/**
 * Returns the text of the file at `filePath`, or `null` if it could not be read.
 *
 * The outcome, including a failed read, is remembered per path so the file is read at most once.
 */
private fun getCurrentFileSourceText(): String? {
    // `getOrPut` treats a stored `null` as "absent" and would retry the read on every call,
    // so the presence of the key is checked explicitly instead.
    if (!sourceTextCache.containsKey(filePath)) {
        sourceTextCache[filePath] =
            runCatching { Files.readString(Path.of(filePath)) }.getOrNull()
    }
    return sourceTextCache[filePath]
}
/**
 * Computes a location that covers only the name of [v] inside its declaration text.
 *
 * The declaration text is the slice of the current file's source between the variable's start
 * and end offsets (end inclusive, clamped to the file length). Returns `null` when the offsets
 * are unusable, the source text is unavailable, the name is empty, or the name does not occur as
 * a whole identifier in that slice.
 */
private fun getVariableNameLocation(v: IrVariable): Label<DbLocation>? {
    if (v.startOffset < 0 || v.endOffset < v.startOffset) return null
    val source = getCurrentFileSourceText() ?: return null
    if (v.startOffset >= source.length) return null
    val name = v.name.asString()
    if (name.isEmpty()) return null
    val endExclusive = minOf(v.endOffset + 1, source.length)
    val declarationText = source.substring(v.startOffset, endExclusive)
    // A plain substring search would match inside other tokens, e.g. a variable named `a`
    // would hit the `a` of a leading `val`; only whole-identifier matches are accepted.
    val nameOffsetInDeclaration = indexOfWholeIdentifier(declarationText, name)
    if (nameOffsetInDeclaration < 0) return null
    val nameStartOffset = v.startOffset + nameOffsetInDeclaration
    val nameEndOffset = nameStartOffset + name.length - 1
    return tw.getLocation(nameStartOffset, nameEndOffset)
}

/**
 * Returns the index of the first occurrence of [identifier] in [text] that is not directly
 * preceded or followed by an identifier character, or -1 if there is none.
 */
private fun indexOfWholeIdentifier(text: String, identifier: String): Int {
    var from = 0
    while (true) {
        val hit = text.indexOf(identifier, from)
        if (hit < 0) return -1
        val end = hit + identifier.length
        val boundedLeft = hit == 0 || !Character.isJavaIdentifierPart(text[hit - 1])
        val boundedRight = end >= text.length || !Character.isJavaIdentifierPart(text[end])
        if (boundedLeft && boundedRight) return hit
        from = hit + 1
    }
}
/**
 * True when the initializer of [v] is a type-operator call whose operator is
 * `IMPLICIT_NOTNULL`; false otherwise, including when there is no initializer.
 */
private fun shouldUseVariableNameLocation(v: IrVariable): Boolean =
    (v.initializer as? IrTypeOperatorCall)?.operator == IrTypeOperator.IMPLICIT_NOTNULL
/**
 * Picks the location to record for [v]: the name-only location when
 * [shouldUseVariableNameLocation] holds and one can be computed, otherwise the location obtained
 * from `getVariableLocationProvider`.
 */
private fun getVariableLocation(v: IrVariable): Label<DbLocation> {
    val nameOnly = if (shouldUseVariableNameLocation(v)) getVariableNameLocation(v) else null
    return nameOnly ?: tw.getLocation(getVariableLocationProvider(v))
}
private fun extractVariable(
v: IrVariable,
callable: Label<out DbCallable>,
@@ -3002,7 +2882,7 @@ open class KotlinFileExtractor(
) {
with("variable", v) {
val stmtId = tw.getFreshIdLabel<DbLocalvariabledeclstmt>()
val locId = getVariableLocation(v)
val locId = tw.getLocation(getVariableLocationProvider(v))
tw.writeStmts_localvariabledeclstmt(stmtId, parent, idx, callable)
tw.writeHasLocation(stmtId, locId)
extractVariableExpr(v, callable, stmtId, 1, stmtId)
@@ -3020,7 +2900,7 @@ open class KotlinFileExtractor(
with("variable expr", v) {
val varId = useVariable(v)
val exprId = tw.getFreshIdLabel<DbLocalvariabledeclexpr>()
val locId = getVariableLocation(v)
val locId = tw.getLocation(getVariableLocationProvider(v))
val type = useType(v.type)
tw.writeLocalvars(varId, v.name.asString(), type.javaResult.id, exprId)
tw.writeLocalvarsKotlinType(varId, type.kotlinResult.id)
@@ -4186,28 +4066,6 @@ open class KotlinFileExtractor(
else -> false
}
/**
 * Chooses the result type to record for the call [c] to [syntacticCallTarget].
 *
 * The call's own type is returned unless all of the following hold: the target originates from an
 * external Java declaration stub, the call type maps to a primitive via `primitiveTypeMapping`,
 * the target has a parent class, and [javaBinaryMethodReturnIsClassifierType] reports `true` for
 * the target. In that case the type built from the primitive's `javaClass` symbol is returned.
 */
private fun getCallResultType(c: IrCall, syntacticCallTarget: IrFunction): IrType {
    val declaredType = c.type
    val targetIsJavaStub =
        syntacticCallTarget.origin == IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB
    if (!targetIsJavaStub || declaredType !is IrSimpleType) return declaredType

    val primitive = primitiveTypeMapping.getPrimitiveInfo(declaredType) ?: return declaredType
    val owner = syntacticCallTarget.parentClassOrNull ?: return declaredType

    val declaredAsReference =
        javaBinaryMethodReturnIsClassifierType(
            owner,
            getFunctionShortName(syntacticCallTarget).nameInDB,
            syntacticCallTarget.codeQlValueParameters.size,
            syntacticCallTarget is IrConstructor
        ) == true

    return if (declaredAsReference) primitive.javaClass.symbol.typeWith() else declaredType
}
private fun isGenericArrayType(typeName: String) =
when (typeName) {
"Array" -> true
@@ -4253,7 +4111,7 @@ open class KotlinFileExtractor(
extractRawMethodAccess(
syntacticCallTarget,
c,
getCallResultType(c, syntacticCallTarget),
c.type,
callable,
parent,
idx,

View File

@@ -36,7 +36,6 @@ import org.jetbrains.kotlin.load.java.BuiltinMethodsWithSpecialGenericSignature
import org.jetbrains.kotlin.load.java.JvmAbi
import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
import org.jetbrains.kotlin.load.java.structure.*
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaClass
import org.jetbrains.kotlin.load.java.typeEnhancement.hasEnhancedNullability
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.NameUtils
@@ -997,20 +996,7 @@ open class KotlinUsesExtractor(
)
return null
}
val fileClassId = extractFileClass(fqName)
// Under K2, external file class members sit directly under IrExternalPackageFragment
// rather than under their IrClass parent. In that case the file class entity won't
// get a location set through the normal extractClassSource path.
if (d is IrMemberWithContainerSource && tw.lm.externalFileClassLocationsExtracted.add(fqName)) {
val binaryPath =
getContainerSourceBinaryPath(d.containerSource)
?.let { normalizeExternalFileClassBinaryPath(it, fqName) }
if (binaryPath != null && shouldUseConcreteExternalFileClassLocation(binaryPath)) {
val fileId = tw.mkFileId(binaryPath, true)
tw.writeHasLocation(fileClassId, tw.getWholeFileLocation(fileId))
}
}
return fileClassId
return extractFileClass(fqName)
}
return useDeclarationParent(parent, canBeTopLevel, classTypeArguments, inReceiverContext)
}
@@ -1385,13 +1371,8 @@ open class KotlinUsesExtractor(
parentId: Label<out DbElement>,
classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?,
maybeParameterList: List<IrValueParameter>? = null
): String {
val javaCallable = getJavaCallable(f)
val addParameterWildcardsByDefault =
!getInnermostWildcardSupppressionAnnotation(f) &&
!(javaCallable == null && f.origin == IrDeclarationOrigin.IR_EXTERNAL_JAVA_DECLARATION_STUB)
return getFunctionLabel(
): String =
getFunctionLabel(
f.parent,
parentId,
getFunctionShortName(f).nameInDB,
@@ -1401,10 +1382,9 @@ open class KotlinUsesExtractor(
getFunctionTypeParameters(f),
classTypeArgsIncludingOuterClasses,
overridesCollectionsMethodWithAlteredParameterTypes(f),
javaCallable,
addParameterWildcardsByDefault
getJavaCallable(f),
!getInnermostWildcardSupppressionAnnotation(f)
)
}
/*
* This function actually generates the label for a function.
@@ -1491,41 +1471,15 @@ open class KotlinUsesExtractor(
// Finally, mimic the Java extractor's behaviour by naming functions with type
// parameters for their erased types;
// those without type parameters are named for the generic type.
var maybeErased =
val maybeErased =
if (functionTypeParameters.isEmpty()) maybeSubbed else erase(maybeSubbed)
// K2 compatibility: under K2, Java @NotNull reference types such as @NotNull Integer
// are enhanced to Kotlin primitives (e.g. kotlin.Int). But the Java extractor uses
// the original reference type (java.lang.Integer) in callable labels. When we detect
// that the original Java parameter type is a reference (classifier) type but the
// Kotlin IR type is a primitive, revert to the boxed Java class so both extractors
// produce matching callable IDs.
if (functionTypeParameters.isEmpty()) {
val primitiveInfo = (maybeErased as? IrSimpleType)?.let {
primitiveTypeMapping.getPrimitiveInfo(it)
}
if (primitiveInfo != null) {
val parentClass = parent as? IrClass
if (parentClass != null) {
val isClassifierType = javaBinaryMethodParamIsClassifierType(
parentClass,
name,
allParamTypes.size,
name == "<init>",
it.index
)
if (isClassifierType == true) {
maybeErased = primitiveInfo.javaClass.symbol.typeWith()
}
}
}
}
"{${useType(maybeErased).javaResult.id}}"
}
val paramTypeIds =
allParamTypes
.withIndex()
.joinToString(separator = ",", transform = getIdForFunctionLabel)
var labelReturnType =
val labelReturnType =
if (name == "<init>") pluginContext.irBuiltIns.unitType
else
erase(
@@ -1535,28 +1489,6 @@ open class KotlinUsesExtractor(
pluginContext
)
)
// K2 compatibility: same as for parameters, if the Java binary method return type is a
// reference type but K2 enhanced it to a Kotlin primitive, use the boxed Java class.
if (functionTypeParameters.isEmpty() && name != "<init>") {
val primitiveInfo = (labelReturnType as? IrSimpleType)?.let {
primitiveTypeMapping.getPrimitiveInfo(it)
}
if (primitiveInfo != null) {
val parentClass = parent as? IrClass
if (parentClass != null) {
val returnIsClassifier =
javaBinaryMethodReturnIsClassifierType(
parentClass,
name,
allParamTypes.size,
false
)
if (returnIsClassifier == true) {
labelReturnType = primitiveInfo.javaClass.symbol.typeWith()
}
}
}
}
// Note that `addJavaLoweringWildcards` is not required here because the return type used to
// form the function
// label is always erased.
@@ -1662,23 +1594,9 @@ open class KotlinUsesExtractor(
}
@OptIn(ObsoleteDescriptorBasedAPI::class)
fun getJavaCallable(f: IrFunction): JavaMember? {
val fromDescriptor = (f.descriptor.source as? JavaSourceElement)?.javaElement as? JavaMember
if (fromDescriptor != null) return fromDescriptor
fun getJavaCallable(f: IrFunction) =
(f.descriptor.source as? JavaSourceElement)?.javaElement as? JavaMember
// K2 fallback: under K2, descriptor.source may not carry JavaSourceElement for binary Java
// methods. Try to get the JavaMember from the parent class's binary class directly.
val parentClass = f.parentClassOrNull ?: return null
val binaryJavaClass = (parentClass.source as? JavaSourceElement)?.javaElement as? BinaryJavaClass
?: return null
val name = getFunctionShortName(f).nameInDB
val nParams = f.codeQlValueParameters.size
return if (f is IrConstructor) {
binaryJavaClass.constructors.find { it.valueParameters.size == nParams }
} else {
binaryJavaClass.methods.find { it.name.asString() == name && it.valueParameters.size == nParams }
}
}
fun getJavaValueParameterType(m: JavaMember, idx: Int) =
when (m) {
is JavaMethod -> m.valueParameters[idx].type

View File

@@ -51,13 +51,6 @@ class TrapLabelManager {
* to avoid duplication.
*/
val fileClassLocationsExtracted = HashSet<IrFile>()
/**
* Tracks external file classes (by FqName) whose location has been set from a binary path.
* Used to avoid writing duplicate hasLocation facts for external file class entities extracted
* through the K2 code path where declarations sit directly under IrExternalPackageFragment.
*/
val externalFileClassLocationsExtracted = HashSet<org.jetbrains.kotlin.name.FqName>()
}
/**

View File

@@ -17,7 +17,6 @@ import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.load.kotlin.KotlinJvmBinarySourceElement
import org.jetbrains.kotlin.load.kotlin.VirtualFileKotlinClass
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource
// Adapted from Kotlin's interpreter/Utils.kt function 'internalName'
// Translates class names into their JLS section 13.1 binary name,
@@ -177,238 +176,15 @@ fun getIrDeclarationBinaryPath(d: IrDeclaration): String? {
// This is in a file class.
val fqName = getFileClassFqName(d)
if (fqName != null) {
if (d is IrMemberWithContainerSource) {
val containerBinaryPath = getContainerSourceBinaryPath(d.containerSource)
if (containerBinaryPath != null) {
return normalizeExternalFileClassBinaryPath(containerBinaryPath, fqName)
}
}
return getUnknownBinaryLocation(fqName.asString())
}
}
return null
}
/**
 * Looks up the binary file path behind [containerSource].
 *
 * Only a [JvmPackagePartSource] with a known JVM binary class yields a result. For a
 * [VirtualFileKotlinClass] the virtual file's path is used; on the JRT file system only the part
 * after the first `!/` is kept, prefixed with `/`. For any other binary class its `location` is
 * returned unless it is empty.
 *
 * @return the path, or `null` when none is available.
 */
fun getContainerSourceBinaryPath(containerSource: org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource?): String? {
    val binaryClass =
        (containerSource as? JvmPackagePartSource)?.knownJvmBinaryClass ?: return null
    if (binaryClass !is VirtualFileKotlinClass) {
        val location = binaryClass.location
        return if (location.isNotEmpty()) location else null
    }
    val file = binaryClass.file
    val fullPath = file.path
    return if (file.fileSystem.protocol == StandardFileSystems.JRT_PROTOCOL) {
        "/${fullPath.split("!/", limit = 2)[1]}"
    } else {
        fullPath
    }
}
/**
 * Builds the placeholder class-file path for the dotted name [s]: dots become slashes, and the
 * result is placed under `/!unknown-binary-location/` with a `.class` suffix.
 */
private fun getUnknownBinaryLocation(s: String): String =
    "/!unknown-binary-location/" + s.replace('.', '/') + ".class"
/**
 * Normalizes the binary [path] of the external file class named [fqName].
 *
 * - A path containing `.kotlinc_installed` is replaced by the unknown-binary-location placeholder
 *   for [fqName].
 * - A path that (with backslashes read as slashes) ends in `/<package path>/<Class>.class` is
 *   shortened to `<last directory before the package path>/<package path>/<Class>.class`,
 *   provided that directory name is not empty.
 * - Any other path is returned unchanged.
 */
fun normalizeExternalFileClassBinaryPath(path: String, fqName: FqName): String {
    val qualifiedName = fqName.asString()
    if (".kotlinc_installed" in path) return getUnknownBinaryLocation(qualifiedName)

    val relativeClassFile = qualifiedName.replace('.', '/') + ".class"
    val forwardSlashed = path.replace('\\', '/')
    if (!forwardSlashed.endsWith("/$relativeClassFile")) return path

    // Drop "/<relativeClassFile>" and keep only the final directory name of what remains.
    val rootName =
        forwardSlashed.dropLast(relativeClassFile.length + 1).substringAfterLast('/')
    return if (rootName.isEmpty()) path else "$rootName/$relativeClassFile"
}
/**
 * True when [path], with backslashes read as slashes, contains a directory separator and does not
 * start with the `/!unknown-binary-location/` placeholder prefix.
 */
fun shouldUseConcreteExternalFileClassLocation(path: String): Boolean {
    val unified = path.replace('\\', '/')
    if (unified.startsWith("/!unknown-binary-location/")) return false
    return '/' in unified
}
/**
 * Maps the fully-qualified name of [c] through `JavaToKotlinClassMap.mapKotlinToJava`.
 * Yields `null` when [c] has no fully-qualified name or the map has no entry for it.
 */
fun getJavaEquivalentClassId(c: IrClass) =
    c.fqNameWhenAvailable?.let { JavaToKotlinClassMap.mapKotlinToJava(it.toUnsafe()) }
/**
 * Decides whether parameter [paramIndex] of a Java binary method or constructor is declared with
 * a reference type rather than a Java primitive. Callers use this to detect parameters that K2
 * has enhanced from a reference type (e.g. `@NotNull Integer`) to a Kotlin primitive, so callable
 * labels can keep using the original reference type.
 *
 * Candidates are all members of [parentClass] named [methodName] (or all constructors when
 * [isConstructor] is set) that take exactly [nParams] parameters.
 *
 * Two sources are tried in order:
 * 1. K1: the class source is a [JavaSourceElement] wrapping a [BinaryJavaClass]; its declared
 *    members are inspected directly.
 * 2. K2: the class source is a [VirtualFileBasedSourceElement] pointing at a `.class` file; the
 *    method descriptors are read from the class bytes with ASM.
 *
 * @return `true`/`false` when every candidate agrees, `null` when there is no candidate, the
 *   candidates disagree, or the class cannot be inspected.
 */
fun javaBinaryMethodParamIsClassifierType(
    parentClass: IrClass,
    methodName: String,
    nParams: Int,
    isConstructor: Boolean,
    paramIndex: Int
): Boolean? {
    val source = parentClass.source

    // K1 path: binary Java class has JavaSourceElement with a BinaryJavaClass.
    val binaryClass = (source as? JavaSourceElement)?.javaElement as? BinaryJavaClass
    if (binaryClass != null) {
        val candidateParameterLists =
            if (isConstructor)
                binaryClass.constructors
                    .filter { it.valueParameters.size == nParams }
                    .map { it.valueParameters }
            else
                binaryClass.methods
                    .filter {
                        it.name.asString() == methodName && it.valueParameters.size == nParams
                    }
                    .map { it.valueParameters }
        val kinds =
            candidateParameterLists
                .mapNotNull { it.getOrNull(paramIndex)?.type }
                .map { it is org.jetbrains.kotlin.load.java.structure.JavaClassifierType }
                .toSet()
        // No candidate at all: fall through to the class-file inspection below.
        if (kinds.isNotEmpty()) return kinds.singleOrNull()
    }

    // K2 path: binary Java class has VirtualFileBasedSourceElement
    if (source !is VirtualFileBasedSourceElement) return null
    val classFile = source.virtualFile
    if (!classFile.name.endsWith(".class")) return null

    return try {
        val classBytes = classFile.contentsToByteArray()
        val targetName = if (isConstructor) "<init>" else methodName
        val kinds = mutableSetOf<Boolean>()
        val visitor =
            object :
                org.jetbrains.org.objectweb.asm.ClassVisitor(
                    org.jetbrains.org.objectweb.asm.Opcodes.ASM9
                ) {
                override fun visitMethod(
                    access: Int,
                    name: String,
                    descriptor: String,
                    signature: String?,
                    exceptions: Array<String>?
                ): org.jetbrains.org.objectweb.asm.MethodVisitor? {
                    if (name == targetName) {
                        val parameterDescriptors = parseAsmMethodDescriptorParams(descriptor)
                        if (parameterDescriptors.size == nParams) {
                            // Reference types start with 'L' or '['; Java primitives are
                            // single chars
                            parameterDescriptors.getOrNull(paramIndex)?.let {
                                kinds.add(it.startsWith("L") || it.startsWith("["))
                            }
                        }
                    }
                    return null
                }
            }
        val readerFlags =
            org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
                org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
                org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
        org.jetbrains.org.objectweb.asm.ClassReader(classBytes).accept(visitor, readerFlags)
        kinds.singleOrNull()
    } catch (e: Exception) {
        // Best effort: an unreadable or unparsable class file means "cannot tell".
        null
    }
}
/**
 * Decides whether the return type of a Java binary method is a reference type rather than a Java
 * primitive.
 *
 * Candidates are all methods of [parentClass] named [methodName] that take exactly [nParams]
 * parameters. Constructors always yield `false`. The K1 source ([JavaSourceElement] wrapping a
 * [BinaryJavaClass]) is consulted first; if it provides no candidate, the `.class` file behind a
 * [VirtualFileBasedSourceElement] is read with ASM.
 *
 * @return `true`/`false` when every candidate agrees, `null` when there is no candidate, the
 *   candidates disagree, or the class cannot be inspected.
 */
fun javaBinaryMethodReturnIsClassifierType(
    parentClass: IrClass,
    methodName: String,
    nParams: Int,
    isConstructor: Boolean
): Boolean? {
    if (isConstructor) return false
    val source = parentClass.source

    // K1 path: binary Java class has JavaSourceElement with a BinaryJavaClass.
    val binaryClass = (source as? JavaSourceElement)?.javaElement as? BinaryJavaClass
    if (binaryClass != null) {
        val kinds = HashSet<Boolean>()
        for (method in binaryClass.methods) {
            if (method.name.asString() == methodName && method.valueParameters.size == nParams) {
                kinds.add(
                    method.returnType is org.jetbrains.kotlin.load.java.structure.JavaClassifierType
                )
            }
        }
        // No candidate at all: fall through to the class-file inspection below.
        if (kinds.isNotEmpty()) return kinds.singleOrNull()
    }

    // K2 path: binary Java class has VirtualFileBasedSourceElement
    if (source !is VirtualFileBasedSourceElement) return null
    val classFile = source.virtualFile
    if (!classFile.name.endsWith(".class")) return null

    return try {
        val classBytes = classFile.contentsToByteArray()
        val kinds = mutableSetOf<Boolean>()
        val visitor =
            object :
                org.jetbrains.org.objectweb.asm.ClassVisitor(
                    org.jetbrains.org.objectweb.asm.Opcodes.ASM9
                ) {
                override fun visitMethod(
                    access: Int,
                    name: String,
                    descriptor: String,
                    signature: String?,
                    exceptions: Array<String>?
                ): org.jetbrains.org.objectweb.asm.MethodVisitor? {
                    if (
                        name == methodName &&
                            parseAsmMethodDescriptorParams(descriptor).size == nParams
                    ) {
                        // Everything after the closing parenthesis is the return descriptor.
                        val returned = descriptor.substring(descriptor.lastIndexOf(')') + 1)
                        kinds.add(returned.startsWith("L") || returned.startsWith("["))
                    }
                    return null
                }
            }
        val readerFlags =
            org.jetbrains.org.objectweb.asm.ClassReader.SKIP_CODE or
                org.jetbrains.org.objectweb.asm.ClassReader.SKIP_DEBUG or
                org.jetbrains.org.objectweb.asm.ClassReader.SKIP_FRAMES
        org.jetbrains.org.objectweb.asm.ClassReader(classBytes).accept(visitor, readerFlags)
        kinds.singleOrNull()
    } catch (e: Exception) {
        // Best effort: an unreadable or unparsable class file means "cannot tell".
        null
    }
}
/**
 * Splits the parameter section of a JVM method descriptor into one descriptor per parameter.
 *
 * For example `(ILjava/lang/String;[[J)V` yields `["I", "Ljava/lang/String;", "[[J"]`.
 * Object types run from `L` to the terminating `;`, array types keep their leading `[`
 * markers, and every other character is treated as a single-character (primitive) descriptor.
 *
 * @param descriptor a method descriptor such as the one ASM passes to `visitMethod`.
 * @return the parameter descriptors in declaration order; empty when [descriptor] has no
 *   parameters or no `(`...`)` parameter section at all.
 * @throws IllegalArgumentException if the parameter section contains an object type without
 *   a terminating `;` before the closing `)`, or array markers that are not followed by an
 *   element type.
 */
fun parseAsmMethodDescriptorParams(descriptor: String): List<String> {
    val params = mutableListOf<String>()
    val open = descriptor.indexOf('(')
    val end = descriptor.lastIndexOf(')')
    // Without a well-ordered "(...)" section there is nothing to parse. Bailing out here also
    // prevents scanning from index 0 when '(' is missing, which could previously loop forever.
    if (open < 0 || end < open) return params

    var i = open + 1
    while (i < end) {
        // Step over any array dimension markers to reach the element type.
        var j = i
        while (j < end && descriptor[j] == '[') j++
        require(j < end) { "Array marker without element type in method descriptor: $descriptor" }

        val last =
            if (descriptor[j] == 'L') {
                // Object type: must be closed by ';' inside the parameter section, otherwise
                // the return type would be swallowed into this parameter.
                val semi = descriptor.indexOf(';', j)
                require(semi in j + 1 until end) {
                    "Unterminated object type in method descriptor: $descriptor"
                }
                semi
            } else {
                // Primitive (or any other single-character) descriptor.
                j
            }
        params.add(descriptor.substring(i, last + 1))
        i = last + 1
    }
    return params
}

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -13,9 +13,7 @@ buildscript {
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
jcenter()
}
/**
@@ -41,8 +39,6 @@ buildscript {
allprojects {
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
jcenter()
}
}

View File

@@ -13,9 +13,7 @@ buildscript {
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
jcenter()
}
/**
@@ -41,8 +39,6 @@ buildscript {
allprojects {
repositories {
google()
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
jcenter()
}
}

View File

@@ -13,9 +13,7 @@ buildscript {
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
jcenter()
}
/**
@@ -41,8 +39,6 @@ buildscript {
allprojects {
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
jcenter()
}
}

View File

@@ -13,9 +13,7 @@ buildscript {
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
jcenter()
}
/**
@@ -34,15 +32,13 @@ buildscript {
* dependencies used by all modules in your project, such as third-party plugins
* or libraries. However, you should configure module-specific dependencies in
* each module-level build.gradle file. For new projects, Android Studio
* includes Maven Central and Google's Maven repository by default, but it does not
* includes JCenter and Google's Maven repository by default, but it does not
* configure any dependencies (unless you select a template that requires some).
*/
allprojects {
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
jcenter()
}
}

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -8,9 +8,7 @@
apply plugin: 'java-library'
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -1,5 +1,5 @@
https://maven-central.storage-download.googleapis.com/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar
https://maven-central.storage-download.googleapis.com/maven2/org/apiguardian/apiguardian-api/1.1.2/apiguardian-api-1.1.2.jar
https://maven-central.storage-download.googleapis.com/maven2/org/junit/jupiter/junit-jupiter-api/5.12.1/junit-jupiter-api-5.12.1.jar
https://maven-central.storage-download.googleapis.com/maven2/org/junit/platform/junit-platform-commons/1.12.1/junit-platform-commons-1.12.1.jar
https://maven-central.storage-download.googleapis.com/maven2/org/opentest4j/opentest4j/1.3.0/opentest4j-1.3.0.jar
https://repo.maven.apache.org/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar
https://repo.maven.apache.org/maven2/org/apiguardian/apiguardian-api/1.1.2/apiguardian-api-1.1.2.jar
https://repo.maven.apache.org/maven2/org/junit/jupiter/junit-jupiter-api/5.12.1/junit-jupiter-api-5.12.1.jar
https://repo.maven.apache.org/maven2/org/junit/platform/junit-platform-commons/1.12.1/junit-platform-commons-1.12.1.jar
https://repo.maven.apache.org/maven2/org/opentest4j/opentest4j/1.3.0/opentest4j-1.3.0.jar

View File

@@ -8,9 +8,7 @@
apply plugin: 'java-library'
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -1,2 +1,2 @@
https://maven-central.storage-download.googleapis.com/maven2/joda-time/joda-time/2.12.7/joda-time-2.12.7-no-tzdb.jar
https://maven-central.storage-download.googleapis.com/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar
https://repo.maven.apache.org/maven2/joda-time/joda-time/2.12.7/joda-time-2.12.7-no-tzdb.jar
https://repo.maven.apache.org/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -8,9 +8,7 @@
apply plugin: 'java-library'
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -1 +1 @@
https://maven-central.storage-download.googleapis.com/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar
https://repo.maven.apache.org/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar

View File

@@ -8,9 +8,7 @@
apply plugin: 'java-library'
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -1 +1 @@
https://maven-central.storage-download.googleapis.com/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar
https://repo.maven.apache.org/maven2/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar

View File

@@ -1,7 +1,6 @@
https://maven-central.storage-download.googleapis.com/maven2/junit/junit/4.11/junit-4.11.jar
https://maven-central.storage-download.googleapis.com/maven2/junit/junit/4.12/junit-4.12.jar
https://maven-central.storage-download.googleapis.com/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar
https://maven-central.storage-download.googleapis.com/maven2/org/slf4j/slf4j-api/1.7.21/slf4j-api-1.7.21.jar
https://jcenter.bintray.com/junit/junit/4.12/junit-4.12.jar
https://jcenter.bintray.com/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar
https://jcenter.bintray.com/org/slf4j/slf4j-api/1.7.21/slf4j-api-1.7.21.jar
https://repo.maven.apache.org/maven2/com/feiniaojin/naaf/naaf-graceful-response-example/1.0/naaf-graceful-response-example-1.0.jar
https://repo.maven.apache.org/maven2/com/github/MoebiusSolutions/avro-registry-in-source/avro-registry-in-source-tests/1.8/avro-registry-in-source-tests-1.8.jar
https://repo.maven.apache.org/maven2/com/github/MoebiusSolutions/avro-registry-in-source/example-project/1.5/example-project-1.5.jar
@@ -13,6 +12,7 @@ https://repo.maven.apache.org/maven2/de/knutwalker/rx-redis-example_2.11/0.1.2/r
https://repo.maven.apache.org/maven2/de/knutwalker/rx-redis-java-example_2.11/0.1.2/rx-redis-java-example_2.11-0.1.2.jar
https://repo.maven.apache.org/maven2/io/github/scrollsyou/example-spring-boot-starter/1.0.0/example-spring-boot-starter-1.0.0.jar
https://repo.maven.apache.org/maven2/io/streamnative/com/example/maven-central-template/server/3.0.0/server-3.0.0.jar
https://repo.maven.apache.org/maven2/junit/junit/4.11/junit-4.11.jar
https://repo.maven.apache.org/maven2/no/nav/security/token-validation-ktor-demo/3.1.0/token-validation-ktor-demo-3.1.0.jar
https://repo.maven.apache.org/maven2/org/minijax/minijax-example-fileupload/0.5.10/minijax-example-fileupload-0.5.10.jar
https://repo.maven.apache.org/maven2/org/minijax/minijax-example-inject/0.5.10/minijax-example-inject-0.5.10.jar

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -1,10 +0,0 @@
<settings>
<mirrors>
<mirror>
<id>google-maven-central</id>
<name>GCS Maven Central mirror</name>
<url>https://maven-central.storage-download.googleapis.com/maven2/</url>
<mirrorOf>central</mirrorOf>
</mirror>
</mirrors>
</settings>

View File

@@ -26,5 +26,4 @@ maven-project-2/src/main/resources/my-app.properties
maven-project-2/src/main/resources/page.xml
maven-project-2/src/main/resources/struts.xml
maven-project-2/src/test/java/com/example/AppTest4.java
settings.xml
test-db/working/settings.xml

View File

@@ -1,5 +1,3 @@
import os
def test(codeql, use_java_11, java, actions_toolchains_file, check_diagnostics_java):
# The version of gradle used doesn't work on java 17
codeql.database.create(
@@ -7,6 +5,5 @@ def test(codeql, use_java_11, java, actions_toolchains_file, check_diagnostics_j
"CODEQL_EXTRACTOR_JAVA_OPTION_BUILDLESS": "true",
"CODEQL_EXTRACTOR_JAVA_OPTION_BUILDLESS_CLASSPATH_FROM_BUILD_FILES": "true",
"LGTM_INDEX_MAVEN_TOOLCHAINS_FILE": str(actions_toolchains_file),
"LGTM_INDEX_MAVEN_SETTINGS_FILE": os.path.join(os.path.dirname(os.path.realpath(__file__)), "settings.xml"),
}
)

View File

@@ -14,9 +14,7 @@ pluginManagement {
repositories {
gradlePluginPortal()
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
dependencyResolutionManagement {
@@ -35,9 +33,7 @@ dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
}
rootProject.name = "Android Sample"

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,8 @@ plugins {
}
repositories {
maven {
url = uri("https://maven-central.storage-download.googleapis.com/maven2/")
}
// Use Maven Central for resolving dependencies.
mavenCentral()
}
dependencies {

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -12,9 +12,9 @@ apply plugin: 'java'
// In this section you declare where to find the dependencies of your project
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use 'jcenter' for resolving your dependencies.
// You can declare any Maven/Ivy/file repository here.
jcenter()
}
// In this section you declare the dependencies for your production and test code

View File

@@ -11,9 +11,7 @@ version = '0.0.1-SNAPSHOT'
// but I omit it to test we recognise the Spring Boot plugin version.
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -15,9 +15,8 @@ plugins {
}
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use Maven Central for resolving dependencies.
mavenCentral()
}
application {

View File

@@ -1,11 +1,11 @@
import pathlib
def test(codeql, java_full):
def test(codeql, java_full, kotlinc_2_3_20):
java_srcs = " ".join([str(s) for s in pathlib.Path().glob("*.java")])
codeql.database.create(
command=[
f"javac {java_srcs} -d build",
"kotlinc -language-version 2.0 user.kt -cp build",
"kotlinc -language-version 1.9 user.kt -cp build",
]
)

View File

@@ -1,6 +1,6 @@
import commands
def test(codeql, java_full):
commands.run("kotlinc -language-version 2.0 test.kt -d lib")
codeql.database.create(command="kotlinc -language-version 2.0 user.kt -cp lib")
def test(codeql, java_full, kotlinc_2_3_20):
commands.run("kotlinc -language-version 1.9 test.kt -d lib")
codeql.database.create(command="kotlinc -language-version 1.9 user.kt -cp lib")

View File

@@ -9,4 +9,4 @@
| Percentage of calls with call target | 100 |
| Total number of lines | 3 |
| Total number of lines with extension kt | 3 |
| Uses Kotlin 2: true | 1 |
| Uses Kotlin 2: false | 1 |

View File

@@ -1,2 +1,2 @@
def test(codeql, java_full):
codeql.database.create(command="kotlinc -J-Xmx2G -language-version 2.0 SomeClass.kt")
def test(codeql, java_full, kotlinc_2_3_20):
codeql.database.create(command=f"kotlinc -J-Xmx2G -language-version 1.9 SomeClass.kt")

View File

@@ -1,6 +1,6 @@
import commands
def test(codeql, java_full):
commands.run("kotlinc -language-version 2.0 A.kt")
codeql.database.create(command="kotlinc -cp . -language-version 2.0 B.kt C.kt")
def test(codeql, java_full, kotlinc_2_3_20):
commands.run("kotlinc -language-version 1.9 A.kt")
codeql.database.create(command="kotlinc -cp . -language-version 1.9 B.kt C.kt")

View File

@@ -15,9 +15,8 @@ plugins {
}
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use Maven Central for resolving dependencies.
mavenCentral()
}
application {

View File

@@ -4,9 +4,7 @@ plugins {
}
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
mavenCentral()
}
dependencies {

View File

@@ -1,6 +1,6 @@
import commands
def test(codeql, java_full):
def test(codeql, java_full, kotlinc_2_3_20):
commands.run(["javac", "Test.java", "-d", "bin"])
codeql.database.create(command="kotlinc -language-version 2.0 user.kt -cp bin")
codeql.database.create(command="kotlinc -language-version 1.9 user.kt -cp bin")

View File

@@ -1,13 +1,13 @@
import commands
def test(codeql, java_full):
def test(codeql, java_full, kotlinc_2_3_20):
# Compile the JavaDefns2 copy outside tracing, to make sure the Kotlin view of it matches the Java view seen by the traced javac compilation of JavaDefns.java below.
commands.run(["javac", "JavaDefns2.java"])
codeql.database.create(
command=[
"kotlinc kotlindefns.kt",
"javac JavaUser.java JavaDefns.java -cp .",
"kotlinc -language-version 2.0 -cp . kotlinuser.kt",
"kotlinc -language-version 1.9 -cp . kotlinuser.kt",
]
)

View File

@@ -15,9 +15,8 @@ plugins {
}
repositories {
maven {
url = 'https://maven-central.storage-download.googleapis.com/maven2/'
}
// Use Maven Central for resolving dependencies.
mavenCentral()
}
application {

View File

@@ -30,7 +30,5 @@ nodes
| BadMacUse.java:152:42:152:51 | ciphertext | semmle.label | ciphertext |
subpaths
testFailures
| BadMacUse.java:50:56:50:66 | // $ Source | Missing result: Source |
| BadMacUse.java:63:118:63:128 | // $ Source | Missing result: Source |
| BadMacUse.java:92:31:92:35 | bytes : byte[] | Unexpected result: Source |
| BadMacUse.java:146:95:146:105 | // $ Source | Missing result: Source |

View File

@@ -31,7 +31,7 @@ nodes
| BadMacUse.java:124:42:124:51 | ciphertext | semmle.label | ciphertext |
subpaths
testFailures
| BadMacUse.java:63:118:63:128 | // $ Source | Missing result: Source |
| BadMacUse.java:50:28:50:53 | doFinal(...) : byte[] | Fixed missing result: Source |
| BadMacUse.java:92:16:92:36 | doFinal(...) : byte[] | Unexpected result: Source |
| BadMacUse.java:124:42:124:51 | ciphertext | Unexpected result: Alert |
| BadMacUse.java:146:95:146:105 | // $ Source | Missing result: Source |

View File

@@ -45,7 +45,7 @@ nodes
| BadMacUse.java:152:42:152:51 | ciphertext | semmle.label | ciphertext |
subpaths
testFailures
| BadMacUse.java:50:56:50:66 | // $ Source | Missing result: Source |
| BadMacUse.java:63:82:63:97 | plaintext : byte[] | Fixed missing result: Source |
| BadMacUse.java:139:79:139:90 | input : byte[] | Unexpected result: Source |
| BadMacUse.java:146:95:146:105 | // $ Source | Missing result: Source |
| BadMacUse.java:152:42:152:51 | ciphertext | Unexpected result: Alert |

View File

@@ -47,7 +47,7 @@ class BadMacUse {
SecretKey encryptionKey = new SecretKeySpec(encryptionKeyBytes, "AES");
Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
cipher.init(Cipher.DECRYPT_MODE, encryptionKey, new SecureRandom());
byte[] plaintext = cipher.doFinal(ciphertext); // $ Source
byte[] plaintext = cipher.doFinal(ciphertext); // $ MISSING: Source
// Now verify MAC (too late)
SecretKey macKey = new SecretKeySpec(macKeyBytes, "HmacSHA256");
@@ -60,7 +60,7 @@ class BadMacUse {
}
}
public void BadMacOnPlaintext(byte[] encryptionKeyBytes, byte[] macKeyBytes, byte[] plaintext) throws Exception {// $ Source
public void BadMacOnPlaintext(byte[] encryptionKeyBytes, byte[] macKeyBytes, byte[] plaintext) throws Exception {// $ MISSING: Source
// Create keys directly from provided byte arrays
SecretKey encryptionKey = new SecretKeySpec(encryptionKeyBytes, "AES");
SecretKey macKey = new SecretKeySpec(macKeyBytes, "HmacSHA256");

View File

@@ -126,5 +126,3 @@ nodes
| InsecureIVorNonceSource.java:202:54:202:55 | iv : byte[] | semmle.label | iv : byte[] |
| InsecureIVorNonceSource.java:206:51:206:56 | ivSpec | semmle.label | ivSpec |
subpaths
testFailures
| InsecureIVorNonceSource.java:42:21:42:21 | 1 : Number | Unexpected result: Source |

View File

@@ -39,7 +39,7 @@ public class InsecureIVorNonceSource {
public byte[] encryptWithStaticIvByteArray(byte[] key, byte[] plaintext) throws Exception {
byte[] iv = new byte[16];
for (byte i = 0; i < iv.length; i++) {
iv[i] = 1;
iv[i] = 1; // $ Source
}
IvParameterSpec ivSpec = new IvParameterSpec(iv);

View File

@@ -40,11 +40,11 @@ public class Test {
* SAST/CBOM: - Parent: PBKDF2. - Iteration count is only 10, which is far
* below acceptable security standards. - Flagged as insecure.
*/
public void pbkdf2LowIteration(String password, int iterationCount) throws Exception { // $ Source
public void pbkdf2LowIteration(String password, int iterationCount) throws Exception { // $ MISSING: Source
byte[] salt = generateSalt(16);
PBEKeySpec spec = new PBEKeySpec(password.toCharArray(), salt, iterationCount, 256); // $ Alert[java/quantum/examples/unknown-kdf-iteration-count]
PBEKeySpec spec = new PBEKeySpec(password.toCharArray(), salt, iterationCount, 256);
SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
byte[] key = factory.generateSecret(spec).getEncoded();
byte[] key = factory.generateSecret(spec).getEncoded(); // $ Alert[java/quantum/examples/unknown-kdf-iteration-count]
}
/**

View File

@@ -1,5 +1 @@
#select
| Test.java:47:22:47:49 | KeyDerivation | Key derivation operation with unknown iteration: $@ | Test.java:43:53:43:70 | iterationCount | iterationCount |
testFailures
| Test.java:45:94:45:154 | // $ Alert[java/quantum/examples/unknown-kdf-iteration-count] | Missing result: Alert[java/quantum/examples/unknown-kdf-iteration-count] |
| Test.java:47:22:47:49 | Key derivation operation with unknown iteration: $@ | Unexpected result: Alert |

View File

@@ -12,5 +12,3 @@ nodes
| Test.java:58:30:58:38 | 1_000_000 : Number | semmle.label | 1_000_000 : Number |
| Test.java:59:72:59:85 | iterationCount | semmle.label | iterationCount |
subpaths
testFailures
| Test.java:43:92:43:102 | // $ Source | Missing result: Source |

View File

@@ -1,4 +0,0 @@
---
category: minorAnalysis
---
* Added support for Angular's `@HostListener('window:message', ...)` and `@HostListener('document:message', ...)` decorators as `postMessage` event handlers. The decorated method's event parameter is now recognized as a client-side remote flow source, and is considered by the `js/missing-origin-check` query.

View File

@@ -195,18 +195,6 @@ class PostMessageEventHandler extends Function {
rhs = DataFlow::globalObjectRef().getAPropertyWrite("onmessage").getRhs() and
rhs.getABoundFunctionValue(paramIndex).getFunction() = this
)
or
// Angular's `@HostListener('window:message', ['$event'])` decorator registers
// a method as a `message` event handler on the global `window` or `document`
// target. The decorated method receives the `MessageEvent` as its first
// parameter, so it is equivalent to `window.addEventListener('message', ...)`.
exists(MethodDefinition method, DataFlow::CallNode decorator |
decorator = DataFlow::moduleMember("@angular/core", "HostListener").getACall() and
decorator = method.getADecorator().getExpression().flow() and
decorator.getArgument(0).mayHaveStringValue(["window:message", "document:message"]) and
method.getBody() = this and
paramIndex = 0
)
}
/**

View File

@@ -1,29 +0,0 @@
import { Component, HostListener } from '@angular/core';
@Component({ selector: 'app-root' })
class AngularComponent {
// Angular registers this as a `window` message handler via the decorator,
// equivalent to `window.addEventListener('message', ...)`.
@HostListener('window:message', ['$event'])
onWindowMessage(event: MessageEvent): void { // $ Alert - no origin check
eval(event.data);
}
@HostListener('document:message', ['$event'])
onDocumentMessage(event: MessageEvent): void { // $ Alert - no origin check
eval(event.data);
}
@HostListener('window:message', ['$event'])
onCheckedMessage(event: MessageEvent): void { // OK - has an origin check
if (event.origin === 'https://www.example.com') {
eval(event.data);
}
}
// Not a message event, so it is not a postMessage handler.
@HostListener('window:resize', ['$event'])
onResize(event: MessageEvent): void { // OK - not a message handler
eval(event.data);
}
}

View File

@@ -1,5 +1,3 @@
| Angular.ts:8:19:8:23 | event | Postmessage handler has no origin check. |
| Angular.ts:13:21:13:25 | event | Postmessage handler has no origin check. |
| tst.js:11:20:11:24 | event | Postmessage handler has no origin check. |
| tst.js:24:27:24:27 | e | Postmessage handler has no origin check. |
| tst.js:40:27:40:27 | e | Postmessage handler has no origin check. |

View File

@@ -3,5 +3,5 @@ argumentToEnsureNotTaintedNotMarkedAsSpurious
untaintedArgumentToEnsureTaintedNotMarkedAsMissing
| taint_test.py:32:9:32:25 | taint_test.py:32 | ERROR, you should add `# $ MISSING: tainted` annotation | should_be_tainted |
| taint_test.py:37:24:37:40 | taint_test.py:37 | ERROR, you should add `# $ MISSING: tainted` annotation | should_be_tainted |
| taint_test.py:41:24:41:40 | taint_test.py:41 | ERROR, you should add `# $ MISSING: tainted` annotation | should_be_tainted |
testFailures
| taint_test.py:41:20:41:21 | ts | Fixed missing result: tainted |

View File

@@ -38,7 +38,7 @@ def bad_usage():
# if you try to get around it by adding BOTH annotations, that results in a problem
# from the default set of inline-test-expectation rules
ensure_tainted(ts, should_be_tainted) # $ tainted MISSING: tainted
ensure_tainted(ts, should_be_tainted) # $ tainted
# simulating handling something we _want_ to treat at untainted, but we currently treat as tainted
should_not_be_tainted = "pretend this is now safe" + ts

View File

@@ -1312,244 +1312,6 @@ module QL {
/** Gets a field or child node of this node. */
final override AstNode getAFieldOrChild() { ql_variable_def(this, result) }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(AddExpr).getLeft() and i = -1 and name = "getLeft"
or
result = node.(AddExpr).getRight() and i = -1 and name = "getRight"
or
result = node.(AddExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(Aggregate).getChild(i) and name = "getChild"
or
result = node.(AnnotArg).getChild() and i = -1 and name = "getChild"
or
result = node.(Annotation).getArgs(i) and name = "getArgs"
or
result = node.(Annotation).getName() and i = -1 and name = "getName"
or
result = node.(AritylessPredicateExpr).getName() and i = -1 and name = "getName"
or
result = node.(AritylessPredicateExpr).getQualifier() and i = -1 and name = "getQualifier"
or
result = node.(AsExpr).getChild(i) and name = "getChild"
or
result = node.(AsExprs).getChild(i) and name = "getChild"
or
result = node.(Body).getChild() and i = -1 and name = "getChild"
or
result = node.(Bool).getChild() and i = -1 and name = "getChild"
or
result = node.(CallBody).getChild(i) and name = "getChild"
or
result = node.(CallOrUnqualAggExpr).getChild(i) and name = "getChild"
or
result = node.(Charpred).getBody() and i = -1 and name = "getBody"
or
result = node.(Charpred).getChild() and i = -1 and name = "getChild"
or
result = node.(ClassMember).getChild(i) and name = "getChild"
or
result = node.(ClasslessPredicate).getName() and i = -1 and name = "getName"
or
result = node.(ClasslessPredicate).getReturnType() and i = -1 and name = "getReturnType"
or
result = node.(ClasslessPredicate).getChild(i) and name = "getChild"
or
result = node.(CompTerm).getLeft() and i = -1 and name = "getLeft"
or
result = node.(CompTerm).getRight() and i = -1 and name = "getRight"
or
result = node.(CompTerm).getChild() and i = -1 and name = "getChild"
or
result = node.(Conjunction).getLeft() and i = -1 and name = "getLeft"
or
result = node.(Conjunction).getRight() and i = -1 and name = "getRight"
or
result = node.(Dataclass).getExtends(i) and name = "getExtends"
or
result = node.(Dataclass).getInstanceof(i) and name = "getInstanceof"
or
result = node.(Dataclass).getName() and i = -1 and name = "getName"
or
result = node.(Dataclass).getChild(i) and name = "getChild"
or
result = node.(Datatype).getName() and i = -1 and name = "getName"
or
result = node.(Datatype).getChild() and i = -1 and name = "getChild"
or
result = node.(DatatypeBranch).getName() and i = -1 and name = "getName"
or
result = node.(DatatypeBranch).getChild(i) and name = "getChild"
or
result = node.(DatatypeBranches).getChild(i) and name = "getChild"
or
result = node.(Disjunction).getLeft() and i = -1 and name = "getLeft"
or
result = node.(Disjunction).getRight() and i = -1 and name = "getRight"
or
result = node.(ExprAggregateBody).getAsExprs() and i = -1 and name = "getAsExprs"
or
result = node.(ExprAggregateBody).getOrderBys() and i = -1 and name = "getOrderBys"
or
result = node.(ExprAnnotation).getAnnotArg() and i = -1 and name = "getAnnotArg"
or
result = node.(ExprAnnotation).getName() and i = -1 and name = "getName"
or
result = node.(ExprAnnotation).getChild() and i = -1 and name = "getChild"
or
result = node.(Field).getChild() and i = -1 and name = "getChild"
or
result = node.(FullAggregateBody).getAsExprs() and i = -1 and name = "getAsExprs"
or
result = node.(FullAggregateBody).getGuard() and i = -1 and name = "getGuard"
or
result = node.(FullAggregateBody).getOrderBys() and i = -1 and name = "getOrderBys"
or
result = node.(FullAggregateBody).getChild(i) and name = "getChild"
or
result = node.(HigherOrderTerm).getName() and i = -1 and name = "getName"
or
result = node.(HigherOrderTerm).getChild(i) and name = "getChild"
or
result = node.(IfTerm).getCond() and i = -1 and name = "getCond"
or
result = node.(IfTerm).getFirst() and i = -1 and name = "getFirst"
or
result = node.(IfTerm).getSecond() and i = -1 and name = "getSecond"
or
result = node.(Implication).getLeft() and i = -1 and name = "getLeft"
or
result = node.(Implication).getRight() and i = -1 and name = "getRight"
or
result = node.(ImportDirective).getChild(i) and name = "getChild"
or
result = node.(ImportModuleExpr).getQualName(i) and name = "getQualName"
or
result = node.(ImportModuleExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(InExpr).getLeft() and i = -1 and name = "getLeft"
or
result = node.(InExpr).getRight() and i = -1 and name = "getRight"
or
result = node.(InstanceOf).getChild(i) and name = "getChild"
or
result = node.(Literal).getChild() and i = -1 and name = "getChild"
or
result = node.(MemberPredicate).getName() and i = -1 and name = "getName"
or
result = node.(MemberPredicate).getReturnType() and i = -1 and name = "getReturnType"
or
result = node.(MemberPredicate).getChild(i) and name = "getChild"
or
result = node.(Module).getImplements(i) and name = "getImplements"
or
result = node.(Module).getName() and i = -1 and name = "getName"
or
result = node.(Module).getParameter(i) and name = "getParameter"
or
result = node.(Module).getChild(i) and name = "getChild"
or
result = node.(ModuleAliasBody).getChild() and i = -1 and name = "getChild"
or
result = node.(ModuleExpr).getName() and i = -1 and name = "getName"
or
result = node.(ModuleExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(ModuleInstantiation).getName() and i = -1 and name = "getName"
or
result = node.(ModuleInstantiation).getChild(i) and name = "getChild"
or
result = node.(ModuleMember).getChild(i) and name = "getChild"
or
result = node.(ModuleName).getChild() and i = -1 and name = "getChild"
or
result = node.(ModuleParam).getParameter() and i = -1 and name = "getParameter"
or
result = node.(ModuleParam).getSignature() and i = -1 and name = "getSignature"
or
result = node.(MulExpr).getLeft() and i = -1 and name = "getLeft"
or
result = node.(MulExpr).getRight() and i = -1 and name = "getRight"
or
result = node.(MulExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(Negation).getChild() and i = -1 and name = "getChild"
or
result = node.(OrderBy).getChild(i) and name = "getChild"
or
result = node.(OrderBys).getChild(i) and name = "getChild"
or
result = node.(ParExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(PredicateAliasBody).getChild() and i = -1 and name = "getChild"
or
result = node.(PredicateExpr).getChild(i) and name = "getChild"
or
result = node.(PrefixCast).getChild(i) and name = "getChild"
or
result = node.(Ql).getChild(i) and name = "getChild"
or
result = node.(QualifiedRhs).getName() and i = -1 and name = "getName"
or
result = node.(QualifiedRhs).getChild(i) and name = "getChild"
or
result = node.(QualifiedExpr).getChild(i) and name = "getChild"
or
result = node.(Quantified).getExpr() and i = -1 and name = "getExpr"
or
result = node.(Quantified).getFormula() and i = -1 and name = "getFormula"
or
result = node.(Quantified).getRange() and i = -1 and name = "getRange"
or
result = node.(Quantified).getChild(i) and name = "getChild"
or
result = node.(Range).getLower() and i = -1 and name = "getLower"
or
result = node.(Range).getUpper() and i = -1 and name = "getUpper"
or
result = node.(Select).getChild(i) and name = "getChild"
or
result = node.(SetLiteral).getChild(i) and name = "getChild"
or
result = node.(SignatureExpr).getModExpr() and i = -1 and name = "getModExpr"
or
result = node.(SignatureExpr).getPredicate() and i = -1 and name = "getPredicate"
or
result = node.(SignatureExpr).getTypeExpr() and i = -1 and name = "getTypeExpr"
or
result = node.(SpecialCall).getChild() and i = -1 and name = "getChild"
or
result = node.(SuperRef).getChild(i) and name = "getChild"
or
result = node.(TypeAliasBody).getChild() and i = -1 and name = "getChild"
or
result = node.(TypeExpr).getName() and i = -1 and name = "getName"
or
result = node.(TypeExpr).getQualifier() and i = -1 and name = "getQualifier"
or
result = node.(TypeExpr).getChild() and i = -1 and name = "getChild"
or
result = node.(TypeUnionBody).getChild(i) and name = "getChild"
or
result = node.(UnaryExpr).getChild(i) and name = "getChild"
or
result = node.(UnqualAggBody).getAsExprs(i) and name = "getAsExprs"
or
result = node.(UnqualAggBody).getGuard() and i = -1 and name = "getGuard"
or
result = node.(UnqualAggBody).getChild(i) and name = "getChild"
or
result = node.(VarDecl).getChild(i) and name = "getChild"
or
result = node.(VarName).getChild() and i = -1 and name = "getChild"
or
result = node.(Variable).getChild() and i = -1 and name = "getChild"
}
}
}
overlay[local]
@@ -1907,60 +1669,6 @@ module Dbscheme {
/** Gets the name of the primary QL class for this element. */
final override string getAPrimaryQlClass() { result = "Varchar" }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(Annotation).getArgsAnnotation() and i = -1 and name = "getArgsAnnotation"
or
result = node.(Annotation).getSimpleAnnotation() and i = -1 and name = "getSimpleAnnotation"
or
result = node.(ArgsAnnotation).getName() and i = -1 and name = "getName"
or
result = node.(ArgsAnnotation).getChild(i) and name = "getChild"
or
result = node.(Branch).getQldoc() and i = -1 and name = "getQldoc"
or
result = node.(Branch).getChild(i) and name = "getChild"
or
result = node.(CaseDecl).getBase() and i = -1 and name = "getBase"
or
result = node.(CaseDecl).getDiscriminator() and i = -1 and name = "getDiscriminator"
or
result = node.(CaseDecl).getChild(i) and name = "getChild"
or
result = node.(ColType).getChild() and i = -1 and name = "getChild"
or
result = node.(Column).getColName() and i = -1 and name = "getColName"
or
result = node.(Column).getColType() and i = -1 and name = "getColType"
or
result = node.(Column).getIsRef() and i = -1 and name = "getIsRef"
or
result = node.(Column).getIsUnique() and i = -1 and name = "getIsUnique"
or
result = node.(Column).getQldoc() and i = -1 and name = "getQldoc"
or
result = node.(Column).getReprType() and i = -1 and name = "getReprType"
or
result = node.(Dbscheme).getChild(i) and name = "getChild"
or
result = node.(Entry).getChild() and i = -1 and name = "getChild"
or
result = node.(ReprType).getChild(i) and name = "getChild"
or
result = node.(Table).getTableName() and i = -1 and name = "getTableName"
or
result = node.(Table).getChild(i) and name = "getChild"
or
result = node.(TableName).getChild() and i = -1 and name = "getChild"
or
result = node.(UnionDecl).getBase() and i = -1 and name = "getBase"
or
result = node.(UnionDecl).getChild(i) and name = "getChild"
}
}
}
overlay[local]
@@ -2095,24 +1803,6 @@ module Blame {
/** Gets the name of the primary QL class for this element. */
final override string getAPrimaryQlClass() { result = "Number" }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(BlameEntry).getDate() and i = -1 and name = "getDate"
or
result = node.(BlameEntry).getLine(i) and name = "getLine"
or
result = node.(BlameInfo).getFileEntry(i) and name = "getFileEntry"
or
result = node.(BlameInfo).getToday() and i = -1 and name = "getToday"
or
result = node.(FileEntry).getBlameEntry(i) and name = "getBlameEntry"
or
result = node.(FileEntry).getFileName() and i = -1 and name = "getFileName"
}
}
}
overlay[local]
@@ -2287,22 +1977,4 @@ module JSON {
/** Gets the name of the primary QL class for this element. */
final override string getAPrimaryQlClass() { result = "True" }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(Array).getChild(i) and name = "getChild"
or
result = node.(Document).getChild(i) and name = "getChild"
or
result = node.(Object).getChild(i) and name = "getChild"
or
result = node.(Pair).getKey() and i = -1 and name = "getKey"
or
result = node.(Pair).getValue() and i = -1 and name = "getValue"
or
result = node.(String).getChild(i) and name = "getChild"
}
}
}

View File

@@ -1964,340 +1964,6 @@ module Ruby {
/** Gets a field or child node of this node. */
final override AstNode getAFieldOrChild() { ruby_yield_child(this, result) }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(Alias).getAlias() and i = -1 and name = "getAlias"
or
result = node.(Alias).getName() and i = -1 and name = "getName"
or
result = node.(AlternativePattern).getAlternatives(i) and name = "getAlternatives"
or
result = node.(ArgumentList).getChild(i) and name = "getChild"
or
result = node.(Array).getChild(i) and name = "getChild"
or
result = node.(ArrayPattern).getClass() and i = -1 and name = "getClass"
or
result = node.(ArrayPattern).getChild(i) and name = "getChild"
or
result = node.(AsPattern).getName() and i = -1 and name = "getName"
or
result = node.(AsPattern).getValue() and i = -1 and name = "getValue"
or
result = node.(Assignment).getLeft() and i = -1 and name = "getLeft"
or
result = node.(Assignment).getRight() and i = -1 and name = "getRight"
or
result = node.(BareString).getChild(i) and name = "getChild"
or
result = node.(BareSymbol).getChild(i) and name = "getChild"
or
result = node.(Begin).getChild(i) and name = "getChild"
or
result = node.(BeginBlock).getChild(i) and name = "getChild"
or
result = node.(Binary).getLeft() and i = -1 and name = "getLeft"
or
result = node.(Binary).getRight() and i = -1 and name = "getRight"
or
result = node.(Block).getBody() and i = -1 and name = "getBody"
or
result = node.(Block).getParameters() and i = -1 and name = "getParameters"
or
result = node.(BlockArgument).getChild() and i = -1 and name = "getChild"
or
result = node.(BlockBody).getChild(i) and name = "getChild"
or
result = node.(BlockParameter).getName() and i = -1 and name = "getName"
or
result = node.(BlockParameters).getLocals(i) and name = "getLocals"
or
result = node.(BlockParameters).getChild(i) and name = "getChild"
or
result = node.(BodyStatement).getChild(i) and name = "getChild"
or
result = node.(Break).getChild() and i = -1 and name = "getChild"
or
result = node.(Call).getArguments() and i = -1 and name = "getArguments"
or
result = node.(Call).getBlock() and i = -1 and name = "getBlock"
or
result = node.(Call).getMethod() and i = -1 and name = "getMethod"
or
result = node.(Call).getOperator() and i = -1 and name = "getOperator"
or
result = node.(Call).getReceiver() and i = -1 and name = "getReceiver"
or
result = node.(Case).getValue() and i = -1 and name = "getValue"
or
result = node.(Case).getChild(i) and name = "getChild"
or
result = node.(CaseMatch).getClauses(i) and name = "getClauses"
or
result = node.(CaseMatch).getElse() and i = -1 and name = "getElse"
or
result = node.(CaseMatch).getValue() and i = -1 and name = "getValue"
or
result = node.(ChainedString).getChild(i) and name = "getChild"
or
result = node.(Class).getBody() and i = -1 and name = "getBody"
or
result = node.(Class).getName() and i = -1 and name = "getName"
or
result = node.(Class).getSuperclass() and i = -1 and name = "getSuperclass"
or
result = node.(Complex).getChild() and i = -1 and name = "getChild"
or
result = node.(Conditional).getAlternative() and i = -1 and name = "getAlternative"
or
result = node.(Conditional).getCondition() and i = -1 and name = "getCondition"
or
result = node.(Conditional).getConsequence() and i = -1 and name = "getConsequence"
or
result = node.(DelimitedSymbol).getChild(i) and name = "getChild"
or
result = node.(DestructuredLeftAssignment).getChild(i) and name = "getChild"
or
result = node.(DestructuredParameter).getChild(i) and name = "getChild"
or
result = node.(Do).getChild(i) and name = "getChild"
or
result = node.(DoBlock).getBody() and i = -1 and name = "getBody"
or
result = node.(DoBlock).getParameters() and i = -1 and name = "getParameters"
or
result = node.(ElementReference).getBlock() and i = -1 and name = "getBlock"
or
result = node.(ElementReference).getObject() and i = -1 and name = "getObject"
or
result = node.(ElementReference).getChild(i) and name = "getChild"
or
result = node.(Else).getChild(i) and name = "getChild"
or
result = node.(Elsif).getAlternative() and i = -1 and name = "getAlternative"
or
result = node.(Elsif).getCondition() and i = -1 and name = "getCondition"
or
result = node.(Elsif).getConsequence() and i = -1 and name = "getConsequence"
or
result = node.(EndBlock).getChild(i) and name = "getChild"
or
result = node.(Ensure).getChild(i) and name = "getChild"
or
result = node.(ExceptionVariable).getChild() and i = -1 and name = "getChild"
or
result = node.(Exceptions).getChild(i) and name = "getChild"
or
result = node.(ExpressionReferencePattern).getValue() and i = -1 and name = "getValue"
or
result = node.(FindPattern).getClass() and i = -1 and name = "getClass"
or
result = node.(FindPattern).getChild(i) and name = "getChild"
or
result = node.(For).getBody() and i = -1 and name = "getBody"
or
result = node.(For).getPattern() and i = -1 and name = "getPattern"
or
result = node.(For).getValue() and i = -1 and name = "getValue"
or
result = node.(Hash).getChild(i) and name = "getChild"
or
result = node.(HashPattern).getClass() and i = -1 and name = "getClass"
or
result = node.(HashPattern).getChild(i) and name = "getChild"
or
result = node.(HashSplatArgument).getChild() and i = -1 and name = "getChild"
or
result = node.(HashSplatParameter).getName() and i = -1 and name = "getName"
or
result = node.(HeredocBody).getChild(i) and name = "getChild"
or
result = node.(If).getAlternative() and i = -1 and name = "getAlternative"
or
result = node.(If).getCondition() and i = -1 and name = "getCondition"
or
result = node.(If).getConsequence() and i = -1 and name = "getConsequence"
or
result = node.(IfGuard).getCondition() and i = -1 and name = "getCondition"
or
result = node.(IfModifier).getBody() and i = -1 and name = "getBody"
or
result = node.(IfModifier).getCondition() and i = -1 and name = "getCondition"
or
result = node.(In).getChild() and i = -1 and name = "getChild"
or
result = node.(InClause).getBody() and i = -1 and name = "getBody"
or
result = node.(InClause).getGuard() and i = -1 and name = "getGuard"
or
result = node.(InClause).getPattern() and i = -1 and name = "getPattern"
or
result = node.(Interpolation).getChild(i) and name = "getChild"
or
result = node.(KeywordParameter).getName() and i = -1 and name = "getName"
or
result = node.(KeywordParameter).getValue() and i = -1 and name = "getValue"
or
result = node.(KeywordPattern).getKey() and i = -1 and name = "getKey"
or
result = node.(KeywordPattern).getValue() and i = -1 and name = "getValue"
or
result = node.(Lambda).getBody() and i = -1 and name = "getBody"
or
result = node.(Lambda).getParameters() and i = -1 and name = "getParameters"
or
result = node.(LambdaParameters).getChild(i) and name = "getChild"
or
result = node.(LeftAssignmentList).getChild(i) and name = "getChild"
or
result = node.(MatchPattern).getPattern() and i = -1 and name = "getPattern"
or
result = node.(MatchPattern).getValue() and i = -1 and name = "getValue"
or
result = node.(Method).getBody() and i = -1 and name = "getBody"
or
result = node.(Method).getName() and i = -1 and name = "getName"
or
result = node.(Method).getParameters() and i = -1 and name = "getParameters"
or
result = node.(MethodParameters).getChild(i) and name = "getChild"
or
result = node.(Module).getBody() and i = -1 and name = "getBody"
or
result = node.(Module).getName() and i = -1 and name = "getName"
or
result = node.(Next).getChild() and i = -1 and name = "getChild"
or
result = node.(OperatorAssignment).getLeft() and i = -1 and name = "getLeft"
or
result = node.(OperatorAssignment).getRight() and i = -1 and name = "getRight"
or
result = node.(OptionalParameter).getName() and i = -1 and name = "getName"
or
result = node.(OptionalParameter).getValue() and i = -1 and name = "getValue"
or
result = node.(Pair).getKey() and i = -1 and name = "getKey"
or
result = node.(Pair).getValue() and i = -1 and name = "getValue"
or
result = node.(ParenthesizedPattern).getChild() and i = -1 and name = "getChild"
or
result = node.(ParenthesizedStatements).getChild(i) and name = "getChild"
or
result = node.(Pattern).getChild() and i = -1 and name = "getChild"
or
result = node.(Program).getChild(i) and name = "getChild"
or
result = node.(Range).getBegin() and i = -1 and name = "getBegin"
or
result = node.(Range).getEnd() and i = -1 and name = "getEnd"
or
result = node.(Rational).getChild() and i = -1 and name = "getChild"
or
result = node.(Redo).getChild() and i = -1 and name = "getChild"
or
result = node.(Regex).getChild(i) and name = "getChild"
or
result = node.(Rescue).getBody() and i = -1 and name = "getBody"
or
result = node.(Rescue).getExceptions() and i = -1 and name = "getExceptions"
or
result = node.(Rescue).getVariable() and i = -1 and name = "getVariable"
or
result = node.(RescueModifier).getBody() and i = -1 and name = "getBody"
or
result = node.(RescueModifier).getHandler() and i = -1 and name = "getHandler"
or
result = node.(RestAssignment).getChild() and i = -1 and name = "getChild"
or
result = node.(Retry).getChild() and i = -1 and name = "getChild"
or
result = node.(Return).getChild() and i = -1 and name = "getChild"
or
result = node.(RightAssignmentList).getChild(i) and name = "getChild"
or
result = node.(ScopeResolution).getName() and i = -1 and name = "getName"
or
result = node.(ScopeResolution).getScope() and i = -1 and name = "getScope"
or
result = node.(Setter).getName() and i = -1 and name = "getName"
or
result = node.(SingletonClass).getBody() and i = -1 and name = "getBody"
or
result = node.(SingletonClass).getValue() and i = -1 and name = "getValue"
or
result = node.(SingletonMethod).getBody() and i = -1 and name = "getBody"
or
result = node.(SingletonMethod).getName() and i = -1 and name = "getName"
or
result = node.(SingletonMethod).getObject() and i = -1 and name = "getObject"
or
result = node.(SingletonMethod).getParameters() and i = -1 and name = "getParameters"
or
result = node.(SplatArgument).getChild() and i = -1 and name = "getChild"
or
result = node.(SplatParameter).getName() and i = -1 and name = "getName"
or
result = node.(String).getChild(i) and name = "getChild"
or
result = node.(StringArray).getChild(i) and name = "getChild"
or
result = node.(Subshell).getChild(i) and name = "getChild"
or
result = node.(Superclass).getChild() and i = -1 and name = "getChild"
or
result = node.(SymbolArray).getChild(i) and name = "getChild"
or
result = node.(TestPattern).getPattern() and i = -1 and name = "getPattern"
or
result = node.(TestPattern).getValue() and i = -1 and name = "getValue"
or
result = node.(Then).getChild(i) and name = "getChild"
or
result = node.(Unary).getOperand() and i = -1 and name = "getOperand"
or
result = node.(Undef).getChild(i) and name = "getChild"
or
result = node.(Unless).getAlternative() and i = -1 and name = "getAlternative"
or
result = node.(Unless).getCondition() and i = -1 and name = "getCondition"
or
result = node.(Unless).getConsequence() and i = -1 and name = "getConsequence"
or
result = node.(UnlessGuard).getCondition() and i = -1 and name = "getCondition"
or
result = node.(UnlessModifier).getBody() and i = -1 and name = "getBody"
or
result = node.(UnlessModifier).getCondition() and i = -1 and name = "getCondition"
or
result = node.(Until).getBody() and i = -1 and name = "getBody"
or
result = node.(Until).getCondition() and i = -1 and name = "getCondition"
or
result = node.(UntilModifier).getBody() and i = -1 and name = "getBody"
or
result = node.(UntilModifier).getCondition() and i = -1 and name = "getCondition"
or
result = node.(VariableReferencePattern).getName() and i = -1 and name = "getName"
or
result = node.(When).getBody() and i = -1 and name = "getBody"
or
result = node.(When).getPattern(i) and name = "getPattern"
or
result = node.(While).getBody() and i = -1 and name = "getBody"
or
result = node.(While).getCondition() and i = -1 and name = "getCondition"
or
result = node.(WhileModifier).getBody() and i = -1 and name = "getBody"
or
result = node.(WhileModifier).getCondition() and i = -1 and name = "getCondition"
or
result = node.(Yield).getChild() and i = -1 and name = "getChild"
}
}
}
overlay[local]
@@ -2441,20 +2107,4 @@ module Erb {
/** Gets a field or child node of this node. */
final override AstNode getAFieldOrChild() { erb_template_child(this, _, result) }
}
/** Provides predicates for mapping AST nodes to their named children. */
module PrintAst {
/** Gets a child of `node` returned by the member predicate with the given `name`. If the predicate takes an index argument, `i` is bound to that index, otherwise `i` is `-1` (which is never a valid index). */
AstNode getChild(AstNode node, string name, int i) {
result = node.(CommentDirective).getChild() and i = -1 and name = "getChild"
or
result = node.(Directive).getChild() and i = -1 and name = "getChild"
or
result = node.(GraphqlDirective).getChild() and i = -1 and name = "getChild"
or
result = node.(OutputDirective).getChild() and i = -1 and name = "getChild"
or
result = node.(Template).getChild(i) and name = "getChild"
}
}
}

View File

@@ -28,6 +28,7 @@ nodes
| string_flow.rb:227:10:227:10 | a | semmle.label | a |
subpaths
testFailures
| string_flow.rb:227:10:227:10 | a | Unexpected result: hasValueFlow=a |
#select
| string_flow.rb:3:10:3:22 | call to new | string_flow.rb:2:9:2:18 | call to source | string_flow.rb:3:10:3:22 | call to new | $@ | string_flow.rb:2:9:2:18 | call to source | call to source |
| string_flow.rb:85:10:85:10 | a | string_flow.rb:83:9:83:18 | call to source | string_flow.rb:85:10:85:10 | a | $@ | string_flow.rb:83:9:83:18 | call to source | call to source |

View File

@@ -82,7 +82,7 @@ end
def m_clear
a = source "a"
a.clear
sink a # $ SPURIOUS: hasValueFlow=a
sink a # $ hasValueFlow=a
end
# concat and prepend omitted because they clash with the summaries for
@@ -224,7 +224,7 @@ def m_replace
b = source "b"
sink a.replace(b) # $ hasTaintFlow=b
# TODO: currently we get value flow for a, because we don't clear content
sink a # $ hasTaintFlow=b SPURIOUS: hasValueFlow=a
sink a # $ hasTaintFlow=b
end
def m_reverse
@@ -316,4 +316,4 @@ def m_upto(i)
a.upto("b", true) { |x| sink x } # $ hasTaintFlow=a
"b".upto(a) { |x| sink x } # $ hasTaintFlow=a
"b".upto(a, true) { |x| sink x }
end
end

View File

@@ -9,7 +9,7 @@ end
class OneController < ActionController::Base
before_action :a
after_action :c
def a
@foo = params[:foo]
end
@@ -25,7 +25,7 @@ end
class TwoController < ActionController::Base
before_action :a
after_action :c
def a
@foo = params[:foo]
end
@@ -35,14 +35,14 @@ class TwoController < ActionController::Base
end
def c
sink @foo # $ SPURIOUS: hasTaintFlow
sink @foo
end
end
class ThreeController < ActionController::Base
before_action :a
after_action :c
def a
@foo = params[:foo]
@foo = "safe"
@@ -52,14 +52,14 @@ class ThreeController < ActionController::Base
end
def c
sink @foo # $ SPURIOUS: hasTaintFlow
sink @foo
end
end
class FourController < ActionController::Base
before_action :a
after_action :c
def a
@foo.bar = params[:foo]
end
@@ -68,14 +68,14 @@ class FourController < ActionController::Base
end
def c
sink(@foo.bar) # $ hasTaintFlow
sink(@foo.bar)
end
end
class FiveController < ActionController::Base
before_action :a
after_action :c
def a
self.taint_foo
end
@@ -84,10 +84,10 @@ class FiveController < ActionController::Base
end
def c
sink @foo # $ hasTaintFlow
sink @foo
end
def taint_foo
@foo = params[:foo]
end
end
end

View File

@@ -270,6 +270,10 @@ nodes
| params_flow.rb:205:10:205:10 | a | semmle.label | a |
subpaths
testFailures
| filter_flow.rb:38:10:38:13 | @foo | Unexpected result: hasTaintFlow |
| filter_flow.rb:55:10:55:13 | @foo | Unexpected result: hasTaintFlow |
| filter_flow.rb:71:10:71:17 | call to bar | Unexpected result: hasTaintFlow |
| filter_flow.rb:87:11:87:14 | @foo | Unexpected result: hasTaintFlow |
#select
| filter_flow.rb:21:10:21:13 | @foo | filter_flow.rb:14:12:14:17 | call to params | filter_flow.rb:21:10:21:13 | @foo | $@ | filter_flow.rb:14:12:14:17 | call to params | call to params |
| filter_flow.rb:38:10:38:13 | @foo | filter_flow.rb:30:12:30:17 | call to params | filter_flow.rb:38:10:38:13 | @foo | $@ | filter_flow.rb:30:12:30:17 | call to params | call to params |

View File

@@ -497,7 +497,6 @@ nodes
| hash_extensions.rb:126:10:126:19 | call to sole | semmle.label | call to sole |
subpaths
testFailures
| hash_extensions.rb:126:10:126:19 | call to sole | Unexpected result: hasValueFlow=b |
#select
| active_support.rb:182:10:182:13 | ...[...] | active_support.rb:180:10:180:17 | call to source | active_support.rb:182:10:182:13 | ...[...] | $@ | active_support.rb:180:10:180:17 | call to source | call to source |
| active_support.rb:188:10:188:13 | ...[...] | active_support.rb:186:10:186:18 | call to source | active_support.rb:188:10:188:13 | ...[...] | $@ | active_support.rb:186:10:186:18 | call to source | call to source |

View File

@@ -123,7 +123,7 @@ def m_sole
multi = [source("b"), source("c")]
sink(empty.sole)
sink(single.sole) # $ hasValueFlow=a
sink(multi.sole) # TODO: model that 'sole' does not return if the receiver has multiple elements
sink(multi.sole) # $ hasValueFlow=b # TODO: model that 'sole' does not return if the receiver has multiple elements
end
m_sole()

View File

@@ -23,7 +23,6 @@ nodes
| views/index.erb:2:10:2:12 | call to foo | semmle.label | call to foo |
subpaths
testFailures
| views/index.erb:2:10:2:12 | call to foo | Unexpected result: hasTaintFlow |
#select
| app.rb:95:10:95:14 | @user | app.rb:103:13:103:22 | call to source | app.rb:95:10:95:14 | @user | $@ | app.rb:103:13:103:22 | call to source | call to source |
| views/index.erb:2:10:2:12 | call to foo | app.rb:75:12:75:17 | call to params | views/index.erb:2:10:2:12 | call to foo | $@ | app.rb:75:12:75:17 | call to params | call to params |

View File

@@ -1,2 +1,2 @@
<%= @foo %>
<%= sink foo %>
<%= sink foo # $ hasTaintFlow %>

View File

@@ -1,5 +1,4 @@
testFailures
| improper_memoization.rb:100:1:104:3 | m14 | Unexpected result: result=BAD |
#select
| improper_memoization.rb:50:1:55:3 | m7 | improper_memoization.rb:50:8:50:10 | arg | improper_memoization.rb:51:3:53:5 | ... \|\|= ... |
| improper_memoization.rb:58:1:63:3 | m8 | improper_memoization.rb:58:8:58:10 | arg | improper_memoization.rb:59:3:61:5 | ... \|\|= ... |

View File

@@ -101,4 +101,4 @@ def m14(arg)
@m14 ||= {}
key = "foo/#{arg}"
@m14[key] ||= long_running_method(arg)
end
end # $ SPURIOUS: result=BAD

View File

@@ -66,7 +66,7 @@ impl<'a> AstNode for Node<'a> {
impl AstNode for yeast::Node {
fn kind(&self) -> &str {
yeast::Node::kind_name(self)
yeast::Node::kind(self)
}
fn is_named(&self) -> bool {
yeast::Node::is_named(self)
@@ -280,11 +280,10 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
}
/// Extracts the source file at `path`, which is assumed to be canonicalized.
/// When `desugarer` is `Some`, the parsed tree is first transformed
/// through the supplied yeast desugarer before TRAP extraction. Building
/// the desugarer (which parses YAML and constructs the schema) is the
/// caller's responsibility, allowing it to be done once and shared across
/// files.
/// When `yeast_runner` is `Some`, the parsed tree is first transformed
/// through the supplied yeast `Runner` before TRAP extraction. Building the
/// `Runner` (which parses YAML and constructs the schema) is the caller's
/// responsibility, allowing it to be done once and shared across files.
#[allow(clippy::too_many_arguments)]
pub fn extract(
language: &Language,
@@ -296,7 +295,7 @@ pub fn extract(
path: &Path,
source: &[u8],
ranges: &[Range],
desugarer: Option<&dyn yeast::Desugarer>,
yeast_runner: Option<&yeast::Runner<'_>>,
) {
let path_str = file_paths::normalize_and_transform_path(path, transformer);
let source_root = std::env::current_dir()
@@ -329,8 +328,8 @@ pub fn extract(
schema,
);
if let Some(desugarer) = desugarer {
let ast = desugarer
if let Some(yeast_runner) = yeast_runner {
let ast = yeast_runner
.run_from_tree(&tree, source)
.unwrap_or_else(|e| panic!("Desugaring failed for {path_str}: {e}"));
traverse_yeast(&ast, &mut visitor);
@@ -882,6 +881,7 @@ fn emit_extras_in(visitor: &mut Visitor, node: Node<'_>) {
}
fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) {
use yeast::Cursor;
let mut cursor = tree.walk();
visitor.enter_node(cursor.node());
let mut recurse = true;

View File

@@ -13,14 +13,11 @@ pub struct LanguageSpec {
pub prefix: &'static str,
pub ts_language: tree_sitter::Language,
pub node_types: &'static str,
/// Optional desugarer. When set, the parsed tree is rewritten through
/// the desugarer before TRAP extraction. The desugarer's
/// `output_node_types_yaml()` (if set) provides the schema used both
/// at runtime (for the rewriter) and for TRAP validation.
///
/// `Box<dyn yeast::Desugarer>` so the shared extractor is agnostic to
/// the user-defined context type the desugarer uses internally.
pub desugar: Option<Box<dyn yeast::Desugarer>>,
/// Optional yeast desugaring configuration. When set, the parsed
/// tree is rewritten through yeast before TRAP extraction. The
/// config's `output_node_types_yaml` (if set) provides the schema
/// used both at runtime (for the rewriter) and for TRAP validation.
pub desugar: Option<yeast::DesugaringConfig>,
pub file_globs: Vec<String>,
}
@@ -94,22 +91,35 @@ impl Extractor {
.collect();
let mut schemas = vec![];
let mut yeast_runners = Vec::new();
for lang in &self.languages {
let effective_node_types: String = match lang
.desugar
.as_ref()
.and_then(|d| d.output_node_types_yaml())
{
Some(yaml) => yeast::node_types_yaml::convert(yaml).map_err(|e| {
std::io::Error::other(format!(
"Failed to convert YAML node-types to JSON for {}: {e}",
lang.prefix
))
})?,
None => lang.node_types.to_string(),
};
let effective_node_types: String =
match lang.desugar.as_ref().and_then(|c| c.output_node_types_yaml) {
Some(yaml) => yeast::node_types_yaml::convert(yaml).map_err(|e| {
std::io::Error::other(format!(
"Failed to convert YAML node-types to JSON for {}: {e}",
lang.prefix
))
})?,
None => lang.node_types.to_string(),
};
let schema = node_types::read_node_types_str(lang.prefix, &effective_node_types)?;
schemas.push(schema);
// Build the yeast runner once per language so the YAML schema
// isn't re-parsed for every file.
let yeast_runner = lang
.desugar
.as_ref()
.map(|config| yeast::Runner::from_config(lang.ts_language.clone(), config))
.transpose()
.map_err(|e| {
std::io::Error::other(format!(
"Failed to build desugaring runner for {}: {e}",
lang.prefix
))
})?;
yeast_runners.push(yeast_runner);
}
// Construct a single globset containing all language globs,
@@ -184,7 +194,7 @@ impl Extractor {
&path,
&source,
&[],
lang.desugar.as_deref(),
yeast_runners[i].as_ref(),
);
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;

View File

@@ -159,7 +159,6 @@ pub fn generate(
));
body.append(&mut ql_gen::convert_nodes(&nodes));
body.push(ql_gen::create_print_ast_module(&nodes));
ql::write(
&mut ql_writer,
&[ql::TopLevel::Module(ql::Module {

View File

@@ -150,14 +150,12 @@ impl fmt::Display for Type<'_> {
pub enum Expression<'a> {
Var(&'a str),
String(&'a str),
Integer(i64),
Integer(usize),
Pred(&'a str, Vec<Expression<'a>>),
And(Vec<Expression<'a>>),
Or(Vec<Expression<'a>>),
Equals(Box<Expression<'a>>, Box<Expression<'a>>),
Dot(Box<Expression<'a>>, &'a str, Vec<Expression<'a>>),
/// A type cast, rendered as `x.(Type)`.
Cast(Box<Expression<'a>>, &'a str),
Aggregate {
name: &'a str,
vars: Vec<FormalParameter<'a>>,
@@ -221,7 +219,6 @@ impl fmt::Display for Expression<'_> {
}
write!(f, ")")
}
Expression::Cast(x, type_name) => write!(f, "{x}.({type_name})"),
Expression::Aggregate {
name,
vars,

View File

@@ -705,7 +705,7 @@ fn create_field_getters<'a>(
),
ql::Expression::Equals(
Box::new(ql::Expression::Var("value")),
Box::new(ql::Expression::Integer(*value as i64)),
Box::new(ql::Expression::Integer(*value)),
),
])
})
@@ -874,99 +874,3 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec<ql::TopLevel<'_>> {
classes
}
/// Builds the `PrintAst` QL module, whose single `getChild` predicate relates
/// an AST node to each of its children.
///
/// Every table-stored field of every node type contributes one disjunct that
/// binds `result` to the field's getter (called on `node` cast to the owning
/// class), binds `name` to the getter's name, and binds `i` either to the
/// getter's index argument or, for getters without an index, to `-1`.
/// Fields whose type is `ReservedWordInt` are left out.
pub fn create_print_ast_module(nodes: &node_types::NodeTypeMap) -> ql::TopLevel<'_> {
    let mut alternatives: Vec<ql::Expression> = Vec::new();
    for entry in nodes.values() {
        // Only table entries carry fields; everything else contributes nothing.
        let fields = match &entry.kind {
            node_types::EntryKind::Table { fields, .. } => fields,
            _ => continue,
        };
        for field in fields {
            // `ReservedWordInt` fields have string-valued getters, so they are
            // not children (they are excluded from `getAFieldOrChild` too).
            if matches!(
                field.type_info,
                node_types::FieldTypeInfo::ReservedWordInt(_)
            ) {
                continue;
            }

            let indexed = matches!(
                field.storage,
                node_types::Storage::Table {
                    has_index: true,
                    ..
                }
            );
            let index_args = if indexed {
                vec![ql::Expression::Var("i")]
            } else {
                Vec::new()
            };

            // result = node.(Class).getter(<i>)
            let typed_node = ql::Expression::Cast(
                Box::new(ql::Expression::Var("node")),
                &entry.ql_class_name,
            );
            let result_binding = ql::Expression::Equals(
                Box::new(ql::Expression::Var("result")),
                Box::new(ql::Expression::Dot(
                    Box::new(typed_node),
                    &field.getter_name,
                    index_args,
                )),
            );
            // name = "getter"
            let name_binding = ql::Expression::Equals(
                Box::new(ql::Expression::Var("name")),
                Box::new(ql::Expression::String(&field.getter_name)),
            );

            let clause = if indexed {
                vec![result_binding, name_binding]
            } else {
                // Getters without an index pin `i` to the `-1` placeholder.
                let no_index = ql::Expression::Equals(
                    Box::new(ql::Expression::Var("i")),
                    Box::new(ql::Expression::Integer(-1)),
                );
                vec![result_binding, no_index, name_binding]
            };
            alternatives.push(ql::Expression::And(clause));
        }
    }

    let formal_parameters = vec![
        ql::FormalParameter {
            name: "node",
            param_type: ql::Type::Normal("AstNode"),
        },
        ql::FormalParameter {
            name: "name",
            param_type: ql::Type::String,
        },
        ql::FormalParameter {
            name: "i",
            param_type: ql::Type::Int,
        },
    ];

    let get_child = ql::Predicate {
        qldoc: Some(
            "Gets a child of `node` returned by the member predicate with the given `name`. \
             If the predicate takes an index argument, `i` is bound to that index, otherwise \
             `i` is `-1` (which is never a valid index)."
                .to_owned(),
        ),
        name: "getChild",
        overridden: false,
        is_private: false,
        is_final: false,
        return_type: Some(ql::Type::Normal("AstNode")),
        formal_parameters,
        body: ql::Expression::Or(alternatives),
        overlay: None,
    };

    ql::TopLevel::Module(ql::Module {
        qldoc: Some("Provides predicates for mapping AST nodes to their named children.".to_owned()),
        name: "PrintAst",
        body: vec![ql::TopLevel::Predicate(get_child)],
        overlay: None,
    })
}

View File

@@ -41,14 +41,22 @@ pub fn query(input: TokenStream) -> TokenStream {
/// (kind "literal") - leaf with static content
/// (kind #{expr}) - leaf with computed content (expr.to_string())
/// (kind $fresh) - leaf with auto-generated unique name
/// {expr} - embed a Rust expression, dispatched via
/// the `IntoFieldIds` trait: `Id` pushes a
/// single id; iterables (`Vec<Id>`,
/// `Option<Id>`, iterator chains) splice
/// their elements
/// field: {expr} - extend a named field with `{expr}`'s ids
/// {expr} - embed a Rust expression returning Id
/// {..expr} - splice an iterable of Id (in child/field position)
/// field: {..expr} - splice into a named field
/// {expr}.map(p -> tpl) - apply tpl to each element; splice result
/// {expr}.reduce_left(f -> init, acc, e -> fold)
/// - fold with per-element init; splice 0 or 1 result
/// ```
///
/// Chain syntax after `{expr}` or `{..expr}`:
/// - `.map(param -> template)` — one output node per input element.
/// - `.reduce_left(first -> init, acc, elem -> fold)` — fold left; the first
/// element is converted by `init`, subsequent elements are folded by `fold`
/// with the accumulator bound to `acc`. An empty iterable yields nothing.
/// - Chains always splice (the result is iterable).
/// - Multiple chains can be chained, e.g. `.map(...).reduce_left(...)`.
///
/// Can be called with an explicit context or using the implicit context
/// from an enclosing `rule!`:
///
@@ -92,7 +100,7 @@ pub fn trees(input: TokenStream) -> TokenStream {
/// rule!(
/// (query_pattern field: (_) @name (kind)* @repeated (_)? @optional)
/// =>
/// (output_template field: {name} {repeated})
/// (output_template field: {name} {..repeated})
/// )
///
/// // Shorthand: captures become fields on the output node

View File

@@ -22,9 +22,10 @@ pub fn parse_query_top(input: TokenStream) -> Result<TokenStream> {
/// Parse a single query node (possibly with a trailing `@capture`).
fn parse_query_node(tokens: &mut Tokens) -> Result<TokenStream> {
let base = parse_query_atom(tokens)?;
// Check for trailing @capture or @@capture
// Check for trailing @capture
if peek_is_at(tokens) {
let capture_name = consume_capture_marker(tokens)?;
tokens.next(); // consume @
let capture_name = expect_ident(tokens, "expected capture name after @")?;
let name_str = capture_name.to_string();
Ok(quote! {
yeast::query::QueryNode::Capture {
@@ -120,9 +121,9 @@ fn parse_query_fields(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
std::collections::HashMap::new();
let mut bare_children: Vec<TokenStream> = Vec::new();
let push_field_elem = |order: &mut Vec<String>,
map: &mut std::collections::HashMap<String, Vec<TokenStream>>,
name: String,
elem: TokenStream| {
map: &mut std::collections::HashMap<String, Vec<TokenStream>>,
name: String,
elem: TokenStream| {
if !map.contains_key(&name) {
order.push(name.clone());
map.insert(name, vec![elem]);
@@ -158,7 +159,9 @@ fn parse_query_fields(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
push_field_elem(&mut field_order, &mut field_elems, field_str, elem);
} else {
let child = if peek_is_at(tokens) {
let capture_name = consume_capture_marker(tokens)?;
tokens.next();
let capture_name =
expect_ident(tokens, "expected capture name after @")?;
let name_str = capture_name.to_string();
quote! {
yeast::query::QueryNode::Capture {
@@ -293,10 +296,10 @@ fn parse_query_list(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
// tree! / trees! parsing — direct code generation against BuildCtx
// ---------------------------------------------------------------------------
const IMPLICIT_CTX: &str = "ctx";
const IMPLICIT_CTX: &str = "__yeast_ctx";
/// Determine the context identifier: either explicit `ctx,` or the implicit
/// `ctx` from an enclosing `rule!`.
/// `__yeast_ctx` from an enclosing `rule!`.
fn parse_ctx_or_implicit(tokens: &mut Tokens) -> Ident {
// Check if first token is an ident followed by a comma
let mut lookahead = tokens.clone();
@@ -304,8 +307,7 @@ fn parse_ctx_or_implicit(tokens: &mut Tokens) -> Ident {
&& matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == ',');
if is_explicit {
let ctx = expect_ident(tokens, "unreachable: ident was just peeked")
.expect("unreachable: ident was just peeked");
let ctx = expect_ident(tokens, "").unwrap();
let _ = tokens.next(); // consume comma
ctx
} else {
@@ -343,7 +345,7 @@ pub fn parse_trees_top(input: TokenStream) -> Result<TokenStream> {
}
Ok(quote! {
{
let mut __nodes: Vec<yeast::Id> = Vec::new();
let mut __nodes: Vec<usize> = Vec::new();
#(#items)*
__nodes
}
@@ -357,7 +359,7 @@ fn parse_direct_node(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStream> {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
let group = expect_group(tokens, Delimiter::Brace)?;
let expr = group.stream();
Ok(quote! { ::std::convert::Into::<yeast::Id>::into({ #expr }) })
Ok(quote! { ::std::convert::Into::<usize>::into(#expr) })
}
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => {
let group = expect_group(tokens, Delimiter::Parenthesis)?;
@@ -394,7 +396,7 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
let expr = group.stream();
return Ok(quote! {
{
let __expr = { #expr };
let __expr = (#expr);
let __value = yeast::YeastDisplay::yeast_to_string(&__expr, &*#ctx.ast);
let __source_range = yeast::YeastSourceRange::yeast_source_range(&__expr, &*#ctx.ast);
#ctx.literal_with_source_range(#kind_str, &__value, __source_range)
@@ -418,11 +420,7 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
// Named fields — compute each value into a temp, then reference it
while peek_is_field(tokens) {
let field_name = expect_ident(tokens, "expected field name")?;
let field_str = field_name
.to_string()
.strip_prefix("r#")
.unwrap_or(&field_name.to_string())
.to_string();
let field_str = field_name.to_string().strip_prefix("r#").unwrap_or(&field_name.to_string()).to_string();
expect_punct(tokens, ':', "expected `:` after field name")?;
let temp = Ident::new(
&format!("__field_{field_str}_{field_counter}"),
@@ -430,24 +428,48 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
);
field_counter += 1;
// Plain `field: {expr}` — trait-dispatched extend.
// Check for field: {..expr}.chain or field: {expr}.chain — splice a Vec<Id> into the field
if peek_is_group(tokens, Delimiter::Brace) {
let group = expect_group(tokens, Delimiter::Brace)?;
let expr = group.stream();
stmts.push(quote! {
let mut #temp: Vec<yeast::Id> = Vec::new();
yeast::IntoFieldIds::extend_into({ #expr }, &mut #temp);
});
// An empty `{expr}` means the field is absent — skip it
// entirely rather than emitting an empty named field.
field_args.push(quote! {
if !#temp.is_empty() { __fields.push((#field_str, #temp)); }
});
continue;
let group_clone = tokens.clone().next().unwrap();
if let TokenTree::Group(g) = &group_clone {
let mut inner_check = g.stream().into_iter();
let is_splice = matches!(inner_check.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
&& matches!(inner_check.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
// Determine if a chain (.map(..)) follows the `{}` group.
let mut after = tokens.clone();
after.next(); // skip the brace group
let has_chain = matches!(after.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
if is_splice || has_chain {
let group = expect_group(tokens, Delimiter::Brace)?;
let base: TokenStream = if is_splice {
let mut inner = group.stream().into_iter().peekable();
inner.next(); // consume first .
inner.next(); // consume second .
let expr: TokenStream = inner.collect();
quote! {
(#expr).into_iter().map(::std::convert::Into::<usize>::into)
}
} else {
let expr = group.stream();
quote! { (#expr).into_iter() }
};
let chained = parse_chain_suffix(tokens, ctx, base)?;
stmts.push(quote! {
let #temp: Vec<usize> = #chained.collect();
});
// An empty splice means the field is absent — skip it
// entirely rather than emitting an empty named field.
field_args.push(quote! {
if !#temp.is_empty() { __fields.push((#field_str, #temp)); }
});
continue;
}
}
}
let value = parse_direct_node(tokens, ctx)?;
stmts.push(quote! { let #temp: yeast::Id = #value; });
stmts.push(quote! { let #temp: usize = #value; });
field_args.push(quote! { __fields.push((#field_str, vec![#temp])); });
}
@@ -464,13 +486,105 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
Ok(quote! {
{
#(#stmts)*
let mut __fields: Vec<(&str, Vec<yeast::Id>)> = Vec::new();
let mut __fields: Vec<(&str, Vec<usize>)> = Vec::new();
#(#field_args)*
#ctx.node(#kind_str, __fields)
}
})
}
/// Parse a chain of `.method(args)` suffixes after a `{expr}` or `{..expr}`
/// placeholder in tree templates, and return the token stream of the
/// resulting iterator expression. Supported methods:
///
/// ```text
/// .map(param -> template)
///     -- iterator map: one template instantiation per element
/// .reduce_left(first -> init, acc, elem -> fold)
///     -- left fold: `init` converts the first element, `fold` combines the
///        accumulator with each later element; yields zero or one element
/// ```
///
/// The chain may be empty (returns `base` unchanged). Multiple chained calls
/// are supported, e.g. `.map(p -> ...).reduce_left(...)`.
///
/// Each call expects the receiver to be an iterator. The `base` argument
/// should therefore already be an iterator (use `.into_iter()` on it before
/// calling this function). The returned expression is itself an iterator;
/// collecting or extending from it is left to the caller.
///
/// # Errors
///
/// Returns an error for an unknown method name, for tokens left over after a
/// lambda body, and for any error reported while parsing the method name,
/// its parenthesised arguments, the `->` arrows or the nested templates.
fn parse_chain_suffix(
tokens: &mut Tokens,
ctx: &Ident,
base: TokenStream,
) -> Result<TokenStream> {
// `current` is the iterator expression built so far; each suffix wraps it.
let mut current = base;
while matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.') {
tokens.next(); // consume .
let method = expect_ident(tokens, "expected method name after `.`")?;
let method_str = method.to_string();
let args_group = expect_group(tokens, Delimiter::Parenthesis)?;
match method_str.as_str() {
"map" => {
// Syntax: map(param -> template)
let mut inner = args_group.stream().into_iter().peekable();
let param = expect_ident(&mut inner, "expected lambda parameter name")?;
// `->` is matched as the two punctuation characters `-` and `>`.
expect_punct(&mut inner, '-', "expected `->` after lambda parameter")?;
expect_punct(&mut inner, '>', "expected `->` after lambda parameter")?;
let body = parse_direct_node(&mut inner, ctx)?;
// The template must be the last thing inside the parentheses.
if let Some(tok) = inner.next() {
return Err(syn::Error::new_spanned(
tok,
"unexpected token after lambda body",
));
}
current = quote! {
#current.map(|#param| #body)
};
}
"reduce_left" => {
// Syntax: reduce_left(first -> init_tpl, acc, elem -> fold_tpl)
// - first -> init_tpl : converts the first element to the initial accumulator
// - acc, elem -> fold_tpl : fold step (acc = current accumulator, elem = next element)
// Empty iterator produces an empty iterator; non-empty produces a single-element iterator.
let mut inner = args_group.stream().into_iter().peekable();
let init_param = expect_ident(&mut inner, "expected initial lambda parameter")?;
expect_punct(&mut inner, '-', "expected `->` after init parameter")?;
expect_punct(&mut inner, '>', "expected `->` after init parameter")?;
let init_body = parse_direct_node(&mut inner, ctx)?;
expect_punct(&mut inner, ',', "expected `,` after init template")?;
let acc_param = expect_ident(&mut inner, "expected accumulator parameter")?;
expect_punct(&mut inner, ',', "expected `,` after accumulator parameter")?;
let elem_param = expect_ident(&mut inner, "expected element parameter")?;
expect_punct(&mut inner, '-', "expected `->` after element parameter")?;
expect_punct(&mut inner, '>', "expected `->` after element parameter")?;
let fold_body = parse_direct_node(&mut inner, ctx)?;
// The fold template must be the last thing inside the parentheses.
if let Some(tok) = inner.next() {
return Err(syn::Error::new_spanned(
tok,
"unexpected token after fold template",
));
}
// The generated block rebinds the running accumulator to the
// user's `acc` name before each evaluation of the fold template.
current = quote! {
{
let mut __iter = #current;
let __result: Option<usize> = if let Some(#init_param) = __iter.next() {
let mut __acc: usize = #init_body;
for #elem_param in __iter {
let #acc_param: usize = __acc;
__acc = #fold_body;
}
Some(__acc)
} else {
None
};
__result.into_iter()
}
};
}
_ => {
return Err(syn::Error::new_spanned(
method,
format!("unknown builtin method `.{method_str}()`"),
));
}
}
}
Ok(current)
}
/// Parse the top-level list of a `trees!` template.
/// Each item is a node template or `{expr}` splice.
fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream>> {
@@ -491,14 +605,34 @@ fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream
continue;
}
// `{expr}` — extend `__nodes` via `IntoFieldIds`, which handles
// single ids and iterables uniformly.
// {expr} or {..expr} (with optional .chain) — single node or splice
if peek_is_group(tokens, Delimiter::Brace) {
let group = expect_group(tokens, Delimiter::Brace)?;
let expr = group.stream();
items.push(quote! {
yeast::IntoFieldIds::extend_into({ #expr }, &mut __nodes);
});
let has_chain = matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '.');
let mut inner = group.stream().into_iter().peekable();
let is_splice = peek_is_dotdot(&inner);
if is_splice || has_chain {
let base: TokenStream = if is_splice {
inner.next(); // consume first .
inner.next(); // consume second .
let expr: TokenStream = inner.collect();
quote! {
(#expr).into_iter().map(::std::convert::Into::<usize>::into)
}
} else {
let expr = group.stream();
quote! { (#expr).into_iter() }
};
let chained = parse_chain_suffix(tokens, ctx, base)?;
items.push(quote! {
__nodes.extend(#chained);
});
} else {
let expr = group.stream();
items.push(quote! {
__nodes.push(::std::convert::Into::<usize>::into(#expr));
});
}
continue;
}
@@ -515,9 +649,6 @@ fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream
struct CaptureInfo {
name: String,
multiplicity: CaptureMultiplicity,
/// `true` for `@@name` captures: the auto-translate prefix skips them,
/// so the bound `Id` refers to the raw (input-schema) node.
raw: bool,
}
#[derive(Clone, Copy, PartialEq)]
@@ -576,14 +707,6 @@ fn extract_captures_inner(
extract_captures_inner(&mut inner, captures, child_mult);
}
TokenTree::Punct(p) if p.as_char() == '@' => {
// `@@name` marks the capture as raw (skip auto-translate).
let raw = matches!(
tokens.peek(),
Some(TokenTree::Punct(p)) if p.as_char() == '@'
);
if raw {
tokens.next(); // consume the second `@`
}
if let Some(TokenTree::Ident(name)) = tokens.next() {
let mult = if parent_mult == CaptureMultiplicity::Repeated
|| last_mult == CaptureMultiplicity::Repeated
@@ -599,7 +722,6 @@ fn extract_captures_inner(
captures.push(CaptureInfo {
name: name.to_string(),
multiplicity: mult,
raw,
});
}
last_mult = CaptureMultiplicity::Single;
@@ -653,14 +775,6 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
// Parse query
let query_code = parse_query_top(query_stream.clone())?;
// Capture names marked `@@name` (raw) — passed to the auto-translate
// prefix as a skip list so those captures keep their input-schema ids.
let raw_capture_names: Vec<&str> = captures
.iter()
.filter(|c| c.raw)
.map(|c| c.name.as_str())
.collect();
// Generate capture bindings
let ctx_ident = Ident::new(IMPLICIT_CTX, Span::call_site());
let bindings: Vec<TokenStream> = captures
@@ -671,17 +785,22 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
match cap.multiplicity {
CaptureMultiplicity::Repeated => {
quote! {
let #name: Vec<yeast::Id> = __captures.get_all(#name_str);
let #name: Vec<yeast::NodeRef> = __captures.get_all(#name_str)
.into_iter()
.map(yeast::NodeRef)
.collect();
}
}
CaptureMultiplicity::Optional => {
quote! {
let #name: Option<yeast::Id> = __captures.get_opt(#name_str);
let #name: Option<yeast::NodeRef> =
__captures.get_opt(#name_str).map(yeast::NodeRef);
}
}
CaptureMultiplicity::Single => {
quote! {
let #name: yeast::Id = __captures.get_var(#name_str).unwrap();
let #name: yeast::NodeRef =
yeast::NodeRef(__captures.get_var(#name_str).unwrap());
}
}
}
@@ -712,7 +831,7 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
__fields.insert(
__field_id,
#name.into_iter()
.map(::std::convert::Into::<yeast::Id>::into)
.map(::std::convert::Into::<usize>::into)
.collect(),
);
},
@@ -721,14 +840,14 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
.unwrap_or_else(|| panic!("field '{}' not found", #name_str));
if let Some(__id) = #name {
__fields.entry(__field_id).or_insert_with(Vec::new)
.push(::std::convert::Into::<yeast::Id>::into(__id));
.push(::std::convert::Into::<usize>::into(__id));
}
},
CaptureMultiplicity::Single => quote! {
let __field_id = #ctx_ident.ast.field_id_for_name(#name_str)
.unwrap_or_else(|| panic!("field '{}' not found", #name_str));
__fields.entry(__field_id).or_insert_with(Vec::new)
.push(::std::convert::Into::<yeast::Id>::into(#name));
.push(::std::convert::Into::<usize>::into(#name));
},
}
})
@@ -760,7 +879,7 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
}
quote! {
let mut __nodes: Vec<yeast::Id> = Vec::new();
let mut __nodes: Vec<usize> = Vec::new();
#(#transform_items)*
__nodes
}
@@ -769,20 +888,10 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
Ok(quote! {
{
let __query = #query_code;
yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, mut __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option<tree_sitter::Range>, __user_ctx: &mut _, __translator: yeast::TranslatorHandle<'_, _>| {
// Auto-translation prefix: recursively translate every
// captured node before invoking the user's transform body,
// except for `@@name` captures listed in `__skip` which the
// body consumes raw.
// For OneShot rules this preserves the legacy behaviour
// (input-schema captures translated to output-schema
// nodes); for Repeating rules it is a no-op.
let __skip: &[&str] = &[#(#raw_capture_names),*];
__translator.auto_translate_captures(&mut __captures, __ast, __user_ctx, __skip)?;
yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option<tree_sitter::Range>| {
#(#bindings)*
let mut #ctx_ident = yeast::build::BuildCtx::with_translator(__ast, &__captures, __fresh, __source_range, __user_ctx, __translator);
let __result: Vec<yeast::Id> = { #transform_body };
Ok(__result)
let mut #ctx_ident = yeast::build::BuildCtx::with_source_range(__ast, &__captures, __fresh, __source_range);
#transform_body
}))
}
})
@@ -796,16 +905,6 @@ fn peek_is_at(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '@')
}
/// Consumes a capture marker, either a single `@` or a doubled `@@`, together
/// with the identifier that follows it, and returns that identifier.
///
/// The caller must already have verified `peek_is_at(tokens)`, so the first
/// token is discarded without being inspected.
fn consume_capture_marker(tokens: &mut Tokens) -> Result<Ident> {
    // Drop the leading `@` the caller peeked at.
    let _ = tokens.next();
    // Another `@` directly behind it makes this the `@@` form; drop that too.
    let doubled = peek_is_at(tokens);
    if doubled {
        let _ = tokens.next();
    }
    expect_ident(tokens, "expected capture name after `@` or `@@`")
}
/// Reports whether the next token is a literal, without consuming it.
fn peek_is_literal(tokens: &mut Tokens) -> bool {
    tokens
        .peek()
        .is_some_and(|tok| matches!(tok, TokenTree::Literal(_)))
}
@@ -818,6 +917,13 @@ fn peek_is_hash(tokens: &mut Tokens) -> bool {
matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '#')
}
/// Check for `..` (two consecutive dot punctuation tokens).
fn peek_is_dotdot(tokens: &Tokens) -> bool {
let mut lookahead = tokens.clone();
matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
&& matches!(lookahead.next(), Some(TokenTree::Punct(p)) if p.as_char() == '.')
}
/// Reports whether the next token is the identifier `_`, without consuming it.
fn peek_is_underscore(tokens: &mut Tokens) -> bool {
    match tokens.peek() {
        Some(TokenTree::Ident(ident)) => *ident == "_",
        _ => false,
    }
}
@@ -899,7 +1005,8 @@ fn expect_repetition(tokens: &mut Tokens) -> Result<TokenStream> {
fn maybe_wrap_capture(tokens: &mut Tokens, base: TokenStream) -> Result<TokenStream> {
if peek_is_at(tokens) {
let name = consume_capture_marker(tokens)?;
tokens.next(); // consume @
let name = expect_ident(tokens, "expected capture name after @")?;
let name_str = name.to_string();
Ok(quote! {
yeast::query::QueryNode::Capture {
@@ -926,12 +1033,13 @@ fn maybe_wrap_repetition(tokens: &mut Tokens, single: TokenStream) -> Result<Tok
}
}
/// If `@name` (or `@@name`) follows a Repeated list element, wrap each
/// child SingleNode inside the repetition with a Capture. This matches
/// tree-sitter semantics where `(_)* @name` captures each matched node.
/// If `@name` follows a Repeated list element, wrap each child SingleNode
/// inside the repetition with a Capture. This matches tree-sitter semantics
/// where `(_)* @name` captures each matched node.
fn maybe_wrap_list_capture(tokens: &mut Tokens, elem: TokenStream) -> Result<TokenStream> {
if peek_is_at(tokens) {
let name = consume_capture_marker(tokens)?;
tokens.next();
let name = expect_ident(tokens, "expected capture name after @")?;
let name_str = name.to_string();
// Re-parse the element isn't practical, so we generate a wrapper
// that creates a new Repeated with each child wrapped in a capture.

View File

@@ -214,7 +214,7 @@ yeast::tree!(ctx,
```rust
yeast::trees!(ctx,
(assignment left: {tmp} right: {right})
{body}
{..body}
)
```
@@ -256,82 +256,27 @@ occurrences of the same `$name` within one `BuildCtx` share the same value:
### Embedded Rust expressions
`{expr}` embeds a Rust expression whose value is appended to the
enclosing field (or to the rule body's id list). Dispatch happens via
the [`IntoFieldIds`] trait, which is implemented for:
- `Id` — pushes the single id.
- Any `IntoIterator<Item: Into<Id>>` — extends with all yielded ids
(covers `Vec<Id>`, `Option<Id>`, iterator chains, etc.).
So the same `{expr}` syntax handles single ids, splices, and zero-or-many
options uniformly:
`{expr}` embeds a Rust expression that returns a single node `Id`:
```rust
(assignment
left: {some_node_id} // a single Id
right: {rhs} // a captured value (inside rule!)
left: {some_node_id} // insert a pre-built node
right: {rhs} // insert a captured value (inside rule!)
)
```
`{..expr}` splices a `Vec<Id>` (or any iterable of `Id`):
```rust
yeast::trees!(ctx,
(assignment left: {tmp} right: {right})
{extra_nodes} // splices a Vec<Id>
{..extra_nodes} // splice a Vec<Id>
)
```
The contents of `{…}` are treated as a Rust block, so multi-statement
expressions (with `let` bindings) work too:
```rust
(assignment
left: {tmp}
right: {
let lit = ctx.literal("integer", "0");
tree!((binary_expr op: (operator "+") left: {tmp} right: {lit}))
})
```
Inside `rule!`, captures are Rust variables — `{name}` works for
single, optional, and repeated captures alike:
```rust
rule!(
(assignment left: @lhs right: _* @parts)
=>
(assignment left: {lhs} right: (block stmt: {parts}))
)
```
### Raw captures (`@@name`)
The default `@name` capture marker is *auto-translated*: in OneShot
phases the macro recursively translates the captured node before
binding it, so `{name}` in the output template splices a node that
already conforms to the output schema.
For rules that need the raw (input-schema) capture — typically to read
its source text or to translate it explicitly with mutable context
state between calls — use `@@name` instead. The body sees the original
input-schema `Id`:
```rust
yeast::rule!(
(assignment left: (_) @@raw_lhs right: (_) @rhs)
=>
{
// raw_lhs is untranslated: read its original source text.
let text = ctx.ast.source_text(raw_lhs);
// rhs is already translated by the auto-translate prefix.
tree!((call
method: (identifier #{text.as_str()})
receiver: {rhs}))
}
);
```
Mix `@` and `@@` freely in the same rule. In a Repeating phase both
markers are equivalent (auto-translation is a no-op for repeating
rules).
Inside `rule!`, captures are Rust variables, so `{name}` inserts a
single capture (`Id`) and `{..name}` splices a repeated capture
(`Vec<Id>`).
## Complete example: for-loop desugaring

View File

@@ -20,7 +20,7 @@ fn main() {
let args = Cli::parse();
let language = get_language(&args.language);
let source = std::fs::read_to_string(&args.file).unwrap();
let runner: yeast::Runner = yeast::Runner::new(language, &[]);
let runner = yeast::Runner::new(language, &[]);
let ast = runner.run(&source).unwrap();
println!("{}", ast.print(&source, ast.get_root()));
}

View File

@@ -2,60 +2,28 @@ use std::collections::BTreeMap;
use crate::captures::Captures;
use crate::tree_builder::FreshScope;
use crate::{Ast, FieldId, Id, NodeContent, TranslatorHandle};
use crate::{Ast, FieldId, Id, NodeContent};
/// Context for building new AST nodes during a transformation.
///
/// Used by the `tree!` and `trees!` macros. Holds a mutable reference to the
/// AST, a reference to the captures from a query match, a `FreshScope` for
/// generating unique identifiers, and a mutable reference to a user-defined
/// context of type `C`.
///
/// The user context `C` is shared across rules via the framework's driver:
/// outer rules can write to it before recursive translation, and inner rules
/// can read (or further mutate) it during their transforms. The framework
/// snapshots and restores the user context around each rule application, so
/// mutations made by a rule are visible to its descendants (via recursive
/// translation) but not to its parent's siblings.
///
/// `BuildCtx` implements [`Deref`] and [`DerefMut`] targeting `C`, so user
/// context fields are accessible as `ctx.my_field` directly (provided they
/// don't collide with `BuildCtx`'s own fields like `ast`, `captures`, etc.).
///
/// The default `C = ()` means rules that don't need any user context don't
/// pay any cost.
///
/// When constructed by the framework (via the rule! macro), `BuildCtx` also
/// carries a [`TranslatorHandle`] that the [`translate`] method delegates
/// to. When constructed by hand (e.g. in tests), the translator is `None`
/// and [`translate`] returns an error.
pub struct BuildCtx<'a, C: 'a = ()> {
/// AST, a reference to the captures from a query match, and a `FreshScope` for
/// generating unique identifiers.
pub struct BuildCtx<'a> {
pub ast: &'a mut Ast,
pub captures: &'a Captures,
pub fresh: &'a FreshScope,
/// Source range of the matched node, inherited by synthetic nodes.
pub source_range: Option<tree_sitter::Range>,
/// User-supplied context, accessible directly via `ctx.field` (via Deref).
pub user_ctx: &'a mut C,
/// Optional translator handle, populated when the context is built by
/// the framework's rule driver. None when the context is built by hand.
pub(crate) translator: Option<TranslatorHandle<'a, C>>,
}
impl<'a, C> BuildCtx<'a, C> {
pub fn new(
ast: &'a mut Ast,
captures: &'a Captures,
fresh: &'a FreshScope,
user_ctx: &'a mut C,
) -> Self {
impl<'a> BuildCtx<'a> {
pub fn new(ast: &'a mut Ast, captures: &'a Captures, fresh: &'a FreshScope) -> Self {
Self {
ast,
captures,
fresh,
source_range: None,
user_ctx,
translator: None,
}
}
@@ -64,35 +32,12 @@ impl<'a, C> BuildCtx<'a, C> {
captures: &'a Captures,
fresh: &'a FreshScope,
source_range: Option<tree_sitter::Range>,
user_ctx: &'a mut C,
) -> Self {
Self {
ast,
captures,
fresh,
source_range,
user_ctx,
translator: None,
}
}
/// Construct a `BuildCtx` carrying a translator handle. Used by the
/// `rule!` macro to enable [`translate`] inside rule transforms.
pub fn with_translator(
ast: &'a mut Ast,
captures: &'a Captures,
fresh: &'a FreshScope,
source_range: Option<tree_sitter::Range>,
user_ctx: &'a mut C,
translator: TranslatorHandle<'a, C>,
) -> Self {
Self {
ast,
captures,
fresh,
source_range,
user_ctx,
translator: Some(translator),
}
}
@@ -158,36 +103,13 @@ impl<'a, C> BuildCtx<'a, C> {
self.ast
.create_named_token_with_range(kind, generated, self.source_range)
}
}
impl<C: Clone> BuildCtx<'_, C> {
/// Recursively translate a node via the framework's rule machinery.
/// In a OneShot phase, applies OneShot rules to the given node and
/// returns the resulting node ids. In a Repeating phase, errors
/// (translation is not meaningful when input and output share a
/// schema).
///
/// Errors if this `BuildCtx` was constructed by hand (without a
/// translator handle) — for example, in unit tests that don't go
/// through the rule driver.
pub fn translate<I: Into<Id>>(&mut self, id: I) -> Result<Vec<Id>, String> {
let id = id.into();
match &self.translator {
Some(t) => t.translate(self.ast, self.user_ctx, id),
None => Err("translate() called on a BuildCtx without a translator handle".into()),
}
}
}
impl<C> std::ops::Deref for BuildCtx<'_, C> {
type Target = C;
fn deref(&self) -> &C {
&*self.user_ctx
}
}
impl<C> std::ops::DerefMut for BuildCtx<'_, C> {
fn deref_mut(&mut self) -> &mut C {
&mut *self.user_ctx
/// Insert `value_id` at the front of the `field_name` field of `node_id`.
///
/// The field name is resolved to a field id through the AST, and the
/// insertion itself is delegated to [`Ast::prepend_field_child`].
///
/// # Panics
///
/// Panics with `build: field '<name>' not found` when the AST has no
/// field called `field_name`.
pub fn prepend_field(&mut self, node_id: Id, field_name: &str, value_id: Id) {
    let Some(field_id) = self.ast.field_id_for_name(field_name) else {
        panic!("build: field '{field_name}' not found");
    };
    self.ast.prepend_field_child(node_id, field_id, value_id);
}
}

View File

@@ -54,24 +54,24 @@ impl Captures {
self.captures.entry(key).or_default().push(id);
}
/// Apply a fallible function to every captured id, replacing each id
/// with the results. A function returning an empty vector removes
/// the capture; returning multiple ids splices them into the
/// capture's value list (suitable for `*`/`+` captures). Captures
/// whose name appears in `skip` are left untouched. Stops and
/// returns the error on the first failure.
///
/// Used by the `rule!` macro's auto-translate prefix to translate
/// every capture except those marked `@@name` (raw).
pub fn try_map_captures_except<E>(
/// Replace, in place, every id captured under `kind` with `f(id)`.
///
/// Ids are visited in their stored order. When there is no capture
/// named `kind`, nothing happens and `f` is never called.
pub fn map_captures(&mut self, kind: &str, f: &mut impl FnMut(Id) -> Id) {
    let Some(ids) = self.captures.get_mut(kind) else {
        return;
    };
    ids.iter_mut().for_each(|slot| *slot = f(*slot));
}
/// Apply a fallible function to every captured id (across all keys),
/// replacing each id with the results. A function returning an empty
/// vector removes the capture; returning multiple ids splices them
/// into the capture's value list (suitable for `*`/`+` captures).
/// Stops and returns the error on the first failure.
pub fn try_map_all_captures<E>(
&mut self,
skip: &[&str],
mut f: impl FnMut(Id) -> Result<Vec<Id>, E>,
) -> Result<(), E> {
for (name, ids) in self.captures.iter_mut() {
if skip.contains(name) {
continue;
}
for ids in self.captures.values_mut() {
let mut new_ids = Vec::with_capacity(ids.len());
for &id in ids.iter() {
new_ids.extend(f(id)?);
@@ -80,6 +80,12 @@ impl Captures {
}
Ok(())
}
/// Map every id captured under `from` through `f` and store the results
/// under the capture name `to`.
///
/// Any ids previously stored under `to` are overwritten. The `from`
/// capture itself is left as it was (unless `from` and `to` are the same
/// name). When there is no capture named `from`, nothing happens and `f`
/// is never called.
pub fn map_captures_to(&mut self, from: &str, to: &'static str, f: &mut impl FnMut(Id) -> Id) {
    let Some(sources) = self.captures.get(from) else {
        return;
    };
    let mapped: Vec<Id> = sources.iter().map(|&id| f(id)).collect();
    self.captures.insert(to, mapped);
}
pub fn merge(&mut self, other: &Captures) {
for (key, ids) in &other.captures {

View File

@@ -0,0 +1,8 @@
/// A cursor for walking a tree one node at a time.
///
/// Type parameters:
/// - `'a`: lifetime of the node references handed out by [`Cursor::node`].
/// - `T`: not used by any method signature here; presumably the tree type
///   the cursor walks (TODO confirm against the implementors).
/// - `N`: the node type returned by [`Cursor::node`].
/// - `F`: the field-id type returned by [`Cursor::field_id`].
pub trait Cursor<'a, T, N, F> {
/// Returns a reference to the node the cursor is currently on.
fn node(&self) -> &'a N;
/// Returns the id of the field the current node sits in, if any.
fn field_id(&self) -> Option<F>;
/// Returns the name of the field the current node sits in, if any.
fn field_name(&self) -> Option<&'static str>;
/// Tries to move to the first child of the current node.
/// Presumably returns `true` only if the cursor moved — confirm per implementor.
fn goto_first_child(&mut self) -> bool;
/// Tries to move to the next sibling of the current node.
/// Presumably returns `true` only if the cursor moved — confirm per implementor.
fn goto_next_sibling(&mut self) -> bool;
/// Tries to move to the parent of the current node.
/// Presumably returns `true` only if the cursor moved — confirm per implementor.
fn goto_parent(&mut self) -> bool;
}

View File

@@ -1,6 +1,6 @@
use std::fmt::Write;
use crate::{schema::Schema, Ast, Id, Node, NodeContent, CHILD_FIELD};
use crate::{schema::Schema, Ast, Node, NodeContent, CHILD_FIELD};
/// Options for controlling AST dump output.
pub struct DumpOptions {
@@ -34,11 +34,16 @@ impl Default for DumpOptions {
/// method:
/// identifier "foo"
/// ```
pub fn dump_ast(ast: &Ast, root: Id, source: &str) -> String {
pub fn dump_ast(ast: &Ast, root: usize, source: &str) -> String {
dump_ast_with_options(ast, root, source, &DumpOptions::default())
}
pub fn dump_ast_with_options(ast: &Ast, root: Id, source: &str, options: &DumpOptions) -> String {
pub fn dump_ast_with_options(
ast: &Ast,
root: usize,
source: &str,
options: &DumpOptions,
) -> String {
let mut out = String::new();
dump_node(ast, root, source, options, 0, None, &mut out);
out
@@ -48,7 +53,12 @@ pub fn dump_ast_with_options(ast: &Ast, root: Id, source: &str, options: &DumpOp
///
/// Any node that does not match the expected type set for its parent field is
/// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line.
pub fn dump_ast_with_type_errors(ast: &Ast, root: Id, source: &str, schema: &Schema) -> String {
pub fn dump_ast_with_type_errors(
ast: &Ast,
root: usize,
source: &str,
schema: &Schema,
) -> String {
dump_ast_with_type_errors_and_options(ast, root, source, schema, &DumpOptions::default())
}
@@ -58,21 +68,13 @@ pub fn dump_ast_with_type_errors(ast: &Ast, root: Id, source: &str, schema: &Sch
/// rendered with a trailing `" <-- ERROR: ..."` annotation on the same line.
pub fn dump_ast_with_type_errors_and_options(
ast: &Ast,
root: Id,
root: usize,
source: &str,
schema: &Schema,
options: &DumpOptions,
) -> String {
let mut out = String::new();
dump_node(
ast,
root,
source,
options,
0,
Some((schema, None, None)),
&mut out,
);
dump_node(ast, root, source, options, 0, Some((schema, None, None)), &mut out);
out
}
@@ -171,7 +173,7 @@ fn expected_for_field<'a>(
fn dump_node(
ast: &Ast,
id: Id,
id: usize,
source: &str,
options: &DumpOptions,
indent: usize,
@@ -230,8 +232,8 @@ fn dump_node(
}
let field_name = ast.field_name_for_id(field_id).unwrap_or("?");
let child_type_check = type_check.map(|(schema, _, _)| {
let expected =
expected_for_field(schema, node.kind_name(), field_id).or(Some(EMPTY_NODE_TYPES));
let expected = expected_for_field(schema, node.kind_name(), field_id)
.or(Some(EMPTY_NODE_TYPES));
let parent_field = Some((node.kind_name(), field_name));
(schema, expected, parent_field)
});
@@ -310,7 +312,7 @@ fn dump_node(
/// Dump a leaf node inline (no newline prefix, caller provides context).
fn dump_node_inline(
ast: &Ast,
id: Id,
id: usize,
source: &str,
options: &DumpOptions,
type_check: Option<(

View File

@@ -7,6 +7,7 @@ use serde_json::{json, Value};
pub mod build;
pub mod captures;
pub mod cursor;
pub mod dump;
pub mod node_types_yaml;
pub mod query;
@@ -18,61 +19,32 @@ mod visitor;
pub use yeast_macros::{query, rule, tree, trees};
use captures::Captures;
pub use cursor::Cursor;
use query::QueryNode;
/// Node id: an index into the [`Ast`] arena. A newtype around `usize`
/// rather than a bare alias so that it can carry its own
/// [`YeastDisplay`] / [`YeastSourceRange`] / [`IntoFieldIds`] impls
/// without colliding with the impls for plain integers.
///
/// Use `id.0` (or `id.into()`) to obtain the raw arena index.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug, Hash, Serialize)]
pub struct Id(pub usize);
impl From<usize> for Id {
fn from(value: usize) -> Self {
Id(value)
}
}
impl From<Id> for usize {
fn from(value: Id) -> Self {
value.0
}
}
/// Node ids are indexes into the arena
pub type Id = usize;
/// Field and Kind ids are provided by tree-sitter
type FieldId = u16;
type KindId = u16;
/// Trait for values that can be appended to a field's id list inside a
/// `tree!`/`trees!`/`rule!` template (in `{expr}` placeholders).
///
/// `Id` pushes a single id; the blanket impl for
/// `IntoIterator<Item: Into<Id>>` handles `Vec<Id>`, `Option<Id>`,
/// arbitrary iterators yielding `Id`, etc.
///
/// This lets `{expr}` interpolate any of these shapes without a
/// dedicated splice syntax — the macro emits the same trait-dispatched
/// call regardless of the value's type.
pub trait IntoFieldIds {
fn extend_into(self, out: &mut Vec<Id>);
}
/// A typed reference to a node in an [`Ast`] arena. Wraps an [`Id`] but
/// deliberately does not implement [`std::fmt::Display`]: rendering a node
/// requires the [`Ast`] it lives in (to resolve [`NodeContent::Range`] back
/// to source text). Use [`YeastDisplay::yeast_to_string`] to format it.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
pub struct NodeRef(pub Id);
impl IntoFieldIds for Id {
fn extend_into(self, out: &mut Vec<Id>) {
out.push(self);
impl NodeRef {
pub fn id(self) -> Id {
self.0
}
}
impl<I, T> IntoFieldIds for I
where
I: IntoIterator<Item = T>,
T: Into<Id>,
{
fn extend_into(self, out: &mut Vec<Id>) {
out.extend(self.into_iter().map(Into::into));
impl From<NodeRef> for Id {
fn from(value: NodeRef) -> Self {
value.0
}
}
@@ -89,21 +61,21 @@ pub trait YeastDisplay {
/// Optional source range for values used in `#{expr}` interpolations.
///
/// By default this returns `None`, so synthesized leaves inherit the matched
/// rule's source range. `Id` returns the referenced node's range, letting
/// rule's source range. `NodeRef` returns the referenced node's range, letting
/// `(kind #{capture})` carry the captured node's location.
pub trait YeastSourceRange {
fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range>;
}
impl YeastDisplay for Id {
impl YeastDisplay for NodeRef {
fn yeast_to_string(&self, ast: &Ast) -> String {
ast.source_text(*self)
ast.source_text(self.0)
}
}
impl YeastSourceRange for Id {
impl YeastSourceRange for NodeRef {
fn yeast_source_range(&self, ast: &Ast) -> Option<tree_sitter::Range> {
ast.get_node(*self).and_then(|n| match &n.content {
ast.get_node(self.0).and_then(|n| match &n.content {
NodeContent::Range(r) => Some(r.clone()),
_ => n.source_range,
})
@@ -172,36 +144,6 @@ impl<'a> AstCursor<'a> {
self.node_id
}
pub fn node(&self) -> &'a Node {
&self.ast.nodes[self.node_id.0]
}
pub fn field_id(&self) -> Option<FieldId> {
let (_, children) = self.parents.last()?;
children.current_field()
}
pub fn field_name(&self) -> Option<&'static str> {
if self.field_id() == Some(CHILD_FIELD) {
None
} else {
self.field_id()
.and_then(|id| self.ast.field_name_for_id(id))
}
}
pub fn goto_first_child(&mut self) -> bool {
self.goto_first_child_opt().is_some()
}
pub fn goto_next_sibling(&mut self) -> bool {
self.goto_next_sibling_opt().is_some()
}
pub fn goto_parent(&mut self) -> bool {
self.goto_parent_opt().is_some()
}
fn goto_next_sibling_opt(&mut self) -> Option<()> {
self.node_id = self.parents.last_mut()?.1.next()?;
Some(())
@@ -222,6 +164,37 @@ impl<'a> AstCursor<'a> {
Some(())
}
}
/// [`Cursor`] implementation for walking an [`Ast`]: nodes are [`Node`]s
/// and fields are identified by [`FieldId`].
impl<'a> Cursor<'a, Ast, Node, FieldId> for AstCursor<'a> {
    /// The node the cursor currently points at, looked up in the arena.
    fn node(&self) -> &'a Node {
        &self.ast.nodes[self.node_id]
    }

    /// The field of the parent that the current node was reached through,
    /// or `None` when the cursor has no parent entry.
    fn field_id(&self) -> Option<FieldId> {
        self.parents
            .last()
            .and_then(|(_, children)| children.current_field())
    }

    /// The name of the current field. Yields `None` when there is no
    /// current field or when the current field is `CHILD_FIELD`.
    fn field_name(&self) -> Option<&'static str> {
        self.field_id()
            .filter(|&field| field != CHILD_FIELD)
            .and_then(|field| self.ast.field_name_for_id(field))
    }

    /// Move to the first child; `true` if the move succeeded.
    fn goto_first_child(&mut self) -> bool {
        self.goto_first_child_opt().is_some()
    }

    /// Move to the next sibling; `true` if the move succeeded.
    fn goto_next_sibling(&mut self) -> bool {
        self.goto_next_sibling_opt().is_some()
    }

    /// Move to the parent; `true` if the move succeeded.
    fn goto_parent(&mut self) -> bool {
        self.goto_parent_opt().is_some()
    }
}
/// An iterator over the child Ids of a node.
#[derive(Debug)]
@@ -324,9 +297,7 @@ impl Ast {
/// Returns the source text for `id`, resolving `NodeContent::Range`
/// against the stored source bytes when available.
pub fn source_text(&self, id: Id) -> String {
let Some(node) = self.get_node(id) else {
return String::new();
};
let Some(node) = self.get_node(id) else { return String::new(); };
let read_range = |range: &tree_sitter::Range| {
let start = range.start_byte;
let end = range.end_byte;
@@ -368,16 +339,16 @@ impl Ast {
///
/// This reflects the effective AST after desugaring and excludes orphaned
/// arena nodes left behind by rewrite operations.
pub fn reachable_node_ids(&self) -> Vec<Id> {
pub fn reachable_node_ids(&self) -> Vec<usize> {
let mut reachable = Vec::new();
let mut stack = vec![self.root];
let mut seen = vec![false; self.nodes.len()];
while let Some(id) = stack.pop() {
if id.0 >= self.nodes.len() || seen[id.0] {
if id >= self.nodes.len() || seen[id] {
continue;
}
seen[id.0] = true;
seen[id] = true;
reachable.push(id);
if let Some(node) = self.get_node(id) {
@@ -401,11 +372,11 @@ impl Ast {
}
pub fn get_node(&self, id: Id) -> Option<&Node> {
self.nodes.get(id.0)
self.nodes.get(id)
}
pub fn print(&self, source: &str, root_id: Id) -> Value {
let root = &self.nodes()[root_id.0];
let root = &self.nodes()[root_id];
self.print_node(root, source)
}
@@ -448,7 +419,7 @@ impl Ast {
is_named,
source_range,
});
Id(id)
id
}
fn union_source_range_of_children(
@@ -515,6 +486,12 @@ impl Ast {
self.create_named_token_with_range(kind, content, None)
}
/// Insert `value_id` as the first child in field `field_id` of node
/// `node_id`. The field's child list is created if the node does not
/// have one yet.
///
/// # Panics
///
/// Panics with `prepend_field_child: invalid node id` when `node_id` is
/// not a valid index into the node arena.
pub fn prepend_field_child(&mut self, node_id: Id, field_id: FieldId, value_id: Id) {
    let children = self
        .nodes
        .get_mut(node_id)
        .expect("prepend_field_child: invalid node id")
        .fields
        .entry(field_id)
        .or_default();
    children.insert(0, value_id);
}
pub fn create_named_token_with_range(
&mut self,
kind: &'static str,
@@ -536,7 +513,7 @@ impl Ast {
fields: BTreeMap::new(),
content: NodeContent::DynamicString(content),
});
Id(id)
id
}
pub fn field_name_for_id(&self, id: FieldId) -> Option<&'static str> {
@@ -620,6 +597,10 @@ pub struct Node {
}
impl Node {
/// The node's kind name. Returns the same value as [`Node::kind_name`].
pub fn kind(&self) -> &'static str {
    self.kind_name()
}
pub fn kind_name(&self) -> &'static str {
self.kind_name
}
@@ -719,120 +700,18 @@ impl From<tree_sitter::Range> for NodeContent {
}
}
/// A handle that lets a rule transform recursively translate AST nodes via
/// the framework's rule machinery. Constructed by the driver and passed as
/// the last argument of every [`Transform`] invocation.
///
/// The `rule!` macro uses [`TranslatorHandle::auto_translate_captures`] in
/// its generated prefix to translate captures before running the user's
/// transform body. Manually-written transforms (using [`Rule::new`]
/// directly) can call [`TranslatorHandle::translate`] selectively on
/// specific node ids to control when translation happens.
pub struct TranslatorHandle<'a, C> {
inner: TranslatorImpl<'a, C>,
}
/// Internal phase-specific translation state. Kept private — callers
/// interact with [`TranslatorHandle`] only.
enum TranslatorImpl<'a, C> {
/// OneShot phase translator: recursively applies OneShot rules.
OneShot {
index: &'a RuleIndex<'a, C>,
fresh: &'a tree_builder::FreshScope,
rewrite_depth: usize,
/// The id of the node the current rule is matching. Used by
/// [`auto_translate_captures`] to avoid infinite recursion when a
/// rule captures its own match root (e.g. via `(_) @_`).
matched_root: Id,
},
/// Repeating phase translator: translation is not meaningful here
/// (input and output schemas are the same). [`translate`] errors;
/// [`auto_translate_captures`] is a no-op so the macro's auto-prefix
/// works unchanged for Repeating rules.
Repeating,
}
impl<'a, C: Clone> TranslatorHandle<'a, C> {
/// Recursively apply OneShot rules to `id` and return the resulting
/// node ids. Errors in a Repeating phase (where translation is not
/// meaningful).
pub fn translate(&self, ast: &mut Ast, user_ctx: &mut C, id: Id) -> Result<Vec<Id>, String> {
match &self.inner {
TranslatorImpl::OneShot {
index,
fresh,
rewrite_depth,
..
} => apply_one_shot_rules_inner(index, ast, user_ctx, id, fresh, rewrite_depth + 1),
TranslatorImpl::Repeating => {
Err("translate() is not available in a Repeating phase".into())
}
}
}
/// Translate every captured node in `captures` in place (OneShot phase
/// only), except for captures whose name appears in `skip` — those are
/// left as raw (input-schema) ids for the rule body to consume
/// directly. In a Repeating phase this is a no-op — Repeating rules
/// receive raw captures regardless of `skip`.
///
/// Used by the `rule!` macro's generated prefix. `skip` is populated
/// from the macro's `@@name` capture markers; for plain `@name`
/// captures (and rules with no `@@` markers) it is empty.
///
/// To avoid infinite recursion, a capture whose id matches the rule's
/// matched root (e.g. from a `(_) @_` pattern) is left unchanged.
pub fn auto_translate_captures(
&self,
captures: &mut Captures,
ast: &mut Ast,
user_ctx: &mut C,
skip: &[&str],
) -> Result<(), String> {
match &self.inner {
TranslatorImpl::OneShot { matched_root, .. } => {
let root = *matched_root;
captures.try_map_captures_except(skip, |cid| {
if cid == root {
Ok(vec![cid])
} else {
self.translate(ast, user_ctx, cid)
}
})
}
TranslatorImpl::Repeating => Ok(()),
}
}
}
/// The transform function for a rule.
///
/// Takes the AST, the (raw, untranslated) captured variables, a fresh-name
/// scope, the source range of the matched node, a mutable reference to the
/// user context of type `C`, and a [`TranslatorHandle`] for recursively
/// translating nodes. Returns the IDs of the replacement nodes, or an
/// error message if the transform could not be completed.
///
/// Transforms produced by [`Rule::new`] receive **raw** captures and must
/// translate them themselves (via the handle). Transforms produced by the
/// `rule!` macro have an auto-translation prefix injected for backward
/// compatibility.
pub type Transform<C = ()> = Box<
dyn Fn(
&mut Ast,
Captures,
&tree_builder::FreshScope,
Option<tree_sitter::Range>,
&mut C,
TranslatorHandle<'_, C>,
) -> Result<Vec<Id>, String>
/// The transform function for a rule: takes the AST, captured variables, a
/// fresh-name scope, and the source range of the matched node, and returns
/// the IDs of the replacement nodes.
pub type Transform = Box<
dyn Fn(&mut Ast, Captures, &tree_builder::FreshScope, Option<tree_sitter::Range>) -> Vec<Id>
+ Send
+ Sync,
>;
pub struct Rule<C = ()> {
pub struct Rule {
query: QueryNode,
transform: Transform<C>,
transform: Transform,
/// If true, after this rule fires on a node the engine will try to
/// re-apply this same rule on the result root. Defaults to false:
/// each rule fires at most once on a given node, which prevents
@@ -840,8 +719,8 @@ pub struct Rule<C = ()> {
repeated: bool,
}
impl<C> Rule<C> {
pub fn new(query: QueryNode, transform: Transform<C>) -> Self {
impl Rule {
pub fn new(query: QueryNode, transform: Transform) -> Self {
Self {
query,
transform,
@@ -863,13 +742,9 @@ impl<C> Rule<C> {
ast: &mut Ast,
node: Id,
fresh: &tree_builder::FreshScope,
user_ctx: &mut C,
translator: TranslatorHandle<'_, C>,
) -> Result<Option<Vec<Id>>, String> {
match self.try_match(ast, node)? {
Some(captures) => Ok(Some(
self.run_transform(ast, captures, node, fresh, user_ctx, translator)?,
)),
Some(captures) => Ok(Some(self.run_transform(ast, captures, node, fresh))),
None => Ok(None),
}
}
@@ -893,31 +768,29 @@ impl<C> Rule<C> {
captures: Captures,
node: Id,
fresh: &tree_builder::FreshScope,
user_ctx: &mut C,
translator: TranslatorHandle<'_, C>,
) -> Result<Vec<Id>, String> {
) -> Vec<Id> {
fresh.next_scope();
let source_range = ast.get_node(node).and_then(|n| match n.content {
NodeContent::Range(r) => Some(r),
_ => n.source_range,
});
(self.transform)(ast, captures, fresh, source_range, user_ctx, translator)
(self.transform)(ast, captures, fresh, source_range)
}
}
const MAX_REWRITE_DEPTH: usize = 100;
/// Index of rules by their root query kind for fast lookup.
struct RuleIndex<'a, C> {
struct RuleIndex<'a> {
/// Rules indexed by root node kind name.
by_kind: BTreeMap<&'static str, Vec<&'a Rule<C>>>,
by_kind: BTreeMap<&'static str, Vec<&'a Rule>>,
/// Rules with wildcard queries (Any) that apply to all nodes.
wildcard: Vec<&'a Rule<C>>,
wildcard: Vec<&'a Rule>,
}
impl<'a, C> RuleIndex<'a, C> {
fn new(rules: &'a [Rule<C>]) -> Self {
let mut by_kind: BTreeMap<&'static str, Vec<&'a Rule<C>>> = BTreeMap::new();
impl<'a> RuleIndex<'a> {
fn new(rules: &'a [Rule]) -> Self {
let mut by_kind: BTreeMap<&'static str, Vec<&'a Rule>> = BTreeMap::new();
let mut wildcard = Vec::new();
for rule in rules {
match rule.query.root_kind() {
@@ -928,7 +801,7 @@ impl<'a, C> RuleIndex<'a, C> {
Self { by_kind, wildcard }
}
fn rules_for_kind(&self, kind: &str) -> impl Iterator<Item = &&'a Rule<C>> {
fn rules_for_kind(&self, kind: &str) -> impl Iterator<Item = &&'a Rule> {
self.by_kind
.get(kind)
.into_iter()
@@ -937,25 +810,23 @@ impl<'a, C> RuleIndex<'a, C> {
}
}
fn apply_repeating_rules<C: Clone>(
rules: &[Rule<C>],
fn apply_repeating_rules(
rules: &[Rule],
ast: &mut Ast,
user_ctx: &mut C,
id: Id,
fresh: &tree_builder::FreshScope,
) -> Result<Vec<Id>, String> {
let index = RuleIndex::new(rules);
apply_repeating_rules_inner(&index, ast, user_ctx, id, fresh, 0, None)
apply_repeating_rules_inner(&index, ast, id, fresh, 0, None)
}
fn apply_repeating_rules_inner<C: Clone>(
index: &RuleIndex<C>,
fn apply_repeating_rules_inner(
index: &RuleIndex,
ast: &mut Ast,
user_ctx: &mut C,
id: Id,
fresh: &tree_builder::FreshScope,
rewrite_depth: usize,
skip_rule: Option<*const Rule<C>>,
skip_rule: Option<*const Rule>,
) -> Result<Vec<Id>, String> {
if rewrite_depth > MAX_REWRITE_DEPTH {
return Err(format!(
@@ -964,25 +835,13 @@ fn apply_repeating_rules_inner<C: Clone>(
));
}
let node_kind = ast.get_node(id).map(|n| n.kind_name()).unwrap_or("");
let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or("");
for rule in index.rules_for_kind(node_kind) {
let rule_ptr = *rule as *const Rule<C>;
let rule_ptr = *rule as *const Rule;
if Some(rule_ptr) == skip_rule {
continue;
}
// Snapshot the user context before invoking the rule so that any
// mutations the rule makes are visible during recursive translation
// of its result, but not leaked to the parent's siblings.
let snapshot = user_ctx.clone();
// Repeating rules don't need a real translator: their captures
// aren't auto-translated (Repeating preserves the input schema),
// and `ctx.translate(id)` errors if invoked from a Repeating
// transform.
let translator = TranslatorHandle {
inner: TranslatorImpl::Repeating,
};
let try_result = rule.try_rule(ast, id, fresh, user_ctx, translator)?;
if let Some(result_node) = try_result {
if let Some(result_node) = rule.try_rule(ast, id, fresh)? {
// For non-repeated rules, suppress further application of *this*
// rule on the result root, so a rule whose output matches its own
// query doesn't loop. Other rules and child traversal are
@@ -993,19 +852,14 @@ fn apply_repeating_rules_inner<C: Clone>(
results.extend(apply_repeating_rules_inner(
index,
ast,
user_ctx,
node,
fresh,
rewrite_depth + 1,
next_skip,
)?);
}
*user_ctx = snapshot;
return Ok(results);
}
// Rule didn't match; restore any speculative changes (none expected
// since try_rule only mutates on match, but be defensive).
*user_ctx = snapshot;
}
// Take the parent's fields by ownership: the recursion will rewrite
@@ -1016,19 +870,11 @@ fn apply_repeating_rules_inner<C: Clone>(
//
// Child traversal does not increment rewrite depth and starts fresh
// (no rule is skipped on child subtrees).
let mut fields = std::mem::take(&mut ast.nodes[id.0].fields);
let mut fields = std::mem::take(&mut ast.nodes[id].fields);
for children in fields.values_mut() {
let mut new_children: Option<Vec<Id>> = None;
for (i, &child_id) in children.iter().enumerate() {
let result = apply_repeating_rules_inner(
index,
ast,
user_ctx,
child_id,
fresh,
rewrite_depth,
None,
)?;
let result = apply_repeating_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?;
let unchanged = result.len() == 1 && result[0] == child_id;
match (&mut new_children, unchanged) {
(None, true) => {} // unchanged so far, no allocation needed
@@ -1049,7 +895,7 @@ fn apply_repeating_rules_inner<C: Clone>(
*children = new;
}
}
ast.nodes[id.0].fields = fields;
ast.nodes[id].fields = fields;
Ok(vec![id])
}
@@ -1057,25 +903,24 @@ fn apply_repeating_rules_inner<C: Clone>(
/// each visited node, recursion proceeds only through captured nodes (not
/// through the input node's children directly), and an error is returned if
/// no rule matches a visited node.
fn apply_one_shot_rules<C: Clone>(
rules: &[Rule<C>],
fn apply_one_shot_rules(
rules: &[Rule],
ast: &mut Ast,
user_ctx: &mut C,
id: Id,
fresh: &tree_builder::FreshScope,
) -> Result<Vec<Id>, String> {
let index = RuleIndex::new(rules);
apply_one_shot_rules_inner(&index, ast, user_ctx, id, fresh, 0)
apply_one_shot_rules_inner(&index, ast, id, fresh, 0)
}
fn apply_one_shot_rules_inner<C: Clone>(
index: &RuleIndex<C>,
fn apply_one_shot_rules_inner(
index: &RuleIndex,
ast: &mut Ast,
user_ctx: &mut C,
id: Id,
fresh: &tree_builder::FreshScope,
rewrite_depth: usize,
) -> Result<Vec<Id>, String> {
if rewrite_depth > MAX_REWRITE_DEPTH {
return Err(format!(
"Desugaring exceeded maximum rewrite depth ({MAX_REWRITE_DEPTH}). \
@@ -1083,30 +928,25 @@ fn apply_one_shot_rules_inner<C: Clone>(
));
}
let node_kind = ast.get_node(id).map(|n| n.kind_name()).unwrap_or("");
let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or("");
for rule in index.rules_for_kind(node_kind) {
if let Some(captures) = rule.try_match(ast, id)? {
// Snapshot the user context before invoking the rule so that any
// mutations the rule (or its transitively-translated captures)
// make are visible during this rule's transform, but not leaked
// to the parent's siblings.
let snapshot = user_ctx.clone();
// Build the translator handle the transform will use to
// recursively translate captures (or, for macro-generated
// rules, the auto-translate prefix uses it to translate every
// capture up front, preserving the legacy behavior).
let translator = TranslatorHandle {
inner: TranslatorImpl::OneShot {
index,
fresh,
rewrite_depth,
matched_root: id,
},
};
let result = rule.run_transform(ast, captures, id, fresh, user_ctx, translator)?;
*user_ctx = snapshot;
return Ok(result);
if let Some(mut captures) = rule.try_match(ast, id)? {
// Recursively translate every captured node before invoking the
// transform. The transform's output uses output-schema kinds, so
// we must translate captured input-schema nodes to their
// output-schema equivalents first.
captures.try_map_all_captures(|captured_id| {
// Avoid infinite recursion when a capture refers to the root
// node of the matched tree (e.g. an `@_` capture on the
// pattern root): re-analyzing it would match the same rule
// again indefinitely.
if captured_id == id {
return Ok(vec![captured_id]);
}
apply_one_shot_rules_inner(index, ast, captured_id, fresh, rewrite_depth + 1)
})?;
return Ok(rule.run_transform(ast, captures, id, fresh));
}
}
@@ -1134,15 +974,15 @@ pub enum PhaseKind {
/// starts. Rules within a phase compete for matches as usual; rules in
/// different phases never compete because each traversal only considers the
/// current phase's rules.
pub struct Phase<C = ()> {
pub struct Phase {
/// Name used in error messages.
pub name: String,
pub rules: Vec<Rule<C>>,
pub rules: Vec<Rule>,
pub kind: PhaseKind,
}
impl<C> Phase<C> {
pub fn new(name: impl Into<String>, kind: PhaseKind, rules: Vec<Rule<C>>) -> Self {
impl Phase {
pub fn new(name: impl Into<String>, kind: PhaseKind, rules: Vec<Rule>) -> Self {
Self {
name: name.into(),
rules,
@@ -1168,30 +1008,17 @@ impl<C> Phase<C> {
/// .add_phase("desugar", PhaseKind::Repeating, desugar_rules)
/// .with_output_node_types_yaml(yaml);
/// ```
///
/// The optional type parameter `C` is the user context type threaded through
/// rule transforms. Defaults to `()` (no user context).
pub struct DesugaringConfig<C = ()> {
#[derive(Default)]
pub struct DesugaringConfig {
/// Phases of rule application, applied in order.
pub phases: Vec<Phase<C>>,
pub phases: Vec<Phase>,
/// Output node-types in YAML format. If `None`, the input grammar's
/// node types are used (i.e. the desugared AST has the same node types
/// as the tree-sitter grammar).
pub output_node_types_yaml: Option<&'static str>,
}
// Manual `Default` impl so users with a custom `C` that doesn't implement
// `Default` can still construct an empty config.
impl<C> Default for DesugaringConfig<C> {
fn default() -> Self {
Self {
phases: Vec::new(),
output_node_types_yaml: None,
}
}
}
impl<C> DesugaringConfig<C> {
impl DesugaringConfig {
/// Create an empty configuration. Add phases via [`add_phase`] and an
/// optional output schema via [`with_output_node_types_yaml`].
pub fn new() -> Self {
@@ -1203,7 +1030,7 @@ impl<C> DesugaringConfig<C> {
mut self,
name: impl Into<String>,
kind: PhaseKind,
rules: Vec<Rule<C>>,
rules: Vec<Rule>,
) -> Self {
self.phases.push(Phase::new(name, kind, rules));
self
@@ -1225,15 +1052,15 @@ impl<C> DesugaringConfig<C> {
}
}
pub struct Runner<'a, C = ()> {
pub struct Runner<'a> {
language: tree_sitter::Language,
schema: schema::Schema,
phases: &'a [Phase<C>],
phases: &'a [Phase],
}
impl<'a, C> Runner<'a, C> {
impl<'a> Runner<'a> {
/// Create a runner using the input grammar's schema for output.
pub fn new(language: tree_sitter::Language, phases: &'a [Phase<C>]) -> Self {
pub fn new(language: tree_sitter::Language, phases: &'a [Phase]) -> Self {
let schema = schema::Schema::from_language(&language);
Self {
language,
@@ -1246,7 +1073,7 @@ impl<'a, C> Runner<'a, C> {
pub fn with_schema(
language: tree_sitter::Language,
schema: &schema::Schema,
phases: &'a [Phase<C>],
phases: &'a [Phase],
) -> Self {
Self {
language,
@@ -1258,7 +1085,7 @@ impl<'a, C> Runner<'a, C> {
/// Create a runner from a [`DesugaringConfig`].
pub fn from_config(
language: tree_sitter::Language,
config: &'a DesugaringConfig<C>,
config: &'a DesugaringConfig,
) -> Result<Self, String> {
let schema = config.build_schema(&language)?;
Ok(Self {
@@ -1267,17 +1094,11 @@ impl<'a, C> Runner<'a, C> {
phases: &config.phases,
})
}
}
impl<'a, C: Clone> Runner<'a, C> {
/// Parse `tree` against `source` and run all phases, threading
/// `user_ctx` through every rule transform. The caller owns the
/// initial context state.
pub fn run_from_tree_with_ctx(
pub fn run_from_tree(
&self,
tree: &tree_sitter::Tree,
source: &[u8],
user_ctx: &mut C,
) -> Result<Ast, String> {
let mut ast = Ast::from_tree_with_schema_and_source(
self.schema.clone(),
@@ -1285,13 +1106,11 @@ impl<'a, C: Clone> Runner<'a, C> {
&self.language,
source.to_vec(),
);
self.run_phases(&mut ast, user_ctx)?;
self.run_phases(&mut ast)?;
Ok(ast)
}
/// Parse `input` and run all phases, threading `user_ctx` through
/// every rule transform. The caller owns the initial context state.
pub fn run_with_ctx(&self, input: &str, user_ctx: &mut C) -> Result<Ast, String> {
pub fn run(&self, input: &str) -> Result<Ast, String> {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&self.language)
@@ -1305,24 +1124,20 @@ impl<'a, C: Clone> Runner<'a, C> {
&self.language,
input.as_bytes().to_vec(),
);
self.run_phases(&mut ast, user_ctx)?;
self.run_phases(&mut ast)?;
Ok(ast)
}
/// Apply each phase in turn to the AST, threading the root through.
/// A single `FreshScope` is shared across phases so that fresh
/// identifiers generated in different phases don't collide.
fn run_phases(&self, ast: &mut Ast, user_ctx: &mut C) -> Result<(), String> {
fn run_phases(&self, ast: &mut Ast) -> Result<(), String> {
let fresh = tree_builder::FreshScope::new();
let mut root = ast.get_root();
for phase in self.phases {
let res = match phase.kind {
PhaseKind::Repeating => {
apply_repeating_rules(&phase.rules, ast, user_ctx, root, &fresh)
}
PhaseKind::OneShot => {
apply_one_shot_rules(&phase.rules, ast, user_ctx, root, &fresh)
}
PhaseKind::Repeating => apply_repeating_rules(&phase.rules, ast, root, &fresh),
PhaseKind::OneShot => apply_one_shot_rules(&phase.rules, ast, root, &fresh),
}
.map_err(|e| format!("Phase `{}`: {e}", phase.name))?;
if res.len() != 1 {
@@ -1338,78 +1153,3 @@ impl<'a, C: Clone> Runner<'a, C> {
Ok(())
}
}
impl<'a, C: Clone + Default> Runner<'a, C> {
    /// Convenience form of `run_from_tree_with_ctx`: runs all phases over
    /// `tree`/`source`, starting from a freshly constructed `C::default()`
    /// context instead of a caller-supplied one.
    pub fn run_from_tree(&self, tree: &tree_sitter::Tree, source: &[u8]) -> Result<Ast, String> {
        self.run_from_tree_with_ctx(tree, source, &mut C::default())
    }

    /// Convenience form of `run_with_ctx`: parses `input` and runs all
    /// phases, starting from a freshly constructed `C::default()` context.
    pub fn run(&self, input: &str) -> Result<Ast, String> {
        self.run_with_ctx(input, &mut C::default())
    }
}
// ---------------------------------------------------------------------------
// Desugarer: type-erased view of a DesugaringConfig + Runner
// ---------------------------------------------------------------------------
/// Type-erased interface to a desugaring pipeline for a single language.
///
/// Consumers (e.g. a generic tree-sitter extractor) hold
/// `Box<dyn Desugarer>` so they can dispatch through the trait without
/// knowing the user context type `C` that's internal to yeast.
///
/// The `Send + Sync` supertraits mean every implementor can be shared
/// between threads.
///
/// Construct one via [`ConcreteDesugarer::new`] from a
/// [`DesugaringConfig<C>`] and a [`tree_sitter::Language`].
pub trait Desugarer: Send + Sync {
/// The output AST schema (in YAML format), or `None` if the input
/// grammar's schema should be used.
fn output_node_types_yaml(&self) -> Option<&'static str>;
/// Parse `tree` against `source` and run the desugaring pipeline.
/// Each call constructs a fresh default user context internally.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when the pipeline fails.
fn run_from_tree(&self, tree: &tree_sitter::Tree, source: &[u8]) -> Result<Ast, String>;
}
/// A concrete [`Desugarer`] backed by a [`DesugaringConfig<C>`] for a
/// specific user context type `C`. Stores the language and a pre-built
/// schema so that per-call cost is bounded to constructing a transient
/// [`Runner`] and cloning the schema (no YAML re-parsing).
///
/// The `Default + Clone + Send + Sync + 'static` requirements on `C` are
/// stated on the `impl` blocks that need them, not on the type itself, so
/// code that merely names `ConcreteDesugarer<C>` does not have to repeat
/// them.
pub struct ConcreteDesugarer<C> {
    /// Grammar used for parsing; cloned into each transient [`Runner`].
    language: tree_sitter::Language,
    /// Output schema, built once at construction time and reused per call.
    schema: schema::Schema,
    /// The phases to run plus the optional output node-types YAML.
    config: DesugaringConfig<C>,
}
impl<C: Default + Clone + Send + Sync + 'static> ConcreteDesugarer<C> {
    /// Build a desugarer for `language` from `config`.
    ///
    /// The output schema is built exactly once here, via
    /// `config.build_schema`, and kept for reuse across files.
    ///
    /// # Errors
    ///
    /// Propagates the error message from `config.build_schema` unchanged.
    pub fn new(
        language: tree_sitter::Language,
        config: DesugaringConfig<C>,
    ) -> Result<Self, String> {
        config.build_schema(&language).map(|schema| Self {
            language,
            schema,
            config,
        })
    }
}
impl<C: Default + Clone + Send + Sync + 'static> Desugarer for ConcreteDesugarer<C> {
    /// Exposes the output node-types YAML stored in the wrapped config.
    fn output_node_types_yaml(&self) -> Option<&'static str> {
        self.config.output_node_types_yaml
    }

    /// Builds a short-lived [`Runner`] over the stored language, schema and
    /// phases, then runs it on `tree`/`source`.
    fn run_from_tree(&self, tree: &tree_sitter::Tree, source: &[u8]) -> Result<Ast, String> {
        let language = self.language.clone();
        Runner::with_schema(language, &self.schema, &self.config.phases)
            .run_from_tree(tree, source)
    }
}

Some files were not shown because too many files have changed in this diff Show More