This is an automated email from the ASF dual-hosted git repository.
guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new d9155fae9 refactor(qdp-core): generic DataReader<T> with FloatElem (#1343)
d9155fae9 is described below
commit d9155fae9877554f9213c9618125f0dcca649fc4
Author: ChenChen Lai <[email protected]>
AuthorDate: Sat Jun 6 15:11:10 2026 +0900
refactor(qdp-core): generic DataReader<T> with FloatElem (#1343)
* refactor(qdp-core): generic DataReader<T> with FloatElem
* add test
---
qdp/qdp-core/src/lib.rs | 2 +-
qdp/qdp-core/src/reader.rs | 26 +++++---
qdp/qdp-core/src/readers/tensorflow.rs | 2 +-
qdp/qdp-core/tests/reader.rs | 115 +++++++++++++++++++++++++++++++++
4 files changed, 133 insertions(+), 12 deletions(-)
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 72d00898e..ac5dd5fe9 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -37,7 +37,7 @@ mod profiling;
pub use error::{MahoutError, Result, cuda_error_to_string};
pub use gpu::memory::Precision;
-pub use reader::{NullHandling, handle_float64_nulls};
+pub use reader::{FloatElem, NullHandling, handle_float64_nulls};
pub use types::{Dtype, Encoding};
// Throughput/latency pipeline runner: single path using QdpEngine and encode_batch in Rust.
diff --git a/qdp/qdp-core/src/reader.rs b/qdp/qdp-core/src/reader.rs
index fd36fa231..a51fd334a 100644
--- a/qdp/qdp-core/src/reader.rs
+++ b/qdp/qdp-core/src/reader.rs
@@ -49,6 +49,15 @@ use arrow::array::{Array, Float64Array};
use crate::error::Result;
+/// Scalar element type for [`DataReader`] output (`f32` or `f64` only).
+///
+/// Keeps f32 file data as `Vec<f32>` end-to-end once readers implement
+/// `DataReader<f32>`; today most readers use the default `T = f64`.
+pub trait FloatElem: Copy + Send + Sync + 'static {}
+
+impl FloatElem for f32 {}
+impl FloatElem for f64 {}
+
/// Policy for handling null values in Float64 arrays.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NullHandling {
@@ -96,19 +105,16 @@ pub fn handle_float64_nulls(
///
/// This interface enables zero-copy streaming where possible and maintains
/// memory efficiency for large datasets.
-//
-// Structural debt: this trait hard-codes `Vec<f64>` in `read_batch`, so
-// `float32_pipeline=true` still forces every file-backed source to materialise
-// as f64 before casting. A future refactor should parameterise over the element
-// type (`T: FloatElem`) so f32 pipelines stay f32 end-to-end.
-pub trait DataReader {
+///
+/// Parameterised by [`FloatElem`] (`T` defaults to `f64` for existing readers).
+pub trait DataReader<T: FloatElem = f64> {
/// Read all data from the source.
///
/// Returns a tuple of:
- /// - `Vec<f64>`: Flattened batch data (all samples concatenated)
+ /// - `Vec<T>`: Flattened batch data (all samples concatenated)
/// - `usize`: Number of samples
/// - `usize`: Sample size (elements per sample)
- fn read_batch(&mut self) -> Result<(Vec<f64>, usize, usize)>;
+ fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)>;
/// Get the sample size if known before reading.
///
@@ -130,7 +136,7 @@ pub trait DataReader {
///
/// This trait enables chunk-by-chunk reading for datasets that don't fit
/// in memory, maintaining constant memory usage regardless of file size.
-pub trait StreamingDataReader: DataReader {
+pub trait StreamingDataReader<T: FloatElem = f64>: DataReader<T> {
/// Read a chunk of data into the provided buffer.
///
/// Returns the number of elements written to the buffer.
@@ -138,7 +144,7 @@ pub trait StreamingDataReader: DataReader {
///
/// The implementation should respect sample boundaries - only complete
/// samples should be written to avoid splitting samples across chunks.
- fn read_chunk(&mut self, buffer: &mut [f64]) -> Result<usize>;
+ fn read_chunk(&mut self, buffer: &mut [T]) -> Result<usize>;
/// Get the total number of rows/samples in the data source.
///
diff --git a/qdp/qdp-core/src/readers/tensorflow.rs b/qdp/qdp-core/src/readers/tensorflow.rs
index 0db45245a..9f8ab2605 100644
--- a/qdp/qdp-core/src/readers/tensorflow.rs
+++ b/qdp/qdp-core/src/readers/tensorflow.rs
@@ -193,7 +193,7 @@ impl TensorFlowReader {
/// Convert `tensor_content` bytes to `Vec<f64>`.
///
- /// Note: Even though `tensor_content` can be zero-copy, `DataReader` requires `Vec<f64>`,
+ /// Note: Even though `tensor_content` can be zero-copy, [`DataReader`] requires `Vec<f64>` today,
/// so one copy is still needed. Uses memcpy (instead of element-wise `from_le_bytes`) for best performance.
///
/// # Safety
diff --git a/qdp/qdp-core/tests/reader.rs b/qdp/qdp-core/tests/reader.rs
new file mode 100644
index 000000000..e16f72086
--- /dev/null
+++ b/qdp/qdp-core/tests/reader.rs
@@ -0,0 +1,115 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Tests for [`qdp_core::reader::DataReader`], [`StreamingDataReader`], and [`FloatElem`].
+
+use qdp_core::MahoutError;
+use qdp_core::Result;
+use qdp_core::reader::{DataReader, FloatElem, StreamingDataReader};
+
+struct BatchReader<T: FloatElem> {
+ data: Vec<T>,
+ num_samples: usize,
+ sample_size: usize,
+ consumed: bool,
+}
+
+impl<T: FloatElem> DataReader<T> for BatchReader<T> {
+ fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)> {
+ if self.consumed {
+ return Err(MahoutError::InvalidInput(
+ "BatchReader already consumed".to_string(),
+ ));
+ }
+ self.consumed = true;
+ Ok((self.data.clone(), self.num_samples, self.sample_size))
+ }
+}
+
+struct ChunkReader<T: FloatElem> {
+ chunks: Vec<Vec<T>>,
+ index: usize,
+ total_rows: usize,
+}
+
+impl<T: FloatElem> DataReader<T> for ChunkReader<T> {
+ fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)> {
+ Err(MahoutError::InvalidInput(
+ "ChunkReader supports streaming only".to_string(),
+ ))
+ }
+}
+
+impl<T: FloatElem> StreamingDataReader<T> for ChunkReader<T> {
+ fn read_chunk(&mut self, buffer: &mut [T]) -> Result<usize> {
+ if self.index >= self.chunks.len() {
+ return Ok(0);
+ }
+ let chunk = &self.chunks[self.index];
+ self.index += 1;
+ let n = chunk.len().min(buffer.len());
+ buffer[..n].copy_from_slice(&chunk[..n]);
+ Ok(n)
+ }
+
+ fn total_rows(&self) -> usize {
+ self.total_rows
+ }
+}
+
+#[test]
+fn data_reader_default_elem_type_is_f64() {
+ let mut reader = BatchReader {
+ data: vec![1.0, 2.0, 3.0, 4.0],
+ num_samples: 2,
+ sample_size: 2,
+ consumed: false,
+ };
+ let (data, num_samples, sample_size) = reader.read_batch().unwrap();
+ assert_eq!(num_samples, 2);
+ assert_eq!(sample_size, 2);
+ assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
+}
+
+#[test]
+fn data_reader_f32_returns_vec_f32_without_widening() {
+ let mut reader = BatchReader {
+ data: vec![1.0f32, 2.0, 3.0, 4.0],
+ num_samples: 2,
+ sample_size: 2,
+ consumed: false,
+ };
+ let (data, num_samples, sample_size) = reader.read_batch().unwrap();
+ assert_eq!(num_samples, 2);
+ assert_eq!(sample_size, 2);
+ assert_eq!(data, vec![1.0f32, 2.0, 3.0, 4.0]);
+}
+
+#[test]
+fn streaming_data_reader_f32_read_chunk() {
+ let mut reader = ChunkReader {
+ chunks: vec![vec![1.0f32, 2.0], vec![3.0, 4.0]],
+ index: 0,
+ total_rows: 2,
+ };
+ let mut buf = [0.0f32; 4];
+ assert_eq!(reader.read_chunk(&mut buf[..2]).unwrap(), 2);
+ assert_eq!(&buf[..2], &[1.0, 2.0]);
+ assert_eq!(reader.read_chunk(&mut buf[2..]).unwrap(), 2);
+ assert_eq!(&buf[2..], &[3.0, 4.0]);
+ assert_eq!(reader.read_chunk(&mut buf).unwrap(), 0);
+ assert_eq!(reader.total_rows(), 2);
+}