This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b50ee5fce2 [TIR][Schedule] Fix type checker to support subscripted generics in Python 3.14+ (#18639)
b50ee5fce2 is described below

commit b50ee5fce20aa98e1f1bdecd97b14490c24e6809
Author: Haejoon Kim <[email protected]>
AuthorDate: Tue Jan 6 18:42:54 2026 +0900

    [TIR][Schedule] Fix type checker to support subscripted generics in Python 3.14+ (#18639)
    
    This PR fixes the type annotation checker in
    `tvm.tir.schedule._type_checker` to correctly handle subscripted
    generics (e.g., `Union[str, int]`, `List[str]`, `Tuple[str, int]`) in
    Python 3.14+.
    
    ## Background
    In Python 3.14, the internal representation of generic types has
    changed:
    - `Union[str, int]` is now of type `typing.Union` instead of
    `typing._GenericAlias` or `typing._SpecialGenericAlias`
    - These types now have an `__origin__` attribute directly on the type
    object
    - The existing type checker failed to recognize these new
    representations, causing the dispatcher to fall through to "atomic"
    instead of correctly identifying them as "union", "list", etc.
    
    ## Changes
    Added a check for the `__origin__` attribute at the beginning of the
    `_Subtype._origin` method to handle Python 3.14's new generic type
    representations. This is fully backward compatible since the new
    `__origin__` check is only applied when the attribute exists.
    
    ## Tests
    Added parametrized tests to verify the dispatcher correctly handles
    subscripted generics:
    - `Union[str, int]` → identified as "union"
    - `List[str]` → identified as "list"
    - `Dict[str, int]` → identified as "dict"
    - `Tuple[str, int]` → identified as "tuple"
    - `Union[List[str], Dict[str, int]]` → identified as "union" with nested
    generics
---
 python/tvm/tir/schedule/_type_checker.py           |  4 +++
 .../python/testing/test_type_annotation_checker.py | 34 ++++++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 5c51b1b09f..148016fb2d 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -47,6 +47,10 @@ if hasattr(typing, "_GenericAlias"):
     class _Subtype:
         @staticmethod
         def _origin(type_: Any) -> Any:
+            # In Python 3.14+, check if the type has __origin__ attribute directly
+            if hasattr(type_, "__origin__"):
+                return type_.__origin__
+
             if hasattr(typing, "_SpecialGenericAlias"):
                 if isinstance(type_, typing._SpecialGenericAlias):  # type: ignore # pylint: disable=protected-access
                     return type_.__origin__
diff --git a/tests/python/testing/test_type_annotation_checker.py b/tests/python/testing/test_type_annotation_checker.py
index 42ce1e1039..71bc9ba98b 100644
--- a/tests/python/testing/test_type_annotation_checker.py
+++ b/tests/python/testing/test_type_annotation_checker.py
@@ -187,5 +187,39 @@ def test_not_matches(type_annotation, case):
         func(case)
 
 
+@pytest.mark.parametrize(
+    ["type_annotation", "expected_key", "expected_subtypes"],
+    [
+        pytest.param(Union[str, int], "union", [str, int], id="Union[str, int]"),
+        pytest.param(List[str], "list", [str], id="List[str]"),
+        pytest.param(Dict[str, int], "dict", [str, int], id="Dict[str, int]"),
+        pytest.param(Tuple[str, int], "tuple", (str, int), id="Tuple[str, int]"),
+        pytest.param(
+            Union[List[str], Dict[str, int]],
+            "union",
+            [List[str], Dict[str, int]],
+            id="Union[List[str], Dict[str, int]]",
+        ),
+    ],
+)
+def test_subscripted_generics(type_annotation, expected_key, expected_subtypes):
+    """Test that _dispatcher correctly handles subscripted generics in Python 3.14+.
+
+    In Python 3.14, Union and other generic types have a different internal representation.
+    This test ensures that the dispatcher correctly identifies these types.
+    """
+    from tvm.tir.schedule._type_checker import _dispatcher
+
+    key, subtypes = _dispatcher(type_annotation)
+    assert key == expected_key, f"Expected '{expected_key}' but got '{key}'"
+
+    if isinstance(expected_subtypes, tuple):
+        assert (
+            tuple(subtypes) == expected_subtypes
+        ), f"Expected {expected_subtypes} but got {subtypes}"
+    else:
+        assert subtypes == expected_subtypes, f"Expected {expected_subtypes} but got {subtypes}"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to