This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b26a09a doc: Properly document `field_static_type_index` and add more tests (#456)
3b26a09a is described below

commit 3b26a09a1e47a55e641f2267317c97114b27ab36
Author: Junru Shao <[email protected]>
AuthorDate: Tue Feb 17 07:04:43 2026 -0800

    doc: Properly document `field_static_type_index` and add more tests (#456)
    
    ## Summary
    
    Restore full usage of `field_static_type_index` in TVMFFIFieldInfo with
    thorough documentation explaining its compile-time static type
    semantics:
    the field reflects the declared type at compile time, which may differ
    from the runtime type (e.g., `Any` vs `int`, `Array<Any>` vs
    `Array<int>`).
    
    The serializer uses this field to inline POD values (None, Bool, Int,
    Float, DataType) directly, avoiding unnecessary node-graph overhead,
    while routing other types through the standard node graph.
    
    Changes:
    - Add comprehensive documentation to `field_static_type_index` in
      TVMFFIFieldInfo (C API) explaining all possible values and semantics
    - Restore `TypeToFieldStaticTypeIndex` and `TypeToRuntimeTypeIndex`
      templates in type_traits.h
    - Restore `field_static_type_index` in all TypeTraits specializations
    - Restore POD field inlining in serialization/deserialization
    - Add comprehensive C++ and Python serialization tests (23 new C++ tests,
      bringing the C++ serialization suite to 40, plus 58 tests in the new Python suite)
    
    ## Test plan
    
    - [x] All 302 C++ tests pass (including 40 serialization + 21 reflection
    tests)
    - [x] All 58 Python serialization tests pass
    - [x] Verified `field_static_type_index` is populated correctly in
    reflection via test object field info
---
 include/tvm/ffi/c_api.h               |  31 +-
 include/tvm/ffi/type_traits.h         |   1 -
 rust/tvm-ffi-sys/src/c_api.rs         |   2 +-
 tests/cpp/extra/test_serialization.cc | 364 ++++++++++++++++++++++++
 tests/cpp/test_reflection.cc          |   2 +
 tests/cpp/testing_object.h            |  96 +++++++
 tests/python/test_serialization.py    | 515 ++++++++++++++++++++++++++++++++++
 7 files changed, 996 insertions(+), 15 deletions(-)

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 67715ca1..236e342f 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -979,23 +979,28 @@ typedef struct {
    */
   TVMFFIAny default_value_or_factory;
   /*!
-   * \brief Records the static type kind of the field.
+   * \brief The compile-time static type index of the field.
    *
-   * Possible values:
+   * This reflects the type declared at compile time, which is only trustworthy
+   * for statically-inferrable types. It does NOT necessarily match the runtime
+   * type. For example, a field declared as `Any` will have
+   * `field_static_type_index == kTVMFFIAny` even if it holds an `int` at runtime,
+   * and `Array<Any>` will report `kTVMFFIArray` even though the elements may be
+   * `Array<int>` at runtime.
    *
-   * - TVMFFITypeIndex::kTVMFFIObject for general objects.
-   *   The value is nullable when kTVMFFIObject is chosen.
-   * - Static object type kinds such as Map, Dict, String
-   * - POD type index, note it does not give information about storage size of the field.
-   * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
-   *   about the field.
+   * \warning Do NOT use this field to determine the actual type of a value at
+   * runtime. It is purely a compile-time hint derived from the C++ field
+   * declaration. The actual runtime type must be obtained from the value's
+   * own `type_index`.
    *
-   * When the value is a type index of Object type, the field is storaged as an ObjectRef.
+   * \warning When the static type is a generic container (e.g. `Array<Any>`),
+   * this field only tells you it is an Array — it says nothing about the
+   * element types actually stored inside. Similarly, `kTVMFFIObject` only
+   * means "some ObjectRef" without any subtype information.
    *
-   * \note This information maybe helpful in designing serializer.
-   * As it helps to narrow down the field type so we don't have to
-   * print type_key for cases like POD types.
-   * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields.
+   * \note This is used by the serializer to inline POD field values directly
+   * (avoiding node-graph overhead for None, Bool, Int, Float, DataType),
+   * while routing other types through the node graph.
    */
   int32_t field_static_type_index;
 } TVMFFIFieldInfo;
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index dc7f982b..e832559b 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -135,7 +135,6 @@ struct TypeToRuntimeTypeIndex<T, std::enable_if_t<std::is_base_of_v<ObjectRef, T
 template <>
 struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
   static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone;
-
   TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) {
     result->type_index = TypeIndex::kTVMFFINone;
     result->zero_padding = 0;
diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs
index d94967ba..ee555e2a 100644
--- a/rust/tvm-ffi-sys/src/c_api.rs
+++ b/rust/tvm-ffi-sys/src/c_api.rs
@@ -291,7 +291,7 @@ pub struct TVMFFIFieldInfo {
     /// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set,
     /// this is a callable factory function () -> Any.
     pub default_value_or_factory: TVMFFIAny,
-    /// Records the static type kind of the field
+    /// Records the compile-time static type kind of the field.
     pub field_static_type_index: i32,
 }
 
diff --git a/tests/cpp/extra/test_serialization.cc b/tests/cpp/extra/test_serialization.cc
index 1d9d4830..ce6e42cc 100644
--- a/tests/cpp/extra/test_serialization.cc
+++ b/tests/cpp/extra/test_serialization.cc
@@ -26,6 +26,8 @@
 #include <tvm/ffi/extra/structural_equal.h>
 #include <tvm/ffi/string.h>
 
+#include <limits>
+
 #include "../testing_object.h"
 
 namespace {
@@ -410,4 +412,366 @@ TEST(Serialization, ShuffleNodeOrder) {
   EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled), duplicated_map));
 }
 
