This is an automated email from the ASF dual-hosted git repository.

spacewander pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git


The following commit(s) were added to refs/heads/master by this push:
     new ae2e653  fix(mqtt): handle properties for MQTT 5 (#5916)
ae2e653 is described below

commit ae2e653f6b354f2fbb34650c2d3a3e5a69f5a5d9
Author: 罗泽轩 <[email protected]>
AuthorDate: Mon Dec 27 21:01:59 2021 +0800

    fix(mqtt): handle properties for MQTT 5 (#5916)
---
 apisix/stream/plugins/mqtt-proxy.lua | 83 +++++++++++++++++++++---------------
 t/stream-plugin/mqtt-proxy.t         | 73 ++++++++++++++++++++++++++++++-
 2 files changed, 121 insertions(+), 35 deletions(-)

diff --git a/apisix/stream/plugins/mqtt-proxy.lua b/apisix/stream/plugins/mqtt-proxy.lua
index fae0eb0..7c89050 100644
--- a/apisix/stream/plugins/mqtt-proxy.lua
+++ b/apisix/stream/plugins/mqtt-proxy.lua
@@ -67,27 +67,39 @@ function _M.check_schema(conf)
 end
 
 
-local function parse_mqtt(data)
-    local res = {}
-    res.packet_type_flags_byte = str_byte(data, 1, 1)
-    if res.packet_type_flags_byte < 16 or res.packet_type_flags_byte > 32 then
-        return nil, "Received unexpected MQTT packet type+flags: "
-                    .. res.packet_type_flags_byte
-    end
-
-    local parsed_pos = 1
-    res.remaining_len = 0
+local function decode_variable_byte_int(data, offset)
     local multiplier = 1
-    for i = 2, 5 do
-        parsed_pos = i
+    local len = 0
+    local pos
+    for i = offset, offset + 3 do
+        pos = i
         local byte = str_byte(data, i, i)
-        res.remaining_len = res.remaining_len + bit.band(byte, 127) * multiplier
+        len = len + bit.band(byte, 127) * multiplier
         multiplier = multiplier * 128
         if bit.band(byte, 128) == 0 then
             break
         end
     end
 
+    return len, pos
+end
+
+
+local function parse_msg_hdr(data)
+    local packet_type_flags_byte = str_byte(data, 1, 1)
+    if packet_type_flags_byte < 16 or packet_type_flags_byte > 32 then
+        return nil, nil,
+            "Received unexpected MQTT packet type+flags: " .. packet_type_flags_byte
+    end
+
+    local len, pos = decode_variable_byte_int(data, 2)
+    return len, pos
+end
+
+
+local function parse_mqtt(data, parsed_pos)
+    local res = {}
+
     local protocol_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256
                          + str_byte(data, parsed_pos + 2, parsed_pos + 2)
     parsed_pos = parsed_pos + 2
@@ -96,10 +108,15 @@ local function parse_mqtt(data)
 
     res.protocol_ver = str_byte(data, parsed_pos + 1, parsed_pos + 1)
     parsed_pos = parsed_pos + 1
-    if res.protocol_ver == 4 then
-        parsed_pos = parsed_pos + 3
-    elseif res.protocol_ver == 5 then
-        parsed_pos = parsed_pos + 9
+
+    -- skip control flags & keepalive
+    parsed_pos = parsed_pos + 3
+
+    if res.protocol_ver == 5 then
+        -- skip properties
+        local property_len
+        property_len, parsed_pos = decode_variable_byte_int(data, parsed_pos + 1)
+        parsed_pos = parsed_pos + property_len
     end
 
     local client_id_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256
@@ -129,31 +146,29 @@ function _M.preread(conf, ctx)
     local sock = ngx.req.socket()
     -- the header format of MQTT CONNECT can be found in
     -- https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901033
-    local data, err = sock:peek(14)
+    local data, err = sock:peek(5)
     if not data then
-        core.log.error("failed to read first 16 bytes: ", err)
+        core.log.error("failed to read the msg header: ", err)
         return 503
     end
 
-    local res, err = parse_mqtt(data)
-    if not res then
-        core.log.error("failed to parse the first 16 bytes: ", err)
+    local remain_len, pos, err = parse_msg_hdr(data)
+    if not remain_len then
+        core.log.error("failed to parse the msg header: ", err)
         return 503
     end
 
-    if res.expect_len > #data then
-        data, err = sock:peek(res.expect_len)
-        if not data then
-            core.log.error("failed to read ", res.expect_len, " bytes: ", err)
-            return 503
-        end
+    local data, err = sock:peek(pos + remain_len)
+    if not data then
+        core.log.error("failed to read the Connect Command: ", err)
+        return 503
+    end
 
-        res = parse_mqtt(data)
-        if res.expect_len > #data then
-            core.log.error("failed to parse mqtt request, expect len: ",
-                           res.expect_len, " but got ", #data)
-            return 503
-        end
+    local res = parse_mqtt(data, pos)
+    if res.expect_len > #data then
+        core.log.error("failed to parse mqtt request, expect len: ",
+                        res.expect_len, " but got ", #data)
+        return 503
     end
 
     if res.protocol and res.protocol ~= conf.protocol_name then
diff --git a/t/stream-plugin/mqtt-proxy.t b/t/stream-plugin/mqtt-proxy.t
index ae46fa8..3aa5cdf 100644
--- a/t/stream-plugin/mqtt-proxy.t
+++ b/t/stream-plugin/mqtt-proxy.t
@@ -328,7 +328,7 @@ mqtt client id: foo
 === TEST 13: hit route with empty client id
 --- stream_enable
 --- stream_request eval
-"\x10\x0f\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00"
+"\x10\x0c\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00"
 --- stream_response
 hello world
 --- grep_error_log eval
@@ -336,3 +336,74 @@ qr/mqtt client id: \w+/
 --- grep_error_log_out
 --- no_error_log
 [error]
+
+
+
+=== TEST 14: MQTT 5
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/stream_routes/1',
+                ngx.HTTP_PUT,
+                [[{
+                    "remote_addr": "127.0.0.1",
+                    "server_port": 1985,
+                    "plugins": {
+                        "mqtt-proxy": {
+                            "protocol_name": "MQTT",
+                            "protocol_level": 5
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": [{
+                            "host": "127.0.0.1",
+                            "port": 1995,
+                            "weight": 1
+                        }]
+                    }
+                }]]
+                )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- request
+GET /t
+--- response_body
+passed
+--- no_error_log
+[error]
+
+
+
+=== TEST 15: hit route with empty property
+--- stream_enable
+--- stream_request eval
+"\x10\x0d\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x00\x00\x00"
+--- stream_response
+hello world
+--- grep_error_log eval
+qr/mqtt client id: \w+/
+--- grep_error_log_out
+--- no_error_log
+[error]
+
+
+
+=== TEST 16: hit route with property
+--- stream_enable
+--- stream_request eval
+"\x10\x1b\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x05\x11\x00\x00\x0e\x10\x00\x09\x63\x6c\x69\x6e\x74\x2d\x31\x31\x31"
+--- stream_response
+hello world
+--- grep_error_log eval
+qr/mqtt client id: \S+/
+--- grep_error_log_out
+mqtt client id: clint-111
+--- no_error_log
+[error]

Reply via email to