This is an automated email from the ASF dual-hosted git repository.

havret pushed a commit to branch allow_deny_list
in repository https://gitbox.apache.org/repos/asf/activemq-nms-openwire.git

commit 6bedbb487cc4f321613bb68811c1656ebe9f52d7
Author: Havret <[email protected]>
AuthorDate: Wed Feb 15 20:03:17 2023 +0100

    AMQNET-829 Add allow, deny types support
---
 src/Commands/ActiveMQObjectMessage.cs       |   4 +
 src/Commands/TrustedClassFilter.cs          |  51 +++++++++
 src/Connection.cs                           |   2 +
 src/ConnectionFactory.cs                    |   6 ++
 src/INmsDeserializationPolicy.cs            |  42 ++++++++
 src/NmsDefaultDeserializationPolicy.cs      | 129 +++++++++++++++++++++++
 test/MessageConsumerTest.cs                 |  87 ++++++++++++++++
 test/NMSConnectionFactoryTest.cs            |  24 ++++-
 test/NmsDefaultDeserializationPolicyTest.cs | 156 ++++++++++++++++++++++++++++
 9 files changed, 500 insertions(+), 1 deletion(-)

diff --git a/src/Commands/ActiveMQObjectMessage.cs 
b/src/Commands/ActiveMQObjectMessage.cs
index 3919111..aec88b5 100644
--- a/src/Commands/ActiveMQObjectMessage.cs
+++ b/src/Commands/ActiveMQObjectMessage.cs
@@ -140,6 +140,10 @@ namespace Apache.NMS.ActiveMQ.Commands
                 if (formatter == null)
                 {
                     formatter = new BinaryFormatter();
+                    if (Connection.DeserializationPolicy != null)
+                    {
+                        formatter.Binder = new 
TrustedClassFilter(Connection.DeserializationPolicy, Destination);
+                    }
                 }
                 return formatter;
             }
diff --git a/src/Commands/TrustedClassFilter.cs 
b/src/Commands/TrustedClassFilter.cs
new file mode 100644
index 0000000..eb90662
--- /dev/null
+++ b/src/Commands/TrustedClassFilter.cs
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Reflection;
+using System.Runtime.Serialization;
+
+namespace Apache.NMS.ActiveMQ.Commands
+{
+    /// <summary>
+    /// <see cref="SerializationBinder"/> installed on the BinaryFormatter of an object message;
+    /// it asks an <see cref="INmsDeserializationPolicy"/> whether each type found in the
+    /// serialized body may be materialized.
+    /// </summary>
+    internal class TrustedClassFilter : SerializationBinder
+    {
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private readonly IDestination destination;
+
+        /// <param name="deserializationPolicy">Policy deciding which types are trusted.</param>
+        /// <param name="destination">Destination of the message being read; forwarded to the policy.</param>
+        public TrustedClassFilter(INmsDeserializationPolicy deserializationPolicy, IDestination destination)
+        {
+            this.deserializationPolicy = deserializationPolicy;
+            this.destination = destination;
+        }
+
+        /// <summary>
+        /// Resolves the serialized type and returns it only when the policy trusts it.
+        /// </summary>
+        /// <exception cref="SerializationException">
+        /// The type cannot be resolved, or the policy does not trust it.
+        /// </exception>
+        public override Type BindToType(string assemblyName, string typeName)
+        {
+            var type = ResolveType(assemblyName, typeName);
+            if (type == null)
+            {
+                // Never return null: a null result makes the formatter fall back to its own
+                // type resolution, which would bypass this filter entirely.
+                throw new SerializationException(
+                    $"Unable to resolve type {typeName} from assembly {assemblyName}.");
+            }
+
+            if (deserializationPolicy.IsTrustedType(destination, type))
+            {
+                return type;
+            }
+
+            var message = $"Forbidden {type.FullName}! " +
+                          "This type is not trusted to be deserialized under the current configuration. " +
+                          "Please refer to the documentation for more information on how to configure trusted types.";
+            throw new SerializationException(message);
+        }
+
+        /// <summary>
+        /// Loads the named assembly and looks the type up in it; returns null when either
+        /// the assembly or the type cannot be found.
+        /// </summary>
+        private static Type ResolveType(string assemblyName, string typeName)
+        {
+            // NOTE(review): the assembly named by the incoming payload is loaded before the
+            // policy is consulted, because the policy API works on a resolved Type.
+            Assembly assembly;
+            try
+            {
+                assembly = Assembly.Load(new AssemblyName(assemblyName));
+            }
+            catch (Exception e) when (e is System.IO.FileNotFoundException || e is System.IO.FileLoadException)
+            {
+                return null;
+            }
+
+            return FormatterServices.GetTypeFromAssembly(assembly, typeName);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Connection.cs b/src/Connection.cs
index df0cc57..a2d7df2 100644
--- a/src/Connection.cs
+++ b/src/Connection.cs
@@ -482,6 +482,8 @@ namespace Apache.NMS.ActiveMQ
             get { return this.compressionPolicy; }
             set { this.compressionPolicy = value; }
         }
+        
+        private INmsDeserializationPolicy deserializationPolicy = new NmsDefaultDeserializationPolicy();
+
+        /// <summary>
+        /// Policy deciding which types may be deserialized from the body of an object message
+        /// received on this connection. When set to null, object messages install no type filter.
+        /// </summary>
+        public INmsDeserializationPolicy DeserializationPolicy
+        {
+            get { return this.deserializationPolicy; }
+            set { this.deserializationPolicy = value; }
+        }
 
         internal MessageTransformation MessageTransformation
         {
diff --git a/src/ConnectionFactory.cs b/src/ConnectionFactory.cs
index 84ecf22..6c778cb 100644
--- a/src/ConnectionFactory.cs
+++ b/src/ConnectionFactory.cs
@@ -397,6 +397,11 @@ namespace Apache.NMS.ActiveMQ
                                }
                        }
                }
