This is an automated email from the ASF dual-hosted git repository.
yaooqinn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9e6e8bf03c [VL] Gate non-binary collation StringType in cached batch stats dispatch (#12112)
9e6e8bf03c is described below
commit 9e6e8bf03ca26b73f61c6010536bd0d6416b2a7f
Author: Kent Yao <[email protected]>
AuthorDate: Wed May 20 12:43:38 2026 +0800
[VL] Gate non-binary collation StringType in cached batch stats dispatch (#12112)
### What changes were proposed in this pull request?
Gate non-binary-collation `StringType` columns in the Velox cache path to
`supported=0` (writer side) AND strip any AND-conjunct that references such a
column from the reader-side `buildFilter` predicate vector (before delegating
to `super.buildFilter`). Writer / wire format unchanged.
New shim API `SparkShims.isBinaryCollationString` — default `true` for
Spark 3.x shims (no collation concept), overridden on Spark 4.0 / 4.1 to check
`collationId == UTF8_BINARY_COLLATION_ID`.
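The sketch below shows how the shim default and the writer-side gate fit together. It uses a toy type model, not Spark's `DataType` hierarchy. `CollationGateSketch`, `ToyStringType`, `Spark3LikeShims`, `Spark4LikeShims` and `statsSupported` are hypothetical names. The only rule taken from this PR is that collation id 0 (UTF8_BINARY) is the one case compatible with byte order.

```scala
object CollationGateSketch {
  // Toy stand-ins for Spark's DataType hierarchy (hypothetical, not Spark classes).
  sealed trait ToyDataType
  case object ToyIntType extends ToyDataType
  final case class ToyStringType(collationId: Int) extends ToyDataType

  // Mirrors the shim contract: default `true`, since Spark 3.x has no collations.
  trait ToyShims {
    def isBinaryCollationString(dt: ToyStringType): Boolean = true
  }

  object Spark3LikeShims extends ToyShims

  object Spark4LikeShims extends ToyShims {
    // Assumption: 0 plays the role of CollationFactory.UTF8_BINARY_COLLATION_ID.
    val Utf8BinaryCollationId: Int = 0
    override def isBinaryCollationString(dt: ToyStringType): Boolean =
      dt.collationId == Utf8BinaryCollationId
  }

  // Writer-side dispatch: only types whose byte-order min/max agrees with the
  // engine's ordering are marked supported; everything else becomes supported=0.
  def statsSupported(dt: ToyDataType, shims: ToyShims): Boolean = dt match {
    case ToyIntType                                           => true
    case s: ToyStringType if shims.isBinaryCollationString(s) => true
    case _                                                    => false
  }
}
```

The Spark 3.x-style shim reports every string as binary. That is safe only because Spark 3.x has no collations. A Spark 4.x shim that forgot to override the method would silently accept every collation, as the shim doc comment warns.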
### Why are the changes needed?
On Spark 4.x with a non-binary collation, Velox's `scanMinMax<StringView>`
does an unsigned byte-order compare while Spark's filter compare is
collation-aware (`PhysicalStringType.ordering =
CollationFactory.fetchCollation(id).comparator`). The two disagree, so
stats-based pruning can silently drop matching rows.
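To make the disagreement concrete, the sketch below records min/max with an unsigned byte-order compare, standing in for `scanMinMax<StringView>`. It then evaluates an equality range check with a case-insensitive comparator. That comparator approximates `UTF8_LCASE` by lowercasing with `Locale.ROOT`; this is an assumption for illustration, not Spark's implementation. `OrderingMismatch` and its methods are hypothetical.

```scala
import java.nio.charset.StandardCharsets
import java.util.Locale

object OrderingMismatch {
  // Unsigned lexicographic compare over UTF-8 bytes, like a byte-order min/max scan.
  def byteOrderCompare(a: String, b: String): Int = {
    val x = a.getBytes(StandardCharsets.UTF_8)
    val y = b.getBytes(StandardCharsets.UTF_8)
    var i = 0
    while (i < x.length && i < y.length) {
      val c = (x(i) & 0xff) - (y(i) & 0xff)
      if (c != 0) return c
      i += 1
    }
    x.length - y.length
  }

  // Rough approximation of a lowercase collation: compare lowercased strings.
  def lcaseCompare(a: String, b: String): Int =
    a.toLowerCase(Locale.ROOT).compareTo(b.toLowerCase(Locale.ROOT))

  // Bounds as a byte-order writer would record them: (min, max).
  def byteOrderBounds(values: Seq[String]): (String, String) = {
    val sorted = values.sortWith((a, b) => byteOrderCompare(a, b) < 0)
    (sorted.head, sorted.last)
  }

  // Stats check for `col = literal`, evaluated with the collation-aware comparator.
  def mayContain(bounds: (String, String), literal: String): Boolean =
    lcaseCompare(bounds._1, literal) <= 0 && lcaseCompare(literal, bounds._2) <= 0
}
```

With the repro's values `'abc'` and `'XYZ'`, byte order records the bounds `[XYZ, abc]` because uppercase letters sort first. The case-insensitive check for `'ABC'` then fails (`xyz > abc`), so the batch is pruned even though `'abc'` matches.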
Repro:
```scala
spark.sql("CREATE TABLE t(s STRING COLLATE UTF8_LCASE) USING parquet")
spark.sql("INSERT INTO t VALUES ('abc'), ('XYZ')")
spark.sql("CACHE TABLE t")
spark.sql("SELECT * FROM t WHERE s = 'ABC'").show()
// Before: 0 rows (wrong). After: 1 row.
```
Vanilla Spark's `StringColumnStats` is collation-aware, so this is
Gluten-specific.
### Reader-side approach (rev 3.2)
Earlier revisions filled a `0xFF * 256B` sentinel upper bound on the
deserialize side to keep vanilla `buildFilter` from pruning. As pointed out by
@zhli1142015, that sentinel is not a universal upper bound under non-binary
collation orderings, so it is not safe.
Rev 3.2 drops the sentinel and instead wraps
`SimpleMetricsCachedBatchSerializer.buildFilter` with a
`splitConjunctivePredicates`-based predicate-strip layer
(`stripUnsupportedConjuncts`):
- For each input predicate, split into AND-conjuncts.
- Drop every conjunct whose `references` contain any attribute that was
demoted to `supported=0` (i.e. a non-binary collation StringType in
`cachedAttributes`).
- Rebuild surviving conjuncts with `And.reduce`; bypass entirely if nothing
references a demoted column.
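- `Or` sub-trees are not split: an `Or` that references a demoted column is
dropped as a single conjunct. This is conservative, since one disjunct without
usable stats already makes the whole `Or` unusable for pruning.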
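The strip steps above can be sketched on a toy predicate tree. `StripSketch`, `Cmp`, `IsNullE`, `AndE` and `OrE` are hypothetical stand-ins for Catalyst expressions. Columns are matched by name rather than by `ExprId`. The shape of `strip` follows `stripUnsupportedConjuncts` in the diff below.

```scala
object StripSketch {
  // Toy predicate tree (hypothetical; stands in for Catalyst expressions).
  sealed trait Expr {
    // Column names referenced anywhere in this subtree (like Expression.references).
    def references: Set[String] = this match {
      case Cmp(col, _, _) => Set(col)
      case IsNullE(col)   => Set(col)
      case AndE(l, r)     => l.references ++ r.references
      case OrE(l, r)      => l.references ++ r.references
    }
  }
  final case class Cmp(col: String, op: String, literal: String) extends Expr
  final case class IsNullE(col: String) extends Expr
  final case class AndE(left: Expr, right: Expr) extends Expr
  final case class OrE(left: Expr, right: Expr) extends Expr

  // Flattens nested ANDs only, in the spirit of splitConjunctivePredicates; an OR stays whole.
  def splitConjuncts(e: Expr): Seq[Expr] = e match {
    case AndE(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
    case other      => Seq(other)
  }

  // Drops every conjunct touching a demoted column and re-ANDs the survivors.
  def strip(predicates: Seq[Expr], demoted: Set[String]): Seq[Expr] =
    if (demoted.isEmpty) predicates // bypass: no column was demoted
    else
      predicates.flatMap { p =>
        val kept = splitConjuncts(p).filterNot(_.references.exists(demoted.contains))
        if (kept.isEmpty) None else Some(kept.reduce(AndE(_, _)))
      }
}
```

Take `nb = 'zzz' AND (i > 300 AND i < 400)` with `nb` demoted: only the two `i` conjuncts survive, and they are re-ANDed. For `nb = 'zzz' OR i < 150`, the whole `Or` is dropped. These mirror cases W4 and W5.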
Empty filtered predicates degrade gracefully: vanilla
`SimpleMetricsCachedBatchSerializer` reduces `partitionFilters` with
`.reduceOption(And).getOrElse(Literal(true))`, so the partition filter becomes
pass-through (verified against `spark-sql_2.13-4.0.1-sources`
`CachedBatchSerializer.scala`).
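Plain functions are enough to show why an empty predicate list is a pass-through. `reduceOption` over an empty sequence yields `None`, so the constant-true fallback is used. `PassThroughSketch` and its `Stats` type are hypothetical. Only the `reduceOption(...).getOrElse(true)` shape comes from the vanilla serializer quoted above.

```scala
object PassThroughSketch {
  // Toy per-batch min/max stats for one int column (hypothetical).
  final case class Stats(lower: Int, upper: Int)

  type StatsPredicate = Stats => Boolean

  // Mirrors `.reduceOption(And).getOrElse(Literal(true))`: AND the surviving
  // stats predicates; with none left, the filter is constant true (pass-through).
  def partitionFilter(preds: Seq[StatsPredicate]): StatsPredicate =
    preds
      .reduceOption[StatsPredicate]((a, b) => s => a(s) && b(s))
      .getOrElse((_: Stats) => true)
}
```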
A real collation-aware bound (matching vanilla
`StringColumnStats.semanticCompare`) would require teaching the cpp
`scanMinMax` path about collations, likely via ICU sort keys — tracked as a
Phase-2 follow-up.
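The Phase-2 idea can be illustrated with the JDK's `java.text.Collator` standing in for ICU. Min/max are recorded by comparing collation sort keys rather than raw bytes, and predicates are checked with the same collator. This is only an illustration: `CollationAwareBounds` is hypothetical, and JDK collation rules differ from both ICU's and Spark's `UNICODE_CI`. It also says nothing about how the cpp `scanMinMax` path would be changed.

```scala
import java.text.Collator
import java.util.Locale

object CollationAwareBounds {
  // Stand-in for a case-insensitive collation: at SECONDARY strength the JDK
  // collator ignores case (tertiary) differences. Not ICU, not Spark's collator.
  private val collator: Collator = {
    val c = Collator.getInstance(Locale.ROOT)
    c.setStrength(Collator.SECONDARY)
    c
  }

  // Record (min, max) in collation order by comparing sort keys.
  def bounds(values: Seq[String]): (String, String) = {
    val keyed = values.map(v => (v, collator.getCollationKey(v)))
    val sorted = keyed.sortWith((a, b) => a._2.compareTo(b._2) < 0)
    (sorted.head._1, sorted.last._1)
  }

  // Equality stats check evaluated with the same collator that built the bounds.
  def mayContain(b: (String, String), literal: String): Boolean =
    collator.compare(b._1, literal) <= 0 && collator.compare(literal, b._2) <= 0
}
```

Under this ordering, the bounds for `'abc'` and `'XYZ'` become `[abc, XYZ]`, and `'ABC'` falls inside them.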
### Does this PR introduce _any_ user-facing change?
Yes — correctness fix. No new config.
### How was this patch tested?
- New W1–W8 cases in `ColumnarCachedBatchBuildFilterPruneSuite` (wrapper
behavior + anti-regression bypass).
- New `ColumnarCachedBatchE2ESuite` cases for UTF8_LCASE and UNICODE_CI
predicates over a cached batch.
- Existing suites: `ColumnarCacheShipBlockerMarshalSuite`,
`ColumnarCachedBatchStatsBlobSuite`, `ColumnarCachedBatchIntFamilyMarshalSuite`.
- `mvn clean install` + suites verified on:
  - `-Pspark-4.0 -Pscala-2.13`: 42/42 PASS
  - `-Pspark-4.1 -Pscala-2.13`: 42/42 PASS
  - `-Pspark-3.5 -Pscala-2.12`: 32 PASS + 10 cleanly canceled (W1–W8 guarded
    by `assume(isCollationAware)` and the 2 E2E collation cases guarded by a
    Spark 4.x version check, since `CollationFactory` does not exist on 3.5)
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.7
---
.../execution/ColumnarCachedBatchSerializer.scala | 57 ++++-
.../ColumnarCacheShipBlockerMarshalSuite.scala | 6 +-
.../ColumnarCachedBatchBuildFilterPruneSuite.scala | 239 ++++++++++++++++++++-
.../execution/ColumnarCachedBatchE2ESuite.scala | 43 ++++
.../org/apache/gluten/sql/shims/SparkShims.scala | 17 +-
.../gluten/sql/shims/spark40/Spark40Shims.scala | 5 +-
.../gluten/sql/shims/spark41/Spark41Shims.scala | 5 +-
7 files changed, 360 insertions(+), 12 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
index 60e264b016..7bf7a17546 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala
@@ -23,6 +23,7 @@ import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowEx
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper
@@ -31,8 +32,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow}
-import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, ExprId}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, PredicateHelper}
+import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch}
+import org.apache.spark.sql.columnar.SimpleMetricsCachedBatchSerializer
import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -243,7 +246,12 @@ object CachedColumnarBatchKryoSerializer {
case FloatType => true // 4B IEEE 754; NaN guard in cpp scanMinMax
case DoubleType => true // 8B IEEE 754; NaN guard in cpp scanMinMax
case BooleanType => true
- case _: StringType => true // truncated to 256B; see encodeStringBounds (any collation)
+ case s: StringType if SparkShimLoader.getSparkShims.isBinaryCollationString(s) => true
+ // Non-binary collation: cpp scanMinMax byte-order disagrees with Spark's
+ // collation-aware String ordering at runtime (PhysicalStringType.ordering
+ // dispatches to CollationFactory.fetchCollation(id).comparator). Demote to
+ // supported=0; the buildFilter wrapper strips any AND-conjunct that
+ // references such columns to guarantee pass-through.
case _ => false
}
@@ -612,7 +620,8 @@ object CachedColumnarBatchKryoSerializer {
* Velox columnar cache serializer. Supports column pruning; converts row-based input via
* [[RowToVeloxColumnarExec]] and falls back to vanilla Spark serialization for unsupported schemas.
*/
-class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
+ with PredicateHelper {
private lazy val rowBasedCachedBatchSerializer = new DefaultCachedBatchSerializer
private def glutenConf: GlutenConfig = GlutenConfig.get
@@ -840,11 +849,49 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
// split, vanilla SimpleMetricsCachedBatchSerializer.buildFilter NPEs on
// partitionFilter.eval(null) for non-trivial predicates -- the codegen and interpreted
// paths both have no fallback for null stats.
+ //
+ // Strip every AND-conjunct that references a non-binary collation StringType attribute
+ // (writer-side gate demoted those columns to supported=0; the cpp byte-order min/max
+ // bytes do not agree with collation-aware String ordering at runtime, so feeding such
+ // a conjunct to super.buildFilter would let the stats-bound check wrongly prune).
+ // Or sub-trees are left intact; one disjunct losing stats already loses the Or anyway.
+ private def stripUnsupportedConjuncts(
+ predicates: Seq[Expression],
+ cachedAttributes: Seq[Attribute]): Seq[Expression] = {
+ val skipAttrIds: Set[ExprId] = cachedAttributes.collect {
+ case a if (a.dataType match {
+ case s: StringType => !SparkShimLoader.getSparkShims.isBinaryCollationString(s)
+ case _ => false
+ }) =>
+ a.exprId
+ }.toSet
+ if (skipAttrIds.isEmpty) {
+ predicates
+ } else {
+ predicates.flatMap {
+ p =>
+ val conjuncts = splitConjunctivePredicates(p)
+ val kept = conjuncts.filterNot(
+ c =>
+ c.references.exists(r => skipAttrIds.contains(r.exprId)))
+ if (kept.isEmpty) None else Some(kept.reduce(And))
+ }
+ }
+ }
+
override def buildFilter(
predicates: Seq[Expression],
cachedAttributes: Seq[Attribute])
: (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = {
- val parent = super.buildFilter(predicates, cachedAttributes)
+ // cachedAttributes carries the cached relation's output ExprIds (the underlying scan
+ // attributes), so ExprId-based matching is stable here -- no aliased ExprIds reach
+ // this layer. Stripping is intentionally done before super.buildFilter sees the
+ // predicate vector; empty filteredPredicates degrade gracefully because
+ // super reduces partitionFilters with .reduceOption(And).getOrElse(Literal(true))
+ // -- verified against spark-sql_2.13-4.0.1-sources CachedBatchSerializer.scala.
+ val parent = super.buildFilter(
+ stripUnsupportedConjuncts(predicates, cachedAttributes),
+ cachedAttributes)
(index, cachedBatchIterator) =>
new Iterator[CachedBatch] {
private val peekable = cachedBatchIterator.buffered
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCacheShipBlockerMarshalSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCacheShipBlockerMarshalSuite.scala
index c25638d91c..5ef5d8d544 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCacheShipBlockerMarshalSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCacheShipBlockerMarshalSuite.scala
@@ -125,7 +125,9 @@ class ColumnarCacheShipBlockerMarshalSuite extends AnyFunSuite {
val schema = StructType(Seq(StructField("s", StringType)))
val blob = CachedColumnarBatchKryoSerializer.serializeStats(stats, schema)
val read = CachedColumnarBatchKryoSerializer.deserializeStats(blob, schema)
- assert(read.isNullAt(0), "lower bound must be null when carry overflows")
- assert(read.isNullAt(1), "upper bound must be null when carry overflows")
+ // supported=0 StringType: no sentinel bound; left null. The buildFilter wrapper
+ // strips conjuncts referencing demoted columns before super sees them.
+ assert(read.isNullAt(0))
+ assert(read.isNullAt(1))
}
}
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchBuildFilterPruneSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchBuildFilterPruneSuite.scala
index 1956139d1c..151cd40015 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchBuildFilterPruneSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchBuildFilterPruneSuite.scala
@@ -16,9 +16,15 @@
*/
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, GreaterThan}
+import org.apache.spark.sql.catalyst.expressions.{In, IsNotNull, IsNull, LessThan}
+import org.apache.spark.sql.catalyst.expressions.{Literal, Or, StartsWith}
+// CollationFactory + StringType(collationId) are Spark 4.0+ only.
+// Use reflection so this suite still compiles against Spark 3.5 shim.
import org.apache.spark.sql.columnar.CachedBatch
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.funsuite.AnyFunSuite
@@ -92,4 +98,233 @@ class ColumnarCachedBatchBuildFilterPruneSuite extends AnyFunSuite {
result.map(_.numRows) === Seq(7, 10),
"order: stats=null pass-through first, then stats-covers-literal kept")
}
+
+ // ---------------------------------------------------------------------------
+ // W1-W8 -- non-binary collation StringType wrapper behavior.
+ // The wrapper strips AND-conjuncts referencing non-binary collation StringType
+ // attributes via splitConjunctivePredicates, leaving binary attribute predicates
+ // intact. See ColumnarCachedBatchSerializer.stripUnsupportedConjuncts.
+ // ---------------------------------------------------------------------------
+
+ private val binaryString: StringType = StringType
+ // Build StringType("UTF8_LCASE") reflectively -- Spark 4.0+ only.
+ // We resolve the companion's `apply(int)` method (Spark types.StringType$.apply(int))
+ // rather than the case-class constructor (which has private/defaulted secondary args).
+ // Fragile only if Spark renames the companion apply signature; if that happens, this
+ // suite would fail loudly at test load on 4.0+, so the break is visible, not silent.
+ // On Spark 3.5 there is no collation system; tests guarded by `assume(isCollationAware)`.
+ private val isCollationAware: Boolean = {
+ try {
+ // scalastyle:off classforname
+ Class.forName("org.apache.spark.sql.catalyst.util.CollationFactory")
+ // scalastyle:on classforname
+ true
+ } catch { case _: ClassNotFoundException => false }
+ }
+ private val nbString: StringType = {
+ if (!isCollationAware) {
+ StringType
+ } else {
+ // scalastyle:off classforname
+ val cf = Class.forName("org.apache.spark.sql.catalyst.util.CollationFactory")
+ val moduleCls = Class.forName("org.apache.spark.sql.types.StringType$")
+ // scalastyle:on classforname
+ val idField = cf.getField("UTF8_LCASE_COLLATION_ID")
+ val collationId = idField.getInt(null)
+ val module = moduleCls.getField("MODULE$").get(null)
+ val applyM = moduleCls.getMethod("apply", java.lang.Integer.TYPE)
+ applyM.invoke(module, Integer.valueOf(collationId)).asInstanceOf[StringType]
+ }
+ }
+
+ // Build a stats row + batch for a single non-binary collation StringType attr.
+ private def stringBatch(
+ lower: String,
+ upper: String,
+ numRows: Int = 10): CachedColumnarBatch = {
+ val stats = new GenericInternalRow(
+ Array[Any](
+ UTF8String.fromString(lower),
+ UTF8String.fromString(upper),
+ 0,
+ numRows,
+ numRows.toLong * 4L))
+ CachedColumnarBatch(
+ numRows = numRows,
+ sizeInBytes = numRows.toLong * 4L,
+ bytes = Array.fill[Byte](numRows * 4)(0),
+ stats = stats)
+ }
+
+ // Build a stats row + batch for [String nb, Int int] schema (5 slots per col).
+ private def mixedBatch(
+ strLower: String,
+ strUpper: String,
+ intLower: Int,
+ intUpper: Int,
+ numRows: Int = 10,
+ strNullCount: Int = 0): CachedColumnarBatch = {
+ val stats = new GenericInternalRow(
+ Array[Any](
+ UTF8String.fromString(strLower),
+ UTF8String.fromString(strUpper),
+ strNullCount,
+ numRows,
+ numRows.toLong * 4L,
+ intLower,
+ intUpper,
+ 0,
+ numRows,
+ numRows.toLong * 4L
+ ))
+ CachedColumnarBatch(
+ numRows = numRows,
+ sizeInBytes = numRows.toLong * 8L,
+ bytes = Array.fill[Byte](numRows * 8)(0),
+ stats = stats)
+ }
+
+ test("W1: wrapper strips predicate on non-binary collation StringType attribute") {
+ assume(isCollationAware)
+ // Stats range [aaa, aaz] does NOT cover literal "zzz". Without strip: super would
+ // generate `lower <= 'zzz' && 'zzz' <= upper` -> false -> drop. With wrapper: stripped
+ // -> no predicate -> kept.
+ val serializer = new ColumnarCachedBatchSerializer
+ val attr = AttributeReference("c", nbString, nullable = false)()
+ val predicate = EqualTo(attr, Literal.create("zzz", nbString))
+ val filter = serializer.buildFilter(Seq(predicate), Seq(attr))
+
+ val result = filter(0, Iterator(stringBatch("aaa", "aaz"))).toList
+ assert(result.length === 1, "non-binary attr predicate must be stripped -> batch kept")
+ }
+
+ test("W2: wrapper preserves predicate on binary collation StringType attribute") {
+ assume(isCollationAware)
+ // Binary collation: predicate untouched, super applies -> batch pruned.
+ val serializer = new ColumnarCachedBatchSerializer
+ val attr = AttributeReference("c", binaryString, nullable = false)()
+ val predicate = EqualTo(attr, Literal.create("zzz", binaryString))
+ val filter = serializer.buildFilter(Seq(predicate), Seq(attr))
+
+ val result = filter(0, Iterator(stringBatch("aaa", "aaz"))).toList
+ assert(
+ result.length === 0,
+ "binary attr predicate preserved -> super prunes batch outside [aaa, aaz]")
+ }
+
+ test("W3: mixed-attr conjunct: nb='zzz' AND int>300 keeps int predicate, batch pruned") {
+ assume(isCollationAware)
+ // Stats: nb=[aaa, aaz], int=[100, 200]. Predicate And(nb='zzz', int>300).
+ // Conjunct-level strip: nb stripped, int>300 remains -> super applies
+ // `300 < upperBound(200)` -> false -> drop. We assert drop to prove conjunct level.
+ val serializer = new ColumnarCachedBatchSerializer
+ val nbAttr = AttributeReference("nb", nbString, nullable = false)()
+ val intAttr = AttributeReference("i", IntegerType, nullable = false)()
+ val pred = And(
+ EqualTo(nbAttr, Literal.create("zzz", nbString)),
+ GreaterThan(intAttr, Literal(300)))
+ val filter = serializer.buildFilter(Seq(pred), Seq(nbAttr, intAttr))
+
+ val result = filter(0, Iterator(mixedBatch("aaa", "aaz", 100, 200))).toList
+ assert(
+ result.length === 0,
+ "conjunct-level strip: int>300 survives strip -> batch w/ int=[100,200] pruned")
+ }
+
+ test("W4: nested And: nb='zzz' AND (int>300 AND int<400) splits deeply, batch pruned") {
+ assume(isCollationAware)
+ // Stats: nb=[aaa, aaz], int=[100, 200]. Predicate And(nb='zzz', And(int>300, int<400)).
+ // splitConjunctivePredicates must unpack nested And: 3 conjuncts -> strip nb -> keep
+ // [int>300, int<400] -> reduce(And) -> super prunes (int=[100,200] disjoint from (300,400)).
+ val serializer = new ColumnarCachedBatchSerializer
+ val nbAttr = AttributeReference("nb", nbString, nullable = false)()
+ val intAttr = AttributeReference("i", IntegerType, nullable = false)()
+ val pred = And(
+ EqualTo(nbAttr, Literal.create("zzz", nbString)),
+ And(GreaterThan(intAttr, Literal(300)), LessThan(intAttr, Literal(400))))
+ val filter = serializer.buildFilter(Seq(pred), Seq(nbAttr, intAttr))
+
+ val result = filter(0, Iterator(mixedBatch("aaa", "aaz", 100, 200))).toList
+ assert(
+ result.length === 0,
+ "splitConjunctivePredicates unpacks nested And -> int conjuncts survive -> batch pruned")
+ }
+
+ test("W5: Or branch: Or(nb='zzz', int<150) stripped entirely (Or conservative), batch kept") {
+ assume(isCollationAware)
+ // Stats: nb=[aaa, aaz], int=[300, 400]. Without wrapper:
+ // nb branch: 'aaa' <= 'zzz' && 'zzz' <= 'aaz' -> false (collation-dependent)
+ // int branch: lowerBound(300) < 150 -> false
+ // Or = false -> drop.
+ // With wrapper: Or references nb -> entire Or stripped (splitConjunctivePredicates
+ // does not split Or) -> kept=empty -> no predicate -> batch kept.
+ val serializer = new ColumnarCachedBatchSerializer
+ val nbAttr = AttributeReference("nb", nbString, nullable = false)()
+ val intAttr = AttributeReference("i", IntegerType, nullable = false)()
+ val pred = Or(
+ EqualTo(nbAttr, Literal.create("zzz", nbString)),
+ LessThan(intAttr, Literal(150)))
+ val filter = serializer.buildFilter(Seq(pred), Seq(nbAttr, intAttr))
+
+ val result = filter(0, Iterator(mixedBatch("aaa", "aaz", 300, 400))).toList
+ assert(
+ result.length === 1,
+ "Or containing nb attr stripped wholesale -> no predicate -> batch kept (pass-through)")
+ }
+
+ test("W6: IsNull(nb attr) stripped, IsNotNull(int) kept, batch evaluated by int only") {
+ assume(isCollationAware)
+ // Stats: nb nullCount=0, int has rows. Predicates [IsNull(nb), IsNotNull(int)].
+ // Without wrapper: IsNull(nb) -> nullCount>0 -> 0>0=false -> batch dropped.
+ // With wrapper: IsNull(nb) stripped; IsNotNull(int) survives -> count-nullCount>0 ->
+ // 10-0=10>0=true -> kept.
+ val serializer = new ColumnarCachedBatchSerializer
+ val nbAttr = AttributeReference("nb", nbString, nullable = true)()
+ val intAttr = AttributeReference("i", IntegerType, nullable = true)()
+ val preds: Seq[Expression] = Seq(IsNull(nbAttr), IsNotNull(intAttr))
+ val filter = serializer.buildFilter(preds, Seq(nbAttr, intAttr))
+
+ val result = filter(0, Iterator(mixedBatch("aaa", "aaz", 100, 200))).toList
+ assert(
+ result.length === 1,
+ "IsNull(nb) stripped -> only IsNotNull(int) drives decision -> kept")
+ }
+
+ test("W7: In(nb attr, list) and StartsWith(nb attr, lit) both stripped, batch kept") {
+ assume(isCollationAware)
+ // Both predicates reference nb only -> both stripped -> no predicate -> kept.
+ val serializer = new ColumnarCachedBatchSerializer
+ val nbAttr = AttributeReference("nb", nbString, nullable = false)()
+ val preds: Seq[Expression] = Seq(
+ In(nbAttr, Seq(Literal.create("xx", nbString), Literal.create("yy", nbString))),
+ StartsWith(nbAttr, Literal.create("zz", nbString)))
+ val filter = serializer.buildFilter(preds, Seq(nbAttr))
+
+ val result = filter(0, Iterator(stringBatch("aaa", "aaz"))).toList
+ assert(result.length === 1, "In and StartsWith on nb attr both stripped -> batch kept")
+ }
+
+ test("W8 (anti-regression): bypassing wrapper would let UTF8_LCASE bound prune batch") {
+ assume(isCollationAware)
+ // Anchors the non-binary collation attack scenario: if the wrapper is ever
+ // moved/removed, super's partition filter on a non-binary collation attr
+ // uses cpp-written byte-order bounds in a collation-aware comparator --
+ // behavior is implementation-defined and can drop valid rows. Use vanilla
+ // DefaultCachedBatchSerializer (extends SimpleMetricsCachedBatchSerializer,
+ // no wrapper) to demonstrate the unsafe path.
+ val vanilla = new org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer
+ val attr = AttributeReference("c", nbString, nullable = false)()
+ val predicate = EqualTo(attr, Literal.create("zzz", nbString))
+ val filter = vanilla.buildFilter(Seq(predicate), Seq(attr))
+
+ val batch = stringBatch("aaa", "aaz")
+ val result = filter(0, Iterator[CachedBatch](batch)).toList
+ // Vanilla (no wrapper): super applies the predicate via collation-aware comparator.
+ // For UTF8_LCASE 'aaa' <= 'zzz' && 'zzz' <= 'aaz' -> false -> batch dropped.
+ // Asserting drop locks in the negative ground truth -- wrapper is the only thing
+ // standing between user data and silent loss for non-binary collation columns.
+ assert(
+ result.length === 0,
+ "without wrapper: non-binary attr predicate applied -> batch dropped")
+ }
}
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchE2ESuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchE2ESuite.scala
index 3b4d1d3a16..34998c887d 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchE2ESuite.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/ColumnarCachedBatchE2ESuite.scala
@@ -321,6 +321,49 @@ class ColumnarCachedBatchE2ESuite
}
}
+ test("non-binary collation StringType: sentinel demotion keeps batch (no silent prune)") {
+ assume(
+ spark.version.startsWith("4."),
+ "COLLATE syntax requires Spark 4.0+; sentinel path is also gated to Spark 4.x shims")
+ val cached = spark
+ .range(N)
+ .selectExpr("concat('k_', lpad(cast(id as string), 4, '0')) COLLATE UTF8_LCASE as s")
+ .repartitionByRange(P, col("s"))
+ .cache()
+ try {
+ cached.count()
+ val result = cached.filter(col("s") === "K_0500").count()
+ assert(
+ result == 1L,
+ s"non-binary collation equality must return the matching row, got $result")
+ } finally {
+ cached.unpersist()
+ }
+ }
+
+ // ICU collation coverage: UNICODE_CI is a separate collation family
+ // (ICU collator, not just lowercase) from UTF8_LCASE. Same wrapper-strip
+ // behavior expected, proving the mechanism is not specific to one collation kind.
+ test("non-binary collation StringType (UNICODE_CI): pass-through, no silent prune") {
+ assume(
+ spark.version.startsWith("4."),
+ "COLLATE syntax requires Spark 4.0+")
+ val cached = spark
+ .range(N)
+ .selectExpr("concat('k_', lpad(cast(id as string), 4, '0')) COLLATE UNICODE_CI as s")
+ .repartitionByRange(P, col("s"))
+ .cache()
+ try {
+ cached.count()
+ val result = cached.filter(col("s") === "K_0500").count()
+ assert(
+ result == 1L,
+ s"UNICODE_CI equality must return the matching row, got $result")
+ } finally {
+ cached.unpersist()
+ }
+ }
+
// Config-gate negative test: with partition stats disabled (the production default),
// serializeWithStats must NOT be invoked -- the legacy serialize() path is taken and stats
// are emitted as null. A bug in the gate could silently activate stats for all users, or
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index 4d1fd804a9..1f1cfdf282 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.execution.window.WindowGroupLimitExecShim
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.util.SparkShimVersionUtil
import org.apache.hadoop.fs.{FileStatus, Path}
@@ -292,4 +292,19 @@ trait SparkShims {
* similar to LeftOuter. Default implementation returns false for Spark 3.x compatibility.
*/
def isLeftSingleJoinType(joinType: JoinType): Boolean = false
+
+ /**
+ * Returns true iff the given StringType uses the UTF8_BINARY collation (id == 0).
+ *
+ * Spark 4.0 introduced collation-aware StringType. Bound computation in gluten cached batch
+ * partition-stats uses unsigned byte order, which only matches Spark's predicate semantics for
+ * UTF8_BINARY. Non-binary collations must be gated out of the dispatch fast path (demoted to
+ * supported=0), and ColumnarCachedBatchSerializer.buildFilter strips every AND-conjunct that
+ * references them before vanilla SimpleMetricsCachedBatchSerializer.buildFilter runs.
+ *
+ * Default returns true (Spark 3.x has no collation concept; all StringType is binary). Any future
+ * Spark 4.0+ shim MUST override and consult collationId, otherwise the binary-only invariant
+ * degrades silently to "accept any collation".
+ */
+ def isBinaryCollationString(dt: StringType): Boolean = true
}
diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index 5c6451a67b..fb38af3060 100644
--- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, MapData, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{CollationFactory, InternalRowComparableWrapper, MapData, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connector.catalog.Table
@@ -647,4 +647,7 @@ class Spark40Shims extends SparkShims {
override def isLeftSingleJoinType(joinType: JoinType): Boolean = {
joinType == LeftSingle
}
+
+ override def isBinaryCollationString(dt: StringType): Boolean =
+ dt.collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
}
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
index 3eabf6b595..238c8ded81 100644
--- a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
+++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, MapData, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{CollationFactory, InternalRowComparableWrapper, MapData, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
@@ -663,4 +663,7 @@ class Spark41Shims extends SparkShims {
override def isLeftSingleJoinType(joinType: JoinType): Boolean = {
joinType == LeftSingle
}
+
+ override def isBinaryCollationString(dt: StringType): Boolean =
+ dt.collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
}