guan404ming commented on code in PR #708:
URL: https://github.com/apache/mahout/pull/708#discussion_r2613126973


##########
qdp/qdp-core/src/lib.rs:
##########
@@ -145,14 +166,160 @@ impl QdpEngine {
     ) -> Result<*mut DLManagedTensor> {
         crate::profile_scope!("Mahout::EncodeFromParquet");
 
-        // Read Parquet directly using Arrow (faster than pandas)
-        let (batch_data, num_samples, sample_size) = {
-            crate::profile_scope!("IO::ReadParquetBatch");
-            crate::io::read_parquet_batch(path)?
-        };
+        #[cfg(target_os = "linux")]
+        {
+            if encoding_method != "amplitude" {
+                return Err(MahoutError::NotImplemented("Only amplitude 
encoding supported for streaming".into()));
+            }
 
-        // Encode using fused batch kernel
-        self.encode_batch(&batch_data, num_samples, sample_size, num_qubits, 
encoding_method)
+            // Initialize reader
+            let mut reader_core = crate::io::ParquetBlockReader::new(path)?;
+            let num_samples = reader_core.total_samples;
+
+            // Allocate GPU memory once
+            let total_state_vector = GpuStateVector::new_batch(&self.device, 
num_samples, num_qubits)?;
+
+            // Initialize dual-stream pipeline context
+            let ctx = PipelineContext::new(&self.device)?;
+
+            // Double-buffered device input (ping-pong)
+            let dev_in_a = unsafe { 
self.device.alloc::<f64>(STAGE_SIZE_ELEMENTS) }
+                .map_err(|e| MahoutError::MemoryAllocation(format!("{:?}", 
e)))?;
+            let dev_in_b = unsafe { 
self.device.alloc::<f64>(STAGE_SIZE_ELEMENTS) }
+                .map_err(|e| MahoutError::MemoryAllocation(format!("{:?}", 
e)))?;
+
+            // Setup Producer-Consumer channels
+            let (full_buf_tx, full_buf_rx): (SyncSender<(PinnedBuffer, 
usize)>, Receiver<(PinnedBuffer, usize)>) = sync_channel(2);
+            let (empty_buf_tx, empty_buf_rx): (SyncSender<PinnedBuffer>, 
Receiver<PinnedBuffer>) = sync_channel(2);
+
+            // CRITICAL FIX: Pre-read first chunk to determine sample_size
+            // This data must be processed, not discarded!
+            let mut host_buf_first = PinnedBuffer::new(STAGE_SIZE_ELEMENTS)?;
+            let first_len = 
reader_core.read_chunk(host_buf_first.as_slice_mut())?;
+
+            let sample_size = reader_core.get_sample_size()
+                .ok_or_else(|| MahoutError::InvalidInput("Could not determine 
sample size".into()))?;
+
+            // Send first chunk directly to GPU loop (must be processed first)
+            full_buf_tx.send((host_buf_first, first_len))
+                .map_err(|_| MahoutError::Io("Failed to send first 
buffer".into()))?;
+
+            // Send one empty buffer to IO thread for subsequent reads
+            empty_buf_tx.send(PinnedBuffer::new(STAGE_SIZE_ELEMENTS)?)
+                .map_err(|_| MahoutError::Io("Failed to send second 
buffer".into()))?;
+
+            // Spawn IO thread (Producer): continues reading from second chunk 
onwards
+            let mut reader = reader_core;
+            let io_handle = thread::spawn(move || {
+                loop {
+                    let mut buffer = match empty_buf_rx.recv() {
+                        Ok(b) => b,
+                        Err(_) => break,
+                    };
+
+                    let len = match reader.read_chunk(buffer.as_slice_mut()) {
+                        Ok(l) => l,
+                        Err(e) => { eprintln!("IO Error: {:?}", e); 0 }
+                    };
+
+                    if full_buf_tx.send((buffer, len)).is_err() { break; }
+                    if len == 0 { break; }
+                }
+            });
+
+            // GPU processing loop: receives pre-read chunk, then IO thread 
chunks
+            let mut global_sample_offset = 0;
+            let mut use_dev_a = true;
+            let state_len_per_sample = 1 << num_qubits;
+
+            loop {
+                let (host_buffer, current_len) = full_buf_rx.recv()
+                    .map_err(|_| MahoutError::Io("IO thread 
disconnected".into()))?;
+
+                // len == 0 means IO thread finished (don't recycle buffer)
+                if current_len == 0 { break; }
+
+                let samples_in_chunk = current_len / sample_size;
+                if samples_in_chunk > 0 {
+                    let dev_ptr = if use_dev_a { *dev_in_a.device_ptr() } else 
{ *dev_in_b.device_ptr() };
+
+                    unsafe {
+                        crate::profile_scope!("GPU::Dispatch");
+
+                        // Async H2D copy → record event → wait for copy → 
launch kernel
+                        ctx.async_copy_to_device(&host_buffer, dev_ptr as *mut 
c_void, current_len);
+                        ctx.record_copy_done();
+                        ctx.wait_for_copy();
+
+                        // Compute norms and encode batch
+                        {
+                            crate::profile_scope!("GPU::BatchEncode");
+                            let offset_elements = global_sample_offset * 
state_len_per_sample;
+                            let state_ptr_offset = 
total_state_vector.ptr().cast::<u8>()
+                                .add(offset_elements * 
std::mem::size_of::<qdp_kernels::CuDoubleComplex>())
+                                .cast::<std::ffi::c_void>();
+
+                            // Allocate norm buffer for this chunk
+                            let mut norm_buffer = 
self.device.alloc_zeros::<f64>(samples_in_chunk)
+                                .map_err(|e| 
MahoutError::MemoryAllocation(format!("Failed to allocate norm buffer: {:?}", 
e)))?;
+
+                            // Step 1: Compute L2 norms for this chunk
+                            {
+                                crate::profile_scope!("GPU::NormBatch");
+                                let ret = launch_l2_norm_batch(
+                                    dev_ptr as *const f64,
+                                    samples_in_chunk,
+                                    sample_size,
+                                    *norm_buffer.device_ptr_mut() as *mut f64,
+                                    ctx.stream_compute.stream as *mut c_void
+                                );
+                                if ret != 0 {
+                                    return 
Err(MahoutError::KernelLaunch(format!("Norm kernel error: {}", ret)));
+                                }
+                            }
+
+                            // Step 2: Encode batch using computed norms
+                            {
+                                crate::profile_scope!("GPU::EncodeBatch");
+                                let ret = launch_amplitude_encode_batch(
+                                    dev_ptr as *const f64,
+                                    state_ptr_offset,
+                                    *norm_buffer.device_ptr() as *const f64,
+                                    samples_in_chunk,
+                                    sample_size,
+                                    state_len_per_sample,
+                                    ctx.stream_compute.stream as *mut c_void
+                                );
+                                if ret != 0 {
+                                    return 
Err(MahoutError::KernelLaunch(format!("Encode kernel error: {}", ret)));
+                                }
+                            }
+                        }
+
+                        // Sync copy stream before buffer reuse
+                        ctx.sync_copy_stream();
+                    }
+                    global_sample_offset += samples_in_chunk;
+                    use_dev_a = !use_dev_a;
+                }
+
+                // Return buffer to IO thread (ignore errors if thread exited)
+                let _ = empty_buf_tx.send(host_buffer);
+            }
+
+            self.device.synchronize().map_err(|e| 
MahoutError::Cuda(format!("{:?}", e)))?;
+            let _ = io_handle.join();
+
+            // Transfer ownership to DLPack (Arc handles ref counting)
+            let dlpack_ptr = total_state_vector.to_dlpack();
+            Ok(dlpack_ptr)
+        }
+
+        #[cfg(not(target_os = "linux"))]
+        {
+            let (batch_data, num_samples, sample_size) = 
crate::io::read_parquet_batch(path)?;
+            self.encode_batch(&batch_data, num_samples, sample_size, 
num_qubits, encoding_method)
+        }
     }

Review Comment:
   It's fine with me to handle the tests either in a follow-up PR or here in this one.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to