This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new abbfdae108 Spark 3.4, 3.3: Support metadata columns in staged scans (#9098)
abbfdae108 is described below

commit abbfdae108dd342191dcc685021edbc0113fe302
Author: zhen <[email protected]>
AuthorDate: Sat Nov 18 09:26:32 2023 +0800

    Spark 3.4, 3.3: Support metadata columns in staged scans (#9098)
    
    This change cherry-picks PR #8872 to Spark 3.4 and 3.3.
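    
    A minimal usage sketch of what this enables, mirroring the new test
    (the table identifier "db.tbl" is illustrative; imports are as in
    TestMetaColumnProjectionWithStageScan below):
    
        Table table = Spark3Util.loadIcebergTable(spark, "db.tbl");
        try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
          // stage the planned tasks under an ID so a follow-up read can pick them up
          String fileSetID = UUID.randomUUID().toString();
          ScanTaskSetManager.get().stageTasks(table, fileSetID, Lists.newArrayList(tasks));
    
          // read the staged tasks back and project the _pos metadata column,
          // which staged scans could not do before this change
          Dataset<Row> df =
              spark
                  .read()
                  .format("iceberg")
                  .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
                  .load(table.location())
                  .select("*", "_pos");
          df.show();
        }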
---
 .../TestMetaColumnProjectionWithStageScan.java     | 127 +++++++++++++++++++++
 .../iceberg/spark/source/SparkStagedScan.java      |  10 +-
 .../spark/source/SparkStagedScanBuilder.java       |  52 ++++++++-
 .../TestMetaColumnProjectionWithStageScan.java     | 127 +++++++++++++++++++++
 .../iceberg/spark/source/SparkStagedScan.java      |  10 +-
 .../spark/source/SparkStagedScanBuilder.java       |  52 ++++++++-
 6 files changed, 368 insertions(+), 10 deletions(-)

diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java
new file mode 100644
index 0000000000..e9013848cf
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.ScanTaskSetManager;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class TestMetaColumnProjectionWithStageScan extends SparkExtensionsTestBase {
+
+  public TestMetaColumnProjectionWithStageScan(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HADOOP.catalogName(),
+        SparkCatalogConfig.HADOOP.implementation(),
+        SparkCatalogConfig.HADOOP.properties()
+      }
+    };
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  private <T extends ScanTask> void stageTask(
+      Table tab, String fileSetID, CloseableIterable<T> tasks) {
+    ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
+    taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks));
+  }
+
+  @Test
+  public void testReadStageTableMeta() throws Exception {
+    sql(
+        "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES"
+            + "('format-version'='2', 'write.delete.mode'='merge-on-read')",
+        tableName);
+
+    List<SimpleRecord> records =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+
+    spark
+        .createDataset(records, Encoders.bean(SimpleRecord.class))
+        .coalesce(1)
+        .writeTo(tableName)
+        .append();
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    table.refresh();
+    String tableLocation = table.location();
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF2 =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation);
+
+      Assertions.assertThat(scanDF2.columns().length).isEqualTo(2);
+    }
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation)
+              .select("*", "_pos");
+
+      List<Row> rows = scanDF.collectAsList();
+      ImmutableList<Object[]> expectedRows =
+          ImmutableList.of(row(1L, "a", 0L), row(2L, "b", 1L), row(3L, "c", 2L), row(4L, "d", 3L));
+      assertEquals("result should match", expectedRows, rowsToJava(rows));
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
index 89b184c91c..ad501f7f91 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
@@ -22,6 +22,7 @@ import java.util.List;
 import java.util.Objects;
 import org.apache.iceberg.ScanTask;
 import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
@@ -40,8 +41,11 @@ class SparkStagedScan extends SparkScan {
   private List<ScanTaskGroup<ScanTask>> taskGroups = null; // lazy cache of tasks
 
   SparkStagedScan(SparkSession spark, Table table, SparkReadConf readConf) {
-    super(spark, table, readConf, table.schema(), ImmutableList.of());
+    this(spark, table, table.schema(), readConf);
+  }
 
+  SparkStagedScan(SparkSession spark, Table table, Schema expectedSchema, SparkReadConf readConf) {
+    super(spark, table, readConf, expectedSchema, ImmutableList.of());
     this.taskSetId = readConf.scanTaskSetId();
     this.splitSize = readConf.splitSize();
     this.splitLookback = readConf.splitLookback();
@@ -77,6 +81,7 @@ class SparkStagedScan extends SparkScan {
     SparkStagedScan that = (SparkStagedScan) other;
     return table().name().equals(that.table().name())
         && Objects.equals(taskSetId, that.taskSetId)
+        && readSchema().equals(that.readSchema())
         && Objects.equals(splitSize, that.splitSize)
         && Objects.equals(splitLookback, that.splitLookback)
         && Objects.equals(openFileCost, that.openFileCost);
@@ -84,7 +89,8 @@ class SparkStagedScan extends SparkScan {
 
   @Override
   public int hashCode() {
-    return Objects.hash(table().name(), taskSetId, splitSize, splitSize, openFileCost);
+    return Objects.hash(
+        table().name(), taskSetId, readSchema(), splitSize, splitLookback, openFileCost);
   }
 
   @Override
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
index 37bbea42e5..25393888f9 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
@@ -18,27 +18,75 @@
  */
 package org.apache.iceberg.spark.source;
 
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
-class SparkStagedScanBuilder implements ScanBuilder {
+class SparkStagedScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns {
 
   private final SparkSession spark;
   private final Table table;
   private final SparkReadConf readConf;
+  private final List<String> metaColumns = Lists.newArrayList();
+
+  private Schema schema = null;
 
   SparkStagedScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) {
     this.spark = spark;
     this.table = table;
     this.readConf = new SparkReadConf(spark, table, options);
+    this.schema = table.schema();
   }
 
   @Override
   public Scan build() {
-    return new SparkStagedScan(spark, table, readConf);
+    return new SparkStagedScan(spark, table, schemaWithMetadataColumns(), readConf);
+  }
+
+  @Override
+  public void pruneColumns(StructType requestedSchema) {
+    StructType requestedProjection = removeMetaColumns(requestedSchema);
+    this.schema = SparkSchemaUtil.prune(schema, requestedProjection);
+
+    Stream.of(requestedSchema.fields())
+        .map(StructField::name)
+        .filter(MetadataColumns::isMetadataColumn)
+        .distinct()
+        .forEach(metaColumns::add);
+  }
+
+  private StructType removeMetaColumns(StructType structType) {
+    return new StructType(
+        Stream.of(structType.fields())
+            .filter(field -> MetadataColumns.nonMetadataColumn(field.name()))
+            .toArray(StructField[]::new));
+  }
+
+  private Schema schemaWithMetadataColumns() {
+    // metadata columns
+    List<Types.NestedField> fields =
+        metaColumns.stream()
+            .distinct()
+            .map(name -> MetadataColumns.metadataColumn(table, name))
+            .collect(Collectors.toList());
+    Schema meta = new Schema(fields);
+
+    // schema of rows returned by readers
+    return TypeUtil.join(schema, meta);
   }
 }
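
For context on the builder change above: pruneColumns splits the projection
Spark requests into data columns, which are pruned into the Iceberg schema,
and metadata columns, which are remembered and re-appended via TypeUtil.join
in schemaWithMetadataColumns. A standalone sketch of that split (the
hardcoded name set is an assumption standing in for
MetadataColumns.isMetadataColumn, not the authoritative list):

    import java.util.Set;
    import java.util.stream.Stream;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.Metadata;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    public class PruneColumnsSketch {
      // hypothetical stand-in for MetadataColumns.isMetadataColumn
      private static final Set<String> META =
          Set.of("_file", "_pos", "_deleted", "_spec_id", "_partition");

      public static void main(String[] args) {
        // projection produced by select("*", "_pos") on an (id, data) table
        StructType requested =
            new StructType(
                new StructField[] {
                  new StructField("id", DataTypes.LongType, false, Metadata.empty()),
                  new StructField("data", DataTypes.StringType, true, Metadata.empty()),
                  new StructField("_pos", DataTypes.LongType, false, Metadata.empty())
                });

        // data columns: the part handed to SparkSchemaUtil.prune(schema, ...)
        StructType dataOnly =
            new StructType(
                Stream.of(requested.fields())
                    .filter(field -> !META.contains(field.name()))
                    .toArray(StructField[]::new));

        // metadata columns: collected, then joined back when the scan is built
        Stream.of(requested.fields())
            .map(StructField::name)
            .filter(META::contains)
            .distinct()
            .forEach(name -> System.out.println("metadata column: " + name));

        System.out.println(dataOnly.simpleString()); // struct<id:bigint,data:string>
      }
    }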
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java
new file mode 100644
index 0000000000..e9013848cf
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMetaColumnProjectionWithStageScan.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.ScanTaskSetManager;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class TestMetaColumnProjectionWithStageScan extends SparkExtensionsTestBase {
+
+  public TestMetaColumnProjectionWithStageScan(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HADOOP.catalogName(),
+        SparkCatalogConfig.HADOOP.implementation(),
+        SparkCatalogConfig.HADOOP.properties()
+      }
+    };
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  private <T extends ScanTask> void stageTask(
+      Table tab, String fileSetID, CloseableIterable<T> tasks) {
+    ScanTaskSetManager taskSetManager = ScanTaskSetManager.get();
+    taskSetManager.stageTasks(tab, fileSetID, Lists.newArrayList(tasks));
+  }
+
+  @Test
+  public void testReadStageTableMeta() throws Exception {
+    sql(
+        "CREATE TABLE %s (id bigint, data string) USING iceberg TBLPROPERTIES"
+            + "('format-version'='2', 'write.delete.mode'='merge-on-read')",
+        tableName);
+
+    List<SimpleRecord> records =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+
+    spark
+        .createDataset(records, Encoders.bean(SimpleRecord.class))
+        .coalesce(1)
+        .writeTo(tableName)
+        .append();
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    table.refresh();
+    String tableLocation = table.location();
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF2 =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation);
+
+      Assertions.assertThat(scanDF2.columns().length).isEqualTo(2);
+    }
+
+    try (CloseableIterable<ScanTask> tasks = table.newBatchScan().planFiles()) {
+      String fileSetID = UUID.randomUUID().toString();
+      stageTask(table, fileSetID, tasks);
+      Dataset<Row> scanDF =
+          spark
+              .read()
+              .format("iceberg")
+              .option(SparkReadOptions.FILE_OPEN_COST, "0")
+              .option(SparkReadOptions.SCAN_TASK_SET_ID, fileSetID)
+              .load(tableLocation)
+              .select("*", "_pos");
+
+      List<Row> rows = scanDF.collectAsList();
+      ImmutableList<Object[]> expectedRows =
+          ImmutableList.of(row(1L, "a", 0L), row(2L, "b", 1L), row(3L, "c", 2L), row(4L, "d", 3L));
+      assertEquals("result should match", expectedRows, rowsToJava(rows));
+    }
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
index 0290bf7e84..fd299ade7f 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScan.java
@@ -22,6 +22,7 @@ import java.util.List;
 import java.util.Objects;
 import org.apache.iceberg.ScanTask;
 import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
@@ -39,9 +40,8 @@ class SparkStagedScan extends SparkScan {
 
   private List<ScanTaskGroup<ScanTask>> taskGroups = null; // lazy cache of tasks
 
-  SparkStagedScan(SparkSession spark, Table table, SparkReadConf readConf) {
-    super(spark, table, readConf, table.schema(), ImmutableList.of(), null);
-
+  SparkStagedScan(SparkSession spark, Table table, Schema expectedSchema, SparkReadConf readConf) {
+    super(spark, table, readConf, expectedSchema, ImmutableList.of(), null);
     this.taskSetId = readConf.scanTaskSetId();
     this.splitSize = readConf.splitSize();
     this.splitLookback = readConf.splitLookback();
@@ -77,6 +77,7 @@ class SparkStagedScan extends SparkScan {
     SparkStagedScan that = (SparkStagedScan) other;
     return table().name().equals(that.table().name())
         && Objects.equals(taskSetId, that.taskSetId)
+        && readSchema().equals(that.readSchema())
         && Objects.equals(splitSize, that.splitSize)
         && Objects.equals(splitLookback, that.splitLookback)
         && Objects.equals(openFileCost, that.openFileCost);
@@ -84,7 +85,8 @@ class SparkStagedScan extends SparkScan {
 
   @Override
   public int hashCode() {
-    return Objects.hash(table().name(), taskSetId, splitSize, splitSize, openFileCost);
+    return Objects.hash(
+        table().name(), taskSetId, readSchema(), splitSize, splitLookback, openFileCost);
   }
 
   @Override
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
index 37bbea42e5..25393888f9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkStagedScanBuilder.java
@@ -18,27 +18,75 @@
  */
 package org.apache.iceberg.spark.source;
 
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
-class SparkStagedScanBuilder implements ScanBuilder {
+class SparkStagedScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns {
 
   private final SparkSession spark;
   private final Table table;
   private final SparkReadConf readConf;
+  private final List<String> metaColumns = Lists.newArrayList();
+
+  private Schema schema = null;
 
   SparkStagedScanBuilder(SparkSession spark, Table table, CaseInsensitiveStringMap options) {
     this.spark = spark;
     this.table = table;
     this.readConf = new SparkReadConf(spark, table, options);
+    this.schema = table.schema();
   }
 
   @Override
   public Scan build() {
-    return new SparkStagedScan(spark, table, readConf);
+    return new SparkStagedScan(spark, table, schemaWithMetadataColumns(), readConf);
+  }
+
+  @Override
+  public void pruneColumns(StructType requestedSchema) {
+    StructType requestedProjection = removeMetaColumns(requestedSchema);
+    this.schema = SparkSchemaUtil.prune(schema, requestedProjection);
+
+    Stream.of(requestedSchema.fields())
+        .map(StructField::name)
+        .filter(MetadataColumns::isMetadataColumn)
+        .distinct()
+        .forEach(metaColumns::add);
+  }
+
+  private StructType removeMetaColumns(StructType structType) {
+    return new StructType(
+        Stream.of(structType.fields())
+            .filter(field -> MetadataColumns.nonMetadataColumn(field.name()))
+            .toArray(StructField[]::new));
+  }
+
+  private Schema schemaWithMetadataColumns() {
+    // metadata columns
+    List<Types.NestedField> fields =
+        metaColumns.stream()
+            .distinct()
+            .map(name -> MetadataColumns.metadataColumn(table, name))
+            .collect(Collectors.toList());
+    Schema meta = new Schema(fields);
+
+    // schema of rows returned by readers
+    return TypeUtil.join(schema, meta);
   }
 }
