This is an automated email from the ASF dual-hosted git repository.

jiekaichang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new ae6e4219f [QDP] PyTorch CUDA stream‑aware encode for GPU tensors (#930)
ae6e4219f is described below

commit ae6e4219f1474f5dbb4b15c625dbc0c3e862570e
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Fri Jan 30 17:45:05 2026 +0800

    [QDP] PyTorch CUDA stream‑aware encode for GPU tensors (#930)
    
    * GPU pointer validation stream sync
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix ci error
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix conflicts
    
    Signed-off-by: 400Ping <[email protected]>
    
    * update
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix conflicts
    
    Signed-off-by: 400Ping <[email protected]>
    
    * fix pre-commit
    
    Signed-off-by: 400Ping <[email protected]>
    
    ---------
    
    Signed-off-by: 400Ping <[email protected]>
    Signed-off-by: 400Ping <[email protected]>
---
 qdp/qdp-core/src/gpu/cuda_sync.rs           |  76 +++++++++++++
 qdp/qdp-core/src/gpu/encodings/amplitude.rs |  18 +++-
 qdp/qdp-core/src/gpu/mod.rs                 |   2 +
 qdp/qdp-core/src/gpu/pipeline.rs            |  18 ++--
 qdp/qdp-core/src/lib.rs                     | 100 +++++++++++++----
 qdp/qdp-python/src/lib.rs                   | 160 ++++++++++++++++++++++++++++
 6 files changed, 341 insertions(+), 33 deletions(-)

diff --git a/qdp/qdp-core/src/gpu/cuda_sync.rs b/qdp/qdp-core/src/gpu/cuda_sync.rs
new file mode 100644
index 000000000..077da077d
--- /dev/null
+++ b/qdp/qdp-core/src/gpu/cuda_sync.rs
@@ -0,0 +1,76 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Shared CUDA stream synchronization with unified error reporting.
+
+#[cfg(target_os = "linux")]
+use std::ffi::c_void;
+
+use crate::error::{MahoutError, Result, cuda_error_to_string};
+
+/// Synchronizes a CUDA stream and returns a consistent error with context.
+///
+/// Error message format: `"{context}: {code} ({description})"` so that all
+/// call sites report stream sync failures the same way.
+///
+/// # Arguments
+/// * `stream` - CUDA stream pointer (e.g. from PyTorch or default null)
+/// * `context` - Short description for the error message (e.g. "Norm stream 
synchronize failed")
+///
+/// # Safety
+/// The stream pointer must be valid for the duration of this call.
+#[cfg(target_os = "linux")]
+pub(crate) fn sync_cuda_stream(stream: *mut c_void, context: &str) -> 
Result<()> {
+    let ret = unsafe { crate::gpu::cuda_ffi::cudaStreamSynchronize(stream) };
+    if ret != 0 {
+        return Err(MahoutError::Cuda(format!(
+            "{}: {} ({})",
+            context,
+            ret,
+            cuda_error_to_string(ret)
+        )));
+    }
+    Ok(())
+}
+
+#[cfg(all(test, target_os = "linux"))]
+mod tests {
+    use super::*;
+    use std::ffi::c_void;
+    use std::ptr;
+
+    #[test]
+    fn sync_null_stream_does_not_panic() {
+        // Default stream (null) sync: may succeed or fail depending on driver/context.
+        let _ = sync_cuda_stream(ptr::null_mut::<c_void>(), "test context");
+    }
+
+    #[test]
+    fn error_message_format_includes_context() {
+        // Exercise the real error path instead of re-deriving the format string here.
+        // Without a usable GPU/driver the default-stream sync fails and the message must
+        // read "{context}: {code} ({desc})"; on a working GPU there is nothing to check.
+        let context = "TestContext";
+        let Err(err) = sync_cuda_stream(ptr::null_mut::<c_void>(), context) else {
+            return;
+        };
+        let MahoutError::Cuda(msg) = err else {
+            panic!("stream sync failures must be reported as MahoutError::Cuda");
+        };
+        assert!(msg.starts_with("TestContext: "), "missing context prefix: {msg}");
+        assert!(msg.ends_with(')'), "missing trailing error description: {msg}");
+    }
+}
diff --git a/qdp/qdp-core/src/gpu/encodings/amplitude.rs b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
index 34c486c57..917336a08 100644
--- a/qdp/qdp-core/src/gpu/encodings/amplitude.rs
+++ b/qdp/qdp-core/src/gpu/encodings/amplitude.rs
@@ -33,6 +33,8 @@ use cudarc::driver::CudaDevice;
 #[cfg(target_os = "linux")]
 use crate::gpu::cuda_ffi::cudaMemsetAsync;
 #[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
+#[cfg(target_os = "linux")]
 use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
 #[cfg(target_os = "linux")]
 use cudarc::driver::{DevicePtr, DevicePtrMut};
@@ -436,6 +438,18 @@ impl AmplitudeEncoder {
         device: &Arc<CudaDevice>,
         input_ptr: *const f64,
         len: usize,
+    ) -> Result<f64> {
+        unsafe {
+            Self::calculate_inv_norm_gpu_with_stream(device, input_ptr, len, 
std::ptr::null_mut())
+        }
+    }
+
+    #[cfg(target_os = "linux")]
+    pub(crate) unsafe fn calculate_inv_norm_gpu_with_stream(
+        device: &Arc<CudaDevice>,
+        input_ptr: *const f64,
+        len: usize,
+        stream: *mut c_void,
     ) -> Result<f64> {
         crate::profile_scope!("GPU::NormSingle");
 
@@ -448,7 +462,7 @@ impl AmplitudeEncoder {
                 input_ptr,
                 len,
                 *norm_buffer.device_ptr_mut() as *mut f64,
-                std::ptr::null_mut(), // default stream
+                stream,
             )
         };
 
@@ -460,6 +474,8 @@ impl AmplitudeEncoder {
             )));
         }
 
