This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 28c246e2c0 Spark 3.4: Fix NotSerializableException when migrating Spark tables (#12705)
28c246e2c0 is described below

commit 28c246e2c020a1f849effcfc7d5d1852f1f59b02
Author: Manu Zhang <[email protected]>
AuthorDate: Wed Apr 2 22:12:18 2025 +0800

    Spark 3.4: Fix NotSerializableException when migrating Spark tables (#12705)
    
    backports #11157 to Spark 3.4
---
 .../spark/extensions/TestAddFilesProcedure.java    |  20 ++++
 .../extensions/TestMigrateTableProcedure.java      |  20 ++++
 .../extensions/TestSnapshotTableProcedure.java     |  20 ++++
 .../org/apache/iceberg/spark/SparkTableUtil.java   | 117 ++++++++++++++++++++-
 .../spark/procedures/MigrateTableProcedure.java    |   3 +-
 .../spark/procedures/SnapshotTableProcedure.java   |   3 +-
 6 files changed, 179 insertions(+), 4 deletions(-)

diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
index 67081ecb95..a34caabe01 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
@@ -1104,6 +1104,26 @@ public class TestAddFilesProcedure extends SparkExtensionsTestBase {
         sql("SELECT * FROM %s ORDER BY id", tableName));
   }
 
+  @Test
+  public void testAddFilesPartitionedWithParallelism() {
+    createPartitionedHiveTable();
+
+    createIcebergTable(
+        "id Integer, name String, dept String, subdept String", "PARTITIONED BY (id)");
+
+    List<Object[]> result =
+        sql(
+            "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
+            catalogName, tableName, sourceTableName);
+
+    assertOutput(result, 8L, 4L);
+
+    assertEquals(
+        "Iceberg table contains correct data",
+        sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", sourceTableName),
+        sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+  }
+
   private static final List<Object[]> EMPTY_QUERY_RESULT = Lists.newArrayList();
 
   private static final StructField[] STRUCT = {
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
index 7365f7a6df..95a0290ddd 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMigrateTableProcedure.java
@@ -238,4 +238,24 @@ public class TestMigrateTableProcedure extends SparkExtensionsTestBase {
     Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName);
     Assert.assertEquals(0L, result);
   }
+
+  @Test
+  public void testMigratePartitionedWithParallelism() throws IOException {
+    Assume.assumeTrue(catalogName.equals("spark_catalog"));
+
+    String location = temp.newFolder().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
+        tableName, location);
+    sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", tableName);
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(2L)),
+        sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2));
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row("a", 1L), row("b", 2L)),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
index 421d6efc93..1cd0545207 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSnapshotTableProcedure.java
@@ -223,4 +223,24 @@ public class TestSnapshotTableProcedure extends SparkExtensionsTestBase {
         .isInstanceOf(IllegalArgumentException.class)
         .hasMessage("Cannot handle an empty identifier for argument table");
   }
