This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9c27523503 [FFI][EXTRA] Serialization To/From JSONGraph (#18187)
9c27523503 is described below
commit 9c275235031ff0d9eca011f047d3797866c6780e
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Aug 5 10:02:58 2025 -0400
[FFI][EXTRA] Serialization To/From JSONGraph (#18187)
---
ffi/CMakeLists.txt | 1 +
ffi/include/tvm/ffi/extra/base64.h | 142 ++++++++
ffi/include/tvm/ffi/extra/json.h | 2 +-
ffi/include/tvm/ffi/extra/serialization.h | 72 ++++
ffi/src/ffi/extra/serialization.cc | 408 ++++++++++++++++++++++
ffi/tests/cpp/extra/test_serialization.cc | 354 +++++++++++++++++++
ffi/tests/cpp/extra/test_structural_equal_hash.cc | 6 +-
ffi/tests/cpp/testing_object.h | 20 +-
8 files changed, 998 insertions(+), 7 deletions(-)
diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt
index b67611a273..55fbd1c1bc 100644
--- a/ffi/CMakeLists.txt
+++ b/ffi/CMakeLists.txt
@@ -68,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc"
)
endif()
diff --git a/ffi/include/tvm/ffi/extra/base64.h b/ffi/include/tvm/ffi/extra/base64.h
new file mode 100644
index 0000000000..136fec2e7f
--- /dev/null
+++ b/ffi/include/tvm/ffi/extra/base64.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ *
+ * \file tvm/ffi/extra/base64.h
+ * \brief Base64 encoding and decoding utilities
+ */
+#ifndef TVM_FFI_EXTRA_BASE64_H_
+#define TVM_FFI_EXTRA_BASE64_H_
+
+#include <tvm/ffi/string.h>
+
+#include <string>
+
+namespace tvm {
+namespace ffi {
+/*!
+ * \brief Encode a byte array into a padded base64 string (standard alphabet).
+ * \param bytes The byte array to encode; may be empty.
+ * \return The base64 encoded string; its length is always a multiple of 4.
+ */
+inline String Base64Encode(TVMFFIByteArray bytes) {
+  constexpr const char kEncodeTable[] =
+      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+  // Read the input as unsigned bytes: plain `char` may be signed, and a negative
+  // value would make the shifts below index outside kEncodeTable.
+  const uint8_t* data = reinterpret_cast<const uint8_t*>(bytes.data);
+  const size_t tail = bytes.size % 3;
+  const size_t full_end = bytes.size - tail;
+  std::string encoded;
+  // every group of up to 3 input bytes produces exactly 4 output characters
+  encoded.reserve(4 * ((bytes.size + 2) / 3));
+
+  for (size_t i = 0; i < full_end; i += 3) {
+    // pack 3 bytes into 24 bits, then emit four 6-bit sextets
+    uint32_t group = (static_cast<uint32_t>(data[i]) << 16) |
+                     (static_cast<uint32_t>(data[i + 1]) << 8) |
+                     static_cast<uint32_t>(data[i + 2]);
+    encoded.push_back(kEncodeTable[(group >> 18) & 0x3F]);
+    encoded.push_back(kEncodeTable[(group >> 12) & 0x3F]);
+    encoded.push_back(kEncodeTable[(group >> 6) & 0x3F]);
+    encoded.push_back(kEncodeTable[group & 0x3F]);
+  }
+  if (tail == 1) {
+    // one leftover byte -> two sextets plus two padding characters
+    uint32_t group = static_cast<uint32_t>(data[full_end]) << 16;
+    encoded.push_back(kEncodeTable[(group >> 18) & 0x3F]);
+    encoded.push_back(kEncodeTable[(group >> 12) & 0x3F]);
+    encoded.push_back('=');
+    encoded.push_back('=');
+  } else if (tail == 2) {
+    // two leftover bytes -> three sextets plus one padding character
+    uint32_t group = (static_cast<uint32_t>(data[full_end]) << 16) |
+                     (static_cast<uint32_t>(data[full_end + 1]) << 8);
+    encoded.push_back(kEncodeTable[(group >> 18) & 0x3F]);
+    encoded.push_back(kEncodeTable[(group >> 12) & 0x3F]);
+    encoded.push_back(kEncodeTable[(group >> 6) & 0x3F]);
+    encoded.push_back('=');
+  }
+  return String(encoded);
+}
+
+/*!
+ * \brief Encode the contents of a Bytes object as base64 text.
+ * \param data The bytes to encode.
+ * \return The padded base64 representation of \p data.
+ */
+inline String Base64Encode(const Bytes& data) {
+  // view the Bytes storage without copying and defer to the array overload
+  TVMFFIByteArray view{data.data(), data.size()};
+  return Base64Encode(view);
+}
+
+/*!
+ * \brief Decode a padded base64 string (standard alphabet) into a byte array
+ * \param bytes The base64 encoded characters to decode
+ * \return The decoded byte array
+ * \note The input length must be a multiple of 4, every character must belong
+ *       to the base64 alphabet, and '=' may only appear as trailing padding in
+ *       the last group; otherwise an "invalid base64 encoding" check fails.
+ */
+inline Bytes Base64Decode(TVMFFIByteArray bytes) {
+  // Map one base64 character to its 6-bit value, or -1 when the character is
+  // not part of the alphabet (this includes '=' and any negative char value).
+  auto decode_char = [](char c) -> int32_t {
+    if (c >= 'A' && c <= 'Z') return c - 'A';
+    if (c >= 'a' && c <= 'z') return c - 'a' + 26;
+    if (c >= '0' && c <= '9') return c - '0' + 52;
+    if (c == '+') return 62;
+    if (c == '/') return 63;
+    return -1;
+  };
+  if (bytes.size == 0) return Bytes();
+  // base64 is always multiple of 4 characters
+  TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding";
+  std::string decoded;
+  decoded.reserve(bytes.size / 4 * 3);
+  for (size_t i = 0; i < bytes.size; i += 4) {
+    // only the final group may end with one or two '=' padding characters
+    const bool is_last_group = (i + 4 == bytes.size);
+    int num_pad = 0;
+    if (is_last_group && bytes.data[i + 3] == '=') {
+      num_pad = (bytes.data[i + 2] == '=') ? 2 : 1;
+    }
+    // decode every 4 characters into 24 bits, each character contains 6 bits;
+    // any remaining '=' is rejected by decode_char
+    int32_t value_i24 = 0;
+    for (int k = 0; k < 4 - num_pad; ++k) {
+      int32_t sextet = decode_char(bytes.data[i + k]);
+      TVM_FFI_ICHECK(sextet >= 0) << "invalid base64 encoding";
+      value_i24 |= sextet << (18 - 6 * k);
+    }
+    // unpack 24 bits into up to 3 bytes, dropping the padded positions
+    decoded.push_back(static_cast<char>((value_i24 >> 16) & 0xFF));
+    if (num_pad < 2) {
+      decoded.push_back(static_cast<char>((value_i24 >> 8) & 0xFF));
+    }
+    if (num_pad < 1) {
+      decoded.push_back(static_cast<char>(value_i24 & 0xFF));
+    }
+  }
+  return Bytes(decoded);
+}
+
+/*!
+ * \brief Decode base64 text held in a String.
+ * \param data The base64 text to decode.
+ * \return The decoded bytes.
+ */
+inline Bytes Base64Decode(const String& data) {
+  // view the String storage without copying and defer to the array overload
+  TVMFFIByteArray view{data.data(), data.size()};
+  return Base64Decode(view);
+}
+
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_EXTRA_BASE64_H_
diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h
index 847e60c0f6..409f7aa525 100644
--- a/ffi/include/tvm/ffi/extra/json.h
+++ b/ffi/include/tvm/ffi/extra/json.h
@@ -17,7 +17,7 @@
* under the License.
*/
/*!
- * \file tvm/ffi/json/json.h
+ * \file tvm/ffi/extra/json.h
* \brief Minimal lightweight JSON parsing and serialization utilities
*/
#ifndef TVM_FFI_EXTRA_JSON_H_
diff --git a/ffi/include/tvm/ffi/extra/serialization.h b/ffi/include/tvm/ffi/extra/serialization.h
new file mode 100644
index 0000000000..c08ad81cc3
--- /dev/null
+++ b/ffi/include/tvm/ffi/extra/serialization.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/serialization.h
+ * \brief Reflection-based serialization utilities
+ */
+#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_
+#define TVM_FFI_EXTRA_SERIALIZATION_H_
+
+#include <tvm/ffi/extra/base.h>
+#include <tvm/ffi/extra/json.h>
+
+namespace tvm {
+namespace ffi {
+
+/**
+ * \brief Serialize ffi::Any to a JSON that stores the object graph.
+ *
+ * The JSON graph structure is stored as follows:
+ *
+ * ```json
+ * {
+ *   "root_index": <int>,       // Index of root node in nodes array
+ *   "nodes": [<node>, ...],    // Array of serialized nodes
+ *   "metadata": <any>          // Optional, present only when metadata is not None
+ * }
+ * ```
+ *
+ * Each node has the format: `{"type": "<type_key>", "data": <type_data>}`;
+ * None nodes carry only the "type" field.
+ * - bool, int, float, DataType (as a string) and Device (as
+ *   `[device_type, device_id]`) store their value inline.
+ * - Strings store their text inline; bytes store base64-encoded text.
+ * - Arrays store a list of node indices; maps store alternating key/value
+ *   node indices.
+ * - Object types store a JSON object keyed by field name, unless the type
+ *   registers a custom `__data_to_json__` type attribute. A field whose static
+ *   type is known as a primitive type is stored directly, otherwise it is
+ *   stored as a reference to the nodes array by an index.
+ *
+ * This function preserves the type and multiple references to the same object,
+ * which is useful for debugging and serialization.
+ *
+ * \param value The ffi::Any value to serialize.
+ * \param metadata Extra metadata attached to "metadata" field of the JSON object.
+ * \return The serialized JSON value.
+ */
+TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr));
+
+/**
+ * \brief Deserialize a JSON that stores the object graph to an ffi::Any value.
+ *
+ * This is the inverse of ToJSONGraph and can be used to implement
+ * deserialization and debugging. A node index that is referenced several
+ * times in the graph decodes to the same value.
+ *
+ * \param value The JSON value to deserialize.
+ * \return The deserialized object graph.
+ * \throws ValueError if \p value is not a JSON object with an integer
+ *         `root_index` field and an array `nodes` field.
+ */
+TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value);
+
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_EXTRA_SERIALIZATION_H_
diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc
new file mode 100644
index 0000000000..b3230f38fb
--- /dev/null
+++ b/ffi/src/ffi/extra/serialization.cc
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/extra/serialization.cc
+ *
+ * \brief Reflection-based serialization utilities.
+ */
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/base64.h>
+#include <tvm/ffi/extra/serialization.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Flattens an ffi::Any value into the node-graph JSON layout returned by
+ *  ToJSONGraph.
+ *
+ * Every distinct value becomes one entry of `nodes`. Values are de-duplicated
+ * through a Map<Any, int64_t>, so a value that is found again in that map is
+ * emitted once and referenced by its existing index.
+ */
+class ObjectGraphSerializer {
+ public:
+  /*!
+   * \brief Serialize a value into a JSON object graph.
+   * \param value The root value.
+   * \param metadata Stored under "metadata" unless it is None.
+   * \return A JSON object with `root_index`, `nodes` and optionally `metadata`.
+   */
+  static json::Value Serialize(const Any& value, Any metadata) {
+    ObjectGraphSerializer serializer;
+    json::Object result;
+    result.Set("root_index", serializer.GetOrCreateNodeIndex(value));
+    result.Set("nodes", std::move(serializer.nodes_));
+    if (metadata != nullptr) {
+      result.Set("metadata", metadata);
+    }
+    return result;
+  }
+
+ private:
+  ObjectGraphSerializer() = default;
+
+  /*!
+   * \brief Return the index of the node for \p value, serializing the value
+   *  (and, recursively, everything it references) on first encounter.
+   * \throws RuntimeError for non-object types that have no JSON encoding.
+   */
+  int64_t GetOrCreateNodeIndex(const Any& value) {
+    // already mapped value, return the index
+    auto it = node_index_map_.find(value);
+    if (it != node_index_map_.end()) {
+      return (*it).second;
+    }
+    json::Object node;
+    switch (value.type_index()) {
+      case TypeIndex::kTVMFFINone: {
+        // None is fully described by its type tag, no "data" entry
+        node.Set("type", ffi::StaticTypeKey::kTVMFFINone);
+        break;
+      }
+      case TypeIndex::kTVMFFIBool: {
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIBool);
+        node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck<bool>(value));
+        break;
+      }
+      case TypeIndex::kTVMFFIInt: {
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIInt);
+        node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck<int64_t>(value));
+        break;
+      }
+      case TypeIndex::kTVMFFIFloat: {
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIFloat);
+        node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck<double>(value));
+        break;
+      }
+      case TypeIndex::kTVMFFIDataType: {
+        DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck<DLDataType>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIDataType);
+        node.Set("data", DLDataTypeToString(dtype));
+        break;
+      }
+      case TypeIndex::kTVMFFIDevice: {
+        // stored inline as [device_type, device_id]
+        DLDevice device = details::AnyUnsafe::CopyFromAnyViewAfterCheck<DLDevice>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIDevice);
+        node.Set("data", json::Array{
+                             static_cast<int64_t>(device.device_type),
+                             static_cast<int64_t>(device.device_id),
+                         });
+        break;
+      }
+      case TypeIndex::kTVMFFISmallStr:
+      case TypeIndex::kTVMFFIStr: {
+        // both string representations share the kTVMFFIStr type key
+        String str = details::AnyUnsafe::CopyFromAnyViewAfterCheck<String>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIStr);
+        node.Set("data", str);
+        break;
+      }
+      case TypeIndex::kTVMFFISmallBytes:
+      case TypeIndex::kTVMFFIBytes: {
+        // bytes are stored as base64 text under the kTVMFFIBytes type key
+        Bytes bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck<Bytes>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIBytes);
+        node.Set("data", Base64Encode(bytes));
+        break;
+      }
+      case TypeIndex::kTVMFFIArray: {
+        Array<Any> array = details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<Any>>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIArray);
+        node.Set("data", CreateArrayData(array));
+        break;
+      }
+      case TypeIndex::kTVMFFIMap: {
+        Map<Any, Any> map = details::AnyUnsafe::CopyFromAnyViewAfterCheck<Map<Any, Any>>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIMap);
+        node.Set("data", CreateMapData(map));
+        break;
+      }
+      default: {
+        if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) {
+          // serialize type key since type index is runtime dependent
+          node.Set("type", value.GetTypeKey());
+          node.Set("data", CreateObjectData(value));
+        } else {
+          TVM_FFI_THROW(RuntimeError) << "Cannot serialize type `" << value.GetTypeKey() << "`";
+          TVM_FFI_UNREACHABLE();
+        }
+      }
+    }
+    // referenced children were appended while building `node`, so this node
+    // takes the next free slot after them
+    int64_t node_index = nodes_.size();
+    nodes_.push_back(node);
+    node_index_map_.Set(value, node_index);
+    return node_index;
+  }
+
+  /*! \brief Serialize array elements as a list of node indices. */
+  json::Array CreateArrayData(const Array<Any>& value) {
+    json::Array data;
+    data.reserve(value.size());
+    for (const Any& item : value) {
+      data.push_back(GetOrCreateNodeIndex(item));
+    }
+    return data;
+  }
+
+  /*! \brief Serialize map entries as interleaved [key0, value0, key1, value1, ...] node indices. */
+  json::Array CreateMapData(const Map<Any, Any>& value) {
+    json::Array data;
+    data.reserve(value.size() * 2);
+    for (const auto& [map_key, map_value] : value) {
+      data.push_back(GetOrCreateNodeIndex(map_key));
+      data.push_back(GetOrCreateNodeIndex(map_value));
+    }
+    return data;
+  }
+
+  /*!
+   * \brief Create the "data" entry of an object node.
+   *
+   * If the type registers a `__data_to_json__` type attribute, that function
+   * produces the data. Otherwise every reflected field is written under its
+   * field name: fields whose static type is None/bool/int/float/DataType are
+   * stored inline, all other fields are stored as node indices.
+   * \throws TypeError when the type has no reflection metadata.
+   */
+  json::Object CreateObjectData(const Any& value) {
+    static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__");
+    if (data_to_json[value.type_index()] != nullptr) {
+      return data_to_json[value.type_index()].cast<Function>()(value).cast<json::Object>();
+    }
+    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index());
+    // field-by-field serialization relies on reflection metadata
+    if (type_info->metadata == nullptr) {
+      TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `"
+                               << String(type_info->type_key)
+                               << "`, so ToJSONGraph is not supported for this type";
+    }
+    const Object* obj = value.cast<const Object*>();
+    json::Object data;
+    // serialize each reflected field of the object
+    reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
+      // read the field value through its reflection getter
+      reflection::FieldGetter getter(field_info);
+      Any field_value = getter(obj);
+      int field_static_type_index = field_info->field_static_type_index;
+      String field_name(field_info->name);
+      // for static field index that are known, we can directly set the field value.
+      switch (field_static_type_index) {
+        case TypeIndex::kTVMFFINone: {
+          data.Set(field_name, nullptr);
+          break;
+        }
+        case TypeIndex::kTVMFFIBool: {
+          data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck<bool>(field_value));
+          break;
+        }
+        case TypeIndex::kTVMFFIInt: {
+          data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck<int64_t>(field_value));
+          break;
+        }
+        case TypeIndex::kTVMFFIFloat: {
+          data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck<double>(field_value));
+          break;
+        }
+        case TypeIndex::kTVMFFIDataType: {
+          DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck<DLDataType>(field_value);
+          data.Set(field_name, DLDataTypeToString(dtype));
+          break;
+        }
+        default: {
+          // for dynamic field index, we need to put them onto nodes
+          int64_t node_index = GetOrCreateNodeIndex(field_value);
+          data.Set(field_name, node_index);
+          break;
+        }
+      }
+    });
+    return data;
+  }
+
+  // maps the original value to the index of the node in the nodes_ array
+  Map<Any, int64_t> node_index_map_;
+  // records nodes that are serialized
+  json::Array nodes_;
+};
+
+json::Value ToJSONGraph(const Any& value, const Any& metadata) {
+  // all serialization state lives in a short-lived ObjectGraphSerializer
+  json::Value graph = ObjectGraphSerializer::Serialize(value, metadata);
+  return graph;
+}
+
+/*!
+ * \brief Rebuilds an ffi::Any value from the node-graph JSON layout produced by
+ *  ToJSONGraph.
+ *
+ * Each node is decoded at most once, so an index referenced several times
+ * decodes to the same value. All indices read from the input are validated,
+ * and a node that (transitively) references itself is rejected.
+ */
+class ObjectGraphDeserializer {
+ public:
+  /*!
+   * \brief Decode a JSON object graph.
+   * \param value JSON object with `root_index` and `nodes` fields.
+   * \return The decoded root value.
+   * \throws ValueError for a malformed graph.
+   */
+  static Any Deserialize(const json::Value& value) {
+    ObjectGraphDeserializer deserializer(value);
+    return deserializer.GetOrDecodeNode(deserializer.root_index_);
+  }
+
+  /*!
+   * \brief Decode the node at \p node_index, memoizing the result.
+   * \param node_index Index into the `nodes` array, taken from the input.
+   * \return The decoded value.
+   * \throws ValueError if the index is out of range or a cycle is found.
+   */
+  Any GetOrDecodeNode(int64_t node_index) {
+    // indices come straight from the JSON input, so validate before use
+    if (node_index < 0 || node_index >= static_cast<int64_t>(nodes_.size())) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, node index " << node_index
+                                << " is out of range [0, " << nodes_.size() << ")";
+    }
+    const size_t slot = static_cast<size_t>(node_index);
+    // already decoded
+    if (decode_state_[slot] == DecodeState::kDecoded) {
+      return decoded_nodes_[slot];
+    }
+    // a node that (transitively) references itself can never finish decoding
+    if (decode_state_[slot] == DecodeState::kInProgress) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, cycle detected at node "
+                                << node_index;
+    }
+    // now decode the node
+    decode_state_[slot] = DecodeState::kInProgress;
+    Any value = DecodeNode(nodes_[node_index].cast<json::Object>());
+    decoded_nodes_[slot] = value;
+    decode_state_[slot] = DecodeState::kDecoded;
+    return value;
+  }
+
+ private:
+  /*! \brief Decoding progress of a node, used for memoization and cycle detection. */
+  enum class DecodeState : uint8_t { kNotStarted, kInProgress, kDecoded };
+
+  /*! \brief Decode one `{"type": ..., "data": ...}` node by its type key. */
+  Any DecodeNode(const json::Object& node) {
+    String type_key = node["type"].cast<String>();
+    TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()};
+    int32_t type_index;
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index));
+
+    switch (type_index) {
+      case TypeIndex::kTVMFFINone: {
+        return nullptr;
+      }
+      case TypeIndex::kTVMFFIBool: {
+        return node["data"].cast<bool>();
+      }
+      case TypeIndex::kTVMFFIInt: {
+        return node["data"].cast<int64_t>();
+      }
+      case TypeIndex::kTVMFFIFloat: {
+        return node["data"].cast<double>();
+      }
+      case TypeIndex::kTVMFFIDataType: {
+        return StringToDLDataType(node["data"].cast<String>());
+      }
+      case TypeIndex::kTVMFFIDevice: {
+        // stored as [device_type, device_id]
+        Array<int32_t> data = node["data"].cast<Array<int32_t>>();
+        if (data.size() != 2) {
+          TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected Device data "
+                                    << "to be [device_type, device_id]";
+        }
+        return DLDevice{static_cast<DLDeviceType>(data[0]), data[1]};
+      }
+      case TypeIndex::kTVMFFIStr: {
+        return node["data"].cast<String>();
+      }
+      case TypeIndex::kTVMFFIBytes: {
+        return Base64Decode(node["data"].cast<String>());
+      }
+      case TypeIndex::kTVMFFIMap: {
+        return DecodeMapData(node["data"].cast<json::Array>());
+      }
+      case TypeIndex::kTVMFFIArray: {
+        return DecodeArrayData(node["data"].cast<json::Array>());
+      }
+      default: {
+        return DecodeObjectData(type_index, node["data"]);
+      }
+    }
+  }
+
+  /*! \brief Decode a list of node indices into an Array. */
+  Array<Any> DecodeArrayData(const json::Array& data) {
+    Array<Any> array;
+    array.reserve(data.size());
+    for (size_t i = 0; i < data.size(); i++) {
+      array.push_back(GetOrDecodeNode(data[i].cast<int64_t>()));
+    }
+    return array;
+  }
+
+  /*! \brief Decode interleaved [key0, value0, key1, value1, ...] node indices into a Map. */
+  Map<Any, Any> DecodeMapData(const json::Array& data) {
+    if (data.size() % 2 != 0) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, map data must contain "
+                                << "alternating key/value indices";
+    }
+    Map<Any, Any> map;
+    for (size_t i = 0; i < data.size(); i += 2) {
+      int64_t key_index = data[i].cast<int64_t>();
+      int64_t value_index = data[i + 1].cast<int64_t>();
+      map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index));
+    }
+    return map;
+  }
+
+  /*!
+   * \brief Construct an object of \p type_index from its "data" entry.
+   *
+   * Uses the type's `__data_from_json__` type attribute when registered;
+   * otherwise default-constructs the object through its reflection creator and
+   * assigns every reflected field from \p data, falling back to the field's
+   * default value when the field is absent.
+   */
+  Any DecodeObjectData(int32_t type_index, const json::Value& data) {
+    static reflection::TypeAttrColumn data_from_json =
+        reflection::TypeAttrColumn("__data_from_json__");
+    if (data_from_json[type_index] != nullptr) {
+      return data_from_json[type_index].cast<Function>()(data);
+    }
+    // otherwise, we go over the fields and create the data.
+    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
+    if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+                                  << "` does not support default constructor"
+                                  << ", so FromJSONGraph is not supported for this type";
+    }
+    TVMFFIObjectHandle handle;
+    TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
+    ObjectPtr<Object> ptr =
+        details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+
+    // mirror of the serializer: primitive static field types are stored
+    // inline, every other field is a node index
+    auto decode_field_value = [&](const TVMFFIFieldInfo* field_info,
+                                  const json::Value& field_data) -> Any {
+      switch (field_info->field_static_type_index) {
+        case TypeIndex::kTVMFFINone: {
+          return nullptr;
+        }
+        case TypeIndex::kTVMFFIBool: {
+          return field_data.cast<bool>();
+        }
+        case TypeIndex::kTVMFFIInt: {
+          return field_data.cast<int64_t>();
+        }
+        case TypeIndex::kTVMFFIFloat: {
+          return field_data.cast<double>();
+        }
+        case TypeIndex::kTVMFFIDataType: {
+          return StringToDLDataType(field_data.cast<String>());
+        }
+        default: {
+          return GetOrDecodeNode(field_data.cast<int64_t>());
+        }
+      }
+    };
+
+    json::Object data_object = data.cast<json::Object>();
+    reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
+      String field_name(field_info->name);
+      void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
+      if (data_object.count(field_name) != 0) {
+        Any field_value = decode_field_value(field_info, data_object[field_name]);
+        field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
+      } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
+        field_info->setter(field_addr, &(field_info->default_value));
+      } else {
+        TVM_FFI_THROW(TypeError) << "Required field `" << field_name << "` not set in type `"
+                                 << TypeIndexToTypeKey(type_index) << "`";
+      }
+    });
+    return ObjectRef(ptr);
+  }
+
+  /*! \brief Validate the top-level layout and prepare per-node bookkeeping. */
+  explicit ObjectGraphDeserializer(json::Value serialized) {
+    if (!serialized.as<json::Object>()) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected an object";
+    }
+    json::Object encoded_object = serialized.cast<json::Object>();
+    if (encoded_object.count("root_index") == 0 || !encoded_object["root_index"].as<int64_t>()) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `root_index` integer field";
+    }
+    if (encoded_object.count("nodes") == 0 || !encoded_object["nodes"].as<json::Array>()) {
+      TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `nodes` array field";
+    }
+    root_index_ = encoded_object["root_index"].cast<int64_t>();
+    nodes_ = encoded_object["nodes"].cast<json::Array>();
+    decoded_nodes_.resize(nodes_.size(), Any(nullptr));
+    decode_state_.resize(nodes_.size(), DecodeState::kNotStarted);
+  }
+  // nodes of the encoded graph
+  json::Array nodes_;
+  // index of the root node within nodes_ (range-checked on use)
+  int64_t root_index_{0};
+  // per-node decoding progress
+  std::vector<DecodeState> decode_state_;
+  // decoded value of each node, valid once its state is kDecoded
+  std::vector<Any> decoded_nodes_;
+};
+
+Any FromJSONGraph(const json::Value& value) {
+  // all decoding state lives in a short-lived ObjectGraphDeserializer
+  Any decoded = ObjectGraphDeserializer::Deserialize(value);
+  return decoded;
+}
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  // expose the graph serializer and deserializer as global functions
+  refl::GlobalDef()
+      .def("ffi.ToJSONGraph", ToJSONGraph)
+      .def("ffi.FromJSONGraph", FromJSONGraph);
+  // type-attribute columns looked up for per-type custom JSON hooks
+  refl::EnsureTypeAttrColumn("__data_to_json__");
+  refl::EnsureTypeAttrColumn("__data_from_json__");
+});
+
+} // namespace ffi
+} // namespace tvm
diff --git a/ffi/tests/cpp/extra/test_serialization.cc b/ffi/tests/cpp/extra/test_serialization.cc
new file mode 100644
index 0000000000..f0aefa3709
--- /dev/null
+++ b/ffi/tests/cpp/extra/test_serialization.cc
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/extra/serialization.h>
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/string.h>
+
+#include "../testing_object.h"
+
+namespace {
+
+using namespace tvm::ffi;
+using namespace tvm::ffi::testing;
+
+TEST(Serialization, BoolNull) {
+  // wrap a single node description into a graph whose root is that node
+  auto single_node_graph = [](json::Object node) {
+    return json::Object{{"root_index", 0}, {"nodes", json::Array{node}}};
+  };
+
+  // None carries only a type tag
+  json::Object none_graph = single_node_graph(json::Object{{"type", "None"}});
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(nullptr), none_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(none_graph), nullptr));
+
+  // both booleans round-trip through an inline "bool" node
+  for (bool flag : {true, false}) {
+    json::Object bool_graph = single_node_graph(json::Object{{"type", "bool"}, {"data", flag}});
+    EXPECT_TRUE(StructuralEqual()(ToJSONGraph(flag), bool_graph));
+    EXPECT_TRUE(StructuralEqual()(FromJSONGraph(bool_graph), flag));
+  }
+}
+
+TEST(Serialization, IntegerTypes) {
+  // a positive int64 value is stored inline in an "int" node
+  const int64_t answer = 42;
+  auto int_graph = json::Object{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "int"}, {"data", answer}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(answer), int_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(int_graph), answer));
+}
+
+TEST(Serialization, FloatTypes) {
+  // a positive double is stored inline in a "float" node
+  const double pi_approx = 3.14159;
+  auto float_graph = json::Object{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "float"}, {"data", pi_approx}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(pi_approx), float_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(float_graph), pi_approx));
+}
+
+TEST(Serialization, StringTypes) {
+  // short text, a 1000-character string and text with characters that need
+  // JSON escaping are all stored inline in an "ffi.String" node
+  const std::string long_text(1000, 'x');
+  const String samples[] = {String("hello"), String(long_text),
+                            String("hello\nworld\t\"quotes\"")};
+  for (const String& text : samples) {
+    auto expected = json::Object{
+        {"root_index", 0},
+        {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", text}}}}};
+    EXPECT_TRUE(StructuralEqual()(ToJSONGraph(text), expected));
+    EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), text));
+  }
+}
+
+TEST(Serialization, Bytes) {
+  // bytes are stored as base64 text in an "ffi.Bytes" node; check both the
+  // encoding direction and the decoding direction
+  auto check_round_trip = [](const Bytes& raw, const String& base64_text) {
+    auto expected = json::Object{
+        {"root_index", 0},
+        {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", base64_text}}}}};
+    EXPECT_TRUE(StructuralEqual()(ToJSONGraph(raw), expected));
+    EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), raw));
+  };
+
+  // empty input encodes to empty text
+  check_round_trip(Bytes(), "");
+  // printable input whose length needs two padding characters
+  check_round_trip(Bytes("abcd"), "YWJjZA==");
+  // non-printable control characters
+  const char control_chars[] = {0x01, 0x02, 0x03, 0x04, 0x01, 0x0b};
+  check_round_trip(Bytes(control_chars, sizeof(control_chars)), "AQIDBAEL");
+}
+
+TEST(Serialization, DataTypes) {
+  // dtypes are stored inline as their string spelling in a "DataType" node
+  auto check_dtype = [](uint8_t code, uint8_t bits, uint16_t lanes, const String& text) {
+    DLDataType dtype;
+    dtype.code = code;
+    dtype.bits = bits;
+    dtype.lanes = lanes;
+    auto expected = json::Object{
+        {"root_index", 0},
+        {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", text}}}}};
+    EXPECT_TRUE(StructuralEqual()(ToJSONGraph(dtype), expected));
+    EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), dtype));
+  };
+
+  check_dtype(kDLInt, 32, 1, String("int32"));
+  check_dtype(kDLFloat, 64, 1, String("float64"));
+  // vector dtype with 4 lanes
+  check_dtype(kDLFloat, 32, 4, String("float32x4"));
+}
+
+TEST(Serialization, DeviceTypes) {
+  // devices are stored inline as [device_type, device_id] in a "Device" node
+  auto check_device = [](DLDeviceType device_type, int32_t device_id) {
+    DLDevice device{device_type, device_id};
+    auto expected = json::Object{
+        {"root_index", 0},
+        {"nodes", json::Array{json::Object{
+                      {"type", "Device"},
+                      {"data", json::Array{static_cast<int64_t>(device_type),
+                                           static_cast<int64_t>(device_id)}}}}}};
+    EXPECT_TRUE(StructuralEqual()(ToJSONGraph(device), expected));
+    EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), device));
+  };
+
+  check_device(kDLCPU, 0);
+  check_device(kDLCUDA, 1);
+}
+
+TEST(Serialization, Arrays) {
+  // An empty array is a single "ffi.Array" node with empty data.
+  Array<Any> empty_arr;
+  json::Object empty_graph = json::Object{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_arr), empty_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(empty_graph), empty_arr));
+
+  // One element: the element is node 0 and the array (node 1, the root) lists
+  // its elements by node index.
+  Array<Any> one_elem{Any(42)};
+  json::Object one_elem_graph = json::Object{
+      {"root_index", 1},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "int"}, {"data", static_cast<int64_t>(42)}},
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{0}}},
+       }}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(one_elem), one_elem_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(one_elem_graph), one_elem));
+
+  // A repeated value is emitted once, so both entries point at node 0.
+  Array<Any> repeated{42, 42};
+  json::Object repeated_graph = json::Object{
+      {"root_index", 1},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "int"}, {"data", 42}},
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 0}}},
+       }}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(repeated), repeated_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(repeated_graph), repeated));
+
+  // Mixed element kinds: 42 and "hello" each appear twice but are emitted only
+  // once, as nodes 0 and 1, and both occurrences reuse those indices.
+  Array<Any> mixed{42, String("hello"), true, nullptr, 42, String("hello")};
+  json::Object mixed_graph = json::Object{
+      {"root_index", 4},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "int"}, {"data", 42}},
+           json::Object{{"type", "ffi.String"}, {"data", String("hello")}},
+           json::Object{{"type", "bool"}, {"data", true}},
+           json::Object{{"type", "None"}},
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 1, 2, 3, 0, 1}}},
+       }}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(mixed), mixed_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(mixed_graph), mixed));
+}
+
+TEST(Serialization, Maps) {
+  // An empty map is a single "ffi.Map" node with empty data.
+  Map<String, Any> empty_map;
+  json::Object empty_graph = json::Object{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "ffi.Map"}, {"data", json::Array{}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_map), empty_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(empty_graph), empty_map));
+
+  // Map data is a flat [key0, value0, key1, value1, ...] list of node indices.
+  Map<String, Any> one_entry{{"key", 42}};
+  json::Object one_entry_graph = json::Object{
+      {"root_index", 2},
+      {"nodes",
+       json::Array{json::Object{{"type", "ffi.String"}, {"data", String("key")}},
+                   json::Object{{"type", "int"}, {"data", 42}},
+                   json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(one_entry), one_entry_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(one_entry_graph), one_entry));
+
+  // The shared value 42 is emitted once (node 1) and referenced by both keys.
+  // NOTE: the expected node order assumes the map iterates "b" before "a",
+  // i.e. in insertion order.
+  Map<String, Any> shared_value{{"b", 42}, {"a", 42}};
+  json::Object shared_value_graph = json::Object{
+      {"root_index", 3},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "ffi.String"}, {"data", "b"}},
+           json::Object{{"type", "int"}, {"data", 42}},
+           json::Object{{"type", "ffi.String"}, {"data", "a"}},
+           json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1, 2, 1}}},
+       }}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(shared_value), shared_value_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(shared_value_graph), shared_value));
+}
+
+TEST(Serialization, TestObjectVar) {
+  // A reflected object's data is a JSON object mapping each field to a node
+  // index; here the "name" field refers to the string node 0.
+  TVar var_x("x");
+  json::Object var_graph = json::Object{
+      {"root_index", 1},
+      {"nodes",
+       json::Array{json::Object{{"type", "ffi.String"}, {"data", "x"}},
+                   json::Object{{"type", "test.Var"},
+                                {"data", json::Object{{"name", 0}}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(var_x), var_graph));
+  // Decoding builds a fresh TVar rather than returning var_x, so the two only
+  // compare equal when free variables may be mapped onto each other.
+  EXPECT_TRUE(StructuralEqual::Equal(FromJSONGraph(var_graph), var_x, /*map_free_vars=*/true));
+}
+
+TEST(Serialization, TestObjectIntCustomToJSON) {
+  // TIntObj registers a custom __data_to_json__, so its data is the literal
+  // {"value": 42} object rather than node indices, and the graph has one node.
+  TInt forty_two(42);
+  json::Object int_graph = json::Object{
+      {"root_index", 0},
+      {"nodes",
+       json::Array{json::Object{{"type", "test.Int"},
+                                {"data", json::Object{{"value", 42}}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(forty_two), int_graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(int_graph), forty_two));
+}
+
+TEST(Serialization, TestObjectFunc) {
+  TVar x = TVar("x");
+  // StructuralEqual ignores TFunc's comment field, but serialization does not:
+  // the comment string is emitted as its own node and referenced by "comment".
+  TFunc fa = TFunc({x}, {x, x}, String("comment a"));
+
+  json::Object expected_fa = json::Object{
+      {"root_index", 5},
+      {"nodes",
+       json::Array{
+           // 0: string "x"
+           json::Object{{"type", "ffi.String"}, {"data", "x"}},
+           // 1: var x, whose name refers to node 0
+           json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}},
+           // 2: params array [x]
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{1}}},
+           // 3: body array [x, x]; x is shared, so both entries point at node 1
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{1, 1}}},
+           // 4: string "comment a"
+           json::Object{{"type", "ffi.String"}, {"data", "comment a"}},
+           // 5: the function itself (root)
+           json::Object{{"type", "test.Func"},
+                        {"data", json::Object{{"params", 2}, {"body", 3}, {"comment", 4}}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fa), expected_fa));
+  // Comparing the decoded object with fa cannot detect a lost or altered
+  // comment, because StructuralEqual ignores that field...
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fa), fa));
+  // ...so also re-encode the decoded object. Comparing JSON graphs does check
+  // that the comment string survived the round trip.
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(FromJSONGraph(expected_fa)), expected_fa));
+
+  // An absent (std::nullopt) comment is serialized as a "None" node.
+  TFunc fb = TFunc({}, {}, std::nullopt);
+  json::Object expected_fb = json::Object{
+      {"root_index", 3},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}},  // 0: params []
+           json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}},  // 1: body []
+           json::Object{{"type", "None"}},                                 // 2: comment
+           json::Object{{"type", "test.Func"},
+                        {"data", json::Object{{"params", 0}, {"body", 1}, {"comment", 2}}}}}}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fb), expected_fb));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fb), fb));
+}
+
+TEST(Serialization, AttachMetadata) {
+  // Metadata passed to ToJSONGraph is stored as-is under the top-level
+  // "metadata" key, and FromJSONGraph still decodes the root value.
+  json::Object meta{{"version", "1.0"}};
+  bool flag = true;
+  json::Object graph = json::Object{
+      {"root_index", 0},
+      {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}},
+      {"metadata", meta}};
+  EXPECT_TRUE(StructuralEqual()(ToJSONGraph(flag, meta), graph));
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(graph), flag));
+}
+
+TEST(Serialization, ShuffleNodeOrder) {
+  // FromJSONGraph resolves nodes through their indices, so the node list does
+  // not need to follow the order ToJSONGraph would emit. Here the root map is
+  // node 0 and every node it references appears after it.
+  json::Object shuffled_graph = json::Object{
+      {"root_index", 0},
+      {"nodes",
+       json::Array{
+           json::Object{{"type", "ffi.Map"}, {"data", json::Array{2, 3, 1, 3}}},
+           json::Object{{"type", "ffi.String"}, {"data", "a"}},
+           json::Object{{"type", "ffi.String"}, {"data", "b"}},
+           json::Object{{"type", "int"}, {"data", 42}},
+       }}};
+  Map<String, Any> expected_map{{"b", 42}, {"a", 42}};
+  EXPECT_TRUE(StructuralEqual()(FromJSONGraph(shuffled_graph), expected_map));
+}
+
+} // namespace
diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc
b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
index 76c485d906..8a377f4837 100644
--- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
@@ -147,10 +147,10 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
TVar x = TVar("x");
TVar y = TVar("y");
// comment fields are ignored
- TFunc fa = TFunc({x}, {TInt(1), x}, "comment a");
- TFunc fb = TFunc({y}, {TInt(1), y}, "comment b");
+ TFunc fa = TFunc({x}, {TInt(1), x}, String("comment a"));
+ TFunc fb = TFunc({y}, {TInt(1), y}, String("comment b"));
- TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c");
+ TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, String("comment c"));
EXPECT_TRUE(StructuralEqual()(fa, fb));
EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb));
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index 78ca008e10..c5725da941 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -21,6 +21,7 @@
#define TVM_FFI_TESTING_OBJECT_H_
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>
@@ -87,6 +88,15 @@ inline void TIntObj::RegisterReflection() {
refl::TypeAttrDef<TIntObj>()
.def("test.GetValue", &TIntObj::GetValue)
.attr("test.size", sizeof(TIntObj));
+ // custom json serialization
+ refl::TypeAttrDef<TIntObj>()
+ .def("__data_to_json__",
+ [](const TIntObj* self) -> Map<String, Any> {
+ return Map<String, Any>{{"value", self->value}};
+ })
+ .def("__data_from_json__", [](Map<String, Any> json_obj) -> TInt {
+ return TInt(json_obj["value"].cast<int64_t>());
+ });
}
class TFloatObj : public TNumberObj {
@@ -154,6 +164,8 @@ class TVarObj : public Object {
public:
std::string name;
+ // need default constructor for json serialization
+ TVarObj() = default;
TVarObj(std::string name) : name(name) {}
static void RegisterReflection() {
@@ -178,9 +190,11 @@ class TFuncObj : public Object {
public:
Array<TVar> params;
Array<ObjectRef> body;
- String comment;
+ Optional<String> comment;
- TFuncObj(Array<TVar> params, Array<ObjectRef> body, String comment)
+ // need default constructor for json serialization
+ TFuncObj() = default;
+ TFuncObj(Array<TVar> params, Array<ObjectRef> body, Optional<String> comment)
: params(params), body(body), comment(comment) {}
static void RegisterReflection() {
@@ -198,7 +212,7 @@ class TFuncObj : public Object {
class TFunc : public ObjectRef {
public:
- explicit TFunc(Array<TVar> params, Array<ObjectRef> body, String comment) {
+ explicit TFunc(Array<TVar> params, Array<ObjectRef> body, Optional<String>
comment) {
data_ = make_object<TFuncObj>(params, body, comment);
}