This is an automated email from the ASF dual-hosted git repository.

xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 493de64bb SAMZA-2742: [Pipeline Drain] Add Drain components and 
integrate them with SamzaContainer and JC (#1605)
493de64bb is described below

commit 493de64bb8c0228bbda11822ae3c84c2085263bf
Author: ajo thomas <[email protected]>
AuthorDate: Mon Jun 13 15:06:05 2022 -0700

    SAMZA-2742: [Pipeline Drain] Add Drain components and integrate them with 
SamzaContainer and JC (#1605)
---
 .../clustermanager/ClusterBasedJobCoordinator.java |  29 ++-
 .../clustermanager/ContainerProcessManager.java    |   4 +
 .../java/org/apache/samza/config/JobConfig.java    |   9 +
 .../java/org/apache/samza/drain/DrainMonitor.java  | 244 +++++++++++++++++++++
 .../org/apache/samza/drain/DrainNotification.java  |  77 +++++++
 .../samza/drain/DrainNotificationObjectMapper.java |  86 ++++++++
 .../java/org/apache/samza/drain/DrainUtils.java    | 129 +++++++++++
 .../apache/samza/processor/StreamProcessor.java    |   3 +-
 .../apache/samza/runtime/ContainerLaunchUtil.java  |   9 +-
 .../apache/samza/container/SamzaContainer.scala    |  33 ++-
 .../apache/samza/job/local/ThreadJobFactory.scala  |  16 +-
 .../org/apache/samza/system/SystemConsumers.scala  |  13 ++
 .../org/apache/samza/drain/DrainMonitorTests.java  | 199 +++++++++++++++++
 .../drain/DrainNotificationObjectMapperTests.java  |  41 ++++
 .../org/apache/samza/drain/DrainUtilsTests.java    | 135 ++++++++++++
 .../samza/container/TestSamzaContainer.scala       |  21 +-
 .../samza/processor/StreamProcessorTestUtils.scala |  18 +-
 17 files changed, 1044 insertions(+), 22 deletions(-)

diff --git 
a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
 
b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
index 286a04087..2d1b810ce 100644
--- 
a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
+++ 
b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
@@ -35,6 +35,7 @@ import org.apache.samza.container.ExecutionContainerIdManager;
 import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.InputStreamsDiscoveredException;
+import org.apache.samza.drain.DrainUtils;
 import org.apache.samza.job.metadata.JobCoordinatorMetadataManager;
 import org.apache.samza.coordinator.JobModelManager;
 import org.apache.samza.coordinator.MetadataResourceUtil;
@@ -181,7 +182,12 @@ public class ClusterBasedJobCoordinator {
   /**
    * Variable to keep the callback exception
    */
-  volatile private Exception coordinatorException = null;
+  volatile private Exception coordinatorCallbackException = null;
+
+  /**
+   * Variable to keep any exception that happened during JC run
+   */
+  volatile private Throwable coordinatorRunException = null;
 
   /**
    * Creates a new ClusterBasedJobCoordinator instance.
@@ -333,6 +339,7 @@ public class ClusterBasedJobCoordinator {
       }
     } catch (Throwable e) {
       LOG.error("Exception thrown in the JobCoordinator loop", e);
+      coordinatorRunException = e;
       throw new SamzaException(e);
     } finally {
       onShutDown();
@@ -340,8 +347,8 @@ public class ClusterBasedJobCoordinator {
   }
 
   private boolean checkAndThrowException() throws Exception {
-    if (coordinatorException != null) {
-      throw coordinatorException;
+    if (coordinatorCallbackException != null) {
+      throw coordinatorCallbackException;
     }
     return false;
   }
@@ -377,6 +384,7 @@ public class ClusterBasedJobCoordinator {
    */
   private void onShutDown() {
     try {
+      cleanupDrainNotifications();
       partitionMonitor.stop();
       inputStreamRegexMonitor.ifPresent(StreamRegexMonitor::stop);
       systemAdmins.stop();
@@ -399,6 +407,17 @@ public class ClusterBasedJobCoordinator {
     }
   }
 
+  private void cleanupDrainNotifications() {
+    if (containerProcessManager.isShutdownSuccessful() && 
coordinatorRunException == null) {
+      // Garbage collect all DrainNotifications from Drain metadata-store of 
the job if the following conditions
+      // are met:
+      // 1) If the job is draining
+      // 2) All containers shutdown successfully due to drain
+      // 3) There was no exception in the coordinator
+      DrainUtils.cleanup(metadataStore, config);
+    }
+  }
+
   private void shutDowncontainerPlacementRequestAllocatorAndUtils() {
     // Shutdown container placement handler
     containerPlacementRequestAllocator.stop();
@@ -422,7 +441,7 @@ public class ClusterBasedJobCoordinator {
             streamsChanged.toString());
         state.status = SamzaApplicationState.SamzaAppStatus.FAILED;
       }
-      coordinatorException = new PartitionChangeException(
+      coordinatorCallbackException = new PartitionChangeException(
           "Input topic partition count changes detected for topics: " + 
streamsChanged.toString());
     });
   }
@@ -436,7 +455,7 @@ public class ClusterBasedJobCoordinator {
               + " Existing input streams: {}", newInputStreams, 
initialInputSet);
           state.status = SamzaApplicationState.SamzaAppStatus.FAILED;
         }
-        coordinatorException = new InputStreamsDiscoveredException("New input 
streams discovered: " + newInputStreams);
+        coordinatorCallbackException = new 
InputStreamsDiscoveredException("New input streams discovered: " + 
newInputStreams);
       });
   }
 
diff --git 
a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
 
b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
index 254e16ec0..f8719890e 100644
--- 
a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
+++ 
b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
@@ -214,6 +214,10 @@ public class ContainerProcessManager implements 
ClusterResourceManager.Callback
     return jobFailureCriteriaMet || state.completedProcessors.get() == 
state.processorCount.get() || !allocatorThread.isAlive();
   }
 
