This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch json-upgrade
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit df738d1df63207d6b923ea60b878e756f6a20ec3
Author: tqchen <[email protected]>
AuthorDate: Tue Aug 5 10:06:36 2025 -0400

    [FFI] Support object creator
---
 ffi/include/tvm/ffi/reflection/creator.h           | 112 +++++++++++++++++++++
 ...t_reflection_accessor.cc => test_reflection.cc} |   6 ++
 ffi/tests/cpp/testing_object.h                     |   1 +
 3 files changed, 119 insertions(+)

diff --git a/ffi/include/tvm/ffi/reflection/creator.h b/ffi/include/tvm/ffi/reflection/creator.h
new file mode 100644
index 0000000000..983b8034a3
--- /dev/null
+++ b/ffi/include/tvm/ffi/reflection/creator.h
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection/creator.h
+ * \brief Reflection-based creator to create objects from type key and fields.
+ */
+#ifndef TVM_FFI_REFLECTION_CREATOR_H_
+#define TVM_FFI_REFLECTION_CREATOR_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/string.h>
+
+namespace tvm {
+namespace ffi {
+namespace reflection {
+/*!
+ * \brief Helper wrapper around TVMFFITypeInfo that creates objects via reflection.
+ *
+ * The target type must have reflection metadata registered and must provide a
+ * creator callback in that metadata; otherwise construction of ObjectCreator throws.
+ */
+class ObjectCreator {
+ public:
+  /*!
+   * \brief Construct a creator for the type registered under \p type_key.
+   * \param type_key The type key of the objects to create.
+   */
+  explicit ObjectCreator(std::string_view type_key)
+      : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {}
+
+  /*!
+   * \brief Construct a creator from the type info of the objects to create.
+   * \param type_info The type info of the objects to create.
+   * \throws RuntimeError if the type has no reflection metadata, or the metadata
+   *         has no creator callback.
+   */
+  explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
+    int32_t type_index = type_info->type_index;
+    if (type_info->metadata == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+                                  << "` does not have reflection registered";
+    }
+    if (type_info->metadata->creator == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
+                                  << "` does not support default constructor, "
+                                  << "as a result cannot be created via reflection";
+    }
+  }
+
+  /*!
+   * \brief Create an object from a map of fields.
+   *
+   * Each reflected field is set from \p fields when its name is present;
+   * otherwise it is set from the field's registered default value if it has one.
+   *
+   * \param fields The field values, keyed by field name.
+   * \return The created object.
+   * \throws TypeError if a field without a default is missing from \p fields, or
+   *         if \p fields contains a name that is not a field of the type.
+   *         Errors reported by the creator or by a field setter are propagated.
+   */
+  Any operator()(const Map<String, Any>& fields) const {
+    TVMFFIObjectHandle handle;
+    TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
+    // Wrap the freshly created handle in an owning ObjectPtr right away so it is
+    // released if any later step throws.
+    ObjectPtr<Object> ptr =
+        details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
+    size_t match_field_count = 0;
+    ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
+      String field_name(field_info->name);
+      void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
+      if (fields.count(field_name) != 0) {
+        Any field_value = fields[field_name];
+        // The setter reports failures (e.g. an incompatible value) via its return
+        // code; propagate them instead of silently leaving the field unset.
+        TVM_FFI_CHECK_SAFE_CALL(
+            field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value)));
+        ++match_field_count;
+      } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
+        TVM_FFI_CHECK_SAFE_CALL(field_info->setter(field_addr, &(field_info->default_value)));
+      } else {
+        TVM_FFI_THROW(TypeError) << "Required field `"
+                                 << String(field_info->name.data, field_info->name.size)
+                                 << "` not set in type `"
+                                 << String(type_info_->type_key.data, type_info_->type_key.size)
+                                 << "`";
+      }
+    });
+    // Every supplied key matched a reflected field: nothing extra was passed.
+    if (match_field_count == fields.size()) return ObjectRef(ptr);
+    // Otherwise at least one key is not a field of the type; find it and report it.
+    auto check_field_name = [&](const String& field_name) {
+      bool found = false;
+      ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) {
+        if (field_name.compare(field_info->name) == 0) {
+          found = true;
+          return true;
+        }
+        return false;
+      });
+      return found;
+    };
+    for (const auto& [field_name, _] : fields) {
+      if (!check_field_name(field_name)) {
+        TVM_FFI_THROW(TypeError) << "Type `"
+                                 << String(type_info_->type_key.data, type_info_->type_key.size)
+                                 << "` does not have field `" << field_name << "`";
+      }
+    }
+    TVM_FFI_UNREACHABLE();
+  }
+
+ private:
+  /*! \brief Type info of the objects to create; validated in the constructor. */
+  const TVMFFITypeInfo* type_info_;
+};
+}  // namespace reflection
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_REFLECTION_CREATOR_H_
diff --git a/ffi/tests/cpp/test_reflection_accessor.cc b/ffi/tests/cpp/test_reflection.cc
similarity index 96%
rename from ffi/tests/cpp/test_reflection_accessor.cc
rename to ffi/tests/cpp/test_reflection.cc
index cb5145db07..98915c54e1 100644
--- a/ffi/tests/cpp/test_reflection_accessor.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -21,6 +21,7 @@
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/object.h>
 #include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/creator.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/ffi/string.h>
 
@@ -159,4 +160,9 @@ TEST(Reflection, FuncRegister) {
   EXPECT_EQ(fget_value(a).cast<int>(), 12);
 }
 
+TEST(Reflection, ObjectCreator) {
+  namespace refl = tvm::ffi::reflection;
+  refl::ObjectCreator creator("test.Int");
+  // Happy path: every supplied name is a reflected field.
+  EXPECT_EQ(creator(Map<String, Any>({{"value", 1}})).cast<TInt>()->value, 1);
+  // A name that is not a reflected field of the type must be rejected.
+  EXPECT_ANY_THROW(creator(Map<String, Any>({{"value", 1}, {"unknown_field", 2}})));
+}
 }  // namespace
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index c5725da941..fe3ba1b013 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -59,6 +59,7 @@ class TIntObj : public TNumberObj {
  public:
   int64_t value;
 
+  TIntObj() = default;
   TIntObj(int64_t value) : value(value) {}
 
   int64_t GetValue() const { return value; }

Reply via email to