This is an automated email from the ASF dual-hosted git repository.

bryanck pushed a commit to branch 1.7.1rc0-contd
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 9b74bbb1195af7d19c3a404f5b230c545be9ae42
Author: Manu Zhang <[email protected]>
AuthorDate: Thu Nov 21 00:39:41 2024 +0800

    Spark 3.5: Fix NotSerializableException when migrating Spark tables (#11157)
---
 .../apache/iceberg/data/TableMigrationUtil.java    |   2 +
 .../spark/extensions/TestAddFilesProcedure.java    |  20 ++++
 .../extensions/TestMigrateTableProcedure.java      |  18 ++++
 .../extensions/TestSnapshotTableProcedure.java     |  18 ++++
 .../org/apache/iceberg/spark/SparkTableUtil.java   | 117 ++++++++++++++++++++-
 .../spark/procedures/MigrateTableProcedure.java    |   3 +-
 .../spark/procedures/SnapshotTableProcedure.java   |   3 +-
 7 files changed, 177 insertions(+), 4 deletions(-)

diff --git a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
index 0602c9e494..eb1c1a341e 100644
--- a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
+++ b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
@@ -263,6 +264,7 @@ public class TableMigrationUtil {
    * <p><b>Important:</b> Callers are responsible for shutting down the returned executor service
    * when it is no longer needed to prevent resource leaks.
    */
+  @Nullable
   public static ExecutorService migrationService(int parallelism) {
     return parallelism == 1 ? null : ThreadPools.newFixedThreadPool("table-migration", parallelism);
   }
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
index 920c2f55ea..332669470a 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
@@ -948,6 +948,26 @@ public class TestAddFilesProcedure extends ExtensionsTestBase {
         sql("SELECT * FROM %s ORDER BY id", tableName));
   }
 
+  @TestTemplate
+  public void testAddFilesPartitionedWithParallelism() {
+    createPartitionedHiveTable();
+
+    createIcebergTable(
+        "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)");
+
+    List<Object[]> result =
+        sql(
+            "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
+            catalogName, tableName, sourceTableName);
+
+    assertOutput(result, 8L, 4L);
+
+    assertEquals(
+        "Iceberg table contains correct data",
+        sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+        sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+  }
+
   private static final List<Object[]> EMPTY_QUERY_RESULT = Lists.newArrayList();
 
   private static final StructField[] STRUCT = {
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
index 23c08b2572..69e80026e6 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
@@ -273,4 +273,22 @@ public class TestMigrateTableProcedure extends ExtensionsTestBase {
         .isInstanceOf(IllegalArgumentException.class)
         .hasMessage("Parallelism should be larger than 0");
   }
+
+  @TestTemplate
+  public void testMigratePartitionedWithParallelism() throws IOException {
+    assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");
+
+    String location = Files.createTempDirectory(temp, "junit").toFile().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
+        tableName, location);
+    sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", tableName);
+    List<Object[]> result =
+        sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2);
+    assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row("a", 1L), row("b", 2L)),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
index 6caff28bb1..28ae31ec6a 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
@@ -263,4 +263,22 @@ public class TestSnapshotTableProcedure extends ExtensionsTestBase {
         .isInstanceOf(IllegalArgumentException.class)
         .hasMessage("Parallelism should be larger than 0");
   }
+
+  @TestTemplate
+  public void testSnapshotPartitionedWithParallelism() throws IOException {
+    String location = Files.createTempDirectory(temp, "junit").toFile().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
+        SOURCE_NAME, location);
+    sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", SOURCE_NAME);
+    List<Object[]> result =
+        sql(
+            "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
+            catalogName, SOURCE_NAME, tableName, 2);
+    assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row("a", 1L), row("b", 2L)),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
index c44969c49e..01912c81cc 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -23,12 +23,18 @@ import static org.apache.spark.sql.functions.col;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URI;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.stream.Collectors;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
@@ -92,6 +98,8 @@ import org.apache.spark.sql.catalyst.parser.ParseException;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
 import scala.Function2;
 import scala.Option;
 import scala.Some;