+  public boolean isShutdownSuccessful() {
+    return state.status == SamzaApplicationState.SamzaAppStatus.SUCCEEDED;
+  }
+
   public void start() {
     LOG.info("Starting the container process manager");
 
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java 
b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 4bff52d59..145f4cda4 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -164,6 +164,11 @@ public class JobConfig extends MapConfig {
 
   private static final String JOB_STARTPOINT_ENABLED = 
"job.startpoint.enabled";
 
+  // Enable DrainMonitor in Samza Containers
+  // Default is false for now. Will be turned on after testing
+  public static final String DRAIN_MONITOR_ENABLED = 
"samza.drain-monitor.enabled";
+  public static final boolean DRAIN_MONITOR_ENABLED_DEFAULT = false;
+
   // Enable ClusterBasedJobCoordinator aka ApplicationMaster High Availability 
(AM-HA).
   // High availability allows new AM to establish connection with already 
running containers
   public static final String YARN_AM_HIGH_AVAILABILITY_ENABLED = 
"yarn.am.high-availability.enabled";
@@ -470,6 +475,10 @@ public class JobConfig extends MapConfig {
     return getBoolean(YARN_AM_HIGH_AVAILABILITY_ENABLED, 
YARN_AM_HIGH_AVAILABILITY_ENABLED_DEFAULT);
   }
 
+  public boolean getDrainMonitorEnabled() {
+    return getBoolean(DRAIN_MONITOR_ENABLED, DRAIN_MONITOR_ENABLED_DEFAULT);
+  }
+
   public long getContainerHeartbeatRetryCount() {
     return getLong(YARN_CONTAINER_HEARTBEAT_RETRY_COUNT, 
YARN_CONTAINER_HEARTBEAT_RETRY_COUNT_DEFAULT);
   }
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java 
b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
new file mode 100644
index 000000000..6b5c98ef4
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.io.IOException;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import 
org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DrainMonitor is intended to monitor the MetadataStore for {@link 
DrainNotification} and invoke
+ * the {@link DrainCallback}.
+ * */
+public class DrainMonitor {
+  private static final Logger LOG = 
LoggerFactory.getLogger(DrainMonitor.class);
+
+  /**
+   * Describes the state of the monitor.
+   * */
+  public enum State {
+    /**
+     * Initial state when DrainMonitor is not polling for DrainNotifications.
+     * */
+    INIT,
+    /**
+     * When Drain Monitor is started, it moves from INIT to RUNNING state and 
starts polling
+     * for Drain Notifications.
+     * */
+    RUNNING,
+    /**
+     * Indicates that the Drain Monitor is stopped. The DrainMonitor could 
have been explicitly stopped or stopped on
+     * its own if a DrainNotification was encountered.
+     * */
+    STOPPED
+  }
+
+  private static final int POLLING_INTERVAL_MILLIS = 60_000;
+  private static final int INITIAL_POLL_DELAY_MILLIS = 0;
+
+  private final ScheduledExecutorService schedulerService =
+      Executors.newSingleThreadScheduledExecutor(
+          new ThreadFactoryBuilder()
+              .setNameFormat("Samza DrainMonitor Thread-%d")
+              .setDaemon(true)
+              .build());
+  private final String appRunId;
+  private final long pollingIntervalMillis;
+  private final NamespaceAwareCoordinatorStreamStore drainMetadataStore;
+  // Used to guard write access to state.
+  private final Object lock = new Object();
+
+  @GuardedBy("lock")
+  private State state = State.INIT;
+  private DrainCallback callback;
+
+  public DrainMonitor(MetadataStore metadataStore, Config config) {
+    this(metadataStore, config, POLLING_INTERVAL_MILLIS);
+  }
+
+  public DrainMonitor(MetadataStore metadataStore, Config config, long 
pollingIntervalMillis) {
+    Preconditions.checkNotNull(metadataStore, "MetadataStore parameter cannot 
be null.");
+    Preconditions.checkNotNull(config, "Config parameter cannot be null.");
+    Preconditions.checkArgument(pollingIntervalMillis > 0,
+        String.format("Polling interval specified is %d ms. It should be 
greater than 0.", pollingIntervalMillis));
+    this.drainMetadataStore =
+        new NamespaceAwareCoordinatorStreamStore(metadataStore, 
DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
+    ApplicationConfig applicationConfig = new ApplicationConfig(config);
+    this.appRunId = applicationConfig.getRunId();
+    this.pollingIntervalMillis = pollingIntervalMillis;
+  }
+
+  /**
+   * Starts the DrainMonitor.
+   * */
+  public void start() {
+    Preconditions.checkState(callback != null,
+        "Drain Callback needs to be set using registerCallback(callback) prior 
to starting the DrainManager.");
+    synchronized (lock) {
+      switch (state) {
+        case INIT:
+          if (shouldDrain(drainMetadataStore, appRunId)) {
+            /*
+             * Prior to starting the periodic polling, we are doing a one-time 
check on the calling(container main) thread
+             * to see if DrainNotification is present in the metadata store 
for the current deployment.
+             * This check is to deal with the case where a container might 
have re-started during Drain.
+             * If yes, we will set the container to drain mode to prevent it 
from processing any new messages. This will
+             * in-turn guarantee that intermediate Drain control messages from 
the previous incarnation of the container are
+             * processed and there are no duplicate intermediate control 
messages for the same deployment.
+             * */
+            LOG.info("Found DrainNotification message on container start. 
Skipping poll of DrainNotifications.");
+            callback.onDrain();
+          } else {
+            state = State.RUNNING;
+            schedulerService.scheduleAtFixedRate(() -> {
+              if (shouldDrain(drainMetadataStore, appRunId)) {
+                LOG.info("Received Drain Notification for deployment: {}", 
appRunId);
+                stop();
+                callback.onDrain();
+              }
+            }, INITIAL_POLL_DELAY_MILLIS, pollingIntervalMillis, 
TimeUnit.MILLISECONDS);
+            LOG.info("Started DrainMonitor.");
+          }
+          break;
+        case RUNNING:
+        case STOPPED:
+          LOG.info("Cannot call start() on the DrainMonitor when it is in {} 
state.", state);
+          break;
+      }
+    }
+  }
+
+  /**
+   * Stops the DrainMonitor.
+   * */
+  public void stop() {
+    synchronized (lock) {
+      switch (state) {
+        case RUNNING:
+          schedulerService.shutdownNow();
+          state = State.STOPPED;
+          LOG.info("Stopped DrainMonitor.");
+          break;
+        case INIT:
+        case STOPPED:
+          LOG.info("Cannot stop DrainMonitor as it is not running. State: 
{}.", state);
+          break;
+      }
+    }
+  }
+
+  /**
+   * Register a callback to be executed when DrainNotification is encountered.
+   *
+   * @param callback the callback to register.
+   * @return Returns {@code true} if registration was successful and {@code 
false} if not.
+   * Registration can fail if the DrainMonitor is stopped or a callback is 
already registered.
+   * */
+  public boolean registerDrainCallback(DrainCallback callback) {
+    Preconditions.checkNotNull(callback);
+
+    switch (state) {
+      case RUNNING:
+      case STOPPED:
+        LOG.warn("Cannot register callback when it is in {} state. Please 
register callback before calling start "
+            + "on DrainMonitor.", state);
+        return false;
+      case INIT:
+        if (this.callback != null) {
+          LOG.warn("Cannot register callback as a callback is already 
registered.");
+          return false;
+        }
+        this.callback = callback;
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  /**
+   * Get the current state of the DrainMonitor.
+   * */
+  @VisibleForTesting
+  State getState() {
+    return state;
+  }
+
+  /**
+   * Callback for any action to be executed by DrainMonitor implementations 
once Drain is encountered.
+   * Registered using {@link #registerDrainCallback(DrainCallback)}.
+   * */
+  public interface DrainCallback {
+    void onDrain();
+  }
+
+  /**
+   * One-time check to see if there are any DrainNotification messages 
available in the
+   * metadata store for the current deployment.
+   * */
+  static boolean shouldDrain(NamespaceAwareCoordinatorStreamStore 
drainMetadataStore, String deploymentId) {
+    final Optional<List<DrainNotification>> drainNotifications = 
readDrainNotificationMessages(drainMetadataStore);
+    if (drainNotifications.isPresent()) {
+      final ImmutableList<DrainNotification> filteredDrainNotifications = 
drainNotifications.get()
+          .stream()
+          .filter(notification -> 
deploymentId.equals(notification.getDeploymentId()))
+          .collect(ImmutableList.toImmutableList());
+      return !filteredDrainNotifications.isEmpty();
+    }
+    return false;
+  }
+
+  /**
+   * Reads all DrainNotification messages from the metadata store.
+   * */
+  private static Optional<List<DrainNotification>> 
readDrainNotificationMessages(NamespaceAwareCoordinatorStreamStore
+      drainMetadataStore) {
+    final ObjectMapper objectMapper = 
DrainNotificationObjectMapper.getObjectMapper();
+    final ImmutableList<DrainNotification> drainNotifications = 
drainMetadataStore.all()
+        .values()
+        .stream()
+        .map(bytes -> {
+          try {
+            return objectMapper.readValue(bytes, DrainNotification.class);
+          } catch (IOException e) {
+            LOG.error("Unable to deserialize DrainNotification from the 
metadata store", e);
+            throw new SamzaException(e);
+          }
+        })
+        .collect(ImmutableList.toImmutableList());
+    return drainNotifications.size() > 0
+        ? Optional.of(drainNotifications)
+        : Optional.empty();
+  }
+}
diff --git 
a/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java 
b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
new file mode 100644
index 000000000..a16595e7d
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.google.common.base.Objects;
+import java.util.UUID;
+
+/**
+ * DrainNotification is a custom message used by an external controller to 
trigger Drain.
+ * The message is written in the metadata store using {@link DrainUtils}.
+ * */
+public class DrainNotification {
+  /**
+   * Unique identifier of a drain notification.
+   */
+  private final UUID uuid;
+  /**
+   * Unique identifier for a deployment so drain notification messages can be 
invalidated across job restarts.
+   */
+  private final String deploymentId;
+
+  public DrainNotification(UUID uuid, String deploymentId) {
+    this.uuid = uuid;
+    this.deploymentId = deploymentId;
+  }
+
+  public UUID getUuid() {
+    return this.uuid;
+  }
+
+  public String getDeploymentId() {
+    return deploymentId;
+  }
+
+  @Override
+  public String toString() {
+    final StringBuilder sb = new StringBuilder("DrainMessage{");
+    sb.append(" UUID: ").append(uuid);
+    sb.append(", deploymentId: '").append(deploymentId).append('\'');
+    sb.append('}');
+    return sb.toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    DrainNotification that = (DrainNotification) o;
+    return Objects.equal(uuid, that.uuid)
+        && Objects.equal(deploymentId, that.deploymentId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(uuid, deploymentId);
+  }
+}
diff --git 
a/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
 
b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
new file mode 100644
index 000000000..48e6db15a
--- /dev/null
+++ 
b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.core.ObjectCodec;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.JsonSerializer;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializerProvider;
+import com.fasterxml.jackson.databind.jsontype.NamedType;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+
+/**
+ * Wraps a ObjectMapper for serializing and deserializing Drain Notification 
Messages.
+ */
+public class DrainNotificationObjectMapper {
+  private static ObjectMapper objectMapper = null;
+
+  private DrainNotificationObjectMapper() {
+  }
+
+  public static ObjectMapper getObjectMapper() {
+    if (objectMapper == null) {
+      objectMapper = createObjectMapper();
+    }
+    return objectMapper;
+  }
+
+  private static ObjectMapper createObjectMapper() {
+    ObjectMapper objectMapper = new ObjectMapper();
+    SimpleModule module = new SimpleModule("DrainModule");
+    module.addSerializer(DrainNotification.class, new 
DrainNotificationSerializer());
+    module.addDeserializer(DrainNotification.class, new 
DrainNotificationDeserializer());
+    objectMapper.registerModule(module);
+    objectMapper.registerSubtypes(new NamedType(DrainNotification.class));
+    return objectMapper;
+  }
+
+  private static class DrainNotificationSerializer extends 
JsonSerializer<DrainNotification> {
+    @Override
+    public void serialize(DrainNotification value, JsonGenerator 
jsonGenerator, SerializerProvider provider)
+        throws IOException {
+      Map<String, Object> drainMessageMap = new HashMap<>();
+      drainMessageMap.put("uuid", value.getUuid().toString());
+      drainMessageMap.put("deploymentId", value.getDeploymentId());
+      jsonGenerator.writeObject(drainMessageMap);
+    }
+  }
+
+  private static class DrainNotificationDeserializer extends 
JsonDeserializer<DrainNotification> {
+    @Override
+    public DrainNotification deserialize(JsonParser jsonParser, 
DeserializationContext context)
+        throws IOException {
+      ObjectCodec oc = jsonParser.getCodec();
+      JsonNode node = oc.readTree(jsonParser);
+      UUID uuid = UUID.fromString(node.get("uuid").textValue());
+      String deploymentId = node.get("deploymentId").textValue();
+      return new DrainNotification(uuid, deploymentId);
+    }
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java 
b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
new file mode 100644
index 000000000..c100a47e1
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.UUID;
+import joptsimple.internal.Strings;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import 
org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DrainUtils provides utility methods for managing {@link DrainNotification} 
in the provided {@link MetadataStore}.
+ * */
+public class DrainUtils {
+  private static final Logger LOG = LoggerFactory.getLogger(DrainUtils.class);
+  private static final Integer VERSION = 1;
+  // namespace for the underlying metadata store
+  public static final String DRAIN_METADATA_STORE_NAMESPACE = "samza-drain-v" 
+ VERSION;
+
+  private DrainUtils() {
+  }
+
+  /**
+   * Writes a {@link DrainNotification} to the underlying metastore. This 
method should be used by external controllers
+   * to issue a DrainNotification to the JobCoordinator and Samza Containers.
+   * @param metadataStore Metadata store to write drain notification to.
+   * @param deploymentId deploymentId for the DrainNotification
+   *
+   * @return generated uuid for the DrainNotification
+   */
+  public static UUID writeDrainNotification(MetadataStore metadataStore, 
String deploymentId) {
+    Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot 
be null.");
+    Preconditions.checkArgument(!Strings.isNullOrEmpty(deploymentId), 
"deploymentId should be non-null.");
+    LOG.info("Attempting to write DrainNotification to metadata-store for the 
deployment ID {}", deploymentId);
+    final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
+        new NamespaceAwareCoordinatorStreamStore(metadataStore, 
DRAIN_METADATA_STORE_NAMESPACE);
+    final ObjectMapper objectMapper = 
DrainNotificationObjectMapper.getObjectMapper();
+    final UUID uuid = UUID.randomUUID();
+    final DrainNotification message = new DrainNotification(uuid, 
deploymentId);
+    try {
+      drainMetadataStore.put(message.getUuid().toString(), 
objectMapper.writeValueAsBytes(message));
+      drainMetadataStore.flush();
+      LOG.info("DrainNotification with id {} written to metadata-store for the 
deployment ID {}", uuid, deploymentId);
+    } catch (Exception ex) {
+      throw new SamzaException(
+          String.format("DrainNotification might have been not written to 
metastore %s", message), ex);
+    }
+    return uuid;
+  }
+
+  /**
+   * Cleans up DrainNotifications for the current deployment from the 
underlying metadata store.
+   * The current deploymentId is extracted from the config.
+   *
+   * @param metadataStore underlying metadata store
+   * @param config Config for the job. Used to extract the deploymentId of the 
job.
+   * */
+  public static void cleanup(MetadataStore metadataStore, Config config) {
+    Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot 
be null.");
+    Preconditions.checkNotNull(config, "Config parameter cannot be null.");
+
+    final ApplicationConfig applicationConfig = new ApplicationConfig(config);
+    final String deploymentId = applicationConfig.getRunId();
+    final ObjectMapper objectMapper = 
DrainNotificationObjectMapper.getObjectMapper();
+    final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
+        new NamespaceAwareCoordinatorStreamStore(metadataStore, 
DRAIN_METADATA_STORE_NAMESPACE);
+
+    if (DrainMonitor.shouldDrain(drainMetadataStore, deploymentId)) {
+      LOG.info("Attempting to clean up DrainNotifications from the 
metadata-store for the current deployment {}", deploymentId);
+      drainMetadataStore.all()
+          .values()
+          .stream()
+          .map(bytes -> {
+            try {
+              return objectMapper.readValue(bytes, DrainNotification.class);
+            } catch (IOException e) {
+              LOG.error("Unable to deserialize DrainNotification from the 
metadata store", e);
+              throw new SamzaException(e);
+            }
+          })
+          .filter(notification -> 
deploymentId.equals(notification.getDeploymentId()))
+          .forEach(notification -> 
drainMetadataStore.delete(notification.getUuid().toString()));
+
+      drainMetadataStore.flush();
+      LOG.info("Successfully cleaned up DrainNotifications from the 
metadata-store for the current deployment {}", deploymentId);
+    } else {
+      LOG.info("No DrainNotification found in the metadata-store for the 
current deployment {}. No need to cleanup.",
+          deploymentId);
+    }
+  }
+
+  /**
+   * Cleans up all DrainNotifications irrespective of the deploymentId.
+   * */
+  public static void cleanupAll(MetadataStore metadataStore) {
+    Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot 
be null.");
+    final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
+        new NamespaceAwareCoordinatorStreamStore(metadataStore, 
DRAIN_METADATA_STORE_NAMESPACE);
+    LOG.info("Attempting to cleanup all DrainNotifications from the 
metadata-store.");
+    drainMetadataStore.all()
+        .keySet()
+        .forEach(drainMetadataStore::delete);
+    drainMetadataStore.flush();
+    LOG.info("Successfully cleaned up all DrainNotifications from the 
metadata-store.");
+  }
+}
diff --git 
a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java 
b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
index 2994606f6..654d9a0b3 100644
--- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
+++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
@@ -395,7 +395,6 @@ public class StreamProcessor {
     } else {
       LOGGER.warn("StartpointManager cannot be instantiated because no 
metadata store defined for this stream processor");
     }
-
     /*
      * StreamProcessor has a metricsRegistry instance variable, but 
StreamProcessor registers its metrics on its own
      * with the reporters. Therefore, don't reuse the 
StreamProcessor.metricsRegistry, because SamzaContainer also
@@ -408,7 +407,7 @@ public class StreamProcessor {
         
Option.apply(this.applicationDefinedContainerContextFactoryOptional.orElse(null)),
         
Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)),
         Option.apply(this.externalContextOptional.orElse(null)), null, 
startpointManager,
-        Option.apply(diagnosticsManager.orElse(null)));
+        Option.apply(diagnosticsManager.orElse(null)), null);
   }
 
   private static JobCoordinator createJobCoordinator(Config config, String 
processorId, MetricsRegistry metricsRegistry, MetadataStore metadataStore) {
diff --git 
a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java 
b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
index 9cc812192..f499eb34b 100644
--- a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
@@ -42,6 +42,7 @@ import org.apache.samza.coordinator.stream.messages.SetConfig;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import 
org.apache.samza.coordinator.stream.messages.SetExecutionEnvContainerIdMapping;
 import org.apache.samza.diagnostics.DiagnosticsManager;
+import org.apache.samza.drain.DrainMonitor;
 import org.apache.samza.environment.EnvironmentVariables;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.logging.LoggingContextHolder;
@@ -141,6 +142,11 @@ public class ContainerLaunchUtil {
               samzaEpochId, config);
       MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
 
+      DrainMonitor drainMonitor = null;
+      if (new JobConfig(config).getDrainMonitorEnabled()) {
+        drainMonitor = new DrainMonitor(coordinatorStreamStore, config);
+      }
+
       SamzaContainer container = SamzaContainer$.MODULE$.apply(
           containerId, jobModel,
           ScalaJavaUtil.toScalaMap(metricsReporters),
@@ -152,7 +158,8 @@ public class ContainerLaunchUtil {
           Option.apply(externalContextOptional.orElse(null)),
           localityManager,
           startpointManager,
-          Option.apply(diagnosticsManager.orElse(null)));
+          Option.apply(diagnosticsManager.orElse(null)),
+          drainMonitor);
 
       ProcessorLifecycleListener processorLifecycleListener = 
appDesc.getProcessorLifecycleListenerFactory()
           .createInstance(new ProcessorContext() { }, config);
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala 
b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index da364f1b6..e6c188db5 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -38,6 +38,8 @@ import 
org.apache.samza.container.disk.{DiskQuotaPolicyFactory, DiskSpaceMonitor
 import org.apache.samza.container.host.{StatisticsMonitorImpl, 
SystemMemoryStatistics, SystemStatisticsMonitor}
 import org.apache.samza.context._
 import org.apache.samza.diagnostics.DiagnosticsManager
+import org.apache.samza.drain.DrainMonitor.DrainCallback
+import org.apache.samza.drain.DrainMonitor
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskMode}
 import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, 
MetricsReporter}
 import org.apache.samza.serializers._
@@ -133,7 +135,8 @@ object SamzaContainer extends Logging {
     externalContextOption: Option[ExternalContext],
     localityManager: LocalityManager = null,
     startpointManager: StartpointManager = null,
-    diagnosticsManager: Option[DiagnosticsManager] = Option.empty) = {
+    diagnosticsManager: Option[DiagnosticsManager] = Option.empty,
+    drainMonitor: DrainMonitor = null) = {
     val config = if (StandbyTaskUtil.isStandbyContainer(containerId)) {
       // standby containers will need to continually poll checkpoint messages
       val newConfig = new util.HashMap[String, String](jobContext.getConfig)
@@ -544,8 +547,6 @@ object SamzaContainer extends Logging {
 
     storeWatchPaths.addAll(containerStorageManager.getStoreDirectoryPaths)
 
-
-
     // Create taskInstances
     val taskInstances: Map[TaskName, TaskInstance] = taskModels
       .filter(taskModel => 
taskModel.getTaskMode.eq(TaskMode.Active)).map(taskModel => {
@@ -671,6 +672,7 @@ object SamzaContainer extends Logging {
     } else {
       info(s"Disk quotas disabled because polling interval is not set 
($DISK_POLL_INTERVAL_KEY)")
     }
+
     info("Samza container setup complete.")
 
     new SamzaContainer(
@@ -696,6 +698,7 @@ object SamzaContainer extends Logging {
       applicationContainerContextOption = applicationContainerContextOption,
       externalContextOption = externalContextOption,
       containerStorageManager = containerStorageManager,
+      drainMonitor = drainMonitor,
       diagnosticsManager = diagnosticsManager)
   }
 }
@@ -723,6 +726,7 @@ class SamzaContainer(
   applicationContainerContextOption: Option[ApplicationContainerContext],
   externalContextOption: Option[ExternalContext],
   containerStorageManager: ContainerStorageManager,
+  drainMonitor: DrainMonitor = null,
   diagnosticsManager: Option[DiagnosticsManager] = Option.empty) extends 
Runnable with Logging {
 
   private val jobConfig = new JobConfig(config)
@@ -742,6 +746,10 @@ class SamzaContainer(
 
   def getStatus(): SamzaContainerStatus = status
 
+  def drain() {
+    consumerMultiplexer.drain
+  }
+
   def getTaskInstances() = taskInstances
 
   def setContainerListener(listener: SamzaContainerListener): Unit = {
@@ -768,6 +776,7 @@ class SamzaContainer(
       startMetrics
       startDiagnostics
       startAdmins
+      startDrainMonitor
       startOffsetManager
       storeContainerLocality
       // TODO HIGH pmaheshw SAMZA-2338: since store restore needs to trim 
changelog messages,
@@ -831,6 +840,7 @@ class SamzaContainer(
 
       shutdownConsumers
       shutdownTask
+      shutdownDrainMonitor
       shutdownTableManager
       shutdownStores
       shutdownDiskSpaceMonitor
@@ -1026,6 +1036,16 @@ class SamzaContainer(
     }
   }
 
+  def startDrainMonitor: Unit = {
+    if (drainMonitor != null) {
+      drainMonitor.registerDrainCallback(new DrainCallback {
+        override def onDrain(): Unit = drain()
+      })
+      info("Starting DrainMonitor.")
+      drainMonitor.start()
+    }
+  }
+
   def shutdownConsumers {
     info("Shutting down consumer multiplexer.")
 
@@ -1143,4 +1163,11 @@ class SamzaContainer(
       hostStatisticsMonitor.stop()
     }
   }
+
+  def shutdownDrainMonitor: Unit = {
+    if (drainMonitor != null) {
+      info("Shutting down DrainMonitor.")
+      drainMonitor.stop();
+    }
+  }
 }
diff --git 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index af0f5be14..f1c476a4b 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -29,6 +29,7 @@ import org.apache.samza.context.{ExternalContext, 
JobContextImpl}
 import org.apache.samza.coordinator.metadatastore.{CoordinatorStreamStore, 
NamespaceAwareCoordinatorStreamStore}
 import org.apache.samza.coordinator.stream.messages.SetChangelogMapping
 import org.apache.samza.coordinator.{JobModelManager, MetadataResourceUtil}
+import org.apache.samza.drain.DrainMonitor
 import org.apache.samza.execution.RemoteJobPlanner
 import org.apache.samza.job.model.JobModelUtil
 import org.apache.samza.job.{StreamJob, StreamJobFactory}
@@ -74,7 +75,9 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     val metadataResourceUtil = new MetadataResourceUtil(jobModel, 
metricsRegistry, config)
     metadataResourceUtil.createResources()
 
-    if (new JobConfig(config).getStartpointEnabled()) {
+    val jobConfig = new JobConfig(config)
+
+    if (jobConfig.getStartpointEnabled()) {
       // fan out the startpoints
       val startpointManager = new StartpointManager(coordinatorStreamStore)
       startpointManager.start()
@@ -84,10 +87,14 @@ class ThreadJobFactory extends StreamJobFactory with 
Logging {
         startpointManager.stop()
       }
     }
+    var drainMonitor: DrainMonitor = null
+    if (jobConfig.getDrainMonitorEnabled()) {
+      drainMonitor = new DrainMonitor(coordinatorStreamStore, config)
+    }
 
     val containerId = "0"
     var jmxServer: JmxServer = null
-    if (new JobConfig(config).getJMXEnabled) {
+    if (jobConfig.getJMXEnabled) {
       jmxServer = new JmxServer()
     }
 
@@ -136,8 +143,8 @@ class ThreadJobFactory extends StreamJobFactory with 
Logging {
         JobContextImpl.fromConfigWithDefaults(config, jobModel),
         Option(appDesc.getApplicationContainerContextFactory.orElse(null)),
         Option(appDesc.getApplicationTaskContextFactory.orElse(null)),
-        buildExternalContext(config)
-      )
+        buildExternalContext(config),
+        drainMonitor = drainMonitor)
       container.setContainerListener(containerListener)
 
       val threadJob = new ThreadJob(container)
@@ -147,7 +154,6 @@ class ThreadJobFactory extends StreamJobFactory with 
Logging {
       if (jmxServer != null) {
         jmxServer.stop
       }
-      coordinatorStreamStore.close()
     }
   }
 
diff --git 
a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala 
b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index 4a24303cb..bcecebb7f 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -155,6 +155,11 @@ class SystemConsumers (
     */
   private var started = false
 
+  /**
+   * Denotes if the SystemConsumers is in drain mode.
+   * */
+  private var draining = false
+
   /**
    * Default timeout to noNewMessagesTimeout. Every time SystemConsumers
    * receives incoming messages, it sets timeout to 0. Every time
@@ -214,6 +219,10 @@ class SystemConsumers (
     refresh
   }
 
+  def drain: Unit = {
+    draining = true
+  }
+
   def stop {
     if (started) {
       debug("Stopping consumers.")
@@ -389,6 +398,10 @@ class SystemConsumers (
   }
 
   private def refresh {
+    if (draining) {
+      trace("Skipping refresh of chooser as the multiplexer is in drain mode.")
+      return
+    }
     trace("Refreshing chooser with new messages.")
 
     // Update last poll time so we don't poll too frequently.
diff --git 
a/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java 
b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
new file mode 100644
index 000000000..e7666aa49
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStore;
+import 
org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil;
+import 
org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.mockito.Mockito;
+
+/**
+ * Tests for {@link DrainMonitor}
+ * */
+public class DrainMonitorTests {
+  private static final String TEST_DEPLOYMENT_ID = "foo";
+
+  private static final Config
+      CONFIG = new MapConfig(ImmutableMap.of(
+          "job.name", "test-job",
+      "job.coordinator.system", "test-kafka",
+      ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+
+  private CoordinatorStreamStore coordinatorStreamStore;
+
+  @Before
+  public void setup() {
+    CoordinatorStreamStoreTestUtil coordinatorStreamStoreTestUtil = new 
CoordinatorStreamStoreTestUtil(CONFIG);
+    coordinatorStreamStore = 
coordinatorStreamStoreTestUtil.getCoordinatorStreamStore();
+    coordinatorStreamStore.init();
+  }
+
+  @After
+  public void teardown() {
+    DrainUtils.cleanupAll(coordinatorStreamStore);
+    coordinatorStreamStore.close();
+  }
+
+  @Rule
+  public ExpectedException exceptionRule = ExpectedException.none();
+
+  @Test()
+  public void testConstructorFailureWhenDrainManagerIsNull() {
+    exceptionRule.expect(NullPointerException.class);
+    exceptionRule.expectMessage("MetadataStore parameter cannot be null.");
+    DrainMonitor unusedMonitor = new DrainMonitor(null, null, 100L);
+  }
+
+  @Test()
+  public void testConstructorFailureWhenConfigIsNull() {
+    exceptionRule.expect(NullPointerException.class);
+    exceptionRule.expectMessage("Config parameter cannot be null.");
+    DrainMonitor unusedMonitor = new 
DrainMonitor(Mockito.mock(MetadataStore.class), null, 100L);
+  }
+
+  @Test()
+  public void testConstructorFailureWithInvalidPollingInterval() {
+    exceptionRule.expect(IllegalArgumentException.class);
+    exceptionRule.expectMessage("Polling interval specified is 0 ms. It should 
be greater than 0.");
+    DrainMonitor unusedMonitor = new 
DrainMonitor(Mockito.mock(MetadataStore.class), Mockito.mock(Config.class), 0);
+  }
+
+  @Test()
+  public void testDrainMonitorStartFailureWhenCallbackIsNotSet() {
+    exceptionRule.expect(IllegalStateException.class);
+    exceptionRule.expectMessage("Drain Callback needs to be set using 
registerCallback(callback) prior to "
+        + "starting the DrainManager.");
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG, 100L);
+    drainMonitor.start();
+  }
+
+  @Test
+  public void testSuccessfulCallbackRegistration() {
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG, 100L);
+    DrainMonitor.DrainCallback emptyCallback = () -> { };
+    boolean callbackRegistrationResult1 = 
drainMonitor.registerDrainCallback(emptyCallback);
+    // first registration of callback should succeed
+    Assert.assertTrue(callbackRegistrationResult1);
+    boolean callbackRegistrationResult2 = 
drainMonitor.registerDrainCallback(emptyCallback);
+    // repeat registration of callback should fail
+    Assert.assertFalse(callbackRegistrationResult2);
+  }
+
+  @Test
+  public void testCallbackCalledIfMonitorEncountersDrainOnStart() throws 
InterruptedException {
+    final AtomicInteger numCallbacks = new AtomicInteger(0);
+    final CountDownLatch latch = new CountDownLatch(1);
+    // write drain before monitor start
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG);
+    drainMonitor.registerDrainCallback(() -> {
+      numCallbacks.incrementAndGet();
+      latch.countDown();
+    });
+    // monitor shouldn't go into RUNNING state as DrainNotification was 
already present and it shouldn't start poll
+    drainMonitor.start();
+    if (!latch.await(2, TimeUnit.SECONDS)) {
+      Assert.fail("Timed out waiting for drain callback to complete");
+    }
+    Assert.assertEquals(1, numCallbacks.get());
+    Assert.assertEquals(DrainMonitor.State.INIT, drainMonitor.getState());
+  }
+
+  @Test
+  public void testCallbackCalledOnDrain() throws InterruptedException {
+    final AtomicInteger numCallbacks = new AtomicInteger(0);
+    final CountDownLatch latch = new CountDownLatch(1);
+
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG, 100L);
+
+    drainMonitor.registerDrainCallback(() -> {
+      numCallbacks.incrementAndGet();
+      latch.countDown();
+    });
+    drainMonitor.start();
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    if (!latch.await(2, TimeUnit.SECONDS)) {
+      Assert.fail("Timed out waiting for drain callback to complete");
+    }
+    Assert.assertEquals(DrainMonitor.State.STOPPED, drainMonitor.getState());
+    Assert.assertEquals(1, numCallbacks.get());
+  }
+
+  @Test
+  public void testCallbackNotCalledDueToMismatchedDeploymentId() throws 
InterruptedException {
+    // The test fails due to timeout as the published DrainNotification's 
deploymentId doesn't match deploymentId
+    // in the config
+    exceptionRule.expect(AssertionError.class);
+    exceptionRule.expectMessage("Timed out waiting for drain callback to 
complete.");
+    final AtomicInteger numCallbacks = new AtomicInteger(0);
+    final CountDownLatch latch = new CountDownLatch(1);
+
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG, 100L);
+
+    drainMonitor.registerDrainCallback(() -> {
+      numCallbacks.incrementAndGet();
+      latch.countDown();
+    });
+
+    drainMonitor.start();
+    final String mismatchedDeploymentId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
mismatchedDeploymentId);
+    if (!latch.await(2, TimeUnit.SECONDS)) {
+      Assert.fail("Timed out waiting for drain callback to complete.");
+    }
+  }
+
+  @Test
+  public void testDrainMonitorStop() {
+    DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, 
CONFIG, 100L);
+    drainMonitor.registerDrainCallback(() -> { });
+    drainMonitor.start();
+    drainMonitor.stop();
+    Assert.assertEquals(drainMonitor.getState(), DrainMonitor.State.STOPPED);
+  }
+
+  @Test
+  public void testShouldDrain() {
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    NamespaceAwareCoordinatorStreamStore drainStore =
+        new NamespaceAwareCoordinatorStreamStore(coordinatorStreamStore, 
DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
+    Assert.assertTrue(DrainMonitor.shouldDrain(drainStore, 
TEST_DEPLOYMENT_ID));
+
+    // Cleanup old drain message
+    DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
+
+    final String mismatchedDeploymentId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
mismatchedDeploymentId);
+    Assert.assertFalse(DrainMonitor.shouldDrain(drainStore, 
TEST_DEPLOYMENT_ID));
+  }
+}
diff --git 
a/samza-core/src/test/java/org/apache/samza/drain/DrainNotificationObjectMapperTests.java
 
b/samza-core/src/test/java/org/apache/samza/drain/DrainNotificationObjectMapperTests.java
new file mode 100644
index 000000000..05c6b7da2
--- /dev/null
+++ 
b/samza-core/src/test/java/org/apache/samza/drain/DrainNotificationObjectMapperTests.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.IOException;
+import java.util.UUID;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link DrainNotificationObjectMapper}
+ * */
+public class DrainNotificationObjectMapperTests {
+  @Test
+  public void testDrainNotificationSerde() throws IOException {
+    UUID uuid = UUID.randomUUID();
+    DrainNotification originalMessage = new DrainNotification(uuid, "foobar");
+    ObjectMapper objectMapper = 
DrainNotificationObjectMapper.getObjectMapper();
+    byte[] bytes = objectMapper.writeValueAsBytes(originalMessage);
+    DrainNotification deserializedMessage = objectMapper.readValue(bytes, 
DrainNotification.class);
+    assertEquals(originalMessage, deserializedMessage);
+  }
+}
diff --git 
a/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java 
b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
new file mode 100644
index 000000000..17265352d
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.drain;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.UUID;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStore;
+import 
org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil;
+import 
org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+
+/**
+ * Tests for {@link DrainUtils}
+ * */
+public class DrainUtilsTests {
+  private static final String TEST_DEPLOYMENT_ID = "foo";
+  private static final Config CONFIG = new MapConfig(ImmutableMap.of(
+      "job.name", "test-job",
+      "job.coordinator.system", "test-kafka",
+      ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+
+  private CoordinatorStreamStore coordinatorStreamStore;
+
+  @Before
+  public void setup() {
+    CoordinatorStreamStoreTestUtil coordinatorStreamStoreTestUtil = new 
CoordinatorStreamStoreTestUtil(CONFIG);
+    coordinatorStreamStore = 
coordinatorStreamStoreTestUtil.getCoordinatorStreamStore();
+    coordinatorStreamStore.init();
+  }
+
+  @After
+  public void teardown() {
+    DrainUtils.cleanupAll(coordinatorStreamStore);
+    coordinatorStreamStore.close();
+  }
+
+  @Test
+  public void testWrites() {
+    String deploymentId1 = "foo1";
+    String deploymentId2 = "foo2";
+    String deploymentId3 = "foo3";
+
+    UUID uuid1 = DrainUtils.writeDrainNotification(coordinatorStreamStore, 
deploymentId1);
+    UUID uuid2 = DrainUtils.writeDrainNotification(coordinatorStreamStore, 
deploymentId2);
+    UUID uuid3 = DrainUtils.writeDrainNotification(coordinatorStreamStore, 
deploymentId3);
+
+    DrainNotification expectedDrainNotification1 = new 
DrainNotification(uuid1, deploymentId1);
+    DrainNotification expectedDrainNotification2 = new 
DrainNotification(uuid2, deploymentId2);
+    DrainNotification expectedDrainNotification3 = new 
DrainNotification(uuid3, deploymentId3);
+    Set<DrainNotification> expectedDrainNotifications = new 
HashSet<>(Arrays.asList(expectedDrainNotification1,
+        expectedDrainNotification2, expectedDrainNotification3));
+
+    Optional<List<DrainNotification>> drainNotifications = 
readDrainNotificationMessages(coordinatorStreamStore);
+    Assert.assertTrue(drainNotifications.isPresent());
+    Assert.assertEquals(3, drainNotifications.get().size());
+    Assert.assertEquals(expectedDrainNotifications, new 
HashSet<>(drainNotifications.get()));
+  }
+
+  @Test
+  public void testCleanup() {
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
+    final Optional<List<DrainNotification>> drainNotifications1 = 
readDrainNotificationMessages(coordinatorStreamStore);
+    Assert.assertFalse(drainNotifications1.isPresent());
+
+    final String deploymentId = "bar";
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId);
+    DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
+    final Optional<List<DrainNotification>> drainNotifications2 = 
readDrainNotificationMessages(coordinatorStreamStore);
+    Assert.assertTrue(drainNotifications2.isPresent());
+    Assert.assertEquals(deploymentId, 
drainNotifications2.get().get(0).getDeploymentId());
+  }
+
+  @Test
+  public void testCleanupAll() {
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, 
TEST_DEPLOYMENT_ID);
+    DrainUtils.writeDrainNotification(coordinatorStreamStore, "bar");
+    DrainUtils.cleanupAll(coordinatorStreamStore);
+    final Optional<List<DrainNotification>> drainNotifications = 
readDrainNotificationMessages(coordinatorStreamStore);
+    Assert.assertFalse(drainNotifications.isPresent());
+  }
+
+  private static Optional<List<DrainNotification>> 
readDrainNotificationMessages(CoordinatorStreamStore metadataStore) {
+    final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
+        new NamespaceAwareCoordinatorStreamStore(metadataStore, 
DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
+    final ObjectMapper objectMapper = 
DrainNotificationObjectMapper.getObjectMapper();
+    final ImmutableList<DrainNotification> drainNotifications = 
drainMetadataStore.all()
+        .values()
+        .stream()
+        .map(bytes -> {
+          try {
+            return objectMapper.readValue(bytes, DrainNotification.class);
+          } catch (IOException e) {
+            throw new SamzaException(e);
+          }
+        })
+        .collect(ImmutableList.toImmutableList());
+    return drainNotifications.size() > 0
+        ? Optional.of(drainNotifications)
+        : Optional.empty();
+  }
+}
diff --git 
a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala 
b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 5154069b3..d9bb916bf 100644
--- 
a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ 
b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -19,15 +19,18 @@
 
 package org.apache.samza.container
 
+import com.google.common.collect.ImmutableMap
+
 import java.util
 import java.util.concurrent.atomic.AtomicReference
-
 import javax.servlet.http.{HttpServlet, HttpServletRequest, 
HttpServletResponse}
 import org.apache.samza.Partition
 import org.apache.samza.config.{ClusterManagerConfig, Config, MapConfig}
 import org.apache.samza.context.{ApplicationContainerContext, ContainerContext}
 import org.apache.samza.coordinator.JobModelManager
+import 
org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
+import org.apache.samza.drain.DrainMonitor
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskModel}
 import org.apache.samza.metrics.Gauge
 import org.apache.samza.serializers.model.SamzaObjectMapper
@@ -74,6 +77,8 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
   @Mock
   private var containerStorageManager: ContainerStorageManager = null
 
+  private var drainMonitor: DrainMonitor = null
+
   private var samzaContainer: SamzaContainer = null
 
   @Before
@@ -151,7 +156,8 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
       containerContext = this.containerContext,
       applicationContainerContextOption = 
Some(this.applicationContainerContext),
       externalContextOption = None,
-      containerStorageManager = containerStorageManager)
+      containerStorageManager = containerStorageManager,
+      drainMonitor = drainMonitor)
     this.samzaContainer.setContainerListener(this.samzaContainerListener)
 
     new ShutDownSignal(samzaContainer).run();
@@ -320,6 +326,14 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
   }
 
   private def setupSamzaContainer(applicationContainerContext: 
Option[ApplicationContainerContext]) {
+    val coordinatorStreamConfig = new MapConfig(ImmutableMap.of(
+      "job.name", "test-job",
+      "job.coordinator.system", "test-kafka"))
+    val coordinatorStreamStoreTestUtil = new 
CoordinatorStreamStoreTestUtil(coordinatorStreamConfig)
+    val coordinatorStreamStore = 
coordinatorStreamStoreTestUtil.getCoordinatorStreamStore
+    coordinatorStreamStore.init()
+    drainMonitor = new DrainMonitor(coordinatorStreamStore, config)
+
     this.samzaContainer = new SamzaContainer(
       this.config,
       Map(TASK_NAME -> this.taskInstance),
@@ -333,7 +347,8 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
       containerContext = this.containerContext,
       applicationContainerContextOption = applicationContainerContext,
       externalContextOption = None,
-      containerStorageManager = containerStorageManager)
+      containerStorageManager = containerStorageManager,
+      drainMonitor = drainMonitor)
     this.samzaContainer.setContainerListener(this.samzaContainerListener)
   }
 
diff --git 
a/samza-core/src/test/scala/org/apache/samza/processor/StreamProcessorTestUtils.scala
 
b/samza-core/src/test/scala/org/apache/samza/processor/StreamProcessorTestUtils.scala
index f7d65d1d6..c8955124a 100644
--- 
a/samza-core/src/test/scala/org/apache/samza/processor/StreamProcessorTestUtils.scala
+++ 
b/samza-core/src/test/scala/org/apache/samza/processor/StreamProcessorTestUtils.scala
@@ -18,13 +18,16 @@
  */
 package org.apache.samza.processor
 
+import com.google.common.collect.ImmutableMap
+
 import java.util
 import java.util.Collections
-
 import org.apache.samza.Partition
 import org.apache.samza.config.MapConfig
 import org.apache.samza.container._
 import org.apache.samza.context.{ContainerContext, JobContext}
+import 
org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil
+import org.apache.samza.drain.DrainMonitor
 import org.apache.samza.job.model.TaskModel
 import org.apache.samza.serializers.SerdeManager
 import org.apache.samza.storage.ContainerStorageManager
@@ -65,6 +68,14 @@ object StreamProcessorTestUtils {
       applicationTaskContextFactoryOption = None,
       externalContextOption = None)
 
+    val coordinatorStreamConfig = new MapConfig(ImmutableMap.of(
+      "job.name", "test-job",
+      "job.coordinator.system", "test-kafka"))
+    val coordinatorStreamStoreTestUtil = new 
CoordinatorStreamStoreTestUtil(coordinatorStreamConfig)
+    val coordinatorStreamStore = 
coordinatorStreamStoreTestUtil.getCoordinatorStreamStore
+    coordinatorStreamStore.init()
+    val drainMonitor = new DrainMonitor(coordinatorStreamStore, config)
+
     val container = new SamzaContainer(
       config = config,
       taskInstances = Map(taskName -> taskInstance),
@@ -77,7 +88,8 @@ object StreamProcessorTestUtils {
       containerContext = containerContext,
       applicationContainerContextOption = None,
       externalContextOption = None,
-      containerStorageManager = Mockito.mock(classOf[ContainerStorageManager]))
+      containerStorageManager = Mockito.mock(classOf[ContainerStorageManager]),
+      drainMonitor = drainMonitor)
     container
   }
-}
\ No newline at end of file
+}

Reply via email to