This is an automated email from the ASF dual-hosted git repository.

rong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 7004e4d8eed Subscription: Push consumer & Async commit (#12294)
7004e4d8eed is described below

commit 7004e4d8eed63c081936b4ecf04769b433a5a2b0
Author: Zikun Ma <[email protected]>
AuthorDate: Fri Apr 19 12:54:23 2024 +0800

    Subscription: Push consumer & Async commit (#12294)
---
 .../it/local/IoTDBSubscriptionBasicIT.java         | 260 ++++++++++++++++++++-
 .../rpc/subscription/config/ConsumerConstant.java  |   7 +-
 ...scriptionPushConsumer.java => AckStrategy.java} |  26 +--
 ...nPushConsumer.java => AsyncCommitCallback.java} |  24 +-
 ...ptionPushConsumer.java => ConsumeListener.java} |  25 +-
 ...riptionPushConsumer.java => ConsumeResult.java} |  25 +-
 .../session/subscription/SubscriptionConsumer.java | 136 ++++++++++-
 .../subscription/SubscriptionPullConsumer.java     |  67 ++----
 .../subscription/SubscriptionPushConsumer.java     | 192 ++++++++++++++-
 9 files changed, 623 insertions(+), 139 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/subscription/it/local/IoTDBSubscriptionBasicIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/subscription/it/local/IoTDBSubscriptionBasicIT.java
index 6938763c8f8..b88fcf1a7cd 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/subscription/it/local/IoTDBSubscriptionBasicIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/subscription/it/local/IoTDBSubscriptionBasicIT.java
@@ -23,8 +23,12 @@ import org.apache.iotdb.isession.ISession;
 import org.apache.iotdb.it.env.EnvFactory;
 import org.apache.iotdb.it.framework.IoTDBTestRunner;
 import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+import org.apache.iotdb.session.subscription.AckStrategy;
+import org.apache.iotdb.session.subscription.AsyncCommitCallback;
+import org.apache.iotdb.session.subscription.ConsumeResult;
 import org.apache.iotdb.session.subscription.SubscriptionMessage;
 import org.apache.iotdb.session.subscription.SubscriptionPullConsumer;
+import org.apache.iotdb.session.subscription.SubscriptionPushConsumer;
 import org.apache.iotdb.session.subscription.SubscriptionSession;
 import org.apache.iotdb.session.subscription.SubscriptionSessionDataSet;
 import org.apache.iotdb.session.subscription.SubscriptionSessionDataSets;
@@ -64,7 +68,7 @@ public class IoTDBSubscriptionBasicIT {
   }
 
   @Test
