This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new a565643ed fix(c++): fix buffer read/write bound check (#3418)
a565643ed is described below
commit a565643ed0c3abf3fa786595f6f8dbf76a22cc5e
Author: Shawn Yang <[email protected]>
AuthorDate: Thu Feb 26 01:30:54 2026 +0800
fix(c++): fix buffer read/write bound check (#3418)
## Why?
This PR hardens the C++ and Rust deserialization paths against
truncated or corrupt inputs and prevents inconsistent type registration
state. It also enforces that xlang and non-xlang payloads are only
deserialized by a Fory instance configured with the matching protocol.
## What does this PR do?
- C++: lock type registration after the first serialize/deserialize call,
and route all `register_*` APIs through the guarded `register_type(...)`
helper so that late registrations fail with an error.
- C++: reject payloads whose header `is_xlang` flag differs from the
local `xlang` config, returning an `InvalidData` protocol-mismatch error.
- C++: make `TypeResolver::register_type_internal` validate all uniqueness
constraints before committing any entries, so failed registrations do not
leak partial type info.
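- C++: harden `TypeMeta` size handling (rejecting metadata whose parsed
content exceeds its declared size) and `Buffer` varint/fixed-width reads
with strict bounds checks; reads on truncated data now report an error
without advancing the reader index.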
- Rust: add overflow-safe (`checked_add`) bound checks to `Reader`
fixed-width reads and `read_varuint36small`, instead of relying on slice
reads that panic on short input.
- Rust: make the row array `get(...)` APIs (`ArrayGetter`,
`FixedArrayGetter`) return `Result` instead of panicking on out-of-bounds
indices, and propagate these errors during map materialization.
- Tests: add C++ serialization/buffer regression tests and Rust
buffer/row tests covering the new error paths.
## Related issues
- None.
## Does this PR introduce any user-facing change?
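- [x] Does this PR introduce any public API change?
  (Rust row `get(...)` now returns `Result`; C++ type registration after
  the first serialize/deserialize call now returns an error.)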
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
- N/A
---
.gitignore | 4 +-
cpp/fory/serialization/fory.h | 82 ++++++++++++++++++----
cpp/fory/serialization/serialization_test.cc | 101 +++++++++++++++++++++++++++
cpp/fory/serialization/serializer.h | 16 ++---
cpp/fory/serialization/type_resolver.cc | 24 +++++--
cpp/fory/serialization/type_resolver.h | 51 +++++++++-----
cpp/fory/util/buffer.h | 97 +++++++++++++++++++++----
cpp/fory/util/buffer_test.cc | 51 ++++++++++++++
rust/fory-core/src/buffer.rs | 35 ++++++----
rust/fory-core/src/row/row.rs | 22 ++++--
rust/fory/src/lib.rs | 2 +-
rust/tests/tests/test_buffer.rs | 20 ++++++
rust/tests/tests/test_row.rs | 29 ++++----
13 files changed, 438 insertions(+), 96 deletions(-)
diff --git a/.gitignore b/.gitignore
index 2c0f134e6..257ce535a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -115,4 +115,6 @@ csharp/src/Fory/obj/
csharp/tests/Fory.Tests/bin/
csharp/tests/Fory.Tests/obj/
csharp/tests/Fory.XlangPeer/bin/
-csharp/tests/Fory.XlangPeer/obj/
\ No newline at end of file
+csharp/tests/Fory.XlangPeer/obj/
+
+tasks/
\ No newline at end of file
diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h
index 8b1428c96..53735400b 100644
--- a/cpp/fory/serialization/fory.h
+++ b/cpp/fory/serialization/fory.h
@@ -230,7 +230,9 @@ public:
/// fory.register_struct<MyStruct>(1);
/// ```
template <typename T> Result<void, Error> register_struct(uint32_t type_id) {
- return type_resolver_->template register_by_id<T>(type_id);
+ return register_type([this, type_id]() {
+ return type_resolver_->template register_by_id<T>(type_id);
+ });
}
/// Register a struct type with namespace and type name.
@@ -250,7 +252,9 @@ public:
template <typename T>
Result<void, Error> register_struct(const std::string &ns,
const std::string &type_name) {
- return type_resolver_->template register_by_name<T>(ns, type_name);
+ return register_type([this, &ns, &type_name]() {
+ return type_resolver_->template register_by_name<T>(ns, type_name);
+ });
}
/// Register a struct type with type name only (no namespace).
@@ -267,7 +271,9 @@ public:
/// ```
template <typename T>
Result<void, Error> register_struct(const std::string &type_name) {
- return type_resolver_->template register_by_name<T>("", type_name);
+ return register_type([this, &type_name]() {
+ return type_resolver_->template register_by_name<T>("", type_name);
+ });
}
/// Register an enum type with a numeric type ID.
@@ -288,7 +294,9 @@ public:
/// fory.register_enum<Color>(1);
/// ```
template <typename T> Result<void, Error> register_enum(uint32_t type_id) {
- return type_resolver_->template register_by_id<T>(type_id);
+ return register_type([this, type_id]() {
+ return type_resolver_->template register_by_id<T>(type_id);
+ });
}
/// Register an enum type with namespace and type name.
@@ -308,7 +316,9 @@ public:
template <typename T>
Result<void, Error> register_enum(const std::string &ns,
const std::string &type_name) {
- return type_resolver_->template register_by_name<T>(ns, type_name);
+ return register_type([this, &ns, &type_name]() {
+ return type_resolver_->template register_by_name<T>(ns, type_name);
+ });
}
/// Register an enum type with type name only (no namespace).
@@ -325,7 +335,9 @@ public:
/// ```
template <typename T>
Result<void, Error> register_enum(const std::string &type_name) {
- return type_resolver_->template register_by_name<T>("", type_name);
+ return register_type([this, &type_name]() {
+ return type_resolver_->template register_by_name<T>("", type_name);
+ });
}
/// Register a union type with a numeric type ID.
@@ -336,7 +348,9 @@ public:
/// @param type_id Unique numeric identifier for this union type.
/// @return Success or error if registration fails.
template <typename T> Result<void, Error> register_union(uint32_t type_id) {
- return type_resolver_->template register_union_by_id<T>(type_id);
+ return register_type([this, type_id]() {
+ return type_resolver_->template register_union_by_id<T>(type_id);
+ });
}
/// Register a union type with namespace and type name.
@@ -348,7 +362,9 @@ public:
template <typename T>
Result<void, Error> register_union(const std::string &ns,
const std::string &type_name) {
- return type_resolver_->template register_union_by_name<T>(ns, type_name);
+ return register_type([this, &ns, &type_name]() {
+ return type_resolver_->template register_union_by_name<T>(ns, type_name);
+ });
}
/// Register a union type with type name only (no namespace).
@@ -358,7 +374,9 @@ public:
/// @return Success or error if registration fails.
template <typename T>
Result<void, Error> register_union(const std::string &type_name) {
- return type_resolver_->template register_union_by_name<T>("", type_name);
+ return register_type([this, &type_name]() {
+ return type_resolver_->template register_union_by_name<T>("", type_name);
+ });
}
/// Register an extension type with a numeric type ID.
@@ -371,7 +389,9 @@ public:
/// @return Success or error if registration fails.
template <typename T>
Result<void, Error> register_extension_type(uint32_t type_id) {
- return type_resolver_->template register_ext_type_by_id<T>(type_id);
+ return register_type([this, type_id]() {
+ return type_resolver_->template register_ext_type_by_id<T>(type_id);
+ });
}
/// Register an extension type with namespace and type name.
@@ -383,7 +403,10 @@ public:
template <typename T>
Result<void, Error> register_extension_type(const std::string &ns,
const std::string &type_name) {
- return type_resolver_->template register_ext_type_by_name<T>(ns,
type_name);
+ return register_type([this, &ns, &type_name]() {
+ return type_resolver_->template register_ext_type_by_name<T>(ns,
+ type_name);
+ });
}
/// Register an extension type with type name only (no namespace).
@@ -393,10 +416,29 @@ public:
/// @return Success or error if registration fails.
template <typename T>
Result<void, Error> register_extension_type(const std::string &type_name) {
- return type_resolver_->template register_ext_type_by_name<T>("",
type_name);
+ return register_type([this, &type_name]() {
+ return type_resolver_->template register_ext_type_by_name<T>("",
+ type_name);
+ });
+ }
+
+private:
+ template <typename RegisterFn>
+ Result<void, Error> register_type(RegisterFn &&fn) {
+ std::lock_guard<std::mutex> lock(registration_mutex_);
+ if (FORY_PREDICT_FALSE(registration_locked_)) {
+ return Unexpected(Error::invalid(
+ "Cannot register types after first serialize/deserialize call"));
+ }
+ return std::forward<RegisterFn>(fn)();
}
protected:
+ void lock_registration() const {
+ std::lock_guard<std::mutex> lock(registration_mutex_);
+ registration_locked_ = true;
+ }
+
/// Protected constructor - only derived classes can instantiate.
explicit BaseFory(const Config &config,
std::shared_ptr<TypeResolver> resolver)
@@ -412,6 +454,8 @@ protected:
Config config_;
std::shared_ptr<TypeResolver> type_resolver_;
+ mutable std::mutex registration_mutex_;
+ mutable bool registration_locked_{false};
};
// ============================================================================
@@ -527,6 +571,12 @@ public:
if (header.is_null) {
return Unexpected(Error::invalid_data("Cannot deserialize null object"));
}
+ if (FORY_PREDICT_FALSE(header.is_xlang != config_.xlang)) {
+ return Unexpected(Error::invalid_data(
+ "Protocol mismatch: payload xlang=" +
+ std::string(header.is_xlang ? "true" : "false") +
+ ", local xlang=" + std::string(config_.xlang ? "true" : "false")));
+ }
read_ctx_->attach(buffer);
ReadContextGuard guard(*read_ctx_);
@@ -560,6 +610,12 @@ public:
if (header.is_null) {
return Unexpected(Error::invalid_data("Cannot deserialize null object"));
}
+ if (FORY_PREDICT_FALSE(header.is_xlang != config_.xlang)) {
+ return Unexpected(Error::invalid_data(
+ "Protocol mismatch: payload xlang=" +
+ std::string(header.is_xlang ? "true" : "false") +
+ ", local xlang=" + std::string(config_.xlang ? "true" : "false")));
+ }
read_ctx_->attach(buffer);
ReadContextGuard guard(*read_ctx_);
@@ -601,6 +657,7 @@ private:
/// Finalize the type resolver on first use.
void ensure_finalized() {
if (!finalized_) {
+ lock_registration();
auto final_result = type_resolver_->build_final_type_resolver();
FORY_CHECK(final_result.ok())
<< "Failed to build finalized TypeResolver: "
@@ -746,6 +803,7 @@ private:
std::shared_ptr<TypeResolver> get_finalized_resolver() const {
std::call_once(finalized_once_flag_, [this]() {
+ lock_registration();
auto final_result = type_resolver_->build_final_type_resolver();
FORY_CHECK(final_result.ok())
<< "Failed to build finalized TypeResolver: "
diff --git a/cpp/fory/serialization/serialization_test.cc
b/cpp/fory/serialization/serialization_test.cc
index 110b75daf..2f6b8ab63 100644
--- a/cpp/fory/serialization/serialization_test.cc
+++ b/cpp/fory/serialization/serialization_test.cc
@@ -417,6 +417,90 @@ TEST(SerializationTest, DeserializeZeroSize) {
EXPECT_FALSE(result.ok());
}
+TEST(SerializationTest, DeserializeRejectsXlangProtocolMismatch) {
+ auto writer = Fory::builder().xlang(true).build();
+ auto reader = Fory::builder().xlang(false).build();
+
+ auto bytes_result = writer.serialize<int32_t>(123);
+ ASSERT_TRUE(bytes_result.ok())
+ << "Serialization failed: " << bytes_result.error().to_string();
+
+ auto result = reader.deserialize<int32_t>(bytes_result.value().data(),
+ bytes_result.value().size());
+ EXPECT_FALSE(result.ok());
+ ASSERT_FALSE(result.ok());
+ EXPECT_EQ(result.error().code(), ErrorCode::InvalidData);
+ EXPECT_NE(result.error().to_string().find("Protocol mismatch"),
+ std::string::npos);
+}
+
+TEST(SerializationTest, RegistrationByIdFailureDoesNotLeakTypeInfo) {
+ auto fory = Fory::builder().xlang(true).track_ref(false).build();
+ TypeResolver &resolver = fory.type_resolver();
+
+ ASSERT_TRUE(fory.register_struct<::SimpleStruct>(1).ok());
+
+ auto duplicate = fory.register_struct<::NestedStruct>(1);
+ EXPECT_FALSE(duplicate.ok());
+ ASSERT_FALSE(duplicate.ok());
+ EXPECT_EQ(duplicate.error().code(), ErrorCode::Invalid);
+
+ auto nested_info = resolver.get_type_info<::NestedStruct>();
+ EXPECT_FALSE(nested_info.ok());
+
+ auto simple_info = resolver.get_type_info<::SimpleStruct>();
+ ASSERT_TRUE(simple_info.ok());
+ auto by_user_id =
+ resolver.get_user_type_info_by_id(simple_info.value()->type_id, 1);
+ ASSERT_TRUE(by_user_id.ok());
+ EXPECT_EQ(by_user_id.value(), simple_info.value());
+}
+
+TEST(SerializationTest, RegistrationByNameFailureDoesNotLeakTypeInfo) {
+ auto fory = Fory::builder().xlang(true).track_ref(false).build();
+ TypeResolver &resolver = fory.type_resolver();
+
+ ASSERT_TRUE(fory.register_struct<::SimpleStruct>("demo", "SharedType").ok());
+
+ auto duplicate = fory.register_struct<::NestedStruct>("demo", "SharedType");
+ EXPECT_FALSE(duplicate.ok());
+ ASSERT_FALSE(duplicate.ok());
+ EXPECT_EQ(duplicate.error().code(), ErrorCode::Invalid);
+
+ auto nested_info = resolver.get_type_info<::NestedStruct>();
+ EXPECT_FALSE(nested_info.ok());
+
+ auto simple_info = resolver.get_type_info<::SimpleStruct>();
+ ASSERT_TRUE(simple_info.ok());
+ auto by_name = resolver.get_type_info_by_name("demo", "SharedType");
+ ASSERT_TRUE(by_name.ok());
+ EXPECT_EQ(by_name.value(), simple_info.value());
+}
+
+TEST(SerializationTest, TypeMetaRejectsOverConsumedDeclaredSize) {
+ TypeMeta meta =
+ TypeMeta::from_fields(static_cast<uint32_t>(TypeId::STRUCT), "", "S",
+ false, 1, std::vector<FieldInfo>{});
+ auto bytes_result = meta.to_bytes();
+ ASSERT_TRUE(bytes_result.ok())
+ << "TypeMeta serialization failed: " << bytes_result.error().to_string();
+
+ std::vector<uint8_t> bytes = bytes_result.value();
+ ASSERT_GE(bytes.size(), sizeof(int64_t));
+
+ int64_t header = 0;
+ std::memcpy(&header, bytes.data(), sizeof(header));
+ // Corrupt declared meta_size to be much smaller than actual payload.
+ header = (header & ~static_cast<int64_t>(0xFF)) | 0x01;
+ std::memcpy(bytes.data(), &header, sizeof(header));
+
+ Buffer buffer(bytes);
+ auto parsed = TypeMeta::from_bytes(buffer, nullptr);
+ EXPECT_FALSE(parsed.ok());
+ ASSERT_FALSE(parsed.ok());
+ EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData);
+}
+
// ============================================================================
// Configuration Tests
// ============================================================================
@@ -478,6 +562,23 @@ TEST(SerializationTest, ThreadSafeForyMultiThread) {
EXPECT_EQ(success_count.load(), k_num_threads * k_iterations_per_thread);
}
+TEST(SerializationTest, ThreadSafeForyRejectsRegistrationAfterFirstSerialize) {
+ auto fory = Fory::builder().xlang(true).track_ref(false).build_thread_safe();
+ ASSERT_TRUE(fory.register_struct<::ComplexStruct>(1).ok());
+
+ ::ComplexStruct original{"Alice", 30, {"reading", "coding"}};
+ auto bytes_result = fory.serialize(original);
+ ASSERT_TRUE(bytes_result.ok())
+ << "Serialization failed: " << bytes_result.error().to_string();
+
+ auto late_registration = fory.register_struct<::SimpleStruct>(2);
+ EXPECT_FALSE(late_registration.ok());
+ ASSERT_FALSE(late_registration.ok());
+ EXPECT_EQ(late_registration.error().code(), ErrorCode::Invalid);
+ EXPECT_NE(late_registration.error().to_string().find("Cannot register
types"),
+ std::string::npos);
+}
+
} // namespace test
} // namespace serialization
} // namespace fory
diff --git a/cpp/fory/serialization/serializer.h
b/cpp/fory/serialization/serializer.h
index 5eeae493e..36477101a 100644
--- a/cpp/fory/serialization/serializer.h
+++ b/cpp/fory/serialization/serializer.h
@@ -83,24 +83,16 @@ struct HeaderInfo {
/// @param buffer Input buffer
/// @return Header information or error
inline Result<HeaderInfo, Error> read_header(Buffer &buffer) {
- // Check minimum header size (1 byte: flags)
- if (buffer.reader_index() + 1 > buffer.size()) {
- return Unexpected(
- Error::buffer_out_of_bound(buffer.reader_index(), 1, buffer.size()));
+ Error error;
+ uint8_t flags = buffer.read_uint8(error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ return Unexpected(std::move(error));
}
-
HeaderInfo info;
- uint32_t start_pos = buffer.reader_index();
-
- // Read flags byte
- uint8_t flags = buffer.get_byte_as<uint8_t>(start_pos);
info.is_null = (flags & (1 << 0)) != 0;
info.is_xlang = (flags & (1 << 1)) != 0;
info.is_oob = (flags & (1 << 2)) != 0;
- // Update reader index (1 byte consumed: flags)
- buffer.increase_reader_index(1);
-
// Note: Meta start offset would be read here if present
info.meta_start_offset = 0;
diff --git a/cpp/fory/serialization/type_resolver.cc
b/cpp/fory/serialization/type_resolver.cc
index a5c993017..e8070a3e9 100644
--- a/cpp/fory/serialization/type_resolver.cc
+++ b/cpp/fory/serialization/type_resolver.cc
@@ -553,9 +553,16 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta
*local_type_info) {
// CRITICAL FIX: Ensure we consume exactly meta_size bytes
size_t current_pos = buffer.reader_index();
size_t expected_end_pos = start_pos + header_size + meta_size;
+ if (FORY_PREDICT_FALSE(current_pos > expected_end_pos)) {
+ return Unexpected(Error::invalid_data(
+ "TypeMeta parser consumed beyond declared meta size"));
+ }
if (current_pos < expected_end_pos) {
size_t remaining = expected_end_pos - current_pos;
- buffer.increase_reader_index(remaining);
+ buffer.skip(static_cast<uint32_t>(remaining), error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ return Unexpected(std::move(error));
+ }
}
auto meta = std::make_unique<TypeMeta>();
@@ -574,16 +581,12 @@ Result<std::unique_ptr<TypeMeta>, Error>
TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) {
Error error;
int64_t meta_size = header & META_SIZE_MASK;
- size_t header_size = 0;
if (meta_size == META_SIZE_MASK) {
- uint32_t before = buffer.reader_index();
uint32_t extra = buffer.read_var_uint32(error);
if (FORY_PREDICT_FALSE(!error.ok())) {
return Unexpected(std::move(error));
}
meta_size += extra;
- uint32_t after = buffer.reader_index();
- header_size = (after - before);
}
int64_t meta_hash = header >> (64 - NUM_HASH_BITS);
@@ -663,10 +666,17 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t
header) {
// CRITICAL FIX: Ensure we consume exactly meta_size bytes
size_t current_pos = buffer.reader_index();
- size_t expected_end_pos = start_pos + meta_size - header_size;
+ size_t expected_end_pos = start_pos + meta_size;
+ if (FORY_PREDICT_FALSE(current_pos > expected_end_pos)) {
+ return Unexpected(Error::invalid_data(
+ "TypeMeta parser consumed beyond declared meta size"));
+ }
if (current_pos < expected_end_pos) {
size_t remaining = expected_end_pos - current_pos;
- buffer.increase_reader_index(remaining);
+ buffer.skip(static_cast<uint32_t>(remaining), error);
+ if (FORY_PREDICT_FALSE(!error.ok())) {
+ return Unexpected(std::move(error));
+ }
}
auto meta = std::make_unique<TypeMeta>();
diff --git a/cpp/fory/serialization/type_resolver.h
b/cpp/fory/serialization/type_resolver.h
index e3489b44a..d13157b5e 100644
--- a/cpp/fory/serialization/type_resolver.h
+++ b/cpp/fory/serialization/type_resolver.h
@@ -1860,44 +1860,59 @@ TypeResolver::register_type_internal(uint64_t ctid,
Error::invalid("TypeInfo or harness is invalid during registration"));
}
- // Store in primary storage and get raw pointer
+ // Validate all uniqueness constraints before mutating resolver state so
+ // failed registration leaves no partial entries behind.
TypeInfo *raw_ptr = info.get();
- type_infos_.push_back(std::move(info));
-
- type_info_by_ctid_.put(ctid, raw_ptr);
+ const bool is_internal = ::fory::is_internal_type(raw_ptr->type_id);
+ const bool has_user_type_key =
+ !raw_ptr->register_by_name && raw_ptr->user_type_id !=
kInvalidUserTypeId;
+ uint64_t user_type_key = 0;
+ std::string name_key;
- if (::fory::is_internal_type(raw_ptr->type_id)) {
+ if (is_internal) {
TypeInfo *existing =
type_info_by_id_.get_or_default(raw_ptr->type_id, nullptr);
- if (existing != nullptr && existing != raw_ptr) {
+ if (existing != nullptr) {
return Unexpected(Error::invalid("Type id already registered: " +
std::to_string(raw_ptr->type_id)));
}
- type_info_by_id_.put(raw_ptr->type_id, raw_ptr);
- } else if (!raw_ptr->register_by_name &&
- raw_ptr->user_type_id != kInvalidUserTypeId) {
- uint64_t key = make_user_type_key(raw_ptr->type_id, raw_ptr->user_type_id);
- TypeInfo *existing = user_type_info_by_id_.get_or_default(key, nullptr);
- if (existing != nullptr && existing != raw_ptr) {
+ } else if (has_user_type_key) {
+ user_type_key = make_user_type_key(raw_ptr->type_id,
raw_ptr->user_type_id);
+ TypeInfo *existing =
+ user_type_info_by_id_.get_or_default(user_type_key, nullptr);
+ if (existing != nullptr) {
return Unexpected(Error::invalid(
"Type id already registered: " + std::to_string(raw_ptr->type_id) +
"/" + std::to_string(raw_ptr->user_type_id)));
}
- user_type_info_by_id_.put(key, raw_ptr);
}
if (raw_ptr->register_by_name) {
- auto key = make_name_key(raw_ptr->namespace_name, raw_ptr->type_name);
- auto it = type_info_by_name_.find(key);
- if (it != type_info_by_name_.end() && it->second != raw_ptr) {
+ name_key = make_name_key(raw_ptr->namespace_name, raw_ptr->type_name);
+ auto it = type_info_by_name_.find(name_key);
+ if (it != type_info_by_name_.end()) {
return Unexpected(Error::invalid(
"Type already registered for namespace '" + raw_ptr->namespace_name +
"' and name '" + raw_ptr->type_name + "'"));
}
- type_info_by_name_[key] = raw_ptr;
}
- return raw_ptr;
+ // Commit all state after validation passes.
+ type_infos_.push_back(std::move(info));
+ TypeInfo *stored_ptr = type_infos_.back().get();
+ type_info_by_ctid_.put(ctid, stored_ptr);
+
+ if (is_internal) {
+ type_info_by_id_.put(stored_ptr->type_id, stored_ptr);
+ } else if (has_user_type_key) {
+ user_type_info_by_id_.put(user_type_key, stored_ptr);
+ }
+
+ if (stored_ptr->register_by_name) {
+ type_info_by_name_[name_key] = stored_ptr;
+ }
+
+ return stored_ptr;
}
inline void
diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h
index 0631d989b..b626fb7d9 100644
--- a/cpp/fory/util/buffer.h
+++ b/cpp/fory/util/buffer.h
@@ -201,20 +201,25 @@ public:
*target = 0;
return Result<void, Error>();
}
- if (size_ - (offset + 8) > 0) {
- uint64_t mask = 0xffffffffffffffff;
+ if (FORY_PREDICT_FALSE(length > 8)) {
+ return Unexpected(Error::invalid_data(
+ "get_bytes_as_int64 length should be in range [0, 8]"));
+ }
+ if (FORY_PREDICT_FALSE(offset > size_ || length > size_ - offset)) {
+ return Unexpected(Error::buffer_out_of_bound(offset, length, size_));
+ }
+ if (size_ - offset >= 8) {
+ uint64_t mask = std::numeric_limits<uint64_t>::max();
uint64_t x = (mask >> (8 - length) * 8);
- *target = get_int64(offset) & x;
- } else {
- if (size_ - (offset + length) < 0) {
- return Unexpected(Error::out_of_bound("buffer out of bound"));
- }
- int64_t result = 0;
- for (size_t i = 0; i < length; i++) {
- result = result | ((int64_t)(data_[offset + i])) << (i * 8);
- }
- *target = result;
+ *target =
+ static_cast<int64_t>(static_cast<uint64_t>(get_int64(offset)) & x);
+ return Result<void, Error>();
+ }
+ int64_t result = 0;
+ for (size_t i = 0; i < length; i++) {
+ result = result | ((int64_t)(data_[offset + i])) << (i * 8);
}
+ *target = result;
return Result<void, Error>();
}
@@ -259,6 +264,10 @@ public:
/// Slow path: byte-by-byte for buffer edge cases.
FORY_ALWAYS_INLINE uint32_t get_var_uint32(uint32_t offset,
uint32_t *read_bytes_length) {
+ if (FORY_PREDICT_FALSE(offset >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
// Fast path: need at least 5 bytes for safe bulk read (4 bytes + potential
// 5th)
if (FORY_PREDICT_TRUE(size_ - offset >= 5)) {
@@ -299,19 +308,39 @@ public:
/// Slow path for get_var_uint32 when not enough bytes for bulk read.
uint32_t get_var_uint32_slow(uint32_t offset, uint32_t *read_bytes_length) {
+ if (FORY_PREDICT_FALSE(offset >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
uint32_t position = offset;
int b = data_[position++];
uint32_t result = b & 0x7F;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
b = data_[position++];
result |= (b & 0x7F) << 7;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
b = data_[position++];
result |= (b & 0x7F) << 14;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
b = data_[position++];
result |= (b & 0x7F) << 21;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
b = data_[position++];
result |= (b & 0x7F) << 28;
}
@@ -383,6 +412,10 @@ public:
/// Uses PVL (Progressive Variable-length Long) encoding per xlang spec.
FORY_ALWAYS_INLINE uint64_t get_var_uint64(uint32_t offset,
uint32_t *read_bytes_length) {
+ if (FORY_PREDICT_FALSE(offset >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
// Fast path: need at least 9 bytes for safe bulk read
if (FORY_PREDICT_TRUE(size_ - offset >= 9)) {
uint64_t bulk = *reinterpret_cast<uint64_t *>(data_ + offset);
@@ -438,10 +471,18 @@ public:
/// Slow path for get_var_uint64 when not enough bytes for bulk read.
uint64_t get_var_uint64_slow(uint32_t offset, uint32_t *read_bytes_length) {
+ if (FORY_PREDICT_FALSE(offset >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
uint32_t position = offset;
uint64_t result = 0;
int shift = 0;
for (int i = 0; i < 8; ++i) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
uint8_t b = data_[position++];
result |= static_cast<uint64_t>(b & 0x7F) << shift;
if ((b & 0x80) == 0) {
@@ -450,6 +491,10 @@ public:
}
shift += 7;
}
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ *read_bytes_length = 0;
+ return 0;
+ }
uint8_t last = data_[position++];
result |= static_cast<uint64_t>(last) << 56;
*read_bytes_length = position - offset;
@@ -830,6 +875,10 @@ public:
}
uint32_t read_bytes = 0;
uint32_t value = get_var_uint32(reader_index_, &read_bytes);
+ if (FORY_PREDICT_FALSE(read_bytes == 0)) {
+ error.set_buffer_out_of_bound(reader_index_, 1, size_);
+ return 0;
+ }
increase_reader_index(read_bytes);
return value;
}
@@ -843,6 +892,10 @@ public:
}
uint32_t read_bytes = 0;
uint32_t raw = get_var_uint32(reader_index_, &read_bytes);
+ if (FORY_PREDICT_FALSE(read_bytes == 0)) {
+ error.set_buffer_out_of_bound(reader_index_, 1, size_);
+ return 0;
+ }
increase_reader_index(read_bytes);
return static_cast<int32_t>((raw >> 1) ^ (~(raw & 1) + 1));
}
@@ -855,6 +908,10 @@ public:
}
uint32_t read_bytes = 0;
uint64_t value = get_var_uint64(reader_index_, &read_bytes);
+ if (FORY_PREDICT_FALSE(read_bytes == 0)) {
+ error.set_buffer_out_of_bound(reader_index_, 1, size_);
+ return 0;
+ }
increase_reader_index(read_bytes);
return value;
}
@@ -986,15 +1043,31 @@ public:
uint8_t b = data_[position++];
uint64_t result = b & 0x7F;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ error.set_buffer_out_of_bound(position, 1, size_);
+ return 0;
+ }
b = data_[position++];
result |= static_cast<uint64_t>(b & 0x7F) << 7;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ error.set_buffer_out_of_bound(position, 1, size_);
+ return 0;
+ }
b = data_[position++];
result |= static_cast<uint64_t>(b & 0x7F) << 14;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ error.set_buffer_out_of_bound(position, 1, size_);
+ return 0;
+ }
b = data_[position++];
result |= static_cast<uint64_t>(b & 0x7F) << 21;
if ((b & 0x80) != 0) {
+ if (FORY_PREDICT_FALSE(position >= size_)) {
+ error.set_buffer_out_of_bound(position, 1, size_);
+ return 0;
+ }
b = data_[position++];
result |= static_cast<uint64_t>(b & 0xFF) << 28;
}
diff --git a/cpp/fory/util/buffer_test.cc b/cpp/fory/util/buffer_test.cc
index 7263e4905..1a49f02a5 100644
--- a/cpp/fory/util/buffer_test.cc
+++ b/cpp/fory/util/buffer_test.cc
@@ -119,6 +119,57 @@ TEST(Buffer, TestGetBytesAsInt64) {
EXPECT_TRUE(buffer->get_bytes_as_int64(0, 1, &result).ok());
EXPECT_EQ(result, 100);
}
+
+TEST(Buffer, TestGetBytesAsInt64OutOfBound) {
+ std::shared_ptr<Buffer> buffer;
+ allocate_buffer(8, &buffer);
+ int64_t result = -1;
+ auto oob = buffer->get_bytes_as_int64(7, 2, &result);
+ EXPECT_FALSE(oob.ok());
+ auto invalid = buffer->get_bytes_as_int64(0, 9, &result);
+ EXPECT_FALSE(invalid.ok());
+}
+
+TEST(Buffer, TestGetVarUint32Truncated) {
+ std::vector<uint8_t> bytes = {0x80};
+ Buffer buffer(bytes);
+ uint32_t read_bytes = 123;
+ uint32_t value = buffer.get_var_uint32(0, &read_bytes);
+ EXPECT_EQ(value, 0U);
+ EXPECT_EQ(read_bytes, 0U);
+
+ Error error;
+ uint32_t decoded = buffer.read_var_uint32(error);
+ EXPECT_EQ(decoded, 0U);
+ EXPECT_FALSE(error.ok());
+ EXPECT_EQ(buffer.reader_index(), 0U);
+}
+
+TEST(Buffer, TestGetVarUint64Truncated) {
+ std::vector<uint8_t> bytes(8, 0x80);
+ Buffer buffer(bytes);
+ uint32_t read_bytes = 123;
+ uint64_t value = buffer.get_var_uint64(0, &read_bytes);
+ EXPECT_EQ(value, 0ULL);
+ EXPECT_EQ(read_bytes, 0U);
+
+ Error error;
+ uint64_t decoded = buffer.read_var_uint64(error);
+ EXPECT_EQ(decoded, 0ULL);
+ EXPECT_FALSE(error.ok());
+ EXPECT_EQ(buffer.reader_index(), 0U);
+}
+
+TEST(Buffer, TestReadVarUint36SmallTruncated) {
+ std::vector<uint8_t> bytes = {0x80, 0x80, 0x80, 0x80};
+ Buffer buffer(bytes);
+
+ Error error;
+ uint64_t decoded = buffer.read_var_uint36_small(error);
+ EXPECT_EQ(decoded, 0ULL);
+ EXPECT_FALSE(error.ok());
+ EXPECT_EQ(buffer.reader_index(), 0U);
+}
} // namespace fory
int main(int argc, char **argv) {
diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs
index 8e6d16838..4cab51acf 100644
--- a/rust/fory-core/src/buffer.rs
+++ b/rust/fory-core/src/buffer.rs
@@ -565,7 +565,11 @@ impl<'a> Reader<'a> {
#[inline(always)]
fn check_bound(&self, n: usize) -> Result<(), Error> {
- if self.cursor + n > self.bf.len() {
+ let end = self
+ .cursor
+ .checked_add(n)
+ .ok_or_else(|| Error::buffer_out_of_bound(self.cursor, n,
self.bf.len()))?;
+ if end > self.bf.len() {
Err(Error::buffer_out_of_bound(self.cursor, n, self.bf.len()))
} else {
Ok(())
@@ -699,8 +703,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_u16(&mut self) -> Result<u16, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_u16(slice);
+ self.check_bound(2)?;
+ let result = LittleEndian::read_u16(&self.bf[self.cursor..self.cursor
+ 2]);
self.cursor += 2;
Ok(result)
}
@@ -709,8 +713,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_u32(&mut self) -> Result<u32, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_u32(slice);
+ self.check_bound(4)?;
+ let result = LittleEndian::read_u32(&self.bf[self.cursor..self.cursor
+ 4]);
self.cursor += 4;
Ok(result)
}
@@ -756,8 +760,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_u64(&mut self) -> Result<u64, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_u64(slice);
+ self.check_bound(8)?;
+ let result = LittleEndian::read_u64(&self.bf[self.cursor..self.cursor
+ 8]);
self.cursor += 8;
Ok(result)
}
@@ -854,8 +858,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_f32(&mut self) -> Result<f32, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_f32(slice);
+ self.check_bound(4)?;
+ let result = LittleEndian::read_f32(&self.bf[self.cursor..self.cursor
+ 4]);
self.cursor += 4;
Ok(result)
}
@@ -863,14 +867,15 @@ impl<'a> Reader<'a> {
// ============ FLOAT64 (TypeId = 18) ============
#[inline(always)]
pub fn read_f16(&mut self) -> Result<float16, Error> {
- let bits = LittleEndian::read_u16(self.slice_after_cursor());
+ self.check_bound(2)?;
+ let bits = LittleEndian::read_u16(&self.bf[self.cursor..self.cursor +
2]);
self.cursor += 2;
Ok(float16::from_bits(bits))
}
pub fn read_f64(&mut self) -> Result<f64, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_f64(slice);
+ self.check_bound(8)?;
+ let result = LittleEndian::read_f64(&self.bf[self.cursor..self.cursor
+ 8]);
self.cursor += 8;
Ok(result)
}
@@ -963,8 +968,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_u128(&mut self) -> Result<u128, Error> {
- let slice = self.slice_after_cursor();
- let result = LittleEndian::read_u128(slice);
+ self.check_bound(16)?;
+ let result = LittleEndian::read_u128(&self.bf[self.cursor..self.cursor
+ 16]);
self.cursor += 16;
Ok(result)
}
@@ -995,6 +1000,8 @@ impl<'a> Reader<'a> {
#[inline(always)]
pub fn read_varuint36small(&mut self) -> Result<u64, Error> {
+ // Keep this API panic-free even if cursor is externally set past buffer end.
+ self.check_bound(0)?;
let start = self.cursor;
let slice = self.slice_after_cursor();
diff --git a/rust/fory-core/src/row/row.rs b/rust/fory-core/src/row/row.rs
index 0ee55a1c7..147a2f110 100644
--- a/rust/fory-core/src/row/row.rs
+++ b/rust/fory-core/src/row/row.rs
@@ -99,12 +99,16 @@ impl<'a, T: Row<'a>, const N: usize> FixedArrayGetter<'a,
T, N> {
self.array_data.num_elements()
}
- pub fn get(&self, idx: usize) -> T::ReadResult {
+ pub fn get(&self, idx: usize) -> Result<T::ReadResult, Error> {
if idx >= self.array_data.num_elements() {
- panic!("out of bound");
+ return Err(Error::buffer_out_of_bound(
+ idx,
+ 1,
+ self.array_data.num_elements(),
+ ));
}
let bytes = self.array_data.get_field_bytes(idx);
- <T as Row>::cast(bytes)
+ Ok(<T as Row>::cast(bytes))
}
}
@@ -190,12 +194,16 @@ impl<'a, T: Row<'a>> ArrayGetter<'a, T> {
self.array_data.num_elements()
}
- pub fn get(&self, idx: usize) -> T::ReadResult {
+ pub fn get(&self, idx: usize) -> Result<T::ReadResult, Error> {
if idx >= self.array_data.num_elements() {
- panic!("out of bound");
+ return Err(Error::buffer_out_of_bound(
+ idx,
+ 1,
+ self.array_data.num_elements(),
+ ));
}
let bytes = self.array_data.get_field_bytes(idx);
- <T as Row>::cast(bytes)
+ Ok(<T as Row>::cast(bytes))
}
}
@@ -241,7 +249,7 @@ impl<'a, T1: Row<'a> + Ord, T2: Row<'a> + Ord>
MapGetter<'a, T1, T2> {
let values = self.values();
for i in 0..self.keys().size() {
- map.insert(keys.get(i), values.get(i));
+ map.insert(keys.get(i)?, values.get(i)?);
}
Ok(map)
}
diff --git a/rust/fory/src/lib.rs b/rust/fory/src/lib.rs
index 6c88b6c45..3ba72033e 100644
--- a/rust/fory/src/lib.rs
+++ b/rust/fory/src/lib.rs
@@ -946,7 +946,7 @@
//!
//! let scores = row.scores();
//! assert_eq!(scores.size(), 4);
-//! assert_eq!(scores.get(0), 95);
+//! assert_eq!(scores.get(0).unwrap(), 95);
//! # }
//! ```
//!
diff --git a/rust/tests/tests/test_buffer.rs b/rust/tests/tests/test_buffer.rs
index 4a4f5e3ca..d589c16c3 100644
--- a/rust/tests/tests/test_buffer.rs
+++ b/rust/tests/tests/test_buffer.rs
@@ -96,3 +96,23 @@ fn test_varuint36_small() {
assert_eq!(value, data, "failed for data {}", data);
}
}
+
+#[test]
+fn test_fixed_width_read_bounds_checks() {
+ let mut empty = Reader::new(&[]);
+ assert!(empty.read_u16().is_err());
+ assert!(empty.read_u32().is_err());
+ assert!(empty.read_u64().is_err());
+ assert!(empty.read_f16().is_err());
+ assert!(empty.read_f32().is_err());
+ assert!(empty.read_f64().is_err());
+ assert!(empty.read_u128().is_err());
+
+ let mut short = Reader::new(&[1, 2, 3]);
+ assert!(short.read_u32().is_err());
+
+ let mut bad_cursor = Reader::new(&[1, 2, 3, 4]);
+ bad_cursor.set_cursor(10);
+ assert!(bad_cursor.read_u16().is_err());
+ assert!(bad_cursor.read_varuint36small().is_err());
+}
diff --git a/rust/tests/tests/test_row.rs b/rust/tests/tests/test_row.rs
index dce383099..488a57cfb 100644
--- a/rust/tests/tests/test_row.rs
+++ b/rust/tests/tests/test_row.rs
@@ -40,10 +40,11 @@ fn row_with_array_field() {
assert_eq!(obj.index(), 42);
let point_getter = obj.point();
assert_eq!(point_getter.size(), 4);
- assert_eq!(point_getter.get(0), 1.0);
- assert_eq!(point_getter.get(1), 2.0);
- assert_eq!(point_getter.get(2), 3.0);
- assert_eq!(point_getter.get(3), 4.0);
+ assert_eq!(point_getter.get(0).expect("index 0"), 1.0);
+ assert_eq!(point_getter.get(1).expect("index 1"), 2.0);
+ assert_eq!(point_getter.get(2).expect("index 2"), 3.0);
+ assert_eq!(point_getter.get(3).expect("index 3"), 4.0);
+ assert!(point_getter.get(4).is_err());
}
#[test]
@@ -73,9 +74,10 @@ fn row_with_nested_struct_array() {
assert_eq!(obj.name(), "origin");
let coords = obj.origin().coords();
assert_eq!(coords.size(), 3);
- assert_eq!(coords.get(0), 0.0);
- assert_eq!(coords.get(1), 0.0);
- assert_eq!(coords.get(2), 0.0);
+ assert_eq!(coords.get(0).expect("index 0"), 0.0);
+ assert_eq!(coords.get(1).expect("index 1"), 0.0);
+ assert_eq!(coords.get(2).expect("index 2"), 0.0);
+ assert!(coords.get(3).is_err());
}
#[test]
@@ -118,17 +120,20 @@ fn row() {
assert_eq!(f3, vec![1, 2, 3]);
let f4_size: usize = obj.f3().f4().size();
assert_eq!(f4_size, 3);
- assert_eq!(obj.f3().f4().get(0), -1);
- assert_eq!(obj.f3().f4().get(1), 2);
- assert_eq!(obj.f3().f4().get(2), -3);
+ assert_eq!(obj.f3().f4().get(0).expect("index 0"), -1);
+ assert_eq!(obj.f3().f4().get(1).expect("index 1"), 2);
+ assert_eq!(obj.f3().f4().get(2).expect("index 2"), -3);
+ assert!(obj.f3().f4().get(3).is_err());
let binding = obj.f3().f5();
assert_eq!(binding.keys().size(), 2);
- assert_eq!(binding.keys().get(0), "k1");
+ assert_eq!(binding.keys().get(0).expect("key 0"), "k1");
+ assert!(binding.keys().get(2).is_err());
assert_eq!(binding.values().size(), 2);
- assert_eq!(binding.values().get(0), "v1");
+ assert_eq!(binding.values().get(0).expect("value 0"), "v1");
+ assert!(binding.values().get(2).is_err());
let f5 = binding.to_btree_map().expect("should be map");
assert_eq!(f5.get("k1").expect("should exists"), &"v1");
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]