This is an automated email from the ASF dual-hosted git repository.

jiacai2050 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/horaedb.git


The following commit(s) were added to refs/heads/main by this push:
     new aa438cae feat: use multithreading to optimize WAL replay (#1572)
aa438cae is described below

commit aa438cae8c552a18f6701d6237bcbcc982bed25a
Author: Draco <[email protected]>
AuthorDate: Fri Sep 27 15:40:42 2024 +0800

    feat: use multithreading to optimize WAL replay (#1572)
    
    ## Rationale
    1. Currently, the WAL replayer uses coroutines to replay the WAL logs of
    multiple tables in parallel. However, coroutines utilize at most one
    CPU. By switching to a multithreaded approach, we can fully leverage
    multiple CPUs.
    2. We observed that during the replay phase, decoding the WAL log is a
    CPU-intensive operation, so we parallelize the decoding as well.
    
    ## Detailed Changes
    
    1. Modify both `TableBasedReplay` and `RegionBasedReplay` to use the
    `spawn task` approach for parallelism, with a maximum of 20 tasks
    running concurrently.
    2. Preload the next segment in the local-storage-based WAL.
    3. In `BatchLogIteratorAdapter::simulated_async_next`, we first retrieve
    all the payloads in a batch and then decode them in parallel.
    
    ## Test Plan
    Manual testing.
---
 Cargo.lock                                        |  22 ++-
 src/analytic_engine/Cargo.toml                    |   1 +
 src/analytic_engine/src/instance/wal_replayer.rs  | 116 +++++++--------
 src/wal/Cargo.toml                                |   2 +
 src/wal/src/local_storage_impl/record_encoding.rs |  16 +--
 src/wal/src/local_storage_impl/segment.rs         | 164 ++++++++++++++++------
 src/wal/src/manager.rs                            |  39 +++--
 7 files changed, 238 insertions(+), 122 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 43d74e7e..584f2e98 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -83,6 +83,7 @@ dependencies = [
  "arc-swap 1.6.0",
  "arena",
  "arrow 49.0.0",
+ "async-scoped",
  "async-stream",
  "async-trait",
  "atomic_enum",
@@ -764,6 +765,17 @@ dependencies = [
  "syn 2.0.48",
 ]
 
+[[package]]
+name = "async-scoped"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740"
+dependencies = [
+ "futures 0.3.28",
+ "pin-project",
+ "tokio",
+]
+
 [[package]]
 name = "async-stream"
 version = "0.3.4"
@@ -5981,9 +5993,9 @@ dependencies = [
 
 [[package]]
 name = "rayon"
-version = "1.8.0"
+version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
 dependencies = [
  "either",
  "rayon-core",
@@ -5991,9 +6003,9 @@ dependencies = [
 
 [[package]]
 name = "rayon-core"
-version = "1.12.0"
+version = "1.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
 dependencies = [
  "crossbeam-deque",
  "crossbeam-utils",
@@ -8220,6 +8232,7 @@ name = "wal"
 version = "2.0.0"
 dependencies = [
  "anyhow",
+ "async-scoped",
  "async-trait",
  "bytes_ext",
  "chrono",
@@ -8237,6 +8250,7 @@ dependencies = [
  "prometheus 0.12.0",
  "prost 0.11.8",
  "rand 0.8.5",
+ "rayon",
  "rocksdb",
  "runtime",
  "serde",
diff --git a/src/analytic_engine/Cargo.toml b/src/analytic_engine/Cargo.toml
index 8197b4ee..09ff47af 100644
--- a/src/analytic_engine/Cargo.toml
+++ b/src/analytic_engine/Cargo.toml
@@ -43,6 +43,7 @@ anyhow = { workspace = true }
 arc-swap = "1.4.0"
 arena = { workspace = true }
 arrow = { workspace = true }
+async-scoped = { version = "0.9.0", features = ["use-tokio"] }
 async-stream = { workspace = true }
 async-trait = { workspace = true }
 atomic_enum = { workspace = true }
diff --git a/src/analytic_engine/src/instance/wal_replayer.rs b/src/analytic_engine/src/instance/wal_replayer.rs
index f7828951..6c674140 100644
--- a/src/analytic_engine/src/instance/wal_replayer.rs
+++ b/src/analytic_engine/src/instance/wal_replayer.rs
@@ -30,14 +30,13 @@ use common_types::{
     schema::{IndexInWriterSchema, Schema},
     table::ShardId,
 };
-use futures::StreamExt;
 use generic_error::BoxError;
 use lazy_static::lazy_static;
 use logger::{debug, error, info, trace, warn};
 use prometheus::{exponential_buckets, register_histogram, Histogram};
 use snafu::ResultExt;
 use table_engine::table::TableId;
-use tokio::sync::{Mutex, MutexGuard};
+use tokio::sync::{Mutex, MutexGuard, Semaphore};
 use wal::{
     log_batch::LogEntry,
     manager::{
@@ -74,6 +73,8 @@ lazy_static! {
     .unwrap();
 }
 
+const MAX_REPLAY_TASK_NUM: usize = 20;
+
 /// Wal replayer supporting both table based and region based
 // TODO: limit the memory usage in `RegionBased` mode.
 pub struct WalReplayer<'a> {
@@ -189,22 +190,23 @@ impl Replay for TableBasedReplay {
             ..Default::default()
         };
 
-        let mut tasks = futures::stream::iter(
-            table_datas
-                .iter()
-                .map(|table_data| {
-                    let table_id = table_data.id;
-                    let read_ctx = &read_ctx;
-                    async move {
-                        let ret = Self::recover_table_logs(context, table_data, read_ctx).await;
-                        (table_id, ret)
-                    }
-                })
-                .collect::<Vec<_>>(),
-        )
-        .buffer_unordered(20);
-        while let Some((table_id, ret)) = tasks.next().await {
-            if let Err(e) = ret {
+        let ((), results) = async_scoped::TokioScope::scope_and_block(|scope| {
+            // Limit the maximum number of concurrent tasks.
+            let semaphore = Arc::new(Semaphore::new(MAX_REPLAY_TASK_NUM));
+            for table_data in table_datas {
+                let table_id = table_data.id;
+                let read_ctx = &read_ctx;
+                let semaphore = semaphore.clone();
+                scope.spawn(async move {
+                    let _permit = semaphore.acquire().await.unwrap();
+                    let ret = Self::recover_table_logs(context, table_data, read_ctx).await;
+                    (table_id, ret)
+                });
+            }
+        });
+
+        for result in results.into_iter().flatten() {
+            if let (table_id, Err(e)) = result {
                 // If occur error, mark this table as failed and store the cause.
                 failed_tables.insert(table_id, e);
             }
@@ -345,7 +347,7 @@ impl RegionBasedReplay {
                 table_data: table_data.clone(),
                 serial_exec,
             };
-            serial_exec_ctxs.insert(table_data.id, serial_exec_ctx);
+            serial_exec_ctxs.insert(table_data.id, Mutex::new(serial_exec_ctx));
             table_datas_by_id.insert(table_data.id.as_u64(), table_data.clone());
         }
 
@@ -353,7 +355,7 @@ impl RegionBasedReplay {
         let schema_provider = TableSchemaProviderAdapter {
             table_datas: table_datas_by_id.clone(),
         };
-        let serial_exec_ctxs = Arc::new(Mutex::new(serial_exec_ctxs));
+        let serial_exec_ctxs = serial_exec_ctxs;
         // Split and replay logs.
         loop {
             let _timer = PULL_LOGS_DURATION_HISTOGRAM.start_timer();
@@ -381,49 +383,53 @@ impl RegionBasedReplay {
     async fn replay_single_batch(
         context: &ReplayContext,
         log_batch: &VecDeque<LogEntry<ReadPayload>>,
-        serial_exec_ctxs: &Arc<Mutex<HashMap<TableId, SerialExecContext<'_>>>>,
+        serial_exec_ctxs: &HashMap<TableId, Mutex<SerialExecContext<'_>>>,
         failed_tables: &mut FailedTables,
     ) -> Result<()> {
         let mut table_batches = Vec::new();
         // TODO: No `group_by` method in `VecDeque`, so implement it manually here...
         Self::split_log_batch_by_table(log_batch, &mut table_batches);
 
-        // TODO: Replay logs of different tables in parallel.
-        let mut replay_tasks = Vec::with_capacity(table_batches.len());
-        for table_batch in table_batches {
-            // Some tables may have failed in previous replay, ignore them.
-            if failed_tables.contains_key(&table_batch.table_id) {
-                continue;
-            }
-            let log_entries: Vec<_> = table_batch
-                .ranges
-                .iter()
-                .flat_map(|range| log_batch.range(range.clone()))
-                .collect();
-
-            let serial_exec_ctxs = serial_exec_ctxs.clone();
-            replay_tasks.push(async move {
-                // Some tables may have been moved to other shards or dropped, ignore such logs.
-                if let Some(ctx) = serial_exec_ctxs.lock().await.get_mut(&table_batch.table_id) {
-                    let result = replay_table_log_entries(
-                        &context.flusher,
-                        context.max_retry_flush_limit,
-                        &mut ctx.serial_exec,
-                        &ctx.table_data,
-                        log_entries.into_iter(),
-                    )
-                    .await;
-                    (table_batch.table_id, Some(result))
-                } else {
-                    (table_batch.table_id, None)
+        let ((), results) = async_scoped::TokioScope::scope_and_block(|scope| {
+            // Limit the maximum number of concurrent tasks.
+            let semaphore = Arc::new(Semaphore::new(MAX_REPLAY_TASK_NUM));
+
+            for table_batch in table_batches {
+                // Some tables may have failed in previous replay, ignore them.
+                if failed_tables.contains_key(&table_batch.table_id) {
+                    continue;
                 }
-            });
-        }
+                let log_entries: Vec<_> = table_batch
+                    .ranges
+                    .iter()
+                    .flat_map(|range| log_batch.range(range.clone()))
+                    .collect();
+                let semaphore = semaphore.clone();
+
+                scope.spawn(async move {
+                    let _permit = semaphore.acquire().await.unwrap();
+                    // Some tables may have been moved to other shards or dropped, ignore such logs.
+                    if let Some(ctx) = serial_exec_ctxs.get(&table_batch.table_id) {
+                        let mut ctx = ctx.lock().await;
+                        let table_data = ctx.table_data.clone();
+                        let result = replay_table_log_entries(
+                            &context.flusher,
+                            context.max_retry_flush_limit,
+                            &mut ctx.serial_exec,
+                            &table_data,
+                            log_entries.into_iter(),
+                        )
+                        .await;
+                        (table_batch.table_id, Some(result))
+                    } else {
+                        (table_batch.table_id, None)
+                    }
+                });
+            }
+        });
 
-        // Run at most 20 tasks in parallel
-        let mut replay_tasks = futures::stream::iter(replay_tasks).buffer_unordered(20);
-        while let Some((table_id, ret)) = replay_tasks.next().await {
-            if let Some(Err(e)) = ret {
+        for result in results.into_iter().flatten() {
+            if let (table_id, Some(Err(e))) = result {
                 // If occur error, mark this table as failed and store the cause.
                 failed_tables.insert(table_id, e);
             }
diff --git a/src/wal/Cargo.toml b/src/wal/Cargo.toml
index 30a5b004..14640164 100644
--- a/src/wal/Cargo.toml
+++ b/src/wal/Cargo.toml
@@ -48,6 +48,7 @@ required-features = ["wal-message-queue", "wal-table-kv", "wal-rocksdb", "wal-lo
 
 [dependencies]
 anyhow = { workspace = true }
+async-scoped = { version = "0.9.0", features = ["use-tokio"] }
 async-trait = { workspace = true }
 bytes_ext = { workspace = true }
 chrono = { workspace = true }
@@ -64,6 +65,7 @@ memmap2 = { version = "0.9.4", optional = true }
 message_queue = { workspace = true, optional = true }
 prometheus = { workspace = true }
 prost = { workspace = true }
+rayon = "1.10.0"
 runtime = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
diff --git a/src/wal/src/local_storage_impl/record_encoding.rs b/src/wal/src/local_storage_impl/record_encoding.rs
index 6e8011c2..e91d3f5a 100644
--- a/src/wal/src/local_storage_impl/record_encoding.rs
+++ b/src/wal/src/local_storage_impl/record_encoding.rs
@@ -63,7 +63,7 @@ define_result!(Error);
 /// +---------+--------+--------+------------+--------------+--------------+-------+
 /// ```
 #[derive(Debug)]
-pub struct Record<'a> {
+pub struct Record {
     /// The version number of the record.
     pub version: u8,
 
@@ -83,11 +83,11 @@ pub struct Record<'a> {
     pub value_length: u32,
 
     /// Common log value.
-    pub value: &'a [u8],
+    pub value: Vec<u8>,
 }
 
-impl<'a> Record<'a> {
-    pub fn new(table_id: u64, sequence_num: u64, value: &'a [u8]) -> Result<Self> {
+impl Record {
+    pub fn new(table_id: u64, sequence_num: u64, value: &[u8]) -> Result<Self> {
         let mut record = Record {
             version: NEWEST_RECORD_ENCODING_VERSION,
             crc: 0,
@@ -95,7 +95,7 @@ impl<'a> Record<'a> {
             table_id,
             sequence_num,
             value_length: value.len() as u32,
-            value,
+            value: value.to_vec(),
         };
 
         // Calculate CRC
@@ -128,7 +128,7 @@ impl RecordEncoding {
     }
 }
 
-impl Encoder<Record<'_>> for RecordEncoding {
+impl Encoder<Record> for RecordEncoding {
     type Error = Error;
 
     fn encode<B: BufMut>(&self, buf: &mut B, record: &Record) -> Result<()> {
@@ -147,7 +147,7 @@ impl Encoder<Record<'_>> for RecordEncoding {
         buf.try_put_u64(record.table_id).context(Encoding)?;
         buf.try_put_u64(record.sequence_num).context(Encoding)?;
         buf.try_put_u32(record.value_length).context(Encoding)?;
-        buf.try_put(record.value).context(Encoding)?;
+        buf.try_put(record.value.as_slice()).context(Encoding)?;
         Ok(())
     }
 
@@ -222,7 +222,7 @@ impl RecordEncoding {
         let value_length = buf.try_get_u32().context(Decoding)?;
 
         // Read value
-        let value = &buf[0..value_length as usize];
+        let value = buf[0..value_length as usize].to_vec();
         buf.advance(value_length as usize);
 
         Ok(Record {
diff --git a/src/wal/src/local_storage_impl/segment.rs b/src/wal/src/local_storage_impl/segment.rs
index 02bd1d13..b66701b2 100644
--- a/src/wal/src/local_storage_impl/segment.rs
+++ b/src/wal/src/local_storage_impl/segment.rs
@@ -33,7 +33,7 @@ use common_types::{table::TableId, SequenceNumber, MAX_SEQUENCE_NUMBER, MIN_SEQU
 use generic_error::{BoxError, GenericError};
 use macros::define_result;
 use memmap2::{MmapMut, MmapOptions};
-use runtime::Runtime;
+use runtime::{JoinHandle, Runtime};
 use snafu::{ensure, Backtrace, ResultExt, Snafu};
 
 use crate::{
@@ -832,6 +832,7 @@ impl Region {
             Some(req.location.table_id),
             start,
             end,
+            self.runtime.clone(),
         )?;
 
         Ok(BatchLogIteratorAdapter::new_with_sync(
@@ -849,6 +850,7 @@ impl Region {
             None,
             MIN_SEQUENCE_NUMBER,
             MAX_SEQUENCE_NUMBER,
+            self.runtime.clone(),
         )?;
         Ok(BatchLogIteratorAdapter::new_with_sync(
             Box::new(iter),
@@ -1006,19 +1008,37 @@ impl RegionManager {
     }
 }
 
+fn decode_segment_content(
+    segment_content: &[u8],
+    record_positions: &[Position],
+    record_encoding: &RecordEncoding,
+) -> Result<Vec<Record>> {
+    let mut records = Vec::with_capacity(record_positions.len());
+
+    for pos in record_positions {
+        // Extract the record data from the segment content
+        let record_data = &segment_content[pos.start..pos.end];
+
+        // Decode the record
+        let record = record_encoding
+            .decode(record_data)
+            .box_err()
+            .context(InvalidRecord)?;
+        records.push(record);
+    }
+    Ok(records)
+}
+
 #[derive(Debug)]
 struct SegmentLogIterator {
     /// Encoding method for common log.
     log_encoding: CommonLogEncoding,
 
     /// Encoding method for records.
-    record_encoding: RecordEncoding,
-
-    /// Raw content of the segment.
-    segment_content: Vec<u8>,
+    _record_encoding: RecordEncoding,
 
-    /// Positions of records within the segment content.
-    record_positions: Vec<Position>,
+    /// Decoded log records in the segment.
+    records: Vec<Record>,
 
     /// Optional identifier for the table, which is used to filter logs.
     table_id: Option<TableId>,
@@ -1040,27 +1060,19 @@ struct SegmentLogIterator {
 }
 
 impl SegmentLogIterator {
-    pub fn new(
+    pub fn new_with_records(
         log_encoding: CommonLogEncoding,
         record_encoding: RecordEncoding,
-        segment: Arc<Mutex<Segment>>,
-        segment_manager: Arc<SegmentManager>,
+        records: Vec<Record>,
+        table_ranges: HashMap<TableId, (SequenceNumber, SequenceNumber)>,
         table_id: Option<TableId>,
         start: SequenceNumber,
         end: SequenceNumber,
     ) -> Result<Self> {
-        let mut guard = segment.lock().unwrap();
-        // Open the segment if it is not open
-        segment_manager.open_segment(&mut guard, segment.clone())?;
-        let segment_content = guard.read(0, guard.current_size)?;
-        let record_positions = guard.record_position.clone();
-        let table_ranges = guard.table_ranges.clone();
-
         Ok(Self {
             log_encoding,
-            record_encoding,
-            segment_content,
-            record_positions,
+            _record_encoding: record_encoding,
+            records,
             table_id,
             table_ranges,
             start,
@@ -1076,24 +1088,14 @@ impl SegmentLogIterator {
         }
 
         loop {
-            // Get the next record position
-            let Some(pos) = self.record_positions.get(self.current_record_idx) else {
+            // Get the next record
+            let Some(record) = self.records.get(self.current_record_idx) else {
                 self.no_more_data = true;
                 return Ok(None);
             };
 
             self.current_record_idx += 1;
 
-            // Extract the record data from the segment content
-            let record_data = &self.segment_content[pos.start..pos.end];
-
-            // Decode the record
-            let record = self
-                .record_encoding
-                .decode(record_data)
-                .box_err()
-                .context(InvalidRecord)?;
-
             // Filter by sequence number
             if record.sequence_num < self.start {
                 continue;
@@ -1122,7 +1124,7 @@ impl SegmentLogIterator {
             // Decode the value
             let value = self
                 .log_encoding
-                .decode_value(record.value)
+                .decode_value(&record.value)
                 .box_err()
                 .context(InvalidRecord)?;
 
@@ -1150,6 +1152,9 @@ pub struct MultiSegmentLogIterator {
     /// Current segment iterator.
     current_iterator: Option<SegmentLogIterator>,
 
+    /// Future iterator for preloading the next segment.
+    next_segment_iterator: Option<JoinHandle<Result<SegmentLogIterator>>>,
+
     /// Encoding method for common log.
     log_encoding: CommonLogEncoding,
 
@@ -1167,6 +1172,9 @@ pub struct MultiSegmentLogIterator {
 
     /// The raw payload data of the current record.
     current_payload: Vec<u8>,
+
+    /// Runtime for preloading segments
+    runtime: Arc<Runtime>,
 }
 
 impl MultiSegmentLogIterator {
@@ -1177,6 +1185,7 @@ impl MultiSegmentLogIterator {
         table_id: Option<TableId>,
         start: SequenceNumber,
         end: SequenceNumber,
+        runtime: Arc<Runtime>,
     ) -> Result<Self> {
         let relevant_segments = segment_manager.get_relevant_segments(table_id, start, end)?;
 
@@ -1185,12 +1194,14 @@ impl MultiSegmentLogIterator {
             segments: relevant_segments,
             current_segment_idx: 0,
             current_iterator: None,
+            next_segment_iterator: None,
             log_encoding,
             record_encoding,
             table_id,
             start,
             end,
             current_payload: Vec::new(),
+            runtime,
         };
 
         // Load the first segment iterator
@@ -1199,25 +1210,88 @@ impl MultiSegmentLogIterator {
         Ok(iter)
     }
 
+    fn preload_next_segment(&mut self) {
+        assert!(self.next_segment_iterator.is_none());
+        if self.current_segment_idx >= self.segments.len() {
+            return;
+        }
+
+        let next_segment_idx = self.current_segment_idx;
+        let segment = self.segments[next_segment_idx].clone();
+        let segment_manager = self.segment_manager.clone();
+        let log_encoding = self.log_encoding.clone();
+        let record_encoding = self.record_encoding.clone();
+        let table_id = self.table_id;
+        let start = self.start;
+        let end = self.end;
+
+        // Spawn an async task to preload the next SegmentLogIterator
+        let handle = self.runtime.spawn(async move {
+            let mut guard = segment.lock().unwrap();
+            // Open the segment if it is not open
+            segment_manager.open_segment(&mut guard, segment.clone())?;
+            let segment_content = guard.read(0, guard.current_size)?;
+            let table_ranges = guard.table_ranges.clone();
+            let records =
+                decode_segment_content(&segment_content, &guard.record_position, &record_encoding)?;
+            let iterator = SegmentLogIterator::new_with_records(
+                log_encoding,
+                record_encoding,
+                records,
+                table_ranges,
+                table_id,
+                start,
+                end,
+            )?;
+            Ok(iterator)
+        });
+
+        self.next_segment_iterator = Some(handle);
+    }
+
     fn load_next_segment_iterator(&mut self) -> Result<bool> {
         if self.current_segment_idx >= self.segments.len() {
             self.current_iterator = None;
             return Ok(false);
         }
 
-        let segment = self.segments[self.current_segment_idx].clone();
-        let iterator = SegmentLogIterator::new(
-            self.log_encoding.clone(),
-            self.record_encoding.clone(),
-            segment,
-            self.segment_manager.clone(),
-            self.table_id,
-            self.start,
-            self.end,
-        )?;
+        if let Some(handle) = self.next_segment_iterator.take() {
+            // Wait for the future to complete
+            let iterator = self
+                .runtime
+                .block_on(handle)
+                .map_err(anyhow::Error::new)
+                .context(Internal)??;
+            self.current_iterator = Some(iterator);
+            self.current_segment_idx += 1;
+        } else {
+            // Preload was not set, load synchronously
+            let segment = self.segments[self.current_segment_idx].clone();
+            let mut guard = segment.lock().unwrap();
+            self.segment_manager
+                .open_segment(&mut guard, segment.clone())?;
+            let segment_content = guard.read(0, guard.current_size)?;
+            let table_ranges = guard.table_ranges.clone();
+            let records = decode_segment_content(
+                &segment_content,
+                &guard.record_position,
+                &self.record_encoding,
+            )?;
+            let iterator = SegmentLogIterator::new_with_records(
+                self.log_encoding.clone(),
+                self.record_encoding.clone(),
+                records,
+                table_ranges,
+                self.table_id,
+                self.start,
+                self.end,
+            )?;
+            self.current_iterator = Some(iterator);
+            self.current_segment_idx += 1;
+        }
 
-        self.current_iterator = Some(iterator);
-        self.current_segment_idx += 1;
+        // Preload the next segment
+        self.preload_next_segment();
 
         Ok(true)
     }
diff --git a/src/wal/src/manager.rs b/src/wal/src/manager.rs
index 9c4a960b..fcd017dc 100644
--- a/src/wal/src/manager.rs
+++ b/src/wal/src/manager.rs
@@ -27,6 +27,7 @@ use common_types::{
 };
 pub use error::*;
 use generic_error::BoxError;
+use rayon::{iter::ParallelIterator, prelude::IntoParallelIterator};
 use runtime::Runtime;
 use snafu::ResultExt;
 
@@ -428,13 +429,29 @@ impl BatchLogIteratorAdapter {
         let batch_size = self.batch_size;
         let (log_entries, iter_opt) = runtime
             .spawn_blocking(move || {
-                while buffer.len() < batch_size {
+                let mut raw_entries = Vec::new();
+
+                while raw_entries.len() < batch_size {
                     if let Some(raw_log_entry) = iter.next_log_entry()? {
                         if !filter(raw_log_entry.table_id) {
                             continue;
                         }
 
-                        let mut raw_payload = raw_log_entry.payload;
+                        raw_entries.push(LogEntry {
+                            table_id: raw_log_entry.table_id,
+                            sequence: raw_log_entry.sequence,
+                            payload: raw_log_entry.payload.to_vec(),
+                        });
+                    } else {
+                        break;
+                    }
+                }
+
+                // Decoding is time-consuming, so we do it in parallel.
+                let result: Result<VecDeque<_>> = raw_entries
+                    .into_par_iter()
+                    .map(|raw_log_entry| {
+                        let mut raw_payload = raw_log_entry.payload.as_slice();
                         let ctx = PayloadDecodeContext {
                             table_id: raw_log_entry.table_id,
                         };
@@ -442,18 +459,20 @@ impl BatchLogIteratorAdapter {
                             .decode(&ctx, &mut raw_payload)
                             .box_err()
                             .context(error::Decoding)?;
-                        let log_entry = LogEntry {
+                        Ok(LogEntry {
                             table_id: raw_log_entry.table_id,
                             sequence: raw_log_entry.sequence,
                             payload,
-                        };
-                        buffer.push_back(log_entry);
-                    } else {
-                        return Ok((buffer, None));
-                    }
-                }
+                        })
+                    })
+                    .collect();
 
-                Ok((buffer, Some(iter)))
+                let log_entries = result?;
+                if log_entries.len() < batch_size {
+                    Ok((log_entries, None))
+                } else {
+                    Ok((log_entries, Some(iter)))
+                }
             })
             .await
             .context(RuntimeExec)??;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to