+// ---------------------------------------------------------------------------
+// Integer edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, IntegerEdgeCases) {
+  // zero
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(static_cast<int64_t>(0))),
+                                static_cast<int64_t>(0)));
+  // negative
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(static_cast<int64_t>(-1))),
+                                static_cast<int64_t>(-1)));
+  // large positive
+  int64_t large = 1000000000000LL;
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(large)), large));
+  // large negative
+  int64_t large_neg = -999999999999LL;
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(large_neg)), large_neg));
+  // INT64_MIN and INT64_MAX
+  int64_t imin = std::numeric_limits<int64_t>::min();
+  int64_t imax = std::numeric_limits<int64_t>::max();
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(imin)), imin));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(imax)), imax));
+}
+
+// ---------------------------------------------------------------------------
+// Float edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, FloatEdgeCases) {
+  // zero
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(0.0)), 0.0));
+  // negative
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(-1.5)), -1.5));
+  // very large
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(1e300)), 1e300));
+  // very small
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(1e-300)), 1e-300));
+}
+
+// ---------------------------------------------------------------------------
+// String edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, EmptyString) {
+  String empty("");
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(empty)), empty));
+}
+
+TEST(Serialization, UnicodeString) {
+  String unicode("hello 世界 🌍");
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(unicode)), unicode));
+}
+
+TEST(Serialization, NullCharInString) {
+  // String with embedded null characters
+  std::string with_null("ab\0cd", 5);
+  String s(with_null);
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(s)), s));
+}
+
+// ---------------------------------------------------------------------------
+// Object with all POD field types (exercises node-graph path for POD fields)
+// ---------------------------------------------------------------------------
+TEST(Serialization, AllFieldsObject) {
+  DLDataType dtype;
+  dtype.code = kDLFloat;
+  dtype.bits = 32;
+  dtype.lanes = 1;
+
+  DLDevice device;
+  device.device_type = kDLCUDA;
+  device.device_id = 3;
+
+  Array<Any> arr;
+  arr.push_back(1);
+  arr.push_back(String("two"));
+
+  Map<String, Any> map{{"k", 99}};
+
+  TAllFields obj(true, -7, 2.5, dtype, device, String("hello"), String("opt"), arr, map);
+  json::Value serialized = ToJSONGraph(obj);
+  Any deserialized = FromJSONGraph(serialized);
+
+  // verify each field
+  TAllFields result = deserialized.cast<TAllFields>();
+  EXPECT_EQ(result->v_bool, true);
+  EXPECT_EQ(result->v_int, -7);
+  EXPECT_DOUBLE_EQ(result->v_float, 2.5);
+  EXPECT_EQ(result->v_dtype.code, kDLFloat);
+  EXPECT_EQ(result->v_dtype.bits, 32);
+  EXPECT_EQ(result->v_dtype.lanes, 1);
+  EXPECT_EQ(result->v_device.device_type, kDLCUDA);
+  EXPECT_EQ(result->v_device.device_id, 3);
+  EXPECT_EQ(std::string(result->v_str), "hello");
+  EXPECT_TRUE(result->v_opt_str.has_value());
+  EXPECT_EQ(std::string(result->v_opt_str.value()), "opt");
+  EXPECT_EQ(result->v_array.size(), 2);
+  EXPECT_EQ(result->v_map.size(), 1);
+}
+
+TEST(Serialization, AllFieldsObjectOptionalNone) {
+  DLDataType dtype;
+  dtype.code = kDLInt;
+  dtype.bits = 64;
+  dtype.lanes = 1;
+
+  DLDevice device;
+  device.device_type = kDLCPU;
+  device.device_id = 0;
+
+  TAllFields obj(false, 0, 0.0, dtype, device, String(""), std::nullopt, Array<Any>(),
+                 Map<String, Any>());
+  json::Value serialized = ToJSONGraph(obj);
+  Any deserialized = FromJSONGraph(serialized);
+
+  TAllFields result = deserialized.cast<TAllFields>();
+  EXPECT_EQ(result->v_bool, false);
+  EXPECT_EQ(result->v_int, 0);
+  EXPECT_DOUBLE_EQ(result->v_float, 0.0);
+  EXPECT_EQ(std::string(result->v_str), "");
+  EXPECT_FALSE(result->v_opt_str.has_value());
+  EXPECT_EQ(result->v_array.size(), 0);
+  EXPECT_EQ(result->v_map.size(), 0);
+}
+
+// ---------------------------------------------------------------------------
+// Default field values during deserialization
+// ---------------------------------------------------------------------------
+TEST(Serialization, DefaultFieldValues) {
+  // serialize a TWithDefaults, then deserialize from JSON with missing default fields
+  TWithDefaults original(100, 42, "default", true);
+  json::Value serialized = ToJSONGraph(original);
+  // roundtrip should work
+  Any deserialized = FromJSONGraph(serialized);
+  TWithDefaults result = deserialized.cast<TWithDefaults>();
+  EXPECT_EQ(result->required_val, 100);
+  EXPECT_EQ(result->default_int, 42);
+  EXPECT_EQ(std::string(result->default_str), "default");
+  EXPECT_EQ(result->default_bool, true);
+}
+
+TEST(Serialization, DefaultFieldValuesMissing) {
+  // manually construct JSON with only required field, defaults should kick in
+  // required_val is int64_t so it is inlined directly (POD field)
+  json::Object data;
+  data.Set("required_val", static_cast<int64_t>(999));
+
+  json::Object graph{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "test.WithDefaults"}, {"data", data}}}}};
+  Any result = FromJSONGraph(graph);
+  TWithDefaults obj = result.cast<TWithDefaults>();
+  EXPECT_EQ(obj->required_val, 999);
+  EXPECT_EQ(obj->default_int, 42);
+  EXPECT_EQ(std::string(obj->default_str), "default");
+  EXPECT_EQ(obj->default_bool, true);
+}
+
+// ---------------------------------------------------------------------------
+// Shared object references
+// ---------------------------------------------------------------------------
+TEST(Serialization, SharedObjectReferences) {
+  TVar shared_var("shared");
+  // two funcs share the same var
+  TFunc f1({shared_var}, {shared_var, shared_var}, std::nullopt);
+
+  json::Value serialized = ToJSONGraph(f1);
+  Any deserialized = FromJSONGraph(serialized);
+  TFunc result = deserialized.cast<TFunc>();
+
+  // all references to "shared" should be the same object after deserialization
+  // via the node dedup mechanism
+  EXPECT_EQ(result->params.size(), 1);
+  EXPECT_EQ(result->body.size(), 2);
+  // the params[0] and body[0] and body[1] should all refer to the same object
+  EXPECT_EQ(result->params[0].get(), result->body[0].get());
+  EXPECT_EQ(result->body[0].get(), result->body[1].get());
+}
+
+// ---------------------------------------------------------------------------
+// Nested objects
+// ---------------------------------------------------------------------------
+TEST(Serialization, NestedObjects) {
+  TVar x("x");
+  TVar y("y");
+  TFunc inner({x}, {x}, String("inner"));
+  // put the inner func as a body element of the outer func
+  TFunc outer({y}, {inner}, String("outer"));
+
+  json::Value serialized = ToJSONGraph(outer);
+  Any deserialized = FromJSONGraph(serialized);
+  TFunc result = deserialized.cast<TFunc>();
+
+  EXPECT_EQ(result->comment.value(), "outer");
+  TFunc inner_result = Any(result->body[0]).cast<TFunc>();
+  EXPECT_EQ(inner_result->comment.value(), "inner");
+  EXPECT_EQ(std::string(Any(inner_result->params[0]).cast<TVar>()->name), "x");
+}
+
+// ---------------------------------------------------------------------------
+// Map with integer keys
+// ---------------------------------------------------------------------------
+TEST(Serialization, MapWithIntKeys) {
+  Map<Any, Any> map;
+  map.Set(static_cast<int64_t>(1), String("one"));
+  map.Set(static_cast<int64_t>(2), String("two"));
+
+  json::Value serialized = ToJSONGraph(map);
+  Any deserialized = FromJSONGraph(serialized);
+  Map<Any, Any> result = deserialized.cast<Map<Any, Any>>();
+  EXPECT_EQ(result.size(), 2);
+  EXPECT_EQ(std::string(result[1].cast<String>()), "one");
+  EXPECT_EQ(std::string(result[2].cast<String>()), "two");
+}
+
+// ---------------------------------------------------------------------------
+// Nested containers
+// ---------------------------------------------------------------------------
+TEST(Serialization, NestedArrays) {
+  Array<Any> inner1;
+  inner1.push_back(1);
+  inner1.push_back(2);
+  Array<Any> inner2;
+  inner2.push_back(3);
+  Array<Any> outer;
+  outer.push_back(inner1);
+  outer.push_back(inner2);
+
+  json::Value serialized = ToJSONGraph(outer);
+  Any deserialized = FromJSONGraph(serialized);
+  Array<Any> result = deserialized.cast<Array<Any>>();
+  EXPECT_EQ(result.size(), 2);
+  Array<Any> r1 = result[0].cast<Array<Any>>();
+  Array<Any> r2 = result[1].cast<Array<Any>>();
+  EXPECT_EQ(r1.size(), 2);
+  EXPECT_EQ(r1[0].cast<int64_t>(), 1);
+  EXPECT_EQ(r1[1].cast<int64_t>(), 2);
+  EXPECT_EQ(r2.size(), 1);
+  EXPECT_EQ(r2[0].cast<int64_t>(), 3);
+}
+
+TEST(Serialization, MapWithArrayValues) {
+  Array<Any> arr;
+  arr.push_back(10);
+  arr.push_back(20);
+  Map<String, Any> map{{"nums", arr}};
+
+  json::Value serialized = ToJSONGraph(map);
+  Any deserialized = FromJSONGraph(serialized);
+  Map<String, Any> result = deserialized.cast<Map<String, Any>>();
+  Array<Any> result_arr = result["nums"].cast<Array<Any>>();
+  EXPECT_EQ(result_arr.size(), 2);
+  EXPECT_EQ(result_arr[0].cast<int64_t>(), 10);
+  EXPECT_EQ(result_arr[1].cast<int64_t>(), 20);
+}
+
+// ---------------------------------------------------------------------------
+// Array and Map with objects
+// ---------------------------------------------------------------------------
+TEST(Serialization, ArrayOfObjects) {
+  TVar x("x");
+  TVar y("y");
+  Array<Any> arr;
+  arr.push_back(x);
+  arr.push_back(y);
+
+  json::Value serialized = ToJSONGraph(arr);
+  Any deserialized = FromJSONGraph(serialized);
+  Array<Any> result = deserialized.cast<Array<Any>>();
+  EXPECT_EQ(result.size(), 2);
+  EXPECT_EQ(std::string(result[0].cast<TVar>()->name), "x");
+  EXPECT_EQ(std::string(result[1].cast<TVar>()->name), "y");
+}
+
+TEST(Serialization, MapOfObjects) {
+  TVar x("x");
+  Map<String, Any> map{{"var", x}};
+
+  json::Value serialized = ToJSONGraph(map);
+  Any deserialized = FromJSONGraph(serialized);
+  Map<String, Any> result = deserialized.cast<Map<String, Any>>();
+  EXPECT_EQ(std::string(result["var"].cast<TVar>()->name), "x");
+}
+
+// ---------------------------------------------------------------------------
+// Mixed-type array (exercises runtime type dispatch for each element)
+// ---------------------------------------------------------------------------
+TEST(Serialization, MixedTypeArrayRoundTrip) {
+  DLDataType dtype;
+  dtype.code = kDLInt;
+  dtype.bits = 32;
+  dtype.lanes = 1;
+
+  DLDevice device;
+  device.device_type = kDLCPU;
+  device.device_id = 0;
+
+  Array<Any> arr;
+  arr.push_back(nullptr);
+  arr.push_back(true);
+  arr.push_back(false);
+  arr.push_back(static_cast<int64_t>(42));
+  arr.push_back(3.14);
+  arr.push_back(String("hello"));
+  arr.push_back(dtype);
+  arr.push_back(device);
+
+  // roundtrip and verify structural equality
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(arr)), arr));
+}
+
+// ---------------------------------------------------------------------------
+// Error cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, ErrorMissingRequiredField) {
+  // required_val is required but not provided
+  json::Object data;
+  json::Object graph{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "test.WithDefaults"}, {"data", data}}}}};
+  EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+TEST(Serialization, ErrorInvalidRootStructure) {
+  // not an object
+  EXPECT_ANY_THROW(FromJSONGraph(json::Value(42)));
+}
+
+TEST(Serialization, ErrorMissingRootIndex) {
+  json::Object graph{{"nodes", json::Array{json::Object{{"type", "None"}}}}};
+  EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+TEST(Serialization, ErrorMissingNodes) {
+  json::Object graph{{"root_index", 0}};
+  EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+// ---------------------------------------------------------------------------
+// String serialization roundtrip (json::Stringify / json::Parse)
+// ---------------------------------------------------------------------------
+TEST(Serialization, StringRoundTrip) {
+  TVar x("x");
+  TFunc f({x}, {x}, String("comment"));
+  String json_str = json::Stringify(ToJSONGraph(f));
+  Any deserialized = FromJSONGraph(json::Parse(json_str));
+  EXPECT_TRUE(StructuralEqual::Equal(deserialized, f, /*map_free_vars=*/true));
+}
+
+TEST(Serialization, StringRoundTripPrimitives) {
+  auto rt = [](const Any& v) {
+    return FromJSONGraph(json::Parse(json::Stringify(ToJSONGraph(v))));
+  };
+  // int
+  EXPECT_TRUE(StructuralEqual()(rt(static_cast<int64_t>(123)), 123));
+  // bool
+  EXPECT_TRUE(StructuralEqual()(rt(true), true));
+  // float
+  EXPECT_TRUE(StructuralEqual()(rt(2.718), 2.718));
+  // string
+  EXPECT_TRUE(StructuralEqual()(rt(String("test")), String("test")));
+  // null
+  EXPECT_TRUE(StructuralEqual()(rt(nullptr), nullptr));
+}
+
 }  // namespace
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index e77368cd..6570857c 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -62,6 +62,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   TVarObj::RegisterReflection();
   TFuncObj::RegisterReflection();
   TCustomFuncObj::RegisterReflection();
+  TAllFieldsObj::RegisterReflection();
+  TWithDefaultsObj::RegisterReflection();
 
   refl::ObjectDef<TestObjA>()
       .def(refl::init<int64_t, int64_t>())
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index b3ce5504..07f070af 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -23,6 +23,7 @@
 #include <tvm/ffi/any.h>
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/map.h>
+#include <tvm/ffi/dtype.h>
 #include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/registry.h>
@@ -288,6 +289,101 @@ class TCustomFunc : public ObjectRef {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TCustomFunc, ObjectRef, TCustomFuncObj);
 };
 
