tqchen commented on code in PR #15653:
URL: https://github.com/apache/tvm/pull/15653#discussion_r1312983067


##########
src/runtime/disco/builtin.cc:
##########
@@ -72,38 +70,48 @@ Module LoadVMModule(std::string path, Device device) {
   return mod;
 }
 
-TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
-
-TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body([](TVMArgs args, TVMRetValue* rv) -> void {
-  runtime::DataType dtype = args[args.num_args - 2];
-  Device device = args[args.num_args - 1];
-  int ndim = args.num_args - 2;
-  std::vector<ShapeTuple::index_type> shape;
-  for (int i = 0; i < ndim; ++i) {
-    shape.push_back(args[i].operator int64_t());
-  }
-  device = UseDefaultDeviceIfNone(device);
-  *rv = NDArray::Empty(ShapeTuple(shape), dtype, device);
-});
+NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device) {
+  return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device));
+}
 
-TVM_REGISTER_GLOBAL("runtime.disco.allreduce").set_body([](TVMArgs args, TVMRetValue* rv) -> void {
+const PackedFunc& GetCCLFunc(const char* name) {
   std::string ccl = DiscoWorker::ThreadLocal()->ccl;
-  std::string pf_name = "runtime.disco." + ccl + ".allreduce";
+  std::string pf_name = "runtime.disco." + ccl + "." + name;
   const PackedFunc* pf = tvm::runtime::Registry::Get(pf_name);
-  CHECK(pf != nullptr) << "ValueError: Cannot find the allreduce function for " << ccl << " via `"
-                       << pf_name << "`";
-  pf->CallPacked(args, rv);
-});
+  CHECK(pf != nullptr) << "ValueError: Cannot find the `" << name << "` function for `" << ccl
+                       << "` via `" << pf_name << "`";
+  return *pf;
+}
 
-TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0")
-    .set_body([](TVMArgs args, TVMRetValue* rv) -> void {
-      std::string ccl = DiscoWorker::ThreadLocal()->ccl;
-      std::string pf_name = "runtime.disco." + ccl + ".broadcast_from_worker0";
-      const PackedFunc* pf = tvm::runtime::Registry::Get(pf_name);
-      CHECK(pf != nullptr) << "ValueError: Cannot find the broadcast function for " << ccl
-                           << " via `" << pf_name << "`";
-      pf->CallPacked(args, rv);
+NDArray AllReduce(NDArray send, ReduceKind reduce_kind) {
+  return GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind));
+}
+
+NDArray BroadcastFromWorker0(NDArray buffer) {
+  return GetCCLFunc("broadcast_from_worker0")(buffer);
+}
+
+void ScatterFromWorker0(const Array<NDArray>& buffers) {

Review Comment:
   API choice: let us avoid the name "scatter", since MPI scatter has a specific expectation (scattering slices of a single buffer).
   
   Instead, let us use SendToWorkers. Since this is also not commonly used, it may be a good idea to start with DebugSendToWorkers.
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to