This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 4660351b6 fix(compiler): reject any, message and union as map key types (#3804)
4660351b6 is described below

commit 4660351b625da3222f7c4665454cdec2c552f419
Author: Peiyang He <[email protected]>
AuthorDate: Tue Jun 30 18:31:14 2026 +0800

    fix(compiler): reject any, message and union as map key types (#3804)
    
    ## Why?
    
    As discussed in
    https://github.com/apache/fory/pull/3789#issuecomment-4798805385, we
    should reject `any`, `message`, and `union` as IDL `map` key types, but
    the validator currently misses them in its check.
    
    ## What does this PR do?
    
    - Reject `any`, `message` and `union` as map key types
    - Modify the test and doc accordingly
    
    ## Related issues
    
    N/A.
    
    ## AI Contribution Checklist
    
    
    
    - [X] Substantial AI assistance was used in this PR: `no`
    
    ## Does this PR introduce any user-facing change?
    
    N/A.
    
    ## Benchmark
    
    N/A.
---
 compiler/fory_compiler/ir/ast.py                   |   2 +-
 compiler/fory_compiler/ir/validator.py             |  47 +++++---
 compiler/fory_compiler/tests/test_weak_ref.py      |   2 +-
 .../fory_compiler/tests/test_xlang_type_system.py  | 129 ++++++++++++++++++++-
 docs/compiler/schema-idl.md                        |   6 +-
 5 files changed, 167 insertions(+), 19 deletions(-)

diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py
index a03fb2b4d..e6ab74627 100644
--- a/compiler/fory_compiler/ir/ast.py
+++ b/compiler/fory_compiler/ir/ast.py
@@ -56,7 +56,7 @@ class PrimitiveType:
 
 @dataclass
 class NamedType:
-    """A reference to a user-defined type (message or enum)."""
+    """A reference to a user-defined type (message, enum or union)."""
 
     name: str
     location: Optional[SourceLocation] = None
diff --git a/compiler/fory_compiler/ir/validator.py b/compiler/fory_compiler/ir/validator.py
index 613bc0544..fa2455f39 100644
--- a/compiler/fory_compiler/ir/validator.py
+++ b/compiler/fory_compiler/ir/validator.py
@@ -38,6 +38,7 @@ from fory_compiler.ir.types import ARRAY_ELEMENT_KINDS, PrimitiveKind
 from fory_compiler.ir.type_id import compute_registered_type_id
 
 INVALID_MAP_KEY_KINDS = {
+    PrimitiveKind.ANY,
     PrimitiveKind.BYTES,
     PrimitiveKind.FLOAT16,
     PrimitiveKind.BFLOAT16,
@@ -45,9 +46,7 @@ INVALID_MAP_KEY_KINDS = {
     PrimitiveKind.FLOAT64,
     PrimitiveKind.DECIMAL,
 }
-INVALID_MAP_KEY_MESSAGE = (
-    "map keys do not support binary, float, decimal, list, map, or array types"
-)
+INVALID_MAP_KEY_MESSAGE = "map keys do not support any, binary, float, decimal, message, union, list, map, or array types"
 
 
 @dataclass
@@ -547,12 +546,30 @@ class SchemaValidator:
                 check_field(f, None)
 
     def _check_collection_type_rules(self) -> None:
-        def invalid_map_key(field_type: FieldType) -> bool:
+        def invalid_map_key(
+            field_type: FieldType,
+            enclosing_messages: Optional[List[Message]],
+        ) -> bool:
             if isinstance(field_type, PrimitiveType):
                 return field_type.kind in INVALID_MAP_KEY_KINDS
+            if isinstance(field_type, NamedType):
+                if enclosing_messages is not None:
+                    resolved = self._resolve_named_type(
+                        field_type.name, enclosing_messages
+                    )
+                else:
+                    resolved = self._find_top_level_type(field_type.name)
+                return isinstance(
+                    resolved, (Message, Union)
+                )  # message and union cannot be used as map key types.
             return isinstance(field_type, (ListType, ArrayType, MapType))
 
-        def check_type(field_type: FieldType, field: Field, in_map_key: bool = False):
+        def check_type(
+            field_type: FieldType,
+            field: Field,
+            enclosing_messages: Optional[List[Message]] = None,
+            in_map_key: bool = False,
+        ):
             if isinstance(field_type, ArrayType):
                 if in_map_key:
                     self._error(INVALID_MAP_KEY_MESSAGE, field.location)
@@ -580,26 +597,30 @@ class SchemaValidator:
                 if in_map_key:
                     self._error(INVALID_MAP_KEY_MESSAGE, field.location)
                     return
-                check_type(field_type.element_type, field)
+                check_type(field_type.element_type, field, enclosing_messages)
             elif isinstance(field_type, MapType):
                 if in_map_key:
                     self._error(INVALID_MAP_KEY_MESSAGE, field.location)
                     return
                 key_type = field_type.key_type
-                if invalid_map_key(key_type):
+                if invalid_map_key(key_type, enclosing_messages):
                     self._error(INVALID_MAP_KEY_MESSAGE, field.location)
                 else:
-                    check_type(key_type, field, in_map_key=True)
-                check_type(field_type.value_type, field)
+                    check_type(key_type, field, enclosing_messages, in_map_key=True)
+                check_type(field_type.value_type, field, enclosing_messages)
 
-        def check_message_fields(message: Message) -> None:
+        def check_message_fields(
+            message: Message,
+            enclosing_messages: Optional[List[Message]] = None,
+        ) -> None:
+            lineage = (enclosing_messages or []) + [message]
             for f in message.fields:
-                check_type(f.field_type, f)
+                check_type(f.field_type, f, lineage)
             for nested_msg in message.nested_messages:
-                check_message_fields(nested_msg)
+                check_message_fields(nested_msg, lineage)
             for nested_union in message.nested_unions:
                 for f in nested_union.fields:
-                    check_type(f.field_type, f)
+                    check_type(f.field_type, f, lineage)
 
         for message in self.schema.messages:
             check_message_fields(message)
diff --git a/compiler/fory_compiler/tests/test_weak_ref.py b/compiler/fory_compiler/tests/test_weak_ref.py
index a90cb8382..b669cae9a 100644
--- a/compiler/fory_compiler/tests/test_weak_ref.py
+++ b/compiler/fory_compiler/tests/test_weak_ref.py
@@ -135,7 +135,7 @@ def test_list_and_map_ref_options_preserve_explicit_opt_out():
     message Holder {
         list<ref Foo> foos = 1;
         list<ref(weak=true, thread_safe=false) Bar> bars = 2;
-        map<Foo, ref(weak=true, thread_safe=false) Bar> bar_map = 3;
+        map<string, ref(weak=true, thread_safe=false) Bar> bar_map = 3;
     }
     """
     schema = parse_schema(source)
diff --git a/compiler/fory_compiler/tests/test_xlang_type_system.py b/compiler/fory_compiler/tests/test_xlang_type_system.py
index b03be3699..4924dc60c 100644
--- a/compiler/fory_compiler/tests/test_xlang_type_system.py
+++ b/compiler/fory_compiler/tests/test_xlang_type_system.py
@@ -166,6 +166,7 @@ def test_array_rejects_optional_or_ref_elements_at_parse_time():
 @pytest.mark.parametrize(
     "key_type",
     [
+        "any",
         "bytes",
         "float16",
         "bfloat16",
@@ -188,12 +189,138 @@ def test_map_rejects_non_portable_key_types(key_type):
 
     assert not ok
     assert any(
-        "map keys do not support binary, float, decimal, list, map, or array types"
+        "map keys do not support any, binary, float, decimal, message, union, list, map, or array types"
+        in err.message
+        for err in validator.errors
+    )
+
+
+@pytest.mark.parametrize(
+    "source",
+    [
+        """
+        message Key {
+            string id = 1;
+        }
+
+        message InvalidMap {
+            map<Key, string> values = 1;
+        }
+        """,
+        """
+        union Choice {
+            string text = 1;
+        }
+
+        message InvalidMap {
+            map<Choice, string> values = 1;
+        }
+        """,
+        """
+        message Key {
+            string id = 1;
+        }
+
+        union InvalidUnion {
+            map<Key, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            message Key {
+                string id = 1;
+            }
+
+            map<Key, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            message Key {
+                string id = 1;
+            }
+        }
+
+        message InvalidMap {
+            map<Outer.Key, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            union Choice {
+                string text = 1;
+            }
+
+            map<Choice, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            union Choice {
+                string text = 1;
+            }
+        }
+
+        message InvalidMap {
+            map<Outer.Choice, string> values = 1;
+        }
+        """,
+    ],
+)
+def test_map_rejects_message_and_union_key_types(source):
+    _schema, validator, ok = validate_schema(source)
+
+    assert not ok
+    assert any(
+        "map keys do not support any, binary, float, decimal, message, union, list, map, or array types"
         in err.message
         for err in validator.errors
     )
 
 
+@pytest.mark.parametrize(
+    "source",
+    [
+        """
+        enum Status {
+            UNKNOWN = 0;
+            READY = 1;
+        }
+
+        message Holder {
+            map<Status, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            enum Status {
+                UNKNOWN = 0;
+                READY = 1;
+            }
+
+            map<Status, string> values = 1;
+        }
+        """,
+        """
+        message Outer {
+            enum Status {
+                UNKNOWN = 0;
+                READY = 1;
+            }
+        }
+
+        message Holder {
+            map<Outer.Status, string> values = 1;
+        }
+        """,
+    ],
+)
+def test_map_accepts_enum_key_types(source):
+    _schema, validator, ok = validate_schema(source)
+
+    assert ok, validator.errors
+
+
 def test_proto_repeated_fields_remain_list_type():
     schema = ProtoFrontend().parse(
         """
diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md
index c12135203..128f2fe2a 100644
--- a/docs/compiler/schema-idl.md
+++ b/docs/compiler/schema-idl.md
@@ -1497,9 +1497,9 @@ message Config {
 - Temporal scalar types (`date`, `timestamp`, `duration`)
 - Enums
 
-Map keys do not support binary `bytes`, floating-point types, `decimal`, `list<T>`, `array<T>`,
-or nested `map<K, V>` types. Put those types in map values or wrap them in a message with a
-portable scalar or enum key.
+Map keys do not support `any`, binary `bytes`, floating-point types, `decimal`, message types,
+union types, `list<T>`, `array<T>`, or nested `map<K, V>` types. Put those types in map values or
+wrap them in a message with a portable scalar or enum key.
 
 ### Type Compatibility Matrix
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to