This is an automated email from the ASF dual-hosted git repository.

amoghj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 22d4e7836a Spark 3.5: Parallelize reading files in add_files procedure (#9274)
22d4e7836a is described below

commit 22d4e7836ac0bb47bed627e7c773d44105b8f0f8
Author: Manu Zhang <[email protected]>
AuthorDate: Thu Dec 28 23:52:21 2023 +0800

    Spark 3.5: Parallelize reading files in add_files procedure (#9274)
---
 .../apache/iceberg/data/TableMigrationUtil.java    |  4 +-
 docs/spark-procedures.md                           |  1 +
 .../spark/extensions/TestAddFilesProcedure.java    | 22 ++++++
 .../org/apache/iceberg/spark/SparkTableUtil.java   | 89 +++++++++++++++++++---
 .../spark/procedures/AddFilesProcedure.java        | 54 ++++++++++---
 .../iceberg/spark/procedures/ProcedureInput.java   | 12 +++
 6 files changed, 156 insertions(+), 26 deletions(-)

diff --git a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
index 0fb290f947..5834a074a1 100644
--- a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
+++ b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
@@ -215,11 +215,11 @@ public class TableMigrationUtil {
         .build();
   }
 
-  private static ExecutorService migrationService(int concurrentDeletes) {
+  private static ExecutorService migrationService(int parallelism) {
     return MoreExecutors.getExitingExecutorService(
         (ThreadPoolExecutor)
             Executors.newFixedThreadPool(
-                concurrentDeletes,
+                parallelism,
                new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build()));
   }
 }
diff --git a/docs/spark-procedures.md b/docs/spark-procedures.md
index cdc1779a88..45a9f80ea6 100644
--- a/docs/spark-procedures.md
+++ b/docs/spark-procedures.md
@@ -640,6 +640,7 @@ Keep in mind the `add_files` procedure will fetch the Parquet metadata from each
 | `source_table`          | ✔️        | string              | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
 | `partition_filter`      | ️         | map<string, string> | A map of partitions in the source table to import from |
 | `check_duplicate_files` | ️         | boolean             | Whether to prevent files existing in the table from being added (defaults to true) |
+| `parallelism`           |           | int                 | Number of threads to use for file reading (defaults to 1) |
 
 Warning : Schema is not validated, adding files with different schema to the Iceberg table will cause issues.
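
For reference, `parallelism` is passed like any other named procedure argument. A minimal invocation sketch in Java, assuming an Iceberg-enabled Spark session; the catalog and table names are placeholders, and the call form mirrors the new test further below:

    import org.apache.spark.sql.SparkSession;

    public class AddFilesParallelismExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().getOrCreate();
        // Placeholder catalog/table names; parallelism => 2 asks the procedure
        // to read file metadata with two threads instead of the default one.
        spark.sql(
            "CALL spark_catalog.system.add_files("
                + "table => 'db.tbl', "
                + "source_table => 'db.src_tbl', "
                + "parallelism => 2)");
        spark.stop();
      }
    }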
 
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
index 3ed99da249..eaa0a5894c 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
@@ -935,6 +935,28 @@ public class TestAddFilesProcedure extends SparkExtensionsTestBase {
         sql("SELECT * FROM %s ORDER BY id", tableName));
   }
 
+  @Test
+  public void testAddFilesWithParallelism() {
+    createUnpartitionedHiveTable();
+
+    String createIceberg =
+        "CREATE TABLE %s (id Integer, name String, dept String, subdept 
String) USING iceberg";
+
+    sql(createIceberg, tableName);
+
+    List<Object[]> result =
+        sql(
+            "CALL %s.system.add_files(table => '%s', source_table => '%s', 
parallelism => 2)",
+            catalogName, tableName, sourceTableName);
+
+    assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), 
result);
+
+    assertEquals(
+        "Iceberg table contains correct data",
+        sql("SELECT * FROM %s ORDER BY id", sourceTableName),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
   private static final List<Object[]> emptyQueryResult = Lists.newArrayList();
 
   private static final StructField[] struct = {
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
index 51df02d569..3a2324d891 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -277,7 +277,8 @@ public class SparkTableUtil {
       PartitionSpec spec,
       SerializableConfiguration conf,
       MetricsConfig metricsConfig,
-      NameMapping mapping) {
+      NameMapping mapping,
+      int parallelism) {
     return TableMigrationUtil.listPartition(
         partition.values,
         partition.uri,
@@ -285,7 +286,8 @@ public class SparkTableUtil {
         spec,
         conf.get(),
         metricsConfig,
-        mapping);
+        mapping,
+        parallelism);
   }
 
   private static SparkPartition toSparkPartition(
@@ -382,6 +384,33 @@ public class SparkTableUtil {
       String stagingDir,
       Map<String, String> partitionFilter,
       boolean checkDuplicateFiles) {
+    importSparkTable(
+        spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
+  }
+
+  /**
+   * Import files from an existing Spark table to an Iceberg table.
+   *
+   * <p>The import uses the Spark session to get table metadata. It assumes no operation is going on
+   * the original and target table and thus is not thread-safe.
+   *
+   * @param spark a Spark session
+   * @param sourceTableIdent an identifier of the source Spark table
+   * @param targetTable an Iceberg table where to import the data
+   * @param stagingDir a staging directory to store temporary manifest files
+   * @param partitionFilter only import partitions whose values match those in the map, can be
+   *     partially defined
+   * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
+   * @param parallelism number of threads to use for file reading
+   */
+  public static void importSparkTable(
+      SparkSession spark,
+      TableIdentifier sourceTableIdent,
+      Table targetTable,
+      String stagingDir,
+      Map<String, String> partitionFilter,
+      boolean checkDuplicateFiles,
+      int parallelism) {
     SessionCatalog catalog = spark.sessionState().catalog();
 
     String db =
@@ -402,7 +431,7 @@ public class SparkTableUtil {
 
       if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
         importUnpartitionedSparkTable(
-            spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles);
+            spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, parallelism);
       } else {
         List<SparkPartition> sourceTablePartitions =
             getPartitions(spark, sourceTableIdent, partitionFilter);
@@ -410,7 +439,13 @@ public class SparkTableUtil {
           targetTable.newAppend().commit();
         } else {
           importSparkPartitions(
-              spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles);
+              spark,
+              sourceTablePartitions,
+              targetTable,
+              spec,
+              stagingDir,
+              checkDuplicateFiles,
+              parallelism);
         }
       }
     } catch (AnalysisException e) {
@@ -443,7 +478,8 @@ public class SparkTableUtil {
         targetTable,
         stagingDir,
         Collections.emptyMap(),
-        checkDuplicateFiles);
+        checkDuplicateFiles,
+        1);
   }
 
   /**
@@ -460,14 +496,15 @@ public class SparkTableUtil {
   public static void importSparkTable(
      SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) {
     importSparkTable(
-        spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false);
+        spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1);
   }
 
   private static void importUnpartitionedSparkTable(
       SparkSession spark,
       TableIdentifier sourceTableIdent,
       Table targetTable,
-      boolean checkDuplicateFiles) {
+      boolean checkDuplicateFiles,
+      int parallelism) {
     try {
      CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
       Option<String> format =
@@ -492,7 +529,8 @@ public class SparkTableUtil {
               spec,
               conf,
               metricsConfig,
-              nameMapping);
+              nameMapping,
+              parallelism);
 
       if (checkDuplicateFiles) {
         Dataset<Row> importedFiles =
@@ -540,9 +578,31 @@ public class SparkTableUtil {
       PartitionSpec spec,
       String stagingDir,
       boolean checkDuplicateFiles) {
+    importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1);
+  }
+
+  /**
+   * Import files from given partitions to an Iceberg table.
+   *
+   * @param spark a Spark session
+   * @param partitions partitions to import
+   * @param targetTable an Iceberg table where to import the data
+   * @param spec a partition spec
+   * @param stagingDir a staging directory to store temporary manifest files
+   * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
+   * @param parallelism number of threads to use for file reading
+   */
+  public static void importSparkPartitions(
+      SparkSession spark,
+      List<SparkPartition> partitions,
+      Table targetTable,
+      PartitionSpec spec,
+      String stagingDir,
+      boolean checkDuplicateFiles,
+      int parallelism) {
     Configuration conf = spark.sessionState().newHadoopConf();
    SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
-    int parallelism =
+    int listingParallelism =
         Math.min(
            partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism());
    int numShufflePartitions = spark.sessionState().conf().numShufflePartitions();
@@ -552,7 +612,7 @@ public class SparkTableUtil {
        nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null;
 
    JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
-    JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, parallelism);
+    JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, listingParallelism);
 
     Dataset<SparkPartition> partitionDS =
        spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class));
@@ -562,7 +622,12 @@ public class SparkTableUtil {
             (FlatMapFunction<SparkPartition, DataFile>)
                 sparkPartition ->
                     listPartition(
-                            sparkPartition, spec, serializableConf, metricsConfig, nameMapping)
+                            sparkPartition,
+                            spec,
+                            serializableConf,
+                            metricsConfig,
+                            nameMapping,
+                            parallelism)
                         .iterator(),
             Encoders.javaSerialization(DataFile.class));
 
@@ -635,7 +700,7 @@ public class SparkTableUtil {
       Table targetTable,
       PartitionSpec spec,
       String stagingDir) {
-    importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false);
+    importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1);
   }
 
   public static List<SparkPartition> filterPartitions(
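
For Java API users, the new overloads above take the thread count as a trailing argument. A sketch under the assumption that an Iceberg Table handle and a staging directory are already available; the source identifier is a placeholder:

    import java.util.Collections;

    import org.apache.iceberg.Table;
    import org.apache.iceberg.spark.SparkTableUtil;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.catalyst.TableIdentifier;

    public class ImportWithParallelism {
      static void importTable(SparkSession spark, Table targetTable, String stagingDir) {
        // Placeholder source table "db.src_tbl".
        TableIdentifier source = new TableIdentifier("src_tbl", scala.Option.apply("db"));
        SparkTableUtil.importSparkTable(
            spark,
            source,
            targetTable,
            stagingDir,
            Collections.emptyMap(), // partitionFilter: import all partitions
            true,                   // checkDuplicateFiles
            4);                     // parallelism: threads used for file reading
      }
    }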
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java
index 6a05706776..40a343b55b 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java
@@ -63,9 +63,16 @@ class AddFilesProcedure extends BaseProcedure {
   private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM =
       ProcedureParameter.optional("check_duplicate_files", 
DataTypes.BooleanType);
 
+  private static final ProcedureParameter PARALLELISM =
+      ProcedureParameter.optional("parallelism", DataTypes.IntegerType);
+
   private static final ProcedureParameter[] PARAMETERS =
       new ProcedureParameter[] {
-        TABLE_PARAM, SOURCE_TABLE_PARAM, PARTITION_FILTER_PARAM, CHECK_DUPLICATE_FILES_PARAM
+        TABLE_PARAM,
+        SOURCE_TABLE_PARAM,
+        PARTITION_FILTER_PARAM,
+        CHECK_DUPLICATE_FILES_PARAM,
+        PARALLELISM
       };
 
   private static final StructType OUTPUT_TYPE =
@@ -112,7 +119,10 @@ class AddFilesProcedure extends BaseProcedure {
 
    boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true);
 
-    return importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles);
+    int parallelism = input.asInt(PARALLELISM, 1);
+
+    return importToIceberg(
+        tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism);
   }
 
   private InternalRow[] toOutputRows(Snapshot snapshot) {
@@ -142,7 +152,8 @@ class AddFilesProcedure extends BaseProcedure {
       Identifier destIdent,
       Identifier sourceIdent,
       Map<String, String> partitionFilter,
-      boolean checkDuplicateFiles) {
+      boolean checkDuplicateFiles,
+      int parallelism) {
     return modifyIcebergTable(
         destIdent,
         table -> {
@@ -153,9 +164,16 @@ class AddFilesProcedure extends BaseProcedure {
             Path sourcePath = new Path(sourceIdent.name());
             String format = sourceIdent.namespace()[0];
             importFileTable(
-                table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec());
+                table,
+                sourcePath,
+                format,
+                partitionFilter,
+                checkDuplicateFiles,
+                table.spec(),
+                parallelism);
           } else {
-            importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles);
+            importCatalogTable(
+                table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism);
           }
 
           Snapshot snapshot = table.currentSnapshot();
@@ -178,7 +196,8 @@ class AddFilesProcedure extends BaseProcedure {
       String format,
       Map<String, String> partitionFilter,
       boolean checkDuplicateFiles,
-      PartitionSpec spec) {
+      PartitionSpec spec,
+      int parallelism) {
     // List Partitions via Spark InMemory file search interface
     List<SparkPartition> partitions =
        Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec);
@@ -193,11 +212,11 @@ class AddFilesProcedure extends BaseProcedure {
       // Build a Global Partition for the source
       SparkPartition partition =
          new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format);
-      importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles);
+      importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism);
     } else {
       Preconditions.checkArgument(
          !partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name());
-      importPartitions(table, partitions, checkDuplicateFiles);
+      importPartitions(table, partitions, checkDuplicateFiles, parallelism);
     }
   }
 
@@ -205,7 +224,8 @@ class AddFilesProcedure extends BaseProcedure {
       Table table,
       Identifier sourceIdent,
       Map<String, String> partitionFilter,
-      boolean checkDuplicateFiles) {
+      boolean checkDuplicateFiles,
+      int parallelism) {
     String stagingLocation = getMetadataLocation(table);
    TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent);
     SparkTableUtil.importSparkTable(
@@ -214,14 +234,24 @@ class AddFilesProcedure extends BaseProcedure {
         table,
         stagingLocation,
         partitionFilter,
-        checkDuplicateFiles);
+        checkDuplicateFiles,
+        parallelism);
   }
 
   private void importPartitions(
-      Table table, List<SparkTableUtil.SparkPartition> partitions, boolean checkDuplicateFiles) {
+      Table table,
+      List<SparkTableUtil.SparkPartition> partitions,
+      boolean checkDuplicateFiles,
+      int parallelism) {
     String stagingLocation = getMetadataLocation(table);
     SparkTableUtil.importSparkPartitions(
-        spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles);
+        spark(),
+        partitions,
+        table,
+        table.spec(),
+        stagingLocation,
+        checkDuplicateFiles,
+        parallelism);
   }
 
   private String getMetadataLocation(Table table) {
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
index 42e4d8ba06..0be4b38de7 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
@@ -68,6 +68,18 @@ class ProcedureInput {
    return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal);
   }
 
+  public Integer asInt(ProcedureParameter param) {
+    Integer value = asInt(param, null);
+    Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
+    return value;
+  }
+
+  public Integer asInt(ProcedureParameter param, Integer defaultValue) {
+    validateParamType(param, DataTypes.IntegerType);
+    int ordinal = ordinal(param);
+    return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal);
+  }
+
   public long asLong(ProcedureParameter param) {
     Long value = asLong(param, null);
    Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
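
As the `source_table` row in the docs table above notes, path-based sources in the `file_format`.`path` form work as well. A hedged sketch combining that form with the new parameter; the path and names are placeholders:

    import org.apache.spark.sql.SparkSession;

    public class AddFilesFromPathExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().getOrCreate();
        // Placeholder path and names; the backticks follow the documented
        // `file_format`.`path` convention for path-based sources.
        spark.sql(
            "CALL spark_catalog.system.add_files("
                + "table => 'db.tbl', "
                + "source_table => '`parquet`.`path/to/files`', "
                + "parallelism => 2)");
        spark.stop();
      }
    }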