+               
+               private INmsDeserializationPolicy deserializationPolicy = new NmsDefaultDeserializationPolicy();
+
+               /// <summary>
+               /// The deserialization policy that is applied when a connection is created.
+               /// Every connection created by this factory receives the result of
+               /// <see cref="INmsDeserializationPolicy.Clone"/> on this value, so reassigning the
+               /// property later does not affect connections that already exist.
+               /// NOTE(review): connection creation calls Clone() on this value, so it must not be null.
+               /// </summary>
+               public INmsDeserializationPolicy DeserializationPolicy
+               {
+                       get { return this.deserializationPolicy; }
+                       set { this.deserializationPolicy = value; }
+               }
 
                public IdGenerator ClientIdGenerator
                {
@@ -546,6 +551,7 @@ namespace Apache.NMS.ActiveMQ
                        connection.RedeliveryPolicy = 
this.redeliveryPolicy.Clone() as IRedeliveryPolicy;
                        connection.PrefetchPolicy = this.prefetchPolicy.Clone() 
as PrefetchPolicy;
                        connection.CompressionPolicy = 
this.compressionPolicy.Clone() as ICompressionPolicy;
+                       connection.DeserializationPolicy = 
this.DeserializationPolicy.Clone();
                        connection.ConsumerTransformer = 
this.consumerTransformer;
                        connection.ProducerTransformer = 
this.producerTransformer;
             connection.WatchTopicAdvisories = this.watchTopicAdvisories;
diff --git a/src/INmsDeserializationPolicy.cs b/src/INmsDeserializationPolicy.cs
new file mode 100644
index 0000000..743855d
--- /dev/null
+++ b/src/INmsDeserializationPolicy.cs
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Policy consulted while the body of an incoming <see cref="IObjectMessage"/> is being
+    /// deserialized; it decides which types that body is permitted to contain.
+    /// </summary>
+    public interface INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Decides whether <paramref name="type"/> may be deserialized by the client.
+        /// </summary>
+        /// <param name="destination">Destination of the message whose body contains the type.</param>
+        /// <param name="type">Type of the object that is about to be read.</param>
+        /// <returns><c>true</c> when the type is trusted; <c>false</c> otherwise.</returns>
+        bool IsTrustedType(IDestination destination, Type type);
+
+        /// <summary>
+        /// Produces a copy of this policy that is safe to use independently of (and
+        /// concurrently with) the original.
+        /// </summary>
+        /// <returns>The copied policy.</returns>
+        INmsDeserializationPolicy Clone();
+    }
+}
\ No newline at end of file
diff --git a/src/NmsDefaultDeserializationPolicy.cs 
b/src/NmsDefaultDeserializationPolicy.cs
new file mode 100644
index 0000000..b56531a
--- /dev/null
+++ b/src/NmsDefaultDeserializationPolicy.cs
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Default implementation of the deserialization policy, driven by an allow list and a
+    /// deny list of type and namespace names.
+    ///
+    /// Both lists are comma separated strings. They can be configured from the connection URI
+    /// through the options nms.deserializationPolicy.allowList and nms.deserializationPolicy.denyList.
+    ///
+    /// The deny list defaults to empty. The allow list defaults to <see cref="CATCH_ALL_WILDCARD"/>,
+    /// which indicates that all types are allowed.
+    ///
+    /// An entry matches a type when it equals the type's full name, or when it is a complete
+    /// namespace prefix of it ("System" matches System.String, "System.C" does not match
+    /// System.Collections.Queue).
+    ///
+    /// The deny list overrides the allow list: entries that could match both are counted as denied.
+    ///
+    /// If the policy should treat all types as untrusted, the deny list should be set to
+    /// <see cref="CATCH_ALL_WILDCARD"/>.
+    /// </summary>
+    public class NmsDefaultDeserializationPolicy : INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Value used to indicate that all types should be allowed or denied.
+        /// </summary>
+        public const string CATCH_ALL_WILDCARD = "*";
+
+        private IReadOnlyList<string> denyList = Array.Empty<string>();
+        private IReadOnlyList<string> allowList = new[] { CATCH_ALL_WILDCARD };
+
+        /// <inheritdoc/>
+        /// <remarks>
+        /// A null type, or one without a full name, is reported as trusted. Deny list entries are
+        /// checked first; if none matches, the type is trusted only when an allow list entry matches.
+        /// </remarks>
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            var typeName = type?.FullName;
+            if (typeName == null)
+            {
+                return true;
+            }
+
+            // NOTE(review): the FullName of a constructed generic type embeds its type arguments;
+            // only the leading name is matched here. Confirm the formatter also binds (and thus
+            // filters) the argument types separately.
+            if (denyList.Any(entry => Matches(typeName, entry)))
+            {
+                return false;
+            }
+
+            // Failing an outright rejection above, trust the type only if it is explicitly allowed.
+            return allowList.Any(entry => Matches(typeName, entry));
+        }
+
+        /// <summary>
+        /// True when the entry is the catch-all wildcard or matches the type or its namespace.
+        /// </summary>
+        private static bool Matches(string typeName, string entry)
+        {
+            return entry == CATCH_ALL_WILDCARD || IsTypeOrNamespaceMatch(typeName, entry);
+        }
+
+        private static bool IsTypeOrNamespaceMatch(string typeName, string listEntry)
+        {
+            // Exact match of the full type name.
+            if (string.Equals(typeName, listEntry, StringComparison.Ordinal))
+            {
+                return true;
+            }
+
+            // The type lives in (a sub-namespace of) the entry. The entry must be followed by '.'
+            // so "System.C" does not match "System.Collections.Queue". Comparison is ordinal:
+            // type names are identifiers and must not be subject to culture-specific rules.
+            var entryLength = listEntry.Length;
+            return typeName.Length > entryLength
+                   && typeName[entryLength] == '.'
+                   && typeName.StartsWith(listEntry, StringComparison.Ordinal);
+        }
+
+        /// <inheritdoc/>
+        public INmsDeserializationPolicy Clone()
+        {
+            return new NmsDefaultDeserializationPolicy
+            {
+                allowList = allowList.ToArray(),
+                denyList = denyList.ToArray()
+            };
+        }
+
+        /// <summary>
+        /// Gets or sets the deny list as a comma separated string of type or namespace names.
+        /// Whitespace around entries is ignored and empty entries are dropped; a null or
+        /// whitespace-only value yields an empty list.
+        /// </summary>
+        public string DenyList
+        {
+            get => string.Join(",", denyList);
+            set => denyList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Gets or sets the allow list as a comma separated string of type or namespace names.
+        /// Whitespace around entries is ignored and empty entries are dropped; a null or
+        /// whitespace-only value yields an empty list (nothing allowed).
+        /// </summary>
+        public string AllowList
+        {
+            get => string.Join(",", allowList);
+            set => allowList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Splits a comma separated list into trimmed, non-empty entries.
+        /// </summary>
+        private static string[] ParseList(string value)
+        {
+            if (string.IsNullOrWhiteSpace(value))
+            {
+                return Array.Empty<string>();
+            }
+
+            // Trim so that "a, b" yields the entry "b" rather than the never-matching " b",
+            // which would otherwise silently disable a deny list entry.
+            return value.Split(',')
+                .Select(entry => entry.Trim())
+                .Where(entry => entry.Length > 0)
+                .ToArray();
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/MessageConsumerTest.cs b/test/MessageConsumerTest.cs
index e635434..4d59dcd 100644
--- a/test/MessageConsumerTest.cs
+++ b/test/MessageConsumerTest.cs
@@ -20,6 +20,7 @@ using Apache.NMS.Test;
 using NUnit.Framework;
 using Apache.NMS.ActiveMQ.Commands;
 using System;
+using System.Runtime.Serialization;
 using Apache.NMS.Util;
 
 namespace Apache.NMS.ActiveMQ.Test
@@ -306,5 +307,91 @@ namespace Apache.NMS.ActiveMQ.Test
                 }
             }
         }
