This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 95227dbbd5 Spark 3.3: Support arbitrary scans in SparkBatchQueryScan (#6309)
95227dbbd5 is described below

commit 95227dbbd558d057f515be969a8031ef77505085
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Thu Dec 1 11:41:40 2022 -0800

    Spark 3.3: Support arbitrary scans in SparkBatchQueryScan (#6309)
---
 .../apache/iceberg/spark/source/SparkBatch.java    | 71 ++++++++++-------
 .../iceberg/spark/source/SparkBatchQueryScan.java  | 88 ++++++++++++----------
 .../iceberg/spark/source/SparkCopyOnWriteScan.java |  2 +-
 .../iceberg/spark/source/SparkFilesScan.java       |  2 +-
 .../org/apache/iceberg/spark/source/SparkScan.java |  7 +-
 .../iceberg/spark/source/SparkScanBuilder.java     | 42 ++++++++---
 6 files changed, 131 insertions(+), 81 deletions(-)
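
The gist of the change: SparkBatchQueryScan is no longer hard-wired to TableScan and
CombinedScanTask; it now works against the generic Scan / ScanTask / ScanTaskGroup
interfaces, so the same Spark scan class can serve snapshot, time-travel, branch/tag, and
incremental-append reads. A minimal illustrative sketch (ScanBoundSketch and pick() are
invented names, not part of the commit) of how both scan flavors fit the widened bound:

    import org.apache.iceberg.BatchScan;
    import org.apache.iceberg.IncrementalAppendScan;
    import org.apache.iceberg.Scan;
    import org.apache.iceberg.ScanTask;
    import org.apache.iceberg.ScanTaskGroup;
    import org.apache.iceberg.Table;

    class ScanBoundSketch {
      // Both scan flavors are assignable to the wildcard bound that
      // SparkBatchQueryScan accepts after this commit.
      static Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> pick(
          Table table, Long fromSnapshotId) {
        if (fromSnapshotId != null) {
          return table.newIncrementalAppendScan().fromSnapshotExclusive(fromSnapshotId);
        } else {
          return table.newBatchScan();
        }
      }
    }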

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index bcfa70bcf2..c5eb17a784 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -20,14 +20,15 @@ package org.apache.iceberg.spark.source;
 
 import java.util.List;
 import java.util.Objects;
-import org.apache.iceberg.CombinedScanTask;
 import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.SchemaParser;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.spark.SparkReadConf;
 import org.apache.iceberg.spark.source.SparkScan.ReaderFactory;
-import org.apache.iceberg.util.TableScanUtil;
 import org.apache.iceberg.util.Tasks;
 import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -41,7 +42,7 @@ class SparkBatch implements Batch {
   private final JavaSparkContext sparkContext;
   private final Table table;
   private final SparkReadConf readConf;
-  private final List<CombinedScanTask> taskGroups;
+  private final List<? extends ScanTaskGroup<?>> taskGroups;
   private final Schema expectedSchema;
   private final boolean caseSensitive;
   private final boolean localityEnabled;
@@ -51,7 +52,7 @@ class SparkBatch implements Batch {
       JavaSparkContext sparkContext,
       Table table,
       SparkReadConf readConf,
-      List<CombinedScanTask> taskGroups,
+      List<? extends ScanTaskGroup<?>> taskGroups,
       Schema expectedSchema,
       int scanHashCode) {
     this.sparkContext = sparkContext;
@@ -95,43 +96,61 @@ class SparkBatch implements Batch {
   }
 
   private int batchSize() {
-    if (parquetOnly() && parquetBatchReadsEnabled()) {
+    if (useParquetBatchReads()) {
       return readConf.parquetBatchSize();
-    } else if (orcOnly() && orcBatchReadsEnabled()) {
+    } else if (useOrcBatchReads()) {
       return readConf.orcBatchSize();
     } else {
       return 0;
     }
   }
 
-  private boolean parquetOnly() {
-    return taskGroups.stream()
-        .allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.PARQUET));
-  }
-
-  private boolean parquetBatchReadsEnabled() {
+  // conditions for using Parquet batch reads:
+  // - Parquet vectorization is enabled
+  // - at least one column is projected
+  // - only primitives are projected
+  // - all tasks are of FileScanTask type and read only Parquet files
+  private boolean useParquetBatchReads() {
     return readConf.parquetVectorizationEnabled()
-        && // vectorization enabled
-        expectedSchema.columns().size() > 0
-        && // at least one column is projected
-        expectedSchema.columns().stream()
-            .allMatch(c -> c.type().isPrimitiveType()); // only primitives
+        && expectedSchema.columns().size() > 0
+        && expectedSchema.columns().stream().allMatch(c -> c.type().isPrimitiveType())
+        && taskGroups.stream().allMatch(this::supportsParquetBatchReads);
   }
 
-  private boolean orcOnly() {
-    return taskGroups.stream()
-        .allMatch(task -> !task.isDataTask() && onlyFileFormat(task, FileFormat.ORC));
+  private boolean supportsParquetBatchReads(ScanTask task) {
+    if (task instanceof ScanTaskGroup) {
+      ScanTaskGroup<?> taskGroup = (ScanTaskGroup<?>) task;
+      return taskGroup.tasks().stream().allMatch(this::supportsParquetBatchReads);
+
+    } else if (task.isFileScanTask() && !task.isDataTask()) {
+      FileScanTask fileScanTask = task.asFileScanTask();
+      return fileScanTask.file().format() == FileFormat.PARQUET;
+
+    } else {
+      return false;
+    }
   }
 
-  private boolean orcBatchReadsEnabled() {
+  // conditions for using ORC batch reads:
+  // - ORC vectorization is enabled
+  // - all tasks are of type FileScanTask and read only ORC files with no delete files
+  private boolean useOrcBatchReads() {
     return readConf.orcVectorizationEnabled()
-        && // vectorization enabled
-        taskGroups.stream().noneMatch(TableScanUtil::hasDeletes); // no delete files
+        && taskGroups.stream().allMatch(this::supportsOrcBatchReads);
   }
 
-  private boolean onlyFileFormat(CombinedScanTask task, FileFormat fileFormat) {
-    return task.files().stream()
-        .allMatch(fileScanTask -> fileScanTask.file().format().equals(fileFormat));
+  private boolean supportsOrcBatchReads(ScanTask task) {
+    if (task instanceof ScanTaskGroup) {
+      ScanTaskGroup<?> taskGroup = (ScanTaskGroup<?>) task;
+      return taskGroup.tasks().stream().allMatch(this::supportsOrcBatchReads);
+
+    } else if (task.isFileScanTask() && !task.isDataTask()) {
+      FileScanTask fileScanTask = task.asFileScanTask();
+      return fileScanTask.file().format() == FileFormat.ORC && fileScanTask.deletes().isEmpty();
+
+    } else {
+      return false;
+    }
   }
 
   @Override
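
The batch-read eligibility checks above no longer assume CombinedScanTask groups of
FileScanTask: they recurse through whatever task tree the scan produced and fall back to
row-based reads (batch size 0) for anything that is not a plain Parquet/ORC file task; the
ORC path additionally requires that a task carries no delete files. A standalone sketch of
the same recursion (BatchReadChecks is an invented helper, not the committed code):

    import org.apache.iceberg.FileFormat;
    import org.apache.iceberg.FileScanTask;
    import org.apache.iceberg.ScanTask;
    import org.apache.iceberg.ScanTaskGroup;

    final class BatchReadChecks {
      private BatchReadChecks() {}

      // A task qualifies for Parquet batch reads only if every leaf under it
      // is a Parquet FileScanTask; groups are checked recursively.
      static boolean allParquetFileTasks(ScanTask task) {
        if (task instanceof ScanTaskGroup) {
          ScanTaskGroup<?> group = (ScanTaskGroup<?>) task;
          return group.tasks().stream().allMatch(BatchReadChecks::allParquetFileTasks);
        } else if (task.isFileScanTask() && !task.isDataTask()) {
          FileScanTask fileTask = task.asFileScanTask();
          return fileTask.file().format() == FileFormat.PARQUET;
        } else {
          // metadata/data tasks and unknown task types fall back to row reads
          return false;
        }
      }
    }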
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
index fb83636b4c..19bb551426 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java
@@ -26,14 +26,15 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
-import org.apache.iceberg.CombinedScanTask;
-import org.apache.iceberg.FileScanTask;
 import org.apache.iceberg.PartitionField;
+import org.apache.iceberg.PartitionScanTask;
 import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.Scan;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.Table;
-import org.apache.iceberg.TableScan;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.expressions.Binder;
 import org.apache.iceberg.expressions.Evaluator;
@@ -63,7 +64,7 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkBatchQueryScan.class);
 
-  private final TableScan scan;
+  private final Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> scan;
   private final Long snapshotId;
   private final Long startSnapshotId;
   private final Long endSnapshotId;
@@ -73,13 +74,13 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
   private final List<Expression> runtimeFilterExpressions;
 
   private Set<Integer> specIds = null; // lazy cache of scanned spec IDs
-  private List<FileScanTask> files = null; // lazy cache of files
-  private List<CombinedScanTask> tasks = null; // lazy cache of tasks
+  private List<PartitionScanTask> tasks = null; // lazy cache of uncombined tasks
+  private List<ScanTaskGroup<PartitionScanTask>> taskGroups = null; // lazy cache of task groups
 
   SparkBatchQueryScan(
       SparkSession spark,
       Table table,
-      TableScan scan,
+      Scan<?, ? extends ScanTask, ? extends ScanTaskGroup<?>> scan,
       SparkReadConf readConf,
       Schema expectedSchema,
       List<Expression> filters) {
@@ -97,8 +98,8 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
 
     if (scan == null) {
       this.specIds = Collections.emptySet();
-      this.files = Collections.emptyList();
       this.tasks = Collections.emptyList();
+      this.taskGroups = Collections.emptyList();
     }
   }
 
@@ -109,8 +110,8 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
   private Set<Integer> specIds() {
     if (specIds == null) {
       Set<Integer> specIdSet = Sets.newHashSet();
-      for (FileScanTask file : files()) {
-        specIdSet.add(file.spec().specId());
+      for (PartitionScanTask task : tasks()) {
+        specIdSet.add(task.spec().specId());
       }
       this.specIds = specIdSet;
     }
@@ -118,31 +119,40 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
     return specIds;
   }
 
-  private List<FileScanTask> files() {
-    if (files == null) {
-      try (CloseableIterable<FileScanTask> filesIterable = scan.planFiles()) {
-        this.files = Lists.newArrayList(filesIterable);
+  private List<PartitionScanTask> tasks() {
+    if (tasks == null) {
+      try (CloseableIterable<? extends ScanTask> taskIterable = scan.planFiles()) {
+        List<PartitionScanTask> partitionScanTasks = Lists.newArrayList();
+        for (ScanTask task : taskIterable) {
+          ValidationException.check(
+              task instanceof PartitionScanTask,
+              "Unsupported task type, expected a subtype of PartitionScanTask: %s",
+              task.getClass().getName());
+
+          partitionScanTasks.add((PartitionScanTask) task);
+        }
+        this.tasks = partitionScanTasks;
       } catch (IOException e) {
-        throw new UncheckedIOException("Failed to close table scan: " + scan, e);
+        throw new UncheckedIOException("Failed to close scan: " + scan, e);
       }
     }
 
-    return files;
+    return tasks;
   }
 
   @Override
-  protected List<CombinedScanTask> tasks() {
-    if (tasks == null) {
-      CloseableIterable<FileScanTask> splitFiles =
-          TableScanUtil.splitFiles(
-              CloseableIterable.withNoopClose(files()), scan.targetSplitSize());
-      CloseableIterable<CombinedScanTask> scanTasks =
-          TableScanUtil.planTasks(
-              splitFiles, scan.targetSplitSize(), scan.splitLookback(), scan.splitOpenFileCost());
-      tasks = Lists.newArrayList(scanTasks);
+  protected List<ScanTaskGroup<PartitionScanTask>> taskGroups() {
+    if (taskGroups == null) {
+      CloseableIterable<ScanTaskGroup<PartitionScanTask>> plannedTaskGroups =
+          TableScanUtil.planTaskGroups(
+              CloseableIterable.withNoopClose(tasks()),
+              scan.targetSplitSize(),
+              scan.splitLookback(),
+              scan.splitOpenFileCost());
+      taskGroups = Lists.newArrayList(plannedTaskGroups);
     }
 
-    return tasks;
+    return taskGroups;
   }
 
   @Override
@@ -184,30 +194,30 @@ class SparkBatchQueryScan extends SparkScan implements SupportsRuntimeFiltering
       }
 
       LOG.info(
-          "Trying to filter {} files using runtime filter {}",
-          files().size(),
+          "Trying to filter {} tasks using runtime filter {}",
+          tasks().size(),
           ExpressionUtil.toSanitizedString(runtimeFilterExpr));
 
-      List<FileScanTask> filteredFiles =
-          files().stream()
+      List<PartitionScanTask> filteredTasks =
+          tasks().stream()
               .filter(
-                  file -> {
-                    Evaluator evaluator = evaluatorsBySpecId.get(file.spec().specId());
-                    return evaluator.eval(file.file().partition());
+                  task -> {
+                    Evaluator evaluator = evaluatorsBySpecId.get(task.spec().specId());
+                    return evaluator.eval(task.partition());
                   })
               .collect(Collectors.toList());
 
       LOG.info(
-          "{}/{} files matched runtime filter {}",
-          filteredFiles.size(),
-          files().size(),
+          "{}/{} tasks matched runtime filter {}",
+          filteredTasks.size(),
+          tasks().size(),
           ExpressionUtil.toSanitizedString(runtimeFilterExpr));
 
       // don't invalidate tasks if the runtime filter had no effect to avoid planning splits again
-      if (filteredFiles.size() < files().size()) {
+      if (filteredTasks.size() < tasks().size()) {
         this.specIds = null;
-        this.files = filteredFiles;
-        this.tasks = null;
+        this.tasks = filteredTasks;
+        this.taskGroups = null;
       }
 
       // save the evaluated filter for equals/hashCode
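
Planning in SparkBatchQueryScan is now split into two steps: tasks() materializes
PartitionScanTask leaves from scan.planFiles() (validating the type so runtime filtering
can use task.spec() and task.partition()), and taskGroups() combines them with
TableScanUtil.planTaskGroups. A hedged sketch of the grouping step; the split-planning
values are illustrative, not the table defaults, and TaskGroupingSketch is an invented name:

    import java.util.List;
    import org.apache.iceberg.PartitionScanTask;
    import org.apache.iceberg.ScanTaskGroup;
    import org.apache.iceberg.io.CloseableIterable;
    import org.apache.iceberg.relocated.com.google.common.collect.Lists;
    import org.apache.iceberg.util.TableScanUtil;

    final class TaskGroupingSketch {
      private TaskGroupingSketch() {}

      static List<ScanTaskGroup<PartitionScanTask>> group(List<PartitionScanTask> tasks) {
        CloseableIterable<ScanTaskGroup<PartitionScanTask>> groups =
            TableScanUtil.planTaskGroups(
                CloseableIterable.withNoopClose(tasks),
                128L * 1024 * 1024, // target split size: 128 MB (illustrative)
                10, // split lookback (illustrative)
                4L * 1024 * 1024); // open file cost: 4 MB (illustrative)
        return Lists.newArrayList(groups);
      }
    }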
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
index a13a09f995..0431f904e7 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
@@ -151,7 +151,7 @@ class SparkCopyOnWriteScan extends SparkScan implements SupportsRuntimeFiltering
   }
 
   @Override
-  protected synchronized List<CombinedScanTask> tasks() {
+  protected synchronized List<CombinedScanTask> taskGroups() {
     if (tasks == null) {
       CloseableIterable<FileScanTask> splitFiles =
           TableScanUtil.splitFiles(
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java
index d40009c9f8..e7670f438b 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkFilesScan.java
@@ -50,7 +50,7 @@ class SparkFilesScan extends SparkScan {
   }
 
   @Override
-  protected List<CombinedScanTask> tasks() {
+  protected List<CombinedScanTask> taskGroups() {
     if (tasks == null) {
       FileScanTaskSetManager taskSetManager = FileScanTaskSetManager.get();
       List<FileScanTask> files = taskSetManager.fetchTasks(table(), taskSetID);
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
index 7d89fc23bc..41b5a197e1 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java
@@ -21,7 +21,6 @@ package org.apache.iceberg.spark.source;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
-import org.apache.iceberg.CombinedScanTask;
 import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Snapshot;
@@ -104,11 +103,11 @@ abstract class SparkScan implements Scan, SupportsReportStatistics {
     return filterExpressions;
   }
 
-  protected abstract List<CombinedScanTask> tasks();
+  protected abstract List<? extends ScanTaskGroup<?>> taskGroups();
 
   @Override
   public Batch toBatch() {
-    return new SparkBatch(sparkContext, table, readConf, tasks(), expectedSchema, hashCode());
+    return new SparkBatch(sparkContext, table, readConf, taskGroups(), expectedSchema, hashCode());
   }
 
   @Override
@@ -149,7 +148,7 @@ abstract class SparkScan implements Scan, SupportsReportStatistics {
       return new Stats(SparkSchemaUtil.estimateSize(readSchema(), totalRecords), totalRecords);
     }
 
-    long rowsCount = tasks().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
+    long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
     long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount);
     return new Stats(sizeInBytes, rowsCount);
   }
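
SparkScan now exposes the planning hook as taskGroups() over any ScanTaskGroup, and the
statistics estimate relies on estimatedRowsCount(), which every task group inherits from
ScanTask, so it works regardless of the concrete task type. A compact sketch of that
aggregation (StatsSketch is an invented name):

    import java.util.List;
    import org.apache.iceberg.ScanTaskGroup;

    final class StatsSketch {
      private StatsSketch() {}

      // Sum of per-group row estimates, mirroring the stats computation above.
      static long estimatedRows(List<? extends ScanTaskGroup<?>> taskGroups) {
        return taskGroups.stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
      }
    }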
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index 150da814ba..69f8ab972c 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -21,6 +21,8 @@ package org.apache.iceberg.spark.source;
 import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
+import org.apache.iceberg.BatchScan;
+import org.apache.iceberg.IncrementalAppendScan;
 import org.apache.iceberg.IncrementalChangelogScan;
 import org.apache.iceberg.MetadataColumns;
 import org.apache.iceberg.Schema;
@@ -211,11 +213,19 @@ public class SparkScanBuilder
         SparkReadOptions.END_SNAPSHOT_ID,
         SparkReadOptions.START_SNAPSHOT_ID);
 
+    if (startSnapshotId != null) {
+      return buildIncrementalAppendScan(startSnapshotId, endSnapshotId);
+    } else {
+      return buildBatchScan(snapshotId, asOfTimestamp, branch, tag);
+    }
+  }
+
+  private Scan buildBatchScan(Long snapshotId, Long asOfTimestamp, String branch, String tag) {
     Schema expectedSchema = schemaWithMetadataColumns();
 
-    TableScan scan =
+    BatchScan scan =
         table
-            .newScan()
+            .newBatchScan()
             .caseSensitive(caseSensitive)
             .filter(filterExpression())
             .project(expectedSchema);
@@ -236,12 +246,24 @@ public class SparkScanBuilder
       scan = scan.useRef(tag);
     }
 
-    if (startSnapshotId != null) {
-      if (endSnapshotId != null) {
-        scan = scan.appendsBetween(startSnapshotId, endSnapshotId);
-      } else {
-        scan = scan.appendsAfter(startSnapshotId);
-      }
+    scan = configureSplitPlanning(scan);
+
+    return new SparkBatchQueryScan(spark, table, scan, readConf, expectedSchema, filterExpressions);
+  }
+
+  private Scan buildIncrementalAppendScan(long startSnapshotId, Long endSnapshotId) {
+    Schema expectedSchema = schemaWithMetadataColumns();
+
+    IncrementalAppendScan scan =
+        table
+            .newIncrementalAppendScan()
+            .fromSnapshotExclusive(startSnapshotId)
+            .caseSensitive(caseSensitive)
+            .filter(filterExpression())
+            .project(expectedSchema);
+
+    if (endSnapshotId != null) {
+      scan = scan.toSnapshot(endSnapshotId);
     }
 
     scan = configureSplitPlanning(scan);
@@ -320,9 +342,9 @@ public class SparkScanBuilder
 
     Schema expectedSchema = schemaWithMetadataColumns();
 
-    TableScan scan =
+    BatchScan scan =
         table
-            .newScan()
+            .newBatchScan()
             .useSnapshot(snapshotId)
             .caseSensitive(caseSensitive)
             .filter(filterExpression())
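
In SparkScanBuilder the single TableScan code path is split: a configured start snapshot ID
routes to an IncrementalAppendScan, everything else to a BatchScan, and both end up wrapped
in the same SparkBatchQueryScan. A rough usage sketch of the two builder chains (projection,
split-planning options, and read options omitted; ScanBuilderSketch is an invented name):

    import org.apache.iceberg.BatchScan;
    import org.apache.iceberg.IncrementalAppendScan;
    import org.apache.iceberg.Table;
    import org.apache.iceberg.expressions.Expressions;

    final class ScanBuilderSketch {
      private ScanBuilderSketch() {}

      // Snapshot (or current-state) read path.
      static BatchScan batchScan(Table table, long snapshotId) {
        return table
            .newBatchScan()
            .useSnapshot(snapshotId)
            .caseSensitive(true)
            .filter(Expressions.alwaysTrue());
      }

      // Incremental-append read path between two snapshots.
      static IncrementalAppendScan appendScan(Table table, long startSnapshotId, Long endSnapshotId) {
        IncrementalAppendScan scan =
            table
                .newIncrementalAppendScan()
                .fromSnapshotExclusive(startSnapshotId)
                .caseSensitive(true)
                .filter(Expressions.alwaysTrue());
        return endSnapshotId != null ? scan.toSnapshot(endSnapshotId) : scan;
      }
    }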
