This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 82469e9b12eea277385dffee65afb4b6abcdc3c4 Author: Siyang Tang <[email protected]> AuthorDate: Thu Feb 8 12:02:53 2024 +0800 [enhancement](mysql-channel) avoid potential buffer overflow when flushing send buffer occurs IOE (#30868) --- .../java/org/apache/doris/mysql/MysqlChannel.java | 22 +++--- .../org/apache/doris/mysql/MysqlChannelTest.java | 92 ++++++++++++++++++++++ 2 files changed, 104 insertions(+), 10 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java index 71eaf59863d..5dfa7947abe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java @@ -401,8 +401,11 @@ public class MysqlChannel { return; } sendBuffer.flip(); - realNetSend(sendBuffer); - sendBuffer.clear(); + try { + realNetSend(sendBuffer); + } finally { + sendBuffer.clear(); + } isSend = true; } @@ -423,18 +426,17 @@ public class MysqlChannel { sendBuffer.put((byte) sequenceId); } - private void writeBuffer(ByteBuffer buffer, boolean isSsl) throws IOException { + private void writeBuffer(ByteBuffer buffer) throws IOException { if (null == sendBuffer) { return; } - long leftLength = sendBuffer.capacity() - sendBuffer.position(); // If too long for buffer, send buffered data. - if (leftLength < buffer.remaining()) { + if (sendBuffer.remaining() < buffer.remaining()) { // Flush data in buffer. 
flush(); } // Send this buffer if large enough - if (buffer.remaining() > sendBuffer.capacity()) { + if (buffer.remaining() > sendBuffer.remaining()) { realNetSend(buffer); return; } @@ -451,20 +453,20 @@ public class MysqlChannel { bufLen = MAX_PHYSICAL_PACKET_LENGTH; packet.limit(packet.position() + bufLen); if (isSslHandshaking) { - writeBuffer(packet, true); + writeBuffer(packet); } else { writeHeader(bufLen, isSslMode); - writeBuffer(packet, isSslMode); + writeBuffer(packet); accSequenceId(); } } if (isSslHandshaking) { packet.limit(oldLimit); - writeBuffer(packet, true); + writeBuffer(packet); } else { writeHeader(oldLimit - packet.position(), isSslMode); packet.limit(oldLimit); - writeBuffer(packet, isSslMode); + writeBuffer(packet); accSequenceId(); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlChannelTest.java b/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlChannelTest.java new file mode 100644 index 00000000000..78d18b58bad --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlChannelTest.java @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.mysql; + + +import org.apache.doris.common.jmockit.Deencapsulation; +import org.apache.doris.qe.ConnectContext; + +import mockit.Delegate; +import mockit.Expectations; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.Test; +import org.xnio.StreamConnection; + +import java.io.IOException; +import java.nio.ByteBuffer; + +public class MysqlChannelTest { + + @Mocked + StreamConnection streamConnection; + + @Test + public void testSendAfterException() throws IOException { + // Mock. + new Expectations() { + { + streamConnection.getSinkChannel().write((ByteBuffer) any); + // The first call to `write()` throws IOException. + result = new IOException(); + // The second call to `write()` executes normally. + result = new Delegate() { + int fakeRead(ByteBuffer buffer) { + int writeLen = buffer.remaining(); + buffer.position(buffer.limit()); + return writeLen; + } + }; + + streamConnection.getSinkChannel().flush(); + result = true; + } + }; + + ConnectContext ctx = new ConnectContext(streamConnection); + MysqlChannel mysqlChannel = new MysqlChannel(streamConnection, ctx); + Deencapsulation.setField(mysqlChannel, "sendBuffer", ByteBuffer.allocate(5)); + // The first call to `realNetSend()` in `flush()` throws IOException. + // If `flush()` doesn't handle this exception, `sendBuffer` won't be reset to write mode, + // which will cause BufferOverflowException on the next call to `sendOnePacket()`.
+ ByteBuffer buf = ByteBuffer.allocate(12); + buf.putInt(1); + buf.putInt(2); + // limit=8 + buf.flip(); + try { + mysqlChannel.sendOnePacket(buf); + Assert.fail(); + } catch (IOException ignore) { + // do nothing + } + buf.clear(); + + buf.putInt(1); + // limit=4 + buf.flip(); + mysqlChannel.sendOnePacket(buf); + buf.clear(); + + buf.putInt(1); + buf.putInt(2); + // limit=8 + buf.flip(); + mysqlChannel.sendOnePacket(buf); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