+        
+        [Test, Timeout(20_000)]
+        public void TestShouldNotDeserializeUntrustedType()
+        {
+            // Deny exactly the payload type; the default allow list still admits everything else.
+            var policy = new NmsDefaultDeserializationPolicy
+            {
+                DenyList = typeof(UntrustedType).FullName
+            };
+            var factory = new ConnectionFactory(ReplaceEnvVar("activemq:tcp://${{activemqhost}}:61616"))
+            {
+                DeserializationPolicy = policy
+            };
+            string expectedMessage = $"Forbidden {typeof(UntrustedType).FullName}! " +
+                                     "This type is not trusted to be deserialized under the current configuration. " +
+                                     "Please refer to the documentation for more information on how to configure trusted types.";
+
+            using var connection = factory.CreateConnection("", "");
+            connection.Start();
+
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+            producer.Send(producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" }));
+
+            var objectMessage = consumer.Receive() as IObjectMessage;
+            Assert.NotNull(objectMessage);
+
+            // Reading the body deserializes it; the policy must reject the payload type.
+            var thrown = Assert.Throws<SerializationException>(() => _ = objectMessage.Body);
+            Assert.AreEqual(expectedMessage, thrown.Message);
+        }
+        
+        /// <summary>
+        /// Verifies that a user-supplied <see cref="INmsDeserializationPolicy"/> set on the
+        /// factory is applied when the body of a received object message is read.
+        /// </summary>
+        // Same Timeout as the sibling test: Receive() is called without a timeout and could
+        // otherwise hang the whole run if the message never arrives.
+        [Test, Timeout(20_000)]
+        public void TestShouldUseCustomDeserializationPolicy()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new CustomDeserializationPolicy()
+            };
+            using var connection = factory.CreateConnection("", "");
+            connection.Start();
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive();
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            _ = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+        }
+        
+        /// <summary>
+        /// Serializable payload used by the deserialization policy tests in this fixture;
+        /// the policies under test are configured to reject it.
+        /// </summary>
+        // Keep the auto-property as is: BinaryFormatter serializes the compiler-generated
+        // backing field, so reshaping the members changes the serialized form.
+        [Serializable]
+        public class UntrustedType
+        {
+            public string Prop1 { get; set; }
+        }
+
+        /// <summary>
+        /// Test policy that trusts every type except <see cref="UntrustedType"/>.
+        /// </summary>
+        private class CustomDeserializationPolicy : INmsDeserializationPolicy
+        {
+            public bool IsTrustedType(IDestination destination, Type type) => type != typeof(UntrustedType);
+
+            // Stateless, so every connection may share this very instance.
+            public INmsDeserializationPolicy Clone() => this;
+        }
     }
 }
