rdblue commented on a change in pull request #2022:
URL: https://github.com/apache/iceberg/pull/2022#discussion_r561408143
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -215,5 +227,41 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
}
!(actions.size == 1 && hasUnconditionalDelete(actions.headOption))
}
+
+ def buildWritePlan(
+ childPlan: LogicalPlan,
+ table: Table): LogicalPlan = {
+ table match {
+ case iceTable: SparkTable =>
+ val numShufflePartitions = spark.sessionState.conf.numShufflePartitions
+ val table = iceTable.table()
+ val distributionMode: String = table.properties
+ .getOrDefault("write.distribution-mode", TableProperties.WRITE_DISTRIBUTION_MODE_RANGE)
+ val order = toCatalyst(toOrderedDistribution(table.spec(), table.sortOrder(), true), childPlan)
+ distributionMode.toLowerCase(Locale.ROOT) match {
+ case TableProperties.WRITE_DISTRIBUTION_MODE_DEFAULT =>
Review comment:
We should rename this to `WRITE_DISTRIBUTION_MODE_NONE` in a follow-up,
since the default depends on the engine. We can also add
`WRITE_DISTRIBUTION_MODE_FLINK_DEFAULT` and
`WRITE_DISTRIBUTION_MODE_SPARK_DEFAULT`.
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -215,5 +227,41 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
}
!(actions.size == 1 && hasUnconditionalDelete(actions.headOption))
}
+
+ def buildWritePlan(
+ childPlan: LogicalPlan,
+ table: Table): LogicalPlan = {
+ table match {
+ case iceTable: SparkTable =>
+ val numShufflePartitions = spark.sessionState.conf.numShufflePartitions
+ val table = iceTable.table()
+ val distributionMode: String = table.properties
+ .getOrDefault("write.distribution-mode", TableProperties.WRITE_DISTRIBUTION_MODE_RANGE)
+ val order = toCatalyst(toOrderedDistribution(table.spec(), table.sortOrder(), true), childPlan)
+ distributionMode.toLowerCase(Locale.ROOT) match {
Review comment:
The Flink commit also added a Java enum for this. We could use that
instead of string matching here. It handles case-insensitive mapping, too:
`DistributionMode.fromName(modeName)`
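
For illustration, a minimal Scala sketch of that suggestion (assuming the enum is `org.apache.iceberg.DistributionMode` with `NONE`/`HASH`/`RANGE` values; the `parseMode` helper and the commented-out match are only illustrative, since the case bodies are elided from this diff):

```scala
// Sketch only: parse the table property with the enum instead of matching on strings.
// DistributionMode.fromName already handles case-insensitive names, so the
// toLowerCase(Locale.ROOT) call above would no longer be needed.
import org.apache.iceberg.DistributionMode

def parseMode(modeName: String): DistributionMode = DistributionMode.fromName(modeName)

// Inside buildWritePlan this would replace the string match (case bodies elided here
// because they are not shown in this diff):
//   parseMode(distributionMode) match {
//     case DistributionMode.NONE  => ...
//     case DistributionMode.HASH  => ...
//     case DistributionMode.RANGE => ...
//   }
```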
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala
##########
@@ -21,13 +21,12 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
case class MergeInto(
mergeIntoProcessor: MergeIntoParams,
- targetRelation: DataSourceV2Relation,
+ targetOutput: Seq[Attribute],
Review comment:
I think this could just be `output` and you wouldn't need to override
`def output` below.
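
A minimal sketch of that rename, assuming `MergeInto` stays a unary logical node as in this PR (`MergeIntoParams` is the params class the PR already defines; its import is omitted here):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Naming the constructor field `output` implements QueryPlan's abstract `output`
// directly, so the separate `def output` override below can be dropped.
case class MergeInto(
    mergeIntoProcessor: MergeIntoParams,
    output: Seq[Attribute],
    child: LogicalPlan) extends UnaryNode
```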
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala
##########
@@ -32,10 +32,10 @@ import org.apache.spark.sql.execution.UnaryExecNode
case class MergeIntoExec(
mergeIntoParams: MergeIntoParams,
- @transient targetRelation: DataSourceV2Relation,
+ targetOutput: Seq[Attribute],
Review comment:
Same here, using `output` would make the method definition unnecessary.
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
##########
@@ -81,8 +81,8 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy {
case ReplaceData(_, batchWrite, query) =>
ReplaceDataExec(batchWrite, planLater(query)) :: Nil
- case MergeInto(mergeIntoProcessor, targetRelation, child) =>
- MergeIntoExec(mergeIntoProcessor, targetRelation, planLater(child)) :: Nil
+ case MergeInto(mergeIntoParms, targetAttributes, child) =>
+ MergeIntoExec(mergeIntoParms, targetAttributes, planLater(child)) :: Nil
Review comment:
I think it would be clearer to use `output` instead of
`targetAttributes` here since that's what this is setting, but this is minor.
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala
##########
@@ -194,4 +220,84 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {
}
}
}
+
+ private object BucketTransform {
+ def unapply(transform: Transform): Option[(Int, FieldReference)] = transform match {
+ case bt: BucketTransform => bt.columns match {
+ case Seq(nf: NamedReference) =>
+ Some(bt.numBuckets.value(), FieldReference(nf.fieldNames()))
+ case _ =>
+ None
+ }
+ case _ => None
+ }
+ }
+
+ protected def toCatalyst(
+ distribution: Distribution,
+ plan: LogicalPlan): Seq[catalyst.expressions.Expression] = {
+
+ distribution match {
+ case d: OrderedDistribution =>
+ d.ordering.map(e => toCatalyst(e, plan, resolver))
+ case d: ClusteredDistribution =>
+ d.clustering.map(e => toCatalyst(e, plan, resolver))
+ case _: UnspecifiedDistribution =>
+ Array.empty[catalyst.expressions.Expression]
+ }
+ }
+
+ private def toCatalyst(
+ expr: Expression,
+ query: LogicalPlan,
+ resolver: Resolver): catalyst.expressions.Expression = {
+
+ def resolve(parts: Seq[String]): NamedExpression = {
+ // this part is controversial as we perform resolution in the optimizer
+ // we cannot perform this step in the analyzer since we need to optimize expressions
+ // in nodes like OverwriteByExpression before constructing a logical write
+ query.resolve(parts, resolver) match {
+ case Some(attr) => attr
+ case None => throw new AnalysisException(s"Cannot resolve '${parts.map(quoteIfNeeded).mkString(".")}'" +
+ s" using ${query.output}")
+ }
+ }
+
+ expr match {
+ case s: SortOrder =>
+ val catalystChild = toCatalyst(s.expression(), query, resolver)
+ catalyst.expressions.SortOrder(catalystChild, toCatalyst(s.direction), toCatalyst(s.nullOrdering), Set.empty)
+ case it: IdentityTransform =>
+ resolve(it.ref.fieldNames())
+ case BucketTransform(numBuckets, ref) =>
+ IcebergBucketTransform(numBuckets, resolve(ref.fieldNames))
+ case yt: YearsTransform =>
+ IcebergYearTransform(resolve(yt.ref.fieldNames))
+ case mt: MonthsTransform =>
+ IcebergMonthTransform(resolve(mt.ref.fieldNames))
+ case dt: DaysTransform =>
+ IcebergDayTransform(resolve(dt.ref.fieldNames))
+ case ht: HoursTransform =>
+ IcebergHourTransform(resolve(ht.ref.fieldNames))
+ case ref: FieldReference =>
+ resolve(ref.fieldNames)
+ case _ =>
+ throw new RuntimeException(s"$expr is not currently supported")
+ }
+ }
+
Review comment:
Nit: extra newline.
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -215,5 +227,41 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
}
!(actions.size == 1 && hasUnconditionalDelete(actions.headOption))
}
+
+ def buildWritePlan(
+ childPlan: LogicalPlan,
+ table: Table): LogicalPlan = {
+ table match {
+ case iceTable: SparkTable =>
+ val numShufflePartitions = spark.sessionState.conf.numShufflePartitions
+ val table = iceTable.table()
+ val distributionMode: String = table.properties
+ .getOrDefault("write.distribution-mode",
TableProperties.WRITE_DISTRIBUTION_MODE_RANGE)
Review comment:
Isn't `write.distribution-mode` listed in `TableProperties`? We should
use the constant.
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -303,25 +305,107 @@ public void testSingleUnconditionalDeleteDisbleCountCheck() throws NoSuchTableEx
}
@Test
- public void testSingleConditionalDeleteCountCheck() throws NoSuchTableException {
Review comment:
Looks like this test case was accidentally deleted?
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -303,25 +305,107 @@ public void testSingleUnconditionalDeleteDisbleCountCheck() throws NoSuchTableEx
}
@Test
- public void testSingleConditionalDeleteCountCheck() throws NoSuchTableException {
- append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
- append(sourceName, new Employee(1, "emp-id-1"), new Employee(1, "emp-id-1"),
- new Employee(2, "emp-id-2"), new Employee(6, "emp-id-6"));
+ public void testIdentityPartition() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (identity(dep))", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
Review comment:
Minor: this passes an extra empty string argument and embeds the table names
in the SQL text instead of filling them in through `%s` placeholders.
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -355,9 +439,17 @@ private void initTable(String tabName) {
});
}
- protected void append(String tabName, Employee... employees) throws NoSuchTableException {
- List<Employee> input = Arrays.asList(employees);
- Dataset<Row> inputDF = spark.createDataFrame(input, Employee.class);
- inputDF.coalesce(1).writeTo(tabName).append();
+ private void setWriteMode(String tabName, String mode) {
Review comment:
I think this should be `setDistributionMode` instead because
`setWriteMode` sounds more general, like "copy-on-write" or "merge-on-read".
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -303,25 +305,107 @@ public void testSingleUnconditionalDeleteDisbleCountCheck() throws NoSuchTableEx
}
@Test
- public void testSingleConditionalDeleteCountCheck() throws NoSuchTableException {
- append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
- append(sourceName, new Employee(1, "emp-id-1"), new Employee(1, "emp-id-1"),
- new Employee(2, "emp-id-2"), new Employee(6, "emp-id-6"));
+ public void testIdentityPartition() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (identity(dep))", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
- String sqlText = "MERGE INTO %s AS target " +
- "USING %s AS source " +
- "ON target.id = source.id " +
- "WHEN MATCHED AND target.id = 1 THEN DELETE " +
- "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+ @Test
+ public void testDaysTransform() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg PARTITIONED BY
(days(ts))", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg", sourceName);
+ initTable(sourceName);
+ sql("INSERT INTO " + targetName + " VALUES (1, timestamp('2001-01-01
00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+ sql("INSERT INto " + sourceName + " VALUES (2, timestamp('2001-01-02
00:00:00'))," +
+ "(1, timestamp('2001-01-01 00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "2001-01-01 00:00:00"), row(2,
"2001-01-02 00:00:00")),
+ sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id ASC NULLS
LAST", targetName));
+ });
+ }
- String tabName = catalogName + "." + "default.target";
- String errorMsg = "The same row of target table `" + tabName + "` was
identified more than\n" +
- " once for an update, delete or insert operation of the MERGE
statement.";
- AssertHelpers.assertThrows("Should complain ambiguous row in target",
- SparkException.class, errorMsg, () -> sql(sqlText, targetName,
sourceName));
- assertEquals("Target should be unchanged",
- ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")),
- sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ @Test
+ public void testBucketExpression() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg" +
+ " CLUSTERED BY (dep) INTO 2 BUCKETS", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
+
+ @Test
+ public void testPartitionedAndOrderedTable() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg" +
+ " PARTITIONED BY (id) CLUSTERED BY (dep) INTO 2 BUCKETS",
targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
Review comment:
Nit: Indentation is off in the new methods. Should be 2 indents or 4
spaces for continuations.
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -303,25 +305,107 @@ public void testSingleUnconditionalDeleteDisbleCountCheck() throws NoSuchTableEx
}
@Test
- public void testSingleConditionalDeleteCountCheck() throws NoSuchTableException {
- append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
- append(sourceName, new Employee(1, "emp-id-1"), new Employee(1, "emp-id-1"),
- new Employee(2, "emp-id-2"), new Employee(6, "emp-id-6"));
+ public void testIdentityPartition() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (identity(dep))", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
- String sqlText = "MERGE INTO %s AS target " +
- "USING %s AS source " +
- "ON target.id = source.id " +
- "WHEN MATCHED AND target.id = 1 THEN DELETE " +
- "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+ @Test
+ public void testDaysTransform() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg PARTITIONED BY
(days(ts))", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg", sourceName);
+ initTable(sourceName);
+ sql("INSERT INTO " + targetName + " VALUES (1, timestamp('2001-01-01
00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+ sql("INSERT INto " + sourceName + " VALUES (2, timestamp('2001-01-02
00:00:00'))," +
+ "(1, timestamp('2001-01-01 00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "2001-01-01 00:00:00"), row(2,
"2001-01-02 00:00:00")),
+ sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id ASC NULLS
LAST", targetName));
+ });
+ }
- String tabName = catalogName + "." + "default.target";
- String errorMsg = "The same row of target table `" + tabName + "` was
identified more than\n" +
- " once for an update, delete or insert operation of the MERGE
statement.";
- AssertHelpers.assertThrows("Should complain ambiguous row in target",
- SparkException.class, errorMsg, () -> sql(sqlText, targetName,
sourceName));
- assertEquals("Target should be unchanged",
- ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")),
- sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ @Test
+ public void testBucketExpression() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg" +
+ " CLUSTERED BY (dep) INTO 2 BUCKETS", targetName);
+ initTable(targetName);
+ setWriteMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
+
+ @Test
+ public void testPartitionedAndOrderedTable() {
Review comment:
Where does this set the table ordering? I would expect it to run `ALTER
TABLE %s WRITE ORDERED BY ...`
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -85,7 +96,7 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
joinedAttributes = joinPlan.output
)
- val mergePlan = MergeInto(mergeParams, target, joinPlan)
+ val mergePlan = MergeInto(mergeParams, target.output, joinPlan)
Review comment:
Why not sort this case as well? If the user has requested a sort order
on the table, it makes sense to enforce it even if we aren't also rewriting
files.
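
Roughly, a sketch of that suggestion (reusing the PR's own `buildWritePlan` and the identifiers already in scope in this branch; whether the non-rewrite path should also repartition is up to the author):

```scala
val mergePlan = MergeInto(mergeParams, target.output, joinPlan)
// Route this branch through buildWritePlan as well, so the table's sort order
// (if any) is applied even when files are not being rewritten.
val writePlan = buildWritePlan(mergePlan, target.table)
// ...then hand writePlan, rather than mergePlan, to the write that follows.
```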
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -142,10 +154,10 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
targetOutput = target.output,
joinedAttributes = joinPlan.output
)
- val mergePlan = MergeInto(mergeParams, target, joinPlan)
+ val mergePlan = MergeInto(mergeParams, target.output, joinPlan)
+ val writePlan = buildWritePlan(mergePlan, target.table)
val batchWrite = mergeBuilder.asWriteBuilder.buildForBatch()
-
- ReplaceData(target, batchWrite, mergePlan)
+ ReplaceData(target, batchWrite, writePlan)
Review comment:
Nit: newline was removed, which is a whitespace change.
##########
File path:
spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala
##########
@@ -215,5 +227,42 @@ case class RewriteMergeInto(spark: SparkSession) extends Rule[LogicalPlan] with
}
!(actions.size == 1 && hasUnconditionalDelete(actions.headOption))
}
+
+ def buildWritePlan(
+ childPlan: LogicalPlan,
+ table: Table): LogicalPlan = {
+ table match {
+ case iceTable: SparkTable =>
+ val numShufflePartitions = spark.sessionState.conf.numShufflePartitions
+ val table = iceTable.table()
+ val distributionMode: String = table.properties
+ .getOrDefault(TableProperties.WRITE_DISTRIBUTION_MODE, TableProperties.WRITE_DISTRIBUTION_MODE_RANGE)
+ val mode = DistributionMode.fromName(distributionMode)
+ val order = toCatalyst(toOrderedDistribution(table.spec(), table.sortOrder(), true), childPlan)
+ mode match {
Review comment:
Minor: The `mode` variable isn't really needed. It could be
`DistributionMode.fromName(distributionMode) match { ... }`
##########
File path:
spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java
##########
@@ -324,6 +326,111 @@ public void testSingleConditionalDeleteCountCheck() throws NoSuchTableException
sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
}
+ @Test
+ public void testIdentityPartition() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY
(identity(dep))", targetName);
+ initTable(targetName);
+ setDistributionMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6,
"emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1,
"emp-id-1"), new Employee(6, "emp-id-6"));
+
+ String sqlText = "MERGE INTO %s AS target " +
+ "USING %s AS source " +
+ "ON target.id = source.id " +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE " +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, targetName, sourceName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
+
+ @Test
+ public void testDaysTransform() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg PARTITIONED BY
(days(ts))", targetName);
+ initTable(targetName);
+ setDistributionMode(targetName, mode);
+ sql("CREATE TABLE %s (id INT, ts timestamp) USING iceberg", sourceName);
+ initTable(sourceName);
+ sql("INSERT INTO " + targetName + " VALUES (1, timestamp('2001-01-01
00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+ sql("INSERT INto " + sourceName + " VALUES (2, timestamp('2001-01-02
00:00:00'))," +
+ "(1, timestamp('2001-01-01 00:00:00'))," +
+ "(6, timestamp('2001-01-06 00:00:00'))");
+
+ String sqlText = "MERGE INTO %s AS target \n" +
+ "USING %s AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, targetName, sourceName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "2001-01-01 00:00:00"), row(2,
"2001-01-02 00:00:00")),
+ sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id ASC NULLS
LAST", targetName));
+ });
+ }
+
+ @Test
+ public void testBucketExpression() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg" +
+ " CLUSTERED BY (dep) INTO 2 BUCKETS", targetName);
+ initTable(targetName);
+ setDistributionMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+ String sqlText = "MERGE INTO %s AS target \n" +
+ "USING %s AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, targetName, sourceName);
+ assertEquals("Should have expected rows",
+ ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName));
+ });
+ }
+
+ @Test
+ public void testPartitionedAndOrderedTable() {
+ writeModes.forEach(mode -> {
+ removeTables();
+ sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg" +
+ " PARTITIONED BY (id)", targetName);
+ sql("ALTER TABLE %s WRITE ORDERED BY (dep)", targetName);
+ initTable(targetName);
+ setDistributionMode(targetName, mode);
+ createAndInitSourceTable(sourceName);
+ append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6"));
+ append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6"));
+ String sqlText = "MERGE INTO " + targetName + " AS target \n" +
+ "USING " + sourceName + " AS source \n" +
+ "ON target.id = source.id \n" +
+ "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" +
+ "WHEN MATCHED AND target.id = 6 THEN DELETE \n" +
+ "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * ";
+
+ sql(sqlText, "");
Review comment:
Nit: use of "" instead of filling in table names in this test as well.