This is an automated email from the ASF dual-hosted git repository.

richhuang pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/dev-qdp by this push:
     new aa306793f [QDP] Fix DLPack device_id hardcoding (#745)
aa306793f is described below

commit aa306793f041340b5aaf2af7a11d03fd3d6de318
Author: KUAN-HAO HUANG <[email protected]>
AuthorDate: Tue Dec 23 10:51:14 2025 +0800

    [QDP] Fix DLPack device_id hardcoding (#745)
    
    * [QDP] Fix DLPack device_id hardcoding
    
    * update and reposition comments
    
    * improve the test
---
 qdp/qdp-core/src/dlpack.rs            |  2 +-
 qdp/qdp-core/src/gpu/memory.rs        |  4 ++++
 qdp/qdp-core/tests/api_workflow.rs    | 43 +++++++++++++++++++++++++++++++++++
 qdp/qdp-core/tests/memory_safety.rs   |  2 +-
 qdp/qdp-python/src/lib.rs             | 18 +++++++++++++--
 qdp/qdp-python/tests/test_bindings.py | 39 +++++++++++++++++++++++++++++++
 6 files changed, 104 insertions(+), 4 deletions(-)
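
For context: DLPack identifies a tensor's device with a (device_type,
device_id) pair, where device_type 2 means kDLCUDA. Before this change the
device_id field was hardcoded to 0, so a state vector allocated on any other
GPU was mislabelled as living on device 0. A minimal sketch of the contract
this commit establishes, assuming a machine with at least two CUDA devices
(API names taken from the Python test further below):

    from mahout_qdp import QdpEngine

    # Allocate on the second GPU; the engine records the device ordinal.
    engine = QdpEngine(1)
    qtensor = engine.encode([1.0, 2.0, 3.0, 4.0], 2, "amplitude")

    # The DLPack device tuple now reflects the allocating device.
    device_type, device_id = qtensor.__dlpack_device__()
    assert (device_type, device_id) == (2, 1)  # previously always (2, 0)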

diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 883d19b37..dd134ca5d 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -140,7 +140,7 @@ impl GpuStateVector {
             data: self.ptr_void(),
             device: DLDevice {
                 device_type: DLDeviceType::kDLCUDA,
-                device_id: 0,
+                device_id: self.device_id as c_int,
             },
             ndim: 1,
             dtype: DLDataType {
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index 26e7b1383..1cfd32eca 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -190,6 +190,7 @@ pub struct GpuStateVector {
     pub(crate) buffer: Arc<BufferStorage>,
     pub num_qubits: usize,
     pub size_elements: usize,
+    pub device_id: usize,
 }
 
 // Safety: CudaSlice and Arc are both Send + Sync
@@ -229,6 +230,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: _size_elements,
+                device_id: _device.ordinal(),
             })
         }
 
@@ -300,6 +302,7 @@ impl GpuStateVector {
                 buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
                 num_qubits: qubits,
                 size_elements: total_elements,
+                device_id: _device.ordinal(),
             })
         }
 
@@ -364,6 +367,7 @@ impl GpuStateVector {
                         buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { 
slice })),
                         num_qubits: self.num_qubits,
                         size_elements: self.size_elements,
+                        device_id: device.ordinal(),
                     })
                 }
 
diff --git a/qdp/qdp-core/tests/api_workflow.rs b/qdp/qdp-core/tests/api_workflow.rs
index a1e97e31a..13c2126ec 100644
--- a/qdp/qdp-core/tests/api_workflow.rs
+++ b/qdp/qdp-core/tests/api_workflow.rs
@@ -107,3 +107,46 @@ fn test_amplitude_encoding_async_pipeline() {
         println!("PASS: Memory freed successfully");
     }
 }
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_dlpack_device_id() {
+    println!("Testing DLPack device_id propagation...");
+
+    let engine = match QdpEngine::new(0) {
+        Ok(e) => e,
+        Err(_) => {
+            println!("SKIP: No GPU available");
+            return;
+        }
+    };
+
+    let data = common::create_test_data(16);
+    let result = engine.encode(&data, 4, "amplitude");
+    assert!(result.is_ok(), "Encoding should succeed");
+
+    let dlpack_ptr = result.unwrap();
+    assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+    unsafe {
+        let managed = &*dlpack_ptr;
+        let tensor = &managed.dl_tensor;
+
+        // Verify device_id is correctly set (0 for device 0)
+        assert_eq!(tensor.device.device_id, 0, "device_id should be 0 for device 0");
+
+        // Verify device_type is CUDA (kDLCUDA = 2)
+        use qdp_core::dlpack::DLDeviceType;
+        match tensor.device.device_type {
+            DLDeviceType::kDLCUDA => println!("PASS: Device type is CUDA"),
+            _ => panic!("Expected CUDA device type"),
+        }
+
+        println!("PASS: DLPack device_id correctly set to {}", tensor.device.device_id);
+
+        // Free memory
+        if let Some(deleter) = managed.deleter {
+            deleter(dlpack_ptr);
+        }
+    }
+}
diff --git a/qdp/qdp-core/tests/memory_safety.rs b/qdp/qdp-core/tests/memory_safety.rs
index 2b5fdd6e8..6aa2d355a 100644
--- a/qdp/qdp-core/tests/memory_safety.rs
+++ b/qdp/qdp-core/tests/memory_safety.rs
@@ -94,7 +94,7 @@ fn test_multiple_concurrent_states() {
 fn test_dlpack_tensor_metadata_default() {
     println!("Testing DLPack tensor metadata...");
 
-    let engine = match QdpEngine::new(0) {
+    let engine = match QdpEngine::new_with_precision(0, qdp_core::Precision::Float64) {
         Ok(e) => e,
         Err(_) => return,
     };
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 04f3c5367..d94aceeb2 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -94,8 +94,22 @@ impl QuantumTensor {
     /// Returns:
     ///     Tuple of (device_type, device_id) where device_type=2 for CUDA
     fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
-        // DLDeviceType::kDLCUDA = 2, device_id = 0
-        Ok((2, 0))
+        if self.ptr.is_null() {
+            return Err(PyRuntimeError::new_err("Invalid DLPack tensor pointer"));
+        }
+
+        unsafe {
+            let tensor = &(*self.ptr).dl_tensor;
+            // device_type is an enum, convert to integer
+            // kDLCUDA = 2, kDLCPU = 1
+            // Ref: https://github.com/dmlc/dlpack/blob/6ea9b3eb64c881f614cd4537f95f0e125a35555c/include/dlpack/dlpack.h#L76-L80
+            let device_type = match tensor.device.device_type {
+                qdp_core::dlpack::DLDeviceType::kDLCUDA => 2,
+                qdp_core::dlpack::DLDeviceType::kDLCPU => 1,
+            };
+            // Read device_id from DLPack tensor metadata
+            Ok((device_type, tensor.device.device_id))
+        }
     }
 }
 
diff --git a/qdp/qdp-python/tests/test_bindings.py b/qdp/qdp-python/tests/test_bindings.py
index 1fc586f78..d3cda3e22 100644
--- a/qdp/qdp-python/tests/test_bindings.py
+++ b/qdp/qdp-python/tests/test_bindings.py
@@ -20,6 +20,16 @@ import pytest
 import mahout_qdp
 
 
+def _has_multi_gpu():
+    """Check if multiple GPUs are available via PyTorch."""
+    try:
+        import torch
+
+        return torch.cuda.is_available() and torch.cuda.device_count() >= 2
+    except ImportError:
+        return False
+
+
 def test_import():
     """Test that PyO3 bindings are properly imported."""
     assert hasattr(mahout_qdp, "QdpEngine")
@@ -50,6 +60,35 @@ def test_dlpack_device():
     assert device_info == (2, 0), "Expected (2, 0) for CUDA device 0"
 
 
+@pytest.mark.gpu
+@pytest.mark.skipif(
+    not _has_multi_gpu(), reason="Multi-GPU setup required for this test"
+)
+def test_dlpack_device_id_non_zero():
+    """Test device_id propagation for non-zero devices (requires multi-GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from mahout_qdp import QdpEngine
+
+    # Test with device_id=1 (second GPU)
+    device_id = 1
+    engine = QdpEngine(device_id)
+    data = [1.0, 2.0, 3.0, 4.0]
+    qtensor = engine.encode(data, 2, "amplitude")
+
+    device_info = qtensor.__dlpack_device__()
+    assert device_info == (2, device_id), (
+        f"Expected (2, {device_id}) for CUDA device {device_id}"
+    )
+
+    # Verify PyTorch integration works with non-zero device_id
+    torch_tensor = torch.from_dlpack(qtensor)
+    assert torch_tensor.is_cuda
+    assert torch_tensor.device.index == device_id, (
+        f"PyTorch tensor should be on device {device_id}"
+    )
+
+
 @pytest.mark.gpu
 def test_dlpack_single_use():
     """Test that __dlpack__ can only be called once (requires GPU)."""

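
As exercised by test_dlpack_device_id_non_zero above, consumers such as
PyTorch read __dlpack_device__ before calling __dlpack__ and place the
imported tensor on the reported device. A rough usage sketch, assuming a
CUDA-enabled PyTorch install (it mirrors the assertions in the tests above):

    import torch
    from mahout_qdp import QdpEngine

    engine = QdpEngine(0)
    qtensor = engine.encode([1.0, 2.0, 3.0, 4.0], 2, "amplitude")

    # torch.from_dlpack consults the device tuple, so the tensor lands on
    # the same GPU that performed the encoding.
    t = torch.from_dlpack(qtensor)
    assert t.is_cuda and t.device.index == 0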