This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 23aa3f98e fix(compiler): never generate C++ equality methods for
message and union containing any (#3810)
23aa3f98e is described below
commit 23aa3f98ee56687c555c681a488a69d8cfbb5832
Author: Peiyang He <[email protected]>
AuthorDate: Fri Jul 3 13:30:12 2026 +0800
fix(compiler): never generate C++ equality methods for message and union
containing any (#3810)
## Why?
Fory IDL `any` maps to `std::any` in generated C++ code.
`std::any` does not define `operator==`, so generated equality code MUST
NOT ask the C++ standard library to compare two `std::any` values directly.
However, the current C++ code generator still does this in some generated
`operator==` implementations, especially when `any` appears inside
standard-comparable containers or union alternatives.
For example, generated code such as `values_ == other.values_` for
`std::vector<std::any>`, `by_name_ == other.by_name_` for
`std::unordered_map<std::string, std::any>`, or `value_ == other.value_`
for `std::variant<std::any, ...>` causes the standard container or variant
equality operator to instantiate `std::any == std::any`.
This causes compilation errors like:
```text
error: no match for 'operator==' (operand types are 'const std::any' and 'const std::any')
```
The Rust code generator avoids the equivalent derived-equality problem by
not deriving comparison traits for generated types that contain `any`:
https://github.com/apache/fory/blob/9b2bcec9618702d6aa5ba4b166f86619fce51baf/compiler/fory_compiler/generators/rust.py#L1149-L1158
## What does this PR do?
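- Follow the Rust generator's approach, i.e. don't generate C++ equality
methods for messages and unions that contain `any`, either directly or
transitively through lists, arrays, map values, or nested messages/unions.
- Remove the old special case that compared direct `any` fields only by
`has_value()`/`type()`, since such messages no longer get `operator==`.
- Add a unit test, and update the C++ IDL integration test to compare the
`AnyHolder` fields individually via `std::any_cast`.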
## Related issues
N/A.
## AI Contribution Checklist
- [X] Substantial AI assistance was used in this PR: `no`
## Does this PR introduce any user-facing change?
N/A.
## Benchmark
N/A.
---
compiler/fory_compiler/generators/cpp.py | 148 +++++++++++++++++----
.../fory_compiler/tests/test_generated_code.py | 88 ++++++++++++
integration_tests/idl_tests/cpp/main.cc | 39 +++++-
3 files changed, 246 insertions(+), 29 deletions(-)
diff --git a/compiler/fory_compiler/generators/cpp.py
b/compiler/fory_compiler/generators/cpp.py
index 91e6b49fd..f7302dd15 100644
--- a/compiler/fory_compiler/generators/cpp.py
+++ b/compiler/fory_compiler/generators/cpp.py
@@ -691,13 +691,6 @@ class CppGenerator(BaseGenerator):
) -> str:
member_name = self.get_field_member_name(field)
other_member = f"other.{member_name}"
- if isinstance(field.field_type, PrimitiveType) and (
- field.field_type.kind == PrimitiveKind.ANY
- ):
- return (
- f"((!{member_name}.has_value() && !{other_member}.has_value())
|| "
- f"({member_name}.type() == {other_member}.type()))"
- )
if self.is_message_type(
field.field_type, parent_stack
) and self.get_field_weak_ref(field):
@@ -716,6 +709,104 @@ class CppGenerator(BaseGenerator):
)
return f"{member_name} == {other_member}"
+ def message_has_any(
+ self,
+ message: Message,
+ parent_stack: Optional[List[Message]] = None,
+ visiting: Optional[Set[Tuple[str, int]]] = None,
+ ) -> bool:
+ if visiting is None:
+ visiting = set()
+ key = ("message", id(message))
+ if key in visiting:
+ return False
+ visiting.add(key)
+ try:
+ lineage = (parent_stack or []) + [message]
+ return any(
+ self.field_type_has_any(field.field_type, lineage, visiting)
+ for field in message.fields
+ )
+ finally:
+ visiting.remove(key)
+
+ def union_has_any(
+ self,
+ union: Union,
+ parent_stack: Optional[List[Message]] = None,
+ visiting: Optional[Set[Tuple[str, int]]] = None,
+ ) -> bool:
+ if visiting is None:
+ visiting = set()
+ key = ("union", id(union))
+ if key in visiting:
+ return False
+ visiting.add(key)
+ try:
+ return any(
+ self.field_type_has_any(field.field_type, parent_stack,
visiting)
+ for field in union.fields
+ )
+ finally:
+ visiting.remove(key)
+
+ def field_type_has_any(
+ self,
+ field_type: FieldType,
+ parent_stack: Optional[List[Message]] = None,
+ visiting: Optional[Set[Tuple[str, int]]] = None,
+ ) -> bool:
+ """Return True when a field type or its children contain `any`."""
+ if isinstance(field_type, PrimitiveType):
+ return field_type.kind == PrimitiveKind.ANY
+ if isinstance(field_type, ListType):
+ return self.field_type_has_any(
+ field_type.element_type, parent_stack, visiting
+ )
+ if isinstance(field_type, ArrayType):
+ return self.field_type_has_any(
+ field_type.element_type, parent_stack, visiting
+ )
+ if isinstance(field_type, MapType):
+ # `any` is not allowed as map key (rejected first by the
validator),
+ # so we only check map value here.
+ return self.field_type_has_any(
+ field_type.value_type, parent_stack, visiting
+ )
+ if isinstance(field_type, NamedType):
+ named_type = self.resolve_named_type(field_type.name, parent_stack)
+ if isinstance(named_type, Message):
+ return self.message_has_any(
+ named_type, self._parent_stack_for_type(named_type),
visiting
+ )
+ if isinstance(named_type, Union):
+ return self.union_has_any(
+ named_type, self._parent_stack_for_type(named_type),
visiting
+ )
+ return False
+
+ def _parent_stack_for_type(self, type_def: object) -> List[Message]:
+ def visit(message: Message, parents: List[Message]) ->
Optional[List[Message]]:
+ if message is type_def:
+ return parents
+ for nested_union in message.nested_unions:
+ if nested_union is type_def:
+ return parents + [message]
+ for nested_enum in message.nested_enums:
+ if nested_enum is type_def:
+ return parents + [message]
+ for nested_message in message.nested_messages:
+ found = visit(nested_message, parents + [message])
+ if found is not None:
+ return found
+ return None
+
+ for top in self.schema.messages:
+ found = visit(top, [])
+ if found is not None:
+ return found
+ return []
+
def is_numeric_field(self, field: Field) -> bool:
if not isinstance(field.field_type, PrimitiveType):
return False
@@ -914,19 +1005,23 @@ class CppGenerator(BaseGenerator):
lines.append("")
lines.append("")
- lines.append(
- f"{body_indent}bool operator==(const {class_name}& other) const {{"
- )
- if message.fields:
- conditions = [
- self.get_field_eq_expression(field, lineage) for field in
message.fields
- ]
- lines.append(f"{body_indent} return {' && '.join(conditions)};")
- else:
- lines.append(f"{body_indent} return true;")
- lines.append(f"{body_indent}}}")
+ # We don't generate equality method for message containing `any`
+ # since C++ doesn't support std::any == std::any.
+ if not self.message_has_any(message, parent_stack):
+ lines.append(
+ f"{body_indent}bool operator==(const {class_name}& other)
const {{"
+ )
+ if message.fields:
+ conditions = [
+ self.get_field_eq_expression(field, lineage)
+ for field in message.fields
+ ]
+ lines.append(f"{body_indent} return {' &&
'.join(conditions)};")
+ else:
+ lines.append(f"{body_indent} return true;")
+ lines.append(f"{body_indent}}}")
+ lines.append("")
- lines.append("")
lines.extend(self.generate_bytes_methods(class_name, body_indent))
struct_type_name = self.get_qualified_type_name(message.name,
parent_stack)
@@ -1069,12 +1164,15 @@ class CppGenerator(BaseGenerator):
)
lines.append(f"{body_indent} }}")
lines.append("")
- lines.append(
- f"{body_indent} bool operator==(const {class_name}& other) const
{{"
- )
- lines.append(f"{body_indent} return value_ == other.value_;")
- lines.append(f"{body_indent} }}")
- lines.append("")
+ # We don't generate equality method for union containing `any`
+ # since C++ doesn't support std::any == std::any.
+ if not self.union_has_any(union, parent_stack):
+ lines.append(
+ f"{body_indent} bool operator==(const {class_name}& other)
const {{"
+ )
+ lines.append(f"{body_indent} return value_ == other.value_;")
+ lines.append(f"{body_indent} }}")
+ lines.append("")
lines.extend(self.generate_bytes_methods(class_name, f"{body_indent}
"))
diff --git a/compiler/fory_compiler/tests/test_generated_code.py
b/compiler/fory_compiler/tests/test_generated_code.py
index d5f7b3960..16e8fca31 100644
--- a/compiler/fory_compiler/tests/test_generated_code.py
+++ b/compiler/fory_compiler/tests/test_generated_code.py
@@ -1126,6 +1126,94 @@ def
test_cpp_generator_supports_decimal_fields_and_unions():
assert "(amount, fory::serialization::Decimal, fory::F(1))" in cpp_output
+def test_cpp_omits_equality_for_any_types():
+ schema = parse_fdl(
+ dedent(
+ """
+ package gen;
+
+ message Inner {
+ any value = 1;
+ }
+
+ union AnyChoice {
+ Inner inner = 1;
+ string name = 2;
+ }
+
+ message DirectAny {
+ any value = 1;
+ }
+
+ message AnyList {
+ list<any> values = 1;
+ }
+
+ message AnyMap {
+ map<string, any> values = 1;
+ }
+
+ union DirectChoice {
+ any payload = 1;
+ list<any> values = 2;
+ string name = 3;
+ }
+
+ message DirectOwner {
+ Inner inner = 1;
+ }
+
+ message ListOwner {
+ list<Inner> values = 1;
+ }
+
+ message MapOwner {
+ map<string, Inner> values = 1;
+ }
+
+ message UnionOwner {
+ AnyChoice choice = 1;
+ }
+
+ message DeclaresNestedOnly {
+ message Nested {
+ any value = 1;
+ }
+
+ string name = 1;
+ }
+
+ message Plain {
+ string name = 1;
+ list<int32> values = 2;
+ map<string, int32> counts = 3;
+ }
+
+ union PlainChoice {
+ string name = 1;
+ int32 code = 2;
+ }
+ """
+ )
+ )
+
+ cpp_output = render_files(generate_files(schema, CppGenerator))
+ assert "bool operator==(const Inner& other) const" not in cpp_output
+ assert "bool operator==(const AnyChoice& other) const" not in cpp_output
+ assert "bool operator==(const DirectAny& other) const" not in cpp_output
+ assert "bool operator==(const AnyList& other) const" not in cpp_output
+ assert "bool operator==(const AnyMap& other) const" not in cpp_output
+ assert "bool operator==(const DirectChoice& other) const" not in cpp_output
+ assert "bool operator==(const DirectOwner& other) const" not in cpp_output
+ assert "bool operator==(const ListOwner& other) const" not in cpp_output
+ assert "bool operator==(const MapOwner& other) const" not in cpp_output
+ assert "bool operator==(const UnionOwner& other) const" not in cpp_output
+ assert "bool operator==(const Nested& other) const" not in cpp_output
+ assert "bool operator==(const DeclaresNestedOnly& other) const" in
cpp_output
+ assert "bool operator==(const Plain& other) const" in cpp_output
+ assert "bool operator==(const PlainChoice& other) const" in cpp_output
+
+
def test_cpp_nested_container_ref_uses_correct_pointer_type():
schema = parse_fdl(
dedent(
diff --git a/integration_tests/idl_tests/cpp/main.cc
b/integration_tests/idl_tests/cpp/main.cc
index cb41a987a..8d9ec6baf 100644
--- a/integration_tests/idl_tests/cpp/main.cc
+++ b/integration_tests/idl_tests/cpp/main.cc
@@ -1076,6 +1076,23 @@ fory::Result<void, fory::Error> RunEvolvingRoundTrip() {
using StringMap = std::unordered_map<std::string, std::string>;
+template <typename T>
+fory::Result<void, fory::Error>
+ValidateAnyField(const std::any &actual_any, const std::any &expected_any,
+ const std::string &field_name) {
+ const auto *actual = std::any_cast<T>(&actual_any);
+ const auto *expected = std::any_cast<T>(&expected_any);
+ if (actual == nullptr || expected == nullptr) {
+ return fory::Unexpected(
+ fory::Error::invalid("any holder " + field_name + " type mismatch"));
+ }
+ if (!(*actual == *expected)) {
+ return fory::Unexpected(
+ fory::Error::invalid("any holder " + field_name + " value mismatch"));
+ }
+ return fory::Result<void, fory::Error>();
+}
+
fory::Result<void, fory::Error> RunRoundTrip(bool compatible) {
auto fory = fory::serialization::Fory::builder()
.xlang(true)
@@ -1479,10 +1496,24 @@ fory::Result<void, fory::Error> RunRoundTrip(bool
compatible) {
FORY_TRY(any_roundtrip, fory.deserialize<any_example::AnyHolder>(
any_bytes.data(), any_bytes.size()));
- if (!(any_roundtrip == any_holder)) {
- return fory::Unexpected(
- fory::Error::invalid("any holder roundtrip mismatch"));
- }
+ FORY_RETURN_IF_ERROR(ValidateAnyField<bool>(
+ any_roundtrip.bool_value(), any_holder.bool_value(), "bool_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<std::string>(
+ any_roundtrip.string_value(), any_holder.string_value(),
"string_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<fory::serialization::Date>(
+ any_roundtrip.date_value(), any_holder.date_value(), "date_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<fory::serialization::Timestamp>(
+ any_roundtrip.timestamp_value(), any_holder.timestamp_value(),
+ "timestamp_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<any_example::AnyInner>(
+ any_roundtrip.message_value(), any_holder.message_value(),
+ "message_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<any_example::AnyUnion>(
+ any_roundtrip.union_value(), any_holder.union_value(), "union_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<std::vector<std::string>>(
+ any_roundtrip.list_value(), any_holder.list_value(), "list_value"));
+ FORY_RETURN_IF_ERROR(ValidateAnyField<StringMap>(
+ any_roundtrip.map_value(), any_holder.map_value(), "map_value"));
example_peer::ExampleMessage example_message = BuildExampleMessage();
FORY_TRY(example_bytes, fory.serialize(example_message));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]