diff --git a/test/NMSConnectionFactoryTest.cs b/test/NMSConnectionFactoryTest.cs
index 5a6ce48..5be6ff6 100644
--- a/test/NMSConnectionFactoryTest.cs
+++ b/test/NMSConnectionFactoryTest.cs
@@ -211,6 +211,28 @@ namespace Apache.NMS.ActiveMQ.Test
 
                                connection.Close();
                        }
-        }              
+        }
+
+        /// <summary>
+        /// Verifies that the allow/deny lists given as connection URI options end up on the
+        /// deserialization policy of the created connection.
+        /// </summary>
+        [Test]
+        public void TestSetDeserializationPolicy()
+        {
+               string baseUri = "activemq:tcp://${{activemqhost}}:61616";
+               string configuredUri = baseUri +
+                                      "?nms.deserializationPolicy.allowList=a,b,c" +
+                                      "&nms.deserializationPolicy.denyList=c,d,e";
+
+               var factory = new NMSConnectionFactory(NMSTestSupport.ReplaceEnvVar(configuredUri));
+
+               Assert.IsNotNull(factory);
+               Assert.IsNotNull(factory.ConnectionFactory);
+               using IConnection connection = factory.CreateConnection("", "");
+               Assert.IsNotNull(connection);
+
+               // Assert the concrete types instead of unchecked 'as' casts, so an unexpected
+               // type fails with a descriptive assertion rather than a NullReferenceException.
+               Assert.IsInstanceOf<Connection>(connection);
+               var amqConnection = (Connection) connection;
+               Assert.IsInstanceOf<NmsDefaultDeserializationPolicy>(amqConnection.DeserializationPolicy);
+               var deserializationPolicy = (NmsDefaultDeserializationPolicy) amqConnection.DeserializationPolicy;
+               Assert.AreEqual("a,b,c", deserializationPolicy.AllowList);
+               Assert.AreEqual("c,d,e", deserializationPolicy.DenyList);
+               connection.Close();
+        }
     }
 }
