This is an automated email from the ASF dual-hosted git repository.
richhuang pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/dev-qdp by this push:
new b6f1a3bc1 [QDP] DLPack shape/strides: Support batch 2D tensor
b6f1a3bc1 is described below
commit b6f1a3bc1b4f5c933407cb26da8e46ab7ee2f01f
Author: rich7420 <[email protected]>
AuthorDate: Fri Dec 19 16:25:15 2025 +0800
[QDP] DLPack shape/strides: Support batch 2D tensor
---
qdp/benchmark/benchmark_e2e.py | 23 ++++++----
qdp/qdp-core/src/dlpack.rs | 17 +++++--
qdp/qdp-core/src/gpu/memory.rs | 5 ++
qdp/qdp-core/src/preprocessing.rs | 6 +++
qdp/qdp-core/tests/api_workflow.rs | 92 +++++++++++++++++++++++++++++++++++++
qdp/qdp-core/tests/memory_safety.rs | 2 +-
qdp/qdp-core/tests/validation.rs | 24 ++++++++++
7 files changed, 155 insertions(+), 14 deletions(-)
diff --git a/qdp/benchmark/benchmark_e2e.py b/qdp/benchmark/benchmark_e2e.py
index b72c81d1f..0d419d0bf 100644
--- a/qdp/benchmark/benchmark_e2e.py
+++ b/qdp/benchmark/benchmark_e2e.py
@@ -277,15 +277,17 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
dlpack_time = time.perf_counter() - dlpack_start
print(f" DLPack conversion: {dlpack_time:.4f} s")
- # Reshape to [n_samples, state_len] (still complex)
+ # Tensor is already 2D [n_samples, state_len] from to_dlpack()
state_len = 1 << n_qubits
+ assert gpu_batched.shape == (n_samples, state_len), (
+ f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+ )
# Convert to float for model (batch already on GPU)
reshape_start = time.perf_counter()
- gpu_reshaped = gpu_batched.view(n_samples, state_len)
- gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+ gpu_all_data = gpu_batched.abs().to(torch.float32)
reshape_time = time.perf_counter() - reshape_start
- print(f" Reshape & convert: {reshape_time:.4f} s")
+ print(f" Convert to float32: {reshape_time:.4f} s")
# Forward pass (data already on GPU)
for i in range(0, n_samples, BATCH_SIZE):
@@ -299,7 +301,7 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
# Clean cache after benchmark completion
clean_cache()
- return total_time, gpu_reshaped
+ return total_time, gpu_batched
# -----------------------------------------------------------
@@ -325,13 +327,16 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
dlpack_time = time.perf_counter() - dlpack_start
print(f" DLPack conversion: {dlpack_time:.4f} s")
+ # Tensor is already 2D [n_samples, state_len] from to_dlpack()
state_len = 1 << n_qubits
+ assert gpu_batched.shape == (n_samples, state_len), (
+ f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
+ )
reshape_start = time.perf_counter()
- gpu_reshaped = gpu_batched.view(n_samples, state_len)
- gpu_all_data = gpu_reshaped.abs().to(torch.float32)
+ gpu_all_data = gpu_batched.abs().to(torch.float32)
reshape_time = time.perf_counter() - reshape_start
- print(f" Reshape & convert: {reshape_time:.4f} s")
+ print(f" Convert to float32: {reshape_time:.4f} s")
for i in range(0, n_samples, BATCH_SIZE):
batch = gpu_all_data[i : i + BATCH_SIZE]
@@ -344,7 +349,7 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
# Clean cache after benchmark completion
clean_cache()
- return total_time, gpu_reshaped
+ return total_time, gpu_batched
def compare_states(name_a, states_a, name_b, states_b):
diff --git a/qdp/qdp-core/src/dlpack.rs b/qdp/qdp-core/src/dlpack.rs
index 883d19b37..5a2f7ecea 100644
--- a/qdp/qdp-core/src/dlpack.rs
+++ b/qdp/qdp-core/src/dlpack.rs
@@ -120,9 +120,18 @@ impl GpuStateVector {
/// Freed by DLPack deleter when PyTorch releases tensor.
/// Do not free manually.
pub fn to_dlpack(&self) -> *mut DLManagedTensor {
- // Allocate shape/strides on heap (freed by deleter)
- let shape = vec![self.size_elements as i64];
- let strides = vec![1i64];
+ let (shape, strides, ndim) = if let Some(num_samples) = self.num_samples {
+ // Batch: 2D shape [num_samples, state_len_per_sample], row-major strides
+ let state_len_per_sample = self.size_elements / num_samples;
+ let shape = vec![num_samples as i64, state_len_per_sample as i64];
+ let strides = vec![state_len_per_sample as i64, 1i64]; // Strides in elements, not bytes
+ (shape, strides, 2)
+ } else {
+ // Single state: 1D shape [size_elements]
+ let shape = vec![self.size_elements as i64];
+ let strides = vec![1i64];
+ (shape, strides, 1)
+ };
// Transfer ownership to DLPack deleter
let shape_ptr = Box::into_raw(shape.into_boxed_slice()) as *mut i64;
@@ -142,7 +151,7 @@ impl GpuStateVector {
device_type: DLDeviceType::kDLCUDA,
device_id: 0,
},
- ndim: 1,
+ ndim,
dtype: DLDataType {
code: DL_COMPLEX,
bits: dtype_bits,
diff --git a/qdp/qdp-core/src/gpu/memory.rs b/qdp/qdp-core/src/gpu/memory.rs
index 26e7b1383..240ec54cf 100644
--- a/qdp/qdp-core/src/gpu/memory.rs
+++ b/qdp/qdp-core/src/gpu/memory.rs
@@ -190,6 +190,8 @@ pub struct GpuStateVector {
pub(crate) buffer: Arc<BufferStorage>,
pub num_qubits: usize,
pub size_elements: usize,
+ /// Number of samples in batch. None for single state, Some(n) for batch.
+ pub(crate) num_samples: Option<usize>,
}
// Safety: CudaSlice and Arc are both Send + Sync
@@ -229,6 +231,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
num_qubits: qubits,
size_elements: _size_elements,
+ num_samples: None,
})
}
@@ -300,6 +303,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
num_qubits: qubits,
size_elements: total_elements,
+ num_samples: Some(num_samples),
})
}
@@ -364,6 +368,7 @@ impl GpuStateVector {
buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { slice })),
num_qubits: self.num_qubits,
size_elements: self.size_elements,
+ num_samples: self.num_samples, // Preserve batch information
})
}
diff --git a/qdp/qdp-core/src/preprocessing.rs b/qdp/qdp-core/src/preprocessing.rs
index 0d8e70148..43577a8eb 100644
--- a/qdp/qdp-core/src/preprocessing.rs
+++ b/qdp/qdp-core/src/preprocessing.rs
@@ -84,6 +84,12 @@ impl Preprocessor {
sample_size: usize,
num_qubits: usize,
) -> Result<()> {
+ if num_samples == 0 {
+ return Err(MahoutError::InvalidInput(
+ "num_samples must be greater than 0".to_string()
+ ));
+ }
+
if batch_data.len() != num_samples * sample_size {
return Err(MahoutError::InvalidInput(
format!("Batch data length {} doesn't match num_samples {} * sample_size {}",
diff --git a/qdp/qdp-core/tests/api_workflow.rs b/qdp/qdp-core/tests/api_workflow.rs
index a1e97e31a..7e9de3081 100644
--- a/qdp/qdp-core/tests/api_workflow.rs
+++ b/qdp/qdp-core/tests/api_workflow.rs
@@ -107,3 +107,95 @@ fn test_amplitude_encoding_async_pipeline() {
println!("PASS: Memory freed successfully");
}
}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_batch_dlpack_2d_shape() {
+ println!("Testing batch DLPack 2D shape...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ // Create batch data: 3 samples, each with 4 elements (2 qubits)
+ let num_samples = 3;
+ let num_qubits = 2;
+ let sample_size = 4;
+ let batch_data: Vec<f64> = (0..num_samples * sample_size)
+ .map(|i| (i as f64) / 10.0)
+ .collect();
+
+ let result = engine.encode_batch(&batch_data, num_samples, sample_size, num_qubits, "amplitude");
+ assert!(result.is_ok(), "Batch encoding should succeed");
+
+ let dlpack_ptr = result.unwrap();
+ assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ // Verify 2D shape for batch tensor
+ assert_eq!(tensor.ndim, 2, "Batch tensor should be 2D");
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(shape_slice[0], num_samples as i64, "First dimension should be num_samples");
+ assert_eq!(shape_slice[1], (1 << num_qubits) as i64, "Second dimension should be 2^num_qubits");
+
+ let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
+ let state_len = 1 << num_qubits;
+ assert_eq!(strides_slice[0], state_len as i64, "Stride for first dimension should be state_len");
+ assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
+
+ println!("PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]", shape_slice[0], shape_slice[1]);
+ println!("PASS: Strides are correct: [{}, {}]", strides_slice[0], strides_slice[1]);
+
+ // Free memory
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+}
+
+#[test]
+#[cfg(target_os = "linux")]
+fn test_single_encode_still_1d() {
+ println!("Testing single encode still returns 1D shape...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => {
+ println!("SKIP: No GPU available");
+ return;
+ }
+ };
+
+ let data = common::create_test_data(16);
+ let result = engine.encode(&data, 4, "amplitude");
+ assert!(result.is_ok(), "Encoding should succeed");
+
+ let dlpack_ptr = result.unwrap();
+ assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
+
+ unsafe {
+ let managed = &*dlpack_ptr;
+ let tensor = &managed.dl_tensor;
+
+ // Verify 1D shape for single encode (backward compatibility)
+ assert_eq!(tensor.ndim, 1, "Single encode should still be 1D");
+
+ let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
+ assert_eq!(shape_slice[0], 16, "Single encode shape should be [2^4]");
+
+ println!("PASS: Single encode still returns 1D shape: [{}]", shape_slice[0]);
+
+ // Free memory
+ if let Some(deleter) = managed.deleter {
+ deleter(dlpack_ptr);
+ }
+ }
+}
diff --git a/qdp/qdp-core/tests/memory_safety.rs b/qdp/qdp-core/tests/memory_safety.rs
index 833190c48..37f45478a 100644
--- a/qdp/qdp-core/tests/memory_safety.rs
+++ b/qdp/qdp-core/tests/memory_safety.rs
@@ -114,7 +114,7 @@ fn test_dlpack_tensor_metadata() {
assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
- assert_eq!(tensor.dtype.bits, 128, "Should be 128 bits (2x64-bit floats)");
+ assert_eq!(tensor.dtype.bits, 64, "Should be 64 bits (2x32-bit floats, Float32 default)");
println!("PASS: DLPack metadata verified");
println!(" ndim: {}", tensor.ndim);
diff --git a/qdp/qdp-core/tests/validation.rs b/qdp/qdp-core/tests/validation.rs
index cc12a995a..6fc591e53 100644
--- a/qdp/qdp-core/tests/validation.rs
+++ b/qdp/qdp-core/tests/validation.rs
@@ -119,6 +119,30 @@ fn test_input_validation_max_qubits() {
}
}
+#[test]
+#[cfg(target_os = "linux")]
+fn test_input_validation_batch_zero_samples() {
+ println!("Testing zero num_samples rejection...");
+
+ let engine = match QdpEngine::new(0) {
+ Ok(e) => e,
+ Err(_) => return,
+ };
+
+ let batch_data = vec![1.0, 2.0, 3.0, 4.0];
+ let result = engine.encode_batch(&batch_data, 0, 4, 2, "amplitude");
+ assert!(result.is_err(), "Should reject zero num_samples");
+
+ match result {
+ Err(MahoutError::InvalidInput(msg)) => {
+ assert!(msg.contains("num_samples must be greater than 0"),
+ "Error should mention num_samples requirement");
+ println!("PASS: Correctly rejected zero num_samples: {}", msg);
+ }
+ _ => panic!("Expected InvalidInput error for zero num_samples"),
+ }
+}
+
#[test]
#[cfg(target_os = "linux")]
fn test_empty_data() {