This is an automated email from the ASF dual-hosted git repository.

havret pushed a commit to branch allow_deny_list
in repository https://gitbox.apache.org/repos/asf/activemq-nms-amqp.git

commit 0ccd4ed78895eee9a860de1e2632988c23be0278
Author: Havret <[email protected]>
AuthorDate: Sun Feb 12 15:20:28 2023 +0100

    AMQNET-828 Add allow, deny types support
---
 src/NMS.AMQP/Meta/NmsConnectionInfo.cs             |   7 +-
 src/NMS.AMQP/Meta/NmsConsumerInfo.cs               |   4 +-
 src/NMS.AMQP/NmsConnectionFactory.cs               |  11 +-
 src/NMS.AMQP/NmsMessageConsumer.cs                 |   3 +-
 src/NMS.AMQP/Policies/INmsDeserializationPolicy.cs |  44 ++++++
 .../Policies/NmsDefaultDeserializationPolicy.cs    | 129 +++++++++++++++++
 src/NMS.AMQP/Provider/Amqp/AmqpConsumer.cs         |   3 +-
 .../Amqp/Message/AmqpNmsObjectMessageFacade.cs     |  18 ++-
 .../Amqp/Message/AmqpSerializedObjectDelegate.cs   |  18 ++-
 .../Amqp/Message/AmqpTypedObjectDelegate.cs        |   2 +
 .../Amqp/Message/IAmqpObjectTypeDelegate.cs        |   1 +
 .../Provider/Amqp/Message/TrustedClassFilter.cs    |  52 +++++++
 .../AmqpTestSupport.cs                             |   3 +-
 .../NmsMessageConsumerTest.cs                      |  84 +++++++++++
 test/Apache-NMS-AMQP-Test/ConnectionFactoryTest.cs |  16 +++
 .../NmsDefaultDeserializationPolicyTest.cs         | 157 +++++++++++++++++++++
 .../Amqp/AmqpNmsObjectMessageFacadeTest.cs         |   1 +
 17 files changed, 535 insertions(+), 18 deletions(-)

diff --git a/src/NMS.AMQP/Meta/NmsConnectionInfo.cs 
b/src/NMS.AMQP/Meta/NmsConnectionInfo.cs
index 103a4cb..f9120c6 100644
--- a/src/NMS.AMQP/Meta/NmsConnectionInfo.cs
+++ b/src/NMS.AMQP/Meta/NmsConnectionInfo.cs
@@ -17,6 +17,7 @@
 
 using System;
 using Amqp;
+using Apache.NMS.AMQP.Policies;
 
 namespace Apache.NMS.AMQP.Meta
 {
@@ -79,10 +80,8 @@ namespace Apache.NMS.AMQP.Meta
         public bool SharedSubsSupported { get; set; }
 
         public PrefetchPolicyInfo PrefetchPolicy { get; set; } = 
DEFAULT_PREFETCH_POLICY;
-        
-        
-        
-        
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; }
+
 
         public void SetClientId(string clientId, bool explicitClientId)
         {
diff --git a/src/NMS.AMQP/Meta/NmsConsumerInfo.cs 
b/src/NMS.AMQP/Meta/NmsConsumerInfo.cs
index b5506c3..3d808be 100644
--- a/src/NMS.AMQP/Meta/NmsConsumerInfo.cs
+++ b/src/NMS.AMQP/Meta/NmsConsumerInfo.cs
@@ -16,6 +16,7 @@
  */
 
 using System;
+using Apache.NMS.AMQP.Policies;
 
 namespace Apache.NMS.AMQP.Meta
 {
@@ -38,7 +39,8 @@ namespace Apache.NMS.AMQP.Meta
         public bool LocalMessageExpiry { get; set; }
         public bool IsBrowser { get; set; }
         public int LinkCredit { get; set; }
-        
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; }
+
         public bool HasSelector() => !string.IsNullOrWhiteSpace(Selector);
 
         protected bool Equals(NmsConsumerInfo other)
diff --git a/src/NMS.AMQP/NmsConnectionFactory.cs b/src/NMS.AMQP/NmsConnectionFactory.cs
index 6212ddc..7684f30 100644
--- a/src/NMS.AMQP/NmsConnectionFactory.cs
+++ b/src/NMS.AMQP/NmsConnectionFactory.cs
@@ -17,9 +17,9 @@
 
 using System;
 using System.Collections.Specialized;
-using System.Threading;
 using System.Threading.Tasks;
 using Apache.NMS.AMQP.Meta;
+using Apache.NMS.AMQP.Policies;
 using Apache.NMS.AMQP.Provider;
 using Apache.NMS.AMQP.Util;
 using Apache.NMS.AMQP.Util.Synchronization;
@@ -308,6 +308,12 @@ namespace Apache.NMS.AMQP
         }
 
         public IRedeliveryPolicy RedeliveryPolicy { get; set; }
