gemini-code-assist[bot] commented on code in PR #594:
URL: https://github.com/apache/tvm-ffi/pull/594#discussion_r3268493797
##########
tests/cpp/extra/test_structural_map.cc:
##########
@@ -0,0 +1,8 @@
+// testing structural map
+
+#include <gtest/gtest.h>
+#include <tvm/ffi/extra/structural_map.h>
+
+TEST(StructuralMap, Basic) {
+ EXPECT_EQ(StructuralMap::Map(1), 1);
Review Comment:

This test will not compile as written. `StructuralMap` is not a class; the
intended class is likely `StructuralMapper`. Furthermore, `Map` is a virtual
member function, not a static one, so it must be called on a
`StructuralMapper` instance. It takes an `AnyView`; passing the literal `1`
may work via implicit conversion, but calling `Map` as a static function is
incorrect.
##########
src/ffi/extra/structural_visit.cc:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/ffi/extra/structural_visit.cc
+ * \brief Structural visit implementation.
+ */
+#include <tvm/ffi/extra/structural_visit.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/accessor.h>
+
+namespace tvm {
+namespace ffi {
+namespace {
+
+// Walk reflected fields of `value` and recurse into each non-ignored field.
+Optional<VisitInterrupt> VisitReflectedFields(StructuralVisitor* visitor,
const ObjectRef& value) {
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index());
+
+ Optional<VisitInterrupt> result = std::nullopt;
+ reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const
TVMFFIFieldInfo* field_info) {
+ if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) {
+ return false;
+ }
+
+ reflection::FieldGetter getter(field_info);
+ Any field_value = getter(value);
+
+ TVMFFIDefRegionKind kind = kTVMFFIDefRegionKindNone;
+ if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive) {
+ kind = kTVMFFIDefRegionKindNonRecursive;
+ } else if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive)
{
+ kind = kTVMFFIDefRegionKindRecursive;
+ }
+
+ result = visitor->WithDefRegionKind(
+ kind, [&]() { return visitor->Visit(field_value.cast<ObjectRef>()); });
Review Comment:

The call to `field_value.cast<ObjectRef>()` will fail if the field contains
a primitive type (e.g., `int`, `float`). The visitor should check that the
field value is an object before recursing into it via `Visit`:
```cpp
if (field_value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
  result = visitor->WithDefRegionKind(
      kind, [&]() { return visitor->Visit(field_value.cast<ObjectRef>()); });
}
```
##########
include/tvm/ffi/extra/structural_visit.h:
##########
@@ -0,0 +1,396 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/structural_visit.h
+ * \brief Structural visit implementation
+ */
+#ifndef TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
+#define TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/extra/base.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/access_path.h>
+
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Object node carrying the optional payload for an interrupted
structural visit.
+ */
+class VisitInterruptObj : public Object {
+ public:
+ /*! \brief Payload returned with the interrupt, or FFI None for no payload.
*/
+ Any value;
+
+ VisitInterruptObj() = default;
+ /*!
+ * \brief Construct a VisitInterruptObj with a payload.
+ * \param value The payload carried by the interrupt.
+ */
+ explicit VisitInterruptObj(Any value) : value(std::move(value)) {}
+
+ /// \cond Doxygen_Suppress
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.VisitInterrupt", VisitInterruptObj,
Object);
+ /// \endcond
+};
+
+/*!
+ * \brief ObjectRef wrapper for VisitInterruptObj.
+ */
+class VisitInterrupt : public ObjectRef {
+ public:
+ /*! \brief Construct an interrupt with no payload. */
+ VisitInterrupt() : VisitInterrupt(Any(nullptr)) {}
+ /*!
+ * \brief Construct an interrupt with a user-defined payload.
+ * \param value The payload carried by the interrupt.
+ */
+ explicit VisitInterrupt(Any value)
+ : ObjectRef(make_object<VisitInterruptObj>(std::move(value))) {}
+
+ /// \cond Doxygen_Suppress
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(VisitInterrupt, ObjectRef,
VisitInterruptObj);
+ /// \endcond
+};
+
+/*!
+ * \brief C-ABI safe-call style visit function pointer.
+ *
+ * Used as the primary entry point of \ref StructuralVisitor so that non-C++
+ * bindings (e.g. Rust) can implement and invoke visitors without crossing a
+ * C++ exception boundary.
+ *
+ * \param self Opaque visitor self pointer (the value stored in
+ * \ref StructuralVisitor::self).
+ * \param value The object being visited.
+ * \param out Out parameter: set to ``std::nullopt`` on no-interrupt, or to a
+ * \ref VisitInterrupt to halt traversal. Must be pointer to a
+ * default-initialized \c Optional<VisitInterrupt>.
+ * \return 0 on success, non-zero on error. On error, the error is set via
+ * \c TVMFFIErrorSetRaised and may be retrieved with
+ * \c TVMFFIErrorMoveFromRaised.
+ *
+ * \sa TVMFFISafeCallType
+ */
+using FStructuralVisitSafe = int (*)(void* self, const ObjectRef& value,
+ Optional<VisitInterrupt>* out);
+
+/*!
+ * \brief C++ fast-path visit function pointer (throws on error).
+ *
+ * Optional companion to \ref FStructuralVisitSafe. When non-null, callers in
+ * the same C++ ABI may invoke it directly and let exceptions propagate, saving
+ * the catch/rethrow round-trip of the safe-call path.
+ *
+ * Always set to \c nullptr for visitors authored outside C++.
+ */
+using FStructuralVisitCpp = Optional<VisitInterrupt> (*)(void* self, const
ObjectRef& value);
+
+/*!
+ * \brief Structural visitor driving recursive traversal of an Object tree.
+ *
+ * The visitor is a layout-stable POD-shaped struct exposing a small
+ * function-pointer table (``safe_visit`` / ``cpp_visit``) and an opaque
+ * ``self`` pointer. This mirrors the design of \ref TVMFFIFunctionCell and
+ * makes the visitor authorable / invokable from non-C++ bindings.
+ *
+ * Construction modes:
+ * - Default-constructed visitors dispatch through the per-type structural
+ * visit attribute registry (\c reflection::type_attr::kStructuralVisit) and
+ * fall back to a reflection-driven field walk when no override is
registered.
+ * - Derived visitors should fill in ``self`` and the function-pointer table in
+ * their constructor. \ref StructuralVisitorImpl is a convenience template
+ * that wires this up automatically for a C++ callable.
+ *
+ * The class deliberately avoids virtual functions so the layout is stable
+ * across the FFI boundary. Custom dispatch is expressed through the
+ * function-pointer table rather than virtual overrides.
+ */
+class StructuralVisitor {
+ public:
+ // -------- C-ABI layout: keep these fields first and in this order ---------
+ /*! \brief Required C-ABI safe-call entry. Never null on a constructed
visitor. */
+ FStructuralVisitSafe safe_visit = nullptr;
+ /*!
+ * \brief Optional C++ fast-path entry. ``nullptr`` for non-C++ visitors.
+ *
+ * Stored as ``void*`` (rather than \ref FStructuralVisitCpp) to keep this
+ * struct free of C++-specific signatures so that language bindings can
+ * mirror its layout with a plain pointer field.
+ */
+ void* cpp_visit = nullptr;
+ /*! \brief Opaque self pointer forwarded to ``safe_visit`` / ``cpp_visit``.
*/
+ void* self = nullptr;
+ /*! \brief Current def-region context for structural eq/hash semantics. */
+ TVMFFIDefRegionKind def_region_mode = kTVMFFIDefRegionKindNone;
+
+ // --------------------------- C++-only API ---------------------------------
+
+ /*!
+ * \brief Construct the default structural visitor.
+ *
+ * Wires up ``safe_visit`` / ``cpp_visit`` to the default dispatcher
+ * implemented in \c structural_visit.cc, which consults the structural-visit
+ * type attribute registry and falls back to a reflection-driven field walk.
+ */
+ StructuralVisitor();
+
+ ~StructuralVisitor() = default;
+ StructuralVisitor(const StructuralVisitor&) = default;
+ StructuralVisitor(StructuralVisitor&&) = default;
+ StructuralVisitor& operator=(const StructuralVisitor&) = default;
+ StructuralVisitor& operator=(StructuralVisitor&&) = default;
+
+ /*!
+ * \brief Visit a value, dispatching through this visitor's function table.
+ *
+ * Prefers ``cpp_visit`` when available (no catch/rethrow), otherwise routes
+ * through ``safe_visit`` and rethrows any raised error as a C++ exception.
+ *
+ * \param value The object to visit.
+ * \return ``std::nullopt`` to continue traversal, or a \ref VisitInterrupt
+ * to halt the entire visit.
+ */
+ TVM_FFI_INLINE Optional<VisitInterrupt> Visit(const ObjectRef& value) {
+ // Use cpp_visit fast path when present, mirroring FunctionObj::CallPacked.
Review Comment:

The `Visit` method should check whether `value` is defined before
dispatching. Visiting a null `ObjectRef` will lead to a crash in
`DefaultVisit` when it attempts to access the object's type index.
```cpp
TVM_FFI_INLINE Optional<VisitInterrupt> Visit(const ObjectRef& value) {
  if (!value.defined()) return std::nullopt;
  // Use cpp_visit fast path when present, mirroring FunctionObj::CallPacked.
```
##########
python/tvm_ffi/structural.py:
##########
@@ -231,3 +231,65 @@ def __hash__(self) -> int:
def __eq__(self, other: Any) -> bool:
"""Compare by structural equality."""
return isinstance(other, StructuralKey) and
_ffi_api.StructuralKeyEqual(self, other)
+
+
+class StructuralVisitor(Object):
+ """Structural visitor object."""
+ def __init__(self) -> None:
+ self.__init_handle_by_constructor__(_ffi_api.StructuralVisitor)
+
+ def visit(self, value: Any) -> Any:
+ return _ffi_api.StructuralVisit(self, value)
+
+ def with_def_region_kind(self, kind: TVMFFIDefRegionKind, callback:
Callable[[], Any]) -> Any:
+ return _ffi_api.StructuralVisitorWithDefRegionKind(self, kind,
callback)
+
+
+
+#### example usage
+
+```
+visitor = StructuralVisitor()
+visitor.visit(root)
+visitor.with_def_region_kind(TVMFFIDefRegionKind.kDefRecursive, lambda:
visitor.visit(root))
+```
+
+
+def visit_for_(visitor: StructuralVisitor, op: For) ->
Optional[VisitInterrupt]:
+ """Visitor implementation for For."""
+ visitor.visit(op.min)
+ visitor.visit(op.extent)
+ if op.step is not None:
+ visitor.visit(op.step)
+ visitor.visit(op.body)
+
+
+def visit_function(visitor: StructuralVisitor, op: Function) -> VisitInterrupt
| None:
+ # params introduce definitions
+ ret = visitor.with_def_region_kind(
+ TVMFFIDefRegionKind.kTVMFFIDefRegionKindRecursive,
+ lambda: visitor.visit(op.params),
+ )
+ if ret is not None:
+ return ret
+
+ # body is normal use context
+ return visitor.visit(op.body)
+
+
+def visit_let(visitor: StructuralVisitor, op: Let) -> VisitInterrupt | None:
+ # RHS is use context
+ ret = visitor.visit(op.value)
+ if ret is not None:
+ return ret
+
+ # LHS var is a definition, but vars inside its annotation are uses
+ ret = visitor.with_def_region_kind(
+ TVMFFIDefRegionKind.kTVMFFIDefRegionKindNonRecursive,
+ lambda: visitor.visit(op.var),
+ )
+ if ret is not None:
+ return ret
+
+ return visitor.visit(op.body)
+
Review Comment:

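This section contains example usage and helper functions that appear to be
scratchpad code or documentation snippets. As written, it also breaks the
module: the bare triple-backtick fence lines around the usage example are not
valid Python syntax, so importing this module would fail with a `SyntaxError`.
These snippets should be moved to a dedicated test or example file to keep the
library module clean and focused on the API implementation.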
##########
tests/cpp/extra/test_structural_visit.cc:
##########
@@ -0,0 +1,114 @@
+Optional<VisitInterrupt> StructuralVisitFor(StructuralVisitor* visitor,
AnyView self) {
+ For op = self.cast<For>();
+
+ if (auto ret = visitor->Visit(op->min)) return ret;
+ if (auto ret = visitor->Visit(op->extent)) return ret;
+ if (auto ret = visitor->Visit(op->body)) return ret;
+
+ return std::nullopt;
+ }
+
+ TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+
+ refl::TypeAttrDef<ForNode>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(&StructuralVisitFor));
+ }
+
+auto result = structuralWalk<Add>(
+root,
+[](Add op, DefRegionKind kind) -> Variant<WalkResult, VisitInterrupt> {
+ // called for every Add
+ return WalkResult::kAdvance;
+},
+WalkOrder::kPreOrder);
+
+
+std::vector<Var> defs;
+std::vector<Var> uses;
+
+structuralWalk<Var>(
+ root,
+ [&](Var var, DefRegionKind kind) -> Variant<WalkResult, VisitInterrupt> {
+ if (kind == DefRegionKind::kNone) {
+ uses.push_back(var);
+ } else {
+ defs.push_back(var);
+ }
+ return WalkResult::kAdvance;
+ }, WalkOrder::kPreOrder);
+
+
+structuralWalk<Add, Mul, Div>(
+ root,
+ [](Variant<Add, Mul, Div> op, DefRegionKind kind)
+ -> Variant<WalkResult, VisitInterrupt> {
+ if (op.as<Add>()) {
+ // handle Add
+ } else if (op.as<Mul>()) {
+ // handle Mul
+ }
+ return WalkResult::kAdvance;
+ }, WalkOrder::kPreOrder);
+
+
+auto interrupt = structuralWalk<Add, Mul, Div, Function>(
+ root,
+ [&](Variant<Add, Mul, Div, Function> op,
+ DefRegionKind kind) -> Variant<WalkResult, VisitInterrupt> {
+ if (auto add = op.as<Add>()) {
+ // handle Add
+ }
+ if (auto mul = op.as<Mul>()) {
+ // handle Mul
+ }
+ return WalkResult::kAdvance;
+ },
+ WalkOrder::kPreOrder
+);
+
+
+Optional<VisitInterrupt> StructuralVisitFunction(StructuralVisitor* visitor,
AnyView value) {
+ Function func = value.cast<Function>();
+
+ if (auto ret = visitor->WithDefRegionKind(DefRegionKind::kDefRecursive,
[&]() {
+ return visitor->Visit(func->params);
+ })) {
+ return ret;
+ }
+
+ return visitor->Visit(func->body);
+ }
+
+ TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+
+ refl::TypeAttrDef<FunctionObj>().attr(
+ refl::type_attr::kStructuralVisit,
+ reinterpret_cast<void*>(&StructuralVisitFunction));
+ }
+
+
+ StructuralVisitor visitor;
+ Optional<VisitInterrupt> interrupt = visitor.Visit(root);
+ if (interrupt) {
+ Any payload = (*interrupt)->value;
+ }
+
+
+ auto interrupt = structuralWalk<Add, Mul, Div, Function>(
+ root,
+ [&](Variant<Add, Mul, Div, Function> op,
+ DefRegionKind kind) -> Variant<WalkResult, VisitInterrupt> {
+ if (auto add = op.as<Add>()) {
+ // handle Add
+ }
+ if (auto mul = op.as<Mul>()) {
+ // handle Mul
+ }
+ return WalkResult::kAdvance;
+ },
+ WalkOrder::kPreOrder
+);
Review Comment:

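This file is not a valid C++ source file. It contains loose code snippets,
is missing the necessary includes and namespaces (e.g., `tvm::ffi`), and has
no GTest `TEST` macros at all. Several snippets are statements at namespace
scope (e.g., the bare `structuralWalk<Var>(...)` call and the
`if (interrupt) { ... }` block), and `interrupt` is defined more than once,
so the file cannot compile. It appears to be a collection of design examples
rather than a functional test suite.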
##########
src/ffi/extra/structural_map.cc:
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/ffi/extra/structural_map.cc
+ * \brief Structural map implementation.
+ */
+#include <tvm/ffi/extra/structural_map.h>
+#include <tvm/ffi/reflection/accessor.h>
+
+namespace tvm {
+namespace ffi {
+
+Any StructuralMapper::Map(AnyView value) {
+ if (value.type_index() == TypeIndex::kTVMFFINone ||
+ value.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) {
+ return Any(value);
+ }
+ return CallStructuralMap(value);
+}
+
+Optional<Any> StructuralMapper::LookupVarRemap(AnyView old_var) const {
+ auto it = var_remap_.find(Any(old_var));
+ if (it == var_remap_.end()) {
+ return std::nullopt;
+ }
+ return (*it).second;
+}
+
+void StructuralMapper::SetVarRemap(Any old_var, Any new_var) {
+ var_remap_.Set(std::move(old_var), std::move(new_var));
+}
+
+// objectref*, or object* (might be risky)
+// return objectref
+// 玩一玩 a version
+Any StructuralMapper::MapOrInplaceMutator(ObjectRef&& obj) {
+ TVM_FFI_ICHECK(obj.defined());
+
+ if (obj.unique() && HasInplaceMutator(obj.type_index())) {
+ return InplaceMutator(std::move(obj));
+ }
+
+ return Map(obj);
+}
+
+Any StructuralMapper::InplaceMutator(ObjectRef&& obj) {
+ TVM_FFI_ICHECK(obj.defined());
+ TVM_FFI_ICHECK(obj.unique());
+
+ static reflection::TypeAttrColumn
column(reflection::type_attr::kStructuralInplaceMutator);
+ AnyView attr = column[obj.type_index()];
+ if (attr.type_index() == TypeIndex::kTVMFFIOpaquePtr) {
+ auto* fn = reinterpret_cast<FStructuralInplaceMutator>(attr.cast<void*>());
+ TVM_FFI_ICHECK_NOTNULL(fn);
+ return (*fn)(this, std::move(obj));
+ }
+ if (attr.type_index() != TypeIndex::kTVMFFINone) {
+ TVM_FFI_THROW(TypeError) <<
reflection::type_attr::kStructuralInplaceMutator
+ << " must be an opaque function pointer";
+ }
+ return InplaceMutateReflectedFields(std::move(obj));
+}
+
+Any StructuralMapper::CallStructuralMap(AnyView value) {
+ static reflection::TypeAttrColumn
column(reflection::type_attr::kStructuralMap);
+ AnyView attr = column[value.type_index()];
+ if (attr.type_index() == TypeIndex::kTVMFFIOpaquePtr) {
+ auto* fn = reinterpret_cast<FStructuralMap>(attr.cast<void*>());
+ TVM_FFI_ICHECK_NOTNULL(fn);
+ return (*fn)(this, value);
+ }
+ if (attr.type_index() != TypeIndex::kTVMFFINone) {
+ TVM_FFI_THROW(TypeError) << reflection::type_attr::kStructuralMap
+ << " must be an opaque function pointer";
+ }
+ return MapReflectedFields(value);
+}
+
+Any StructuralMapper::MapReflectedFields(AnyView value) {
+ return Any(value);
+}
+
+Any StructuralMapper::InplaceMutateReflectedFields(ObjectRef&& obj) {
+ return Any(std::move(obj));
+}
Review Comment:

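The current implementations of `MapReflectedFields` and
`InplaceMutateReflectedFields` are stubs that return the input value unchanged.
`CallStructuralMap` and `InplaceMutator` fall back to these methods whenever no
type attribute is registered. As a result, every object of an unregistered type
is currently returned as-is, without its fields being visited. To function as a
proper structural mapper or mutator, these methods should provide a default
implementation that recursively processes the reflected fields of the object,
similar to the recursive traversal implemented in `StructuralVisitor`.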
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]