This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3680a0d5a2 [RUNTIME][VULKAN] Support total_global_memory (#16890)
3680a0d5a2 is described below

commit 3680a0d5a23da22124c17a845a39f3ae36b70ca3
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Apr 16 15:48:41 2024 -0400

    [RUNTIME][VULKAN] Support total_global_memory (#16890)
    
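    This PR adds support for the total_global_memory (kTotalGlobalMemory) device-attribute query on Vulkan devices. The reported value is the size of the memory heap backing the memory type selected for compute.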
---
 src/runtime/vulkan/vulkan_device.cc     | 7 +++++--
 src/runtime/vulkan/vulkan_device.h      | 2 ++
 src/runtime/vulkan/vulkan_device_api.cc | 1 +
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc
index 7c5ac55f0b..cc39972432 100644
--- a/src/runtime/vulkan/vulkan_device.cc
+++ b/src/runtime/vulkan/vulkan_device.cc
@@ -293,7 +293,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
 
   for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
     VkMemoryType ty = prop.memoryTypes[k];
-    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+    int64_t heap_size = 
static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
     // host visible
     if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
     // match copy requirment
@@ -312,7 +312,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
   win_rank = -1;
   for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
     VkMemoryType ty = prop.memoryTypes[k];
-    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+    int64_t heap_size = 
static_cast<int64_t>(prop.memoryHeaps[ty.heapIndex].size);
     // host visible
     if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
     // match copy requirment
@@ -324,8 +324,10 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
     if (rank > win_rank) {
       win_rank = rank;
       compute_mtype_index = k;
+      compute_memory_size = heap_size;
     }
   }
+
   ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
 
   if (device_properties.supports_push_descriptor) {
@@ -383,6 +385,7 @@ void VulkanDevice::do_swap(VulkanDevice&& other) {
   std::swap(queue_insert_debug_utils_label_functions,
             other.queue_insert_debug_utils_label_functions);
   std::swap(compute_mtype_index, other.compute_mtype_index);
+  std::swap(compute_memory_size, other.compute_memory_size);
   std::swap(queue, other.queue);
   std::swap(queue_family_index, other.queue_family_index);
   std::swap(physical_device_, other.physical_device_);
diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h
index 296483a6b1..0573a00e5c 100644
--- a/src/runtime/vulkan/vulkan_device.h
+++ b/src/runtime/vulkan/vulkan_device.h
@@ -223,6 +223,8 @@ class VulkanDevice {
       queue_insert_debug_utils_label_functions{nullptr};
   // Memory type index for compute
   uint32_t compute_mtype_index{0};
+  // maximum memory size for compute
+  int64_t compute_memory_size{0};
 
   // queue family_index;
   uint32_t queue_family_index{uint32_t(-1)};
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index 18a40bf54f..4b337dd524 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -165,6 +165,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
       break;
 
     case kTotalGlobalMemory: {
+      *rv = device(index).compute_memory_size;
       return;
     }
   }

Reply via email to