gemini-code-assist[bot] commented on code in PR #454: URL: https://github.com/apache/tvm-ffi/pull/454#discussion_r2810658028
########## src/ffi/extra/repr_print.cc: ########## @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/repr_print.cc + * + * \brief Reflection-based repr printing with BFS-based cycle/DAG handling. + */ +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/list.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/container/shape.h> +#include <tvm/ffi/container/tensor.h> +#include <tvm/ffi/dtype.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/reflection/accessor.h> +#include <tvm/ffi/reflection/registry.h> + +#include <iomanip> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +namespace tvm { +namespace ffi { + +namespace { + +/*! + * \brief Convert a DLDeviceType to a short name string. 
+ */ +const char* DeviceTypeName(int device_type) { + switch (device_type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLOpenCL: + return "opencl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + default: + return "unknown"; + } +} + +/*! + * \brief Format a DLDevice as "device_name:device_id". + */ +std::string DeviceToString(DLDevice device) { + std::ostringstream os; + os << DeviceTypeName(device.device_type) << ":" << device.device_id; + return os.str(); +} + +/*! + * \brief Format an object address as a hex string. + */ +std::string AddressStr(const Object* obj) { + std::ostringstream os; + os << "0x" << std::hex << reinterpret_cast<uintptr_t>(obj); + return os.str(); +} + +/*! + * \brief Get the type key of an object as a std::string. + */ +std::string GetTypeKeyStr(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + return std::string(type_info->type_key.data, type_info->type_key.size); +} + +/*! + * \brief Lazily initialize and return the __ffi_repr__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist (e.g., before registration). + */ +const TVMFFITypeAttrColumn* GetReprColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kRepr, + std::char_traits<char>::length(reflection::type_attr::kRepr)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Lazily initialize and return the __ffi_repr_fields__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist. 
+ */ +const TVMFFITypeAttrColumn* GetReprFieldsColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kReprFields, + std::char_traits<char>::length(reflection::type_attr::kReprFields)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Look up a type attribute from a column by type_index. + * \return AnyView of the attribute, or a None AnyView if not found. + */ +AnyView LookupTypeAttr(const TVMFFITypeAttrColumn* column, int32_t type_index) { + if (column == nullptr) return AnyView(); + size_t tindex = static_cast<size_t>(type_index); + if (tindex >= column->size) return AnyView(); + const AnyView* data = reinterpret_cast<const AnyView*>(column->data); + return data[tindex]; +} + +/*! + * \brief BFS-based repr printer. + * + * Algorithm: + * 1. BFS collect all objects reachable from root (tracking visit count). + * 2. Process in reverse BFS order (leaves first), building repr strings. + * 3. Objects encountered more than once use short-form on second+ occurrence. + * + * Modeled after ObjectDeepCopier in deep_copy.cc. 
+ */ +class ReprPrinter { + public: + String Run(const Any& value) { + // Phase 1: BFS collection + CollectAny(value); + // Phase 2: Reverse-BFS processing + for (int64_t i = static_cast<int64_t>(bfs_queue_.size()) - 1; i >= 0; --i) { + ProcessNode(bfs_queue_[i]); + } + // Phase 3: Return repr of root + return String(ReprOfAny(value)); + } + + private: + // ---------- Phase 1: BFS Collection ---------- + + void CollectAny(const Any& value) { + int32_t ti = value.type_index(); + if (ti < TypeIndex::kTVMFFIStaticObjectBegin) return; + const Object* obj = static_cast<const Object*>(value.as<Object>()); + if (obj == nullptr) return; + // Track visit count + auto [it, inserted] = visit_count_.emplace(obj, 1); + if (!inserted) { + it->second++; + return; // Already visited + } + bfs_queue_.push_back(obj); + CollectChildren(obj, ti); + } + + void CollectChildren(const Object* obj, int32_t ti) { + switch (ti) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: + case TypeIndex::kTVMFFIShape: + case TypeIndex::kTVMFFITensor: + case TypeIndex::kTVMFFIFunction: + case TypeIndex::kTVMFFIError: + case TypeIndex::kTVMFFIOpaquePyObject: + // Leaf types: no children + break; + case TypeIndex::kTVMFFIArray: { + const ArrayObj* arr = static_cast<const ArrayObj*>(obj); + for (const Any& elem : *arr) CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIList: { + const ListObj* lst = static_cast<const ListObj*>(obj); + for (const Any& elem : *lst) CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIMap: { + const MapObj* map = static_cast<const MapObj*>(obj); + for (const auto& [k, v] : *map) { + CollectAny(k); + CollectAny(v); + } + break; + } + default: + // User-defined object: collect via reflection fields + CollectFieldChildren(obj); + break; + } + } + + void CollectFieldChildren(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + if (type_info == nullptr) return; + + // Check if __ffi_repr_fields__ restricts which 
 fields to show + AnyView repr_fields_attr = LookupTypeAttr(GetReprFieldsColumn(), obj->type_index()); + if (repr_fields_attr != nullptr) { + // Explicit field list + Array<Any> field_names = repr_fields_attr.cast<Array<Any>>(); + for (const Any& name_any : field_names) { + std::string name = name_any.cast<String>().c_str(); + reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* finfo) { + std::string_view field_name(finfo->name.data, finfo->name.size); + if (field_name == name) { + reflection::FieldGetter getter(finfo); + CollectAny(getter(obj)); + return true; + } + return false; + }); + } Review Comment:  The current implementation iterates through all fields for each field name specified in `repr_fields`. This results in a time complexity of O(N*M), where N is the number of fields to represent and M is the total number of fields in the object. This could be inefficient for objects with many fields. A more efficient approach would be to first collect the desired field names into a `std::unordered_set<std::string>` and then iterate through all fields once with `ForEachFieldInfo`, checking for membership in the set. This would reduce the complexity to roughly O(N + M). A similar inefficiency exists in `GenericRepr`, but there the order matters: a set-filtered single pass would emit fields in declaration order rather than in the order listed in `__ffi_repr_fields__`, changing the printed output. To keep the O(N + M) bound while preserving the requested order, build a name-to-`const TVMFFIFieldInfo*` map in one `ForEachFieldInfo` pass, then iterate over the `__ffi_repr_fields__` list and look each name up in the map (this works for both call sites). ########## src/ffi/extra/repr_print.cc: ########## @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/repr_print.cc + * + * \brief Reflection-based repr printing with BFS-based cycle/DAG handling. + */ +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/list.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/container/shape.h> +#include <tvm/ffi/container/tensor.h> +#include <tvm/ffi/dtype.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/reflection/accessor.h> +#include <tvm/ffi/reflection/registry.h> + +#include <iomanip> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +namespace tvm { +namespace ffi { + +namespace { + +/*! + * \brief Convert a DLDeviceType to a short name string. + */ +const char* DeviceTypeName(int device_type) { + switch (device_type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLOpenCL: + return "opencl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + default: + return "unknown"; + } +} + +/*! + * \brief Format a DLDevice as "device_name:device_id". 
+ */ +std::string DeviceToString(DLDevice device) { + std::ostringstream os; + os << DeviceTypeName(device.device_type) << ":" << device.device_id; + return os.str(); +} + +/*! + * \brief Format an object address as a hex string. + */ +std::string AddressStr(const Object* obj) { + std::ostringstream os; + os << "0x" << std::hex << reinterpret_cast<uintptr_t>(obj); + return os.str(); +} + +/*! + * \brief Get the type key of an object as a std::string. + */ +std::string GetTypeKeyStr(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + return std::string(type_info->type_key.data, type_info->type_key.size); +} + +/*! + * \brief Lazily initialize and return the __ffi_repr__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist (e.g., before registration). + */ +const TVMFFITypeAttrColumn* GetReprColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kRepr, + std::char_traits<char>::length(reflection::type_attr::kRepr)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Lazily initialize and return the __ffi_repr_fields__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist. + */ +const TVMFFITypeAttrColumn* GetReprFieldsColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kReprFields, + std::char_traits<char>::length(reflection::type_attr::kReprFields)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Look up a type attribute from a column by type_index. + * \return AnyView of the attribute, or a None AnyView if not found. 
+ */ +AnyView LookupTypeAttr(const TVMFFITypeAttrColumn* column, int32_t type_index) { + if (column == nullptr) return AnyView(); + size_t tindex = static_cast<size_t>(type_index); + if (tindex >= column->size) return AnyView(); + const AnyView* data = reinterpret_cast<const AnyView*>(column->data); + return data[tindex]; +} + +/*! + * \brief BFS-based repr printer. + * + * Algorithm: + * 1. BFS collect all objects reachable from root (tracking visit count). + * 2. Process in reverse BFS order (leaves first), building repr strings. + * 3. Objects encountered more than once use short-form on second+ occurrence. + * + * Modeled after ObjectDeepCopier in deep_copy.cc. + */ +class ReprPrinter { + public: + String Run(const Any& value) { + // Phase 1: BFS collection + CollectAny(value); + // Phase 2: Reverse-BFS processing + for (int64_t i = static_cast<int64_t>(bfs_queue_.size()) - 1; i >= 0; --i) { + ProcessNode(bfs_queue_[i]); + } + // Phase 3: Return repr of root + return String(ReprOfAny(value)); + } + + private: + // ---------- Phase 1: BFS Collection ---------- + + void CollectAny(const Any& value) { + int32_t ti = value.type_index(); + if (ti < TypeIndex::kTVMFFIStaticObjectBegin) return; + const Object* obj = static_cast<const Object*>(value.as<Object>()); + if (obj == nullptr) return; + // Track visit count + auto [it, inserted] = visit_count_.emplace(obj, 1); + if (!inserted) { + it->second++; + return; // Already visited + } + bfs_queue_.push_back(obj); + CollectChildren(obj, ti); + } + + void CollectChildren(const Object* obj, int32_t ti) { + switch (ti) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: + case TypeIndex::kTVMFFIShape: + case TypeIndex::kTVMFFITensor: + case TypeIndex::kTVMFFIFunction: + case TypeIndex::kTVMFFIError: + case TypeIndex::kTVMFFIOpaquePyObject: + // Leaf types: no children + break; + case TypeIndex::kTVMFFIArray: { + const ArrayObj* arr = static_cast<const ArrayObj*>(obj); + for (const Any& elem : *arr) 
CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIList: { + const ListObj* lst = static_cast<const ListObj*>(obj); + for (const Any& elem : *lst) CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIMap: { + const MapObj* map = static_cast<const MapObj*>(obj); + for (const auto& [k, v] : *map) { + CollectAny(k); + CollectAny(v); + } + break; + } + default: + // User-defined object: collect via reflection fields + CollectFieldChildren(obj); + break; + } + } + + void CollectFieldChildren(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + if (type_info == nullptr) return; + + // Check if __ffi_repr_fields__ restricts which fields to show + AnyView repr_fields_attr = LookupTypeAttr(GetReprFieldsColumn(), obj->type_index()); + if (repr_fields_attr != nullptr) { + // Explicit field list + Array<Any> field_names = repr_fields_attr.cast<Array<Any>>(); + for (const Any& name_any : field_names) { + std::string name = name_any.cast<String>().c_str(); + reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* finfo) { + std::string_view field_name(finfo->name.data, finfo->name.size); + if (field_name == name) { + reflection::FieldGetter getter(finfo); + CollectAny(getter(obj)); + return true; + } + return false; + }); + } + } else { + // No restriction: collect all fields + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) { + reflection::FieldGetter getter(finfo); + Any fv = getter(obj); + if (fv.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { + CollectAny(fv); + } + }); + } + } + + // ---------- Phase 2: Processing ---------- + + void ProcessNode(const Object* obj) { + if (repr_map_.count(obj)) return; // Already processed + + int32_t ti = obj->type_index(); + AnyView custom_repr = LookupTypeAttr(GetReprColumn(), ti); + + if (custom_repr != nullptr) { + // Custom __ffi_repr__: call it with fn_repr callback + Function repr_fn = custom_repr.cast<Function>(); + Function 
fn_repr = CreateFnRepr(); + String result = repr_fn(obj, fn_repr).cast<String>(); + repr_map_[obj] = std::string(result.data(), result.size()); + } else { + // Generic reflection-based repr + repr_map_[obj] = GenericRepr(obj); + } + } + + Function CreateFnRepr() { + return Function::FromTyped( + [this](AnyView value) -> String { return String(ReprOfAny(Any(value))); }); + } + + // ---------- Repr Helpers ---------- + + std::string ReprOfAny(const Any& value) { + int32_t ti = value.type_index(); + switch (ti) { + case TypeIndex::kTVMFFINone: + return "None"; + case TypeIndex::kTVMFFIBool: + return value.cast<bool>() ? "True" : "False"; + case TypeIndex::kTVMFFIInt: + return std::to_string(value.cast<int64_t>()); + case TypeIndex::kTVMFFIFloat: { + std::ostringstream os; + os << value.cast<double>(); + return os.str(); + } + case TypeIndex::kTVMFFIDataType: { + String s = DLDataTypeToString(value.cast<DLDataType>()); + return std::string(s.data(), s.size()); + } + case TypeIndex::kTVMFFIDevice: { + return DeviceToString(value.cast<DLDevice>()); + } + default: + break; + } + if (ti == TypeIndex::kTVMFFISmallStr) { + String s = value.cast<String>(); + return "\"" + std::string(s.data(), s.size()) + "\""; + } + if (ti == TypeIndex::kTVMFFISmallBytes) { + Bytes b = value.cast<Bytes>(); + std::ostringstream os; + os << "b\""; + for (size_t i = 0; i < b.size(); ++i) { + unsigned char c = static_cast<unsigned char>(b.data()[i]); + if (c >= 32 && c < 127 && c != '\"' && c != '\\') { + os << static_cast<char>(c); + } else { + os << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); + } + } + os << "\""; + return os.str(); + } Review Comment:  The logic for formatting bytes into a string representation is duplicated here for `kTVMFFISmallBytes` and in `ReprBytes` for `kTVMFFIBytes`. To improve maintainability, this logic could be extracted into a helper function that takes a `const char*` and a `size_t` and returns the formatted `std::string`. 
For example, you could add a helper function: ```cpp namespace { // ... std::string FormatBytes(const char* data, size_t size) { std::ostringstream os; os << "b\""; for (size_t i = 0; i < size; ++i) { unsigned char c = static_cast<unsigned char>(data[i]); if (c >= 32 && c < 127 && c != '\"' && c != '\\') { os << static_cast<char>(c); } else { os << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); } } os << "\""; return os.str(); } // ... } ``` Then you can use it in both places. In `ReprOfAny`: ```cpp if (ti == TypeIndex::kTVMFFISmallBytes) { Bytes b = value.cast<Bytes>(); return FormatBytes(b.data(), b.size()); } ``` and in `ReprBytes`: ```cpp return String(FormatBytes(obj->data, obj->size)); ``` ########## src/ffi/extra/repr_print.cc: ########## @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/repr_print.cc + * + * \brief Reflection-based repr printing with BFS-based cycle/DAG handling.
+ */ +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/list.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/container/shape.h> +#include <tvm/ffi/container/tensor.h> +#include <tvm/ffi/dtype.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/reflection/accessor.h> +#include <tvm/ffi/reflection/registry.h> + +#include <iomanip> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +namespace tvm { +namespace ffi { + +namespace { + +/*! + * \brief Convert a DLDeviceType to a short name string. + */ +const char* DeviceTypeName(int device_type) { + switch (device_type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLOpenCL: + return "opencl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + default: + return "unknown"; + } +} + +/*! + * \brief Format a DLDevice as "device_name:device_id". + */ +std::string DeviceToString(DLDevice device) { + std::ostringstream os; + os << DeviceTypeName(device.device_type) << ":" << device.device_id; + return os.str(); +} + +/*! + * \brief Format an object address as a hex string. + */ +std::string AddressStr(const Object* obj) { + std::ostringstream os; + os << "0x" << std::hex << reinterpret_cast<uintptr_t>(obj); + return os.str(); +} + +/*! + * \brief Get the type key of an object as a std::string. + */ +std::string GetTypeKeyStr(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + return std::string(type_info->type_key.data, type_info->type_key.size); +} + +/*! 
+ * \brief Lazily initialize and return the __ffi_repr__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist (e.g., before registration). + */ +const TVMFFITypeAttrColumn* GetReprColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kRepr, + std::char_traits<char>::length(reflection::type_attr::kRepr)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Lazily initialize and return the __ffi_repr_fields__ TypeAttrColumn. + * + * Returns nullptr if the column does not exist. + */ +const TVMFFITypeAttrColumn* GetReprFieldsColumn() { + static const TVMFFITypeAttrColumn* column = []() -> const TVMFFITypeAttrColumn* { + TVMFFIByteArray name = {reflection::type_attr::kReprFields, + std::char_traits<char>::length(reflection::type_attr::kReprFields)}; + return TVMFFIGetTypeAttrColumn(&name); + }(); + return column; +} + +/*! + * \brief Look up a type attribute from a column by type_index. + * \return AnyView of the attribute, or a None AnyView if not found. + */ +AnyView LookupTypeAttr(const TVMFFITypeAttrColumn* column, int32_t type_index) { + if (column == nullptr) return AnyView(); + size_t tindex = static_cast<size_t>(type_index); + if (tindex >= column->size) return AnyView(); + const AnyView* data = reinterpret_cast<const AnyView*>(column->data); + return data[tindex]; +} + +/*! + * \brief BFS-based repr printer. + * + * Algorithm: + * 1. BFS collect all objects reachable from root (tracking visit count). + * 2. Process in reverse BFS order (leaves first), building repr strings. + * 3. Objects encountered more than once use short-form on second+ occurrence. + * + * Modeled after ObjectDeepCopier in deep_copy.cc. 
+ */ +class ReprPrinter { + public: + String Run(const Any& value) { + // Phase 1: BFS collection + CollectAny(value); + // Phase 2: Reverse-BFS processing + for (int64_t i = static_cast<int64_t>(bfs_queue_.size()) - 1; i >= 0; --i) { + ProcessNode(bfs_queue_[i]); + } + // Phase 3: Return repr of root + return String(ReprOfAny(value)); + } + + private: + // ---------- Phase 1: BFS Collection ---------- + + void CollectAny(const Any& value) { + int32_t ti = value.type_index(); + if (ti < TypeIndex::kTVMFFIStaticObjectBegin) return; + const Object* obj = static_cast<const Object*>(value.as<Object>()); + if (obj == nullptr) return; + // Track visit count + auto [it, inserted] = visit_count_.emplace(obj, 1); + if (!inserted) { + it->second++; + return; // Already visited + } + bfs_queue_.push_back(obj); + CollectChildren(obj, ti); + } + + void CollectChildren(const Object* obj, int32_t ti) { + switch (ti) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: + case TypeIndex::kTVMFFIShape: + case TypeIndex::kTVMFFITensor: + case TypeIndex::kTVMFFIFunction: + case TypeIndex::kTVMFFIError: + case TypeIndex::kTVMFFIOpaquePyObject: + // Leaf types: no children + break; + case TypeIndex::kTVMFFIArray: { + const ArrayObj* arr = static_cast<const ArrayObj*>(obj); + for (const Any& elem : *arr) CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIList: { + const ListObj* lst = static_cast<const ListObj*>(obj); + for (const Any& elem : *lst) CollectAny(elem); + break; + } + case TypeIndex::kTVMFFIMap: { + const MapObj* map = static_cast<const MapObj*>(obj); + for (const auto& [k, v] : *map) { + CollectAny(k); + CollectAny(v); + } + break; + } + default: + // User-defined object: collect via reflection fields + CollectFieldChildren(obj); + break; + } + } + + void CollectFieldChildren(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + if (type_info == nullptr) return; + + // Check if __ffi_repr_fields__ restricts which 
fields to show + AnyView repr_fields_attr = LookupTypeAttr(GetReprFieldsColumn(), obj->type_index()); + if (repr_fields_attr != nullptr) { + // Explicit field list + Array<Any> field_names = repr_fields_attr.cast<Array<Any>>(); + for (const Any& name_any : field_names) { + std::string name = name_any.cast<String>().c_str(); + reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* finfo) { + std::string_view field_name(finfo->name.data, finfo->name.size); + if (field_name == name) { + reflection::FieldGetter getter(finfo); + CollectAny(getter(obj)); + return true; + } + return false; + }); + } + } else { + // No restriction: collect all fields + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) { + reflection::FieldGetter getter(finfo); + Any fv = getter(obj); + if (fv.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { + CollectAny(fv); + } + }); + } + } + + // ---------- Phase 2: Processing ---------- + + void ProcessNode(const Object* obj) { + if (repr_map_.count(obj)) return; // Already processed + + int32_t ti = obj->type_index(); + AnyView custom_repr = LookupTypeAttr(GetReprColumn(), ti); + + if (custom_repr != nullptr) { + // Custom __ffi_repr__: call it with fn_repr callback + Function repr_fn = custom_repr.cast<Function>(); + Function fn_repr = CreateFnRepr(); + String result = repr_fn(obj, fn_repr).cast<String>(); + repr_map_[obj] = std::string(result.data(), result.size()); + } else { + // Generic reflection-based repr + repr_map_[obj] = GenericRepr(obj); + } + } + + Function CreateFnRepr() { + return Function::FromTyped( + [this](AnyView value) -> String { return String(ReprOfAny(Any(value))); }); + } + + // ---------- Repr Helpers ---------- + + std::string ReprOfAny(const Any& value) { + int32_t ti = value.type_index(); + switch (ti) { + case TypeIndex::kTVMFFINone: + return "None"; + case TypeIndex::kTVMFFIBool: + return value.cast<bool>() ? 
"True" : "False"; + case TypeIndex::kTVMFFIInt: + return std::to_string(value.cast<int64_t>()); + case TypeIndex::kTVMFFIFloat: { + std::ostringstream os; + os << value.cast<double>(); + return os.str(); + } + case TypeIndex::kTVMFFIDataType: { + String s = DLDataTypeToString(value.cast<DLDataType>()); + return std::string(s.data(), s.size()); + } + case TypeIndex::kTVMFFIDevice: { + return DeviceToString(value.cast<DLDevice>()); + } + default: + break; + } + if (ti == TypeIndex::kTVMFFISmallStr) { + String s = value.cast<String>(); + return "\"" + std::string(s.data(), s.size()) + "\""; + } + if (ti == TypeIndex::kTVMFFISmallBytes) { + Bytes b = value.cast<Bytes>(); + std::ostringstream os; + os << "b\""; + for (size_t i = 0; i < b.size(); ++i) { + unsigned char c = static_cast<unsigned char>(b.data()[i]); + if (c >= 32 && c < 127 && c != '\"' && c != '\\') { + os << static_cast<char>(c); + } else { + os << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); + } + } + os << "\""; + return os.str(); + } + if (ti < TypeIndex::kTVMFFIStaticObjectBegin) { + // Other POD types + return value.GetTypeKey(); + } + // Object type + const Object* obj = static_cast<const Object*>(value.as<Object>()); + if (obj == nullptr) return "None"; + auto it = repr_map_.find(obj); + if (it != repr_map_.end()) { + return it->second; + } + // Not yet processed -- use short form + return ShortRepr(obj); + } + + std::string ShortRepr(const Object* obj) { return GetTypeKeyStr(obj) + "@" + AddressStr(obj); } + + std::string GenericRepr(const Object* obj) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + if (type_info == nullptr) { + return ShortRepr(obj); + } + std::string type_key(type_info->type_key.data, type_info->type_key.size); + std::ostringstream os; + os << type_key << "@" << AddressStr(obj) << "("; + + bool first = true; + auto emit_field = [&](const TVMFFIFieldInfo* finfo) { + if (!first) os << ", "; + first = false; + os << 
std::string_view(finfo->name.data, finfo->name.size) << "="; + reflection::FieldGetter getter(finfo); + Any fv = getter(obj); + os << ReprOfAny(fv); + }; + + AnyView repr_fields_attr = LookupTypeAttr(GetReprFieldsColumn(), obj->type_index()); + if (repr_fields_attr != nullptr) { + // Explicit field subset + Array<Any> field_names = repr_fields_attr.cast<Array<Any>>(); + for (const Any& name_any : field_names) { + std::string name = name_any.cast<String>().c_str(); + reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* finfo) { + std::string_view field_name(finfo->name.data, finfo->name.size); + if (field_name == name) { + emit_field(finfo); + return true; + } + return false; + }); + } + } else { + // All fields (or no fields registered at all) + bool has_fields = false; + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) { + has_fields = true; + emit_field(finfo); + }); + if (!has_fields) { + // No reflection fields -- just show type_key@address + return type_key + "@" + AddressStr(obj); + } + } + os << ")"; + return os.str(); + } + + // ---------- Data members ---------- + std::vector<const Object*> bfs_queue_; + std::unordered_map<const Object*, int> visit_count_; + std::unordered_map<const Object*, std::string> repr_map_; +}; + +// ---------- Built-in __ffi_repr__ functions ---------- + +String ReprString(const details::StringObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "\"" << std::string_view(obj->data, obj->size) << "\""; + return String(os.str()); +} + +String ReprBytes(const details::BytesObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "b\""; + for (size_t i = 0; i < obj->size; ++i) { + unsigned char c = static_cast<unsigned char>(obj->data[i]); + if (c >= 32 && c < 127 && c != '\"' && c != '\\') { + os << static_cast<char>(c); + } else { + os << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); + } + } + os << "\""; + return 
String(os.str()); +} + +String ReprTensor(const TensorObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "Tensor." << DLDataTypeToString(obj->dtype); + os << "["; + for (int i = 0; i < obj->ndim; ++i) { + if (i > 0) os << ", "; + os << obj->shape[i]; + } + os << "]@" << DeviceToString(obj->device); + os << "@" << AddressStr(obj); + return String(os.str()); +} + +String ReprShape(const ShapeObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "Shape(["; + for (size_t i = 0; i < obj->size; ++i) { + if (i > 0) os << ", "; + os << obj->data[i]; + } + os << "])"; + return String(os.str()); +} + +String ReprArray(const ArrayObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "["; + bool first = true; + for (const Any& elem : *obj) { + if (!first) os << ", "; + first = false; + String s = fn_repr(elem).cast<String>(); + os << std::string_view(s.data(), s.size()); + } + os << "]"; + return String(os.str()); +} + +String ReprList(const ListObj* obj, const Function& fn_repr) { + std::ostringstream os; + os << "["; + bool first = true; + for (const Any& elem : *obj) { + if (!first) os << ", "; + first = false; + String s = fn_repr(elem).cast<String>(); + os << std::string_view(s.data(), s.size()); + } + os << "]"; + return String(os.str()); +} Review Comment:  The implementations of `ReprArray` and `ReprList` are identical. This code duplication can be avoided by creating a template function that works for any sequence-like container that can be iterated over. 
For example: ```cpp template <typename TContainer> String ReprSequence(const TContainer* obj, const Function& fn_repr) { std::ostringstream os; os << "["; bool first = true; for (const Any& elem : *obj) { if (!first) os << ", "; first = false; String s = fn_repr(elem).cast<String>(); os << std::string_view(s.data(), s.size()); } os << "]"; return String(os.str()); } String ReprArray(const ArrayObj* obj, const Function& fn_repr) { return ReprSequence(obj, fn_repr); } String ReprList(const ListObj* obj, const Function& fn_repr) { return ReprSequence(obj, fn_repr); } ``` This would improve maintainability. ########## python/tvm_ffi/cython/object.pxi: ########## @@ -26,8 +26,21 @@ def _set_class_object(cls): _CLASS_OBJECT = cls +_REPR_PRINT = None +_REPR_PRINT_LOADED = False + + def __object_repr__(obj: "Object") -> str: - """Object repr function that can be overridden by assigning to it""" + """Object repr function using ffi.ReprPrint when available.""" + global _REPR_PRINT, _REPR_PRINT_LOADED + if not _REPR_PRINT_LOADED: + _REPR_PRINT_LOADED = True + _REPR_PRINT = _get_global_func("ffi.ReprPrint", False) + if _REPR_PRINT is not None: + try: + return str(_REPR_PRINT(obj)) + except Exception: + pass Review Comment:  The use of `except Exception: pass` is concerning as it will silently swallow any errors that occur within the `ffi.ReprPrint` C++ implementation. This can make debugging issues with object representation very difficult. Consider logging the exception here to make developers aware of underlying problems. For example: ```python import logging ... try: return str(_REPR_PRINT(obj)) except Exception as e: logging.warning("Failed to generate repr using ffi.ReprPrint: %s", e, exc_info=True) ``` If importing the standard-library `logging` module here is not desirable, you could instead print the error to `sys.stderr` when a debug flag is set.
########## tests/python/test_repr.py: ########## @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for __ffi_repr__ / ffi.ReprPrint.""" + +import re + +import numpy as np +import pytest +import tvm_ffi +import tvm_ffi.testing +from tvm_ffi import _ffi_api + + +def test_repr_primitives() -> None: + """Test repr of primitive types.""" + assert _ffi_api.ReprPrint(42) == "42" + assert _ffi_api.ReprPrint(0) == "0" + assert _ffi_api.ReprPrint(-1) == "-1" + assert _ffi_api.ReprPrint(True) == "True" + assert _ffi_api.ReprPrint(False) == "False" + assert _ffi_api.ReprPrint(None) == "None" + + +def test_repr_float() -> None: + """Test repr of floating point.""" + result = _ffi_api.ReprPrint(3.14) + assert "3.14" in result + + +def test_repr_string() -> None: + """Test repr of FFI String.""" + result = _ffi_api.ReprPrint("hello world") + assert result == '"hello world"' + + +def test_repr_array() -> None: + """Test repr of FFI Array.""" + arr = tvm_ffi.Array([1, 2, 3]) + result = _ffi_api.ReprPrint(arr) + assert result == "[1, 2, 3]" + + +def test_repr_array_nested() -> None: + """Test repr of nested Array.""" + arr = tvm_ffi.Array(["a", "b"]) + result = _ffi_api.ReprPrint(arr) + assert result == '["a", "b"]' + + +def 
test_repr_list() -> None: + """Test repr of FFI List.""" + lst = tvm_ffi.List([10, 20]) + result = _ffi_api.ReprPrint(lst) + assert result == "[10, 20]" + + +def test_repr_map() -> None: + """Test repr of FFI Map.""" + m = tvm_ffi.Map({"key": "value"}) + result = _ffi_api.ReprPrint(m) + assert '"key": "value"' in result + + +def test_repr_tensor() -> None: + """Test repr of Tensor shows dtype, shape, device, address.""" + x = tvm_ffi.from_dlpack(np.zeros((3, 4), dtype="float32")) + result = _ffi_api.ReprPrint(x) + # Should match: Tensor.float32[3, 4]@cpu:0@0x... + assert result.startswith("Tensor.float32[3, 4]@cpu:0@0x") + + +def test_repr_tensor_int8() -> None: + """Test repr of Tensor with int8 dtype.""" + x = tvm_ffi.from_dlpack(np.zeros((2,), dtype="int8")) + result = _ffi_api.ReprPrint(x) + assert result.startswith("Tensor.int8[2]@cpu:0@0x") + + +def test_repr_shape() -> None: + """Test repr of Shape.""" + shape = tvm_ffi.Shape((5, 6)) + result = _ffi_api.ReprPrint(shape) + assert result == "Shape([5, 6])" + + +def test_repr_user_object_all_fields() -> None: + """Test repr of user-defined object with all fields shown.""" + obj = tvm_ffi.testing.create_object("testing.TestIntPair", a=10, b=20) + result = _ffi_api.ReprPrint(obj) + # Format: testing.TestIntPair@0x...(a=10, b=20) + assert re.match(r"testing\.TestIntPair@0x[0-9a-f]+\(a=10, b=20\)", result) + + +def test_repr_user_object_repr_fields() -> None: + """Test repr of object with repr_fields restriction (TestCxxClassDerived shows only v_f64, v_f32).""" + obj = tvm_ffi.testing._TestCxxClassDerived( + v_i64=1, + v_i32=2, + v_f64=3.5, + v_f32=4.5, + ) + result = _ffi_api.ReprPrint(obj) + # Should show only v_f64 and v_f32, not v_i64 and v_i32 + assert "v_f64=3.5" in result + assert "v_f32=4.5" in result + assert "v_i64" not in result + assert "v_i32" not in result + + +def test_repr_duplicate_reference() -> None: + """Test that duplicate object references use short form.""" + inner = 
tvm_ffi.testing.create_object("testing.TestIntPair", a=1, b=2) + arr = tvm_ffi.Array([inner, inner]) + result = _ffi_api.ReprPrint(arr) + # The array should show two elements; one full, one short form + # The full form includes (a=1, b=2), the short form is just type@addr + assert "a=1" in result + assert "b=2" in result Review Comment:  The assertions in this test verify that the full representation of the object is present, but they don't explicitly check that the short form is used for the duplicate reference. A more robust test would be to assert that the full representation appears only once. ```suggestion assert result.count("(a=1, b=2)") == 1 ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
