This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a7dd32cc16 [DeviceAPI] Support querying total global memory (#16398)
a7dd32cc16 is described below

commit a7dd32cc168b434b591bc4bfe1f446e42c07e9de
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jan 15 18:43:27 2024 -0500

    [DeviceAPI] Support querying total global memory (#16398)
    
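    This PR introduces a new attribute for device backends,
    `total_global_memory`, which returns the total global memory available
    on a device, in bytes. On Metal the reported value is the device's
    `recommendedMaxWorkingSetSize`; backends without support (e.g. Vulkan)
    return None.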
    
    Tested locally on CUDA/ROCm/Metal/OpenCL:
    ```python
    >>> import tvm
    >>> tvm.metal().total_global_memory
    154618822656
    ```
---
 include/tvm/runtime/device_api.h        |  1 +
 python/tvm/_ffi/runtime_ctypes.py       | 14 ++++++++++++++
 src/runtime/cuda/cuda_device_api.cc     | 10 +++++++++-
 src/runtime/metal/metal_device_api.mm   |  4 ++++
 src/runtime/opencl/opencl_device_api.cc | 10 +++++++++-
 src/runtime/rocm/rocm_device_api.cc     | 11 ++++++++++-
 src/runtime/vulkan/vulkan_device_api.cc |  4 ++++
 7 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index e33539dadd..9ff469b7c8 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -50,6 +50,7 @@ enum DeviceAttrKind : int {
   kApiVersion = 11,
   kDriverVersion = 12,
   kL2CacheSizeBytes = 13,
+  kTotalGlobalMemory = 14,
 };
 
 #ifdef TVM_KALLOC_ALIGNMENT
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 7836f42247..54e4d8f205 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -506,6 +506,20 @@ class Device(ctypes.Structure):
         """
         return self._GetDeviceAttr(self.device_type, self.device_id, 13)
 
+    @property
+    def total_global_memory(self):
+        """Return size of the total global memory.
+
+        Supported devices include CUDA/ROCm/Metal/OpenCL.
+
+        Returns
+        -------
+        total_global_memory : int or None
+            Return the global memory available on device in bytes.
+            Return None if the device does not support this feature.
+        """
+        return self._GetDeviceAttr(self.device_type, self.device_id, 14)
+
     def texture_spatial_limit(self):
         """Returns limits for textures by spatial dimensions
 
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 769f01063f..f493865e0d 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -106,12 +106,20 @@ class CUDADeviceAPI final : public DeviceAPI {
       }
       case kDriverVersion:
         return;
-      case kL2CacheSizeBytes:
+      case kL2CacheSizeBytes: {
         // Get size of device l2 cache size in bytes.
         int l2_size = 0;
         CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, 
dev.device_id));
         *rv = l2_size;
         return;
+      }
+      case kTotalGlobalMemory: {
+        cudaDeviceProp prop;
+        CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id));
+        int64_t total_global_memory = prop.totalGlobalMem;
+        *rv = total_global_memory;
+        return;
+      }
     }
     *rv = value;
   }
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index f7c2976d22..c4ffc8943c 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -89,6 +89,10 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
         return;
       case kL2CacheSizeBytes:
         return;
+      case kTotalGlobalMemory: {
+        *rv = static_cast<int64_t>([devices[dev.device_id] 
recommendedMaxWorkingSetSize]);
+        return;
+      }
     }
   };
 }
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index fb9adc2757..96ec8ed69f 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -199,13 +199,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
       *rv = std::string(value);
       break;
     }
-    case kL2CacheSizeBytes:
+    case kL2CacheSizeBytes: {
       // NOTE(Zihao): this API cannot reflect the real L2 cache size in both 
CUDA/AMD GPUs.
       cl_ulong value;
       OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, 
sizeof(value), &value,
                                   nullptr));
       *rv = static_cast<int64_t>(value);
       break;
+    }
+    case kTotalGlobalMemory: {
+      cl_ulong total_global_memory;
+      OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_SIZE, 
sizeof(total_global_memory),
+                                  &total_global_memory, nullptr));
+      *rv = static_cast<int64_t>(total_global_memory);
+      return;
+    }
   }
 }
 
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c2fb42ee36..72f17ede52 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -122,11 +122,20 @@ class ROCMDeviceAPI final : public DeviceAPI {
       }
       case kDriverVersion:
         return;
-      case kL2CacheSizeBytes:
+      case kL2CacheSizeBytes: {
         // Get size of device l2 cache size in bytes.
         int l2_size;
         ROCM_CALL(hipDeviceGetAttribute(&l2_size, 
hipDeviceAttributeL2CacheSize, device.device_id));
         *rv = l2_size;
+        return;
+      }
+      case kTotalGlobalMemory: {
+        hipDeviceProp_t prop;
+        ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id));
+        int64_t total_global_memory = prop.totalGlobalMem;
+        *rv = total_global_memory;
+        return;
+      }
     }
     *rv = value;
   }
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index d67746856c..e02c9304e1 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -163,6 +163,10 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
 
     case kL2CacheSizeBytes:
       break;
+
+    case kTotalGlobalMemory: {
+      return;
+    }
   }
 }
 

Reply via email to