+        
+        /// <summary>
+        /// The deserialization policy that is cloned into each connection created by this factory.
+        /// </summary>
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
+
         public ConsumerTransformerDelegate ConsumerTransformer { get; set; }
         public ProducerTransformerDelegate ProducerTransformer { get; set; }
 
@@ -339,7 +345,8 @@ namespace Apache.NMS.AMQP
                 SendTimeout = SendTimeout,
                 CloseTimeout = CloseTimeout,
                 LocalMessageExpiry = LocalMessageExpiry,
-                PrefetchPolicy = PrefetchPolicy.Clone()
+                PrefetchPolicy = PrefetchPolicy.Clone(),
+                DeserializationPolicy = DeserializationPolicy?.Clone()
             };
 
             bool userSpecifiedClientId = ClientId != null;
diff --git a/src/NMS.AMQP/NmsMessageConsumer.cs b/src/NMS.AMQP/NmsMessageConsumer.cs
index c214ee0..05261e6 100644
--- a/src/NMS.AMQP/NmsMessageConsumer.cs
+++ b/src/NMS.AMQP/NmsMessageConsumer.cs
@@ -62,7 +62,8 @@ namespace Apache.NMS.AMQP
                 IsDurable = IsDurableSubscription,
                 IsBrowser =  IsBrowser,
                 LocalMessageExpiry = Session.Connection.ConnectionInfo.LocalMessageExpiry,
-                LinkCredit = Session.Connection.ConnectionInfo.PrefetchPolicy.GetLinkCredit(destination, IsBrowser, IsDurableSubscription)
+                LinkCredit = Session.Connection.ConnectionInfo.PrefetchPolicy.GetLinkCredit(destination, IsBrowser, IsDurableSubscription),
+                DeserializationPolicy = Session.Connection.ConnectionInfo.DeserializationPolicy?.Clone()
             };
             deliveryTask = new MessageDeliveryTask(this);
         }
