This is an automated email from the ASF dual-hosted git repository.
szehon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new a47937c0c1 Spark 3.5: Support Aggregate push down for incremental scan (#10538)
a47937c0c1 is described below
commit a47937c0c1fcafe57d7dc83551d8c9a3ce0ab1b9
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Jun 21 10:34:13 2024 -0700
Spark 3.5: Support Aggregate push down for incremental scan (#10538)
---
.../iceberg/spark/source/SparkScanBuilder.java | 89 +++++++++-------------
.../spark/source/TestDataSourceOptions.java | 32 ++++++--
.../iceberg/spark/sql/TestAggregatePushDown.java | 50 ++++++++++++
3 files changed, 109 insertions(+), 62 deletions(-)
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index d6f34231ae..b430e6fca2 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -37,7 +37,6 @@ import org.apache.iceberg.SparkDistributedDataScan;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.TableScan;
import org.apache.iceberg.expressions.AggregateEvaluator;
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.BoundAggregate;
@@ -232,15 +231,8 @@ public class SparkScanBuilder
return false;
}
- TableScan scan = table.newScan().includeColumnStats();
- Snapshot snapshot = readSnapshot();
- if (snapshot == null) {
- LOG.info("Skipping aggregate pushdown: table snapshot is null");
- return false;
- }
- scan = scan.useSnapshot(snapshot.snapshotId());
- scan = configureSplitPlanning(scan);
- scan = scan.filter(filterExpression());
+ org.apache.iceberg.Scan scan =
+ buildIcebergBatchScan(true /* include Column Stats */,
schemaWithMetadataColumns());
try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
@@ -282,11 +274,6 @@ public class SparkScanBuilder
return false;
}
- if (readConf.startSnapshotId() != null) {
- LOG.info("Skipping aggregate pushdown: incremental scan is not supported");
- return false;
- }
-
// If group by expression is the same as the partition, the statistics information can still
// be used to calculate min/max/count, will enable aggregate push down in next phase.
// TODO: enable aggregate push down for partition col group by expression
@@ -298,17 +285,6 @@ public class SparkScanBuilder
return true;
}
- private Snapshot readSnapshot() {
- Snapshot snapshot;
- if (readConf.snapshotId() != null) {
- snapshot = table.snapshot(readConf.snapshotId());
- } else {
- snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch());
- }
-
- return snapshot;
- }
-
private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?,
?>> aggregates) {
MetricsConfig config = MetricsConfig.forTable(table);
for (BoundAggregate aggregate : aggregates) {
@@ -387,6 +363,18 @@ public class SparkScanBuilder
}
private Scan buildBatchScan() {
+ Schema expectedSchema = schemaWithMetadataColumns();
+ return new SparkBatchQueryScan(
+ spark,
+ table,
+ buildIcebergBatchScan(false /* not include Column Stats */,
expectedSchema),
+ readConf,
+ expectedSchema,
+ filterExpressions,
+ metricsReporter::scanReport);
+ }
+
+ private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats,
Schema expectedSchema) {
Long snapshotId = readConf.snapshotId();
Long asOfTimestamp = readConf.asOfTimestamp();
String branch = readConf.branch();
@@ -427,15 +415,19 @@ public class SparkScanBuilder
SparkReadOptions.END_TIMESTAMP);
if (startSnapshotId != null) {
- return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
+ return buildIncrementalAppendScan(startSnapshotId, endSnapshotId,
withStats, expectedSchema);
} else {
- return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
+ return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats,
expectedSchema);
}
}
- private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String
branch, String tag) {
- Schema expectedSchema = schemaWithMetadataColumns();
-
+ private org.apache.iceberg.Scan buildBatchScan(
+ Long snapshotId,
+ Long asOfTimestamp,
+ String branch,
+ String tag,
+ boolean withStats,
+ Schema expectedSchema) {
BatchScan scan =
newBatchScan()
.caseSensitive(caseSensitive)
@@ -443,6 +435,10 @@ public class SparkScanBuilder
.project(expectedSchema)
.metricsReporter(metricsReporter);
+ if (withStats) {
+ scan = scan.includeColumnStats();
+ }
+
if (snapshotId != null) {
scan = scan.useSnapshot(snapshotId);
}
@@ -459,21 +455,11 @@ public class SparkScanBuilder
scan = scan.useRef(tag);
}
- scan = configureSplitPlanning(scan);
-
- return new SparkBatchQueryScan(
- spark,
- table,
- scan,
- readConf,
- expectedSchema,
- filterExpressions,
- metricsReporter::scanReport);
+ return configureSplitPlanning(scan);
}
- private Scan buildIncrementalAppendScan(long startSnapshotId, Long
endSnapshotId) {
- Schema expectedSchema = schemaWithMetadataColumns();
-
+ private org.apache.iceberg.Scan buildIncrementalAppendScan(
+ long startSnapshotId, Long endSnapshotId, boolean withStats, Schema
expectedSchema) {
IncrementalAppendScan scan =
table
.newIncrementalAppendScan()
@@ -483,20 +469,15 @@ public class SparkScanBuilder
.project(expectedSchema)
.metricsReporter(metricsReporter);
+ if (withStats) {
+ scan = scan.includeColumnStats();
+ }
+
if (endSnapshotId != null) {
scan = scan.toSnapshot(endSnapshotId);
}
- scan = configureSplitPlanning(scan);
-
- return new SparkBatchQueryScan(
- spark,
- table,
- scan,
- readConf,
- expectedSchema,
- filterExpressions,
- metricsReporter::scanReport);
+ return configureSplitPlanning(scan);
}
@SuppressWarnings("CyclomaticComplexity")
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
index ff6ddea323..627fe15f28 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
@@ -57,6 +57,7 @@ import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
@@ -290,29 +291,44 @@ public class TestDataSourceOptions extends TestBaseWithCatalog {
"Cannot set only end-snapshot-id for incremental scans. Please, set start-snapshot-id too.");
// test (1st snapshot, current snapshot] incremental scan.
- List<SimpleRecord> result =
+ Dataset<Row> unboundedIncrementalResult =
spark
.read()
.format("iceberg")
.option("start-snapshot-id", snapshotIds.get(3).toString())
- .load(tableLocation)
+ .load(tableLocation);
+ List<SimpleRecord> result1 =
+ unboundedIncrementalResult
.orderBy("id")
.as(Encoders.bean(SimpleRecord.class))
.collectAsList();
- assertThat(result).as("Records should
match").isEqualTo(expectedRecords.subList(1, 4));
+ assertThat(result1).as("Records should
match").isEqualTo(expectedRecords.subList(1, 4));
+ assertThat(unboundedIncrementalResult.count())
+ .as("Unprocessed count should match record count")
+ .isEqualTo(3);
+
+ Row row1 = unboundedIncrementalResult.agg(functions.min("id"),
functions.max("id")).head();
+ assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2);
+ assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4);
// test (2nd snapshot, 3rd snapshot] incremental scan.
- Dataset<Row> resultDf =
+ Dataset<Row> incrementalResult =
spark
.read()
.format("iceberg")
.option("start-snapshot-id", snapshotIds.get(2).toString())
.option("end-snapshot-id", snapshotIds.get(1).toString())
.load(tableLocation);
- List<SimpleRecord> result1 =
-
resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
- assertThat(result1).as("Records should
match").isEqualTo(expectedRecords.subList(2, 3));
- assertThat(resultDf.count()).as("Unprocessed count should match record
count").isEqualTo(1);
+ List<SimpleRecord> result2 =
+
incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+ assertThat(result2).as("Records should
match").isEqualTo(expectedRecords.subList(2, 3));
+ assertThat(incrementalResult.count())
+ .as("Unprocessed count should match record count")
+ .isEqualTo(1);
+
+ Row row2 = incrementalResult.agg(functions.min("id"),
functions.max("id")).head();
+ assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3);
+ assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3);
}
@TestTemplate
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 05515946c1..7e9bdeec8a 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -35,8 +35,13 @@ import org.apache.iceberg.hive.TestHiveMetastore;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.TestBase;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.sql.functions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
@@ -808,4 +813,49 @@ public class TestAggregatePushDown extends CatalogTestBase {
});
assertEquals("min/max/count push down", expected, actual);
}
+
+ @TestTemplate
+ public void testAggregatePushDownForIncrementalScan() {
+ sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+ sql(
+ "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ",
+ tableName);
+ long snapshotId1 =
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+ sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName);
+ long snapshotId2 =
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+ sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName);
+ long snapshotId3 =
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+ sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName);
+
+ Dataset<Row> pushdownDs =
+ spark
+ .read()
+ .format("iceberg")
+ .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2)
+ .option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3)
+ .load(tableName)
+ .agg(functions.min("data"), functions.max("data"),
functions.count("data"));
+ String explain1 =
pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+ assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)",
"count(data)");
+
+ List<Object[]> expected1 = Lists.newArrayList();
+ expected1.add(new Object[] {-7777, 8888, 2L});
+ assertEquals("min/max/count push down", expected1,
rowsToJava(pushdownDs.collectAsList()));
+
+ Dataset<Row> unboundedPushdownDs =
+ spark
+ .read()
+ .format("iceberg")
+ .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
+ .load(tableName)
+ .agg(functions.min("data"), functions.max("data"),
functions.count("data"));
+ String explain2 =
+
unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+ assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)",
"count(data)");
+
+ List<Object[]> expected2 = Lists.newArrayList();
+ expected2.add(new Object[] {-7777, 9999, 6L});
+ assertEquals(
+ "min/max/count push down", expected2,
rowsToJava(unboundedPushdownDs.collectAsList()));
+ }
}