This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch seqhashdef-non-recursive
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git

commit 107ad5194ccdeead9c03764b48b5982966bacec9
Author: tqchen <[email protected]>
AuthorDate: Sun May 10 00:36:58 2026 +0000

    [FEAT] Split SEqHashDef into recursive and non-recursive variants
    
    The single ``SEqHashDef`` flag treated every nested free var inside a
    def-region field as a fresh def. That conflates two different binding
    shapes:
    
    - *Recursive* (function-style): the value var and any free vars inside
      its sub-fields (e.g. shape vars in a relax::Var's struct_info) are
      co-introduced at the same site.
    
    - *Non-recursive* (let-style): only the immediate value var binds; free
      vars in the var's sub-fields are use references that must resolve
      against an outer-scope binding.
    
    This change:
    
    - Renames the C constant ``kTVMFFIFieldFlagBitMaskSEqHashDef`` ->
      ``...SEqHashDefRecursive`` (no alias) and adds a new sibling
      ``...SEqHashDefNonRecursive`` at bit ``1 << 12`` (the next free bit).
    - Renames ``AttachFieldFlag::SEqHashDef()`` ->
      ``SEqHashDefRecursive()`` and adds ``SEqHashDefNonRecursive()``.
    - Adds a ``TVMFFIFieldDefKind`` C enum (None=0, Recursive=1,
      NonRecursive=2) for the custom ``__s_equal__`` / ``__s_hash__``
      callback's def-mode parameter. That parameter is now declared as plain
      ``int`` (previously ``bool``), so legacy callers that pass a ``bool``
      still compile and preserve their meaning via the standard bool->int
      coercion (true -> 1 = Recursive, false -> 0 = None).
    - Updates the structural_equal / structural_hash dispatch to clamp
      ``map_free_vars_`` to false when descending into a FreeVar's own
      sub-fields under the non-recursive flag (the binding step itself
      still runs with the caller's setting, so the immediate FreeVar still
      binds).
    - Mirrors the rename and the new enum into the Cython base.pxi /
      type_info.pxi / object.pxi layer; preserves Python ``"def"`` as a
      back-compat alias for ``"def-recursive"`` and adds
      ``"def-non-recursive"`` to the dataclass field vocabulary.
    - Adds C++ regression tests in
      tests/cpp/extra/test_structural_equal_hash.cc covering the four
      recursive / non-recursive corner cases on a new ``TDefHolder``
      test type whose ``def_recursive`` and ``def_non_recursive`` fields
      hold a FreeVar with a nested FreeVar sub-field.
---
 include/tvm/ffi/c_api.h                       | 92 ++++++++++++++++++++++++++-
 include/tvm/ffi/reflection/registry.h         | 21 +++++-
 python/tvm_ffi/cython/base.pxi                |  8 ++-
 python/tvm_ffi/cython/object.pxi              |  8 ++-
 python/tvm_ffi/cython/type_info.pxi           | 10 ++-
 python/tvm_ffi/dataclasses/field.py           | 29 +++++++--
 src/ffi/extra/structural_equal.cc             | 80 ++++++++++++++++++++---
 src/ffi/extra/structural_hash.cc              | 69 +++++++++++++++++---
 tests/cpp/extra/test_structural_equal_hash.cc | 77 ++++++++++++++++++++++
 tests/cpp/test_reflection.cc                  |  2 +
 tests/cpp/testing_object.h                    | 77 +++++++++++++++++++++-
 11 files changed, 435 insertions(+), 38 deletions(-)

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 0d2c9df..e69fd6c 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -858,11 +858,21 @@ typedef enum {
    */
   kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3,
   /*!
-   * \brief The field enters a def region where var can be defined/matched.
+   * \brief The field enters a recursive def region.
+   *
+   * When equality / hashing first encounters a free var reachable through
+   * this field, the var binds. Sub-fields of that var (e.g. its
+   * struct_info / type_annotation / shape) remain in the def region — any
+   * free vars discovered transitively also bind as fresh defs at the same
+   * site.
+   *
+   * Use for "function-style" bindings where the value var and its shape
+   * vars are co-introduced (e.g. ``relax::FunctionNode::params``,
+   * ``tirx::AllocBufferNode::buffer``).
    *
    * This is an optional meta-data for structural eq/hash.
    */
-  kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4,
+  kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4,
   /*!
    * \brief The default_value_or_factory is a callable factory function () -> 
Any.
    *
@@ -922,6 +932,27 @@ typedef enum {
    * ``(field_addr_as_OpaquePtr, value_as_AnyView)``.
    */
   kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11,
+  /*!
+   * \brief The field enters a non-recursive def region.
+   *
+   * When equality / hashing first encounters a free var reachable through
+   * this field, the var binds. Sub-fields of that var are NOT in the def
+   * region — they are use references that must resolve against an outer
+   * binding. Free vars found in those sub-fields therefore do not rebind;
+   * if they are not already bound they match only by pointer identity, so
+   * distinct unbound vars make equality fail (and hash by pointer).
+   *
+   * Use for "let-style" bindings where only the value var is introduced
+   * and its shape / type sub-fields refer to outer-scope vars
+   * (e.g. ``relax::BindingNode::var``, ``tirx::LetNode::var``,
+   * ``tirx::ForNode::loop_var``).
+   *
+   * This is an optional meta-data for structural eq/hash.
+   *
+   * \note Bit 1 << 12 is used here because bits 1 << 5 .. 1 << 11 are
+   *       already taken by other field flags above.
+   */
+  kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12,
 #ifdef __cplusplus
 };
 #else
@@ -993,6 +1024,63 @@ typedef enum {
 } TVMFFISEqHashKind;
 #endif
 
+/*!
+ * \brief Kind of def region a structural-equal / structural-hash callback is
+ *        currently in when visiting a field.
+ *
+ * The numeric values are stable: a legacy ``bool def_region`` argument that
+ * is implicitly coerced to ``int`` will land on ``kTVMFFIFieldDefKindNone``
+ * (false → 0) or ``kTVMFFIFieldDefKindRecursive`` (true → 1), preserving
+ * the meaning of any pre-existing call site that passes a bool.
+ */
+#ifdef __cplusplus
+enum TVMFFIFieldDefKind : int32_t {
+#else
+typedef enum {
+#endif
+  /*!
+   * \brief Not in a def region.
+   *
+   * Free vars reachable through this field are uses; they must already
+   * be bound by an enclosing def region or equality / hashing falls
+   * back to pointer identity.
+   */
+  kTVMFFIFieldDefKindNone = 0,
+  /*!
+   * \brief In a recursive def region.
+   *
+   * When we see a free var for the first time, we define the var, and
+   * the sub-fields of the var (e.g. its struct_info / type_annotation /
+   * shape) are also still in the def region — any free vars discovered
+   * inside those sub-fields are themselves treated as fresh defs at the
+   * same site.
+   *
+   * Use for "function-style" bindings where the value var and its shape
+   * vars are co-introduced (e.g. ``relax::FunctionNode::params``,
+   * ``tirx::AllocBufferNode::buffer``).
+   */
+  kTVMFFIFieldDefKindRecursive = 1,
+  /*!
+   * \brief In a non-recursive def region.
+   *
+   * When we see a free var for the first time, we define the var, but
+   * the sub-fields of the var are NOT in the def region — they are
+   * treated as use references that must resolve against an outer
+   * binding. Free vars found in those sub-fields therefore do not
+   * rebind; if they are not already bound, they match only by pointer identity.
+   *
+   * Use for "let-style" bindings where only the value var is introduced
+   * and its shape / type sub-fields refer to outer-scope vars
+   * (e.g. ``relax::BindingNode::var``, ``tirx::LetNode::var``,
+   * ``tirx::ForNode::loop_var``).
+   */
+  kTVMFFIFieldDefKindNonRecursive = 2,
+#ifdef __cplusplus
+};
+#else
+} TVMFFIFieldDefKind;
+#endif
+
 /*!
  * \brief Information support for optional object reflection.
  */
diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index 3e715fe..09ceb8b 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -213,10 +213,25 @@ class AttachFieldFlag : public InfoTrait {
   explicit AttachFieldFlag(int32_t flag) : flag_(flag) {}
 
   /*!
-   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef
+   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+   *
+   * The field enters a recursive def region: free vars discovered both at
+   * the field's value and inside that value's sub-fields bind as fresh
+   * defs at the same site. Use for "function-style" bindings.
+   */
+  TVM_FFI_INLINE static AttachFieldFlag SEqHashDefRecursive() {
+    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefRecursive);
+  }
+  /*!
+   * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
+   *
+   * The field enters a non-recursive def region: only the immediate free
+   * var at the field's value binds; free vars in its sub-fields are uses
+   * that must already be bound by an outer def region. Use for "let-style"
+   * bindings whose sub-fields reference outer-scope vars.
    */
-  TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() {
-    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef);
+  TVM_FFI_INLINE static AttachFieldFlag SEqHashDefNonRecursive() {
+    return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive);
   }
   /*!
    * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index c5c28a1..10d96c4 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -207,7 +207,7 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1
         kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2
         kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3
-        kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4
+        kTVMFFIFieldFlagBitMaskSEqHashDefRecursive = 1 << 4
         kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5
         kTVMFFIFieldFlagBitMaskReprOff = 1 << 6
         kTVMFFIFieldFlagBitMaskCompareOff = 1 << 7
@@ -215,6 +215,7 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFIFieldFlagBitMaskInitOff = 1 << 9
         kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10
         kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11
+        kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive = 1 << 12
 
     ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept
     ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) 
noexcept
@@ -248,6 +249,11 @@ cdef extern from "tvm/ffi/c_api.h":
         kTVMFFISEqHashKindConstTreeNode = 4
         kTVMFFISEqHashKindUniqueInstance = 5
 
+    cdef enum TVMFFIFieldDefKind:
+        kTVMFFIFieldDefKindNone = 0
+        kTVMFFIFieldDefKindRecursive = 1
+        kTVMFFIFieldDefKindNonRecursive = 2
+
     ctypedef struct TVMFFITypeMetadata:
         TVMFFIByteArray doc
         TVMFFIObjectCreator creator
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 803ead2..c564b42 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -541,11 +541,13 @@ cdef _type_info_create_from_type_key(object type_cls, str 
type_key):
                 c_default_factory = make_ret(owned_default)
             else:
                 c_default = make_ret(owned_default)
-        # Decode SEqHashIgnore / SEqHashDef into the Field.structural_eq 
vocabulary.
+        # Decode SEqHashIgnore / SEqHashDef* into the Field.structural_eq 
vocabulary.
         if (field.flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) != 0:
             c_structural_eq = "ignore"
-        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDef) != 0:
-            c_structural_eq = "def"
+        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive) != 0:
+            c_structural_eq = "def-recursive"
+        elif (field.flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 
0:
+            c_structural_eq = "def-non-recursive"
         else:
             c_structural_eq = None
         fields.append(
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index d7f39be..960d5cd 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -908,8 +908,14 @@ cdef _register_one_field(
     cdef object field_structure = getattr(py_field, "structural_eq", None)
     if field_structure == "ignore":
         flags |= kTVMFFIFieldFlagBitMaskSEqHashIgnore
-    elif field_structure == "def":
-        flags |= kTVMFFIFieldFlagBitMaskSEqHashDef
+    elif field_structure == "def" or field_structure == "def-recursive":
+        # ``"def"`` is the legacy short form, kept as a Python-side synonym for
+        # ``"def-recursive"`` since the C-level rename of the underlying flag
+        # (``kTVMFFIFieldFlagBitMaskSEqHashDef`` -> ``...SEqHashDefRecursive``)
+        # only changed the constant name, not the recursive semantics.
+        flags |= kTVMFFIFieldFlagBitMaskSEqHashDefRecursive
+    elif field_structure == "def-non-recursive":
+        flags |= kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive
     info.flags = flags
 
     # --- native layout ---
diff --git a/python/tvm_ffi/dataclasses/field.py 
b/python/tvm_ffi/dataclasses/field.py
index 53f576c..b332cd4 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -87,9 +87,18 @@ class Field:
           structural comparison and hashing.
         - ``"ignore"``: the field is excluded from structural equality
           and hashing entirely (e.g. source spans, caches).
-        - ``"def"``: the field is a **definition region** that introduces
-          new variable bindings.  Free variables encountered inside this
-          field are mapped by position, enabling alpha-equivalence.
+        - ``"def-recursive"`` (alias: ``"def"``): the field is a
+          **recursive definition region** that introduces new variable
+          bindings.  Free variables encountered anywhere in this field's
+          subtree (including inside the var's own sub-fields) are
+          mapped by position. Use for "function-style" bindings where
+          the value var and its shape vars are co-introduced.
+        - ``"def-non-recursive"``: the field is a **non-recursive
+          definition region**.  Only the immediate free var(s) at this
+          field's value bind; free vars inside their sub-fields must
+          resolve against an outer binding (use semantics). Use for
+          "let-style" bindings whose sub-fields reference outer-scope
+          vars.
     doc : str | None
         Optional docstring for the field.
 
@@ -125,8 +134,12 @@ class Field:
     doc: str | None
 
     #: Valid values for the *structural_eq* parameter.
+    #:
+    #: ``"def"`` is kept as a Python-side alias for ``"def-recursive"`` to
+    #: preserve back-compat with code written against the old single-flag
+    #: ``SEqHashDef`` API.
     _VALID_STRUCTURAL_EQ_VALUES: ClassVar[frozenset[str | None]] = frozenset(
-        {None, "ignore", "def"}
+        {None, "ignore", "def", "def-recursive", "def-non-recursive"}
     )
 
     def __init__(  # noqa: PLR0913
@@ -226,8 +239,12 @@ def field(
     structural_eq
         Structural equality/hashing annotation. ``None`` (default) means
         the field participates normally. ``"ignore"`` excludes the field
-        from structural comparison and hashing. ``"def"`` marks the field
-        as a definition region for variable binding.
+        from structural comparison and hashing. ``"def-recursive"``
+        (alias ``"def"``) marks the field as a recursive definition
+        region: free vars in the field's whole subtree bind. 
``"def-non-recursive"``
+        marks it as a non-recursive definition region: only immediate
+        free vars bind; nested free vars must resolve against an outer
+        binding.
     doc
         Optional docstring for the field.
 
diff --git a/src/ffi/extra/structural_equal.cc 
b/src/ffi/extra/structural_equal.cc
index 5f4db3c..5c1cb0d 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -174,6 +174,25 @@ class StructEqualHandler {
     static reflection::TypeAttrColumn custom_s_equal =
         reflection::TypeAttrColumn(reflection::type_attr::kSEqual);
 
+    // Non-recursive def boundary. When we enter a non-recursive def region we
+    // keep ``map_free_vars_`` on so that any FreeVar reachable through
+    // containers in the field's value (e.g. each ``Var`` in an ``Array<Var>``)
+    // can still bind. But once we are about to walk a FreeVar's OWN sub-fields
+    // (e.g. ``struct_info``, ``type_annotation``), we turn ``map_free_vars_``
+    // off so that nested free vars do not rebind — they must instead resolve
+    // against a binding established by an outer def region.
+    //
+    // ``non_recursive_def_active_`` stays on for the entire field value
+    // subtree (saved/restored at the field walk site below). It is consulted
+    // here to decide whether to clamp ``map_free_vars_`` to false during the
+    // FreeVar's field walk.
+    bool save_map_free_vars = map_free_vars_;
+    bool clamp_map_free_vars =
+        (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) && 
non_recursive_def_active_;
+    if (clamp_map_free_vars) {
+      map_free_vars_ = false;
+    }
+
     bool success = true;
     if (custom_s_equal[type_info->type_index] == nullptr) {
       // We recursively compare the fields the object
@@ -184,12 +203,27 @@ class StructEqualHandler {
         reflection::FieldGetter getter(field_info);
         Any lhs_value = getter(lhs);
         Any rhs_value = getter(rhs);
-        // field is in def region, enable free var mapping
-        if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
-          bool allow_free_var = true;
-          std::swap(allow_free_var, map_free_vars_);
+        // Dispatch on the def-region flags.
+        //   - Recursive    : enable ``map_free_vars_`` for the whole subtree
+        //                    of this field's value, including nested FreeVars'
+        //                    sub-fields.
+        //   - NonRecursive : enable ``map_free_vars_`` only for the immediate
+        //                    FreeVar(s) reachable through this field; their
+        //                    own sub-fields are walked with ``map_free_vars_``
+        //                    clamped to false (the clamp lives in the
+        //                    ``CompareObject`` prologue above, gated by
+        //                    ``non_recursive_def_active_``).
+        constexpr int64_t kSEqHashDefAny = 
kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+                                           
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+        if (field_info->flags & kSEqHashDefAny) {
+          bool save_allow_free_var = map_free_vars_;
+          bool save_non_recursive = non_recursive_def_active_;
+          map_free_vars_ = true;
+          non_recursive_def_active_ =
+              (field_info->flags & 
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0;
           success = CompareAny(lhs_value, rhs_value);
-          std::swap(allow_free_var, map_free_vars_);
+          map_free_vars_ = save_allow_free_var;
+          non_recursive_def_active_ = save_non_recursive;
         } else {
           success = CompareAny(lhs_value, rhs_value);
         }
@@ -212,16 +246,27 @@ class StructEqualHandler {
       // run custom equal function defined via __s_equal__ type attribute
       if (s_equal_callback_ == nullptr) {
         s_equal_callback_ = ffi::Function::FromTyped(
-            [this](AnyView lhs, AnyView rhs, bool def_region, AnyView 
field_name) {
+            // The third parameter is a ``TVMFFIFieldDefKind`` (typed as plain
+            // ``int`` on the wire to keep the FFI signature stable; legacy
+            // callers passing ``bool`` continue to compile and preserve their
+            // meaning via the implicit bool->int coercion: false -> 0
+            // (kTVMFFIFieldDefKindNone), true -> 1 
(kTVMFFIFieldDefKindRecursive)).
+            [this](AnyView lhs, AnyView rhs, int def_kind, AnyView field_name) 
{
               // NOTE: we explicitly make field_name as AnyView to avoid copy 
overhead initially
               // and only cast to string if mismatch happens
               bool success = true;
-              if (def_region) {
-                bool allow_free_var = true;
-                std::swap(allow_free_var, map_free_vars_);
+              if (def_kind == kTVMFFIFieldDefKindRecursive ||
+                  def_kind == kTVMFFIFieldDefKindNonRecursive) {
+                bool save_allow_free_var = map_free_vars_;
+                bool save_non_recursive = non_recursive_def_active_;
+                map_free_vars_ = true;
+                non_recursive_def_active_ = (def_kind == 
kTVMFFIFieldDefKindNonRecursive);
                 success = CompareAny(lhs, rhs);
-                std::swap(allow_free_var, map_free_vars_);
+                map_free_vars_ = save_allow_free_var;
+                non_recursive_def_active_ = save_non_recursive;
               } else {
+                // kTVMFFIFieldDefKindNone (or any unknown value treated as 
None):
+                // not in a def region, leave map_free_vars_ as-is.
                 success = CompareAny(lhs, rhs);
               }
               if (!success) {
@@ -241,6 +286,14 @@ class StructEqualHandler {
                     .cast<bool>();
     }
 
+    // Restore the pre-clamp value of map_free_vars_ before deciding whether
+    // to bind a FreeVar pair below. The binding decision must use the value
+    // that the caller of CompareObject set — the clamp only affects this
+    // FreeVar's OWN sub-field walk.
+    if (clamp_map_free_vars) {
+      map_free_vars_ = save_map_free_vars;
+    }
+
     if (success) {
       if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
         // we are in a free var case that is not yet mapped.
@@ -415,6 +468,13 @@ class StructEqualHandler {
   }
   // whether we map free variables that are not defined
   bool map_free_vars_{false};
+  // Whether we are currently inside a non-recursive def region. Set when a
+  // field flagged ``kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive`` is being
+  // walked (or the custom-callback caller passed 
``kTVMFFIFieldDefKindNonRecursive``).
+  // Consulted in CompareObject to clamp ``map_free_vars_`` to false when
+  // descending into a FreeVar's own sub-fields, while still allowing the
+  // FreeVar itself to bind in the post-pass.
+  bool non_recursive_def_active_{false};
   // whether we compare tensor data
   bool skip_tensor_content_{false};
   // the root lhs for result printing
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index 8ab96f0..e4993f8 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -130,6 +130,21 @@ class StructuralHashHandler {
     static reflection::TypeAttrColumn custom_s_hash =
         reflection::TypeAttrColumn(reflection::type_attr::kSHash);
 
+    // Non-recursive def boundary (mirror of structural_equal.cc). When the
+    // current object is a FreeVar AND we are inside a non-recursive def
+    // region, clamp ``map_free_vars_`` to false during the FreeVar's own
+    // sub-field walk: nested free vars in those sub-fields then hash by
+    // pointer (matching use-semantics) instead of receiving fresh
+    // ``free_var_counter_`` slots. The clamp is restored before the
+    // FreeVar-level injection below so the FreeVar itself still gets its
+    // counter slot when ``map_free_vars_`` was on at the call site.
+    bool save_map_free_vars = map_free_vars_;
+    bool clamp_map_free_vars =
+        (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) && 
non_recursive_def_active_;
+    if (clamp_map_free_vars) {
+      map_free_vars_ = false;
+    }
+
     // compute the hash value
     uint64_t hash_value = obj->GetTypeKeyHash();
     if (custom_s_hash[type_info->type_index] == nullptr) {
@@ -140,12 +155,23 @@ class StructuralHashHandler {
           // get the field value from both side
           reflection::FieldGetter getter(field_info);
           Any field_value = getter(obj);
-          // field is in def region, enable free var mapping
-          if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
-            bool allow_free_var = true;
-            std::swap(allow_free_var, map_free_vars_);
+          // Dispatch on the def-region flags (mirror of the equality side).
+          //   - Recursive    : map_free_vars_ stays on for the whole subtree.
+          //   - NonRecursive : map_free_vars_ on for the immediate FreeVar(s);
+          //                    clamped off when descending into a FreeVar's
+          //                    own sub-fields (the clamp lives in HashObject's
+          //                    prologue above, gated by 
non_recursive_def_active_).
+          constexpr int64_t kSEqHashDefAny = 
kTVMFFIFieldFlagBitMaskSEqHashDefRecursive |
+                                             
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive;
+          if (field_info->flags & kSEqHashDefAny) {
+            bool save_allow_free_var = map_free_vars_;
+            bool save_non_recursive = non_recursive_def_active_;
+            map_free_vars_ = true;
+            non_recursive_def_active_ =
+                (field_info->flags & 
kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) != 0;
             hash_value = details::StableHashCombine(hash_value, 
HashAny(field_value));
-            std::swap(allow_free_var, map_free_vars_);
+            map_free_vars_ = save_allow_free_var;
+            non_recursive_def_active_ = save_non_recursive;
           } else {
             hash_value = details::StableHashCombine(hash_value, 
HashAny(field_value));
           }
@@ -154,12 +180,21 @@ class StructuralHashHandler {
     } else {
       if (s_hash_callback_ == nullptr) {
         s_hash_callback_ =
-            ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, 
bool def_region) {
-              if (def_region) {
-                bool allow_free_var = true;
-                std::swap(allow_free_var, map_free_vars_);
+            // The third parameter is a ``TVMFFIFieldDefKind`` (typed as plain
+            // ``int`` on the wire to keep the FFI signature stable; legacy
+            // callers passing ``bool`` continue to compile and preserve their
+            // meaning via the implicit bool->int coercion: false -> 0
+            // (kTVMFFIFieldDefKindNone), true -> 1 
(kTVMFFIFieldDefKindRecursive)).
+            ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, 
int def_kind) {
+              if (def_kind == kTVMFFIFieldDefKindRecursive ||
+                  def_kind == kTVMFFIFieldDefKindNonRecursive) {
+                bool save_allow_free_var = map_free_vars_;
+                bool save_non_recursive = non_recursive_def_active_;
+                map_free_vars_ = true;
+                non_recursive_def_active_ = (def_kind == 
kTVMFFIFieldDefKindNonRecursive);
                 uint64_t hash_value = HashAny(val);
-                std::swap(allow_free_var, map_free_vars_);
+                map_free_vars_ = save_allow_free_var;
+                non_recursive_def_active_ = save_non_recursive;
                 return 
static_cast<int64_t>(details::StableHashCombine(init_hash, hash_value));
               } else {
                 // we explicitly bitcast the result from `uint64_t` to 
`int64_t`.
@@ -175,6 +210,13 @@ class StructuralHashHandler {
               .cast<uint64_t>();
     }
 
+    // Restore the pre-clamp value of map_free_vars_ before deciding the
+    // FreeVar-level hash injection: the clamp only suppresses binding inside
+    // the FreeVar's own sub-fields, not the FreeVar slot itself.
+    if (clamp_map_free_vars) {
+      map_free_vars_ = save_map_free_vars;
+    }
+
     if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
       if (map_free_vars_) {
         // use lexical order of free var and its type
@@ -318,6 +360,13 @@ class StructuralHashHandler {
   }
 
   bool map_free_vars_{false};
+  // Whether we are currently inside a non-recursive def region. Set when a
+  // field flagged ``kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive`` is being
+  // hashed (or the custom-callback caller passed 
``kTVMFFIFieldDefKindNonRecursive``).
+  // Consulted in HashObject to clamp ``map_free_vars_`` to false when
+  // descending into a FreeVar's own sub-fields, so nested free vars hash by
+  // pointer rather than receiving fresh ``free_var_counter_`` slots.
+  bool non_recursive_def_active_{false};
   bool skip_tensor_content_{false};
   // free var counter.
   uint32_t free_var_counter_{0};
diff --git a/tests/cpp/extra/test_structural_equal_hash.cc 
b/tests/cpp/extra/test_structural_equal_hash.cc
index ad081e3..4649461 100644
--- a/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/tests/cpp/extra/test_structural_equal_hash.cc
@@ -229,6 +229,83 @@ TEST(StructuralEqualHash, CustomTreeNode) {
   EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
 }
 
+// Regression tests for the SEqHashDefRecursive vs SEqHashDefNonRecursive
+// distinction. ``TDefHolder`` has two sibling fields:
+//   - ``def_recursive``     tagged AttachFieldFlag::SEqHashDefRecursive()
+//   - ``def_non_recursive`` tagged AttachFieldFlag::SEqHashDefNonRecursive()
+// each holding a ``TVarWithDep`` (a FreeVar with a sub-field ``dep`` that
+// can itself reference another FreeVar). The tests below cover the four
+// observable combinations of the two flags.
+TEST(StructuralEqualHash, NonRecursiveDef_NestedFreeVarRebindsUnderRecursive) {
+  // Both fields receive a TVarWithDep whose ``dep`` contains a *fresh*
+  // FreeVar (TVar). Under the recursive flag, the nested ``dep`` rebinds
+  // transitively, so two holders that differ only in fresh names compare
+  // equal.
+  TVarWithDep a("a", TVar("m"));
+  TVarWithDep b("b", TVar("n"));
+  TDefHolder lhs(/*def_recursive=*/a, /*def_non_recursive=*/a);
+  TDefHolder rhs(/*def_recursive=*/b, /*def_non_recursive=*/b);
+  // ``def_non_recursive`` holds the same var as the recursive sibling on each
+  // side (``a`` on lhs, ``b`` on rhs) -- NOT the same object across sides --
+  // so it equates via the a<->b binding the recursive field already made.
+  EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+  EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+            StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+}
+
+TEST(StructuralEqualHash, 
NonRecursiveDef_NestedFreeVarDoesNotRebindUnderNonRecursive) {
+  // The ``def_non_recursive`` field's value (TVarWithDep "c"/"d") binds
+  // because the holder explicitly tags the field as a def region. But the
+  // nested ``dep`` (TVar "p" / "q") is in the FreeVar's sub-field, which
+  // the non-recursive flag CLAMPS out of the def region. With no enclosing
+  // def region for "p" / "q", they must hit the unmapped FreeVar path in
+  // CompareObject and equality fails.
+  //
+  // The recursive sibling is set to identical pointers on both sides so the
+  // test isolates the non-recursive field as the failure source.
+  TVarWithDep shared("shared", std::nullopt);
+  TVarWithDep c_with_dep("c", TVar("p"));
+  TVarWithDep d_with_dep("d", TVar("q"));
+  TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+  TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+  EXPECT_FALSE(StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false));
+}
+
+TEST(StructuralEqualHash, 
NonRecursiveDef_NestedFreeVarResolvesViaOuterBinding) {
+  // Now wire the same nested free var on both sides, so that even under the
+  // non-recursive clamp the FreeVars at the leaf compare equal by *pointer*
+  // (the same.same_as branch in CompareObject's FreeVar handling). This
+  // mirrors the let-style use case where the nested var has been bound by
+  // an outer scope (here we cheat by using the same pointer directly).
+  TVar shared_dep("dep");
+  TVarWithDep c_with_dep("c", shared_dep);
+  TVarWithDep d_with_dep("d", shared_dep);
+  TVarWithDep shared("shared", std::nullopt);
+  TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_with_dep);
+  TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_with_dep);
+  // ``c`` and ``d`` bind via the non-recursive def region; the nested
+  // shared_dep is the same object and so passes pointer-equality without
+  // needing map_free_vars_ to be on inside its sub-field walk.
+  EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+  EXPECT_EQ(StructuralHash()(lhs), StructuralHash()(rhs));
+}
+
+TEST(StructuralEqualHash, NonRecursiveDef_TopLevelFreeVarStillBinds) {
+  // Sanity: even with the non-recursive flag, the immediate FreeVar at the
+  // field's value MUST still bind. So a TVarWithDep with no nested ``dep``
+  // (Optional set to nullopt) under the non-recursive flag should still
+  // compare equal across two fresh names — the binding step itself is not
+  // suppressed, only the descent into the FreeVar's sub-fields is.
+  TVarWithDep shared("shared", std::nullopt);
+  TVarWithDep c_no_dep("c", std::nullopt);
+  TVarWithDep d_no_dep("d", std::nullopt);
+  TDefHolder lhs(/*def_recursive=*/shared, /*def_non_recursive=*/c_no_dep);
+  TDefHolder rhs(/*def_recursive=*/shared, /*def_non_recursive=*/d_no_dep);
+  EXPECT_TRUE(StructuralEqual()(lhs, rhs));
+  EXPECT_EQ(StructuralHash::Hash(lhs, /*map_free_vars=*/true),
+            StructuralHash::Hash(rhs, /*map_free_vars=*/true));
+}
+
 TEST(StructuralEqualHash, List) {
   List<int> a = {1, 2, 3};
   List<int> b = {1, 2, 3};
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index f9d567f..eef8ddb 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -66,6 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   TFloatObj::RegisterReflection();
   TPrimExprObj::RegisterReflection();
   TVarObj::RegisterReflection();
+  TVarWithDepObj::RegisterReflection();
+  TDefHolderObj::RegisterReflection();
   TFuncObj::RegisterReflection();
   TCustomFuncObj::RegisterReflection();
   TAllFieldsObj::RegisterReflection();
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index 48b1a01..d1bf0a9 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -206,6 +206,81 @@ class TVar : public ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVar, ObjectRef, TVarObj);
 };
 