@@ -487,7 +495,7 @@ public class SparkTableUtil {
         stagingDir,
         partitionFilter,
         checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
   }
 
   /**
@@ -711,7 +719,7 @@ public class SparkTableUtil {
         spec,
         stagingDir,
         checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
   }
 
   /**
@@ -971,4 +979,130 @@ public class SparkTableUtil {
       return Objects.hashCode(values, uri, format);
     }
   }
+
+  /**
+   * Returns a serializable executor service for parallel table migration, or {@code null} when
+   * {@code parallelism} is 1.
+   *
+   * <p>Unlike {@link TableMigrationUtil#migrationService(int)}, the returned service is
+   * serializable: its thread pool is created lazily and never serialized, so the service can
+   * be captured in Spark closures.
+   *
+   * <p><b>Important:</b> Callers are responsible for shutting down the returned executor service
+   * when it is no longer needed to prevent resource leaks.
+   */
+  @Nullable
+  public static ExecutorService migrationService(int parallelism) {
+    return parallelism == 1 ? null : new LazyExecutorService(parallelism);
+  }
+
+  /**
+   * A serializable {@link ExecutorService} that delegates to a thread pool created on first use.
+   *
+   * <p>The delegate is {@code transient} because thread pools are not serializable: a
+   * deserialized copy starts without a pool and lazily creates its own, independent of (and not
+   * shut down by) the original instance. Lifecycle methods such as {@link #shutdown()} also go
+   * through {@link #getService()}, so the delegate records shutdown state even if no task was
+   * ever submitted.
+   */
+  private static class LazyExecutorService implements ExecutorService, Serializable {
+
+    private final int parallelism;
+    private transient volatile ExecutorService service;
+
+    LazyExecutorService(int parallelism) {
+      this.parallelism = parallelism;
+    }
+
+    @Override
+    public void shutdown() {
+      getService().shutdown();
+    }
+
+    @NotNull
+    @Override
+    public List<Runnable> shutdownNow() {
+      return getService().shutdownNow();
+    }
+
+    @Override
+    public boolean isShutdown() {
+      return getService().isShutdown();
+    }
+
+    @Override
+    public boolean isTerminated() {
+      return getService().isTerminated();
+    }
+
+    @Override
+    public boolean awaitTermination(long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException {
+      return getService().awaitTermination(timeout, unit);
+    }
+
+    @NotNull
+    @Override
+    public <T> Future<T> submit(@NotNull Callable<T> task) {
+      return getService().submit(task);
+    }
+
+    @NotNull
+    @Override
+    public <T> Future<T> submit(@NotNull Runnable task, T result) {
+      return getService().submit(task, result);
+    }
+
+    @NotNull
+    @Override
+    public Future<?> submit(@NotNull Runnable task) {
+      return getService().submit(task);
+    }
+
+    @NotNull
+    @Override
+    public <T> List<Future<T>> invokeAll(@NotNull Collection<? extends Callable<T>> tasks)
+        throws InterruptedException {
+      return getService().invokeAll(tasks);
+    }
+
+    @NotNull
+    @Override
+    public <T> List<Future<T>> invokeAll(
+        @NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException {
+      return getService().invokeAll(tasks, timeout, unit);
+    }
+
+    @NotNull
+    @Override
+    public <T> T invokeAny(@NotNull Collection<? extends Callable<T>> tasks)
+        throws InterruptedException, ExecutionException {
+      return getService().invokeAny(tasks);
+    }
+
+    @Override
+    public <T> T invokeAny(
+        @NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException, ExecutionException, TimeoutException {
+      return getService().invokeAny(tasks, timeout, unit);
+    }
+
+    @Override
+    public void execute(@NotNull Runnable command) {
+      getService().execute(command);
+    }
+
+    /** Returns the delegate pool, creating it on first call via double-checked locking. */
+    private ExecutorService getService() {
+      if (service == null) {
+        synchronized (this) {
+          if (service == null) {
+            service = TableMigrationUtil.migrationService(parallelism);
+          }
+        }
+      }
+      return service;
+    }
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
index a0bd04dd99..7c67a1aced 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
@@ -22,6 +22,7 @@ import java.util.Map;
 import org.apache.iceberg.actions.MigrateTable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.actions.MigrateTableSparkAction;
 import org.apache.iceberg.spark.actions.SparkActions;
 import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
@@ -110,7 +111,7 @@ class MigrateTableProcedure extends BaseProcedure {
       int parallelism = args.getInt(4);
       Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
       migrateTableSparkAction =
-          migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration"));
+          migrateTableSparkAction.executeWith(SparkTableUtil.migrationService(parallelism));
     }
 
     MigrateTable.Result result = migrateTableSparkAction.execute();
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
index f709f64ebf..37dfde76b7 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
@@ -22,6 +22,7 @@ import java.util.Map;
 import org.apache.iceberg.actions.SnapshotTable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.actions.SparkActions;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
@@ -106,7 +107,7 @@ class SnapshotTableProcedure extends BaseProcedure {
     if (!args.isNullAt(4)) {
       int parallelism = args.getInt(4);
       Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
-      action = action.executeWith(executorService(parallelism, "table-snapshot"));
+      action = action.executeWith(SparkTableUtil.migrationService(parallelism));
     }
 
     SnapshotTable.Result result = action.tableProperties(properties).execute();

Reply via email to