Copilot commented on code in PR #16935: URL: https://github.com/apache/iotdb/pull/16935#discussion_r2639977012
########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { Review Comment: Empty catch block silently swallows the FileNotFoundException. This makes debugging difficult when a keystore or truststore file cannot be found as a file path. Consider logging the exception at debug level or adding a comment explaining why the exception is intentionally ignored (since the code will try alternative loading methods). ```suggestion } catch (FileNotFoundException e) { // File not found at the given path; will try URL and classpath loading next. 
LOGGER.debug("Store file not found at path '{}', trying alternative loading methods", store, e); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = 
sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** 
{@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); + } + + try { + num = doWrap(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (num < 0) { + LOGGER.error("Failed while writing. Probably server is down"); + return -1; + } + numBytes += num; + } + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public void close() { + sslEngine_.closeOutbound(); + super.close(); + } + + /** {@inheritDoc} */ + @Override + public boolean startConnect() throws IOException { + if (this.isOpen()) { + return true; + } + sslEngine_.beginHandshake(); + return super.startConnect() && doHandShake(); + } + + /** {@inheritDoc} */ + @Override + public boolean finishConnect() throws IOException { + return super.finishConnect() && doHandShake(); + } + + private synchronized boolean doHandShake() throws IOException { + LOGGER.debug("Handshake is started"); + while (true) { + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + switch (hs) { + case NEED_UNWRAP: + if (doUnwrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during unwrap"); + return false; + } + break; + case NEED_WRAP: + if (doWrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during wrap"); + return false; + } + break; + case NEED_TASK: + if (!doTask()) { + LOGGER.error("Unexpected. Handshake failed abruptly during task"); + return false; + } + break; + case FINISHED: + case NOT_HANDSHAKING: + isHandshakeCompleted = true; + return true; + default: + LOGGER.error("Unknown handshake status. 
Handshake failed"); + return false; + } + } + } + + private synchronized boolean doTask() { + Runnable runnable; + while ((runnable = sslEngine_.getDelegatedTask()) != null) { + runnable.run(); + } Review Comment: The method doTask() executes SSL engine delegated tasks directly on the calling thread using runnable.run(). For production systems, long-running or CPU-intensive tasks should be executed on a separate thread pool to avoid blocking the I/O thread. Consider using an executor service for running delegated tasks asynchronously. ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + 
int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); + } + + try { + num = doWrap(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (num < 0) { + LOGGER.error("Failed while writing. Probably server is down"); + return -1; + } + numBytes += num; + } + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public void close() { + sslEngine_.closeOutbound(); + super.close(); + } + + /** {@inheritDoc} */ + @Override + public boolean startConnect() throws IOException { + if (this.isOpen()) { + return true; + } + sslEngine_.beginHandshake(); + return super.startConnect() && doHandShake(); + } + + /** {@inheritDoc} */ + @Override + public boolean finishConnect() throws IOException { + return super.finishConnect() && doHandShake(); + } + + private synchronized boolean doHandShake() throws IOException { + LOGGER.debug("Handshake is started"); + while (true) { + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + switch (hs) { + case NEED_UNWRAP: + if (doUnwrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during unwrap"); + return false; + } + break; + case NEED_WRAP: + if (doWrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during wrap"); + return false; + } + break; + case NEED_TASK: + if (!doTask()) { + LOGGER.error("Unexpected. 
Handshake failed abruptly during task"); + return false; + } + break; + case FINISHED: + case NOT_HANDSHAKING: + isHandshakeCompleted = true; + return true; + default: + LOGGER.error("Unknown handshake status. Handshake failed"); + return false; + } + } + } + + private synchronized boolean doTask() { + Runnable runnable; + while ((runnable = sslEngine_.getDelegatedTask()) != null) { + runnable.run(); + } + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + return hs != HandshakeStatus.NEED_TASK; + } + + private synchronized int doUnwrap() throws IOException { + int num = getSocketChannel().read(netUnwrap); + if (num < 0) { + LOGGER.error("Failed during read operation. Probably server is down"); + return -1; + } + SSLEngineResult unwrapResult; + + try { + netUnwrap.flip(); + unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap); + netUnwrap.compact(); + } catch (SSLException ex) { + LOGGER.error(ex.getMessage()); + throw ex; + } + + switch (unwrapResult.getStatus()) { + case OK: + if (appUnwrap.position() > 0) { + appUnwrap.flip(); + appUnwrap.compact(); + } + break; + case CLOSED: + return -1; + case BUFFER_OVERFLOW: + throw new IllegalStateException("Failed to unwrap"); + case BUFFER_UNDERFLOW: + break; + } + return num; + } + + private synchronized int doWrap() throws IOException { + int num = 0; + SSLEngineResult wrapResult; + try { + appWrap.flip(); + wrapResult = sslEngine_.wrap(appWrap, netWrap); + appWrap.clear(); + } catch (SSLException exc) { + LOGGER.error(exc.getMessage()); + throw exc; + } + + switch (wrapResult.getStatus()) { + case OK: + if (netWrap.position() > 0) { + netWrap.flip(); + num = getSocketChannel().write(netWrap); + netWrap.clear(); + } + break; + case BUFFER_UNDERFLOW: + // try again later + break; + case BUFFER_OVERFLOW: + throw new IllegalStateException("Failed to wrap"); Review Comment: The error message "Failed to wrap" is vague and doesn't provide enough context for debugging. 
BUFFER_OVERFLOW in the wrap operation typically indicates that the network buffer is too small for the encrypted data. Consider providing a more descriptive error message that explains the actual problem, such as "SSL wrap failed: network buffer overflow - buffer size may be insufficient for encrypted data". ```suggestion throw new IllegalStateException( "SSL wrap failed: network buffer overflow - buffer size may be insufficient for encrypted data"); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); Review Comment: The open() method throws a RuntimeException rather than the TTransportException declared in its signature. This compiles, because unchecked exceptions need not be declared, but callers that catch TTransportException will not intercept it. Consider throwing TTransportException with the NOT_OPEN error code instead, or document why RuntimeException is used here. ```suggestion throw new TTransportException( TTransportException.NOT_OPEN, "open() is not implemented for TNonblockingSSLSocket"); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); Review Comment: Variable `tmf` may be null at this access: it is initialized to null and only assigned when `truststore != null`, while this final `else` branch is reached whenever `keystore == null`. If both `keystore` and `truststore` are null, `tmf.getTrustManagers()` dereferences null. ```suggestion if (kmf != null && tmf != null) { ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); } else if (kmf != null) { ctx.init(kmf.getKeyManagers(), null, null); } else if (tmf != null) { ctx.init(null, tmf.getTrustManagers(), null); } else { ctx.init(null, null, null); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + 
int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); Review Comment: The call to appWrap.put(buffer) on line 249 is intended to transfer only nTransfer bytes from buffer to appWrap. However, ByteBuffer.put(ByteBuffer) transfers all remaining bytes from the source buffer, not a specific number. This will cause a BufferOverflowException when buffer has more bytes remaining than appWrap can hold. Use a bounded transfer by slicing the buffer or using put with offset and length. ```suggestion if (nTransfer > 0) { int originalLimit = buffer.limit(); if (nTransfer < buffer.remaining()) { buffer.limit(buffer.position() + nTransfer); } appWrap.put(buffer); buffer.limit(originalLimit); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); Review Comment: Potential resource leak in getStoreAsStream. If the URL.openStream() on line 149 succeeds and returns a non-null stream but an exception occurs before the stream is returned (though unlikely in this specific code path), the stream could leak. While the current code structure makes this unlikely, consider using try-with-resources or ensuring cleanup in a finally block for better resource management. 
```suggestion try { return new URL(store).openStream(); } catch (MalformedURLException e) { // Ignore and fall through to classloader-based lookup } InputStream storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + 
int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); + } + + try { + num = doWrap(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (num < 0) { + LOGGER.error("Failed while writing. Probably server is down"); + return -1; + } + numBytes += num; + } + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public void close() { + sslEngine_.closeOutbound(); + super.close(); + } + + /** {@inheritDoc} */ + @Override + public boolean startConnect() throws IOException { + if (this.isOpen()) { + return true; + } + sslEngine_.beginHandshake(); + return super.startConnect() && doHandShake(); + } + + /** {@inheritDoc} */ + @Override + public boolean finishConnect() throws IOException { + return super.finishConnect() && doHandShake(); + } + + private synchronized boolean doHandShake() throws IOException { + LOGGER.debug("Handshake is started"); + while (true) { + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + switch (hs) { + case NEED_UNWRAP: + if (doUnwrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during unwrap"); + return false; + } + break; + case NEED_WRAP: + if (doWrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during wrap"); + return false; + } + break; + case NEED_TASK: + if (!doTask()) { + LOGGER.error("Unexpected. 
Handshake failed abruptly during task"); + return false; + } + break; + case FINISHED: + case NOT_HANDSHAKING: + isHandshakeCompleted = true; + return true; + default: + LOGGER.error("Unknown handshake status. Handshake failed"); + return false; + } + } + } + + private synchronized boolean doTask() { + Runnable runnable; + while ((runnable = sslEngine_.getDelegatedTask()) != null) { + runnable.run(); + } + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + return hs != HandshakeStatus.NEED_TASK; + } + + private synchronized int doUnwrap() throws IOException { + int num = getSocketChannel().read(netUnwrap); + if (num < 0) { + LOGGER.error("Failed during read operation. Probably server is down"); + return -1; + } + SSLEngineResult unwrapResult; + + try { + netUnwrap.flip(); + unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap); + netUnwrap.compact(); + } catch (SSLException ex) { + LOGGER.error(ex.getMessage()); + throw ex; + } + + switch (unwrapResult.getStatus()) { + case OK: + if (appUnwrap.position() > 0) { + appUnwrap.flip(); + appUnwrap.compact(); Review Comment: The appUnwrap buffer manipulation logic has a potential issue. On lines 351-353, when status is OK and appUnwrap.position() > 0, the code calls flip() followed immediately by compact(). The flip() sets limit to position and position to 0, then compact() copies remaining data to the beginning. This sequence doesn't make sense - flip() is typically used to prepare for reading, but compact() is used after reading to prepare for writing. This likely leaves data in the wrong state. ```suggestion ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = 
sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** 
{@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); + } + + try { + num = doWrap(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (num < 0) { + LOGGER.error("Failed while writing. Probably server is down"); + return -1; + } + numBytes += num; + } + return numBytes; Review Comment: The write method increments numBytes with the return value from doWrap(), which is the number of bytes written to the underlying socket channel. However, the method should return the number of bytes consumed from the input buffer parameter, not the number of encrypted bytes written to the network. This mismatch could cause the caller to receive incorrect information about how much data was processed. ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { Review Comment: The condition on line 215 modifies decodedBytes by calling flip() as a side effect within the condition check. 
This is problematic because if appUnwrap.hasRemaining() is true, the flip() on decodedBytes is never executed due to short-circuit evaluation, leading to inconsistent state. Additionally, evaluating flip() inside a boolean expression is confusing and error-prone. Extract the flip() calls outside the condition. ```suggestion // Determine whether decodedBytes would have remaining data after a flip, // without mutating the original buffer during condition evaluation. ByteBuffer decodedBytesView = decodedBytes.asReadOnlyBuffer(); decodedBytesView.flip(); boolean decodedHasRemaining = decodedBytesView.hasRemaining(); if (appUnwrap.hasRemaining() || (decodedBytes.position() > 0 && decodedHasRemaining)) { ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. */ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory 
tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, 
TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + 
int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); + + int nTransfer; + int num; + while (buffer.hasRemaining()) { + nTransfer = Math.min(appWrap.remaining(), buffer.remaining()); + if (nTransfer > 0) { + appWrap.put(buffer); + } + + try { + num = doWrap(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (num < 0) { + LOGGER.error("Failed while writing. Probably server is down"); + return -1; + } + numBytes += num; + } + return numBytes; + } + + /** {@inheritDoc} */ + @Override + public void close() { + sslEngine_.closeOutbound(); + super.close(); + } + + /** {@inheritDoc} */ + @Override + public boolean startConnect() throws IOException { + if (this.isOpen()) { + return true; + } + sslEngine_.beginHandshake(); + return super.startConnect() && doHandShake(); + } + + /** {@inheritDoc} */ + @Override + public boolean finishConnect() throws IOException { + return super.finishConnect() && doHandShake(); + } + + private synchronized boolean doHandShake() throws IOException { + LOGGER.debug("Handshake is started"); + while (true) { + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + switch (hs) { + case NEED_UNWRAP: + if (doUnwrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during unwrap"); + return false; + } + break; + case NEED_WRAP: + if (doWrap() == -1) { + LOGGER.error("Unexpected. Handshake failed abruptly during wrap"); + return false; + } + break; + case NEED_TASK: + if (!doTask()) { + LOGGER.error("Unexpected. 
Handshake failed abruptly during task"); + return false; + } + break; + case FINISHED: + case NOT_HANDSHAKING: + isHandshakeCompleted = true; + return true; + default: + LOGGER.error("Unknown handshake status. Handshake failed"); + return false; + } + } + } + + private synchronized boolean doTask() { + Runnable runnable; + while ((runnable = sslEngine_.getDelegatedTask()) != null) { + runnable.run(); + } + HandshakeStatus hs = sslEngine_.getHandshakeStatus(); + return hs != HandshakeStatus.NEED_TASK; + } + + private synchronized int doUnwrap() throws IOException { + int num = getSocketChannel().read(netUnwrap); + if (num < 0) { + LOGGER.error("Failed during read operation. Probably server is down"); + return -1; + } + SSLEngineResult unwrapResult; + + try { + netUnwrap.flip(); + unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap); + netUnwrap.compact(); + } catch (SSLException ex) { + LOGGER.error(ex.getMessage()); + throw ex; + } + + switch (unwrapResult.getStatus()) { + case OK: + if (appUnwrap.position() > 0) { + appUnwrap.flip(); + appUnwrap.compact(); + } + break; + case CLOSED: + return -1; + case BUFFER_OVERFLOW: + throw new IllegalStateException("Failed to unwrap"); Review Comment: The error message "Failed to unwrap" is vague and doesn't provide enough context for debugging. BUFFER_OVERFLOW typically indicates that the application buffer is too small for the decrypted data. Consider providing a more descriptive error message that explains the actual problem and potential resolution, such as "SSL unwrap failed: application buffer overflow - buffer size may be insufficient for decrypted data". 
```suggestion throw new IllegalStateException( "SSL unwrap failed: application buffer overflow - buffer size may be insufficient for decrypted data"); ``` ########## iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TNonblockingSSLSocket.java: ########## @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.rpc; + +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.KeyStore; + +/** Transport for use with async client. 
*/ +public class TNonblockingSSLSocket extends TNonblockingSocket { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName()); + + private final SSLEngine sslEngine_; + + private final ByteBuffer appUnwrap; + private final ByteBuffer netUnwrap; + + private final ByteBuffer appWrap; + private final ByteBuffer netWrap; + + private ByteBuffer decodedBytes; + + private boolean isHandshakeCompleted; + + private SelectionKey selectionKey; + + public TNonblockingSSLSocket( + String host, + int port, + int timeout, + String keystore, + String keystorePassword, + String truststore, + String truststorePassword) + throws TTransportException, IOException { + this( + host, + port, + timeout, + createSSLContext(keystore, keystorePassword, truststore, truststorePassword)); + } + + private static SSLContext createSSLContext( + String keystore, String keystorePassword, String truststore, String truststorePassword) + throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (truststore != null) { + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance("JKS"); + in = getStoreAsStream(truststore); + ts.load(in, (truststorePassword != null ? 
truststorePassword.toCharArray() : null)); + tmf.init(ts); + } + + if (keystore != null) { + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance("JKS"); + is = getStoreAsStream(keystore); + ks.load(is, keystorePassword.toCharArray()); + kmf.init(ks, keystorePassword.toCharArray()); + } + + if (keystore != null && truststore != null) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } else if (keystore != null) { + ctx.init(kmf.getKeyManagers(), null, null); + } else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException( + TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch (FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch (MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext) + throws IOException, TTransportException { + super(host, port, timeout); + sslEngine_ = sslContext.createSSLEngine(host, port); + sslEngine_.setUseClientMode(true); + + int appBufferSize = sslEngine_.getSession().getApplicationBufferSize(); + int netBufferSize = 
sslEngine_.getSession().getPacketBufferSize(); + appUnwrap = ByteBuffer.allocate(appBufferSize); + netUnwrap = ByteBuffer.allocate(netBufferSize); + appWrap = ByteBuffer.allocate(appBufferSize); + netWrap = ByteBuffer.allocate(netBufferSize); + decodedBytes = ByteBuffer.allocate(appBufferSize); + decodedBytes.flip(); + isHandshakeCompleted = false; + } + + /** {@inheritDoc} */ + @Override + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + selectionKey = super.registerSelector(selector, interests); + return selectionKey; + } + + /** {@inheritDoc} */ + @Override + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return super.isOpen() && isHandshakeCompleted; + } + + /** {@inheritDoc} */ + @Override + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket"); + } + + /** {@inheritDoc} */ + @Override + public synchronized int read(ByteBuffer buffer) throws TTransportException { + int numBytes = buffer.remaining(); + while (decodedBytes.remaining() < numBytes) { + try { + if (doUnwrap() == -1) { + throw new IOException("Unable to read " + numBytes + " bytes"); + } + } catch (IOException exc) { + throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage()); + } + if (appUnwrap.hasRemaining() + || (decodedBytes.position() > 0 && decodedBytes.flip().hasRemaining())) { + appUnwrap.flip(); + decodedBytes.flip(); + + ByteBuffer tempBuffer = + ByteBuffer.allocate(appUnwrap.remaining() + decodedBytes.remaining()); + tempBuffer.put(decodedBytes); + tempBuffer.put(appUnwrap); + tempBuffer.flip(); + + decodedBytes = tempBuffer; + appUnwrap.clear(); + } + } + int oldLimit = decodedBytes.limit(); + decodedBytes.limit(decodedBytes.position() + numBytes); + buffer.put(decodedBytes); + decodedBytes.limit(oldLimit); + selectionKey.interestOps(SelectionKey.OP_WRITE); + return numBytes; + } + + /** 
{@inheritDoc} */ + @Override + public synchronized int write(ByteBuffer buffer) throws TTransportException { + int numBytes = 0; + + if (buffer.position() > 0) buffer.flip(); Review Comment: The buffer position manipulation appears incorrect. When buffer.position() > 0, the code calls buffer.flip(), which sets the limit to the current position and resets the position to 0. This assumes the buffer is still in write mode (just filled, not yet flipped). If the buffer was already prepared for reading and has been partially consumed (position > 0, limit at the end of the data), this flip would rewind to bytes that were already consumed and cut off the unread remainder. Consider checking the buffer state or documenting the expected buffer state on entry. ```suggestion // Only flip if the buffer appears to be in write-mode (just filled, not yet flipped) if (buffer.position() > 0 && buffer.limit() == buffer.capacity()) { buffer.flip(); } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscribe@iotdb.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
