This is an automated email from the ASF dual-hosted git repository.

szehon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new a47937c0c1 Spark 3.5: Support Aggregate push down for incremental scan (#10538)
a47937c0c1 is described below

commit a47937c0c1fcafe57d7dc83551d8c9a3ce0ab1b9
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Jun 21 10:34:13 2024 -0700

    Spark 3.5: Support Aggregate push down for incremental scan (#10538)
---
 .../iceberg/spark/source/SparkScanBuilder.java     | 89 +++++++++-------------
 .../spark/source/TestDataSourceOptions.java        | 32 ++++++--
 .../iceberg/spark/sql/TestAggregatePushDown.java   | 50 ++++++++++++
 3 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index d6f34231ae..b430e6fca2 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -37,7 +37,6 @@ import org.apache.iceberg.SparkDistributedDataScan;
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.TableScan;
 import org.apache.iceberg.expressions.AggregateEvaluator;
 import org.apache.iceberg.expressions.Binder;
 import org.apache.iceberg.expressions.BoundAggregate;
@@ -232,15 +231,8 @@ public class SparkScanBuilder
       return false;
     }
 
-    TableScan scan = table.newScan().includeColumnStats();
-    Snapshot snapshot = readSnapshot();
-    if (snapshot == null) {
-      LOG.info("Skipping aggregate pushdown: table snapshot is null");
-      return false;
-    }
-    scan = scan.useSnapshot(snapshot.snapshotId());
-    scan = configureSplitPlanning(scan);
-    scan = scan.filter(filterExpression());
+    org.apache.iceberg.Scan scan =
+        buildIcebergBatchScan(true /* include Column Stats */, 
schemaWithMetadataColumns());
 
     try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
       List<FileScanTask> tasks = ImmutableList.copyOf(fileScanTasks);
@@ -282,11 +274,6 @@ public class SparkScanBuilder
       return false;
     }
 
-    if (readConf.startSnapshotId() != null) {
-      LOG.info("Skipping aggregate pushdown: incremental scan is not 
supported");
-      return false;
-    }
-
     // If group by expression is the same as the partition, the statistics 
information can still
     // be used to calculate min/max/count, will enable aggregate push down in 
next phase.
     // TODO: enable aggregate push down for partition col group by expression
@@ -298,17 +285,6 @@ public class SparkScanBuilder
     return true;
   }
 
-  private Snapshot readSnapshot() {
-    Snapshot snapshot;
-    if (readConf.snapshotId() != null) {
-      snapshot = table.snapshot(readConf.snapshotId());
-    } else {
-      snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch());
-    }
-
-    return snapshot;
-  }
-
   private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?, 
?>> aggregates) {
     MetricsConfig config = MetricsConfig.forTable(table);
     for (BoundAggregate aggregate : aggregates) {
@@ -387,6 +363,18 @@ public class SparkScanBuilder
   }
 
   private Scan buildBatchScan() {
+    Schema expectedSchema = schemaWithMetadataColumns();
+    return new SparkBatchQueryScan(
+        spark,
+        table,
+        buildIcebergBatchScan(false /* not include Column Stats */, 
expectedSchema),
+        readConf,
+        expectedSchema,
+        filterExpressions,
+        metricsReporter::scanReport);
+  }
+
+  private org.apache.iceberg.Scan buildIcebergBatchScan(boolean withStats, 
Schema expectedSchema) {
     Long snapshotId = readConf.snapshotId();
     Long asOfTimestamp = readConf.asOfTimestamp();
     String branch = readConf.branch();
@@ -427,15 +415,19 @@ public class SparkScanBuilder
         SparkReadOptions.END_TIMESTAMP);
 
     if (startSnapshotId != null) {
-      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
+      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId, 
withStats, expectedSchema);
     } else {
-      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
+      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag, withStats, 
expectedSchema);
     }
   }
 
-  private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String 
branch, String tag) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildBatchScan(
+      Long snapshotId,
+      Long asOfTimestamp,
+      String branch,
+      String tag,
+      boolean withStats,
+      Schema expectedSchema) {
     BatchScan scan =
         newBatchScan()
             .caseSensitive(caseSensitive)
@@ -443,6 +435,10 @@ public class SparkScanBuilder
             .project(expectedSchema)
             .metricsReporter(metricsReporter);
 
+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (snapshotId != null) {
       scan = scan.useSnapshot(snapshotId);
     }
@@ -459,21 +455,11 @@ public class SparkScanBuilder
       scan = scan.useRef(tag);
     }
 
-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }
 
