This is an automated email from the ASF dual-hosted git repository. hgruszecki pushed a commit to branch io_uring_tpc_task_registry in repository https://gitbox.apache.org/repos/asf/iggy.git
commit daf781b851b52ca39460e5e3db0fa0258b6b0a68 Author: Hubert Gruszecki <[email protected]> AuthorDate: Sun Sep 28 13:32:48 2025 +0200 refactor --- core/common/src/error/iggy_error.rs | 2 + core/server/src/http/http_server.rs | 28 +- core/server/src/http/mod.rs | 2 +- core/server/src/quic/listener.rs | 10 +- core/server/src/shard/builder.rs | 20 +- core/server/src/shard/mod.rs | 14 +- core/server/src/shard/task_registry/builders.rs | 22 + .../src/shard/task_registry/builders/continuous.rs | 116 +++++ .../src/shard/task_registry/builders/oneshot.rs | 128 +++++ .../src/shard/task_registry/builders/periodic.rs | 144 ++++++ core/server/src/shard/task_registry/mod.rs | 12 + core/server/src/shard/task_registry/registry.rs | 550 ++++++++++++++++++++ .../src/shard/{tasks => task_registry}/shutdown.rs | 0 core/server/src/shard/task_registry/specs.rs | 59 +++ core/server/src/shard/task_registry/tls.rs | 104 ++++ core/server/src/shard/tasks/builder.rs | 366 -------------- .../src/shard/tasks/continuous/http_server.rs | 27 +- .../src/shard/tasks/continuous/message_pump.rs | 47 +- core/server/src/shard/tasks/continuous/mod.rs | 2 - .../src/shard/tasks/continuous/quic_server.rs | 20 +- .../src/shard/tasks/continuous/tcp_server.rs | 24 +- core/server/src/shard/tasks/mod.rs | 91 ++-- .../src/shard/tasks/periodic/clear_jwt_tokens.rs | 92 ++++ .../tasks/periodic/clear_personal_access_tokens.rs | 100 ++-- core/server/src/shard/tasks/periodic/mod.rs | 2 + .../src/shard/tasks/periodic/print_sysinfo.rs | 121 +++-- .../src/shard/tasks/periodic/save_messages.rs | 127 +++-- .../src/shard/tasks/periodic/verify_heartbeats.rs | 129 +++-- core/server/src/shard/tasks/specs.rs | 203 -------- core/server/src/shard/tasks/supervisor.rs | 558 --------------------- core/server/src/shard/tasks/tls.rs | 141 ------ core/server/src/slab/streams.rs | 6 +- core/server/src/tcp/tcp_listener.rs | 8 +- core/server/src/tcp/tcp_tls_listener.rs | 8 +- 34 files changed, 1579 insertions(+), 1704 deletions(-) diff --git a/core/common/src/error/iggy_error.rs b/core/common/src/error/iggy_error.rs index 614757f1..19dd41a9 100644 --- a/core/common/src/error/iggy_error.rs +++ b/core/common/src/error/iggy_error.rs @@ -476,6 +476,8 @@ pub enum IggyError { #[error("Cannot bind to socket with addr: {0}")] CannotBindToSocket(String) = 12000, + #[error("Task execution timeout")] + TaskTimeout = 12001, } impl IggyError { diff --git a/core/server/src/http/http_server.rs b/core/server/src/http/http_server.rs index 1d5ebf9d..187b9e40 100644 --- a/core/server/src/http/http_server.rs +++ b/core/server/src/http/http_server.rs @@ -25,7 +25,7 @@ use crate::http::metrics::metrics; use crate::http::shared::AppState; use crate::http::*; use crate::shard::IggyShard; -use crate::shard::tasks::{TaskScope, task_supervisor}; +use crate::shard::task_registry::{TaskScope, task_registry}; use crate::streaming::persistence::persister::PersisterKind; // use crate::streaming::systems::system::SharedSystem; use axum::extract::DefaultBodyLimit; @@ -111,27 +111,11 @@ pub async fn start( // JWT token cleaner task { - let app_state_for_cleaner = Arc::clone(&app_state); - let shard_clone = shard.clone(); - task_supervisor().spawn_periodic_tick( - "jwt_token_cleaner", - std::time::Duration::from_secs(300), - TaskScope::SpecificShard(0), - shard_clone, - move |_ctx| { - let app = app_state_for_cleaner.clone(); - async move { - use iggy_common::IggyTimestamp; - let now = IggyTimestamp::now().to_secs(); - app.jwt_manager - .delete_expired_revoked_tokens(now) - .await - .unwrap_or_else(|err| { - error!("Failed to delete expired revoked tokens: {}", err); - }); - Ok(()) - } - }, + use crate::shard::tasks::periodic::ClearJwtTokens; + let period = std::time::Duration::from_secs(300); // 5 minutes + task_registry().spawn_periodic( + shard.clone(), + Box::new(ClearJwtTokens::new(app_state.clone(), period)), ); } diff --git a/core/server/src/http/mod.rs b/core/server/src/http/mod.rs index ba369815..4cf663d3 100644 --- a/core/server/src/http/mod.rs +++ b/core/server/src/http/mod.rs @@ -23,7 +23,7 @@ mod http_shard_wrapper; pub mod jwt; mod mapper; pub mod metrics; -mod shared; +pub mod shared; pub mod consumer_groups; pub mod consumer_offsets; diff --git a/core/server/src/quic/listener.rs b/core/server/src/quic/listener.rs index 623c00c6..a6bb8502 100644 --- a/core/server/src/quic/listener.rs +++ b/core/server/src/quic/listener.rs @@ -20,7 +20,7 @@ use crate::binary::command::{ServerCommand, ServerCommandHandler}; use crate::binary::sender::SenderKind; use crate::server_error::ConnectionError; use crate::shard::IggyShard; -use crate::shard::tasks::task_supervisor; +use crate::shard::task_registry::task_registry; use crate::shard::transmission::event::ShardEvent; use crate::streaming::session::Session; use crate::{shard_debug, shard_info}; @@ -70,7 +70,7 @@ pub async fn start(endpoint: Endpoint, shard: Rc<IggyShard>) -> Result<(), IggyE let shard_for_conn = shard.clone(); // Use TaskSupervisor to track connection handlers for graceful shutdown - task_supervisor().spawn_tracked(async move { + task_registry().spawn_tracked(async move { trace!("Accepting connection from {}", remote_addr); match incoming_conn.await { Ok(connection) => { @@ -127,7 +127,7 @@ async fn handle_connection( // TODO(hubcio): unused? let _responses = shard.broadcast_event_to_all_shards(event.into()).await; - let conn_stop_receiver = task_supervisor().add_connection(client_id); + let conn_stop_receiver = task_registry().add_connection(client_id); loop { futures::select! { @@ -143,7 +143,7 @@ async fn handle_connection( let shard_clone = shard.clone(); let session_clone = session.clone(); - task_supervisor().spawn_tracked(async move { + task_registry().spawn_tracked(async move { if let Err(err) = handle_stream(stream, shard_clone, session_clone).await { error!("Error when handling QUIC stream: {:?}", err) } @@ -155,7 +155,7 @@ async fn handle_connection( } } - task_supervisor().remove_connection(&client_id); + task_registry().remove_connection(&client_id); info!("QUIC connection {} closed", client_id); Ok(()) } diff --git a/core/server/src/shard/builder.rs b/core/server/src/shard/builder.rs index 4d98dd96..a5cf3b63 100644 --- a/core/server/src/shard/builder.rs +++ b/core/server/src/shard/builder.rs @@ -16,17 +16,6 @@ * under the License. */ -use std::{ - cell::{Cell, RefCell}, - rc::Rc, - sync::{Arc, atomic::AtomicBool}, -}; - -use ahash::HashMap; -use dashmap::DashMap; -use iggy_common::{Aes256GcmEncryptor, EncryptorKind, UserId}; -use tracing::info; - use crate::{ configs::server::ServerConfig, io::storage::Storage, @@ -39,6 +28,15 @@ use crate::{ }, versioning::SemanticVersion, }; +use ahash::HashMap; +use dashmap::DashMap; +use iggy_common::{Aes256GcmEncryptor, EncryptorKind, UserId}; +use std::{ + cell::{Cell, RefCell}, + rc::Rc, + sync::{Arc, atomic::AtomicBool}, +}; +use tracing::info; use super::{IggyShard, transmission::connector::ShardConnector, transmission::frame::ShardFrame}; diff --git a/core/server/src/shard/mod.rs b/core/server/src/shard/mod.rs index 9f8b4cfc..2be4be23 100644 --- a/core/server/src/shard/mod.rs +++ b/core/server/src/shard/mod.rs @@ -20,6 +20,7 @@ pub mod builder; pub mod logging; pub mod namespace; pub mod system; +pub mod task_registry; pub mod tasks; pub mod transmission; @@ -58,7 +59,7 @@ use crate::{ io::fs_utils, shard::{ namespace::{IggyFullNamespace, IggyNamespace}, - tasks::{init_supervisor, shard_task_specs, task_supervisor}, + task_registry::{init_registry, task_registry}, transmission::{ event::ShardEvent, frame::{ShardFrame, ShardResponse}, @@ -229,7 +230,7 @@ impl IggyShard { pub async fn run(self: &Rc<Self>, persister: Arc<PersisterKind>) -> Result<(), IggyError> { // Initialize TLS task supervisor for this thread - init_supervisor(self.id); + init_registry(self.id); // Workaround to ensure that the statistics are initialized before the server // loads streams and starts accepting connections. This is necessary to // have the correct statistics when the server starts. @@ -241,8 +242,7 @@ impl IggyShard { //self.assert_init(); // Create and spawn all tasks via the supervisor - let task_specs = shard_task_specs(self.clone()); - task_supervisor().spawn(self.clone(), task_specs); + crate::shard::tasks::register_all(&task_registry(), self.clone()); // Create a oneshot channel for shutdown completion notification let (shutdown_complete_tx, shutdown_complete_rx) = async_channel::bounded(1); @@ -372,15 +372,15 @@ impl IggyShard { /// /// # Panics /// Panics if the task supervisor has not been initialized - pub fn task_supervisor() -> Rc<crate::shard::tasks::TaskSupervisor> { - task_supervisor() + pub fn task_registry() -> Rc<crate::shard::task_registry::TaskRegistry> { + task_registry() } #[instrument(skip_all, name = "trace_shutdown")] pub async fn trigger_shutdown(&self) -> bool { self.is_shutting_down.store(true, Ordering::SeqCst); debug!("Shard {} shutdown state set", self.id); - task_supervisor().graceful_shutdown(SHUTDOWN_TIMEOUT).await + task_registry().graceful_shutdown(SHUTDOWN_TIMEOUT).await } pub fn get_available_shards_count(&self) -> u32 { diff --git a/core/server/src/shard/task_registry/builders.rs b/core/server/src/shard/task_registry/builders.rs new file mode 100644 index 00000000..062a6143 --- /dev/null +++ b/core/server/src/shard/task_registry/builders.rs @@ -0,0 +1,22 @@ +pub mod continuous; +pub mod oneshot; +pub mod periodic; + +use super::registry::TaskRegistry; + +impl TaskRegistry { + pub fn periodic(&self, name: &'static str) -> periodic::PeriodicBuilder<'_> { + periodic::PeriodicBuilder::new(self, name) + } + + pub fn continuous(&self, name: &'static str) -> continuous::ContinuousBuilder<'_> { + continuous::ContinuousBuilder::new(self, name) + } + + pub fn oneshot(&self, name: &'static str) -> oneshot::OneShotBuilder<'_> { + oneshot::OneShotBuilder::new(self, name) + } +} + +pub struct NoTask; +pub struct HasTask; diff --git a/core/server/src/shard/task_registry/builders/continuous.rs b/core/server/src/shard/task_registry/builders/continuous.rs new file mode 100644 index 00000000..73743548 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/continuous.rs @@ -0,0 +1,116 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{ + ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope, +}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct ContinuousBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + run: Option<Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> ContinuousBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + run: None, + _p: PhantomData, + } + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn run<F, Fut>(self, f: F) -> ContinuousBuilder<'a, HasTask> + where + F: FnOnce(TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + ContinuousBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, + run: Some(Box::new(move |ctx| Box::pin(f(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> ContinuousBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosureContinuous { + name: self.name, + scope: self.scope, + critical: self.critical, + run: self.run.expect("run required"), + }); + self.reg.spawn_continuous(shard, spec); + } +} + +struct ClosureContinuous { + name: &'static str, + scope: TaskScope, + critical: bool, + run: Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + +impl Debug for ClosureContinuous { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosureContinuous") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .finish() + } +} + +impl TaskMeta for ClosureContinuous { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl ContinuousTask for ClosureContinuous { + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + (self.run)(ctx) + } +} diff --git a/core/server/src/shard/task_registry/builders/oneshot.rs b/core/server/src/shard/task_registry/builders/oneshot.rs new file mode 100644 index 00000000..706bc243 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/oneshot.rs @@ -0,0 +1,128 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{OneShotTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc, time::Duration}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct OneShotBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + timeout: Option<Duration>, + run: Option<Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> OneShotBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + timeout: None, + run: None, + _p: PhantomData, + } + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn timeout(mut self, d: Duration) -> Self { + self.timeout = Some(d); + self + } + + pub fn run<F, Fut>(self, f: F) -> OneShotBuilder<'a, HasTask> + where + F: FnOnce(TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + OneShotBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, + timeout: self.timeout, + run: Some(Box::new(move |ctx| Box::pin(f(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> OneShotBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosureOneShot { + name: self.name, + scope: self.scope, + critical: self.critical, + timeout: self.timeout, + run: self.run.expect("run required"), + }); + self.reg.spawn_oneshot(shard, spec); + } +} + +struct ClosureOneShot { + name: &'static str, + scope: TaskScope, + critical: bool, + timeout: Option<Duration>, + run: Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + +impl Debug for ClosureOneShot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosureOneShot") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .field("timeout", &self.timeout) + .finish() + } +} + +impl TaskMeta for ClosureOneShot { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl OneShotTask for ClosureOneShot { + fn run_once(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + (self.run)(ctx) + } + fn timeout(&self) -> Option<Duration> { + self.timeout + } +} diff --git a/core/server/src/shard/task_registry/builders/periodic.rs b/core/server/src/shard/task_registry/builders/periodic.rs new file mode 100644 index 00000000..598021e6 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/periodic.rs @@ -0,0 +1,144 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc, time::Duration}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct PeriodicBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + period: Option<Duration>, + last_on_shutdown: bool, + tick: Option<Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> PeriodicBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + period: None, + last_on_shutdown: false, + tick: None, + _p: PhantomData, + } + } + + pub fn every(mut self, d: Duration) -> Self { + self.period = Some(d); + self + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn last_tick_on_shutdown(mut self, v: bool) -> Self { + self.last_on_shutdown = v; + self + } + + pub fn tick<F, Fut>(self, f: F) -> PeriodicBuilder<'a, HasTask> + where + F: FnMut(&TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + let mut g = f; + PeriodicBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, + period: self.period, + last_on_shutdown: self.last_on_shutdown, + tick: Some(Box::new(move |ctx| Box::pin(g(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> PeriodicBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + let period = self.period.expect("period required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosurePeriodic { + name: self.name, + scope: self.scope, + critical: self.critical, + period, + last_on_shutdown: self.last_on_shutdown, + tick: self.tick.expect("tick required"), + }); + self.reg.spawn_periodic(shard, spec); + } +} + +struct ClosurePeriodic { + name: &'static str, + scope: TaskScope, + critical: bool, + period: Duration, + last_on_shutdown: bool, + tick: Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + +impl Debug for ClosurePeriodic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosurePeriodic") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .field("period", &self.period) + .field("last_on_shutdown", &self.last_on_shutdown) + .finish() + } +} + +impl TaskMeta for ClosurePeriodic { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl PeriodicTask for ClosurePeriodic { + fn period(&self) -> Duration { + self.period + } + fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture { + (self.tick)(ctx) + } + fn last_tick_on_shutdown(&self) -> bool { + self.last_on_shutdown + } +} diff --git a/core/server/src/shard/task_registry/mod.rs b/core/server/src/shard/task_registry/mod.rs new file mode 100644 index 00000000..f11e76ba --- /dev/null +++ b/core/server/src/shard/task_registry/mod.rs @@ -0,0 +1,12 @@ +pub mod builders; +pub mod registry; +pub mod shutdown; +pub mod specs; +pub mod tls; + +pub use registry::TaskRegistry; +pub use shutdown::{Shutdown, ShutdownToken}; +pub use specs::{ + ContinuousTask, OneShotTask, PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskResult, TaskScope, +}; +pub use tls::{init_registry, is_registry_initialized, task_registry}; diff --git a/core/server/src/shard/task_registry/registry.rs b/core/server/src/shard/task_registry/registry.rs new file mode 100644 index 00000000..9accce5b --- /dev/null +++ b/core/server/src/shard/task_registry/registry.rs @@ -0,0 +1,550 @@ +use super::shutdown::{Shutdown, ShutdownToken}; +use super::specs::{ + ContinuousTask, OneShotTask, PeriodicTask, TaskCtx, TaskMeta, TaskResult, TaskScope, +}; +use crate::shard::IggyShard; +use compio::runtime::JoinHandle; +use futures::future::join_all; +use iggy_common::IggyError; +use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration}; +use tracing::{debug, error, trace, warn}; + +enum Kind { + Continuous, + Periodic(Duration), + OneShot, +} + +struct TaskHandle { + name: String, + kind: Kind, + handle: JoinHandle<TaskResult>, + critical: bool, +} + +pub struct TaskRegistry { + pub(crate) shard_id: u16, + shutdown: Shutdown, + shutdown_token: ShutdownToken, + long_running: RefCell<Vec<TaskHandle>>, + oneshots: RefCell<Vec<TaskHandle>>, + connections: RefCell<HashMap<u32, async_channel::Sender<()>>>, + shutting_down: RefCell<bool>, +} + +impl TaskRegistry { + pub fn new(shard_id: u16) -> Self { + let (s, t) = Shutdown::new(); + Self { + shard_id, + shutdown: s, + shutdown_token: t, + long_running: RefCell::new(vec![]), + oneshots: RefCell::new(vec![]), + connections: RefCell::new(HashMap::new()), + shutting_down: RefCell::new(false), + } + } + + pub fn shutdown_token(&self) -> ShutdownToken { + self.shutdown_token.clone() + } + + pub fn spawn_continuous(&self, shard: Rc<IggyShard>, mut task: Box<dyn ContinuousTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn continuous task '{}' during shutdown", + task.name() + ); + return; + } + if !task.scope().should_run(&shard) { + return; + } + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!("continuous '{}' starting on shard {}", name, shard_id); + let r = task.run(ctx).await; + match &r { + Ok(()) => debug!("continuous '{}' completed on shard {}", name, shard_id), + Err(e) => error!("continuous '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.long_running.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::Continuous, + handle, + critical: is_critical, + }); + } + + pub fn spawn_periodic(&self, shard: Rc<IggyShard>, mut task: Box<dyn PeriodicTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn periodic task '{}' during shutdown", + task.name() + ); + return; + } + if !task.scope().should_run(&shard) { + return; + } + let period = task.period(); + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shutdown = self.shutdown_token.clone(); + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!( + "periodic '{}' every {:?} on shard {}", + name, period, shard_id + ); + loop { + if !shutdown.sleep_or_shutdown(period).await { + break; + } + if let Err(e) = task.tick(&ctx).await { + error!( + "periodic '{}' tick failed on shard {}: {}", + name, shard_id, e + ); + } + } + if task.last_tick_on_shutdown() { + let _ = task.tick(&ctx).await; + } + Ok(()) + }); + + self.long_running.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::Periodic(period), + handle, + critical: is_critical, + }); + } + + pub fn spawn_oneshot(&self, shard: Rc<IggyShard>, mut task: Box<dyn OneShotTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn oneshot task '{}' during shutdown", + task.name() + ); + return; + } + if !task.scope().should_run(&shard) { + return; + } + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let timeout = task.timeout(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!("oneshot '{}' starting on shard {}", name, shard_id); + let fut = task.run_once(ctx); + let r = if let Some(d) = timeout { + match compio::time::timeout(d, fut).await { + Ok(r) => r, + Err(_) => Err(IggyError::TaskTimeout), + } + } else { + fut.await + }; + match &r { + Ok(()) => trace!("oneshot '{}' completed on shard {}", name, shard_id), + Err(e) => error!("oneshot '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.oneshots.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::OneShot, + handle, + critical: is_critical, + }); + } + + pub async fn graceful_shutdown(&self, timeout: Duration) -> bool { + use std::time::Instant; + + let start = Instant::now(); + *self.shutting_down.borrow_mut() = true; + self.shutdown_connections(); + self.shutdown.trigger(); + + // First shutdown long-running tasks (continuous and periodic) + let long = self.long_running.take(); + let long_ok = if !long.is_empty() { + debug!( + "Shutting down {} long-running task(s) on shard {}", + long.len(), + self.shard_id + ); + self.await_with_timeout(long, timeout).await + } else { + true + }; + + // Calculate remaining time for oneshots + let elapsed = start.elapsed(); + let remaining = timeout.saturating_sub(elapsed); + + // Then shutdown oneshot tasks with remaining time + let ones = self.oneshots.take(); + let ones_ok = if !ones.is_empty() { + if remaining.is_zero() { + warn!( + "No time remaining for {} oneshot task(s) on shard {}, they will be cancelled", + ones.len(), + self.shard_id + ); + false + } else { + debug!( + "Shutting down {} oneshot task(s) on shard {} with {:?} remaining", + ones.len(), + self.shard_id, + remaining + ); + self.await_with_timeout(ones, remaining).await + } + } else { + true + }; + + let total_elapsed = start.elapsed(); + if long_ok && ones_ok { + debug!( + "Graceful shutdown completed successfully on shard {} in {:?}", + self.shard_id, total_elapsed + ); + } else { + warn!( + "Graceful shutdown completed with failures on shard {} in {:?}", + self.shard_id, total_elapsed + ); + } + + long_ok && ones_ok + } + + async fn await_with_timeout(&self, tasks: Vec<TaskHandle>, timeout: Duration) -> bool { + if tasks.is_empty() { + return true; + } + let results = join_all(tasks.into_iter().map(|t| async move { + match compio::time::timeout(timeout, t.handle).await { + Ok(Ok(Ok(()))) => true, + Ok(Ok(Err(e))) => { + error!("task '{}' failed: {}", t.name, e); + !t.critical + } + Ok(Err(_)) => { + error!("task '{}' panicked", t.name); + !t.critical + } + Err(_) => { + error!("task '{}' timed out after {:?}", t.name, timeout); + !t.critical + } + } + })) + .await; + + results.into_iter().all(|x| x) + } + + async fn await_all(&self, tasks: Vec<TaskHandle>) -> bool { + if tasks.is_empty() { + return true; + } + let results = join_all(tasks.into_iter().map(|t| async move { + match t.handle.await { + Ok(Ok(())) => true, + Ok(Err(e)) => { + error!("task '{}' failed: {}", t.name, e); + !t.critical + } + Err(_) => { + error!("task '{}' panicked", t.name); + !t.critical + } + } + })) + .await; + results.into_iter().all(|x| x) + } + + pub fn add_connection(&self, client_id: u32) -> async_channel::Receiver<()> { + let (tx, rx) = async_channel::bounded(1); + self.connections.borrow_mut().insert(client_id, tx); + rx + } + + pub fn remove_connection(&self, client_id: &u32) { + self.connections.borrow_mut().remove(client_id); + } + + fn shutdown_connections(&self) { + for tx in self.connections.borrow().values() { + let _ = tx.send_blocking(()); + } + } + + pub fn spawn_tracked<F>(&self, future: F) + where + F: futures::Future<Output = ()> + 'static, + { + let handle = compio::runtime::spawn(async move { + future.await; + Ok(()) + }); + self.long_running.borrow_mut().push(TaskHandle { + name: "connection_handler".to_string(), + kind: Kind::Continuous, + handle, + critical: false, + }); + } + + pub fn spawn_oneshot_future<F>(&self, name: &'static str, critical: bool, f: F) + where + F: futures::Future<Output = Result<(), IggyError>> + 'static, + { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn oneshot future '{}' during shutdown", + name + ); + return; + } + let shard_id = self.shard_id; + let handle = compio::runtime::spawn(async move { + trace!("oneshot '{}' starting on shard {}", name, shard_id); + let r = f.await; + match &r { + Ok(()) => trace!("oneshot '{}' completed on shard {}", name, shard_id), + Err(e) => error!("oneshot '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.oneshots.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::OneShot, + handle, + critical, + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::shard::task_registry::specs::{OneShotTask, TaskCtx, TaskFuture, TaskMeta}; + use std::fmt::Debug; + + #[derive(Debug)] + struct TestOneShotTask { + should_fail: bool, + is_critical: bool, + } + + impl TaskMeta for TestOneShotTask { + fn name(&self) -> &'static str { + "test_oneshot" + } + + fn is_critical(&self) -> bool { + self.is_critical + } + } + + impl OneShotTask for TestOneShotTask { + fn run_once(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + if self.should_fail { + Err(IggyError::Error) + } else { + Ok(()) + } + }) + } + + fn timeout(&self) -> Option<Duration> { + Some(Duration::from_millis(100)) + } + } + + #[compio::test] + async fn test_oneshot_completion_detection() { + let registry = TaskRegistry::new(1); + + // Spawn a failing non-critical task + registry.spawn_oneshot_future("failing_non_critical", false, async { + Err(IggyError::Error) + }); + + // Spawn a successful task + registry.spawn_oneshot_future("successful", false, async { Ok(()) }); + + // Wait for all tasks + let all_ok = registry.await_all(registry.oneshots.take()).await; + + // Should return true because the failing task is not critical + assert!(all_ok); + } + + #[compio::test] + async fn test_oneshot_critical_failure() { + let registry = TaskRegistry::new(1); + + // Spawn a failing critical task + registry.spawn_oneshot_future("failing_critical", true, async { Err(IggyError::Error) }); + + // Wait for all tasks + let all_ok = registry.await_all(registry.oneshots.take()).await; + + // Should return false because the failing task is critical + assert!(!all_ok); + } + + #[compio::test] + async fn test_shutdown_prevents_spawning() { + let registry = TaskRegistry::new(1); + + // Trigger shutdown + *registry.shutting_down.borrow_mut() = true; + + let initial_count = registry.oneshots.borrow().len(); + + // Try to spawn after shutdown + registry.spawn_oneshot_future("should_not_spawn", false, async { Ok(()) }); + + // Task should not be added + assert_eq!(registry.oneshots.borrow().len(), initial_count); + } + + #[compio::test] + async fn test_timeout_error() { + let registry = TaskRegistry::new(1); + + // Create a task that will timeout + let handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_secs(10)).await; + Ok(()) + }); + + let task_handle = TaskHandle { + name: "timeout_test".to_string(), + kind: Kind::OneShot, + handle, + critical: false, + }; + + let tasks = vec![task_handle]; + let all_ok = registry + .await_with_timeout(tasks, Duration::from_millis(50)) + .await; + + // Should return true because the task is not critical + assert!(all_ok); + } + + #[compio::test] + async fn test_composite_timeout() { + let registry = TaskRegistry::new(1); + + // Create a long-running task that takes 100ms + let long_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(100)).await; + Ok(()) + }); + + registry.long_running.borrow_mut().push(TaskHandle { + name: "long_task".to_string(), + kind: Kind::Continuous, + handle: long_handle, + critical: false, + }); + + // Create a oneshot that would succeed quickly + let oneshot_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(10)).await; + Ok(()) + }); + + registry.oneshots.borrow_mut().push(TaskHandle { + name: "quick_oneshot".to_string(), + kind: Kind::OneShot, + handle: oneshot_handle, + critical: false, + }); + + // Give total timeout of 150ms + // Long-running should complete in ~100ms + // Oneshot should have ~50ms remaining, which is enough + let all_ok = registry.graceful_shutdown(Duration::from_millis(150)).await; + assert!(all_ok); + } + + #[compio::test] + async fn test_composite_timeout_insufficient() { + let registry = TaskRegistry::new(1); + + // Create a long-running task that takes 50ms + let long_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(50)).await; + Ok(()) + }); + + registry.long_running.borrow_mut().push(TaskHandle { + name: "long_task".to_string(), + kind: Kind::Continuous, + handle: long_handle, + critical: false, + }); + + // Create a oneshot that would take 100ms (much longer) + let oneshot_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(100)).await; + Ok(()) + }); + + registry.oneshots.borrow_mut().push(TaskHandle { + name: "slow_oneshot".to_string(), + kind: Kind::OneShot, + handle: oneshot_handle, + critical: true, // Make it critical so failure is detected + }); + + // Give total timeout of 60ms + // Long-running should complete in ~50ms + // Oneshot would need 100ms but only has ~10ms, so it should definitely fail + let all_ok = registry.graceful_shutdown(Duration::from_millis(60)).await; + assert!(!all_ok); // Should fail because critical oneshot times out + } +} diff --git a/core/server/src/shard/tasks/shutdown.rs b/core/server/src/shard/task_registry/shutdown.rs similarity index 100% rename from core/server/src/shard/tasks/shutdown.rs rename to core/server/src/shard/task_registry/shutdown.rs diff --git a/core/server/src/shard/task_registry/specs.rs b/core/server/src/shard/task_registry/specs.rs new file mode 100644 index 00000000..15ea0e55 --- /dev/null +++ b/core/server/src/shard/task_registry/specs.rs @@ -0,0 +1,59 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::ShutdownToken; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, rc::Rc, time::Duration}; + +pub type TaskResult = Result<(), IggyError>; +pub type TaskFuture = LocalBoxFuture<'static, TaskResult>; + +#[derive(Clone, Debug)] +pub enum TaskScope { + AllShards, + SpecificShard(u16), +} + +impl TaskScope { + pub fn should_run(&self, shard: &IggyShard) -> bool { + match self { + TaskScope::AllShards => true, + TaskScope::SpecificShard(id) => shard.id == *id, + } + } +} + +#[derive(Clone)] +pub struct TaskCtx { + pub shard: Rc<IggyShard>, + pub shutdown: ShutdownToken, +} + +pub trait TaskMeta: 'static + Debug { + fn name(&self) -> &'static str; + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + fn is_critical(&self) -> bool { + false + } + fn on_start(&self) {} +} + +pub trait ContinuousTask: TaskMeta { + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; +} + +pub trait PeriodicTask: TaskMeta { + fn period(&self) -> Duration; + fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture; + fn last_tick_on_shutdown(&self) -> bool { + false + } +} + +pub trait OneShotTask: TaskMeta { + fn run_once(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; + fn timeout(&self) -> Option<Duration> { + None + } +} diff --git a/core/server/src/shard/task_registry/tls.rs b/core/server/src/shard/task_registry/tls.rs new file mode 100644 index 00000000..00a344b9 --- /dev/null +++ b/core/server/src/shard/task_registry/tls.rs @@ -0,0 +1,104 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::registry::TaskRegistry; +use std::cell::RefCell; +use std::rc::Rc; + +thread_local! { + static REGISTRY: RefCell<Option<Rc<TaskRegistry>>> = RefCell::new(None); +} + +pub fn init_registry(shard_id: u16) { + REGISTRY.with(|s| { + *s.borrow_mut() = Some(Rc::new(TaskRegistry::new(shard_id))); + }); +} + +pub fn task_registry() -> Rc<TaskRegistry> { + REGISTRY.with(|s| { + s.borrow() + .as_ref() + .expect("Task registry not initialized for this thread. Call init_registry() first.") + .clone() + }) +} + +pub fn is_registry_initialized() -> bool { + REGISTRY.with(|s| s.borrow().is_some()) +} + +pub fn clear_registry() { + REGISTRY.with(|s| { + *s.borrow_mut() = None; + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::panic; + + #[test] + fn test_registry_initialization() { + clear_registry(); + assert!(!is_registry_initialized()); + + let result = panic::catch_unwind(|| { + task_registry(); + }); + assert!(result.is_err()); + + init_registry(42); + assert!(is_registry_initialized()); + + clear_registry(); + assert!(!is_registry_initialized()); + } + + #[test] + fn test_multiple_initializations() { + clear_registry(); + init_registry(1); + let _reg1 = task_registry(); + init_registry(2); + let _reg2 = task_registry(); + clear_registry(); + } + + #[test] + fn test_thread_locality() { + use std::thread; + + clear_registry(); + init_registry(100); + assert!(is_registry_initialized()); + + let handle = thread::spawn(|| { + assert!(!is_registry_initialized()); + init_registry(200); + assert!(is_registry_initialized()); + let _ = task_registry(); + }); + + handle.join().expect("Thread should complete successfully"); + assert!(is_registry_initialized()); + let _ = task_registry(); + clear_registry(); + } +} diff --git a/core/server/src/shard/tasks/builder.rs b/core/server/src/shard/tasks/builder.rs deleted file mode 100644 index e1a06326..00000000 --- a/core/server/src/shard/tasks/builder.rs +++ /dev/null @@ -1,366 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use super::specs::{TaskCtx, TaskKind, TaskScope}; -use super::supervisor::TaskSupervisor; -use crate::shard::IggyShard; -use futures::Future; -use futures::future::LocalBoxFuture; -use iggy_common::IggyError; -use std::marker::PhantomData; -use std::rc::Rc; -use std::time::Duration; -use tracing::{debug, info}; - -/// Type state for builder that hasn't had its task function set yet -pub struct NoTask; - -/// Type state for builder that has its task function set -pub struct HasTask; - -/// Builder for periodic tasks with fluent API -pub struct PeriodicTaskBuilder<'a, State = NoTask> { - supervisor: &'a TaskSupervisor, - name: &'static str, - period: Option<Duration>, - scope: TaskScope, - critical: bool, - shard: Option<Rc<IggyShard>>, - tick: Option<Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, - _phantom: PhantomData<State>, -} - -impl<'a> PeriodicTaskBuilder<'a, NoTask> { - pub(crate) fn new(supervisor: &'a TaskSupervisor, name: &'static str) -> Self { - Self { - supervisor, - name, - period: None, - scope: TaskScope::AllShards, - critical: false, - shard: None, - tick: None, - _phantom: PhantomData, - } - } - - /// Set the period for this periodic task - pub fn every(mut self, period: Duration) -> Self { - self.period = Some(period); - self - } - - /// Set the shard scope for this task - pub fn on_shard(mut self, scope: TaskScope) -> Self { - self.scope = scope; - self - } - - /// Mark this task as critical - pub fn critical(mut self, critical: bool) -> Self { - self.critical = critical; - self - } - - /// Set the shard for this task - pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { - self.shard = Some(shard); - self - } - - /// Set the tick function for this periodic task - pub fn tick<F, Fut>(self, tick: F) -> PeriodicTaskBuilder<'a, HasTask> - where - F: FnMut(&TaskCtx) -> Fut + 'static, - Fut: Future<Output = Result<(), IggyError>> + 'static, - { - let mut tick = tick; - let tick_boxed: Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>> = - Box::new(move |ctx| Box::pin(tick(ctx))); - - PeriodicTaskBuilder { - supervisor: self.supervisor, - name: self.name, - period: self.period, - scope: self.scope, - critical: self.critical, - shard: self.shard, - tick: Some(tick_boxed), - _phantom: PhantomData, - } - } -} - -impl<'a> PeriodicTaskBuilder<'a, HasTask> { - /// Spawn the periodic task - pub fn spawn(self) { - let period = self.period.expect("Period must be set for periodic tasks"); - let shard = self.shard.expect("Shard must be set for periodic tasks"); - let mut tick = self.tick.expect("Tick function must be set"); - - // Check if task should run on this shard - if !self.scope.should_run(&shard) { - return; - } - - let ctx = TaskCtx { - shard: shard.clone(), - shutdown: self.supervisor.shutdown_token(), - }; - - let name = self.name; - let shard_id = self.supervisor.shard_id; - let shutdown = self.supervisor.shutdown_token(); - let is_critical = self.critical; - - debug!( - "Spawning periodic task '{}' on shard {} with period {:?}", - name, shard_id, period - ); - - let handle = compio::runtime::spawn(async move { - loop { - // Use shutdown-aware sleep - if !shutdown.sleep_or_shutdown(period).await { - tracing::trace!("Periodic task '{}' shutting down", name); - break; - } - - // Execute tick - match tick(&ctx).await { - Ok(()) => tracing::trace!("Periodic task '{}' tick completed", name), - Err(e) => tracing::error!("Periodic task '{}' tick failed: {}", name, e), - } - } - - debug!("Periodic task '{}' completed on shard {}", name, shard_id); - Ok(()) - }); - - self.supervisor - .register_task(name, TaskKind::Periodic { period }, handle, is_critical); - } -} - -/// Builder for continuous tasks with fluent API -pub struct ContinuousTaskBuilder<'a, State = NoTask> { - supervisor: &'a TaskSupervisor, - name: &'static str, - scope: TaskScope, - critical: bool, - shard: Option<Rc<IggyShard>>, - run: Option<Box<dyn FnOnce(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, - _phantom: PhantomData<State>, -} - -impl<'a> ContinuousTaskBuilder<'a, NoTask> { - pub(crate) fn new(supervisor: &'a TaskSupervisor, name: &'static str) -> Self { - Self { - supervisor, - name, - scope: TaskScope::AllShards, - critical: false, - shard: None, - run: None, - _phantom: PhantomData, - } - } - - /// Set the shard scope for this task - pub fn on_shard(mut self, scope: TaskScope) -> Self { - self.scope = scope; - self - } - - /// Mark this task as critical - pub fn critical(mut self, critical: bool) -> Self { - self.critical = critical; - self - } - - /// Set the shard for this task - pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { - self.shard = Some(shard); - self - } - - /// Set the run function for this continuous task - pub fn run<F, Fut>(self, run: F) -> ContinuousTaskBuilder<'a, HasTask> - where - F: FnOnce(&TaskCtx) -> Fut + 'static, - Fut: Future<Output = Result<(), IggyError>> + 'static, - { - let run_boxed: Box<dyn FnOnce(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>> = - Box::new(move |ctx| Box::pin(run(ctx))); - - ContinuousTaskBuilder { - supervisor: self.supervisor, - name: self.name, - scope: self.scope, - critical: self.critical, - shard: self.shard, - run: Some(run_boxed), - _phantom: PhantomData, - } - } -} - -impl<'a> ContinuousTaskBuilder<'a, HasTask> { - /// Spawn the continuous task - pub fn spawn(self) { - let shard = self.shard.expect("Shard must be set for continuous tasks"); - let run = self.run.expect("Run function must be set"); - - if !self.scope.should_run(&shard) { - return; - } - - let ctx = TaskCtx { - shard: shard.clone(), - shutdown: self.supervisor.shutdown_token(), - }; - - let name = self.name; - let shard_id = self.supervisor.shard_id; - let is_critical = self.critical; - - debug!("Spawning continuous task '{}' on shard {}", name, shard_id); - - let handle = compio::runtime::spawn(async move { - let result = run(&ctx).await; - match &result { - Ok(()) => debug!("Continuous task '{}' completed on shard {}", name, shard_id), - Err(e) => tracing::error!( - "Continuous task '{}' failed on shard {}: {}", - name, - shard_id, - e - ), - } - result - }); - - self.supervisor - .register_task(name, TaskKind::Continuous, handle, is_critical); - } -} - -/// Builder for oneshot tasks with fluent API -pub struct OneshotTaskBuilder<'a, State = NoTask> { - supervisor: &'a TaskSupervisor, - name: &'static str, - scope: TaskScope, - critical: bool, - timeout: Option<Duration>, - shard: Option<Rc<IggyShard>>, - run: Option<Box<dyn FnOnce(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, - _phantom: PhantomData<State>, -} - -impl<'a> OneshotTaskBuilder<'a, NoTask> { - pub(crate) fn new(supervisor: &'a TaskSupervisor, name: &'static str) -> Self { - Self { - supervisor, - name, - scope: TaskScope::AllShards, - critical: false, - timeout: None, - shard: None, - run: None, - _phantom: PhantomData, - } - } - - /// Set the shard scope for this task - pub fn on_shard(mut self, scope: TaskScope) -> Self { - self.scope = scope; - self - } - - /// Mark this task as critical - pub fn critical(mut self, critical: bool) -> Self { - self.critical = critical; - self - } - - /// Set a timeout for this oneshot task - pub fn timeout(mut self, timeout: Duration) -> Self { - self.timeout = Some(timeout); - self - } - - /// Set the shard for this task - pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { - self.shard = Some(shard); - self - } - - /// Set the run function for this oneshot task - pub fn run<F, Fut>(self, run: F) -> OneshotTaskBuilder<'a, HasTask> - where - F: FnOnce(&TaskCtx) -> Fut + 'static, - Fut: Future<Output = Result<(), IggyError>> + 'static, - { - let run_boxed: Box<dyn FnOnce(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>> = - Box::new(move |ctx| Box::pin(run(ctx))); - - OneshotTaskBuilder { - supervisor: self.supervisor, - name: self.name, - scope: self.scope, - critical: self.critical, - timeout: self.timeout, - shard: self.shard, - run: Some(run_boxed), - _phantom: PhantomData, - } - } -} - -impl<'a> OneshotTaskBuilder<'a, HasTask> { - /// Spawn the oneshot task - pub fn spawn(self) { - let shard = self.shard.expect("Shard must be set for oneshot tasks"); - let run = self.run.expect("Run function must be set"); - - // Check if task should run on this shard - if !self.scope.should_run(&shard) { - return; - } - - let ctx = TaskCtx { - shard: shard.clone(), - shutdown: self.supervisor.shutdown_token(), - }; - - // If timeout is specified, wrap the future with timeout - if let Some(timeout) = self.timeout { - let name = self.name; - self.supervisor - .spawn_oneshot(name, self.critical, async move { - match compio::time::timeout(timeout, run(&ctx)).await { - Ok(result) => result, - Err(_) => Err(IggyError::InvalidCommand), - } - }); - } else { - self.supervisor - .spawn_oneshot(self.name, self.critical, async move { run(&ctx).await }); - } - } -} diff --git a/core/server/src/shard/tasks/continuous/http_server.rs b/core/server/src/shard/tasks/continuous/http_server.rs index efd624b7..81dde7d3 100644 --- a/core/server/src/shard/tasks/continuous/http_server.rs +++ b/core/server/src/shard/tasks/continuous/http_server.rs @@ -19,14 +19,11 @@ use crate::bootstrap::resolve_persister; use crate::http::http_server; use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use std::fmt::Debug; use std::rc::Rc; use tracing::info; -/// Continuous task for running the HTTP REST API server pub struct HttpServer { shard: Rc<IggyShard>, } @@ -45,15 +42,11 @@ impl HttpServer { } } -impl TaskSpec for HttpServer { +impl TaskMeta for HttpServer { fn name(&self) -> &'static str { "http_server" } - fn kind(&self) -> TaskKind { - TaskKind::Continuous - } - fn scope(&self) -> TaskScope { TaskScope::SpecificShard(0) } @@ -61,19 +54,15 @@ impl TaskSpec for HttpServer { fn is_critical(&self) -> bool { false } +} +impl ContinuousTask for HttpServer { fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); Box::pin(async move { - info!("Starting HTTP server on shard: {}", self.shard.id); - let persister = resolve_persister(self.shard.config.system.partition.enforce_fsync); - http_server::start( - self.shard.config.http.clone(), - persister, - self.shard.clone(), - ) - .await + info!("Starting HTTP server on shard: {}", shard.id); + let persister = resolve_persister(shard.config.system.partition.enforce_fsync); + http_server::start(shard.config.http.clone(), persister, shard).await }) } } - -impl ContinuousSpec for HttpServer {} diff --git a/core/server/src/shard/tasks/continuous/message_pump.rs b/core/server/src/shard/tasks/continuous/message_pump.rs index b50979f1..b0340e1b 100644 --- a/core/server/src/shard/tasks/continuous/message_pump.rs +++ b/core/server/src/shard/tasks/continuous/message_pump.rs @@ -17,16 +17,13 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use crate::shard::transmission::frame::ShardFrame; use crate::{shard_debug, shard_info}; use futures::{FutureExt, StreamExt}; use std::fmt::Debug; use std::rc::Rc; -/// Continuous task for processing inter-shard messages pub struct MessagePump { shard: Rc<IggyShard>, } @@ -45,15 +42,11 @@ impl MessagePump { } } -impl TaskSpec for MessagePump { +impl TaskMeta for MessagePump { fn name(&self) -> &'static str { "message_pump" } - fn kind(&self) -> TaskKind { - TaskKind::Continuous - } - fn scope(&self) -> TaskScope { TaskScope::AllShards } @@ -61,10 +54,18 @@ impl TaskSpec for MessagePump { fn is_critical(&self) -> bool { true } +} +impl ContinuousTask for MessagePump { fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { Box::pin(async move { - let mut messages_receiver = self.shard.messages_receiver.take().unwrap(); + let Some(mut messages_receiver) = self.shard.messages_receiver.take() else { + shard_info!( + self.shard.id, + "Message receiver already taken; pump not started" + ); + return Ok(()); + }; shard_info!(self.shard.id, "Starting message passing task"); @@ -75,18 +76,18 @@ impl TaskSpec for MessagePump { break; } frame = messages_receiver.next().fuse() => { - if let Some(frame) = frame { - let ShardFrame { - message, - response_sender, - } = frame; - - if let (Some(response), Some(response_sender)) = (self.shard.handle_shard_message(message).await, response_sender) { - response_sender - .send(response) - .await - .expect("Failed to send response back to origin shard."); - }; + match frame { + Some(ShardFrame { message, response_sender }) => { + if let (Some(response), Some(tx)) = + (self.shard.handle_shard_message(message).await, response_sender) + { + let _ = tx.send(response).await; + } + } + None => { + shard_debug!(self.shard.id, "Message receiver closed; exiting pump"); + break; + } } } } @@ -96,5 +97,3 @@ impl TaskSpec for MessagePump { }) } } - -impl ContinuousSpec for MessagePump {} diff --git a/core/server/src/shard/tasks/continuous/mod.rs b/core/server/src/shard/tasks/continuous/mod.rs index 97a3ff75..59ed6293 100644 --- a/core/server/src/shard/tasks/continuous/mod.rs +++ b/core/server/src/shard/tasks/continuous/mod.rs @@ -16,8 +16,6 @@ * under the License. */ -//! Continuous task specifications for long-running services - pub mod http_server; pub mod message_pump; pub mod quic_server; diff --git a/core/server/src/shard/tasks/continuous/quic_server.rs b/core/server/src/shard/tasks/continuous/quic_server.rs index c0d451b4..81cd2fcd 100644 --- a/core/server/src/shard/tasks/continuous/quic_server.rs +++ b/core/server/src/shard/tasks/continuous/quic_server.rs @@ -18,13 +18,10 @@ use crate::quic::quic_server; use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use std::fmt::Debug; use std::rc::Rc; -/// Continuous task for running the QUIC server pub struct QuicServer { shard: Rc<IggyShard>, } @@ -43,26 +40,23 @@ impl QuicServer { } } -impl TaskSpec for QuicServer { +impl TaskMeta for QuicServer { fn name(&self) -> &'static str { "quic_server" } - fn kind(&self) -> TaskKind { - TaskKind::Continuous - } - fn scope(&self) -> TaskScope { TaskScope::AllShards } fn is_critical(&self) -> bool { - false + true } +} +impl ContinuousTask for QuicServer { fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { - Box::pin(async move { quic_server::span_quic_server(self.shard.clone()).await }) + let shard = self.shard.clone(); + Box::pin(async move { quic_server::span_quic_server(shard).await }) } } - -impl ContinuousSpec for QuicServer {} diff --git a/core/server/src/shard/tasks/continuous/tcp_server.rs b/core/server/src/shard/tasks/continuous/tcp_server.rs index b269db58..f2bcda62 100644 --- a/core/server/src/shard/tasks/continuous/tcp_server.rs +++ b/core/server/src/shard/tasks/continuous/tcp_server.rs @@ -17,14 +17,11 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use crate::tcp::tcp_server::spawn_tcp_server; use std::fmt::Debug; use std::rc::Rc; -/// Continuous task for running the TCP server pub struct TcpServer { shard: Rc<IggyShard>, } @@ -43,30 +40,23 @@ impl TcpServer { } } -impl TaskSpec for TcpServer { +impl TaskMeta for TcpServer { fn name(&self) -> &'static str { "tcp_server" } - fn kind(&self) -> TaskKind { - TaskKind::Continuous - } - fn scope(&self) -> TaskScope { TaskScope::AllShards } fn is_critical(&self) -> bool { - false + true } +} +impl ContinuousTask for TcpServer { fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { - Box::pin(async move { - // The existing spawn_tcp_server already handles shutdown internally - // via is_shutting_down checks. This will be refactored later. - spawn_tcp_server(self.shard.clone()).await - }) + let shard = self.shard.clone(); + Box::pin(async move { spawn_tcp_server(shard).await }) } } - -impl ContinuousSpec for TcpServer {} diff --git a/core/server/src/shard/tasks/mod.rs b/core/server/src/shard/tasks/mod.rs index b156d68b..367c93c3 100644 --- a/core/server/src/shard/tasks/mod.rs +++ b/core/server/src/shard/tasks/mod.rs @@ -16,63 +16,54 @@ * under the License. */ -//! Task management system for shard operations -//! -//! This module provides a unified framework for managing all asynchronous tasks -//! within a shard, including servers, periodic maintenance, and one-shot operations. - -pub mod builder; pub mod continuous; pub mod periodic; -pub mod shutdown; -pub mod specs; -pub mod supervisor; -pub mod tls; - -pub use shutdown::{Shutdown, ShutdownToken}; -pub use specs::{TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec}; -pub use supervisor::TaskSupervisor; -pub use tls::{init_supervisor, is_supervisor_initialized, task_supervisor}; use crate::shard::IggyShard; -use specs::IsNoOp; +use crate::shard::task_registry::TaskRegistry; use std::rc::Rc; -use std::time::Duration; - -/// Create all task specifications for a shard -/// -/// This is the central place where all tasks are declared and configured. -/// The supervisor will filter tasks based on their scope and the shard's properties. -pub fn shard_task_specs(shard: Rc<IggyShard>) -> Vec<Box<dyn TaskSpec>> { - let mut specs: Vec<Box<dyn TaskSpec>> = vec![]; - // Continuous tasks - servers and message processing - specs.push(Box::new(continuous::MessagePump::new(shard.clone()))); +pub fn register_all(reg: &TaskRegistry, shard: Rc<IggyShard>) { + reg.spawn_continuous( + shard.clone(), + Box::new(continuous::MessagePump::new(shard.clone())), + ); if shard.config.tcp.enabled { - specs.push(Box::new(continuous::TcpServer::new(shard.clone()))); + reg.spawn_continuous( + shard.clone(), + Box::new(continuous::TcpServer::new(shard.clone())), + ); } if shard.config.http.enabled { - specs.push(Box::new(continuous::HttpServer::new(shard.clone()))); + reg.spawn_continuous( + shard.clone(), + Box::new(continuous::HttpServer::new(shard.clone())), + ); } if shard.config.quic.enabled { - specs.push(Box::new(continuous::QuicServer::new(shard.clone()))); + reg.spawn_continuous( + shard.clone(), + Box::new(continuous::QuicServer::new(shard.clone())), + ); } - // Periodic tasks - maintenance and monitoring if shard.config.message_saver.enabled { let period = shard.config.message_saver.interval.get_duration(); - specs.push(Box::new(periodic::SaveMessages::new(shard.clone(), period))); + reg.spawn_periodic( + shard.clone(), + Box::new(periodic::SaveMessages::new(shard.clone(), period)), + ); } if shard.config.heartbeat.enabled { let period = shard.config.heartbeat.interval.get_duration(); - specs.push(Box::new(periodic::VerifyHeartbeats::new( + reg.spawn_periodic( shard.clone(), - period, - ))); + Box::new(periodic::VerifyHeartbeats::new(shard.clone(), period)), + ); } if shard.config.personal_access_token.cleaner.enabled { @@ -82,26 +73,32 @@ pub fn shard_task_specs(shard: Rc<IggyShard>) -> Vec<Box<dyn TaskSpec>> { .cleaner .interval .get_duration(); - specs.push(Box::new(periodic::ClearPersonalAccessTokens::new( + reg.spawn_periodic( shard.clone(), - period, - ))); + Box::new(periodic::ClearPersonalAccessTokens::new( + shard.clone(), + period, + )), + ); } - // System info printing (leader only) - let sysinfo_period = shard + if shard .config .system .logging .sysinfo_print_interval - .get_duration(); - if !sysinfo_period.is_zero() { - specs.push(Box::new(periodic::PrintSysinfo::new( + .as_micros() + > 0 + { + let period = shard + .config + .system + .logging + .sysinfo_print_interval + .get_duration(); + reg.spawn_periodic( shard.clone(), - sysinfo_period, - ))); + Box::new(periodic::PrintSysinfo::new(shard.clone(), period)), + ); } - - // Filter out no-op tasks - specs.into_iter().filter(|spec| !spec.is_noop()).collect() } diff --git a/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs b/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs new file mode 100644 index 00000000..cfb72ac2 --- /dev/null +++ b/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs @@ -0,0 +1,92 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::http::shared::AppState; +use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use iggy_common::IggyTimestamp; +use std::fmt::Debug; +use std::sync::Arc; +use std::time::Duration; +use tracing::{error, info, trace}; + +pub struct ClearJwtTokens { + app_state: Arc<AppState>, + period: Duration, +} + +impl Debug for ClearJwtTokens { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClearJwtTokens") + .field("period", &self.period) + .finish() + } +} + +impl ClearJwtTokens { + pub fn new(app_state: Arc<AppState>, period: Duration) -> Self { + Self { app_state, period } + } +} + +impl TaskMeta for ClearJwtTokens { + fn name(&self) -> &'static str { + "clear_jwt_tokens" + } + + fn scope(&self) -> TaskScope { + TaskScope::SpecificShard(0) + } + + fn on_start(&self) { + info!( + "JWT token cleaner is enabled, expired revoked tokens will be deleted every: {:?}.", + self.period + ); + } +} + +impl PeriodicTask for ClearJwtTokens { + fn period(&self) -> Duration { + self.period + } + + fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture { + let app_state = self.app_state.clone(); + + Box::pin(async move { + trace!("Checking for expired revoked JWT tokens..."); + + let now = IggyTimestamp::now().to_secs(); + + match app_state + .jwt_manager + .delete_expired_revoked_tokens(now) + .await + { + Ok(()) => { + trace!("Successfully cleaned up expired revoked JWT tokens"); + } + Err(err) => { + error!("Failed to delete expired revoked JWT tokens: {}", err); + } + } + + Ok(()) + }) + } +} diff --git a/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs index 1df3268d..bf1c6076 100644 --- a/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs +++ b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs @@ -17,16 +17,14 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use iggy_common::IggyTimestamp; use std::fmt::Debug; use std::rc::Rc; +use std::sync::Arc; use std::time::Duration; -use tracing::{debug, info, trace}; +use tracing::{info, trace}; -/// Periodic task for cleaning expired personal access tokens pub struct ClearPersonalAccessTokens { shard: Rc<IggyShard>, period: Duration, @@ -47,76 +45,62 @@ impl ClearPersonalAccessTokens { } } -impl TaskSpec for ClearPersonalAccessTokens { +impl TaskMeta for ClearPersonalAccessTokens { fn name(&self) -> &'static str { "clear_personal_access_tokens" } - fn kind(&self) -> TaskKind { - TaskKind::Periodic { - period: self.period, - } - } - fn scope(&self) -> TaskScope { - // Only clean tokens on shard 0 to avoid conflicts - TaskScope::SpecificShard(0) + TaskScope::AllShards } - fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { - Box::pin(async move { - info!( - "Personal access token cleaner is enabled, expired tokens will be deleted every: {:?}.", - self.period - ); - - loop { - if !ctx.shutdown.sleep_or_shutdown(self.period).await { - trace!("Personal access token cleaner shutting down"); - break; - } - - trace!("Cleaning expired personal access tokens..."); + fn on_start(&self) { + info!( + "Personal access token cleaner is enabled, expired tokens will be deleted every: {:?}.", + self.period + ); + } +} - let users = self.shard.users.borrow(); - let now = IggyTimestamp::now(); - let mut deleted_tokens_count = 0; +impl PeriodicTask for ClearPersonalAccessTokens { + fn period(&self) -> Duration { + self.period + } - for (_, user) in users.iter() { - let expired_tokens = user - .personal_access_tokens - .iter() - .filter(|token| token.is_expired(now)) - .map(|token| token.token.clone()) - .collect::<Vec<_>>(); + fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); - for token in expired_tokens { - debug!( - "Personal access token: {} for user with ID: {} is expired.", - token, user.id - ); - deleted_tokens_count += 1; - user.personal_access_tokens.remove(&token); - debug!( - "Deleted personal access token: {} for user with ID: {}.", - token, user.id + Box::pin(async move { + trace!("Checking for expired personal access tokens..."); + + let now = IggyTimestamp::now(); + let mut total_removed = 0; + + let users = shard.users.borrow(); + for user in users.values() { + let expired_tokens: Vec<Arc<String>> = user + .personal_access_tokens + .iter() + .filter(|entry| entry.value().is_expired(now)) + .map(|entry| entry.key().clone()) + .collect(); + + for token_hash in expired_tokens { + if let Some((_, pat)) = user.personal_access_tokens.remove(&token_hash) { + info!( + "Removed expired personal access token '{}' for user ID {}", + pat.name, user.id ); + total_removed += 1; } } + } - info!( - "Deleted {} expired personal access tokens.", - deleted_tokens_count - ); + if total_removed > 0 { + info!("Removed {total_removed} expired personal access tokens"); } Ok(()) }) } } - -impl PeriodicSpec for ClearPersonalAccessTokens { - fn period(&self) -> Duration { - self.period - } -} diff --git a/core/server/src/shard/tasks/periodic/mod.rs b/core/server/src/shard/tasks/periodic/mod.rs index 493230a9..ca808dca 100644 --- a/core/server/src/shard/tasks/periodic/mod.rs +++ b/core/server/src/shard/tasks/periodic/mod.rs @@ -16,11 +16,13 @@ * under the License. */ +pub mod clear_jwt_tokens; pub mod clear_personal_access_tokens; pub mod print_sysinfo; pub mod save_messages; pub mod verify_heartbeats; +pub use clear_jwt_tokens::ClearJwtTokens; pub use clear_personal_access_tokens::ClearPersonalAccessTokens; pub use print_sysinfo::PrintSysinfo; pub use save_messages::SaveMessages; diff --git a/core/server/src/shard/tasks/periodic/print_sysinfo.rs b/core/server/src/shard/tasks/periodic/print_sysinfo.rs index 1866f867..b1c91f86 100644 --- a/core/server/src/shard/tasks/periodic/print_sysinfo.rs +++ b/core/server/src/shard/tasks/periodic/print_sysinfo.rs @@ -17,17 +17,13 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; -use crate::streaming::utils::memory_pool; +use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use human_repr::HumanCount; use std::fmt::Debug; use std::rc::Rc; use std::time::Duration; use tracing::{error, info, trace}; -/// Periodic task for printing system information pub struct PrintSysinfo { shard: Rc<IggyShard>, period: Duration, @@ -48,70 +44,36 @@ impl PrintSysinfo { } } -impl TaskSpec for PrintSysinfo { +impl TaskMeta for PrintSysinfo { fn name(&self) -> &'static str { "print_sysinfo" } - fn kind(&self) -> TaskKind { - TaskKind::Periodic { - period: self.period, - } - } - fn scope(&self) -> TaskScope { - // Only print sysinfo from shard 0 to avoid duplicate logs TaskScope::SpecificShard(0) } - fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + fn on_start(&self) { + info!( + "System info logger is enabled, OS info will be printed every: {:?}", + self.period + ); + } +} + +impl PeriodicTask for PrintSysinfo { + fn period(&self) -> Duration { + self.period + } + + fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); + Box::pin(async move { - if self.period == Duration::ZERO { - info!("SysInfoPrinter is disabled."); - return Ok(()); - } + trace!("Fetching OS info..."); - info!( - "SysInfoPrinter is enabled, system information will be printed every {:?}.", - self.period - ); - - loop { - if !ctx.shutdown.sleep_or_shutdown(self.period).await { - trace!("SysInfoPrinter shutting down"); - break; - } - - trace!("Printing system information..."); - - let stats = match self.shard.get_stats().await { - Ok(stats) => stats, - Err(e) => { - error!("Failed to get system information. Error: {e}"); - continue; - } - }; - - let free_memory_percent = (stats.available_memory.as_bytes_u64() as f64 - / stats.total_memory.as_bytes_u64() as f64) - * 100f64; - - info!( - "CPU: {:.2}%/{:.2}% (IggyUsage/Total), Mem: {:.2}%/{}/{}/{} (Free/IggyUsage/TotalUsed/Total), Clients: {}, Messages processed: {}, Read: {}, Written: {}, Uptime: {}", - stats.cpu_usage, - stats.total_cpu_usage, - free_memory_percent, - stats.memory_usage, - stats.total_memory - stats.available_memory, - stats.total_memory, - stats.clients_count, - stats.messages_count.human_count_bare(), - stats.read_bytes, - stats.written_bytes, - stats.run_time - ); - - memory_pool().log_stats(); + if let Err(e) = print_os_info() { + error!("Failed to print system info: {}", e); } Ok(()) @@ -119,8 +81,43 @@ impl TaskSpec for PrintSysinfo { } } -impl PeriodicSpec for PrintSysinfo { - fn period(&self) -> Duration { - self.period +fn print_os_info() -> Result<(), String> { + let mut sys = sysinfo::System::new(); + sys.refresh_memory(); + + let available_memory = sys.available_memory(); + let free_memory = sys.free_memory(); + let used_memory = sys.used_memory(); + let total_memory = sys.total_memory(); + + info!( + "Memory -> available: {}, free: {}, used: {}, total: {}.", + available_memory.human_count_bytes(), + free_memory.human_count_bytes(), + used_memory.human_count_bytes(), + total_memory.human_count_bytes() + ); + + let disks = sysinfo::Disks::new_with_refreshed_list(); + for disk in disks.list() { + let name = disk.name().to_string_lossy(); + let mount_point = disk.mount_point(); + let available_space = disk.available_space(); + let total_space = disk.total_space(); + let used_space = total_space - available_space; + let file_system = disk.file_system().to_string_lossy(); + let is_removable = disk.is_removable(); + info!( + "Disk: {}, mounted: {}, removable: {}, file system: {}, available: {}, used: {}, total: {}", + name, + mount_point.display(), + is_removable, + file_system, + available_space.human_count_bytes(), + used_space.human_count_bytes(), + total_space.human_count_bytes() + ); } + + Ok(()) } diff --git a/core/server/src/shard/tasks/periodic/save_messages.rs b/core/server/src/shard/tasks/periodic/save_messages.rs index 5599cedd..0ae705a5 100644 --- a/core/server/src/shard/tasks/periodic/save_messages.rs +++ b/core/server/src/shard/tasks/periodic/save_messages.rs @@ -17,8 +17,8 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +use crate::shard::task_registry::{ + PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskResult, TaskScope, }; use crate::shard_info; use iggy_common::Identifier; @@ -27,7 +27,6 @@ use std::rc::Rc; use std::time::Duration; use tracing::{error, info, trace}; -/// Periodic task for saving buffered messages to disk pub struct SaveMessages { shard: Rc<IggyShard>, period: Duration, @@ -48,89 +47,77 @@ impl SaveMessages { } } -impl TaskSpec for SaveMessages { +impl TaskMeta for SaveMessages { fn name(&self) -> &'static str { "save_messages" } - fn kind(&self) -> TaskKind { - TaskKind::Periodic { - period: self.period, - } - } - fn scope(&self) -> TaskScope { TaskScope::AllShards } - fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { - Box::pin(async move { - let enforce_fsync = self.shard.config.message_saver.enforce_fsync; - info!( - "Message saver is enabled, buffered messages will be automatically saved every: {:?}, enforce fsync: {enforce_fsync}.", - self.period - ); - - loop { - if !ctx.shutdown.sleep_or_shutdown(self.period).await { - trace!("Message saver shutting down"); - break; - } - - trace!("Saving buffered messages..."); - - let namespaces = self.shard.get_current_shard_namespaces(); - let mut total_saved_messages = 0u32; - let reason = "background saver triggered".to_string(); + fn on_start(&self) { + let enforce_fsync = self.shard.config.message_saver.enforce_fsync; + info!( + "Message saver is enabled, buffered messages will be automatically saved every: {:?}, enforce fsync: {enforce_fsync}.", + self.period + ); + } +} - for ns in namespaces { - let stream_id = Identifier::numeric(ns.stream_id() as u32).unwrap(); - let topic_id = Identifier::numeric(ns.topic_id() as u32).unwrap(); - let partition_id = ns.partition_id(); +impl PeriodicTask for SaveMessages { + fn period(&self) -> Duration { + self.period + } - match self - .shard - .streams2 - .persist_messages( - self.shard.id, - &stream_id, - &topic_id, - partition_id, - reason.clone(), - &self.shard.config.system, - ) - .await - { - Ok(batch_count) => { - total_saved_messages += batch_count; - } - Err(err) => { - error!( - "Failed to save messages for partition {}: {}", - partition_id, err - ); - } + fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture { + let shard = ctx.shard.clone(); + Box::pin(async move { + trace!("Saving buffered messages..."); + + let namespaces = shard.get_current_shard_namespaces(); + let mut total_saved_messages = 0u32; + let reason = "background saver triggered".to_string(); + + for ns in namespaces { + let stream_id = Identifier::numeric(ns.stream_id() as u32).unwrap(); + let topic_id = Identifier::numeric(ns.topic_id() as u32).unwrap(); + let partition_id = ns.partition_id(); + + match shard + .streams2 + .persist_messages( + shard.id, + &stream_id, + &topic_id, + partition_id, + reason.clone(), + &shard.config.system, + ) + .await + { + Ok(batch_count) => { + total_saved_messages += batch_count; + } + Err(err) => { + error!( + "Failed to save messages for partition {}: {}", + partition_id, err + ); } } + } - if total_saved_messages > 0 { - shard_info!( - self.shard.id, - "Saved {} buffered messages on disk.", - total_saved_messages - ); - } - - trace!("Finished saving buffered messages."); + if total_saved_messages > 0 { + shard_info!( + shard.id, + "Saved {} buffered messages on disk.", + total_saved_messages + ); } + trace!("Finished saving buffered messages."); Ok(()) }) } } - -impl PeriodicSpec for SaveMessages { - fn period(&self) -> Duration { - self.period - } -} diff --git a/core/server/src/shard/tasks/periodic/verify_heartbeats.rs b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs index a2ce9685..5810d310 100644 --- a/core/server/src/shard/tasks/periodic/verify_heartbeats.rs +++ b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs @@ -17,9 +17,7 @@ */ use crate::shard::IggyShard; -use crate::shard::tasks::specs::{ - PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, -}; +use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; use iggy_common::{IggyDuration, IggyTimestamp}; use std::fmt::Debug; use std::rc::Rc; @@ -28,10 +26,10 @@ use tracing::{debug, info, trace, warn}; const MAX_THRESHOLD: f64 = 1.2; -/// Periodic task for verifying client heartbeats and removing stale clients pub struct VerifyHeartbeats { shard: Rc<IggyShard>, period: Duration, + max_interval: IggyDuration, } impl Debug for VerifyHeartbeats { @@ -45,87 +43,84 @@ impl Debug for VerifyHeartbeats { impl VerifyHeartbeats { pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self { - Self { shard, period } + let interval = IggyDuration::from(period); + let max_interval = IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64); + Self { + shard, + period, + max_interval, + } } } -impl TaskSpec for VerifyHeartbeats { +impl TaskMeta for VerifyHeartbeats { fn name(&self) -> &'static str { "verify_heartbeats" } - fn kind(&self) -> TaskKind { - TaskKind::Periodic { - period: self.period, - } - } - fn scope(&self) -> TaskScope { TaskScope::AllShards } - fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { - Box::pin(async move { - let interval = IggyDuration::from(self.period); - let max_interval = - IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64); - info!( - "Heartbeats will be verified every: {interval}. Max allowed interval: {max_interval}." - ); - - loop { - if !ctx.shutdown.sleep_or_shutdown(self.period).await { - trace!("Heartbeat verifier shutting down"); - break; - } + fn on_start(&self) { + info!( + "Heartbeats will be verified every: {}. Max allowed interval: {}.", + IggyDuration::from(self.period), + self.max_interval + ); + } +} - trace!("Verifying heartbeats..."); - - let clients = { - let client_manager = self.shard.client_manager.borrow(); - client_manager.get_clients() - }; - - let now = IggyTimestamp::now(); - let heartbeat_to = IggyTimestamp::from(now.as_micros() - max_interval.as_micros()); - debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}"); - - let mut stale_clients = Vec::new(); - for client in clients { - if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() { - warn!( - "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}", - client.session, client.last_heartbeat, - ); - client.session.set_stale(); - stale_clients.push(client.session.client_id); - } else { - debug!( - "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}", - client.last_heartbeat, client.session, - ); - } - } +impl PeriodicTask for VerifyHeartbeats { + fn period(&self) -> Duration { + self.period + } - if stale_clients.is_empty() { - continue; - } + fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); + let max_interval = self.max_interval; - let count = stale_clients.len(); - info!("Removing {count} stale clients..."); - for client_id in stale_clients { - self.shard.delete_client(client_id); + Box::pin(async move { + trace!("Verifying heartbeats..."); + + let clients = { + let client_manager = shard.client_manager.borrow(); + client_manager.get_clients() + }; + + let now = IggyTimestamp::now(); + let heartbeat_to = IggyTimestamp::from(now.as_micros() - max_interval.as_micros()); + debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}"); + + let mut stale_clients = Vec::new(); + for client in clients { + if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() { + warn!( + "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}", + client.session, client.last_heartbeat, + ); + client.session.set_stale(); + stale_clients.push(client.session.client_id); + } else { + debug!( + "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}", + client.last_heartbeat, client.session, + ); } - info!("Removed {count} stale clients."); } + if stale_clients.is_empty() { + return Ok(()); + } + + let count = stale_clients.len(); + info!("Removing {count} stale clients..."); + for client_id in stale_clients { + shard.delete_client(client_id); + } + info!("Removed {count} stale clients."); + Ok(()) }) } } - -impl PeriodicSpec for VerifyHeartbeats { - fn period(&self) -> Duration { - self.period - } -} diff --git a/core/server/src/shard/tasks/specs.rs b/core/server/src/shard/tasks/specs.rs deleted file mode 100644 index 2b3340ea..00000000 --- a/core/server/src/shard/tasks/specs.rs +++ /dev/null @@ -1,203 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::shard::IggyShard; -use iggy_common::IggyError; -use std::fmt::Debug; -use std::future::Future; -use std::pin::Pin; -use std::rc::Rc; -use std::time::Duration; -use strum::Display; - -use super::shutdown::ShutdownToken; - -/// Context provided to all tasks when they run -#[derive(Clone)] -pub struct TaskCtx { - pub shard: Rc<IggyShard>, - pub shutdown: ShutdownToken, -} - -/// Future type returned by task run methods -pub type TaskFuture = Pin<Box<dyn Future<Output = Result<(), IggyError>>>>; - -/// Describes the kind of task and its scheduling behavior -#[derive(Debug, Clone, Display)] -pub enum TaskKind { - /// Runs continuously until shutdown - Continuous, - /// Runs periodically at specified intervals - Periodic { period: Duration }, - /// Runs once and completes - OneShot, -} - -/// Determines which shards should run this task -#[derive(Clone, Debug)] -pub enum TaskScope { - /// Run on all shards - AllShards, - /// Run on a specific shard by ID - SpecificShard(u16), -} - -impl TaskScope { - /// Check if this task should run on the given shard - pub fn should_run(&self, shard: &IggyShard) -> bool { - match self { - TaskScope::AllShards => true, - TaskScope::SpecificShard(id) => shard.id == *id, - } - } -} - -/// Core trait that all tasks must implement -pub trait TaskSpec: Debug { - /// Unique name for this task - fn name(&self) -> &'static str; - - /// The kind of task (continuous, periodic, or oneshot) - fn kind(&self) -> TaskKind; - - /// Scope determining which shards run this task - fn scope(&self) -> TaskScope; - - /// Run the task with the provided context - fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; - - /// Optional: called before the task starts - fn on_start(&self) { - tracing::info!("Starting task: {}", self.name()); - } - - /// Optional: called after the task completes - fn on_complete(&self, result: &Result<(), IggyError>) { - match result { - Ok(()) => tracing::info!("Task {} completed successfully", self.name()), - Err(e) => tracing::error!("Task {} failed: {}", self.name(), e), - } - } - - /// Optional: whether this task is critical (failure should stop the shard) - fn is_critical(&self) -> bool { - false - } -} - -/// Marker trait for continuous tasks -pub trait ContinuousSpec: TaskSpec { - fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> - where - Self: 'static + Sized, - { - self as Box<dyn TaskSpec> - } -} - -/// Marker trait for periodic tasks -pub trait PeriodicSpec: TaskSpec { - /// Get the period for this periodic task - fn period(&self) -> Duration; - - /// Optional: whether to run one final tick on shutdown - fn last_tick_on_shutdown(&self) -> bool { - false - } - - fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> - where - Self: 'static + Sized, - { - self as Box<dyn TaskSpec> - } -} - -/// Marker trait for oneshot tasks -pub trait OneShotSpec: TaskSpec { - /// Optional: timeout for this oneshot task - fn timeout(&self) -> Option<Duration> { - None - } - - fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> - where - Self: 'static + Sized, - { - self as Box<dyn TaskSpec> - } -} - -/// A no-op task that does nothing (useful for conditional task creation) -#[derive(Debug)] -pub struct NoOpTask; - -impl TaskSpec for NoOpTask { - fn name(&self) -> &'static str { - "noop" - } - - fn kind(&self) -> TaskKind { - TaskKind::OneShot - } - - fn scope(&self) -> TaskScope { - TaskScope::AllShards - } - - fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { - Box::pin(async { Ok(()) }) - } -} - -/// Helper to create a no-op task -pub fn noop() -> Box<dyn TaskSpec> { - Box::new(NoOpTask) -} - -/// Helper trait to check if a task is a no-op -pub trait IsNoOp { - fn is_noop(&self) -> bool; -} - -impl IsNoOp for Box<dyn TaskSpec> { - fn is_noop(&self) -> bool { - self.name() == "noop" - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_task_scope() { - let shard = IggyShard::default_from_config(Default::default()); - - assert!(TaskScope::AllShards.should_run(&shard)); - assert!(TaskScope::SpecificShard(0).should_run(&shard)); // shard.id == 0 by default - assert!(!TaskScope::SpecificShard(1).should_run(&shard)); - } - - #[test] - fn test_noop_task() { - let task = noop(); - assert!(task.is_noop()); - assert_eq!(task.name(), "noop"); - } -} diff --git a/core/server/src/shard/tasks/supervisor.rs b/core/server/src/shard/tasks/supervisor.rs deleted file mode 100644 index e1efa04a..00000000 --- a/core/server/src/shard/tasks/supervisor.rs +++ /dev/null @@ -1,558 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use super::builder::{ContinuousTaskBuilder, OneshotTaskBuilder, PeriodicTaskBuilder}; -use super::shutdown::{Shutdown, ShutdownToken}; -use super::specs::{TaskCtx, TaskKind, TaskScope, TaskSpec}; -use crate::shard::IggyShard; -use compio::runtime::JoinHandle; -use futures::Future; -use futures::future::join_all; -use iggy_common::IggyError; -use std::cell::RefCell; -use std::collections::HashMap; -use std::rc::Rc; -use std::time::Duration; -use tracing::{debug, error, info, trace, warn}; - -/// Handle to a spawned task -struct TaskHandle { - name: String, - kind: TaskKind, - handle: JoinHandle<Result<(), IggyError>>, - is_critical: bool, -} - -/// Supervises the lifecycle of all tasks in a shard -pub struct TaskSupervisor { - pub(crate) shard_id: u16, - shutdown: Shutdown, - shutdown_token: ShutdownToken, - tasks: RefCell<Vec<TaskHandle>>, - oneshot_handles: RefCell<Vec<TaskHandle>>, - active_connections: RefCell<HashMap<u32, async_channel::Sender<()>>>, -} - -impl TaskSupervisor { - /// Create a new task supervisor for a shard - pub fn new(shard_id: u16) -> Self { - let (shutdown, shutdown_token) = Shutdown::new(); - - Self { - shard_id, - shutdown, - shutdown_token, - tasks: RefCell::new(Vec::new()), - oneshot_handles: RefCell::new(Vec::new()), - active_connections: RefCell::new(HashMap::new()), - } - } - - /// Get a shutdown token for tasks - pub fn shutdown_token(&self) -> ShutdownToken { - self.shutdown_token.clone() - } - - /// Create a builder for a periodic task - pub fn periodic(&self, name: &'static str) -> PeriodicTaskBuilder { - PeriodicTaskBuilder::new(self, name) - } - - /// Create a builder for a continuous task - pub fn continuous(&self, name: &'static str) -> ContinuousTaskBuilder { - ContinuousTaskBuilder::new(self, name) - } - - /// Create a builder for a oneshot task - pub fn oneshot(&self, name: &'static str) -> OneshotTaskBuilder { - OneshotTaskBuilder::new(self, name) - } - - /// Spawn all tasks according to their specifications - pub fn spawn(&self, shard: Rc<IggyShard>, specs: Vec<Box<dyn TaskSpec>>) { - for spec in specs { - // Check if task should run on this shard - if !spec.scope().should_run(&shard) { - trace!( - "Skipping task {} ({}) on shard {} due to scope", - spec.name(), - spec.kind(), - self.shard_id - ); - continue; - } - - // Skip no-op tasks - if spec.name() == "noop" { - continue; - } - - self.spawn_single(shard.clone(), spec); - } - - info!( - "Shard {} spawned {} tasks ({} oneshot)", - self.shard_id, - self.tasks.borrow().len(), - self.oneshot_handles.borrow().len() - ); - } - - /// Spawn a single task based on its kind - fn spawn_single(&self, shard: Rc<IggyShard>, spec: Box<dyn TaskSpec>) { - let name = spec.name(); - let kind = spec.kind(); - let is_critical = spec.is_critical(); - - info!( - "Spawning {} task '{}' on shard {}", - match &kind { - TaskKind::Continuous => "continuous", - TaskKind::Periodic { .. } => "periodic", - TaskKind::OneShot => "oneshot", - }, - name, - self.shard_id - ); - - spec.on_start(); - - let ctx = TaskCtx { - shard, - shutdown: self.shutdown_token.clone(), - }; - - let handle = match kind { - TaskKind::Continuous => self.spawn_continuous(spec, ctx), - TaskKind::Periodic { period } => self.spawn_periodic(spec, ctx, period), - TaskKind::OneShot => self.spawn_oneshot_spec(spec, ctx), - }; - - let task_handle = TaskHandle { - name: name.to_string(), - kind: kind.clone(), - handle, - is_critical, - }; - - // Store handle based on kind - match kind { - TaskKind::OneShot => self.oneshot_handles.borrow_mut().push(task_handle), - _ => self.tasks.borrow_mut().push(task_handle), - } - } - - /// Spawn a continuous task - fn spawn_continuous( - &self, - spec: Box<dyn TaskSpec>, - ctx: TaskCtx, - ) -> JoinHandle<Result<(), IggyError>> { - let name = spec.name(); - let shard_id = self.shard_id; - - compio::runtime::spawn(async move { - trace!("Continuous task '{}' starting on shard {}", name, shard_id); - let result = spec.run(ctx).await; - - match &result { - Ok(()) => debug!("Continuous task '{}' completed on shard {}", name, shard_id), - Err(e) => error!( - "Continuous task '{}' failed on shard {}: {}", - name, shard_id, e - ), - } - - result - }) - } - - /// Spawn a periodic task - fn spawn_periodic( - &self, - spec: Box<dyn TaskSpec>, - ctx: TaskCtx, - period: Duration, - ) -> JoinHandle<Result<(), IggyError>> { - let name = spec.name(); - let shard_id = self.shard_id; - - compio::runtime::spawn(async move { - trace!( - "Periodic task '{}' starting on shard {} with period {:?}", - name, shard_id, period - ); - - // Periodic tasks handle their own loop internally - let result = spec.run(ctx).await; - - match &result { - Ok(()) => debug!("Periodic task '{}' completed on shard {}", name, shard_id), - Err(e) => error!( - "Periodic task '{}' failed on shard {}: {}", - name, shard_id, e - ), - } - - result - }) - } - - /// Spawn a oneshot task from a TaskSpec - fn spawn_oneshot_spec( - &self, - spec: Box<dyn TaskSpec>, - ctx: TaskCtx, - ) -> JoinHandle<Result<(), IggyError>> { - let name = spec.name(); - let shard_id = self.shard_id; - - compio::runtime::spawn(async move { - trace!("OneShot task '{}' starting on shard {}", name, shard_id); - let result = spec.run(ctx).await; - - match &result { - Ok(()) => debug!("OneShot task '{}' completed on shard {}", name, shard_id), - Err(e) => error!( - "OneShot task '{}' failed on shard {}: {}", - name, shard_id, e - ), - } - - result - }) - } - - /// Trigger graceful shutdown for all tasks - pub async fn graceful_shutdown(&self, timeout: Duration) -> bool { - debug!( - "Initiating graceful shutdown for {} tasks on shard {}", - self.tasks.borrow().len() + self.oneshot_handles.borrow().len(), - self.shard_id - ); - - // First shutdown connections - self.shutdown_connections(); - - // Trigger shutdown signal - self.shutdown.trigger(); - - // Wait for continuous and periodic tasks with timeout - let continuous_periodic_tasks = self.tasks.take(); - let continuous_periodic_complete = self - .await_tasks_with_timeout(continuous_periodic_tasks, timeout) - .await; - - // Always wait for oneshot tasks (no timeout for durability) - let oneshot_tasks = self.oneshot_handles.take(); - let oneshot_complete = self.await_oneshot_tasks(oneshot_tasks).await; - - let all_complete = continuous_periodic_complete && oneshot_complete; - - if !all_complete { - warn!( - "Some tasks did not shutdown cleanly on shard {}", - self.shard_id - ); - } - - all_complete - } - - /// Wait for tasks with a timeout - async fn await_tasks_with_timeout(&self, tasks: Vec<TaskHandle>, timeout: Duration) -> bool { - if tasks.is_empty() { - return true; - } - - let task_count = tasks.len(); - let timeout_futures: Vec<_> = tasks - .into_iter() - .map(|task| async move { - let name = task.name.clone(); - let kind = task.kind; - let is_critical = task.is_critical; - - match compio::time::timeout(timeout, task.handle).await { - Ok(Ok(Ok(()))) => { - trace!("Task '{}' ({}) shutdown gracefully", name, kind); - (name, true, false) - } - Ok(Ok(Err(e))) => { - error!("Task '{}' ({}) failed during shutdown: {}", name, kind, e); - (name, false, is_critical) - } - Ok(Err(e)) => { - error!( - "Task '{}' ({}) panicked during shutdown: {:?}", - name, kind, e - ); - (name, false, is_critical) - } - Err(_) => { - warn!("Task '{}' ({}) did not complete within timeout", name, kind); - (name, false, is_critical) - } - } - }) - .collect(); - - let results = join_all(timeout_futures).await; - - let completed = results.iter().filter(|(_, success, _)| *success).count(); - let critical_failures = results.iter().any(|(_, _, critical)| *critical); - - info!( - "Shard {} shutdown: {}/{} tasks completed", - self.shard_id, completed, task_count - ); - - if critical_failures { - error!("Critical task(s) failed on shard {}", self.shard_id); - } - - completed == task_count && !critical_failures - } - - /// Wait for oneshot tasks (no timeout for durability) - async fn await_oneshot_tasks(&self, tasks: Vec<TaskHandle>) -> bool { - if tasks.is_empty() { - return true; - } - - info!( - "Waiting for {} oneshot tasks to complete on shard {}", - tasks.len(), - self.shard_id - ); - - let futures: Vec<_> = tasks - .into_iter() - .map(|task| async move { - let name = task.name.clone(); - match task.handle.await { - Ok(Ok(())) => { - trace!("OneShot task '{}' completed", name); - true - } - Ok(Err(e)) => { - error!("OneShot task '{}' failed: {}", name, e); - false - } - Err(e) => { - error!("OneShot task '{}' panicked: {:?}", name, e); - false - } - } - }) - .collect(); - - let results = join_all(futures).await; - let all_complete = results.iter().all(|&r| r); - - if all_complete { - debug!("All oneshot tasks completed on shard {}", self.shard_id); - } else { - error!("Some oneshot tasks failed on shard {}", self.shard_id); - } - - all_complete - } - - /// Add a connection for tracking - pub fn add_connection(&self, client_id: u32) -> async_channel::Receiver<()> { - let (stop_sender, stop_receiver) = async_channel::bounded(1); - self.active_connections - .borrow_mut() - .insert(client_id, stop_sender); - stop_receiver - } - - /// Remove a connection from tracking - pub fn remove_connection(&self, client_id: &u32) { - self.active_connections.borrow_mut().remove(client_id); - } - - /// Spawn a tracked task (for connection handlers) - pub fn spawn_tracked<F>(&self, future: F) - where - F: Future<Output = ()> + 'static, - { - let handle = compio::runtime::spawn(async move { - future.await; - Ok(()) - }); - self.tasks.borrow_mut().push(TaskHandle { - name: "connection_handler".to_string(), - kind: TaskKind::Continuous, - handle, - is_critical: false, - }); - } - - /// Spawn a oneshot task directly without going through TaskSpec - pub fn spawn_oneshot<F>(&self, name: impl Into<String>, critical: bool, f: F) - where - F: Future<Output = Result<(), IggyError>> + 'static, - { - let name = name.into(); - let shard_id = self.shard_id; - - trace!("Spawning oneshot task '{}' on shard {}", name, shard_id); - - let task_name = name.clone(); - let handle = compio::runtime::spawn(async move { - trace!( - "OneShot task '{}' starting on shard {}", - task_name, shard_id - ); - let result = f.await; - - match &result { - Ok(()) => trace!( - "OneShot task '{}' completed on shard {}", - task_name, shard_id - ), - Err(e) => error!( - "OneShot task '{}' failed on shard {}: {}", - task_name, shard_id, e - ), - } - - result - }); - - self.oneshot_handles.borrow_mut().push(TaskHandle { - name, - kind: TaskKind::OneShot, - handle, - is_critical: critical, - }); - } - - /// Spawn a periodic task using a closure-based tick function - /// The supervisor handles the timing loop, the closure provides the tick logic - pub fn spawn_periodic_tick<F, Fut>( - &self, - name: &'static str, - period: Duration, - scope: TaskScope, - shard: Rc<IggyShard>, - mut tick: F, - ) where - F: FnMut(&TaskCtx) -> Fut + 'static, - Fut: Future<Output = Result<(), IggyError>>, - { - // Check if task should run on this shard - if !scope.should_run(&shard) { - trace!( - "Skipping periodic tick task '{}' on shard {} due to scope", - name, self.shard_id - ); - return; - } - - let ctx = TaskCtx { - shard, - shutdown: self.shutdown_token.clone(), - }; - - let shard_id = self.shard_id; - let shutdown = self.shutdown_token.clone(); - - info!( - "Spawning periodic tick task '{}' on shard {} with period {:?}", - name, shard_id, period - ); - - let handle = compio::runtime::spawn(async move { - trace!( - "Periodic tick task '{}' starting on shard {} with period {:?}", - name, shard_id, period - ); - - loop { - // Use shutdown-aware sleep - if !shutdown.sleep_or_shutdown(period).await { - trace!("Periodic tick task '{}' shutting down", name); - break; - } - - // Execute tick - match tick(&ctx).await { - Ok(()) => trace!("Periodic tick task '{}' tick completed", name), - Err(e) => error!("Periodic tick task '{}' tick failed: {}", name, e), - } - } - - trace!( - "Periodic tick task '{}' completed on shard {}", - name, shard_id - ); - Ok(()) - }); - - self.tasks.borrow_mut().push(TaskHandle { - name: name.to_string(), - kind: TaskKind::Periodic { period }, - handle, - is_critical: false, - }); - } - - /// Shutdown all connections gracefully - fn shutdown_connections(&self) { - let connections = self.active_connections.borrow(); - if connections.is_empty() { - return; - } - - info!("Shutting down {} active connections", connections.len()); - - for (client_id, stop_sender) in connections.iter() { - trace!("Sending shutdown signal to connection {}", client_id); - if let Err(e) = stop_sender.send_blocking(()) { - warn!( - "Failed to send shutdown signal to connection {}: {}", - client_id, e - ); - } - } - } - - /// Register a task handle (used by builders) - pub(crate) fn register_task( - &self, - name: &str, - kind: TaskKind, - handle: JoinHandle<Result<(), IggyError>>, - is_critical: bool, - ) { - let task_handle = TaskHandle { - name: name.to_string(), - kind: kind.clone(), - handle, - is_critical, - }; - - match kind { - TaskKind::OneShot => self.oneshot_handles.borrow_mut().push(task_handle), - _ => self.tasks.borrow_mut().push(task_handle), - } - } -} diff --git a/core/server/src/shard/tasks/tls.rs b/core/server/src/shard/tasks/tls.rs deleted file mode 100644 index 1370c5e6..00000000 --- a/core/server/src/shard/tasks/tls.rs +++ /dev/null @@ -1,141 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use super::supervisor::TaskSupervisor; -use std::cell::RefCell; -use std::rc::Rc; - -thread_local! { - static SUPERVISOR: RefCell<Option<Rc<TaskSupervisor>>> = RefCell::new(None); -} - -/// Initialize the task supervisor for this thread -pub fn init_supervisor(shard_id: u16) { - SUPERVISOR.with(|s| { - *s.borrow_mut() = Some(Rc::new(TaskSupervisor::new(shard_id))); - }); -} - -/// Get the task supervisor for the current thread -/// -/// # Panics -/// Panics if the supervisor has not been initialized for this thread -pub fn task_supervisor() -> Rc<TaskSupervisor> { - SUPERVISOR.with(|s| { - s.borrow() - .as_ref() - .expect( - "Task supervisor not initialized for this thread. Call init_supervisor() first.", - ) - .clone() - }) -} - -/// Check if the task supervisor has been initialized for this thread -pub fn is_supervisor_initialized() -> bool { - SUPERVISOR.with(|s| s.borrow().is_some()) -} - -/// Clear the task supervisor for this thread (for cleanup/testing) -pub fn clear_supervisor() { - SUPERVISOR.with(|s| { - *s.borrow_mut() = None; - }); -} - -#[cfg(test)] -mod tests { - use super::*; - use std::panic; - - #[test] - fn test_supervisor_initialization() { - // Clean any existing supervisor first - clear_supervisor(); - - // Initially, supervisor should not be initialized - assert!(!is_supervisor_initialized()); - - // Trying to get supervisor without initialization should panic - let result = panic::catch_unwind(|| { - task_supervisor(); - }); - assert!(result.is_err()); - - // Initialize supervisor - init_supervisor(42); - - // Now it should be initialized - assert!(is_supervisor_initialized()); - - // Clean up - clear_supervisor(); - assert!(!is_supervisor_initialized()); - } - - #[test] - fn test_multiple_initializations() { - clear_supervisor(); - - // First initialization - init_supervisor(1); - let _supervisor1 = task_supervisor(); - - // Second initialization (should replace the first) - init_supervisor(2); - let _supervisor2 = task_supervisor(); - - // They should be different instances (different Rc references) - // Note: We can't use ptr::eq since Rc doesn't expose the inner pointer easily - // But we can verify that initialization works multiple times - - clear_supervisor(); - } - - #[test] - fn test_thread_locality() { - use std::thread; - - clear_supervisor(); - - // Initialize in main thread - init_supervisor(100); - assert!(is_supervisor_initialized()); - - // Spawn a new thread and verify it doesn't have the supervisor - let handle = thread::spawn(|| { - // This thread should not have supervisor initialized - assert!(!is_supervisor_initialized()); - - // Initialize different supervisor in this thread - init_supervisor(200); - assert!(is_supervisor_initialized()); - - // Get supervisor to verify it works - let _ = task_supervisor(); - }); - - handle.join().expect("Thread should complete successfully"); - - // Main thread should still have its supervisor - assert!(is_supervisor_initialized()); - let _ = task_supervisor(); - - clear_supervisor(); - } -} diff --git a/core/server/src/slab/streams.rs b/core/server/src/slab/streams.rs index 8ac12f6f..5c1218c8 100644 --- a/core/server/src/slab/streams.rs +++ b/core/server/src/slab/streams.rs @@ -736,10 +736,10 @@ impl Streams { }); // Use task supervisor for proper tracking and graceful shutdown - use crate::shard::tasks::tls::task_supervisor; + use crate::shard::task_registry::tls::task_registry; use tracing::error; - task_supervisor().spawn_oneshot("fsync:segment-close-messages", true, async move { + task_registry().spawn_oneshot_future("fsync:segment-close-messages", true, async move { match log_writer.fsync().await { Ok(_) => Ok(()), Err(e) => { @@ -749,7 +749,7 @@ impl Streams { } }); - task_supervisor().spawn_oneshot("fsync:segment-close-index", true, async move { + task_registry().spawn_oneshot_future("fsync:segment-close-index", true, async move { match index_writer.fsync().await { Ok(_) => { drop(index_writer); diff --git a/core/server/src/tcp/tcp_listener.rs b/core/server/src/tcp/tcp_listener.rs index 8cda9310..22a34feb 100644 --- a/core/server/src/tcp/tcp_listener.rs +++ b/core/server/src/tcp/tcp_listener.rs @@ -19,7 +19,7 @@ use crate::binary::sender::SenderKind; use crate::configs::tcp::TcpSocketConfig; use crate::shard::IggyShard; -use crate::shard::tasks::task_supervisor; +use crate::shard::task_registry::task_registry; use crate::shard::transmission::event::ShardEvent; use crate::tcp::connection_handler::{handle_connection, handle_error}; use crate::{shard_debug, shard_error, shard_info}; @@ -181,14 +181,14 @@ async fn accept_loop( shard_info!(shard.id, "Created new session: {}", session); let mut sender = SenderKind::get_tcp_sender(stream); - let conn_stop_receiver = task_supervisor().add_connection(client_id); + let conn_stop_receiver = task_registry().add_connection(client_id); let shard_for_conn = shard_clone.clone(); - task_supervisor().spawn_tracked(async move { + task_registry().spawn_tracked(async move { if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await { handle_error(error); } - task_supervisor().remove_connection(&client_id); + task_registry().remove_connection(&client_id); if let Err(error) = sender.shutdown().await { shard_error!(shard.id, "Failed to shutdown TCP stream for client: {}, address: {}. {}", client_id, address, error); diff --git a/core/server/src/tcp/tcp_tls_listener.rs b/core/server/src/tcp/tcp_tls_listener.rs index 276eb52c..ee83f4dc 100644 --- a/core/server/src/tcp/tcp_tls_listener.rs +++ b/core/server/src/tcp/tcp_tls_listener.rs @@ -19,7 +19,7 @@ use crate::binary::sender::SenderKind; use crate::configs::tcp::TcpSocketConfig; use crate::shard::IggyShard; -use crate::shard::tasks::task_supervisor; +use crate::shard::task_registry::task_registry; use crate::shard::transmission::event::ShardEvent; use crate::tcp::connection_handler::{handle_connection, handle_error}; use crate::{shard_error, shard_info, shard_warn}; @@ -220,7 +220,7 @@ async fn accept_loop( // Perform TLS handshake in a separate task to avoid blocking the accept loop let task_shard = shard_clone.clone(); - task_supervisor().spawn_tracked(async move { + task_registry().spawn_tracked(async move { match acceptor.accept(stream).await { Ok(tls_stream) => { // TLS handshake successful, now create session @@ -238,13 +238,13 @@ async fn accept_loop( let client_id = session.client_id; shard_info!(shard_clone.id, "Created new session: {}", session); - let conn_stop_receiver = task_supervisor().add_connection(client_id); + let conn_stop_receiver = task_registry().add_connection(client_id); let shard_for_conn = shard_clone.clone(); let mut sender = SenderKind::get_tcp_tls_sender(tls_stream); if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await { handle_error(error); } - task_supervisor().remove_connection(&client_id); + task_registry().remove_connection(&client_id); if let Err(error) = sender.shutdown().await { shard_error!(shard.id, "Failed to shutdown TCP TLS stream for client: {}, address: {}. {}", client_id, address, error);