diff --git a/src/NMS.AMQP/Policies/INmsDeserializationPolicy.cs b/src/NMS.AMQP/Policies/INmsDeserializationPolicy.cs
new file mode 100644
index 0000000..cd686dd
--- /dev/null
+++ b/src/NMS.AMQP/Policies/INmsDeserializationPolicy.cs
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+
+namespace Apache.NMS.AMQP.Policies
+{
+    /// <summary>
+    /// Defines the interface for a policy object that controls what types of message
+    /// content are permissible when the body of an incoming ObjectMessage is being
+    /// deserialized.
+    /// </summary>
+    public interface INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Returns whether the given type is trusted and can therefore be deserialized
+        /// by the client when <see cref="IObjectMessage.Body"/> is read.
+        /// </summary>
+        /// <param name="destination">The destination of the message that contains the type to be deserialized.</param>
+        /// <param name="type">The type of the object that is about to be read.</param>
+        /// <returns>True if the type is trusted; false otherwise.</returns>
+        bool IsTrustedType(IDestination destination, Type type);
+        
+        /// <summary>
+        /// Creates a copy of this policy instance so that the copy can be used in a thread-safe way.
+        /// </summary>
+        /// <returns>A copy of this INmsDeserializationPolicy instance.</returns>
+        INmsDeserializationPolicy Clone();
+    }
+}
\ No newline at end of file
diff --git a/src/NMS.AMQP/Policies/NmsDefaultDeserializationPolicy.cs b/src/NMS.AMQP/Policies/NmsDefaultDeserializationPolicy.cs
new file mode 100644
index 0000000..609c1f7
--- /dev/null
+++ b/src/NMS.AMQP/Policies/NmsDefaultDeserializationPolicy.cs
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Apache.NMS.AMQP.Policies
+{
+    /// <summary>
+    /// Default implementation of the deserialization policy that can read allow and deny lists of
+    /// types/namespaces from the connection uri options.
+    ///
+    /// The policy reads a default deny list string value (comma separated) from the connection uri options
+    /// (nms.deserializationPolicy.denyList) which defaults to an empty deny list.
+    ///
+    /// The policy reads a default allow list string value (comma separated) from the connection uri options
+    /// (nms.deserializationPolicy.allowList) which defaults to <see cref="CATCH_ALL_WILDCARD"/> which
+    /// indicates that all types are allowed. Whitespace around individual list entries is ignored.
+    ///
+    /// The deny list overrides the allow list, entries that could match both are counted as denied.
+    ///
+    /// If the policy should treat all types as untrusted, the deny list should be set to <see cref="CATCH_ALL_WILDCARD"/>.
+    /// </summary>
+    public class NmsDefaultDeserializationPolicy : INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Value used to indicate that all types should be allowed or denied
+        /// </summary>
+        public const string CATCH_ALL_WILDCARD = "*";
+        
+        private IReadOnlyList<string> denyList = Array.Empty<string>();
+        private IReadOnlyList<string> allowList = new[] { CATCH_ALL_WILDCARD };
+        
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            var typeName = type?.FullName;
+            if (typeName == null)
+            {
+                return true;
+            }
+
+            foreach (var denyListEntry in denyList)
+            {
+                if (CATCH_ALL_WILDCARD == denyListEntry)
+                {
+                    return false;
+                }
+                if (IsTypeOrNamespaceMatch(typeName, denyListEntry))
+                {
+                    return false;
+                }
+            }
+
+            foreach (var allowListEntry in allowList)
+            {
+                if (CATCH_ALL_WILDCARD == allowListEntry)
+                {
+                    return true;
+                }
+                if (IsTypeOrNamespaceMatch(typeName, allowListEntry))
+                {
+                    return true;
+                }
+            }
+            
+            // Failing outright rejection or allow from above, reject.
+            return false;
+        }
+
+        private bool IsTypeOrNamespaceMatch(string typeName, string listEntry)
+        {
+            // Check if type is an exact match of the entry
+            if (typeName == listEntry)
+            {
+                return true;
+            }
+            
+            // Check if the type is from a namespace matching the entry (ordinal: type names are identifiers, not text)
+            var entryLength = listEntry.Length;
+            return typeName.Length > entryLength && typeName.StartsWith(listEntry, StringComparison.Ordinal) && '.' == typeName[entryLength];
+        }
+
+        public INmsDeserializationPolicy Clone()
+        {
+            return new NmsDefaultDeserializationPolicy
+            {
+                allowList = allowList.ToArray(),
+                denyList = denyList.ToArray()
+            };
+        }
+
+        /// <summary>
+        /// Gets or sets the deny list (comma separated; whitespace around entries is trimmed) on this policy instance.
+        /// </summary>
+        public string DenyList
+        {
+            get => string.Join(",", denyList);
+            set => denyList = string.IsNullOrWhiteSpace(value) 
+                ? Array.Empty<string>() 
+                : value.Split(',').Select(entry => entry.Trim()).ToArray();
+        }
+
+        /// <summary>
+        /// Gets or sets the allow list (comma separated; whitespace around entries is trimmed) on this policy instance.
+        /// </summary>
+        public string AllowList
+        {
+            get => string.Join(",", allowList);
+            set => allowList = string.IsNullOrWhiteSpace(value) 
+                ? Array.Empty<string>() 
+                : value.Split(',').Select(entry => entry.Trim()).ToArray();
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/NMS.AMQP/Provider/Amqp/AmqpConsumer.cs 
b/src/NMS.AMQP/Provider/Amqp/AmqpConsumer.cs
index fc3a5cf..99666c8 100644
--- a/src/NMS.AMQP/Provider/Amqp/AmqpConsumer.cs
+++ b/src/NMS.AMQP/Provider/Amqp/AmqpConsumer.cs
@@ -33,6 +33,7 @@ namespace Apache.NMS.AMQP.Provider.Amqp
 {
     public interface IAmqpConsumer
     {
+        NmsConsumerInfo ResourceInfo { get; }
         IDestination Destination { get; }
         IAmqpConnection Connection { get; }
     }
@@ -60,7 +61,7 @@ namespace Apache.NMS.AMQP.Provider.Amqp
         }
 
         public NmsConsumerId ConsumerId => this.info.Id;
-        
+        public NmsConsumerInfo ResourceInfo => this.info;
 
         public Task Attach()
         {
diff --git a/src/NMS.AMQP/Provider/Amqp/Message/AmqpNmsObjectMessageFacade.cs 
b/src/NMS.AMQP/Provider/Amqp/Message/AmqpNmsObjectMessageFacade.cs
index d8454fa..94c9074 100644
--- a/src/NMS.AMQP/Provider/Amqp/Message/AmqpNmsObjectMessageFacade.cs
+++ b/src/NMS.AMQP/Provider/Amqp/Message/AmqpNmsObjectMessageFacade.cs
@@ -19,16 +19,18 @@ using System;
 using System.IO;
 using Apache.NMS.AMQP.Message;
 using Apache.NMS.AMQP.Message.Facade;
+using Apache.NMS.AMQP.Policies;
 using Apache.NMS.AMQP.Util;
 
 namespace Apache.NMS.AMQP.Provider.Amqp.Message
 {
     public class AmqpNmsObjectMessageFacade : AmqpNmsMessageFacade, 
INmsObjectMessageFacade
     {
+        private INmsDeserializationPolicy deserializationPolicy;
         private IAmqpObjectTypeDelegate typeDelegate;
 
         public IAmqpObjectTypeDelegate Delegate => typeDelegate;
-
+        
         public object Object
         {
             get => Delegate.Object;
@@ -68,15 +70,18 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
         public override void Initialize(IAmqpConsumer consumer, 
global::Amqp.Message message)
         {
             base.Initialize(consumer, message);
-            bool dotnetSerialized = 
MessageSupport.SERIALIZED_DOTNET_OBJECT_CONTENT_TYPE.Equals(ContentType);
-            InitSerializer(!dotnetSerialized);
+            deserializationPolicy = 
consumer.ResourceInfo.DeserializationPolicy;
+            bool hasDotNetSerializedType = 
MessageSupport.SERIALIZED_DOTNET_OBJECT_CONTENT_TYPE.Equals(ContentType);
+            InitSerializer(!hasDotNetSerializedType);
         }
 
+        
+
         private void InitSerializer(bool useAmqpTypes)
         {
             if (!useAmqpTypes)
             {
-                typeDelegate = new AmqpSerializedObjectDelegate(this);
+                typeDelegate = new AmqpSerializedObjectDelegate(this, 
deserializationPolicy);
             }
             else
             {
@@ -86,9 +91,10 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
 
         public override INmsMessageFacade Copy()
         {
-            AmqpNmsObjectMessageFacade copy = new AmqpNmsObjectMessageFacade();
+            var copy = new AmqpNmsObjectMessageFacade();
+            copy.deserializationPolicy = deserializationPolicy;
             CopyInto(copy);
-            copy.typeDelegate = typeDelegate;
+            copy.InitSerializer(typeDelegate.IsAmqpTypeEncoded());
             return copy;
         }
 
diff --git a/src/NMS.AMQP/Provider/Amqp/Message/AmqpSerializedObjectDelegate.cs b/src/NMS.AMQP/Provider/Amqp/Message/AmqpSerializedObjectDelegate.cs
index 9c1a4a4..b6494c1 100644
--- a/src/NMS.AMQP/Provider/Amqp/Message/AmqpSerializedObjectDelegate.cs
+++ b/src/NMS.AMQP/Provider/Amqp/Message/AmqpSerializedObjectDelegate.cs
@@ -15,10 +15,12 @@
   * limitations under the License.
   */
 
+using System;
 using System.IO;
 using System.Runtime.Serialization;
 using System.Runtime.Serialization.Formatters.Binary;
 using Amqp.Framing;
+using Apache.NMS.AMQP.Policies;
 using Apache.NMS.AMQP.Util;
 
 namespace Apache.NMS.AMQP.Provider.Amqp.Message
@@ -28,10 +30,13 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
         public static readonly Data NULL_OBJECT_BODY = new Data() {Binary = new byte[] {0xac, 0xed, 0x00, 0x05, 0x70}};
 
         private readonly AmqpNmsObjectMessageFacade facade;
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private bool localContent;
 
-        public AmqpSerializedObjectDelegate(AmqpNmsObjectMessageFacade facade)
+        public AmqpSerializedObjectDelegate(AmqpNmsObjectMessageFacade facade, INmsDeserializationPolicy deserializationPolicy)
         {
             this.facade = facade;
+            this.deserializationPolicy = deserializationPolicy;
             facade.ContentType = MessageSupport.SERIALIZED_DOTNET_OBJECT_CONTENT_TYPE;
         }
 
@@ -64,6 +69,8 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
                 {
                     facade.Message.BodySection = NULL_OBJECT_BODY;
                 }
+
+                localContent = true;
             }
         }
 
@@ -74,11 +81,19 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
                 facade.Message.BodySection = NULL_OBJECT_BODY;
         }
 
+        public bool IsAmqpTypeEncoded() => false;
+
         private object Deserialize(byte[] binary)
         {
-            using (MemoryStream stream = new MemoryStream(binary))
+            using (var stream = new MemoryStream(binary))
             {
                 IFormatter formatter = new BinaryFormatter();
+                // Content assigned locally through the Object setter is not filtered; a null policy disables filtering.
+                if (!localContent && deserializationPolicy != null)
+                {
+                    formatter.Binder = new TrustedClassFilter(deserializationPolicy, facade.NMSDestination);
+                }
+
                 return formatter.Deserialize(stream);
             }
         }
diff --git a/src/NMS.AMQP/Provider/Amqp/Message/AmqpTypedObjectDelegate.cs 
b/src/NMS.AMQP/Provider/Amqp/Message/AmqpTypedObjectDelegate.cs
index e1dd071..cfaafef 100644
--- a/src/NMS.AMQP/Provider/Amqp/Message/AmqpTypedObjectDelegate.cs
+++ b/src/NMS.AMQP/Provider/Amqp/Message/AmqpTypedObjectDelegate.cs
@@ -103,5 +103,7 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
             if (facade.Message.BodySection == null)
                 facade.Message.BodySection = NULL_OBJECT_BODY;
         }
+
+        public bool IsAmqpTypeEncoded() => true;
     }
 }
\ No newline at end of file
diff --git a/src/NMS.AMQP/Provider/Amqp/Message/IAmqpObjectTypeDelegate.cs 
b/src/NMS.AMQP/Provider/Amqp/Message/IAmqpObjectTypeDelegate.cs
index d0a0a28..72e7f01 100644
--- a/src/NMS.AMQP/Provider/Amqp/Message/IAmqpObjectTypeDelegate.cs
+++ b/src/NMS.AMQP/Provider/Amqp/Message/IAmqpObjectTypeDelegate.cs
@@ -25,5 +25,6 @@ namespace Apache.NMS.AMQP.Provider.Amqp.Message
     {
         object Object { get; set; }
         void OnSend();
+        bool IsAmqpTypeEncoded();
     }
 }
