This is an automated email from the ASF dual-hosted git repository.

cshannon pushed a commit to branch activemq-5.19.x
in repository https://gitbox.apache.org/repos/asf/activemq.git


The following commit(s) were added to refs/heads/activemq-5.19.x by this push:
     new 63ca7332a4 Update Stomp transports with improved validation (#2064) 
(#2066)
63ca7332a4 is described below

commit 63ca7332a47239964cbf2e343603412254ebb358
Author: Christopher L. Shannon <[email protected]>
AuthorDate: Tue Jun 2 08:29:10 2026 -0400

    Update Stomp transports with improved validation (#2064) (#2066)
    
    This update makes the following changes to improve validation for the
    Stomp transport:
    
    * Verifies that the first frame seen by the server is a CONNECT
      (or STOMP) frame.
    * Verifies that a duplicate CONNECT (or STOMP) frame is not received.
    * Adds validation to make sure a content-length header that is set is
      not negative.
    * Adds a new server mode (default true) to the Stomp wireformat to
      handle the validation differences between clients and servers. Client
      mode is only used for testing (currently). Also adds the option to
      configure using the StompWireFormatFactory in case there is a future use
      case.
    * Centralizes the state tracking for frame size validation and for the
      new validation checks inside StompWireFormat so that it is shared by
      NIO, non-NIO and WS transports.
    * Adds tests covering the new validation for the NIO, non-NIO and WS
      transports.
    
    If any of these new validation checks throws a protocol error, the
    error is marked as fatal, an ERROR frame is sent to the client, and the
    connection is closed. Both the NIO and non-NIO transports stop parsing
    the rest of the frame on error, but only the NIO transport also stops
    reading the frame from the socket buffer, because the non-NIO transport
    must read the entire frame into a buffer before it can validate it.
    
    (cherry picked from commit 1493db95b5918d4f4a305fd1df8155f57c38850b)
---
 .../transport/ws/StompWSTransportTest.java         | 102 ++++++++++++++++
 .../activemq/transport/stomp/StompCodec.java       |  44 ++++---
 .../activemq/transport/stomp/StompConnection.java  |   1 +
 .../activemq/transport/stomp/StompWireFormat.java  | 132 +++++++++++++++++----
 .../transport/stomp/StompWireFormatFactory.java    |  13 +-
 .../apache/activemq/transport/stomp/StompTest.java | 111 ++++++++++++++++-
 .../stomp/StompWireFormatFactoryTest.java          |  49 ++++++++
 7 files changed, 402 insertions(+), 50 deletions(-)

diff --git 
a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java
 
b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java
index d9ea8ec417..9a20f2f494 100644
--- 
a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java
+++ 
b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java
@@ -28,6 +28,7 @@ import java.util.concurrent.TimeUnit;
 import org.apache.activemq.transport.stomp.Stomp;
 import org.apache.activemq.transport.stomp.StompFrame;
 import org.apache.activemq.util.Wait;
+import org.apache.activemq.util.Wait.Condition;
 import org.eclipse.jetty.client.HttpClient;
 import org.eclipse.jetty.util.ssl.SslContextFactory;
 import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
@@ -115,6 +116,107 @@ public class StompWSTransportTest extends 
WSTransportTestSupport {
         }));
     }
 