+
+  @Test
+  public void testSnapshotPartitionedWithParallelism() throws IOException {
+    String location = temp.newFolder().toString();
+    sql(
+        "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet PARTITIONED BY (id) LOCATION '%s'",
+        SOURCE_NAME, location);
+    sql("INSERT INTO TABLE %s (id, data) VALUES (1, 'a'), (2, 'b')", SOURCE_NAME);
+
+    assertEquals(
+        "Procedure output must match",
+        ImmutableList.of(row(2L)),
+        sql(
+            "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
+            catalogName, SOURCE_NAME, tableName, 2));
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(row("a", 1L), row("b", 2L)),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
index 12a38b68f0..1998f1c61a 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -23,6 +23,7 @@ import static org.apache.spark.sql.functions.col;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URI;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
@@ -30,7 +31,12 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.stream.Collectors;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
@@ -97,6 +103,8 @@ import org.apache.spark.sql.catalyst.parser.ParseException;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
 import scala.Function2;
 import scala.Option;
 import scala.Some;
@@ -490,7 +498,7 @@ public class SparkTableUtil {
         stagingDir,
         partitionFilter,
         checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
   }
 
   /**
@@ -714,7 +722,7 @@ public class SparkTableUtil {
         spec,
         stagingDir,
         checkDuplicateFiles,
-        TableMigrationUtil.migrationService(parallelism));
+        migrationService(parallelism));
   }
 
   /**
@@ -1062,4 +1070,109 @@ public class SparkTableUtil {
           tableName);
     }
   }
+
+  @Nullable
+  public static ExecutorService migrationService(int parallelism) {
+    return parallelism == 1 ? null : new LazyExecutorService(parallelism);
+  }
+
+  private static class LazyExecutorService implements ExecutorService, Serializable {
+
+    private final int parallelism;
+    private volatile ExecutorService service;
+
+    LazyExecutorService(int parallelism) {
+      this.parallelism = parallelism;
+    }
+
+    @Override
+    public void shutdown() {
+      getService().shutdown();
+    }
+
+    @NotNull
+    @Override
+    public List<Runnable> shutdownNow() {
+      return getService().shutdownNow();
+    }
+
+    @Override
+    public boolean isShutdown() {
+      return getService().isShutdown();
+    }
+
+    @Override
+    public boolean isTerminated() {
+      return getService().isTerminated();
+    }
+
+    @Override
+    public boolean awaitTermination(long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException {
+      return getService().awaitTermination(timeout, unit);
+    }
+
+    @NotNull
+    @Override
+    public <T> Future<T> submit(@NotNull Callable<T> task) {
+      return getService().submit(task);
+    }
+
+    @NotNull
+    @Override
+    public <T> Future<T> submit(@NotNull Runnable task, T result) {
+      return getService().submit(task, result);
+    }
+
+    @NotNull
+    @Override
+    public Future<?> submit(@NotNull Runnable task) {
+      return getService().submit(task);
+    }
+
+    @NotNull
+    @Override
+    public <T> List<Future<T>> invokeAll(@NotNull Collection<? extends Callable<T>> tasks)
+        throws InterruptedException {
+      return getService().invokeAll(tasks);
+    }
+
+    @NotNull
+    @Override
+    public <T> List<Future<T>> invokeAll(
+        @NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException {
+      return getService().invokeAll(tasks, timeout, unit);
+    }
+
+    @NotNull
+    @Override
+    public <T> T invokeAny(@NotNull Collection<? extends Callable<T>> tasks)
+        throws InterruptedException, ExecutionException {
+      return getService().invokeAny(tasks);
+    }
+
+    @Override
+    public <T> T invokeAny(
+        @NotNull Collection<? extends Callable<T>> tasks, long timeout, @NotNull TimeUnit unit)
+        throws InterruptedException, ExecutionException, TimeoutException {
+      return getService().invokeAny(tasks, timeout, unit);
+    }
+
+    @Override
+    public void execute(@NotNull Runnable command) {
+      getService().execute(command);
+    }
+
+    private ExecutorService getService() {
+      if (service == null) {
+        synchronized (this) {
+          if (service == null) {
+            service = TableMigrationUtil.migrationService(parallelism);
+          }
+        }
+      }
+      return service;
+    }
+  }
 }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
index a0bd04dd99..7c67a1aced 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java
@@ -22,6 +22,7 @@ import java.util.Map;
 import org.apache.iceberg.actions.MigrateTable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.actions.MigrateTableSparkAction;
 import org.apache.iceberg.spark.actions.SparkActions;
 import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
@@ -110,7 +111,7 @@ class MigrateTableProcedure extends BaseProcedure {
       int parallelism = args.getInt(4);
       Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
       migrateTableSparkAction =
-          migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration"));
+          migrateTableSparkAction.executeWith(SparkTableUtil.migrationService(parallelism));
     }
 
     MigrateTable.Result result = migrateTableSparkAction.execute();
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
index f709f64ebf..37dfde76b7 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java
@@ -22,6 +22,7 @@ import java.util.Map;
 import org.apache.iceberg.actions.SnapshotTable;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.actions.SparkActions;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
@@ -106,7 +107,7 @@ class SnapshotTableProcedure extends BaseProcedure {
     if (!args.isNullAt(4)) {
       int parallelism = args.getInt(4);
       Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
-      action = action.executeWith(executorService(parallelism, "table-snapshot"));
+      action = action.executeWith(SparkTableUtil.migrationService(parallelism));
     }
 
     SnapshotTable.Result result = action.tableProperties(properties).execute();

Reply via email to