This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4a6b7354d [FFI][REFACTOR] Phase out TVM_FFI_REGISTER_GLOBAL in favor of GlobalDef (#18148)
b4a6b7354d is described below
commit b4a6b7354d9cc8b4a54141c5472a8230eece9b79
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jul 15 13:59:25 2025 -0400
[FFI][REFACTOR] Phase out TVM_FFI_REGISTER_GLOBAL in favor of GlobalDef (#18148)
This PR migrates the remaining global function registrations to the new mechanism.
It also phases out the TVM_FFI_REGISTER_GLOBAL macro in favor of
the GlobalDef mechanism.
---
docs/arch/device_target_interactions.rst | 12 +-
docs/arch/pass_infra.rst | 6 +-
docs/arch/runtime.rst | 16 ++-
ffi/include/tvm/ffi/function.h | 144 -----------------------
include/tvm/runtime/profiling.h | 7 +-
python/tvm/contrib/msc/plugin/codegen/sources.py | 12 ++
python/tvm/relax/frontend/nn/op.py | 2 +-
python/tvm/runtime/_ffi_api.py | 3 +-
python/tvm/runtime/_ffi_node_api.py | 4 +-
src/contrib/msc/plugin/tvm_codegen.cc | 9 +-
src/relax/op/op_common.h | 3 +-
src/relax/op/tensor/binary.h | 3 +-
src/relax/op/tensor/search.cc | 3 +-
src/relax/op/tensor/statistical.h | 3 +-
src/runtime/device_api.cc | 16 ++-
src/runtime/disco/nccl/nccl.cc | 99 +++++++---------
src/script/ir_builder/tir/ir.cc | 110 +++++++++--------
src/target/datatype/registry.h | 2 +-
src/tir/op/op.cc | 96 +++++++--------
src/topi/broadcast.cc | 87 +++++++-------
tests/python/contrib/test_hexagon/README_RPC.md | 23 ++--
21 files changed, 271 insertions(+), 389 deletions(-)
diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst
index 09867f88fa..6015c43510 100644
--- a/docs/arch/device_target_interactions.rst
+++ b/docs/arch/device_target_interactions.rst
@@ -153,7 +153,10 @@ then be registered with the following steps.
#. Register the function to the tvm registry::
- TVM_FFI_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global);
+ TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("device_api.foo", FooDeviceAPI::Global);
+ });
.. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h
@@ -164,7 +167,7 @@ then be registered with the following steps.
#. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the
enum value to a string representation. This string representation
- should match the name given to ``TVM_FFI_REGISTER_GLOBAL``.
+ should match the name given to ``GlobalDef().def``.
#. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE``
dictionaries of
:py:class:`tvm.runtime.Device` for the new enum value.
@@ -225,7 +228,10 @@ the same name as was used in the
``TVM_REGISTER_TARGET_KIND``
definition above. ::
tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target);
- TVM_FFI_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode);
+ TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("target.build.foo", GeneratorFooCode);
+ });
The code generator takes two arguments. The first is the ``IRModule``
to compile, and the second is the ``Target`` that describes the device
diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst
index c54ba18b0a..4bf3abceb0 100644
--- a/docs/arch/pass_infra.rst
+++ b/docs/arch/pass_infra.rst
@@ -376,8 +376,10 @@ Python when needed.
return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}
- TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant")
- .set_body_typed(FoldConstant);
+ TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant);
+ });
} // namespace transform
diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst
index 613c7d86e1..8c2a9a0995 100644
--- a/docs/arch/runtime.rst
+++ b/docs/arch/runtime.rst
@@ -80,8 +80,10 @@ The following example registers PackedFunc in C++ and calls
from python.
.. code:: c
// register a global packed function in c++
- TVM_FFI_REGISTER_GLOBAL("myadd")
- .set_body_packed(MyAdd);
+ TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def_packed("myadd", MyAdd);
+ });
.. code:: python
@@ -110,10 +112,12 @@ we can pass functions from python (as PackedFunc) to C++.
.. code:: c
- TVM_FFI_REGISTER_GLOBAL("callhello")
- .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
- PackedFunc f = args[0];
- f("hello world");
+ TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def_packed("callhello", [](ffi::PackedArgs args, ffi::Any* rv) {
+ ffi::Function f = args[0].cast<ffi::Function>();
+ f("hello world");
+ });
});
.. code:: python
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 34d994b64c..5a30f25a7b 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -756,135 +756,6 @@ struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase {
TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo<FType>::Sig(); }
};
-/*! \brief Registry for global function */
-class Function::Registry {
- public:
- /*! \brief constructor */
- explicit Registry(const char* name) : name_(name) {}
-
- /*!
- * \brief Set body to be to use the packed convention.
- *
- * \tparam FLambda The signature of the function.
- * \param f The body of the function.
- */
- template <typename FLambda>
- Registry& set_body_packed(FLambda f) {
- return Register(ffi::Function::FromPacked(f));
- }
- /*!
- * \brief set the body of the function to the given function.
- * Note that this will ignore default arg values and always require all arguments to be
- * provided.
- *
- * \code
- *
- * int multiply(int x, int y) {
- * return x * y;
- * }
- *
- * TVM_FFI_REGISTER_GLOBAL("multiply")
- * .set_body_typed(multiply); // will have type int(int, int)
- *
- * // will have type int(int, int)
- * TVM_FFI_REGISTER_GLOBAL("sub")
- * .set_body_typed([](int a, int b) -> int { return a - b; });
- *
- * \endcode
- *
- * \param f The function to forward to.
- * \tparam FLambda The signature of the function.
- */
- template <typename FLambda>
- Registry& set_body_typed(FLambda f) {
- return Register(Function::FromTyped(f, name_));
- }
-
- /*!
- * \brief set the body of the function to be the passed method pointer.
- * Note that this will ignore default arg values and always require all arguments to be
- * provided.
- *
- * \code
- *
- * // objectRef subclass:
- * struct Example : ObjectRef {
- * int DoThing(int x);
- * }
- * TVM_FFI_REGISTER_GLOBAL("Example_DoThing")
- * .set_body_method(&Example::DoThing); // will have type int(self, int)
- *
- * // Object subclass:
- * struct Example : Object {
- * int DoThing(int x);
- * }
- *
- * TVM_FFI_REGISTER_GLOBAL("Example_DoThing")
- * .set_body_method(&Example::DoThing); // will have type int(self, int)
- *
- * \endcode
- *
- * \param f the method pointer to forward to.
- * \tparam T the type containing the method (inferred).
- * \tparam R the return type of the function (inferred).
- * \tparam Args the argument types of the function (inferred).
- */
- template <typename T, typename R, typename... Args>
- Registry& set_body_method(R (T::*f)(Args...)) {
- static_assert(std::is_base_of_v<ObjectRef, T> || std::is_base_of_v<Object, T>,
- "T must be derived from ObjectRef or Object");
- if constexpr (std::is_base_of_v<ObjectRef, T>) {
- auto fwrap = [f](T target, Args... params) -> R {
- // call method pointer
- return (target.*f)(std::forward<Args>(params)...);
- };
- return Register(ffi::Function::FromTyped(fwrap, name_));
- }
- if constexpr (std::is_base_of_v<Object, T>) {
- auto fwrap = [f](const T* target, Args... params) -> R {
- // call method pointer
- return (const_cast<T*>(target)->*f)(std::forward<Args>(params)...);
- };
- return Register(ffi::Function::FromTyped(fwrap, name_));
- }
- return *this;
- }
-
- template <typename T, typename R, typename... Args>
- Registry& set_body_method(R (T::*f)(Args...) const) {
- static_assert(std::is_base_of_v<ObjectRef, T> || std::is_base_of_v<Object, T>,
- "T must be derived from ObjectRef or Object");
- if constexpr (std::is_base_of_v<ObjectRef, T>) {
- auto fwrap = [f](const T target, Args... params) -> R {
- // call method pointer
- return (target.*f)(std::forward<Args>(params)...);
- };
- return Register(ffi::Function::FromTyped(fwrap, name_));
- }
- if constexpr (std::is_base_of_v<Object, T>) {
- auto fwrap = [f](const T* target, Args... params) -> R {
- // call method pointer
- return (target->*f)(std::forward<Args>(params)...);
- };
- return Register(ffi::Function::FromTyped(fwrap, name_));
- }
- return *this;
- }
-
- protected:
- /*!
- * \brief set the body of the function to be f
- * \param f The body of the function.
- */
- Registry& Register(Function f) {
- Function::SetGlobal(name_, f);
- return *this;
- }
-
- /*! \brief name of the function */
- const char* name_;
-};
-
/*!
* \brief helper function to get type index from key
*/
@@ -895,21 +766,6 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
return type_index;
}
-#define TVM_FFI_FUNC_REG_VAR_DEF \
- TVM_FFI_ATTRIBUTE_UNUSED static inline ::tvm::ffi::Function::Registry& __##TVMFFIFuncReg
-
-/*!
- * \brief Register a function globally.
- * \code
- * TVM_FFI_REGISTER_GLOBAL("MyAdd")
- * .set_body_typed([](int a, int b) {
- * return a + b;
- * });
- * \endcode
- */
-#define TVM_FFI_REGISTER_GLOBAL(OpName) \
- TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName)
-
/*!
* \brief Export typed function as a SafeCallType symbol.
*
diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h
index c43543cc38..cf467870c6 100644
--- a/include/tvm/runtime/profiling.h
+++ b/include/tvm/runtime/profiling.h
@@ -134,8 +134,11 @@ class Timer : public ObjectRef {
* };
* TVM_REGISTER_OBJECT_TYPE(CPUTimerNode);
*
- * TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) {
- * return Timer(make_object<CPUTimerNode>());
+ * TVM_FFI_STATIC_INIT_BLOCK({
+ * namespace refl = tvm::ffi::reflection;
+ * refl::GlobalDef().def("profiling.timer.cpu", [](Device dev) {
+ * return Timer(make_object<CPUTimerNode>());
+ * });
* });
* \endcode
*/
diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py
index b507d7b825..a4e89ad7ec 100644
--- a/python/tvm/contrib/msc/plugin/codegen/sources.py
+++ b/python/tvm/contrib/msc/plugin/codegen/sources.py
@@ -684,6 +684,17 @@ class TVMUtils {
return cuda_dev;
}
};
+
+#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF(FuncName, Body) \
+ TVM_FFI_STATIC_INIT_BLOCK({ \
+ tvm::ffi::reflection::GlobalDef().def(FuncName, Body); \
+ })
+
+#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED(FuncName, Body) \
+ TVM_FFI_STATIC_INIT_BLOCK({ \
+ tvm::ffi::reflection::GlobalDef().def_packed(FuncName, Body); \
+ })
+
#endif // PLUGIN_SUPPORT_TVM
"""
@@ -1101,6 +1112,7 @@ def get_plugin_utils_h_code() -> str:
#ifdef PLUGIN_SUPPORT_TVM
#include <tvm/relax/expr.h>
+#include <tvm/ffi/reflection/registry.h>
#include "tvm/../../src/contrib/msc/core/transform/layout_utils.h"
#include "tvm/../../src/contrib/msc/core/utils.h"
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index ab416ef141..1e42c862fe 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -2087,7 +2087,7 @@ def extern(
out: OutType,
) -> OutType:
"""Invoke an extern function during runtime. The extern function must be registered with the
- TVM runtime using `TVM_FFI_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python).
+ TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_func` (Python).
Parameters
----------
diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py
index 71f96983ee..88a49f3a63 100644
--- a/python/tvm/runtime/_ffi_api.py
+++ b/python/tvm/runtime/_ffi_api.py
@@ -17,6 +17,5 @@
"""FFI APIs for tvm.runtime"""
import tvm.ffi
-# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "runtime" prefix.
-# e.g. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
+# Exports functions registered in runtime namespace.
tvm.ffi._init_api("runtime", __name__)
diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py
index 493dfceab5..aef9ded9cc 100644
--- a/python/tvm/runtime/_ffi_node_api.py
+++ b/python/tvm/runtime/_ffi_node_api.py
@@ -24,7 +24,6 @@ import tvm.ffi.core
# The implementations below are default ones when the corresponding
# functions are not available in the runtime only mode.
# They will be overriden via _init_api to the ones registered
-# via TVM_FFI_REGISTER_GLOBAL in the compiler mode.
def AsRepr(obj):
return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
@@ -45,6 +44,5 @@ def LoadJSON(json_str):
raise RuntimeError("Do not support object serialization in runtime only mode")
-# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "node" prefix.
-# e.g. TVM_FFI_REGISTER_GLOBAL("node.AsRepr")
+# Exports functions registered in node namespace.
tvm.ffi._init_api("node", __name__)
diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc
index 08cd3d7da6..a3861aabe7 100644
--- a/src/contrib/msc/plugin/tvm_codegen.cc
+++ b/src/contrib/msc/plugin/tvm_codegen.cc
@@ -215,14 +215,12 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
stack_.func_end("infer_output");
// register funcs
- stack_.func_call("TVM_FFI_REGISTER_GLOBAL")
+ stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name))
- .method_call("set_body_typed")
.call_arg("InferStructInfo" + plugin->name)
.line()
- .func_call("TVM_FFI_REGISTER_GLOBAL")
+ .func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF")
.call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name))
- .method_call("set_body_typed")
.call_arg("InferLayout" + plugin->name)
.line();
}
@@ -262,9 +260,8 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) {
CodeGenCompute(plugin, "cpu");
stack_.cond_end().func_end();
// register the compute
- stack_.func_call("TVM_FFI_REGISTER_GLOBAL")
+ stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED")
.call_arg(DocUtils::ToStr(plugin->name))
- .method_call("set_body")
.call_arg(func_name)
.line();
}
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index d7d50f8fa7..4da8b18fcb 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -181,7 +181,8 @@ std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& c
static const Op& op = Op::Get("relax." OpRegName); \
return Call(op, {std::move(x)}, Attrs(), {}); \
} \
- TVM_FFI_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName)
+ TVM_FFI_STATIC_INIT_BLOCK( \
+ { tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); })
/************ Utilities ************/
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index ae36d45b36..f612ec0598 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -42,7 +42,8 @@ namespace relax {
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {x1, x2}, Attrs(), {}); \
} \
- TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
+ TVM_FFI_STATIC_INIT_BLOCK( \
+ { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(2) \
.add_argument("x1", "Tensor", "The first input tensor.") \
diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc
index a3d17be135..6808acdedf 100644
--- a/src/relax/op/tensor/search.cc
+++ b/src/relax/op/tensor/search.cc
@@ -256,7 +256,8 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call,
const BlockBuilder& ctx
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {std::move(x)}, Attrs(attrs)); \
} \
- TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
+ TVM_FFI_STATIC_INIT_BLOCK( \
+ { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(1) \
.add_argument("x", "Tensor", "The input data tensor") \
diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h
index 331562454e..e79ce1d4ae 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -50,7 +50,8 @@ namespace relax {
static const Op& op = Op::Get("relax." #OpName); \
return Call(op, {std::move(x)}, Attrs{attrs}, {}); \
} \
- TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \
+ TVM_FFI_STATIC_INIT_BLOCK( \
+ { tvm::ffi::reflection::GlobalDef().def("relax.op." #OpName, OpName); }); \
TVM_REGISTER_OP("relax." #OpName) \
.set_num_inputs(1) \
.add_argument("x", "Tensor", "The input data tensor") \
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index 3eb817fd0f..06eb0284f9 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -193,19 +193,17 @@ TVM_FFI_STATIC_INIT_BLOCK({
});
});
-// set device api
-TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
- .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
- DLDevice dev;
- dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>());
- dev.device_id = args[1].cast<int>();
- DeviceAPIManager::Get(dev)->SetDevice(dev);
- });
-
// set device api
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
+ .def_packed(tvm::runtime::symbol::tvm_set_device,
+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
+ DLDevice dev;
+ dev.device_type = static_cast<DLDeviceType>(args[0].cast<int>());
+ dev.device_id = args[1].cast<int>();
+ DeviceAPIManager::Get(dev)->SetDevice(dev);
+ })
.def_packed("runtime.GetDeviceAttr",
[](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
DLDevice dev;
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index f3f79c9ccc..9e41bbd0de 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -331,64 +331,49 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("runtime.disco.compiled_ccl",
[]() -> String { return TVM_DISCO_CCL_NAME; });
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl", InitCCL)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker", InitCCLPerWorker)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce",
+ [](NDArray send, int kind, bool in_group, NDArray recv) {
+ CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
+ nccl::AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
+ })
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".allgather",
+ [](NDArray send, bool in_group, NDArray recv) { nccl::AllGather(send, in_group, recv); })
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0", BroadcastFromWorker0)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0", ScatterFromWorker0)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0", GatherToWorker0)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0", RecvFromWorker0)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group", SendToNextGroup)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group", RecvFromPrevGroup)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker", SendToWorker)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker", RecvFromWorker)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker", SyncWorker)
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_send_to_next_group_recv_from_prev_group",
+ [](NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
+ CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
+ int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
+ int group_id = ctx->worker->worker_id / group_size;
+ if (group_id == 0) {
+ tvm::runtime::nccl::SendToNextGroup(buffer);
+ } else {
+ tvm::runtime::nccl::RecvFromPrevGroup(buffer);
+ }
+ })
+ .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0",
+ [](NDArray buffer) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
+ CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
+ if (ctx->worker->worker_id == 2) {
+ tvm::runtime::nccl::SendToWorker(buffer, 0);
+ } else if (ctx->worker->worker_id == 0) {
+ tvm::runtime::nccl::RecvFromWorker(buffer, 2);
+ }
+ });
});
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker")
- .set_body_typed(InitCCLPerWorker);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
- .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) {
- CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
- nccl::AllReduce(send, static_cast<ReduceKind>(kind), in_group, recv);
- });
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather")
- .set_body_typed([](NDArray send, bool in_group, NDArray recv) {
- nccl::AllGather(send, in_group, recv);
- });
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
- .set_body_typed(BroadcastFromWorker0);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
- .set_body_typed(ScatterFromWorker0);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
- .set_body_typed(GatherToWorker0);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
- .set_body_typed(RecvFromWorker0);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
- .set_body_typed(SendToNextGroup);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group")
- .set_body_typed(RecvFromPrevGroup);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
- .set_body_typed(SendToWorker);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
- .set_body_typed(RecvFromWorker);
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker")
- .set_body_typed(SyncWorker);
-
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_send_to_next_group_recv_from_prev_group")
- .set_body_typed([](NDArray buffer) {
- CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
- CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
- CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
- int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
- int group_id = ctx->worker->worker_id / group_size;
- if (group_id == 0) {
- tvm::runtime::nccl::SendToNextGroup(buffer);
- } else {
- tvm::runtime::nccl::RecvFromPrevGroup(buffer);
- }
- });
-
-TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0")
- .set_body_typed([](NDArray buffer) {
- CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
- CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
- CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2.";
- if (ctx->worker->worker_id == 2) {
- tvm::runtime::nccl::SendToWorker(buffer, 0);
- } else if (ctx->worker->worker_id == 0) {
- tvm::runtime::nccl::RecvFromWorker(buffer, 2);
- }
- });
} // namespace nccl
} // namespace runtime
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 998b468f4b..e8c8d62c9b 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -760,107 +760,117 @@ TVM_FFI_STATIC_INIT_BLOCK({
#define TVM_TMP_STR(x) #x
-#define TVM_FFI_REGISTER_GLOBAL_SIZE(Prefix, DType) \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64);
-
-TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float);
-TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt);
-TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int);
-
-#define TVM_FFI_REGISTER_GLOBAL_LANES(Prefix, Func) \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \
- TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64);
-
-#define TVM_FFI_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \
- TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \
- TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \
- TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \
- TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64);
-
-TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
-TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
-TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
+#define TVM_FFI_REFL_DEF_GLOBAL_SIZE(Prefix, DType) \
+ def(Prefix TVM_TMP_STR(8), DType##8) \
+ .def(Prefix TVM_TMP_STR(16), DType##16) \
+ .def(Prefix TVM_TMP_STR(32), DType##32) \
+ .def(Prefix TVM_TMP_STR(64), DType##64)
+
+#define TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix, Func) \
+ def(Prefix TVM_TMP_STR(x4), Func##x4) \
+ .def(Prefix TVM_TMP_STR(x8), Func##x8) \
+ .def(Prefix TVM_TMP_STR(x16), Func##x16) \
+ .def(Prefix TVM_TMP_STR(x32), Func##x32) \
+ .def(Prefix TVM_TMP_STR(x64), Func##x64)
+
+#define TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES(Prefix, DType) \
+ TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8) \
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16) \
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32) \
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64)
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.BFloat16", BFloat16);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.BFloat16", BFloat16)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Float", Float)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Int", Int)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt)
+ .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16);
// Float8 variants
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E3M4", Float8E3M4);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E3M4", Float8E3M4)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3", Float8E4M3);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E4M3", Float8E4M3)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2", Float8E5M2);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E5M2", Float8E5M2)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU);
// Float6 variants
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN);
// Float4 variant
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
+ refl::GlobalDef()
+ .def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN)
+ .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
});
-TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h
index b1a1a4a7f5..363494e0fd 100644
--- a/src/target/datatype/registry.h
+++ b/src/target/datatype/registry.h
@@ -37,7 +37,7 @@ namespace datatype {
* directly---see the TVM globals registered in the corresponding .cc file.
* Currently, user should manually choose a type name and a type code,
* ensuring that neither conflict with existing types.
- * 2. Use TVM_FFI_REGISTER_GLOBAL to register the lowering functions needed to
+ * 2. Register the lowering functions needed to
* lower the custom datatype. In general, these will look like:
* For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
* Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 03b22de838..9ced6f556c 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -255,7 +255,10 @@ PrimExpr thread_return(Span span) {
return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span);
}
-TVM_FFI_REGISTER_GLOBAL("tir.thread_return").set_body_typed(thread_return);
+TVM_FFI_STATIC_INIT_BLOCK({
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("tir.thread_return", thread_return);
+});
// maximum and min limits
PrimExpr max_value(const DataType& dtype, Span span) {
@@ -1158,54 +1161,22 @@ TVM_FFI_STATIC_INIT_BLOCK({
});
// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func)                                                       \
-  TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) {    \
-    return (Func(a, b, span));                                                                    \
+#define DEF_MAKE_BINARY_OP(Node, Func) \
+  def("tir." #Node, [](PrimExpr a, PrimExpr b, Span span) { return (Func(a, b, span)); })
+
+#define DEF_MAKE_BIT_OP(Node, Func)                                                               \
+  def_packed("tir." #Node, [](ffi::PackedArgs args, ffi::Any* ret) {                              \
+    bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt;                         \
+    bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt;                         \
+    if (lhs_is_int) {                                                                             \
+      *ret = (Func(args[0].cast<int>(), args[1].cast<PrimExpr>(), args[2].cast<Span>()));         \
+    } else if (rhs_is_int) {                                                                      \
+      *ret = (Func(args[0].cast<PrimExpr>(), args[1].cast<int>(), args[2].cast<Span>()));         \
+    } else {                                                                                      \
+      *ret = (Func(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>(), args[2].cast<Span>()));    \
+    }                                                                                             \
  })
-#define REGISTER_MAKE_BIT_OP(Node, Func)                                                          \
-  TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \
-    bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt;                         \
-    bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt;                         \
-    if (lhs_is_int) {                                                                             \
-      *ret = (Func(args[0].cast<int>(), args[1].cast<PrimExpr>(), args[2].cast<Span>()));         \
-    } else if (rhs_is_int) {                                                                      \
-      *ret = (Func(args[0].cast<PrimExpr>(), args[1].cast<int>(), args[2].cast<Span>()));         \
-    } else {                                                                                      \
-      *ret = (Func(args[0].cast<PrimExpr>(), args[1].cast<PrimExpr>(), args[2].cast<Span>()));    \
-    }                                                                                             \
-  })
-
-REGISTER_MAKE_BINARY_OP(_OpAdd, add);
-REGISTER_MAKE_BINARY_OP(_OpSub, sub);
-REGISTER_MAKE_BINARY_OP(_OpMul, mul);
-REGISTER_MAKE_BINARY_OP(_OpDiv, div);
-REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
-REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
-REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
-REGISTER_MAKE_BINARY_OP(_OpLogAddExp, logaddexp);
-REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
-REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
-REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv);
-REGISTER_MAKE_BINARY_OP(_OpPow, pow);
-REGISTER_MAKE_BINARY_OP(_OpMin, min);
-REGISTER_MAKE_BINARY_OP(_OpMax, max);
-REGISTER_MAKE_BINARY_OP(_OpEQ, equal);
-REGISTER_MAKE_BINARY_OP(_OpNE, not_equal);
-REGISTER_MAKE_BINARY_OP(_OpLT, less); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, less_equal); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, greater); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGE, greater_equal);
-REGISTER_MAKE_BINARY_OP(_OpAnd, logical_and);
-REGISTER_MAKE_BINARY_OP(_OpOr, logical_or);
-REGISTER_MAKE_BIT_OP(bitwise_and, bitwise_and);
-REGISTER_MAKE_BIT_OP(bitwise_or, bitwise_or);
-REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
-REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
-REGISTER_MAKE_BIT_OP(right_shift, right_shift);
-
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
@@ -1213,7 +1184,36 @@ TVM_FFI_STATIC_INIT_BLOCK({
          [](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
            return if_then_else(cond, true_value, false_value, span);
          })
-      .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); });
+      .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); })
+ .DEF_MAKE_BINARY_OP(_OpAdd, add)
+ .DEF_MAKE_BINARY_OP(_OpSub, sub)
+ .DEF_MAKE_BINARY_OP(_OpMul, mul)
+ .DEF_MAKE_BINARY_OP(_OpDiv, div)
+ .DEF_MAKE_BINARY_OP(_OpMod, truncmod)
+ .DEF_MAKE_BINARY_OP(_OpIndexDiv, indexdiv)
+ .DEF_MAKE_BINARY_OP(_OpIndexMod, indexmod)
+ .DEF_MAKE_BINARY_OP(_OpFloorDiv, floordiv)
+ .DEF_MAKE_BINARY_OP(_OpLogAddExp, logaddexp)
+ .DEF_MAKE_BINARY_OP(_OpFloorMod, floormod)
+ .DEF_MAKE_BINARY_OP(_OpTruncDiv, truncdiv)
+ .DEF_MAKE_BINARY_OP(_OpTruncMod, truncmod)
+ .DEF_MAKE_BINARY_OP(_OpCeilDiv, ceildiv)
+ .DEF_MAKE_BINARY_OP(_OpPow, pow)
+ .DEF_MAKE_BINARY_OP(_OpMin, min)
+ .DEF_MAKE_BINARY_OP(_OpMax, max)
+ .DEF_MAKE_BINARY_OP(_OpEQ, equal)
+ .DEF_MAKE_BINARY_OP(_OpNE, not_equal)
+ .DEF_MAKE_BINARY_OP(_OpLT, less) // NOLINT(*)
+ .DEF_MAKE_BINARY_OP(_OpLE, less_equal) // NOLINT(*)
+ .DEF_MAKE_BINARY_OP(_OpGT, greater) // NOLINT(*)
+ .DEF_MAKE_BINARY_OP(_OpGE, greater_equal)
+ .DEF_MAKE_BINARY_OP(_OpAnd, logical_and)
+ .DEF_MAKE_BINARY_OP(_OpOr, logical_or)
+ .DEF_MAKE_BIT_OP(bitwise_and, bitwise_and)
+ .DEF_MAKE_BIT_OP(bitwise_or, bitwise_or)
+ .DEF_MAKE_BIT_OP(bitwise_xor, bitwise_xor)
+ .DEF_MAKE_BIT_OP(left_shift, left_shift) // NOLINT(*)
+ .DEF_MAKE_BIT_OP(right_shift, right_shift);
});
PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index 99b43b82e9..1ca901c6fb 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -32,52 +32,53 @@ namespace topi {
using namespace tvm;
using namespace tvm::runtime;
-#define TOPI_REGISTER_BCAST_OP(OpName, Op)                                                 \
-  TVM_FFI_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \
-    bool lhs_is_tensor = args[0].as<tvm::te::Tensor>().has_value();                        \
-    bool rhs_is_tensor = args[1].as<tvm::te::Tensor>().has_value();                        \
-    if (lhs_is_tensor && rhs_is_tensor) {                                                  \
-      *rv = Op(args[0].cast<tvm::te::Tensor>(), args[1].cast<tvm::te::Tensor>());          \
-    } else if (!lhs_is_tensor && rhs_is_tensor) {                                          \
-      *rv = Op(args[0].cast<tvm::PrimExpr>(), args[1].cast<tvm::te::Tensor>());            \
-    } else if (lhs_is_tensor && !rhs_is_tensor) {                                          \
-      *rv = Op(args[0].cast<tvm::te::Tensor>(), args[1].cast<tvm::PrimExpr>());            \
-    } else if (!lhs_is_tensor && !rhs_is_tensor) {                                         \
-      *rv = Op(args[0].cast<tvm::PrimExpr>(), args[1].cast<tvm::PrimExpr>());              \
-    }                                                                                      \
-  });
-
-TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
-TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
-TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
-TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
-TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
-TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp);
-TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
-TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
-TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
-TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
-TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
-TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
-TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
-TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
-TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
-TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
-TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
-TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
-TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
-TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
-TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
-TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
-TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
-TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
-TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
+#define TOPI_DEF_BCAST_OP(OpName, Op)                                                      \
+  def_packed(OpName, [](ffi::PackedArgs args, ffi::Any* rv) {                              \
+    bool lhs_is_tensor = args[0].as<tvm::te::Tensor>().has_value();                        \
+    bool rhs_is_tensor = args[1].as<tvm::te::Tensor>().has_value();                        \
+    if (lhs_is_tensor && rhs_is_tensor) {                                                  \
+      *rv = Op(args[0].cast<tvm::te::Tensor>(), args[1].cast<tvm::te::Tensor>());          \
+    } else if (!lhs_is_tensor && rhs_is_tensor) {                                          \
+      *rv = Op(args[0].cast<tvm::PrimExpr>(), args[1].cast<tvm::te::Tensor>());            \
+    } else if (lhs_is_tensor && !rhs_is_tensor) {                                          \
+      *rv = Op(args[0].cast<tvm::te::Tensor>(), args[1].cast<tvm::PrimExpr>());            \
+    } else if (!lhs_is_tensor && !rhs_is_tensor) {                                         \
+      *rv = Op(args[0].cast<tvm::PrimExpr>(), args[1].cast<tvm::PrimExpr>());              \
+    }                                                                                      \
+  })
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def_packed("topi.broadcast_to", [](ffi::PackedArgs args, ffi::Any* rv) {
-    *rv = broadcast_to(args[0].cast<te::Tensor>(), args[1].cast<Array<PrimExpr>>());
-  });
+  refl::GlobalDef()
+      .def_packed("topi.broadcast_to",
+                  [](ffi::PackedArgs args, ffi::Any* rv) {
+                    *rv = broadcast_to(args[0].cast<te::Tensor>(), args[1].cast<Array<PrimExpr>>());
+                  })
+ .TOPI_DEF_BCAST_OP("topi.add", topi::add)
+ .TOPI_DEF_BCAST_OP("topi.subtract", topi::subtract)
+ .TOPI_DEF_BCAST_OP("topi.multiply", topi::multiply)
+ .TOPI_DEF_BCAST_OP("topi.divide", topi::divide)
+ .TOPI_DEF_BCAST_OP("topi.floor_divide", topi::floor_divide)
+ .TOPI_DEF_BCAST_OP("topi.log_add_exp", topi::log_add_exp)
+ .TOPI_DEF_BCAST_OP("topi.mod", topi::mod)
+ .TOPI_DEF_BCAST_OP("topi.floor_mod", topi::floor_mod)
+ .TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum)
+ .TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum)
+ .TOPI_DEF_BCAST_OP("topi.power", topi::power)
+ .TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift)
+ .TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and)
+ .TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or)
+ .TOPI_DEF_BCAST_OP("topi.logical_xor", topi::logical_xor)
+ .TOPI_DEF_BCAST_OP("topi.bitwise_and", topi::bitwise_and)
+ .TOPI_DEF_BCAST_OP("topi.bitwise_or", topi::bitwise_or)
+ .TOPI_DEF_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor)
+ .TOPI_DEF_BCAST_OP("topi.right_shift", topi::right_shift)
+ .TOPI_DEF_BCAST_OP("topi.greater", topi::greater)
+ .TOPI_DEF_BCAST_OP("topi.less", topi::less)
+ .TOPI_DEF_BCAST_OP("topi.equal", topi::equal)
+ .TOPI_DEF_BCAST_OP("topi.not_equal", topi::not_equal)
+ .TOPI_DEF_BCAST_OP("topi.greater_equal", topi::greater_equal)
+ .TOPI_DEF_BCAST_OP("topi.less_equal", topi::less_equal);
});
} // namespace topi
diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md
index 955cd58dc2..8d185fcbeb 100644
--- a/tests/python/contrib/test_hexagon/README_RPC.md
+++ b/tests/python/contrib/test_hexagon/README_RPC.md
@@ -80,12 +80,15 @@ Which eventually jumps to the following line in C++, which creates a RPC client
[https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129)
```cpp
-TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
-  auto url = args[0].cast<std::string>();
-  int port = args[1].cast<int>();
-  auto key = args[2].cast<std::string>();
-  *rv = RPCClientConnect(url, port, key,
-                         ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_packed("rpc.Connect", [](ffi::PackedArgs args, ffi::Any* rv) {
+    auto url = args[0].cast<std::string>();
+    int port = args[1].cast<int>();
+    auto key = args[2].cast<std::string>();
+    *rv = RPCClientConnect(url, port, key,
+                           ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
+  });
});
```
@@ -94,8 +97,11 @@ TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args,
[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106)
```cpp
-TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session")
-    .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def_packed(
+      "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) {
auto session_name = args[0].cast<std::string>();
int remote_stack_size_bytes = args[1].cast<int>();
HexagonTransportChannel* hexagon_channel =
@@ -105,6 +111,7 @@ TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session")
auto sess = CreateClientSession(ep);
*rv = CreateRPCSessionModule(sess);
});
+});
```
`HexagonTransportChannel` is the one that actually knows how to talk to
Hexagon. It uses functions such as `hexagon_rpc_send`, `hexagon_rpc_receive`
defined in