+    @Test(timeout = 60000)
+    public void testMissingStompConnect() throws Exception {
+        // Send a frame without first sending a CONNECT frame, which is a 
protocol violation
+        String message = "SEND\n" + "destination:/queue/" + getTestName() + 
"\n\n" + "Hello World" + Stomp.NULL;
+        wsStompConnection.sendRawFrame(message);
+
+        String incoming = wsStompConnection.receive(5, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("ERROR"));
+        assertTrue(incoming.contains("StompWireFormat is configured for 
'server' mode and received an"
+                + " unexpected frame before CONNECT or STOMP frame: SEND"));
+
+        assertTrue("Connection should close", Wait.waitFor(
+                (Condition) () -> wsStompConnection.isNotConnected()));
+    }
+
+    @Test(timeout = 60000)
+    public void testNegativeContentLength() throws Exception {
+        String connectFrame = "STOMP\n" +
+                "login:system\n" +
+                "passcode:manager\n" +
+                "accept-version:1.2\n" +
+                "host:localhost\n" +
+                "\n" + Stomp.NULL;
+
+        wsStompConnection.sendRawFrame(connectFrame);
+
+        String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("CONNECTED"));
+
+        String message = "SEND\n" + "destination:/queue/" + getTestName() + 
"\ncontent-length:-1" + " \n\n" + "body" + Stomp.NULL;
+        wsStompConnection.sendRawFrame(message);
+
+        // Negative content length is a protocol error and should return
+        // an error and close the connection
+        incoming = wsStompConnection.receive(5, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("ERROR"));
+        assertTrue(incoming.contains("Specified content-length may not be 
negative"));
+
+        assertTrue("Connection should close", Wait.waitFor(
+                (Condition) () -> wsStompConnection.isNotConnected()));
+    }
+
+    @Test(timeout = 60000)
+    public void testDuplicateConnect() throws Exception {
+        String connectFrame = "STOMP\n" +
+                "login:system\n" +
+                "passcode:manager\n" +
+                "accept-version:1.2\n" +
+                "host:localhost\n" +
+                "\n" + Stomp.NULL;
+
+        wsStompConnection.sendRawFrame(connectFrame);
+
+        String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("CONNECTED"));
+
+        // Sending a second CONNECT frame is not allowed and should error
+        wsStompConnection.sendRawFrame(connectFrame);
+
+        incoming = wsStompConnection.receive(5, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("ERROR"));
+        assertTrue(incoming.contains("duplicate CONNECT or STOMP frame"));
+
+        assertTrue("Connection should close", Wait.waitFor(
+                (Condition) () -> wsStompConnection.isNotConnected()));
+    }
+
+    @Test(timeout = 60000)
+    public void testInvalidServerResponseReceived() throws Exception {
+        String connectFrame = "STOMP\n" +
+                "login:system\n" +
+                "passcode:manager\n" +
+                "accept-version:1.2\n" +
+                "host:localhost\n" +
+                "\n" + Stomp.NULL;
+
+        wsStompConnection.sendRawFrame(connectFrame);
+
+        String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("CONNECTED"));
+
+        // Sending a server response to the server, which is invalid
+        String invalidFrame = "RECEIPT\n" + "receipt-id:message-12345\n\n" + 
Stomp.NULL;
+        wsStompConnection.sendRawFrame(invalidFrame);
+        incoming = wsStompConnection.receive(5, TimeUnit.SECONDS);
+        assertNotNull(incoming);
+        assertTrue(incoming.startsWith("ERROR"));
+        assertTrue(incoming.contains("StompWireFormat is configured for 
'server' mode and received a"
+                + " frame that is only expected when configured for 'client' 
mode: RECEIPT"));
+
+        // make sure the connection was closed by the server
+        assertTrue("Connection should close", Wait.waitFor(
+                (Condition) () -> wsStompConnection.isNotConnected()));
+    }
+
     @Test(timeout = 60000)
     public void testConnectWithVersionOptions() throws Exception {
         String connectFrame = "STOMP\n" +
diff --git 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java
 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java
index 54ec21a82c..dd0f48f5aa 100644
--- 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java
+++ 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompCodec.java
@@ -19,10 +19,9 @@ package org.apache.activemq.transport.stomp;
 import java.io.ByteArrayInputStream;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicLong;
+import java.util.Objects;
 
 import org.apache.activemq.transport.tcp.TcpTransport;
 import org.apache.activemq.util.ByteArrayOutputStream;
@@ -30,24 +29,22 @@ import org.apache.activemq.util.DataByteArrayInputStream;
 
 public class StompCodec {
 
-    final static byte[] crlfcrlf = new byte[]{'\r','\n','\r','\n'};
-    TcpTransport transport;
-    StompWireFormat wireFormat;
-
-    AtomicLong frameSize = new AtomicLong(); 
-    ByteArrayOutputStream currentCommand = new ByteArrayOutputStream();
-    boolean processedHeaders = false;
-    String action;
-    HashMap<String, String> headers;
-    int contentLength = -1;
-    int readLength = 0;
-    int previousByte = -1;
-    boolean awaitingCommandStart = true;
-    String version = Stomp.DEFAULT_VERSION;
+    private final static byte[] crlfcrlf = new byte[]{'\r','\n','\r','\n'};
+    private final TcpTransport transport;
+    private final StompWireFormat wireFormat;
+
+    private final ByteArrayOutputStream currentCommand = new 
ByteArrayOutputStream();
+    private boolean processedHeaders = false;
+    private String action;
+    private Map<String, String> headers;
+    private int contentLength = -1;
+    private int readLength = 0;
+    private int previousByte = -1;
+    private boolean awaitingCommandStart = true;
 
     public StompCodec(TcpTransport transport) {
-        this.transport = transport;
-        this.wireFormat = (StompWireFormat) transport.getWireFormat();
+        this.transport = Objects.requireNonNull(transport);
+        this.wireFormat = (StompWireFormat) 
Objects.requireNonNull(transport.getWireFormat());
     }
 
     public void parse(ByteArrayInputStream input, int readSize) throws 
Exception {
@@ -75,12 +72,13 @@ public class StompCodec {
                    DataByteArrayInputStream data = new 
DataByteArrayInputStream(currentCommand.toByteArray());
 
                    try {
-                       action = wireFormat.parseAction(data, frameSize);
-                       headers = wireFormat.parseHeaders(data, frameSize);
+                       action = wireFormat.parseAction(data);
+                       wireFormat.validateAction(action);
+                       headers = wireFormat.parseHeaders(data);
                        
                        String contentLengthHeader = 
headers.get(Stomp.Headers.CONTENT_LENGTH);
                        if ((action.equals(Stomp.Commands.SEND) || 
action.equals(Stomp.Responses.MESSAGE)) && contentLengthHeader != null) {
-                           contentLength = 
wireFormat.parseContentLength(contentLengthHeader, frameSize);
+                           contentLength = 
wireFormat.parseContentLength(contentLengthHeader);
                        } else {
                            contentLength = -1;
                        }
@@ -106,7 +104,7 @@ public class StompCodec {
                            transport.doConsume(errorFrame);
                            return;
                        }
-                       if (frameSize.incrementAndGet() > 
wireFormat.getMaxFrameSize()) {
+                       if (wireFormat.incrementAndGetFrameSize() > 
wireFormat.getMaxFrameSize()) {
                            StompFrameError errorFrame = new 
StompFrameError(new ProtocolException("The maximum frame size was exceeded", 
true));
                            errorFrame.setAction(this.action);
                            transport.doConsume(errorFrame);
@@ -135,7 +133,7 @@ public class StompCodec {
         awaitingCommandStart = true;
         currentCommand.reset();
         contentLength = -1;
-        frameSize.set(0);
+        wireFormat.resetFrame();
     }
 
     public static String detectVersion(Map<String, String> headers) throws 
ProtocolException {
diff --git 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompConnection.java
 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompConnection.java
index 7a4ba5e701..fbf46e38de 100644
--- 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompConnection.java
+++ 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompConnection.java
@@ -73,6 +73,7 @@ public class StompConnection {
         InputStream is = stompSocket.getInputStream();
         StompWireFormat wf = new StompWireFormat();
         wf.setStompVersion(version);
+        wf.setServerMode(false);
         DataInputStream dis = new DataInputStream(is);
         return (StompFrame)wf.unmarshal(dis);
     }
diff --git 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java
 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java
index cfc4c2d8ab..621023e92b 100644
--- 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java
+++ 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormat.java
@@ -23,10 +23,12 @@ import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.PushbackInputStream;
+import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicLong;
 
+import org.apache.activemq.transport.stomp.Stomp.Commands;
+import org.apache.activemq.transport.stomp.Stomp.Responses;
 import org.apache.activemq.util.ByteArrayInputStream;
 import org.apache.activemq.util.ByteArrayOutputStream;
 import org.apache.activemq.util.ByteSequence;
@@ -48,15 +50,24 @@ public class StompWireFormat implements WireFormat {
     public static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
     public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE;
     public static final long DEFAULT_CONNECTION_TIMEOUT = 30000;
+    public static final boolean DEFAULT_SERVER_MODE = true;
 
     private int version = 1;
     private int maxDataLength = MAX_DATA_LENGTH;
     private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
     private String stompVersion = Stomp.DEFAULT_VERSION;
     private long connectionAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT;
+    // Track if this wireformat is used on the server or a client
+    // This will generally be set to true, client mode is normally
+    // used for testing
+    private boolean serverMode = DEFAULT_SERVER_MODE;
 
     //The current frame size as it is unmarshalled from the stream
-    private final AtomicLong frameSize = new AtomicLong();
+    private long frameSize = 0;
+    // A new StompWireFormat is instantiated for each connection
+    // This tracks if the server has received the CONNECT (or FRAME) frame 
first.
+    private boolean processedConnect = false;
+    private boolean fatalError = false;
 
     @Override
     public ByteSequence marshal(Object command) throws IOException {
@@ -103,7 +114,7 @@ public class StompWireFormat implements WireFormat {
 
         StringBuilder builder = new StringBuilder();
 
-        os.write(marshalHeaders(stomp, builder).toString().getBytes("UTF-8"));
+        os.write(marshalHeaders(stomp, 
builder).toString().getBytes(StandardCharsets.UTF_8));
         os.write(stomp.getContent());
         os.write(END_OF_FRAME);
     }
@@ -117,7 +128,7 @@ public class StompWireFormat implements WireFormat {
         marshalHeaders(stomp, buffer);
 
         if (stomp.getContent() != null) {
-            String contentString = new String(stomp.getContent(), "UTF-8");
+            String contentString = new String(stomp.getContent(), 
StandardCharsets.UTF_8);
             buffer.append(contentString);
         }
 
@@ -128,11 +139,19 @@ public class StompWireFormat implements WireFormat {
     @Override
     public Object unmarshal(DataInput in) throws IOException {
         try {
+            // Prevent processing any future data, this is necessary because 
there is a delay
+            // in the connection close so it's possible this gets called again.
+            if (fatalError) {
+                throw new IOException("Can't process anymore data, fatal 
ProtocolError previously received.");
+            }
+
             // parse action
-            String action = parseAction(in, frameSize);
+            String action = parseAction(in);
+            // Ensure the action is valid
+            validateAction(action);
 
             // Parse the headers
-            HashMap<String, String> headers = parseHeaders(in, frameSize);
+            HashMap<String, String> headers = parseHeaders(in);
 
             // Read in the data part.
             byte[] data = NO_DATA;
@@ -140,7 +159,7 @@ public class StompWireFormat implements WireFormat {
             if ((action.equals(Stomp.Commands.SEND) || 
action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) {
 
                 // Bless the client, he's telling us how much data to read in.
-                int length = parseContentLength(contentLength, frameSize);
+                int length = parseContentLength(contentLength);
 
                 data = new byte[length];
                 in.readFully(data);
@@ -160,7 +179,7 @@ public class StompWireFormat implements WireFormat {
                     } else if (baos.size() > getMaxDataLength()) {
                         throw new ProtocolException("The maximum data length 
was exceeded", true);
                     } else {
-                        if (frameSize.incrementAndGet() > getMaxFrameSize()) {
+                        if (++frameSize > getMaxFrameSize()) {
                             throw new ProtocolException("The maximum frame 
size was exceeded", true);
                         }
                     }
@@ -177,15 +196,19 @@ public class StompWireFormat implements WireFormat {
             return new StompFrame(action, headers, data);
 
         } catch (ProtocolException e) {
+            if (e.isFatal()) {
+                fatalError = true;
+            }
             return new StompFrameError(e);
         } finally {
-            frameSize.set(0);
+            resetFrame();
         }
     }
 
     private String readLine(DataInput in, int maxLength, String errorMessage) 
throws IOException {
         ByteSequence sequence = readHeaderLine(in, maxLength, errorMessage);
-        return new String(sequence.getData(), sequence.getOffset(), 
sequence.getLength(), "UTF-8").trim();
+        return new String(sequence.getData(), sequence.getOffset(), 
sequence.getLength(),
+                StandardCharsets.UTF_8).trim();
     }
 
     private ByteSequence readHeaderLine(DataInput in, int maxLength, String 
errorMessage) throws IOException {
@@ -212,7 +235,7 @@ public class StompWireFormat implements WireFormat {
         return line;
     }
 
-    protected String parseAction(DataInput in, AtomicLong frameSize) throws 
IOException {
+    String parseAction(DataInput in) throws IOException {
         String action = null;
 
         // skip white space to next real action line
@@ -227,11 +250,59 @@ public class StompWireFormat implements WireFormat {
                 }
             }
         }
-        frameSize.addAndGet(action.length());
+        frameSize += action.length();
         return action;
     }
 
-    protected HashMap<String, String> parseHeaders(DataInput in, AtomicLong 
frameSize) throws IOException {
+    // Validate that the server/client receive packets that are expected
+    void validateAction(String action) throws ProtocolException {
+        // Validate for the server
+        if (serverMode) {
+            switch(action) {
+                // Mark that we received the expected first frame it isn't a 
duplicate
+                case Commands.CONNECT:
+                case Commands.STOMP:
+                    if (processedConnect) {
+                        throw new ProtocolException("StompWireFormat is 
configured for 'server' mode"
+                                + " and received a duplicate CONNECT or STOMP 
frame",
+                                true);
+                    }
+                    processedConnect = true;
+                    return;
+                // These are response packets, the server should not receive 
them
+                case Responses.CONNECTED:
+                case Responses.MESSAGE:
+                case Responses.ERROR:
+                case Responses.RECEIPT:
+                    throw new ProtocolException(
+                            "StompWireFormat is configured for 'server' mode 
and received a"
+                                    + " frame that is only expected when 
configured for 'client' mode: " + action, true);
+                default:
+                    // Any other frame received before CONNECT/STOMP is an 
error
+                    if (!processedConnect) {
+                        throw new ProtocolException("StompWireFormat is 
configured for 'server' mode and received an" +
+                                " unexpected frame before CONNECT or STOMP 
frame: " + action, true);
+                    }
+            }
+        } else {
+            switch(action) {
+                // The client should only receive response frames
+                case Responses.CONNECTED:
+                case Responses.MESSAGE:
+                case Responses.ERROR:
+                case Responses.RECEIPT:
+                    return;
+                default:
+                    // Any other frame received that is not a Response is an 
error in client mode
+                    throw new ProtocolException(
+                            "StompWireFormat is configured for 'client' mode 
and received a"
+                                    + " frame that is not expected: " + 
action, true);
+            }
+        }
+
+    }
+
+    HashMap<String, String> parseHeaders(DataInput in) throws IOException {
         HashMap<String, String> headers = new HashMap<>(25);
         while (true) {
             ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The 
maximum header length was exceeded");
@@ -240,7 +311,7 @@ public class StompWireFormat implements WireFormat {
                 if (headers.size() > MAX_HEADERS) {
                     throw new ProtocolException("The maximum number of headers 
was exceeded", true);
                 }
-                frameSize.addAndGet(line.length);
+                frameSize += line.length;
 
                 try {
 
@@ -259,7 +330,8 @@ public class StompWireFormat implements WireFormat {
 
                     ByteSequence nameSeq = stream.toByteSequence();
 
-                    String name = new String(nameSeq.getData(), 
nameSeq.getOffset(), nameSeq.getLength(), "UTF-8");
+                    String name = new String(nameSeq.getData(), 
nameSeq.getOffset(), nameSeq.getLength(),
+                            StandardCharsets.UTF_8);
                     String value = decodeHeader(headerLine);
                     if (stompVersion.equals(Stomp.V1_0)) {
                         value = value.trim();
@@ -281,7 +353,7 @@ public class StompWireFormat implements WireFormat {
         return headers;
     }
 
-    protected int parseContentLength(String contentLength, AtomicLong 
frameSize) throws ProtocolException {
+    int parseContentLength(String contentLength) throws ProtocolException {
         int length;
         try {
             length = Integer.parseInt(contentLength.trim());
@@ -289,11 +361,15 @@ public class StompWireFormat implements WireFormat {
             throw new ProtocolException("Specified content-length is not a 
valid integer", true);
         }
 
+        if (length < 0) {
+            throw new ProtocolException("Specified content-length may not be 
negative", true);
+        }
+
         if (length > getMaxDataLength()) {
             throw new ProtocolException("The maximum data length was 
exceeded", true);
         }
 
-        if (frameSize.addAndGet(length) > getMaxFrameSize()) {
+        if ((frameSize += length) > getMaxFrameSize()) {
             throw new ProtocolException("The maximum frame size was exceeded", 
true);
         }
 
@@ -303,7 +379,7 @@ public class StompWireFormat implements WireFormat {
     private String encodeHeader(String header) throws IOException {
         String result = header;
         if (!stompVersion.equals(Stomp.V1_0)) {
-            byte[] utf8buf = header.getBytes("UTF-8");
+            byte[] utf8buf = header.getBytes(StandardCharsets.UTF_8);
             ByteArrayOutputStream stream = new 
ByteArrayOutputStream(utf8buf.length);
             for(byte val : utf8buf) {
                 switch(val) {
@@ -325,7 +401,7 @@ public class StompWireFormat implements WireFormat {
                     stream.write(val);
                 }
             }
-            result =  new String(stream.toByteArray(), "UTF-8");
+            result =  new String(stream.toByteArray(), StandardCharsets.UTF_8);
             stream.close();
         }
 
@@ -372,7 +448,7 @@ public class StompWireFormat implements WireFormat {
 
         decoded.close();
 
-        return new String(decoded.toByteArray(), "UTF-8");
+        return new String(decoded.toByteArray(), StandardCharsets.UTF_8);
     }
 
     @Override
@@ -416,4 +492,20 @@ public class StompWireFormat implements WireFormat {
     public void setConnectionAttemptTimeout(long connectionAttemptTimeout) {
         this.connectionAttemptTimeout = connectionAttemptTimeout;
     }
+
+    public boolean isServerMode() {
+        return serverMode;
+    }
+
+    public void setServerMode(boolean serverMode) {
+        this.serverMode = serverMode;
+    }
+
+    long incrementAndGetFrameSize() {
+        return ++frameSize;
+    }
+
+    void resetFrame() {
+        frameSize = 0;
+    }
 }
diff --git 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormatFactory.java
 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormatFactory.java
index 54effa6d96..549eab0441 100644
--- 
a/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormatFactory.java
+++ 
b/activemq-stomp/src/main/java/org/apache/activemq/transport/stomp/StompWireFormatFactory.java
@@ -16,7 +16,6 @@
  */
 package org.apache.activemq.transport.stomp;
 
-import org.apache.activemq.wireformat.WireFormat;
 import org.apache.activemq.wireformat.WireFormatFactory;
 
 /**
@@ -26,13 +25,15 @@ public class StompWireFormatFactory implements 
WireFormatFactory {
 
     private int maxDataLength = StompWireFormat.MAX_DATA_LENGTH;
     private long maxFrameSize = StompWireFormat.DEFAULT_MAX_FRAME_SIZE;
+    private boolean serverMode = StompWireFormat.DEFAULT_SERVER_MODE;
 
     @Override
-    public WireFormat createWireFormat() {
+    public StompWireFormat createWireFormat() {
         StompWireFormat wireFormat = new StompWireFormat();
 
         wireFormat.setMaxDataLength(getMaxDataLength());
         wireFormat.setMaxFrameSize(getMaxFrameSize());
+        wireFormat.setServerMode(isServerMode());
 
         return wireFormat;
     }
@@ -52,4 +53,12 @@ public class StompWireFormatFactory implements 
WireFormatFactory {
     public void setMaxFrameSize(long maxFrameSize) {
         this.maxFrameSize = maxFrameSize;
     }
+
+    public boolean isServerMode() {
+        return serverMode;
+    }
+
+    public void setServerMode(boolean serverMode) {
+        this.serverMode = serverMode;
+    }
 }
diff --git 
a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java
 
b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java
index 4142e1fd66..000743ed5f 100644
--- 
a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java
+++ 
b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java
@@ -25,7 +25,6 @@ import static org.junit.Assert.fail;
 
 import java.io.IOException;
 import java.io.StringReader;
-import java.lang.reflect.Field;
 import java.net.SocketTimeoutException;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -50,11 +49,8 @@ import javax.jms.Session;
 import javax.jms.TextMessage;
 import javax.management.ObjectName;
 
-import javax.net.ssl.SSLEngineResult;
-import javax.net.ssl.SSLEngineResult.HandshakeStatus;
 import javax.net.ssl.SSLSocket;
 import org.apache.activemq.broker.BrokerService;
-import org.apache.activemq.broker.TransportConnection;
 import org.apache.activemq.broker.TransportConnector;
 import org.apache.activemq.broker.jmx.BrokerViewMBean;
 import org.apache.activemq.broker.jmx.QueueViewMBean;
@@ -67,7 +63,8 @@ import org.apache.activemq.command.ActiveMQDestination;
 import org.apache.activemq.command.ActiveMQMessage;
 import org.apache.activemq.command.ActiveMQQueue;
 import org.apache.activemq.command.ActiveMQTextMessage;
-import org.apache.activemq.transport.nio.NIOSSLTransport;
+import org.apache.activemq.transport.stomp.Stomp.Commands;
+import org.apache.activemq.transport.stomp.Stomp.Responses;
 import org.apache.activemq.util.NioSslTestUtil;
 import org.apache.activemq.util.Wait;
 import org.junit.Assume;
@@ -2644,6 +2641,110 @@ public class StompTest extends StompTestSupport {
         receiveForSslHandshakeTest();
     }
 
+    /**
+     * Sends a SEND frame before any CONNECT/STOMP frame. The auto transport cannot
+     * detect the protocol and must drop the connection. The plain STOMP transports
+     * must reply with an ERROR frame and then close the socket.
+     */
+    @Test(timeout = 60000)
+    public void testMissingConnectFrame() throws Exception {
+        final boolean isAutoTransport = transportConnectorName.contains("auto");
+
+        // Send a frame without first sending a CONNECT frame, which is a protocol violation
+        String frame = "SEND\n" + "destination:/queue/" + getQueueName() + " \n\n" + "body" + Stomp.NULL;
+        stompConnection.sendFrame(frame);
+
+        // The auto transport will just disconnect because it can't detect the protocol
+        // without the CONNECT frame
+        if (isAutoTransport) {
+            try {
+                StompFrame unexpected = stompConnection.receive();
+                // Previously a successful receive fell through and the test passed
+                // silently; an actual frame here means the server did not disconnect
+                if (unexpected != null) {
+                    fail("Expected the connection to be closed, but received: " + unexpected);
+                }
+            } catch (IOException expected) {
+                // Expected, the connection should be closed because the transport
+                // can't detect the wire protocol without the initial packet
+            }
+        } else {
+            // For the other stomp transports we should get an error back first
+            StompFrame message = stompConnection.receive();
+            assertEquals(Responses.ERROR, message.getAction());
+            assertTrue(message.getBody().contains(
+                    "StompWireFormat is configured for 'server' mode and received an"
+                            + " unexpected frame before CONNECT or STOMP frame: SEND"));
+            // make sure the connection was closed by the server
+            assertConnectionClosed(5000);
+        }
+    }
+
+    @Test(timeout = 60000)
+    public void testNegativeContentLength() throws Exception {
+        String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" 
+ Stomp.NULL;
+        stompConnection.sendFrame(frame);
+
+        frame = stompConnection.receiveFrame();
+        assertTrue(frame.startsWith("CONNECTED"));
+
+        frame = "SEND\n" + "destination:/queue/" + getQueueName() + 
"\ncontent-length:-1" + " \n\n" + "body" + Stomp.NULL;
+        stompConnection.sendFrame(frame);
+
+        // Negative content length is a protocol error and should return
+        // an error and close the connection
+        StompFrame message = stompConnection.receive();
+        assertEquals(Responses.ERROR, message.getAction());
+        assertTrue(message.getBody().contains("Specified content-length may 
not be negative"));
+
+        // make sure the connection was closed by the server
+        assertConnectionClosed(5000);
+    }
+
+    @Test(timeout = 60000)
+    public void testDuplicateConnect() throws Exception {
+        testDuplicateConnect(Commands.CONNECT);
+    }
+
+    @Test(timeout = 60000)
+    public void testDuplicateStomp() throws Exception {
+        testDuplicateConnect(Commands.STOMP);
+    }
+
+    private void testDuplicateConnect(String connectPacket) throws Exception {
+        String frame = connectPacket + "\n" + "login:system\n" + 
"passcode:manager\n\n" + Stomp.NULL;
+        stompConnection.sendFrame(frame);
+
+        String received = stompConnection.receiveFrame();
+        assertTrue(received.startsWith("CONNECTED"));
+
+        // Sending a second CONNECT frame is not allowed and should error
+        stompConnection.sendFrame(frame);
+        StompFrame message = stompConnection.receive();
+        assertEquals(Responses.ERROR, message.getAction());
+        assertTrue(message.getBody().contains("duplicate CONNECT or STOMP 
frame"));
+
+        // make sure the connection was closed by the server
+        assertConnectionClosed(5000);
+    }
+
+    @Test(timeout = 60000)
+    public void testInvalidServerResponseReceived() throws Exception {
+        String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" 
+ Stomp.NULL;
+        stompConnection.sendFrame(frame);
+
+        String received = stompConnection.receiveFrame();
+        assertTrue(received.startsWith("CONNECTED"));
+
+        // Sending a server response to the server, which is invalid
+        frame = "RECEIPT\n" + "receipt-id:message-12345\n\n" + Stomp.NULL;
+        stompConnection.sendFrame(frame);
+        StompFrame message = stompConnection.receive();
+        assertEquals(Responses.ERROR, message.getAction());
+        assertTrue(message.getBody().contains("StompWireFormat is configured 
for 'server' mode and received a"
+                + " frame that is only expected when configured for 'client' 
mode: RECEIPT"));
+
+        // make sure the connection was closed by the server
+        assertConnectionClosed(5000);
+    }
+
+    protected void assertConnectionClosed(int timeout) throws Exception {
+        stompConnection.getStompSocket().setSoTimeout(timeout);
+        // -1 read means the socket was closed by the server
+        assertTrue("Should drop connection", Wait.waitFor(
+                () -> stompConnection.getStompSocket().getInputStream().read() 
== -1, timeout, 10));
+    }
+
     private void checkHandshakeStatusAdvances(SSLSocket socket) throws 
Exception {
         TransportConnector connector = 
brokerService.getTransportConnectorByName(transportConnectorName);
         NioSslTestUtil.checkHandshakeStatusAdvances(connector, socket);
diff --git 
a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompWireFormatFactoryTest.java
 
b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompWireFormatFactoryTest.java
new file mode 100644
index 0000000000..acff2d0cf7
--- /dev/null
+++ 
b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompWireFormatFactoryTest.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.activemq.transport.stomp;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import org.junit.Test;
+
+public class StompWireFormatFactoryTest {
+
+    /**
+     * An unconfigured factory must produce a wire format carrying the
+     * StompWireFormat default frame size, data length and server mode.
+     */
+    @Test
+    public void testDefaults() {
+        final StompWireFormat created = new StompWireFormatFactory().createWireFormat();
+
+        assertEquals(StompWireFormat.DEFAULT_MAX_FRAME_SIZE, created.getMaxFrameSize());
+        assertEquals(StompWireFormat.MAX_DATA_LENGTH, created.getMaxDataLength());
+        assertEquals(StompWireFormat.DEFAULT_SERVER_MODE, created.isServerMode());
+    }
+
+    /**
+     * Values set on the factory must be carried over to the wire format it creates.
+     */
+    @Test
+    public void testSetters() {
+        final StompWireFormatFactory configured = new StompWireFormatFactory();
+        configured.setMaxFrameSize(1020L);
+        configured.setMaxDataLength(2040);
+        configured.setServerMode(false);
+
+        final StompWireFormat created = configured.createWireFormat();
+
+        assertEquals(1020L, created.getMaxFrameSize());
+        assertEquals(2040, created.getMaxDataLength());
+        assertFalse(created.isServerMode());
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
For further information, visit: https://activemq.apache.org/contact



Reply via email to