+// Test object with all POD field types to exercise serialization of every field kind.
+class TAllFieldsObj : public Object {
+ public:
+  bool v_bool;
+  int64_t v_int;
+  double v_float;
+  DLDataType v_dtype;
+  DLDevice v_device;
+  String v_str;
+  Optional<String> v_opt_str;
+  Array<Any> v_array;
+  Map<String, Any> v_map;
+
+  TAllFieldsObj(bool v_bool, int64_t v_int, double v_float, DLDataType v_dtype, DLDevice v_device,
+                String v_str, Optional<String> v_opt_str, Array<Any> v_array,
+                Map<String, Any> v_map)
+      : v_bool(v_bool),
+        v_int(v_int),
+        v_float(v_float),
+        v_dtype(v_dtype),
+        v_device(v_device),
+        v_str(v_str),
+        v_opt_str(v_opt_str),
+        v_array(v_array),
+        v_map(v_map) {}
+  explicit TAllFieldsObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TAllFieldsObj>()
+        .def_ro("v_bool", &TAllFieldsObj::v_bool)
+        .def_ro("v_int", &TAllFieldsObj::v_int)
+        .def_ro("v_float", &TAllFieldsObj::v_float)
+        .def_ro("v_dtype", &TAllFieldsObj::v_dtype)
+        .def_ro("v_device", &TAllFieldsObj::v_device)
+        .def_ro("v_str", &TAllFieldsObj::v_str)
+        .def_ro("v_opt_str", &TAllFieldsObj::v_opt_str)
+        .def_ro("v_array", &TAllFieldsObj::v_array)
+        .def_ro("v_map", &TAllFieldsObj::v_map);
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.AllFields", TAllFieldsObj, Object);
+};
+
+class TAllFields : public ObjectRef {
+ public:
+  explicit TAllFields(bool v_bool, int64_t v_int, double v_float, DLDataType v_dtype,
+                      DLDevice v_device, String v_str, Optional<String> v_opt_str,
+                      Array<Any> v_array, Map<String, Any> v_map) {
+    data_ = make_object<TAllFieldsObj>(v_bool, v_int, v_float, v_dtype, v_device, v_str, v_opt_str,
+                                       v_array, v_map);
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TAllFields, ObjectRef, TAllFieldsObj);
+};
+
+// Test object with fields that have default values to test deserialization with missing fields
+class TWithDefaultsObj : public Object {
+ public:
+  int64_t required_val;
+  int64_t default_int;
+  String default_str;
+  bool default_bool;
+
+  TWithDefaultsObj(int64_t required_val, int64_t default_int, String default_str, bool default_bool)
+      : required_val(required_val),
+        default_int(default_int),
+        default_str(default_str),
+        default_bool(default_bool) {}
+  explicit TWithDefaultsObj(UnsafeInit) {}
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<TWithDefaultsObj>()
+        .def_ro("required_val", &TWithDefaultsObj::required_val)
+        .def_ro("default_int", &TWithDefaultsObj::default_int, refl::DefaultValue(42))
+        .def_ro("default_str", &TWithDefaultsObj::default_str, refl::DefaultValue("default"))
+        .def_ro("default_bool", &TWithDefaultsObj::default_bool, refl::DefaultValue(true));
+  }
+
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.WithDefaults", TWithDefaultsObj, Object);
+};
+
+class TWithDefaults : public ObjectRef {
+ public:
+  explicit TWithDefaults(int64_t required_val, int64_t default_int = 42,
+                         String default_str = "default", bool default_bool = true) {
+    data_ = make_object<TWithDefaultsObj>(required_val, default_int, default_str, default_bool);
+  }
+
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TWithDefaults, ObjectRef, TWithDefaultsObj);
+};
+
 }  // namespace testing
 
 template <>
