This is an automated email from the ASF dual-hosted git repository.

haonan pushed a commit to branch TNonblockingSSLSocket
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit f36f33aed1ed55033ff1640ed49d1f722b53ddf8
Author: HTHou <[email protected]>
AuthorDate: Fri Dec 19 09:52:32 2025 +0800

    Add a new TNonblockingSSLSocket
---
 .../apache/iotdb/rpc/TNonblockingSSLSocket.java    | 412 +++++++++++++++++++++
 .../iotdb/rpc/TNonblockingTransportWrapper.java    |   4 +-
 2 files changed, 414 insertions(+), 2 deletions(-)

diff --git 
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java
 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java
new file mode 100644
index 00000000000..e11b22bcc3a
--- /dev/null
+++ 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.rpc;
+
+import org.apache.thrift.transport.TNonblockingSocket;
+import org.apache.thrift.transport.TTransportException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
+import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.TrustManagerFactory;
+
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.security.KeyStore;
+
+/** Transport for use with async client. */
+public class TNonblockingSSLSocket extends TNonblockingSocket {
+
+  private static final Logger LOGGER =
+      LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName());
+
+  private final SSLEngine sslEngine_;
+
+  private final ByteBuffer appUnwrap;
+  private final ByteBuffer netUnwrap;
+
+  private final ByteBuffer appWrap;
+  private final ByteBuffer netWrap;
+
+  private ByteBuffer decodedBytes;
+
+  private boolean isHandshakeCompleted;
+
+  private SelectionKey selectionKey;
+
+  public TNonblockingSSLSocket(
+      String host,
+      int port,
+      int timeout,
+      String keystore,
+      String keystorePassword,
+      String truststore,
+      String truststorePassword)
+      throws TTransportException, IOException {
+    this(
+        host,
+        port,
+        timeout,
+        createSSLContext(keystore, keystorePassword, truststore, 
truststorePassword));
+  }
+
+  private static SSLContext createSSLContext(
+      String keystore, String keystorePassword, String truststore, String 
truststorePassword)
+      throws TTransportException {
+    SSLContext ctx;
+    InputStream in = null;
+    InputStream is = null;
+
+    try {
+      ctx = SSLContext.getInstance("TLS");
+      TrustManagerFactory tmf = null;
+      KeyManagerFactory kmf = null;
+
+      if (truststore != null) {
+        tmf = 
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+        KeyStore ts = KeyStore.getInstance("JKS");
+        in = getStoreAsStream(truststore);
+        ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null));
+        tmf.init(ts);
+      }
+
+      if (keystore != null) {
+        kmf = 
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+        KeyStore ks = KeyStore.getInstance("JKS");
+        is = getStoreAsStream(keystore);
+        ks.load(is, keystorePassword.toCharArray());
+        kmf.init(ks, keystorePassword.toCharArray());
+      }
+
+      if (keystore != null && truststore != null) {
+        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
+      } else if (keystore != null) {
+        ctx.init(kmf.getKeyManagers(), null, null);
+      } else {
+        ctx.init(null, tmf.getTrustManagers(), null);
+      }
+
+    } catch (Exception e) {
+      throw new TTransportException(
+          TTransportException.NOT_OPEN, "Error creating the transport", e);
+    } finally {
+      if (in != null) {
+        try {
+          in.close();
+        } catch (IOException e) {
+          LOGGER.warn("Unable to close stream", e);
+        }
+      }
+      if (is != null) {
+        try {
+          is.close();
+        } catch (IOException e) {
+          LOGGER.warn("Unable to close stream", e);
+        }
+      }
+    }
+
+    return ctx;
+  }
+
+  private static InputStream getStoreAsStream(String store) throws IOException 
{
+    try {
+      return new FileInputStream(store);
+    } catch (FileNotFoundException e) {
+    }
+
+    InputStream storeStream = null;
+    try {
+      storeStream = new URL(store).openStream();
+      if (storeStream != null) {
+        return storeStream;
+      }
+    } catch (MalformedURLException e) {
+    }
+
+    storeStream = 
Thread.currentThread().getContextClassLoader().getResourceAsStream(store);
+
+    if (storeStream != null) {
+      return storeStream;
+    } else {
+      throw new IOException("Could not load file: " + store);
+    }
+  }
+
+  protected TNonblockingSSLSocket(String host, int port, int timeout, 
SSLContext sslContext)
+      throws IOException, TTransportException {
+    super(host, port, timeout);
+    sslEngine_ = sslContext.createSSLEngine(host, port);
+    sslEngine_.setUseClientMode(true);
+
+    int appBufferSize = sslEngine_.getSession().getApplicationBufferSize();
+    int netBufferSize = sslEngine_.getSession().getPacketBufferSize();
+    appUnwrap = ByteBuffer.allocate(appBufferSize);
+    netUnwrap = ByteBuffer.allocate(netBufferSize);
+    appWrap = ByteBuffer.allocate(appBufferSize);
+    netWrap = ByteBuffer.allocate(netBufferSize);
+    decodedBytes = ByteBuffer.allocate(appBufferSize);
+    decodedBytes.flip();
+    isHandshakeCompleted = false;
+  }
+
+  /**
+   * Register the new SocketChannel with our Selector, indicating we'd like to 
be notified when it's
+   * ready for I/O.
+   *
+   * @param selector
+   * @return the selection key for this socket.
+   */
+  public SelectionKey registerSelector(Selector selector, int interests) 
throws IOException {
+    selectionKey = super.registerSelector(selector, interests);
+    return selectionKey;
+  }
+
+  /** Checks whether the socket is connected. */
+  public boolean isOpen() {
+    // isConnected() does not return false after close(), but isOpen() does
+    return super.isOpen() && isHandshakeCompleted;
+  }
+
+  /** Do not call, the implementation provides its own lazy non-blocking 
connect. */
+  public void open() throws TTransportException {
+    throw new RuntimeException("open() is not implemented for 
TNonblockingSSLSocket");
+  }
+
+  /** Perform a nonblocking read into buffer. */
+  public int read(ByteBuffer buffer) throws TTransportException {
+    int numBytes = buffer.limit();
+    while (decodedBytes.remaining() < numBytes) {
+      HandshakeStatus hs = sslEngine_.getHandshakeStatus();
+      if (hs == HandshakeStatus.FINISHED)
+        throw new TTransportException(
+            TTransportException.UNKNOWN, "Read operation is terminated. 
Handshake is completed");
+      try {
+        if (doUnwrap() == -1) {
+          throw new IOException("Unable to read " + numBytes + " bytes");
+        }
+      } catch (IOException exc) {
+        throw new TTransportException(TTransportException.UNKNOWN, 
exc.getMessage());
+      }
+      if (appUnwrap.position() > 0) {
+        int t;
+        appUnwrap.flip();
+        if (decodedBytes.position() > 0) decodedBytes.flip();
+        t = appUnwrap.limit() + decodedBytes.limit();
+        byte[] tmpBuffer = new byte[t];
+        decodedBytes.get(tmpBuffer, 0, decodedBytes.remaining());
+        appUnwrap.get(tmpBuffer, 0, appUnwrap.remaining());
+        if (appUnwrap.position() > 0) {
+          appUnwrap.clear();
+          appUnwrap.flip();
+          appUnwrap.compact();
+        }
+        decodedBytes = ByteBuffer.wrap(tmpBuffer);
+      }
+    }
+    byte[] b = new byte[numBytes];
+    decodedBytes.get(b, 0, numBytes);
+    if (decodedBytes.position() > 0) {
+      decodedBytes.compact();
+      decodedBytes.flip();
+    }
+    buffer.put(b);
+    selectionKey.interestOps(SelectionKey.OP_WRITE);
+    return numBytes;
+  }
+
+  /** Perform a nonblocking write of the data in buffer; */
+  public int write(ByteBuffer buffer) throws TTransportException {
+    int numBytes = 0;
+
+    if (buffer.position() > 0) buffer.flip();
+
+    int nTransfer;
+    int num;
+    while (buffer.remaining() != 0) {
+      nTransfer = Math.min(appWrap.remaining(), buffer.remaining());
+      if (nTransfer > 0) {
+        appWrap.put(buffer.array(), buffer.arrayOffset() + buffer.position(), 
nTransfer);
+        buffer.position(buffer.position() + nTransfer);
+      }
+
+      try {
+        num = doWrap();
+      } catch (IOException iox) {
+        throw new TTransportException(TTransportException.UNKNOWN, iox);
+      }
+      if (num < 0) {
+        LOGGER.error("Failed while writing. Probably server is down");
+        return -1;
+      }
+      numBytes += num;
+    }
+    return numBytes;
+  }
+
+  /** Closes the socket. */
+  public void close() {
+    sslEngine_.closeOutbound();
+    super.close();
+  }
+
+  /** {@inheritDoc} */
+  public boolean startConnect() throws IOException {
+    if (this.isOpen()) {
+      return true;
+    }
+    sslEngine_.beginHandshake();
+    return super.startConnect() && doHandShake();
+  }
+
+  /** {@inheritDoc} */
+  public boolean finishConnect() throws IOException {
+    return super.finishConnect() && doHandShake();
+  }
+
+  private synchronized boolean doHandShake() throws IOException {
+    LOGGER.debug("Handshake is started");
+    while (true) {
+      HandshakeStatus hs = sslEngine_.getHandshakeStatus();
+      switch (hs) {
+        case NEED_UNWRAP:
+          if (doUnwrap() == -1) {
+            LOGGER.error("Unexpected. Handshake failed abruptly during 
unwrap");
+            return false;
+          }
+          break;
+        case NEED_WRAP:
+          if (doWrap() == -1) {
+            LOGGER.error("Unexpected. Handshake failed abruptly during wrap");
+            return false;
+          }
+          break;
+        case NEED_TASK:
+          if (!doTask()) {
+            LOGGER.error("Unexpected. Handshake failed abruptly during task");
+            return false;
+          }
+          break;
+        case FINISHED:
+        case NOT_HANDSHAKING:
+          isHandshakeCompleted = true;
+          return true;
+        default:
+          LOGGER.error("Unknown handshake status. Handshake failed");
+          return false;
+      }
+    }
+  }
+
+  private synchronized boolean doTask() {
+    Runnable runnable;
+    while ((runnable = sslEngine_.getDelegatedTask()) != null) {
+      runnable.run();
+    }
+    HandshakeStatus hs = sslEngine_.getHandshakeStatus();
+    if (hs == HandshakeStatus.NEED_TASK) {
+      try {
+        throw new TTransportException(
+            TTransportException.UNKNOWN, "handshake shouldn't need additional 
tasks");
+      } catch (TTransportException e) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private synchronized int doUnwrap() throws IOException {
+    int num = getSocketChannel().read(netUnwrap);
+    if (num < 0) {
+      LOGGER.error("Failed during read operation. Probably server is down");
+      return -1;
+    }
+    SSLEngineResult unwrapResult;
+
+    try {
+      netUnwrap.flip();
+      unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap);
+      netUnwrap.compact();
+    } catch (SSLException ex) {
+      LOGGER.error(ex.getMessage());
+      throw ex;
+    }
+
+    switch (unwrapResult.getStatus()) {
+      case OK:
+        if (appUnwrap.position() > 0) {
+          appUnwrap.flip();
+          appUnwrap.compact();
+        }
+        break;
+      case CLOSED:
+        return -1;
+      case BUFFER_OVERFLOW:
+        throw new IllegalStateException("Failed to unwrap");
+      case BUFFER_UNDERFLOW:
+        break;
+    }
+    return num;
+  }
+
+  private synchronized int doWrap() throws IOException {
+    int num = 0;
+    SSLEngineResult wrapResult;
+    try {
+      appWrap.flip();
+      wrapResult = sslEngine_.wrap(appWrap, netWrap);
+      appWrap.compact();
+    } catch (SSLException exc) {
+      LOGGER.error(exc.getMessage());
+      throw exc;
+    }
+
+    switch (wrapResult.getStatus()) {
+      case OK:
+        if (netWrap.position() > 0) {
+          netWrap.flip();
+          num = getSocketChannel().write(netWrap);
+          netWrap.compact();
+        }
+        break;
+      case BUFFER_UNDERFLOW:
+        // try again later
+        break;
+      case BUFFER_OVERFLOW:
+        throw new IllegalStateException("Failed to wrap");
+      case CLOSED:
+        LOGGER.error("SSL session is closed");
+        return -1;
+    }
+    return num;
+  }
+}
diff --git 
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingTransportWrapper.java
 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingTransportWrapper.java
index 40cf543fb0a..d985542b17b 100644
--- 
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingTransportWrapper.java
+++ 
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingTransportWrapper.java
@@ -68,9 +68,9 @@ public class TNonblockingTransportWrapper {
       String trustStorePath,
       String trustStorePwd) {
     try {
-      return new NettyTNonblockingTransport(
+      return new TNonblockingSSLSocket(
           host, port, timeout, keyStorePath, keyStorePwd, trustStorePath, 
trustStorePwd);
-    } catch (TTransportException e) {
+    } catch (Exception e) {
       // never happen
       return null;
     }

Reply via email to