This is an automated email from the ASF dual-hosted git repository.

ahmedabu98 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 51e35e52440 [IcebergIO] Improve TableCache (#38882)
51e35e52440 is described below

commit 51e35e524400373654fff8d998a17a407877f719
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Thu Jun 11 09:23:01 2026 -0400

    [IcebergIO] Improve TableCache (#38882)
    
    * improve table cache
    
    * address comments
    
    * cleanups
---
 .../IO_Iceberg_Integration_Tests.json              |   2 +-
 .../iceberg/AssignDestinationsAndPartitions.java   |   5 +-
 .../beam/sdk/io/iceberg/CreateReadTasksDoFn.java   |   7 +-
 .../org/apache/beam/sdk/io/iceberg/IcebergIO.java  |   2 +-
 .../beam/sdk/io/iceberg/IcebergScanConfig.java     |   3 +-
 .../beam/sdk/io/iceberg/IncrementalScanSource.java |   6 +-
 .../apache/beam/sdk/io/iceberg/ReadFromTasks.java  |   7 +-
 .../beam/sdk/io/iceberg/RecordWriterManager.java   | 140 +++++----------
 .../org/apache/beam/sdk/io/iceberg/ScanSource.java |   3 +-
 .../org/apache/beam/sdk/io/iceberg/TableCache.java | 198 ++++++++++++++++-----
 .../beam/sdk/io/iceberg/WatchForSnapshots.java     |   7 +-
 .../sdk/io/iceberg/WriteDirectRowsToFiles.java     |  12 +-
 .../sdk/io/iceberg/WriteGroupedRowsToFiles.java    |  12 +-
 .../io/iceberg/WritePartitionedRowsToFiles.java    | 166 +++++++----------
 .../sdk/io/iceberg/WriteUngroupedRowsToFiles.java  |  12 +-
 .../sdk/io/iceberg/RecordWriterManagerTest.java    | 121 ++++++++-----
 .../apache/beam/sdk/io/iceberg/TableCacheTest.java | 126 +++++++++++++
 17 files changed, 479 insertions(+), 350 deletions(-)

diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json
index 7ab7bcd9a9c..b73af5e61a4 100644
--- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json
+++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json
@@ -1,4 +1,4 @@
 {
     "comment": "Modify this file in a trivial way to cause this test suite to 
run.",
-    "modification": 2
+    "modification": 1
 }
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AssignDestinationsAndPartitions.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AssignDestinationsAndPartitions.java
index e5d70d85d87..99cd07b23c8 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AssignDestinationsAndPartitions.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AssignDestinationsAndPartitions.java
@@ -147,7 +147,10 @@ class AssignDestinationsAndPartitions
 
           try {
             // see if table already exists with a spec
-            spec = 
catalogConfig.catalog().loadTable(TableIdentifier.parse(tableIdentifier)).spec();
+            spec =
+                TableCache.getAndRefreshIfStale(
+                        catalogConfig, TableIdentifier.parse(tableIdentifier))
+                    .spec();
 
           } catch (NoSuchTableException ignored) {
             // no partition to apply
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/CreateReadTasksDoFn.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/CreateReadTasksDoFn.java
index a40e0e13f78..d9781859a88 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/CreateReadTasksDoFn.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/CreateReadTasksDoFn.java
@@ -52,18 +52,13 @@ class CreateReadTasksDoFn
     this.scanConfig = scanConfig;
   }
 
-  @Setup
-  public void setup() {
-    TableCache.setup(scanConfig);
-  }
-
   @ProcessElement
   public void process(
       @Element KV<String, List<SnapshotInfo>> element,
       OutputReceiver<KV<ReadTaskDescriptor, ReadTask>> out)
       throws IOException, ExecutionException {
     // force refresh because the table must be updated before scanning 
snapshots
-    Table table = TableCache.getRefreshed(element.getKey());
+    Table table = TableCache.getRefreshed(scanConfig.getCatalogConfig(), 
element.getKey());
 
     // scan snapshots individually and assign commit timestamp to files
     for (SnapshotInfo snapshot : element.getValue()) {
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java
index a5a3beef8f5..5c5f934ea20 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java
@@ -655,7 +655,7 @@ public class IcebergIO {
       TableIdentifier tableId =
           checkStateNotNull(getTableIdentifier(), "Must set a table to read 
from.");
 
-      Table table = getCatalogConfig().catalog().loadTable(tableId);
+      Table table = TableCache.get(getCatalogConfig(), tableId);
 
       IcebergScanConfig scanConfig =
           IcebergScanConfig.builder()
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java
index 45ecc7cf71c..a942c9804c9 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergScanConfig.java
@@ -68,8 +68,7 @@ public abstract class IcebergScanConfig implements Serializable {
   @Pure
   public Table getTable() {
     if (cachedTable == null) {
-      cachedTable =
-          
getCatalogConfig().catalog().loadTable(TableIdentifier.parse(getTableIdentifier()));
+      cachedTable = TableCache.get(getCatalogConfig(), 
TableIdentifier.parse(getTableIdentifier()));
     }
     return cachedTable;
   }
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IncrementalScanSource.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IncrementalScanSource.java
index 4df3eecb18e..58cc8f50e0b 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IncrementalScanSource.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IncrementalScanSource.java
@@ -53,10 +53,8 @@ class IncrementalScanSource extends PTransform<PBegin, PCollection<Row>> {
   @Override
   public PCollection<Row> expand(PBegin input) {
     Table table =
-        scanConfig
-            .getCatalogConfig()
-            .catalog()
-            .loadTable(TableIdentifier.parse(scanConfig.getTableIdentifier()));
+        TableCache.get(
+            scanConfig.getCatalogConfig(), 
TableIdentifier.parse(scanConfig.getTableIdentifier()));
 
     PCollection<KV<String, List<SnapshotInfo>>> snapshots =
         MoreObjects.firstNonNull(scanConfig.getStreaming(), false)
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ReadFromTasks.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ReadFromTasks.java
index 528b89c203b..5eeeacda48e 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ReadFromTasks.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ReadFromTasks.java
@@ -51,11 +51,6 @@ class ReadFromTasks extends DoFn<KV<ReadTaskDescriptor, ReadTask>, Row> {
     this.scanConfig = scanConfig;
   }
 
-  @Setup
-  public void setup() {
-    TableCache.setup(scanConfig);
-  }
-
   @ProcessElement
   public void process(
       @Element KV<ReadTaskDescriptor, ReadTask> element,
@@ -63,7 +58,7 @@ class ReadFromTasks extends DoFn<KV<ReadTaskDescriptor, ReadTask>, Row> {
       OutputReceiver<Row> out)
       throws IOException, ExecutionException, InterruptedException {
     ReadTask readTask = element.getValue();
-    Table table = TableCache.get(scanConfig.getTableIdentifier());
+    Table table = TableCache.get(scanConfig.getCatalogConfig(), 
scanConfig.getTableIdentifier());
 
     List<FileScanTask> fileScanTasks = readTask.getFileScanTasks();
 
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java
index 2f532a08754..1e25e6b9c23 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java
@@ -21,8 +21,6 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Pr
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
 
 import java.io.IOException;
-import java.time.Duration;
-import java.time.Instant;
 import java.time.LocalDateTime;
 import java.time.YearMonth;
 import java.time.ZoneOffset;
@@ -248,7 +246,7 @@ class RecordWriterManager implements AutoCloseable {
       DateTimeFormatter.ofPattern("yyyy-MM-dd-HH");
   private static final LocalDateTime EPOCH = LocalDateTime.ofEpochSecond(0, 0, 
ZoneOffset.UTC);
 
-  private final Catalog catalog;
+  private final IcebergCatalogConfig catalogConfig;
   private final String filePrefix;
   private final long maxFileSize;
   private final int maxNumWriters;
@@ -260,46 +258,11 @@ class RecordWriterManager implements AutoCloseable {
   private final Map<WindowedValue<IcebergDestination>, 
List<SerializableDataFile>>
       totalSerializableDataFiles = Maps.newHashMap();
 
-  static final class LastRefreshedTable {
-    final Table table;
-    volatile Instant lastRefreshTime;
-    static final Duration STALENESS_THRESHOLD = Duration.ofMinutes(2);
-
-    LastRefreshedTable(Table table, Instant lastRefreshTime) {
-      this.table = table;
-      this.lastRefreshTime = lastRefreshTime;
-    }
-
-    /**
-     * Refreshes the table metadata if it is considered stale (older than 2 
minutes).
-     *
-     * <p>This method first performs a non-synchronized check on the table's 
freshness. This
-     * provides a lock-free fast path that avoids synchronization overhead in 
the common case where
-     * the table does not need to be refreshed. If the table might be stale, 
it then enters a
-     * synchronized block to ensure that only one thread performs the refresh 
operation.
-     */
-    void refreshIfStale() {
-      // Fast path: Avoid entering the synchronized block if the table is not 
stale.
-      if (lastRefreshTime.isAfter(Instant.now().minus(STALENESS_THRESHOLD))) {
-        return;
-      }
-      synchronized (this) {
-        if 
(lastRefreshTime.isBefore(Instant.now().minus(STALENESS_THRESHOLD))) {
-          table.refresh();
-          lastRefreshTime = Instant.now();
-        }
-      }
-    }
-  }
-
-  @VisibleForTesting
-  static final Cache<TableIdentifier, LastRefreshedTable> 
LAST_REFRESHED_TABLE_CACHE =
-      CacheBuilder.newBuilder().expireAfterAccess(10, 
TimeUnit.MINUTES).build();
-
   private boolean isClosed = false;
 
-  RecordWriterManager(Catalog catalog, String filePrefix, long maxFileSize, 
int maxNumWriters) {
-    this.catalog = catalog;
+  RecordWriterManager(
+      IcebergCatalogConfig catalogConfig, String filePrefix, long maxFileSize, 
int maxNumWriters) {
+    this.catalogConfig = catalogConfig;
     this.filePrefix = filePrefix;
     this.maxFileSize = maxFileSize;
     this.maxNumWriters = maxNumWriters;
@@ -308,9 +271,9 @@ class RecordWriterManager implements AutoCloseable {
   /**
    * Returns an Iceberg {@link Table}.
    *
-   * <p>First attempts to fetch the table from the {@link 
#LAST_REFRESHED_TABLE_CACHE}. If it's not
-   * there, we attempt to load it using the Iceberg API. If the table doesn't 
exist at all, we
-   * attempt to create it, inferring the table schema from the record schema.
+   * <p>First attempts to fetch the table from the shared {@link TableCache}. 
If it's not there, we
+   * attempt to load it using the Iceberg API. If the table doesn't exist at 
all, we attempt to
+   * create it, inferring the table schema from the record schema.
    *
    * <p>Note that this is a best-effort operation that depends on the {@link 
Catalog}
    * implementation. Although it is expected, some implementations may not 
support creating a table
@@ -319,13 +282,13 @@ class RecordWriterManager implements AutoCloseable {
   @VisibleForTesting
   Table getOrCreateTable(IcebergDestination destination, Schema dataSchema) {
     TableIdentifier identifier = destination.getTableIdentifier();
-    @Nullable
-    LastRefreshedTable lastRefreshedTable = 
LAST_REFRESHED_TABLE_CACHE.getIfPresent(identifier);
-    if (lastRefreshedTable != null && lastRefreshedTable.table != null) {
-      lastRefreshedTable.refreshIfStale();
-      return lastRefreshedTable.table;
-    }
+    return TableCache.getAndRefreshIfStale(
+        catalogConfig, identifier, () -> loadOrCreateTable(destination, 
dataSchema));
+  }
 
+  private Table loadOrCreateTable(IcebergDestination destination, Schema 
dataSchema) {
+    Catalog catalog = catalogConfig.catalog();
+    TableIdentifier identifier = destination.getTableIdentifier();
     Namespace namespace = identifier.namespace();
     @Nullable IcebergTableCreateConfig createConfig = 
destination.getTableCreateConfig();
     PartitionSpec partitionSpec =
@@ -336,53 +299,48 @@ class RecordWriterManager implements AutoCloseable {
             ? createConfig.getTableProperties()
             : Maps.newHashMap();
 
-    @Nullable Table table = null;
-    synchronized (LAST_REFRESHED_TABLE_CACHE) {
-      // Create namespace if it does not exist yet
-      if (!namespace.isEmpty() && catalog instanceof SupportsNamespaces) {
-        SupportsNamespaces supportsNamespaces = (SupportsNamespaces) catalog;
-        if (!supportsNamespaces.namespaceExists(namespace)) {
-          try {
-            supportsNamespaces.createNamespace(namespace);
-            LOG.info("Created new namespace '{}'.", namespace);
-          } catch (AlreadyExistsException ignored) {
-            // race condition: another worker already created this namespace
-          }
+    // Create namespace if it does not exist yet
+    if (!namespace.isEmpty() && catalog instanceof SupportsNamespaces) {
+      SupportsNamespaces supportsNamespaces = (SupportsNamespaces) catalog;
+      if (!supportsNamespaces.namespaceExists(namespace)) {
+        try {
+          supportsNamespaces.createNamespace(namespace);
+          LOG.info("Created new namespace '{}'.", namespace);
+        } catch (AlreadyExistsException ignored) {
+          // race condition: another worker already created this namespace
         }
       }
+    }
 
-      // If table exists, just load it
-      // Note: the implementation of catalog.tableExists() will load the table 
to check its
-      // existence. We don't use it here to avoid double loadTable() calls.
+    // If table exists, just load it
+    // Note: the implementation of catalog.tableExists() will load the table 
to check its
+    // existence. We don't use it here to avoid double loadTable() calls.
+    try {
+      return catalog.loadTable(identifier);
+    } catch (NoSuchTableException e) { // Otherwise, create the table
+      org.apache.iceberg.Schema tableSchema = 
IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
       try {
-        table = catalog.loadTable(identifier);
-      } catch (NoSuchTableException e) { // Otherwise, create the table
-        org.apache.iceberg.Schema tableSchema = 
IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
-        try {
-          table =
-              catalog
-                  .buildTable(identifier, tableSchema)
-                  .withPartitionSpec(partitionSpec)
-                  .withSortOrder(sortOrder)
-                  .withProperties(tableProperties)
-                  .create();
-          LOG.info(
-              "Created Iceberg table '{}' with schema: {}\n"
-                  + ", partition spec: {}, sort order: {}, table properties: 
{}",
-              identifier,
-              tableSchema,
-              partitionSpec,
-              sortOrder,
-              tableProperties);
-        } catch (AlreadyExistsException ignored) {
-          // race condition: another worker already created this table
-          table = catalog.loadTable(identifier);
-        }
+        Table table =
+            catalog
+                .buildTable(identifier, tableSchema)
+                .withPartitionSpec(partitionSpec)
+                .withSortOrder(sortOrder)
+                .withProperties(tableProperties)
+                .create();
+        LOG.info(
+            "Created Iceberg table '{}' with schema: {}\n"
+                + ", partition spec: {}, sort order: {}, table properties: {}",
+            identifier,
+            tableSchema,
+            partitionSpec,
+            sortOrder,
+            tableProperties);
+        return table;
+      } catch (AlreadyExistsException ignored) {
+        // race condition: another worker already created this table
+        return catalog.loadTable(identifier);
       }
     }
-    lastRefreshedTable = new LastRefreshedTable(table, Instant.now());
-    LAST_REFRESHED_TABLE_CACHE.put(identifier, lastRefreshedTable);
-    return table;
   }
 
   /**
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanSource.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanSource.java
index 19218b85b63..c407ef8d3e2 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanSource.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanSource.java
@@ -47,7 +47,8 @@ class ScanSource extends BoundedSource<Row> {
   }
 
   private TableScan getTableScan() {
-    Table table = scanConfig.getTable();
+    Table table =
+        TableCache.getRefreshed(scanConfig.getCatalogConfig(), 
scanConfig.getTableIdentifier());
     TableScan tableScan = 
table.newScan().project(scanConfig.getProjectedSchema());
 
     if (scanConfig.getFilter() != null) {
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/TableCache.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/TableCache.java
index cb00d90f7fb..e95a6a5f66f 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/TableCache.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/TableCache.java
@@ -17,70 +17,170 @@
  */
 package org.apache.beam.sdk.io.iceberg;
 
-import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
-import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Futures;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListenableFuture;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.UncheckedExecutionException;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.catalog.TableIdentifier;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
-/** Utility to fetch and cache Iceberg {@link Table}s. */
+/**
+ * Process-wide cache for Iceberg {@link Table}s.
+ *
+ * <p>Entries are keyed by catalog configuration and table identifier, so one 
machine can share
+ * table metadata across source and sink threads without colliding when 
different catalogs contain
+ * the same identifier. The underlying catalog is only resolved from {@link 
IcebergCatalogConfig}
+ * when the table has to be loaded. Refreshes are synchronized per table 
entry: if another thread
+ * refreshed after a caller started its request, the caller reuses that 
refresh instead of making
+ * another catalog call.
+ */
 class TableCache {
-  private static final Map<String, IcebergCatalogConfig> CATALOG_CACHE = new 
ConcurrentHashMap<>();
-  private static final LoadingCache<String, Table> INTERNAL_CACHE =
-      CacheBuilder.newBuilder()
-          .expireAfterAccess(1, TimeUnit.HOURS)
-          .refreshAfterWrite(3, TimeUnit.MINUTES)
-          .build(
-              new CacheLoader<String, Table>() {
-                @Override
-                public Table load(String identifier) {
-                  return checkStateNotNull(CATALOG_CACHE.get(identifier))
-                      .catalog()
-                      .loadTable(TableIdentifier.parse(identifier));
-                }
-
-                @Override
-                public ListenableFuture<Table> reload(String unusedIdentifier, 
Table table) {
-                  table.refresh();
-                  return Futures.immediateFuture(table);
-                }
-              });;
-
-  static Table get(String identifier) {
+  static final Duration DEFAULT_REFRESH_INTERVAL = Duration.ofMinutes(2);
+
+  private static final Cache<CacheKey, CachedTable> TABLES =
+      CacheBuilder.newBuilder().expireAfterAccess(1, TimeUnit.HOURS).build();
+
+  /** Returns the cached table, loading it from the catalog on a cache miss. */
+  static Table get(IcebergCatalogConfig catalogConfig, TableIdentifier 
identifier) {
+    return get(catalogConfig, identifier, () -> 
catalogConfig.catalog().loadTable(identifier));
+  }
+
+  /** Returns the cached table for a string identifier, loading it on a cache 
miss. */
+  static Table get(IcebergCatalogConfig catalogConfig, String identifier) {
+    return get(catalogConfig, TableIdentifier.parse(identifier));
+  }
+
+  /** Returns the cached table, using the given loader only on a cache miss. */
+  static Table get(
+      IcebergCatalogConfig catalogConfig, TableIdentifier identifier, 
Callable<Table> loader) {
+    return getEntry(catalogConfig, identifier, loader).table;
+  }
+
+  /** Returns the cached table after forcing a refresh of any pre-existing 
cache entry. */
+  static Table getRefreshed(IcebergCatalogConfig catalogConfig, 
TableIdentifier identifier) {
+    Instant refreshRequestTime = Instant.now();
+    CachedTable cachedTable =
+        getEntry(catalogConfig, identifier, () -> 
catalogConfig.catalog().loadTable(identifier));
+    cachedTable.refreshIfOlderThan(refreshRequestTime);
+    return cachedTable.table;
+  }
+
+  /** Returns the cached table for a string identifier after refreshing any 
pre-existing entry. */
+  static Table getRefreshed(IcebergCatalogConfig catalogConfig, String 
identifier) {
+    return getRefreshed(catalogConfig, TableIdentifier.parse(identifier));
+  }
+
+  /**
+   * Returns the cached table, refreshing it only if it is older than {@link
+   * #DEFAULT_REFRESH_INTERVAL}.
+   */
+  static Table getAndRefreshIfStale(
+      IcebergCatalogConfig catalogConfig, TableIdentifier identifier) {
+    return getAndRefreshIfStale(
+        catalogConfig, identifier, () -> 
catalogConfig.catalog().loadTable(identifier));
+  }
+
+  /** Returns the cached table, using the loader on a miss and refreshing 
stale entries. */
+  static Table getAndRefreshIfStale(
+      IcebergCatalogConfig catalogConfig, TableIdentifier identifier, 
Callable<Table> loader) {
+    CachedTable cachedTable = getEntry(catalogConfig, identifier, loader);
+    
cachedTable.refreshIfOlderThan(Instant.now().minus(DEFAULT_REFRESH_INTERVAL));
+    return cachedTable.table;
+  }
+
+  private static CachedTable getEntry(
+      IcebergCatalogConfig catalogConfig, TableIdentifier identifier, 
Callable<Table> loader) {
+    CacheKey key = new CacheKey(catalogConfig, identifier);
     try {
-      return INTERNAL_CACHE.get(identifier);
-    } catch (ExecutionException e) {
+      return TABLES.get(key, () -> new CachedTable(loader.call(), 
Instant.now()));
+    } catch (ExecutionException | UncheckedExecutionException e) {
+      if (e.getCause() instanceof RuntimeException) {
+        throw (RuntimeException) e.getCause();
+      }
       throw new RuntimeException(
           "Encountered a problem fetching table " + identifier + " from 
cache.", e);
     }
   }
 
-  /** Forces a table refresh and returns. */
-  static Table getRefreshed(String identifier) {
-    INTERNAL_CACHE.refresh(identifier);
-    return get(identifier);
+  @VisibleForTesting
+  static long size() {
+    return TABLES.size();
   }
 
-  static void setup(IcebergScanConfig scanConfig) {
-    String tableIdentifier = scanConfig.getTableIdentifier();
-    IcebergCatalogConfig catalogConfig = scanConfig.getCatalogConfig();
-    if (CATALOG_CACHE.containsKey(tableIdentifier)) {
-      checkState(
-          catalogConfig.equals(CATALOG_CACHE.get(tableIdentifier)),
-          "TableCache is already set up with a different catalog. " + 
"Existing: %s, new: %s.",
-          CATALOG_CACHE.get(tableIdentifier),
-          catalogConfig);
-    } else {
-      CATALOG_CACHE.put(scanConfig.getTableIdentifier(), 
scanConfig.getCatalogConfig());
+  @VisibleForTesting
+  static void invalidateAll() {
+    TABLES.invalidateAll();
+  }
+
+  @VisibleForTesting
+  static void put(
+      IcebergCatalogConfig catalogConfig,
+      TableIdentifier identifier,
+      Table table,
+      Instant lastRefreshTime) {
+    TABLES.put(new CacheKey(catalogConfig, identifier), new CachedTable(table, 
lastRefreshTime));
+  }
+
+  @VisibleForTesting
+  static void markStale(IcebergCatalogConfig catalogConfig, TableIdentifier 
identifier) {
+    CachedTable cachedTable = TABLES.getIfPresent(new CacheKey(catalogConfig, 
identifier));
+    if (cachedTable != null) {
+      cachedTable.lastRefreshTime = Instant.EPOCH;
+    }
+  }
+
+  private static class CachedTable {
+    private final Table table;
+    private volatile Instant lastRefreshTime;
+
+    private CachedTable(Table table, Instant lastRefreshTime) {
+      this.table = table;
+      this.lastRefreshTime = lastRefreshTime;
+    }
+
+    private void refreshIfOlderThan(Instant refreshRequestTime) {
+      if (lastRefreshTime.isAfter(refreshRequestTime)) {
+        return;
+      }
+      synchronized (this) {
+        if (lastRefreshTime.isBefore(refreshRequestTime)) {
+          table.refresh();
+          lastRefreshTime = Instant.now();
+        }
+      }
+    }
+  }
+
+  private static class CacheKey {
+    private final IcebergCatalogConfig catalogConfig;
+    private final TableIdentifier identifier;
+
+    private CacheKey(IcebergCatalogConfig catalogConfig, TableIdentifier 
identifier) {
+      this.catalogConfig = catalogConfig;
+      this.identifier = identifier;
+    }
+
+    @Override
+    public boolean equals(@Nullable Object obj) {
+      if (this == obj) {
+        return true;
+      }
+      if (!(obj instanceof CacheKey)) {
+        return false;
+      }
+      CacheKey other = (CacheKey) obj;
+      return catalogConfig.equals(other.catalogConfig) && 
identifier.equals(other.identifier);
+    }
+
+    @Override
+    public int hashCode() {
+      return 31 * catalogConfig.hashCode() + identifier.hashCode();
     }
   }
 }
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WatchForSnapshots.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WatchForSnapshots.java
index 1af5588c251..8bd436c5570 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WatchForSnapshots.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WatchForSnapshots.java
@@ -87,7 +87,6 @@ class WatchForSnapshots extends PTransform<PBegin, PCollection<KV<String, List<S
   private static class SnapshotPollFn extends Watch.Growth.PollFn<String, 
List<SnapshotInfo>> {
     private final IcebergScanConfig scanConfig;
     private @Nullable Long fromSnapshotId;
-    boolean isCacheSetup = false;
 
     SnapshotPollFn(IcebergScanConfig scanConfig) {
       this.scanConfig = scanConfig;
@@ -95,11 +94,7 @@ class WatchForSnapshots extends PTransform<PBegin, PCollection<KV<String, List<S
 
     @Override
     public PollResult<List<SnapshotInfo>> apply(String tableIdentifier, 
Context c) {
-      if (!isCacheSetup) {
-        TableCache.setup(scanConfig);
-        isCacheSetup = true;
-      }
-      Table table = TableCache.getRefreshed(tableIdentifier);
+      Table table = TableCache.getRefreshed(scanConfig.getCatalogConfig(), 
tableIdentifier);
 
       @Nullable Long userSpecifiedToSnapshot = ReadUtils.getToSnapshot(table, 
scanConfig);
       boolean isComplete = userSpecifiedToSnapshot != null;
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteDirectRowsToFiles.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteDirectRowsToFiles.java
index fbd6c15095e..5cf095dc3c3 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteDirectRowsToFiles.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteDirectRowsToFiles.java
@@ -31,8 +31,6 @@ import org.apache.beam.sdk.values.WindowedValue;
 import org.apache.beam.sdk.values.WindowedValues;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
-import org.apache.iceberg.catalog.Catalog;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 class WriteDirectRowsToFiles
@@ -66,7 +64,6 @@ class WriteDirectRowsToFiles
 
     private final DynamicDestinations dynamicDestinations;
     private final IcebergCatalogConfig catalogConfig;
-    private transient @MonotonicNonNull Catalog catalog;
     private final String filePrefix;
     private final long maxFileSize;
     private transient @Nullable RecordWriterManager recordWriterManager;
@@ -83,17 +80,10 @@ class WriteDirectRowsToFiles
       this.recordWriterManager = null;
     }
 
-    private org.apache.iceberg.catalog.Catalog getCatalog() {
-      if (catalog == null) {
-        this.catalog = catalogConfig.catalog();
-      }
-      return catalog;
-    }
-
     @StartBundle
     public void startBundle() {
       recordWriterManager =
-          new RecordWriterManager(getCatalog(), filePrefix, maxFileSize, 
Integer.MAX_VALUE);
+          new RecordWriterManager(catalogConfig, filePrefix, maxFileSize, 
Integer.MAX_VALUE);
     }
 
     @ProcessElement
diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java
index 12d9570d4a3..b16496240d1 100644
--- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java
+++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java
@@ -30,8 +30,6 @@ import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.WindowedValue;
 import org.apache.beam.sdk.values.WindowedValues;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
-import org.apache.iceberg.catalog.Catalog;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 
 class WriteGroupedRowsToFiles
     extends PTransform<
@@ -67,7 +65,6 @@ class WriteGroupedRowsToFiles
 
     private final DynamicDestinations dynamicDestinations;
     private final IcebergCatalogConfig catalogConfig;
-    private transient @MonotonicNonNull Catalog catalog;
     private final String filePrefix;
     private final long maxFileSize;
 
@@ -82,13 +79,6 @@ class WriteGroupedRowsToFiles
       this.maxFileSize = maxFileSize;
     }
 
-    private org.apache.iceberg.catalog.Catalog getCatalog() {
-      if (catalog == null) {
-        this.catalog = catalogConfig.catalog();
-      }
-      return catalog;
-    }
-
     @ProcessElement
     public void processElement(
         ProcessContext c,
@@ -103,7 +93,7 @@ class WriteGroupedRowsToFiles
           WindowedValues.of(destination, window.maxTimestamp(), window, 
paneInfo);
       RecordWriterManager writer;
       try (RecordWriterManager openWriter =
-          new RecordWriterManager(getCatalog(), filePrefix, maxFileSize, 
Integer.MAX_VALUE)) {
+          new RecordWriterManager(catalogConfig, filePrefix, maxFileSize, 
Integer.MAX_VALUE)) {
         writer = openWriter;
         for (Row e : element.getValue()) {
           writer.write(windowedDestination, e);
diff --git 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WritePartitionedRowsToFiles.java
 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WritePartitionedRowsToFiles.java
index 54ad120f1ac..8b4ae0863f7 100644
--- 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WritePartitionedRowsToFiles.java
+++ 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WritePartitionedRowsToFiles.java
@@ -22,11 +22,8 @@ import static 
org.apache.beam.sdk.io.iceberg.AssignDestinationsAndPartitions.PAR
 import static 
org.apache.beam.sdk.io.iceberg.RecordWriterManager.getPartitionDataPath;
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 
-import java.time.Duration;
-import java.time.Instant;
 import java.util.Map;
 import java.util.UUID;
-import java.util.concurrent.TimeUnit;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.RowCoder;
@@ -37,8 +34,6 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.iceberg.DataFiles;
 import org.apache.iceberg.PartitionField;
@@ -53,6 +48,7 @@ import org.apache.iceberg.catalog.TableIdentifier;
 import org.apache.iceberg.data.Record;
 import org.apache.iceberg.exceptions.AlreadyExistsException;
 import org.apache.iceberg.exceptions.NoSuchTableException;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -91,8 +87,9 @@ class WritePartitionedRowsToFiles
     private final IcebergCatalogConfig catalogConfig;
     private final String filePrefix;
     private final Schema dataSchema;
-    static final Cache<TableIdentifier, LastRefreshedTable> 
LAST_REFRESHED_TABLE_CACHE =
-        CacheBuilder.newBuilder().expireAfterAccess(10, 
TimeUnit.MINUTES).build();
+    private transient @MonotonicNonNull Map<TableIdentifier, Integer> specIds;
+    private transient @MonotonicNonNull Map<TableIdentifier, Map<String, 
PartitionField>>
+        partitionFieldMaps;
 
     WriteDoFn(
         IcebergCatalogConfig catalogConfig,
@@ -105,6 +102,12 @@ class WritePartitionedRowsToFiles
       this.dataSchema = dataSchema;
     }
 
+    @Setup
+    public void setup() {
+      partitionFieldMaps = Maps.newHashMap();
+      specIds = Maps.newHashMap();
+    }
+
     @ProcessElement
     public void processElement(
         @Element KV<Row, Iterable<Row>> element, 
OutputReceiver<FileWriteResult> out)
@@ -113,9 +116,10 @@ class WritePartitionedRowsToFiles
       String partitionPath = 
checkStateNotNull(element.getKey().getString(PARTITION));
 
       IcebergDestination destination = 
dynamicDestinations.instantiateDestination(tableIdentifier);
-      LastRefreshedTable lastRefreshedTable = getOrCreateTable(destination, 
dataSchema);
-      Table table = lastRefreshedTable.table;
-      partitionPath = getPartitionDataPath(partitionPath, 
lastRefreshedTable.partitionFieldMap);
+      Table table = getOrCreateTable(destination, dataSchema);
+      partitionPath =
+          getPartitionDataPath(
+              partitionPath, 
getPartitionFieldMap(destination.getTableIdentifier(), table));
 
       StructLike partitionData =
           table.spec().isPartitioned()
@@ -146,60 +150,32 @@ class WritePartitionedRowsToFiles
               .build());
     }
 
-    static final class LastRefreshedTable {
-      final Table table;
-      volatile Instant lastRefreshTime;
-      static final Duration STALENESS_THRESHOLD = Duration.ofMinutes(2);
-      private int specId;
-      volatile Map<String, PartitionField> partitionFieldMap = 
Maps.newHashMap();
-
-      LastRefreshedTable(Table table, Instant lastRefreshTime) {
-        this.table = table;
-        this.specId = table.spec().specId();
-        this.lastRefreshTime = lastRefreshTime;
-        for (PartitionField partitionField : table.spec().fields()) {
-          partitionFieldMap.put(partitionField.name(), partitionField);
-        }
+    private Map<String, PartitionField> getPartitionFieldMap(
+        TableIdentifier identifier, Table table) {
+      @Nullable Integer specId = checkStateNotNull(specIds).get(identifier);
+      if (specId != null && specId == table.spec().specId()) {
+        return 
checkStateNotNull(checkStateNotNull(partitionFieldMaps).get(identifier));
       }
-
-      /**
-       * Refreshes the table metadata if it is considered stale (older than 2 
minutes).
-       *
-       * <p>This method first performs a non-synchronized check on the table's 
freshness. This
-       * provides a lock-free fast path that avoids synchronization overhead 
in the common case
-       * where the table does not need to be refreshed. If the table might be 
stale, it then enters
-       * a synchronized block to ensure that only one thread performs the 
refresh operation.
-       */
-      void refreshIfStale() {
-        // Fast path: Avoid entering the synchronized block if the table is 
not stale.
-        if (lastRefreshTime.isAfter(Instant.now().minus(STALENESS_THRESHOLD))) 
{
-          return;
-        }
-        synchronized (this) {
-          if 
(lastRefreshTime.isBefore(Instant.now().minus(STALENESS_THRESHOLD))) {
-            table.refresh();
-            lastRefreshTime = Instant.now();
-            if (table.spec().specId() != this.specId) {
-              partitionFieldMap = Maps.newHashMap();
-              for (PartitionField partitionField : table.spec().fields()) {
-                partitionFieldMap.put(partitionField.name(), partitionField);
-              }
-              this.specId = table.spec().specId();
-            }
-          }
-        }
+      Map<String, PartitionField> partitionFieldMap = Maps.newHashMap();
+      for (PartitionField partitionField : table.spec().fields()) {
+        partitionFieldMap.put(partitionField.name(), partitionField);
       }
+      checkStateNotNull(specIds).put(identifier, table.spec().specId());
+      checkStateNotNull(partitionFieldMaps).put(identifier, partitionFieldMap);
+      return partitionFieldMap;
     }
 
-    LastRefreshedTable getOrCreateTable(IcebergDestination destination, Schema 
dataSchema) {
+    Table getOrCreateTable(IcebergDestination destination, Schema dataSchema) {
       TableIdentifier identifier = destination.getTableIdentifier();
-      @Nullable
-      LastRefreshedTable lastRefreshedTable = 
LAST_REFRESHED_TABLE_CACHE.getIfPresent(identifier);
-      if (lastRefreshedTable != null) {
-        lastRefreshedTable.refreshIfStale();
-        return lastRefreshedTable;
-      }
+      return TableCache.getAndRefreshIfStale(
+          catalogConfig,
+          identifier,
+          () -> loadOrCreateTable(catalogConfig.catalog(), destination, 
dataSchema));
+    }
 
+    private Table loadOrCreateTable(
+        Catalog catalog, IcebergDestination destination, Schema dataSchema) {
+      TableIdentifier identifier = destination.getTableIdentifier();
       Namespace namespace = identifier.namespace();
       @Nullable IcebergTableCreateConfig createConfig = 
destination.getTableCreateConfig();
       PartitionSpec partitionSpec =
@@ -209,55 +185,43 @@ class WritePartitionedRowsToFiles
               ? createConfig.getTableProperties()
               : Maps.newHashMap();
 
-      @Nullable Table table = null;
-      synchronized (LAST_REFRESHED_TABLE_CACHE) {
-        lastRefreshedTable = 
LAST_REFRESHED_TABLE_CACHE.getIfPresent(identifier);
-        if (lastRefreshedTable != null) {
-          lastRefreshedTable.refreshIfStale();
-          return lastRefreshedTable;
-        }
-
-        Catalog catalog = catalogConfig.catalog();
-        // Create namespace if it does not exist yet
-        if (!namespace.isEmpty() && catalog instanceof SupportsNamespaces) {
-          SupportsNamespaces supportsNamespaces = (SupportsNamespaces) catalog;
-          if (!supportsNamespaces.namespaceExists(namespace)) {
-            try {
-              supportsNamespaces.createNamespace(namespace);
-              LOG.info("Created new namespace '{}'.", namespace);
-            } catch (AlreadyExistsException ignored) {
-              // race condition: another worker already created this namespace
-              LOG.info("Namespace `{}` already exists.", namespace);
-            }
+      // Create namespace if it does not exist yet
+      if (!namespace.isEmpty() && catalog instanceof SupportsNamespaces) {
+        SupportsNamespaces supportsNamespaces = (SupportsNamespaces) catalog;
+        if (!supportsNamespaces.namespaceExists(namespace)) {
+          try {
+            supportsNamespaces.createNamespace(namespace);
+            LOG.info("Created new namespace '{}'.", namespace);
+          } catch (AlreadyExistsException ignored) {
+            // race condition: another worker already created this namespace
+            LOG.info("Namespace `{}` already exists.", namespace);
           }
         }
+      }
 
-        // If table exists, just load it
-        // Note: the implementation of catalog.tableExists() will load the 
table to check its
-        // existence. We don't use it here to avoid double loadTable() calls.
+      // If table exists, just load it
+      // Note: the implementation of catalog.tableExists() will load the table 
to check its
+      // existence. We don't use it here to avoid double loadTable() calls.
+      try {
+        return catalog.loadTable(identifier);
+      } catch (NoSuchTableException e) { // Otherwise, create the table
+        org.apache.iceberg.Schema tableSchema = 
IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
         try {
-          table = catalog.loadTable(identifier);
-        } catch (NoSuchTableException e) { // Otherwise, create the table
-          org.apache.iceberg.Schema tableSchema =
-              IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
-          try {
-            table = catalog.createTable(identifier, tableSchema, 
partitionSpec, tableProperties);
-            LOG.info(
-                "Created Iceberg table '{}' with schema: {}\n"
-                    + ", partition spec: {}, table properties: {}",
-                identifier,
-                tableSchema,
-                partitionSpec,
-                tableProperties);
-          } catch (AlreadyExistsException ignored) {
-            // race condition: another worker already created this table
-            table = catalog.loadTable(identifier);
-          }
+          Table table =
+              catalog.createTable(identifier, tableSchema, partitionSpec, 
tableProperties);
+          LOG.info(
+              "Created Iceberg table '{}' with schema: {}\n"
+                  + ", partition spec: {}, table properties: {}",
+              identifier,
+              tableSchema,
+              partitionSpec,
+              tableProperties);
+          return table;
+        } catch (AlreadyExistsException ignored) {
+          // race condition: another worker already created this table
+          return catalog.loadTable(identifier);
         }
       }
-      lastRefreshedTable = new LastRefreshedTable(table, Instant.now());
-      LAST_REFRESHED_TABLE_CACHE.put(identifier, lastRefreshedTable);
-      return lastRefreshedTable;
     }
   }
 }
diff --git 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteUngroupedRowsToFiles.java
 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteUngroupedRowsToFiles.java
index 1db6ede3016..ff9a98c6200 100644
--- 
a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteUngroupedRowsToFiles.java
+++ 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteUngroupedRowsToFiles.java
@@ -47,8 +47,6 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Precondit
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
-import org.apache.iceberg.catalog.Catalog;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
@@ -193,7 +191,6 @@ class WriteUngroupedRowsToFiles
     private final long maxFileSize;
     private final DynamicDestinations dynamicDestinations;
     private final IcebergCatalogConfig catalogConfig;
-    private transient @MonotonicNonNull Catalog catalog;
     private transient @Nullable RecordWriterManager recordWriterManager;
     private int spilledShardNumber;
 
@@ -210,17 +207,10 @@ class WriteUngroupedRowsToFiles
       this.maxFileSize = maxFileSize;
     }
 
-    private org.apache.iceberg.catalog.Catalog getCatalog() {
-      if (catalog == null) {
-        this.catalog = catalogConfig.catalog();
-      }
-      return catalog;
-    }
-
     @StartBundle
     public void startBundle() {
       recordWriterManager =
-          new RecordWriterManager(getCatalog(), filename, maxFileSize, 
maxWritersPerBundle);
+          new RecordWriterManager(catalogConfig, filename, maxFileSize, 
maxWritersPerBundle);
       this.spilledShardNumber = 
ThreadLocalRandom.current().nextInt(SPILLED_RECORD_SHARDING_FACTOR);
     }
 
diff --git 
a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java
 
b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java
index f27a86cc72a..390e8d87af2 100644
--- 
a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java
+++ 
b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java
@@ -53,7 +53,6 @@ import org.apache.beam.sdk.values.WindowedValue;
 import org.apache.beam.sdk.values.WindowedValues;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.commons.lang3.RandomStringUtils;
-import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.AppendFiles;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileFormat;
@@ -69,7 +68,6 @@ import org.apache.iceberg.catalog.Catalog;
 import org.apache.iceberg.catalog.Namespace;
 import org.apache.iceberg.catalog.TableIdentifier;
 import org.apache.iceberg.data.Record;
-import org.apache.iceberg.hadoop.HadoopCatalog;
 import org.apache.iceberg.io.FileIO;
 import org.apache.iceberg.io.InputFile;
 import org.apache.iceberg.io.OutputFile;
@@ -109,16 +107,26 @@ public class RecordWriterManagerTest {
       IcebergUtils.beamSchemaToIcebergSchema(BEAM_SCHEMA);
   private static final PartitionSpec PARTITION_SPEC =
       PartitionSpec.builderFor(ICEBERG_SCHEMA).truncate("name", 
3).identity("bool").build();
+  private IcebergCatalogConfig catalogConfig;
 
   private WindowedValue<IcebergDestination> windowedDestination;
-  private HadoopCatalog catalog;
 
   @Before
   public void setUp() {
     windowedDestination =
         getWindowedDestination("table_" + testName.getMethodName(), 
PARTITION_SPEC);
-    catalog = new HadoopCatalog(new Configuration(), warehouse.location);
-    RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.invalidateAll();
+    catalogConfig =
+        IcebergCatalogConfig.builder()
+            .setCatalogProperties(
+                ImmutableMap.of("type", "hadoop", "warehouse", 
warehouse.location))
+            .build();
+    TableCache.invalidateAll();
+  }
+
+  private static IcebergCatalogConfig mockCatalogConfigFor(Catalog 
mockCatalog) {
+    IcebergCatalogConfig catalogConfig = mock(IcebergCatalogConfig.class);
+    Mockito.doReturn(mockCatalog).when(catalogConfig).catalog();
+    return catalogConfig;
   }
 
   private WindowedValue<IcebergDestination> getWindowedDestination(
@@ -145,7 +153,8 @@ public class RecordWriterManagerTest {
 
   @Test
   public void testCreateNamespaceAndTable() {
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     Namespace newNamespace = Namespace.of("new_namespace");
     TableIdentifier identifier = TableIdentifier.of(newNamespace, 
testName.getMethodName());
     WindowedValue<IcebergDestination> dest =
@@ -157,15 +166,16 @@ public class RecordWriterManagerTest {
 
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
 
-    assertFalse(catalog.namespaceExists(newNamespace));
+    assertFalse(catalogConfig.namespaceExists(newNamespace.toString()));
     boolean writeSuccess = writerManager.write(dest, row);
     assertTrue(writeSuccess);
-    assertTrue(catalog.namespaceExists(newNamespace));
+    assertTrue(catalogConfig.namespaceExists(newNamespace.toString()));
   }
 
   @Test
   public void testCreateTableWithSortOrder() throws IOException {
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     TableIdentifier identifier = TableIdentifier.of("default", 
testName.getMethodName());
     WindowedValue<IcebergDestination> dest =
         WindowedValues.valueInGlobalWindow(
@@ -184,7 +194,7 @@ public class RecordWriterManagerTest {
     assertTrue(writerManager.write(dest, row));
     writerManager.close();
 
-    Table created = catalog.loadTable(identifier);
+    Table created = catalogConfig.catalog().loadTable(identifier);
     SortOrder order = created.sortOrder();
     assertEquals(2, order.fields().size());
     assertEquals(SortDirection.DESC, order.fields().get(0).direction());
@@ -196,7 +206,8 @@ public class RecordWriterManagerTest {
   @Test
   public void testCreateNewWriterForEachDestination() throws IOException {
     // Writer manager with a maximum limit of 3 writers
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     assertEquals(0, writerManager.openWriters);
 
     boolean writeSuccess;
@@ -257,7 +268,8 @@ public class RecordWriterManagerTest {
   @Test
   public void testCreateNewWriterForEachPartition() throws IOException {
     // Writer manager with a maximum limit of 3 writers
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     assertEquals(0, writerManager.openWriters);
 
     boolean writeSuccess;
@@ -318,7 +330,8 @@ public class RecordWriterManagerTest {
   @Test
   public void testRespectMaxFileSize() throws IOException {
     // Writer manager with a maximum file size of 100 bytes
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 100, 2);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 100, 2);
     assertEquals(0, writerManager.openWriters);
     boolean writeSuccess;
 
@@ -364,7 +377,8 @@ public class RecordWriterManagerTest {
 
   @Test
   public void testRequireClosingBeforeFetchingDataFiles() {
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 100, 2);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 100, 2);
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
     writerManager.write(windowedDestination, row);
     assertEquals(1, writerManager.openWriters);
@@ -401,7 +415,11 @@ public class RecordWriterManagerTest {
     partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, 
row));
 
     RecordWriter writer =
-        new RecordWriter(catalog, windowedDestination.getValue(), 
"test_file_name", partitionKey);
+        new RecordWriter(
+            catalogConfig.catalog(),
+            windowedDestination.getValue(),
+            "test_file_name",
+            partitionKey);
     writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row));
     writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2));
 
@@ -443,7 +461,11 @@ public class RecordWriterManagerTest {
 
     // write some rows
     RecordWriter writer =
-        new RecordWriter(catalog, windowedDestination.getValue(), 
"test_file_name", partitionKey);
+        new RecordWriter(
+            catalogConfig.catalog(),
+            windowedDestination.getValue(),
+            "test_file_name",
+            partitionKey);
     writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row));
     writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2));
     writer.close();
@@ -462,7 +484,8 @@ public class RecordWriterManagerTest {
     assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId());
 
     // update spec
-    Table table = 
catalog.loadTable(windowedDestination.getValue().getTableIdentifier());
+    Table table =
+        
catalogConfig.catalog().loadTable(windowedDestination.getValue().getTableIdentifier());
     table.updateSpec().addField("id").removeField("bool").commit();
 
     Map<Integer, PartitionSpec> updatedSpecs = table.specs();
@@ -473,13 +496,14 @@ public class RecordWriterManagerTest {
 
   @Test
   public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException {
-    Table table = 
catalog.loadTable(windowedDestination.getValue().getTableIdentifier());
+    Table table =
+        
catalogConfig.catalog().loadTable(windowedDestination.getValue().getTableIdentifier());
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build();
     Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", 
true).build();
 
     // write some rows
     RecordWriterManager writer =
-        new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
+        new RecordWriterManager(catalogConfig, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
     writer.write(windowedDestination, row);
     writer.write(windowedDestination, row2);
     writer.close();
@@ -497,20 +521,17 @@ public class RecordWriterManagerTest {
     assertThat(dataFile.path().toString(), containsString("bool=true"));
 
     // table is cached
-    assertEquals(1, RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.size());
+    assertEquals(1, TableCache.size());
 
     // update spec
     table.updateSpec().addField("id").removeField("bool").commit();
-    // Make the cached table stale to force reloading its metadata.
-    RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.getIfPresent(
-                windowedDestination.getValue().getTableIdentifier())
-            .lastRefreshTime =
-        Instant.EPOCH;
+    // Make the cached table stale to force refreshing its metadata.
+    TableCache.markStale(catalogConfig, 
windowedDestination.getValue().getTableIdentifier());
 
     // write a second data file
     // should refresh the table and use the new partition spec
     RecordWriterManager writer2 =
-        new RecordWriterManager(catalog, "test_prefix_2", Long.MAX_VALUE, 
Integer.MAX_VALUE);
+        new RecordWriterManager(catalogConfig, "test_prefix_2", 
Long.MAX_VALUE, Integer.MAX_VALUE);
     writer2.write(windowedDestination, row);
     writer2.write(windowedDestination, row2);
     writer2.close();
@@ -578,7 +599,7 @@ public class RecordWriterManagerTest {
         getWindowedDestination("identity_partitioning", icebergSchema, spec);
 
     RecordWriterManager writer =
-        new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
+        new RecordWriterManager(catalogConfig, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
     writer.write(dest, row);
     writer.close();
     List<SerializableDataFile> files = 
writer.getSerializableDataFiles().get(dest);
@@ -664,7 +685,7 @@ public class RecordWriterManagerTest {
         getWindowedDestination("bucket_partitioning", icebergSchema, spec);
 
     RecordWriterManager writer =
-        new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
+        new RecordWriterManager(catalogConfig, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
     writer.write(dest, row);
     writer.close();
     List<SerializableDataFile> files = 
writer.getSerializableDataFiles().get(dest);
@@ -730,7 +751,7 @@ public class RecordWriterManagerTest {
 
     // write some rows
     RecordWriterManager writer =
-        new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
+        new RecordWriterManager(catalogConfig, "test_prefix", Long.MAX_VALUE, 
Integer.MAX_VALUE);
     writer.write(dest, row);
     writer.close();
     List<SerializableDataFile> files = 
writer.getSerializableDataFiles().get(dest);
@@ -763,7 +784,7 @@ public class RecordWriterManagerTest {
     String expectedPartition = String.join("/", expectedPartitions);
     DataFile dataFile =
         serializableDataFile.createDataFile(
-            catalog.loadTable(dest.getValue().getTableIdentifier()).specs());
+            
catalogConfig.catalog().loadTable(dest.getValue().getTableIdentifier()).specs());
     assertThat(dataFile.path().toString(), containsString(expectedPartition));
   }
 
@@ -771,7 +792,8 @@ public class RecordWriterManagerTest {
 
   @Test
   public void testWriterExceptionGetsCaught() throws IOException {
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 100, 2);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 100, 2);
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build();
     PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, 
ICEBERG_SCHEMA);
     partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, 
row));
@@ -783,7 +805,10 @@ public class RecordWriterManagerTest {
     // replace with a failing record writer
     FailingRecordWriter failingWriter =
         new FailingRecordWriter(
-            catalog, windowedDestination.getValue(), "test_failing_writer", 
partitionKey);
+            catalogConfig.catalog(),
+            windowedDestination.getValue(),
+            "test_failing_writer",
+            partitionKey);
     state.writers.put(partitionKey, failingWriter);
     writerManager.write(windowedDestination, row);
 
@@ -843,7 +868,8 @@ public class RecordWriterManagerTest {
     WindowedValue<IcebergDestination> singleDestination =
         WindowedValues.valueInGlobalWindow(destination);
 
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     Row row1 = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
     Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "bbb", false).build();
     Row row3 = Row.withSchema(BEAM_SCHEMA).addValues(3, "ccc", true).build();
@@ -905,7 +931,8 @@ public class RecordWriterManagerTest {
     WindowedValue<IcebergDestination> singleDestination =
         WindowedValues.valueInGlobalWindow(destination);
 
-    RecordWriterManager writerManager = new RecordWriterManager(catalog, 
"test_file_name", 1000, 3);
+    RecordWriterManager writerManager =
+        new RecordWriterManager(catalogConfig, "test_file_name", 1000, 3);
     Row row1 = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
     Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "bbb", false).build();
     Row row3 = Row.withSchema(BEAM_SCHEMA).addValues(3, "ccc", true).build();
@@ -1090,15 +1117,16 @@ public class RecordWriterManagerTest {
     Mockito.doReturn(sharedTrackingIO).when(spyTable1).io();
     Mockito.doReturn(sharedTrackingIO).when(spyTable2).io();
 
-    Catalog spyCatalog = Mockito.spy(catalog);
+    Catalog spyCatalog = Mockito.spy(catalogConfig.catalog());
     Mockito.doReturn(spyTable1).when(spyCatalog).loadTable(tableId1);
     Mockito.doReturn(spyTable2).when(spyCatalog).loadTable(tableId2);
 
     WindowedValue<IcebergDestination> dest1 = 
getWindowedDestination(tableName1, null);
     WindowedValue<IcebergDestination> dest2 = 
getWindowedDestination(tableName2, null);
 
+    IcebergCatalogConfig spyCatalogConfig = mockCatalogConfigFor(spyCatalog);
     RecordWriterManager writerManager =
-        new RecordWriterManager(spyCatalog, "test_file_name", 1000, 3);
+        new RecordWriterManager(spyCatalogConfig, "test_file_name", 1000, 3);
 
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
     assertTrue(writerManager.write(dest1, row));
@@ -1205,17 +1233,15 @@ public class RecordWriterManagerTest {
     // test.
     Schema beamSchema = null;
 
-    // Instantiate a RecordWriterManager with a dummy catalog.
-    RecordWriterManager writer = new RecordWriterManager(null, "p", 1L, 1);
+    IcebergCatalogConfig mockCatalogConfig = mock(IcebergCatalogConfig.class);
+    RecordWriterManager writer = new RecordWriterManager(mockCatalogConfig, 
"p", 1L, 1);
 
     // Clean up cache before test
-    RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.invalidateAll();
+    TableCache.invalidateAll();
 
     // --- 1. Test the fast path (entry is not stale) ---
     Instant freshTimestamp = Instant.now().minus(Duration.ofMinutes(1));
-    RecordWriterManager.LastRefreshedTable freshEntry =
-        new RecordWriterManager.LastRefreshedTable(mockTable, freshTimestamp);
-    RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.put(identifier, freshEntry);
+    TableCache.put(mockCatalogConfig, identifier, mockTable, freshTimestamp);
 
     // Access the table
     writer.getOrCreateTable(destination, beamSchema);
@@ -1225,9 +1251,7 @@ public class RecordWriterManagerTest {
 
     // --- 2. Test the stale path (entry is stale) ---
     Instant staleTimestamp = Instant.now().minus(Duration.ofMinutes(5));
-    RecordWriterManager.LastRefreshedTable staleEntry =
-        new RecordWriterManager.LastRefreshedTable(mockTable, staleTimestamp);
-    RecordWriterManager.LAST_REFRESHED_TABLE_CACHE.put(identifier, staleEntry);
+    TableCache.put(mockCatalogConfig, identifier, mockTable, staleTimestamp);
 
     // Access the table again
     writer.getOrCreateTable(destination, beamSchema);
@@ -1253,14 +1277,15 @@ public class RecordWriterManagerTest {
     Table spyTable = Mockito.spy(realTable);
     Mockito.doReturn(sharedIO).when(spyTable).io();
 
-    Catalog spyCatalog = Mockito.spy(catalog);
+    Catalog spyCatalog = Mockito.spy(catalogConfig.catalog());
     Mockito.doReturn(spyTable).when(spyCatalog).loadTable(tableId);
 
     WindowedValue<IcebergDestination> dest = getWindowedDestination(tableName, 
null);
     Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build();
+    IcebergCatalogConfig spyCatalogConfig = mockCatalogConfigFor(spyCatalog);
 
     // Bundle 1: write and close
-    RecordWriterManager bundle1 = new RecordWriterManager(spyCatalog, 
"file_b1", 1000, 3);
+    RecordWriterManager bundle1 = new RecordWriterManager(spyCatalogConfig, 
"file_b1", 1000, 3);
     assertTrue(bundle1.write(dest, row));
     bundle1.close();
     assertFalse("FileIO must survive after bundle 1 close", sharedIO.closed);
@@ -1268,7 +1293,7 @@ public class RecordWriterManagerTest {
         "Bundle 1 should produce data files", 
bundle1.getSerializableDataFiles().containsKey(dest));
 
     // Bundle 2: write and close using the same catalog (simulates DoFn reuse)
-    RecordWriterManager bundle2 = new RecordWriterManager(spyCatalog, 
"file_b2", 1000, 3);
+    RecordWriterManager bundle2 = new RecordWriterManager(spyCatalogConfig, 
"file_b2", 1000, 3);
     assertTrue(bundle2.write(dest, row));
     bundle2.close();
     assertFalse("FileIO must survive after bundle 2 close", sharedIO.closed);
diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TableCacheTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TableCacheTest.java
new file mode 100644
index 00000000000..6b3d7614ba4
--- /dev/null
+++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TableCacheTest.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.iceberg;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.time.Instant;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.catalog.Catalog;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/** Tests for {@link TableCache}. */
+public class TableCacheTest {
+  private static final TableIdentifier IDENTIFIER = TableIdentifier.of("db", "table");
+
+  // TableCache state is shared across tests (it is accessed statically), so clear it
+  // before and after each test to keep tests independent.
+  @Before
+  public void setUp() {
+    TableCache.invalidateAll();
+  }
+
+  @After
+  public void tearDown() {
+    TableCache.invalidateAll();
+  }
+
+  @Test
+  public void getLoadsTableOnceForSameCatalogAndIdentifier() {
+    Catalog catalog = mock(Catalog.class);
+    IcebergCatalogConfig catalogConfig = mock(IcebergCatalogConfig.class);
+    Table table = mock(Table.class);
+    when(catalogConfig.catalog()).thenReturn(catalog);
+    when(catalog.loadTable(IDENTIFIER)).thenReturn(table);
+
+    assertSame(table, TableCache.get(catalogConfig, IDENTIFIER));
+    assertSame(table, TableCache.get(catalogConfig, IDENTIFIER));
+    assertSame(table, TableCache.get(catalogConfig, IDENTIFIER));
+
+    // Repeated lookups must be served from the cache after the first catalog load.
+    verify(catalog, times(1)).loadTable(IDENTIFIER);
+  }
+
+  @Test
+  public void getKeysByCatalogConfigAndIdentifier() throws Exception {
+    // Two separately built configs with identical settings must map to the same cache entry,
+    // so only the first loader below is expected to run.
+    IcebergCatalogConfig catalogConfig1 =
+        IcebergCatalogConfig.builder().setCatalogName("catalog").build();
+    IcebergCatalogConfig catalogConfig2 =
+        IcebergCatalogConfig.builder().setCatalogName("catalog").build();
+    Table table = mock(Table.class);
+    AtomicInteger loadCount = new AtomicInteger();
+
+    assertSame(
+        table,
+        TableCache.get(
+            catalogConfig1,
+            IDENTIFIER,
+            () -> {
+              loadCount.incrementAndGet();
+              return table;
+            }));
+    // This loader must not run; if it did, loadCount would exceed 1.
+    assertSame(
+        table,
+        TableCache.get(
+            catalogConfig2,
+            IDENTIFIER,
+            () -> {
+              loadCount.incrementAndGet();
+              return null;
+            }));
+
+    assertEquals(1, loadCount.get());
+  }
+
+  @Test
+  public void getRefreshedDoesNotRefreshNewlyLoadedTable() {
+    Catalog catalog = mock(Catalog.class);
+    IcebergCatalogConfig catalogConfig = mock(IcebergCatalogConfig.class);
+    Table table = mock(Table.class);
+    when(catalogConfig.catalog()).thenReturn(catalog);
+    when(catalog.loadTable(IDENTIFIER)).thenReturn(table);
+
+    assertSame(table, TableCache.getRefreshed(catalogConfig, IDENTIFIER));
+
+    // A table that was just loaded from the catalog is not refreshed a second time.
+    verify(catalog, times(1)).loadTable(IDENTIFIER);
+    verify(table, never()).refresh();
+  }
+
+  @Test
+  public void getRefreshedPropagatesRefreshFailure() {
+    IcebergCatalogConfig catalogConfig = mock(IcebergCatalogConfig.class);
+    Table table = mock(Table.class);
+    RuntimeException refreshFailure = new RuntimeException("refresh failed");
+    doThrow(refreshFailure).when(table).refresh();
+    // Seed an entry timestamped at the epoch so it is stale and getRefreshed() must call
+    // refresh() on it instead of returning it as-is.
+    TableCache.put(catalogConfig, IDENTIFIER, table, Instant.EPOCH);
+
+    RuntimeException thrown =
+        assertThrows(
+            RuntimeException.class, () -> TableCache.getRefreshed(catalogConfig, IDENTIFIER));
+
+    assertSame(refreshFailure, thrown);
+  }
+}

Reply via email to