This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 2fc20ef2847 [feature](proxy-protocol) Support proxy protocol v1 (#32338) (#32368)
2fc20ef2847 is described below

commit 2fc20ef2847f07b95c33b6f0c4882daadfdf122e
Author: Mingyu Chen <[email protected]>
AuthorDate: Mon Mar 18 23:09:55 2024 +0800

    [feature](proxy-protocol) Support proxy protocol v1 (#32338) (#32368)
    
    bp #32338
---
 .../main/java/org/apache/doris/common/Config.java  |   7 +
 .../org/apache/doris/mysql/AcceptListener.java     |  14 ++
 .../java/org/apache/doris/mysql/BytesChannel.java  |  29 +++
 .../java/org/apache/doris/mysql/MysqlChannel.java  |  32 +++-
 .../java/org/apache/doris/mysql/MysqlProto.java    |   1 -
 .../apache/doris/mysql/ProxyProtocolHandler.java   | 212 +++++++++++++++++++++
 .../apache/doris/qe/ProxyProtocolHandlerTest.java  | 135 +++++++++++++
 7 files changed, 428 insertions(+), 2 deletions(-)

diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
index ec25f8d350d..d39b0576a15 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
@@ -2305,4 +2305,11 @@ public class Config extends ConfigBase {
                     + "and the deleted labels can be reused."
     })
     public static int label_num_threshold = 2000;
+
+    @ConfField(description = {
+            "是否开启 Proxy Protocol 支持",
+            "Whether to enable proxy protocol"
+    })
+    public static boolean enable_proxy_protocol = false;
+
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java
index 30d0693a6b5..2682426530e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/AcceptListener.java
@@ -18,11 +18,14 @@
 package org.apache.doris.mysql;
 
 import org.apache.doris.catalog.Env;
+import org.apache.doris.common.Config;
 import org.apache.doris.common.ErrorCode;
+import org.apache.doris.mysql.ProxyProtocolHandler.ProxyProtocolResult;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.ConnectProcessor;
 import org.apache.doris.qe.ConnectScheduler;
 
+import com.google.common.base.Preconditions;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.xnio.ChannelListener;
@@ -65,6 +68,17 @@ public class AcceptListener implements ChannelListener<AcceptingChannel<StreamCo
                         // Set thread local info
                         context.setThreadLocalInfo();
                         context.setConnectScheduler(connectScheduler);
+
+                        if (Config.enable_proxy_protocol) {
+                            ProxyProtocolResult result = ProxyProtocolHandler.handle(context.getMysqlChannel());
+                            Preconditions.checkNotNull(result);
+                            if (!result.isUnknown) {
+                                context.getMysqlChannel().setRemoteAddr(result.sourceIP, result.sourcePort);
+                            }
+                            // ignore the UNKNOWN, and just use IP from MySQL protocol.
+                            // which is already set when creating MysqlChannel.
+                        }
+
                         // authenticate check failed.
                         if (!MysqlProto.negotiate(context)) {
                             throw new AfterConnectedException("mysql negotiate failed");
diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/BytesChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/BytesChannel.java
new file mode 100644
index 00000000000..bf97ae8068d
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/BytesChannel.java
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.mysql;
+
+import java.nio.ByteBuffer;
+
+public interface BytesChannel {
+    /**
+     * Read N bytes from channel to buffer, N = dstBuf.remaining()
+     * @param buffer
+     * @return number of bytes read
+     */
+    public int read(ByteBuffer buffer);
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java
index 8e7c5f79ffd..4b10dc00656 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java
@@ -41,7 +41,7 @@ import javax.net.ssl.SSLException;
  * MySQL protocol will split one logical packet more than 16MB to many packets.
  * http://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
  */
-public class MysqlChannel {
+public class MysqlChannel implements BytesChannel {
     // logger for this class
     private static final Logger LOG = LogManager.getLogger(MysqlChannel.class);
     // max length which one MySQL physical can hold, if one logical packet is bigger than this,
@@ -101,6 +101,9 @@ public class MysqlChannel {
         this.remoteHostPortString = "";
         this.remoteIp = "";
         this.conn = connection;
+
+        // if proxy protocal is enabled, the remote address will be got from proxy protocal header
+        // and overwrite the original remote address.
         if (connection.getPeerAddress() instanceof InetSocketAddress) {
             InetSocketAddress address = (InetSocketAddress) connection.getPeerAddress();
             remoteHostPortString = NetUtils
@@ -216,6 +219,28 @@ public class MysqlChannel {
         return readLen;
     }
 
+    @Override
+    public int read(ByteBuffer dstBuf) {
+        int readLen = 0;
+        try {
+            while (dstBuf.remaining() != 0) {
+                int ret = Channels.readBlocking(conn.getSourceChannel(), dstBuf, context.getNetReadTimeout(),
+                        TimeUnit.SECONDS);
+                // return -1 when remote peer close the channel
+                if (ret == -1) {
+                    return 0;
+                }
+                readLen += ret;
+            }
+        } catch (IOException e) {
+            if (LOG.isDebugEnabled()) {
+                LOG.debug("Read channel exception, ignore.", e);
+            }
+            return 0;
+        }
+        return readLen;
+    }
+
     protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLException {
         // after decrypt, we get a mysql packet with mysql header.
         if (!isSslMode || isHeader) {
@@ -558,4 +583,9 @@ public class MysqlChannel {
         }
     }
 
+    // for proxy protocal only
+    public void setRemoteAddr(String ip, int port) {
+        this.remoteIp = ip;
+        this.remoteHostPortString = NetUtils.getHostPortInAccessibleFormat(ip, port);
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java
index ad2fb515d63..c463e8f4264 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java
@@ -443,5 +443,4 @@ public class MysqlProto {
         buffer.get();
         return buf;
     }
-
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/ProxyProtocolHandler.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/ProxyProtocolHandler.java
new file mode 100644
index 00000000000..0f52a05286e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/ProxyProtocolHandler.java
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.mysql;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+
+/**
+ * Proxy protocol handler.
+ * The proxy protocol is a simple protocol to pass client connection information to the server.
+ * It is used in some load balancers and proxies to pass the client's IP address and port to the server.
+ * The protocol is defined in https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
+ * The protocol has two versions: V1 and V2.
+ * V1 is a text-based protocol, and V2 is a binary protocol.
+ * This class only supports V1.
+ * The V1 protocol is a text-based protocol, and the header is "PROXY ".
+ * The protocol is defined as:
+ * PROXY TCP4[TCP6] <srcip> <dstip> <srcport> <dstport>\r\n
+ * or
+ * PROXY UNKNOWN xxxx\r\n
+ */
+public class ProxyProtocolHandler {
+    private static final Logger LOG = LogManager.getLogger(ProxyProtocolHandler.class);
+
+    private static final byte[] V1_HEADER = "PROXY ".getBytes(StandardCharsets.US_ASCII);
+    private static final byte[] V2_HEADER
+            = new byte[] {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
+
+    private static final String UNKNOWN = "UNKNOWN";
+    private static final String TCP4 = "TCP4";
+    private static final String TCP6 = "TCP6";
+
+    public static class ProxyProtocolResult {
+        public String sourceIP = null;
+        public int sourcePort = -1;
+        public String destIp = null;
+        public int destPort = -1;
+        public boolean isUnknown = false;
+
+        @Override
+        public String toString() {
+            return "ProxyProtocolResult{"
+                    + "sourceIP='" + sourceIP + '\''
+                    + ", sourcePort=" + sourcePort
+                    + ", destIp='" + destIp + '\''
+                    + ", destPort=" + destPort
+                    + ", isUnknown=" + isUnknown
+                    + '}';
+        }
+    }
+
+    public static ProxyProtocolResult handle(BytesChannel channel) throws IOException {
+        // First read 1 byte to see if it is V1 or V2
+        ByteBuffer buffer = ByteBuffer.allocate(1);
+        int readLen = channel.read(buffer);
+        if (readLen != 1) {
+            throw new IOException("Invalid proxy protocol, expect incoming bytes first");
+        }
+        buffer.flip();
+        byte firstByte = buffer.get();
+        if ((char) firstByte == V1_HEADER[0]) {
+            return handleV1(channel);
+        } else if (firstByte == V2_HEADER[0]) {
+            return handleV2(channel);
+        } else {
+            throw new IOException("Invalid proxy protocol header in first bytes: " + firstByte + ".");
+        }
+    }
+
+    private static ProxyProtocolResult handleV1(BytesChannel channel) throws IOException {
+        ProxyProtocolResult result = new ProxyProtocolResult();
+
+        int byteCount = 1; // already read the first byte, so start with 1
+        boolean parsingUnknown = false; // true if "UNKNOWN" is found
+        boolean carriageFound = false;  // true if \r is found
+        String protocol = null;
+        StringBuilder stringBuilder = new StringBuilder();
+
+        // read last 5 bytes of "PROXY "
+        ByteBuffer buffer = ByteBuffer.allocate(5);
+        int readLen = channel.read(buffer);
+        if (readLen != 5) {
+            throw new IOException("Invalid proxy protocol v1, expected \"PROXY \"");
+        }
+        byteCount += readLen;
+        StringBuilder debugInfo = new StringBuilder("PROXY ");
+        // start reading
+        buffer = ByteBuffer.allocate(1);
+        channel.read(buffer);
+        buffer.flip();
+        while (buffer.hasRemaining()) {
+            char c = (char) buffer.get();
+            debugInfo.append(c);
+            if (parsingUnknown) {
+                // Found "PROXY UNKNOWN"
+                // ignore any other bytes until "\r\n"
+                if (c == '\r') {
+                    carriageFound = true;
+                } else if (c == '\n') {
+                    if (!carriageFound) {
+                        throw new ProtocolException("Invalid proxy protocol v1. '\\r' is not found before '\\n'",
+                                debugInfo.toString());
+                    }
+                    result.isUnknown = true;
+                    return result;
+                } else if (carriageFound) {
+                    throw new ProtocolException("Invalid proxy protocol v1. "
+                            + "'\\r' should follow with '\\n', but see: " + c + ".", debugInfo.toString());
+                }
+            } else if (carriageFound) {
+                if (c == '\n') {
+                    // eof, set remote ip
+                    if (LOG.isDebugEnabled()) {
+                        LOG.debug("Finish parsing proxy protocol v1. result: {}", result);
+                    }
+                    return result;
+                } else {
+                    throw new ProtocolException("Invalid proxy protocol v1. "
+                            + "'\\r' should follow with '\\n', but see: " + c + ".", debugInfo.toString());
+                }
+            } else {
+                switch (c) {
+                    case ' ':
+                        if (result.sourcePort != -1 || stringBuilder.length() == 0) {
+                            throw new ProtocolException("Invalid proxy protocol v1. expecting a '\\r' or a '\\n'",
+                                    debugInfo.toString());
+                        } else if (protocol == null) {
+                            protocol = stringBuilder.toString();
+                            stringBuilder.setLength(0);
+                            if (protocol.equals(UNKNOWN)) {
+                                parsingUnknown = true;
+                            } else if (!protocol.equals(TCP4) && !protocol.equals(TCP6)) {
+                                throw new ProtocolException("Invalid proxy protocol v1. expecting TCP4/TCP6/UNKNOWN."
+                                        + " See: " + protocol + ".", debugInfo.toString());
+                            }
+                        } else if (result.sourceIP == null) {
+                            result.sourceIP = stringBuilder.toString();
+                            stringBuilder.setLength(0);
+                        } else if (result.destIp == null) {
+                            result.destIp = stringBuilder.toString();
+                            stringBuilder.setLength(0);
+                        } else {
+                            result.sourcePort = Integer.parseInt(stringBuilder.toString());
+                            stringBuilder.setLength(0);
+                        }
+                        break;
+                    case '\r':
+                        if (result.destPort == -1 && result.sourcePort != -1
+                                && !carriageFound && stringBuilder.length() > 0) {
+                            result.destPort = Integer.parseInt(stringBuilder.toString());
+                            stringBuilder.setLength(0);
+                            carriageFound = true;
+                        } else if (protocol == null) {
+                            if (UNKNOWN.equals(stringBuilder.toString())) {
+                                parsingUnknown = true;
+                                carriageFound = true;
+                            }
+                        } else {
+                            throw new ProtocolException(
+                                    "Invalid proxy protocol v1. Already see '\\r' but no valid info",
+                                    debugInfo.toString());
+                        }
+                        break;
+                    case '\n':
+                        throw new ProtocolException("Invalid proxy protocol v1. '\\r' is not found before '\\n'",
+                                debugInfo.toString());
+                    default:
+                        stringBuilder.append(c);
+                }
+            }
+            byteCount++;
+            if (byteCount == 107) {
+                throw new ProtocolException("Invalid proxy protocol v1, max length(107) exceeds",
+                        debugInfo.toString());
+            } else {
+                buffer.clear();
+                channel.read(buffer);
+                buffer.flip();
+            }
+        }
+        throw new ProtocolException("Invalid proxy protocol v1, unexpected end of stream", debugInfo.toString());
+    }
+
+    private static ProxyProtocolResult handleV2(BytesChannel channel) throws IOException {
+        throw new IOException("proxy protocol v2 is not supported yet");
+    }
+
+    public static class ProtocolException extends IOException {
+        public ProtocolException(String message, String protocolStr) {
+            super(message + ": " + protocolStr);
+        }
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/qe/ProxyProtocolHandlerTest.java b/fe/fe-core/src/test/java/org/apache/doris/qe/ProxyProtocolHandlerTest.java
new file mode 100644
index 00000000000..13ad0b67d92
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/qe/ProxyProtocolHandlerTest.java
@@ -0,0 +1,135 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.qe;
+
+import org.apache.doris.mysql.BytesChannel;
+import org.apache.doris.mysql.ProxyProtocolHandler;
+
+import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
+
+import java.io.IOException;
+
+public class ProxyProtocolHandlerTest {
+
+    public static class TestChannel implements BytesChannel {
+        private byte[] data;
+        private int pos;
+
+        public TestChannel(byte[] data) {
+            this.data = data;
+            this.pos = 0;
+        }
+
+        @Override
+        public int read(java.nio.ByteBuffer buffer) {
+            int len = Math.min(buffer.remaining(), data.length - pos);
+            if (len > 0) {
+                buffer.put(data, pos, len);
+                pos += len;
+            }
+            return len;
+        }
+    }
+
+    private TestChannel testChannel;
+
+    @Test
+    public void handleV1ProtocolWithValidData() throws IOException {
+        byte[] data = "PROXY TCP4 192.168.0.1 192.168.0.2 12345 54321\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.ProxyProtocolResult result = ProxyProtocolHandler.handle(testChannel);
+        Assertions.assertNotNull(result);
+        Assertions.assertFalse(result.isUnknown);
+        Assertions.assertEquals("192.168.0.1", result.sourceIP);
+        Assertions.assertEquals(12345, result.sourcePort);
+        Assertions.assertEquals("192.168.0.2", result.destIp);
+        Assertions.assertEquals(54321, result.destPort);
+    }
+
+    @Test
+    public void handleV1ProtocolWithUnknown() throws IOException {
+        byte[] data = "PROXY UNKNOWN xxxxxxxxxxxxxxxxxx\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.ProxyProtocolResult result = ProxyProtocolHandler.handle(testChannel);
+        Assertions.assertNotNull(result);
+        Assertions.assertTrue(result.isUnknown);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithInvalidProtocol() throws IOException {
+        byte[] data = "PROXY TCP7 xxx\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithInvalidData() throws IOException {
+        byte[] data = "INVALID DATA".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithIncompleteData() throws IOException {
+        byte[] data = "PROXY TCP4 192.168.0.1 192.168.0.2 12345".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithExtraData() throws IOException {
+        byte[] data = "PROXY TCP4 192.168.0.1 192.168.0.2 12345 54321 EXTRA DATA\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test
+    public void handleV1ProtocolWithValidIPv6Data() throws IOException {
+        byte[] data = "PROXY TCP6 2001:db8:0:1:1:1:1:1 2001:db8:0:1:1:1:1:2 12345 54321\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.ProxyProtocolResult result = ProxyProtocolHandler.handle(testChannel);
+        Assertions.assertNotNull(result);
+        Assertions.assertFalse(result.isUnknown);
+        Assertions.assertEquals("2001:db8:0:1:1:1:1:1", result.sourceIP);
+        Assertions.assertEquals(12345, result.sourcePort);
+        Assertions.assertEquals("2001:db8:0:1:1:1:1:2", result.destIp);
+        Assertions.assertEquals(54321, result.destPort);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithInvalidIPv6Data() throws IOException {
+        byte[] data = "PROXY TCP6 2001:db8:0:1:1:1:1:1 2001:db8:0:1:1:1:1:2 12345 EXTRA DATA\r\n".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV1ProtocolWithIncompleteIPv6Data() throws IOException {
+        byte[] data = "PROXY TCP6 2001:db8:0:1:1:1:1:1 2001:db8:0:1:1:1:1:2 12345".getBytes();
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+
+    @Test(expected = IOException.class)
+    public void handleV2Protocol() throws IOException {
+        byte[] data = new byte[] {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
+        testChannel = new TestChannel(data);
+        ProxyProtocolHandler.handle(testChannel);
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to