This is an automated email from the ASF dual-hosted git repository. havret pushed a commit to branch allow_deny_list in repository https://gitbox.apache.org/repos/asf/activemq-nms-openwire.git
commit 6bedbb487cc4f321613bb68811c1656ebe9f52d7
Author: Havret <[email protected]>
AuthorDate: Wed Feb 15 20:03:17 2023 +0100

    AMQNET-829 Add allow, deny types support
---
 src/Commands/ActiveMQObjectMessage.cs       |   6 ++
 src/Commands/TrustedClassFilter.cs          |  72 +++++++++++++
 src/Connection.cs                           |   6 ++
 src/ConnectionFactory.cs                    |   6 ++
 src/INmsDeserializationPolicy.cs            |  42 ++++++++
 src/NmsDefaultDeserializationPolicy.cs      | 156 ++++++++++++++++++++++++++++
 test/MessageConsumerTest.cs                 |  87 ++++++++++++++++
 test/NMSConnectionFactoryTest.cs            |  25 ++++-
 test/NmsDefaultDeserializationPolicyTest.cs | 156 ++++++++++++++++++++++++++++
 9 files changed, 555 insertions(+), 1 deletion(-)

diff --git a/src/Commands/ActiveMQObjectMessage.cs b/src/Commands/ActiveMQObjectMessage.cs
index 3919111..aec88b5 100644
--- a/src/Commands/ActiveMQObjectMessage.cs
+++ b/src/Commands/ActiveMQObjectMessage.cs
@@ -140,6 +140,12 @@ namespace Apache.NMS.ActiveMQ.Commands
                 if (formatter == null)
                 {
                     formatter = new BinaryFormatter();
+                    // Guard against a message with no Connection assigned; a null policy means no type filtering.
+                    var deserializationPolicy = Connection?.DeserializationPolicy;
+                    if (deserializationPolicy != null)
+                    {
+                        formatter.Binder = new TrustedClassFilter(deserializationPolicy, Destination);
+                    }
                 }
                 return formatter;
             }
diff --git a/src/Commands/TrustedClassFilter.cs b/src/Commands/TrustedClassFilter.cs
new file mode 100644
index 0000000..eb90662
--- /dev/null
+++ b/src/Commands/TrustedClassFilter.cs
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Reflection;
+using System.Runtime.Serialization;
+
+namespace Apache.NMS.ActiveMQ.Commands
+{
+    /// <summary>
+    /// <see cref="SerializationBinder"/> that consults an <see cref="INmsDeserializationPolicy"/> before
+    /// allowing a type to be materialized while deserializing the body of an object message.
+    /// </summary>
+    internal class TrustedClassFilter : SerializationBinder
+    {
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private readonly IDestination destination;
+
+        /// <param name="deserializationPolicy">The policy deciding which types are trusted.</param>
+        /// <param name="destination">The destination of the message being deserialized; passed through to the policy.</param>
+        public TrustedClassFilter(INmsDeserializationPolicy deserializationPolicy, IDestination destination)
+        {
+            this.deserializationPolicy = deserializationPolicy ?? throw new ArgumentNullException(nameof(deserializationPolicy));
+            this.destination = destination;
+        }
+
+        /// <summary>
+        /// Resolves the requested type and returns it only if the policy trusts it.
+        /// </summary>
+        /// <exception cref="SerializationException">
+        /// The type cannot be resolved from the named assembly, or the policy does not trust it.
+        /// </exception>
+        public override Type BindToType(string assemblyName, string typeName)
+        {
+            // NOTE: the policy API works on Type, so the assembly named by the incoming payload
+            // has to be loaded before the policy can be consulted.
+            var assembly = Assembly.Load(new AssemblyName(assemblyName));
+            var type = FormatterServices.GetTypeFromAssembly(assembly, typeName);
+
+            // Never return null: BinaryFormatter treats a null binder result as "use default binding",
+            // which would resolve the type on its own and bypass the policy check entirely.
+            if (type == null)
+            {
+                throw new SerializationException($"Unable to resolve type {typeName} from assembly {assemblyName}.");
+            }
+
+            if (deserializationPolicy.IsTrustedType(destination, type))
+            {
+                return type;
+            }
+
+            var message = $"Forbidden {type.FullName}! " +
+                          "This type is not trusted to be deserialized under the current configuration. " +
+                          "Please refer to the documentation for more information on how to configure trusted types.";
+            throw new SerializationException(message);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Connection.cs b/src/Connection.cs
index df0cc57..a2d7df2 100644
--- a/src/Connection.cs
+++ b/src/Connection.cs
@@ -482,6 +482,12 @@ namespace Apache.NMS.ActiveMQ
             get { return this.compressionPolicy; }
             set { this.compressionPolicy = value; }
         }
+
+        /// <summary>
+        /// The policy used to decide which types may be deserialized from the body of an
+        /// incoming <see cref="IObjectMessage"/>. A null value disables type filtering.
+        /// </summary>
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
 
         internal MessageTransformation MessageTransformation
         {
diff --git a/src/ConnectionFactory.cs b/src/ConnectionFactory.cs
index 84ecf22..6c778cb 100644
--- a/src/ConnectionFactory.cs
+++ b/src/ConnectionFactory.cs
@@ -397,6 +397,11 @@ namespace Apache.NMS.ActiveMQ
                 }
             }
         }