+// FreeVar test object that has a sub-field referencing another FreeVar.
+// This models the "var with nested vars" case (analogous to a relax::Var
+// whose struct_info contains tir shape vars). It is used to exercise the
+// difference between SEqHashDefRecursive and SEqHashDefNonRecursive at the
+// FFI layer: under recursive semantics the nested ``dep`` var rebinds
+// transitively; under non-recursive semantics it is treated as a use of an
+// outer-scope binding and equality fails when no such outer binding exists.
+class TVarWithDepObj : public Object {
+ public:
+  std::string name;
+  // Optional dependency var; when null, this object behaves like a plain
+  // FreeVar with no nested free vars.
+  Optional<ObjectRef> dep;
+
+  TVarWithDepObj(std::string name, Optional<ObjectRef> dep)
+      : name(std::move(name)), dep(std::move(dep)) {}
+  explicit TVarWithDepObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TVarWithDepObj>()
+        .def_ro("name", &TVarWithDepObj::name, refl::AttachFieldFlag::SEqHashIgnore())
+        // ``dep`` participates in structural equality without any def flag,
+        // so it is a USE position. Whether the FreeVar in ``dep`` may rebind
+        // is decided by the def flag on whichever outer field reaches this
+        // object.
+        .def_ro("dep", &TVarWithDepObj::dep);
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.VarWithDep", TVarWithDepObj, Object);
+};
+
+class TVarWithDep : public ObjectRef {
+ public:
+  explicit TVarWithDep(std::string name, Optional<ObjectRef> dep = std::nullopt) {
+    data_ = make_object<TVarWithDepObj>(std::move(name), std::move(dep));
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TVarWithDep, ObjectRef, TVarWithDepObj);
+};
+
+// Holder with one recursive-def field and one non-recursive-def field.
+// Used by the StructuralEqualHash.NonRecursiveDef_* structural-equality tests.
+class TDefHolderObj : public Object {
+ public:
+  TVarWithDep def_recursive;
+  TVarWithDep def_non_recursive;
+
+  TDefHolderObj(TVarWithDep def_recursive, TVarWithDep def_non_recursive)
+      : def_recursive(std::move(def_recursive)), def_non_recursive(std::move(def_non_recursive)) {}
+  explicit TDefHolderObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TDefHolderObj>()
+        .def_ro("def_recursive", &TDefHolderObj::def_recursive,
+                refl::AttachFieldFlag::SEqHashDefRecursive())
+        .def_ro("def_non_recursive", &TDefHolderObj::def_non_recursive,
+                refl::AttachFieldFlag::SEqHashDefNonRecursive());
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.DefHolder", TDefHolderObj, Object);
+};
+
+class TDefHolder : public ObjectRef {
+ public:
+  explicit TDefHolder(TVarWithDep def_recursive, TVarWithDep def_non_recursive) {
+    data_ = make_object<TDefHolderObj>(std::move(def_recursive), std::move(def_non_recursive));
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TDefHolder, ObjectRef, TDefHolderObj);
+};
+
 class TFuncObj : public Object {
  public:
   Array<TVar> params;
@@ -220,7 +295,7 @@ class TFuncObj : public Object {
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<TFuncObj>()
-        .def_ro("params", &TFuncObj::params, 
refl::AttachFieldFlag::SEqHashDef())
+        .def_ro("params", &TFuncObj::params, 
refl::AttachFieldFlag::SEqHashDefRecursive())
         .def_ro("body", &TFuncObj::body)
         .def_ro("comment", &TFuncObj::comment, 
refl::AttachFieldFlag::SEqHashIgnore());
   }


Reply via email to