This is an automated email from the ASF dual-hosted git repository.
gkoszyk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iggy.git
The following commit(s) were added to refs/heads/master by this push:
new fc3976019 fix(consensus): prevent UB from untrusted bytes interpreted
as enum discriminants in consensus headers
(#2887)
fc3976019 is described below
commit fc397601977e131de077c7327cc5d09a6ed68781
Author: Krishna Vishal <[email protected]>
AuthorDate: Tue Mar 10 19:51:39 2026 +0530
fix(consensus): prevent UB from untrusted bytes interpreted as enum
discriminants in consensus headers
(#2887)
Closes #2878
## Rationale
Consensus headers store `Command2` and `Operation` enum values directly in
`#[repr(C)]` structs, and those structs were marked `bytemuck::Pod` /
`Zeroable` via `unsafe impl`. `Pod` asserts that every bit pattern is a valid
value, which is false for a struct with enum fields: arbitrary bytes from the
network were reinterpreted as enum discriminants without validation, which is
undefined behavior in Rust.
## What changed?
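- Replace the hand-written `unsafe impl Pod` / `unsafe impl Zeroable` on every
consensus header with derived `bytemuck::CheckedBitPattern` + `NoUninit`
(enabling bytemuck's `derive` feature), and make `ConsensusHeader` require
those traits instead of `Pod + Zeroable`. The `command` and `operation` fields
keep their enum types; headers are now read with
`bytemuck::checked::try_from_bytes`, which validates the enum discriminants
before a typed reference is created.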
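- Implement `NoUninit` / `CheckedBitPattern` for `Command2` by hand (valid
discriminants `0..=12`) and derive them for `Operation`. Add
`ConsensusError::InvalidBitPattern`, returned when a header contains an
out-of-range discriminant. Rename `Operation::Default` (0) to
`Operation::Reserved` and remove the former `Reserved = 200` variant, so the
reserved operation is now encoded as 0 and an operation byte of 200 no longer
decodes.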
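- Replace `From<Message<T>> for MessageBag` with `TryFrom`, returning
`ConsensusError::InvalidCommand` for commands other than `Request`, `Prepare`
and `PrepareOk` instead of panicking via `unreachable!()`. `dispatch_request`
now returns `Result<Receiver<R>, ConsensusError>`, while `dispatch` and the
shard message handler log a warning and drop such messages.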
---
Cargo.lock | 14 +++++
Cargo.toml | 2 +-
DEPENDENCIES.md | 1 +
core/common/src/types/consensus/header.rs | 96 +++++++++++++-----------------
core/common/src/types/consensus/message.rs | 89 ++++++++++++++++-----------
core/consensus/src/plane_helpers.rs | 3 +-
core/metadata/src/stm/mod.rs | 1 -
core/shard/src/lib.rs | 11 ++--
core/shard/src/router.rs | 24 ++++++--
core/simulator/src/packet.rs | 3 +-
10 files changed, 140 insertions(+), 104 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 140692a6f..6f4d7aa02 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1595,6 +1595,20 @@ name = "bytemuck"
version = "1.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec"
+dependencies = [
+ "bytemuck_derive",
+]
+
+[[package]]
+name = "bytemuck_derive"
+version = "1.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.117",
+]
[[package]]
name = "byteorder"
diff --git a/Cargo.toml b/Cargo.toml
index 2640a12a3..d91476771 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -88,7 +88,7 @@ bit-set = "0.8.0"
blake3 = "1.8.3"
bon = "3.9.0"
byte-unit = { version = "5.2.0", default-features = false, features =
["serde", "byte", "std"] }
-bytemuck = { version = "1.25" }
+bytemuck = { version = "1.25", features = ["derive", "min_const_generics"] }
bytes = "1.11.1"
charming = "0.6.0"
chrono = { version = "0.4.44", features = ["serde"] }
diff --git a/DEPENDENCIES.md b/DEPENDENCIES.md
index 7cb477ac4..a4069999c 100644
--- a/DEPENDENCIES.md
+++ b/DEPENDENCIES.md
@@ -131,6 +131,7 @@ bytecheck: 0.6.12, "MIT",
bytecheck_derive: 0.6.12, "MIT",
bytecount: 0.6.9, "Apache-2.0 OR MIT",
bytemuck: 1.25.0, "Apache-2.0 OR MIT OR Zlib",
+bytemuck_derive: 1.10.2, "Apache-2.0 OR MIT OR Zlib",
byteorder: 1.5.0, "MIT OR Unlicense",
byteorder-lite: 0.1.0, "MIT OR Unlicense",
bytes: 1.11.1, "MIT",
diff --git a/core/common/src/types/consensus/header.rs
b/core/common/src/types/consensus/header.rs
index c096fa44f..e7d62e758 100644
--- a/core/common/src/types/consensus/header.rs
+++ b/core/common/src/types/consensus/header.rs
@@ -15,12 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-use bytemuck::{Pod, Zeroable};
+use bytemuck::{CheckedBitPattern, NoUninit};
use enumset::EnumSetType;
use thiserror::Error;
const HEADER_SIZE: usize = 256;
-pub trait ConsensusHeader: Sized + Pod + Zeroable {
+pub trait ConsensusHeader: Sized + CheckedBitPattern + NoUninit {
const COMMAND: Command2;
fn validate(&self) -> Result<(), ConsensusError>;
@@ -51,6 +51,18 @@ pub enum Command2 {
StartView = 12,
}
+// SAFETY: Command2 is #[repr(u8)] with no padding bytes.
+unsafe impl NoUninit for Command2 {}
+
+// SAFETY: Command2 is #[repr(u8)]; is_valid_bit_pattern matches all defined
discriminants.
+unsafe impl CheckedBitPattern for Command2 {
+ type Bits = u8;
+
+ fn is_valid_bit_pattern(bits: &u8) -> bool {
+ *bits <= 12
+ }
+}
+
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ConsensusError {
#[error("invalid command: expected {expected:?}, found {found:?}")]
@@ -89,13 +101,16 @@ pub enum ConsensusError {
#[error("context_padding must be 0")]
ReplyContextPaddingNonZero,
+
+ #[error("invalid bit pattern in header (enum discriminant out of range)")]
+ InvalidBitPattern,
}
-#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, NoUninit,
CheckedBitPattern)]
#[repr(u8)]
pub enum Operation {
#[default]
- Default = 0,
+ Reserved = 0,
CreateStream = 128,
UpdateStream = 129,
DeleteStream = 130,
@@ -120,8 +135,6 @@ pub enum Operation {
// Partition operations (replicated via consensus)
SendMessages = 160,
StoreConsumerOffset = 161,
-
- Reserved = 200,
}
impl Operation {
@@ -165,7 +178,7 @@ impl Operation {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct GenericHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -194,14 +207,11 @@ const _: () = {
);
};
-unsafe impl Pod for GenericHeader {}
-unsafe impl Zeroable for GenericHeader {}
-
impl ConsensusHeader for GenericHeader {
const COMMAND: Command2 = Command2::Reserved;
fn operation(&self) -> Operation {
- Operation::Default
+ Operation::Reserved
}
fn command(&self) -> Command2 {
@@ -218,7 +228,7 @@ impl ConsensusHeader for GenericHeader {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct RequestHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -263,14 +273,14 @@ impl Default for RequestHeader {
size: 0,
view: 0,
release: 0,
- command: Default::default(),
+ command: Command2::Reserved,
replica: 0,
reserved_frame: [0; 66],
client: 0,
request_checksum: 0,
timestamp: 0,
request: 0,
- operation: Default::default(),
+ operation: Operation::Reserved,
operation_padding: [0; 7],
namespace: 0,
reserved: [0; 64],
@@ -278,9 +288,6 @@ impl Default for RequestHeader {
}
}
-unsafe impl Pod for RequestHeader {}
-unsafe impl Zeroable for RequestHeader {}
-
impl ConsensusHeader for RequestHeader {
const COMMAND: Command2 = Command2::Request;
@@ -308,7 +315,7 @@ impl ConsensusHeader for RequestHeader {
// TODO: Manually impl default (and use a const for the `release`)
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct PrepareHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -347,9 +354,6 @@ const _: () = {
);
};
-unsafe impl Pod for PrepareHeader {}
-unsafe impl Zeroable for PrepareHeader {}
-
impl ConsensusHeader for PrepareHeader {
const COMMAND: Command2 = Command2::Prepare;
@@ -384,7 +388,7 @@ impl Default for PrepareHeader {
size: 0,
view: 0,
release: 0,
- command: Default::default(),
+ command: Command2::Reserved,
replica: 0,
reserved_frame: [0; 66],
client: 0,
@@ -394,7 +398,7 @@ impl Default for PrepareHeader {
commit: 0,
timestamp: 0,
request: 0,
- operation: Default::default(),
+ operation: Operation::Reserved,
operation_padding: [0; 7],
namespace: 0,
reserved: [0; 32],
@@ -404,7 +408,7 @@ impl Default for PrepareHeader {
// TODO: Manually impl default (and use a const for the `release`)
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct PrepareOkHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -442,9 +446,6 @@ const _: () = {
);
};
-unsafe impl Pod for PrepareOkHeader {}
-unsafe impl Zeroable for PrepareOkHeader {}
-
impl ConsensusHeader for PrepareOkHeader {
const COMMAND: Command2 = Command2::PrepareOk;
@@ -479,7 +480,7 @@ impl Default for PrepareOkHeader {
size: 0,
view: 0,
release: 0,
- command: Default::default(),
+ command: Command2::Reserved,
replica: 0,
reserved_frame: [0; 66],
parent: 0,
@@ -488,7 +489,7 @@ impl Default for PrepareOkHeader {
commit: 0,
timestamp: 0,
request: 0,
- operation: Default::default(),
+ operation: Operation::Reserved,
operation_padding: [0; 7],
namespace: 0,
reserved: [0; 48],
@@ -497,7 +498,7 @@ impl Default for PrepareOkHeader {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct CommitHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -531,14 +532,11 @@ const _: () = {
);
};
-unsafe impl Pod for CommitHeader {}
-unsafe impl Zeroable for CommitHeader {}
-
impl ConsensusHeader for CommitHeader {
const COMMAND: Command2 = Command2::Commit;
fn operation(&self) -> Operation {
- Operation::Default
+ Operation::Reserved
}
fn command(&self) -> Command2 {
self.command
@@ -560,7 +558,7 @@ impl ConsensusHeader for CommitHeader {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, CheckedBitPattern, NoUninit)]
pub struct ReplyHeader {
pub checksum: u128,
pub checksum_body: u128,
@@ -598,9 +596,6 @@ const _: () = {
);
};
-unsafe impl Pod for ReplyHeader {}
-unsafe impl Zeroable for ReplyHeader {}
-
impl ConsensusHeader for ReplyHeader {
const COMMAND: Command2 = Command2::Reply;
@@ -632,7 +627,7 @@ impl Default for ReplyHeader {
size: 0,
view: 0,
release: 0,
- command: Default::default(),
+ command: Command2::Reserved,
replica: 0,
reserved_frame: [0; 66],
request_checksum: 0,
@@ -641,7 +636,7 @@ impl Default for ReplyHeader {
commit: 0,
timestamp: 0,
request: 0,
- operation: Default::default(),
+ operation: Operation::Reserved,
operation_padding: [0; 7],
namespace: 0,
reserved: [0; 48],
@@ -653,7 +648,7 @@ impl Default for ReplyHeader {
///
/// Sent by a replica when it suspects the primary has failed.
/// This is a header-only message with no body.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, CheckedBitPattern, NoUninit)]
#[repr(C)]
pub struct StartViewChangeHeader {
pub checksum: u128,
@@ -684,14 +679,11 @@ const _: () = {
);
};
-unsafe impl Pod for StartViewChangeHeader {}
-unsafe impl Zeroable for StartViewChangeHeader {}
-
impl ConsensusHeader for StartViewChangeHeader {
const COMMAND: Command2 = Command2::StartViewChange;
fn operation(&self) -> Operation {
- Operation::Default
+ Operation::Reserved
}
fn command(&self) -> Command2 {
self.command
@@ -720,7 +712,7 @@ impl ConsensusHeader for StartViewChangeHeader {
///
/// Sent by replicas to the primary candidate after collecting a quorum of
/// StartViewChange messages.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, CheckedBitPattern, NoUninit)]
#[repr(C)]
pub struct DoViewChangeHeader {
pub checksum: u128,
@@ -761,14 +753,11 @@ const _: () = {
);
};
-unsafe impl Pod for DoViewChangeHeader {}
-unsafe impl Zeroable for DoViewChangeHeader {}
-
impl ConsensusHeader for DoViewChangeHeader {
const COMMAND: Command2 = Command2::DoViewChange;
fn operation(&self) -> Operation {
- Operation::Default
+ Operation::Reserved
}
fn command(&self) -> Command2 {
self.command
@@ -813,7 +802,7 @@ impl ConsensusHeader for DoViewChangeHeader {
///
/// Sent by the new primary to all replicas after collecting a quorum of
/// DoViewChange messages.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, CheckedBitPattern, NoUninit)]
#[repr(C)]
pub struct StartViewHeader {
pub checksum: u128,
@@ -851,14 +840,11 @@ const _: () = {
);
};
-unsafe impl Pod for StartViewHeader {}
-unsafe impl Zeroable for StartViewHeader {}
-
impl ConsensusHeader for StartViewHeader {
const COMMAND: Command2 = Command2::StartView;
fn operation(&self) -> Operation {
- Operation::Default
+ Operation::Reserved
}
fn command(&self) -> Command2 {
self.command
diff --git a/core/common/src/types/consensus/message.rs
b/core/common/src/types/consensus/message.rs
index 7ef39cdca..5a49671db 100644
--- a/core/common/src/types/consensus/message.rs
+++ b/core/common/src/types/consensus/message.rs
@@ -37,7 +37,8 @@ where
{
fn header(&self) -> &H {
let header_bytes = &self.buffer[..size_of::<H>()];
- bytemuck::from_bytes(header_bytes)
+ bytemuck::checked::try_from_bytes(header_bytes)
+ .expect("header validated at construction time")
}
}
@@ -55,22 +56,17 @@ where
#[allow(unused)]
pub fn header(&self) -> &H {
let header_bytes = &self.buffer[..size_of::<H>()];
- bytemuck::from_bytes(header_bytes)
+ bytemuck::checked::try_from_bytes(header_bytes)
+ .expect("header validated at construction time")
}
/// Create a new message from a buffer.
///
- /// # Safety
- ///
- /// The buffer must:
- /// - be at least `size_of::<H>()` bytes long
- /// - contain a valid header at the start
- /// - be properly aligned for type H
- ///
/// # Errors
///
/// Returns an error if:
/// - buffer is too small for the header
+ /// - buffer contains invalid bit patterns (enum discriminants)
/// - buffer is too small for the size specified in the header
/// - header validation fails
#[allow(unused)]
@@ -83,24 +79,27 @@ where
});
}
- let message = Self {
- buffer,
- _marker: PhantomData,
- };
+ // Validate bit patterns (enum discriminants) via try_from_bytes
+ let header_bytes = &buffer[..size_of::<H>()];
+ let header = bytemuck::checked::try_from_bytes::<H>(header_bytes)
+ .map_err(|_| header::ConsensusError::InvalidBitPattern)?;
// validate the header
- message.header().validate()?;
+ header.validate()?;
// verify buffer size matches header.size
- let header_size = message.header().size() as usize;
- if message.buffer.len() < header_size {
+ let header_size = header.size() as usize;
+ if buffer.len() < header_size {
return Err(header::ConsensusError::InvalidCommand {
expected: H::COMMAND,
found: header::Command2::Reserved,
});
}
- Ok(message)
+ Ok(Self {
+ buffer,
+ _marker: PhantomData,
+ })
}
/// Create a new message with a specific size, initializing the buffer
with zeros.
@@ -131,7 +130,10 @@ where
unsafe {
let ptr = buffer.as_ptr() as *mut u8;
let slice = std::slice::from_raw_parts_mut(ptr, size_of::<T>());
- let new_header = bytemuck::from_bytes_mut(slice);
+ // Zero out to ensure valid bit patterns before creating a typed
reference.
+ slice.fill(0);
+ let new_header =
+ bytemuck::checked::try_from_bytes_mut(slice).expect("zeroed
bytes are valid");
f(old_header, new_header);
}
@@ -250,11 +252,14 @@ where
});
}
- let new_message = unsafe {
Message::<T>::from_buffer_unchecked(self.buffer) };
+ // Validate bit patterns for the target type
+ let header_bytes = &self.buffer[..size_of::<T>()];
+ let header = bytemuck::checked::try_from_bytes::<T>(header_bytes)
+ .map_err(|_| header::ConsensusError::InvalidBitPattern)?;
- new_message.header().validate()?;
+ header.validate()?;
- Ok(new_message)
+ Ok(unsafe { Message::<T>::from_buffer_unchecked(self.buffer) })
}
/// Try to get a reference to this message as a different header type.
@@ -282,10 +287,13 @@ where
});
}
- let typed_message = unsafe { &*(self as *const Self as *const
Message<T>) };
+ // Validate bit patterns for the target type
+ let header_bytes = &self.buffer[..size_of::<T>()];
+ bytemuck::checked::try_from_bytes::<T>(header_bytes)
+ .map_err(|_| header::ConsensusError::InvalidBitPattern)?
+ .validate()?;
- // validate the header
- typed_message.header().validate()?;
+ let typed_message = unsafe { &*(self as *const Self as *const
Message<T>) };
Ok(typed_message)
}
@@ -328,13 +336,14 @@ impl MessageBag {
}
}
-impl<T> From<Message<T>> for MessageBag
+impl<T> TryFrom<Message<T>> for MessageBag
where
T: ConsensusHeader,
{
- fn from(value: Message<T>) -> Self {
- let command = value.as_generic().header().command;
+ type Error = header::ConsensusError;
+ fn try_from(value: Message<T>) -> Result<Self, Self::Error> {
+ let command = value.as_generic().header().command;
let buffer = value.into_inner();
// SAFETY: All Message<H> types have identical memory layout (only
PhantomData differs).
@@ -343,19 +352,22 @@ where
header::Command2::Prepare => {
let msg =
unsafe {
Message::<header::PrepareHeader>::from_buffer_unchecked(buffer) };
- MessageBag::Prepare(msg)
+ Ok(MessageBag::Prepare(msg))
}
header::Command2::Request => {
let msg =
unsafe {
Message::<header::RequestHeader>::from_buffer_unchecked(buffer) };
- MessageBag::Request(msg)
+ Ok(MessageBag::Request(msg))
}
header::Command2::PrepareOk => {
let msg =
unsafe {
Message::<header::PrepareOkHeader>::from_buffer_unchecked(buffer) };
- MessageBag::PrepareOk(msg)
+ Ok(MessageBag::PrepareOk(msg))
}
- _ => unreachable!(),
+ other => Err(header::ConsensusError::InvalidCommand {
+ expected: header::Command2::Reserved,
+ found: other,
+ }),
}
}
}
@@ -378,7 +390,8 @@ mod tests {
let mut buffer = BytesMut::zeroed(total_size);
- let header = bytemuck::from_bytes_mut::<Self>(&mut
buffer[..header_size]);
+ let header = bytemuck::checked::try_from_bytes_mut::<Self>(&mut
buffer[..header_size])
+ .expect("zeroed bytes are valid");
header.checksum = 123456;
header.cluster = 12345;
@@ -406,7 +419,9 @@ mod tests {
let mut buffer = BytesMut::zeroed(total_size);
- let header = bytemuck::from_bytes_mut::<Self>(&mut
buffer[..header_size]);
+ // Zeroed bytes are valid (Command2::Reserved=0,
Operation::Reserved=0).
+ let header = bytemuck::checked::try_from_bytes_mut::<Self>(&mut
buffer[..header_size])
+ .expect("zeroed bytes are valid");
header.checksum = 123456;
header.checksum_body = 789012;
@@ -431,7 +446,8 @@ mod tests {
let mut buffer = BytesMut::zeroed(total_size);
- let header = bytemuck::from_bytes_mut::<Self>(&mut
buffer[..header_size]);
+ let header = bytemuck::checked::try_from_bytes_mut::<Self>(&mut
buffer[..header_size])
+ .expect("zeroed bytes are valid");
header.checksum = 123456;
header.cluster = 12345;
@@ -453,7 +469,8 @@ mod tests {
let mut buffer = BytesMut::zeroed(total_size);
- let header = bytemuck::from_bytes_mut::<Self>(&mut
buffer[..header_size]);
+ let header = bytemuck::checked::try_from_bytes_mut::<Self>(&mut
buffer[..header_size])
+ .expect("zeroed bytes are valid");
header.checksum = 123456;
header.cluster = 12345;
@@ -515,7 +532,7 @@ mod tests {
#[test]
fn test_message_bag_from_prepare() {
let prepare = header::PrepareHeader::create_test();
- let bag = MessageBag::from(prepare);
+ let bag = MessageBag::try_from(prepare).expect("valid prepare
message");
assert_eq!(bag.command(), header::Command2::Prepare);
assert!(matches!(bag, MessageBag::Prepare(_)));
diff --git a/core/consensus/src/plane_helpers.rs
b/core/consensus/src/plane_helpers.rs
index 55d21534a..31356a448 100644
--- a/core/consensus/src/plane_helpers.rs
+++ b/core/consensus/src/plane_helpers.rs
@@ -240,7 +240,8 @@ where
let total_size = header_size + body.len();
let mut buffer = bytes::BytesMut::zeroed(total_size);
- let header = bytemuck::from_bytes_mut::<ReplyHeader>(&mut
buffer[..header_size]);
+ let header = bytemuck::checked::try_from_bytes_mut::<ReplyHeader>(&mut
buffer[..header_size])
+ .expect("zeroed bytes are valid");
*header = ReplyHeader {
checksum: 0,
checksum_body: 0,
diff --git a/core/metadata/src/stm/mod.rs b/core/metadata/src/stm/mod.rs
index 5171207ae..e25aaaf65 100644
--- a/core/metadata/src/stm/mod.rs
+++ b/core/metadata/src/stm/mod.rs
@@ -248,7 +248,6 @@ macro_rules! collect_handlers {
use ::iggy_common::BytesSerializable;
use ::iggy_common::Either;
use ::iggy_common::header::Operation;
-
match input.header().operation {
$(
Operation::$operation => {
diff --git a/core/shard/src/lib.rs b/core/shard/src/lib.rs
index e41e210a6..4e03d0a4f 100644
--- a/core/shard/src/lib.rs
+++ b/core/shard/src/lib.rs
@@ -171,10 +171,13 @@ where
Error = iggy_common::IggyError,
>,
{
- match MessageBag::from(message) {
- MessageBag::Request(request) => self.on_request(request).await,
- MessageBag::Prepare(prepare) => self.on_replicate(prepare).await,
- MessageBag::PrepareOk(prepare_ok) => self.on_ack(prepare_ok).await,
+ match MessageBag::try_from(message) {
+ Ok(MessageBag::Request(request)) => self.on_request(request).await,
+ Ok(MessageBag::Prepare(prepare)) =>
self.on_replicate(prepare).await,
+ Ok(MessageBag::PrepareOk(prepare_ok)) =>
self.on_ack(prepare_ok).await,
+ Err(e) => {
+ tracing::warn!(shard = self.id, error = %e, "dropping message
with invalid command");
+ }
}
}
diff --git a/core/shard/src/router.rs b/core/shard/src/router.rs
index 0a68b98bf..5e75cb9dc 100644
--- a/core/shard/src/router.rs
+++ b/core/shard/src/router.rs
@@ -18,7 +18,7 @@
use crate::shards_table::ShardsTable;
use crate::{IggyShard, Receiver, ShardFrame};
use futures::FutureExt;
-use iggy_common::header::{GenericHeader, PrepareHeader};
+use iggy_common::header::{ConsensusError, GenericHeader, PrepareHeader};
use iggy_common::message::{Message, MessageBag};
use iggy_common::sharding::IggyNamespace;
use journal::{Journal, JournalHandle};
@@ -44,7 +44,14 @@ where
/// or `PrepareOk`) to access the operation and namespace, then resolves
/// the target shard and enqueues the message via its channel sender.
pub fn dispatch(&self, message: Message<GenericHeader>) {
- let (operation, namespace, generic) = match MessageBag::from(message) {
+ let bag = match MessageBag::try_from(message) {
+ Ok(bag) => bag,
+ Err(e) => {
+ tracing::warn!(shard = self.id, error = %e, "dropping message
with invalid command");
+ return;
+ }
+ };
+ let (operation, namespace, generic) = match bag {
MessageBag::Request(ref r) => {
let h = r.header();
(h.operation, h.namespace, r.as_generic().clone())
@@ -80,8 +87,15 @@ where
/// Dispatch a message and return a receiver that resolves when the target
/// shard has finished processing it.
- pub fn dispatch_request(&self, message: Message<GenericHeader>) ->
Receiver<R> {
- let (operation, namespace, generic) = match MessageBag::from(message) {
+ ///
+ /// # Errors
+ /// Returns `ConsensusError` if the message cannot be routed.
+ pub fn dispatch_request(
+ &self,
+ message: Message<GenericHeader>,
+ ) -> Result<Receiver<R>, ConsensusError> {
+ let bag = MessageBag::try_from(message)?;
+ let (operation, namespace, generic) = match bag {
MessageBag::Request(ref r) => {
let h = r.header();
(h.operation, h.namespace, r.as_generic().clone())
@@ -123,7 +137,7 @@ where
// Create a frame and send it to the target shard.
let (frame, rx) = ShardFrame::<R>::with_response(generic);
let _ = self.senders[target as usize].send(frame);
- rx
+ Ok(rx)
}
/// Drain this shard's inbox and process each frame locally.
diff --git a/core/simulator/src/packet.rs b/core/simulator/src/packet.rs
index 59317adcc..9aba6e17a 100644
--- a/core/simulator/src/packet.rs
+++ b/core/simulator/src/packet.rs
@@ -741,7 +741,8 @@ mod tests {
fn create_test_message_with_command(command: Command2) ->
Message<GenericHeader> {
let size = std::mem::size_of::<GenericHeader>();
let mut buf = vec![0u8; size];
- let header: &mut GenericHeader = bytemuck::from_bytes_mut(&mut buf);
+ let header: &mut GenericHeader =
+ bytemuck::checked::try_from_bytes_mut(&mut buf).expect("zeroed
bytes are valid");
header.command = command;
Message::<GenericHeader>::from_bytes(bytes::Bytes::from(buf)).unwrap()
}