Repository: thrift Updated Branches: refs/heads/master 3815e0b2d -> 1dc265301
THRIFT-3070 Add ability to set the LocalCertificateSelectionCallback Client: C# Patch: Hans-Peter Klett <[email protected]> This closes #415 Added an optional LocalCertificateSelectionCallback. Also cleans up the connection when a secure authentication fails on the server. Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/1dc26530 Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/1dc26530 Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/1dc26530 Branch: refs/heads/master Commit: 1dc265301d7d184438c163afd5bfd93918844603 Parents: 3815e0b Author: Jens Geyer <[email protected]> Authored: Sun Apr 5 19:13:29 2015 +0200 Committer: Jens Geyer <[email protected]> Committed: Sun Apr 5 19:27:19 2015 +0200 ---------------------------------------------------------------------- lib/csharp/src/Transport/TTLSServerSocket.cs | 23 +++++- lib/csharp/src/Transport/TTLSSocket.cs | 85 +++++++++++++++++++---- 2 files changed, 93 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/1dc26530/lib/csharp/src/Transport/TTLSServerSocket.cs ---------------------------------------------------------------------- diff --git a/lib/csharp/src/Transport/TTLSServerSocket.cs b/lib/csharp/src/Transport/TTLSServerSocket.cs index 2e2d299..631a593 100644 --- a/lib/csharp/src/Transport/TTLSServerSocket.cs +++ b/lib/csharp/src/Transport/TTLSServerSocket.cs @@ -60,6 +60,11 @@ namespace Thrift.Transport private RemoteCertificateValidationCallback clientCertValidator; /// <summary> + /// The function to determine which certificate to use. + /// </summary> + private LocalCertificateSelectionCallback localCertificateSelectionCallback; + + /// <summary> /// Initializes a new instance of the <see cref="TTLSServerSocket" /> class. /// </summary> /// <param name="port">The port where the server runs.</param> @@ -88,7 +93,14 @@ namespace Thrift.Transport /// <param name="useBufferedSockets">If set to <c>true</c> [use buffered sockets].</param> /// <param name="certificate">The certificate object.</param> /// <param name="clientCertValidator">The certificate validator.</param> - public TTLSServerSocket(int port, int clientTimeout, bool useBufferedSockets, X509Certificate2 certificate, RemoteCertificateValidationCallback clientCertValidator = null) + /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param> + public TTLSServerSocket( + int port, + int clientTimeout, + bool useBufferedSockets, + X509Certificate2 certificate, + RemoteCertificateValidationCallback clientCertValidator = null, + LocalCertificateSelectionCallback localCertificateSelectionCallback = null) { if (!certificate.HasPrivateKey) { @@ -99,6 +111,7 @@ namespace Thrift.Transport this.serverCertificate = certificate; this.useBufferedSockets = useBufferedSockets; this.clientCertValidator = clientCertValidator; + this.localCertificateSelectionCallback = localCertificateSelectionCallback; try { // Create server socket @@ -150,7 +163,13 @@ namespace Thrift.Transport client.SendTimeout = client.ReceiveTimeout = this.clientTimeout; //wrap the client in an SSL Socket passing in the SSL cert - TTLSSocket socket = new TTLSSocket(client, this.serverCertificate, true, this.clientCertValidator); + TTLSSocket socket = new TTLSSocket( + client, + this.serverCertificate, + true, + this.clientCertValidator, + this.localCertificateSelectionCallback + ); socket.setupTLS(); http://git-wip-us.apache.org/repos/asf/thrift/blob/1dc26530/lib/csharp/src/Transport/TTLSSocket.cs ---------------------------------------------------------------------- diff --git a/lib/csharp/src/Transport/TTLSSocket.cs b/lib/csharp/src/Transport/TTLSSocket.cs index ca8ee41..5652556 100644 --- a/lib/csharp/src/Transport/TTLSSocket.cs +++ b/lib/csharp/src/Transport/TTLSSocket.cs @@ -72,17 +72,29 @@ namespace Thrift.Transport private RemoteCertificateValidationCallback certValidator = null; /// <summary> + /// The function to determine which certificate to use. + /// </summary> + private LocalCertificateSelectionCallback localCertificateSelectionCallback; + + /// <summary> /// Initializes a new instance of the <see cref="TTLSSocket"/> class. /// </summary> /// <param name="client">An already created TCP-client</param> /// <param name="certificate">The certificate.</param> /// <param name="isServer">if set to <c>true</c> [is server].</param> /// <param name="certValidator">User defined cert validator.</param> - public TTLSSocket(TcpClient client, X509Certificate certificate, bool isServer = false, RemoteCertificateValidationCallback certValidator = null) + /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param> + public TTLSSocket( + TcpClient client, + X509Certificate certificate, + bool isServer = false, + RemoteCertificateValidationCallback certValidator = null, + LocalCertificateSelectionCallback localCertificateSelectionCallback = null) { this.client = client; this.certificate = certificate; this.certValidator = certValidator; + this.localCertificateSelectionCallback = localCertificateSelectionCallback; this.isServer = isServer; if (IsOpen) @@ -99,8 +111,14 @@ namespace Thrift.Transport /// <param name="port">The port.</param> /// <param name="certificatePath">The certificate path.</param> /// <param name="certValidator">User defined cert validator.</param> - public TTLSSocket(string host, int port, string certificatePath, RemoteCertificateValidationCallback certValidator = null) - : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator) + /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param> + public TTLSSocket( + string host, + int port, + string certificatePath, + RemoteCertificateValidationCallback certValidator = null, + LocalCertificateSelectionCallback localCertificateSelectionCallback = null) + : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator, localCertificateSelectionCallback) { } @@ -111,8 +129,14 @@ namespace Thrift.Transport /// <param name="port">The port.</param> /// <param name="certificate">The certificate.</param> /// <param name="certValidator">User defined cert validator.</param> - public TTLSSocket(string host, int port, X509Certificate certificate, RemoteCertificateValidationCallback certValidator = null) - : this(host, port, 0, certificate, certValidator) + /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param> + public TTLSSocket( + string host, + int port, + X509Certificate certificate, + RemoteCertificateValidationCallback certValidator = null, + LocalCertificateSelectionCallback localCertificateSelectionCallback = null) + : this(host, port, 0, certificate, certValidator, localCertificateSelectionCallback) { } @@ -124,13 +148,21 @@ namespace Thrift.Transport /// <param name="timeout">The timeout.</param> /// <param name="certificate">The certificate.</param> /// <param name="certValidator">User defined cert validator.</param> - public TTLSSocket(string host, int port, int timeout, X509Certificate certificate, RemoteCertificateValidationCallback certValidator = null) + /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param> + public TTLSSocket( + string host, + int port, + int timeout, + X509Certificate certificate, + RemoteCertificateValidationCallback certValidator = null, + LocalCertificateSelectionCallback localCertificateSelectionCallback = null) { this.host = host; this.port = port; this.timeout = timeout; this.certificate = certificate; this.certValidator = certValidator; + this.localCertificateSelectionCallback = localCertificateSelectionCallback; InitSocket(); } @@ -213,7 +245,7 @@ namespace Thrift.Transport /// <param name="chain">The certificate chain.</param> /// <param name="sslPolicyErrors">An enum, which lists all the errors from the .NET certificate check.</param> /// <returns></returns> - private bool CertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors) + private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors) { return (sslValidationErrors == SslPolicyErrors.None); } @@ -253,16 +285,43 @@ namespace Thrift.Transport /// </summary> public void setupTLS() { - this.secureStream = new SslStream(this.client.GetStream(), false, this.certValidator ?? CertificateValidator); - if (isServer) + RemoteCertificateValidationCallback validator = this.certValidator ?? DefaultCertificateValidator; + + if( this.localCertificateSelectionCallback != null) { - // Server authentication - this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, SslProtocols.Tls, true); + this.secureStream = new SslStream( + this.client.GetStream(), + false, + validator, + this.localCertificateSelectionCallback + ); } else { - // Client authentication - this.secureStream.AuthenticateAsClient(host, new X509CertificateCollection { certificate }, SslProtocols.Tls, true); + this.secureStream = new SslStream( + this.client.GetStream(), + false, + validator + ); + } + + try + { + if (isServer) + { + // Server authentication + this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, SslProtocols.Tls, true); + } + else + { + // Client authentication + this.secureStream.AuthenticateAsClient(host, new X509CertificateCollection { certificate }, SslProtocols.Tls, true); + } + } + catch (Exception) + { + this.Close(); + throw; } inputStream = this.secureStream;
