This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch bigtable-cdc-feature-branch
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/bigtable-cdc-feature-branch by 
this push:
     new eaa17a71d93 Cloud Bigtable stream changes and handle CloseStream 
responses (#25460)
eaa17a71d93 is described below

commit eaa17a71d93169d3905200f6c876a69b435e4f2f
Author: Tony Tang <[email protected]>
AuthorDate: Tue Feb 21 14:22:05 2023 -0500

    Cloud Bigtable stream changes and handle CloseStream responses (#25460)
    
    * Handle ChangeStreamMutation
    
    * Evaluate CloseStream split and merge messages from Change Stream API
    
    * Fix rebase issues
    
    ---------
    
    Co-authored-by: Pablo <[email protected]>
---
 .../changestreams/ByteStringRangeHelper.java       | 118 +++++++++++-
 .../changestreams/ChangeStreamMetrics.java         |  45 +++++
 .../changestreams/action/ChangeStreamAction.java   |  34 ++++
 .../action/ReadChangeStreamPartitionAction.java    |  51 ++++++
 .../changestreams/dao/MetadataTableDao.java        |  88 +++++++++
 .../dofn/ReadChangeStreamPartitionDoFn.java        |   1 +
 .../ReadChangeStreamPartitionProgressTracker.java  |   2 +-
 .../changestreams/restriction/StreamProgress.java  |  17 +-
 .../changestreams/ByteStringRangeHelperTest.java   | 186 +++++++++++++++++++
 .../action/ChangeStreamActionTest.java             |  26 +++
 .../ReadChangeStreamPartitionActionTest.java       | 203 +++++++++++++++++++++
 .../changestreams/dao/MetadataTableDaoTest.java    |  44 +++++
 12 files changed, 811 insertions(+), 4 deletions(-)

diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
index 8f307f526f0..34d3affa01c 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java
@@ -18,11 +18,15 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams;
 
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.protobuf.ByteString;
 import com.google.protobuf.TextFormat;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 
 /** Helper functions to evaluate the completeness of collection of 
ByteStringRanges. */
 public class ByteStringRangeHelper {
-
   /**
    * Returns formatted string of a partition for debugging.
    *
@@ -36,4 +40,116 @@ public class ByteStringRangeHelper {
         + TextFormat.escapeBytes(partition.getEnd())
         + "')";
   }
+
+  /**
+   * Convert partitions to a string for debugging.
+   *
+   * @param partitions to print
+   * @return string representation of partitions
+   */
+  public static String partitionsToString(List<ByteStringRange> partitions) {
+    return partitions.stream()
+        .map(ByteStringRangeHelper::formatByteStringRange)
+        .collect(Collectors.joining(", ", "{", "}"));
+  }
+
+  @VisibleForTesting
+  static class PartitionComparator implements Comparator<ByteStringRange> {
+    @Override
+    // If first > second, it returns a positive number.
+    // If first < second, it returns a negative number.
+    // If first == second, it returns 0.
+    // First is greater than second if either of the following is true:
+    // - Its start key comes after second's start key.
+    // - The start keys are equal and its end key comes after second's end key.
+    // An end key of "" represents the final end key, so it needs to be handled as a special case.
+    public int compare(ByteStringRange first, ByteStringRange second) {
+      int compareStart =
+          ByteString.unsignedLexicographicalComparator()
+              .compare(first.getStart(), second.getStart());
+      if (compareStart != 0) {
+        return compareStart;
+      }
+      if (first.getEnd().isEmpty() && !second.getEnd().isEmpty()) {
+        return 1;
+      }
+      if (second.getEnd().isEmpty() && !first.getEnd().isEmpty()) {
+        return -1;
+      }
+      return ByteString.unsignedLexicographicalComparator()
+          .compare(first.getEnd(), second.getEnd());
+    }
+  }
+
+  private static boolean childStartsBeforeParent(
+      ByteString parentStartKey, ByteString childStartKey) {
+    // Check if the start key of the child partition comes before the start key of the entire
+    // parentPartitions.
+    return 
ByteString.unsignedLexicographicalComparator().compare(parentStartKey, 
childStartKey)
+        > 0;
+  }
+
+  private static boolean childEndsAfterParent(ByteString parentEndKey, 
ByteString childEndKey) {
+    // A final end key is represented by "", but "" sorts before all other keys, so we need to
+    // handle it as a special case.
+    if (childEndKey.isEmpty() && !parentEndKey.isEmpty()) {
+      return true;
+    }
+
+    // Check if the end key of the child partition comes after the end key of the entire
+    // parentPartitions. "" represents the final end key, so we need to handle that as a
+    // special case when it is the end key of the entire parentPartitions.
+    return 
ByteString.unsignedLexicographicalComparator().compare(parentEndKey, 
childEndKey) < 0
+        && !parentEndKey.isEmpty();
+  }
+
+  // This assumes sortedParentPartitions is sorted by PartitionComparator. If it has not already
+  // been sorted, the result will be incorrect.
+  private static boolean gapsInParentPartitions(List<ByteStringRange> 
sortedParentPartitions) {
+    for (int i = 1; i < sortedParentPartitions.size(); i++) {
+      // In a sorted list, each start key should be at or before the previous partition's end key.
+      // An empty ("") previous end key means the end of the keyspace, so it never leaves a gap.
+      ByteString prevEndKey = sortedParentPartitions.get(i - 1).getEnd();
+      if (ByteString.unsignedLexicographicalComparator()
+                  .compare(sortedParentPartitions.get(i).getStart(), 
prevEndKey)
+              > 0
+          && !prevEndKey.isEmpty()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Returns true if parentPartitions is a superset of childPartition.
+   *
+   * <p>If the sorted parentPartitions row ranges form a contiguous range whose start key is at or
+   * before childPartition's start key and whose end key is at or after childPartition's end key,
+   * then parentPartitions is a superset of childPartition. An empty list is never a superset.
+   *
+   * <p>Parents may overlap, since arbitrary partitions can merge, but they must not leave any gap.
+   * Only adjacent parents (in sorted order) are compared, so a parent nested entirely inside a
+   * wider parent can cause a false negative.
+   *
+   * @param parentPartitions partitions that should form a contiguous range; sorted in place
+   * @param childPartition the partition that should be covered by parentPartitions
+   * @return true if parentPartitions is a superset of childPartition, otherwise false
+   */
+  public static boolean isSuperset(
+      List<ByteStringRange> parentPartitions, ByteStringRange childPartition) {
+    // Sort parentPartitions by start key, then compare each partition's (inclusive) start key
+    // with the previous partition's (exclusive) end key to ensure there are no gaps. The first
+    // start key and last end key must be equal to or wider than the child partition's start and
+    // end keys.
+    if (parentPartitions.isEmpty()) {
+      return false;
+    }
+    parentPartitions.sort(new PartitionComparator());
+    ByteString parentStartKey = parentPartitions.get(0).getStart();
+    ByteString parentEndKey = parentPartitions.get(parentPartitions.size() - 
1).getEnd();
+
+    return !childStartsBeforeParent(parentStartKey, childPartition.getStart())
+        && !childEndsAfterParent(parentEndKey, childPartition.getEnd())
+        && !gapsInParentPartitions(parentPartitions);
+  }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
index ed14eb50d1c..b7b24c4892d 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ChangeStreamMetrics.java
@@ -47,6 +47,14 @@ public class ChangeStreamMetrics implements Serializable {
           
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
           "heartbeat_count");
 
+  /**
+   * Counter for the total number of CloseStreams identified during the execution of the Connector.
+   */
+  public static final Counter CLOSESTREAM_COUNT =
+      Metrics.counter(
+          
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
+          "closestream_count");
+
   /**
    * Counter for the total number of ChangeStreamMutations that are initiated 
by users (not garbage
    * collection) identified during the execution of the Connector.
@@ -71,6 +79,12 @@ public class ChangeStreamMetrics implements Serializable {
           
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
           "processing_delay_from_commit_timestamp");
 
+  /** Counter for the total number of active partitions being streamed. */
+  public static final Counter PARTITION_STREAM_COUNT =
+      Metrics.counter(
+          
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics.class,
+          "partition_stream_count");
+
   /**
    * Increments the {@link
    * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#LIST_PARTITIONS_COUNT}
 by
@@ -89,6 +103,15 @@ public class ChangeStreamMetrics implements Serializable {
     inc(HEARTBEAT_COUNT);
   }
 
+  /**
+   * Increments the {@link
+   * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#CLOSESTREAM_COUNT}
 by 1
+   * if the metric is enabled.
+   */
+  public void incClosestreamCount() {
+    inc(CLOSESTREAM_COUNT);
+  }
+
   /**
    * Increments the {@link
    * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#CHANGE_STREAM_MUTATION_USER_COUNT}
@@ -107,6 +130,24 @@ public class ChangeStreamMetrics implements Serializable {
     inc(CHANGE_STREAM_MUTATION_GC_COUNT);
   }
 
+  /**
+   * Increments the {@link
+   * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PARTITION_STREAM_COUNT}
+   * by 1.
+   */
+  public void incPartitionStreamCount() {
+    inc(PARTITION_STREAM_COUNT);
+  }
+
+  /**
+   * Decrements the {@link
+   * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PARTITION_STREAM_COUNT}
+   * by 1.
+   */
+  public void decPartitionStreamCount() {
+    dec(PARTITION_STREAM_COUNT);
+  }
+
   /**
    * Adds measurement of an instance for the {@link
    * 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics#PROCESSING_DELAY_FROM_COMMIT_TIMESTAMP}.
@@ -119,6 +160,10 @@ public class ChangeStreamMetrics implements Serializable {
     counter.inc();
   }
 
+  private void dec(Counter counter) {
+    counter.dec();
+  }
+
   private void update(Distribution distribution, long value) {
     distribution.update(value);
   }
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
index e64cd6fb876..5f3784721e2 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamAction.java
@@ -22,10 +22,12 @@ import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeH
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.cloud.bigtable.data.v2.models.Heartbeat;
 import com.google.cloud.bigtable.data.v2.models.Range;
 import com.google.protobuf.ByteString;
 import java.util.Optional;
+import java.util.stream.Collectors;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
@@ -135,6 +137,38 @@ public class ChangeStreamAction {
         return Optional.of(DoFn.ProcessContinuation.stop());
       }
       metrics.incHeartbeatCount();
+    } else if (record instanceof CloseStream) {
+      CloseStream closeStream = (CloseStream) record;
+      StreamProgress streamProgress = new StreamProgress(closeStream);
+
+      if (shouldDebug) {
+        LOG.info(
+            "RCSP {}: CloseStream: {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            closeStream.getChangeStreamContinuationTokens().stream()
+                .map(
+                    c ->
+                        "{partition: "
+                            + formatByteStringRange(c.getPartition())
+                            + " token: "
+                            + c.getToken()
+                            + "}")
+                .collect(Collectors.joining(", ", "[", "]")));
+      }
+      // If the tracker fails to claim the streamProgress, it most likely means the runner initiated
+      // a checkpoint. See {@link
+      // org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker}
+      // for more information regarding runner initiated checkpoints.
+      if (!tracker.tryClaim(streamProgress)) {
+        if (shouldDebug) {
+          LOG.info(
+              "RCSP {}: Failed to claim close stream tracker",
+              formatByteStringRange(partitionRecord.getPartition()));
+        }
+        return Optional.of(DoFn.ProcessContinuation.stop());
+      }
+      metrics.incClosestreamCount();
+      return Optional.of(DoFn.ProcessContinuation.resume());
     } else if (record instanceof ChangeStreamMutation) {
       ChangeStreamMutation changeStreamMutation = (ChangeStreamMutation) 
record;
       final Instant watermark =
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
index a36b57ff42f..8a49788542d 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionAction.java
@@ -17,11 +17,21 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.isSuperset;
+import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString;
+
 import com.google.api.gax.rpc.ServerStream;
+import com.google.cloud.bigtable.common.Status;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
+import com.google.cloud.bigtable.data.v2.models.Range;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.ChangeStreamDao;
@@ -126,6 +136,47 @@ public class ReadChangeStreamPartitionAction {
               + tracker.currentRestriction());
     }
 
+    // Process CloseStream if it exists
+    CloseStream closeStream = tracker.currentRestriction().getCloseStream();
+    if (closeStream != null) {
+      if (closeStream.getStatus().getCode() != Status.Code.OUT_OF_RANGE) {
+        LOG.error(
+            "RCSP {}: Reached unexpected terminal state: {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            closeStream.getStatus().toString());
+        metrics.decPartitionStreamCount();
+        return ProcessContinuation.stop();
+      }
+      // The partitions in the continuation tokens should be a superset of this partition.
+      // If there's only one token, its partition should be a superset of this partition.
+      // If there are multiple tokens, they should form a contiguous row range that is a
+      // superset of this partition; otherwise, a warning is logged below.
+      List<Range.ByteStringRange> partitions = new ArrayList<>();
+      for (ChangeStreamContinuationToken changeStreamContinuationToken :
+          closeStream.getChangeStreamContinuationTokens()) {
+        partitions.add(changeStreamContinuationToken.getPartition());
+        metadataTableDao.writeNewPartition(
+            changeStreamContinuationToken,
+            partitionRecord.getPartition(),
+            watermarkEstimator.getState());
+      }
+      if (shouldDebug) {
+        LOG.info(
+            "RCSP {}: Split/Merge into {}",
+            formatByteStringRange(partitionRecord.getPartition()),
+            partitionsToString(partitions));
+      }
+      if (!isSuperset(partitions, partitionRecord.getPartition())) {
+        LOG.warn(
+            "RCSP {}: CloseStream has child partition(s) {} that doesn't cover 
the keyspace",
+            formatByteStringRange(partitionRecord.getPartition()),
+            partitionsToString(partitions));
+      }
+      
metadataTableDao.deleteStreamPartitionRow(partitionRecord.getPartition());
+      metrics.decPartitionStreamCount();
+      return ProcessContinuation.stop();
+    }
+
     // Update the metadata table with the watermark
     metadataTableDao.updateWatermark(
         partitionRecord.getPartition(),
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
index 3c3e828c5bb..fd7c3ae2bc1 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java
@@ -21,9 +21,13 @@ import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTabl
 import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.NEW_PARTITION_PREFIX;
 import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.STREAM_PARTITION_PREFIX;
 
+import com.google.api.gax.rpc.ServerStream;
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.Filters;
+import com.google.cloud.bigtable.data.v2.models.Query;
 import com.google.cloud.bigtable.data.v2.models.Range;
+import com.google.cloud.bigtable.data.v2.models.Row;
 import com.google.cloud.bigtable.data.v2.models.RowMutation;
 import com.google.protobuf.ByteString;
 import javax.annotation.Nullable;
@@ -95,6 +99,78 @@ public class MetadataTableDao {
     return 
getFullStreamPartitionPrefix().concat(Range.ByteStringRange.toByteString(partition));
   }
 
+  /**
+   * Convert partition to a New Partition row key to query for partitions 
ready to be streamed as
+   * the result of splits and merges.
+   *
+   * @param partition convert to row key
+   * @return row key to insert to Cloud Bigtable.
+   */
+  public ByteString convertPartitionToNewPartitionRowKey(Range.ByteStringRange 
partition) {
+    return 
getFullNewPartitionPrefix().concat(Range.ByteStringRange.toByteString(partition));
+  }
+
+  /**
+   * Returns a stream of the rows of all new partitions, resulting from splits and merges, that are
+   * waiting to be streamed. Only the latest cell of each column is read.
+   */
+  public ServerStream<Row> readNewPartitions() {
+    // It's important that we limit to the latest value per column because 
it's possible to write to
+    // the same column multiple times. We don't want to read and send 
duplicate tokens to the
+    // server.
+    Query query =
+        Query.create(tableId)
+            .prefix(getFullNewPartitionPrefix())
+            .filter(Filters.FILTERS.limit().cellsPerColumn(1));
+    return dataClient.readRows(query);
+  }
+
+  /**
+   * After a split or merge from a close stream, write the new partition's 
information to the
+   * metadata table.
+   *
+   * @param changeStreamContinuationToken the token that can be used to pick 
up from where the
+   *     parent left off
+   * @param parentPartition the parent that stopped and split or merged
+   * @param lowWatermark the low watermark of the parent stream
+   */
+  public void writeNewPartition(
+      ChangeStreamContinuationToken changeStreamContinuationToken,
+      Range.ByteStringRange parentPartition,
+      Instant lowWatermark) {
+    writeNewPartition(
+        changeStreamContinuationToken.getPartition(),
+        changeStreamContinuationToken.toByteString(),
+        Range.ByteStringRange.toByteString(parentPartition),
+        lowWatermark);
+  }
+
+  /**
+   * After a split or merge from a close stream, write the new partition's 
information to the
+   * metadata table.
+   *
+   * @param newPartition the new partition
+   * @param newPartitionContinuationToken continuation token for the new 
partition
+   * @param parentPartition the parent that stopped
+   * @param lowWatermark low watermark of the parent
+   */
+  private void writeNewPartition(
+      Range.ByteStringRange newPartition,
+      ByteString newPartitionContinuationToken,
+      ByteString parentPartition,
+      Instant lowWatermark) {
+    ByteString rowKey = convertPartitionToNewPartitionRowKey(newPartition);
+    RowMutation rowMutation =
+        RowMutation.create(tableId, rowKey)
+            .setCell(MetadataTableAdminDao.CF_INITIAL_TOKEN, 
newPartitionContinuationToken, 1)
+            .setCell(MetadataTableAdminDao.CF_PARENT_PARTITIONS, 
parentPartition, 1)
+            .setCell(
+                MetadataTableAdminDao.CF_PARENT_LOW_WATERMARKS,
+                parentPartition,
+                
ByteString.copyFromUtf8(Long.toString(lowWatermark.getMillis())));
+    dataClient.mutateRow(rowMutation);
+  }
+
   /**
    * Update the metadata for the rowKey. This helper adds necessary prefixes 
to the row key.
    *
@@ -134,6 +210,18 @@ public class MetadataTableDao {
         convertPartitionToStreamPartitionRowKey(partition), watermark, 
currentToken);
   }
 
+  /**
+   * Delete the row whose key is derived from the partition, indicating that the partition will no
+   * longer be streamed.
+   *
+   * @param partition forms the row key of the row to delete
+   */
+  public void deleteStreamPartitionRow(Range.ByteStringRange partition) {
+    ByteString rowKey = convertPartitionToStreamPartitionRowKey(partition);
+    RowMutation rowMutation = RowMutation.create(tableId, rowKey).deleteRow();
+    dataClient.mutateRow(rowMutation);
+  }
+
   /**
    * Set the version number for DetectNewPartition. This value can be checked 
later to verify that
    * the existing metadata table is compatible with current beam connector 
code.
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
index a7871dd5414..e3665070b0d 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
@@ -82,6 +82,7 @@ public class ReadChangeStreamPartitionDoFn
 
   @GetInitialRestriction
   public StreamProgress initialRestriction() {
+    metrics.incPartitionStreamCount();
     return new StreamProgress();
   }
 
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
index f6158898754..5eeb096bc3f 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/ReadChangeStreamPartitionProgressTracker.java
@@ -79,7 +79,7 @@ public class ReadChangeStreamPartitionProgressTracker
    */
   @Override
   public void checkDone() throws java.lang.IllegalStateException {
-    boolean done = shouldStop;
+    boolean done = shouldStop || streamProgress.getCloseStream() != null;
     Preconditions.checkState(done, "There's more work to be done");
   }
 
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
index ef35a040ee8..c594af3a344 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/restriction/StreamProgress.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction;
 
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.protobuf.Timestamp;
 import java.io.Serializable;
 import java.util.Objects;
@@ -36,9 +37,10 @@ import org.checkerframework.checker.nullness.qual.Nullable;
  */
 @Internal
 public class StreamProgress implements Serializable {
-  private static final long serialVersionUID = -5384329262726188695L;
+  private static final long serialVersionUID = -8597355120329526194L;
 
   private @Nullable ChangeStreamContinuationToken currentToken;
+  private @Nullable CloseStream closeStream;
   private @Nullable Timestamp lowWatermark;
 
   public @Nullable ChangeStreamContinuationToken getCurrentToken() {
@@ -49,6 +51,10 @@ public class StreamProgress implements Serializable {
     return lowWatermark;
   }
 
+  public @Nullable CloseStream getCloseStream() {
+    return closeStream;
+  }
+
   public StreamProgress() {}
 
   public StreamProgress(@Nullable ChangeStreamContinuationToken token, 
Timestamp lowWatermark) {
@@ -56,6 +62,10 @@ public class StreamProgress implements Serializable {
     this.lowWatermark = lowWatermark;
   }
 
+  public StreamProgress(@Nullable CloseStream closeStream) {
+    this.closeStream = closeStream;
+  }
+
   @Override
   public boolean equals(@Nullable Object o) {
     if (this == o) {
@@ -66,7 +76,8 @@ public class StreamProgress implements Serializable {
     }
     StreamProgress that = (StreamProgress) o;
     return Objects.equals(getCurrentToken(), that.getCurrentToken())
-        && Objects.equals(getLowWatermark(), that.getLowWatermark());
+        && Objects.equals(getLowWatermark(), that.getLowWatermark())
+        && Objects.equals(getCloseStream(), that.getCloseStream());
   }
 
   @Override
@@ -81,6 +92,8 @@ public class StreamProgress implements Serializable {
         + currentToken
         + ", lowWatermark="
         + lowWatermark
+        + ", closeStream="
+        + closeStream
         + '}';
   }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java
new file mode 100644
index 00000000000..78e63a1a494
--- /dev/null
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams;
+
+import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange;
+import static 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ByteStringRangeHelper}. An empty start or end key denotes an unbounded side of
+ * a range, so {@code ["", "")} is the entire key space.
+ */
+public class ByteStringRangeHelperTest {
+
+  @Test
+  public void testParentIsEntireKeySpaceIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("", "");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "B");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsLeftSubSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("", "n");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsRightSubSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("n", "");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testChildIsEntireKeySpaceParentIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition1 = ByteStringRange.create("", "n");
+    ByteStringRange partition2 = ByteStringRange.create("n", "");
+    parentPartitions.add(partition1);
+    parentPartitions.add(partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentKeySpaceStartsBeforeAndEndAfterChildIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("AA", "AB");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentStartKeyIsAfterChildStartKeyIsNotSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("AA", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "AB");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentEndKeyIsBeforeChildEndKeyIsNotSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("AA", "BA");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentIsSameAsChildIsSuperSet() {
+    List<ByteStringRange> parentPartitions = new ArrayList<>();
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    parentPartitions.add(partition);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "B");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentIsMissingPartitionIsNotSuperSet() {
+    // The gap ["B", "C") between the parents is not covered.
+    ByteStringRange partition1 = ByteStringRange.create("A", "B");
+    ByteStringRange partition2 = ByteStringRange.create("C", "Z");
+    List<ByteStringRange> parentPartitions = Arrays.asList(partition1, partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "Z");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testParentHasOverlapIsSuperSet() {
+    // Overlapping parents still cover the child because they leave no gap.
+    ByteStringRange partition1 = ByteStringRange.create("A", "C");
+    ByteStringRange partition2 = ByteStringRange.create("B", "Z");
+    List<ByteStringRange> parentPartitions = Arrays.asList(partition1, partition2);
+
+    ByteStringRange childPartition = ByteStringRange.create("A", "Z");
+
+    assertTrue(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testEmptyParentsIsNotSuperset() {
+    List<ByteStringRange> parentPartitions = Collections.emptyList();
+    ByteStringRange childPartition = ByteStringRange.create("", "");
+
+    assertFalse(ByteStringRangeHelper.isSuperset(parentPartitions, childPartition));
+  }
+
+  @Test
+  public void testPartitionsToString() {
+    ByteStringRange partition1 = ByteStringRange.create("", "A");
+    ByteStringRange partition2 = ByteStringRange.create("A", "B");
+    ByteStringRange partition3 = ByteStringRange.create("B", "");
+    List<ByteStringRange> partitions = Arrays.asList(partition1, partition2, partition3);
+    String partitionsString = partitionsToString(partitions);
+    assertEquals(
+        String.format(
+            "{%s, %s, %s}",
+            formatByteStringRange(partition1),
+            formatByteStringRange(partition2),
+            formatByteStringRange(partition3)),
+        partitionsString);
+  }
+
+  @Test
+  public void testPartitionsToStringEmptyPartition() {
+    List<ByteStringRange> partitions = new ArrayList<>();
+    String partitionsString = partitionsToString(partitions);
+    assertEquals("{}", partitionsString);
+  }
+
+  @Test
+  public void testPartitionComparator() {
+    // Expected order: by start key, then by end key, with an empty (unbounded) end key last.
+    ByteStringRange partition1 = ByteStringRange.create("", "a");
+    ByteStringRange partition2 = ByteStringRange.create("", "");
+    ByteStringRange partition3 = ByteStringRange.create("a", "z");
+    ByteStringRange partition4 = ByteStringRange.create("a", "");
+    List<ByteStringRange> actual = Arrays.asList(partition3, partition4, partition2, partition1);
+    List<ByteStringRange> expected = Arrays.asList(partition1, partition2, partition3, partition4);
+    actual.sort(new ByteStringRangeHelper.PartitionComparator());
+    // JUnit's contract is assertEquals(expected, actual); swapping them inverts failure messages.
+    assertEquals(expected, actual);
+  }
+}
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
index b2c85481e2e..46453a4d892 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ChangeStreamActionTest.java
@@ -17,7 +17,9 @@
  */
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
@@ -27,10 +29,13 @@ import static org.mockito.Mockito.when;
 
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
 import com.google.cloud.bigtable.data.v2.models.Heartbeat;
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.Timestamp;
+import com.google.rpc.Status;
+import java.util.Collections;
 import java.util.Optional;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.TimestampConverter;
@@ -88,6 +93,27 @@ public class ChangeStreamActionTest {
     verify(tracker).tryClaim(eq(streamProgress));
   }
 
+  @Test
+  public void testCloseStreamResume() {
+    ChangeStreamContinuationToken changeStreamContinuationToken =
+        new ChangeStreamContinuationToken(ByteStringRange.create("a", "b"), "1234");
+    CloseStream mockCloseStream = Mockito.mock(CloseStream.class);
+    // 11 is the gRPC OUT_OF_RANGE status code.
+    Status statusProto = Status.newBuilder().setCode(11).build();
+    Mockito.when(mockCloseStream.getStatus())
+        .thenReturn(com.google.cloud.bigtable.common.Status.fromProto(statusProto));
+    Mockito.when(mockCloseStream.getChangeStreamContinuationTokens())
+        .thenReturn(Collections.singletonList(changeStreamContinuationToken));
+
+    final Optional<DoFn.ProcessContinuation> result =
+        action.run(partitionRecord, mockCloseStream, tracker, receiver, watermarkEstimator, false);
+
+    // The CloseStream is counted, claimed as progress, and the DoFn is asked to resume.
+    assertTrue(result.isPresent());
+    assertEquals(DoFn.ProcessContinuation.resume(), result.get());
+    verify(metrics).incClosestreamCount();
+    StreamProgress streamProgress = new StreamProgress(mockCloseStream);
+    verify(tracker).tryClaim(eq(streamProgress));
+  }
+
   @Test
   public void testChangeStreamMutationUser() {
     ByteStringRange partition = ByteStringRange.create("", "");
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java
new file mode 100644
index 00000000000..bea8728f3ca
--- /dev/null
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/ReadChangeStreamPartitionActionTest.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.api.gax.rpc.ServerStream;
+import com.google.cloud.Timestamp;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation;
+import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord;
+import com.google.cloud.bigtable.data.v2.models.CloseStream;
+import com.google.cloud.bigtable.data.v2.models.Heartbeat;
+import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
+import com.google.protobuf.ByteString;
+import com.google.rpc.Status;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Optional;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.ChangeStreamDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao;
+import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord;
+import 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.ReadChangeStreamPartitionProgressTracker;
+import 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+/** Tests for {@link ReadChangeStreamPartitionAction}. */
+public class ReadChangeStreamPartitionActionTest {
+
+  private ReadChangeStreamPartitionAction action;
+
+  private MetadataTableDao metadataTableDao;
+  private ChangeStreamDao changeStreamDao;
+  private ChangeStreamMetrics metrics;
+  private ChangeStreamAction changeStreamAction;
+
+  private StreamProgress restriction;
+  private RestrictionTracker<StreamProgress, StreamProgress> tracker;
+  private DoFn.OutputReceiver<KV<ByteString, ChangeStreamMutation>> receiver;
+  private ManualWatermarkEstimator<Instant> watermarkEstimator;
+
+  private PartitionRecord partitionRecord;
+
+  @Before
+  @SuppressWarnings("unchecked") // Mocks of generic types are created from raw class literals.
+  public void setUp() throws Exception {
+    metadataTableDao = mock(MetadataTableDao.class);
+    changeStreamDao = mock(ChangeStreamDao.class);
+    metrics = mock(ChangeStreamMetrics.class);
+    changeStreamAction = mock(ChangeStreamAction.class);
+    Duration heartbeatDuration = Duration.standardSeconds(1);
+
+    action =
+        new ReadChangeStreamPartitionAction(
+            metadataTableDao, changeStreamDao, metrics, changeStreamAction, heartbeatDuration);
+
+    restriction = mock(StreamProgress.class);
+    tracker = mock(ReadChangeStreamPartitionProgressTracker.class);
+    receiver = mock(DoFn.OutputReceiver.class);
+    watermarkEstimator = mock(ManualWatermarkEstimator.class);
+
+    ByteStringRange partition = ByteStringRange.create("A", "B");
+    String uuid = "123456";
+    Timestamp startTime = Timestamp.now();
+    Timestamp parentLowWatermark = Timestamp.now();
+    partitionRecord = new PartitionRecord(partition, startTime, uuid, parentLowWatermark);
+    when(tracker.currentRestriction()).thenReturn(restriction);
+    // By default the restriction has neither a continuation token nor a CloseStream.
+    when(restriction.getCurrentToken()).thenReturn(null);
+    when(restriction.getCloseStream()).thenReturn(null);
+    // Report the watermark as current. Presumably this keeps the action off its lagging-watermark
+    // debug path -- TODO confirm against ReadChangeStreamPartitionAction.
+    when(watermarkEstimator.getState()).thenReturn(Instant.now());
+  }
+
+  /** Returns a mock {@link CloseStream} whose status carries the given gRPC status code. */
+  private static CloseStream mockCloseStreamWithStatus(int statusCode) {
+    CloseStream closeStream = Mockito.mock(CloseStream.class);
+    Status statusProto = Status.newBuilder().setCode(statusCode).build();
+    Mockito.when(closeStream.getStatus())
+        .thenReturn(com.google.cloud.bigtable.common.Status.fromProto(statusProto));
+    return closeStream;
+  }
+
+  // NOTE(review): despite its name, this test only verifies that ChangeStreamAction.run is
+  // invoked; it does not check any metric counter.
+  @Test
+  @SuppressWarnings("unchecked") // Mocks of generic ServerStream/Iterator use raw class literals.
+  public void testThatChangeStreamWorkerCounterIsIncrementedOnInitialRun() throws IOException {
+    // Return null token to indicate that this is the first ever run.
+    when(restriction.getCurrentToken()).thenReturn(null);
+    when(restriction.getCloseStream()).thenReturn(null);
+
+    final ServerStream<ChangeStreamRecord> responses = mock(ServerStream.class);
+    final Iterator<ChangeStreamRecord> responseIterator = mock(Iterator.class);
+    when(responses.iterator()).thenReturn(responseIterator);
+
+    Heartbeat mockHeartBeat = Mockito.mock(Heartbeat.class);
+    when(responseIterator.next()).thenReturn(mockHeartBeat);
+    when(responseIterator.hasNext()).thenReturn(true);
+    when(changeStreamDao.readChangeStreamPartition(any(), any(), any(), anyBoolean()))
+        .thenReturn(responses);
+
+    when(changeStreamAction.run(any(), any(), any(), any(), any(), anyBoolean()))
+        .thenReturn(Optional.of(DoFn.ProcessContinuation.stop()));
+
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    verify(changeStreamAction).run(any(), any(), any(), any(), any(), anyBoolean());
+  }
+
+  @Test
+  public void testCloseStreamTerminateOKStatus() throws IOException {
+    // 0 is the gRPC OK status code.
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(0);
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before reaching processing stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Should not try to write any new partition to the metadata table.
+    verify(metadataTableDao, never()).writeNewPartition(any(), any(), any());
+    verify(metadataTableDao, never()).deleteStreamPartitionRow(any());
+  }
+
+  @Test
+  public void testCloseStreamTerminateNotOutOfRangeStatus() throws IOException {
+    // Out of Range code is 11; 10 (ABORTED) is any other non-OK status.
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(10);
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before reaching processing stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Should not try to write any new partition to the metadata table.
+    verify(metadataTableDao, never()).writeNewPartition(any(), any(), any());
+    verify(metadataTableDao, never()).deleteStreamPartitionRow(any());
+  }
+
+  @Test
+  public void testCloseStreamWritesContinuationTokens() throws IOException {
+    ChangeStreamContinuationToken changeStreamContinuationToken1 =
+        new ChangeStreamContinuationToken(ByteStringRange.create("A", "AJ"), "1234");
+    ChangeStreamContinuationToken changeStreamContinuationToken2 =
+        new ChangeStreamContinuationToken(ByteStringRange.create("AJ", "B"), "5678");
+
+    // 11 is the gRPC OUT_OF_RANGE status code.
+    CloseStream mockCloseStream = mockCloseStreamWithStatus(11);
+    Mockito.when(mockCloseStream.getChangeStreamContinuationTokens())
+        .thenReturn(Arrays.asList(changeStreamContinuationToken1, changeStreamContinuationToken2));
+
+    when(restriction.getCloseStream()).thenReturn(mockCloseStream);
+    final DoFn.ProcessContinuation result =
+        action.run(partitionRecord, tracker, receiver, watermarkEstimator);
+    assertEquals(DoFn.ProcessContinuation.stop(), result);
+    // Should terminate before reaching processing stream partition responses.
+    verify(changeStreamAction, never()).run(any(), any(), any(), any(), any(), anyBoolean());
+    // Should decrement the metric on termination.
+    verify(metrics).decPartitionStreamCount();
+    // Write the new partitions.
+    verify(metadataTableDao).writeNewPartition(eq(changeStreamContinuationToken1), any(), any());
+    verify(metadataTableDao).writeNewPartition(eq(changeStreamContinuationToken2), any(), any());
+    verify(metadataTableDao, times(1)).deleteStreamPartitionRow(partitionRecord.getPartition());
+  }
+}
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
index 8bb238162e7..fbea020295c 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java
@@ -18,7 +18,9 @@
 package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
+import com.google.api.gax.rpc.ServerStream;
 import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient;
 import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings;
 import com.google.cloud.bigtable.data.v2.BigtableDataClient;
@@ -27,6 +29,8 @@ import 
com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken;
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
 import com.google.cloud.bigtable.data.v2.models.Row;
 import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import org.apache.beam.sdk.io.gcp.bigtable.changestreams.UniqueIdGenerator;
 import 
org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder;
@@ -80,6 +84,46 @@ public class MetadataTableDaoTest {
             metadataTableAdminDao.getChangeStreamNamePrefix());
   }
 
+  @Test
+  public void testNewPartitionsWriteRead() throws InvalidProtocolBufferException {
+    // This tests a split of ["", "") into ["", "a") and ["a", "").
+    ByteStringRange parentPartition = ByteStringRange.create("", "");
+    ByteStringRange partition1 = ByteStringRange.create("", "a");
+    ChangeStreamContinuationToken changeStreamContinuationToken1 =
+        new ChangeStreamContinuationToken(partition1, "tk1");
+    ByteStringRange partition2 = ByteStringRange.create("a", "");
+    ChangeStreamContinuationToken changeStreamContinuationToken2 =
+        new ChangeStreamContinuationToken(partition2, "tk2");
+
+    Instant lowWatermark = Instant.now();
+    metadataTableDao.writeNewPartition(
+        changeStreamContinuationToken1, parentPartition, lowWatermark);
+    metadataTableDao.writeNewPartition(
+        changeStreamContinuationToken2, parentPartition, lowWatermark);
+
+    ServerStream<Row> rows = metadataTableDao.readNewPartitions();
+    // New-partition row keys are this prefix followed by the encoded partition; the prefix is the
+    // same for every row, so compute it once.
+    ByteString newPartitionPrefix =
+        metadataTableDao
+            .getChangeStreamNamePrefix()
+            .concat(MetadataTableAdminDao.NEW_PARTITION_PREFIX);
+    int rowsCount = 0;
+    boolean matchedPartition1 = false;
+    boolean matchedPartition2 = false;
+    for (Row row : rows) {
+      rowsCount++;
+      ByteStringRange partition =
+          ByteStringRange.toByteStringRange(row.getKey().substring(newPartitionPrefix.size()));
+      if (partition.equals(partition1)) {
+        matchedPartition1 = true;
+      } else if (partition.equals(partition2)) {
+        matchedPartition2 = true;
+      }
+    }
+    assertTrue(matchedPartition1);
+    assertTrue(matchedPartition2);
+    // Exactly the two written partitions are read back; nothing else.
+    assertEquals(2, rowsCount);
+  }
+
   @Test
   public void testUpdateWatermark() {
     ByteStringRange partition = ByteStringRange.create("a", "b");

Reply via email to