This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 3b26a09a doc: Properly document `field_static_type_index` and add more
tests (#456)
3b26a09a is described below
commit 3b26a09a1e47a55e641f2267317c97114b27ab36
Author: Junru Shao <[email protected]>
AuthorDate: Tue Feb 17 07:04:43 2026 -0800
doc: Properly document `field_static_type_index` and add more tests (#456)
## Summary
Restore full usage of `field_static_type_index` in `TVMFFIFieldInfo`, with
thorough documentation of its compile-time static-type semantics:
the field reflects the type declared at compile time, which may differ
from the runtime type (e.g., `Any` vs. `int`, or `Array<Any>` vs.
`Array<int>`).
The serializer uses this field to inline POD values (None, Bool, Int,
Float, DataType) directly, avoiding unnecessary node-graph overhead,
while routing other types through the standard node graph.
Changes:
- Add comprehensive documentation to `field_static_type_index` in
TVMFFIFieldInfo (C API) explaining all possible values and semantics
- Restore `TypeToFieldStaticTypeIndex` and `TypeToRuntimeTypeIndex`
templates in type_traits.h
- Restore `field_static_type_index` in all TypeTraits specializations
- Restore POD field inlining in serialization/deserialization
- Add comprehensive C++ and Python serialization tests (40 + 58 tests)
## Test plan
- [x] All 302 C++ tests pass (including 40 serialization + 21 reflection
tests)
- [x] All 58 Python serialization tests pass
- [x] Verified `field_static_type_index` is populated correctly in
reflection via test object field info
---
include/tvm/ffi/c_api.h | 31 +-
include/tvm/ffi/type_traits.h | 1 -
rust/tvm-ffi-sys/src/c_api.rs | 2 +-
tests/cpp/extra/test_serialization.cc | 364 ++++++++++++++++++++++++
tests/cpp/test_reflection.cc | 2 +
tests/cpp/testing_object.h | 96 +++++++
tests/python/test_serialization.py | 515 ++++++++++++++++++++++++++++++++++
7 files changed, 996 insertions(+), 15 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 67715ca1..236e342f 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -979,23 +979,28 @@ typedef struct {
*/
TVMFFIAny default_value_or_factory;
/*!
- * \brief Records the static type kind of the field.
+ * \brief The compile-time static type index of the field.
*
- * Possible values:
+ * This reflects the type declared at compile time, which is only trustworthy
+ * for statically-inferrable types. It does NOT necessarily match the runtime
+ * type. For example, a field declared as `Any` will have
+ * `field_static_type_index == kTVMFFIAny` even if it holds an `int` at runtime,
+ * and a field declared as `Array<Any>` will report `kTVMFFIArray` even though
+ * its elements may all be `int` at runtime.
*
- * - TVMFFITypeIndex::kTVMFFIObject for general objects.
- * The value is nullable when kTVMFFIObject is chosen.
- * - Static object type kinds such as Map, Dict, String
- * - POD type index, note it does not give information about storage size of
the field.
- * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
- * about the field.
+ * \warning Do NOT use this field to determine the actual type of a value at
+ * runtime. It is purely a compile-time hint derived from the C++ field
+ * declaration. The actual runtime type must be obtained from the value's
+ * own `type_index`.
*
- * When the value is a type index of Object type, the field is storaged as
an ObjectRef.
+ * \warning When the static type is a generic container (e.g. `Array<Any>`),
+ * this field only tells you it is an Array — it says nothing about the
+ * element types actually stored inside. Similarly, `kTVMFFIObject` only
+ * means "some ObjectRef" without any subtype information.
*
- * \note This information maybe helpful in designing serializer.
- * As it helps to narrow down the field type so we don't have to
- * print type_key for cases like POD types.
- * It also helps to provide opportunities to enable short-cut getter to
ObjectRef fields.
+ * \note This is used by the serializer to inline POD field values directly
+ * (avoiding node-graph overhead for None, Bool, Int, Float, DataType),
+ * while routing other types through the node graph.
*/
int32_t field_static_type_index;
} TVMFFIFieldInfo;
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index dc7f982b..e832559b 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -135,7 +135,6 @@ struct TypeToRuntimeTypeIndex<T,
std::enable_if_t<std::is_base_of_v<ObjectRef, T
template <>
struct TypeTraits<std::nullptr_t> : public TypeTraitsBase {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone;
-
TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t&, TVMFFIAny*
result) {
result->type_index = TypeIndex::kTVMFFINone;
result->zero_padding = 0;
diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs
index d94967ba..ee555e2a 100644
--- a/rust/tvm-ffi-sys/src/c_api.rs
+++ b/rust/tvm-ffi-sys/src/c_api.rs
@@ -291,7 +291,7 @@ pub struct TVMFFIFieldInfo {
/// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set,
/// this is a callable factory function () -> Any.
pub default_value_or_factory: TVMFFIAny,
- /// Records the static type kind of the field
+ /// The compile-time static type index of the field; may differ from the runtime type of the value.
pub field_static_type_index: i32,
}
diff --git a/tests/cpp/extra/test_serialization.cc
b/tests/cpp/extra/test_serialization.cc
index 1d9d4830..ce6e42cc 100644
--- a/tests/cpp/extra/test_serialization.cc
+++ b/tests/cpp/extra/test_serialization.cc
@@ -26,6 +26,8 @@
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/string.h>
+#include <limits>
+
#include "../testing_object.h"
namespace {
@@ -410,4 +412,366 @@ TEST(Serialization, ShuffleNodeOrder) {
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled),
duplicated_map));
}
+// ---------------------------------------------------------------------------
+// Integer edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, IntegerEdgeCases) {
+ // zero
+
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(static_cast<int64_t>(0))),
+ static_cast<int64_t>(0)));
+ // negative
+
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(static_cast<int64_t>(-1))),
+ static_cast<int64_t>(-1)));
+ // large positive
+ int64_t large = 1000000000000LL;
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(large)), large));
+ // large negative
+ int64_t large_neg = -999999999999LL;
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(large_neg)),
large_neg));
+ // INT64_MIN and INT64_MAX
+ int64_t imin = std::numeric_limits<int64_t>::min();
+ int64_t imax = std::numeric_limits<int64_t>::max();
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(imin)), imin));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(imax)), imax));
+}
+
+// ---------------------------------------------------------------------------
+// Float edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, FloatEdgeCases) {
+ // zero
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(0.0)), 0.0));
+ // negative
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(-1.5)), -1.5));
+ // very large
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(1e300)), 1e300));
+ // very small
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(1e-300)), 1e-300));
+}
+
+// ---------------------------------------------------------------------------
+// String edge cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, EmptyString) {
+ String empty("");
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(empty)), empty));
+}
+
+TEST(Serialization, UnicodeString) {
+ String unicode("hello 世界 🌍");
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(unicode)), unicode));
+}
+
+TEST(Serialization, NullCharInString) {
+ // String with embedded null characters
+ std::string with_null("ab\0cd", 5);
+ String s(with_null);
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(s)), s));
+}
+
+// ---------------------------------------------------------------------------
+// Object with POD and object-typed fields (POD values such as bool/int/float are inlined; String/Array/Map go through the node graph)
+// ---------------------------------------------------------------------------
+TEST(Serialization, AllFieldsObject) {
+ DLDataType dtype;
+ dtype.code = kDLFloat;
+ dtype.bits = 32;
+ dtype.lanes = 1;
+
+ DLDevice device;
+ device.device_type = kDLCUDA;
+ device.device_id = 3;
+
+ Array<Any> arr;
+ arr.push_back(1);
+ arr.push_back(String("two"));
+
+ Map<String, Any> map{{"k", 99}};
+
+ TAllFields obj(true, -7, 2.5, dtype, device, String("hello"), String("opt"),
arr, map);
+ json::Value serialized = ToJSONGraph(obj);
+ Any deserialized = FromJSONGraph(serialized);
+
+ // verify each field
+ TAllFields result = deserialized.cast<TAllFields>();
+ EXPECT_EQ(result->v_bool, true);
+ EXPECT_EQ(result->v_int, -7);
+ EXPECT_DOUBLE_EQ(result->v_float, 2.5);
+ EXPECT_EQ(result->v_dtype.code, kDLFloat);
+ EXPECT_EQ(result->v_dtype.bits, 32);
+ EXPECT_EQ(result->v_dtype.lanes, 1);
+ EXPECT_EQ(result->v_device.device_type, kDLCUDA);
+ EXPECT_EQ(result->v_device.device_id, 3);
+ EXPECT_EQ(std::string(result->v_str), "hello");
+ EXPECT_TRUE(result->v_opt_str.has_value());
+ EXPECT_EQ(std::string(result->v_opt_str.value()), "opt");
+ EXPECT_EQ(result->v_array.size(), 2);
+ EXPECT_EQ(result->v_map.size(), 1);
+}
+
+TEST(Serialization, AllFieldsObjectOptionalNone) {
+ DLDataType dtype;
+ dtype.code = kDLInt;
+ dtype.bits = 64;
+ dtype.lanes = 1;
+
+ DLDevice device;
+ device.device_type = kDLCPU;
+ device.device_id = 0;
+
+ TAllFields obj(false, 0, 0.0, dtype, device, String(""), std::nullopt,
Array<Any>(),
+ Map<String, Any>());
+ json::Value serialized = ToJSONGraph(obj);
+ Any deserialized = FromJSONGraph(serialized);
+
+ TAllFields result = deserialized.cast<TAllFields>();
+ EXPECT_EQ(result->v_bool, false);
+ EXPECT_EQ(result->v_int, 0);
+ EXPECT_DOUBLE_EQ(result->v_float, 0.0);
+ EXPECT_EQ(std::string(result->v_str), "");
+ EXPECT_FALSE(result->v_opt_str.has_value());
+ EXPECT_EQ(result->v_array.size(), 0);
+ EXPECT_EQ(result->v_map.size(), 0);
+}
+
+// ---------------------------------------------------------------------------
+// Default field values during deserialization
+// ---------------------------------------------------------------------------
+TEST(Serialization, DefaultFieldValues) {
+ // full roundtrip with every field set explicitly (the missing-field case is covered by DefaultFieldValuesMissing)
+ TWithDefaults original(100, 42, "default", true);
+ json::Value serialized = ToJSONGraph(original);
+ // roundtrip should work
+ Any deserialized = FromJSONGraph(serialized);
+ TWithDefaults result = deserialized.cast<TWithDefaults>();
+ EXPECT_EQ(result->required_val, 100);
+ EXPECT_EQ(result->default_int, 42);
+ EXPECT_EQ(std::string(result->default_str), "default");
+ EXPECT_EQ(result->default_bool, true);
+}
+
+TEST(Serialization, DefaultFieldValuesMissing) {
+ // manually construct JSON with only required field, defaults should kick in
+ // required_val is int64_t so it is inlined directly (POD field)
+ json::Object data;
+ data.Set("required_val", static_cast<int64_t>(999));
+
+ json::Object graph{
+ {"root_index", 0},
+ {"nodes", json::Array{json::Object{{"type", "test.WithDefaults"},
{"data", data}}}}};
+ Any result = FromJSONGraph(graph);
+ TWithDefaults obj = result.cast<TWithDefaults>();
+ EXPECT_EQ(obj->required_val, 999);
+ EXPECT_EQ(obj->default_int, 42);
+ EXPECT_EQ(std::string(obj->default_str), "default");
+ EXPECT_EQ(obj->default_bool, true);
+}
+
+// ---------------------------------------------------------------------------
+// Shared object references
+// ---------------------------------------------------------------------------
+TEST(Serialization, SharedObjectReferences) {
+ TVar shared_var("shared");
+ // the same var appears in both the params and the body of one func
+ TFunc f1({shared_var}, {shared_var, shared_var}, std::nullopt);
+
+ json::Value serialized = ToJSONGraph(f1);
+ Any deserialized = FromJSONGraph(serialized);
+ TFunc result = deserialized.cast<TFunc>();
+
+ // all references to "shared" should be the same object after deserialization
+ // via the node dedup mechanism
+ EXPECT_EQ(result->params.size(), 1);
+ EXPECT_EQ(result->body.size(), 2);
+ // the params[0] and body[0] and body[1] should all refer to the same object
+ EXPECT_EQ(result->params[0].get(), result->body[0].get());
+ EXPECT_EQ(result->body[0].get(), result->body[1].get());
+}
+
+// ---------------------------------------------------------------------------
+// Nested objects
+// ---------------------------------------------------------------------------
+TEST(Serialization, NestedObjects) {
+ TVar x("x");
+ TVar y("y");
+ TFunc inner({x}, {x}, String("inner"));
+ // put the inner func as a body element of the outer func
+ TFunc outer({y}, {inner}, String("outer"));
+
+ json::Value serialized = ToJSONGraph(outer);
+ Any deserialized = FromJSONGraph(serialized);
+ TFunc result = deserialized.cast<TFunc>();
+
+ EXPECT_EQ(result->comment.value(), "outer");
+ TFunc inner_result = Any(result->body[0]).cast<TFunc>();
+ EXPECT_EQ(inner_result->comment.value(), "inner");
+ EXPECT_EQ(std::string(Any(inner_result->params[0]).cast<TVar>()->name), "x");
+}
+
+// ---------------------------------------------------------------------------
+// Map with integer keys
+// ---------------------------------------------------------------------------
+TEST(Serialization, MapWithIntKeys) {
+ Map<Any, Any> map;
+ map.Set(static_cast<int64_t>(1), String("one"));
+ map.Set(static_cast<int64_t>(2), String("two"));
+
+ json::Value serialized = ToJSONGraph(map);
+ Any deserialized = FromJSONGraph(serialized);
+ Map<Any, Any> result = deserialized.cast<Map<Any, Any>>();
+ EXPECT_EQ(result.size(), 2);
+ EXPECT_EQ(std::string(result[1].cast<String>()), "one");
+ EXPECT_EQ(std::string(result[2].cast<String>()), "two");
+}
+
+// ---------------------------------------------------------------------------
+// Nested containers
+// ---------------------------------------------------------------------------
+TEST(Serialization, NestedArrays) {
+ Array<Any> inner1;
+ inner1.push_back(1);
+ inner1.push_back(2);
+ Array<Any> inner2;
+ inner2.push_back(3);
+ Array<Any> outer;
+ outer.push_back(inner1);
+ outer.push_back(inner2);
+
+ json::Value serialized = ToJSONGraph(outer);
+ Any deserialized = FromJSONGraph(serialized);
+ Array<Any> result = deserialized.cast<Array<Any>>();
+ EXPECT_EQ(result.size(), 2);
+ Array<Any> r1 = result[0].cast<Array<Any>>();
+ Array<Any> r2 = result[1].cast<Array<Any>>();
+ EXPECT_EQ(r1.size(), 2);
+ EXPECT_EQ(r1[0].cast<int64_t>(), 1);
+ EXPECT_EQ(r1[1].cast<int64_t>(), 2);
+ EXPECT_EQ(r2.size(), 1);
+ EXPECT_EQ(r2[0].cast<int64_t>(), 3);
+}
+
+TEST(Serialization, MapWithArrayValues) {
+ Array<Any> arr;
+ arr.push_back(10);
+ arr.push_back(20);
+ Map<String, Any> map{{"nums", arr}};
+
+ json::Value serialized = ToJSONGraph(map);
+ Any deserialized = FromJSONGraph(serialized);
+ Map<String, Any> result = deserialized.cast<Map<String, Any>>();
+ Array<Any> result_arr = result["nums"].cast<Array<Any>>();
+ EXPECT_EQ(result_arr.size(), 2);
+ EXPECT_EQ(result_arr[0].cast<int64_t>(), 10);
+ EXPECT_EQ(result_arr[1].cast<int64_t>(), 20);
+}
+
+// ---------------------------------------------------------------------------
+// Array and Map with objects
+// ---------------------------------------------------------------------------
+TEST(Serialization, ArrayOfObjects) {
+ TVar x("x");
+ TVar y("y");
+ Array<Any> arr;
+ arr.push_back(x);
+ arr.push_back(y);
+
+ json::Value serialized = ToJSONGraph(arr);
+ Any deserialized = FromJSONGraph(serialized);
+ Array<Any> result = deserialized.cast<Array<Any>>();
+ EXPECT_EQ(result.size(), 2);
+ EXPECT_EQ(std::string(result[0].cast<TVar>()->name), "x");
+ EXPECT_EQ(std::string(result[1].cast<TVar>()->name), "y");
+}
+
+TEST(Serialization, MapOfObjects) {
+ TVar x("x");
+ Map<String, Any> map{{"var", x}};
+
+ json::Value serialized = ToJSONGraph(map);
+ Any deserialized = FromJSONGraph(serialized);
+ Map<String, Any> result = deserialized.cast<Map<String, Any>>();
+ EXPECT_EQ(std::string(result["var"].cast<TVar>()->name), "x");
+}
+
+// ---------------------------------------------------------------------------
+// Mixed-type array (exercises runtime type dispatch for each element)
+// ---------------------------------------------------------------------------
+TEST(Serialization, MixedTypeArrayRoundTrip) {
+ DLDataType dtype;
+ dtype.code = kDLInt;
+ dtype.bits = 32;
+ dtype.lanes = 1;
+
+ DLDevice device;
+ device.device_type = kDLCPU;
+ device.device_id = 0;
+
+ Array<Any> arr;
+ arr.push_back(nullptr);
+ arr.push_back(true);
+ arr.push_back(false);
+ arr.push_back(static_cast<int64_t>(42));
+ arr.push_back(3.14);
+ arr.push_back(String("hello"));
+ arr.push_back(dtype);
+ arr.push_back(device);
+
+ // roundtrip and verify structural equality
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(ToJSONGraph(arr)), arr));
+}
+
+// ---------------------------------------------------------------------------
+// Error cases
+// ---------------------------------------------------------------------------
+TEST(Serialization, ErrorMissingRequiredField) {
+ // required_val is required but not provided
+ json::Object data;
+ json::Object graph{
+ {"root_index", 0},
+ {"nodes", json::Array{json::Object{{"type", "test.WithDefaults"},
{"data", data}}}}};
+ EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+TEST(Serialization, ErrorInvalidRootStructure) {
+ // not an object
+ EXPECT_ANY_THROW(FromJSONGraph(json::Value(42)));
+}
+
+TEST(Serialization, ErrorMissingRootIndex) {
+ json::Object graph{{"nodes", json::Array{json::Object{{"type", "None"}}}}};
+ EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+TEST(Serialization, ErrorMissingNodes) {
+ json::Object graph{{"root_index", 0}};
+ EXPECT_ANY_THROW(FromJSONGraph(graph));
+}
+
+// ---------------------------------------------------------------------------
+// String serialization roundtrip (json::Stringify / json::Parse)
+// ---------------------------------------------------------------------------
+TEST(Serialization, StringRoundTrip) {
+ TVar x("x");
+ TFunc f({x}, {x}, String("comment"));
+ String json_str = json::Stringify(ToJSONGraph(f));
+ Any deserialized = FromJSONGraph(json::Parse(json_str));
+ EXPECT_TRUE(StructuralEqual::Equal(deserialized, f, /*map_free_vars=*/true));
+}
+
+TEST(Serialization, StringRoundTripPrimitives) {
+ auto rt = [](const Any& v) {
+ return FromJSONGraph(json::Parse(json::Stringify(ToJSONGraph(v))));
+ };
+ // int
+ EXPECT_TRUE(StructuralEqual()(rt(static_cast<int64_t>(123)), 123));
+ // bool
+ EXPECT_TRUE(StructuralEqual()(rt(true), true));
+ // float
+ EXPECT_TRUE(StructuralEqual()(rt(2.718), 2.718));
+ // string
+ EXPECT_TRUE(StructuralEqual()(rt(String("test")), String("test")));
+ // null
+ EXPECT_TRUE(StructuralEqual()(rt(nullptr), nullptr));
+}
+
} // namespace
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index e77368cd..6570857c 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -62,6 +62,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TVarObj::RegisterReflection();
TFuncObj::RegisterReflection();
TCustomFuncObj::RegisterReflection();
+ TAllFieldsObj::RegisterReflection();
+ TWithDefaultsObj::RegisterReflection();
refl::ObjectDef<TestObjA>()
.def(refl::init<int64_t, int64_t>())
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index b3ce5504..07f070af 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -23,6 +23,7 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/dtype.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>
@@ -288,6 +289,101 @@ class TCustomFunc : public ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TCustomFunc, ObjectRef,
TCustomFuncObj);
};
+// Test object with POD fields (bool, int, float, dtype, device) and object fields (String, Optional<String>, Array, Map) to exercise serialization of every field kind.
+class TAllFieldsObj : public Object {
+ public:
+ bool v_bool;
+ int64_t v_int;
+ double v_float;
+ DLDataType v_dtype;
+ DLDevice v_device;
+ String v_str;
+ Optional<String> v_opt_str;
+ Array<Any> v_array;
+ Map<String, Any> v_map;
+
+ TAllFieldsObj(bool v_bool, int64_t v_int, double v_float, DLDataType
v_dtype, DLDevice v_device,
+ String v_str, Optional<String> v_opt_str, Array<Any> v_array,
+ Map<String, Any> v_map)
+ : v_bool(v_bool),
+ v_int(v_int),
+ v_float(v_float),
+ v_dtype(v_dtype),
+ v_device(v_device),
+ v_str(v_str),
+ v_opt_str(v_opt_str),
+ v_array(v_array),
+ v_map(v_map) {}
+ explicit TAllFieldsObj(UnsafeInit) {}
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TAllFieldsObj>()
+ .def_ro("v_bool", &TAllFieldsObj::v_bool)
+ .def_ro("v_int", &TAllFieldsObj::v_int)
+ .def_ro("v_float", &TAllFieldsObj::v_float)
+ .def_ro("v_dtype", &TAllFieldsObj::v_dtype)
+ .def_ro("v_device", &TAllFieldsObj::v_device)
+ .def_ro("v_str", &TAllFieldsObj::v_str)
+ .def_ro("v_opt_str", &TAllFieldsObj::v_opt_str)
+ .def_ro("v_array", &TAllFieldsObj::v_array)
+ .def_ro("v_map", &TAllFieldsObj::v_map);
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.AllFields", TAllFieldsObj, Object);
+};
+
+class TAllFields : public ObjectRef {
+ public:
+ explicit TAllFields(bool v_bool, int64_t v_int, double v_float, DLDataType
v_dtype,
+ DLDevice v_device, String v_str, Optional<String>
v_opt_str,
+ Array<Any> v_array, Map<String, Any> v_map) {
+ data_ = make_object<TAllFieldsObj>(v_bool, v_int, v_float, v_dtype,
v_device, v_str, v_opt_str,
+ v_array, v_map);
+ }
+
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TAllFields, ObjectRef,
TAllFieldsObj);
+};
+
+// Test object with fields that have default values to test deserialization
with missing fields
+class TWithDefaultsObj : public Object {
+ public:
+ int64_t required_val;
+ int64_t default_int;
+ String default_str;
+ bool default_bool;
+
+ TWithDefaultsObj(int64_t required_val, int64_t default_int, String
default_str, bool default_bool)
+ : required_val(required_val),
+ default_int(default_int),
+ default_str(default_str),
+ default_bool(default_bool) {}
+ explicit TWithDefaultsObj(UnsafeInit) {}
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TWithDefaultsObj>()
+ .def_ro("required_val", &TWithDefaultsObj::required_val)
+ .def_ro("default_int", &TWithDefaultsObj::default_int,
refl::DefaultValue(42))
+ .def_ro("default_str", &TWithDefaultsObj::default_str,
refl::DefaultValue("default"))
+ .def_ro("default_bool", &TWithDefaultsObj::default_bool,
refl::DefaultValue(true));
+ }
+
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.WithDefaults", TWithDefaultsObj,
Object);
+};
+
+class TWithDefaults : public ObjectRef {
+ public:
+ explicit TWithDefaults(int64_t required_val, int64_t default_int = 42,
+ String default_str = "default", bool default_bool =
true) {
+ data_ = make_object<TWithDefaultsObj>(required_val, default_int,
default_str, default_bool);
+ }
+
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TWithDefaults, ObjectRef,
TWithDefaultsObj);
+};
+
} // namespace testing
template <>
diff --git a/tests/python/test_serialization.py
b/tests/python/test_serialization.py
new file mode 100644
index 00000000..d2357fa4
--- /dev/null
+++ b/tests/python/test_serialization.py
@@ -0,0 +1,515 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for JSON graph serialization/deserialization roundtrips."""
+
+from __future__ import annotations
+
+import json
+from typing import Any, Callable
+
+import pytest
+import tvm_ffi
+import tvm_ffi.testing
+from tvm_ffi.serialization import from_json_graph_str, to_json_graph_str
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+def _roundtrip(obj: Any) -> Any:
+ """Serialize then deserialize and return the result."""
+ return from_json_graph_str(to_json_graph_str(obj))
+
+
+def _assert_roundtrip_eq(obj: Any, cmp: Callable[..., Any] | None = None) ->
None:
+ """Assert that roundtrip preserves the value."""
+ result = _roundtrip(obj)
+ if cmp is not None:
+ cmp(result)
+ elif isinstance(obj, float):
+ assert isinstance(result, float)
+ assert result == pytest.approx(obj)
+ elif obj is None:
+ assert result is None
+ else:
+ _assert_any_equal(result, obj)
+
+
+def _assert_any_equal(a: Any, b: Any) -> None:
+ """Recursively compare two tvm_ffi values for equality."""
+ if isinstance(b, tvm_ffi.Array):
+ assert len(a) == len(b)
+ for x, y in zip(a, b):
+ _assert_any_equal(x, y)
+ elif isinstance(b, tvm_ffi.Map):
+ assert len(a) == len(b)
+ for k in b:
+ _assert_any_equal(a[k], b[k])
+ elif isinstance(b, tvm_ffi.Shape):
+ assert list(a) == list(b)
+ elif isinstance(b, str):
+ # tvm_ffi String inherits from str
+ assert str(a) == str(b)
+ else:
+ assert a == b
+
+
+# ---------------------------------------------------------------------------
+# Primitive types
+# ---------------------------------------------------------------------------
+class TestNone:
+ """Roundtrip tests for None."""
+
+ def test_none(self) -> None:
+ """None roundtrips to None."""
+ assert _roundtrip(None) is None
+
+
+class TestBool:
+ """Roundtrip tests for bool values."""
+
+ def test_true(self) -> None:
+ """True roundtrips to True."""
+ assert _roundtrip(True) is True
+
+ def test_false(self) -> None:
+ """False roundtrips to False."""
+ assert _roundtrip(False) is False
+
+
+class TestInt:
+ """Roundtrip tests for integer values."""
+
+ def test_zero(self) -> None:
+ """Zero roundtrips correctly."""
+ result = _roundtrip(0)
+ assert result == 0
+
+ def test_positive(self) -> None:
+ """Positive int roundtrips correctly."""
+ result = _roundtrip(42)
+ assert result == 42
+
+ def test_negative(self) -> None:
+ """Negative int roundtrips correctly."""
+ result = _roundtrip(-1)
+ assert result == -1
+
+ def test_large_positive(self) -> None:
+ """Large positive int roundtrips correctly."""
+ result = _roundtrip(10**15)
+ assert result == 10**15
+
+ def test_large_negative(self) -> None:
+ """Large negative int roundtrips correctly."""
+ result = _roundtrip(-(10**15))
+ assert result == -(10**15)
+
+
+class TestFloat:
+ """Roundtrip tests for float values."""
+
+ def test_zero(self) -> None:
+ """Float zero roundtrips correctly."""
+ result = _roundtrip(0.0)
+ assert result == 0.0
+
+ def test_positive(self) -> None:
+ """Positive float roundtrips correctly."""
+ result = _roundtrip(3.14159)
+ assert result == pytest.approx(3.14159)
+
+ def test_negative(self) -> None:
+ """Negative float roundtrips correctly."""
+ result = _roundtrip(-2.718)
+ assert result == pytest.approx(-2.718)
+
+ def test_very_small(self) -> None:
+ """Very small float roundtrips correctly."""
+ result = _roundtrip(1e-300)
+ assert result == pytest.approx(1e-300)
+
+ def test_very_large(self) -> None:
+ """Very large float roundtrips correctly."""
+ result = _roundtrip(1e300)
+ assert result == pytest.approx(1e300)
+
+
+# ---------------------------------------------------------------------------
+# String types
+# ---------------------------------------------------------------------------
+class TestString:
+ """Roundtrip tests for ffi.String values."""
+
+ def test_empty(self) -> None:
+ """Empty string roundtrips correctly."""
+ _assert_roundtrip_eq(tvm_ffi.convert(""))
+
+ def test_short(self) -> None:
+ """Short string roundtrips correctly."""
+ _assert_roundtrip_eq(tvm_ffi.convert("hello"))
+
+ def test_long(self) -> None:
+ """Long string roundtrips correctly."""
+ _assert_roundtrip_eq(tvm_ffi.convert("x" * 1000))
+
+ def test_special_chars(self) -> None:
+ """String with special characters roundtrips correctly."""
+ _assert_roundtrip_eq(tvm_ffi.convert('hello\nworld\t"quotes"'))
+
+ def test_unicode(self) -> None:
+ """Unicode string roundtrips correctly."""
+ _assert_roundtrip_eq(tvm_ffi.convert("hello 世界"))
+
+
+# ---------------------------------------------------------------------------
+# DataType
+# ---------------------------------------------------------------------------
+class TestDataType:
+ """Roundtrip tests for DLDataType values."""
+
+ def test_int32(self) -> None:
+ """int32 dtype roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.dtype("int32"))
+ result = from_json_graph_str(s)
+ assert str(result) == "int32"
+
+ def test_float64(self) -> None:
+ """float64 dtype roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.dtype("float64"))
+ result = from_json_graph_str(s)
+ assert str(result) == "float64"
+
+ def test_bool(self) -> None:
+ """Bool dtype roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.dtype("bool"))
+ result = from_json_graph_str(s)
+ assert str(result) == "bool"
+
+ def test_vector(self) -> None:
+ """Vector dtype roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.dtype("float32x4"))
+ result = from_json_graph_str(s)
+ assert str(result) == "float32x4"
+
+
+# ---------------------------------------------------------------------------
+# Device
+# ---------------------------------------------------------------------------
+class TestDevice:
+ """Roundtrip tests for DLDevice values."""
+
+ def test_cpu(self) -> None:
+ """CPU device roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.Device("cpu", 0))
+ result = from_json_graph_str(s)
+ assert result.dlpack_device_type() == tvm_ffi.Device("cpu",
0).dlpack_device_type()
+ assert result.index == 0
+
+ def test_cuda(self) -> None:
+ """CUDA device roundtrips correctly."""
+ s = to_json_graph_str(tvm_ffi.Device("cuda", 1))
+ result = from_json_graph_str(s)
+ assert result.dlpack_device_type() == tvm_ffi.Device("cuda",
1).dlpack_device_type()
+ assert result.index == 1
+
+
+# ---------------------------------------------------------------------------
+# Containers
+# ---------------------------------------------------------------------------
+class TestArray:
+ """Roundtrip tests for ffi.Array containers."""
+
+ def test_empty(self) -> None:
+ """Empty array roundtrips correctly."""
+ arr = tvm_ffi.convert([])
+ _assert_roundtrip_eq(arr)
+
+ def test_single_element(self) -> None:
+ """Single-element array roundtrips correctly."""
+ arr = tvm_ffi.convert([42])
+ result = _roundtrip(arr)
+ assert len(result) == 1
+ assert result[0] == 42
+
+ def test_multiple_elements(self) -> None:
+ """Multi-element array roundtrips correctly."""
+ arr = tvm_ffi.convert([1, 2, 3])
+ result = _roundtrip(arr)
+ assert len(result) == 3
+ assert list(result) == [1, 2, 3]
+
+ def test_mixed_types(self) -> None:
+ """Array with mixed types roundtrips correctly."""
+ arr = tvm_ffi.convert([42, "hello", True, None])
+ result = _roundtrip(arr)
+ assert len(result) == 4
+ assert result[0] == 42
+ assert result[1] == "hello"
+ assert result[2] is True
+ assert result[3] is None
+
+ def test_nested_arrays(self) -> None:
+ """Nested arrays roundtrip correctly."""
+ inner1 = tvm_ffi.convert([1, 2])
+ inner2 = tvm_ffi.convert([3])
+ outer = tvm_ffi.convert([inner1, inner2])
+ result = _roundtrip(outer)
+ assert len(result) == 2
+ assert list(result[0]) == [1, 2]
+ assert list(result[1]) == [3]
+
+ def test_duplicated_elements(self) -> None:
+ """Array with duplicated elements roundtrips correctly."""
+ arr = tvm_ffi.convert([42, 42, 42])
+ result = _roundtrip(arr)
+ assert len(result) == 3
+ assert all(x == 42 for x in result)
+
+
+class TestMap:
+ """Roundtrip tests for ffi.Map containers."""
+
+ def test_empty(self) -> None:
+ """Empty map roundtrips correctly."""
+ m = tvm_ffi.convert({})
+ _assert_roundtrip_eq(m)
+
+ def test_single_entry(self) -> None:
+ """Single-entry map roundtrips correctly."""
+ m = tvm_ffi.convert({"key": 42})
+ result = _roundtrip(m)
+ assert len(result) == 1
+ assert result["key"] == 42
+
+ def test_multiple_entries(self) -> None:
+ """Multi-entry map roundtrips correctly."""
+ m = tvm_ffi.convert({"a": 1, "b": 2, "c": 3})
+ result = _roundtrip(m)
+ assert len(result) == 3
+ assert result["a"] == 1
+ assert result["b"] == 2
+ assert result["c"] == 3
+
+ def test_mixed_value_types(self) -> None:
+ """Map with mixed value types roundtrips correctly."""
+ m = tvm_ffi.convert({"int": 42, "str": "hello", "bool": True, "none":
None})
+ result = _roundtrip(m)
+ assert result["int"] == 42
+ assert result["str"] == "hello"
+ assert result["bool"] is True
+ assert result["none"] is None
+
+ def test_nested_map(self) -> None:
+ """Nested maps roundtrip correctly."""
+ inner = tvm_ffi.convert({"x": 1})
+ outer = tvm_ffi.convert({"inner": inner})
+ result = _roundtrip(outer)
+ assert result["inner"]["x"] == 1
+
+ def test_map_with_array_value(self) -> None:
+ """Map with array values roundtrips correctly."""
+ arr = tvm_ffi.convert([10, 20])
+ m = tvm_ffi.convert({"nums": arr})
+ result = _roundtrip(m)
+ assert list(result["nums"]) == [10, 20]
+
+
+class TestShape:
+ """Roundtrip tests for ffi.Shape containers."""
+
+ def test_empty(self) -> None:
+ """Empty shape roundtrips correctly."""
+ shape = tvm_ffi.Shape(())
+ _assert_roundtrip_eq(shape)
+
+ def test_1d(self) -> None:
+ """1D shape roundtrips correctly."""
+ shape = tvm_ffi.Shape((10,))
+ result = _roundtrip(shape)
+ assert list(result) == [10]
+
+ def test_nd(self) -> None:
+ """N-D shape roundtrips correctly."""
+ shape = tvm_ffi.Shape((1, 2, 3, 4))
+ result = _roundtrip(shape)
+ assert list(result) == [1, 2, 3, 4]
+
+
+# ---------------------------------------------------------------------------
+# Objects with reflection
+# ---------------------------------------------------------------------------
+class TestObjectSerialization:
+ """Roundtrip tests for objects with reflection metadata."""
+
+ def test_int_pair_roundtrip(self) -> None:
+ """TestIntPair has refl::init and POD int64 fields."""
+ pair = tvm_ffi.testing.TestIntPair(3, 7) # ty:
ignore[too-many-positional-arguments]
+ s = to_json_graph_str(pair)
+ result = from_json_graph_str(s)
+ assert result.a == 3
+ assert result.b == 7
+
+ def test_int_pair_zero_values(self) -> None:
+ """TestIntPair with zero values roundtrips correctly."""
+ pair = tvm_ffi.testing.TestIntPair(0, 0) # ty:
ignore[too-many-positional-arguments]
+ result = _roundtrip(pair)
+ assert result.a == 0
+ assert result.b == 0
+
+ def test_int_pair_negative_values(self) -> None:
+ """TestIntPair with negative values roundtrips correctly."""
+ pair = tvm_ffi.testing.TestIntPair(-100, -200) # ty:
ignore[too-many-positional-arguments]
+ result = _roundtrip(pair)
+ assert result.a == -100
+ assert result.b == -200
+
+ def test_int_pair_large_values(self) -> None:
+ """TestIntPair with large values roundtrips correctly."""
+ pair = tvm_ffi.testing.TestIntPair(10**15, -(10**15)) # ty:
ignore[too-many-positional-arguments]
+ result = _roundtrip(pair)
+ assert result.a == 10**15
+ assert result.b == -(10**15)
+
+
+# ---------------------------------------------------------------------------
+# JSON structure verification
+# ---------------------------------------------------------------------------
+class TestJSONStructure:
+ """Tests verifying the internal JSON graph structure."""
+
+ def test_null_json_structure(self) -> None:
+ """None produces a single node with type 'None'."""
+ s = to_json_graph_str(None)
+ parsed = json.loads(s)
+ assert parsed["root_index"] == 0
+ assert len(parsed["nodes"]) == 1
+ assert parsed["nodes"][0]["type"] == "None"
+
+ def test_bool_json_structure(self) -> None:
+ """Bool produces a node with type 'bool'."""
+ s = to_json_graph_str(True)
+ parsed = json.loads(s)
+ assert parsed["nodes"][0]["type"] == "bool"
+ assert parsed["nodes"][0]["data"] is True
+
+ def test_int_json_structure(self) -> None:
+ """Int produces a node with type 'int'."""
+ s = to_json_graph_str(42)
+ parsed = json.loads(s)
+ assert parsed["nodes"][0]["type"] == "int"
+ assert parsed["nodes"][0]["data"] == 42
+
+ def test_float_json_structure(self) -> None:
+ """Float produces a node with type 'float'."""
+ s = to_json_graph_str(3.14)
+ parsed = json.loads(s)
+ assert parsed["nodes"][0]["type"] == "float"
+ assert parsed["nodes"][0]["data"] == pytest.approx(3.14)
+
+ def test_string_json_structure(self) -> None:
+ """String produces a node with type 'ffi.String'."""
+ s = to_json_graph_str(tvm_ffi.convert("hello"))
+ parsed = json.loads(s)
+ assert parsed["nodes"][parsed["root_index"]]["type"] == "ffi.String"
+ assert parsed["nodes"][parsed["root_index"]]["data"] == "hello"
+
+ def test_array_json_structure(self) -> None:
+ """Array data contains node index references."""
+ s = to_json_graph_str(tvm_ffi.convert([1, 2]))
+ parsed = json.loads(s)
+ root = parsed["nodes"][parsed["root_index"]]
+ assert root["type"] == "ffi.Array"
+ # data should be list of node indices
+ assert isinstance(root["data"], list)
+ assert len(root["data"]) == 2
+
+ def test_map_json_structure(self) -> None:
+ """Map data contains flattened key-value node index pairs."""
+ s = to_json_graph_str(tvm_ffi.convert({"a": 1}))
+ parsed = json.loads(s)
+ root = parsed["nodes"][parsed["root_index"]]
+ assert root["type"] == "ffi.Map"
+ assert isinstance(root["data"], list)
+ # key-value pairs flattened: [key_idx, val_idx]
+ assert len(root["data"]) == 2
+
+ def test_node_dedup(self) -> None:
+ """Duplicate values should share the same node index."""
+ s = to_json_graph_str(tvm_ffi.convert([42, 42, 42]))
+ parsed = json.loads(s)
+ root = parsed["nodes"][parsed["root_index"]]
+ # all three elements should reference the same node
+ assert root["data"][0] == root["data"][1] == root["data"][2]
+
+ def test_object_pod_fields_are_inlined(self) -> None:
+ """POD fields (int, bool, float) are inlined directly via
field_static_type_index."""
+ pair = tvm_ffi.testing.TestIntPair(3, 7) # ty:
ignore[too-many-positional-arguments]
+ s = to_json_graph_str(pair)
+ parsed = json.loads(s)
+ root = parsed["nodes"][parsed["root_index"]]
+ assert root["type"] == "testing.TestIntPair"
+ # fields 'a' and 'b' should be inlined as direct int values (not node
indices)
+ assert root["data"]["a"] == 3
+ assert root["data"]["b"] == 7
+
+
+# ---------------------------------------------------------------------------
+# Metadata
+# ---------------------------------------------------------------------------
+class TestMetadata:
+ """Tests for optional metadata in serialized output."""
+
+ def test_with_metadata(self) -> None:
+ """Metadata dict appears in serialized JSON when provided."""
+ s = to_json_graph_str(42, {"version": "1.0"})
+ parsed = json.loads(s)
+ assert "metadata" in parsed
+ assert parsed["metadata"]["version"] == "1.0"
+
+ def test_without_metadata(self) -> None:
+ """Metadata key is absent when not provided."""
+ s = to_json_graph_str(42)
+ parsed = json.loads(s)
+ assert "metadata" not in parsed
+
+
+# ---------------------------------------------------------------------------
+# Error cases
+# ---------------------------------------------------------------------------
+class TestErrors:
+ """Tests for error handling in deserialization."""
+
+ def test_invalid_json_string(self) -> None:
+ """Invalid JSON string raises an error."""
+ with pytest.raises(Exception):
+ from_json_graph_str("not valid json")
+
+ def test_empty_json_string(self) -> None:
+ """Empty string raises an error."""
+ with pytest.raises(Exception):
+ from_json_graph_str("")
+
+ def test_missing_root_index(self) -> None:
+ """JSON without root_index raises an error."""
+ with pytest.raises(Exception):
+ from_json_graph_str('{"nodes": []}')
+
+ def test_missing_nodes(self) -> None:
+ """JSON without nodes raises an error."""
+ with pytest.raises(Exception):
+ from_json_graph_str('{"root_index": 0}')