This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 97ee25a7af96 [SPARK-50691][SQL][FOLLOWUP] Use UnsafeProjection for LocalRelation rows instead of ComparableLocalRelation
97ee25a7af96 is described below
commit 97ee25a7af967d083dfb4b7cf58e38aeb8edcfe5
Author: Vladimir Golubev <[email protected]>
AuthorDate: Mon Dec 30 20:39:57 2024 +0800
[SPARK-50691][SQL][FOLLOWUP] Use UnsafeProjection for LocalRelation rows instead of ComparableLocalRelation
### What changes were proposed in this pull request?
Use `UnsafeProjection` for `LocalRelation` rows instead of `ComparableLocalRelation`.
### Why are the changes needed?
`UnsafeRow.equals` compares the entire underlying byte sequence, so it's a convenient way to compare all kinds of row values, including `ArrayBasedMapData`.
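As an illustration (not part of this patch), a minimal sketch of that behavior against catalyst internals; the `CatalystTypeConverters` round-trip and the `copy()` calls are just scaffolding for the example:
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("m", MapType(IntegerType, StringType))))
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)

// Two structurally identical rows whose map field is backed by ArrayBasedMapData.
val row1 = toCatalyst(Row(Map(1 -> "a"))).asInstanceOf[InternalRow]
val row2 = toCatalyst(Row(Map(1 -> "a"))).asInstanceOf[InternalRow]
assert(row1 != row2) // ArrayBasedMapData doesn't define equals, so this falls back to reference equality

// Projected to UnsafeRow, equality becomes a byte-wise comparison.
val unsafeProjection = UnsafeProjection.create(schema)
// copy() because the projection reuses its output buffer between calls.
assert(unsafeProjection(row1).copy() == unsafeProjection(row2).copy())
```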
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test.
### Was this patch authored or co-authored using generative AI tooling?
copilot.nvim.
Closes #49322 from vladimirg-db/vladimirg-db/update-comparable-local-relation-to-cover-unsafe-row.
Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/plans/NormalizePlan.scala | 42 +++++-----------------
.../optimizer/BooleanSimplificationSuite.scala | 4 +--
.../catalyst/optimizer/LimitPushdownSuite.scala | 10 ++++--
3 files changed, 18 insertions(+), 38 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
index ee68e433fbea..38cf2730e9ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.analysis.GetViewColumnByNameAndOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 
 object NormalizePlan extends PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan =
@@ -105,8 +104,15 @@ object NormalizePlan extends PredicateHelper {
       case Project(projectList, child) =>
         Project(normalizeProjectList(projectList), child)
       case c: KeepAnalyzedQuery => c.storeAnalyzedQuery()
-      case localRelation: LocalRelation =>
-        ComparableLocalRelation.fromLocalRelation(localRelation)
+      case localRelation: LocalRelation if !localRelation.data.isEmpty =>
+        /**
+         * A substitute for the [[LocalRelation.data]]. [[GenericInternalRow]] is incomparable for
+         * maps, because [[ArrayBasedMapData]] doesn't define [[equals]].
+         */
+        val unsafeProjection = UnsafeProjection.create(localRelation.schema)
+        localRelation.copy(data = localRelation.data.map { row =>
+          unsafeProjection(row)
+        })
     }
   }
 
@@ -137,33 +143,3 @@ object NormalizePlan extends PredicateHelper {
     case _ => condition // Don't reorder.
   }
 }
-
-/**
- * A substitute for the [[LocalRelation]] that has comparable `data` field. [[LocalRelation]]'s
- * `data` is incomparable for maps, because [[ArrayBasedMapData]] doesn't define [[equals]].
- */
-case class ComparableLocalRelation(
-    override val output: Seq[Attribute],
-    data: Seq[Seq[Expression]],
-    override val isStreaming: Boolean,
-    stream: Option[SparkDataStream]) extends LeafNode
-
-object ComparableLocalRelation {
-  def fromLocalRelation(localRelation: LocalRelation): ComparableLocalRelation = {
-    val dataTypes = localRelation.output.map(_.dataType)
-    ComparableLocalRelation(
-      output = localRelation.output,
-      data = localRelation.data.map { row =>
-        if (row != null) {
-          row.toSeq(dataTypes).zip(dataTypes).map {
-            case (value, dataType) => Literal(value, dataType)
-          }
-        } else {
-          Seq.empty
-        }
-      },
-      isStreaming = localRelation.isStreaming,
-      stream = localRelation.stream
-    )
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index fc2697d55f6d..4cc2ee99284a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -46,7 +46,7 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper {
     $"e".boolean, $"f".boolean, $"g".boolean, $"h".boolean)
 
   val testRelationWithData = LocalRelation.fromExternalRows(
-    testRelation.output, Seq(Row(1, 2, 3, "abc"))
+    testRelation.output, Seq(Row(1, 2, 3, "abc", true, true, true, true))
   )
 
   val testNotNullableRelation = LocalRelation($"a".int.notNull, $"b".int.notNull, $"c".int.notNull,
@@ -54,7 +54,7 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper {
     $"h".boolean.notNull)
 
   val testNotNullableRelationWithData = LocalRelation.fromExternalRows(
-    testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc"))
+    testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc", true, true, true, true))
   )
 
   private def checkCondition(input: Expression, expected: LogicalPlan): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index 02631c4cf61c..2dcab5cfd29c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Add
+import org.apache.spark.sql.catalyst.expressions.{Add, GenericInternalRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -189,7 +189,9 @@ class LimitPushdownSuite extends PlanTest {
   }
 
   test("full outer join where neither side is limited and left side has larger statistics") {
-    val xBig = testRelation.copy(data = Seq.fill(10)(null)).subquery("x")
+    val nulls = new GenericInternalRow(
+      Seq.fill(testRelation.output.length)(null).toArray.asInstanceOf[Array[Any]])
+    val xBig = testRelation.copy(data = Seq.fill(10)(nulls)).subquery("x")
     assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes)
     Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition =>
       val originalQuery = xBig.join(y, FullOuter, condition).limit(1).analyze
@@ -204,7 +206,9 @@ class LimitPushdownSuite extends PlanTest {
   }
 
   test("full outer join where neither side is limited and right side has larger statistics") {
-    val yBig = testRelation.copy(data = Seq.fill(10)(null)).subquery("y")
+    val nulls = new GenericInternalRow(
+      Seq.fill(testRelation.output.length)(null).toArray.asInstanceOf[Array[Any]])
+    val yBig = testRelation.copy(data = Seq.fill(10)(nulls)).subquery("y")
     assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes)
     Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition =>
       val originalQuery = x.join(yBig, FullOuter, condition).limit(1).analyze
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]