-  private Scan buildIncrementalAppendScan(long startSnapshotId, Long 
endSnapshotId) {
-    Schema expectedSchema = schemaWithMetadataColumns();
-
+  private org.apache.iceberg.Scan buildIncrementalAppendScan(
+      long startSnapshotId, Long endSnapshotId, boolean withStats, Schema 
expectedSchema) {
     IncrementalAppendScan scan =
         table
             .newIncrementalAppendScan()
@@ -483,20 +469,15 @@ public class SparkScanBuilder
             .project(expectedSchema)
             .metricsReporter(metricsReporter);
 
+    if (withStats) {
+      scan = scan.includeColumnStats();
+    }
+
     if (endSnapshotId != null) {
       scan = scan.toSnapshot(endSnapshotId);
     }
 
-    scan = configureSplitPlanning(scan);
-
-    return new SparkBatchQueryScan(
-        spark,
-        table,
-        scan,
-        readConf,
-        expectedSchema,
-        filterExpressions,
-        metricsReporter::scanReport);
+    return configureSplitPlanning(scan);
   }
 
   @SuppressWarnings("CyclomaticComplexity")
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
index ff6ddea323..627fe15f28 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestDataSourceOptions.java
@@ -57,6 +57,7 @@ import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -290,29 +291,44 @@ public class TestDataSourceOptions extends 
TestBaseWithCatalog {
             "Cannot set only end-snapshot-id for incremental scans. Please, 
set start-snapshot-id too.");
 
     // test (1st snapshot, current snapshot] incremental scan.
-    List<SimpleRecord> result =
+    Dataset<Row> unboundedIncrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(3).toString())
-            .load(tableLocation)
+            .load(tableLocation);
+    List<SimpleRecord> result1 =
+        unboundedIncrementalResult
             .orderBy("id")
             .as(Encoders.bean(SimpleRecord.class))
             .collectAsList();
-    assertThat(result).as("Records should 
match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(result1).as("Records should 
match").isEqualTo(expectedRecords.subList(1, 4));
+    assertThat(unboundedIncrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(3);
+
+    Row row1 = unboundedIncrementalResult.agg(functions.min("id"), 
functions.max("id")).head();
+    assertThat(row1.getInt(0)).as("min value should match").isEqualTo(2);
+    assertThat(row1.getInt(1)).as("max value should match").isEqualTo(4);
 
     // test (2nd snapshot, 3rd snapshot] incremental scan.
-    Dataset<Row> resultDf =
+    Dataset<Row> incrementalResult =
         spark
             .read()
             .format("iceberg")
             .option("start-snapshot-id", snapshotIds.get(2).toString())
             .option("end-snapshot-id", snapshotIds.get(1).toString())
             .load(tableLocation);
-    List<SimpleRecord> result1 =
-        
resultDf.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
-    assertThat(result1).as("Records should 
match").isEqualTo(expectedRecords.subList(2, 3));
-    assertThat(resultDf.count()).as("Unprocessed count should match record 
count").isEqualTo(1);
+    List<SimpleRecord> result2 =
+        
incrementalResult.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
+    assertThat(result2).as("Records should 
match").isEqualTo(expectedRecords.subList(2, 3));
+    assertThat(incrementalResult.count())
+        .as("Unprocessed count should match record count")
+        .isEqualTo(1);
+
+    Row row2 = incrementalResult.agg(functions.min("id"), 
functions.max("id")).head();
+    assertThat(row2.getInt(0)).as("min value should match").isEqualTo(3);
+    assertThat(row2.getInt(1)).as("max value should match").isEqualTo(3);
   }
 
   @TestTemplate
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 05515946c1..7e9bdeec8a 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -35,8 +35,13 @@ import org.apache.iceberg.hive.TestHiveMetastore;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.TestBase;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestTemplate;
@@ -808,4 +813,49 @@ public class TestAggregatePushDown extends CatalogTestBase 
{
         });
     assertEquals("min/max/count push down", expected, actual);
   }
+
+  @TestTemplate
+  public void testAggregatePushDownForIncrementalScan() {
+    sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 
4444), (3, 5555), (3, 6666) ",
+        tableName);
+    long snapshotId1 = 
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (4, 7777), (5, 8888)", tableName);
+    long snapshotId2 = 
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (6, -7777), (7, 8888)", tableName);
+    long snapshotId3 = 
validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    sql("INSERT INTO %s VALUES (8, 7777), (9, 9999)", tableName);
+
+    Dataset<Row> pushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId2)
+            .option(SparkReadOptions.END_SNAPSHOT_ID, snapshotId3)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), 
functions.count("data"));
+    String explain1 = 
pushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain1).contains("LocalTableScan", "min(data)", "max(data)", 
"count(data)");
+
+    List<Object[]> expected1 = Lists.newArrayList();
+    expected1.add(new Object[] {-7777, 8888, 2L});
+    assertEquals("min/max/count push down", expected1, 
rowsToJava(pushdownDs.collectAsList()));
+
+    Dataset<Row> unboundedPushdownDs =
+        spark
+            .read()
+            .format("iceberg")
+            .option(SparkReadOptions.START_SNAPSHOT_ID, snapshotId1)
+            .load(tableName)
+            .agg(functions.min("data"), functions.max("data"), 
functions.count("data"));
+    String explain2 =
+        
unboundedPushdownDs.queryExecution().explainString(ExplainMode.fromString("simple"));
+    assertThat(explain2).contains("LocalTableScan", "min(data)", "max(data)", 
"count(data)");
+
+    List<Object[]> expected2 = Lists.newArrayList();
+    expected2.add(new Object[] {-7777, 9999, 6L});
+    assertEquals(
+        "min/max/count push down", expected2, 
rowsToJava(unboundedPushdownDs.collectAsList()));
+  }
 }

Reply via email to