This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ce1fe57cdd7 [SPARK-44653][SQL] Non-trivial DataFrame unions should not break caching ce1fe57cdd7 is described below commit ce1fe57cdd7004a891ef8b97c77ac96b3719efcd Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Fri Aug 4 11:26:36 2023 +0800 [SPARK-44653][SQL] Non-trivial DataFrame unions should not break caching ### What changes were proposed in this pull request? We have a long-standing tricky optimization in `Dataset.union`, which invokes the optimizer rule `CombineUnions` to pre-optimize the analyzed plan. This is to avoid an overly large analyzed plan for a specific dataframe query pattern `df1.union(df2).union(df3).union...`. This tricky optimization breaks dataframe caching by design, but we thought it was fine as people usually won't cache the intermediate dataframe in a union chain. However, `CombineUnions` gets improved from time to time (e.g. https://github.com/apache/spark/pull/35214) and now it can optimize a wide range of Union patterns. Now it's possible that people union two dataframes, do something with `select`, and cache it. Then the dataframe is unioned again with other dataframes and peop [...] This PR updates `Dataset.union` to only combine adjacent Unions to match the original purpose. ### Why are the changes needed? Fix perf regression due to breaking df caching ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #42315 from cloud-fan/union. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- .../main/scala/org/apache/spark/sql/Dataset.scala | 56 ++++++++++++++++++---- .../org/apache/spark/sql/DatasetCacheSuite.scala | 21 ++++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9fc664bb1c2..f83cd36f0a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -157,7 +157,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // since the other rules might make two separate Unions operators adjacent. Batch("Inline CTE", Once, InlineCTE()) :: - Batch("Union", Once, + Batch("Union", fixedPoint, RemoveNoopOperators, CombineUnions, RemoveNoopUnion) :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7b2259a6d99..61c83829d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -42,11 +42,10 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} -import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import 
org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -2241,6 +2240,51 @@ class Dataset[T] private[sql]( Offset(Literal(n), logicalPlan) } + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + private def combineUnions(plan: LogicalPlan): LogicalPlan = { + plan.transformDownWithPruning(_.containsPattern(TreePattern.UNION)) { + case Distinct(u: Union) => + Distinct(flattenUnion(u, isUnionDistinct = true)) + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], u: Union) if AttributeSet(keys) == u.outputSet => + Deduplicate(keys, flattenUnion(u, true)) + case u: Union => + flattenUnion(u, isUnionDistinct = false) + } + } + + private def flattenUnion(u: Union, isUnionDistinct: Boolean): Union = { + var changed = false + // We only need to look at the direct children of Union, as the nested adjacent Unions should + // have been combined already by previous `Dataset#union` transformations. 
+ val newChildren = u.children.flatMap { + case Distinct(Union(children, byName, allowMissingCol)) + if isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + // Only handle distinct-like 'Deduplicate', where the keys == output + case Deduplicate(keys: Seq[Attribute], child @ Union(children, byName, allowMissingCol)) + if AttributeSet(keys) == child.outputSet && isUnionDistinct && byName == u.byName && + allowMissingCol == u.allowMissingCol => + changed = true + children + case Union(children, byName, allowMissingCol) + if !isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol => + changed = true + children + case other => + Seq(other) + } + if (changed) { + val newUnion = Union(newChildren) + newUnion.copyTagsFrom(u) + newUnion + } else { + u + } + } + /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * @@ -2272,9 +2316,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def union(other: Dataset[T]): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + combineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -2366,9 +2408,7 @@ class Dataset[T] private[sql]( * @since 3.1.0 */ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. 
- CombineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) + combineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 6033b9fee84..a657c6212aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -273,4 +273,25 @@ class DatasetCacheSuite extends QueryTest } } } + + test("SPARK-44653: non-trivial DataFrame unions should not break caching") { + val df1 = Seq(1 -> 1).toDF("i", "j") + val df2 = Seq(2 -> 2).toDF("i", "j") + val df3 = Seq(3 -> 3).toDF("i", "j") + + withClue("positive") { + val unionDf = df1.union(df2).select($"i") + unionDf.cache() + val finalDf = unionDf.union(df3.select($"i")) + assert(finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + + withClue("negative") { + val unionDf = df1.union(df2) + unionDf.cache() + val finalDf = unionDf.union(df3) + // It's by design to break caching here. + assert(!finalDf.queryExecution.executedPlan.exists(_.isInstanceOf[InMemoryTableScanExec])) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org