viiccwen commented on code in PR #934:
URL: https://github.com/apache/mahout/pull/934#discussion_r2740643607


##########
qdp/qdp-core/src/lib.rs:
##########
@@ -424,151 +484,220 @@ impl QdpEngine {
     /// TODO: Refactor to use QuantumEncoder trait (see `encode_from_gpu_ptr` 
TODO).
     ///
     /// # Arguments
-    /// * `input_batch_d` - Device pointer to batch input data (flattened f64 
array on GPU)
+    /// * `input_batch_d` - Device pointer to batch input data (f64 for 
amplitude, usize/int64 for basis)
     /// * `num_samples` - Number of samples in the batch
-    /// * `sample_size` - Size of each sample in f64 elements
+    /// * `sample_size` - Size of each sample in elements
     /// * `num_qubits` - Number of qubits for encoding
-    /// * `encoding_method` - Strategy (currently only "amplitude" supported)
+    /// * `encoding_method` - Strategy ("amplitude" or "basis")
     ///
     /// # Returns
     /// Single DLPack pointer containing all encoded states (shape: 
[num_samples, 2^num_qubits])
     ///
     /// # Safety
     /// The input pointer must:
     /// - Point to valid GPU memory on the same device as the engine
-    /// - Contain at least `num_samples * sample_size` f64 elements
+    /// - Contain at least `num_samples * sample_size` elements of the 
expected dtype
     /// - Remain valid for the duration of this call
     #[cfg(target_os = "linux")]
     pub unsafe fn encode_batch_from_gpu_ptr(
         &self,
-        input_batch_d: *const f64,
+        input_batch_d: *const std::ffi::c_void,
         num_samples: usize,
         sample_size: usize,
         num_qubits: usize,
         encoding_method: &str,
     ) -> Result<*mut DLManagedTensor> {
         crate::profile_scope!("Mahout::EncodeBatchFromGpuPtr");
 
-        if encoding_method != "amplitude" {
-            return Err(MahoutError::NotImplemented(format!(
-                "GPU pointer batch encoding currently only supports 
'amplitude' method, got '{}'",
-                encoding_method
-            )));
-        }
-
-        if num_samples == 0 {
-            return Err(MahoutError::InvalidInput(
-                "Number of samples cannot be zero".into(),
-            ));
-        }
-
-        if sample_size == 0 {
-            return Err(MahoutError::InvalidInput(
-                "Sample size cannot be zero".into(),
-            ));
-        }
-
         let state_len = 1usize << num_qubits;
-        if sample_size > state_len {
-            return Err(MahoutError::InvalidInput(format!(
-                "Sample size {} exceeds state vector size {} (2^{} qubits)",
-                sample_size, state_len, num_qubits
-            )));
-        }
-
-        // Allocate output state vector
-        let batch_state_vector = {
-            crate::profile_scope!("GPU::AllocBatch");
-            gpu::GpuStateVector::new_batch(&self.device, num_samples, 
num_qubits)?
-        };
-
-        // Compute inverse norms on GPU using warp-reduced kernel
-        let inv_norms_gpu = {
-            crate::profile_scope!("GPU::BatchNormKernel");
-            use cudarc::driver::DevicePtrMut;
-
-            let mut buffer = 
self.device.alloc_zeros::<f64>(num_samples).map_err(|e| {
-                MahoutError::MemoryAllocation(format!("Failed to allocate norm 
buffer: {:?}", e))
-            })?;
-
-            let ret = unsafe {
-                qdp_kernels::launch_l2_norm_batch(
-                    input_batch_d,
-                    num_samples,
-                    sample_size,
-                    *buffer.device_ptr_mut() as *mut f64,
-                    std::ptr::null_mut(), // default stream
-                )
-            };
-
-            if ret != 0 {
-                return Err(MahoutError::KernelLaunch(format!(
-                    "Norm reduction kernel failed with CUDA error code: {} 
({})",
-                    ret,
-                    cuda_error_to_string(ret)
-                )));
+        match encoding_method {
+            "amplitude" => {
+                if num_samples == 0 {
+                    return Err(MahoutError::InvalidInput(
+                        "Number of samples cannot be zero".into(),
+                    ));
+                }
+
+                if sample_size == 0 {
+                    return Err(MahoutError::InvalidInput(
+                        "Sample size cannot be zero".into(),
+                    ));
+                }

Review Comment:
   Ensure `encoding_method` is normalized to ASCII lowercase (e.g. via `to_ascii_lowercase()`) before matching on it.



##########
qdp/qdp-python/src/lib.rs:
##########
@@ -218,23 +219,34 @@ fn validate_cuda_tensor_for_encoding(
     expected_device_id: usize,
     encoding_method: &str,
 ) -> PyResult<()> {
-    // Check encoding method support (currently only amplitude is supported 
for CUDA tensors)
-    if encoding_method != "amplitude" {
-        return Err(PyRuntimeError::new_err(format!(
-            "CUDA tensor encoding currently only supports 'amplitude' method, 
got '{}'. \
-             Use tensor.cpu() to convert to CPU tensor for other encoding 
methods.",
-            encoding_method
-        )));
-    }
-
-    // Check dtype is float64
+    // Check encoding method support and dtype.
     let dtype = tensor.getattr("dtype")?;
     let dtype_str: String = dtype.str()?.extract()?;
-    if !dtype_str.contains("float64") {
-        return Err(PyRuntimeError::new_err(format!(
-            "CUDA tensor must have dtype float64, got {}. Use 
tensor.to(torch.float64)",
-            dtype_str
-        )));
+    match encoding_method {
+        "amplitude" => {
+            if !dtype_str.contains("float64") {
+                return Err(PyRuntimeError::new_err(format!(
+                    "CUDA tensor must have dtype float64, got {}. Use 
tensor.to(torch.float64)",
+                    dtype_str
+                )));
+            }
+        }
+        "basis" => {
+            if !dtype_str.contains("int64") {
+                return Err(PyRuntimeError::new_err(format!(
+                    "CUDA tensor must have dtype int64 for basis encoding, got 
{}. \
+                     Use tensor.to(torch.int64)",
+                    dtype_str
+                )));
+            }
+        }
+        _ => {
+            return Err(PyRuntimeError::new_err(format!(
+                "CUDA tensor encoding currently only supports 'amplitude' or 
'basis' methods, got '{}'. \
+                 Use tensor.cpu() to convert to CPU tensor for other encoding 
methods.",
+                encoding_method
+            )));
+        }

Review Comment:
   Ensure `encoding_method` is normalized to ASCII lowercase (e.g. via `to_ascii_lowercase()`) before matching on it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to