+
+        /// <summary>
+        /// The deserialization policy cloned into each connection created by this factory; null disables type filtering.
+        /// </summary>
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
 
         public IdGenerator ClientIdGenerator
         {
@@ -546,6 +551,7 @@ namespace Apache.NMS.ActiveMQ
             connection.RedeliveryPolicy = this.redeliveryPolicy.Clone() as IRedeliveryPolicy;
             connection.PrefetchPolicy = this.prefetchPolicy.Clone() as PrefetchPolicy;
             connection.CompressionPolicy = this.compressionPolicy.Clone() as ICompressionPolicy;
+            connection.DeserializationPolicy = this.DeserializationPolicy?.Clone();
             connection.ConsumerTransformer = this.consumerTransformer;
             connection.ProducerTransformer = this.producerTransformer;
             connection.WatchTopicAdvisories = this.watchTopicAdvisories;
diff --git a/src/INmsDeserializationPolicy.cs b/src/INmsDeserializationPolicy.cs
new file mode 100644
index 0000000..743855d
--- /dev/null
+++ b/src/INmsDeserializationPolicy.cs
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Defines the interface for a policy that controls the permissible message content
+    /// during the deserialization of the body of an incoming <see cref="IObjectMessage"/>.
+    /// </summary>
+    public interface INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Determines if the given type is a trusted type that can be deserialized by the client.
+        /// </summary>
+        /// <param name="destination">The Destination for the message containing the type to be deserialized.</param>
+        /// <param name="type">The type of the object that is about to be read.</param>
+        /// <returns>True if the type is trusted, otherwise false.</returns>
+        bool IsTrustedType(IDestination destination, Type type);
+
+        /// <summary>
+        /// Makes a thread-safe copy of the INmsDeserializationPolicy object.
+        /// </summary>
+        /// <returns>A copy of the INmsDeserializationPolicy object.</returns>
+        INmsDeserializationPolicy Clone();
+    }
+}
\ No newline at end of file
diff --git a/src/NmsDefaultDeserializationPolicy.cs b/src/NmsDefaultDeserializationPolicy.cs
new file mode 100644
index 0000000..b56531a
--- /dev/null
+++ b/src/NmsDefaultDeserializationPolicy.cs
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Default implementation of the deserialization policy that can read allow and deny lists of
+    /// types/namespaces from the connection URI options.
+    ///
+    /// The policy reads a default deny list string value (comma separated) from the connection URI options
+    /// (nms.deserializationPolicy.denyList) which defaults to null which indicates an empty deny list.
+    ///
+    /// The policy reads a default allow list string value (comma separated) from the connection URI options
+    /// (nms.deserializationPolicy.allowList) which defaults to <see cref="CATCH_ALL_WILDCARD"/> which
+    /// indicates that all types are allowed.
+    ///
+    /// Entries are compared ordinally against <see cref="Type.FullName"/>. An entry matches a type whose
+    /// full name equals it, or any type whose namespace is the entry or one of its sub-namespaces.
+    /// Whitespace around entries is ignored and empty entries are dropped.
+    ///
+    /// The deny list overrides the allow list, entries that could match both are counted as denied.
+    ///
+    /// If the policy should treat all classes as untrusted, the deny list should be set to <see cref="CATCH_ALL_WILDCARD"/>.
+    /// </summary>
+    public class NmsDefaultDeserializationPolicy : INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Value used to indicate that all types should be allowed or denied
+        /// </summary>
+        public const string CATCH_ALL_WILDCARD = "*";
+
+        private IReadOnlyList<string> denyList = Array.Empty<string>();
+        private IReadOnlyList<string> allowList = new[] { CATCH_ALL_WILDCARD };
+
+        /// <inheritdoc />
+        /// <remarks>
+        /// A null type, or one without a full name, is treated as trusted.
+        /// </remarks>
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            var typeName = type?.FullName;
+            if (typeName == null)
+            {
+                return true;
+            }
+
+            // The deny list is consulted first so that it always wins over the allow list.
+            if (ContainsMatch(denyList, typeName))
+            {
+                return false;
+            }
+
+            // Failing outright rejection, only types matched by the allow list are trusted.
+            return ContainsMatch(allowList, typeName);
+        }
+
+        /// <summary>
+        /// Makes an independent copy of this policy, including its allow and deny lists.
+        /// </summary>
+        public INmsDeserializationPolicy Clone()
+        {
+            return new NmsDefaultDeserializationPolicy
+            {
+                allowList = allowList.ToArray(),
+                denyList = denyList.ToArray()
+            };
+        }
+
+        /// <summary>
+        /// Gets or sets the deny list on this policy instance as a comma separated string.
+        /// A null or blank value clears the list.
+        /// </summary>
+        public string DenyList
+        {
+            get => string.Join(",", denyList);
+            set => denyList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Gets or sets the allow list on this policy instance as a comma separated string.
+        /// A null or blank value clears the list, in which case every non-null type is rejected.
+        /// </summary>
+        public string AllowList
+        {
+            get => string.Join(",", allowList);
+            set => allowList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Returns true if any entry is the catch-all wildcard or matches the given type name.
+        /// </summary>
+        private static bool ContainsMatch(IReadOnlyList<string> entries, string typeName)
+        {
+            foreach (var entry in entries)
+            {
+                if (entry == CATCH_ALL_WILDCARD || IsTypeOrNamespaceMatch(typeName, entry))
+                {
+                    return true;
+                }
+            }
+
+            return false;
+        }
+
+        private static bool IsTypeOrNamespaceMatch(string typeName, string listEntry)
+        {
+            // Check if type is an exact match of the entry
+            if (typeName == listEntry)
+            {
+                return true;
+            }
+
+            // Check if the type is from a namespace matching the entry. Requiring a '.' right after the
+            // prefix stops e.g. "System.C" from matching "System.Collections.Queue".
+            var entryLength = listEntry.Length;
+            return typeName.Length > entryLength
+                   && typeName.StartsWith(listEntry, StringComparison.Ordinal)
+                   && typeName[entryLength] == '.';
+        }
+
+        /// <summary>
+        /// Splits a comma separated list into entries, trimming surrounding whitespace and dropping
+        /// empty entries so that values such as "System, MyApp" behave as expected.
+        /// </summary>
+        private static IReadOnlyList<string> ParseList(string value)
+        {
+            if (string.IsNullOrWhiteSpace(value))
+            {
+                return Array.Empty<string>();
+            }
+
+            return value.Split(',')
+                .Select(entry => entry.Trim())
+                .Where(entry => entry.Length > 0)
+                .ToArray();
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/MessageConsumerTest.cs b/test/MessageConsumerTest.cs
index e635434..4d59dcd 100644
--- a/test/MessageConsumerTest.cs
+++ b/test/MessageConsumerTest.cs
@@ -20,6 +20,7 @@ using Apache.NMS.Test;
 using NUnit.Framework;
 using Apache.NMS.ActiveMQ.Commands;
 using System;
+using System.Runtime.Serialization;
 using Apache.NMS.Util;
 
 namespace Apache.NMS.ActiveMQ.Test
@@ -306,5 +307,91 @@ namespace Apache.NMS.ActiveMQ.Test
                 }
             }
         }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldNotDeserializeUntrustedType()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new NmsDefaultDeserializationPolicy
