CheyuWu commented on code in PR #881:
URL: https://github.com/apache/mahout/pull/881#discussion_r2708070502


##########
qdp/qdp-python/src/lib.rs:
##########
@@ -321,6 +460,78 @@ impl QdpEngine {
 
         // Check if it's a PyTorch tensor
         if is_pytorch_tensor(data)? {
+            // Check if it's a CUDA tensor - use zero-copy GPU encoding
+            if is_cuda_tensor(data)? {
+                // Validate CUDA tensor for direct GPU encoding
+                validate_cuda_tensor_for_encoding(
+                    data,
+                    self.engine.device().ordinal(),
+                    encoding_method,
+                )?;
+
+                // Extract GPU pointer via DLPack
+                let dlpack_info = extract_dlpack_tensor(data.py(), data)?;
+
+                let ndim: usize = data.call_method0("dim")?.extract()?;
+
+                match ndim {
+                    1 => {
+                        // 1D CUDA tensor: single sample encoding
+                        let input_len = dlpack_info.shape[0] as usize;
+                        // SAFETY: dlpack_info.data_ptr was validated via 
DLPack protocol from a
+                        // valid PyTorch CUDA tensor. The tensor remains alive 
during this call
+                        // (held by Python's GIL), and we validated 
dtype/contiguity/device above.
+                        let ptr = unsafe {
+                            self.engine
+                                .encode_from_gpu_ptr(
+                                    dlpack_info.data_ptr,
+                                    input_len,
+                                    num_qubits,
+                                    encoding_method,
+                                )
+                                .map_err(|e| {
+                                    PyRuntimeError::new_err(format!("Encoding 
failed: {}", e))
+                                })?
+                        };
+                        return Ok(QuantumTensor {
+                            ptr,
+                            consumed: false,
+                        });
+                    }
+                    2 => {
+                        // 2D CUDA tensor: batch encoding
+                        let num_samples = dlpack_info.shape[0] as usize;
+                        let sample_size = dlpack_info.shape[1] as usize;
+                        // SAFETY: Same as above - pointer from validated 
DLPack tensor
+                        let ptr = unsafe {
+                            self.engine
+                                .encode_batch_from_gpu_ptr(
+                                    dlpack_info.data_ptr,
+                                    num_samples,
+                                    sample_size,
+                                    num_qubits,
+                                    encoding_method,
+                                )
+                                .map_err(|e| {
+                                    PyRuntimeError::new_err(format!("Encoding 
failed: {}", e))
+                                })?
+                        };
+                        return Ok(QuantumTensor {
+                            ptr,
+                            consumed: false,
+                        });
+                    }
+                    _ => {
+                        return Err(PyRuntimeError::new_err(format!(
+                            "Unsupported CUDA tensor shape: {}D. Expected 1D 
tensor for single \
+                             sample encoding or 2D tensor (batch_size, 
features) for batch encoding.",
+                            ndim
+                        )));
+                    }
+                }
+            }
+
+            // CPU tensor path (existing code)
             validate_tensor(data)?;

Review Comment:
   I think we don't need this anymore



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to