This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3680a0d5a2 [RUNTIME][VULKAN] Support total_global_memory (#16890)
3680a0d5a2 is described below
commit 3680a0d5a23da22124c17a845a39f3ae36b70ca3
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 16 15:48:41 2024 -0400
[RUNTIME][VULKAN] Support total_global_memory (#16890)
This PR adds support for the total_global_memory query on Vulkan devices.
---
src/runtime/vulkan/vulkan_device.cc | 7 +++++--
src/runtime/vulkan/vulkan_device.h | 2 ++
src/runtime/vulkan/vulkan_device_api.cc | 1 +
3 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc
index 7c5ac55f0b..cc39972432 100644
--- a/src/runtime/vulkan/vulkan_device.cc
+++ b/src/runtime/vulkan/vulkan_device.cc
@@ -293,7 +293,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
- size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+ int64_t heap_size = static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
// match copy requirment
@@ -312,7 +312,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
- size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+ int64_t heap_size =
static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
// match copy requirment
@@ -324,8 +324,10 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
if (rank > win_rank) {
win_rank = rank;
compute_mtype_index = k;
+ compute_memory_size = heap_size;
}
}
+
ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
if (device_properties.supports_push_descriptor) {
@@ -383,6 +385,7 @@ void VulkanDevice::do_swap(VulkanDevice&& other) {
std::swap(queue_insert_debug_utils_label_functions,
other.queue_insert_debug_utils_label_functions);
std::swap(compute_mtype_index, other.compute_mtype_index);
+ std::swap(compute_memory_size, other.compute_memory_size);
std::swap(queue, other.queue);
std::swap(queue_family_index, other.queue_family_index);
std::swap(physical_device_, other.physical_device_);
diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h
index 296483a6b1..0573a00e5c 100644
--- a/src/runtime/vulkan/vulkan_device.h
+++ b/src/runtime/vulkan/vulkan_device.h
@@ -223,6 +223,8 @@ class VulkanDevice {
queue_insert_debug_utils_label_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};
+ // maximum memory size for compute
+ int64_t compute_memory_size{0};
// queue family_index;
uint32_t queue_family_index{uint32_t(-1)};
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index 18a40bf54f..4b337dd524 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -165,6 +165,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;
case kTotalGlobalMemory: {
+ *rv = device(index).compute_memory_size;
return;
}
}