+                {
+                    DenyList = typeof(UntrustedType).FullName
+                }
+            };
+            using var connection = factory.CreateConnection("", "");
+
+            connection.Start();
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive(TimeSpan.FromSeconds(10));
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            var exception = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+            Assert.AreEqual($"Forbidden {typeof(UntrustedType).FullName}! " +
+                            "This type is not trusted to be deserialized under the current configuration. " +
+                            "Please refer to the documentation for more information on how to configure trusted types.",
+                exception.Message);
+        }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldUseCustomDeserializationPolicy()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new CustomDeserializationPolicy()
+            };
+            using var connection = factory.CreateConnection("", "");
+            connection.Start();
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive(TimeSpan.FromSeconds(10));
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            _ = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+        }
+
+        [Serializable]
+        public class UntrustedType
+        {
+            public string Prop1 { get; set; }
+        }
+
+        private class CustomDeserializationPolicy : INmsDeserializationPolicy
+        {
+            public bool IsTrustedType(IDestination destination, Type type)
+            {
+                if (type == typeof(UntrustedType))
+                {
+                    return false;
+                }
+
+                return true;
+            }
+
+            public INmsDeserializationPolicy Clone()
+            {
+                return this;
+            }
+        }
     }
 }
diff --git a/test/NMSConnectionFactoryTest.cs b/test/NMSConnectionFactoryTest.cs
index 5a6ce48..5be6ff6 100644
--- a/test/NMSConnectionFactoryTest.cs
+++ b/test/NMSConnectionFactoryTest.cs
@@ -211,6 +211,29 @@ namespace Apache.NMS.ActiveMQ.Test
 
                 connection.Close();
             }
