This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a7dd32cc16 [DeviceAPI] Support querying total global memory (#16398)
a7dd32cc16 is described below
commit a7dd32cc168b434b591bc4bfe1f446e42c07e9de
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jan 15 18:43:27 2024 -0500
[DeviceAPI] Support querying total global memory (#16398)
This PR introduces a new attribute for device backends:
`total_global_memory`. This attribute returns the total available
global memory on a device, in bytes.
Tested locally on CUDA/ROCm/Metal/OpenCL:
```python
>>> import tvm
>>> tvm.metal().total_global_memory
154618822656
```
---
include/tvm/runtime/device_api.h | 1 +
python/tvm/_ffi/runtime_ctypes.py | 14 ++++++++++++++
src/runtime/cuda/cuda_device_api.cc | 10 +++++++++-
src/runtime/metal/metal_device_api.mm | 4 ++++
src/runtime/opencl/opencl_device_api.cc | 10 +++++++++-
src/runtime/rocm/rocm_device_api.cc | 11 ++++++++++-
src/runtime/vulkan/vulkan_device_api.cc | 4 ++++
7 files changed, 51 insertions(+), 3 deletions(-)
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index e33539dadd..9ff469b7c8 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -50,6 +50,7 @@ enum DeviceAttrKind : int {
kApiVersion = 11,
kDriverVersion = 12,
kL2CacheSizeBytes = 13,
+ kTotalGlobalMemory = 14,
};
#ifdef TVM_KALLOC_ALIGNMENT
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 7836f42247..54e4d8f205 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -506,6 +506,20 @@ class Device(ctypes.Structure):
"""
return self._GetDeviceAttr(self.device_type, self.device_id, 13)
+ @property
+ def total_global_memory(self):
+ """Return size of the total global memory.
+
+ Supported devices include CUDA/ROCm/Metal/OpenCL.
+
+ Returns
+ -------
+ total_global_memory : int or None
+ Return the global memory available on device in bytes.
+ Return None if the device does not support this feature.
+ """
+ return self._GetDeviceAttr(self.device_type, self.device_id, 14)
+
def texture_spatial_limit(self):
"""Returns limits for textures by spatial dimensions
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 769f01063f..f493865e0d 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -106,12 +106,20 @@ class CUDADeviceAPI final : public DeviceAPI {
}
case kDriverVersion:
return;
- case kL2CacheSizeBytes:
+ case kL2CacheSizeBytes: {
// Get size of device l2 cache size in bytes.
int l2_size = 0;
CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize,
dev.device_id));
*rv = l2_size;
return;
+ }
+ case kTotalGlobalMemory: {
+ cudaDeviceProp prop;
+ CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id));
+ int64_t total_global_memory = prop.totalGlobalMem;
+ *rv = total_global_memory;
+ return;
+ }
}
*rv = value;
}
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index f7c2976d22..c4ffc8943c 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -89,6 +89,10 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
return;
case kL2CacheSizeBytes:
return;
+ case kTotalGlobalMemory: {
+ *rv = static_cast<int64_t>([devices[dev.device_id]
recommendedMaxWorkingSetSize]);
+ return;
+ }
}
};
}
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index fb9adc2757..96ec8ed69f 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -199,13 +199,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
*rv = std::string(value);
break;
}
- case kL2CacheSizeBytes:
+ case kL2CacheSizeBytes: {
// NOTE(Zihao): this API cannot reflect the real L2 cache size in both
CUDA/AMD GPUs.
cl_ulong value;
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE,
sizeof(value), &value,
nullptr));
*rv = static_cast<int64_t>(value);
break;
+ }
+ case kTotalGlobalMemory: {
+ cl_ulong total_global_memory;
+ OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_SIZE,
sizeof(total_global_memory),
+ &total_global_memory, nullptr));
+ *rv = static_cast<int64_t>(total_global_memory);
+ return;
+ }
}
}
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c2fb42ee36..72f17ede52 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -122,11 +122,20 @@ class ROCMDeviceAPI final : public DeviceAPI {
}
case kDriverVersion:
return;
- case kL2CacheSizeBytes:
+ case kL2CacheSizeBytes: {
// Get size of device l2 cache size in bytes.
int l2_size;
ROCM_CALL(hipDeviceGetAttribute(&l2_size,
hipDeviceAttributeL2CacheSize, device.device_id));
*rv = l2_size;
+ return;
+ }
+ case kTotalGlobalMemory: {
+ hipDeviceProp_t prop;
+ ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id));
+ int64_t total_global_memory = prop.totalGlobalMem;
+ *rv = total_global_memory;
+ return;
+ }
}
*rv = value;
}
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index d67746856c..e02c9304e1 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -163,6 +163,10 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
case kL2CacheSizeBytes:
break;
+
+ case kTotalGlobalMemory: {
+ return;
+ }
}
}