This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 8026442e48 Spark 3.3: Use regular planning for applicable row-level
operations (#6534)
8026442e48 is described below
commit 8026442e486c5fde70c855fe88839bc2411a5ba1
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Fri Jan 6 13:52:29 2023 -0800
Spark 3.3: Use regular planning for applicable row-level operations (#6534)
---
.../catalyst/analysis/RewriteMergeIntoTable.scala | 29 +++++++--
.../v2/RowLevelCommandScanRelationPushDown.scala | 12 +++-
.../SparkRowLevelOperationsTestBase.java | 6 +-
.../apache/iceberg/spark/extensions/TestMerge.java | 75 ++++++++++++++++++++++
.../spark/source/SparkCopyOnWriteOperation.java | 2 -
.../apache/iceberg/spark/source/SparkWrite.java | 25 ++++++--
.../iceberg/spark/source/SparkWriteBuilder.java | 2 -
7 files changed, 134 insertions(+), 17 deletions(-)
diff --git
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 2e720bdd44..ca37f99955 100644
---
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.And
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.AttributeSet
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.LeftAnti
@@ -74,7 +76,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* This rule assumes the commands have been fully resolved and all assignments
have been aligned.
* That's why it must be run after AlignRowLevelCommandAssignments.
*/
-object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
+object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with
PredicateHelper {
private final val ROW_FROM_SOURCE = "__row_from_source"
private final val ROW_FROM_TARGET = "__row_from_target"
@@ -185,12 +187,14 @@ object RewriteMergeIntoTable extends
RewriteRowLevelIcebergCommand {
val readRelation = buildRelationWithAttrs(relation, operationTable,
metadataAttrs)
val readAttrs = readRelation.output
+ val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
// project an extra column to check if a target row exists after the join
// project a synthetic row ID to perform the cardinality check
val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)()
val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)()
val targetTableProjExprs = readAttrs ++ Seq(rowFromTarget, rowId)
- val targetTableProj = Project(targetTableProjExprs, readRelation)
+ val targetTableProj = Project(targetTableProjExprs, Filter(targetCond,
readRelation))
// project an extra column to check if a source row exists after the join
val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)()
@@ -202,7 +206,7 @@ object RewriteMergeIntoTable extends
RewriteRowLevelIcebergCommand {
// disable broadcasts for the target table to perform the cardinality check
val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
val joinHint = JoinHint(leftHint =
Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
- val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj,
joinType, Some(cond), joinHint)
+ val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj,
joinType, Some(joinCond), joinHint)
// add an extra matched action to output the original row if none of the
actual actions matched
// this is needed to keep target rows that should be copied over
@@ -253,9 +257,11 @@ object RewriteMergeIntoTable extends
RewriteRowLevelIcebergCommand {
val readRelation = buildRelationWithAttrs(relation, operationTable,
rowIdAttrs ++ metadataAttrs)
val readAttrs = readRelation.output
+ val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
// project an extra column to check if a target row exists after the join
val targetTableProjExprs = readAttrs :+ Alias(TrueLiteral,
ROW_FROM_TARGET)()
- val targetTableProj = Project(targetTableProjExprs, readRelation)
+ val targetTableProj = Project(targetTableProjExprs, Filter(targetCond,
readRelation))
// project an extra column to check if a source row exists after the join
val sourceTableProjExprs = source.output :+ Alias(TrueLiteral,
ROW_FROM_SOURCE)()
@@ -266,7 +272,7 @@ object RewriteMergeIntoTable extends
RewriteRowLevelIcebergCommand {
// also disable broadcasts for the target table to perform the cardinality
check
val joinType = if (notMatchedActions.isEmpty) Inner else RightOuter
val joinHint = JoinHint(leftHint =
Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
- val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj,
joinType, Some(cond), joinHint)
+ val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj,
joinType, Some(joinCond), joinHint)
val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains)
@@ -439,4 +445,17 @@ object RewriteMergeIntoTable extends
RewriteRowLevelIcebergCommand {
ProjectingInternalRow(schema, projectedOrdinals)
}
+
+ // splits the MERGE condition into a predicate that references columns only
from the target table,
+ // which can be pushed down, and a predicate used as a join condition to
find matches
+ private def splitMergeCond(
+ cond: Expression,
+ targetTable: LogicalPlan): (Expression, Expression) = {
+
+ val (targetPredicates, joinPredicates) = splitConjunctivePredicates(cond)
+ .partition(_.references.subsetOf(targetTable.outputSet))
+ val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral)
+ val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral)
+ (targetCond, joinCond)
+ }
}
diff --git
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 2be73cb6ee..9ee3035c26 100644
---
a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++
b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -26,6 +26,8 @@ import
org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
+import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -38,8 +40,16 @@ object RowLevelCommandScanRelationPushDown extends
Rule[LogicalPlan] with Predic
import ExtendedDataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ // use native Spark planning for delta-based plans and copy-on-write MERGE
operations
+ // unlike other commands, these plans have filters that can be pushed down
directly
+ case RewrittenRowLevelCommand(command, _: DataSourceV2Relation,
rewritePlan)
+ if rewritePlan.isInstanceOf[WriteDelta] ||
command.isInstanceOf[MergeIntoIcebergTable] =>
+
+ val newRewritePlan = V2ScanRelationPushDown.apply(rewritePlan)
+ command.withNewRewritePlan(newRewritePlan)
+
// push down the filter from the command condition instead of the filter
in the rewrite plan,
- // which may be negated for copy-on-write operations
+ // which may be negated for copy-on-write DELETE and UPDATE operations
case RewrittenRowLevelCommand(command, relation: DataSourceV2Relation,
rewritePlan) =>
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)
diff --git
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
index 633b2ee431..3039958d2c 100644
---
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
+++
b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -170,7 +170,11 @@ public abstract class SparkRowLevelOperationsTestBase
extends SparkExtensionsTes
}
protected void createAndInitTable(String schema, String jsonData) {
- sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema);
+ createAndInitTable(schema, "", jsonData);
+ }
+
+ protected void createAndInitTable(String schema, String partitioning, String
jsonData) {
+ sql("CREATE TABLE %s (%s) USING iceberg %s", tableName, schema,
partitioning);
initTable();
if (jsonData != null) {
diff --git
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 58fbb6241e..c598cb720c 100644
---
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++
b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -18,7 +18,10 @@
*/
package org.apache.iceberg.spark.extensions;
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL;
+import static org.apache.iceberg.TableProperties.MERGE_MODE;
+import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT;
import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
@@ -39,6 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
@@ -54,6 +58,7 @@ import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assertions;
import org.junit.After;
@@ -85,6 +90,55 @@ public abstract class TestMerge extends
SparkRowLevelOperationsTestBase {
sql("DROP TABLE IF EXISTS source");
}
+ @Test
+ public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() {
+ createAndInitTable(
+ "id INT, salary INT, dep STRING, sub_dep STRING",
+ "PARTITIONED BY (dep, sub_dep)",
+ "{ \"id\": 1, \"salary\": 100, \"dep\": \"d1\", \"sub_dep\": \"sd1\"
}\n"
+ + "{ \"id\": 6, \"salary\": 600, \"dep\": \"d6\", \"sub_dep\":
\"sd6\" }");
+
+ createOrReplaceView(
+ "source",
+ "id INT, salary INT, dep STRING, sub_dep STRING",
+ "{ \"id\": 1, \"salary\": 101, \"dep\": \"d1\", \"sub_dep\": \"sd1\"
}\n"
+ + "{ \"id\": 2, \"salary\": 200, \"dep\": \"d2\", \"sub_dep\":
\"sd2\" }\n"
+ + "{ \"id\": 3, \"salary\": 300, \"dep\": \"d3\", \"sub_dep\":
\"sd3\" }");
+
+ String query =
+ String.format(
+ "MERGE INTO %s AS t USING source AS s "
+ + "ON t.id == s.id AND ((t.dep = 'd1' AND t.sub_dep IN ('sd1',
'sd3')) OR (t.dep = 'd6' AND t.sub_dep IN ('sd2', 'sd6'))) "
+ + "WHEN MATCHED THEN "
+ + " UPDATE SET salary = s.salary "
+ + "WHEN NOT MATCHED THEN "
+ + " INSERT *",
+ tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ if (mode(table) == COPY_ON_WRITE) {
+ checkJoinAndFilterConditions(
+ query,
+ "Join [id], [id], FullOuter",
+ "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND
sub_dep IN ('sd2', 'sd6')))");
+ } else {
+ checkJoinAndFilterConditions(
+ query,
+ "Join [id], [id], RightOuter",
+ "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND
sub_dep IN ('sd2', 'sd6')))");
+ }
+
+ assertEquals(
+ "Should have expected rows",
+ ImmutableList.of(
+ row(1, 101, "d1", "sd1"), // updated
+ row(2, 200, "d2", "sd2"), // new
+ row(3, 300, "d3", "sd3"), // new
+ row(6, 600, "d6", "sd6")), // existing
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
@Test
public void testMergeWithStaticPredicatePushDown() {
createAndInitTable("id BIGINT, dep STRING");
@@ -2274,4 +2328,25 @@ public abstract class TestMerge extends
SparkRowLevelOperationsTestBase {
List<Object[]> result = sql("SELECT * FROM %s ORDER BY id", tableName);
assertEquals("Should correctly add the non-matching rows", expectedRows,
result);
}
+
+ private void checkJoinAndFilterConditions(String query, String join, String
icebergFilters) {
+ // disable runtime filtering for easier validation
+ withSQLConf(
+ ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(),
"false"),
+ () -> {
+ SparkPlan sparkPlan = executeAndKeepPlan(() -> sql(query));
+ String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)",
"");
+
+ Assertions.assertThat(planAsString).as("Join should
match").contains(join + "\n");
+
+ Assertions.assertThat(planAsString)
+ .as("Pushed filters must match")
+ .contains("[filters=" + icebergFilters + ",");
+ });
+ }
+
+ private RowLevelOperationMode mode(Table table) {
+ String modeName = table.properties().getOrDefault(MERGE_MODE,
MERGE_MODE_DEFAULT);
+ return RowLevelOperationMode.fromName(modeName);
+ }
}
diff --git
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
index 72c243fcbc..68c9944044 100644
---
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
+++
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
@@ -24,7 +24,6 @@ import static
org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPD
import org.apache.iceberg.IsolationLevel;
import org.apache.iceberg.MetadataColumns;
import org.apache.iceberg.Table;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.NamedReference;
@@ -81,7 +80,6 @@ class SparkCopyOnWriteOperation implements RowLevelOperation {
@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
if (lazyWriteBuilder == null) {
- Preconditions.checkState(configuredScan != null, "Write must be
configured after scan");
SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table,
info);
lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command,
isolationLevel);
}
diff --git
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 0b19fec9fd..f77d96da7f 100644
---
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -383,7 +383,11 @@ abstract class SparkWrite implements Write,
RequiresDistributionAndOrdering {
}
private List<DataFile> overwrittenFiles() {
- return
scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+ if (scan == null) {
+ return ImmutableList.of();
+ } else {
+ return
scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+ }
}
private Expression conflictDetectionFilter() {
@@ -415,12 +419,21 @@ abstract class SparkWrite implements Write,
RequiresDistributionAndOrdering {
overwriteFiles.addFile(file);
}
- if (isolationLevel == SERIALIZABLE) {
- commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles,
numAddedFiles);
- } else if (isolationLevel == SNAPSHOT) {
- commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles,
numAddedFiles);
+ // the scan may be null if the optimizer replaces it with an empty
relation (e.g. false cond)
+ // no validation is needed in this case as the command does not depend
on the table state
+ if (scan != null) {
+ if (isolationLevel == SERIALIZABLE) {
+ commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles,
numAddedFiles);
+ } else if (isolationLevel == SNAPSHOT) {
+ commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles,
numAddedFiles);
+ } else {
+ throw new IllegalArgumentException("Unsupported isolation level: " +
isolationLevel);
+ }
+
} else {
- throw new IllegalArgumentException("Unsupported isolation level: " +
isolationLevel);
+ commitOperation(
+ overwriteFiles,
+ String.format("overwrite with %d new data files (no validation)",
numAddedFiles));
}
}
diff --git
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
index 6483f13048..55cf7961e9 100644
---
a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
+++
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
@@ -86,8 +86,6 @@ class SparkWriteBuilder implements WriteBuilder,
SupportsDynamicOverwrite, Suppo
}
public WriteBuilder overwriteFiles(Scan scan, Command command,
IsolationLevel isolationLevel) {
- Preconditions.checkArgument(
- scan instanceof SparkCopyOnWriteScan, "%s is not
SparkCopyOnWriteScan", scan);
Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual
files and by filter");
Preconditions.checkState(
!overwriteDynamic, "Cannot overwrite individual files and
dynamically");