This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a309b6b857 [Thrust] Use pointer to tls pool to prevent creating new pool (#16856)
a309b6b857 is described below
commit a309b6b857e9abc6849193cc7fa80c015fee7969
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Apr 8 17:29:35 2024 -0700
[Thrust] Use pointer to tls pool to prevent creating new pool (#16856)
---
src/runtime/contrib/thrust/thrust.cu | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 7a95b4b0a3..9e35290fab 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -54,7 +54,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
this->workspace_size = workspace->shape[0];
} else {
// Fallback to thrust TLS caching allocator if workspace is not provided.
- thrust_pool_ = thrust::mr::tls_disjoint_pool(
+ thrust_pool_ = &thrust::mr::tls_disjoint_pool(
thrust::mr::get_global_resource<thrust::device_memory_resource>(),
thrust::mr::get_global_resource<thrust::mr::new_delete_resource>());
}
@@ -67,20 +67,20 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
<< " bytes.";
return result;
}
- return thrust_pool_.do_allocate(bytes, alignment).get();
+ return thrust_pool_->do_allocate(bytes, alignment).get();
}
void do_deallocate(void* p, size_t bytes, size_t alignment) override {
if (workspace != nullptr) {
// No-op
} else {
- thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p),
bytes, alignment);
+ thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p),
bytes, alignment);
}
}
thrust::mr::disjoint_unsynchronized_pool_resource<thrust::device_memory_resource,
-
thrust::mr::new_delete_resource>
- thrust_pool_;
+
thrust::mr::new_delete_resource>* thrust_pool_ =
+ nullptr;
void* workspace = nullptr;
size_t workspace_size = 0;