This is an automated email from the ASF dual-hosted git repository.

ycai pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-analytics.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e168011  CASSANDRA-19806: Stream sstable eagerly when bulk writing to reclaim local disk space sooner (#69)
e168011 is described below

commit e168011c40de2ca48d138514640838067e61feea
Author: Yifan Cai <[email protected]>
AuthorDate: Wed Aug 7 17:02:38 2024 -0700

    CASSANDRA-19806: Stream sstable eagerly when bulk writing to reclaim local disk space sooner (#69)
    
    Patch by Yifan Cai; Reviewed by Francisco Guerrero for CASSANDRA-19806
---
 CHANGES.txt                                        |   1 +
 .../apache/cassandra/bridge/SSTableDescriptor.java |  61 +++++++++++
 .../org/apache/cassandra/bridge/SSTableWriter.java |  18 +++
 .../apache/cassandra/spark/common/SSTables.java    |  23 ++++
 .../org/apache/cassandra}/util/ThreadUtil.java     |   2 +-
 .../example/ExampleStorageTransportExtension.java  |   2 +-
 .../CassandraDirectDataTransportContext.java       |   7 +-
 .../spark/bulkwriter/CassandraTopologyMonitor.java |   2 +-
 .../spark/bulkwriter/CommitCoordinator.java        |   2 +-
 .../spark/bulkwriter/DirectStreamSession.java      | 100 +++++++++++++----
 .../spark/bulkwriter/HeartbeatReporter.java        |   2 +-
 .../bulkwriter/ImportCompletionCoordinator.java    |   2 +-
 .../cassandra/spark/bulkwriter/RecordWriter.java   |  50 +++------
 .../spark/bulkwriter/SortedSSTableWriter.java      |  99 +++++++++++++----
 .../cassandra/spark/bulkwriter/StreamSession.java  | 117 ++++++++++++++++----
 .../spark/bulkwriter/TransportContext.java         |   5 +-
 .../bulkwriter/blobupload/BlobStreamSession.java   |  85 +++++++++++++--
 .../CassandraCloudStorageTransportContext.java     |   7 +-
 .../bulkwriter/blobupload/SSTableCollector.java    |   6 +
 .../spark/bulkwriter/blobupload/SSTableLister.java |  38 +++++--
 .../bulkwriter/blobupload/SSTablesBundler.java     |  10 ++
 .../cassandra/spark/data/LocalDataLayer.java       |  29 ++++-
 .../spark/bulkwriter/DirectStreamSessionTest.java  |  21 ++--
 .../spark/bulkwriter/MockBulkWriterContext.java    |   7 +-
 .../spark/bulkwriter/MockTableWriter.java          |  14 ++-
 .../NonValidatingTestSortedSSTableWriter.java      |   6 +-
 .../spark/bulkwriter/RecordWriterTest.java         |   8 +-
 .../spark/bulkwriter/SortedSSTableWriterTest.java  |  19 +++-
 .../bulkwriter/StreamSessionConsistencyTest.java   |   9 +-
 .../blobupload/BlobStreamSessionTest.java          |   6 +-
 .../bulkwriter/blobupload/SSTableListerTest.java   |  89 ++++++++++++---
 .../bridge/SSTableWriterImplementation.java        | 121 +++++++++++++++++++--
 .../bridge/SSTableWriterImplementationTest.java    |  55 +++++++++-
 33 files changed, 837 insertions(+), 186 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index a30d3f4..6c0e3a0 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 1.0.0
+ * Stream sstable eagerly when bulk writing to reclaim local disk space sooner (CASSANDRA-19806)
  * Split the Cassandra type logic out from CassandraBridge into a separate module (CASSANDRA-19793)
  * Remove other uses of Apache Commons lang for hashcode, equality and random string generation (CASSANDRA-19791)
  * Split out BufferingInputStream stats into separate interface (CASSANDRA-19778)
diff --git a/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableDescriptor.java b/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableDescriptor.java
new file mode 100644
index 0000000..f14b86a
--- /dev/null
+++ b/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableDescriptor.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.bridge;
+
+import java.util.Objects;
+
+/**
+ * Descriptor for each SSTable.
+ *
+ * (as of now, it is just a wrapper around the base filename of an sstable; add more methods and properties as appropriate)
+ */
+public class SSTableDescriptor
+{
+    // base filename that is shared among all components of the same sstable
+    public final String baseFilename;
+
+    public SSTableDescriptor(String baseFilename)
+    {
+        this.baseFilename = baseFilename;
+    }
+
+    @Override
+    public boolean equals(Object o)
+    {
+        if (this == o)
+        {
+            return true;
+        }
+
+        if (o == null || getClass() != o.getClass())
+        {
+            return false;
+        }
+
+        SSTableDescriptor that = (SSTableDescriptor) o;
+        return Objects.equals(baseFilename, that.baseFilename);
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hashCode(baseFilename);
+    }
+}
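
A minimal usage sketch, not part of the patch: because equals and hashCode are defined on the base filename, SSTableDescriptor can key the sets and maps that track produced sstables. The sstable names below are hypothetical.

    import java.util.HashSet;
    import java.util.Set;

    import org.apache.cassandra.bridge.SSTableDescriptor;

    public class SSTableDescriptorSketch
    {
        public static void main(String[] args)
        {
            Set<SSTableDescriptor> produced = new HashSet<>();
            produced.add(new SSTableDescriptor("nb-1-big"));
            produced.add(new SSTableDescriptor("nb-1-big")); // equal descriptor, set is unchanged
            produced.add(new SSTableDescriptor("nb-2-big"));
            System.out.println(produced.size()); // prints 2
        }
    }
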
diff --git a/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableWriter.java b/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableWriter.java
index 9ec752d..1e4338a 100644
--- a/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableWriter.java
+++ b/cassandra-analytics-common/src/main/java/org/apache/cassandra/bridge/SSTableWriter.java
@@ -22,8 +22,26 @@ package org.apache.cassandra.bridge;
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
 
 public interface SSTableWriter extends Closeable
 {
+    /**
+     * Write a new row
+     *
+     * @param values values of the row
+     * @throws IOException i/o exception when writing
+     */
     void addRow(Map<String, Object> values) throws IOException;
+
+    /**
+     * Register the listener for the sets of newly produced sstables, each identified by its unique base filename.
+     * The base filename is the filename of the sstable without the component suffix.
+     * For example, "nb-1-big"
+     * <p>
+     * Note that once a produced sstable has been reported, it is not included in subsequent notifications.
+     * Therefore, the listener only receives the _newly_ produced sstables.
+     */
+    void setSSTablesProducedListener(Consumer<Set<SSTableDescriptor>> listener);
 }
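
A short sketch of how a caller might register the new listener; the helper method and the println are illustrative only, not part of the patch.

    import java.util.Set;
    import java.util.function.Consumer;

    import org.apache.cassandra.bridge.SSTableDescriptor;
    import org.apache.cassandra.bridge.SSTableWriter;

    public final class ListenerSketch
    {
        // react to each batch of newly produced sstables; each batch contains only new descriptors
        static void watch(SSTableWriter writer)
        {
            Consumer<Set<SSTableDescriptor>> listener =
                sstables -> sstables.forEach(s -> System.out.println("produced: " + s.baseFilename));
            writer.setSSTablesProducedListener(listener);
        }
    }
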
diff --git a/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/common/SSTables.java b/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/common/SSTables.java
index 9571cf6..192f1dc 100644
--- a/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/common/SSTables.java
+++ b/cassandra-analytics-common/src/main/java/org/apache/cassandra/spark/common/SSTables.java
@@ -21,6 +21,8 @@ package org.apache.cassandra.spark.common;
 
 import java.nio.file.Path;
 
+import org.apache.cassandra.bridge.SSTableDescriptor;
+
 public final class SSTables
 {
     private SSTables()
@@ -28,9 +30,30 @@ public final class SSTables
         throw new IllegalStateException(getClass() + " is static utility class and shall not be instantiated");
     }
 
+    /**
+     * Get the sstable base name from the data file path.
+     * For example, the base name of data file '/path/to/table/nb-1-big-Data.db' is 'nb-1-big'
+     *
+     * @deprecated use {@code #getSSTableDescriptor(Path).baseFilename} instead
+     *
+     * @param dataFile data file path
+     * @return sstable base name
+     */
+    @Deprecated
     public static String getSSTableBaseName(Path dataFile)
     {
         String fileName = dataFile.getFileName().toString();
         return fileName.substring(0, fileName.lastIndexOf("-") + 1);
     }
+
+    /**
+     * Get the {@link SSTableDescriptor} from the data file path.
+     * @param dataFile data file path
+     * @return sstable descriptor
+     */
+    public static SSTableDescriptor getSSTableDescriptor(Path dataFile)
+    {
+        String baseFilename = getSSTableBaseName(dataFile);
+        return new SSTableDescriptor(baseFilename);
+    }
 }
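
A quick sketch of the new helper; since it delegates to getSSTableBaseName, the descriptor's baseFilename follows the same naming convention. The path below is hypothetical.

    import java.nio.file.Path;
    import java.nio.file.Paths;

    import org.apache.cassandra.bridge.SSTableDescriptor;
    import org.apache.cassandra.spark.common.SSTables;

    public final class DescriptorSketch
    {
        public static void main(String[] args)
        {
            Path dataFile = Paths.get("/path/to/table/nb-1-big-Data.db");
            // the descriptor wraps the base name shared by all components of this sstable
            SSTableDescriptor descriptor = SSTables.getSSTableDescriptor(dataFile);
            System.out.println(descriptor.baseFilename);
        }
    }
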
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/ThreadUtil.java b/cassandra-analytics-common/src/main/java/org/apache/cassandra/util/ThreadUtil.java
similarity index 96%
rename from cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/ThreadUtil.java
rename to cassandra-analytics-common/src/main/java/org/apache/cassandra/util/ThreadUtil.java
index 7aababb..a3aff6e 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/ThreadUtil.java
+++ b/cassandra-analytics-common/src/main/java/org/apache/cassandra/util/ThreadUtil.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.cassandra.spark.bulkwriter.util;
+package org.apache.cassandra.util;
 
 import java.util.concurrent.ThreadFactory;
 
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java
index a4524a7..1b71b1a 100644
--- a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java
@@ -27,7 +27,7 @@ import com.google.common.collect.ImmutableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
 import org.apache.cassandra.spark.transports.storage.StorageCredentials;
 import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java
index 963ebad..bed36ac 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java
@@ -20,6 +20,7 @@
 package org.apache.cassandra.spark.bulkwriter;
 
 import java.math.BigInteger;
+import java.util.concurrent.ExecutorService;
 
 import com.google.common.collect.Range;
 
@@ -49,14 +50,16 @@ public class CassandraDirectDataTransportContext implements TransportContext.Dir
                                                    String sessionId,
                                                    SortedSSTableWriter sstableWriter,
                                                    Range<BigInteger> range,
-                                                   ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                                                   ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                                                   ExecutorService executorService)
     {
         return new DirectStreamSession(writerContext,
                                        sstableWriter,
                                        this,
                                        sessionId,
                                        range,
-                                       failureHandler);
+                                       failureHandler,
+                                       executorService);
     }
 
     @Override
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java
index 52b9cca..1530677 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java
@@ -29,7 +29,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 
 /**
  * A monitor that checks whether the Cassandra topology has changed.
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
index b90b213..7ce727d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
@@ -43,7 +43,7 @@ import com.google.common.util.concurrent.MoreExecutors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 import org.jetbrains.annotations.Nullable;
 
 public final class CommitCoordinator extends AbstractFuture<List<CommitResult>> implements AutoCloseable
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java
index fafa4d0..52432cb 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java
@@ -31,6 +31,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
@@ -40,9 +41,11 @@ import org.apache.commons.io.FileUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.cassandra.bridge.SSTableDescriptor;
 import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
 import org.apache.cassandra.spark.common.Digest;
 import org.apache.cassandra.spark.common.SSTables;
+import org.apache.cassandra.spark.data.FileType;
 
 public class DirectStreamSession extends StreamSession<TransportContext.DirectDataBulkWriterContext>
 {
@@ -56,16 +59,64 @@ public class DirectStreamSession extends StreamSession<TransportContext.DirectDa
                                TransportContext.DirectDataBulkWriterContext transportContext,
                                String sessionID,
                                Range<BigInteger> tokenRange,
-                               ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                               ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                               ExecutorService executorService)
     {
-        super(writerContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler);
+        super(writerContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler, executorService);
         this.directDataTransferApi = transportContext.dataTransferApi();
     }
 
     @Override
-    protected StreamResult doScheduleStream(SortedSSTableWriter sstableWriter)
+    protected void onSSTablesProduced(Set<SSTableDescriptor> sstables)
     {
-        sendSSTables(sstableWriter);
+        // do not submit the streaming task if it is in the last stream run; the rest of the sstables are handled by doFinalizeStream
+        if (sstables.isEmpty() || isStreamFinalized())
+        {
+            return;
+        }
+
+        // send sstables asynchronously
+        executorService.submit(() -> {
+            try
+            {
+                // The task does these steps
+                // 1. find the newly produced sstables
+                // 2. validate the sstables
+                // 3. send the sstables to all replicas
+                // 4. remove the sstables once sent
+                Map<Path, Digest> fileDigests = sstableWriter.prepareSStablesToSend(writerContext, sstables);
+                recordStreamedFiles(fileDigests.keySet());
+                fileDigests.keySet()
+                           .stream()
+                           .filter(p -> p.getFileName().toString().endsWith(FileType.DATA.getFileSuffix()))
+                           .forEach(this::sendSStableToReplicas);
+                LOGGER.info("[{}]: Sent SSTables. sstables={}", sessionID, sstableWriter.sstableCount());
+                LOGGER.info("[{}]: Removing temporary files after streaming. files={}", sessionID, fileDigests);
+                fileDigests.keySet().forEach(path -> {
+                    try
+                    {
+                        Files.deleteIfExists(path);
+                    }
+                    catch (IOException e)
+                    {
+                        LOGGER.warn("[{}]: Failed to delete temporary file. file={}", sessionID, path);
+                    }
+                });
+            }
+            catch (IOException e)
+            {
+                LOGGER.error("[{}]: Unexpected exception while streaming SSTables {}",
+                             sessionID, sstableWriter.getOutDir());
+                setLastStreamFailure(e);
+                cleanAllReplicas();
+            }
+        });
+    }
+
+    @Override
+    protected StreamResult doFinalizeStream()
+    {
+        sendRemainingSSTables();
         // StreamResult has errors streaming to replicas
         DirectStreamResult streamResult = new DirectStreamResult(sessionID,
                                                                  tokenRange,
@@ -94,18 +145,19 @@ public class DirectStreamSession extends StreamSession<TransportContext.DirectDa
     }
 
     @Override
-    protected void sendSSTables(final SortedSSTableWriter sstableWriter)
+    protected void sendRemainingSSTables()
     {
         try (DirectoryStream<Path> dataFileStream = Files.newDirectoryStream(sstableWriter.getOutDir(), "*Data.db"))
         {
             for (Path dataFile : dataFileStream)
             {
-                int ssTableIdx = nextSSTableIdx.getAndIncrement();
+                if (isFileStreamed(dataFile))
+                {
+                    // the file is already streamed or being streamed; skipping it
+                    continue;
+                }
 
-                LOGGER.info("[{}]: Pushing SSTable {} to replicas {}",
-                            sessionID, dataFile,
-                            replicas.stream().map(RingInstance::nodeName).collect(Collectors.joining(",")));
-                replicas.removeIf(replica -> !trySendSSTableToReplica(sstableWriter, dataFile, ssTableIdx, replica));
+                sendSStableToReplicas(dataFile);
             }
 
             LOGGER.info("[{}]: Sent SSTables. sstables={}", sessionID, sstableWriter.sstableCount());
@@ -133,14 +185,24 @@ public class DirectStreamSession extends StreamSession<TransportContext.DirectDa
         }
     }
 
-    private boolean trySendSSTableToReplica(SortedSSTableWriter sstableWriter,
-                                            Path dataFile,
-                                            int ssTableIdx,
-                                            RingInstance replica)
+    private void sendSStableToReplicas(Path dataFile)
+    {
+        int ssTableIdx = nextSSTableIdx.getAndIncrement();
+
+        LOGGER.info("[{}]: Pushing SSTable {} to replicas {}",
+                    sessionID, dataFile,
+                    
replicas.stream().map(RingInstance::nodeName).collect(Collectors.joining(",")));
+        replicas.removeIf(replica -> !trySendSSTableToOneReplica(dataFile, 
ssTableIdx, replica, sstableWriter.fileDigestMap()));
+    }
+
+    private boolean trySendSSTableToOneReplica(Path dataFile,
+                                               int ssTableIdx,
+                                               RingInstance replica,
+                                               Map<Path, Digest> fileDigests)
     {
         try
         {
-            sendSSTableToReplica(dataFile, ssTableIdx, replica, sstableWriter.fileDigestMap());
+            sendSSTableToOneReplica(dataFile, ssTableIdx, replica, fileDigests);
             return true;
         }
         catch (Exception exception)
@@ -155,10 +217,10 @@ public class DirectStreamSession extends StreamSession<TransportContext.DirectDa
         }
     }
 
-    private void sendSSTableToReplica(Path dataFile,
-                                      int ssTableIdx,
-                                      RingInstance instance,
-                                      Map<Path, Digest> fileHashes) throws IOException
+    private void sendSSTableToOneReplica(Path dataFile,
+                                         int ssTableIdx,
+                                         RingInstance instance,
+                                         Map<Path, Digest> fileHashes) throws IOException
     {
         try (DirectoryStream<Path> componentFileStream = Files.newDirectoryStream(dataFile.getParent(),
                                                                                   SSTables.getSSTableBaseName(dataFile) + "*"))
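
To summarize the eager-streaming flow implemented above, here is a condensed, hypothetical model (names such as sendAndDelete are illustrative, not part of the patch): the writer notifies the session as sstables appear, the session streams each batch on the shared executor and deletes the local files, and finalization streams whatever remains synchronously.

    import java.util.Set;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    final class EagerStreamingSketch
    {
        private final ExecutorService executor = Executors.newSingleThreadExecutor();
        private volatile boolean finalized = false;

        // invoked by the sstable writer whenever new sstables land on disk
        void onSSTablesProduced(Set<String> baseFilenames)
        {
            if (baseFilenames.isEmpty() || finalized)
            {
                return; // the last batch is handled by finalizeStream()
            }
            executor.submit(() -> sendAndDelete(baseFilenames));
        }

        // stream the batch to replicas, then delete local files to reclaim disk space
        private void sendAndDelete(Set<String> baseFilenames)
        {
            baseFilenames.forEach(name -> System.out.println("streaming " + name));
        }

        void finalizeStream()
        {
            finalized = true; // remaining sstables are sent in the finalize step
        }
    }
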
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java
index b27983c..19a32e6 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java
@@ -32,7 +32,7 @@ import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 
 public class HeartbeatReporter implements Closeable
 {
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java
index 1133bc5..e77e6c6 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java
@@ -45,7 +45,7 @@ import org.apache.cassandra.sidecar.client.SidecarInstance;
 import org.apache.cassandra.spark.bulkwriter.blobupload.BlobDataTransferApi;
 import org.apache.cassandra.spark.bulkwriter.blobupload.BlobStreamResult;
 import org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice;
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 import org.apache.cassandra.spark.data.ReplicationFactor;
 import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
 
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
index aa0104a..5368003 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
@@ -53,7 +53,7 @@ import o.a.c.sidecar.client.shaded.common.response.TimeSkewResponse;
 import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
 import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
 import org.apache.cassandra.spark.bulkwriter.util.TaskContextUtils;
-import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.util.ThreadUtil;
 import org.apache.cassandra.spark.data.BridgeUdtValue;
 import org.apache.cassandra.spark.data.CqlField;
 import org.apache.cassandra.spark.data.CqlTable;
@@ -198,7 +198,7 @@ public class RecordWriter
                     }
                     currentRange = subRanges.get(currentRangeIndex);
                 }
-                maybeCreateStreamSession(taskContext, currentRange);
+                maybeSwitchToNewStreamSession(taskContext, currentRange);
                 writeRow(rowData, valueMap, partitionId, streamSession.getTokenRange());
             }
 
@@ -219,6 +219,12 @@ public class RecordWriter
                          taskContext.stageAttemptNumber(),
                          taskContext.attemptNumber());
 
+            // if streamSession was not closed/nullified, clean it up here
+            if (streamSession != null)
+            {
+                streamSession.cleanupOnFailure();
+            }
+
             if (exception instanceof InterruptedException)
             {
                 Thread.currentThread().interrupt();
@@ -274,27 +280,10 @@ public class RecordWriter
      * If we do find the need to split a range into sub-ranges, we create the corresponding session for the sub-range
      * if the token from the row data belongs to the range.
      */
-    private void maybeCreateStreamSession(TaskContext taskContext,
-                                          Range<BigInteger> currentRange) throws IOException
+    private void maybeSwitchToNewStreamSession(TaskContext taskContext,
+                                               Range<BigInteger> currentRange) throws IOException
     {
-        maybeCreateSubRangeSession(taskContext, currentRange);
-
-        // If we do not have any stream session at this point, we create a session using the partition's token range
-        if (streamSession == null)
-        {
-            createStreamSessionWithAssignedRange(taskContext);
-        }
-    }
-
-    /**
-     * Given that the token belongs to a sub-range, creates a new stream session if either
-     * 1) we do not have an existing stream session, or 2) the existing stream session corresponds to a range that
-     * does NOT match the sub-range the token belongs to.
-     */
-    private void maybeCreateSubRangeSession(TaskContext taskContext,
-                                            Range<BigInteger> matchingSubRange) throws IOException
-    {
-        if (streamSession != null && streamSession.getTokenRange().equals(matchingSubRange))
+        if (streamSession != null && streamSession.getTokenRange().equals(currentRange))
         {
             return;
         }
@@ -306,12 +295,7 @@ public class RecordWriter
             flushAsync(taskContext.partitionId());
         }
 
-        streamSession = createStreamSession(taskContext, matchingSubRange);
-    }
-
-    private void createStreamSessionWithAssignedRange(TaskContext taskContext) throws IOException
-    {
-        createStreamSession(taskContext, getTokenRange(taskContext));
+        streamSession = createStreamSession(taskContext, currentRange);
     }
 
     private StreamSession<?> createStreamSession(TaskContext taskContext, Range<BigInteger> range) throws IOException
@@ -321,11 +305,11 @@ public class RecordWriter
         String sessionId = TaskContextUtils.createStreamSessionId(taskContext);
         Path perSessionDirectory = baseDir.resolve(sessionId);
         Files.createDirectories(perSessionDirectory);
-        SortedSSTableWriter sstableWriter = tableWriterFactory.create(writerContext, perSessionDirectory, digestAlgorithm);
+        SortedSSTableWriter sstableWriter = tableWriterFactory.create(writerContext, perSessionDirectory, digestAlgorithm, taskContext.partitionId());
         LOGGER.info("[{}][{}] Created new SSTable writer with directory={}",
                     taskContext.partitionId(), sessionId, perSessionDirectory);
         return writerContext.transportContext()
-                            .createStreamSession(writerContext, sessionId, sstableWriter, range, failureHandler);
+                            .createStreamSession(writerContext, sessionId, sstableWriter, range, failureHandler, executorService);
     }
 
     /**
@@ -478,7 +462,7 @@ public class RecordWriter
         Preconditions.checkState(streamSession != null);
         LOGGER.info("[{}][{}] Closing writer and scheduling SStable stream 
with {} rows",
                     partitionId, streamSession.sessionID, 
streamSession.rowCount());
-        Future<StreamResult> future = 
streamSession.scheduleStreamAsync(partitionId, executorService);
+        Future<StreamResult> future = streamSession.finalizeStreamAsync();
         streamFutures.put(streamSession.sessionID, future);
         streamSession = null;
     }
@@ -495,11 +479,13 @@ public class RecordWriter
          * @param writerContext   the context for the bulk writer job
          * @param outDir          an output directory where SSTable components will be written to
          * @param digestAlgorithm a digest provider to calculate digests for every SSTable component
+         * @param partitionId     partition id
          * @return a new {@link SortedSSTableWriter}
          */
         SortedSSTableWriter create(BulkWriterContext writerContext,
                                    Path outDir,
-                                   DigestAlgorithm digestAlgorithm);
+                                   DigestAlgorithm digestAlgorithm,
+                                   int partitionId);
     }
 
     // The java version of org.apache.spark.InterruptibleIterator
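
The session-switching logic above reduces to: keep the current session while the token range matches, otherwise flush it asynchronously and open a new one. A condensed, hypothetical sketch of that control flow (DummySession stands in for StreamSession and is not part of the patch):

    import java.math.BigInteger;

    import com.google.common.collect.Range;

    final class SessionSwitchSketch
    {
        private DummySession streamSession;

        void maybeSwitch(Range<BigInteger> currentRange)
        {
            if (streamSession != null && streamSession.tokenRange.equals(currentRange))
            {
                return; // still writing into the matching range
            }
            if (streamSession != null)
            {
                streamSession.finalizeAsync(); // flush the old session before switching
            }
            streamSession = new DummySession(currentRange);
        }

        static final class DummySession
        {
            final Range<BigInteger> tokenRange;

            DummySession(Range<BigInteger> tokenRange)
            {
                this.tokenRange = tokenRange;
            }

            void finalizeAsync()
            {
                // schedule the final stream for this session's sstables
            }
        }
    }
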
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
index e5e4227..bd0e35a 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
@@ -27,8 +27,10 @@ import java.nio.file.Path;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Consumer;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Range;
@@ -38,9 +40,10 @@ import org.slf4j.LoggerFactory;
 import org.apache.cassandra.bridge.CassandraBridgeFactory;
 import org.apache.cassandra.bridge.CassandraVersion;
 import org.apache.cassandra.bridge.CassandraVersionFeatures;
+import org.apache.cassandra.bridge.SSTableDescriptor;
 import org.apache.cassandra.spark.common.Digest;
 import org.apache.cassandra.spark.common.SSTables;
-import org.apache.cassandra.spark.data.DataLayer;
+import org.apache.cassandra.spark.data.FileType;
 import org.apache.cassandra.spark.data.LocalDataLayer;
 import org.apache.cassandra.spark.data.partitioner.Partitioner;
 import org.apache.cassandra.spark.reader.RowData;
@@ -68,9 +71,10 @@ public class SortedSSTableWriter
 
     private final Path outDir;
     private final org.apache.cassandra.bridge.SSTableWriter cqlSSTableWriter;
+    private final int partitionId;
     private BigInteger minToken = null;
     private BigInteger maxToken = null;
-    private final Map<Path, Digest> fileDigestMap = new HashMap<>();
+    private final Map<Path, Digest> overallFileDigests = new HashMap<>();
     private final DigestAlgorithm digestAlgorithm;
 
     private int sstableCount = 0;
@@ -78,17 +82,20 @@ public class SortedSSTableWriter
     private long bytesWritten = 0;
 
     public SortedSSTableWriter(org.apache.cassandra.bridge.SSTableWriter tableWriter, Path outDir,
-                               DigestAlgorithm digestAlgorithm)
+                               DigestAlgorithm digestAlgorithm,
+                               int partitionId)
     {
-        cqlSSTableWriter = tableWriter;
+        this.cqlSSTableWriter = tableWriter;
         this.outDir = outDir;
         this.digestAlgorithm = digestAlgorithm;
+        this.partitionId = partitionId;
     }
 
-    public SortedSSTableWriter(BulkWriterContext writerContext, Path outDir, DigestAlgorithm digestAlgorithm)
+    public SortedSSTableWriter(BulkWriterContext writerContext, Path outDir, DigestAlgorithm digestAlgorithm, int partitionId)
     {
         this.outDir = outDir;
         this.digestAlgorithm = digestAlgorithm;
+        this.partitionId = partitionId;
 
         String lowestCassandraVersion = writerContext.cluster().getLowestCassandraVersion();
         String packageVersion = getPackageVersion(lowestCassandraVersion);
@@ -131,6 +138,11 @@ public class SortedSSTableWriter
         rowCount += 1;
     }
 
+    public void setSSTablesProducedListener(Consumer<Set<SSTableDescriptor>> listener)
+    {
+        cqlSSTableWriter.setSSTablesProducedListener(listener);
+    }
+
     /**
      * @return the total number of rows written
      */
@@ -155,24 +167,64 @@ public class SortedSSTableWriter
         return sstableCount;
     }
 
-    public void close(BulkWriterContext writerContext, int partitionId) throws IOException
+    public Map<Path, Digest> prepareSStablesToSend(@NotNull BulkWriterContext writerContext, Set<SSTableDescriptor> sstables) throws IOException
+    {
+        DirectoryStream.Filter<Path> sstableFilter = path -> {
+            SSTableDescriptor baseName = SSTables.getSSTableDescriptor(path);
+            return sstables.contains(baseName);
+        };
+        Set<Path> dataFilePaths = new HashSet<>();
+        Map<Path, Digest> fileDigests = new HashMap<>();
+        try (DirectoryStream<Path> stream = Files.newDirectoryStream(getOutDir(), sstableFilter))
+        {
+            for (Path path : stream)
+            {
+                if (path.getFileName().toString().endsWith("-" + FileType.DATA.getFileSuffix()))
+                {
+                    dataFilePaths.add(path);
+                    sstableCount += 1;
+                }
+
+                Digest digest = digestAlgorithm.calculateFileDigest(path);
+                fileDigests.put(path, digest);
+                LOGGER.debug("Calculated digest={} for path={}", digest, path);
+            }
+        }
+        bytesWritten += calculatedTotalSize(fileDigests.keySet());
+        overallFileDigests.putAll(fileDigests);
+        validateSSTables(writerContext, dataFilePaths);
+        return fileDigests;
+    }
+
+    public void close(BulkWriterContext writerContext) throws IOException
     {
         cqlSSTableWriter.close();
-        sstableCount = 0;
         for (Path dataFile : getDataFileStream())
         {
             // NOTE: We calculate file hashes before re-reading so that we know what we hashed
             //       is what we validated. Then we send these along with the files and the
             //       receiving end re-hashes the files to make sure they still match.
-            fileDigestMap.putAll(calculateFileDigestMap(dataFile));
+            overallFileDigests.putAll(calculateFileDigestMap(dataFile));
             sstableCount += 1;
         }
-        bytesWritten = calculatedTotalSize(fileDigestMap.keySet());
-        validateSSTables(writerContext, partitionId);
+        bytesWritten += calculatedTotalSize(overallFileDigests.keySet());
+        validateSSTables(writerContext);
+    }
+
+    @VisibleForTesting
+    public void validateSSTables(@NotNull BulkWriterContext writerContext)
+    {
+        validateSSTables(writerContext, null);
     }
 
+    /**
+     * Validate SSTables. If dataFilePaths is null, it finds all sstables under the output directory of the writer and validates them.
+     * @param writerContext bulk writer context
+     * @param dataFilePaths paths of sstables (data files) to be validated. The argument is nullable.
+     *                      When it is null, it validates all sstables under the output directory.
+     */
     @VisibleForTesting
-    public void validateSSTables(@NotNull BulkWriterContext writerContext, int partitionId)
+    public void validateSSTables(@NotNull BulkWriterContext writerContext, Set<Path> dataFilePaths)
     {
         // NOTE: If this current implementation of SS-tables' validation proves to be a performance issue,
         //       we will need to modify LocalDataLayer to allow scanning and compaction of single data file,
@@ -185,15 +237,20 @@ public class SortedSSTableWriter
             Partitioner partitioner = writerContext.cluster().getPartitioner();
             Set<String> udtStatements = writerContext.schema().getUserDefinedTypeStatements();
             String directory = getOutDir().toString();
-            DataLayer layer = new LocalDataLayer(version,
-                                                 partitioner,
-                                                 keyspace,
-                                                 schema,
-                                                 udtStatements,
-                                                 Collections.emptyList() /* requestedFeatures */,
-                                                 false /* useSSTableInputStream */,
-                                                 null /* statsClass */,
-                                                 directory);
+            LocalDataLayer layer = new LocalDataLayer(version,
+                                                      partitioner,
+                                                      keyspace,
+                                                      schema,
+                                                      udtStatements,
+                                                      Collections.emptyList() /* requestedFeatures */,
+                                                      false /* useSSTableInputStream */,
+                                                      null /* statsClass */,
+                                                      directory);
+            if (dataFilePaths != null)
+            {
+                layer.setDataFilePaths(dataFilePaths);
+            }
+
             try (StreamScanner<RowData> scanner = layer.openCompactionScanner(partitionId, Collections.emptyList(), null))
             {
                 while (scanner.next())
@@ -255,6 +312,6 @@ public class SortedSSTableWriter
      */
     public Map<Path, Digest> fileDigestMap()
     {
-        return Collections.unmodifiableMap(fileDigestMap);
+        return Collections.unmodifiableMap(overallFileDigests);
     }
 }
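
The component filter in prepareSStablesToSend works because every component of an sstable shares its base name. A small sketch of the same matching idea, using hypothetical file names:

    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.Set;
    import java.util.stream.Collectors;

    import org.apache.cassandra.bridge.SSTableDescriptor;
    import org.apache.cassandra.spark.common.SSTables;

    public final class ComponentFilterSketch
    {
        public static void main(String[] args)
        {
            Set<SSTableDescriptor> wanted =
                Collections.singleton(SSTables.getSSTableDescriptor(Paths.get("nb-1-big-Data.db")));
            List<Path> components = Arrays.asList(Paths.get("nb-1-big-Data.db"),
                                                  Paths.get("nb-1-big-Index.db"),
                                                  Paths.get("nb-2-big-Data.db"));
            // keep only components whose base name matches a wanted descriptor
            List<Path> selected = components.stream()
                                            .filter(p -> wanted.contains(SSTables.getSSTableDescriptor(p)))
                                            .collect(Collectors.toList());
            System.out.println(selected); // the two nb-1-big components
        }
    }
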
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
index c437f6e..839b46d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
@@ -21,6 +21,7 @@ package org.apache.cassandra.spark.bulkwriter;
 
 import java.io.IOException;
 import java.math.BigInteger;
+import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -29,14 +30,18 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Range;
+import org.apache.commons.io.FileUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import o.a.c.sidecar.client.shaded.io.vertx.core.impl.ConcurrentHashSet;
+import org.apache.cassandra.bridge.SSTableDescriptor;
 import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
 import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
 
@@ -53,6 +58,11 @@ public abstract class StreamSession<T extends TransportContext>
     protected final ReplicaAwareFailureHandler<RingInstance> failureHandler;
     protected final TokenRangeMapping<RingInstance> tokenRangeMapping;
     protected final SortedSSTableWriter sstableWriter;
+    protected final ExecutorService executorService;
+
+    private final Set<Path> streamedFiles = new ConcurrentHashSet<>();
+    private final AtomicReference<Exception> lastStreamFailure = new AtomicReference<>();
+    private volatile boolean isStreamFinalized = false;
 
     @VisibleForTesting
     protected StreamSession(BulkWriterContext writerContext,
@@ -60,18 +70,41 @@ public abstract class StreamSession<T extends TransportContext>
                             T transportContext,
                             String sessionID,
                             Range<BigInteger> tokenRange,
-                            ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                            ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                            ExecutorService executorService)
     {
         this.writerContext = writerContext;
         this.sstableWriter = sstableWriter;
+        this.sstableWriter.setSSTablesProducedListener(this::onSSTablesProduced);
         this.transportContext = transportContext;
         this.tokenRangeMapping = writerContext.cluster().getTokenRangeMapping(true);
         this.sessionID = sessionID;
         this.tokenRange = tokenRange;
         this.failureHandler = failureHandler;
         this.replicas = getReplicas();
+        this.executorService = executorService;
     }
 
+    /**
+     * Get notified on sstables produced. When the method is invoked, the input parameter 'sstables' is guaranteed to be non-empty.
+     *
+     * @param sstables produced SSTables
+     */
+    protected abstract void onSSTablesProduced(Set<SSTableDescriptor> sstables);
+
+    /**
+     * Finalize the stream with the produced sstables and return the stream result.
+     *
+     * @return stream result
+     */
+    protected abstract StreamResult doFinalizeStream();
+
+    /**
+     * Send the remaining SSTable(s) written by the SSTableWriter.
+     * The code runs on a separate thread.
+     */
+    protected abstract void sendRemainingSSTables();
+
     public Range<BigInteger> getTokenRange()
     {
         return tokenRange;
@@ -79,6 +112,9 @@ public abstract class StreamSession<T extends TransportContext>
 
     public void addRow(BigInteger token, Map<String, Object> boundValues) throws IOException
     {
+        // exit early when sending the produced sstables has failed
+        rethrowIfLastStreamFailed();
+
         sstableWriter.addRow(token, boundValues);
     }
 
@@ -87,14 +123,69 @@ public abstract class StreamSession<T extends TransportContext>
         return sstableWriter.rowCount();
     }
 
-    public Future<StreamResult> scheduleStreamAsync(int partitionId, ExecutorService executorService) throws IOException
+    public Future<StreamResult> finalizeStreamAsync() throws IOException
     {
-        Preconditions.checkState(!sstableWriter.getTokenRange().isEmpty(), "Trying to stream empty SSTable");
+        isStreamFinalized = true;
+        rethrowIfLastStreamFailed();
+        Preconditions.checkState(!sstableWriter.getTokenRange().isEmpty(), "Cannot stream empty SSTable");
         Preconditions.checkState(tokenRange.encloses(sstableWriter.getTokenRange()),
                                  "SSTable range %s should be enclosed in the partition range %s",
                                  sstableWriter.getTokenRange(), tokenRange);
-        sstableWriter.close(writerContext, partitionId);
-        return executorService.submit(() -> doScheduleStream(sstableWriter));
+        // close the writer before finalizing stream
+        sstableWriter.close(writerContext);
+        return executorService.submit(this::doFinalizeStream);
+    }
+
+    /**
+     * Clean up any remaining files on disk when streaming has failed
+     */
+    public void cleanupOnFailure()
+    {
+        try
+        {
+            sstableWriter.close(writerContext);
+        }
+        catch (IOException e)
+        {
+            LOGGER.warn("[{}]: Failed to close sstable writer on streaming 
failure", sessionID, e);
+        }
+
+        try
+        {
+            FileUtils.deleteDirectory(sstableWriter.getOutDir().toFile());
+        }
+        catch (IOException e)
+        {
+            LOGGER.warn("[{}]: Failed to clean up the produced sstables on 
streaming failure", sessionID, e);
+        }
+    }
+
+    protected boolean isStreamFinalized()
+    {
+        return isStreamFinalized;
+    }
+
+    protected boolean setLastStreamFailure(Exception streamFailure)
+    {
+        return lastStreamFailure.compareAndSet(null, streamFailure);
+    }
+
+    protected void recordStreamedFiles(Set<Path> files)
+    {
+        streamedFiles.addAll(files);
+    }
+
+    protected boolean isFileStreamed(Path file)
+    {
+        return streamedFiles.contains(file);
+    }
+
+    private void rethrowIfLastStreamFailed() throws IOException
+    {
+        if (lastStreamFailure.get() != null)
+        {
+            throw new IOException("Unexpected exception while streaming 
SSTables", lastStreamFailure.get());
+        }
     }
 
     @VisibleForTesting
@@ -138,20 +229,4 @@ public abstract class StreamSession<T extends TransportContext>
         return failedInstances.contains(ringInstance)
                || blockedInstanceIps.contains(ringInstance.ipAddress());
     }
-
-    /**
-     * Schedule the stream with the produced sstables and return the stream result.
-     *
-     * @param sstableWriter produces SSTable(s)
-     * @return stream result
-     */
-    protected abstract StreamResult doScheduleStream(SortedSSTableWriter sstableWriter);
-
-    /**
-     * Send the SSTable(s) written by SSTableWriter
-     * The code runs on a separate thread
-     *
-     * @param sstableWriter produces SSTable(s)
-     */
-    protected abstract void sendSSTables(SortedSSTableWriter sstableWriter);
 }
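
The failure-propagation contract above can be summarized as: a background send failure is recorded at most once via compareAndSet, then rethrown on the writer thread at the next addRow or at finalize. A minimal, hypothetical model of that contract:

    import java.io.IOException;
    import java.util.concurrent.atomic.AtomicReference;

    final class FailurePropagationSketch
    {
        private final AtomicReference<Exception> lastStreamFailure = new AtomicReference<>();

        // called from the background streaming task; only the first failure is kept
        boolean setLastStreamFailure(Exception e)
        {
            return lastStreamFailure.compareAndSet(null, e);
        }

        // called on the writer thread before accepting more rows
        void rethrowIfFailed() throws IOException
        {
            Exception failure = lastStreamFailure.get();
            if (failure != null)
            {
                throw new IOException("Unexpected exception while streaming SSTables", failure);
            }
        }
    }
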
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java
index 59e848a..823c1d6 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java
@@ -21,6 +21,7 @@ package org.apache.cassandra.spark.bulkwriter;
 
 import java.io.Serializable;
 import java.math.BigInteger;
+import java.util.concurrent.ExecutorService;
 
 import com.google.common.collect.Range;
 
@@ -42,13 +43,15 @@ public interface TransportContext
      * @param sstableWriter sstable writer of the stream session
      * @param range token range of the stream session
      * @param failureHandler handler to track failures of the stream session
+     * @param executorService executor service
      * @return a new stream session
      */
     StreamSession<? extends TransportContext> createStreamSession(BulkWriterContext writerContext,
                                                                   String sessionId,
                                                                   SortedSSTableWriter sstableWriter,
                                                                   Range<BigInteger> range,
-                                                                  ReplicaAwareFailureHandler<RingInstance> failureHandler);
+                                                                  ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                                                                  ExecutorService executorService);
 
     default void close()
     {
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java
index 9d3ff95..8ec1bcf 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java
@@ -19,10 +19,14 @@
 
 package org.apache.cassandra.spark.bulkwriter.blobupload;
 
+import java.io.IOException;
 import java.math.BigInteger;
 import java.nio.file.Path;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Range;
@@ -33,6 +37,7 @@ import o.a.c.sidecar.client.shaded.common.request.data.CreateSliceRequestPayload
 import o.a.c.sidecar.client.shaded.common.request.data.RestoreJobSummaryResponsePayload;
 import org.apache.cassandra.bridge.CassandraBridge;
 import org.apache.cassandra.bridge.CassandraBridgeFactory;
+import org.apache.cassandra.bridge.SSTableDescriptor;
 import org.apache.cassandra.clients.Sidecar;
 import org.apache.cassandra.sidecar.client.SidecarInstance;
 import org.apache.cassandra.spark.bulkwriter.BulkWriteValidator;
@@ -45,6 +50,8 @@ import org.apache.cassandra.spark.bulkwriter.StreamResult;
 import org.apache.cassandra.spark.bulkwriter.StreamSession;
 import org.apache.cassandra.spark.bulkwriter.TransportContext;
 import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.common.Digest;
+import org.apache.cassandra.spark.common.SSTables;
 import org.apache.cassandra.spark.common.client.ClientException;
 import org.apache.cassandra.spark.data.QualifiedTableName;
 import org.apache.cassandra.spark.transports.storage.StorageCredentials;
@@ -61,25 +68,27 @@ public class BlobStreamSession extends StreamSession<TransportContext.CloudStora
     protected final CassandraBridge bridge;
     private final Set<CreatedRestoreSlice> createdRestoreSlices = new HashSet<>();
     private final SSTablesBundler sstablesBundler;
-    private int bundleCount;
+    private int bundleCount = 0;
 
     public BlobStreamSession(BulkWriterContext bulkWriterContext, SortedSSTableWriter sstableWriter,
                              TransportContext.CloudStorageTransportContext transportContext,
                              String sessionID, Range<BigInteger> tokenRange,
-                             ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                             ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                             ExecutorService executorService)
     {
         this(bulkWriterContext, sstableWriter, transportContext, sessionID, tokenRange,
              CassandraBridgeFactory.get(bulkWriterContext.cluster().getLowestCassandraVersion()),
-             failureHandler);
+             failureHandler, executorService);
     }
 
     @VisibleForTesting
     public BlobStreamSession(BulkWriterContext bulkWriterContext, SortedSSTableWriter sstableWriter,
                              TransportContext.CloudStorageTransportContext transportContext,
                              String sessionID, Range<BigInteger> tokenRange,
-                             CassandraBridge bridge, ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                             CassandraBridge bridge, ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                             ExecutorService executorService)
     {
-        super(bulkWriterContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler);
+        super(bulkWriterContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler, executorService);
 
         JobInfo job = bulkWriterContext.job();
         long maxSizePerBundleInBytes = job.transportInfo().getMaxSizePerBundleInBytes();
@@ -94,7 +103,57 @@ public class BlobStreamSession extends StreamSession<TransportContext.CloudStora
     }
 
     @Override
-    protected StreamResult doScheduleStream(SortedSSTableWriter sstableWriter)
+    protected void onSSTablesProduced(Set<SSTableDescriptor> sstables)
+    {
+        if (sstables.isEmpty() || isStreamFinalized())
+        {
+            return;
+        }
+
+        executorService.submit(() -> {
+            try
+            {
+                Map<Path, Digest> fileDigests = sstableWriter.prepareSStablesToSend(writerContext, sstables);
+                // sstablesBundler keeps track of the known files. No need to record the streamed files.
+                // group the files by sstable (unique) basename and add to bundler
+                fileDigests.keySet()
+                           .stream()
+                           .collect(Collectors.groupingBy(SSTables::getSSTableBaseName))
+                           .values()
+                           .forEach(sstablesBundler::includeSSTable);
+
+                if (!sstablesBundler.hasNext())
+                {
+                    // hold on until a bundle can be produced
+                    return;
+                }
+
+                bundleCount += 1;
+                Bundle bundle = sstablesBundler.next();
+                try
+                {
+                    sendBundle(bundle, false);
+                }
+                catch (RuntimeException e)
+                {
+                    // log and record the failure so it is rethrown on the writer thread
+                    LOGGER.error("[{}]: Unexpected exception while uploading SSTable", sessionID, e);
+                    setLastStreamFailure(e);
+                }
+                finally
+                {
+                    bundle.deleteAll();
+                }
+            }
+            catch (IOException e)
+            {
+                throw new RuntimeException(e);
+            }
+        });
+    }
+
+    @Override
+    protected StreamResult doFinalizeStream()
     {
         sstablesBundler.includeDirectory(sstableWriter.getOutDir());
 
@@ -112,7 +171,7 @@ public class BlobStreamSession extends StreamSession<TransportContext.CloudStora
             return BlobStreamResult.empty(sessionID, tokenRange);
         }
 
-        sendSSTables(sstableWriter);
+        sendRemainingSSTables();
         LOGGER.info("[{}]: Uploaded bundles to S3. sstables={} bundles={}", 
sessionID, sstableWriter.sstableCount(), bundleCount);
 
         BlobStreamResult streamResult = new BlobStreamResult(sessionID,
@@ -130,9 +189,8 @@ public class BlobStreamSession extends StreamSession<TransportContext.CloudStora
     }
 
     @Override
-    protected void sendSSTables(SortedSSTableWriter sstableWriter)
+    protected void sendRemainingSSTables()
     {
-        bundleCount = 0;
         while (sstablesBundler.hasNext())
         {
             bundleCount++;
@@ -153,6 +211,15 @@ public class BlobStreamSession extends StreamSession<TransportContext.CloudStora
         }
     }
 
+    @Override
+    public void cleanupOnFailure()
+    {
+        super.cleanupOnFailure();
+
+        // remove any remaining bundle
+        sstablesBundler.cleanupBundle(sessionID);
+    }
+
     void sendBundle(Bundle bundle, boolean hasRefreshedCredentials)
     {
         StorageCredentials writeCredentials = getStorageCredentialsFromSidecar();
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java
index fe07304..f6ba470 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java
@@ -22,6 +22,7 @@ package org.apache.cassandra.spark.bulkwriter.blobupload;
 import java.lang.reflect.InvocationTargetException;
 import java.math.BigInteger;
 import java.util.Objects;
+import java.util.concurrent.ExecutorService;
 
 import com.google.common.collect.Range;
 import org.slf4j.Logger;
@@ -80,10 +81,12 @@ public class CassandraCloudStorageTransportContext implements TransportContext.C
                                                  String sessionId,
                                                  SortedSSTableWriter sstableWriter,
                                                  Range<BigInteger> range,
-                                                 ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                                                 ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                                                 ExecutorService executorService)
     {
         return new BlobStreamSession(writerContext, sstableWriter,
-                                     this, sessionId, range, failureHandler);
+                                     this, sessionId, range, failureHandler,
+                                     executorService);
     }
 
     @Override
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java
index 7bb1c6d..1badfd8 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java
@@ -38,6 +38,12 @@ public interface SSTableCollector
      */
     void includeDirectory(Path dir);
 
+    /**
+     * Include the sstable components of an individual SSTable
+     * @param sstableComponents sstable components
+     */
+    void includeSSTable(List<Path> sstableComponents);
+
     /**
      * @return total size of all sstables included
      */
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java
index 2f3b546..b67d754 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java
@@ -46,6 +46,7 @@ import org.apache.cassandra.spark.data.FileSystemSSTable;
 import org.apache.cassandra.spark.data.QualifiedTableName;
 import org.apache.cassandra.spark.data.SSTable;
 import org.apache.cassandra.spark.stats.Stats;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * {@link SSTableLister} lists the directories containing SSTables.
@@ -63,6 +64,7 @@ public class SSTableLister implements SSTableCollector
     private final CassandraBridge bridge;
     private final Queue<SSTableFilesAndRange> sstables;
     private final Set<Path> sstableDirectories;
+    private final Set<Path> knownFiles;
     private long totalSize;
 
     public SSTableLister(QualifiedTableName qualifiedTableName, CassandraBridge bridge)
@@ -71,6 +73,7 @@ public class SSTableLister implements SSTableCollector
         this.bridge = bridge;
         this.sstables = new LinkedBlockingQueue<>();
         this.sstableDirectories = new HashSet<>();
+        this.knownFiles = new HashSet<>();
     }
 
     @Override
@@ -83,19 +86,19 @@ public class SSTableLister implements SSTableCollector
         }
 
         listSSTables(dir)
-        .map(components -> {
-            SSTable sstable = buildSSTable(components);
-            SSTableSummary summary = bridge.getSSTableSummary(qualifiedTableName.keyspace(),
-                                                              qualifiedTableName.table(),
-                                                              sstable);
-            long size = sizeSum(components);
-            totalSize += size;
-            return new SSTableFilesAndRange(summary, components, sizeSum(components));
-        })
+        .map(this::createSSTableFilesAndRange)
         .sorted(SORT_BY_FIRST_TOKEN_THEN_LAST_TOKEN)
         .forEach(sstables::add);
     }
 
+    @Override
+    public void includeSSTable(List<Path> sstableComponents)
+    {
+        knownFiles.addAll(sstableComponents);
+        SSTableFilesAndRange sstableAndRange = createSSTableFilesAndRange(sstableComponents);
+        sstables.add(sstableAndRange);
+    }
+
     @Override
     public long totalSize()
     {
@@ -132,6 +135,12 @@ public class SSTableLister implements SSTableCollector
         try (Stream<Path> stream = Files.list(dir))
         {
             stream.forEach(path -> {
+                if (knownFiles.contains(path))
+                {
+                    // ignore the file as it has been included via includeSSTable
+                    return;
+                }
+
                 final String ssTablePrefix = getSSTablePrefix(path.getFileName().toString());
 
                 if (ssTablePrefix.isEmpty())
@@ -195,4 +204,15 @@ public class SSTableLister implements SSTableCollector
         }
         return new FileSystemSSTable(dataComponents.get(0), true, Stats.DoNothingStats.INSTANCE::bufferingInputStreamStats);
     }
+
+    private @NotNull SSTableFilesAndRange createSSTableFilesAndRange(List<Path> sstableComponents)
+    {
+        SSTable sstable = buildSSTable(sstableComponents);
+        SSTableSummary summary = bridge.getSSTableSummary(qualifiedTableName.keyspace(),
+                                                          qualifiedTableName.table(),
+                                                          sstable);
+        long size = sizeSum(sstableComponents);
+        totalSize += size;
+        return new SSTableFilesAndRange(summary, sstableComponents, size);
+    }
 }
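
The knownFiles set added above is what lets includeSSTable and includeDirectory coexist: components registered eagerly are remembered and skipped when the whole output directory is scanned at the end of the session. A minimal sketch of that dedup pattern, with hypothetical class and method names:

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.HashSet;
    import java.util.Set;
    import java.util.stream.Stream;

    // Sketch of the knownFiles dedup: explicitly included paths are remembered,
    // and a later directory scan silently skips them.
    class KnownFilesSketch
    {
        private final Set<Path> knownFiles = new HashSet<>();

        void includeFile(Path file)
        {
            knownFiles.add(file);
        }

        void includeDirectory(Path dir) throws IOException
        {
            try (Stream<Path> stream = Files.list(dir))
            {
                stream.filter(path -> !knownFiles.contains(path)) // skip already-included files
                      .forEach(path -> System.out.println("newly discovered: " + path));
            }
        }
    }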
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java
index c93044f..5297af8 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java
@@ -107,6 +107,11 @@ public class SSTablesBundler implements Iterator<Bundle>
         collector.includeDirectory(dir);
     }
 
+    public void includeSSTable(List<Path> sstableComponents)
+    {
+        collector.includeSSTable(sstableComponents);
+    }
+
     public void finish()
     {
         reachedEnd = true;
@@ -115,6 +120,11 @@ public class SSTablesBundler implements Iterator<Bundle>
     public void cleanupBundle(String sessionID)
     {
         LOGGER.info("[{}]: Clean up bundle files after stream session bundle={}", sessionID, currentBundle);
+        if (currentBundle == null)
+        {
+            return;
+        }
+
         try
         {
             Bundle bundle = currentBundle;
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java
index b0e6161..ff42489 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java
@@ -31,6 +31,7 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -82,6 +83,7 @@ public class LocalDataLayer extends DataLayer implements Serializable
     private boolean useBufferingInputStream;
     private String[] paths;
     private int minimumReplicasPerMutation = 1;
+    private Set<Path> dataFilePaths = null;
 
     @Nullable
     private static Stats loadStats(@Nullable String statsClass)
@@ -236,6 +238,7 @@ public class LocalDataLayer extends DataLayer implements Serializable
         this.useBufferingInputStream = useBufferingInputStream;
         this.statsClass = statsClass;
         this.paths = paths;
+        this.dataFilePaths = new HashSet<>();
     }
 
     // For serialization
@@ -314,16 +317,32 @@ public class LocalDataLayer extends DataLayer implements Serializable
         return stats;
     }
 
+    public void setDataFilePaths(Set<Path> dataFilePaths)
+    {
+        this.dataFilePaths = dataFilePaths;
+    }
+
     @Override
     public SSTablesSupplier sstables(int partitionId,
                                      @Nullable SparkRangeFilter sparkRangeFilter,
                                      @NotNull List<PartitionKeyFilter> partitionKeyFilters)
     {
-        return new BasicSupplier(Arrays
-                .stream(paths)
-                .map(Paths::get)
-                .flatMap(Throwing.function(Files::list))
-                .filter(path -> path.getFileName().toString().endsWith("-" + FileType.DATA.getFileSuffix()))
+        Stream<Path> dataFilePathsStream;
+        // if data file paths are supplied, prefer them over listing files
+        if (dataFilePaths != null && !dataFilePaths.isEmpty())
+        {
+            dataFilePathsStream = dataFilePaths.stream();
+        }
+        else
+        {
+            dataFilePathsStream = Arrays
+                                  .stream(paths)
+                                  .map(Paths::get)
+                                  .flatMap(Throwing.function(Files::list))
+                                  .filter(path -> path.getFileName().toString().endsWith("-" + FileType.DATA.getFileSuffix()));
+        }
+
+        return new BasicSupplier(dataFilePathsStream
                 .map(path -> new FileSystemSSTable(path, useBufferingInputStream, () -> this.stats.bufferingInputStreamStats()))
                 .collect(Collectors.toSet()));
     }
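
LocalDataLayer now short-circuits directory listing whenever a set of Data.db paths has been supplied through setDataFilePaths. A small sketch of that prefer-explicit-else-scan fallback, under assumed names (DataFileSourceSketch is illustrative, not the patch's API):

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.List;
    import java.util.Set;
    import java.util.stream.Stream;

    // Sketch: prefer explicitly supplied data files; otherwise list
    // "*-Data.db" files under the configured directories.
    final class DataFileSourceSketch
    {
        static Stream<Path> dataFiles(Set<Path> supplied, List<Path> dirs) throws IOException
        {
            if (supplied != null && !supplied.isEmpty())
            {
                return supplied.stream();  // trust the caller-provided paths
            }
            Stream<Path> listed = Stream.empty();
            for (Path dir : dirs)
            {
                // note: in production code each Files.list stream should be closed
                listed = Stream.concat(listed, Files.list(dir));
            }
            return listed.filter(p -> p.getFileName().toString().endsWith("-Data.db"));
        }
    }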
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
index aa16da3..966de28 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
@@ -107,7 +107,7 @@ public class DirectStreamSessionTest
         StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
         ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
         assertThat(ss.rowCount(), is(1L));
-        StreamResult streamResult = ss.scheduleStreamAsync(1, executor).get();
+        StreamResult streamResult = ss.finalizeStreamAsync().get();
         assertThat(streamResult.rowCount, is(1L));
         executor.assertFuturesCalled();
         assertThat(executor.futures.size(), equalTo(1));  // We only scheduled one SSTable
@@ -122,11 +122,11 @@ public class DirectStreamSessionTest
         Exception exception = assertThrows(IllegalStateException.class,
                                            () -> new DirectStreamSession(
                                            writerContext,
-                                           new NonValidatingTestSortedSSTableWriter(tableWriter, folder, digestAlgorithm),
+                                           new NonValidatingTestSortedSSTableWriter(tableWriter, folder, digestAlgorithm, 1),
                                            transportContext,
                                            "sessionId",
                                            Range.range(BigInteger.valueOf(0L), BoundType.OPEN, BigInteger.valueOf(0L), BoundType.CLOSED),
-                                           replicaAwareFailureHandler())
+                                           replicaAwareFailureHandler(), null)
                                            );
         assertThat(exception.getMessage(), is("No replicas found for range (0‥0]"));
     }
@@ -137,7 +137,7 @@ public class DirectStreamSessionTest
         StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
         ss.addRow(BigInteger.valueOf(9999L), COLUMN_BOUND_VALUES);
         IllegalStateException illegalStateException = assertThrows(IllegalStateException.class,
-                                                      () -> ss.scheduleStreamAsync(1, executor));
+                                                                   ss::finalizeStreamAsync);
         assertThat(illegalStateException.getMessage(), matchesPattern(
         "SSTable range \\[9999(‥|..)9999] should be enclosed in the partition range \\[101(‥|..)199]"));
     }
@@ -184,7 +184,7 @@ public class DirectStreamSessionTest
         ExecutionException ex = assertThrows(ExecutionException.class, () -> {
             StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
             ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
-            Future<?> fut = ss.scheduleStreamAsync(1, executor);
+            Future<?> fut = ss.finalizeStreamAsync();
             tableWriter.removeOutDir();
             fut.get();
         });
@@ -220,7 +220,7 @@ public class DirectStreamSessionTest
             }
         });
         ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
-        ss.scheduleStreamAsync(1, executor).get();
+        ss.finalizeStreamAsync().get();
         executor.assertFuturesCalled();
         assertThat(writerContext.getUploads().values().stream().mapToInt(Collection::size).sum(), equalTo(RF * FILES_PER_SSTABLE));
         final List<String> instances = writerContext.getUploads().keySet().stream().map(CassandraInstance::nodeName).collect(Collectors.toList());
@@ -245,7 +245,7 @@ public class DirectStreamSessionTest
         });
         ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
         ExecutionException exception = assertThrows(ExecutionException.class,
-                                                    () -> ss.scheduleStreamAsync(1, executor).get());
+                                                    () -> ss.finalizeStreamAsync().get());
         assertEquals("Failed to load 1 ranges with LOCAL_QUORUM for job " + writerContext.job().getId()
                      + " in phase UploadAndCommit.", exception.getCause().getMessage());
         executor.assertFuturesCalled();
@@ -260,7 +260,7 @@ public class DirectStreamSessionTest
         StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
         ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
         ExecutionException ex = assertThrows(ExecutionException.class,
-                                             () -> ss.scheduleStreamAsync(1, executor).get());
+                                             () -> ss.finalizeStreamAsync().get());
         assertThat(ex.getCause().getMessage(), startsWith(LOAD_RANGE_ERROR_PREFIX));
     }
 
@@ -279,10 +279,11 @@ public class DirectStreamSessionTest
     private DirectStreamSession createStreamSession(MockTableWriter.Creator writerCreator)
     {
         return new DirectStreamSession(writerContext,
-                                       writerCreator.create(tableWriter, folder, digestAlgorithm),
+                                       writerCreator.create(tableWriter, folder, digestAlgorithm, 1),
                                        transportContext,
                                        "sessionId",
                                        range,
-                                       replicaAwareFailureHandler());
+                                       replicaAwareFailureHandler(),
+                                       executor);
     }
 }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
index d863ea6..2dc5eab 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
@@ -30,6 +30,7 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
 import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -498,14 +499,16 @@ public class MockBulkWriterContext implements BulkWriterContext, ClusterInfo, Jo
                                                         String sessionId,
                                                        SortedSSTableWriter sstableWriter,
                                                        Range<BigInteger> range,
-                                                        ReplicaAwareFailureHandler<RingInstance> failureHandler)
+                                                        ReplicaAwareFailureHandler<RingInstance> failureHandler,
+                                                        ExecutorService executorService)
             {
                 return new DirectStreamSession(mockBulkWriterContext,
                                                sstableWriter,
                                                this,
                                                sessionId,
                                                range,
-                                               failureHandler);
+                                               failureHandler,
+                                               executorService);
             }
         };
     }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
index 9551fe2..69d2a31 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
@@ -25,10 +25,13 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.io.FileUtils;
 
+import org.apache.cassandra.bridge.SSTableDescriptor;
 import org.apache.cassandra.bridge.SSTableWriter;
 import org.apache.cassandra.spark.utils.DigestAlgorithm;
 
@@ -75,12 +78,16 @@ public class MockTableWriter implements SSTableWriter
         rows.add(values.values().toArray());
     }
 
+    @Override
+    public void setSSTablesProducedListener(Consumer<Set<SSTableDescriptor>> listener)
+    {
+        // do nothing
+    }
+
     @Override
     public void close() throws IOException
     {
         // Create files to mimic SSTableWriter
-        // TODO: Instead, we shouldn't have SSTableWriter return the outDir - we should
-        //       provide a way to iterate over the data files and pass a callable of some kind in
         for (String component: TABLE_COMPONENTS)
         {
             Path path = Paths.get(outDir.toString(), BASE_NAME + component);
@@ -107,6 +114,7 @@ public class MockTableWriter implements SSTableWriter
         // to match with SortedSSTableWriter's constructor
         SortedSSTableWriter create(MockTableWriter tableWriter,
                                    Path outDir,
-                                   DigestAlgorithm digestAlgorithm);
+                                   DigestAlgorithm digestAlgorithm,
+                                   int partitionId);
     }
 }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
index eae8795..0498b7e 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
@@ -26,13 +26,13 @@ import org.jetbrains.annotations.NotNull;
 
 public class NonValidatingTestSortedSSTableWriter extends SortedSSTableWriter
 {
-    public NonValidatingTestSortedSSTableWriter(MockTableWriter tableWriter, Path path, DigestAlgorithm digestAlgorithm)
+    public NonValidatingTestSortedSSTableWriter(MockTableWriter tableWriter, Path path, DigestAlgorithm digestAlgorithm, int partitionId)
     {
-        super(tableWriter, path, digestAlgorithm);
+        super(tableWriter, path, digestAlgorithm, partitionId);
     }
 
     @Override
-    public void validateSSTables(@NotNull BulkWriterContext writerContext, int partitionId)
+    public void validateSSTables(@NotNull BulkWriterContext writerContext)
     {
         // Skip validation for these tests
     }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
index 14abca6..6fcd366 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
@@ -389,7 +389,7 @@ class RecordWriterTest
     void testCorruptSSTable()
     {
         rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
-                              (wc, path, dp) -> new SortedSSTableWriter(tw.setOutDir(path), path, digestAlgorithm));
+                              (wc, path, dp, pid) -> new SortedSSTableWriter(tw.setOutDir(path), path, digestAlgorithm, pid));
         Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
         // TODO: Add better error handling with human-readable exception messages in SSTableReader::new
         // That way we can assert on the exception thrown here
@@ -400,7 +400,7 @@ class RecordWriterTest
     void testWriteWithOutOfRangeTokenFails()
     {
         rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
-                              (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
+                              (wc, path, dp, pid) -> new SortedSSTableWriter(tw, folder, digestAlgorithm, pid));
         Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(5, Range.all(), false, false, false);
         RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
         String expectedErr = "java.lang.IllegalStateException: Received Token " +
@@ -412,7 +412,7 @@ class RecordWriterTest
     void testAddRowThrowingFails()
     {
         rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
-                              (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
+                              (wc, path, dp, pid) -> new SortedSSTableWriter(tw, folder, digestAlgorithm, pid));
         tw.setAddRowThrows(true);
         Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
         RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
@@ -425,7 +425,7 @@ class RecordWriterTest
         // Mock context returns a 60-minute allowable time skew, so we use something just outside the limits
         long sixtyOneMinutesInMillis = TimeUnit.MINUTES.toMillis(61);
         rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
-                              (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
+                              (wc, path, dp, pid) -> new SortedSSTableWriter(tw, folder, digestAlgorithm, pid));
         writerContext.setTimeProvider(() -> System.currentTimeMillis() - sixtyOneMinutesInMillis);
         Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
         RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
index e7acbfb..f7d4584 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
@@ -25,6 +25,8 @@ import java.nio.file.DirectoryStream;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.stream.Collectors;
 import javax.validation.constraints.NotNull;
 
@@ -86,15 +88,20 @@ public class SortedSSTableWriterTest
     public void canCreateWriterForVersion(String version) throws IOException
     {
         MockBulkWriterContext writerContext = new MockBulkWriterContext(tokenRangeMapping, version, ConsistencyLevel.CL.LOCAL_QUORUM);
-        SortedSSTableWriter tw = new SortedSSTableWriter(writerContext, tmpDir, new XXHash32DigestAlgorithm());
+        SortedSSTableWriter tw = new SortedSSTableWriter(writerContext, tmpDir, new XXHash32DigestAlgorithm(), 1);
         tw.addRow(BigInteger.ONE, ImmutableMap.of("id", 1, "date", 1, "course", "foo", "marks", 1));
-        tw.close(writerContext, 1);
+        tw.close(writerContext);
+        Set<Path> dataFilePaths = new HashSet<>();
         try (DirectoryStream<Path> dataFileStream = Files.newDirectoryStream(tw.getOutDir(), "*Data.db"))
         {
-            dataFileStream.forEach(dataFilePath ->
-                                   assertEquals(CassandraVersionFeatures.cassandraVersionFeaturesFromCassandraVersion(version).getMajorVersion(),
-                                                SSTables.cassandraVersionFromTable(dataFilePath).getMajorVersion()));
+            dataFileStream.forEach(dataFilePath -> {
+                dataFilePaths.add(dataFilePath);
+                assertEquals(CassandraVersionFeatures.cassandraVersionFeaturesFromCassandraVersion(version).getMajorVersion(),
+                             SSTables.cassandraVersionFromTable(dataFilePath).getMajorVersion());
+            });
         }
-        tw.validateSSTables(writerContext, 1);
+        // no exception should be thrown by either of the validate methods
+        tw.validateSSTables(writerContext);
+        tw.validateSSTables(writerContext, dataFilePaths);
     }
 }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
index b9722d1..2e2265d 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
@@ -114,7 +114,7 @@ public class StreamSessionConsistencyTest
         });
         StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
         streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
-        Future<?> fut = streamSession.scheduleStreamAsync(1, executor);
+        Future<?> fut = streamSession.finalizeStreamAsync();
         if (shouldFail)
         {
             ExecutionException exception = assertThrows(ExecutionException.class, fut::get);
@@ -152,7 +152,7 @@ public class StreamSessionConsistencyTest
         writerContext.setUploadSupplier(instance -> dcFailures.get(instance.datacenter()).getAndDecrement() <= 0);
         StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
         streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
-        Future<?> fut =  streamSession.scheduleStreamAsync(1, executor);
+        Future<?> fut =  streamSession.finalizeStreamAsync();
         if (shouldFail)
         {
             ExecutionException exception = assertThrows(ExecutionException.class, fut::get);
@@ -209,10 +209,11 @@ public class StreamSessionConsistencyTest
     private StreamSession<?> createStreamSession(MockTableWriter.Creator writerCreator)
     {
         return new DirectStreamSession(writerContext,
-                                       writerCreator.create(tableWriter, folder, digestAlgorithm),
+                                       writerCreator.create(tableWriter, folder, digestAlgorithm, 1),
                                        transportContext,
                                        "sessionId",
                                        RANGE,
-                                       new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
+                                       new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()),
+                                       executor);
     }
 }
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java
index 01c4960..cbba94d 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java
@@ -28,6 +28,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.concurrent.Executors;
 
 import com.google.common.collect.BoundType;
 import com.google.common.collect.ImmutableList;
@@ -95,7 +96,7 @@ class BlobStreamSessionTest
         when(job.getRestoreJobId()).thenReturn(jobId);
         when(job.qualifiedTableName()).thenReturn(new QualifiedTableName("ks", "table1"));
         MockTableWriter tableWriter = new MockTableWriter(folder);
-        SortedSSTableWriter sstableWriter = new NonValidatingTestSortedSSTableWriter(tableWriter, folder, new XXHash32DigestAlgorithm());
+        SortedSSTableWriter sstableWriter = new NonValidatingTestSortedSSTableWriter(tableWriter, folder, new XXHash32DigestAlgorithm(), 1);
 
         DataTransportInfo transportInfo = mock(DataTransportInfo.class);
         when(transportInfo.getTransport()).thenReturn(DataTransport.S3_COMPAT);
@@ -135,7 +136,8 @@ class BlobStreamSessionTest
 
             BlobStreamSession ss = new BlobStreamSession(spiedWriterContext, sstableWriter,
                                                          transportContext, sessionId,
-                                                         range, bridge, replicaAwareFailureHandler);
+                                                         range, bridge, replicaAwareFailureHandler,
+                                                         Executors.newSingleThreadExecutor());
 
             // test begins
             for (Bundle bundle : bundles)
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java
index 5abfa19..492794a 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java
@@ -22,6 +22,8 @@ package org.apache.cassandra.spark.bulkwriter.blobupload;
 import java.io.IOException;
 import java.math.BigInteger;
 import java.net.URISyntaxException;
+import java.nio.file.DirectoryStream;
+import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
@@ -39,6 +41,7 @@ import org.apache.cassandra.spark.data.QualifiedTableName;
 import org.apache.cassandra.spark.utils.TemporaryDirectory;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.Mockito.mock;
@@ -46,28 +49,20 @@ import static org.mockito.Mockito.when;
 
 class SSTableListerTest
 {
+    private Path outputDir;
+
     @Test
     void testOutput() throws URISyntaxException
     {
-        Path outputDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
-        CassandraBridge bridge = mock(CassandraBridge.class);
-
-        SSTableSummary summary1 = new SSTableSummary(BigInteger.valueOf(1L), BigInteger.valueOf(3L), "na-1-big-");
-        SSTableSummary summary2 = new SSTableSummary(BigInteger.valueOf(3L), BigInteger.valueOf(6L), "na-2-big-");
-
-        FileSystemSSTable ssTable1 = new FileSystemSSTable(outputDir.resolve("na-1-big-Data.db"), false, null);
-        FileSystemSSTable ssTable2 = new FileSystemSSTable(outputDir.resolve("na-2-big-Data.db"), false, null);
-        when(bridge.getSSTableSummary("ks", "table1", ssTable1)).thenReturn(summary1);
-        when(bridge.getSSTableSummary("ks", "table1", ssTable2)).thenReturn(summary2);
-        SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
-        ssTableLister.includeDirectory(outputDir);
+        SSTableLister sstableLister = setupSSTableLister();
+        sstableLister.includeDirectory(outputDir);
         List<SSTableLister.SSTableFilesAndRange> sstables = new ArrayList<>();
         // 10196 is the total size of files in /data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0
         // If this line fails, maybe something has been changed in the folder.
-        assertEquals(10196, ssTableLister.totalSize());
-        while (!ssTableLister.isEmpty())
+        assertEquals(10196, sstableLister.totalSize());
+        while (!sstableLister.isEmpty())
         {
-            sstables.add(ssTableLister.consumeOne());
+            sstables.add(sstableLister.consumeOne());
         }
         assertEquals(2, sstables.size());
         Set<String> ssTablePrefixes = sstables.stream()
@@ -106,4 +101,68 @@ class SSTableListerTest
             assertTrue(ssTableLister.isEmpty());
         }
     }
+
+    @Test
+    void testIncludeSSTable() throws Exception
+    {
+        SSTableLister sstableLister = setupSSTableLister();
+        List<Path> sstableComponents = new ArrayList<>();
+        try (DirectoryStream<Path> stream = Files.newDirectoryStream(outputDir, "na-1-big-*"))
+        {
+            stream.forEach(sstableComponents::add);
+        }
+        sstableLister.includeSSTable(sstableComponents);
+        List<SSTableLister.SSTableFilesAndRange> sstables = new ArrayList<>();
+        assertFalse(sstableLister.isEmpty());
+        assertEquals(5098, sstableLister.totalSize());
+        while (!sstableLister.isEmpty())
+        {
+            sstables.add(sstableLister.consumeOne());
+        }
+        assertEquals(0, sstableLister.totalSize());
+        assertEquals(1, sstables.size());
+        Set<Path> range1Files = sstables.get(0).files;
+        assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Data.db")));
+        assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Index.db")));
+        assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Summary.db")));
+        assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Statistics.db")));
+        assertTrue(range1Files.contains(outputDir.resolve("na-1-big-TOC.txt")));
+
+        // now include the entire directory
+        // note that one sstable has been included. The sstable should be ignored when including the directory
+        sstableLister.includeDirectory(outputDir);
+        assertFalse(sstableLister.isEmpty());
+        assertEquals(5098, sstableLister.totalSize());
+        int producedSSTables = 0;
+        while (!sstableLister.isEmpty())
+        {
+            producedSSTables += 1;
+            sstables.add(sstableLister.consumeOne());
+        }
+        assertEquals(1, producedSSTables);
+        assertEquals(0, sstableLister.totalSize());
+        assertEquals(2, sstables.size());
+
+        Set<Path> range2Files = sstables.get(1).files;
+        assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Data.db")));
+        assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Index.db")));
+        assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Summary.db")));
+        assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Statistics.db")));
+        assertTrue(range2Files.contains(outputDir.resolve("na-2-big-TOC.txt")));
+    }
+
+    private SSTableLister setupSSTableLister() throws URISyntaxException
+    {
+        outputDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
+        CassandraBridge bridge = mock(CassandraBridge.class);
+
+        SSTableSummary summary1 = new SSTableSummary(BigInteger.valueOf(1L), BigInteger.valueOf(3L), "na-1-big-");
+        SSTableSummary summary2 = new SSTableSummary(BigInteger.valueOf(3L), BigInteger.valueOf(6L), "na-2-big-");
+
+        FileSystemSSTable ssTable1 = new FileSystemSSTable(outputDir.resolve("na-1-big-Data.db"), false, null);
+        FileSystemSSTable ssTable2 = new FileSystemSSTable(outputDir.resolve("na-2-big-Data.db"), false, null);
+        when(bridge.getSSTableSummary("ks", "table1", ssTable1)).thenReturn(summary1);
+        when(bridge.getSSTableSummary("ks", "table1", ssTable2)).thenReturn(summary2);
+        return new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+    }
 }
diff --git a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java
index fdc9cab..b97b236 100644
--- a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java
+++ b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java
@@ -19,11 +19,25 @@
 
 package org.apache.cassandra.bridge;
 
+import java.io.Closeable;
 import java.io.IOException;
+import java.nio.file.DirectoryStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.HashSet;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
 
 import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.cassandra.config.Config;
 import org.apache.cassandra.dht.IPartitioner;
@@ -31,6 +45,7 @@ import org.apache.cassandra.dht.Murmur3Partitioner;
 import org.apache.cassandra.dht.RandomPartitioner;
 import org.apache.cassandra.exceptions.InvalidRequestException;
 import org.apache.cassandra.io.sstable.CQLSSTableWriter;
+import org.apache.cassandra.util.ThreadUtil;
 import org.jetbrains.annotations.NotNull;
 
 public class SSTableWriterImplementation implements SSTableWriter
@@ -40,7 +55,12 @@ public class SSTableWriterImplementation implements SSTableWriter
         Config.setClientMode(true);
     }
 
+    private static final Logger LOGGER = LoggerFactory.getLogger(SSTableWriterImplementation.class);
+
     private final CQLSSTableWriter writer;
+    private final Path outputDir;
+    private final SSTableWatcher sstableWatcher;
+    private Consumer<Set<SSTableDescriptor>> producedSSTablesListener;
 
     public SSTableWriterImplementation(String inDirectory,
                                        String partitioner,
@@ -48,17 +68,95 @@ public class SSTableWriterImplementation implements SSTableWriter
                                        String insertStatement,
                                        @NotNull Set<String> userDefinedTypeStatements,
                                        int bufferSizeMB)
+    {
+        this(inDirectory, partitioner, createStatement, insertStatement, userDefinedTypeStatements, bufferSizeMB, 10);
+    }
+
+    @VisibleForTesting
+    SSTableWriterImplementation(String inDirectory,
+                                String partitioner,
+                                String createStatement,
+                                String insertStatement,
+                                @NotNull Set<String> userDefinedTypeStatements,
+                                int bufferSizeMB,
+                                long sstableWatcherDelaySeconds)
     {
         IPartitioner cassPartitioner = partitioner.toLowerCase().contains("random") ? new RandomPartitioner()
                                                                                     : new Murmur3Partitioner();
 
-        CQLSSTableWriter.Builder builder = configureBuilder(inDirectory,
-                                                            createStatement,
-                                                            insertStatement,
-                                                            bufferSizeMB,
-                                                            userDefinedTypeStatements,
-                                                            cassPartitioner);
-                                                            cassPartitioner);
-        writer = builder.build();
+        this.writer = configureBuilder(inDirectory,
+                                       createStatement,
+                                       insertStatement,
+                                       bufferSizeMB,
+                                       userDefinedTypeStatements,
+                                       cassPartitioner)
+                      .build();
+        this.outputDir = Paths.get(inDirectory);
+        this.sstableWatcher = new SSTableWatcher(sstableWatcherDelaySeconds);
+    }
+
+    private class SSTableWatcher implements Closeable
+    {
+        // The TOC component is the last one flushed when finishing an SSTable.
+        // Therefore, the watcher monitors the creation of TOC components to detect newly completed SSTables
+        private static final String TOC_COMPONENT_SUFFIX = "-TOC.txt";
+        private static final String GLOB_PATTERN_FOR_TOC = "*" + TOC_COMPONENT_SUFFIX;
+
+        private final ScheduledExecutorService sstableWatcherScheduler;
+        private final Set<SSTableDescriptor> knownSSTables;
+
+        SSTableWatcher(long delaySeconds)
+        {
+            ThreadFactory tf = ThreadUtil.threadFactory("SSTableWatcher-" + outputDir.getFileName().toString());
+            this.sstableWatcherScheduler = Executors.newSingleThreadScheduledExecutor(tf);
+            this.knownSSTables = new HashSet<>();
+            sstableWatcherScheduler.scheduleWithFixedDelay(this::listSSTables, delaySeconds, delaySeconds, TimeUnit.SECONDS);
+        }
+
+        private void listSSTables()
+        {
+            try (DirectoryStream<Path> stream = Files.newDirectoryStream(outputDir, GLOB_PATTERN_FOR_TOC))
+            {
+                HashSet<SSTableDescriptor> newlyProducedSSTables = new HashSet<>();
+                stream.forEach(path -> {
+                    String baseFilename = path.getFileName().toString().replace(TOC_COMPONENT_SUFFIX, "");
+                    SSTableDescriptor sstable = new SSTableDescriptor(baseFilename);
+                    if (!knownSSTables.contains(sstable))
+                    {
+                        newlyProducedSSTables.add(sstable);
+                    }
+                });
+
+                if (!newlyProducedSSTables.isEmpty())
+                {
+                    knownSSTables.addAll(newlyProducedSSTables);
+                    producedSSTablesListener.accept(newlyProducedSSTables);
+                }
+            }
+            catch (IOException e)
+            {
+                LOGGER.warn("Failed to list SSTables", e);
+            }
+        }
+
+        @Override
+        public void close()
+        {
+            sstableWatcherScheduler.shutdown();
+            try
+            {
+                boolean terminated = sstableWatcherScheduler.awaitTermination(10, TimeUnit.SECONDS);
+                if (!terminated)
+                {
+                    LOGGER.debug("SSTableWatcher scheduler termination timed out");
+                }
+            }
+            catch (InterruptedException e)
+            {
+                LOGGER.debug("Interrupted while closing the SSTableWatcher scheduler");
+            }
+            knownSSTables.clear();
+        }
     }
 
     @Override
@@ -74,9 +172,18 @@ public class SSTableWriterImplementation implements SSTableWriter
         }
     }
 
+    @Override
+    public void setSSTablesProducedListener(Consumer<Set<SSTableDescriptor>> listener)
+    {
+        producedSSTablesListener = Objects.requireNonNull(listener);
+    }
+
     @Override
     public void close() throws IOException
     {
+        // Close the sstableWatcher first; there is no need to continue monitoring new sstables,
+        // since StreamSession handles the last set of sstables.
+        // writer.close is guaranteed to create one more sstable
+        sstableWatcher.close();
         writer.close();
     }
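
The SSTableWatcher introduced above is essentially a fixed-delay directory poll keyed on the "-TOC.txt" suffix: each pass lists matching files, diffs them against the already-seen set, and notifies a listener with only the new base names. The standalone sketch below shows the same pattern in isolation; the class and names are hypothetical, not the patch's API.

    import java.io.IOException;
    import java.nio.file.DirectoryStream;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.HashSet;
    import java.util.Set;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;
    import java.util.function.Consumer;

    // Minimal sketch of the TOC-polling idea: scan a directory on a fixed
    // delay and report base names whose "-TOC.txt" appeared since the last scan.
    class TocWatcherSketch
    {
        public static void main(String[] args)
        {
            Path dir = Paths.get(args.length > 0 ? args[0] : ".");
            Set<String> known = new HashSet<>();
            Consumer<Set<String>> listener = produced -> System.out.println("new sstables: " + produced);
            // A real implementation must also shut this scheduler down, as the
            // patch does in SSTableWatcher.close().
            ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
            scheduler.scheduleWithFixedDelay(() -> {
                Set<String> fresh = new HashSet<>();
                try (DirectoryStream<Path> stream = Files.newDirectoryStream(dir, "*-TOC.txt"))
                {
                    for (Path p : stream)
                    {
                        String base = p.getFileName().toString().replace("-TOC.txt", "");
                        if (known.add(base))  // add() returns true only for unseen names
                        {
                            fresh.add(base);
                        }
                    }
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
                if (!fresh.isEmpty())
                {
                    listener.accept(fresh);
                }
            }, 1, 1, TimeUnit.SECONDS);
        }
    }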
 
diff --git a/cassandra-four-zero-bridge/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java b/cassandra-four-zero-bridge/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java
index 589ee4e..8d396f0 100644
--- a/cassandra-four-zero-bridge/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java
+++ b/cassandra-four-zero-bridge/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java
@@ -20,10 +20,15 @@
 package org.apache.cassandra.bridge;
 
 import java.io.File;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
 
+import com.google.common.util.concurrent.Uninterruptibles;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
 
@@ -39,7 +44,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
  */
 class SSTableWriterImplementationTest
 {
-    public static final String CREATE_STATEMENT = "CREATE TABLE test_keyspace.test_table (a int, b text)";
+    public static final String CREATE_STATEMENT = "CREATE TABLE test_keyspace.test_table (a int PRIMARY KEY, b text)";
     public static final String INSERT_STATEMENT = "INSERT INTO test_keyspace.test_table (a, b) VALUES (?, ?)";
 
     @TempDir
@@ -52,7 +57,7 @@ class SSTableWriterImplementationTest
                                                                                         CREATE_STATEMENT,
                                                                                         INSERT_STATEMENT,
                                                                                         250,
-                                                                                        new HashSet<>(),
+                                                                                        Collections.emptySet(),
                                                                                         new Murmur3Partitioner());
 
 
@@ -60,6 +65,41 @@ class SSTableWriterImplementationTest
     }
 
+    @Test
+    void testGetProducedSSTables() throws IOException
+    {
+        Set<SSTableDescriptor> produced = new HashSet<>();
+        try (SSTableWriterImplementation writer = new SSTableWriterImplementation(writeDirectory.getAbsolutePath(),
+                                                                                  "murmur3",
+                                                                                  CREATE_STATEMENT,
+                                                                                  INSERT_STATEMENT,
+                                                                                  Collections.emptySet(),
+                                                                                  1,
+                                                                                  1))
+        {
+            writer.setSSTablesProducedListener(produced::addAll);
+            assertTrue(produced.isEmpty());
+
+            File tocFile1 = new File(writeDirectory, "foo-big-TOC.txt");
+            File tocFile2 = new File(writeDirectory, "bar-big-TOC.txt");
+            assertTrue(tocFile1.createNewFile());
+            assertTrue(tocFile2.createNewFile());
+            waitForProduced(produced);
+            assertEquals(2, produced.size());
+            Set<SSTableDescriptor> expected = new HashSet<>(Arrays.asList(new SSTableDescriptor("foo-big"),
+                                                                          new SSTableDescriptor("bar-big")));
+            assertEquals(expected, produced);
+            produced.clear();
+
+            assertTrue(produced.isEmpty());
+            File tocFile3 = new File(writeDirectory, "baz-big-TOC.txt");
+            assertTrue(tocFile3.createNewFile());
+            waitForProduced(produced);
+            assertEquals(1, produced.size());
+            assertEquals(Collections.singleton(new SSTableDescriptor("baz-big")), produced);
+        }
+    }
+
     static boolean peekSorted(CQLSSTableWriter.Builder builder) throws NoSuchFieldException, IllegalAccessException
     {
         Field sortedField = ReflectionUtils.getField(builder.getClass(), "sorted");
@@ -77,7 +117,7 @@ class SSTableWriterImplementationTest
         return (long) sizeField.get(builder);
     }
 
-    static Field findFirstField(Class<?> clazz, String... fieldNames) throws NoSuchFieldException, IllegalAccessException
+    static Field findFirstField(Class<?> clazz, String... fieldNames) throws NoSuchFieldException
     {
         Field field = null;
         for (String fieldName : fieldNames)
@@ -99,4 +139,13 @@ class SSTableWriterImplementationTest
 
         return field;
     }
+
+    private void waitForProduced(Set<SSTableDescriptor> produced)
+    {
+        int i = 15; // the test runs roughly within 2 seconds; a 3_000 millisecond timeout should suffice.
+        while (produced.isEmpty() && i-- > 0)
+        {
+            Uninterruptibles.sleepUninterruptibly(200, TimeUnit.MILLISECONDS);
+        }
+    }
 }
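
waitForProduced above is a bounded poll: it retries a cheap condition with a fixed sleep until the condition holds or the attempt budget runs out. A generic sketch of the same idiom (it assumes Guava on the classpath, as the test itself does; the helper name is hypothetical):

    import java.util.concurrent.TimeUnit;
    import java.util.function.BooleanSupplier;

    import com.google.common.util.concurrent.Uninterruptibles;

    // Generic bounded-poll helper: returns true as soon as the condition holds,
    // false once the attempt budget is exhausted.
    final class BoundedPoll
    {
        static boolean await(BooleanSupplier condition, int attempts, long sleepMillis)
        {
            while (!condition.getAsBoolean() && attempts-- > 0)
            {
                Uninterruptibles.sleepUninterruptibly(sleepMillis, TimeUnit.MILLISECONDS);
            }
            return condition.getAsBoolean();
        }
    }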

