This is an automated email from the ASF dual-hosted git repository. hgruszecki pushed a commit to branch io_uring_tpc_task_registry in repository https://gitbox.apache.org/repos/asf/iggy.git
commit d29da9d64e5d4cfc758c2cba2dafc529ecb94f7d Author: Hubert Gruszecki <[email protected]> AuthorDate: Wed Sep 24 13:14:23 2025 +0200 feat(io_uring): implement task supervisor --- Cargo.lock | 4 - core/bench/src/actors/consumer/client/low_level.rs | 2 +- core/bench/src/benchmarks/benchmark.rs | 7 +- core/bench/src/benchmarks/common.rs | 10 +- core/common/src/error/iggy_error.rs | 2 + core/server/Cargo.toml | 1 - .../commands/clean_personal_access_tokens.rs | 75 --- core/server/src/channels/commands/print_sysinfo.rs | 70 --- core/server/src/channels/commands/save_messages.rs | 86 ---- .../src/channels/commands/verify_heartbeats.rs | 80 --- core/server/src/http/http_server.rs | 27 +- core/server/src/http/jwt/cleaner.rs | 41 -- core/server/src/http/jwt/mod.rs | 1 - core/server/src/http/mod.rs | 2 +- core/server/src/http/partitions.rs | 1 - core/server/src/lib.rs | 1 - core/server/src/log/logger.rs | 4 +- core/server/src/log/runtime.rs | 21 +- core/server/src/main.rs | 30 +- core/server/src/quic/listener.rs | 121 +++-- core/server/src/shard/builder.rs | 23 +- core/server/src/shard/mod.rs | 100 ++-- core/server/src/shard/stats.rs | 1 - core/server/src/shard/system/messages.rs | 8 +- core/server/src/shard/task_registry.rs | 108 ---- core/server/src/shard/task_registry/builders.rs | 22 + .../src/shard/task_registry/builders/continuous.rs | 116 +++++ .../src/shard/task_registry/builders/oneshot.rs | 128 +++++ .../src/shard/task_registry/builders/periodic.rs | 144 ++++++ core/server/src/shard/task_registry/mod.rs | 12 + core/server/src/shard/task_registry/registry.rs | 550 +++++++++++++++++++++ core/server/src/shard/task_registry/shutdown.rs | 233 +++++++++ core/server/src/shard/task_registry/specs.rs | 59 +++ core/server/src/shard/task_registry/tls.rs | 104 ++++ .../src/shard/tasks/continuous/http_server.rs | 68 +++ .../src/shard/tasks/continuous/message_pump.rs | 99 ++++ .../commands => shard/tasks/continuous}/mod.rs | 13 +- .../src/shard/tasks/continuous/quic_server.rs 
| 62 +++ .../src/shard/tasks/continuous/tcp_server.rs | 62 +++ core/server/src/shard/tasks/messages.rs | 59 --- core/server/src/shard/tasks/mod.rs | 105 +++- .../src/shard/tasks/periodic/clear_jwt_tokens.rs | 92 ++++ .../tasks/periodic/clear_personal_access_tokens.rs | 106 ++++ .../src/{channels => shard/tasks/periodic}/mod.rs | 12 +- .../src/shard/tasks/periodic/print_sysinfo.rs | 123 +++++ .../src/shard/tasks/periodic/save_messages.rs | 127 +++++ .../src/shard/tasks/periodic/verify_heartbeats.rs | 126 +++++ core/server/src/shard/transmission/message.rs | 5 +- core/server/src/slab/streams.rs | 39 +- core/server/src/streaming/storage.rs | 2 +- core/server/src/tcp/tcp_listener.rs | 11 +- core/server/src/tcp/tcp_tls_listener.rs | 7 +- 52 files changed, 2595 insertions(+), 717 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 876935f9..0b2edd0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -614,8 +614,6 @@ dependencies = [ "futures-lite", "pin-project", "thiserror 2.0.16", - "tokio", - "tokio-util", ] [[package]] @@ -5716,8 +5714,6 @@ dependencies = [ "rand 0.9.2", "serde_json", "thiserror 2.0.16", - "tokio", - "tokio-stream", ] [[package]] diff --git a/core/bench/src/actors/consumer/client/low_level.rs b/core/bench/src/actors/consumer/client/low_level.rs index 088ad42f..4c80c4ce 100644 --- a/core/bench/src/actors/consumer/client/low_level.rs +++ b/core/bench/src/actors/consumer/client/low_level.rs @@ -109,7 +109,7 @@ impl ConsumerClient for LowLevelConsumerClient { } Ok(Some(BatchMetrics { - messages: messages_count as u32, + messages: messages_count.try_into().unwrap(), user_data_bytes: user_bytes, total_bytes, latency, diff --git a/core/bench/src/benchmarks/benchmark.rs b/core/bench/src/benchmarks/benchmark.rs index 37e65b78..31f1828c 100644 --- a/core/bench/src/benchmarks/benchmark.rs +++ b/core/bench/src/benchmarks/benchmark.rs @@ -102,7 +102,7 @@ pub trait Benchmarkable: Send { login_root(&client).await; let streams = client.get_streams().await?; for i in 
1..=number_of_streams { - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); let stream_id: Identifier = stream_name.as_str().try_into()?; if streams.iter().all(|s| s.name != stream_name) { info!("Creating the test stream '{}'", stream_name); @@ -141,11 +141,10 @@ pub trait Benchmarkable: Send { login_root(&client).await; let streams = client.get_streams().await?; for i in 1..=number_of_streams { - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); if streams.iter().all(|s| s.name != stream_name) { return Err(IggyError::ResourceNotFound(format!( - "Streams for testing are not properly initialized. Stream '{}' is missing.", - stream_name + "Streams for testing are not properly initialized. Stream '{stream_name}' is missing." ))); } } diff --git a/core/bench/src/benchmarks/common.rs b/core/bench/src/benchmarks/common.rs index 3ee4f686..5aa2c458 100644 --- a/core/bench/src/benchmarks/common.rs +++ b/core/bench/src/benchmarks/common.rs @@ -79,7 +79,7 @@ pub async fn init_consumer_groups( login_root(&client).await; for i in 1..=cg_count { let consumer_group_id = CONSUMER_GROUP_BASE_ID + i; - let stream_name = format!("bench-stream-{}", i); + let stream_name = format!("bench-stream-{i}"); let stream_id: Identifier = stream_name.as_str().try_into()?; let topic_id: Identifier = "topic-1".try_into()?; let consumer_group_name = format!("{CONSUMER_GROUP_NAME_PREFIX}-{consumer_group_id}"); @@ -136,7 +136,7 @@ pub fn build_producer_futures( }; let stream_idx = 1 + ((producer_id - 1) % streams); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); async move { let producer = TypedBenchmarkProducer::new( @@ -203,7 +203,7 @@ pub fn build_consumer_futures( } else { consumer_id }; - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let consumer_group_id = if cg_count > 
0 { Some(CONSUMER_GROUP_BASE_ID + 1 + (consumer_id % cg_count)) } else { @@ -251,7 +251,7 @@ pub fn build_producing_consumers_futures( let client_factory_clone = client_factory.clone(); let args_clone = args.clone(); let stream_idx = 1 + ((actor_id - 1) % streams); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let send_finish_condition = BenchmarkFinishCondition::new( &args, @@ -329,7 +329,7 @@ pub fn build_producing_consumer_groups_futures( let client_factory_clone = client_factory.clone(); let args_clone = args.clone(); let stream_idx = 1 + ((actor_id - 1) % cg_count); - let stream_id = format!("bench-stream-{}", stream_idx); + let stream_id = format!("bench-stream-{stream_idx}"); let should_produce = actor_id <= producers; let should_consume = actor_id <= consumers; diff --git a/core/common/src/error/iggy_error.rs b/core/common/src/error/iggy_error.rs index 614757f1..19dd41a9 100644 --- a/core/common/src/error/iggy_error.rs +++ b/core/common/src/error/iggy_error.rs @@ -476,6 +476,8 @@ pub enum IggyError { #[error("Cannot bind to socket with addr: {0}")] CannotBindToSocket(String) = 12000, + #[error("Task execution timeout")] + TaskTimeout = 12001, } impl IggyError { diff --git a/core/server/Cargo.toml b/core/server/Cargo.toml index 64f71a31..49b6b66c 100644 --- a/core/server/Cargo.toml +++ b/core/server/Cargo.toml @@ -34,7 +34,6 @@ path = "src/main.rs" [features] default = ["mimalloc"] -tokio-console = ["dep:console-subscriber", "tokio/tracing"] disable-mimalloc = [] mimalloc = ["dep:mimalloc"] diff --git a/core/server/src/channels/commands/clean_personal_access_tokens.rs b/core/server/src/channels/commands/clean_personal_access_tokens.rs deleted file mode 100644 index d846aa27..00000000 --- a/core/server/src/channels/commands/clean_personal_access_tokens.rs +++ /dev/null @@ -1,75 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::rc::Rc; - -use crate::shard::IggyShard; -use iggy_common::{IggyError, IggyTimestamp}; -use tracing::{debug, info, trace}; - -pub async fn clear_personal_access_tokens(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.personal_access_token.cleaner; - if !config.enabled { - info!("Personal access token cleaner is disabled."); - return Ok(()); - } - - info!( - "Personal access token cleaner is enabled, expired tokens will be deleted every: {}.", - config.interval - ); - - let interval = config.interval.get_duration(); - let mut interval_timer = compio::time::interval(interval); - - loop { - interval_timer.tick().await; - trace!("Cleaning expired personal access tokens..."); - - let users = shard.users.borrow(); - let now = IggyTimestamp::now(); - let mut deleted_tokens_count = 0; - - for (_, user) in users.iter() { - let expired_tokens = user - .personal_access_tokens - .iter() - .filter(|token| token.is_expired(now)) - .map(|token| token.token.clone()) - .collect::<Vec<_>>(); - - for token in expired_tokens { - debug!( - "Personal access token: {} for user with ID: {} is expired.", - token, user.id - ); - deleted_tokens_count += 1; - user.personal_access_tokens.remove(&token); - debug!( - "Deleted personal access token: {} for user 
with ID: {}.", - token, user.id - ); - } - } - - info!( - "Deleted {} expired personal access tokens.", - deleted_tokens_count - ); - } -} diff --git a/core/server/src/channels/commands/print_sysinfo.rs b/core/server/src/channels/commands/print_sysinfo.rs deleted file mode 100644 index 0c3992d2..00000000 --- a/core/server/src/channels/commands/print_sysinfo.rs +++ /dev/null @@ -1,70 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::rc::Rc; - -use crate::shard::IggyShard; -use crate::streaming::utils::memory_pool; -use human_repr::HumanCount; -use iggy_common::IggyError; -use tracing::{error, info, trace}; - -pub async fn print_sys_info(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.system.logging; - let interval = config.sysinfo_print_interval; - - if interval.is_zero() { - info!("SysInfoPrinter is disabled."); - return Ok(()); - } - info!("SysInfoPrinter is enabled, system information will be printed every {interval}."); - let mut interval_timer = compio::time::interval(interval.get_duration()); - loop { - interval_timer.tick().await; - trace!("Printing system information..."); - - let stats = match shard.get_stats().await { - Ok(stats) => stats, - Err(e) => { - error!("Failed to get system information. Error: {e}"); - continue; - } - }; - - let free_memory_percent = (stats.available_memory.as_bytes_u64() as f64 - / stats.total_memory.as_bytes_u64() as f64) - * 100f64; - - info!( - "CPU: {:.2}%/{:.2}% (IggyUsage/Total), Mem: {:.2}%/{}/{}/{} (Free/IggyUsage/TotalUsed/Total), Clients: {}, Messages processed: {}, Read: {}, Written: {}, Uptime: {}", - stats.cpu_usage, - stats.total_cpu_usage, - free_memory_percent, - stats.memory_usage, - stats.total_memory - stats.available_memory, - stats.total_memory, - stats.clients_count, - stats.messages_count.human_count_bare(), - stats.read_bytes, - stats.written_bytes, - stats.run_time - ); - - memory_pool().log_stats(); - } -} diff --git a/core/server/src/channels/commands/save_messages.rs b/core/server/src/channels/commands/save_messages.rs deleted file mode 100644 index 84330563..00000000 --- a/core/server/src/channels/commands/save_messages.rs +++ /dev/null @@ -1,86 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{shard::IggyShard, shard_info}; -use iggy_common::{Identifier, IggyError}; -use std::rc::Rc; -use tracing::{error, info, trace}; - -pub async fn save_messages(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.message_saver; - if !config.enabled { - info!("Message saver is disabled."); - return Ok(()); - } - - // TODO: Maybe we should get rid of it in order to not complicate, and use the fsync settings per partition from config. - let enforce_fsync = config.enforce_fsync; - let interval = config.interval; - info!( - "Message saver is enabled, buffered messages will be automatically saved every: {interval}, enforce fsync: {enforce_fsync}." 
- ); - - let mut interval_timer = compio::time::interval(interval.get_duration()); - loop { - interval_timer.tick().await; - trace!("Saving buffered messages..."); - - let namespaces = shard.get_current_shard_namespaces(); - let mut total_saved_messages = 0u32; - let reason = "background saver triggered".to_string(); - - for ns in namespaces { - let stream_id = Identifier::numeric(ns.stream_id() as u32).unwrap(); - let topic_id = Identifier::numeric(ns.topic_id() as u32).unwrap(); - let partition_id = ns.partition_id(); - - match shard - .streams2 - .persist_messages( - shard.id, - &stream_id, - &topic_id, - partition_id, - reason.clone(), - &shard.config.system, - ) - .await - { - Ok(batch_count) => { - total_saved_messages += batch_count; - } - Err(err) => { - error!( - "Failed to save messages for partition {}: {}", - partition_id, err - ); - } - } - } - - if total_saved_messages > 0 { - shard_info!( - shard.id, - "Saved {} buffered messages on disk.", - total_saved_messages - ); - } - - trace!("Finished saving buffered messages."); - } -} diff --git a/core/server/src/channels/commands/verify_heartbeats.rs b/core/server/src/channels/commands/verify_heartbeats.rs deleted file mode 100644 index 123c174f..00000000 --- a/core/server/src/channels/commands/verify_heartbeats.rs +++ /dev/null @@ -1,80 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::shard::IggyShard; -use iggy_common::{IggyDuration, IggyError, IggyTimestamp}; -use std::rc::Rc; -use tracing::{debug, info, trace, warn}; - -const MAX_THRESHOLD: f64 = 1.2; - -pub async fn verify_heartbeats(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let config = &shard.config.heartbeat; - if !config.enabled { - info!("Heartbeats verification is disabled."); - return Ok(()); - } - - let interval = config.interval; - let max_interval = IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64); - info!("Heartbeats will be verified every: {interval}. 
Max allowed interval: {max_interval}."); - - let mut interval_timer = compio::time::interval(interval.get_duration()); - - loop { - interval_timer.tick().await; - trace!("Verifying heartbeats..."); - - let clients = { - let client_manager = shard.client_manager.borrow(); - client_manager.get_clients() - }; - - let now = IggyTimestamp::now(); - let heartbeat_to = IggyTimestamp::from(now.as_micros() - max_interval.as_micros()); - debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}"); - - let mut stale_clients = Vec::new(); - for client in clients { - if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() { - warn!( - "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}", - client.session, client.last_heartbeat, - ); - client.session.set_stale(); - stale_clients.push(client.session.client_id); - } else { - debug!( - "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}", - client.last_heartbeat, client.session, - ); - } - } - - if stale_clients.is_empty() { - continue; - } - - let count = stale_clients.len(); - info!("Removing {count} stale clients..."); - for client_id in stale_clients { - shard.delete_client(client_id); - } - info!("Removed {count} stale clients."); - } -} diff --git a/core/server/src/http/http_server.rs b/core/server/src/http/http_server.rs index 2ae46271..187b9e40 100644 --- a/core/server/src/http/http_server.rs +++ b/core/server/src/http/http_server.rs @@ -19,13 +19,13 @@ use crate::configs::http::{HttpConfig, HttpCorsConfig}; use crate::http::diagnostics::request_diagnostics; use crate::http::http_shard_wrapper::HttpSafeShard; -use crate::http::jwt::cleaner::start_expired_tokens_cleaner; use crate::http::jwt::jwt_manager::JwtManager; use crate::http::jwt::middleware::jwt_auth; use crate::http::metrics::metrics; use crate::http::shared::AppState; use crate::http::*; use crate::shard::IggyShard; +use crate::shard::task_registry::{TaskScope, 
task_registry}; use crate::streaming::persistence::persister::PersisterKind; // use crate::streaming::systems::system::SharedSystem; use axum::extract::DefaultBodyLimit; @@ -66,7 +66,11 @@ impl<'a> Connected<cyper_axum::IncomingStream<'a, TcpListener>> for CompioSocket /// Starts the HTTP API server. /// Returns the address the server is listening on. -pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc<IggyShard>) -> Result<(), IggyError> { +pub async fn start( + config: HttpConfig, + persister: Arc<PersisterKind>, + shard: Rc<IggyShard>, +) -> Result<(), IggyError> { if shard.id != 0 { info!( "HTTP server disabled for shard {} (only runs on shard 0)", @@ -81,7 +85,7 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< "HTTP API" }; - let app_state = build_app_state(&config, persister, shard).await; + let app_state = build_app_state(&config, persister, shard.clone()).await; let mut app = Router::new() .merge(system::router(app_state.clone(), &config.metrics)) .merge(personal_access_tokens::router(app_state.clone())) @@ -105,7 +109,16 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< app = app.layer(middleware::from_fn_with_state(app_state.clone(), metrics)); } - start_expired_tokens_cleaner(app_state.clone()); + // JWT token cleaner task + { + use crate::shard::tasks::periodic::ClearJwtTokens; + let period = std::time::Duration::from_secs(300); // 5 minutes + task_registry().spawn_periodic( + shard.clone(), + Box::new(ClearJwtTokens::new(app_state.clone(), period)), + ); + } + app = app.layer(middleware::from_fn(request_diagnostics)); if !config.tls.enabled { @@ -158,7 +171,11 @@ pub async fn start(config: HttpConfig, persister: Arc<PersisterKind>, shard: Rc< } } -async fn build_app_state(config: &HttpConfig, persister: Arc<PersisterKind>, shard: Rc<IggyShard>) -> Arc<AppState> { +async fn build_app_state( + config: &HttpConfig, + persister: Arc<PersisterKind>, + shard: 
Rc<IggyShard>, +) -> Arc<AppState> { let tokens_path; { tokens_path = shard.config.system.get_state_tokens_path(); diff --git a/core/server/src/http/jwt/cleaner.rs b/core/server/src/http/jwt/cleaner.rs deleted file mode 100644 index 6705d955..00000000 --- a/core/server/src/http/jwt/cleaner.rs +++ /dev/null @@ -1,41 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::http::shared::AppState; -use iggy_common::IggyTimestamp; -use std::{sync::Arc, time::Duration}; -use tracing::{error, trace}; - -pub fn start_expired_tokens_cleaner(app_state: Arc<AppState>) { - compio::runtime::spawn(async move { - let mut interval_timer = compio::time::interval(Duration::from_secs(300)); - loop { - interval_timer.tick().await; - trace!("Deleting expired tokens..."); - let now = IggyTimestamp::now().to_secs(); - app_state - .jwt_manager - .delete_expired_revoked_tokens(now) - .await - .unwrap_or_else(|err| { - error!("Failed to delete expired revoked access tokens. 
Error: {err}"); - }); - } - }) - .detach(); -} diff --git a/core/server/src/http/jwt/mod.rs b/core/server/src/http/jwt/mod.rs index a60b270b..5481f110 100644 --- a/core/server/src/http/jwt/mod.rs +++ b/core/server/src/http/jwt/mod.rs @@ -16,7 +16,6 @@ * under the License. */ -pub mod cleaner; pub mod json_web_token; pub mod jwt_manager; pub mod middleware; diff --git a/core/server/src/http/mod.rs b/core/server/src/http/mod.rs index ba369815..4cf663d3 100644 --- a/core/server/src/http/mod.rs +++ b/core/server/src/http/mod.rs @@ -23,7 +23,7 @@ mod http_shard_wrapper; pub mod jwt; mod mapper; pub mod metrics; -mod shared; +pub mod shared; pub mod consumer_groups; pub mod consumer_offsets; diff --git a/core/server/src/http/partitions.rs b/core/server/src/http/partitions.rs index 11a1c4a1..b54539ea 100644 --- a/core/server/src/http/partitions.rs +++ b/core/server/src/http/partitions.rs @@ -33,7 +33,6 @@ use iggy_common::Identifier; use iggy_common::Validatable; use iggy_common::create_partitions::CreatePartitions; use iggy_common::delete_partitions::DeletePartitions; -use iggy_common::locking::IggyRwLockFn; use send_wrapper::SendWrapper; use std::sync::Arc; use tracing::instrument; diff --git a/core/server/src/lib.rs b/core/server/src/lib.rs index 38263692..735a2e51 100644 --- a/core/server/src/lib.rs +++ b/core/server/src/lib.rs @@ -31,7 +31,6 @@ compile_error!("iggy-server doesn't support windows."); pub mod args; pub mod binary; pub mod bootstrap; -pub mod channels; pub(crate) mod compat; pub mod configs; pub mod http; diff --git a/core/server/src/log/logger.rs b/core/server/src/log/logger.rs index 59725685..b400233a 100644 --- a/core/server/src/log/logger.rs +++ b/core/server/src/log/logger.rs @@ -16,10 +16,10 @@ * under the License. 
*/ -use crate::log::runtime::CompioRuntime; use crate::VERSION; use crate::configs::server::{TelemetryConfig, TelemetryTransport}; use crate::configs::system::LoggingConfig; +use crate::log::runtime::CompioRuntime; use crate::server_error::LogError; use opentelemetry::KeyValue; use opentelemetry::global; @@ -250,7 +250,7 @@ impl Logging { .with_span_processor( span_processor_with_async_runtime::BatchSpanProcessor::builder( trace_exporter, - CompioRuntime + CompioRuntime, ) .build(), ) diff --git a/core/server/src/log/runtime.rs b/core/server/src/log/runtime.rs index 1683b34a..4959acee 100644 --- a/core/server/src/log/runtime.rs +++ b/core/server/src/log/runtime.rs @@ -1,7 +1,6 @@ -use std::{pin::Pin, task::Poll, time::Duration}; - -use futures::{channel::mpsc, future::poll_fn, FutureExt, SinkExt, Stream, StreamExt}; +use futures::{FutureExt, SinkExt, Stream, StreamExt, channel::mpsc, future::poll_fn}; use opentelemetry_sdk::runtime::{Runtime, RuntimeChannel, TrySend}; +use std::{pin::Pin, task::Poll, time::Duration}; #[derive(Clone)] pub struct CompioRuntime; @@ -34,17 +33,20 @@ impl<T> CompioSender<T> { pub fn new(sender: mpsc::UnboundedSender<T>) -> Self { Self { sender } } -} +} -// Safety: Since we use compio runtime which is single-threaded, or rather the Future: !Send + !Sync, +// Safety: Since we use compio runtime which is single-threaded, or rather the Future: !Send + !Sync, // we can implement those traits, to satisfy the trait bounds from `Runtime` and `RuntimeChannel` traits. 
unsafe impl<T> Send for CompioSender<T> {} unsafe impl<T> Sync for CompioSender<T> {} -impl<T: std::fmt::Debug + Send> TrySend for CompioSender<T> { +impl<T: std::fmt::Debug + Send> TrySend for CompioSender<T> { type Message = T; - fn try_send(&self, item: Self::Message) -> Result<(), opentelemetry_sdk::runtime::TrySendError> { + fn try_send( + &self, + item: Self::Message, + ) -> Result<(), opentelemetry_sdk::runtime::TrySendError> { self.sender.unbounded_send(item).map_err(|_err| { // Unbounded channels can only fail if disconnected, never full opentelemetry_sdk::runtime::TrySendError::ChannelClosed @@ -65,7 +67,10 @@ impl<T> CompioReceiver<T> { impl<T: std::fmt::Debug + Send> Stream for CompioReceiver<T> { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { self.receiver.poll_next_unpin(cx) } } diff --git a/core/server/src/main.rs b/core/server/src/main.rs index 095ba4c2..8006b5d2 100644 --- a/core/server/src/main.rs +++ b/core/server/src/main.rs @@ -42,10 +42,7 @@ use server::configs::config_provider::{self}; use server::configs::server::ServerConfig; use server::configs::sharding::CpuAllocation; use server::io::fs_utils; -#[cfg(not(feature = "tokio-console"))] use server::log::logger::Logging; -#[cfg(feature = "tokio-console")] -use server::log::tokio_console::Logging; use server::server_error::{ConfigError, ServerError}; use server::shard::namespace::IggyNamespace; use server::shard::system::info::SystemInfo; @@ -354,22 +351,19 @@ async fn main() -> Result<(), ServerError> { } let shutdown_handles_for_signal = shutdown_handles.clone(); - /* - ::set_handler(move || { - info!("Received shutdown signal (SIGTERM/SIGINT), initiating graceful shutdown..."); - - for (shard_id, stop_sender) in &shutdown_handles_for_signal { - info!("Sending shutdown signal to shard 
{}", shard_id); - if let Err(e) = stop_sender.send_blocking(()) { - error!( - "Failed to send shutdown signal to shard {}: {}", - shard_id, e - ); - } + ctrlc::set_handler(move || { + info!("Received shutdown signal (SIGTERM/SIGINT), initiating graceful shutdown..."); + + for (shard_id, stop_sender) in &shutdown_handles_for_signal { + if let Err(e) = stop_sender.try_send(()) { + error!( + "Failed to send shutdown signal to shard {}: {}", + shard_id, e + ); } - }) - .expect("Error setting Ctrl-C handler"); - */ + } + }) + .expect("Error setting Ctrl-C handler"); info!("Iggy server is running. Press Ctrl+C or send SIGTERM to shutdown."); for (idx, handle) in handles.into_iter().enumerate() { diff --git a/core/server/src/quic/listener.rs b/core/server/src/quic/listener.rs index 5fce5c84..b10119cc 100644 --- a/core/server/src/quic/listener.rs +++ b/core/server/src/quic/listener.rs @@ -16,56 +16,85 @@ * under the License. */ -use std::rc::Rc; - use crate::binary::command::{ServerCommand, ServerCommandHandler}; use crate::binary::sender::SenderKind; use crate::server_error::ConnectionError; use crate::shard::IggyShard; +use crate::shard::task_registry::task_registry; use crate::shard::transmission::event::ShardEvent; use crate::streaming::session::Session; use crate::{shard_debug, shard_info}; use anyhow::anyhow; use compio_quic::{Connection, Endpoint, RecvStream, SendStream}; +use futures::FutureExt; use iggy_common::IggyError; use iggy_common::TransportProtocol; +use std::rc::Rc; use tracing::{error, info, trace}; const INITIAL_BYTES_LENGTH: usize = 4; pub async fn start(endpoint: Endpoint, shard: Rc<IggyShard>) -> Result<(), IggyError> { - info!("Starting QUIC listener for shard {}", shard.id); - // Since the QUIC Endpoint is internally Arc-wrapped and can be shared, // we only need one worker per shard rather than multiple workers per endpoint. // This avoids the N×workers multiplication when multiple shards are used. 
- while let Some(incoming_conn) = endpoint.wait_incoming().await { - let remote_addr = incoming_conn.remote_address(); - trace!("Incoming connection from client: {}", remote_addr); - let shard = shard.clone(); - - // Spawn each connection handler independently to maintain concurrency - compio::runtime::spawn(async move { - trace!("Accepting connection from {}", remote_addr); - match incoming_conn.await { - Ok(connection) => { - trace!("Connection established from {}", remote_addr); - if let Err(error) = handle_connection(connection, shard).await { - error!("QUIC connection from {} has failed: {error}", remote_addr); - } + loop { + let shard_clone = shard.clone(); + let shutdown_check = async { + loop { + if shard_clone.is_shutting_down() { + return; } - Err(error) => { - error!( - "Error when accepting incoming connection from {}: {:?}", - remote_addr, error - ); + compio::time::sleep(std::time::Duration::from_millis(100)).await; + } + }; + + let accept_future = endpoint.wait_incoming(); + + futures::select! 
{ + _ = shutdown_check.fuse() => { + shard_debug!(shard.id, "QUIC listener detected shutdown flag, no longer accepting connections"); + break; + } + incoming_conn = accept_future.fuse() => { + match incoming_conn { + Some(incoming_conn) => { + let remote_addr = incoming_conn.remote_address(); + + if shard.is_shutting_down() { + shard_info!(shard.id, "Rejecting new QUIC connection from {} during shutdown", remote_addr); + continue; + } + + trace!("Incoming connection from client: {}", remote_addr); + let shard_for_conn = shard.clone(); + + task_registry().spawn_tracked(async move { + trace!("Accepting connection from {}", remote_addr); + match incoming_conn.await { + Ok(connection) => { + trace!("Connection established from {}", remote_addr); + if let Err(error) = handle_connection(connection, shard_for_conn).await { + error!("QUIC connection from {} has failed: {error}", remote_addr); + } + } + Err(error) => { + error!( + "Error when accepting incoming connection from {}: {:?}", + remote_addr, error + ); + } + } + }); + } + None => { + info!("QUIC endpoint closed for shard {}", shard.id); + break; + } } } - }) - .detach(); + } } - - info!("QUIC listener for shard {} stopped", shard.id); Ok(()) } @@ -77,7 +106,6 @@ async fn handle_connection( info!("Client has connected: {address}"); let session = shard.add_client(&address, TransportProtocol::Quic); - let session = shard.add_client(&address, TransportProtocol::Quic); let client_id = session.client_id; shard_debug!( shard.id, @@ -93,19 +121,40 @@ async fn handle_connection( address, transport: TransportProtocol::Quic, }; + + // TODO(hubcio): unused? let _responses = shard.broadcast_event_to_all_shards(event.into()).await; - while let Some(stream) = accept_stream(&connection, &shard, client_id).await? 
{ - let shard = shard.clone(); - let session = session.clone(); + let conn_stop_receiver = task_registry().add_connection(client_id); - let handle_stream_task = async move { - if let Err(err) = handle_stream(stream, shard, session).await { - error!("Error when handling QUIC stream: {:?}", err) + loop { + futures::select! { + // Check for shutdown signal + _ = conn_stop_receiver.recv().fuse() => { + info!("QUIC connection {} shutting down gracefully", client_id); + break; } - }; - let _handle = compio::runtime::spawn(handle_stream_task).detach(); + // Accept new connection + stream_result = accept_stream(&connection, &shard, client_id).fuse() => { + match stream_result? { + Some(stream) => { + let shard_clone = shard.clone(); + let session_clone = session.clone(); + + task_registry().spawn_tracked(async move { + if let Err(err) = handle_stream(stream, shard_clone, session_clone).await { + error!("Error when handling QUIC stream: {:?}", err) + } + }); + } + None => break, // Connection closed + } + } + } } + + task_registry().remove_connection(&client_id); + info!("QUIC connection {} closed", client_id); Ok(()) } diff --git a/core/server/src/shard/builder.rs b/core/server/src/shard/builder.rs index 287f31bb..a5cf3b63 100644 --- a/core/server/src/shard/builder.rs +++ b/core/server/src/shard/builder.rs @@ -16,21 +16,10 @@ * under the License. 
*/ -use std::{ - cell::{Cell, RefCell}, - rc::Rc, - sync::{Arc, atomic::AtomicBool}, -}; - -use ahash::HashMap; -use dashmap::DashMap; -use iggy_common::{Aes256GcmEncryptor, EncryptorKind, UserId}; -use tracing::info; - use crate::{ configs::server::ServerConfig, io::storage::Storage, - shard::{Shard, ShardInfo, namespace::IggyNamespace, task_registry::TaskRegistry}, + shard::{Shard, ShardInfo, namespace::IggyNamespace}, slab::streams::Streams, state::{StateKind, system::SystemState}, streaming::{ @@ -39,6 +28,15 @@ use crate::{ }, versioning::SemanticVersion, }; +use ahash::HashMap; +use dashmap::DashMap; +use iggy_common::{Aes256GcmEncryptor, EncryptorKind, UserId}; +use std::{ + cell::{Cell, RefCell}, + rc::Rc, + sync::{Arc, atomic::AtomicBool}, +}; +use tracing::info; use super::{IggyShard, transmission::connector::ShardConnector, transmission::frame::ShardFrame}; @@ -151,7 +149,6 @@ impl IggyShardBuilder { stop_sender: stop_sender, messages_receiver: Cell::new(Some(frame_receiver)), metrics: metrics, - task_registry: TaskRegistry::new(), is_shutting_down: AtomicBool::new(false), tcp_bound_address: Cell::new(None), quic_bound_address: Cell::new(None), diff --git a/core/server/src/shard/mod.rs b/core/server/src/shard/mod.rs index 534ddfe3..65a173bd 100644 --- a/core/server/src/shard/mod.rs +++ b/core/server/src/shard/mod.rs @@ -19,7 +19,6 @@ inner() * or more contributor license agreements. 
See the NOTICE file pub mod builder; pub mod logging; pub mod namespace; -pub mod stats; pub mod system; pub mod task_registry; pub mod tasks; @@ -51,18 +50,17 @@ use std::{ }, time::{Duration, Instant}, }; -use tracing::{error, info, instrument, trace, warn}; +use tracing::{debug, error, info, instrument, trace, warn}; use transmission::connector::{Receiver, ShardConnector, StopReceiver, StopSender}; use crate::{ binary::handlers::messages::poll_messages_handler::IggyPollMetadata, configs::server::ServerConfig, - http::http_server, io::fs_utils, shard::{ namespace::{IggyFullNamespace, IggyNamespace}, - task_registry::TaskRegistry, - tasks::messages::spawn_shard_message_task, + task_registry::{init_task_registry, task_registry}, + tasks::register_tasks, transmission::{ event::ShardEvent, frame::{ShardFrame, ShardResponse}, @@ -75,12 +73,21 @@ use crate::{ traits_ext::{EntityComponentSystem, EntityMarker, Insert}, }, state::{ - file::FileState, system::{StreamState, SystemState, UserState}, StateKind + StateKind, + file::FileState, + system::{StreamState, SystemState, UserState}, }, streaming::{ - clients::client_manager::ClientManager, diagnostics::metrics::Metrics, partitions, persistence::persister::PersisterKind, polling_consumer::PollingConsumer, session::Session, traits::MainOps, users::{permissioner::Permissioner, user::User}, utils::ptr::EternalPtr + clients::client_manager::ClientManager, + diagnostics::metrics::Metrics, + partitions, + persistence::persister::PersisterKind, + polling_consumer::PollingConsumer, + session::Session, + traits::MainOps, + users::{permissioner::Permissioner, user::User}, + utils::ptr::EternalPtr, }, - tcp::tcp_server::spawn_tcp_server, versioning::SemanticVersion, }; @@ -157,7 +164,6 @@ pub struct IggyShard { pub messages_receiver: Cell<Option<Receiver<ShardFrame>>>, pub(crate) stop_receiver: StopReceiver, pub(crate) stop_sender: StopSender, - pub(crate) task_registry: TaskRegistry, pub(crate) is_shutting_down: AtomicBool, 
pub(crate) tcp_bound_address: Cell<Option<SocketAddr>>, pub(crate) quic_bound_address: Cell<Option<SocketAddr>>, @@ -175,77 +181,45 @@ impl IggyShard { } pub async fn run(self: &Rc<Self>, persister: Arc<PersisterKind>) -> Result<(), IggyError> { + let now = Instant::now(); + + // Initialize thread-local task registry for this thread + init_task_registry(self.id); + // Workaround to ensure that the statistics are initialized before the server // loads streams and starts accepting connections. This is necessary to // have the correct statistics when the server starts. - let now = Instant::now(); self.get_stats().await?; shard_info!(self.id, "Starting..."); self.init().await?; + // TODO: Fixme //self.assert_init(); - // Create all tasks (tcp listener, http listener, command processor, in the future also the background jobs). - let mut tasks: Vec<Task> = vec![Box::pin(spawn_shard_message_task(self.clone()))]; - if self.config.tcp.enabled { - tasks.push(Box::pin(spawn_tcp_server(self.clone()))); - } - - if self.config.http.enabled && self.id == 0 { - println!("Starting HTTP server on shard: {}", self.id); - tasks.push(Box::pin(http_server::start( - self.config.http.clone(), - persister, - self.clone(), - ))); - } + // Create and spawn all tasks via the supervisor + register_tasks(&task_registry(), self.clone()); - if self.config.quic.enabled { - tasks.push(Box::pin(crate::quic::quic_server::span_quic_server( - self.clone(), - ))); - } - - tasks.push(Box::pin( - crate::channels::commands::clean_personal_access_tokens::clear_personal_access_tokens( - self.clone(), - ), - )); - // TOOD: Fixme, not always id 0 is the first shard. 
- if self.id == 0 { - tasks.push(Box::pin( - crate::channels::commands::print_sysinfo::print_sys_info(self.clone()), - )); - } - - tasks.push(Box::pin( - crate::channels::commands::verify_heartbeats::verify_heartbeats(self.clone()), - )); - tasks.push(Box::pin( - crate::channels::commands::save_messages::save_messages(self.clone()), - )); + // Create a oneshot channel for shutdown completion notification + let (shutdown_complete_tx, shutdown_complete_rx) = async_channel::bounded(1); let stop_receiver = self.get_stop_receiver(); let shard_for_shutdown = self.clone(); - /* + // Spawn shutdown handler - only this task consumes the stop signal compio::runtime::spawn(async move { let _ = stop_receiver.recv().await; - info!("Shard {} received shutdown signal", shard_for_shutdown.id); - let shutdown_success = shard_for_shutdown.trigger_shutdown().await; if !shutdown_success { shard_error!(shard_for_shutdown.id, "shutdown timed out"); - } else { - shard_info!(shard_for_shutdown.id, "shutdown completed successfully"); } - }); - */ + let _ = shutdown_complete_tx.send(()).await; + }) + .detach(); let elapsed = now.elapsed(); shard_info!(self.id, "Initialized in {} ms.", elapsed.as_millis()); - let result = try_join_all(tasks).await; - result?; + + shutdown_complete_rx.recv().await.ok(); Ok(()) } @@ -349,11 +323,19 @@ impl IggyShard { self.stop_receiver.clone() } + /// Get the task supervisor for the current thread + /// + /// # Panics + /// Panics if the task supervisor has not been initialized + pub fn task_registry() -> Rc<crate::shard::task_registry::TaskRegistry> { + task_registry() + } + #[instrument(skip_all, name = "trace_shutdown")] pub async fn trigger_shutdown(&self) -> bool { self.is_shutting_down.store(true, Ordering::SeqCst); - info!("Shard {} shutdown state set", self.id); - self.task_registry.shutdown_all(SHUTDOWN_TIMEOUT).await + debug!("Shard {} shutdown state set", self.id); + task_registry().graceful_shutdown(SHUTDOWN_TIMEOUT).await } pub fn 
get_available_shards_count(&self) -> u32 { diff --git a/core/server/src/shard/stats.rs b/core/server/src/shard/stats.rs deleted file mode 100644 index 8b137891..00000000 --- a/core/server/src/shard/stats.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/core/server/src/shard/system/messages.rs b/core/server/src/shard/system/messages.rs index a64602d0..e7f28b4d 100644 --- a/core/server/src/shard/system/messages.rs +++ b/core/server/src/shard/system/messages.rs @@ -16,8 +16,6 @@ * under the License. */ -use std::sync::atomic::Ordering; - use super::COMPONENT; use crate::binary::handlers::messages::poll_messages_handler::IggyPollMetadata; use crate::shard::IggyShard; @@ -34,10 +32,11 @@ use crate::streaming::traits::MainOps; use crate::streaming::utils::{PooledBuffer, hash}; use crate::streaming::{partitions, streams, topics}; use error_set::ErrContext; - use iggy_common::{ - BytesSerializable, Consumer, EncryptorKind, Identifier, IggyError, IggyTimestamp, Partitioning, PartitioningKind, PollingKind, PollingStrategy, IGGY_MESSAGE_HEADER_SIZE + BytesSerializable, Consumer, EncryptorKind, IGGY_MESSAGE_HEADER_SIZE, Identifier, IggyError, + IggyTimestamp, Partitioning, PartitioningKind, PollingKind, PollingStrategy, }; +use std::sync::atomic::Ordering; use tracing::{error, trace}; impl IggyShard { @@ -352,7 +351,6 @@ impl IggyShard { todo!(); } - async fn decrypt_messages( &self, batches: IggyMessagesBatchSet, diff --git a/core/server/src/shard/task_registry.rs b/core/server/src/shard/task_registry.rs deleted file mode 100644 index ff674a8e..00000000 --- a/core/server/src/shard/task_registry.rs +++ /dev/null @@ -1,108 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use async_channel::{Receiver, Sender, bounded}; -use compio::runtime::JoinHandle; -use futures::future::join_all; -use std::cell::RefCell; -use std::collections::HashMap; -use std::future::Future; -use std::time::Duration; -use tracing::{error, info, warn}; - -pub struct TaskRegistry { - tasks: RefCell<Vec<JoinHandle<()>>>, - active_connections: RefCell<HashMap<u32, Sender<()>>>, -} - -impl TaskRegistry { - pub fn new() -> Self { - Self { - tasks: RefCell::new(Vec::new()), - active_connections: RefCell::new(HashMap::new()), - } - } - - pub fn spawn_tracked<F>(&self, future: F) - where - F: Future<Output = ()> + 'static, - { - let handle = compio::runtime::spawn(future); - self.tasks.borrow_mut().push(handle); - } - - pub fn add_connection(&self, client_id: u32) -> Receiver<()> { - let (stop_sender, stop_receiver) = bounded(1); - self.active_connections - .borrow_mut() - .insert(client_id, stop_sender); - stop_receiver - } - - pub fn remove_connection(&self, client_id: &u32) { - self.active_connections.borrow_mut().remove(client_id); - } - - pub async fn shutdown_all(&self, timeout: Duration) -> bool { - info!("Initiating task registry shutdown"); - - let connections = self.active_connections.borrow(); - for (client_id, stop_sender) in connections.iter() { - info!("Sending shutdown signal to client {}", client_id); - if let Err(e) = stop_sender.send(()).await { - warn!( - 
"Failed to send shutdown signal to client {}: {}", - client_id, e - ); - } - } - drop(connections); - - let tasks = self.tasks.take(); - let total = tasks.len(); - - if total == 0 { - info!("No tasks to shut down"); - return true; - } - - let timeout_futures: Vec<_> = tasks - .into_iter() - .enumerate() - .map(|(idx, handle)| async move { - match compio::time::timeout(timeout, handle).await { - Ok(_) => (idx, true), - Err(_) => { - warn!("Task {} did not complete within timeout", idx); - (idx, false) - } - } - }) - .collect(); - - let results = join_all(timeout_futures).await; - let completed = results.iter().filter(|(_, success)| *success).count(); - - info!( - "Task registry shutdown complete. {} of {} tasks completed", - completed, total - ); - - completed == total - } -} diff --git a/core/server/src/shard/task_registry/builders.rs b/core/server/src/shard/task_registry/builders.rs new file mode 100644 index 00000000..062a6143 --- /dev/null +++ b/core/server/src/shard/task_registry/builders.rs @@ -0,0 +1,22 @@ +pub mod continuous; +pub mod oneshot; +pub mod periodic; + +use super::registry::TaskRegistry; + +impl TaskRegistry { + pub fn periodic(&self, name: &'static str) -> periodic::PeriodicBuilder<'_> { + periodic::PeriodicBuilder::new(self, name) + } + + pub fn continuous(&self, name: &'static str) -> continuous::ContinuousBuilder<'_> { + continuous::ContinuousBuilder::new(self, name) + } + + pub fn oneshot(&self, name: &'static str) -> oneshot::OneShotBuilder<'_> { + oneshot::OneShotBuilder::new(self, name) + } +} + +pub struct NoTask; +pub struct HasTask; diff --git a/core/server/src/shard/task_registry/builders/continuous.rs b/core/server/src/shard/task_registry/builders/continuous.rs new file mode 100644 index 00000000..73743548 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/continuous.rs @@ -0,0 +1,116 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{ + 
ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope, +}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct ContinuousBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + run: Option<Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> ContinuousBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + run: None, + _p: PhantomData, + } + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn run<F, Fut>(self, f: F) -> ContinuousBuilder<'a, HasTask> + where + F: FnOnce(TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + ContinuousBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, + run: Some(Box::new(move |ctx| Box::pin(f(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> ContinuousBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosureContinuous { + name: self.name, + scope: self.scope, + critical: self.critical, + run: self.run.expect("run required"), + }); + self.reg.spawn_continuous(shard, spec); + } +} + +struct ClosureContinuous { + name: &'static str, + scope: TaskScope, + critical: bool, + run: Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + 
+impl Debug for ClosureContinuous { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosureContinuous") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .finish() + } +} + +impl TaskMeta for ClosureContinuous { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl ContinuousTask for ClosureContinuous { + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + (self.run)(ctx) + } +} diff --git a/core/server/src/shard/task_registry/builders/oneshot.rs b/core/server/src/shard/task_registry/builders/oneshot.rs new file mode 100644 index 00000000..706bc243 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/oneshot.rs @@ -0,0 +1,128 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{OneShotTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc, time::Duration}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct OneShotBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + timeout: Option<Duration>, + run: Option<Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> OneShotBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + timeout: None, + run: None, + _p: PhantomData, + } + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn 
with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn timeout(mut self, d: Duration) -> Self { + self.timeout = Some(d); + self + } + + pub fn run<F, Fut>(self, f: F) -> OneShotBuilder<'a, HasTask> + where + F: FnOnce(TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + OneShotBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, + timeout: self.timeout, + run: Some(Box::new(move |ctx| Box::pin(f(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> OneShotBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosureOneShot { + name: self.name, + scope: self.scope, + critical: self.critical, + timeout: self.timeout, + run: self.run.expect("run required"), + }); + self.reg.spawn_oneshot(shard, spec); + } +} + +struct ClosureOneShot { + name: &'static str, + scope: TaskScope, + critical: bool, + timeout: Option<Duration>, + run: Box<dyn FnOnce(TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + +impl Debug for ClosureOneShot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosureOneShot") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .field("timeout", &self.timeout) + .finish() + } +} + +impl TaskMeta for ClosureOneShot { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl OneShotTask for ClosureOneShot { + fn run_once(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + (self.run)(ctx) + } + fn timeout(&self) -> Option<Duration> { + self.timeout + } +} diff --git a/core/server/src/shard/task_registry/builders/periodic.rs 
b/core/server/src/shard/task_registry/builders/periodic.rs new file mode 100644 index 00000000..598021e6 --- /dev/null +++ b/core/server/src/shard/task_registry/builders/periodic.rs @@ -0,0 +1,144 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::registry::TaskRegistry; +use crate::shard::task_registry::specs::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, marker::PhantomData, rc::Rc, time::Duration}; + +use crate::shard::task_registry::builders::{HasTask, NoTask}; + +pub struct PeriodicBuilder<'a, S = NoTask> { + reg: &'a TaskRegistry, + name: &'static str, + scope: TaskScope, + critical: bool, + shard: Option<Rc<IggyShard>>, + period: Option<Duration>, + last_on_shutdown: bool, + tick: Option<Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>>, + _p: PhantomData<S>, +} + +impl<'a> PeriodicBuilder<'a, NoTask> { + pub fn new(reg: &'a TaskRegistry, name: &'static str) -> Self { + Self { + reg, + name, + scope: TaskScope::AllShards, + critical: false, + shard: None, + period: None, + last_on_shutdown: false, + tick: None, + _p: PhantomData, + } + } + + pub fn every(mut self, d: Duration) -> Self { + self.period = Some(d); + self + } + + pub fn on_shard(mut self, scope: TaskScope) -> Self { + self.scope = scope; + self + } + + pub fn critical(mut self, c: bool) -> Self { + self.critical = c; + self + } + + pub fn with_shard(mut self, shard: Rc<IggyShard>) -> Self { + self.shard = Some(shard); + self + } + + pub fn last_tick_on_shutdown(mut self, v: bool) -> Self { + self.last_on_shutdown = v; + self + } + + pub fn tick<F, Fut>(self, f: F) -> PeriodicBuilder<'a, HasTask> + where + F: FnMut(&TaskCtx) -> Fut + 'static, + Fut: std::future::Future<Output = Result<(), IggyError>> + 'static, + { + let mut g = f; + PeriodicBuilder { + reg: self.reg, + name: self.name, + scope: self.scope, + critical: self.critical, + shard: self.shard, 
+ period: self.period, + last_on_shutdown: self.last_on_shutdown, + tick: Some(Box::new(move |ctx| Box::pin(g(ctx)))), + _p: PhantomData, + } + } +} + +impl<'a> PeriodicBuilder<'a, HasTask> { + pub fn spawn(self) { + let shard = self.shard.expect("shard required"); + let period = self.period.expect("period required"); + if !self.scope.should_run(&shard) { + return; + } + let spec = Box::new(ClosurePeriodic { + name: self.name, + scope: self.scope, + critical: self.critical, + period, + last_on_shutdown: self.last_on_shutdown, + tick: self.tick.expect("tick required"), + }); + self.reg.spawn_periodic(shard, spec); + } +} + +struct ClosurePeriodic { + name: &'static str, + scope: TaskScope, + critical: bool, + period: Duration, + last_on_shutdown: bool, + tick: Box<dyn FnMut(&TaskCtx) -> LocalBoxFuture<'static, Result<(), IggyError>>>, +} + +impl Debug for ClosurePeriodic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClosurePeriodic") + .field("name", &self.name) + .field("scope", &self.scope) + .field("critical", &self.critical) + .field("period", &self.period) + .field("last_on_shutdown", &self.last_on_shutdown) + .finish() + } +} + +impl TaskMeta for ClosurePeriodic { + fn name(&self) -> &'static str { + self.name + } + fn scope(&self) -> TaskScope { + self.scope.clone() + } + fn is_critical(&self) -> bool { + self.critical + } +} + +impl PeriodicTask for ClosurePeriodic { + fn period(&self) -> Duration { + self.period + } + fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture { + (self.tick)(ctx) + } + fn last_tick_on_shutdown(&self) -> bool { + self.last_on_shutdown + } +} diff --git a/core/server/src/shard/task_registry/mod.rs b/core/server/src/shard/task_registry/mod.rs new file mode 100644 index 00000000..7fad6ce5 --- /dev/null +++ b/core/server/src/shard/task_registry/mod.rs @@ -0,0 +1,12 @@ +pub mod builders; +pub mod registry; +pub mod shutdown; +pub mod specs; +pub mod tls; + +pub use registry::TaskRegistry; +pub 
use shutdown::{Shutdown, ShutdownToken}; +pub use specs::{ + ContinuousTask, OneShotTask, PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskResult, TaskScope, +}; +pub use tls::{init_task_registry, is_registry_initialized, task_registry}; diff --git a/core/server/src/shard/task_registry/registry.rs b/core/server/src/shard/task_registry/registry.rs new file mode 100644 index 00000000..9accce5b --- /dev/null +++ b/core/server/src/shard/task_registry/registry.rs @@ -0,0 +1,550 @@ +use super::shutdown::{Shutdown, ShutdownToken}; +use super::specs::{ + ContinuousTask, OneShotTask, PeriodicTask, TaskCtx, TaskMeta, TaskResult, TaskScope, +}; +use crate::shard::IggyShard; +use compio::runtime::JoinHandle; +use futures::future::join_all; +use iggy_common::IggyError; +use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration}; +use tracing::{debug, error, trace, warn}; + +enum Kind { + Continuous, + Periodic(Duration), + OneShot, +} + +struct TaskHandle { + name: String, + kind: Kind, + handle: JoinHandle<TaskResult>, + critical: bool, +} + +pub struct TaskRegistry { + pub(crate) shard_id: u16, + shutdown: Shutdown, + shutdown_token: ShutdownToken, + long_running: RefCell<Vec<TaskHandle>>, + oneshots: RefCell<Vec<TaskHandle>>, + connections: RefCell<HashMap<u32, async_channel::Sender<()>>>, + shutting_down: RefCell<bool>, +} + +impl TaskRegistry { + pub fn new(shard_id: u16) -> Self { + let (s, t) = Shutdown::new(); + Self { + shard_id, + shutdown: s, + shutdown_token: t, + long_running: RefCell::new(vec![]), + oneshots: RefCell::new(vec![]), + connections: RefCell::new(HashMap::new()), + shutting_down: RefCell::new(false), + } + } + + pub fn shutdown_token(&self) -> ShutdownToken { + self.shutdown_token.clone() + } + + pub fn spawn_continuous(&self, shard: Rc<IggyShard>, mut task: Box<dyn ContinuousTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn continuous task '{}' during shutdown", + task.name() + ); + return; + } + if 
!task.scope().should_run(&shard) { + return; + } + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!("continuous '{}' starting on shard {}", name, shard_id); + let r = task.run(ctx).await; + match &r { + Ok(()) => debug!("continuous '{}' completed on shard {}", name, shard_id), + Err(e) => error!("continuous '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.long_running.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::Continuous, + handle, + critical: is_critical, + }); + } + + pub fn spawn_periodic(&self, shard: Rc<IggyShard>, mut task: Box<dyn PeriodicTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn periodic task '{}' during shutdown", + task.name() + ); + return; + } + if !task.scope().should_run(&shard) { + return; + } + let period = task.period(); + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shutdown = self.shutdown_token.clone(); + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!( + "periodic '{}' every {:?} on shard {}", + name, period, shard_id + ); + loop { + if !shutdown.sleep_or_shutdown(period).await { + break; + } + if let Err(e) = task.tick(&ctx).await { + error!( + "periodic '{}' tick failed on shard {}: {}", + name, shard_id, e + ); + } + } + if task.last_tick_on_shutdown() { + let _ = task.tick(&ctx).await; + } + Ok(()) + }); + + self.long_running.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::Periodic(period), + handle, + critical: is_critical, + }); + } + + pub fn spawn_oneshot(&self, shard: Rc<IggyShard>, mut task: Box<dyn OneShotTask>) { + if *self.shutting_down.borrow() { + warn!( + "Attempted to 
spawn oneshot task '{}' during shutdown", + task.name() + ); + return; + } + if !task.scope().should_run(&shard) { + return; + } + task.on_start(); + let name = task.name(); + let is_critical = task.is_critical(); + let timeout = task.timeout(); + let ctx = TaskCtx { + shard, + shutdown: self.shutdown_token.clone(), + }; + let shard_id = self.shard_id; + + let handle = compio::runtime::spawn(async move { + trace!("oneshot '{}' starting on shard {}", name, shard_id); + let fut = task.run_once(ctx); + let r = if let Some(d) = timeout { + match compio::time::timeout(d, fut).await { + Ok(r) => r, + Err(_) => Err(IggyError::TaskTimeout), + } + } else { + fut.await + }; + match &r { + Ok(()) => trace!("oneshot '{}' completed on shard {}", name, shard_id), + Err(e) => error!("oneshot '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.oneshots.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::OneShot, + handle, + critical: is_critical, + }); + } + + pub async fn graceful_shutdown(&self, timeout: Duration) -> bool { + use std::time::Instant; + + let start = Instant::now(); + *self.shutting_down.borrow_mut() = true; + self.shutdown_connections(); + self.shutdown.trigger(); + + // First shutdown long-running tasks (continuous and periodic) + let long = self.long_running.take(); + let long_ok = if !long.is_empty() { + debug!( + "Shutting down {} long-running task(s) on shard {}", + long.len(), + self.shard_id + ); + self.await_with_timeout(long, timeout).await + } else { + true + }; + + // Calculate remaining time for oneshots + let elapsed = start.elapsed(); + let remaining = timeout.saturating_sub(elapsed); + + // Then shutdown oneshot tasks with remaining time + let ones = self.oneshots.take(); + let ones_ok = if !ones.is_empty() { + if remaining.is_zero() { + warn!( + "No time remaining for {} oneshot task(s) on shard {}, they will be cancelled", + ones.len(), + self.shard_id + ); + false + } else { + debug!( + "Shutting down {} 
oneshot task(s) on shard {} with {:?} remaining", + ones.len(), + self.shard_id, + remaining + ); + self.await_with_timeout(ones, remaining).await + } + } else { + true + }; + + let total_elapsed = start.elapsed(); + if long_ok && ones_ok { + debug!( + "Graceful shutdown completed successfully on shard {} in {:?}", + self.shard_id, total_elapsed + ); + } else { + warn!( + "Graceful shutdown completed with failures on shard {} in {:?}", + self.shard_id, total_elapsed + ); + } + + long_ok && ones_ok + } + + async fn await_with_timeout(&self, tasks: Vec<TaskHandle>, timeout: Duration) -> bool { + if tasks.is_empty() { + return true; + } + let results = join_all(tasks.into_iter().map(|t| async move { + match compio::time::timeout(timeout, t.handle).await { + Ok(Ok(Ok(()))) => true, + Ok(Ok(Err(e))) => { + error!("task '{}' failed: {}", t.name, e); + !t.critical + } + Ok(Err(_)) => { + error!("task '{}' panicked", t.name); + !t.critical + } + Err(_) => { + error!("task '{}' timed out after {:?}", t.name, timeout); + !t.critical + } + } + })) + .await; + + results.into_iter().all(|x| x) + } + + async fn await_all(&self, tasks: Vec<TaskHandle>) -> bool { + if tasks.is_empty() { + return true; + } + let results = join_all(tasks.into_iter().map(|t| async move { + match t.handle.await { + Ok(Ok(())) => true, + Ok(Err(e)) => { + error!("task '{}' failed: {}", t.name, e); + !t.critical + } + Err(_) => { + error!("task '{}' panicked", t.name); + !t.critical + } + } + })) + .await; + results.into_iter().all(|x| x) + } + + pub fn add_connection(&self, client_id: u32) -> async_channel::Receiver<()> { + let (tx, rx) = async_channel::bounded(1); + self.connections.borrow_mut().insert(client_id, tx); + rx + } + + pub fn remove_connection(&self, client_id: &u32) { + self.connections.borrow_mut().remove(client_id); + } + + fn shutdown_connections(&self) { + for tx in self.connections.borrow().values() { + let _ = tx.send_blocking(()); + } + } + + pub fn spawn_tracked<F>(&self, future: 
F) + where + F: futures::Future<Output = ()> + 'static, + { + let handle = compio::runtime::spawn(async move { + future.await; + Ok(()) + }); + self.long_running.borrow_mut().push(TaskHandle { + name: "connection_handler".to_string(), + kind: Kind::Continuous, + handle, + critical: false, + }); + } + + pub fn spawn_oneshot_future<F>(&self, name: &'static str, critical: bool, f: F) + where + F: futures::Future<Output = Result<(), IggyError>> + 'static, + { + if *self.shutting_down.borrow() { + warn!( + "Attempted to spawn oneshot future '{}' during shutdown", + name + ); + return; + } + let shard_id = self.shard_id; + let handle = compio::runtime::spawn(async move { + trace!("oneshot '{}' starting on shard {}", name, shard_id); + let r = f.await; + match &r { + Ok(()) => trace!("oneshot '{}' completed on shard {}", name, shard_id), + Err(e) => error!("oneshot '{}' failed on shard {}: {}", name, shard_id, e), + } + r + }); + + self.oneshots.borrow_mut().push(TaskHandle { + name: name.into(), + kind: Kind::OneShot, + handle, + critical, + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::shard::task_registry::specs::{OneShotTask, TaskCtx, TaskFuture, TaskMeta}; + use std::fmt::Debug; + + #[derive(Debug)] + struct TestOneShotTask { + should_fail: bool, + is_critical: bool, + } + + impl TaskMeta for TestOneShotTask { + fn name(&self) -> &'static str { + "test_oneshot" + } + + fn is_critical(&self) -> bool { + self.is_critical + } + } + + impl OneShotTask for TestOneShotTask { + fn run_once(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + if self.should_fail { + Err(IggyError::Error) + } else { + Ok(()) + } + }) + } + + fn timeout(&self) -> Option<Duration> { + Some(Duration::from_millis(100)) + } + } + + #[compio::test] + async fn test_oneshot_completion_detection() { + let registry = TaskRegistry::new(1); + + // Spawn a failing non-critical task + registry.spawn_oneshot_future("failing_non_critical", false, async { + 
Err(IggyError::Error) + }); + + // Spawn a successful task + registry.spawn_oneshot_future("successful", false, async { Ok(()) }); + + // Wait for all tasks + let all_ok = registry.await_all(registry.oneshots.take()).await; + + // Should return true because the failing task is not critical + assert!(all_ok); + } + + #[compio::test] + async fn test_oneshot_critical_failure() { + let registry = TaskRegistry::new(1); + + // Spawn a failing critical task + registry.spawn_oneshot_future("failing_critical", true, async { Err(IggyError::Error) }); + + // Wait for all tasks + let all_ok = registry.await_all(registry.oneshots.take()).await; + + // Should return false because the failing task is critical + assert!(!all_ok); + } + + #[compio::test] + async fn test_shutdown_prevents_spawning() { + let registry = TaskRegistry::new(1); + + // Trigger shutdown + *registry.shutting_down.borrow_mut() = true; + + let initial_count = registry.oneshots.borrow().len(); + + // Try to spawn after shutdown + registry.spawn_oneshot_future("should_not_spawn", false, async { Ok(()) }); + + // Task should not be added + assert_eq!(registry.oneshots.borrow().len(), initial_count); + } + + #[compio::test] + async fn test_timeout_error() { + let registry = TaskRegistry::new(1); + + // Create a task that will timeout + let handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_secs(10)).await; + Ok(()) + }); + + let task_handle = TaskHandle { + name: "timeout_test".to_string(), + kind: Kind::OneShot, + handle, + critical: false, + }; + + let tasks = vec![task_handle]; + let all_ok = registry + .await_with_timeout(tasks, Duration::from_millis(50)) + .await; + + // Should return true because the task is not critical + assert!(all_ok); + } + + #[compio::test] + async fn test_composite_timeout() { + let registry = TaskRegistry::new(1); + + // Create a long-running task that takes 100ms + let long_handle = compio::runtime::spawn(async move { + 
compio::time::sleep(Duration::from_millis(100)).await; + Ok(()) + }); + + registry.long_running.borrow_mut().push(TaskHandle { + name: "long_task".to_string(), + kind: Kind::Continuous, + handle: long_handle, + critical: false, + }); + + // Create a oneshot that would succeed quickly + let oneshot_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(10)).await; + Ok(()) + }); + + registry.oneshots.borrow_mut().push(TaskHandle { + name: "quick_oneshot".to_string(), + kind: Kind::OneShot, + handle: oneshot_handle, + critical: false, + }); + + // Give total timeout of 150ms + // Long-running should complete in ~100ms + // Oneshot should have ~50ms remaining, which is enough + let all_ok = registry.graceful_shutdown(Duration::from_millis(150)).await; + assert!(all_ok); + } + + #[compio::test] + async fn test_composite_timeout_insufficient() { + let registry = TaskRegistry::new(1); + + // Create a long-running task that takes 50ms + let long_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(50)).await; + Ok(()) + }); + + registry.long_running.borrow_mut().push(TaskHandle { + name: "long_task".to_string(), + kind: Kind::Continuous, + handle: long_handle, + critical: false, + }); + + // Create a oneshot that would take 100ms (much longer) + let oneshot_handle = compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(100)).await; + Ok(()) + }); + + registry.oneshots.borrow_mut().push(TaskHandle { + name: "slow_oneshot".to_string(), + kind: Kind::OneShot, + handle: oneshot_handle, + critical: true, // Make it critical so failure is detected + }); + + // Give total timeout of 60ms + // Long-running should complete in ~50ms + // Oneshot would need 100ms but only has ~10ms, so it should definitely fail + let all_ok = registry.graceful_shutdown(Duration::from_millis(60)).await; + assert!(!all_ok); // Should fail because critical oneshot times out + } +} diff --git 
a/core/server/src/shard/task_registry/shutdown.rs b/core/server/src/shard/task_registry/shutdown.rs new file mode 100644 index 00000000..fb57b5a9 --- /dev/null +++ b/core/server/src/shard/task_registry/shutdown.rs @@ -0,0 +1,233 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use async_channel::{Receiver, Sender, bounded}; +use futures::FutureExt; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; +use tracing::trace; + +/// Coordinates graceful shutdown across multiple tasks +#[derive(Clone)] +pub struct Shutdown { + sender: Sender<()>, + is_triggered: Arc<AtomicBool>, +} + +impl Shutdown { + pub fn new() -> (Self, ShutdownToken) { + let (sender, receiver) = bounded(1); + let is_triggered = Arc::new(AtomicBool::new(false)); + + let shutdown = Self { + sender, + is_triggered: is_triggered.clone(), + }; + + let token = ShutdownToken { + receiver, + is_triggered, + }; + + (shutdown, token) + } + + pub fn trigger(&self) { + if self.is_triggered.swap(true, Ordering::SeqCst) { + return; + } + + trace!("Triggering shutdown signal"); + let _ = self.sender.close(); + } + + pub fn is_triggered(&self) -> bool { + self.is_triggered.load(Ordering::Relaxed) + } +} + +/// Token held by tasks to receive shutdown signals +#[derive(Clone)] +pub struct ShutdownToken { + receiver: Receiver<()>, + is_triggered: Arc<AtomicBool>, +} + +impl ShutdownToken { + /// Wait for shutdown signal + pub async fn wait(&self) { + let _ = self.receiver.recv().await; + } + + /// Check if shutdown has been triggered (non-blocking) + pub fn is_triggered(&self) -> bool { + self.is_triggered.load(Ordering::Relaxed) + } + + /// Sleep for the specified duration or until shutdown is triggered + /// Returns true if the full duration elapsed, false if shutdown was triggered + pub async fn sleep_or_shutdown(&self, duration: Duration) -> bool { + futures::select! { + _ = self.wait().fuse() => false, + _ = compio::time::sleep(duration).fuse() => !self.is_triggered(), + } + } + + /// Creates a scoped shutdown pair (child `Shutdown`, combined `ShutdownToken`). + /// + /// This is a bit complicated, but it needs to be this way to avoid deadlocks. 
+ /// + /// The returned token fires when EITHER the parent or the child is triggered, + /// while a child trigger does NOT propagate back to the parent. + /// Internally spawns a tiny forwarder to merge both signals into one channel, + /// so callers can await a single `wait()` and use fast `is_triggered()` checks + /// without writing `select!` at every call site. + /// Use when a subtree needs cancelation that respects parent cancelation, + /// but can also be canceled locally. + pub fn child(&self) -> (Shutdown, ShutdownToken) { + let (child_shutdown, child_token) = Shutdown::new(); + let parent_receiver = self.receiver.clone(); + let child_receiver = child_token.receiver.clone(); + + let (combined_sender, combined_receiver) = bounded(1); + let combined_is_triggered = Arc::new(AtomicBool::new(false)); + + let parent_triggered = self.is_triggered.clone(); + let child_triggered = child_token.is_triggered.clone(); + let combined_flag_for_task = combined_is_triggered.clone(); + + compio::runtime::spawn(async move { + futures::select! 
{ + _ = parent_receiver.recv().fuse() => { + trace!("Child token triggered by parent shutdown"); + }, + _ = child_receiver.recv().fuse() => { + trace!("Child token triggered by child shutdown"); + }, + } + + if parent_triggered.load(Ordering::Relaxed) || child_triggered.load(Ordering::Relaxed) { + combined_flag_for_task.store(true, Ordering::SeqCst); + } + + let _ = combined_sender.close(); + }) + .detach(); + + let combined_token = ShutdownToken { + receiver: combined_receiver, + is_triggered: combined_is_triggered, + }; + + (child_shutdown, combined_token) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[compio::test] + async fn test_shutdown_trigger() { + let (shutdown, token) = Shutdown::new(); + + assert!(!token.is_triggered()); + + shutdown.trigger(); + + assert!(token.is_triggered()); + + token.wait().await; + } + + #[compio::test] + async fn test_sleep_or_shutdown_completes() { + let (_shutdown, token) = Shutdown::new(); + + let completed = token.sleep_or_shutdown(Duration::from_millis(10)).await; + assert!(completed); + } + + #[compio::test] + async fn test_sleep_or_shutdown_interrupted() { + let (shutdown, token) = Shutdown::new(); + + // Trigger shutdown after a short delay + let shutdown_clone = shutdown.clone(); + compio::runtime::spawn(async move { + compio::time::sleep(Duration::from_millis(10)).await; + shutdown_clone.trigger(); + }) + .detach(); + + // Should be interrupted + let completed = token.sleep_or_shutdown(Duration::from_secs(10)).await; + assert!(!completed); + } + + #[compio::test] + async fn test_child_token_parent_trigger() { + let (parent_shutdown, parent_token) = Shutdown::new(); + let (_child_shutdown, combined_token) = parent_token.child(); + + assert!(!combined_token.is_triggered()); + + // Trigger parent shutdown + parent_shutdown.trigger(); + + // Combined token should be triggered + combined_token.wait().await; + assert!(combined_token.is_triggered()); + } + + #[compio::test] + async fn test_child_token_child_trigger() 
{ + let (_parent_shutdown, parent_token) = Shutdown::new(); + let (child_shutdown, combined_token) = parent_token.child(); + + assert!(!combined_token.is_triggered()); + + // Trigger child shutdown + child_shutdown.trigger(); + + // Combined token should be triggered + combined_token.wait().await; + assert!(combined_token.is_triggered()); + } + + #[compio::test] + async fn test_child_token_no_polling_overhead() { + let (_parent_shutdown, parent_token) = Shutdown::new(); + let (_child_shutdown, combined_token) = parent_token.child(); + + // Test that we can create many child tokens without performance issues + let start = std::time::Instant::now(); + for _ in 0..100 { + let _ = combined_token.child(); + } + let elapsed = start.elapsed(); + + // Should complete very quickly since there's no polling + assert!( + elapsed.as_millis() < 100, + "Creating child tokens took too long: {:?}", + elapsed + ); + } +} diff --git a/core/server/src/shard/task_registry/specs.rs b/core/server/src/shard/task_registry/specs.rs new file mode 100644 index 00000000..15ea0e55 --- /dev/null +++ b/core/server/src/shard/task_registry/specs.rs @@ -0,0 +1,59 @@ +use crate::shard::IggyShard; +use crate::shard::task_registry::ShutdownToken; +use futures::future::LocalBoxFuture; +use iggy_common::IggyError; +use std::{fmt::Debug, rc::Rc, time::Duration}; + +pub type TaskResult = Result<(), IggyError>; +pub type TaskFuture = LocalBoxFuture<'static, TaskResult>; + +#[derive(Clone, Debug)] +pub enum TaskScope { + AllShards, + SpecificShard(u16), +} + +impl TaskScope { + pub fn should_run(&self, shard: &IggyShard) -> bool { + match self { + TaskScope::AllShards => true, + TaskScope::SpecificShard(id) => shard.id == *id, + } + } +} + +#[derive(Clone)] +pub struct TaskCtx { + pub shard: Rc<IggyShard>, + pub shutdown: ShutdownToken, +} + +pub trait TaskMeta: 'static + Debug { + fn name(&self) -> &'static str; + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + fn is_critical(&self) -> bool { + 
false + } + fn on_start(&self) {} +} + +pub trait ContinuousTask: TaskMeta { + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; +} + +pub trait PeriodicTask: TaskMeta { + fn period(&self) -> Duration; + fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture; + fn last_tick_on_shutdown(&self) -> bool { + false + } +} + +pub trait OneShotTask: TaskMeta { + fn run_once(self: Box<Self>, ctx: TaskCtx) -> TaskFuture; + fn timeout(&self) -> Option<Duration> { + None + } +} diff --git a/core/server/src/shard/task_registry/tls.rs b/core/server/src/shard/task_registry/tls.rs new file mode 100644 index 00000000..798f9f6f --- /dev/null +++ b/core/server/src/shard/task_registry/tls.rs @@ -0,0 +1,104 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::registry::TaskRegistry; +use std::cell::RefCell; +use std::rc::Rc; + +thread_local! { + static REGISTRY: RefCell<Option<Rc<TaskRegistry>>> = RefCell::new(None); +} + +pub fn init_task_registry(shard_id: u16) { + REGISTRY.with(|s| { + *s.borrow_mut() = Some(Rc::new(TaskRegistry::new(shard_id))); + }); +} + +pub fn task_registry() -> Rc<TaskRegistry> { + REGISTRY.with(|s| { + s.borrow() + .as_ref() + .expect("Task registry not initialized for this thread. 
Call init_registry() first.") + .clone() + }) +} + +pub fn is_registry_initialized() -> bool { + REGISTRY.with(|s| s.borrow().is_some()) +} + +pub fn clear_registry() { + REGISTRY.with(|s| { + *s.borrow_mut() = None; + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use std::panic; + + #[test] + fn test_registry_initialization() { + clear_registry(); + assert!(!is_registry_initialized()); + + let result = panic::catch_unwind(|| { + task_registry(); + }); + assert!(result.is_err()); + + init_task_registry(42); + assert!(is_registry_initialized()); + + clear_registry(); + assert!(!is_registry_initialized()); + } + + #[test] + fn test_multiple_initializations() { + clear_registry(); + init_task_registry(1); + let _reg1 = task_registry(); + init_task_registry(2); + let _reg2 = task_registry(); + clear_registry(); + } + + #[test] + fn test_thread_locality() { + use std::thread; + + clear_registry(); + init_task_registry(100); + assert!(is_registry_initialized()); + + let handle = thread::spawn(|| { + assert!(!is_registry_initialized()); + init_task_registry(200); + assert!(is_registry_initialized()); + let _ = task_registry(); + }); + + handle.join().expect("Thread should complete successfully"); + assert!(is_registry_initialized()); + let _ = task_registry(); + clear_registry(); + } +} diff --git a/core/server/src/shard/tasks/continuous/http_server.rs b/core/server/src/shard/tasks/continuous/http_server.rs new file mode 100644 index 00000000..81dde7d3 --- /dev/null +++ b/core/server/src/shard/tasks/continuous/http_server.rs @@ -0,0 +1,68 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::bootstrap::resolve_persister; +use crate::http::http_server; +use crate::shard::IggyShard; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use std::fmt::Debug; +use std::rc::Rc; +use tracing::info; + +pub struct HttpServer { + shard: Rc<IggyShard>, +} + +impl Debug for HttpServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl HttpServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskMeta for HttpServer { + fn name(&self) -> &'static str { + "http_server" + } + + fn scope(&self) -> TaskScope { + TaskScope::SpecificShard(0) + } + + fn is_critical(&self) -> bool { + false + } +} + +impl ContinuousTask for HttpServer { + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); + Box::pin(async move { + info!("Starting HTTP server on shard: {}", shard.id); + let persister = resolve_persister(shard.config.system.partition.enforce_fsync); + http_server::start(shard.config.http.clone(), persister, shard).await + }) + } +} diff --git a/core/server/src/shard/tasks/continuous/message_pump.rs b/core/server/src/shard/tasks/continuous/message_pump.rs new file mode 100644 index 00000000..b0340e1b --- /dev/null +++ b/core/server/src/shard/tasks/continuous/message_pump.rs @@ -0,0 +1,99 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use crate::shard::transmission::frame::ShardFrame; +use crate::{shard_debug, shard_info}; +use futures::{FutureExt, StreamExt}; +use std::fmt::Debug; +use std::rc::Rc; + +pub struct MessagePump { + shard: Rc<IggyShard>, +} + +impl Debug for MessagePump { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessagePump") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl MessagePump { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskMeta for MessagePump { + fn name(&self) -> &'static str { + "message_pump" + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + true + } +} + +impl ContinuousTask for MessagePump { + fn run(self: Box<Self>, ctx: TaskCtx) -> TaskFuture { + Box::pin(async move { + let Some(mut messages_receiver) = self.shard.messages_receiver.take() else { + shard_info!( + self.shard.id, + "Message receiver already taken; pump not started" + ); + return Ok(()); + }; + + shard_info!(self.shard.id, "Starting message passing task"); + + loop { + futures::select! 
{ + _ = ctx.shutdown.wait().fuse() => { + shard_debug!(self.shard.id, "Message receiver shutting down"); + break; + } + frame = messages_receiver.next().fuse() => { + match frame { + Some(ShardFrame { message, response_sender }) => { + if let (Some(response), Some(tx)) = + (self.shard.handle_shard_message(message).await, response_sender) + { + let _ = tx.send(response).await; + } + } + None => { + shard_debug!(self.shard.id, "Message receiver closed; exiting pump"); + break; + } + } + } + } + } + + Ok(()) + }) + } +} diff --git a/core/server/src/channels/commands/mod.rs b/core/server/src/shard/tasks/continuous/mod.rs similarity index 78% rename from core/server/src/channels/commands/mod.rs rename to core/server/src/shard/tasks/continuous/mod.rs index 0130170b..59ed6293 100644 --- a/core/server/src/channels/commands/mod.rs +++ b/core/server/src/shard/tasks/continuous/mod.rs @@ -16,7 +16,12 @@ * under the License. */ -pub mod clean_personal_access_tokens; -pub mod print_sysinfo; -pub mod save_messages; -pub mod verify_heartbeats; +pub mod http_server; +pub mod message_pump; +pub mod quic_server; +pub mod tcp_server; + +pub use http_server::HttpServer; +pub use message_pump::MessagePump; +pub use quic_server::QuicServer; +pub use tcp_server::TcpServer; diff --git a/core/server/src/shard/tasks/continuous/quic_server.rs b/core/server/src/shard/tasks/continuous/quic_server.rs new file mode 100644 index 00000000..81cd2fcd --- /dev/null +++ b/core/server/src/shard/tasks/continuous/quic_server.rs @@ -0,0 +1,62 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::quic::quic_server; +use crate::shard::IggyShard; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use std::fmt::Debug; +use std::rc::Rc; + +pub struct QuicServer { + shard: Rc<IggyShard>, +} + +impl Debug for QuicServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QuicServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl QuicServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskMeta for QuicServer { + fn name(&self) -> &'static str { + "quic_server" + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + true + } +} + +impl ContinuousTask for QuicServer { + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); + Box::pin(async move { quic_server::span_quic_server(shard).await }) + } +} diff --git a/core/server/src/shard/tasks/continuous/tcp_server.rs b/core/server/src/shard/tasks/continuous/tcp_server.rs new file mode 100644 index 00000000..f2bcda62 --- /dev/null +++ b/core/server/src/shard/tasks/continuous/tcp_server.rs @@ -0,0 +1,62 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::shard::IggyShard; +use crate::shard::task_registry::{ContinuousTask, TaskCtx, TaskFuture, TaskMeta, TaskScope}; +use crate::tcp::tcp_server::spawn_tcp_server; +use std::fmt::Debug; +use std::rc::Rc; + +pub struct TcpServer { + shard: Rc<IggyShard>, +} + +impl Debug for TcpServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TcpServer") + .field("shard_id", &self.shard.id) + .finish() + } +} + +impl TcpServer { + pub fn new(shard: Rc<IggyShard>) -> Self { + Self { shard } + } +} + +impl TaskMeta for TcpServer { + fn name(&self) -> &'static str { + "tcp_server" + } + + fn scope(&self) -> TaskScope { + TaskScope::AllShards + } + + fn is_critical(&self) -> bool { + true + } +} + +impl ContinuousTask for TcpServer { + fn run(self: Box<Self>, _ctx: TaskCtx) -> TaskFuture { + let shard = self.shard.clone(); + Box::pin(async move { spawn_tcp_server(shard).await }) + } +} diff --git a/core/server/src/shard/tasks/messages.rs b/core/server/src/shard/tasks/messages.rs deleted file mode 100644 index 2feadd1c..00000000 --- a/core/server/src/shard/tasks/messages.rs +++ /dev/null @@ -1,59 +0,0 @@ -use futures::{FutureExt, StreamExt}; -use iggy_common::IggyError; -use std::{rc::Rc, time::Duration}; - -use crate::{ - shard::{IggyShard, transmission::frame::ShardFrame}, - shard_error, shard_info, -}; - -async fn 
run_shard_messages_receiver(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let mut messages_receiver = shard.messages_receiver.take().unwrap(); - - shard_info!(shard.id, "Starting message passing task"); - loop { - let shutdown_check = async { - loop { - if shard.is_shutting_down() { - return; - } - compio::time::sleep(Duration::from_millis(100)).await; - } - }; - - futures::select! { - _ = shutdown_check.fuse() => { - shard_info!(shard.id, "Message receiver shutting down"); - break; - } - frame = messages_receiver.next().fuse() => { - if let Some(frame) = frame { - let ShardFrame { - message, - response_sender, - } = frame; - match (shard.handle_shard_message(message).await, response_sender) { - (Some(response), Some(response_sender)) => { - response_sender - .send(response) - .await - .expect("Failed to send response back to origin shard."); - } - _ => {} - }; - } - } - } - } - - Ok(()) -} - -pub async fn spawn_shard_message_task(shard: Rc<IggyShard>) -> Result<(), IggyError> { - let result = run_shard_messages_receiver(shard.clone()).await; - if let Err(err) = result { - shard_error!(shard.id, "Error running shard message receiver: {err}"); - return Err(err); - } - Ok(()) -} diff --git a/core/server/src/shard/tasks/mod.rs b/core/server/src/shard/tasks/mod.rs index ba63992f..6863f17f 100644 --- a/core/server/src/shard/tasks/mod.rs +++ b/core/server/src/shard/tasks/mod.rs @@ -1 +1,104 @@ -pub mod messages; +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+pub mod continuous;
+pub mod periodic;
+
+use crate::shard::IggyShard;
+use crate::shard::task_registry::TaskRegistry;
+use std::rc::Rc;
+
+pub fn register_tasks(reg: &TaskRegistry, shard: Rc<IggyShard>) {
+    reg.spawn_continuous(
+        shard.clone(),
+        Box::new(continuous::MessagePump::new(shard.clone())),
+    );
+
+    if shard.config.tcp.enabled {
+        reg.spawn_continuous(
+            shard.clone(),
+            Box::new(continuous::TcpServer::new(shard.clone())),
+        );
+    }
+
+    if shard.config.http.enabled {
+        reg.spawn_continuous(
+            shard.clone(),
+            Box::new(continuous::HttpServer::new(shard.clone())),
+        );
+    }
+
+    if shard.config.quic.enabled {
+        reg.spawn_continuous(
+            shard.clone(),
+            Box::new(continuous::QuicServer::new(shard.clone())),
+        );
+    }
+
+    if shard.config.message_saver.enabled {
+        let period = shard.config.message_saver.interval.get_duration();
+        reg.spawn_periodic(
+            shard.clone(),
+            Box::new(periodic::SaveMessages::new(shard.clone(), period)),
+        );
+    }
+
+    if shard.config.heartbeat.enabled {
+        let period = shard.config.heartbeat.interval.get_duration();
+        reg.spawn_periodic(
+            shard.clone(),
+            Box::new(periodic::VerifyHeartbeats::new(shard.clone(), period)),
+        );
+    }
+
+    if shard.config.personal_access_token.cleaner.enabled {
+        let period = shard
+            .config
+            .personal_access_token
+            .cleaner
+            .interval
+            .get_duration();
+        reg.spawn_periodic(
+            shard.clone(),
+            Box::new(periodic::ClearPersonalAccessTokens::new(
+                shard.clone(),
+                period,
+            )),
+        );
+    }
+
+    if shard
+        .config
+        .system
+        .logging
+        .sysinfo_print_interval
+        .as_micros()
+        > 0
+    {
+        let period = shard
+            .config
+            .system
+            .logging
+            .sysinfo_print_interval
+            .get_duration();
+        reg.spawn_periodic(
+            shard.clone(),
+            Box::new(periodic::PrintSysinfo::new(shard.clone(), period)),
+        );
+    }
+}
diff --git a/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs b/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs
new file mode 100644
index 00000000..cfb72ac2
--- /dev/null
+++ b/core/server/src/shard/tasks/periodic/clear_jwt_tokens.rs
@@ -0,0 +1,92 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::http::shared::AppState;
+use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope};
+use iggy_common::IggyTimestamp;
+use std::fmt::Debug;
+use std::sync::Arc;
+use std::time::Duration;
+use tracing::{error, info, trace};
+
+pub struct ClearJwtTokens {
+    app_state: Arc<AppState>,
+    period: Duration,
+}
+
+impl Debug for ClearJwtTokens {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ClearJwtTokens")
+            .field("period", &self.period)
+            .finish()
+    }
+}
+
+impl ClearJwtTokens {
+    pub fn new(app_state: Arc<AppState>, period: Duration) -> Self {
+        Self { app_state, period }
+    }
+}
+
+impl TaskMeta for ClearJwtTokens {
+    fn name(&self) -> &'static str {
+        "clear_jwt_tokens"
+    }
+
+    fn scope(&self) -> TaskScope {
+        TaskScope::SpecificShard(0)
+    }
+
+    fn on_start(&self) {
+        info!(
+            "JWT token cleaner is enabled, expired revoked tokens will be deleted every: {:?}.",
+            self.period
+        );
+    }
+}
+
+impl PeriodicTask for ClearJwtTokens {
+    fn period(&self) -> Duration {
+        self.period
+    }
+
+    fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture {
+        let app_state = self.app_state.clone();
+
+        Box::pin(async move {
+            trace!("Checking for expired revoked JWT tokens...");
+
+            let now = IggyTimestamp::now().to_secs();
+
+            match app_state
+                .jwt_manager
+                .delete_expired_revoked_tokens(now)
+                .await
+            {
+                Ok(()) => {
+                    trace!("Successfully cleaned up expired revoked JWT tokens");
+                }
+                Err(err) => {
+                    error!("Failed to delete expired revoked JWT tokens: {}", err);
+                }
+            }
+
+            Ok(())
+        })
+    }
+}
diff --git a/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs
new file mode 100644
index 00000000..bf1c6076
--- /dev/null
+++ b/core/server/src/shard/tasks/periodic/clear_personal_access_tokens.rs
@@ -0,0 +1,106 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::shard::IggyShard;
+use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope};
+use iggy_common::IggyTimestamp;
+use std::fmt::Debug;
+use std::rc::Rc;
+use std::sync::Arc;
+use std::time::Duration;
+use tracing::{info, trace};
+
+pub struct ClearPersonalAccessTokens {
+    shard: Rc<IggyShard>,
+    period: Duration,
+}
+
+impl Debug for ClearPersonalAccessTokens {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ClearPersonalAccessTokens")
+            .field("shard_id", &self.shard.id)
+            .field("period", &self.period)
+            .finish()
+    }
+}
+
+impl ClearPersonalAccessTokens {
+    pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self {
+        Self { shard, period }
+    }
+}
+
+impl TaskMeta for ClearPersonalAccessTokens {
+    fn name(&self) -> &'static str {
+        "clear_personal_access_tokens"
+    }
+
+    fn scope(&self) -> TaskScope {
+        TaskScope::AllShards
+    }
+
+    fn on_start(&self) {
+        info!(
+            "Personal access token cleaner is enabled, expired tokens will be deleted every: {:?}.",
+            self.period
+        );
+    }
+}
+
+impl PeriodicTask for ClearPersonalAccessTokens {
+    fn period(&self) -> Duration {
+        self.period
+    }
+
+    fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture {
+        let shard = self.shard.clone();
+
+        Box::pin(async move {
+            trace!("Checking for expired personal access tokens...");
+
+            let now = IggyTimestamp::now();
+            let mut total_removed = 0;
+
+            let users = shard.users.borrow();
+            for user in users.values() {
+                let expired_tokens: Vec<Arc<String>> = user
+                    .personal_access_tokens
+                    .iter()
+                    .filter(|entry| entry.value().is_expired(now))
+                    .map(|entry| entry.key().clone())
+                    .collect();
+
+                for token_hash in expired_tokens {
+                    if let Some((_, pat)) = user.personal_access_tokens.remove(&token_hash) {
+                        info!(
+                            "Removed expired personal access token '{}' for user ID {}",
+                            pat.name, user.id
+                        );
+                        total_removed += 1;
+                    }
+                }
+            }
+
+            if total_removed > 0 {
+                info!("Removed {total_removed} expired personal access tokens");
+            }
+
+            Ok(())
+        })
+    }
+}
diff --git a/core/server/src/channels/mod.rs b/core/server/src/shard/tasks/periodic/mod.rs
similarity index 68%
rename from core/server/src/channels/mod.rs
rename to core/server/src/shard/tasks/periodic/mod.rs
index d147f7e0..ca808dca 100644
--- a/core/server/src/channels/mod.rs
+++ b/core/server/src/shard/tasks/periodic/mod.rs
@@ -16,4 +16,14 @@
  * under the License.
  */
 
-pub mod commands;
+pub mod clear_jwt_tokens;
+pub mod clear_personal_access_tokens;
+pub mod print_sysinfo;
+pub mod save_messages;
+pub mod verify_heartbeats;
+
+pub use clear_jwt_tokens::ClearJwtTokens;
+pub use clear_personal_access_tokens::ClearPersonalAccessTokens;
+pub use print_sysinfo::PrintSysinfo;
+pub use save_messages::SaveMessages;
+pub use verify_heartbeats::VerifyHeartbeats;
diff --git a/core/server/src/shard/tasks/periodic/print_sysinfo.rs b/core/server/src/shard/tasks/periodic/print_sysinfo.rs
new file mode 100644
index 00000000..b1c91f86
--- /dev/null
+++ b/core/server/src/shard/tasks/periodic/print_sysinfo.rs
@@ -0,0 +1,130 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::shard::IggyShard;
+use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope};
+use human_repr::HumanCount;
+use std::fmt::Debug;
+use std::rc::Rc;
+use std::time::Duration;
+use tracing::{error, info, trace};
+
+/// Periodic task that logs host-wide memory and disk usage.
+///
+/// Runs only on shard 0 (`TaskScope::SpecificShard(0)`) since the figures describe the whole host.
+pub struct PrintSysinfo {
+    shard: Rc<IggyShard>,
+    period: Duration,
+}
+
+impl Debug for PrintSysinfo {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PrintSysinfo")
+            .field("shard_id", &self.shard.id)
+            .field("period", &self.period)
+            .finish()
+    }
+}
+
+impl PrintSysinfo {
+    pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self {
+        Self { shard, period }
+    }
+}
+
+impl TaskMeta for PrintSysinfo {
+    fn name(&self) -> &'static str {
+        "print_sysinfo"
+    }
+
+    fn scope(&self) -> TaskScope {
+        TaskScope::SpecificShard(0)
+    }
+
+    fn on_start(&self) {
+        info!(
+            "System info logger is enabled, OS info will be printed every: {:?}",
+            self.period
+        );
+    }
+}
+
+impl PeriodicTask for PrintSysinfo {
+    fn period(&self) -> Duration {
+        self.period
+    }
+
+    fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture {
+        // Host-wide report: no shard state is needed, so nothing is cloned into the future.
+        Box::pin(async move {
+            trace!("Fetching OS info...");
+
+            if let Err(e) = print_os_info() {
+                error!("Failed to print system info: {}", e);
+            }
+
+            Ok(())
+        })
+    }
+}
+
+/// Logs memory figures and one line per disk at `info` level.
+///
+/// Always returns `Ok(())` in its current form.
+fn print_os_info() -> Result<(), String> {
+    let mut sys = sysinfo::System::new();
+    sys.refresh_memory();
+
+    let available_memory = sys.available_memory();
+    let free_memory = sys.free_memory();
+    let used_memory = sys.used_memory();
+    let total_memory = sys.total_memory();
+
+    info!(
+        "Memory -> available: {}, free: {}, used: {}, total: {}.",
+        available_memory.human_count_bytes(),
+        free_memory.human_count_bytes(),
+        used_memory.human_count_bytes(),
+        total_memory.human_count_bytes()
+    );
+
+    let disks = sysinfo::Disks::new_with_refreshed_list();
+    for disk in disks.list() {
+        let name = disk.name().to_string_lossy();
+        let mount_point = disk.mount_point();
+        let available_space = disk.available_space();
+        let total_space = disk.total_space();
+        // Saturate instead of underflowing (which panics in debug builds) if a filesystem
+        // reports more available space than its total size.
+        let used_space = total_space.saturating_sub(available_space);
+        let file_system = disk.file_system().to_string_lossy();
+        let is_removable = disk.is_removable();
+        info!(
+            "Disk: {}, mounted: {}, removable: {}, file system: {}, available: {}, used: {}, total: {}",
+            name,
+            mount_point.display(),
+            is_removable,
+            file_system,
+            available_space.human_count_bytes(),
+            used_space.human_count_bytes(),
+            total_space.human_count_bytes()
+        );
+    }
+
+    Ok(())
+}
diff --git a/core/server/src/shard/tasks/periodic/save_messages.rs b/core/server/src/shard/tasks/periodic/save_messages.rs
new file mode 100644
index 00000000..0bb8b35e
--- /dev/null
+++ b/core/server/src/shard/tasks/periodic/save_messages.rs
@@ -0,0 +1,143 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::shard::IggyShard;
+use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope};
+use crate::shard_info;
+use iggy_common::Identifier;
+use std::fmt::Debug;
+use std::rc::Rc;
+use std::time::Duration;
+use tracing::{error, info, trace};
+
+/// Periodic task that persists buffered messages for the partition namespaces returned by
+/// `get_current_shard_namespaces`; registered on all shards.
+///
+/// `last_tick_on_shutdown` returns `true`, presumably so the registry runs one final save
+/// while shutting down.
+pub struct SaveMessages {
+    shard: Rc<IggyShard>,
+    period: Duration,
+}
+
+impl Debug for SaveMessages {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SaveMessages")
+            .field("shard_id", &self.shard.id)
+            .field("period", &self.period)
+            .finish()
+    }
+}
+
+impl SaveMessages {
+    pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self {
+        Self { shard, period }
+    }
+}
+
+impl TaskMeta for SaveMessages {
+    fn name(&self) -> &'static str {
+        "save_messages"
+    }
+
+    fn scope(&self) -> TaskScope {
+        TaskScope::AllShards
+    }
+
+    fn on_start(&self) {
+        let enforce_fsync = self.shard.config.message_saver.enforce_fsync;
+        info!(
+            "Message saver is enabled, buffered messages will be automatically saved every: {:?}, enforce fsync: {enforce_fsync}.",
+            self.period
+        );
+    }
+}
+
+impl PeriodicTask for SaveMessages {
+    fn period(&self) -> Duration {
+        self.period
+    }
+
+    fn tick(&mut self, ctx: &TaskCtx) -> TaskFuture {
+        let shard = ctx.shard.clone();
+        Box::pin(async move {
+            trace!("Saving buffered messages...");
+
+            let namespaces = shard.get_current_shard_namespaces();
+            let mut total_saved_messages = 0u32;
+            let reason = "background saver triggered".to_string();
+
+            for ns in namespaces {
+                let partition_id = ns.partition_id();
+                // An identifier that fails validation must not panic the background saver;
+                // log it and keep saving the remaining partitions.
+                let (stream_id, topic_id) = match (
+                    Identifier::numeric(ns.stream_id() as u32),
+                    Identifier::numeric(ns.topic_id() as u32),
+                ) {
+                    (Ok(stream_id), Ok(topic_id)) => (stream_id, topic_id),
+                    (Err(err), _) | (_, Err(err)) => {
+                        error!(
+                            "Failed to build identifiers for partition {}: {}",
+                            partition_id, err
+                        );
+                        continue;
+                    }
+                };
+
+                match shard
+                    .streams2
+                    .persist_messages(
+                        shard.id,
+                        &stream_id,
+                        &topic_id,
+                        partition_id,
+                        reason.clone(),
+                        &shard.config.system,
+                    )
+                    .await
+                {
+                    Ok(batch_count) => {
+                        total_saved_messages += batch_count;
+                    }
+                    Err(err) => {
+                        error!(
+                            "Failed to save messages for partition {}: {}",
+                            partition_id, err
+                        );
+                    }
+                }
+            }
+
+            if total_saved_messages > 0 {
+                shard_info!(
+                    shard.id,
+                    "Saved {} buffered messages on disk.",
+                    total_saved_messages
+                );
+            }
+
+            trace!("Finished saving buffered messages.");
+            Ok(())
+        })
+    }
+
+    fn last_tick_on_shutdown(&self) -> bool {
+        true
+    }
+}
diff --git a/core/server/src/shard/tasks/periodic/verify_heartbeats.rs b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs
new file mode 100644
index 00000000..5810d310
--- /dev/null
+++ b/core/server/src/shard/tasks/periodic/verify_heartbeats.rs
@@ -0,0 +1,132 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::shard::IggyShard;
+use crate::shard::task_registry::{PeriodicTask, TaskCtx, TaskFuture, TaskMeta, TaskScope};
+use iggy_common::{IggyDuration, IggyTimestamp};
+use std::fmt::Debug;
+use std::rc::Rc;
+use std::time::Duration;
+use tracing::{debug, info, trace, warn};
+
+/// Multiplier applied to the verification period to get the maximum allowed heartbeat age.
+const MAX_THRESHOLD: f64 = 1.2;
+
+/// Periodic task that marks clients whose last heartbeat is older than `max_interval` as stale
+/// and deletes them; registered on all shards.
+pub struct VerifyHeartbeats {
+    shard: Rc<IggyShard>,
+    period: Duration,
+    max_interval: IggyDuration,
+}
+
+impl Debug for VerifyHeartbeats {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("VerifyHeartbeats")
+            .field("shard_id", &self.shard.id)
+            .field("period", &self.period)
+            .field("max_interval", &self.max_interval)
+            .finish()
+    }
+}
+
+impl VerifyHeartbeats {
+    pub fn new(shard: Rc<IggyShard>, period: Duration) -> Self {
+        let interval = IggyDuration::from(period);
+        let max_interval = IggyDuration::from((MAX_THRESHOLD * interval.as_micros() as f64) as u64);
+        Self {
+            shard,
+            period,
+            max_interval,
+        }
+    }
+}
+
+impl TaskMeta for VerifyHeartbeats {
+    fn name(&self) -> &'static str {
+        "verify_heartbeats"
+    }
+
+    fn scope(&self) -> TaskScope {
+        TaskScope::AllShards
+    }
+
+    fn on_start(&self) {
+        info!(
+            "Heartbeats will be verified every: {}. Max allowed interval: {}.",
+            IggyDuration::from(self.period),
+            self.max_interval
+        );
+    }
+}
+
+impl PeriodicTask for VerifyHeartbeats {
+    fn period(&self) -> Duration {
+        self.period
+    }
+
+    fn tick(&mut self, _ctx: &TaskCtx) -> TaskFuture {
+        let shard = self.shard.clone();
+        let max_interval = self.max_interval;
+
+        Box::pin(async move {
+            trace!("Verifying heartbeats...");
+
+            let clients = {
+                let client_manager = shard.client_manager.borrow();
+                client_manager.get_clients()
+            };
+
+            let now = IggyTimestamp::now();
+            // Saturate: an oversized `max_interval` clamps the cutoff to 0 instead of underflowing.
+            let heartbeat_to =
+                IggyTimestamp::from(now.as_micros().saturating_sub(max_interval.as_micros()));
+            debug!("Verifying heartbeats at: {now}, max allowed timestamp: {heartbeat_to}");
+
+            let mut stale_clients = Vec::new();
+            for client in clients {
+                if client.last_heartbeat.as_micros() < heartbeat_to.as_micros() {
+                    warn!(
+                        "Stale client session: {}, last heartbeat at: {}, max allowed timestamp: {heartbeat_to}",
+                        client.session, client.last_heartbeat,
+                    );
+                    client.session.set_stale();
+                    stale_clients.push(client.session.client_id);
+                } else {
+                    debug!(
+                        "Valid heartbeat at: {} for client session: {}, max allowed timestamp: {heartbeat_to}",
+                        client.last_heartbeat, client.session,
+                    );
+                }
+            }
+
+            if stale_clients.is_empty() {
+                return Ok(());
+            }
+
+            let count = stale_clients.len();
+            info!("Removing {count} stale clients...");
+            for client_id in stale_clients {
+                shard.delete_client(client_id);
+            }
+            info!("Removed {count} stale clients.");
+
+            Ok(())
+        })
+    }
+}
diff --git a/core/server/src/shard/transmission/message.rs b/core/server/src/shard/transmission/message.rs
index a4e4bb3a..14bf075b 100644
--- a/core/server/src/shard/transmission/message.rs
+++ b/core/server/src/shard/transmission/message.rs
@@ -1,7 +1,3 @@
-use std::{rc::Rc, sync::Arc};
-
-use iggy_common::{Identifier, PollingStrategy};
-
 /* Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -27,6 +23,7 @@ use crate::{
     slab::partitions,
     streaming::{polling_consumer::PollingConsumer, segments::IggyMessagesBatchMut},
 };
+use iggy_common::Identifier;
 
 pub enum ShardSendRequestResult {
     // TODO: In the future we can add other variants, for example backpressure from the destination shard,
diff --git a/core/server/src/slab/streams.rs b/core/server/src/slab/streams.rs
index d733e04e..e8ee7468 100644
--- a/core/server/src/slab/streams.rs
+++ b/core/server/src/slab/streams.rs
@@ -1,3 +1,5 @@
+use crate::shard::task_registry::tls::task_registry;
+use crate::streaming::partitions as streaming_partitions;
 use crate::{
     binary::handlers::messages::poll_messages_handler::IggyPollMetadata,
     configs::{cache_indexes::CacheIndexesConfig, system::SystemConfig},
@@ -43,10 +45,7 @@ use std::{
     cell::RefCell,
     sync::{Arc, atomic::Ordering},
 };
-use tracing::trace;
-
-// Import streaming partitions helpers for the persist_messages method
-use crate::streaming::partitions as streaming_partitions;
+use tracing::{error, trace};
 
 const CAPACITY: usize = 1024;
 pub type ContainerId = usize;
@@ -735,15 +734,29 @@ impl Streams {
                 (msg.unwrap(), index.unwrap())
             });
 
-        compio::runtime::spawn(async move {
-            let _ = log_writer.fsync().await;
-        })
-        .detach();
-        compio::runtime::spawn(async move {
-            let _ = index_writer.fsync().await;
-            drop(index_writer)
-        })
-        .detach();
+        task_registry().spawn_oneshot_future("fsync:segment-close-log", true, async move {
+            match log_writer.fsync().await {
+                Ok(_) => Ok(()),
+                Err(e) => {
+                    error!("Failed to fsync log writer on segment close: {}", e);
+                    Err(e)
+                }
+            }
+        });
+
+        task_registry().spawn_oneshot_future("fsync:segment-close-index", true, async move {
+            match index_writer.fsync().await {
+                Ok(_) => {
+                    drop(index_writer);
+                    Ok(())
+                }
+                Err(e) => {
+                    error!("Failed to fsync index writer on segment close: {}", e);
+                    drop(index_writer);
+                    Err(e)
+                }
+            }
+        });
 
         let (start_offset, size, end_offset) =
             self.with_partition_by_id(stream_id, topic_id, partition_id, |(.., log)| {
diff --git a/core/server/src/streaming/storage.rs b/core/server/src/streaming/storage.rs
index 03199d70..d64845a2 100644
--- a/core/server/src/streaming/storage.rs
+++ b/core/server/src/streaming/storage.rs
@@ -47,7 +47,7 @@ macro_rules! forward_async_methods {
     }
 }
 
-// TODO: Tech debt, how to get rid of this ?
+// TODO: Tech debt, how to get rid of this ?
 #[derive(Debug)]
 pub enum SystemInfoStorageKind {
     File(FileSystemInfoStorage),
diff --git a/core/server/src/tcp/tcp_listener.rs b/core/server/src/tcp/tcp_listener.rs
index 18565ae6..22a34feb 100644
--- a/core/server/src/tcp/tcp_listener.rs
+++ b/core/server/src/tcp/tcp_listener.rs
@@ -19,9 +19,10 @@
 use crate::binary::sender::SenderKind;
 use crate::configs::tcp::TcpSocketConfig;
 use crate::shard::IggyShard;
+use crate::shard::task_registry::task_registry;
 use crate::shard::transmission::event::ShardEvent;
 use crate::tcp::connection_handler::{handle_connection, handle_error};
-use crate::{shard_error, shard_info};
+use crate::{shard_debug, shard_error, shard_info};
 use compio::net::{TcpListener, TcpOpts};
 use error_set::ErrContext;
 use futures::FutureExt;
@@ -154,7 +155,7 @@ async fn accept_loop(
         let accept_future = listener.accept();
         futures::select! {
             _ = shutdown_check.fuse() => {
-                shard_info!(shard.id, "{} detected shutdown flag, no longer accepting connections", server_name);
+                shard_debug!(shard.id, "{} detected shutdown flag, no longer accepting connections", server_name);
                 break;
             }
             result = accept_future.fuse() => {
@@ -180,14 +181,14 @@ async fn accept_loop(
                         shard_info!(shard.id, "Created new session: {}", session);
 
                         let mut sender = SenderKind::get_tcp_sender(stream);
-                        let conn_stop_receiver = shard_clone.task_registry.add_connection(client_id);
+                        let conn_stop_receiver = task_registry().add_connection(client_id);
 
                         let shard_for_conn = shard_clone.clone();
-                        shard_clone.task_registry.spawn_tracked(async move {
+                        task_registry().spawn_tracked(async move {
                             if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await {
                                 handle_error(error);
                             }
-                            shard_for_conn.task_registry.remove_connection(&client_id);
+                            task_registry().remove_connection(&client_id);
 
                             if let Err(error) = sender.shutdown().await {
                                 shard_error!(shard.id, "Failed to shutdown TCP stream for client: {}, address: {}. {}", client_id, address, error);
diff --git a/core/server/src/tcp/tcp_tls_listener.rs b/core/server/src/tcp/tcp_tls_listener.rs
index ec69dbcd..ee83f4dc 100644
--- a/core/server/src/tcp/tcp_tls_listener.rs
+++ b/core/server/src/tcp/tcp_tls_listener.rs
@@ -19,6 +19,7 @@
 use crate::binary::sender::SenderKind;
 use crate::configs::tcp::TcpSocketConfig;
 use crate::shard::IggyShard;
+use crate::shard::task_registry::task_registry;
 use crate::shard::transmission::event::ShardEvent;
 use crate::tcp::connection_handler::{handle_connection, handle_error};
 use crate::{shard_error, shard_info, shard_warn};
@@ -219,7 +220,7 @@ async fn accept_loop(
 
                         // Perform TLS handshake in a separate task to avoid blocking the accept loop
                         let task_shard = shard_clone.clone();
-                        task_shard.task_registry.spawn_tracked(async move {
+                        task_registry().spawn_tracked(async move {
                             match acceptor.accept(stream).await {
                                 Ok(tls_stream) => {
                                     // TLS handshake successful, now create session
@@ -237,13 +238,13 @@ async fn accept_loop(
                                             let client_id = session.client_id;
                                             shard_info!(shard_clone.id, "Created new session: {}", session);
 
-                                            let conn_stop_receiver = shard_clone.task_registry.add_connection(client_id);
+                                            let conn_stop_receiver = task_registry().add_connection(client_id);
                                             let shard_for_conn = shard_clone.clone();
                                             let mut sender = SenderKind::get_tcp_tls_sender(tls_stream);
                                             if let Err(error) = handle_connection(&session, &mut sender, &shard_for_conn, conn_stop_receiver).await {
                                                 handle_error(error);
                                             }
-                                            shard_for_conn.task_registry.remove_connection(&client_id);
+                                            task_registry().remove_connection(&client_id);
 
                                             if let Err(error) = sender.shutdown().await {
                                                 shard_error!(shard.id, "Failed to shutdown TCP TLS stream for client: {}, address: {}. {}", client_id, address, error);
