This is an automated email from the ASF dual-hosted git repository.
piotr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iggy.git
The following commit(s) were added to refs/heads/master by this push:
new f7ade4b60 fix(server): improve shutdown handling for critical task failures (#2506)
f7ade4b60 is described below
commit f7ade4b6029559322d92b96736b232803be0a995
Author: Hubert Gruszecki <[email protected]>
AuthorDate: Fri Dec 19 10:57:19 2025 +0100
fix(server): improve shutdown handling for critical task failures (#2506)
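- Broadcast a shutdown signal to all shards when a critical task fails or panics
- Mark the HTTP server as a critical task
- Make the config writer respect the shutdown token via `futures::select!`
- Return a `Result` from CORS parsing instead of panicking on invalid config
- Enable the `err_trail` `tracing` feature for automatic error logging
- Wrap continuous tasks with `catch_unwind` for panic detection
- Pass the stop senders of all shards to `TaskRegistry::new` (callers and tests updated)
- Drop the "shutdown timed out" error log in the shard shutdown handler; the `trigger_shutdown` result is now discarded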
---
Cargo.lock | 3 +
Cargo.toml | 2 +-
core/integration/tests/streaming/mod.rs | 7 +-
core/server/src/http/http_server.rs | 80 ++++++++++++++--------
core/server/src/shard/builder.rs | 17 ++---
core/server/src/shard/mod.rs | 5 +-
core/server/src/shard/task_registry/registry.rs | 74 ++++++++++++++++----
.../src/shard/tasks/continuous/http_server.rs | 1 +
.../src/shard/tasks/oneshot/config_writer.rs | 33 ++++++---
9 files changed, 151 insertions(+), 71 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 8c079c767..085a4802e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3008,6 +3008,9 @@ name = "err_trail"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85b9e8330eccf84d08fb8efe2f923ddacc9f02c1359edfc33cc0af4100caf764"
+dependencies = [
+ "tracing",
+]
[[package]]
name = "errno"
diff --git a/Cargo.toml b/Cargo.toml
index f2491ac6d..dd31ec539 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -123,7 +123,7 @@ dlopen2 = "0.8.2"
dotenvy = "0.15.7"
enum_dispatch = "0.3.13"
env_logger = "0.11.8"
-err_trail = "0.10.2"
+err_trail = { version = "0.10.2", features = ["tracing"] }
error_set = "0.9.1"
figlet-rs = "0.1.5"
figment = { version = "0.10.19", features = ["toml", "env"] }
diff --git a/core/integration/tests/streaming/mod.rs
b/core/integration/tests/streaming/mod.rs
index 0fde4b3c5..6ad773647 100644
--- a/core/integration/tests/streaming/mod.rs
+++ b/core/integration/tests/streaming/mod.rs
@@ -19,7 +19,7 @@
use iggy_common::{CompressionAlgorithm, Identifier, IggyError, IggyExpiry,
MaxTopicSize};
use server::{
configs::system::SystemConfig,
- shard::task_registry::TaskRegistry,
+ shard::{task_registry::TaskRegistry,
transmission::connector::ShardConnector},
slab::{streams::Streams, traits_ext::EntityMarker},
streaming::{
self,
@@ -118,8 +118,9 @@ async fn bootstrap_test_environment(
});
}
- // Create a test task registry
- let task_registry = Rc::new(TaskRegistry::new(shard_id));
+ // Create a test task registry with dummy stop sender from ShardConnector
+ let connector: ShardConnector<()> = ShardConnector::new(shard_id);
+ let task_registry = Rc::new(TaskRegistry::new(shard_id,
vec![connector.stop_sender]));
Ok(BootstrapResult {
streams,
diff --git a/core/server/src/http/http_server.rs
b/core/server/src/http/http_server.rs
index 669ff772c..3a114aecf 100644
--- a/core/server/src/http/http_server.rs
+++ b/core/server/src/http/http_server.rs
@@ -36,6 +36,7 @@ use axum::http::Method;
use axum::{Router, middleware};
use axum_server::tls_rustls::RustlsConfig;
use compio_net::TcpListener;
+use err_trail::ErrContext;
use iggy_common::IggyError;
use iggy_common::TransportProtocol;
use std::net::SocketAddr;
@@ -106,7 +107,7 @@ pub async fn start_http_server(
.layer(middleware::from_fn_with_state(app_state.clone(), jwt_auth));
if config.cors.enabled {
- app = app.layer(configure_cors(config.cors));
+ app = app.layer(configure_cors(config.cors)?);
}
if config.metrics.enabled {
@@ -250,50 +251,73 @@ async fn build_app_state(
})
}
-fn configure_cors(config: HttpCorsConfig) -> CorsLayer {
+fn configure_cors(config: HttpCorsConfig) -> Result<CorsLayer, IggyError> {
let allowed_origins = match config.allowed_origins {
- origins if origins.is_empty() => AllowOrigin::default(),
- origins if origins.first().unwrap() == "*" => AllowOrigin::any(),
- origins => AllowOrigin::list(origins.iter().map(|s|
s.parse().unwrap())),
+ ref origins if origins.is_empty() => AllowOrigin::default(),
+ ref origins if origins.first().unwrap() == "*" => AllowOrigin::any(),
+ origins => {
+ let parsed: Result<Vec<_>, _> = origins
+ .iter()
+ .filter(|s| !s.trim().is_empty())
+ .map(|s| {
+ s.parse()
+ .with_error(|e| format!("Invalid CORS origin '{s}':
{e}"))
+ .map_err(|_| IggyError::InvalidConfiguration)
+ })
+ .collect();
+ AllowOrigin::list(parsed?)
+ }
};
- let allowed_headers = config
+ let allowed_headers: Result<Vec<_>, _> = config
.allowed_headers
.iter()
- .filter(|s| !s.is_empty())
- .map(|s| s.parse().unwrap())
- .collect::<Vec<_>>();
+ .filter(|s| !s.trim().is_empty())
+ .map(|s| {
+ s.parse()
+ .with_error(|e| format!("Invalid CORS header '{s}': {e}"))
+ .map_err(|_| IggyError::InvalidConfiguration)
+ })
+ .collect();
+ let allowed_headers = allowed_headers?;
- let exposed_headers = config
+ let exposed_headers: Result<Vec<_>, _> = config
.exposed_headers
.iter()
- .filter(|s| !s.is_empty())
- .map(|s| s.parse().unwrap())
- .collect::<Vec<_>>();
+ .filter(|s| !s.trim().is_empty())
+ .map(|s| {
+ s.parse()
+ .with_error(|e| format!("Invalid CORS exposed header '{s}':
{e}"))
+ .map_err(|_| IggyError::InvalidConfiguration)
+ })
+ .collect();
+ let exposed_headers = exposed_headers?;
- let allowed_methods = config
+ let allowed_methods: Result<Vec<_>, _> = config
.allowed_methods
.iter()
- .filter(|s| !s.is_empty())
+ .filter(|s| !s.trim().is_empty())
.map(|s| match s.to_uppercase().as_str() {
- "GET" => Method::GET,
- "POST" => Method::POST,
- "PUT" => Method::PUT,
- "DELETE" => Method::DELETE,
- "HEAD" => Method::HEAD,
- "OPTIONS" => Method::OPTIONS,
- "CONNECT" => Method::CONNECT,
- "PATCH" => Method::PATCH,
- "TRACE" => Method::TRACE,
- _ => panic!("Invalid HTTP method: {s}"),
+ "GET" => Ok(Method::GET),
+ "POST" => Ok(Method::POST),
+ "PUT" => Ok(Method::PUT),
+ "DELETE" => Ok(Method::DELETE),
+ "HEAD" => Ok(Method::HEAD),
+ "OPTIONS" => Ok(Method::OPTIONS),
+ "CONNECT" => Ok(Method::CONNECT),
+ "PATCH" => Ok(Method::PATCH),
+ "TRACE" => Ok(Method::TRACE),
+ _ => Err(IggyError::InvalidConfiguration)
+ .with_error(|_| format!("Invalid HTTP method in CORS config:
'{s}'")),
})
- .collect::<Vec<_>>();
+ .collect();
+ let allowed_methods = allowed_methods?;
- CorsLayer::new()
+ Ok(CorsLayer::new()
.allow_methods(allowed_methods)
.allow_origin(allowed_origins)
.allow_headers(allowed_headers)
.expose_headers(exposed_headers)
.allow_credentials(config.allow_credentials)
- .allow_private_network(config.allow_private_network)
+ .allow_private_network(config.allow_private_network))
}
diff --git a/core/server/src/shard/builder.rs b/core/server/src/shard/builder.rs
index 98914543e..188452fe8 100644
--- a/core/server/src/shard/builder.rs
+++ b/core/server/src/shard/builder.rs
@@ -127,25 +127,22 @@ impl IggyShardBuilder {
let encryptor = self.encryptor;
let client_manager = self.client_manager.unwrap();
let version = self.version.unwrap();
- let (_, stop_receiver, frame_receiver) = connections
+ let (stop_receiver, frame_receiver) = connections
.iter()
.filter(|c| c.id == id)
- .map(|c| {
- (
- c.stop_sender.clone(),
- c.stop_receiver.clone(),
- c.receiver.clone(),
- )
- })
+ .map(|c| (c.stop_receiver.clone(), c.receiver.clone()))
.next()
.expect("Failed to find connection with the specified ID");
+
+ // Collect all stop_senders for broadcasting shutdown to all shards
+ let all_stop_senders: Vec<_> = connections.iter().map(|c|
c.stop_sender.clone()).collect();
let shards = connections;
// Initialize metrics
let metrics = self.metrics.unwrap_or_else(Metrics::init);
- // Create TaskRegistry for this shard
- let task_registry = Rc::new(TaskRegistry::new(id));
+ // Create TaskRegistry with all stop_senders for critical task failures
+ let task_registry = Rc::new(TaskRegistry::new(id, all_stop_senders));
// Create notification channel for config writer
let (config_writer_notify, config_writer_receiver) =
async_channel::bounded(1);
diff --git a/core/server/src/shard/mod.rs b/core/server/src/shard/mod.rs
index 64243f51e..32e07d76d 100644
--- a/core/server/src/shard/mod.rs
+++ b/core/server/src/shard/mod.rs
@@ -181,10 +181,7 @@ impl IggyShard {
// Spawn shutdown handler
compio::runtime::spawn(async move {
let _ = stop_receiver.recv().await;
- let shutdown_success = shard_for_shutdown.trigger_shutdown().await;
- if !shutdown_success {
- error!("shutdown timed out");
- }
+ shard_for_shutdown.trigger_shutdown().await;
let _ = shutdown_complete_tx.send(()).await;
})
.detach();
diff --git a/core/server/src/shard/task_registry/registry.rs
b/core/server/src/shard/task_registry/registry.rs
index 011924225..8e96db6eb 100644
--- a/core/server/src/shard/task_registry/registry.rs
+++ b/core/server/src/shard/task_registry/registry.rs
@@ -16,12 +16,15 @@
// under the License.
use super::shutdown::{Shutdown, ShutdownToken};
+use crate::shard::transmission::connector::StopSender;
use compio::runtime::JoinHandle;
+use futures::FutureExt;
use futures::future::join_all;
use iggy_common::IggyError;
use std::cell::RefCell;
use std::collections::HashMap;
use std::ops::{AsyncFn, AsyncFnOnce};
+use std::panic::AssertUnwindSafe;
use std::time::{Duration, Instant};
use tracing::{debug, error, trace, warn};
@@ -44,6 +47,7 @@ pub struct TaskRegistry {
shard_id: u16,
shutdown: Shutdown,
shutdown_token: ShutdownToken,
+ all_stop_senders: Vec<StopSender>,
long_running: RefCell<Vec<TaskHandle>>,
oneshots: RefCell<Vec<TaskHandle>>,
connections: RefCell<HashMap<u32, async_channel::Sender<()>>>,
@@ -51,12 +55,13 @@ pub struct TaskRegistry {
}
impl TaskRegistry {
- pub fn new(shard_id: u16) -> Self {
+ pub fn new(shard_id: u16, all_stop_senders: Vec<StopSender>) -> Self {
let (s, t) = Shutdown::new();
Self {
shard_id,
shutdown: s,
shutdown_token: t,
+ all_stop_senders,
long_running: RefCell::new(vec![]),
oneshots: RefCell::new(vec![]),
connections: RefCell::new(HashMap::new()),
@@ -88,15 +93,40 @@ impl TaskRegistry {
let shutdown = self.shutdown_token.clone();
let shard_id = self.shard_id;
+ let all_stop_senders = self.all_stop_senders.clone();
let handle = compio::runtime::spawn(async move {
trace!("continuous '{}' starting on shard {}", name, shard_id);
- let fut = f(shutdown);
- let r = fut.await;
- match &r {
- Ok(()) => debug!("continuous '{}' completed on shard {}",
name, shard_id),
- Err(e) => error!("continuous '{}' failed on shard {}: {}",
name, shard_id, e),
- }
+
+ let fut = AssertUnwindSafe(f(shutdown)).catch_unwind();
+ let result = fut.await;
+
+ let (r, should_trigger_shutdown) = match result {
+ Ok(r) => {
+ match &r {
+ Ok(()) => debug!("continuous '{}' completed on shard
{}", name, shard_id),
+ Err(e) => {
+ error!("continuous '{}' failed on shard {}: {}",
name, shard_id, e);
+ }
+ }
+ // Trigger shutdown for critical task errors
+ let trigger = critical && r.is_err();
+ (r, trigger)
+ }
+ Err(panic_payload) => {
+ let panic_msg = panic_payload
+ .downcast_ref::<&str>()
+ .map(|s| s.to_string())
+ .or_else(||
panic_payload.downcast_ref::<String>().cloned())
+ .unwrap_or_else(|| "unknown panic".to_string());
+ error!(
+ "continuous '{}' panicked on shard {}: {}",
+ name, shard_id, panic_msg
+ );
+ // Trigger shutdown for critical task panics
+ (Err(IggyError::Error), critical)
+ }
+ };
// Execute on_shutdown callback if provided
if let Some(shutdown_fn) = on_shutdown {
@@ -104,6 +134,17 @@ impl TaskRegistry {
shutdown_fn(r.clone()).await;
}
+ // Trigger shutdown for ALL shards when critical task fails
+ if should_trigger_shutdown {
+ error!(
+ "Critical task '{}' failed on shard {}, triggering
shutdown for all shards",
+ name, shard_id
+ );
+ for stop_sender in &all_stop_senders {
+ let _ = stop_sender.try_send(());
+ }
+ }
+
r
});
@@ -391,9 +432,14 @@ impl TaskRegistry {
mod tests {
use super::*;
+ fn create_test_registry(shard_id: u16) -> TaskRegistry {
+ let (stop_sender, _stop_receiver) = async_channel::bounded(1);
+ TaskRegistry::new(shard_id, vec![stop_sender])
+ }
+
#[compio::test]
async fn test_oneshot_completion_detection() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Spawn a failing non-critical task
registry
@@ -416,7 +462,7 @@ mod tests {
#[compio::test]
async fn test_oneshot_critical_failure() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Spawn a failing critical task
registry
@@ -434,7 +480,7 @@ mod tests {
#[compio::test]
async fn test_shutdown_prevents_spawning() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Trigger shutdown
*registry.shutting_down.borrow_mut() = true;
@@ -453,7 +499,7 @@ mod tests {
#[compio::test]
async fn test_timeout_error() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Create a task that will timeout
let handle = compio::runtime::spawn(async move {
@@ -479,7 +525,7 @@ mod tests {
#[compio::test]
async fn test_composite_timeout() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Create a long-running task that takes 100ms
let long_handle = compio::runtime::spawn(async move {
@@ -516,7 +562,7 @@ mod tests {
#[compio::test]
async fn test_composite_timeout_insufficient() {
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Create a long-running task that takes 50ms
let long_handle = compio::runtime::spawn(async move {
@@ -555,7 +601,7 @@ mod tests {
async fn test_periodic_last_tick_timeout() {
// This test verifies that periodic tasks with last_tick_on_shutdown
// don't hang shutdown if the final tick takes too long
- let registry = TaskRegistry::new(1);
+ let registry = create_test_registry(1);
// Create a handle that simulates a periodic task whose final tick
will hang
let handle = compio::runtime::spawn(async move {
diff --git a/core/server/src/shard/tasks/continuous/http_server.rs
b/core/server/src/shard/tasks/continuous/http_server.rs
index 24518426a..40a024cf3 100644
--- a/core/server/src/shard/tasks/continuous/http_server.rs
+++ b/core/server/src/shard/tasks/continuous/http_server.rs
@@ -29,6 +29,7 @@ pub fn spawn_http_server(shard: Rc<IggyShard>) {
shard
.task_registry
.continuous("http_server")
+ .critical(true)
.run(move |shutdown| http_server(shard_clone, shutdown))
.spawn();
}
diff --git a/core/server/src/shard/tasks/oneshot/config_writer.rs
b/core/server/src/shard/tasks/oneshot/config_writer.rs
index 19d04afe9..7399782c9 100644
--- a/core/server/src/shard/tasks/oneshot/config_writer.rs
+++ b/core/server/src/shard/tasks/oneshot/config_writer.rs
@@ -16,11 +16,13 @@
// under the License.
use crate::shard::IggyShard;
+use crate::shard::task_registry::ShutdownToken;
use compio::io::AsyncWriteAtExt;
use err_trail::ErrContext;
+use futures::FutureExt;
use iggy_common::IggyError;
use std::rc::Rc;
-use tracing::info;
+use tracing::{info, warn};
pub fn spawn_config_writer_task(shard: &Rc<IggyShard>) {
let shard_clone = shard.clone();
@@ -28,11 +30,14 @@ pub fn spawn_config_writer_task(shard: &Rc<IggyShard>) {
.task_registry
.oneshot("config_writer")
.critical(false)
- .run(move |_shutdown_token| async move {
write_config(shard_clone).await })
+ .run(move |shutdown_token| async move { write_config(shard_clone,
shutdown_token).await })
.spawn();
}
-async fn write_config(shard: Rc<IggyShard>) -> Result<(), IggyError> {
+async fn write_config(
+ shard: Rc<IggyShard>,
+ shutdown_token: ShutdownToken,
+) -> Result<(), IggyError> {
let shard_clone = shard.clone();
let tcp_enabled = shard.config.tcp.enabled;
let quic_enabled = shard.config.quic.enabled;
@@ -41,15 +46,21 @@ async fn write_config(shard: Rc<IggyShard>) -> Result<(),
IggyError> {
let notify_receiver = shard_clone.config_writer_receiver.clone();
- // Wait for notifications until all servers have bound
+ // Wait for notifications until all servers have bound, or shutdown is
triggered
loop {
- notify_receiver
- .recv()
- .await
- .map_err(|_| IggyError::CannotWriteToFile)
- .with_error(
- |_| "config_writer: notification channel closed before all
servers bound",
- )?;
+ futures::select! {
+ _ = shutdown_token.wait().fuse() => {
+ warn!("config_writer: shutdown triggered before all servers
bound, skipping config write");
+ return Ok(());
+ }
+ result = notify_receiver.recv().fuse() => {
+ if result.is_err() {
+ return Err(IggyError::CannotWriteToFile).with_error(
+ |_| "config_writer: notification channel closed before
all servers bound",
+ );
+ }
+ }
+ }
let tcp_ready = !tcp_enabled ||
shard_clone.tcp_bound_address.get().is_some();
let quic_ready = !quic_enabled ||
shard_clone.quic_bound_address.get().is_some();