This is an automated email from the ASF dual-hosted git repository.

piotr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iggy.git


The following commit(s) were added to refs/heads/master by this push:
     new f7ade4b60 fix(server): improve shutdown handling for critical task failures (#2506)
f7ade4b60 is described below

commit f7ade4b6029559322d92b96736b232803be0a995
Author: Hubert Gruszecki <[email protected]>
AuthorDate: Fri Dec 19 10:57:19 2025 +0100

    fix(server): improve shutdown handling for critical task failures (#2506)
    
    - Add broadcast shutdown to all shards when critical task fails or panics
    - Mark HTTP server as critical task
    - Config writer now respects shutdown token via futures::select!
    - CORS parsing returns Result instead of panicking on invalid config
    - Enable err_trail tracing feature for automatic error logging
    - Wrap continuous tasks with catch_unwind for panic detection
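    
    A minimal standalone sketch of the panic-detection plus broadcast-shutdown
    pattern from the first and last bullets (simplified from the registry
    changes in the diff below; run_critical_task is an illustrative name only,
    and the channel/runtime types are assumed to mirror the async_channel and
    compio setup used there):
    
        use futures::FutureExt;
        use std::panic::AssertUnwindSafe;
        
        // Sketch: drive a task future, and on error or panic broadcast a stop
        // signal to every shard's stop channel so the whole server shuts down.
        async fn run_critical_task<F>(
            task: F,
            all_stop_senders: Vec<async_channel::Sender<()>>,
        ) where
            F: std::future::Future<Output = Result<(), String>>,
        {
            // AssertUnwindSafe is needed because the future may capture
            // state that is not UnwindSafe.
            let result = AssertUnwindSafe(task).catch_unwind().await;
        
            let failed = match result {
                Ok(Ok(())) => false,         // task completed normally
                Ok(Err(_)) => true,          // task returned an error
                Err(_panic_payload) => true, // task panicked
            };
        
            if failed {
                for sender in &all_stop_senders {
                    // try_send never blocks, even if a channel is
                    // already full or closed.
                    let _ = sender.try_send(());
                }
            }
        }
    
    Using try_send keeps the failing task's teardown from blocking on a full or
    closed stop channel, which matches the behavior in the registry diff below.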
---
 Cargo.lock                                         |  3 +
 Cargo.toml                                         |  2 +-
 core/integration/tests/streaming/mod.rs            |  7 +-
 core/server/src/http/http_server.rs                | 80 ++++++++++++++--------
 core/server/src/shard/builder.rs                   | 17 ++---
 core/server/src/shard/mod.rs                       |  5 +-
 core/server/src/shard/task_registry/registry.rs    | 74 ++++++++++++++++----
 .../src/shard/tasks/continuous/http_server.rs      |  1 +
 .../src/shard/tasks/oneshot/config_writer.rs       | 33 ++++++---
 9 files changed, 151 insertions(+), 71 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 8c079c767..085a4802e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3008,6 +3008,9 @@ name = "err_trail"
 version = "0.10.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "85b9e8330eccf84d08fb8efe2f923ddacc9f02c1359edfc33cc0af4100caf764"
+dependencies = [
+ "tracing",
+]
 
 [[package]]
 name = "errno"
diff --git a/Cargo.toml b/Cargo.toml
index f2491ac6d..dd31ec539 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -123,7 +123,7 @@ dlopen2 = "0.8.2"
 dotenvy = "0.15.7"
 enum_dispatch = "0.3.13"
 env_logger = "0.11.8"
-err_trail = "0.10.2"
+err_trail = { version = "0.10.2", features = ["tracing"] }
 error_set = "0.9.1"
 figlet-rs = "0.1.5"
 figment = { version = "0.10.19", features = ["toml", "env"] }
diff --git a/core/integration/tests/streaming/mod.rs b/core/integration/tests/streaming/mod.rs
index 0fde4b3c5..6ad773647 100644
--- a/core/integration/tests/streaming/mod.rs
+++ b/core/integration/tests/streaming/mod.rs
@@ -19,7 +19,7 @@
 use iggy_common::{CompressionAlgorithm, Identifier, IggyError, IggyExpiry, MaxTopicSize};
 use server::{
     configs::system::SystemConfig,
-    shard::task_registry::TaskRegistry,
+    shard::{task_registry::TaskRegistry, transmission::connector::ShardConnector},
     slab::{streams::Streams, traits_ext::EntityMarker},
     streaming::{
         self,
@@ -118,8 +118,9 @@ async fn bootstrap_test_environment(
         });
     }
 
-    // Create a test task registry
-    let task_registry = Rc::new(TaskRegistry::new(shard_id));
+    // Create a test task registry with dummy stop sender from ShardConnector
+    let connector: ShardConnector<()> = ShardConnector::new(shard_id);
+    let task_registry = Rc::new(TaskRegistry::new(shard_id, vec![connector.stop_sender]));
 
     Ok(BootstrapResult {
         streams,
diff --git a/core/server/src/http/http_server.rs b/core/server/src/http/http_server.rs
index 669ff772c..3a114aecf 100644
--- a/core/server/src/http/http_server.rs
+++ b/core/server/src/http/http_server.rs
@@ -36,6 +36,7 @@ use axum::http::Method;
 use axum::{Router, middleware};
 use axum_server::tls_rustls::RustlsConfig;
 use compio_net::TcpListener;
+use err_trail::ErrContext;
 use iggy_common::IggyError;
 use iggy_common::TransportProtocol;
 use std::net::SocketAddr;
@@ -106,7 +107,7 @@ pub async fn start_http_server(
         .layer(middleware::from_fn_with_state(app_state.clone(), jwt_auth));
 
     if config.cors.enabled {
-        app = app.layer(configure_cors(config.cors));
+        app = app.layer(configure_cors(config.cors)?);
     }
 
     if config.metrics.enabled {
@@ -250,50 +251,73 @@ async fn build_app_state(
     })
 }
 
-fn configure_cors(config: HttpCorsConfig) -> CorsLayer {
+fn configure_cors(config: HttpCorsConfig) -> Result<CorsLayer, IggyError> {
     let allowed_origins = match config.allowed_origins {
-        origins if origins.is_empty() => AllowOrigin::default(),
-        origins if origins.first().unwrap() == "*" => AllowOrigin::any(),
-        origins => AllowOrigin::list(origins.iter().map(|s| s.parse().unwrap())),
+        ref origins if origins.is_empty() => AllowOrigin::default(),
+        ref origins if origins.first().unwrap() == "*" => AllowOrigin::any(),
+        origins => {
+            let parsed: Result<Vec<_>, _> = origins
+                .iter()
+                .filter(|s| !s.trim().is_empty())
+                .map(|s| {
+                    s.parse()
+                        .with_error(|e| format!("Invalid CORS origin '{s}': 
{e}"))
+                        .map_err(|_| IggyError::InvalidConfiguration)
+                })
+                .collect();
+            AllowOrigin::list(parsed?)
+        }
     };
 
-    let allowed_headers = config
+    let allowed_headers: Result<Vec<_>, _> = config
         .allowed_headers
         .iter()
-        .filter(|s| !s.is_empty())
-        .map(|s| s.parse().unwrap())
-        .collect::<Vec<_>>();
+        .filter(|s| !s.trim().is_empty())
+        .map(|s| {
+            s.parse()
+                .with_error(|e| format!("Invalid CORS header '{s}': {e}"))
+                .map_err(|_| IggyError::InvalidConfiguration)
+        })
+        .collect();
+    let allowed_headers = allowed_headers?;
 
-    let exposed_headers = config
+    let exposed_headers: Result<Vec<_>, _> = config
         .exposed_headers
         .iter()
-        .filter(|s| !s.is_empty())
-        .map(|s| s.parse().unwrap())
-        .collect::<Vec<_>>();
+        .filter(|s| !s.trim().is_empty())
+        .map(|s| {
+            s.parse()
+                .with_error(|e| format!("Invalid CORS exposed header '{s}': 
{e}"))
+                .map_err(|_| IggyError::InvalidConfiguration)
+        })
+        .collect();
+    let exposed_headers = exposed_headers?;
 
-    let allowed_methods = config
+    let allowed_methods: Result<Vec<_>, _> = config
         .allowed_methods
         .iter()
-        .filter(|s| !s.is_empty())
+        .filter(|s| !s.trim().is_empty())
         .map(|s| match s.to_uppercase().as_str() {
-            "GET" => Method::GET,
-            "POST" => Method::POST,
-            "PUT" => Method::PUT,
-            "DELETE" => Method::DELETE,
-            "HEAD" => Method::HEAD,
-            "OPTIONS" => Method::OPTIONS,
-            "CONNECT" => Method::CONNECT,
-            "PATCH" => Method::PATCH,
-            "TRACE" => Method::TRACE,
-            _ => panic!("Invalid HTTP method: {s}"),
+            "GET" => Ok(Method::GET),
+            "POST" => Ok(Method::POST),
+            "PUT" => Ok(Method::PUT),
+            "DELETE" => Ok(Method::DELETE),
+            "HEAD" => Ok(Method::HEAD),
+            "OPTIONS" => Ok(Method::OPTIONS),
+            "CONNECT" => Ok(Method::CONNECT),
+            "PATCH" => Ok(Method::PATCH),
+            "TRACE" => Ok(Method::TRACE),
+            _ => Err(IggyError::InvalidConfiguration)
+                .with_error(|_| format!("Invalid HTTP method in CORS config: 
'{s}'")),
         })
-        .collect::<Vec<_>>();
+        .collect();
+    let allowed_methods = allowed_methods?;
 
-    CorsLayer::new()
+    Ok(CorsLayer::new()
         .allow_methods(allowed_methods)
         .allow_origin(allowed_origins)
         .allow_headers(allowed_headers)
         .expose_headers(exposed_headers)
         .allow_credentials(config.allow_credentials)
-        .allow_private_network(config.allow_private_network)
+        .allow_private_network(config.allow_private_network))
 }
diff --git a/core/server/src/shard/builder.rs b/core/server/src/shard/builder.rs
index 98914543e..188452fe8 100644
--- a/core/server/src/shard/builder.rs
+++ b/core/server/src/shard/builder.rs
@@ -127,25 +127,22 @@ impl IggyShardBuilder {
         let encryptor = self.encryptor;
         let client_manager = self.client_manager.unwrap();
         let version = self.version.unwrap();
-        let (_, stop_receiver, frame_receiver) = connections
+        let (stop_receiver, frame_receiver) = connections
             .iter()
             .filter(|c| c.id == id)
-            .map(|c| {
-                (
-                    c.stop_sender.clone(),
-                    c.stop_receiver.clone(),
-                    c.receiver.clone(),
-                )
-            })
+            .map(|c| (c.stop_receiver.clone(), c.receiver.clone()))
             .next()
             .expect("Failed to find connection with the specified ID");
+
+        // Collect all stop_senders for broadcasting shutdown to all shards
+        let all_stop_senders: Vec<_> = connections.iter().map(|c| c.stop_sender.clone()).collect();
         let shards = connections;
 
         // Initialize metrics
         let metrics = self.metrics.unwrap_or_else(Metrics::init);
 
-        // Create TaskRegistry for this shard
-        let task_registry = Rc::new(TaskRegistry::new(id));
+        // Create TaskRegistry with all stop_senders for critical task failures
+        let task_registry = Rc::new(TaskRegistry::new(id, all_stop_senders));
 
         // Create notification channel for config writer
         let (config_writer_notify, config_writer_receiver) = async_channel::bounded(1);
diff --git a/core/server/src/shard/mod.rs b/core/server/src/shard/mod.rs
index 64243f51e..32e07d76d 100644
--- a/core/server/src/shard/mod.rs
+++ b/core/server/src/shard/mod.rs
@@ -181,10 +181,7 @@ impl IggyShard {
         // Spawn shutdown handler
         compio::runtime::spawn(async move {
             let _ = stop_receiver.recv().await;
-            let shutdown_success = shard_for_shutdown.trigger_shutdown().await;
-            if !shutdown_success {
-                error!("shutdown timed out");
-            }
+            shard_for_shutdown.trigger_shutdown().await;
             let _ = shutdown_complete_tx.send(()).await;
         })
         .detach();
diff --git a/core/server/src/shard/task_registry/registry.rs b/core/server/src/shard/task_registry/registry.rs
index 011924225..8e96db6eb 100644
--- a/core/server/src/shard/task_registry/registry.rs
+++ b/core/server/src/shard/task_registry/registry.rs
@@ -16,12 +16,15 @@
 // under the License.
 
 use super::shutdown::{Shutdown, ShutdownToken};
+use crate::shard::transmission::connector::StopSender;
 use compio::runtime::JoinHandle;
+use futures::FutureExt;
 use futures::future::join_all;
 use iggy_common::IggyError;
 use std::cell::RefCell;
 use std::collections::HashMap;
 use std::ops::{AsyncFn, AsyncFnOnce};
+use std::panic::AssertUnwindSafe;
 use std::time::{Duration, Instant};
 use tracing::{debug, error, trace, warn};
 
@@ -44,6 +47,7 @@ pub struct TaskRegistry {
     shard_id: u16,
     shutdown: Shutdown,
     shutdown_token: ShutdownToken,
+    all_stop_senders: Vec<StopSender>,
     long_running: RefCell<Vec<TaskHandle>>,
     oneshots: RefCell<Vec<TaskHandle>>,
     connections: RefCell<HashMap<u32, async_channel::Sender<()>>>,
@@ -51,12 +55,13 @@ pub struct TaskRegistry {
 }
 
 impl TaskRegistry {
-    pub fn new(shard_id: u16) -> Self {
+    pub fn new(shard_id: u16, all_stop_senders: Vec<StopSender>) -> Self {
         let (s, t) = Shutdown::new();
         Self {
             shard_id,
             shutdown: s,
             shutdown_token: t,
+            all_stop_senders,
             long_running: RefCell::new(vec![]),
             oneshots: RefCell::new(vec![]),
             connections: RefCell::new(HashMap::new()),
@@ -88,15 +93,40 @@ impl TaskRegistry {
 
         let shutdown = self.shutdown_token.clone();
         let shard_id = self.shard_id;
+        let all_stop_senders = self.all_stop_senders.clone();
 
         let handle = compio::runtime::spawn(async move {
             trace!("continuous '{}' starting on shard {}", name, shard_id);
-            let fut = f(shutdown);
-            let r = fut.await;
-            match &r {
-                Ok(()) => debug!("continuous '{}' completed on shard {}", 
name, shard_id),
-                Err(e) => error!("continuous '{}' failed on shard {}: {}", 
name, shard_id, e),
-            }
+
+            let fut = AssertUnwindSafe(f(shutdown)).catch_unwind();
+            let result = fut.await;
+
+            let (r, should_trigger_shutdown) = match result {
+                Ok(r) => {
+                    match &r {
+                        Ok(()) => debug!("continuous '{}' completed on shard 
{}", name, shard_id),
+                        Err(e) => {
+                            error!("continuous '{}' failed on shard {}: {}", 
name, shard_id, e);
+                        }
+                    }
+                    // Trigger shutdown for critical task errors
+                    let trigger = critical && r.is_err();
+                    (r, trigger)
+                }
+                Err(panic_payload) => {
+                    let panic_msg = panic_payload
+                        .downcast_ref::<&str>()
+                        .map(|s| s.to_string())
+                        .or_else(|| panic_payload.downcast_ref::<String>().cloned())
+                        .unwrap_or_else(|| "unknown panic".to_string());
+                    error!(
+                        "continuous '{}' panicked on shard {}: {}",
+                        name, shard_id, panic_msg
+                    );
+                    // Trigger shutdown for critical task panics
+                    (Err(IggyError::Error), critical)
+                }
+            };
 
             // Execute on_shutdown callback if provided
             if let Some(shutdown_fn) = on_shutdown {
@@ -104,6 +134,17 @@ impl TaskRegistry {
                 shutdown_fn(r.clone()).await;
             }
 
+            // Trigger shutdown for ALL shards when critical task fails
+            if should_trigger_shutdown {
+                error!(
+                    "Critical task '{}' failed on shard {}, triggering 
shutdown for all shards",
+                    name, shard_id
+                );
+                for stop_sender in &all_stop_senders {
+                    let _ = stop_sender.try_send(());
+                }
+            }
+
             r
         });
 
@@ -391,9 +432,14 @@ impl TaskRegistry {
 mod tests {
     use super::*;
 
+    fn create_test_registry(shard_id: u16) -> TaskRegistry {
+        let (stop_sender, _stop_receiver) = async_channel::bounded(1);
+        TaskRegistry::new(shard_id, vec![stop_sender])
+    }
+
     #[compio::test]
     async fn test_oneshot_completion_detection() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Spawn a failing non-critical task
         registry
@@ -416,7 +462,7 @@ mod tests {
 
     #[compio::test]
     async fn test_oneshot_critical_failure() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Spawn a failing critical task
         registry
@@ -434,7 +480,7 @@ mod tests {
 
     #[compio::test]
     async fn test_shutdown_prevents_spawning() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Trigger shutdown
         *registry.shutting_down.borrow_mut() = true;
@@ -453,7 +499,7 @@ mod tests {
 
     #[compio::test]
     async fn test_timeout_error() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Create a task that will timeout
         let handle = compio::runtime::spawn(async move {
@@ -479,7 +525,7 @@ mod tests {
 
     #[compio::test]
     async fn test_composite_timeout() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Create a long-running task that takes 100ms
         let long_handle = compio::runtime::spawn(async move {
@@ -516,7 +562,7 @@ mod tests {
 
     #[compio::test]
     async fn test_composite_timeout_insufficient() {
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Create a long-running task that takes 50ms
         let long_handle = compio::runtime::spawn(async move {
@@ -555,7 +601,7 @@ mod tests {
     async fn test_periodic_last_tick_timeout() {
         // This test verifies that periodic tasks with last_tick_on_shutdown
         // don't hang shutdown if the final tick takes too long
-        let registry = TaskRegistry::new(1);
+        let registry = create_test_registry(1);
 
         // Create a handle that simulates a periodic task whose final tick will hang
         let handle = compio::runtime::spawn(async move {
diff --git a/core/server/src/shard/tasks/continuous/http_server.rs b/core/server/src/shard/tasks/continuous/http_server.rs
index 24518426a..40a024cf3 100644
--- a/core/server/src/shard/tasks/continuous/http_server.rs
+++ b/core/server/src/shard/tasks/continuous/http_server.rs
@@ -29,6 +29,7 @@ pub fn spawn_http_server(shard: Rc<IggyShard>) {
     shard
         .task_registry
         .continuous("http_server")
+        .critical(true)
         .run(move |shutdown| http_server(shard_clone, shutdown))
         .spawn();
 }
diff --git a/core/server/src/shard/tasks/oneshot/config_writer.rs b/core/server/src/shard/tasks/oneshot/config_writer.rs
index 19d04afe9..7399782c9 100644
--- a/core/server/src/shard/tasks/oneshot/config_writer.rs
+++ b/core/server/src/shard/tasks/oneshot/config_writer.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 use crate::shard::IggyShard;
+use crate::shard::task_registry::ShutdownToken;
 use compio::io::AsyncWriteAtExt;
 use err_trail::ErrContext;
+use futures::FutureExt;
 use iggy_common::IggyError;
 use std::rc::Rc;
-use tracing::info;
+use tracing::{info, warn};
 
 pub fn spawn_config_writer_task(shard: &Rc<IggyShard>) {
     let shard_clone = shard.clone();
@@ -28,11 +30,14 @@ pub fn spawn_config_writer_task(shard: &Rc<IggyShard>) {
         .task_registry
         .oneshot("config_writer")
         .critical(false)
-        .run(move |_shutdown_token| async move { write_config(shard_clone).await })
+        .run(move |shutdown_token| async move { write_config(shard_clone, shutdown_token).await })
         .spawn();
 }
 
-async fn write_config(shard: Rc<IggyShard>) -> Result<(), IggyError> {
+async fn write_config(
+    shard: Rc<IggyShard>,
+    shutdown_token: ShutdownToken,
+) -> Result<(), IggyError> {
     let shard_clone = shard.clone();
     let tcp_enabled = shard.config.tcp.enabled;
     let quic_enabled = shard.config.quic.enabled;
@@ -41,15 +46,21 @@ async fn write_config(shard: Rc<IggyShard>) -> Result<(), IggyError> {
 
     let notify_receiver = shard_clone.config_writer_receiver.clone();
 
-    // Wait for notifications until all servers have bound
+    // Wait for notifications until all servers have bound, or shutdown is triggered
     loop {
-        notify_receiver
-            .recv()
-            .await
-            .map_err(|_| IggyError::CannotWriteToFile)
-            .with_error(
-                |_| "config_writer: notification channel closed before all 
servers bound",
-            )?;
+        futures::select! {
+            _ = shutdown_token.wait().fuse() => {
+                warn!("config_writer: shutdown triggered before all servers 
bound, skipping config write");
+                return Ok(());
+            }
+            result = notify_receiver.recv().fuse() => {
+                if result.is_err() {
+                    return Err(IggyError::CannotWriteToFile).with_error(
+                        |_| "config_writer: notification channel closed before 
all servers bound",
+                    );
+                }
+            }
+        }
 
         let tcp_ready = !tcp_enabled || shard_clone.tcp_bound_address.get().is_some();
         let quic_ready = !quic_enabled || shard_clone.quic_bound_address.get().is_some();