-  public void testBasicSubscription() throws Exception {
+  public void testBasicPullConsumer() throws Exception {
     // Insert some historical data
     try (final ISession session = EnvFactory.getEnv().getSessionConnection()) {
       for (int i = 0; i < 100; ++i) {
@@ -154,4 +158,258 @@ public class IoTDBSubscriptionBasicIT {
       thread.join();
     }
   }
+
+  @Test
+  public void testBasicPullConsumerWithCommitAsync() throws Exception {
+    // Insert some historical data
+    try (final ISession session = EnvFactory.getEnv().getSessionConnection()) {
+      for (int i = 0; i < 100; ++i) {
+        session.executeNonQueryStatement(
+            String.format("insert into root.db.d1(time, s1) values (%s, 1)", 
i));
+      }
+      session.executeNonQueryStatement("flush");
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    // Create topic
+    final String host = EnvFactory.getEnv().getIP();
+    final int port = Integer.parseInt(EnvFactory.getEnv().getPort());
+    try (final SubscriptionSession session = new SubscriptionSession(host, 
port)) {
+      session.open();
+      session.createTopic("topic1");
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    // Subscription
+    final AtomicInteger rowCount = new AtomicInteger();
+    final AtomicInteger commitSuccessCount = new AtomicInteger();
+    final AtomicInteger lastCommitSuccessCount = new AtomicInteger();
+    final AtomicInteger commitFailureCount = new AtomicInteger();
+    final AtomicBoolean isClosed = new AtomicBoolean(false);
+    final Thread thread =
+        new Thread(
+            () -> {
+              try (final SubscriptionPullConsumer consumer =
+                  new SubscriptionPullConsumer.Builder()
+                      .host(host)
+                      .port(port)
+                      .consumerId("c1")
+                      .consumerGroupId("cg1")
+                      .autoCommit(false)
+                      .buildPullConsumer()) {
+                consumer.open();
+                consumer.subscribe("topic1");
+                while (!isClosed.get()) {
+                  try {
+                    Thread.sleep(1000); // wait some time
+                  } catch (final InterruptedException e) {
+                    break;
+                  }
+                  final List<SubscriptionMessage> messages =
+                      consumer.poll(Duration.ofMillis(10000));
+                  if (messages.isEmpty()) {
+                    continue;
+                  }
+                  for (final SubscriptionMessage message : messages) {
+                    final SubscriptionSessionDataSets payload =
+                        (SubscriptionSessionDataSets) message.getPayload();
+                    int rowCountInOneMessage = 0;
+                    for (final SubscriptionSessionDataSet dataSet : payload) {
+                      while (dataSet.hasNext()) {
+                        dataSet.next();
+                        rowCount.addAndGet(1);
+                        rowCountInOneMessage++;
+                      }
+                    }
+                    LOGGER.info(rowCountInOneMessage + " rows in message");
+                  }
+                  consumer.commitAsync(
+                      messages,
+                      new AsyncCommitCallback() {
+                        @Override
+                        public void onComplete() {
+                          commitSuccessCount.incrementAndGet();
+                          LOGGER.info("commit success, messages size: {}", 
messages.size());
+                        }
+
+                        @Override
+                        public void onFailure(Throwable e) {
+                          commitFailureCount.incrementAndGet();
+                        }
+                      });
+                }
+                consumer.unsubscribe("topic1");
+              } catch (final Exception e) {
+                e.printStackTrace();
+                // avoid fail
+              } finally {
+                LOGGER.info("consumer exiting...");
+              }
+            });
+    thread.start();
+
+    // Check row count
+    try {
+      // Keep retrying if there are execution failures
+      Awaitility.await()
+          .pollDelay(1, TimeUnit.SECONDS)
+          .pollInterval(1, TimeUnit.SECONDS)
+          .atMost(120, TimeUnit.SECONDS)
+          .untilAsserted(() -> Assert.assertEquals(100, rowCount.get()));
+      Assert.assertTrue(commitSuccessCount.get() > 
lastCommitSuccessCount.get());
+      Assert.assertEquals(0, commitFailureCount.get());
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    lastCommitSuccessCount.set(commitSuccessCount.get());
+
+    // Insert more data, the pull consumer is also running, so the data may be 
pulled more than
+    // once.
+    try (final ISession session = EnvFactory.getEnv().getSessionConnection()) {
+      for (int i = 100; i < 200; ++i) {
+        session.executeNonQueryStatement(
+            String.format("insert into root.db.d1(time, s1) values (%s, 1)", 
i));
+      }
+      session.executeNonQueryStatement("flush");
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    // Check row count
+    try {
+      // Keep retrying if there are execution failures
+      Awaitility.await()
+          .pollDelay(1, TimeUnit.SECONDS)
+          .pollInterval(1, TimeUnit.SECONDS)
+          .atMost(120, TimeUnit.SECONDS)
+          .untilAsserted(() -> Assert.assertEquals(200, rowCount.get()));
+      Assert.assertTrue(commitSuccessCount.get() > 
lastCommitSuccessCount.get());
+      Assert.assertEquals(0, commitFailureCount.get());
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    } finally {
+      isClosed.set(true);
+      thread.join();
+    }
+  }
+
+  @Test
+  public void testBasicPushConsumer() {
+    final AtomicInteger onReceiveCount = new AtomicInteger(0);
+    final AtomicInteger lastOnReceiveCount = new AtomicInteger(0);
+    final AtomicInteger rowCount = new AtomicInteger(0);
+
+    // Insert some historical data
+    try (final ISession session = EnvFactory.getEnv().getSessionConnection()) {
+      for (int i = 0; i < 10; ++i) {
+        session.executeNonQueryStatement(
+            String.format("insert into root.db.d1(time, s1) values (%s, 1)", 
i));
+      }
+      session.executeNonQueryStatement("flush");
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    // Create topic
+    final String host = EnvFactory.getEnv().getIP();
+    final int port = Integer.parseInt(EnvFactory.getEnv().getPort());
+    try (final SubscriptionSession session = new SubscriptionSession(host, 
port)) {
+      session.open();
+      session.createTopic("topic1");
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+
+    // Subscription
+    try (final SubscriptionPushConsumer consumer =
+        new SubscriptionPushConsumer.Builder()
+            .host(host)
+            .port(port)
+            .consumerId("c1")
+            .consumerGroupId("cg1")
+            .ackStrategy(AckStrategy.BEFORE_CONSUME)
+            .consumeListener(
+                message -> {
+                  onReceiveCount.getAndIncrement();
+                  SubscriptionSessionDataSets dataSets =
+                      (SubscriptionSessionDataSets) message.getPayload();
+                  dataSets
+                      .tabletIterator()
+                      .forEachRemaining(tablet -> 
rowCount.addAndGet(tablet.rowSize));
+                  return ConsumeResult.SUCCESS;
+                })
+            .buildPushConsumer()) {
+
+      consumer.open();
+      consumer.subscribe("topic1");
+
+      // The push consumer should automatically poll 10 rows of data by 1 
onReceive()
+      Awaitility.await()
+          .pollDelay(1, TimeUnit.SECONDS)
+          .pollInterval(1, TimeUnit.SECONDS)
+          .atMost(10, TimeUnit.SECONDS)
+          .untilAsserted(
+              () -> {
+                Assert.assertEquals(10, rowCount.get());
+                Assert.assertTrue(onReceiveCount.get() > 
lastOnReceiveCount.get());
+              });
+
+      lastOnReceiveCount.set(onReceiveCount.get());
+
+      // Insert more rows and check if the push consumer can automatically 
poll the new data
+      try (final ISession session = 
EnvFactory.getEnv().getSessionConnection()) {
+        for (int i = 10; i < 20; ++i) {
+          session.executeNonQueryStatement(
+              String.format("insert into root.db.d1(time, s1) values (%s, 1)", 
i));
+        }
+        session.executeNonQueryStatement("flush");
+      }
+
+      Awaitility.await()
+          .pollDelay(1, TimeUnit.SECONDS)
+          .pollInterval(1, TimeUnit.SECONDS)
+          .atMost(10, TimeUnit.SECONDS)
+          .untilAsserted(
+              () -> {
+                Assert.assertEquals(20, rowCount.get());
+                Assert.assertTrue(onReceiveCount.get() > 
lastOnReceiveCount.get());
+              });
+
+      lastOnReceiveCount.set(onReceiveCount.get());
+
+      try (final ISession session = 
EnvFactory.getEnv().getSessionConnection()) {
+        for (int i = 20; i < 30; ++i) {
+          session.executeNonQueryStatement(
+              String.format("insert into root.db.d1(time, s1) values (%s, 1)", 
i));
+        }
+        session.executeNonQueryStatement("flush");
+      }
+
+      Awaitility.await()
+          .pollDelay(1, TimeUnit.SECONDS)
+          .pollInterval(1, TimeUnit.SECONDS)
+          .atMost(10, TimeUnit.SECONDS)
+          .untilAsserted(
+              () -> {
+                Assert.assertEquals(30, rowCount.get());
+                Assert.assertTrue(onReceiveCount.get() > 
lastOnReceiveCount.get());
+              });
+
+      consumer.unsubscribe("topic1");
+
+    } catch (final Exception e) {
+      e.printStackTrace();
+      fail(e.getMessage());
+    }
+  }
 }
diff --git 
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/subscription/config/ConsumerConstant.java
 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/subscription/config/ConsumerConstant.java
index 1db77c7e03a..7f2b01ea480 100644
--- 
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/subscription/config/ConsumerConstant.java
+++ 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/subscription/config/ConsumerConstant.java
@@ -53,7 +53,12 @@ public class ConsumerConstant {
 
   /////////////////////////////// push consumer ///////////////////////////////
 
-  // TODO
+  public static final String ACK_STRATEGY_KEY = "ack-strategy";
+  public static final String CONSUME_LISTENER_KEY = "consume-listener";
+
+  // TODO: configure this parameter
+  public static final int PUSH_CONSUMER_AUTO_POLL_INTERVAL_MS = 1000;
+  public static final int PUSH_CONSUMER_AUTO_POLL_TIME_OUT_MS = 2000;
 
   private ConsumerConstant() {
     throw new IllegalStateException("Utility class");
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AckStrategy.java
similarity index 57%
copy from 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
copy to 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AckStrategy.java
index 950a396d58d..217d50ffd02 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AckStrategy.java
@@ -19,26 +19,12 @@
 
 package org.apache.iotdb.session.subscription;
 
-import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
+public enum AckStrategy {
+  BEFORE_CONSUME,
+  AFTER_CONSUME;
 
-// TODO
-public class SubscriptionPushConsumer extends SubscriptionConsumer {
-
-  protected SubscriptionPushConsumer(Builder builder) {
-    super(builder);
-  }
-
-  public static class Builder extends SubscriptionConsumer.Builder {
-
-    @Override
-    public SubscriptionPullConsumer buildPullConsumer() {
-      throw new SubscriptionException(
-          "SubscriptionPushConsumer.Builder do not support build pull 
consumer.");
-    }
-
-    @Override
-    public SubscriptionPushConsumer buildPushConsumer() {
-      return new SubscriptionPushConsumer(this);
-    }
+  public static AckStrategy defaultValue() {
+    // Use AFTER_CONSUME by default
+    return AFTER_CONSUME;
   }
 }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AsyncCommitCallback.java
similarity index 58%
copy from 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
copy to 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AsyncCommitCallback.java
index 950a396d58d..0d538dfd671 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/AsyncCommitCallback.java
@@ -19,26 +19,12 @@
 
 package org.apache.iotdb.session.subscription;
 
-import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
-
-// TODO
-public class SubscriptionPushConsumer extends SubscriptionConsumer {
-
-  protected SubscriptionPushConsumer(Builder builder) {
-    super(builder);
+public interface AsyncCommitCallback {
+  default void onComplete() {
+    // Do nothing
   }
 
-  public static class Builder extends SubscriptionConsumer.Builder {
-
-    @Override
-    public SubscriptionPullConsumer buildPullConsumer() {
-      throw new SubscriptionException(
-          "SubscriptionPushConsumer.Builder do not support build pull 
consumer.");
-    }
-
-    @Override
-    public SubscriptionPushConsumer buildPushConsumer() {
-      return new SubscriptionPushConsumer(this);
-    }
+  default void onFailure(Throwable e) {
+    // Do nothing
   }
 }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeListener.java
similarity index 57%
copy from 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
copy to 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeListener.java
index 950a396d58d..f266c60b926 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeListener.java
@@ -19,26 +19,7 @@
 
 package org.apache.iotdb.session.subscription;
 
-import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
-
-// TODO
-public class SubscriptionPushConsumer extends SubscriptionConsumer {
-
-  protected SubscriptionPushConsumer(Builder builder) {
-    super(builder);
-  }
-
-  public static class Builder extends SubscriptionConsumer.Builder {
-
-    @Override
-    public SubscriptionPullConsumer buildPullConsumer() {
-      throw new SubscriptionException(
-          "SubscriptionPushConsumer.Builder do not support build pull 
consumer.");
-    }
-
-    @Override
-    public SubscriptionPushConsumer buildPushConsumer() {
-      return new SubscriptionPushConsumer(this);
-    }
-  }
+@FunctionalInterface
+public interface ConsumeListener {
+  ConsumeResult onReceive(SubscriptionMessage message);
 }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeResult.java
similarity index 57%
copy from 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
copy to 
iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeResult.java
index 950a396d58d..63bf701a02d 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/ConsumeResult.java
@@ -19,26 +19,7 @@
 
 package org.apache.iotdb.session.subscription;
 
-import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
-
-// TODO
-public class SubscriptionPushConsumer extends SubscriptionConsumer {
-
-  protected SubscriptionPushConsumer(Builder builder) {
-    super(builder);
-  }
-
-  public static class Builder extends SubscriptionConsumer.Builder {
-
-    @Override
-    public SubscriptionPullConsumer buildPullConsumer() {
-      throw new SubscriptionException(
-          "SubscriptionPushConsumer.Builder do not support build pull 
consumer.");
-    }
-
-    @Override
-    public SubscriptionPushConsumer buildPushConsumer() {
-      return new SubscriptionPushConsumer(this);
-    }
-  }
+public enum ConsumeResult {
+  SUCCESS,
+  FAILURE
 }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionConsumer.java
index 4f93c97e795..ccf73c35983 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionConsumer.java
@@ -24,6 +24,7 @@ import org.apache.iotdb.isession.SessionConfig;
 import org.apache.iotdb.rpc.IoTDBConnectionException;
 import org.apache.iotdb.rpc.StatementExecutionException;
 import org.apache.iotdb.rpc.subscription.config.ConsumerConstant;
+import org.apache.iotdb.rpc.subscription.payload.EnrichedTablets;
 import org.apache.iotdb.session.util.SessionUtils;
 
 import org.apache.thrift.TException;
@@ -34,6 +35,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -42,6 +44,7 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
@@ -74,6 +77,8 @@ public abstract class SubscriptionConsumer implements 
AutoCloseable {
   private ScheduledExecutorService heartbeatWorkerExecutor;
   private ScheduledExecutorService endpointsSyncerExecutor;
 
+  private ExecutorService asyncCommitExecutor;
+
   private final AtomicBoolean isClosed = new AtomicBoolean(true);
 
   public String getConsumerId() {
@@ -173,8 +178,8 @@ public abstract class SubscriptionConsumer implements 
AutoCloseable {
       // shutdown endpoints syncer
       shutdownEndpointsSyncer();
 
-      // shutdown heartbeat worker
-      shutdownHeartbeatWorker();
+      // shutdown workers
+      shutdownWorkers();
 
       // close subscription providers
       acquireWriteLock();
@@ -274,9 +279,18 @@ public abstract class SubscriptionConsumer implements 
AutoCloseable {
         new ConsumerHeartbeatWorker(this), 0, heartbeatIntervalMs, 
TimeUnit.MILLISECONDS);
   }
 
-  private void shutdownHeartbeatWorker() {
+  /**
+   * Shut down workers upon close. There are currently two workers: heartbeat 
worker and async
+   * commit executor.
+   */
+  private void shutdownWorkers() {
     heartbeatWorkerExecutor.shutdown();
     heartbeatWorkerExecutor = null;
+
+    if (asyncCommitExecutor != null) {
+      asyncCommitExecutor.shutdown();
+      asyncCommitExecutor = null;
+    }
   }
 
   /////////////////////////////// endpoints syncer 
///////////////////////////////
@@ -369,6 +383,98 @@ public abstract class SubscriptionConsumer implements 
AutoCloseable {
     }
   }
 
+  /////////////////////////////// poll & commit ///////////////////////////////
+
+  protected List<SubscriptionMessage> poll(Set<String> topicNames, long 
timeoutMs)
+      throws TException, IOException, StatementExecutionException {
+    List<EnrichedTablets> enrichedTabletsList = new ArrayList<>();
+
+    acquireReadLock();
+    try {
+      for (final SubscriptionProvider provider : getAllAvailableProviders()) {
+        // TODO: network timeout
+        
enrichedTabletsList.addAll(provider.getSessionConnection().poll(topicNames, 
timeoutMs));
+      }
+    } finally {
+      releaseReadLock();
+    }
+
+    return 
enrichedTabletsList.stream().map(SubscriptionMessage::new).collect(Collectors.toList());
+  }
+
+  protected void commitSync(Iterable<SubscriptionMessage> messages)
+      throws TException, IOException, StatementExecutionException, 
IoTDBConnectionException {
+    Map<Integer, Map<String, List<String>>> 
dataNodeIdToTopicNameToSubscriptionCommitIds =
+        new HashMap<>();
+    for (SubscriptionMessage message : messages) {
+      dataNodeIdToTopicNameToSubscriptionCommitIds
+          .computeIfAbsent(
+              message.parseDataNodeIdFromSubscriptionCommitId(), (id) -> new 
HashMap<>())
+          .computeIfAbsent(message.getTopicName(), (topicName) -> new 
ArrayList<>())
+          .add(message.getSubscriptionCommitId());
+    }
+    for (Map.Entry<Integer, Map<String, List<String>>> entry :
+        dataNodeIdToTopicNameToSubscriptionCommitIds.entrySet()) {
+      commitSyncInternal(entry.getKey(), entry.getValue());
+    }
+  }
+
+  protected void commitAsync(Iterable<SubscriptionMessage> messages) {
+    commitAsync(messages, new AsyncCommitCallback() {});
+  }
+
+  protected void commitAsync(Iterable<SubscriptionMessage> messages, 
AsyncCommitCallback callback) {
+
+    // Initiate executor if needed
+    if (asyncCommitExecutor == null) {
+      synchronized (this) {
+        if (asyncCommitExecutor != null) {
+          return;
+        }
+
+        asyncCommitExecutor =
+            Executors.newSingleThreadExecutor(
+                r -> {
+                  Thread t =
+                      new Thread(
+                          Thread.currentThread().getThreadGroup(),
+                          r,
+                          "SubscriptionConsumerAsyncCommitWorker",
+                          0);
+                  if (!t.isDaemon()) {
+                    t.setDaemon(true);
+                  }
+                  if (t.getPriority() != Thread.NORM_PRIORITY) {
+                    t.setPriority(Thread.NORM_PRIORITY);
+                  }
+                  return t;
+                });
+      }
+    }
+
+    asyncCommitExecutor.submit(new AsyncCommitWorker(messages, callback));
+  }
+
+  /////////////////////////////// utility ///////////////////////////////
+
+  private void commitSyncInternal(
+      int dataNodeId, Map<String, List<String>> 
topicNameToSubscriptionCommitIds)
+      throws TException, IOException, StatementExecutionException, 
IoTDBConnectionException {
+    acquireReadLock();
+    try {
+      final SubscriptionProvider provider = getProvider(dataNodeId);
+      if (Objects.isNull(provider) || !provider.isAvailable()) {
+        throw new IoTDBConnectionException(
+            String.format(
+                "something unexpected happened when commit messages to 
subscription provider with data node id %s, the subscription provider may be 
unavailable or not existed",
+                dataNodeId));
+      }
+      
provider.getSessionConnection().commitSync(topicNameToSubscriptionCommitIds);
+    } finally {
+      releaseReadLock();
+    }
+  }
+
   /** Caller should ensure that the method is called in the lock {@link 
#acquireWriteLock()}. */
   private void closeProviders() throws IoTDBConnectionException {
     for (final SubscriptionProvider provider : getAllProviders()) {
@@ -553,4 +659,28 @@ public abstract class SubscriptionConsumer implements 
AutoCloseable {
 
     public abstract SubscriptionPushConsumer buildPushConsumer();
   }
+
+  class AsyncCommitWorker implements Runnable {
+    private final Iterable<SubscriptionMessage> messages;
+    private final AsyncCommitCallback callback;
+
+    public AsyncCommitWorker(Iterable<SubscriptionMessage> messages, 
AsyncCommitCallback callback) {
+      this.messages = messages;
+      this.callback = callback;
+    }
+
+    @Override
+    public void run() {
+      if (isClosed()) {
+        return;
+      }
+
+      try {
+        commitSync(messages);
+        callback.onComplete();
+      } catch (Exception e) {
+        callback.onFailure(e);
+      }
+    }
+  }
 }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPullConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPullConsumer.java
index 892ae5a44bf..3fba0d730fe 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPullConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPullConsumer.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.rpc.IoTDBConnectionException;
 import org.apache.iotdb.rpc.StatementExecutionException;
 import org.apache.iotdb.rpc.subscription.config.ConsumerConstant;
 import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
-import org.apache.iotdb.rpc.subscription.payload.EnrichedTablets;
 
 import org.apache.thrift.TException;
 import org.slf4j.Logger;
@@ -31,12 +30,9 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
 import java.util.SortedMap;
@@ -46,7 +42,6 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.stream.Collectors;
 
 public class SubscriptionPullConsumer extends SubscriptionConsumer {
 
@@ -147,20 +142,7 @@ public class SubscriptionPullConsumer extends 
SubscriptionConsumer {
 
   public List<SubscriptionMessage> poll(Set<String> topicNames, long timeoutMs)
       throws TException, IOException, StatementExecutionException {
-    List<EnrichedTablets> enrichedTabletsList = new ArrayList<>();
-
-    acquireReadLock();
-    try {
-      for (final SubscriptionProvider provider : getAllAvailableProviders()) {
-        // TODO: network timeout
-        
enrichedTabletsList.addAll(provider.getSessionConnection().poll(topicNames, 
timeoutMs));
-      }
-    } finally {
-      releaseReadLock();
-    }
-
-    List<SubscriptionMessage> messages =
-        
enrichedTabletsList.stream().map(SubscriptionMessage::new).collect(Collectors.toList());
+    List<SubscriptionMessage> messages = super.poll(topicNames, timeoutMs);
 
     if (autoCommit) {
       long currentTimestamp = System.currentTimeMillis();
@@ -178,44 +160,28 @@ public class SubscriptionPullConsumer extends 
SubscriptionConsumer {
 
   public void commitSync(SubscriptionMessage message)
       throws TException, IOException, StatementExecutionException, 
IoTDBConnectionException {
-    commitSync(Collections.singletonList(message));
+    super.commitSync(Collections.singletonList(message));
   }
 
   public void commitSync(Iterable<SubscriptionMessage> messages)
       throws TException, IOException, StatementExecutionException, 
IoTDBConnectionException {
-    Map<Integer, Map<String, List<String>>> 
dataNodeIdToTopicNameToSubscriptionCommitIds =
-        new HashMap<>();
-    for (SubscriptionMessage message : messages) {
-      dataNodeIdToTopicNameToSubscriptionCommitIds
-          .computeIfAbsent(
-              message.parseDataNodeIdFromSubscriptionCommitId(), (id) -> new 
HashMap<>())
-          .computeIfAbsent(message.getTopicName(), (topicName) -> new 
ArrayList<>())
-          .add(message.getSubscriptionCommitId());
-    }
-    for (Map.Entry<Integer, Map<String, List<String>>> entry :
-        dataNodeIdToTopicNameToSubscriptionCommitIds.entrySet()) {
-      commitSyncInternal(entry.getKey(), entry.getValue());
-    }
+    super.commitSync(messages);
   }
 
-  /////////////////////////////// utility ///////////////////////////////
+  public void commitAsync(SubscriptionMessage message) {
+    super.commitAsync(Collections.singletonList(message));
+  }
 
-  private void commitSyncInternal(
-      int dataNodeId, Map<String, List<String>> 
topicNameToSubscriptionCommitIds)
-      throws TException, IOException, StatementExecutionException, 
IoTDBConnectionException {
-    acquireReadLock();
-    try {
-      final SubscriptionProvider provider = getProvider(dataNodeId);
-      if (Objects.isNull(provider) || !provider.isAvailable()) {
-        throw new IoTDBConnectionException(
-            String.format(
-                "something unexpected happened when commit messages to 
subscription provider with data node id %s, the subscription provider may be 
unavailable or not existed",
-                dataNodeId));
-      }
-      
provider.getSessionConnection().commitSync(topicNameToSubscriptionCommitIds);
-    } finally {
-      releaseReadLock();
-    }
+  public void commitAsync(Iterable<SubscriptionMessage> messages) {
+    super.commitAsync(messages);
+  }
+
+  public void commitAsync(SubscriptionMessage message, AsyncCommitCallback 
callback) {
+    super.commitAsync(Collections.singletonList(message), callback);
+  }
+
+  public void commitAsync(Iterable<SubscriptionMessage> messages, 
AsyncCommitCallback callback) {
+    super.commitAsync(messages, callback);
   }
 
   /////////////////////////////// auto commit ///////////////////////////////
@@ -260,6 +226,7 @@ public class SubscriptionPullConsumer extends 
SubscriptionConsumer {
     }
   }
 
+  @Override
   boolean isClosed() {
     return isClosed.get();
   }
diff --git 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
index 950a396d58d..8355d90a6c1 100644
--- 
a/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
+++ 
b/iotdb-client/session/src/main/java/org/apache/iotdb/session/subscription/SubscriptionPushConsumer.java
@@ -19,17 +19,171 @@
 
 package org.apache.iotdb.session.subscription;
 
+import org.apache.iotdb.rpc.IoTDBConnectionException;
+import org.apache.iotdb.rpc.StatementExecutionException;
+import org.apache.iotdb.rpc.subscription.config.ConsumerConstant;
 import org.apache.iotdb.rpc.subscription.exception.SubscriptionException;
 
-// TODO
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
 public class SubscriptionPushConsumer extends SubscriptionConsumer {
 
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(SubscriptionPushConsumer.class);
+
+  private final AckStrategy ackStrategy;
+  private final ConsumeListener consumeListener;
+
+  private ScheduledExecutorService workerExecutor;
+
+  private final AtomicBoolean isClosed = new AtomicBoolean(true);
+
   protected SubscriptionPushConsumer(Builder builder) {
     super(builder);
+
+    this.ackStrategy = builder.ackStrategy;
+    this.consumeListener = builder.consumeListener;
+  }
+
+  public SubscriptionPushConsumer(Properties config) {
+    this(
+        config,
+        (AckStrategy)
+            config.getOrDefault(ConsumerConstant.ACK_STRATEGY_KEY, 
AckStrategy.defaultValue()),
+        (ConsumeListener)
+            config.getOrDefault(
+                ConsumerConstant.CONSUME_LISTENER_KEY,
+                (ConsumeListener) message -> ConsumeResult.SUCCESS));
+  }
+
+  private SubscriptionPushConsumer(
+      Properties config, AckStrategy ackStrategy, ConsumeListener 
consumeListener) {
+    super(new Builder().ackStrategy(ackStrategy), config);
+
+    this.ackStrategy = ackStrategy;
+    this.consumeListener = consumeListener;
+  }
+
+  /////////////////////////////// open & close ///////////////////////////////
+
+  public synchronized void open()
+      throws TException, IoTDBConnectionException, IOException, 
StatementExecutionException {
+    if (!isClosed.get()) {
+      return;
+    }
+
+    super.open();
+
+    launchAutoPollWorker();
+
+    isClosed.set(false);
+  }
+
+  @Override
+  public synchronized void close() throws IoTDBConnectionException {
+    if (isClosed.get()) {
+      return;
+    }
+
+    try {
+      shutdownWorker();
+      super.close();
+    } finally {
+      isClosed.set(true);
+    }
+  }
+
+  @Override
+  boolean isClosed() {
+    return isClosed.get();
+  }
+
+  /////////////////////////////// auto poll worker 
///////////////////////////////
+
+  @SuppressWarnings("unsafeThreadSchedule")
+  private void launchAutoPollWorker() {
+    workerExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            r -> {
+              Thread t =
+                  new Thread(Thread.currentThread().getThreadGroup(), r, 
"PushConsumerWorker", 0);
+              if (!t.isDaemon()) {
+                t.setDaemon(true);
+              }
+              if (t.getPriority() != Thread.NORM_PRIORITY) {
+                t.setPriority(Thread.NORM_PRIORITY);
+              }
+              return t;
+            });
+    workerExecutor.scheduleAtFixedRate(
+        new PushConsumerWorker(),
+        0,
+        ConsumerConstant.PUSH_CONSUMER_AUTO_POLL_INTERVAL_MS,
+        TimeUnit.MILLISECONDS);
   }
 
+  private void shutdownWorker() {
+    workerExecutor.shutdown();
+    workerExecutor = null;
+  }
+
+  /////////////////////////////// builder ///////////////////////////////
+
   public static class Builder extends SubscriptionConsumer.Builder {
 
+    private AckStrategy ackStrategy = AckStrategy.defaultValue();
+    private ConsumeListener consumeListener = message -> ConsumeResult.SUCCESS;
+
+    public SubscriptionPushConsumer.Builder host(String host) {
+      super.host(host);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder port(int port) {
+      super.port(port);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder username(String username) {
+      super.username(username);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder password(String password) {
+      super.password(password);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder consumerId(String consumerId) {
+      super.consumerId(consumerId);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder consumerGroupId(String 
consumerGroupId) {
+      super.consumerGroupId(consumerGroupId);
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder ackStrategy(AckStrategy 
ackStrategy) {
+      this.ackStrategy = ackStrategy;
+      return this;
+    }
+
+    public SubscriptionPushConsumer.Builder consumeListener(ConsumeListener 
consumeListener) {
+      this.consumeListener = consumeListener;
+      return this;
+    }
+
     @Override
     public SubscriptionPullConsumer buildPullConsumer() {
       throw new SubscriptionException(
@@ -41,4 +195,40 @@ public class SubscriptionPushConsumer extends 
SubscriptionConsumer {
       return new SubscriptionPushConsumer(this);
     }
   }
+
+  class PushConsumerWorker implements Runnable {
+    @Override
+    public void run() {
+      if (isClosed()) {
+        return;
+      }
+
+      try {
+        // Poll all subscribed topics by passing an empty set
+        List<SubscriptionMessage> pollResults =
+            poll(Collections.emptySet(), 
ConsumerConstant.PUSH_CONSUMER_AUTO_POLL_TIME_OUT_MS);
+
+        if (ackStrategy.equals(AckStrategy.BEFORE_CONSUME)) {
+          commitSync(pollResults);
+        }
+
+        for (SubscriptionMessage pollResult : pollResults) {
+          ConsumeResult consumeResult = consumeListener.onReceive(pollResult);
+          if (consumeResult.equals(ConsumeResult.FAILURE)) {
+            LOGGER.warn("consumeListener failed when processing message: {}", 
pollResult);
+          }
+        }
+
+        if (ackStrategy.equals(AckStrategy.AFTER_CONSUME)) {
+          commitSync(pollResults);
+        }
+
+      } catch (TException
+          | IOException
+          | StatementExecutionException
+          | IoTDBConnectionException e) {
+        LOGGER.warn("Exception occurred when auto polling: ", e);
+      }
+    }
+  }
 }


Reply via email to