-		}
+        }
+
+        [Test]
+        public void TestSetDeserializationPolicy()
+        {
+            string baseUri = "activemq:tcp://${{activemqhost}}:61616";
+            string configuredUri = baseUri +
+                                   "?nms.deserializationPolicy.allowList=a,b,c" +
+                                   "&nms.deserializationPolicy.denyList=c,d,e";
+
+            var factory = new NMSConnectionFactory(NMSTestSupport.ReplaceEnvVar(configuredUri));
+
+            Assert.IsNotNull(factory);
+            Assert.IsNotNull(factory.ConnectionFactory);
+            using IConnection connection = factory.CreateConnection("", "");
+            Assert.IsNotNull(connection);
+            var amqConnection = connection as Connection;
+            Assert.IsNotNull(amqConnection);
+            var deserializationPolicy = amqConnection.DeserializationPolicy as NmsDefaultDeserializationPolicy;
+            Assert.IsNotNull(deserializationPolicy);
+            Assert.AreEqual("a,b,c", deserializationPolicy.AllowList);
+            Assert.AreEqual("c,d,e", deserializationPolicy.DenyList);
+            connection.Close();
+        }
     }
 }
diff --git a/test/NmsDefaultDeserializationPolicyTest.cs b/test/NmsDefaultDeserializationPolicyTest.cs
new file mode 100644
index 0000000..7ac772b
--- /dev/null
+++ b/test/NmsDefaultDeserializationPolicyTest.cs
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using Apache.NMS.Commands;
+using NUnit.Framework;
+
+namespace Apache.NMS.ActiveMQ.Test
+{
+    [TestFixture]
+    public class NmsDefaultDeserializationPolicyTest
+    {
+        [Test]
+        public void TestIsTrustedType()
+        {
+            var destination = new Queue("test-queue");
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+
+            // Only types in System
+            policy.AllowList = "System";
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+            Assert.False(policy.IsTrustedType(destination, GetType()));
+
+            // Entry must be complete namespace name prefix to match
+            // i.e. while "System.C" is a prefix of "System.Collections", this
+            // won't match the Queue class below.
+            policy.AllowList = "System.C";
+            Assert.False(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(System.Collections.Queue)));
+
+            // Add a non-core namespace
+            policy.AllowList = $"System,{GetType().Namespace}";
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, GetType()));
+
+            // Try with a type-specific entry
+            policy.AllowList = typeof(string).FullName;
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(bool)));
+
+            // Verify deny list overrides allow list
+            policy.AllowList = "System";
+            policy.DenyList = "System";
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list entry prefix overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = typeof(string).Namespace;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list catch-all overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicy()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            Assert.IsNotEmpty(policy.AllowList);
+            Assert.IsEmpty(policy.DenyList);
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicyClone()
+        {
+            var policy = new NmsDefaultDeserializationPolicy
+            {
+                AllowList = "a.b.c",
+                DenyList = "d.e.f"
+            };
+
+            var clone = (NmsDefaultDeserializationPolicy) policy.Clone();
+            Assert.AreEqual(policy.AllowList, clone.AllowList);
+            Assert.AreEqual(policy.DenyList, clone.DenyList);
+            Assert.AreNotSame(clone, policy);
+        }
+
+        [Test]
+        public void TestSetAllowList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.AllowList);
+
+            policy.AllowList = null;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = string.Empty;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = "*";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+
+            policy.AllowList = "a,b,c";
+            Assert.NotNull(policy.AllowList);
+            Assert.IsNotEmpty(policy.AllowList);
+            Assert.AreEqual("a,b,c", policy.AllowList);
+        }
+
+        [Test]
+        public void TestSetDenyList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.DenyList);
+
+            policy.DenyList = null;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = string.Empty;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = "*";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+
+            policy.DenyList = "a,b,c";
+            Assert.NotNull(policy.DenyList);
+            Assert.IsNotEmpty(policy.DenyList);
+            Assert.AreEqual("a,b,c", policy.DenyList);
+        }
+    }
+}
\ No newline at end of file