diff --git a/tests/python/test_serialization.py b/tests/python/test_serialization.py
new file mode 100644
index 00000000..d2357fa4
--- /dev/null
+++ b/tests/python/test_serialization.py
@@ -0,0 +1,515 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for JSON graph serialization/deserialization roundtrips."""
+
+from __future__ import annotations
+
+import json
+from typing import Any, Callable
+
+import pytest
+import tvm_ffi
+import tvm_ffi.testing
+from tvm_ffi.serialization import from_json_graph_str, to_json_graph_str
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+def _roundtrip(obj: Any) -> Any:
+    """Serialize then deserialize and return the result."""
+    return from_json_graph_str(to_json_graph_str(obj))
+
+
+def _assert_roundtrip_eq(obj: Any, cmp: Callable[..., Any] | None = None) -> None:
+    """Assert that roundtrip preserves the value."""
+    result = _roundtrip(obj)
+    if cmp is not None:
+        cmp(result)
+    elif isinstance(obj, float):
+        assert isinstance(result, float)
+        assert result == pytest.approx(obj)
+    elif obj is None:
+        assert result is None
+    else:
+        _assert_any_equal(result, obj)
+
+
+def _assert_any_equal(a: Any, b: Any) -> None:
+    """Recursively compare two tvm_ffi values for equality."""
+    if isinstance(b, tvm_ffi.Array):
+        assert len(a) == len(b)
+        for x, y in zip(a, b):
+            _assert_any_equal(x, y)
+    elif isinstance(b, tvm_ffi.Map):
+        assert len(a) == len(b)
+        for k in b:
+            _assert_any_equal(a[k], b[k])
+    elif isinstance(b, tvm_ffi.Shape):
+        assert list(a) == list(b)
+    elif isinstance(b, str):
+        # tvm_ffi String inherits from str
+        assert str(a) == str(b)
+    else:
+        assert a == b
+
+
+# ---------------------------------------------------------------------------
+# Primitive types
+# ---------------------------------------------------------------------------
+class TestNone:
+    """Roundtrip tests for None."""
+
+    def test_none(self) -> None:
+        """None roundtrips to None."""
+        assert _roundtrip(None) is None
+
+
+class TestBool:
+    """Roundtrip tests for bool values."""
+
+    def test_true(self) -> None:
+        """True roundtrips to True."""
+        assert _roundtrip(True) is True
+
+    def test_false(self) -> None:
+        """False roundtrips to False."""
+        assert _roundtrip(False) is False
+
+
+class TestInt:
+    """Roundtrip tests for integer values."""
+
+    def test_zero(self) -> None:
+        """Zero roundtrips correctly."""
+        result = _roundtrip(0)
+        assert result == 0
+
+    def test_positive(self) -> None:
+        """Positive int roundtrips correctly."""
+        result = _roundtrip(42)
+        assert result == 42
+
+    def test_negative(self) -> None:
+        """Negative int roundtrips correctly."""
+        result = _roundtrip(-1)
+        assert result == -1
+
+    def test_large_positive(self) -> None:
+        """Large positive int roundtrips correctly."""
+        result = _roundtrip(10**15)
+        assert result == 10**15
+
+    def test_large_negative(self) -> None:
+        """Large negative int roundtrips correctly."""
+        result = _roundtrip(-(10**15))
+        assert result == -(10**15)
+
+
+class TestFloat:
+    """Roundtrip tests for float values."""
+
+    def test_zero(self) -> None:
+        """Float zero roundtrips correctly."""
+        result = _roundtrip(0.0)
+        assert result == 0.0
+
+    def test_positive(self) -> None:
+        """Positive float roundtrips correctly."""
+        result = _roundtrip(3.14159)
+        assert result == pytest.approx(3.14159)
+
+    def test_negative(self) -> None:
+        """Negative float roundtrips correctly."""
+        result = _roundtrip(-2.718)
+        assert result == pytest.approx(-2.718)
+
+    def test_very_small(self) -> None:
+        """Very small float roundtrips correctly."""
+        result = _roundtrip(1e-300)
+        assert result == pytest.approx(1e-300)
+
+    def test_very_large(self) -> None:
+        """Very large float roundtrips correctly."""
+        result = _roundtrip(1e300)
+        assert result == pytest.approx(1e300)
+
+
+# ---------------------------------------------------------------------------
+# String types
+# ---------------------------------------------------------------------------
+class TestString:
+    """Roundtrip tests for ffi.String values."""
+
+    def test_empty(self) -> None:
+        """Empty string roundtrips correctly."""
+        _assert_roundtrip_eq(tvm_ffi.convert(""))
+
+    def test_short(self) -> None:
+        """Short string roundtrips correctly."""
+        _assert_roundtrip_eq(tvm_ffi.convert("hello"))
+
+    def test_long(self) -> None:
+        """Long string roundtrips correctly."""
+        _assert_roundtrip_eq(tvm_ffi.convert("x" * 1000))
+
+    def test_special_chars(self) -> None:
+        """String with special characters roundtrips correctly."""
+        _assert_roundtrip_eq(tvm_ffi.convert('hello\nworld\t"quotes"'))
+
+    def test_unicode(self) -> None:
+        """Unicode string roundtrips correctly."""
+        _assert_roundtrip_eq(tvm_ffi.convert("hello 世界"))
+
+
+# ---------------------------------------------------------------------------
+# DataType
+# ---------------------------------------------------------------------------
+class TestDataType:
+    """Roundtrip tests for DLDataType values."""
+
+    def test_int32(self) -> None:
+        """int32 dtype roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.dtype("int32"))
+        result = from_json_graph_str(s)
+        assert str(result) == "int32"
+
+    def test_float64(self) -> None:
+        """float64 dtype roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.dtype("float64"))
+        result = from_json_graph_str(s)
+        assert str(result) == "float64"
+
+    def test_bool(self) -> None:
+        """Bool dtype roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.dtype("bool"))
+        result = from_json_graph_str(s)
+        assert str(result) == "bool"
+
+    def test_vector(self) -> None:
+        """Vector dtype roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.dtype("float32x4"))
+        result = from_json_graph_str(s)
+        assert str(result) == "float32x4"
+
+
+# ---------------------------------------------------------------------------
+# Device
+# ---------------------------------------------------------------------------
+class TestDevice:
+    """Roundtrip tests for DLDevice values."""
+
+    def test_cpu(self) -> None:
+        """CPU device roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.Device("cpu", 0))
+        result = from_json_graph_str(s)
+        assert result.dlpack_device_type() == tvm_ffi.Device("cpu", 
0).dlpack_device_type()
+        assert result.index == 0
+
+    def test_cuda(self) -> None:
+        """CUDA device roundtrips correctly."""
+        s = to_json_graph_str(tvm_ffi.Device("cuda", 1))
+        result = from_json_graph_str(s)
+        assert result.dlpack_device_type() == tvm_ffi.Device("cuda", 
1).dlpack_device_type()
+        assert result.index == 1
+
+
+# ---------------------------------------------------------------------------
+# Containers
+# ---------------------------------------------------------------------------
+class TestArray:
+    """Roundtrip tests for ffi.Array containers."""
+
+    def test_empty(self) -> None:
+        """Empty array roundtrips correctly."""
+        arr = tvm_ffi.convert([])
+        _assert_roundtrip_eq(arr)
+
+    def test_single_element(self) -> None:
+        """Single-element array roundtrips correctly."""
+        arr = tvm_ffi.convert([42])
+        result = _roundtrip(arr)
+        assert len(result) == 1
+        assert result[0] == 42
+
+    def test_multiple_elements(self) -> None:
+        """Multi-element array roundtrips correctly."""
+        arr = tvm_ffi.convert([1, 2, 3])
+        result = _roundtrip(arr)
+        assert len(result) == 3
+        assert list(result) == [1, 2, 3]
+
+    def test_mixed_types(self) -> None:
+        """Array with mixed types roundtrips correctly."""
+        arr = tvm_ffi.convert([42, "hello", True, None])
+        result = _roundtrip(arr)
+        assert len(result) == 4
+        assert result[0] == 42
+        assert result[1] == "hello"
+        assert result[2] is True
+        assert result[3] is None
+
+    def test_nested_arrays(self) -> None:
+        """Nested arrays roundtrip correctly."""
+        inner1 = tvm_ffi.convert([1, 2])
+        inner2 = tvm_ffi.convert([3])
+        outer = tvm_ffi.convert([inner1, inner2])
+        result = _roundtrip(outer)
+        assert len(result) == 2
+        assert list(result[0]) == [1, 2]
+        assert list(result[1]) == [3]
+
+    def test_duplicated_elements(self) -> None:
+        """Array with duplicated elements roundtrips correctly."""
+        arr = tvm_ffi.convert([42, 42, 42])
+        result = _roundtrip(arr)
+        assert len(result) == 3
+        assert all(x == 42 for x in result)
+
+
+class TestMap:
+    """Roundtrip tests for ffi.Map containers."""
+
+    def test_empty(self) -> None:
+        """Empty map roundtrips correctly."""
+        m = tvm_ffi.convert({})
+        _assert_roundtrip_eq(m)
+
+    def test_single_entry(self) -> None:
+        """Single-entry map roundtrips correctly."""
+        m = tvm_ffi.convert({"key": 42})
+        result = _roundtrip(m)
+        assert len(result) == 1
+        assert result["key"] == 42
+
+    def test_multiple_entries(self) -> None:
+        """Multi-entry map roundtrips correctly."""
+        m = tvm_ffi.convert({"a": 1, "b": 2, "c": 3})
+        result = _roundtrip(m)
+        assert len(result) == 3
+        assert result["a"] == 1
+        assert result["b"] == 2
+        assert result["c"] == 3
+
+    def test_mixed_value_types(self) -> None:
+        """Map with mixed value types roundtrips correctly."""
+        m = tvm_ffi.convert({"int": 42, "str": "hello", "bool": True, "none": 
None})
+        result = _roundtrip(m)
+        assert result["int"] == 42
+        assert result["str"] == "hello"
+        assert result["bool"] is True
+        assert result["none"] is None
+
+    def test_nested_map(self) -> None:
+        """Nested maps roundtrip correctly."""
+        inner = tvm_ffi.convert({"x": 1})
+        outer = tvm_ffi.convert({"inner": inner})
+        result = _roundtrip(outer)
+        assert result["inner"]["x"] == 1
+
+    def test_map_with_array_value(self) -> None:
+        """Map with array values roundtrips correctly."""
+        arr = tvm_ffi.convert([10, 20])
+        m = tvm_ffi.convert({"nums": arr})
+        result = _roundtrip(m)
+        assert list(result["nums"]) == [10, 20]
+
+
+class TestShape:
+    """Roundtrip tests for ffi.Shape containers."""
+
+    def test_empty(self) -> None:
+        """Empty shape roundtrips correctly."""
+        shape = tvm_ffi.Shape(())
+        _assert_roundtrip_eq(shape)
+
+    def test_1d(self) -> None:
+        """1D shape roundtrips correctly."""
+        shape = tvm_ffi.Shape((10,))
+        result = _roundtrip(shape)
+        assert list(result) == [10]
+
+    def test_nd(self) -> None:
+        """N-D shape roundtrips correctly."""
+        shape = tvm_ffi.Shape((1, 2, 3, 4))
+        result = _roundtrip(shape)
+        assert list(result) == [1, 2, 3, 4]
+
+
+# ---------------------------------------------------------------------------
+# Objects with reflection
+# ---------------------------------------------------------------------------
+class TestObjectSerialization:
+    """Roundtrip tests for objects with reflection metadata."""
+
+    def test_int_pair_roundtrip(self) -> None:
+        """TestIntPair has refl::init and POD int64 fields."""
+        pair = tvm_ffi.testing.TestIntPair(3, 7)  # ty: 
ignore[too-many-positional-arguments]
+        s = to_json_graph_str(pair)
+        result = from_json_graph_str(s)
+        assert result.a == 3
+        assert result.b == 7
+
+    def test_int_pair_zero_values(self) -> None:
+        """TestIntPair with zero values roundtrips correctly."""
+        pair = tvm_ffi.testing.TestIntPair(0, 0)  # ty: 
ignore[too-many-positional-arguments]
+        result = _roundtrip(pair)
+        assert result.a == 0
+        assert result.b == 0
+
+    def test_int_pair_negative_values(self) -> None:
+        """TestIntPair with negative values roundtrips correctly."""
+        pair = tvm_ffi.testing.TestIntPair(-100, -200)  # ty: 
ignore[too-many-positional-arguments]
+        result = _roundtrip(pair)
+        assert result.a == -100
+        assert result.b == -200
+
+    def test_int_pair_large_values(self) -> None:
+        """TestIntPair with large values roundtrips correctly."""
+        pair = tvm_ffi.testing.TestIntPair(10**15, -(10**15))  # ty: 
ignore[too-many-positional-arguments]
+        result = _roundtrip(pair)
+        assert result.a == 10**15
+        assert result.b == -(10**15)
+
+
+# ---------------------------------------------------------------------------
+# JSON structure verification
+# ---------------------------------------------------------------------------
+class TestJSONStructure:
+    """Tests verifying the internal JSON graph structure."""
+
+    def test_null_json_structure(self) -> None:
+        """None produces a single node with type 'None'."""
+        s = to_json_graph_str(None)
+        parsed = json.loads(s)
+        assert parsed["root_index"] == 0
+        assert len(parsed["nodes"]) == 1
+        assert parsed["nodes"][0]["type"] == "None"
+
+    def test_bool_json_structure(self) -> None:
+        """Bool produces a node with type 'bool'."""
+        s = to_json_graph_str(True)
+        parsed = json.loads(s)
+        assert parsed["nodes"][0]["type"] == "bool"
+        assert parsed["nodes"][0]["data"] is True
+
+    def test_int_json_structure(self) -> None:
+        """Int produces a node with type 'int'."""
+        s = to_json_graph_str(42)
+        parsed = json.loads(s)
+        assert parsed["nodes"][0]["type"] == "int"
+        assert parsed["nodes"][0]["data"] == 42
+
+    def test_float_json_structure(self) -> None:
+        """Float produces a node with type 'float'."""
+        s = to_json_graph_str(3.14)
+        parsed = json.loads(s)
+        assert parsed["nodes"][0]["type"] == "float"
+        assert parsed["nodes"][0]["data"] == pytest.approx(3.14)
+
+    def test_string_json_structure(self) -> None:
+        """String produces a node with type 'ffi.String'."""
+        s = to_json_graph_str(tvm_ffi.convert("hello"))
+        parsed = json.loads(s)
+        assert parsed["nodes"][parsed["root_index"]]["type"] == "ffi.String"
+        assert parsed["nodes"][parsed["root_index"]]["data"] == "hello"
+
+    def test_array_json_structure(self) -> None:
+        """Array data contains node index references."""
+        s = to_json_graph_str(tvm_ffi.convert([1, 2]))
+        parsed = json.loads(s)
+        root = parsed["nodes"][parsed["root_index"]]
+        assert root["type"] == "ffi.Array"
+        # data should be list of node indices
+        assert isinstance(root["data"], list)
+        assert len(root["data"]) == 2
+
+    def test_map_json_structure(self) -> None:
+        """Map data contains flattened key-value node index pairs."""
+        s = to_json_graph_str(tvm_ffi.convert({"a": 1}))
+        parsed = json.loads(s)
+        root = parsed["nodes"][parsed["root_index"]]
+        assert root["type"] == "ffi.Map"
+        assert isinstance(root["data"], list)
+        # key-value pairs flattened: [key_idx, val_idx]
+        assert len(root["data"]) == 2
+
+    def test_node_dedup(self) -> None:
+        """Duplicate values should share the same node index."""
+        s = to_json_graph_str(tvm_ffi.convert([42, 42, 42]))
+        parsed = json.loads(s)
+        root = parsed["nodes"][parsed["root_index"]]
+        # all three elements should reference the same node
+        assert root["data"][0] == root["data"][1] == root["data"][2]
+
+    def test_object_pod_fields_are_inlined(self) -> None:
+        """POD fields (int, bool, float) are inlined directly via 
field_static_type_index."""
+        pair = tvm_ffi.testing.TestIntPair(3, 7)  # ty: 
ignore[too-many-positional-arguments]
+        s = to_json_graph_str(pair)
+        parsed = json.loads(s)
+        root = parsed["nodes"][parsed["root_index"]]
+        assert root["type"] == "testing.TestIntPair"
+        # fields 'a' and 'b' should be inlined as direct int values (not node 
indices)
+        assert root["data"]["a"] == 3
+        assert root["data"]["b"] == 7
+
+
+# ---------------------------------------------------------------------------
+# Metadata
+# ---------------------------------------------------------------------------
+class TestMetadata:
+    """Tests for optional metadata in serialized output."""
+
+    def test_with_metadata(self) -> None:
+        """Metadata dict appears in serialized JSON when provided."""
+        s = to_json_graph_str(42, {"version": "1.0"})
+        parsed = json.loads(s)
+        assert "metadata" in parsed
+        assert parsed["metadata"]["version"] == "1.0"
+
+    def test_without_metadata(self) -> None:
+        """Metadata key is absent when not provided."""
+        s = to_json_graph_str(42)
+        parsed = json.loads(s)
+        assert "metadata" not in parsed
+
+
+# ---------------------------------------------------------------------------
+# Error cases
+# ---------------------------------------------------------------------------
+class TestErrors:
+    """Tests for error handling in deserialization."""
+
+    def test_invalid_json_string(self) -> None:
+        """Invalid JSON string raises an error."""
+        with pytest.raises(Exception):
+            from_json_graph_str("not valid json")
+
+    def test_empty_json_string(self) -> None:
+        """Empty string raises an error."""
+        with pytest.raises(Exception):
+            from_json_graph_str("")
+
+    def test_missing_root_index(self) -> None:
+        """JSON without root_index raises an error."""
+        with pytest.raises(Exception):
+            from_json_graph_str('{"nodes": []}')
+
+    def test_missing_nodes(self) -> None:
+        """JSON without nodes raises an error."""
+        with pytest.raises(Exception):
+            from_json_graph_str('{"root_index": 0}')


Reply via email to