This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch json-upgrade in repository https://gitbox.apache.org/repos/asf/tvm.git
commit df738d1df63207d6b923ea60b878e756f6a20ec3 Author: tqchen <[email protected]> AuthorDate: Tue Aug 5 10:06:36 2025 -0400 [FFI] Support object creator --- ffi/include/tvm/ffi/reflection/creator.h | 112 +++++++++++++++++++++ ...t_reflection_accessor.cc => test_reflection.cc} | 6 ++ ffi/tests/cpp/testing_object.h | 1 + 3 files changed, 119 insertions(+) diff --git a/ffi/include/tvm/ffi/reflection/creator.h b/ffi/include/tvm/ffi/reflection/creator.h new file mode 100644 index 0000000000..983b8034a3 --- /dev/null +++ b/ffi/include/tvm/ffi/reflection/creator.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/creator.h + * \brief Reflection-based creator to create objects from type key and fields. + */ +#ifndef TVM_FFI_REFLECTION_CREATOR_H_ +#define TVM_FFI_REFLECTION_CREATOR_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/reflection/accessor.h> +#include <tvm/ffi/string.h> + +namespace tvm { +namespace ffi { +namespace reflection { +/*! + * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. 
+ */ +class ObjectCreator { + public: + explicit ObjectCreator(std::string_view type_key) + : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} + /*! \brief Throws RuntimeError unless the type has reflection metadata and a creator. */ + explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { + int32_t type_index = type_info->type_index; + if (type_info->metadata == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not have reflection registered"; + } + if (type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support default constructor, " + << "as a result cannot be created via reflection"; + } + } + + /** + * \brief Create an object from a map of fields. + * \param fields Field values keyed by name; omitted fields fall back to their default, if any. + * \return The created object; throws on a missing, unknown, or rejected field. + */ + Any operator()(const Map<String, Any>& fields) const { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); + ObjectPtr<Object> ptr = + details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)); + size_t match_field_count = 0; + ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { + String field_name(field_info->name); + void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset; + if (auto it = fields.find(field_name); it != fields.end()) { + Any field_value = (*it).second; + // The C setter signals failure through its return code; never drop it. + TVM_FFI_CHECK_SAFE_CALL( + field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value))); + ++match_field_count; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + TVM_FFI_CHECK_SAFE_CALL(field_info->setter(field_addr, &(field_info->default_value))); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" << field_name << "` not set in type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "`"; + } + }); + if (match_field_count == fields.size()) return ObjectRef(ptr); + // Some provided keys were not consumed: report the first one the type does not declare.
+ auto check_field_name = [&](const String& field_name) { + bool found = false; + ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) { + if (field_name.compare(field_info->name) == 0) { + found = true; + return true; + } + return false; + }); + return found; + }; + for (const auto& [field_name, _] : fields) { + if (!check_field_name(field_name)) { + TVM_FFI_THROW(TypeError) << "Type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "` does not have field `" << field_name << "`"; + } + } + TVM_FFI_UNREACHABLE(); + } + + private: + const TVMFFITypeInfo* type_info_; // Non-owning; validated in the constructor. +}; +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/ffi/tests/cpp/test_reflection_accessor.cc b/ffi/tests/cpp/test_reflection.cc similarity index 96% rename from ffi/tests/cpp/test_reflection_accessor.cc rename to ffi/tests/cpp/test_reflection.cc index cb5145db07..98915c54e1 100644 --- a/ffi/tests/cpp/test_reflection_accessor.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -21,6 +21,7 @@ #include <tvm/ffi/container/map.h> #include <tvm/ffi/object.h> #include <tvm/ffi/reflection/accessor.h> +#include <tvm/ffi/reflection/creator.h> #include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/string.h> @@ -159,4 +160,9 @@ TEST(Reflection, FuncRegister) { EXPECT_EQ(fget_value(a).cast<int>(), 12); } +TEST(Reflection, ObjectCreator) { + namespace refl = tvm::ffi::reflection; + refl::ObjectCreator creator("test.Int"); + EXPECT_EQ(creator(Map<String, Any>({{"value", 1}})).cast<TInt>()->value, 1); +} } // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index c5725da941..fe3ba1b013 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -59,6 +59,7 @@ class TIntObj : public TNumberObj { public: int64_t value; + TIntObj() = default; TIntObj(int64_t value) : value(value) {} int64_t GetValue() const { return value; }
