This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7c35267756 [Fix] add TVM_DLL to disco functions (#16258)
7c35267756 is described below
commit 7c352677568df0f12c49a4b5b8864b11fb37701f
Author: Lesheng Jin <[email protected]>
AuthorDate: Mon Dec 18 15:32:52 2023 +0800
[Fix] add TVM_DLL to disco functions (#16258)
---
include/tvm/runtime/disco/builtin.h | 4 ++--
include/tvm/runtime/disco/disco_worker.h | 2 +-
include/tvm/runtime/relax_vm/ndarray_cache_support.h | 10 +++++-----
src/runtime/disco/builtin.cc | 4 ++--
src/runtime/disco/disco_worker.cc | 2 +-
src/runtime/relax_vm/ndarray_cache_support.cc | 11 ++++++-----
6 files changed, 17 insertions(+), 16 deletions(-)
diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h
index 3847aef3f2..512059b31b 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -89,14 +89,14 @@ void AllGather(NDArray send, NDArray recv);
* \param send The buffer to be broadcasted
* \param recv The buffer receives the broadcasted array
*/
-void BroadcastFromWorker0(NDArray send, NDArray recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
/*!
* \brief Perform a scatter operation from worker-0, chunking the given buffer
into equal parts.
* \param send For worker-0, it must be provided, and otherwise, the buffer
must be None.
* The buffer will be divided into equal parts and sent to each worker
accordingly.
* \param recv The receiving buffer, which must not be None.
*/
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
/*!
* \brief Perform a gather operation to worker-0.
* \param send The sending buffer, which must not be None.
diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 0c666150d4..14f8f23807 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -60,7 +60,7 @@ class DiscoWorker {
/*! \brief Main loop of the worker */
void MainLoop();
/*! \brief Get the worker instance on the current thread */
- static DiscoWorker* ThreadLocal();
+ TVM_DLL static DiscoWorker* ThreadLocal();
/*! \brief Set the specific register to a specific value */
void SetRegister(int reg_id, TVMArgValue value);
diff --git a/include/tvm/runtime/relax_vm/ndarray_cache_support.h b/include/tvm/runtime/relax_vm/ndarray_cache_support.h
index 3d8b639ee4..584da8f0ca 100644
--- a/include/tvm/runtime/relax_vm/ndarray_cache_support.h
+++ b/include/tvm/runtime/relax_vm/ndarray_cache_support.h
@@ -63,10 +63,10 @@ struct NDArrayCacheMetadata {
};
/*! \brief Load a FileRecord into memory */
- Array<NDArray> Load(Device device, //
- const std::string& path_prefix, //
- std::string* raw_data_buffer, //
- Optional<NDArray>* staging_buffer = nullptr) const;
+ TVM_DLL Array<NDArray> Load(Device device, //
+ const std::string& path_prefix, //
+ std::string* raw_data_buffer, //
+ Optional<NDArray>* staging_buffer = nullptr)
const;
/*! \brief Relative path to the bin file */
std::string data_path;
@@ -83,7 +83,7 @@ struct NDArrayCacheMetadata {
std::string path;
/*! \brief Load the metadata from a specific directory */
- static NDArrayCacheMetadata Load(const std::string& path);
+ TVM_DLL static NDArrayCacheMetadata Load(const std::string& path);
/*! \brief Load the metadata from a given JSON string */
static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const
std::string& path);
};
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 51fe4c13fc..911fdaae3d 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -85,11 +85,11 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send,
recv); }
-void BroadcastFromWorker0(NDArray send, NDArray recv) {
+TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
GetCCLFunc("broadcast_from_worker0")(send, recv);
}
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
GetCCLFunc("scatter_from_worker0")(send, recv);
}
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index d3c6d6a383..e8ba351e79 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -37,7 +37,7 @@ struct ThreadLocalDiscoWorker {
}
};
-DiscoWorker* DiscoWorker::ThreadLocal() {
+TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
return ret;
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 613c70bb44..ce028f4d7d 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -123,7 +123,7 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
return result;
}
-NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
+TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string&
path) {
picojson::value json_info;
{
std::string json_str;
@@ -183,10 +183,11 @@ NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
return arr;
}
-Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(Device device,
- const std::string&
path_prefix, //
- std::string*
raw_data_buffer, //
- Optional<NDArray>*
staging_buffer) const {
+TVM_DLL Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(
+ Device device,
+ const std::string& path_prefix, //
+ std::string* raw_data_buffer, //
+ Optional<NDArray>* staging_buffer) const {
LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer);
CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format
is supported";
CHECK_EQ(this->nbytes, raw_data_buffer->length())