This is an automated email from the ASF dual-hosted git repository. hgruszecki pushed a commit to branch io_uring_tpc_task_registry in repository https://gitbox.apache.org/repos/asf/iggy.git
commit 759ead2716e716d05a59ebc3b26add6cca9b8773 Author: Hubert Gruszecki <[email protected]> AuthorDate: Wed Sep 24 13:14:23 2025 +0200 feat(io_uring): implement task supervisor --- Cargo.lock | 4 - core/bench/src/actors/consumer/client/low_level.rs | 2 +- core/bench/src/benchmarks/benchmark.rs | 7 +- core/bench/src/benchmarks/common.rs | 10 +- .../commands/clean_personal_access_tokens.rs | 75 --- core/server/src/channels/commands/print_sysinfo.rs | 70 --- core/server/src/channels/commands/save_messages.rs | 86 ---- .../src/channels/commands/verify_heartbeats.rs | 80 ---- core/server/src/http/http_server.rs | 43 +- core/server/src/http/jwt/cleaner.rs | 41 -- core/server/src/http/jwt/mod.rs | 1 - core/server/src/http/partitions.rs | 1 - core/server/src/lib.rs | 1 - core/server/src/log/logger.rs | 4 +- core/server/src/log/runtime.rs | 18 +- core/server/src/main.rs | 28 +- core/server/src/quic/listener.rs | 52 +- core/server/src/shard/builder.rs | 3 +- core/server/src/shard/mod.rs | 138 +++--- core/server/src/shard/stats.rs | 1 - core/server/src/shard/system/messages.rs | 4 +- core/server/src/shard/task_registry.rs | 108 ----- .../src/shard/tasks/continuous/http_server.rs | 79 ++++ .../src/shard/tasks/continuous/message_pump.rs | 100 ++++ .../{channels => shard/tasks/continuous}/mod.rs | 12 +- .../src/shard/tasks/continuous/quic_server.rs | 68 +++ .../src/shard/tasks/continuous/tcp_server.rs | 72 +++ core/server/src/shard/tasks/messages.rs | 59 --- core/server/src/shard/tasks/mod.rs | 107 ++++- .../tasks/periodic/clear_personal_access_tokens.rs | 122 +++++ .../commands => shard/tasks/periodic}/mod.rs | 7 +- .../src/shard/tasks/periodic/print_sysinfo.rs | 126 +++++ .../src/shard/tasks/periodic/save_messages.rs | 136 ++++++ .../src/shard/tasks/periodic/verify_heartbeats.rs | 131 ++++++ core/server/src/shard/tasks/shutdown.rs | 232 +++++++++ core/server/src/shard/tasks/specs.rs | 203 ++++++++ core/server/src/shard/tasks/supervisor.rs | 522 +++++++++++++++++++++ core/server/src/shard/tasks/tls.rs | 141 ++++++ core/server/src/shard/transmission/message.rs | 5 +- core/server/src/slab/streams.rs | 36 +- core/server/src/streaming/storage.rs | 2 +- core/server/src/tcp/tcp_listener.rs | 7 +- core/server/src/tcp/tcp_tls_listener.rs | 7 +- 43 files changed, 2291 insertions(+), 660 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 876935f9..0b2edd0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -614,8 +614,6 @@ dependencies = [ "futures-lite", "pin-project", "thiserror 2.0.16", - "tokio", - "tokio-util", ] [[package]] @@ -5716,8 +5714,6 @@ dependencies = [ "rand 0.9.2", "serde_json", "thiserror 2.0.16", - "tokio", - "tokio-stream", ] [[package]] diff --git a/core/bench/src/actors/consumer/client/low_level.rs b/core/bench/src/actors/consumer/client/low_level.rs index 088ad42f..4c80c4ce 100644 --- a/core/bench/src/actors/consumer/client/low_level.rs +++ b/core/bench/src/actors/consumer/client/low_level.rs @@ -109,7 +109,7 @@ impl ConsumerClient for LowLevelConsumerClient { } Ok(Some(BatchMetrics { - messages: messages_count as u32, + messages: messages_count.try_into().unwrap(), user_data_bytes: user_bytes, total_bytes, latency, diff --git a/core/bench/src/benchmarks/benchmark.rs b/core/bench/src/benchmarks/benchmark.rs index 37e65b78..31f1828c 100644 --- a/core/bench/src/benchmarks/benchmark.rs +++ b/core/bench/src/benchmarks/benchmark.rs @@ -102,7 +102,7 @@ pub trait Benchmarkable: Send { login_root(&client).await; let streams = client.get_streams().await?; for i in 1..=number_of_streams { - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); let stream_id: Identifier = stream_name.as_str().try_into()?; if streams.iter().all(|s| s.name != stream_name) { info!("Creating the test stream '{}'", stream_name); @@ -141,11 +141,10 @@ pub trait Benchmarkable: Send { login_root(&client).await; let streams = client.get_streams().await?; for i in 1..=number_of_streams { - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); if streams.iter().all(|s| s.name != stream_name) { return Err(IggyError::ResourceNotFound(format!( - "Streams for testing are not properly initialized. Stream '{}' is missing.", - stream_name + "Streams for testing are not properly initialized. Stream '{stream_name}' is missing." ))); } } diff --git a/core/bench/src/benchmarks/common.rs b/core/bench/src/benchmarks/common.rs index 3ee4f686..5aa2c458 100644 --- a/core/bench/src/benchmarks/common.rs +++ b/core/bench/src/benchmarks/common.rs @@ -79,7 +79,7 @@ pub async fn init_consumer_groups( login_root(&client).await; for i in 1..=cg_count { let consumer_group_id = CONSUMER_GROUP_BASE_ID + i; - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); let stream_id: Identifier = stream_name.as_str().try_into()?; let topic_id: Identifier = "topic-1".try_into()?; let consumer_group_name = format!("{CONSUMER_GROUP_NAME_PREFIX}-{consumer_group_id}"); @@ -136,7 +136,7 @@ pub fn build_producer_futures( }; let stream_idx = 1 + ((producer_id - 1) % streams); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); async move { let producer = TypedBenchmarkProducer::new( @@ -203,7 +203,7 @@ pub fn build_consumer_futures( } else { consumer_id }; - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let consumer_group_id = if cg_count > 0 { Some(CONSUMER_GROUP_BASE_ID + 1 + (consumer_id % cg_count)) } else { @@ -251,7 +251,7 @@ pub fn build_producing_consumers_futures( let client_factory_clone = client_factory.clone(); let args_clone = args.clone(); let stream_idx = 1 + ((actor_id - 1) % streams); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let send_finish_condition = BenchmarkFinishCondition::new( &args, @@ -329,7 +329,7 @@ pub fn build_producing_consumer_groups_futures( let client_factory_clone = client_factory.clone(); let args_clone = args.clone(); let stream_idx = 1 + ((actor_id - 1) % cg_count); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let should_produce = actor_id <= producers; let should_consume = actor_id <= consumers; diff --git a/core/server/src/channels/commands/clean_personal_access_tokens.rs b/core/server/src/channels/commands/clean_personal_access_tokens.rs deleted file mode 100644 index d846aa27..00000000 --- a/core/server/src/channels/commands/clean_personal_access_tokens.rs +++ /dev/null @@ -1,75 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::rc::Rc; - -use crate::shard::IggyShard; -use iggy_common::{IggyError, IggyTimestamp}; -use tracing::{debug, info, trace}; - -pub async fn clear_personal_access_tokens(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.personal_access_token.cleaner; - if !config.enabled { - info!("Personal access token cleaner is disabled."); - return Ok(()); - } - - info!( - "Personal access token cleaner is enabled, expired tokens will be deleted every: {}.", - config.interval - ); - - let interval = config.interval.get_duration(); - let mut interval_timer = compio::time::interval(interval); - - loop { - interval_timer.tick().await; - trace!("Cleaning expired personal access tokens..."); - - let users = shard.users.borrow(); - let now = IggyTimestamp::now(); - let mut deleted_tokens_count = 0; - - for (_, user) in users.iter() { - let expired_tokens = user - .personal_access_tokens - .iter() - .filter(|token| token.is_expired(now)) - .map(|token| token.token.clone()) - .collect::<Vec<_>>(); - - for token in expired_tokens { - debug!( - "Personal access token: {} for user with ID: {} is expired.", - token, user.id - ); - deleted_tokens_count += 1; - user.personal_access_tokens.remove(&token); - debug!( - "Deleted personal access token: {} for user with ID: {}.", - token, user.id - ); - } - } - - info!( - "Deleted {} expired personal access tokens.", - deleted_tokens_count - ); - } -} diff --git a/core/server/src/channels/commands/print_sysinfo.rs b/core/server/src/channels/commands/print_sysinfo.rs deleted file mode 100644 index 0c3992d2..00000000 --- a/core/server/src/channels/commands/print_sysinfo.rs +++ /dev/null @@ -1,70 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::rc::Rc; - -use crate::shard::IggyShard; -use crate::streaming::utils::memory_pool; -use human_repr::HumanCount; -use iggy_common::IggyError; -use tracing::{error, info, trace}; - -pub async fn print_sys_info(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.system.logging; - let interval = config.sysinfo_print_interval; - - if interval.is_zero() { - info!("SysInfoPrinter is disabled."); - return Ok(()); - } - info!("SysInfoPrinter is enabled, system information will be printed every {interval}."); - let mut interval_timer = compio::time::interval(interval.get_duration()); - loop { - interval_timer.tick().await; - trace!("Printing system information..."); - - let stats = match shard.get_stats().await { - Ok(stats) => stats, - Err(e) => { - error!("Failed to get system information. Error: {e}"); - continue; - } - }; - - let free_memory_percent = (stats.available_memory.as_bytes_u64() as f64 - / stats.total_memory.as_bytes_u64() as f64) - * 100f64; - - info!( - "CPU: {:.2}%/{:.2}% (IggyUsage/Total), Mem: {:.2}%/{}/{}/{} (Free/IggyUsage/TotalUsed/Total), Clients: {}, Messages processed: {}, Read: {}, Written: {}, Uptime: {}", - stats.cpu_usage, - stats.total_cpu_usage, - free_memory_percent, - stats.memory_usage, - stats.total_memory - stats.available_memory, - stats.total_memory, - stats.clients_count, - stats.messages_count.human_count_bare(), - stats.read_bytes, - stats.written_bytes, - stats.run_time - ); - - memory_pool().log_stats(); - } -} diff --git a/core/server/src/channels/commands/save_messages.rs b/core/server/src/channels/commands/save_messages.rs deleted file mode 100644 index 84330563..00000000 --- a/core/server/src/channels/commands/save_messages.rs +++ /dev/null @@ -1,86 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{shard::IggyShard, shard_info}; -use iggy_common::{Identifier, IggyError}; -use std::rc::Rc; -use tracing::{error, info, trace}; - -pub async fn save_messages(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.message_saver; - if !config.enabled { - info!("Message saver is disabled."); - return Ok(()); - } - - // TODO: Maybe we should get rid of it in order to not complicate, and use the fsync settings per partition from config. - let enforce_fsync = config.enforce_fsync; - let interval = config.interval; - info!( - "Message saver is enabled, buffered messages will be automatically saved every: {interval}, enforce fsync: {enforce_fsync}." - ); - - let mut interval_timer = compio::time::interval(interval.get_duration()); - loop { - interval_timer.tick().await; - trace!("Saving buffered messages..."); - - let namespaces = shard.get_current_shard_namespaces(); - let mut total_saved_messages = 0u32; - let reason = "background saver triggered".to_string(); - - for ns in namespaces { - let stream_id = Identifier::numeric(ns.stream_id() as u32).unwrap(); - let topic_id = Identifier::numeric(ns.topic_id() as u32).unwrap(); - let partition_id = ns.partition_id(); - - match shard - .streams2 - .persist_messages( - shard.id, - &stream_id, - &topic_id, - partition_id, - reason.clone(), - &shard.config.system, - ) - .await - { - Ok(batch_count) => { - total_saved_messages += batch_count; - } - Err(err) => { - error!( - "Failed to save messages for partition {}: {}", - partition_id, err - ); - } - } - } - - if total_saved_messages > 0 { - shard_info!( - shard.id, - "Saved {} buffered messages on disk.", - total_saved_messages - ); - } - - trace!("Finished saving buffered messages."); - } -} diff --git a/core/server/src/channels/commands/verify_heartbeats.rs b/core/server/src/channels/commands/verify_heartbeats.rs deleted file mode 100644 index 123c174f..00000000 --- a/core/server/src/channels/commands/verify_heartbeats.rs +++ /dev/null @@ -1,80 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::shard::IggyShard; -use iggy_common::{IggyDuration, IggyError, IggyTimestamp}; -use std::rc::Rc; -use tracing::{debug, info, trace, warn}; - -const MAX_THRESHOLD: f64 = 1.2; - -pub async fn verify_heartbeats(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.heartbeat; - if !config.enabled { - info!("Heartbeats verification is disabled."); - return Ok(()); - } - - let interval = config.interval; - let max_interval = IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64); - info!("Heartbeats will be verified every: {interval}. Max allowed interval: {max_interval}."); - - let mut interval_timer = compio::time::interval(interval.get_duration()); - - loop { - interval_timer.tick().await; - trace!("Verifying heartbeats..."); - - let clients = { - let client_manager = shard.client_manager.borrow(); - client_manager.get_clients() - }; - - let now = IggyTimestamp::now(); - let heartbeat_to = IggyTimestamp::from(now.as_micros() - max_interval.as_micros()); - debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}"); - - let mut stale_clients = Vec::new(); - for client in clients { - if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() { - warn!( - "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}", - client.session, client.last_heartbeat, - ); - client.session.set_stale(); - stale_clients.push(client.session.client_id); - } else { - debug!( - "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}", - client.last_heartbeat, client.session, - ); - } - } - - if stale_clients.is_empty() { - continue; - } - - let count = stale_clients.len(); - info!("Removing {count} stale clients..."); - for client_id in stale_clients { - shard.delete_client(client_id); - } - info!("Removed {count} stale clients."); - } -} diff --git a/core/server/src/http/http_server.rs b/core/server/src/http/http_server.rs index 2ae46271..1d5ebf9d 100644 --- a/core/server/src/http/http_server.rs +++ b/core/server/src/http/http_server.rs @@ -19,13 +19,13 @@ use crate::configs::http::{HttpConfig, HttpCorsConfig}; use crate::http::diagnostics::request_diagnostics; use crate::http::http_shard_wrapper::HttpSafeShard; -use crate::http::jwt::cleaner::start_expired_tokens_cleaner; use crate::http::jwt::jwt_manager::JwtManager; use crate::http::jwt::middleware::jwt_auth; use crate::http::metrics::metrics; use crate::http::shared::AppState; use crate::http::*; use crate::shard::IggyShard; +use crate::shard::tasks::{TaskScope, task_supervisor}; use crate::streaming::persistence::persister::PersisterKind; // use crate::streaming::systems::system::SharedSystem; use axum::extract::DefaultBodyLimit; @@ -66,7 +66,11 @@ impl<'a> Connected<cyper_axum::IncomingStream<'a, TcpListener>> for CompioSocket /// Starts the HTTP API server. /// Returns the address the server is listening on. -pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc<IggyShard>) -> Result<(), IggyError> { +pub async fn start( + config: HttpConfig, + persister: Arc<PersisterKind>, + shard: Rc<IggyShard>, +) -> Result<(), IggyError> { if shard.id != 0 { info!( "HTTP server disabled for shard {} (only runs on shard 0)", @@ -81,7 +85,7 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< "HTTP API" }; - let app_state = build_app_state(&config, persister, shard).await; + let app_state = build_app_state(&config, persister, shard.clone()).await; let mut app = Router::new() .merge(system::router(app_state.clone(), &config.metrics)) .merge(personal_access_tokens::router(app_state.clone())) @@ -105,7 +109,32 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< app = app.layer(middleware::from_fn_with_state(app_state.clone(), metrics)); } - start_expired_tokens_cleaner(app_state.clone()); + // JWT token cleaner task + { + let app_state_for_cleaner = Arc::clone(&app_state); + let shard_clone = shard.clone(); + task_supervisor().spawn_periodic_tick( + "jwt_token_cleaner", + std::time::Duration::from_secs(300), + TaskScope::SpecificShard(0), + shard_clone, + move |_ctx| { + let app = app_state_for_cleaner.clone(); + async move { + use iggy_common::IggyTimestamp; + let now = IggyTimestamp::now().to_secs(); + app.jwt_manager + .delete_expired_revoked_tokens(now) + .await + .unwrap_or_else(|err| { + error!("Failed to delete expired revoked tokens: {}", err); + }); + Ok(()) + } + }, + ); + } + app = app.layer(middleware::from_fn(request_diagnostics)); if !config.tls.enabled { @@ -158,7 +187,11 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< } } -async fn build_app_state(config: &HttpConfig, persister: Arc<PersisterKind>, shard: Rc<IggyShard>) -> Arc<AppState> { +async fn build_app_state( + config: &HttpConfig, + persister: Arc<PersisterKind>, + shard: Rc<IggyShard>, +) -> Arc<AppState> { let tokens_path; { tokens_path = shard.config.system.get_state_tokens_path(); diff --git a/core/server/src/http/jwt/cleaner.rs b/core/server/src/http/jwt/cleaner.rs deleted file mode 100644 index 6705d955..00000000 --- a/core/server/src/http/jwt/cleaner.rs +++ /dev/null @@ -1,41 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::http::shared::AppState; -use iggy_common::IggyTimestamp; -use std::{sync::Arc, time::Duration}; -use tracing::{error, trace}; - -pub fn start_expired_tokens_cleaner(app_state: Arc<AppState>) { - compio::runtime::spawn(async move { - let mut interval_timer = compio::time::interval(Duration::from_secs(300)); - loop { - interval_timer.tick().await; - trace!("Deleting expired tokens..."); - let now = IggyTimestamp::now().to_secs(); - app_state - .jwt_manager - .delete_expired_revoked_tokens(now) - .await - .unwrap_or_else(|err| { - error!("Failed to delete expired revoked access tokens. Error: {err}"); - }); - } - }) - .detach(); -} diff --git a/core/server/src/http/jwt/mod.rs b/core/server/src/http/jwt/mod.rs index a60b270b..5481f110 100644 --- a/core/server/src/http/jwt/mod.rs +++ b/core/server/src/http/jwt/mod.rs @@ -16,7 +16,6 @@ * under the License. */ -pub mod cleaner; pub mod json_web_token; pub mod jwt_manager; pub mod middleware; diff --git a/core/server/src/http/partitions.rs b/core/server/src/http/partitions.rs index 11a1c4a1..b54539ea 100644 --- a/core/server/src/http/partitions.rs +++ b/core/server/src/http/partitions.rs @@ -33,7 +33,6 @@ use iggy_common::Identifier; use iggy_common::Validatable; use iggy_common::create_partitions::CreatePartitions; use iggy_common::delete_partitions::DeletePartitions; -use iggy_common::locking::IggyRwLockFn; use send_wrapper::SendWrapper; use std::sync::Arc; use tracing::instrument; diff --git a/core/server/src/lib.rs b/core/server/src/lib.rs index 38263692..735a2e51 100644 --- a/core/server/src/lib.rs +++ b/core/server/src/lib.rs @@ -31,7 +31,6 @@ compile_error!("iggy-server doesn't support windows."); pub mod args; pub mod binary; pub mod bootstrap; -pub mod channels; pub(crate) mod compat; pub mod configs; pub mod http; diff --git a/core/server/src/log/logger.rs b/core/server/src/log/logger.rs index 59725685..b400233a 100644 --- a/core/server/src/log/logger.rs +++ b/core/server/src/log/logger.rs @@ -16,10 +16,10 @@ * under the License. */ -use crate::log::runtime::CompioRuntime; use crate::VERSION; use crate::configs::server::{TelemetryConfig, TelemetryTransport}; use crate::configs::system::LoggingConfig; +use crate::log::runtime::CompioRuntime; use crate::server_error::LogError; use opentelemetry::KeyValue; use opentelemetry::global; @@ -250,7 +250,7 @@ impl Logging { .with_span_processor( span_processor_with_async_runtime::BatchSpanProcessor::builder( trace_exporter, - CompioRuntime + CompioRuntime, ) .build(), ) diff --git a/core/server/src/log/runtime.rs b/core/server/src/log/runtime.rs index 1683b34a..95f77585 100644 --- a/core/server/src/log/runtime.rs +++ b/core/server/src/log/runtime.rs @@ -1,6 +1,6 @@ use std::{pin::Pin, task::Poll, time::Duration}; -use futures::{channel::mpsc, future::poll_fn, FutureExt, SinkExt, Stream, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt, channel::mpsc, future::poll_fn}; use opentelemetry_sdk::runtime::{Runtime, RuntimeChannel, TrySend}; #[derive(Clone)] @@ -34,17 +34,20 @@ impl<T> CompioSender<T> { pub fn new(sender: mpsc::UnboundedSender<T>) -> Self { Self { sender } } -} +} -// Safety: Since we use compio runtime which is single-threaded, or rather the Future: !Send + !Sync, +// Safety: Since we use compio runtime which is single-threaded, or rather the Future: !Send + !Sync, // we can implement those traits, to satisfy the trait bounds from `Runtime` and `RuntimeChannel` traits. unsafe impl<T> Send for CompioSender<T> {} unsafe impl<T> Sync for CompioSender<T> {} -impl<T: std::fmt::Debug + Send> TrySend for CompioSender<T> { +impl<T: std::fmt::Debug + Send> TrySend for CompioSender<T> { type Message = T; - fn try_send(&self, item: Self::Message) -> Result<(), opentelemetry_sdk::runtime::TrySendError> { + fn try_send( + &self, + item: Self::Message, + ) -> Result<(), opentelemetry_sdk::runtime::TrySendError> { self.sender.unbounded_send(item).map_err(|_err| { // Unbounded channels can only fail if disconnected, never full opentelemetry_sdk::runtime::TrySendError::ChannelClosed @@ -65,7 +68,10 @@ impl<T> CompioReceiver<T> { impl<T: std::fmt::Debug + Send> Stream for CompioReceiver<T> { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { self.receiver.poll_next_unpin(cx) } } diff --git a/core/server/src/main.rs b/core/server/src/main.rs index 095ba4c2..cee7b9d0 100644 --- a/core/server/src/main.rs +++ b/core/server/src/main.rs @@ -354,22 +354,20 @@ async fn main() -> Result<(), ServerError> { } let shutdown_handles_for_signal = shutdown_handles.clone(); - /* - ::set_handler(move || { - info!("Received shutdown signal (SIGTERM/SIGINT), initiating graceful shutdown..."); - - for (shard_id, stop_sender) in &shutdown_handles_for_signal { - info!("Sending shutdown signal to shard {}", shard_id); - if let Err(e) = stop_sender.send_blocking(()) { - error!( - "Failed to send shutdown signal to shard {}: {}", - shard_id, e - ); - } + ctrlc::set_handler(move || { + info!("Received shutdown signal (SIGTERM/SIGINT), initiating graceful shutdown..."); + + for (shard_id, stop_sender) in &shutdown_handles_for_signal { + info!("Sending shutdown signal to shard {}", shard_id); + if let Err(e) = stop_sender.try_send(()) { + error!( + "Failed to send shutdown signal to shard {}: {}", + shard_id, e + ); } - }) - .expect("Error setting Ctrl-C handler"); - */ + } + }) + .expect("Error setting Ctrl-C handler"); info!("Iggy server is running. Press Ctrl+C or send SIGTERM to shutdown."); for (idx, handle) in handles.into_iter().enumerate() { diff --git a/core/server/src/quic/listener.rs b/core/server/src/quic/listener.rs index 5fce5c84..60833e2e 100644 --- a/core/server/src/quic/listener.rs +++ b/core/server/src/quic/listener.rs @@ -22,11 +22,13 @@ use crate::binary::command::{ServerCommand, ServerCommandHandler}; use crate::binary::sender::SenderKind; use crate::server_error::ConnectionError; use crate::shard::IggyShard; +use crate::shard::tasks::task_supervisor; use crate::shard::transmission::event::ShardEvent; use crate::streaming::session::Session; use crate::{shard_debug, shard_info}; use anyhow::anyhow; use compio_quic::{Connection, Endpoint, RecvStream, SendStream}; +use futures::FutureExt; use iggy_common::IggyError; use iggy_common::TransportProtocol; use tracing::{error, info, trace}; @@ -42,15 +44,16 @@ pub async fn start(endpoint: Endpoint, shard: Rc<IggyShard>) -> Result<(), IggyE while let Some(incoming_conn) = endpoint.wait_incoming().await { let remote_addr = incoming_conn.remote_address(); trace!("Incoming connection from client: {}", remote_addr); - let shard = shard.clone(); + let shard_clone = shard.clone(); + let shard_for_conn = shard_clone.clone(); - // Spawn each connection handler independently to maintain concurrency - compio::runtime::spawn(async move { + // Use TaskSupervisor to track connection handlers for graceful shutdown + task_supervisor().spawn_tracked(async move { trace!("Accepting connection from {}", remote_addr); match incoming_conn.await { Ok(connection) => { trace!("Connection established from {}", remote_addr); - if let Err(error) = handle_connection(connection, shard).await { + if let Err(error) = handle_connection(connection, shard_for_conn).await { error!("QUIC connection from {} has failed: {error}", remote_addr); } } @@ -61,8 +64,7 @@ pub async fn start(endpoint: Endpoint, shard: Rc<IggyShard>) -> Result<(), IggyE ); } } - }) - .detach(); + }); } info!("QUIC listener for shard {} stopped", shard.id); @@ -77,7 +79,6 @@ async fn handle_connection( info!("Client has connected: {address}"); let session = shard.add_client(&address, TransportProtocol::Quic); - let session = shard.add_client(&address, TransportProtocol::Quic); let client_id = session.client_id; shard_debug!( shard.id, @@ -93,19 +94,40 @@ async fn handle_connection( address, transport: TransportProtocol::Quic, }; + + // TODO(hubcio): unused? let _responses = shard.broadcast_event_to_all_shards(event.into()).await; - while let Some(stream) = accept_stream(&connection, &shard, client_id).await? { - let shard = shard.clone(); - let session = session.clone(); + let conn_stop_receiver = task_supervisor().add_connection(client_id); - let handle_stream_task = async move { - if let Err(err) = handle_stream(stream, shard, session).await { - error!("Error when handling QUIC stream: {:?}", err) + loop { + futures::select! { + // Check for shutdown signal + _ = conn_stop_receiver.recv().fuse() => { + info!("QUIC connection {} shutting down gracefully", client_id); + break; } - }; - let _handle = compio::runtime::spawn(handle_stream_task).detach(); + // Accept new connection + stream_result = accept_stream(&connection, &shard, client_id).fuse() => { + match stream_result? { + Some(stream) => { + let shard_clone = shard.clone(); + let session_clone = session.clone(); + + task_supervisor().spawn_tracked(async move { + if let Err(err) = handle_stream(stream, shard_clone, session_clone).await { + error!("Error when handling QUIC stream: {:?}", err) + } + }); + } + None => break, // Connection closed + } + } + } } + + task_supervisor().remove_connection(&client_id); + info!("QUIC connection {} closed", client_id); Ok(()) } diff --git a/core/server/src/shard/builder.rs b/core/server/src/shard/builder.rs index 287f31bb..4d98dd96 100644 --- a/core/server/src/shard/builder.rs +++ b/core/server/src/shard/builder.rs @@ -30,7 +30,7 @@ use tracing::info; use crate::{ configs::server::ServerConfig, io::storage::Storage, - shard::{Shard, ShardInfo, namespace::IggyNamespace, task_registry::TaskRegistry}, + shard::{Shard, ShardInfo, namespace::IggyNamespace}, slab::streams::Streams, state::{StateKind, system::SystemState}, streaming::{ @@ -151,7 +151,6 @@ impl IggyShardBuilder { stop_sender: stop_sender, messages_receiver: Cell::new(Some(frame_receiver)), metrics: metrics, - task_registry: TaskRegistry::new(), is_shutting_down: AtomicBool::new(false), tcp_bound_address: Cell::new(None), quic_bound_address: Cell::new(None), diff --git a/core/server/src/shard/mod.rs b/core/server/src/shard/mod.rs index 534ddfe3..7cb6c6f7 100644 --- a/core/server/src/shard/mod.rs +++ b/core/server/src/shard/mod.rs @@ -19,9 +19,7 @@ inner() * or more contributor license agreements. See the NOTICE file pub mod builder; pub mod logging; pub mod namespace; -pub mod stats; pub mod system; -pub mod task_registry; pub mod tasks; pub mod transmission; @@ -57,12 +55,10 @@ use transmission::connector::{Receiver, ShardConnector, StopReceiver, StopSender use crate::{ binary::handlers::messages::poll_messages_handler::IggyPollMetadata, configs::server::ServerConfig, - http::http_server, io::fs_utils, shard::{ namespace::{IggyFullNamespace, IggyNamespace}, - task_registry::TaskRegistry, - tasks::messages::spawn_shard_message_task, + tasks::{init_supervisor, shard_task_specs, task_supervisor}, transmission::{ event::ShardEvent, frame::{ShardFrame, ShardResponse}, @@ -75,12 +71,21 @@ use crate::{ traits_ext::{EntityComponentSystem, EntityMarker, Insert}, }, state::{ - file::FileState, system::{StreamState, SystemState, UserState}, StateKind + StateKind, + file::FileState, + system::{StreamState, SystemState, UserState}, }, streaming::{ - clients::client_manager::ClientManager, diagnostics::metrics::Metrics, partitions, persistence::persister::PersisterKind, polling_consumer::PollingConsumer, session::Session, traits::MainOps, users::{permissioner::Permissioner, user::User}, utils::ptr::EternalPtr + clients::client_manager::ClientManager, + diagnostics::metrics::Metrics, + partitions, + persistence::persister::PersisterKind, + polling_consumer::PollingConsumer, + session::Session, + traits::MainOps, + users::{permissioner::Permissioner, user::User}, + utils::ptr::EternalPtr, }, - tcp::tcp_server::spawn_tcp_server, versioning::SemanticVersion, }; @@ -157,7 +162,6 @@ pub struct IggyShard { pub messages_receiver: Cell<Option<Receiver<ShardFrame>>>, pub(crate) stop_receiver: StopReceiver, pub(crate) stop_sender: StopSender, - pub(crate) task_registry: TaskRegistry, pub(crate) is_shutting_down: AtomicBool, pub(crate) tcp_bound_address: Cell<Option<SocketAddr>>, pub(crate) quic_bound_address: Cell<Option<SocketAddr>>, @@ -168,6 +172,55 @@ impl IggyShard { Default::default() } + pub fn default_from_config(server_config: ServerConfig) -> Self { + use crate::bootstrap::resolve_persister; + use crate::versioning::SemanticVersion; + + let version = SemanticVersion::current().expect("Invalid version"); + let persister = resolve_persister(server_config.system.partition.enforce_fsync); + + let (stop_sender, stop_receiver) = async_channel::unbounded(); + + let state_path = server_config.system.get_state_messages_file_path(); + let file_state = FileState::new(&state_path, &version, persister, None); + let state = crate::state::StateKind::File(file_state); + let shards_table = Box::new(DashMap::new()); + let shards_table = Box::leak(shards_table); + + let shard = Self { + id: 0, + shards: Vec::new(), + shards_table: shards_table.into(), + version, + streams2: Default::default(), + state, + //TODO: Fix + encryptor: None, + config: server_config, + client_manager: Default::default(), + active_sessions: Default::default(), + permissioner: Default::default(), + users: Default::default(), + metrics: Metrics::init(), + messages_receiver: Cell::new(None), + stop_receiver, + stop_sender, + is_shutting_down: AtomicBool::new(false), + tcp_bound_address: Cell::new(None), + quic_bound_address: Cell::new(None), + }; + let user = User::root(DEFAULT_ROOT_USERNAME, DEFAULT_ROOT_PASSWORD); + shard + .create_user_bypass_auth( + user.id, + &user.username, + &user.password, + UserStatus::Active, + Some(Permissions::root()), + ) + .unwrap(); + shard + } pub async fn init(&self) -> Result<(), IggyError> { self.load_segments().await?; let _ = self.load_users().await; @@ -175,6 +228,8 @@ impl IggyShard { } pub async fn run(self: &Rc<Self>, persister: Arc<PersisterKind>) -> Result<(), IggyError> { + // Initialize TLS task supervisor for this thread + init_supervisor(self.id); // Workaround to ensure that the statistics are initialized before the server // loads streams and starts accepting connections. This is necessary to // have the correct statistics when the server starts. @@ -185,50 +240,17 @@ impl IggyShard { // TODO: Fixme //self.assert_init(); - // Create all tasks (tcp listener, http listener, command processor, in the future also the background jobs). - let mut tasks: Vec<Task> = vec![Box::pin(spawn_shard_message_task(self.clone()))]; - if self.config.tcp.enabled { - tasks.push(Box::pin(spawn_tcp_server(self.clone()))); - } - - if self.config.http.enabled && self.id == 0 { - println!("Starting HTTP server on shard: {}", self.id); - tasks.push(Box::pin(http_server::start( - self.config.http.clone(), - persister, - self.clone(), - ))); - } + // Create and spawn all tasks via the supervisor + let task_specs = shard_task_specs(self.clone()); + task_supervisor().spawn(self.clone(), task_specs); - if self.config.quic.enabled { - tasks.push(Box::pin(crate::quic::quic_server::span_quic_server( - self.clone(), - ))); - } - - tasks.push(Box::pin( - crate::channels::commands::clean_personal_access_tokens::clear_personal_access_tokens( - self.clone(), - ), - )); - // TOOD: Fixme, not always id 0 is the first shard. - if self.id == 0 { - tasks.push(Box::pin( - crate::channels::commands::print_sysinfo::print_sys_info(self.clone()), - )); - } - - tasks.push(Box::pin( - crate::channels::commands::verify_heartbeats::verify_heartbeats(self.clone()), - )); - tasks.push(Box::pin( - crate::channels::commands::save_messages::save_messages(self.clone()), - )); + // Create a oneshot channel for shutdown completion notification + let (shutdown_complete_tx, shutdown_complete_rx) = async_channel::bounded(1); let stop_receiver = self.get_stop_receiver(); let shard_for_shutdown = self.clone(); - /* + // Spawn shutdown handler - only this task consumes the stop signal compio::runtime::spawn(async move { let _ = stop_receiver.recv().await; info!("Shard {} received shutdown signal", shard_for_shutdown.id); @@ -239,13 +261,15 @@ impl IggyShard { } else { shard_info!(shard_for_shutdown.id, "shutdown completed successfully"); } - }); - */ + + let _ = shutdown_complete_tx.send(()).await; + }) + .detach(); let elapsed = now.elapsed(); shard_info!(self.id, "Initialized in {} ms.", elapsed.as_millis()); - let result = try_join_all(tasks).await; - result?; + + shutdown_complete_rx.recv().await.ok(); Ok(()) } @@ -349,11 +373,19 @@ impl IggyShard { self.stop_receiver.clone() } + /// Get the task supervisor for the current thread + /// + /// # Panics + /// Panics if the task supervisor has not been initialized + pub fn task_supervisor() -> Rc<crate::shard::tasks::TaskSupervisor> { + task_supervisor() + } + #[instrument(skip_all, name = "trace_shutdown")] pub async fn trigger_shutdown(&self) -> bool { self.is_shutting_down.store(true, Ordering::SeqCst); info!("Shard {} shutdown state set", self.id); - self.task_registry.shutdown_all(SHUTDOWN_TIMEOUT).await + task_supervisor().graceful_shutdown(SHUTDOWN_TIMEOUT).await } pub fn get_available_shards_count(&self) -> u32 { diff --git a/core/server/src/shard/stats.rs b/core/server/src/shard/stats.rs deleted file mode 100644 index 8b137891..00000000 --- a/core/server/src/shard/stats.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/core/server/src/shard/system/messages.rs b/core/server/src/shard/system/messages.rs index a64602d0..08237d91 100644 --- a/core/server/src/shard/system/messages.rs +++ b/core/server/src/shard/system/messages.rs @@ -36,7 +36,8 @@ use crate::streaming::{partitions, streams, topics}; use error_set::ErrContext; use iggy_common::{ - BytesSerializable, Consumer, EncryptorKind, Identifier, IggyError, IggyTimestamp, Partitioning, PartitioningKind, PollingKind, PollingStrategy, IGGY_MESSAGE_HEADER_SIZE + BytesSerializable, Consumer, EncryptorKind, IGGY_MESSAGE_HEADER_SIZE, Identifier, IggyError, + IggyTimestamp, Partitioning, PartitioningKind, PollingKind, PollingStrategy, }; use tracing::{error, trace}; @@ -352,7 +353,6 @@ impl IggyShard { todo!(); } - async fn decrypt_messages( &self, batches: IggyMessagesBatchSet, diff --git a/core/server/src/shard/task_registry.rs b/core/server/src/shard/task_registry.rs deleted file mode 100644 index ff674a8e..00000000 --- a/core/server/src/shard/task_registry.rs +++ /dev/null @@ -1,108 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use async_channel::{Receiver, Sender, bounded}; -use compio::runtime::JoinHandle; -use futures::future::join_all; -use std::cell::RefCell; -use std::collections::HashMap; -use std::future::Future; -use std::time::Duration; -use tracing::{error, info, warn}; - -pub struct TaskRegistry { - tasks: RefCell<Vec<JoinHandle<()>>>, - active_connections: RefCell<HashMap<u32, Sender<()>>>, -} - -impl TaskRegistry { - pub fn new() -> Self { - Self { - tasks: RefCell::new(Vec::new()), - active_connections: RefCell::new(HashMap::new()), - } - } - - pub fn spawn_tracked<F>(&self, future: F) - where - F: Future<Output = ()> + 'static, - { - let handle = compio::runtime::spawn(future); - self.tasks.borrow_mut().push(handle); - } - - pub fn add_connection(&self, client_id: u32) -> Receiver<()> { - let (stop_sender, stop_receiver) = bounded(1); - self.active_connections - .borrow_mut() - .insert(client_id, stop_sender); - stop_receiver - } - - pub fn remove_connection(&self, client_id: &u32) { - self.active_connections.borrow_mut().remove(client_id); - } - - pub async fn shutdown_all(&self, timeout: Duration) -> bool { - info!("Initiating task registry shutdown"); - - let connections = self.active_connections.borrow(); - for (client_id, stop_sender) in connections.iter() { - info!("Sending shutdown signal to client {}", client_id); - if let Err(e) = stop_sender.send(()).await { - warn!( - "Failed to send shutdown signal to client {}: {}", - client_id, e - ); - } - } - drop(connections); - - let tasks = self.tasks.take(); - let total = tasks.len(); - - if total == 0 { - info!("No tasks to shut down"); - return true; - } - - let timeout_futures: Vec<_> = tasks - .into_iter() - .enumerate() - .map(|(idx, handle)| async move { - match compio::time::timeout(timeout, handle).await { - Ok(_) => (idx, true), - Err(_) => { - warn!("Task {} did not complete within timeout", idx); - (idx, false) - } - } - }) - .collect(); - - let results = join_all(timeout_futures).await; - let completed = results.iter().filter(|(_, success)| *success).count(); - - info!( - "Task registry shutdown complete. {} of {} tasks completed", - completed, total - ); - - completed == total - } -} diff --git a/core/server/src/shard/tasks/continuous/http_server.rs b/core/server/src/shard/tasks/continuous/http_server.rs new file mode 100644 index 00000000..efd624b7 --- /dev/null +++ b/core/server/src/shard/tasks/continuous/http_server.rs @@ -0,0 +1,79 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::bootstrap::resolve_persister; +use crate::http::http_server; +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use std::fmt::Debug; +use std::rc::Rc; +use tracing::info; + +/// Continuous task for running the HTTP REST API server +pub struct HttpServer { + shard: Rc<IggyShard>, +} + +impl Debug for HttpServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl HttpServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskSpec for HttpServer { + fn name(&self) -> &'static str { + "http_server" + } + + fn kind(&self) -> TaskKind { + TaskKind::Continuous + } + + fn scope(&self) -> TaskScope { + TaskScope::SpecificShard(0) + } + + fn is_critical(&self) -> bool { + false + } + + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + info!("Starting HTTP server on shard: {}", self.shard.id); + let persister = resolve_persister(self.shard.config.system.partition.enforce_fsync); + http_server::start( + self.shard.config.http.clone(), + persister, + self.shard.clone(), + ) + .await + }) + } +} + +impl ContinuousSpec for HttpServer {} diff --git a/core/server/src/shard/tasks/continuous/message_pump.rs b/core/server/src/shard/tasks/continuous/message_pump.rs new file mode 100644 index 00000000..cf45a3ff --- /dev/null +++ b/core/server/src/shard/tasks/continuous/message_pump.rs @@ -0,0 +1,100 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use crate::shard::transmission::frame::ShardFrame; +use crate::shard_info; +use futures::{FutureExt, StreamExt}; +use std::fmt::Debug; +use std::rc::Rc; + +/// Continuous task for processing inter-shard messages +pub struct MessagePump { + shard: Rc<IggyShard>, +} + +impl Debug for MessagePump { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessagePump") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl MessagePump { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskSpec for MessagePump { + fn name(&self) -> &'static str { + "message_pump" + } + + fn kind(&self) -> TaskKind { + TaskKind::Continuous + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + true + } + + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + let mut messages_receiver = self.shard.messages_receiver.take().unwrap(); + + shard_info!(self.shard.id, "Starting message passing task"); + + loop { + futures::select! { + _ = ctx.shutdown.wait().fuse() => { + shard_info!(self.shard.id, "Message receiver shutting down"); + break; + } + frame = messages_receiver.next().fuse() => { + if let Some(frame) = frame { + let ShardFrame { + message, + response_sender, + } = frame; + + if let (Some(response), Some(response_sender)) = (self.shard.handle_shard_message(message).await, response_sender) { + response_sender + .send(response) + .await + .expect("Failed to send response back to origin shard."); + }; + } + } + } + } + + Ok(()) + }) + } +} + +impl ContinuousSpec for MessagePump {} diff --git a/core/server/src/channels/mod.rs b/core/server/src/shard/tasks/continuous/mod.rs similarity index 74% rename from core/server/src/channels/mod.rs rename to core/server/src/shard/tasks/continuous/mod.rs index d147f7e0..97a3ff75 100644 --- a/core/server/src/channels/mod.rs +++ b/core/server/src/shard/tasks/continuous/mod.rs @@ -16,4 +16,14 @@ * under the License. */ -pub mod commands; +//! Continuous task specifications for long-running services + +pub mod http_server; +pub mod message_pump; +pub mod quic_server; +pub mod tcp_server; + +pub use http_server::HttpServer; +pub use message_pump::MessagePump; +pub use quic_server::QuicServer; +pub use tcp_server::TcpServer; diff --git a/core/server/src/shard/tasks/continuous/quic_server.rs b/core/server/src/shard/tasks/continuous/quic_server.rs new file mode 100644 index 00000000..c0d451b4 --- /dev/null +++ b/core/server/src/shard/tasks/continuous/quic_server.rs @@ -0,0 +1,68 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::quic::quic_server; +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use std::fmt::Debug; +use std::rc::Rc; + +/// Continuous task for running the QUIC server +pub struct QuicServer { + shard: Rc<IggyShard>, +} + +impl Debug for QuicServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QuicServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl QuicServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskSpec for QuicServer { + fn name(&self) -> &'static str { + "quic_server" + } + + fn kind(&self) -> TaskKind { + TaskKind::Continuous + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + false + } + + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { quic_server::span_quic_server(self.shard.clone()).await }) + } +} + +impl ContinuousSpec for QuicServer {} diff --git a/core/server/src/shard/tasks/continuous/tcp_server.rs b/core/server/src/shard/tasks/continuous/tcp_server.rs new file mode 100644 index 00000000..b269db58 --- /dev/null +++ b/core/server/src/shard/tasks/continuous/tcp_server.rs @@ -0,0 +1,72 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + ContinuousSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use crate::tcp::tcp_server::spawn_tcp_server; +use std::fmt::Debug; +use std::rc::Rc; + +/// Continuous task for running the TCP server +pub struct TcpServer { + shard: Rc<IggyShard>, +} + +impl Debug for TcpServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TcpServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl TcpServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskSpec for TcpServer { + fn name(&self) -> &'static str { + "tcp_server" + } + + fn kind(&self) -> TaskKind { + TaskKind::Continuous + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + false + } + + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + // The existing spawn_tcp_server already handles shutdown internally + // via is_shutting_down checks. This will be refactored later. + spawn_tcp_server(self.shard.clone()).await + }) + } +} + +impl ContinuousSpec for TcpServer {} diff --git a/core/server/src/shard/tasks/messages.rs b/core/server/src/shard/tasks/messages.rs deleted file mode 100644 index 2feadd1c..00000000 --- a/core/server/src/shard/tasks/messages.rs +++ /dev/null @@ -1,59 +0,0 @@ -use futures::{FutureExt, StreamExt}; -use iggy_common::IggyError; -use std::{rc::Rc, time::Duration}; - -use crate::{ - shard::{IggyShard, transmission::frame::ShardFrame}, - shard_error, shard_info, -}; - -async fn run_shard_messages_receiver(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let mut messages_receiver = shard.messages_receiver.take().unwrap(); - - shard_info!(shard.id, "Starting message passing task"); - loop { - let shutdown_check = async { - loop { - if shard.is_shutting_down() { - return; - } - compio::time::sleep(Duration::from_millis(100)).await; - } - }; - - futures::select! { - _ = shutdown_check.fuse() => { - shard_info!(shard.id, "Message receiver shutting down"); - break; - } - frame = messages_receiver.next().fuse() => { - if let Some(frame) = frame { - let ShardFrame { - message, - response_sender, - } = frame; - match (shard.handle_shard_message(message).await, response_sender) { - (Some(response), Some(response_sender)) => { - response_sender - .send(response) - .await - .expect("Failed to send response back to origin shard."); - } - _ => {} - }; - } - } - } - } - - Ok(()) -} - -pub async fn spawn_shard_message_task(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let result = run_shard_messages_receiver(shard.clone()).await; - if let Err(err) = result { - shard_error!(shard.id, "Error running shard message receiver: {err}"); - return Err(err); - } - Ok(()) -} diff --git a/core/server/src/shard/tasks/mod.rs b/core/server/src/shard/tasks/mod.rs index ba63992f..448f2f9e 100644 --- a/core/server/src/shard/tasks/mod.rs +++ b/core/server/src/shard/tasks/mod.rs @@ -1 +1,106 @@ -pub mod messages; +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Task management system for shard operations +//! +//! This module provides a unified framework for managing all asynchronous tasks +//! within a shard, including servers, periodic maintenance, and one-shot operations. + +pub mod continuous; +pub mod periodic; +pub mod shutdown; +pub mod specs; +pub mod supervisor; +pub mod tls; + +pub use shutdown::{Shutdown, ShutdownToken}; +pub use specs::{TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec}; +pub use supervisor::TaskSupervisor; +pub use tls::{init_supervisor, is_supervisor_initialized, task_supervisor}; + +use crate::shard::IggyShard; +use specs::IsNoOp; +use std::rc::Rc; +use std::time::Duration; + +/// Create all task specifications for a shard +/// +/// This is the central place where all tasks are declared and configured. +/// The supervisor will filter tasks based on their scope and the shard's properties. +pub fn shard_task_specs(shard: Rc<IggyShard>) -> Vec<Box<dyn TaskSpec>> { + let mut specs: Vec<Box<dyn TaskSpec>> = vec![]; + + // Continuous tasks - servers and message processing + specs.push(Box::new(continuous::MessagePump::new(shard.clone()))); + + if shard.config.tcp.enabled { + specs.push(Box::new(continuous::TcpServer::new(shard.clone()))); + } + + if shard.config.http.enabled { + specs.push(Box::new(continuous::HttpServer::new(shard.clone()))); + } + + if shard.config.quic.enabled { + specs.push(Box::new(continuous::QuicServer::new(shard.clone()))); + } + + // Periodic tasks - maintenance and monitoring + if shard.config.message_saver.enabled { + let period = shard.config.message_saver.interval.get_duration(); + specs.push(Box::new(periodic::SaveMessages::new(shard.clone(), period))); + } + + if shard.config.heartbeat.enabled { + let period = shard.config.heartbeat.interval.get_duration(); + specs.push(Box::new(periodic::VerifyHeartbeats::new( + shard.clone(), + period, + ))); + } + + if shard.config.personal_access_token.cleaner.enabled { + let period = shard + .config + .personal_access_token + .cleaner + .interval + .get_duration(); + specs.push(Box::new(periodic::ClearPersonalAccessTokens::new( + shard.clone(), + period, + ))); + } + + // System info printing (leader only) + let sysinfo_period = shard + .config + .system + .logging + .sysinfo_print_interval + .get_duration(); + if !sysinfo_period.is_zero() { + specs.push(Box::new(periodic::PrintSysinfo::new( + shard.clone(), + sysinfo_period, + ))); + } + + // Filter out no-op tasks + specs.into_iter().filter(|spec| !spec.is_noop()).collect() +} diff --git a/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs new file mode 100644 index 00000000..1df3268d --- /dev/null +++ b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs @@ -0,0 +1,122 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use iggy_common::IggyTimestamp; +use std::fmt::Debug; +use std::rc::Rc; +use std::time::Duration; +use tracing::{debug, info, trace}; + +/// Periodic task for cleaning expired personal access tokens +pub struct ClearPersonalAccessTokens { + shard: Rc<IggyShard>, + period: Duration, +} + +impl Debug for ClearPersonalAccessTokens { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClearPersonalAccessTokens") + .field("shard_id", &self.shard.id) + .field("period", &self.period) + .finish() + } +} + +impl ClearPersonalAccessTokens { + pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self { + Self { shard, period } + } +} + +impl TaskSpec for ClearPersonalAccessTokens { + fn name(&self) -> &'static str { + "clear_personal_access_tokens" + } + + fn kind(&self) -> TaskKind { + TaskKind::Periodic { + period: self.period, + } + } + + fn scope(&self) -> TaskScope { + // Only clean tokens on shard 0 to avoid conflicts + TaskScope::SpecificShard(0) + } + + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + info!( + "Personal access token cleaner is enabled, expired tokens will be deleted every: {:?}.", + self.period + ); + + loop { + if !ctx.shutdown.sleep_or_shutdown(self.period).await { + trace!("Personal access token cleaner shutting down"); + break; + } + + trace!("Cleaning expired personal access tokens..."); + + let users = self.shard.users.borrow(); + let now = IggyTimestamp::now(); + let mut deleted_tokens_count = 0; + + for (_, user) in users.iter() { + let expired_tokens = user + .personal_access_tokens + .iter() + .filter(|token| token.is_expired(now)) + .map(|token| token.token.clone()) + .collect::<Vec<_>>(); + + for token in expired_tokens { + debug!( + "Personal access token: {} for user with ID: {} is expired.", + token, user.id + ); + deleted_tokens_count += 1; + user.personal_access_tokens.remove(&token); + debug!( + "Deleted personal access token: {} for user with ID: {}.", + token, user.id + ); + } + } + + info!( + "Deleted {} expired personal access tokens.", + deleted_tokens_count + ); + } + + Ok(()) + }) + } +} + +impl PeriodicSpec for ClearPersonalAccessTokens { + fn period(&self) -> Duration { + self.period + } +} diff --git a/core/server/src/channels/commands/mod.rs b/core/server/src/shard/tasks/periodic/mod.rs similarity index 79% rename from core/server/src/channels/commands/mod.rs rename to core/server/src/shard/tasks/periodic/mod.rs index 0130170b..493230a9 100644 --- a/core/server/src/channels/commands/mod.rs +++ b/core/server/src/shard/tasks/periodic/mod.rs @@ -16,7 +16,12 @@ * under the License. */ -pub mod clean_personal_access_tokens; +pub mod clear_personal_access_tokens; pub mod print_sysinfo; pub mod save_messages; pub mod verify_heartbeats; + +pub use clear_personal_access_tokens::ClearPersonalAccessTokens; +pub use print_sysinfo::PrintSysinfo; +pub use save_messages::SaveMessages; +pub use verify_heartbeats::VerifyHeartbeats; diff --git a/core/server/src/shard/tasks/periodic/print_sysinfo.rs b/core/server/src/shard/tasks/periodic/print_sysinfo.rs new file mode 100644 index 00000000..1866f867 --- /dev/null +++ b/core/server/src/shard/tasks/periodic/print_sysinfo.rs @@ -0,0 +1,126 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use crate::streaming::utils::memory_pool; +use human_repr::HumanCount; +use std::fmt::Debug; +use std::rc::Rc; +use std::time::Duration; +use tracing::{error, info, trace}; + +/// Periodic task for printing system information +pub struct PrintSysinfo { + shard: Rc<IggyShard>, + period: Duration, +} + +impl Debug for PrintSysinfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrintSysinfo") + .field("shard_id", &self.shard.id) + .field("period", &self.period) + .finish() + } +} + +impl PrintSysinfo { + pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self { + Self { shard, period } + } +} + +impl TaskSpec for PrintSysinfo { + fn name(&self) -> &'static str { + "print_sysinfo" + } + + fn kind(&self) -> TaskKind { + TaskKind::Periodic { + period: self.period, + } + } + + fn scope(&self) -> TaskScope { + // Only print sysinfo from shard 0 to avoid duplicate logs + TaskScope::SpecificShard(0) + } + + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + if self.period == Duration::ZERO { + info!("SysInfoPrinter is disabled."); + return Ok(()); + } + + info!( + "SysInfoPrinter is enabled, system information will be printed every {:?}.", + self.period + ); + + loop { + if !ctx.shutdown.sleep_or_shutdown(self.period).await { + trace!("SysInfoPrinter shutting down"); + break; + } + + trace!("Printing system information..."); + + let stats = match self.shard.get_stats().await { + Ok(stats) => stats, + Err(e) => { + error!("Failed to get system information. Error: {e}"); + continue; + } + }; + + let free_memory_percent = (stats.available_memory.as_bytes_u64() as f64 + / stats.total_memory.as_bytes_u64() as f64) + * 100f64; + + info!( + "CPU: {:.2}%/{:.2}% (IggyUsage/Total), Mem: {:.2}%/{}/{}/{} (Free/IggyUsage/TotalUsed/Total), Clients: {}, Messages processed: {}, Read: {}, Written: {}, Uptime: {}", + stats.cpu_usage, + stats.total_cpu_usage, + free_memory_percent, + stats.memory_usage, + stats.total_memory - stats.available_memory, + stats.total_memory, + stats.clients_count, + stats.messages_count.human_count_bare(), + stats.read_bytes, + stats.written_bytes, + stats.run_time + ); + + memory_pool().log_stats(); + } + + Ok(()) + }) + } +} + +impl PeriodicSpec for PrintSysinfo { + fn period(&self) -> Duration { + self.period + } +} diff --git a/core/server/src/shard/tasks/periodic/save_messages.rs b/core/server/src/shard/tasks/periodic/save_messages.rs new file mode 100644 index 00000000..5599cedd --- /dev/null +++ b/core/server/src/shard/tasks/periodic/save_messages.rs @@ -0,0 +1,136 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use crate::shard_info; +use iggy_common::Identifier; +use std::fmt::Debug; +use std::rc::Rc; +use std::time::Duration; +use tracing::{error, info, trace}; + +/// Periodic task for saving buffered messages to disk +pub struct SaveMessages { + shard: Rc<IggyShard>, + period: Duration, +} + +impl Debug for SaveMessages { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SaveMessages") + .field("shard_id", &self.shard.id) + .field("period", &self.period) + .finish() + } +} + +impl SaveMessages { + pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self { + Self { shard, period } + } +} + +impl TaskSpec for SaveMessages { + fn name(&self) -> &'static str { + "save_messages" + } + + fn kind(&self) -> TaskKind { + TaskKind::Periodic { + period: self.period, + } + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + let enforce_fsync = self.shard.config.message_saver.enforce_fsync; + info!( + "Message saver is enabled, buffered messages will be automatically saved every: {:?}, enforce fsync: {enforce_fsync}.", + self.period + ); + + loop { + if !ctx.shutdown.sleep_or_shutdown(self.period).await { + trace!("Message saver shutting down"); + break; + } + + trace!("Saving buffered messages..."); + + let namespaces = self.shard.get_current_shard_namespaces(); + let mut total_saved_messages = 0u32; + let reason = "background saver triggered".to_string(); + + for ns in namespaces { + let stream_id = Identifier::numeric(ns.stream_id() as u32).unwrap(); + let topic_id = Identifier::numeric(ns.topic_id() as u32).unwrap(); + let partition_id = ns.partition_id(); + + match self + .shard + .streams2 + .persist_messages( + self.shard.id, + &stream_id, + &topic_id, + partition_id, + reason.clone(), + &self.shard.config.system, + ) + .await + { + Ok(batch_count) => { + total_saved_messages += batch_count; + } + Err(err) => { + error!( + "Failed to save messages for partition {}: {}", + partition_id, err + ); + } + } + } + + if total_saved_messages > 0 { + shard_info!( + self.shard.id, + "Saved {} buffered messages on disk.", + total_saved_messages + ); + } + + trace!("Finished saving buffered messages."); + } + + Ok(()) + }) + } +} + +impl PeriodicSpec for SaveMessages { + fn period(&self) -> Duration { + self.period + } +} diff --git a/core/server/src/shard/tasks/periodic/verify_heartbeats.rs b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs new file mode 100644 index 00000000..a2ce9685 --- /dev/null +++ b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs @@ -0,0 +1,131 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::tasks::specs::{ + PeriodicSpec, TaskCtx, TaskFuture, TaskKind, TaskScope, TaskSpec, +}; +use iggy_common::{IggyDuration, IggyTimestamp}; +use std::fmt::Debug; +use std::rc::Rc; +use std::time::Duration; +use tracing::{debug, info, trace, warn}; + +const MAX_THRESHOLD: f64 = 1.2; + +/// Periodic task for verifying client heartbeats and removing stale clients +pub struct VerifyHeartbeats { + shard: Rc<IggyShard>, + period: Duration, +} + +impl Debug for VerifyHeartbeats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VerifyHeartbeats") + .field("shard_id", &self.shard.id) + .field("period", &self.period) + .finish() + } +} + +impl VerifyHeartbeats { + pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self { + Self { shard, period } + } +} + +impl TaskSpec for VerifyHeartbeats { + fn name(&self) -> &'static str { + "verify_heartbeats" + } + + fn kind(&self) -> TaskKind { + TaskKind::Periodic { + period: self.period, + } + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + let interval = IggyDuration::from(self.period); + let max_interval = + IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64); + info!( + "Heartbeats will be verified every: {interval}. Max allowed interval: {max_interval}." + ); + + loop { + if !ctx.shutdown.sleep_or_shutdown(self.period).await { + trace!("Heartbeat verifier shutting down"); + break; + } + + trace!("Verifying heartbeats..."); + + let clients = { + let client_manager = self.shard.client_manager.borrow(); + client_manager.get_clients() + }; + + let now = IggyTimestamp::now(); + let heartbeat_to = IggyTimestamp::from(now.as_micros() - max_interval.as_micros()); + debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}"); + + let mut stale_clients = Vec::new(); + for client in clients { + if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() { + warn!( + "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}", + client.session, client.last_heartbeat, + ); + client.session.set_stale(); + stale_clients.push(client.session.client_id); + } else { + debug!( + "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}", + client.last_heartbeat, client.session, + ); + } + } + + if stale_clients.is_empty() { + continue; + } + + let count = stale_clients.len(); + info!("Removing {count} stale clients..."); + for client_id in stale_clients { + self.shard.delete_client(client_id); + } + info!("Removed {count} stale clients."); + } + + Ok(()) + }) + } +} + +impl PeriodicSpec for VerifyHeartbeats { + fn period(&self) -> Duration { + self.period + } +} diff --git a/core/server/src/shard/tasks/shutdown.rs b/core/server/src/shard/tasks/shutdown.rs new file mode 100644 index 00000000..ab9e40a1 --- /dev/null +++ b/core/server/src/shard/tasks/shutdown.rs @@ -0,0 +1,232 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use async_channel::{Receiver, Sender, bounded}; +use futures::FutureExt; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; +use tracing::trace; + +/// Coordinates graceful shutdown across multiple tasks +#[derive(Clone)] +pub struct Shutdown { + sender: Sender<()>, + is_triggered: Arc<AtomicBool>, +} + +impl Shutdown { + pub fn new() -> (Self, ShutdownToken) { + let (sender, receiver) = bounded(1); + let is_triggered = Arc::new(AtomicBool::new(false)); + + let shutdown = Self { + sender, + is_triggered: is_triggered.clone(), + }; + + let token = ShutdownToken { + receiver, + is_triggered, + }; + + (shutdown, token) + } + + pub fn trigger(&self) { + if self.is_triggered.swap(true, Ordering::SeqCst) { + return; // Already triggered + } + + trace!("Triggering shutdown signal"); + let _ = self.sender.close(); + } + + pub fn is_triggered(&self) -> bool { + self.is_triggered.load(Ordering::Relaxed) + } +} + +/// Token held by tasks to receive shutdown signals +#[derive(Clone)] +pub struct ShutdownToken { + receiver: Receiver<()>, + is_triggered: Arc<AtomicBool>, +} + +impl ShutdownToken { + /// Wait for shutdown signal + pub async fn wait(&self) { + let _ = self.receiver.recv().await; + } + + /// Check if shutdown has been triggered (non-blocking) + pub fn is_triggered(&self) -> bool { + self.is_triggered.load(Ordering::Relaxed) + } + + /// Sleep for the specified duration or until shutdown is triggered + /// Returns true if the full duration elapsed, false if shutdown was triggered + pub async fn sleep_or_shutdown(&self, duration: Duration) -> bool { + futures::select! { + _ = self.wait().fuse() => false, + _ = compio::time::sleep(duration).fuse() => !self.is_triggered(), + } + } + + /// Create a child token that triggers when either this token or the child's own signal triggers + pub fn child(&self) -> (Shutdown, ShutdownToken) { + let (child_shutdown, child_token) = Shutdown::new(); + let parent_receiver = self.receiver.clone(); + let child_receiver = child_token.receiver.clone(); + + // Create a new channel for the combined token + let (combined_sender, combined_receiver) = bounded(1); + + // Create a combined is_triggered flag + let combined_is_triggered = Arc::new(AtomicBool::new(false)); + + // Clone references for the async task + let parent_triggered = self.is_triggered.clone(); + let child_triggered = child_token.is_triggered.clone(); + let combined_flag_for_task = combined_is_triggered.clone(); + + // Spawn a single task that waits for either parent or child shutdown + compio::runtime::spawn(async move { + // Wait for either parent or child to trigger + futures::select! { + _ = parent_receiver.recv().fuse() => { + trace!("Child token triggered by parent shutdown"); + }, + _ = child_receiver.recv().fuse() => { + trace!("Child token triggered by child shutdown"); + }, + } + + // Set the combined flag based on which one triggered + if parent_triggered.load(Ordering::Relaxed) || child_triggered.load(Ordering::Relaxed) { + combined_flag_for_task.store(true, Ordering::SeqCst); + } + + // Close the combined channel to signal shutdown + let _ = combined_sender.close(); + }) + .detach(); + + let combined_token = ShutdownToken { + receiver: combined_receiver, + is_triggered: combined_is_triggered, + }; + + (child_shutdown, combined_token) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[compio::test] + async fn test_shutdown_trigger() { + let (shutdown, token) = Shutdown::new(); + + assert!(!token.is_triggered()); + + shutdown.trigger(); + + assert!(token.is_triggered()); + + // Should complete immediately after trigger + token.wait().await; + } + + #[compio::test] + async fn test_sleep_or_shutdown_completes() { + let (_shutdown, token) = Shutdown::new(); + + let completed = token.sleep_or_shutdown(Duration::from_millis(10)).await; + assert!(completed); + } + + #[compio::test] + async fn test_sleep_or_shutdown_interrupted() { + let (shutdown, token) = Shutdown::new(); + + // Trigger shutdown after a short delay + let shutdown_clone = shutdown.clone(); + compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(10)).await; + shutdown_clone.trigger(); + }) + .detach(); + + // Should be interrupted + let completed = token.sleep_or_shutdown(Duration::from_secs(10)).await; + assert!(!completed); + } + + #[compio::test] + async fn test_child_token_parent_trigger() { + let (parent_shutdown, parent_token) = Shutdown::new(); + let (_child_shutdown, combined_token) = parent_token.child(); + + assert!(!combined_token.is_triggered()); + + // Trigger parent shutdown + parent_shutdown.trigger(); + + // Combined token should be triggered + combined_token.wait().await; + assert!(combined_token.is_triggered()); + } + + #[compio::test] + async fn test_child_token_child_trigger() { + let (_parent_shutdown, parent_token) = Shutdown::new(); + let (child_shutdown, combined_token) = parent_token.child(); + + assert!(!combined_token.is_triggered()); + + // Trigger child shutdown + child_shutdown.trigger(); + + // Combined token should be triggered + combined_token.wait().await; + assert!(combined_token.is_triggered()); + } + + #[compio::test] + async fn test_child_token_no_polling_overhead() { + let (_parent_shutdown, parent_token) = Shutdown::new(); + let (_child_shutdown, combined_token) = parent_token.child(); + + // Test that we can create many child tokens without performance issues + let start = std::time::Instant::now(); + for _ in 0..100 { + let _ = combined_token.child(); + } + let elapsed = start.elapsed(); + + // Should complete very quickly since there's no polling + assert!( + elapsed.as_millis() < 100, + "Creating child tokens took too long: {:?}", + elapsed + ); + } +} diff --git a/core/server/src/shard/tasks/specs.rs b/core/server/src/shard/tasks/specs.rs new file mode 100644 index 00000000..2b3340ea --- /dev/null +++ b/core/server/src/shard/tasks/specs.rs @@ -0,0 +1,203 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use iggy_common::IggyError; +use std::fmt::Debug; +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::time::Duration; +use strum::Display; + +use super::shutdown::ShutdownToken; + +/// Context provided to all tasks when they run +#[derive(Clone)] +pub struct TaskCtx { + pub shard: Rc<IggyShard>, + pub shutdown: ShutdownToken, +} + +/// Future type returned by task run methods +pub type TaskFuture = Pin<Box<dyn Future<Output = Result<(), IggyError>>>>; + +/// Describes the kind of task and its scheduling behavior +#[derive(Debug, Clone, Display)] +pub enum TaskKind { + /// Runs continuously until shutdown + Continuous, + /// Runs periodically at specified intervals + Periodic { period: Duration }, + /// Runs once and completes + OneShot, +} + +/// Determines which shards should run this task +#[derive(Clone, Debug)] +pub enum TaskScope { + /// Run on all shards + AllShards, + /// Run on a specific shard by ID + SpecificShard(u16), +} + +impl TaskScope { + /// Check if this task should run on the given shard + pub fn should_run(&self, shard: &IggyShard) -> bool { + match self { + TaskScope::AllShards => true, + TaskScope::SpecificShard(id) => shard.id == *id, + } + } +} + +/// Core trait that all tasks must implement +pub trait TaskSpec: Debug { + /// Unique name for this task + fn name(&self) -> &'static str; + + /// The kind of task (continuous, periodic, or oneshot) + fn kind(&self) -> TaskKind; + + /// Scope determining which shards run this task + fn scope(&self) -> TaskScope; + + /// Run the task with the provided context + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; + + /// Optional: called before the task starts + fn on_start(&self) { + tracing::info!("Starting task: {}", self.name()); + } + + /// Optional: called after the task completes + fn on_complete(&self, result: &Result<(), IggyError>) { + match result { + Ok(()) => tracing::info!("Task {} completed successfully", self.name()), + Err(e) => tracing::error!("Task {} failed: {}", self.name(), e), + } + } + + /// Optional: whether this task is critical (failure should stop the shard) + fn is_critical(&self) -> bool { + false + } +} + +/// Marker trait for continuous tasks +pub trait ContinuousSpec: TaskSpec { + fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> + where + Self: 'static + Sized, + { + self as Box<dyn TaskSpec> + } +} + +/// Marker trait for periodic tasks +pub trait PeriodicSpec: TaskSpec { + /// Get the period for this periodic task + fn period(&self) -> Duration; + + /// Optional: whether to run one final tick on shutdown + fn last_tick_on_shutdown(&self) -> bool { + false + } + + fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> + where + Self: 'static + Sized, + { + self as Box<dyn TaskSpec> + } +} + +/// Marker trait for oneshot tasks +pub trait OneShotSpec: TaskSpec { + /// Optional: timeout for this oneshot task + fn timeout(&self) -> Option<Duration> { + None + } + + fn as_task_spec(self: Box<Self>) -> Box<dyn TaskSpec> + where + Self: 'static + Sized, + { + self as Box<dyn TaskSpec> + } +} + +/// A no-op task that does nothing (useful for conditional task creation) +#[derive(Debug)] +pub struct NoOpTask; + +impl TaskSpec for NoOpTask { + fn name(&self) -> &'static str { + "noop" + } + + fn kind(&self) -> TaskKind { + TaskKind::OneShot + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async { Ok(()) }) + } +} + +/// Helper to create a no-op task +pub fn noop() -> Box<dyn TaskSpec> { + Box::new(NoOpTask) +} + +/// Helper trait to check if a task is a no-op +pub trait IsNoOp { + fn is_noop(&self) -> bool; +} + +impl IsNoOp for Box<dyn TaskSpec> { + fn is_noop(&self) -> bool { + self.name() == "noop" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_task_scope() { + let shard = IggyShard::default_from_config(Default::default()); + + assert!(TaskScope::AllShards.should_run(&shard)); + assert!(TaskScope::SpecificShard(0).should_run(&shard)); // shard.id == 0 by default + assert!(!TaskScope::SpecificShard(1).should_run(&shard)); + } + + #[test] + fn test_noop_task() { + let task = noop(); + assert!(task.is_noop()); + assert_eq!(task.name(), "noop"); + } +} diff --git a/core/server/src/shard/tasks/supervisor.rs b/core/server/src/shard/tasks/supervisor.rs new file mode 100644 index 00000000..9bd62f07 --- /dev/null +++ b/core/server/src/shard/tasks/supervisor.rs @@ -0,0 +1,522 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::shutdown::{Shutdown, ShutdownToken}; +use super::specs::{TaskCtx, TaskKind, TaskScope, TaskSpec}; +use crate::shard::IggyShard; +use compio::runtime::JoinHandle; +use futures::Future; +use futures::future::join_all; +use iggy_common::IggyError; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::time::Duration; +use tracing::{error, info, trace, warn}; + +/// Handle to a spawned task +struct TaskHandle { + name: String, + kind: TaskKind, + handle: JoinHandle<Result<(), IggyError>>, + is_critical: bool, +} + +/// Supervises the lifecycle of all tasks in a shard +pub struct TaskSupervisor { + shard_id: u16, + shutdown: Shutdown, + shutdown_token: ShutdownToken, + tasks: RefCell<Vec<TaskHandle>>, + oneshot_handles: RefCell<Vec<TaskHandle>>, + active_connections: RefCell<HashMap<u32, async_channel::Sender<()>>>, +} + +impl TaskSupervisor { + /// Create a new task supervisor for a shard + pub fn new(shard_id: u16) -> Self { + let (shutdown, shutdown_token) = Shutdown::new(); + + Self { + shard_id, + shutdown, + shutdown_token, + tasks: RefCell::new(Vec::new()), + oneshot_handles: RefCell::new(Vec::new()), + active_connections: RefCell::new(HashMap::new()), + } + } + + /// Get a shutdown token for tasks + pub fn shutdown_token(&self) -> ShutdownToken { + self.shutdown_token.clone() + } + + /// Spawn all tasks according to their specifications + pub fn spawn(&self, shard: Rc<IggyShard>, specs: Vec<Box<dyn TaskSpec>>) { + for spec in specs { + // Check if task should run on this shard + if !spec.scope().should_run(&shard) { + trace!( + "Skipping task {} ({}) on shard {} due to scope", + spec.name(), + spec.kind(), + self.shard_id + ); + continue; + } + + // Skip no-op tasks + if spec.name() == "noop" { + continue; + } + + self.spawn_single(shard.clone(), spec); + } + + info!( + "Shard {} spawned {} tasks ({} oneshot)", + self.shard_id, + self.tasks.borrow().len(), + self.oneshot_handles.borrow().len() + ); + } + + /// Spawn a single task based on its kind + fn spawn_single(&self, shard: Rc<IggyShard>, spec: Box<dyn TaskSpec>) { + let name = spec.name(); + let kind = spec.kind(); + let is_critical = spec.is_critical(); + + info!( + "Spawning {} task '{}' on shard {}", + match &kind { + TaskKind::Continuous => "continuous", + TaskKind::Periodic { .. } => "periodic", + TaskKind::OneShot => "oneshot", + }, + name, + self.shard_id + ); + + spec.on_start(); + + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + + let handle = match kind { + TaskKind::Continuous => self.spawn_continuous(spec, ctx), + TaskKind::Periodic { period } => self.spawn_periodic(spec, ctx, period), + TaskKind::OneShot => self.spawn_oneshot_spec(spec, ctx), + }; + + let task_handle = TaskHandle { + name: name.to_string(), + kind: kind.clone(), + handle, + is_critical, + }; + + // Store handle based on kind + match kind { + TaskKind::OneShot => self.oneshot_handles.borrow_mut().push(task_handle), + _ => self.tasks.borrow_mut().push(task_handle), + } + } + + /// Spawn a continuous task + fn spawn_continuous( + &self, + spec: Box<dyn TaskSpec>, + ctx: TaskCtx, + ) -> JoinHandle<Result<(), IggyError>> { + let name = spec.name(); + let shard_id = self.shard_id; + + compio::runtime::spawn(async move { + trace!("Continuous task '{}' starting on shard {}", name, shard_id); + let result = spec.run(ctx).await; + + match &result { + Ok(()) => info!("Continuous task '{}' completed on shard {}", name, shard_id), + Err(e) => error!( + "Continuous task '{}' failed on shard {}: {}", + name, shard_id, e + ), + } + + result + }) + } + + /// Spawn a periodic task + fn spawn_periodic( + &self, + spec: Box<dyn TaskSpec>, + ctx: TaskCtx, + period: Duration, + ) -> JoinHandle<Result<(), IggyError>> { + let name = spec.name(); + let shard_id = self.shard_id; + + compio::runtime::spawn(async move { + trace!( + "Periodic task '{}' starting on shard {} with period {:?}", + name, shard_id, period + ); + + // Periodic tasks handle their own loop internally + let result = spec.run(ctx).await; + + match &result { + Ok(()) => info!("Periodic task '{}' completed on shard {}", name, shard_id), + Err(e) => error!( + "Periodic task '{}' failed on shard {}: {}", + name, shard_id, e + ), + } + + result + }) + } + + /// Spawn a oneshot task from a TaskSpec + fn spawn_oneshot_spec( + &self, + spec: Box<dyn TaskSpec>, + ctx: TaskCtx, + ) -> JoinHandle<Result<(), IggyError>> { + let name = spec.name(); + let shard_id = self.shard_id; + + compio::runtime::spawn(async move { + trace!("OneShot task '{}' starting on shard {}", name, shard_id); + let result = spec.run(ctx).await; + + match &result { + Ok(()) => info!("OneShot task '{}' completed on shard {}", name, shard_id), + Err(e) => error!( + "OneShot task '{}' failed on shard {}: {}", + name, shard_id, e + ), + } + + result + }) + } + + /// Trigger graceful shutdown for all tasks + pub async fn graceful_shutdown(&self, timeout: Duration) -> bool { + info!( + "Initiating graceful shutdown for {} tasks on shard {}", + self.tasks.borrow().len() + self.oneshot_handles.borrow().len(), + self.shard_id + ); + + // First shutdown connections + self.shutdown_connections(); + + // Trigger shutdown signal + self.shutdown.trigger(); + + // Wait for continuous and periodic tasks with timeout + let continuous_periodic_tasks = self.tasks.take(); + let continuous_periodic_complete = self + .await_tasks_with_timeout(continuous_periodic_tasks, timeout) + .await; + + // Always wait for oneshot tasks (no timeout for durability) + let oneshot_tasks = self.oneshot_handles.take(); + let oneshot_complete = self.await_oneshot_tasks(oneshot_tasks).await; + + let all_complete = continuous_periodic_complete && oneshot_complete; + + if all_complete { + info!("All tasks shutdown gracefully on shard {}", self.shard_id); + } else { + warn!( + "Some tasks did not shutdown cleanly on shard {}", + self.shard_id + ); + } + + all_complete + } + + /// Wait for tasks with a timeout + async fn await_tasks_with_timeout(&self, tasks: Vec<TaskHandle>, timeout: Duration) -> bool { + if tasks.is_empty() { + return true; + } + + let task_count = tasks.len(); + let timeout_futures: Vec<_> = tasks + .into_iter() + .map(|task| async move { + let name = task.name.clone(); + let kind = task.kind; + let is_critical = task.is_critical; + + match compio::time::timeout(timeout, task.handle).await { + Ok(Ok(Ok(()))) => { + trace!("Task '{}' ({}) shutdown gracefully", name, kind); + (name, true, false) + } + Ok(Ok(Err(e))) => { + error!("Task '{}' ({}) failed during shutdown: {}", name, kind, e); + (name, false, is_critical) + } + Ok(Err(e)) => { + error!( + "Task '{}' ({}) panicked during shutdown: {:?}", + name, kind, e + ); + (name, false, is_critical) + } + Err(_) => { + warn!("Task '{}' ({}) did not complete within timeout", name, kind); + (name, false, is_critical) + } + } + }) + .collect(); + + let results = join_all(timeout_futures).await; + + let completed = results.iter().filter(|(_, success, _)| *success).count(); + let critical_failures = results.iter().any(|(_, _, critical)| *critical); + + info!( + "Shard {} shutdown: {}/{} tasks completed", + self.shard_id, completed, task_count + ); + + if critical_failures { + error!("Critical task(s) failed on shard {}", self.shard_id); + } + + completed == task_count && !critical_failures + } + + /// Wait for oneshot tasks (no timeout for durability) + async fn await_oneshot_tasks(&self, tasks: Vec<TaskHandle>) -> bool { + if tasks.is_empty() { + return true; + } + + info!( + "Waiting for {} oneshot tasks to complete on shard {}", + tasks.len(), + self.shard_id + ); + + let futures: Vec<_> = tasks + .into_iter() + .map(|task| async move { + let name = task.name.clone(); + match task.handle.await { + Ok(Ok(())) => { + trace!("OneShot task '{}' completed", name); + true + } + Ok(Err(e)) => { + error!("OneShot task '{}' failed: {}", name, e); + false + } + Err(e) => { + error!("OneShot task '{}' panicked: {:?}", name, e); + false + } + } + }) + .collect(); + + let results = join_all(futures).await; + let all_complete = results.iter().all(|&r| r); + + if all_complete { + info!("All oneshot tasks completed on shard {}", self.shard_id); + } else { + error!("Some oneshot tasks failed on shard {}", self.shard_id); + } + + all_complete + } + + /// Add a connection for tracking + pub fn add_connection(&self, client_id: u32) -> async_channel::Receiver<()> { + let (stop_sender, stop_receiver) = async_channel::bounded(1); + self.active_connections + .borrow_mut() + .insert(client_id, stop_sender); + stop_receiver + } + + /// Remove a connection from tracking + pub fn remove_connection(&self, client_id: &u32) { + self.active_connections.borrow_mut().remove(client_id); + } + + /// Spawn a tracked task (for connection handlers) + pub fn spawn_tracked<F>(&self, future: F) + where + F: Future<Output = ()> + 'static, + { + let handle = compio::runtime::spawn(async move { + future.await; + Ok(()) + }); + self.tasks.borrow_mut().push(TaskHandle { + name: "connection_handler".to_string(), + kind: TaskKind::Continuous, + handle, + is_critical: false, + }); + } + + /// Spawn a oneshot task directly without going through TaskSpec + pub fn spawn_oneshot<F>(&self, name: impl Into<String>, critical: bool, f: F) + where + F: Future<Output = Result<(), IggyError>> + 'static, + { + let name = name.into(); + let shard_id = self.shard_id; + + trace!("Spawning oneshot task '{}' on shard {}", name, shard_id); + + let task_name = name.clone(); + let handle = compio::runtime::spawn(async move { + trace!( + "OneShot task '{}' starting on shard {}", + task_name, shard_id + ); + let result = f.await; + + match &result { + Ok(()) => trace!( + "OneShot task '{}' completed on shard {}", + task_name, shard_id + ), + Err(e) => error!( + "OneShot task '{}' failed on shard {}: {}", + task_name, shard_id, e + ), + } + + result + }); + + self.oneshot_handles.borrow_mut().push(TaskHandle { + name, + kind: TaskKind::OneShot, + handle, + is_critical: critical, + }); + } + + /// Spawn a periodic task using a closure-based tick function + /// The supervisor handles the timing loop, the closure provides the tick logic + pub fn spawn_periodic_tick<F, Fut>( + &self, + name: &'static str, + period: Duration, + scope: TaskScope, + shard: Rc<IggyShard>, + mut tick: F, + ) where + F: FnMut(&TaskCtx) -> Fut + 'static, + Fut: Future<Output = Result<(), IggyError>>, + { + // Check if task should run on this shard + if !scope.should_run(&shard) { + trace!( + "Skipping periodic tick task '{}' on shard {} due to scope", + name, self.shard_id + ); + return; + } + + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + + let shard_id = self.shard_id; + let shutdown = self.shutdown_token.clone(); + + info!( + "Spawning periodic tick task '{}' on shard {} with period {:?}", + name, shard_id, period + ); + + let handle = compio::runtime::spawn(async move { + trace!( + "Periodic tick task '{}' starting on shard {} with period {:?}", + name, shard_id, period + ); + + loop { + // Use shutdown-aware sleep + if !shutdown.sleep_or_shutdown(period).await { + trace!("Periodic tick task '{}' shutting down", name); + break; + } + + // Execute tick + match tick(&ctx).await { + Ok(()) => trace!("Periodic tick task '{}' tick completed", name), + Err(e) => error!("Periodic tick task '{}' tick failed: {}", name, e), + } + } + + info!( + "Periodic tick task '{}' completed on shard {}", + name, shard_id + ); + Ok(()) + }); + + self.tasks.borrow_mut().push(TaskHandle { + name: name.to_string(), + kind: TaskKind::Periodic { period }, + handle, + is_critical: false, + }); + } + + /// Shutdown all connections gracefully + fn shutdown_connections(&self) { + info!( + "Shutting down {} active connections", + self.active_connections.borrow().len() + ); + + let connections = self.active_connections.borrow(); + for (client_id, stop_sender) in connections.iter() { + trace!("Sending shutdown signal to connection {}", client_id); + if let Err(e) = stop_sender.send_blocking(()) { + warn!( + "Failed to send shutdown signal to connection {}: {}", + client_id, e + ); + } + } + } +} diff --git a/core/server/src/shard/tasks/tls.rs b/core/server/src/shard/tasks/tls.rs new file mode 100644 index 00000000..1370c5e6 --- /dev/null +++ b/core/server/src/shard/tasks/tls.rs @@ -0,0 +1,141 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::supervisor::TaskSupervisor; +use std::cell::RefCell; +use std::rc::Rc; + +thread_local! { + static SUPERVISOR: RefCell<Option<Rc<TaskSupervisor>>> = RefCell::new(None); +} + +/// Initialize the task supervisor for this thread +pub fn init_supervisor(shard_id: u16) { + SUPERVISOR.with(|s| { + *s.borrow_mut() = Some(Rc::new(TaskSupervisor::new(shard_id))); + }); +} + +/// Get the task supervisor for the current thread +/// +/// # Panics +/// Panics if the supervisor has not been initialized for this thread +pub fn task_supervisor() -> Rc<TaskSupervisor> { + SUPERVISOR.with(|s| { + s.borrow() + .as_ref() + .expect( + "Task supervisor not initialized for this thread. Call init_supervisor() first.", + ) + .clone() + }) +} + +/// Check if the task supervisor has been initialized for this thread +pub fn is_supervisor_initialized() -> bool { + SUPERVISOR.with(|s| s.borrow().is_some()) +} + +/// Clear the task supervisor for this thread (for cleanup/testing) +pub fn clear_supervisor() { + SUPERVISOR.with(|s| { + *s.borrow_mut() = None; + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::panic; + + #[test] + fn test_supervisor_initialization() { + // Clean any existing supervisor first + clear_supervisor(); + + // Initially, supervisor should not be initialized + assert!(!is_supervisor_initialized()); + + // Trying to get supervisor without initialization should panic + let result = panic::catch_unwind(|| { + task_supervisor(); + }); + assert!(result.is_err()); + + // Initialize supervisor + init_supervisor(42); + + // Now it should be initialized + assert!(is_supervisor_initialized()); + + // Clean up + clear_supervisor(); + assert!(!is_supervisor_initialized()); + } + + #[test] + fn test_multiple_initializations() { + clear_supervisor(); + + // First initialization + init_supervisor(1); + let _supervisor1 = task_supervisor(); + + // Second initialization (should replace the first) + init_supervisor(2); + let _supervisor2 = task_supervisor(); + + // They should be different instances (different Rc references) + // Note: We can't use ptr::eq since Rc doesn't expose the inner pointer easily + // But we can verify that initialization works multiple times + + clear_supervisor(); + } + + #[test] + fn test_thread_locality() { + use std::thread; + + clear_supervisor(); + + // Initialize in main thread + init_supervisor(100); + assert!(is_supervisor_initialized()); + + // Spawn a new thread and verify it doesn't have the supervisor + let handle = thread::spawn(|| { + // This thread should not have supervisor initialized + assert!(!is_supervisor_initialized()); + + // Initialize different supervisor in this thread + init_supervisor(200); + assert!(is_supervisor_initialized()); + + // Get supervisor to verify it works + let _ = task_supervisor(); + }); + + handle.join().expect("Thread should complete successfully"); + + // Main thread should still have its supervisor + assert!(is_supervisor_initialized()); + let _ = task_supervisor(); + + clear_supervisor(); + } +} diff --git a/core/server/src/shard/transmission/message.rs b/core/server/src/shard/transmission/message.rs index a4e4bb3a..14bf075b 100644 --- a/core/server/src/shard/transmission/message.rs +++ b/core/server/src/shard/transmission/message.rs @@ -1,7 +1,3 @@ -use std::{rc::Rc, sync::Arc}; - -use iggy_common::{Identifier, PollingStrategy}; - /* Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -27,6 +23,7 @@ use crate::{ slab::partitions, streaming::{polling_consumer::PollingConsumer, segments::IggyMessagesBatchMut}, }; +use iggy_common::Identifier; pub enum ShardSendRequestResult { // TODO: In the future we can add other variants, for example backpressure from the destination shard, diff --git a/core/server/src/slab/streams.rs b/core/server/src/slab/streams.rs index d733e04e..8ac12f6f 100644 --- a/core/server/src/slab/streams.rs +++ b/core/server/src/slab/streams.rs @@ -735,15 +735,33 @@ impl Streams { (msg.unwrap(), index.unwrap()) }); - compio::runtime::spawn(async move { - let _ = log_writer.fsync().await; - }) - .detach(); - compio::runtime::spawn(async move { - let _ = index_writer.fsync().await; - drop(index_writer) - }) - .detach(); + // Use task supervisor for proper tracking and graceful shutdown + use crate::shard::tasks::tls::task_supervisor; + use tracing::error; + + task_supervisor().spawn_oneshot("fsync:segment-close-messages", true, async move { + match log_writer.fsync().await { + Ok(_) => Ok(()), + Err(e) => { + error!("Failed to fsync log writer on segment close: {}", e); + Err(e) + } + } + }); + + task_supervisor().spawn_oneshot("fsync:segment-close-index", true, async move { + match index_writer.fsync().await { + Ok(_) => { + drop(index_writer); + Ok(()) + } + Err(e) => { + error!("Failed to fsync index writer on segment close: {}", e); + drop(index_writer); + Err(e) + } + } + }); let (start_offset, size, end_offset) = self.with_partition_by_id(stream_id, topic_id, partition_id, |(.., log)| { diff --git a/core/server/src/streaming/storage.rs b/core/server/src/streaming/storage.rs index 03199d70..d64845a2 100644 --- a/core/server/src/streaming/storage.rs +++ b/core/server/src/streaming/storage.rs @@ -47,7 +47,7 @@ macro_rules! forward_async_methods { } } -// TODO: Tech debt, how to get rid of this ? +// TODO: Tech debt, how to get rid of this ? #[derive(Debug)] pub enum SystemInfoStorageKind { File(FileSystemInfoStorage), diff --git a/core/server/src/tcp/tcp_listener.rs b/core/server/src/tcp/tcp_listener.rs index 18565ae6..f39dcc94 100644 --- a/core/server/src/tcp/tcp_listener.rs +++ b/core/server/src/tcp/tcp_listener.rs @@ -19,6 +19,7 @@ use crate::binary::sender::SenderKind; use crate::configs::tcp::TcpSocketConfig; use crate::shard::IggyShard; +use crate::shard::tasks::task_supervisor; use crate::shard::transmission::event::ShardEvent; use crate::tcp::connection_handler::{handle_connection, handle_error}; use crate::{shard_error, shard_info}; @@ -180,14 +181,14 @@ async fn accept_loop( shard_info!(shard.id, "Created new session: {}", session); let mut sender = SenderKind::get_tcp_sender(stream); - let conn_stop_receiver = shard_clone.task_registry.add_connection(client_id); + let conn_stop_receiver = task_supervisor().add_connection(client_id); let shard_for_conn = shard_clone.clone(); - shard_clone.task_registry.spawn_tracked(async move { + task_supervisor().spawn_tracked(async move { if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await { handle_error(error); } - shard_for_conn.task_registry.remove_connection(&client_id); + task_supervisor().remove_connection(&client_id); if let Err(error) = sender.shutdown().await { shard_error!(shard.id, "Failed to shutdown TCP stream for client: {}, address: {}. {}", client_id, address, error); diff --git a/core/server/src/tcp/tcp_tls_listener.rs b/core/server/src/tcp/tcp_tls_listener.rs index ec69dbcd..276eb52c 100644 --- a/core/server/src/tcp/tcp_tls_listener.rs +++ b/core/server/src/tcp/tcp_tls_listener.rs @@ -19,6 +19,7 @@ use crate::binary::sender::SenderKind; use crate::configs::tcp::TcpSocketConfig; use crate::shard::IggyShard; +use crate::shard::tasks::task_supervisor; use crate::shard::transmission::event::ShardEvent; use crate::tcp::connection_handler::{handle_connection, handle_error}; use crate::{shard_error, shard_info, shard_warn}; @@ -219,7 +220,7 @@ async fn accept_loop( // Perform TLS handshake in a separate task to avoid blocking the accept loop let task_shard = shard_clone.clone(); - task_shard.task_registry.spawn_tracked(async move { + task_supervisor().spawn_tracked(async move { match acceptor.accept(stream).await { Ok(tls_stream) => { // TLS handshake successful, now create session @@ -237,13 +238,13 @@ async fn accept_loop( let client_id = session.client_id; shard_info!(shard_clone.id, "Created new session: {}", session); - let conn_stop_receiver = shard_clone.task_registry.add_connection(client_id); + let conn_stop_receiver = task_supervisor().add_connection(client_id); let shard_for_conn = shard_clone.clone(); let mut sender = SenderKind::get_tcp_tls_sender(tls_stream); if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await { handle_error(error); } - shard_for_conn.task_registry.remove_connection(&client_id); + task_supervisor().remove_connection(&client_id); if let Err(error) = sender.shutdown().await { shard_error!(shard.id, "Failed to shutdown TCP TLS stream for client: {}, address: {}. {}", client_id, address, error);
