guan404ming commented on code in PR #851: URL: https://github.com/apache/mahout/pull/851#discussion_r2704377197
########## qdp/qdp-core/src/encoding/mod.rs: ########## @@ -0,0 +1,359 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Streaming encoding implementations for different quantum encoding methods. + +mod amplitude; +mod basis; + +use std::ffi::c_void; +use std::sync::Arc; +use std::sync::mpsc::{Receiver, SyncSender, sync_channel}; +use std::thread::{self, JoinHandle}; + +use cudarc::driver::{CudaDevice, DevicePtr}; + +/// Guard that ensures GPU synchronization and IO thread cleanup on drop. +/// Used to handle early returns in `stream_encode`. +struct CleanupGuard<'a> { + device: &'a Arc<CudaDevice>, + io_handle: Option<JoinHandle<()>>, +} + +impl<'a> CleanupGuard<'a> { + fn new(device: &'a Arc<CudaDevice>, io_handle: JoinHandle<()>) -> Self { + Self { + device, + io_handle: Some(io_handle), + } + } + + /// Defuse the guard and return the IO handle for explicit cleanup. + /// After calling this, drop() will not perform cleanup. + fn defuse(mut self) -> JoinHandle<()> { + self.io_handle.take().expect("IO handle already taken") + } +} + +impl Drop for CleanupGuard<'_> { + fn drop(&mut self) { + // Best-effort cleanup on early return + let _ = self.device.synchronize(); + if let Some(handle) = self.io_handle.take() { + let _ = handle.join(); + } + } +} + +use crate::dlpack::DLManagedTensor; +use crate::gpu::PipelineContext; +use crate::gpu::memory::{GpuStateVector, PinnedHostBuffer}; +use crate::reader::StreamingDataReader; +use crate::{MahoutError, QdpEngine, Result}; + +/// 512MB staging buffer for large Parquet row groups (reduces fragmentation) +pub(crate) const STAGE_SIZE_BYTES: usize = 512 * 1024 * 1024; +pub(crate) const STAGE_SIZE_ELEMENTS: usize = STAGE_SIZE_BYTES / std::mem::size_of::<f64>(); + +pub(crate) type FullBufferResult = std::result::Result<(PinnedHostBuffer, usize), MahoutError>; +pub(crate) type FullBufferChannel = (SyncSender<FullBufferResult>, Receiver<FullBufferResult>); + +/// Trait for chunk-based quantum state encoding. +/// +/// Implementations provide the encoding-specific logic while the shared +/// streaming pipeline handles IO, buffering, and GPU memory management. +pub(crate) trait ChunkEncoder { + /// Encoder-specific state (e.g., norm buffer for amplitude encoding). + type State; + + /// Validate that the sample size is appropriate for this encoding method. + fn validate_sample_size(&self, sample_size: usize) -> Result<()>; + + /// Whether this encoder needs the staging buffer H2D copy. + /// + /// If false, the streaming pipeline will skip the async copy to device + /// staging buffer, avoiding unnecessary memory bandwidth overhead. + /// Encoders that process data on CPU before uploading should return false. + fn needs_staging_copy(&self) -> bool { + true + } + + /// Initialize encoder-specific state. + fn init_state( + &self, + engine: &QdpEngine, + sample_size: usize, + num_qubits: usize, + ) -> Result<Self::State>; + + /// Encode a chunk of samples to quantum states. + /// + /// # Arguments + /// * `state` - Encoder-specific state + /// * `engine` - QDP engine for GPU operations + /// * `ctx` - Pipeline context for async operations + /// * `host_buffer` - Pinned host buffer containing input data + /// * `dev_ptr` - Device pointer to staging buffer with copied data + /// * `samples_in_chunk` - Number of samples in this chunk + /// * `sample_size` - Size of each sample in f64 elements + /// * `state_ptr_offset` - Pointer to output location in state vector + /// * `state_len` - Length of each quantum state (2^num_qubits) + /// * `num_qubits` - Number of qubits + #[allow(clippy::too_many_arguments)] + fn encode_chunk( + &self, + state: &mut Self::State, + engine: &QdpEngine, + ctx: &PipelineContext, + host_buffer: &PinnedHostBuffer, + dev_ptr: u64, + samples_in_chunk: usize, + sample_size: usize, + state_ptr_offset: *mut c_void, + state_len: usize, + num_qubits: usize, + global_sample_offset: usize, + ) -> Result<()>; +} + +/// Shared streaming pipeline for encoding data from Parquet files. +/// +/// This function handles all the common IO, buffering, and GPU memory +/// management logic. The actual encoding is delegated to the `ChunkEncoder`. +pub(crate) fn stream_encode<E: ChunkEncoder>( + engine: &QdpEngine, + path: &str, + num_qubits: usize, + encoder: E, +) -> Result<*mut DLManagedTensor> { + // Initialize reader + let mut reader_core = crate::io::ParquetBlockReader::new(path, None)?; + let num_samples = reader_core.total_rows; + + // Allocate output state vector + let total_state_vector = GpuStateVector::new_batch(&engine.device, num_samples, num_qubits)?; + const PIPELINE_EVENT_SLOTS: usize = 2; + let ctx = PipelineContext::new(&engine.device, PIPELINE_EVENT_SLOTS)?; + + // Double-buffered device staging + let dev_in_a = unsafe { engine.device.alloc::<f64>(STAGE_SIZE_ELEMENTS) } + .map_err(|e| MahoutError::MemoryAllocation(format!("{:?}", e)))?; + let dev_in_b = unsafe { engine.device.alloc::<f64>(STAGE_SIZE_ELEMENTS) } + .map_err(|e| MahoutError::MemoryAllocation(format!("{:?}", e)))?; Review Comment: Nice catch, I've updated -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