\ No newline at end of file
diff --git a/src/NMS.AMQP/Provider/Amqp/Message/TrustedClassFilter.cs b/src/NMS.AMQP/Provider/Amqp/Message/TrustedClassFilter.cs
new file mode 100644
index 0000000..c761487
--- /dev/null
+++ b/src/NMS.AMQP/Provider/Amqp/Message/TrustedClassFilter.cs
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Reflection;
+using System.Runtime.Serialization;
+using Apache.NMS.AMQP.Policies;
+
+namespace Apache.NMS.AMQP.Provider.Amqp.Message
+{
+    /// <summary>
+    /// Serialization binder that asks an <see cref="INmsDeserializationPolicy"/> about every type the formatter
+    /// needs to bind and throws <see cref="SerializationException"/> for types that are not trusted.
+    /// </summary>
+    internal class TrustedClassFilter : SerializationBinder
+    {
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private readonly IDestination destination;
+
+        public TrustedClassFilter(INmsDeserializationPolicy deserializationPolicy, IDestination destination)
+        {
+            this.deserializationPolicy = deserializationPolicy;
+            this.destination = destination;
+        }
+        
+        public override Type BindToType(string assemblyName, string typeName)
+        {
+            var name = new AssemblyName(assemblyName);
+            var assembly = Assembly.Load(name);
+            var type = FormatterServices.GetTypeFromAssembly(assembly, typeName);
+            // Never return null: BinaryFormatter falls back to its own type resolution when a binder returns null,
+            // which would skip the policy check entirely. Types that cannot be resolved here are rejected instead.
+            if (type != null && deserializationPolicy.IsTrustedType(destination, type))
+            {
+                return type;
+            }
+
+            var message = $"Forbidden {type?.FullName ?? typeName}! " +
+                          "This type is not trusted to be deserialized under the current configuration. " +
+                          "Please refer to the documentation for more information on how to configure trusted types.";
+            throw new SerializationException(message);
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/Apache-NMS-AMQP-Interop-Test/AmqpTestSupport.cs 
b/test/Apache-NMS-AMQP-Interop-Test/AmqpTestSupport.cs
index c742457..2564139 100644
--- a/test/Apache-NMS-AMQP-Interop-Test/AmqpTestSupport.cs
+++ b/test/Apache-NMS-AMQP-Interop-Test/AmqpTestSupport.cs
@@ -49,7 +49,7 @@ namespace NMS.AMQP.Test
             return connection;
         }
 
-        protected IConnection CreateAmqpConnection(string clientId = null, 
string options = null)
+        protected IConnection CreateAmqpConnection(string clientId = null, 
string options = null, Action<NmsConnectionFactory> configureConnectionFactory 
= null)
         {
             string brokerUri = 
Environment.GetEnvironmentVariable("NMS_AMQP_TEST_URI") ?? 
"amqp://127.0.0.1:5672";
             if (options != null)
@@ -61,6 +61,7 @@ namespace NMS.AMQP.Test
 
             NmsConnectionFactory factory = new NmsConnectionFactory(brokerUri);
             factory.ClientId = clientId;
+            configureConnectionFactory?.Invoke(factory);
             return factory.CreateConnection(userName, password);
         }
 
diff --git a/test/Apache-NMS-AMQP-Interop-Test/NmsMessageConsumerTest.cs b/test/Apache-NMS-AMQP-Interop-Test/NmsMessageConsumerTest.cs
index 7b2de68..ff60deb 100644
--- a/test/Apache-NMS-AMQP-Interop-Test/NmsMessageConsumerTest.cs
+++ b/test/Apache-NMS-AMQP-Interop-Test/NmsMessageConsumerTest.cs
@@ -19,9 +19,11 @@ using System;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.Linq;
+using System.Runtime.Serialization;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.NMS;
+using Apache.NMS.AMQP.Policies;
 using NUnit.Framework;
 
 namespace NMS.AMQP.Test
@@ -375,5 +377,87 @@ namespace NMS.AMQP.Test
             IMessageConsumer messageConsumer = session.CreateConsumer(topic, null, noLocal: true);
             Assert.IsNull(messageConsumer.Receive(TimeSpan.FromMilliseconds(500)));
         }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldNotDeserializeUntrustedType()
+        {
+            Connection = CreateAmqpConnection(configureConnectionFactory: factory =>
+            {
+                var deserializationPolicy = new NmsDefaultDeserializationPolicy
+                {
+                    DenyList = typeof(UntrustedType).FullName
+                };
+                factory.DeserializationPolicy = deserializationPolicy;
+            });
+            Connection.Start();
+            var session = Connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(TestName);
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+            
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive();
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            var exception = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+            Assert.AreEqual($"Forbidden {typeof(UntrustedType).FullName}! " +
+                            "This type is not trusted to be deserialized under the current configuration. " +
+                            "Please refer to the documentation for more information on how to configure trusted types.",
+                exception.Message);
+        }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldUseCustomDeserializationPolicy()
+        {
+            Connection = CreateAmqpConnection(configureConnectionFactory: factory =>
+            {
+                factory.DeserializationPolicy = new CustomDeserializationPolicy();
+            });
+            Connection.Start();
+            var session = Connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(TestName);
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+            
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive();
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            _ = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+        }
+    }
+
+    [Serializable]
+    public class UntrustedType
+    {
+        public string Prop1 { get; set; }
+    }
+    
+    public class CustomDeserializationPolicy : INmsDeserializationPolicy
+    {
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            if (type == typeof(UntrustedType))
+            {
+                return false;
+            }
+
+            return true;
+        }
+
+        public INmsDeserializationPolicy Clone()
+        {
+            return this;
+        }
     }
 }