diff --git a/test/NmsDefaultDeserializationPolicyTest.cs 
b/test/NmsDefaultDeserializationPolicyTest.cs
new file mode 100644
index 0000000..7ac772b
--- /dev/null
+++ b/test/NmsDefaultDeserializationPolicyTest.cs
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using Apache.NMS.Commands;
+using NUnit.Framework;
+
+namespace Apache.NMS.ActiveMQ.Test
+{
+    /// <summary>
+    /// Unit tests for <see cref="NmsDefaultDeserializationPolicy"/>; these exercise the policy
+    /// object directly and do not open a connection.
+    /// </summary>
+    [TestFixture]
+    public class NmsDefaultDeserializationPolicyTest
+    {
+        [Test]
+        public void TestIsTrustedType()
+        {
+            var destination = new Queue("test-queue");
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+
+            // Only types in System
+            policy.AllowList = "System";
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+            Assert.False(policy.IsTrustedType(destination, GetType()));
+
+            // An entry must be a complete namespace prefix to match:
+            // while "System.C" is a prefix of "System.Collections", it
+            // won't match the Queue class below.
+            policy.AllowList = "System.C";
+            Assert.False(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(System.Collections.Queue)));
+
+            // Add a non-core namespace
+            policy.AllowList = $"System,{GetType().Namespace}";
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, GetType()));
+
+            // Try with a type-specific entry
+            policy.AllowList = typeof(string).FullName;
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(bool)));
+
+            // Verify deny list overrides allow list
+            policy.AllowList = "System";
+            policy.DenyList = "System";
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list entry prefix overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = typeof(string).Namespace;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list catch-all overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicy()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            // By default every type is allowed and nothing is denied.
+            Assert.AreEqual(NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD, policy.AllowList);
+            Assert.IsEmpty(policy.DenyList);
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicyClone()
+        {
+            var policy = new NmsDefaultDeserializationPolicy
+            {
+                AllowList = "a.b.c",
+                DenyList = "d.e.f"
+            };
+
+            var clone = (NmsDefaultDeserializationPolicy) policy.Clone();
+            Assert.AreEqual(policy.AllowList, clone.AllowList);
+            Assert.AreEqual(policy.DenyList, clone.DenyList);
+            Assert.AreNotSame(clone, policy);
+
+            // The clone must not observe later changes made to the original.
+            policy.AllowList = "x.y.z";
+            policy.DenyList = "u.v.w";
+            Assert.AreEqual("a.b.c", clone.AllowList);
+            Assert.AreEqual("d.e.f", clone.DenyList);
+        }
+
+        [Test]
+        public void TestSetAllowList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.AllowList);
+
+            policy.AllowList = null;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = string.Empty;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = "*";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+
+            policy.AllowList = "a,b,c";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+            Assert.AreEqual("a,b,c", policy.AllowList);
+        }
+
+        [Test]
+        public void TestSetDenyList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.DenyList);
+
+            policy.DenyList = null;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = string.Empty;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = "*";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+
+            policy.DenyList = "a,b,c";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+            Assert.AreEqual("a,b,c", policy.DenyList);
+        }
+    }
+}
\ No newline at end of file

Reply via email to