This is an automated email from the ASF dual-hosted git repository.
jiekaichang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new ae6e4219f [QDP] PyTorch CUDA stream-aware encode for GPU tensors (#930)
ae6e4219f is described below
commit ae6e4219f1474f5dbb4b15c625dbc0c3e862570e
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Fri Jan 30 17:45:05 2026 +0800
[QDP] PyTorch CUDA stream-aware encode for GPU tensors (#930)
* GPU pointer validation stream sync
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* fix ci error
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* fix conflicts
Signed-off-by: 400Ping <[email protected]>
* update
Signed-off-by: 400Ping <[email protected]>
* fix conflicts
Signed-off-by: 400Ping <[email protected]>
* fix pre-commit
Signed-off-by: 400Ping <[email protected]>
---------
Signed-off-by: 400Ping <[email protected]>
---
qdp/qdp-core/src/gpu/cuda_sync.rs | 76 +++++++++++++
qdp/qdp-core/src/gpu/encodings/amplitude.rs | 18 +++-
qdp/qdp-core/src/gpu/mod.rs | 2 +
qdp/qdp-core/src/gpu/pipeline.rs | 18 ++--
qdp/qdp-core/src/lib.rs | 100 +++++++++++++----
qdp/qdp-python/src/lib.rs | 160 ++++++++++++++++++++++++++++
6 files changed, 341 insertions(+), 33 deletions(-)
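Usage note: with this change, encode() on a PyTorch CUDA tensor is ordered on
the tensor device's current PyTorch stream rather than the default stream. A
minimal sketch of the intended interop (the `qdp` import name and the
`QdpEngine(device_id)` constructor are illustrative assumptions, not shown in
this diff):

    import torch
    from qdp import QdpEngine  # assumed import name

    engine = QdpEngine(0)      # assumed constructor taking a device ordinal
    x = torch.rand(1024, dtype=torch.float64, device="cuda:0")

    # Producer work and the encode below are enqueued on the same non-default
    # stream, so ordering holds without a device-wide synchronize.
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        x.mul_(2.0)                             # enqueued on s
        qt = engine.encode(x, 10, "amplitude")  # 2^10 == 1024 amplitudes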
diff --git a/qdp/qdp-core/src/gpu/cuda_sync.rs b/qdp/qdp-core/src/gpu/cuda_sync.rs
new file mode 100644
index 000000000..077da077d
--- /dev/null
+++ b/qdp/qdp-core/src/gpu/cuda_sync.rs
@@ -0,0 +1,76 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Shared CUDA stream synchronization with unified error reporting.
+
+#[cfg(target_os = "linux")]
+use std::ffi::c_void;
+
+use crate::error::{MahoutError, Result, cuda_error_to_string};
+
+/// Synchronizes a CUDA stream and returns a consistent error with context.
+///
+/// Error message format: `"{context}: {code} ({description})"` so that all
+/// call sites report stream sync failures the same way.
+///
+/// # Arguments
+/// * `stream` - CUDA stream pointer (e.g. from PyTorch or default null)
+/// * `context` - Short description for the error message (e.g. "Norm stream synchronize failed")
+///
+/// # Safety
+/// The stream pointer must be valid for the duration of this call.
+#[cfg(target_os = "linux")]
+pub(crate) fn sync_cuda_stream(stream: *mut c_void, context: &str) -> Result<()> {
+ let ret = unsafe { crate::gpu::cuda_ffi::cudaStreamSynchronize(stream) };
+ if ret != 0 {
+ return Err(MahoutError::Cuda(format!(
+ "{}: {} ({})",
+ context,
+ ret,
+ cuda_error_to_string(ret)
+ )));
+ }
+ Ok(())
+}
+
+#[cfg(all(test, target_os = "linux"))]
+mod tests {
+ use super::*;
+ use std::ffi::c_void;
+ use std::ptr;
+
+ #[test]
+ fn sync_null_stream_does_not_panic() {
+ // Default stream (null) sync: may succeed or fail depending on driver/context.
+ let _ = sync_cuda_stream(ptr::null_mut::<c_void>(), "test context");
+ }
+
+ #[test]
+ fn error_message_format_includes_context() {
+ // When sync fails, error must be MahoutError::Cuda with format "{context}: {code} ({desc})".
+ // We build the same format as sync_cuda_stream to assert consistency.
+ let context = "TestContext";
+ let code = 999i32;
+ let desc = crate::error::cuda_error_to_string(code);
+ let msg = format!("{}: {} ({})", context, code, desc);
+ assert!(
+ msg.starts_with("TestContext:"),
+ "format should start with context"
+ );
+ assert!(msg.contains("TestContext"), "format should contain context");
+ assert!(msg.contains("999"), "format should contain error code");
+ }
+}
diff --git a/qdp/qdp-core/src/gpu/encodings/amplitude.rs b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
index 34c486c57..917336a08 100644
--- a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
+++ b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
@@ -33,6 +33,8 @@ use cudarc::driver::CudaDevice;
#[cfg(target_os = "linux")]
use crate::gpu::cuda_ffi::cudaMemsetAsync;
#[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
+#[cfg(target_os = "linux")]
use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
#[cfg(target_os = "linux")]
use cudarc::driver::{DevicePtr, DevicePtrMut};
@@ -436,6 +438,18 @@ impl AmplitudeEncoder {
device: &Arc<CudaDevice>,
input_ptr: *const f64,
len: usize,
+ ) -> Result<f64> {
+ unsafe {
+ Self::calculate_inv_norm_gpu_with_stream(device, input_ptr, len, std::ptr::null_mut())
+ }
+ }
+
+ #[cfg(target_os = "linux")]
+ pub(crate) unsafe fn calculate_inv_norm_gpu_with_stream(
+ device: &Arc<CudaDevice>,
+ input_ptr: *const f64,
+ len: usize,
+ stream: *mut c_void,
) -> Result<f64> {
crate::profile_scope!("GPU::NormSingle");
@@ -448,7 +462,7 @@ impl AmplitudeEncoder {
input_ptr,
len,
*norm_buffer.device_ptr_mut() as *mut f64,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -460,6 +474,8 @@ impl AmplitudeEncoder {
)));
}
+ sync_cuda_stream(stream, "Norm stream synchronize failed")?;
+
let inv_norm_host = device
.dtoh_sync_copy(&norm_buffer)
.map_err(|e| MahoutError::Cuda(format!("Failed to copy norm to host: {:?}", e)))?;
diff --git a/qdp/qdp-core/src/gpu/mod.rs b/qdp/qdp-core/src/gpu/mod.rs
index 964662af7..7e16be7be 100644
--- a/qdp/qdp-core/src/gpu/mod.rs
+++ b/qdp/qdp-core/src/gpu/mod.rs
@@ -16,6 +16,8 @@
#[cfg(target_os = "linux")]
pub mod buffer_pool;
+#[cfg(target_os = "linux")]
+pub(crate) mod cuda_sync;
pub mod encodings;
pub mod memory;
#[cfg(target_os = "linux")]
diff --git a/qdp/qdp-core/src/gpu/pipeline.rs b/qdp/qdp-core/src/gpu/pipeline.rs
index 073ab1c7d..175621603 100644
--- a/qdp/qdp-core/src/gpu/pipeline.rs
+++ b/qdp/qdp-core/src/gpu/pipeline.rs
@@ -29,9 +29,11 @@ use crate::gpu::buffer_pool::{PinnedBufferHandle, PinnedBufferPool};
#[cfg(target_os = "linux")]
use crate::gpu::cuda_ffi::{
CUDA_EVENT_DISABLE_TIMING, CUDA_MEMCPY_HOST_TO_DEVICE, cudaEventCreateWithFlags,
- cudaEventDestroy, cudaEventRecord, cudaMemcpyAsync, cudaStreamSynchronize, cudaStreamWaitEvent,
+ cudaEventDestroy, cudaEventRecord, cudaMemcpyAsync, cudaStreamWaitEvent,
};
#[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
+#[cfg(target_os = "linux")]
use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
#[cfg(target_os = "linux")]
use crate::gpu::overlap_tracker::OverlapTracker;
@@ -179,16 +181,10 @@ impl PipelineContext {
/// The context and its copy stream must be valid and not destroyed while syncing.
pub unsafe fn sync_copy_stream(&self) -> Result<()> {
crate::profile_scope!("Pipeline::SyncCopy");
- unsafe {
- let ret = cudaStreamSynchronize(self.stream_copy.stream as *mut c_void);
- if ret != 0 {
- return Err(MahoutError::Cuda(format!(
- "cudaStreamSynchronize(copy) failed: {}",
- ret
- )));
- }
- }
- Ok(())
+ sync_cuda_stream(
+ self.stream_copy.stream as *mut c_void,
+ "cudaStreamSynchronize(copy) failed",
+ )
}
}
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 9a5290447..3de648fc9 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -35,9 +35,13 @@ mod profiling;
pub use error::{MahoutError, Result, cuda_error_to_string};
pub use gpu::memory::Precision;
+#[cfg(target_os = "linux")]
+use std::ffi::c_void;
use std::sync::Arc;
use crate::dlpack::DLManagedTensor;
+#[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
use crate::gpu::get_encoder;
use cudarc::driver::CudaDevice;
@@ -362,6 +366,9 @@ impl QdpEngine {
/// a raw GPU pointer directly, avoiding the GPU→CPU→GPU copy that would otherwise
/// be required.
///
+ /// Uses the default CUDA stream. For PyTorch stream interop, use
+ /// `encode_from_gpu_ptr_with_stream`.
+ ///
/// TODO: Refactor to use QuantumEncoder trait (add `encode_from_gpu_ptr` to trait)
/// to reduce duplication with AmplitudeEncoder::encode(). This would also make it
/// easier to add GPU pointer support for other encoders (angle, basis) in the future.
@@ -387,6 +394,34 @@ impl QdpEngine {
input_len: usize,
num_qubits: usize,
encoding_method: &str,
+ ) -> Result<*mut DLManagedTensor> {
+ unsafe {
+ self.encode_from_gpu_ptr_with_stream(
+ input_d,
+ input_len,
+ num_qubits,
+ encoding_method,
+ std::ptr::null_mut(),
+ )
+ }
+ }
+
+ /// Encode from existing GPU pointer on a specified CUDA stream.
+ ///
+ /// The caller must ensure the stream is valid for the device, and that any
+ /// producer work on that stream has been enqueued before this call.
+ ///
+ /// # Safety
+ /// In addition to the `encode_from_gpu_ptr` requirements, the stream pointer
+ /// must remain valid for the duration of this call.
+ #[cfg(target_os = "linux")]
+ pub unsafe fn encode_from_gpu_ptr_with_stream(
+ &self,
+ input_d: *const f64,
+ input_len: usize,
+ num_qubits: usize,
+ encoding_method: &str,
+ stream: *mut c_void,
) -> Result<*mut DLManagedTensor> {
crate::profile_scope!("Mahout::EncodeFromGpuPtr");
@@ -399,7 +434,7 @@ impl QdpEngine {
validate_cuda_input_ptr(&self.device, input_d)?;
let state_len = 1usize << num_qubits;
- let method = encoding_method.to_ascii_lowercase();
+ let method = encoding_method.to_lowercase();
match method.as_str() {
"amplitude" => {
@@ -417,12 +452,12 @@ impl QdpEngine {
let inv_norm = {
crate::profile_scope!("GPU::NormFromPtr");
- // SAFETY: input_d validity is guaranteed by the caller's safety contract
unsafe {
- gpu::AmplitudeEncoder::calculate_inv_norm_gpu(
+ gpu::AmplitudeEncoder::calculate_inv_norm_gpu_with_stream(
&self.device,
input_d,
input_len,
+ stream,
)?
}
};
@@ -442,7 +477,7 @@ impl QdpEngine {
input_len,
state_len,
inv_norm,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -457,9 +492,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::Synchronize");
- self.device.synchronize().map_err(|e| {
- MahoutError::Cuda(format!("CUDA device synchronize failed: {:?}", e))
- })?;
+ sync_cuda_stream(stream, "CUDA stream synchronize failed")?;
}
let state_vector = state_vector.to_precision(&self.device, self.precision)?;
@@ -492,7 +525,7 @@ impl QdpEngine {
state_ptr as *mut std::ffi::c_void,
state_len,
num_qubits as u32,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -507,9 +540,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::Synchronize");
- self.device.synchronize().map_err(|e| {
- MahoutError::Cuda(format!("CUDA device synchronize failed: {:?}", e))
- })?;
+ sync_cuda_stream(stream, "CUDA stream synchronize failed")?;
}
let state_vector = state_vector.to_precision(&self.device, self.precision)?;
@@ -525,6 +556,8 @@ impl QdpEngine {
/// Encode batch from existing GPU pointer (zero-copy for CUDA tensors)
///
/// This method enables zero-copy batch encoding from PyTorch CUDA tensors.
+ /// Uses the default CUDA stream. For PyTorch stream interop, use
+ /// `encode_batch_from_gpu_ptr_with_stream`.
///
/// TODO: Refactor to use QuantumEncoder trait (see `encode_from_gpu_ptr` TODO).
///
@@ -551,6 +584,33 @@ impl QdpEngine {
sample_size: usize,
num_qubits: usize,
encoding_method: &str,
+ ) -> Result<*mut DLManagedTensor> {
+ unsafe {
+ self.encode_batch_from_gpu_ptr_with_stream(
+ input_batch_d,
+ num_samples,
+ sample_size,
+ num_qubits,
+ encoding_method,
+ std::ptr::null_mut(),
+ )
+ }
+ }
+
+ /// Encode batch from existing GPU pointer on a specified CUDA stream.
+ ///
+ /// # Safety
+ /// In addition to the `encode_batch_from_gpu_ptr` requirements, the stream pointer
+ /// must remain valid for the duration of this call.
+ #[cfg(target_os = "linux")]
+ pub unsafe fn encode_batch_from_gpu_ptr_with_stream(
+ &self,
+ input_batch_d: *const f64,
+ num_samples: usize,
+ sample_size: usize,
+ num_qubits: usize,
+ encoding_method: &str,
+ stream: *mut c_void,
) -> Result<*mut DLManagedTensor> {
crate::profile_scope!("Mahout::EncodeBatchFromGpuPtr");
@@ -602,7 +662,7 @@ impl QdpEngine {
num_samples,
sample_size,
*buffer.device_ptr_mut() as *mut f64,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -619,6 +679,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::NormValidation");
+ sync_cuda_stream(stream, "Norm stream synchronize
failed")?;
let host_inv_norms =
self.device.dtoh_sync_copy(&inv_norms_gpu).map_err(|e|
{
MahoutError::Cuda(format!("Failed to copy norms to
host: {:?}", e))
@@ -650,7 +711,7 @@ impl QdpEngine {
num_samples,
sample_size,
state_len,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -665,9 +726,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::Synchronize");
- self.device
- .synchronize()
- .map_err(|e| MahoutError::Cuda(format!("Sync failed:
{:?}", e)))?;
+ sync_cuda_stream(stream, "CUDA stream synchronize
failed")?;
}
let batch_state_vector =
@@ -702,7 +761,7 @@ impl QdpEngine {
num_samples,
sample_size,
*buffer.device_ptr_mut() as *mut f64,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -719,6 +778,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::AngleFiniteValidationHostCopy");
+ sync_cuda_stream(stream, "Angle norm stream synchronize
failed")?;
let host_norms = self
.device
.dtoh_sync_copy(&angle_validation_buffer)
@@ -758,7 +818,7 @@ impl QdpEngine {
num_samples,
state_len,
num_qubits as u32,
- std::ptr::null_mut(), // default stream
+ stream,
)
};
@@ -773,9 +833,7 @@ impl QdpEngine {
{
crate::profile_scope!("GPU::Synchronize");
- self.device
- .synchronize()
- .map_err(|e| MahoutError::Cuda(format!("Sync failed:
{:?}", e)))?;
+ sync_cuda_stream(stream, "CUDA stream synchronize
failed")?;
}
let batch_state_vector =
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 1af58a617..1a6ebe735 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -20,6 +20,7 @@ use pyo3::ffi;
use pyo3::prelude::*;
use qdp_core::dlpack::DLManagedTensor;
use qdp_core::{Precision, QdpEngine as CoreEngine};
+use std::ffi::c_void;
/// Quantum tensor wrapper implementing DLPack protocol
///
@@ -221,6 +222,59 @@ fn get_tensor_device_id(tensor: &Bound<'_, PyAny>) -> PyResult<i32> {
Ok(device_index)
}
+/// Get the current CUDA stream pointer for the tensor's device.
+fn get_torch_cuda_stream_ptr(tensor: &Bound<'_, PyAny>) -> PyResult<*mut c_void> {
+ let py = tensor.py();
+ let torch = PyModule::import(py, "torch")
+ .map_err(|_| PyRuntimeError::new_err("Failed to import torch
module"))?;
+ let cuda = torch.getattr("cuda")?;
+ let device = tensor.getattr("device")?;
+ let stream = cuda.call_method1("current_stream", (device,))?;
+
+ // Defensive validation: ensure the stream is a CUDA stream on the same device
+ let stream_device = stream.getattr("device").map_err(|_| {
+ PyRuntimeError::new_err("CUDA stream object from PyTorch is missing
'device' attribute")
+ })?;
+ let stream_device_type: String = stream_device
+ .getattr("type")
+ .and_then(|obj| obj.extract())
+ .map_err(|_| {
+ PyRuntimeError::new_err(
+ "Failed to extract CUDA stream device type from PyTorch
stream.device",
+ )
+ })?;
+ if stream_device_type != "cuda" {
+ return Err(PyRuntimeError::new_err(format!(
+ "Expected CUDA stream device type 'cuda', got '{}'",
+ stream_device_type
+ )));
+ }
+
+ let stream_device_index: i32 = stream_device
+ .getattr("index")
+ .and_then(|obj| obj.extract())
+ .map_err(|_| {
+ PyRuntimeError::new_err(
+ "Failed to extract CUDA stream device index from PyTorch
stream.device",
+ )
+ })?;
+ let tensor_device_index = get_tensor_device_id(tensor)?;
+ if stream_device_index != tensor_device_index {
+ return Err(PyRuntimeError::new_err(format!(
+ "CUDA stream device index ({}) does not match tensor device index
({})",
+ stream_device_index, tensor_device_index
+ )));
+ }
+
+ let stream_ptr: u64 = stream.getattr("cuda_stream")?.extract()?;
+ if stream_ptr == 0 {
+ return Err(PyRuntimeError::new_err(
+ "PyTorch returned a null CUDA stream pointer",
+ ));
+ }
+ Ok(stream_ptr as *mut c_void)
+}
+
/// Validate a CUDA tensor for direct GPU encoding
/// Checks: dtype=float64, contiguous, non-empty, device_id matches engine
fn validate_cuda_tensor_for_encoding(
@@ -275,6 +329,38 @@ fn validate_cuda_tensor_for_encoding(
Ok(())
}
+/// Minimal CUDA tensor metadata extracted via PyTorch APIs.
+struct CudaTensorInfo {
+ data_ptr: *const f64,
+ shape: Vec<i64>,
+}
+
+/// Extract GPU pointer and shape directly from a PyTorch CUDA tensor.
+///
+/// # Safety
+/// The returned pointer is borrowed from the source tensor. The caller must
+/// ensure the tensor remains alive and unmodified for the duration of use.
fn extract_cuda_tensor_info(tensor: &Bound<'_, PyAny>) -> PyResult<CudaTensorInfo> {
+ let data_ptr: u64 = tensor.call_method0("data_ptr")?.extract()?;
+ if data_ptr == 0 {
+ return Err(PyRuntimeError::new_err(
+ "PyTorch returned a null data pointer for CUDA tensor",
+ ));
+ }
+
+ let ndim: usize = tensor.call_method0("dim")?.extract()?;
+ let mut shape = Vec::with_capacity(ndim);
+ for axis in 0..ndim {
+ let dim: i64 = tensor.call_method1("size", (axis,))?.extract()?;
+ shape.push(dim);
+ }
+
+ Ok(CudaTensorInfo {
+ data_ptr: data_ptr as *const f64,
+ shape,
+ })
+}
+
/// DLPack tensor information extracted from a PyCapsule
///
/// This struct owns the DLManagedTensor pointer and ensures proper cleanup
@@ -465,6 +551,80 @@ impl QdpEngine {
// Check if it's a PyTorch tensor
if is_pytorch_tensor(data)? {
+ // Check if it's a CUDA tensor - use zero-copy GPU encoding
+ if is_cuda_tensor(data)? {
+ // Validate CUDA tensor for direct GPU encoding
+ validate_cuda_tensor_for_encoding(
+ data,
+ self.engine.device().ordinal(),
+ encoding_method,
+ )?;
+
+ // Extract GPU pointer directly from PyTorch tensor
+ let tensor_info = extract_cuda_tensor_info(data)?;
+ let stream_ptr = get_torch_cuda_stream_ptr(data)?;
+
+ let ndim: usize = data.call_method0("dim")?.extract()?;
+
+ match ndim {
+ 1 => {
+ // 1D CUDA tensor: single sample encoding
+ let input_len = tensor_info.shape[0] as usize;
+ // SAFETY: tensor_info.data_ptr was obtained via PyTorch's data_ptr() from a
+ // valid CUDA tensor. The tensor remains alive during this call
+ // (held by Python's GIL), and we validated dtype/contiguity/device above.
+ let ptr = unsafe {
+ self.engine
+ .encode_from_gpu_ptr_with_stream(
+ tensor_info.data_ptr,
+ input_len,
+ num_qubits,
+ encoding_method,
+ stream_ptr,
+ )
+ .map_err(|e| {
+ PyRuntimeError::new_err(format!("Encoding
failed: {}", e))
+ })?
+ };
+ return Ok(QuantumTensor {
+ ptr,
+ consumed: false,
+ });
+ }
+ 2 => {
+ // 2D CUDA tensor: batch encoding
+ let num_samples = tensor_info.shape[0] as usize;
+ let sample_size = tensor_info.shape[1] as usize;
+ // SAFETY: Same as above - pointer from validated PyTorch CUDA tensor
+ let ptr = unsafe {
+ self.engine
+ .encode_batch_from_gpu_ptr_with_stream(
+ tensor_info.data_ptr,
+ num_samples,
+ sample_size,
+ num_qubits,
+ encoding_method,
+ stream_ptr,
+ )
+ .map_err(|e| {
+ PyRuntimeError::new_err(format!("Encoding
failed: {}", e))
+ })?
+ };
+ return Ok(QuantumTensor {
+ ptr,
+ consumed: false,
+ });
+ }
+ _ => {
+ return Err(PyRuntimeError::new_err(format!(
+ "Unsupported CUDA tensor shape: {}D. Expected 1D
tensor for single \
+ sample encoding or 2D tensor (batch_size,
features) for batch encoding.",
+ ndim
+ )));
+ }
+ }
+ }
+ // CPU PyTorch tensor path
return self.encode_from_pytorch(data, num_qubits, encoding_method);
}
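The 2D branch above goes through the same entry point for batches; continuing
the sketch (same assumed names as the usage note near the top):

    xb = torch.rand(32, 1024, dtype=torch.float64, device="cuda:0")
    with torch.cuda.stream(s):
        qt_batch = engine.encode(xb, 10, "amplitude")  # (batch, features) in,
                                                       # batch of states out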