\ No newline at end of file
diff --git a/test/Apache-NMS-AMQP-Test/ConnectionFactoryTest.cs 
b/test/Apache-NMS-AMQP-Test/ConnectionFactoryTest.cs
index 096e53f..6c62465 100644
--- a/test/Apache-NMS-AMQP-Test/ConnectionFactoryTest.cs
+++ b/test/Apache-NMS-AMQP-Test/ConnectionFactoryTest.cs
@@ -20,6 +20,7 @@ using System.Diagnostics;
 using System.Threading.Tasks;
 using Apache.NMS;
 using Apache.NMS.AMQP;
+using Apache.NMS.AMQP.Policies;
 using Apache.NMS.AMQP.Provider;
 using NMS.AMQP.Test.Provider.Mock;
 using NUnit.Framework;
@@ -176,6 +177,21 @@ namespace NMS.AMQP.Test
             Assert.IsFalse(factory.LocalMessageExpiry);
         }
 
+        [Test]
+        public void TestSetDeserializationPolicy()
+        {
+            string baseUri = "amqp://localhost:1234";
+            string configuredUri = baseUri +
+                                   
"?nms.deserializationPolicy.allowList=a,b,c" +
+                                   "&nms.deserializationPolicy.denyList=c,d,e";
+
+            var factory = new NmsConnectionFactory(new Uri(configuredUri));
+            var deserializationPolicy = factory.DeserializationPolicy as 
NmsDefaultDeserializationPolicy;
+            Assert.IsNotNull(deserializationPolicy);
+            Assert.AreEqual("a,b,c", deserializationPolicy.AllowList);
+            Assert.AreEqual("c,d,e", deserializationPolicy.DenyList);
+        }
+
         [Test]
         public void TestCreateConnectionBadBrokerUri()
         {
diff --git 
a/test/Apache-NMS-AMQP-Test/Policies/NmsDefaultDeserializationPolicyTest.cs 
b/test/Apache-NMS-AMQP-Test/Policies/NmsDefaultDeserializationPolicyTest.cs
new file mode 100644
index 0000000..33645b1
--- /dev/null
+++ b/test/Apache-NMS-AMQP-Test/Policies/NmsDefaultDeserializationPolicyTest.cs
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using Apache.NMS.AMQP;
+using Apache.NMS.AMQP.Policies;
+using NUnit.Framework;
+
+namespace NMS.AMQP.Test.Policies
+{
+    [TestFixture]
+    public class NmsDefaultDeserializationPolicyTest
+    {
+        [Test]
+        public void TestIsTrustedType()
+        {
+            var destination = new NmsQueue("test-queue");
+            var policy = new NmsDefaultDeserializationPolicy();
+            // With the default lists, null and common System types are trusted.
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+            
+            // Restrict the allow list to the System namespace
+            policy.AllowList = "System";
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+            Assert.False(policy.IsTrustedType(destination, GetType()));
+            
+            // An entry must consist of complete namespace segments to match:
+            // although "System.C" is a string prefix of "System.Collections",
+            // it won't match the Queue class below.
+            policy.AllowList = "System.C";
+            Assert.False(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(System.Collections.Queue)));
+
+            // Add a non-System namespace (this test's own) alongside System
+            policy.AllowList = $"System,{GetType().Namespace}";
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, GetType()));
+            
+            // An entry may also name a single type by its full name
+            policy.AllowList = typeof(string).FullName;
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(bool)));
+            
+            // Verify deny list overrides allow list
+            policy.AllowList = "System";
+            policy.DenyList = "System";
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            
+            // Verify a deny list namespace entry overrides an allow list type entry
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = typeof(string).Namespace;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            
+            // Verify deny list catch-all overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicy()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            // A fresh policy has a non-empty allow list and an empty deny list.
+            Assert.IsNotEmpty(policy.AllowList);
+            Assert.IsEmpty(policy.DenyList);
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicyClone()
+        {
+            var policy = new NmsDefaultDeserializationPolicy
+            {
+                AllowList = "a.b.c",
+                DenyList = "d.e.f"
+            };
+            // The clone must carry both lists yet be a distinct instance.
+            var clone = (NmsDefaultDeserializationPolicy) policy.Clone();
+            Assert.AreEqual(policy.AllowList, clone.AllowList);
+            Assert.AreEqual(policy.DenyList, clone.DenyList);
+            Assert.AreNotSame(clone, policy);
+        }
+
+        [Test]
+        public void TestSetAllowList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.AllowList);
+            // Assigning null or an empty string leaves an empty, non-null list.
+            policy.AllowList = null;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = string.Empty;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+            
+            policy.AllowList = "*";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+            
+            policy.AllowList = "a,b,c";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+            Assert.AreEqual("a,b,c", policy.AllowList);
+        }
+        
+        [Test]
+        public void TestSetDenyList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.DenyList);
+            // Assigning null or an empty string leaves an empty, non-null list.
+            policy.DenyList = null;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = string.Empty;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+            
+            policy.DenyList = "*";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+            
+            policy.DenyList = "a,b,c";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+            Assert.AreEqual("a,b,c", policy.DenyList);
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/Apache-NMS-AMQP-Test/Provider/Amqp/AmqpNmsObjectMessageFacadeTest.cs b/test/Apache-NMS-AMQP-Test/Provider/Amqp/AmqpNmsObjectMessageFacadeTest.cs
index 2a9b224..17a7d80 100644
--- a/test/Apache-NMS-AMQP-Test/Provider/Amqp/AmqpNmsObjectMessageFacadeTest.cs
+++ b/test/Apache-NMS-AMQP-Test/Provider/Amqp/AmqpNmsObjectMessageFacadeTest.cs
@@ -358,6 +358,7 @@ namespace NMS.AMQP.Test.Provider.Amqp
             Assert.AreEqual(amqpObjectMessageFacade.Object, copy.Object);
         }
 
+        [Obsolete("Obsolete")]
         private static byte[] GetSerializedBytes(object content)
         {
             using (MemoryStream stream = new MemoryStream())


Reply via email to