+        sync_cuda_stream(stream, "Norm stream synchronize failed")?;
+
         let inv_norm_host = device
             .dtoh_sync_copy(&norm_buffer)
             .map_err(|e| MahoutError::Cuda(format!("Failed to copy norm to 
host: {:?}", e)))?;
diff --git a/qdp/qdp-core/src/gpu/mod.rs b/qdp/qdp-core/src/gpu/mod.rs
index 964662af7..7e16be7be 100644
--- a/qdp/qdp-core/src/gpu/mod.rs
+++ b/qdp/qdp-core/src/gpu/mod.rs
@@ -16,6 +16,8 @@
 
 #[cfg(target_os = "linux")]
 pub mod buffer_pool;
+#[cfg(target_os = "linux")]
+pub(crate) mod cuda_sync;
 pub mod encodings;
 pub mod memory;
 #[cfg(target_os = "linux")]
diff --git a/qdp/qdp-core/src/gpu/pipeline.rs b/qdp/qdp-core/src/gpu/pipeline.rs
index 073ab1c7d..175621603 100644
--- a/qdp/qdp-core/src/gpu/pipeline.rs
+++ b/qdp/qdp-core/src/gpu/pipeline.rs
@@ -29,9 +29,11 @@ use crate::gpu::buffer_pool::{PinnedBufferHandle, PinnedBufferPool};
 #[cfg(target_os = "linux")]
 use crate::gpu::cuda_ffi::{
     CUDA_EVENT_DISABLE_TIMING, CUDA_MEMCPY_HOST_TO_DEVICE, 
cudaEventCreateWithFlags,
-    cudaEventDestroy, cudaEventRecord, cudaMemcpyAsync, cudaStreamSynchronize, 
cudaStreamWaitEvent,
+    cudaEventDestroy, cudaEventRecord, cudaMemcpyAsync, cudaStreamWaitEvent,
 };
 #[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
+#[cfg(target_os = "linux")]
 use crate::gpu::memory::{ensure_device_memory_available, map_allocation_error};
 #[cfg(target_os = "linux")]
 use crate::gpu::overlap_tracker::OverlapTracker;
@@ -179,16 +181,10 @@ impl PipelineContext {
     /// The context and its copy stream must be valid and not destroyed while 
syncing.
     pub unsafe fn sync_copy_stream(&self) -> Result<()> {
         crate::profile_scope!("Pipeline::SyncCopy");
-        unsafe {
-            let ret = cudaStreamSynchronize(self.stream_copy.stream as *mut 
c_void);
-            if ret != 0 {
-                return Err(MahoutError::Cuda(format!(
-                    "cudaStreamSynchronize(copy) failed: {}",
-                    ret
-                )));
-            }
-        }
-        Ok(())
+        sync_cuda_stream(
+            self.stream_copy.stream as *mut c_void,
+            "cudaStreamSynchronize(copy) failed",
+        )
     }
 }
 
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 9a5290447..3de648fc9 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -35,9 +35,13 @@ mod profiling;
 pub use error::{MahoutError, Result, cuda_error_to_string};
 pub use gpu::memory::Precision;
 
+#[cfg(target_os = "linux")]
+use std::ffi::c_void;
 use std::sync::Arc;
 
 use crate::dlpack::DLManagedTensor;
+#[cfg(target_os = "linux")]
+use crate::gpu::cuda_sync::sync_cuda_stream;
 use crate::gpu::get_encoder;
 use cudarc::driver::CudaDevice;
 
@@ -362,6 +366,9 @@ impl QdpEngine {
     /// a raw GPU pointer directly, avoiding the GPU→CPU→GPU copy that would 
otherwise
     /// be required.
     ///
+    /// Uses the default CUDA stream. For PyTorch stream interop, use
+    /// `encode_from_gpu_ptr_with_stream`.
+    ///
     /// TODO: Refactor to use QuantumEncoder trait (add `encode_from_gpu_ptr` 
to trait)
     /// to reduce duplication with AmplitudeEncoder::encode(). This would also 
make it
     /// easier to add GPU pointer support for other encoders (angle, basis) in 
the future.
@@ -387,6 +394,34 @@ impl QdpEngine {
         input_len: usize,
         num_qubits: usize,
         encoding_method: &str,
+    ) -> Result<*mut DLManagedTensor> {
+        unsafe {
+            self.encode_from_gpu_ptr_with_stream(
+                input_d,
+                input_len,
+                num_qubits,
+                encoding_method,
+                std::ptr::null_mut(),
+            )
+        }
+    }
+
+    /// Encode from existing GPU pointer on a specified CUDA stream.
+    ///
+    /// The caller must ensure the stream is valid for the device, and that any
+    /// producer work on that stream has been enqueued before this call.
+    ///
+    /// # Safety
+    /// In addition to the `encode_from_gpu_ptr` requirements, the stream 
pointer
+    /// must remain valid for the duration of this call.
+    #[cfg(target_os = "linux")]
+    pub unsafe fn encode_from_gpu_ptr_with_stream(
+        &self,
+        input_d: *const f64,
+        input_len: usize,
+        num_qubits: usize,
+        encoding_method: &str,
+        stream: *mut c_void,
     ) -> Result<*mut DLManagedTensor> {
         crate::profile_scope!("Mahout::EncodeFromGpuPtr");
 
@@ -399,7 +434,7 @@ impl QdpEngine {
         validate_cuda_input_ptr(&self.device, input_d)?;
 
         let state_len = 1usize << num_qubits;
-        let method = encoding_method.to_ascii_lowercase();
+        let method = encoding_method.to_lowercase();
 
         match method.as_str() {
             "amplitude" => {
@@ -417,12 +452,12 @@ impl QdpEngine {
 
                 let inv_norm = {
                     crate::profile_scope!("GPU::NormFromPtr");
-                    // SAFETY: input_d validity is guaranteed by the caller's 
safety contract
                     unsafe {
-                        gpu::AmplitudeEncoder::calculate_inv_norm_gpu(
+                        
gpu::AmplitudeEncoder::calculate_inv_norm_gpu_with_stream(
                             &self.device,
                             input_d,
                             input_len,
+                            stream,
                         )?
                     }
                 };
@@ -442,7 +477,7 @@ impl QdpEngine {
                             input_len,
                             state_len,
                             inv_norm,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -457,9 +492,7 @@ impl QdpEngine {
 
                 {
                     crate::profile_scope!("GPU::Synchronize");
-                    self.device.synchronize().map_err(|e| {
-                        MahoutError::Cuda(format!("CUDA device synchronize 
failed: {:?}", e))
-                    })?;
+                    sync_cuda_stream(stream, "CUDA stream synchronize 
failed")?;
                 }
 
                 let state_vector = state_vector.to_precision(&self.device, 
self.precision)?;
@@ -492,7 +525,7 @@ impl QdpEngine {
                             state_ptr as *mut std::ffi::c_void,
                             state_len,
                             num_qubits as u32,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -507,9 +540,7 @@ impl QdpEngine {
 
                 {
                     crate::profile_scope!("GPU::Synchronize");
-                    self.device.synchronize().map_err(|e| {
-                        MahoutError::Cuda(format!("CUDA device synchronize 
failed: {:?}", e))
-                    })?;
+                    sync_cuda_stream(stream, "CUDA stream synchronize 
failed")?;
                 }
 
                 let state_vector = state_vector.to_precision(&self.device, 
self.precision)?;
@@ -525,6 +556,8 @@ impl QdpEngine {
     /// Encode batch from existing GPU pointer (zero-copy for CUDA tensors)
     ///
     /// This method enables zero-copy batch encoding from PyTorch CUDA tensors.
+    /// Uses the default CUDA stream. For PyTorch stream interop, use
+    /// `encode_batch_from_gpu_ptr_with_stream`.
     ///
     /// TODO: Refactor to use QuantumEncoder trait (see `encode_from_gpu_ptr` 
TODO).
     ///
@@ -551,6 +584,33 @@ impl QdpEngine {
         sample_size: usize,
         num_qubits: usize,
         encoding_method: &str,
+    ) -> Result<*mut DLManagedTensor> {
+        unsafe {
+            self.encode_batch_from_gpu_ptr_with_stream(
+                input_batch_d,
+                num_samples,
+                sample_size,
+                num_qubits,
+                encoding_method,
+                std::ptr::null_mut(),
+            )
+        }
+    }
+
+    /// Encode batch from existing GPU pointer on a specified CUDA stream.
+    ///
+    /// # Safety
+    /// In addition to the `encode_batch_from_gpu_ptr` requirements, the 
stream pointer
+    /// must remain valid for the duration of this call.
+    #[cfg(target_os = "linux")]
+    pub unsafe fn encode_batch_from_gpu_ptr_with_stream(
+        &self,
+        input_batch_d: *const f64,
+        num_samples: usize,
+        sample_size: usize,
+        num_qubits: usize,
+        encoding_method: &str,
+        stream: *mut c_void,
     ) -> Result<*mut DLManagedTensor> {
         crate::profile_scope!("Mahout::EncodeBatchFromGpuPtr");
 
@@ -602,7 +662,7 @@ impl QdpEngine {
                             num_samples,
                             sample_size,
                             *buffer.device_ptr_mut() as *mut f64,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -619,6 +679,7 @@ impl QdpEngine {
 
                 {
                     crate::profile_scope!("GPU::NormValidation");
+                    sync_cuda_stream(stream, "Norm stream synchronize 
failed")?;
                     let host_inv_norms =
                         self.device.dtoh_sync_copy(&inv_norms_gpu).map_err(|e| 
{
                             MahoutError::Cuda(format!("Failed to copy norms to 
host: {:?}", e))
@@ -650,7 +711,7 @@ impl QdpEngine {
                             num_samples,
                             sample_size,
                             state_len,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -665,9 +726,7 @@ impl QdpEngine {
 
                 {
                     crate::profile_scope!("GPU::Synchronize");
-                    self.device
-                        .synchronize()
-                        .map_err(|e| MahoutError::Cuda(format!("Sync failed: 
{:?}", e)))?;
+                    sync_cuda_stream(stream, "CUDA stream synchronize 
failed")?;
                 }
 
                 let batch_state_vector =
@@ -702,7 +761,7 @@ impl QdpEngine {
                             num_samples,
                             sample_size,
                             *buffer.device_ptr_mut() as *mut f64,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -719,6 +778,7 @@ impl QdpEngine {
 
                 {
                     
crate::profile_scope!("GPU::AngleFiniteValidationHostCopy");
+                    sync_cuda_stream(stream, "Angle norm stream synchronize 
failed")?;
                     let host_norms = self
                         .device
                         .dtoh_sync_copy(&angle_validation_buffer)
@@ -758,7 +818,7 @@ impl QdpEngine {
                             num_samples,
                             state_len,
                             num_qubits as u32,
-                            std::ptr::null_mut(), // default stream
+                            stream,
                         )
                     };
 
@@ -773,9 +833,7 @@ impl QdpEngine {
 
                 {
                     crate::profile_scope!("GPU::Synchronize");
-                    self.device
-                        .synchronize()
-                        .map_err(|e| MahoutError::Cuda(format!("Sync failed: 
{:?}", e)))?;
+                    sync_cuda_stream(stream, "CUDA stream synchronize 
failed")?;
                 }
 
                 let batch_state_vector =
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 1af58a617..1a6ebe735 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -20,6 +20,7 @@ use pyo3::ffi;
 use pyo3::prelude::*;
 use qdp_core::dlpack::DLManagedTensor;
 use qdp_core::{Precision, QdpEngine as CoreEngine};
+use std::ffi::c_void;
 
 /// Quantum tensor wrapper implementing DLPack protocol
 ///
@@ -221,6 +222,59 @@ fn get_tensor_device_id(tensor: &Bound<'_, PyAny>) -> PyResult<i32> {
     Ok(device_index)
 }
 
+/// Get the current PyTorch CUDA stream handle for the tensor's device.
+///
+/// A null handle is accepted: PyTorch reports 0 for its default (legacy null) stream,
+/// which is the same default stream the non-`_with_stream` core encode APIs use.
+fn get_torch_cuda_stream_ptr(tensor: &Bound<'_, PyAny>) -> PyResult<*mut c_void> {
+    let py = tensor.py();
+    let torch = PyModule::import(py, "torch")
+        .map_err(|_| PyRuntimeError::new_err("Failed to import torch module"))?;
+    let cuda = torch.getattr("cuda")?;
+    let device = tensor.getattr("device")?;
+    let stream = cuda.call_method1("current_stream", (device,))?;
+
+    // Defensive validation: ensure the stream is a CUDA stream on the same device
+    let stream_device = stream.getattr("device").map_err(|_| {
+        PyRuntimeError::new_err("CUDA stream object from PyTorch is missing 'device' attribute")
+    })?;
+    let stream_device_type: String = stream_device
+        .getattr("type")
+        .and_then(|obj| obj.extract())
+        .map_err(|_| {
+            PyRuntimeError::new_err(
+                "Failed to extract CUDA stream device type from PyTorch stream.device",
+            )
+        })?;
+    if stream_device_type != "cuda" {
+        return Err(PyRuntimeError::new_err(format!(
+            "Expected CUDA stream device type 'cuda', got '{}'",
+            stream_device_type
+        )));
+    }
+
+    let stream_device_index: i32 = stream_device
+        .getattr("index")
+        .and_then(|obj| obj.extract())
+        .map_err(|_| {
+            PyRuntimeError::new_err(
+                "Failed to extract CUDA stream device index from PyTorch stream.device",
+            )
+        })?;
+    let tensor_device_index = get_tensor_device_id(tensor)?;
+    if stream_device_index != tensor_device_index {
+        return Err(PyRuntimeError::new_err(format!(
+            "CUDA stream device index ({}) does not match tensor device index ({})",
+            stream_device_index, tensor_device_index
+        )));
+    }
+
+    // `cuda_stream` is the raw cudaStream_t as an integer. Do not reject 0: it is the
+    // legacy default stream, so kernels and the final sync simply run on that stream.
+    let stream_ptr: usize = stream.getattr("cuda_stream")?.extract()?;
+    Ok(stream_ptr as *mut c_void)
+}
+
 /// Validate a CUDA tensor for direct GPU encoding
 /// Checks: dtype=float64, contiguous, non-empty, device_id matches engine
 fn validate_cuda_tensor_for_encoding(
@@ -275,6 +329,38 @@ fn validate_cuda_tensor_for_encoding(
     Ok(())
 }
 
+/// Minimal CUDA tensor metadata read through PyTorch's own APIs.
+/// `shape` holds one size per tensor dimension, in PyTorch's axis order.
+struct CudaTensorInfo {
+    data_ptr: *const f64,
+    shape: Vec<i64>,
+}
+
+/// Read the raw device pointer and the per-axis sizes of a PyTorch CUDA tensor.
+///
+/// # Pointer lifetime
+/// `data_ptr` is borrowed from `tensor`, not owned; it is only meaningful while the
+/// tensor stays alive and unmodified. A null `data_ptr()` is reported as an error.
+fn extract_cuda_tensor_info(tensor: &Bound<'_, PyAny>) -> PyResult<CudaTensorInfo> {
+    let raw_ptr: u64 = tensor.call_method0("data_ptr")?.extract()?;
+    if raw_ptr == 0 {
+        return Err(PyRuntimeError::new_err(
+            "PyTorch returned a null data pointer for CUDA tensor",
+        ));
+    }
+
+    // One `size(axis)` call per dimension, short-circuiting on the first failure.
+    let rank: usize = tensor.call_method0("dim")?.extract()?;
+    let shape = (0..rank)
+        .map(|axis| tensor.call_method1("size", (axis,))?.extract::<i64>())
+        .collect::<PyResult<Vec<i64>>>()?;
+
+    Ok(CudaTensorInfo {
+        data_ptr: raw_ptr as *const f64,
+        shape,
+    })
+}
+
 /// DLPack tensor information extracted from a PyCapsule
 ///
 /// This struct owns the DLManagedTensor pointer and ensures proper cleanup
@@ -465,6 +551,80 @@ impl QdpEngine {
 
         // Check if it's a PyTorch tensor
         if is_pytorch_tensor(data)? {
+            // Check if it's a CUDA tensor - use zero-copy GPU encoding
+            if is_cuda_tensor(data)? {
+                // Validate CUDA tensor for direct GPU encoding
+                validate_cuda_tensor_for_encoding(
+                    data,
+                    self.engine.device().ordinal(),
+                    encoding_method,
+                )?;
+
+                // Extract GPU pointer directly from PyTorch tensor
+                let tensor_info = extract_cuda_tensor_info(data)?;
+                let stream_ptr = get_torch_cuda_stream_ptr(data)?;
+
+                let ndim: usize = data.call_method0("dim")?.extract()?;
+
+                match ndim {
+                    1 => {
+                        // 1D CUDA tensor: single sample encoding
+                        let input_len = tensor_info.shape[0] as usize;
+                        // SAFETY: tensor_info.data_ptr was obtained via 
PyTorch's data_ptr() from a
+                        // valid CUDA tensor. The tensor remains alive during 
this call
+                        // (held by Python's GIL), and we validated 
dtype/contiguity/device above.
+                        let ptr = unsafe {
+                            self.engine
+                                .encode_from_gpu_ptr_with_stream(
+                                    tensor_info.data_ptr,
+                                    input_len,
+                                    num_qubits,
+                                    encoding_method,
+                                    stream_ptr,
+                                )
+                                .map_err(|e| {
+                                    PyRuntimeError::new_err(format!("Encoding 
failed: {}", e))
+                                })?
+                        };
+                        return Ok(QuantumTensor {
+                            ptr,
+                            consumed: false,
+                        });
+                    }
+                    2 => {
+                        // 2D CUDA tensor: batch encoding
+                        let num_samples = tensor_info.shape[0] as usize;
+                        let sample_size = tensor_info.shape[1] as usize;
+                        // SAFETY: Same as above - pointer from validated 
PyTorch CUDA tensor
+                        let ptr = unsafe {
+                            self.engine
+                                .encode_batch_from_gpu_ptr_with_stream(
+                                    tensor_info.data_ptr,
+                                    num_samples,
+                                    sample_size,
+                                    num_qubits,
+                                    encoding_method,
+                                    stream_ptr,
+                                )
+                                .map_err(|e| {
+                                    PyRuntimeError::new_err(format!("Encoding 
failed: {}", e))
+                                })?
+                        };
+                        return Ok(QuantumTensor {
+                            ptr,
+                            consumed: false,
+                        });
+                    }
+                    _ => {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "Unsupported CUDA tensor shape: {}D. Expected 1D 
tensor for single \
+                             sample encoding or 2D tensor (batch_size, 
features) for batch encoding.",
+                            ndim
+                        )));
+                    }
+                }
+            }
+            // CPU PyTorch tensor path
             return self.encode_from_pytorch(data, num_qubits, encoding_method);
         }
 

Reply via email to