This is an automated email from the ASF dual-hosted git repository.
gkoszyk pushed a commit to branch io_uring_tpc
in repository https://gitbox.apache.org/repos/asf/iggy.git
The following commit(s) were added to refs/heads/io_uring_tpc by this push:
new bd0d43416 feat(io_uring): Add TLS support for iggy websocket server
and client (#2265)
bd0d43416 is described below
commit bd0d4341621ba564bf89b073f2a5c12e283fc7ae
Author: Krishna Vishal <[email protected]>
AuthorDate: Thu Oct 16 14:42:11 2025 +0530
feat(io_uring): Add TLS support for iggy websocket server and client (#2265)
---
Cargo.lock | 6 +
Cargo.toml | 2 +-
core/common/src/types/args/mod.rs | 16 ++
.../websocket_config/websocket_client_config.rs | 18 +-
.../websocket_client_config_builder.rs | 2 +-
.../websocket_connection_string_options.rs | 8 +-
core/configs/server.toml | 6 +
core/sdk/src/client_provider.rs | 4 +
core/sdk/src/websocket/mod.rs | 2 +
core/sdk/src/websocket/websocket_client.rs | 262 +++++++++++++++------
.../src/websocket/websocket_connection_stream.rs | 6 -
core/sdk/src/websocket/websocket_stream_kind.rs | 59 +++++
...tream.rs => websocket_tls_connection_stream.rs} | 55 +++--
core/server/src/binary/sender.rs | 7 +
core/server/src/configs/defaults.rs | 14 +-
core/server/src/configs/websocket.rs | 10 +
core/server/src/websocket/mod.rs | 2 +
core/server/src/websocket/websocket_server.rs | 22 +-
.../server/src/websocket/websocket_tls_listener.rs | 257 ++++++++++++++++++++
core/server/src/websocket/websocket_tls_sender.rs | 186 +++++++++++++++
examples/rust/src/shared/args.rs | 20 ++
21 files changed, 841 insertions(+), 123 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index f92a8002a..419b01320 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8403,8 +8403,12 @@ checksum =
"489a59b6730eda1b0171fcfda8b121f4bee2b35cba8645ca35c5f7ba3eb736c1"
dependencies = [
"futures-util",
"log",
+ "rustls",
+ "rustls-pki-types",
"tokio",
+ "tokio-rustls",
"tungstenite",
+ "webpki-roots 0.26.11",
]
[[package]]
@@ -8781,6 +8785,8 @@ dependencies = [
"httparse",
"log",
"rand 0.9.2",
+ "rustls",
+ "rustls-pki-types",
"sha1",
"thiserror 2.0.17",
"utf-8",
diff --git a/Cargo.toml b/Cargo.toml
index 247c2946a..fb61475ae 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -165,7 +165,7 @@ compio-ws = { git =
"https://github.com/krishvishal/compio-ws", branch = "main",
"rustls",
] }
tungstenite = "0.27.0"
-tokio-tungstenite = "0.27.0"
+tokio-tungstenite = { version = "0.27", features = ["rustls-tls-webpki-roots"]
}
tokio-rustls = "0.26.2"
tokio-util = { version = "0.7.16", features = ["compat"] }
toml = "0.9.5"
diff --git a/core/common/src/types/args/mod.rs
b/core/common/src/types/args/mod.rs
index e4e1a56f1..11f93194e 100644
--- a/core/common/src/types/args/mod.rs
+++ b/core/common/src/types/args/mod.rs
@@ -340,6 +340,18 @@ pub struct Args {
/// The optional heartbeat interval for the WebSocket transport
pub websocket_heartbeat_interval: String,
+
+ /// The optional TLS enabled for the WebSocket transport
+ pub websocket_tls_enabled: bool,
+
+ /// The optional TLS domain for the WebSocket transport
+ pub websocket_tls_domain: String,
+
+ /// The optional TLS CA file for the WebSocket transport
+ pub websocket_tls_ca_file: Option<String>,
+
+ /// The optional TLS validate certificate for the WebSocket transport
+ pub websocket_tls_validate_certificate: bool,
}
const QUIC_TRANSPORT: &str = "quic";
@@ -409,6 +421,10 @@ impl Default for Args {
websocket_reconnection_interval: "1s".to_string(),
websocket_reconnection_reestablish_after: "5s".to_string(),
websocket_heartbeat_interval: "5s".to_string(),
+ websocket_tls_enabled: false,
+ websocket_tls_domain: "localhost".to_string(),
+ websocket_tls_ca_file: None,
+ websocket_tls_validate_certificate: false,
}
}
}
diff --git
a/core/common/src/types/configuration/websocket_config/websocket_client_config.rs
b/core/common/src/types/configuration/websocket_config/websocket_client_config.rs
index 34cf20f53..a3250272b 100644
---
a/core/common/src/types/configuration/websocket_config/websocket_client_config.rs
+++
b/core/common/src/types/configuration/websocket_config/websocket_client_config.rs
@@ -36,6 +36,14 @@ pub struct WebSocketClientConfig {
pub heartbeat_interval: IggyDuration,
/// WebSocket-specific configuration.
pub ws_config: WebSocketConfig,
+ /// Whether tls is enabled
+ pub tls_enabled: bool,
+ /// The domain to use for TLS
+ pub tls_domain: String,
+ /// The path to the CA file for TLS
+ pub tls_ca_file: Option<String>,
+ /// Whether to validate the TLS certificate
+ pub tls_validate_certificate: bool,
}
/// WebSocket-specific configuration that maps to tungstenite options.
@@ -59,11 +67,15 @@ pub struct WebSocketConfig {
impl Default for WebSocketClientConfig {
fn default() -> Self {
WebSocketClientConfig {
- server_address: "127.0.0.1:8080".to_string(),
+ server_address: "127.0.0.1:8092".to_string(),
auto_login: AutoLogin::Disabled,
reconnection: WebSocketClientReconnectionConfig::default(),
heartbeat_interval: IggyDuration::from_str("5s").unwrap(),
ws_config: WebSocketConfig::default(),
+ tls_enabled: false,
+ tls_domain: "localhost".to_string(),
+ tls_ca_file: None,
+ tls_validate_certificate: false,
}
}
}
@@ -156,6 +168,10 @@ impl
From<ConnectionString<WebSocketConnectionStringOptions>> for WebSocketClien
reconnection:
connection_string.options().reconnection().to_owned(),
heartbeat_interval:
connection_string.options().heartbeat_interval(),
ws_config,
+ tls_enabled: options.tls_enabled(),
+ tls_domain: options.tls_domain().into(),
+ tls_ca_file: options.tls_ca_file().map(|s| s.to_string()),
+ tls_validate_certificate: options.tls_validate_certificate(),
}
}
}
diff --git
a/core/common/src/types/configuration/websocket_config/websocket_client_config_builder.rs
b/core/common/src/types/configuration/websocket_config/websocket_client_config_builder.rs
index ca581829b..2f5d18d9f 100644
---
a/core/common/src/types/configuration/websocket_config/websocket_client_config_builder.rs
+++
b/core/common/src/types/configuration/websocket_config/websocket_client_config_builder.rs
@@ -21,7 +21,7 @@ use std::net::SocketAddr;
/// Builder for the WebSocket client configuration.
/// Allows configuring the WebSocket client with custom settings or using
defaults:
-/// - `server_address`: Default is "127.0.0.1:8080"
+/// - `server_address`: Default is "127.0.0.1:8092"
/// - `auto_login`: Default is AutoLogin::Disabled.
/// - `reconnection`: Default is enabled unlimited retries and 1 second
interval.
/// - `heartbeat_interval`: Default is 5 seconds.
diff --git
a/core/common/src/types/configuration/websocket_config/websocket_connection_string_options.rs
b/core/common/src/types/configuration/websocket_config/websocket_connection_string_options.rs
index 60da518ea..85c5fb205 100644
---
a/core/common/src/types/configuration/websocket_config/websocket_connection_string_options.rs
+++
b/core/common/src/types/configuration/websocket_config/websocket_connection_string_options.rs
@@ -78,8 +78,8 @@ impl WebSocketConnectionStringOptions {
&self.tls_domain
}
- pub fn tls_ca_file(&self) -> &Option<String> {
- &self.tls_ca_file
+ pub fn tls_ca_file(&self) -> Option<&str> {
+ self.tls_ca_file.as_deref()
}
pub fn tls_validate_certificate(&self) -> bool {
@@ -201,9 +201,9 @@ impl Default for WebSocketConnectionStringOptions {
max_frame_size: None,
accept_unmasked_frames: None,
tls_enabled: false,
- tls_domain: "localhost".to_string(),
+ tls_domain: "".to_string(),
tls_ca_file: None,
- tls_validate_certificate: true,
+ tls_validate_certificate: false,
}
}
}
diff --git a/core/configs/server.toml b/core/configs/server.toml
index 2de2dd11f..d721bc033 100644
--- a/core/configs/server.toml
+++ b/core/configs/server.toml
@@ -605,3 +605,9 @@ cpu_allocation = "all"
enabled = true
address = "0.0.0.0:8092"
+
+[websocket.tls]
+enabled = true
+self_signed = true
+cert_file = "core/certs/iggy_cert.pem"
+key_file = "core/certs/iggy_key.pem"
diff --git a/core/sdk/src/client_provider.rs b/core/sdk/src/client_provider.rs
index eca6dd13d..71982878a 100644
--- a/core/sdk/src/client_provider.rs
+++ b/core/sdk/src/client_provider.rs
@@ -182,6 +182,10 @@ impl ClientProviderConfig {
AutoLogin::Disabled
},
ws_config: WebSocketConfig::default(),
+ tls_enabled: args.websocket_tls_enabled,
+ tls_domain: args.websocket_tls_domain,
+ tls_ca_file: args.websocket_tls_ca_file,
+ tls_validate_certificate:
args.websocket_tls_validate_certificate,
}));
}
}
diff --git a/core/sdk/src/websocket/mod.rs b/core/sdk/src/websocket/mod.rs
index 40e3aa46b..e94195686 100644
--- a/core/sdk/src/websocket/mod.rs
+++ b/core/sdk/src/websocket/mod.rs
@@ -19,3 +19,5 @@
pub mod websocket_client;
pub(crate) mod websocket_connection_stream;
pub(crate) mod websocket_stream;
+pub(crate) mod websocket_stream_kind;
+pub(crate) mod websocket_tls_connection_stream;
diff --git a/core/sdk/src/websocket/websocket_client.rs
b/core/sdk/src/websocket/websocket_client.rs
index 9c0aaa1c3..a5d1414c2 100644
--- a/core/sdk/src/websocket/websocket_client.rs
+++ b/core/sdk/src/websocket/websocket_client.rs
@@ -17,7 +17,11 @@
*/
use crate::websocket::websocket_connection_stream::WebSocketConnectionStream;
-use crate::{prelude::Client, websocket::websocket_stream::ConnectionStream};
+use crate::websocket::websocket_stream_kind::WebSocketStreamKind;
+use
crate::websocket::websocket_tls_connection_stream::WebSocketTlsConnectionStream;
+use rustls::{ClientConfig, pki_types::pem::PemObject};
+
+use crate::prelude::Client;
use async_broadcast::{Receiver, Sender, broadcast};
use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut};
@@ -31,7 +35,10 @@ use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::sleep;
-use tokio_tungstenite::{client_async_with_config,
tungstenite::client::IntoClientRequest};
+use tokio_tungstenite::{
+ Connector, client_async_with_config, connect_async_tls_with_config,
+ tungstenite::client::IntoClientRequest,
+};
use tracing::{debug, error, info, trace, warn};
const REQUEST_INITIAL_BYTES_LENGTH: usize = 4;
@@ -40,7 +47,7 @@ const NAME: &str = "WebSocket";
#[derive(Debug)]
pub struct WebSocketClient {
- stream: Arc<Mutex<Option<WebSocketConnectionStream>>>,
+ stream: Arc<Mutex<Option<WebSocketStreamKind>>>,
pub(crate) config: Arc<WebSocketClientConfig>,
pub(crate) state: Mutex<ClientState>,
client_address: Mutex<Option<SocketAddr>>,
@@ -175,9 +182,10 @@ impl WebSocketClient {
let mut retry_count = 0;
loop {
+ let protocol = if self.config.tls_enabled { "wss" } else { "ws" };
info!(
- "{NAME} client is connecting to server: {}...",
- self.config.server_address
+ "{NAME} client is connecting to server: {}://{}...",
+ protocol, self.config.server_address
);
self.set_state(ClientState::Connecting).await;
@@ -209,84 +217,24 @@ impl WebSocketClient {
IggyError::InvalidConfiguration
})?;
- let tcp_stream = match TcpStream::connect(&server_addr).await {
- Ok(stream) => stream,
- Err(error) => {
- error!(
- "Failed to connect to server: {}. Error: {}",
- self.config.server_address, error
- );
-
- if !self.config.reconnection.enabled {
- warn!("Automatic reconnection is disabled.");
+ let connection_stream = if self.config.tls_enabled {
+ match self.connect_tls(server_addr, &mut retry_count).await {
+ Ok(stream) => stream,
+ Err(IggyError::CannotEstablishConnection) => {
return Err(IggyError::CannotEstablishConnection);
}
-
- let unlimited_retries =
self.config.reconnection.max_retries.is_none();
- let max_retries =
self.config.reconnection.max_retries.unwrap_or_default();
- let max_retries_str = self
- .config
- .reconnection
- .max_retries
- .map(|r| r.to_string())
- .unwrap_or_else(|| "unlimited".to_string());
-
- let interval_str =
self.config.reconnection.interval.as_human_time_string();
- if unlimited_retries || retry_count < max_retries {
- retry_count += 1;
- info!(
- "Retrying to connect to server
({retry_count}/{max_retries_str}): {} in: {interval_str}",
- self.config.server_address,
- );
-
sleep(self.config.reconnection.interval.get_duration()).await;
- continue;
+ Err(_) => continue, // retry
+ }
+ } else {
+ match self.connect_plain(server_addr, &mut retry_count).await {
+ Ok(stream) => stream,
+ Err(IggyError::CannotEstablishConnection) => {
+ return Err(IggyError::CannotEstablishConnection);
}
-
- self.set_state(ClientState::Disconnected).await;
- self.publish_event(DiagnosticEvent::Disconnected).await;
- return Err(IggyError::CannotEstablishConnection);
+ Err(_) => continue, // retry
}
};
- let ws_url = format!("ws://{}", server_addr);
-
- let request = ws_url.into_client_request().map_err(|e| {
- error!("Failed to create WebSocket request: {}", e);
- IggyError::InvalidConfiguration
- })?;
-
- let tungstenite_config =
self.config.ws_config.to_tungstenite_config();
-
- let (websocket_stream, response) =
- match client_async_with_config(request, tcp_stream,
tungstenite_config).await {
- Ok(result) => result,
- Err(error) => {
- error!("WebSocket handshake failed: {}", error);
-
- if !self.config.reconnection.enabled {
- return Err(IggyError::WebSocketConnectionError);
- }
-
- let unlimited_retries =
self.config.reconnection.max_retries.is_none();
- let max_retries =
self.config.reconnection.max_retries.unwrap_or_default();
-
- if unlimited_retries || retry_count < max_retries {
- retry_count += 1;
-
sleep(self.config.reconnection.interval.get_duration()).await;
- continue;
- }
-
- return Err(IggyError::WebSocketConnectionError);
- }
- };
-
- debug!(
- "WebSocket connection established. Response status: {}",
- response.status()
- );
-
- let connection_stream =
WebSocketConnectionStream::new(server_addr, websocket_stream);
-
*self.stream.lock().await = Some(connection_stream);
*self.client_address.lock().await = Some(server_addr);
self.set_state(ClientState::Connected).await;
@@ -304,6 +252,166 @@ impl WebSocketClient {
}
}
+ async fn connect_plain(
+ &self,
+ server_addr: SocketAddr,
+ retry_count: &mut u32,
+ ) -> Result<WebSocketStreamKind, IggyError> {
+ let tcp_stream = match TcpStream::connect(&server_addr).await {
+ Ok(stream) => stream,
+ Err(error) => {
+ error!(
+ "Failed to connect to server: {}. Error: {}",
+ self.config.server_address, error
+ );
+ return self.handle_connection_error(retry_count).await;
+ }
+ };
+
+ let ws_url = format!("ws://{}", server_addr);
+ let request = ws_url.into_client_request().map_err(|e| {
+ error!("Failed to create WebSocket request: {}", e);
+ IggyError::InvalidConfiguration
+ })?;
+
+ let tungstenite_config = self.config.ws_config.to_tungstenite_config();
+
+ let (websocket_stream, response) =
+ match client_async_with_config(request, tcp_stream,
tungstenite_config).await {
+ Ok(result) => result,
+ Err(error) => {
+ error!("WebSocket handshake failed: {}", error);
+ return self.handle_connection_error(retry_count).await;
+ }
+ };
+
+ debug!(
+ "WebSocket connection established. Response status: {}",
+ response.status()
+ );
+
+ let connection_stream = WebSocketConnectionStream::new(server_addr,
websocket_stream);
+ Ok(WebSocketStreamKind::Plain(connection_stream))
+ }
+
+ async fn connect_tls(
+ &self,
+ server_addr: SocketAddr,
+ retry_count: &mut u32,
+ ) -> Result<WebSocketStreamKind, IggyError> {
+ let tls_config = self.build_tls_config()?;
+ let connector = Connector::Rustls(Arc::new(tls_config));
+
+ let domain = if !self.config.tls_domain.is_empty() {
+ self.config.tls_domain.clone()
+ } else {
+ server_addr.ip().to_string()
+ };
+
+ let ws_url = format!("wss://{}:{}", domain, server_addr.port());
+ let tungstenite_config = self.config.ws_config.to_tungstenite_config();
+
+ debug!("Initiating WebSocket TLS connection to: {}", ws_url);
+ println!("Initiating WebSocket TLS connection to: {}", ws_url);
+ println!("tungstenite_config: {:?}", tungstenite_config);
+ let (websocket_stream, response) =
+ match connect_async_tls_with_config(ws_url, tungstenite_config,
false, Some(connector))
+ .await
+ {
+ Ok(result) => result,
+ Err(error) => {
+ error!("WebSocket TLS handshake failed: {}", error);
+ return self.handle_connection_error(retry_count).await;
+ }
+ };
+
+ debug!(
+ "WebSocket TLS connection established. Response status: {}",
+ response.status()
+ );
+
+ let connection_stream = WebSocketTlsConnectionStream::new(server_addr,
websocket_stream);
+ Ok(WebSocketStreamKind::Tls(connection_stream))
+ }
+
+ fn build_tls_config(&self) -> Result<ClientConfig, IggyError> {
+ if rustls::crypto::CryptoProvider::get_default().is_none() {
+ let _ =
rustls::crypto::aws_lc_rs::default_provider().install_default();
+ }
+
+ let config = if self.config.tls_validate_certificate {
+ let mut root_cert_store = rustls::RootCertStore::empty();
+
+ if let Some(certificate_path) = &self.config.tls_ca_file {
+ // load CA certificates from file
+ for cert in
rustls::pki_types::CertificateDer::pem_file_iter(certificate_path)
+ .map_err(|error| {
+ error!("Failed to read the CA file:
{certificate_path}. {error}");
+ IggyError::InvalidTlsCertificatePath
+ })?
+ {
+ let certificate = cert.map_err(|error| {
+ error!("Failed to read a certificate from the CA file:
{certificate_path}. {error}");
+ IggyError::InvalidTlsCertificate
+ })?;
+ root_cert_store.add(certificate).map_err(|error| {
+ error!(
+ "Failed to add a certificate to the root
certificate store. {error}"
+ );
+ IggyError::InvalidTlsCertificate
+ })?;
+ }
+ } else {
+
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
+ }
+
+ rustls::ClientConfig::builder()
+ .with_root_certificates(root_cert_store)
+ .with_no_client_auth()
+ } else {
+ // skip certificate validation (development/self-signed certs)
+ use crate::tcp::tcp_tls_verifier::NoServerVerification;
+ rustls::ClientConfig::builder()
+ .dangerous()
+
.with_custom_certificate_verifier(Arc::new(NoServerVerification))
+ .with_no_client_auth()
+ };
+
+ Ok(config)
+ }
+
+ async fn handle_connection_error<T>(&self, retry_count: &mut u32) ->
Result<T, IggyError> {
+ if !self.config.reconnection.enabled {
+ warn!("Automatic reconnection is disabled.");
+ return Err(IggyError::CannotEstablishConnection);
+ }
+
+ let unlimited_retries = self.config.reconnection.max_retries.is_none();
+ let max_retries =
self.config.reconnection.max_retries.unwrap_or_default();
+ let max_retries_str = self
+ .config
+ .reconnection
+ .max_retries
+ .map(|r| r.to_string())
+ .unwrap_or_else(|| "unlimited".to_string());
+
+ let interval_str =
self.config.reconnection.interval.as_human_time_string();
+
+ if unlimited_retries || *retry_count < max_retries {
+ *retry_count += 1;
+ info!(
+ "Retrying to connect to server ({}/{}): {} in: {}",
+ retry_count, max_retries_str, self.config.server_address,
interval_str
+ );
+ sleep(self.config.reconnection.interval.get_duration()).await;
+ return Err(IggyError::Disconnected); // signal to retry
+ }
+
+ self.set_state(ClientState::Disconnected).await;
+ self.publish_event(DiagnosticEvent::Disconnected).await;
+ Err(IggyError::CannotEstablishConnection)
+ }
+
async fn auto_login(&self) -> Result<(), IggyError> {
let client_address = self.get_client_address_value().await;
match &self.config.auto_login {
diff --git a/core/sdk/src/websocket/websocket_connection_stream.rs
b/core/sdk/src/websocket/websocket_connection_stream.rs
index e71a47283..b61320383 100644
--- a/core/sdk/src/websocket/websocket_connection_stream.rs
+++ b/core/sdk/src/websocket/websocket_connection_stream.rs
@@ -142,12 +142,6 @@ impl ConnectionStream for WebSocketConnectionStream {
requested_bytes, self.client_address
);
- debug!(
- "WebSocket read {} bytes: {:02x?}",
- requested_bytes,
- &buf[..requested_bytes.min(16)] // Log first 16 bytes
- );
-
Ok(requested_bytes)
}
diff --git a/core/sdk/src/websocket/websocket_stream_kind.rs
b/core/sdk/src/websocket/websocket_stream_kind.rs
new file mode 100644
index 000000000..b30c2ce54
--- /dev/null
+++ b/core/sdk/src/websocket/websocket_stream_kind.rs
@@ -0,0 +1,59 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::websocket::websocket_connection_stream::WebSocketConnectionStream;
+use crate::websocket::websocket_stream::ConnectionStream;
+use
crate::websocket::websocket_tls_connection_stream::WebSocketTlsConnectionStream;
+use iggy_common::IggyError;
+
+#[derive(Debug)]
+#[allow(clippy::large_enum_variant)]
+pub(crate) enum WebSocketStreamKind {
+ Plain(WebSocketConnectionStream),
+ Tls(WebSocketTlsConnectionStream),
+}
+
+impl WebSocketStreamKind {
+ pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, IggyError> {
+ match self {
+ Self::Plain(stream) => stream.read(buf).await,
+ Self::Tls(stream) => stream.read(buf).await,
+ }
+ }
+
+ pub async fn write(&mut self, buf: &[u8]) -> Result<(), IggyError> {
+ match self {
+ Self::Plain(stream) => stream.write(buf).await,
+ Self::Tls(stream) => stream.write(buf).await,
+ }
+ }
+
+ pub async fn flush(&mut self) -> Result<(), IggyError> {
+ match self {
+ Self::Plain(stream) => stream.flush().await,
+ Self::Tls(stream) => stream.flush().await,
+ }
+ }
+
+ pub async fn shutdown(&mut self) -> Result<(), IggyError> {
+ match self {
+ Self::Plain(stream) => stream.shutdown().await,
+ Self::Tls(stream) => stream.shutdown().await,
+ }
+ }
+}
diff --git a/core/sdk/src/websocket/websocket_connection_stream.rs
b/core/sdk/src/websocket/websocket_tls_connection_stream.rs
similarity index 80%
copy from core/sdk/src/websocket/websocket_connection_stream.rs
copy to core/sdk/src/websocket/websocket_tls_connection_stream.rs
index e71a47283..8f954f191 100644
--- a/core/sdk/src/websocket/websocket_connection_stream.rs
+++ b/core/sdk/src/websocket/websocket_tls_connection_stream.rs
@@ -24,18 +24,21 @@ use iggy_common::IggyError;
use std::io::ErrorKind;
use std::net::SocketAddr;
use tokio::net::TcpStream;
-use tokio_tungstenite::{WebSocketStream, tungstenite::Message};
+use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
use tracing::{debug, error, trace};
#[derive(Debug)]
-pub struct WebSocketConnectionStream {
+pub struct WebSocketTlsConnectionStream {
client_address: SocketAddr,
- stream: WebSocketStream<TcpStream>,
+ stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
read_buffer: BytesMut,
}
-impl WebSocketConnectionStream {
- pub fn new(client_address: SocketAddr, stream: WebSocketStream<TcpStream>)
-> Self {
+impl WebSocketTlsConnectionStream {
+ pub fn new(
+ client_address: SocketAddr,
+ stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
+ ) -> Self {
Self {
client_address,
stream,
@@ -49,7 +52,7 @@ impl WebSocketConnectionStream {
match self.stream.next().await {
Some(Ok(Message::Binary(data))) => {
trace!(
- "Received WebSocket binary message from {}, size: {}
bytes",
+ "Received WebSocket TLS binary message from {}, size:
{} bytes",
self.client_address,
data.len()
);
@@ -58,7 +61,7 @@ impl WebSocketConnectionStream {
}
Some(Ok(Message::Text(text))) => {
trace!(
- "Received WebSocket text message from {}, converting
to binary",
+ "Received WebSocket TLS text message from {},
converting to binary",
self.client_address
);
self.read_buffer.extend_from_slice(text.as_bytes());
@@ -66,12 +69,12 @@ impl WebSocketConnectionStream {
}
Some(Ok(Message::Ping(data))) => {
trace!(
- "Received WebSocket ping from {}, sending pong",
+ "Received WebSocket TLS ping from {}, sending pong",
self.client_address
);
if let Err(e) =
self.stream.send(Message::Pong(data)).await {
error!(
- "Failed to send WebSocket pong to {}: {}",
+ "Failed to send WebSocket TLS pong to {}: {}",
self.client_address, e
);
return Err(IggyError::WebSocketSendError);
@@ -79,23 +82,22 @@ impl WebSocketConnectionStream {
continue;
}
Some(Ok(Message::Pong(_))) => {
- trace!("Received WebSocket pong from {}",
self.client_address);
+ trace!("Received WebSocket TLS pong from {}",
self.client_address);
continue;
}
Some(Ok(Message::Close(_))) => {
debug!(
- "WebSocket connection closed by client: {}",
+ "WebSocket TLS connection closed by client: {}",
self.client_address
);
return Err(IggyError::ConnectionClosed);
}
Some(Ok(Message::Frame(_))) => {
- // Raw frames - just continue
continue;
}
Some(Err(e)) => {
error!(
- "Failed to read WebSocket message from {}: {}",
+ "Failed to read WebSocket TLS message from {}: {}",
self.client_address, e
);
return match e {
@@ -113,7 +115,10 @@ impl WebSocketConnectionStream {
};
}
None => {
- debug!("WebSocket stream ended for client: {}",
self.client_address);
+ debug!(
+ "WebSocket TLS stream ended for client: {}",
+ self.client_address
+ );
return Err(IggyError::ConnectionClosed);
}
}
@@ -122,7 +127,7 @@ impl WebSocketConnectionStream {
}
#[async_trait]
-impl ConnectionStream for WebSocketConnectionStream {
+impl ConnectionStream for WebSocketTlsConnectionStream {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, IggyError> {
let requested_bytes = buf.len();
@@ -138,28 +143,22 @@ impl ConnectionStream for WebSocketConnectionStream {
let _consumed = self.read_buffer.split_to(requested_bytes);
trace!(
- "Read {} bytes from WebSocket stream for client: {}",
+ "Read {} bytes from WebSocket TLS stream for client: {}",
requested_bytes, self.client_address
);
- debug!(
- "WebSocket read {} bytes: {:02x?}",
- requested_bytes,
- &buf[..requested_bytes.min(16)] // Log first 16 bytes
- );
-
Ok(requested_bytes)
}
async fn write(&mut self, buf: &[u8]) -> Result<(), IggyError> {
trace!(
- "Writing {} bytes to WebSocket stream for client: {}",
+ "Writing {} bytes to WebSocket TLS stream for client: {}",
buf.len(),
self.client_address
);
debug!(
- "WebSocket write {} bytes: {:02x?}",
+ "WebSocket TLS write {} bytes: {:02x?}",
buf.len(),
&buf[..buf.len().min(16)]
);
@@ -169,7 +168,7 @@ impl ConnectionStream for WebSocketConnectionStream {
.await
.map_err(|e| {
error!(
- "Failed to write data to WebSocket connection for client:
{}: {}",
+ "Failed to write data to WebSocket TLS connection for
client: {}: {}",
self.client_address, e
);
match e {
@@ -190,7 +189,7 @@ impl ConnectionStream for WebSocketConnectionStream {
async fn flush(&mut self) -> Result<(), IggyError> {
trace!(
- "Flushing WebSocket stream for client: {}",
+ "Flushing WebSocket TLS stream for client: {}",
self.client_address
);
Ok(())
@@ -198,7 +197,7 @@ impl ConnectionStream for WebSocketConnectionStream {
async fn shutdown(&mut self) -> Result<(), IggyError> {
debug!(
- "Shutting down WebSocket connection for client: {}",
+ "Shutting down WebSocket TLS connection for client: {}",
self.client_address
);
@@ -210,7 +209,7 @@ impl ConnectionStream for WebSocketConnectionStream {
self.stream.send(close_message).await.map_err(|e| {
error!(
- "Failed to send close frame to WebSocket connection for
client: {}: {}",
+ "Failed to send close frame to WebSocket TLS connection for
client: {}: {}",
self.client_address, e
);
IggyError::WebSocketCloseError
diff --git a/core/server/src/binary/sender.rs b/core/server/src/binary/sender.rs
index eac21de20..81112d01f 100644
--- a/core/server/src/binary/sender.rs
+++ b/core/server/src/binary/sender.rs
@@ -20,6 +20,7 @@ use crate::streaming::utils::PooledBuffer;
use crate::tcp::tcp_sender::TcpSender;
use crate::tcp::tcp_tls_sender::TcpTlsSender;
use crate::websocket::websocket_sender::WebSocketSender;
+use crate::websocket::websocket_tls_sender::WebSocketTlsSender;
use crate::{quic::quic_sender::QuicSender, server_error::ServerError};
use compio::buf::IoBufMut;
use compio::net::TcpStream;
@@ -47,6 +48,7 @@ macro_rules! forward_async_methods {
Self::TcpTls(s) => s.$method_name$(::<$($generic),+>)?($(
$arg ),*).await,
Self::Quic(s) => s.$method_name$(::<$($generic),+>)?($(
$arg ),*).await,
Self::WebSocket(s) =>
s.$method_name$(::<$($generic),+>)?($( $arg ),*).await,
+ Self::WebSocketTls(s) =>
s.$method_name$(::<$($generic),+>)?($( $arg ),*).await,
}
}
)*
@@ -75,6 +77,7 @@ pub enum SenderKind {
TcpTls(TcpTlsSender),
Quic(QuicSender),
WebSocket(WebSocketSender),
+ WebSocketTls(WebSocketTlsSender),
}
impl SenderKind {
@@ -97,6 +100,10 @@ impl SenderKind {
Self::WebSocket(stream)
}
+ pub fn get_websocket_tls_sender(stream: WebSocketTlsSender) -> Self {
+ Self::WebSocketTls(stream)
+ }
+
forward_async_methods! {
async fn read<B: IoBufMut>(&mut self, buffer: B) -> (Result<(),
IggyError>, B);
async fn send_empty_ok_response(&mut self) -> Result<(), IggyError>;
diff --git a/core/server/src/configs/defaults.rs
b/core/server/src/configs/defaults.rs
index bb14600f9..7aed1e606 100644
--- a/core/server/src/configs/defaults.rs
+++ b/core/server/src/configs/defaults.rs
@@ -35,7 +35,7 @@ use crate::configs::system::{
StateConfig, StreamConfig, SystemConfig, TopicConfig,
};
use crate::configs::tcp::{TcpConfig, TcpTlsConfig};
-use crate::configs::websocket::WebSocketConfig;
+use crate::configs::websocket::{WebSocketConfig, WebSocketTlsConfig};
use iggy_common::IggyByteSize;
use iggy_common::IggyDuration;
use std::sync::Arc;
@@ -184,6 +184,18 @@ impl Default for WebSocketConfig {
max_message_size: None,
max_frame_size: None,
accept_unmasked_frames: false,
+ tls: WebSocketTlsConfig::default(),
+ }
+ }
+}
+
+impl Default for WebSocketTlsConfig {
+ fn default() -> WebSocketTlsConfig {
+ WebSocketTlsConfig {
+ enabled: SERVER_CONFIG.websocket.tls.enabled,
+ self_signed: SERVER_CONFIG.websocket.tls.self_signed,
+ cert_file: SERVER_CONFIG.websocket.tls.cert_file.parse().unwrap(),
+ key_file: SERVER_CONFIG.websocket.tls.key_file.parse().unwrap(),
}
}
}
diff --git a/core/server/src/configs/websocket.rs
b/core/server/src/configs/websocket.rs
index 7841edfbe..541e8253e 100644
--- a/core/server/src/configs/websocket.rs
+++ b/core/server/src/configs/websocket.rs
@@ -37,6 +37,16 @@ pub struct WebSocketConfig {
pub max_frame_size: Option<String>,
#[serde(default)]
pub accept_unmasked_frames: bool,
+ #[serde(default)]
+ pub tls: WebSocketTlsConfig,
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct WebSocketTlsConfig {
+ pub enabled: bool,
+ pub self_signed: bool,
+ pub cert_file: String,
+ pub key_file: String,
}
impl WebSocketConfig {
diff --git a/core/server/src/websocket/mod.rs b/core/server/src/websocket/mod.rs
index 745b48cf2..b4ede3067 100644
--- a/core/server/src/websocket/mod.rs
+++ b/core/server/src/websocket/mod.rs
@@ -2,5 +2,7 @@ pub mod connection_handler;
pub mod websocket_listener;
pub mod websocket_sender;
pub mod websocket_server;
+pub mod websocket_tls_listener;
+pub mod websocket_tls_sender;
pub const COMPONENT: &str = "WEBSOCKET";
diff --git a/core/server/src/websocket/websocket_server.rs
b/core/server/src/websocket/websocket_server.rs
index 6edfe7a13..b95b3361d 100644
--- a/core/server/src/websocket/websocket_server.rs
+++ b/core/server/src/websocket/websocket_server.rs
@@ -18,7 +18,8 @@
use crate::shard::IggyShard;
use crate::shard::task_registry::ShutdownToken;
-use crate::websocket::websocket_listener::start;
+use crate::websocket::websocket_listener;
+use crate::websocket::websocket_tls_listener;
use crate::{shard_error, shard_info};
use iggy_common::IggyError;
use std::rc::Rc;
@@ -34,17 +35,30 @@ pub async fn spawn_websocket_server(
return Ok(());
}
+ let server_name = if config.tls.enabled {
+ "WebSocket TLS"
+ } else {
+ "WebSocket"
+ };
+
shard_info!(
shard.id,
- "Starting WebSocket server on: {} for shard: {}...",
+ "Starting {} server on: {} for shard: {}...",
+ server_name,
config.address,
shard.id
);
- if let Err(error) = start(config, shard.clone(), shutdown).await {
+ let result = match config.tls.enabled {
+ true => websocket_tls_listener::start(config, shard.clone(),
shutdown).await,
+ false => websocket_listener::start(config, shard.clone(),
shutdown).await,
+ };
+
+ if let Err(error) = result {
shard_error!(
shard.id,
- "WebSocket server has failed to start, error: {error}"
+ "{} server has failed to start, error: {error}",
+ server_name
);
return Err(error);
}
diff --git a/core/server/src/websocket/websocket_tls_listener.rs
b/core/server/src/websocket/websocket_tls_listener.rs
new file mode 100644
index 000000000..4341a36e4
--- /dev/null
+++ b/core/server/src/websocket/websocket_tls_listener.rs
@@ -0,0 +1,257 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::binary::sender::SenderKind;
+use crate::configs::websocket::WebSocketConfig;
+use crate::shard::IggyShard;
+use crate::shard::task_registry::ShutdownToken;
+use crate::shard::transmission::event::ShardEvent;
+use crate::websocket::connection_handler::{handle_connection, handle_error};
+use crate::{shard_debug, shard_error, shard_info, shard_warn};
+use compio::net::TcpListener;
+use compio_tls::TlsAcceptor;
+use compio_ws::accept_async_with_config;
+use error_set::ErrContext;
+use futures::FutureExt;
+use iggy_common::{IggyError, TransportProtocol};
+use rustls::ServerConfig;
+use rustls::pki_types::{CertificateDer, PrivateKeyDer};
+use rustls_pemfile::{certs, private_key};
+use std::io::BufReader;
+use std::net::SocketAddr;
+use std::rc::Rc;
+use std::sync::Arc;
+use tracing::{error, info, trace};
+
+pub async fn start(
+ config: WebSocketConfig,
+ shard: Rc<IggyShard>,
+ shutdown: ShutdownToken,
+) -> Result<(), IggyError> {
+ let addr: SocketAddr = config
+ .address
+ .parse()
+ .with_error_context(|error| {
+ format!(
+ "WebSocket TLS (error: {error}) - failed to parse address: {}",
+ config.address
+ )
+ })
+ .map_err(|_| IggyError::InvalidConfiguration)?;
+
+ let listener = TcpListener::bind(addr)
+ .await
+ .with_error_context(|error| {
+ format!("WebSocket TLS (error: {error}) - failed to bind to
address: {addr}")
+ })
+ .map_err(|_| IggyError::CannotBindToSocket(addr.to_string()))?;
+
+ let local_addr = listener.local_addr().unwrap();
+
+ // Notify shard about the bound address
+ let event = ShardEvent::AddressBound {
+ protocol: TransportProtocol::WebSocket,
+ address: local_addr,
+ };
+ shard.handle_event(event).await.ok();
+
+ // Ensure rustls crypto provider is installed
+ if rustls::crypto::CryptoProvider::get_default().is_none()
+ && let Err(e) =
rustls::crypto::ring::default_provider().install_default()
+ {
+ shard_warn!(
+ shard.id,
+ "Failed to install rustls crypto provider: {:?}. This may be
normal if another thread installed it first.",
+ e
+ );
+ } else {
+ trace!("Rustls crypto provider installed or already present");
+ }
+
+ // Load or generate TLS certificates
+ let tls_config = &shard.config.tcp.tls;
+ let (certs, key) =
+ if tls_config.self_signed &&
!std::path::Path::new(&tls_config.cert_file).exists() {
+ shard_info!(
+ shard.id,
+ "Generating self-signed certificate for WebSocket TLS server"
+ );
+ generate_self_signed_cert()
+ .unwrap_or_else(|e| panic!("Failed to generate self-signed
certificate: {e}"))
+ } else {
+ shard_info!(
+ shard.id,
+ "Loading certificates from cert_file: {}, key_file: {}",
+ tls_config.cert_file,
+ tls_config.key_file
+ );
+ load_certificates(&tls_config.cert_file, &tls_config.key_file)
+ .unwrap_or_else(|e| panic!("Failed to load certificates: {e}"))
+ };
+
+ let server_config = ServerConfig::builder()
+ .with_no_client_auth()
+ .with_single_cert(certs, key)
+ .unwrap_or_else(|e| panic!("Unable to create TLS server config: {e}"));
+
+ let acceptor = TlsAcceptor::from(Arc::new(server_config));
+
+ shard_info!(
+ shard.id,
+ "{} has started on: wss://{}",
+ "WebSocket TLS Server",
+ local_addr
+ );
+
+ let ws_config = config.to_tungstenite_config();
+ shard_info!(
+ shard.id,
+ "WebSocket TLS config: max_message_size: {:?}, max_frame_size: {:?},
accept_unmasked_frames: {}",
+ config.max_message_size,
+ config.max_frame_size,
+ config.accept_unmasked_frames
+ );
+
+ let result = accept_loop(listener, acceptor, ws_config, shard.clone(),
shutdown).await;
+
+ shard_info!(
+ shard.id,
+ "WebSocket TLS listener task exiting with result: {:?}",
+ result
+ );
+
+ result
+}
+
+async fn accept_loop(
+ listener: TcpListener,
+ acceptor: TlsAcceptor,
+ ws_config: Option<compio_ws::WebSocketConfig>,
+ shard: Rc<IggyShard>,
+ shutdown: ShutdownToken,
+) -> Result<(), IggyError> {
+ shard_info!(
+ shard.id,
+ "WebSocket TLS accept loop started, waiting for connections..."
+ );
+
+ loop {
+ let shard = shard.clone();
+ let ws_config = ws_config.clone();
+ let acceptor = acceptor.clone();
+ let accept_future = listener.accept();
+
+ futures::select! {
+ _ = shutdown.wait().fuse() => {
+ shard_debug!(shard.id, "WebSocket TLS Server received shutdown
signal, no longer accepting connections");
+ break;
+ }
+ result = accept_future.fuse() => {
+ match result {
+ Ok((tcp_stream, remote_addr)) => {
+ if shard.is_shutting_down() {
+ shard_info!(shard.id, "Rejecting new WebSocket TLS
connection from {} during shutdown", remote_addr);
+ continue;
+ }
+ shard_info!(shard.id, "Accepted new TCP connection for
WebSocket TLS handshake from: {}", remote_addr);
+
+ let shard_clone = shard.clone();
+ let ws_config_clone = ws_config.clone();
+ let registry = shard.task_registry.clone();
+ let registry_clone = registry.clone();
+
+ registry.spawn_connection(async move {
+ match acceptor.accept(tcp_stream).await {
+ Ok(tls_stream) => {
+ shard_info!(shard_clone.id, "TLS handshake
successful for {}, performing WebSocket upgrade...", remote_addr);
+
+ match accept_async_with_config(tls_stream,
ws_config_clone).await {
+ Ok(websocket) => {
+ info!("WebSocket TLS handshake
successful from: {}", remote_addr);
+
+ let session =
shard_clone.add_client(&remote_addr, TransportProtocol::WebSocket);
+ let client_id = session.client_id;
+
shard_clone.add_active_session(session.clone());
+
+ let event = ShardEvent::NewSession
{
+ address: remote_addr,
+ transport:
TransportProtocol::WebSocket,
+ };
+ let _ =
shard_clone.broadcast_event_to_all_shards(event).await;
+
+ let sender =
crate::websocket::websocket_tls_sender::WebSocketTlsSender::new(websocket);
+ let mut sender_kind =
SenderKind::WebSocketTls(sender);
+ let client_stop_receiver =
registry_clone.add_connection(client_id);
+
+ if let Err(error) =
handle_connection(&session, &mut sender_kind, &shard_clone,
client_stop_receiver).await {
+ handle_error(error);
+ }
+
registry_clone.remove_connection(&client_id);
+
+ if let Err(error) =
sender_kind.shutdown().await {
+ shard_error!(shard_clone.id,
"Failed to shutdown WebSocket TLS stream for client: {}, address: {}. {}",
client_id, remote_addr, error);
+ } else {
+ shard_info!(shard_clone.id,
"Successfully closed WebSocket TLS stream for client: {}, address: {}.",
client_id, remote_addr);
+ }
+ }
+ Err(error) => {
+ error!("WebSocket handshake failed
on TLS connection from {}: {:?}", remote_addr, error);
+ }
+ }
+ }
+ Err(error) => {
+ error!("TLS handshake failed from {}:
{:?}", remote_addr, error);
+ }
+ }
+ });
+ }
+ Err(error) => {
+ shard_error!(shard.id, "Failed to accept WebSocket TLS
connection: {}", error);
+ }
+ }
+ }
+ }
+ }
+
+ shard_info!(shard.id, "WebSocket TLS Server listener has stopped");
+ Ok(())
+}
+
+fn generate_self_signed_cert()
+-> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), Box<dyn
std::error::Error>> {
+ iggy_common::generate_self_signed_certificate("localhost")
+}
+
+fn load_certificates(
+ cert_file: &str,
+ key_file: &str,
+) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), Box<dyn
std::error::Error>> {
+ let cert_file = std::fs::File::open(cert_file)?;
+ let mut cert_reader = BufReader::new(cert_file);
+ let certs: Vec<_> = certs(&mut cert_reader).collect::<Result<Vec<_>,
_>>()?;
+
+ if certs.is_empty() {
+ return Err("No certificates found in certificate file".into());
+ }
+
+ let key_file = std::fs::File::open(key_file)?;
+ let mut key_reader = BufReader::new(key_file);
+ let key = private_key(&mut key_reader)?.ok_or("No private key found in key
file")?;
+
+ Ok((certs, key))
+}
diff --git a/core/server/src/websocket/websocket_tls_sender.rs
b/core/server/src/websocket/websocket_tls_sender.rs
new file mode 100644
index 000000000..06f77161b
--- /dev/null
+++ b/core/server/src/websocket/websocket_tls_sender.rs
@@ -0,0 +1,186 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::binary::sender::Sender;
+use crate::server_error::ServerError;
+use crate::streaming::utils::PooledBuffer;
+use bytes::{BufMut, BytesMut};
+use compio::buf::IoBufMut;
+use compio::net::TcpStream;
+use compio_tls::TlsStream;
+use compio_ws::TungsteniteError;
+use compio_ws::{WebSocketMessage as Message, WebSocketStream};
+use iggy_common::IggyError;
+use std::ptr;
+use tracing::debug;
+
+const READ_BUFFER_CAPACITY: usize = 8192;
+const WRITE_BUFFER_CAPACITY: usize = 8192;
+const STATUS_OK: &[u8] = &[0; 4];
+
+pub struct WebSocketTlsSender {
+ pub(crate) stream: WebSocketStream<TlsStream<TcpStream>>,
+ pub(crate) read_buffer: BytesMut,
+ pub(crate) write_buffer: BytesMut,
+}
+
+impl WebSocketTlsSender {
+ pub fn new(stream: WebSocketStream<TlsStream<TcpStream>>) -> Self {
+ Self {
+ stream,
+ read_buffer: BytesMut::with_capacity(READ_BUFFER_CAPACITY),
+ write_buffer: BytesMut::with_capacity(WRITE_BUFFER_CAPACITY),
+ }
+ }
+
+ async fn flush_write_buffer(&mut self) -> Result<(), IggyError> {
+ if self.write_buffer.is_empty() {
+ return Ok(());
+ }
+ let data = self.write_buffer.split().freeze();
+ self.stream
+ .send(Message::Binary(data))
+ .await
+ .map_err(|_| IggyError::TcpError)
+ }
+}
+
+impl std::fmt::Debug for WebSocketTlsSender {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("WebSocketTlsSender").finish()
+ }
+}
+
+impl Sender for WebSocketTlsSender {
+ async fn read<B: IoBufMut>(&mut self, mut buffer: B) -> (Result<(),
IggyError>, B) {
+ let required_len = buffer.buf_capacity();
+ if required_len == 0 {
+ return (Ok(()), buffer);
+ }
+
+ while self.read_buffer.len() < required_len {
+ match self.stream.read().await {
+ Ok(Message::Binary(data)) => {
+ self.read_buffer.extend_from_slice(&data);
+ }
+ Ok(Message::Close(_)) => {
+ return (Err(IggyError::ConnectionClosed), buffer);
+ }
+ Ok(Message::Ping(data)) => {
+ if self.stream.send(Message::Pong(data)).await.is_err() {
+ return (Err(IggyError::ConnectionClosed), buffer);
+ }
+ }
+ Ok(_) => { /* Ignore other message types */ }
+ Err(_) => {
+ return (Err(IggyError::ConnectionClosed), buffer);
+ }
+ }
+ }
+
+ let data_to_copy = self.read_buffer.split_to(required_len);
+
+ unsafe {
+ ptr::copy_nonoverlapping(data_to_copy.as_ptr(),
buffer.as_buf_mut_ptr(), required_len);
+ buffer.set_buf_init(required_len);
+ }
+
+ (Ok(()), buffer)
+ }
+
+ async fn send_empty_ok_response(&mut self) -> Result<(), IggyError> {
+ self.send_ok_response(&[]).await
+ }
+
+ async fn send_ok_response(&mut self, payload: &[u8]) -> Result<(),
IggyError> {
+ debug!(
+ "Sending WebSocket TLS response with status: OK, payload length:
{}",
+ payload.len()
+ );
+
+ let length = (payload.len() as u32).to_le_bytes();
+ let total_size = STATUS_OK.len() + length.len() + payload.len();
+
+ if self.write_buffer.len() + total_size > self.write_buffer.capacity()
{
+ self.flush_write_buffer().await?;
+ }
+
+ self.write_buffer.put_slice(STATUS_OK);
+ self.write_buffer.put_slice(&length);
+ self.write_buffer.put_slice(payload);
+
+ self.flush_write_buffer().await
+ }
+
+ async fn send_error_response(&mut self, error: IggyError) -> Result<(),
IggyError> {
+ let status = &error.as_code().to_le_bytes();
+ debug!(
+ "Sending WebSocket TLS error response with status: {:?}",
+ status
+ );
+ let length = 0u32.to_le_bytes();
+ let total_size = status.len() + length.len();
+
+ if self.write_buffer.len() + total_size > self.write_buffer.capacity()
{
+ self.flush_write_buffer().await?;
+ }
+ self.write_buffer.put_slice(status);
+ self.write_buffer.put_slice(&length);
+ self.flush_write_buffer().await
+ }
+
+ async fn shutdown(&mut self) -> Result<(), ServerError> {
+ self.flush_write_buffer().await.map_err(ServerError::from)?;
+
+ match self.stream.close(None).await {
+ Ok(_) => Ok(()),
+ Err(e) => match e {
+ TungsteniteError::ConnectionClosed |
TungsteniteError::AlreadyClosed => {
+ debug!("WebSocket TLS connection already closed: {}", e);
+ Ok(())
+ }
+ _ => Err(ServerError::from(
+ IggyError::CannotCloseWebSocketConnection(format!("{}",
e)),
+ )),
+ },
+ }
+ }
+
+ async fn send_ok_response_vectored(
+ &mut self,
+ length: &[u8],
+ slices: Vec<PooledBuffer>,
+ ) -> Result<(), IggyError> {
+ self.flush_write_buffer().await?;
+
+ let total_payload_size = slices.iter().map(|s| s.len()).sum::<usize>();
+ let total_size = STATUS_OK.len() + length.len() + total_payload_size;
+
+ let mut response_bytes = BytesMut::with_capacity(total_size);
+ response_bytes.put_slice(STATUS_OK);
+ response_bytes.put_slice(length);
+ for slice in slices {
+ response_bytes.put_slice(&slice);
+ }
+
+ self.stream
+ .send(Message::Binary(response_bytes.freeze()))
+ .await
+ .map_err(|_| IggyError::TcpError)
+ }
+}
diff --git a/examples/rust/src/shared/args.rs b/examples/rust/src/shared/args.rs
index aa5eb47f6..66efb125e 100644
--- a/examples/rust/src/shared/args.rs
+++ b/examples/rust/src/shared/args.rs
@@ -176,6 +176,18 @@ pub struct Args {
#[arg(long, default_value = "5s")]
pub websocket_heartbeat_interval: String,
+
+ #[arg(long, default_value = "false")]
+ pub websocket_tls_enabled: bool,
+
+ #[arg(long, default_value = "localhost")]
+ pub websocket_tls_domain: String,
+
+ #[arg(long, default_value = "false")]
+ pub websocket_tls_ca_file: Option<String>,
+
+ #[arg(long, default_value = "false")]
+ pub websocket_tls_validate_certificate: bool,
}
impl Args {
@@ -243,6 +255,10 @@ impl Default for Args {
websocket_reconnection_interval: "1s".to_string(),
websocket_reconnection_reestablish_after: "5s".to_string(),
websocket_heartbeat_interval: "5s".to_string(),
+ websocket_tls_enabled: false,
+ websocket_tls_domain: "localhost".to_string(),
+ websocket_tls_ca_file: None,
+ websocket_tls_validate_certificate: false,
}
}
}
@@ -344,6 +360,10 @@ impl Args {
.websocket_reconnection_reestablish_after
.clone(),
websocket_heartbeat_interval:
self.websocket_heartbeat_interval.clone(),
+ websocket_tls_enabled: self.websocket_tls_enabled,
+ websocket_tls_domain: self.websocket_tls_domain.clone(),
+ websocket_tls_ca_file: self.websocket_tls_ca_file.clone(),
+ websocket_tls_validate_certificate:
self.websocket_tls_validate_certificate,
}
}