This is an automated email from the ASF dual-hosted git repository.

guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new d9155fae9 refactor(qdp-core): generic DataReader<T> with FloatElem (#1343)
d9155fae9 is described below

commit d9155fae9877554f9213c9618125f0dcca649fc4
Author: ChenChen Lai <[email protected]>
AuthorDate: Sat Jun 6 15:11:10 2026 +0900

    refactor(qdp-core): generic DataReader<T> with FloatElem (#1343)
    
    * refactor(qdp-core): generic DataReader<T> with FloatElem
    
    * add test
---
 qdp/qdp-core/src/lib.rs                |   2 +-
 qdp/qdp-core/src/reader.rs             |  26 +++++---
 qdp/qdp-core/src/readers/tensorflow.rs |   2 +-
 qdp/qdp-core/tests/reader.rs           | 115 +++++++++++++++++++++++++++++++++
 4 files changed, 133 insertions(+), 12 deletions(-)

diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index 72d00898e..ac5dd5fe9 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -37,7 +37,7 @@ mod profiling;
 
 pub use error::{MahoutError, Result, cuda_error_to_string};
 pub use gpu::memory::Precision;
-pub use reader::{NullHandling, handle_float64_nulls};
+pub use reader::{FloatElem, NullHandling, handle_float64_nulls};
 pub use types::{Dtype, Encoding};
 
 // Throughput/latency pipeline runner: single path using QdpEngine and encode_batch in Rust.
diff --git a/qdp/qdp-core/src/reader.rs b/qdp/qdp-core/src/reader.rs
index fd36fa231..a51fd334a 100644
--- a/qdp/qdp-core/src/reader.rs
+++ b/qdp/qdp-core/src/reader.rs
@@ -49,6 +49,15 @@ use arrow::array::{Array, Float64Array};
 
 use crate::error::Result;
 
+/// Scalar element type for [`DataReader`] output (`f32` or `f64` only).
+///
+/// Keeps f32 file data as `Vec<f32>` end-to-end once readers implement
+/// `DataReader<f32>`; today most readers use the default `T = f64`.
+pub trait FloatElem: Copy + Send + Sync + 'static {}
+
+impl FloatElem for f32 {}
+impl FloatElem for f64 {}
+
 /// Policy for handling null values in Float64 arrays.
 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
 pub enum NullHandling {
@@ -96,19 +105,16 @@ pub fn handle_float64_nulls(
 ///
 /// This interface enables zero-copy streaming where possible and maintains
 /// memory efficiency for large datasets.
-//
-// Structural debt: this trait hard-codes `Vec<f64>` in `read_batch`, so
-// `float32_pipeline=true` still forces every file-backed source to materialise
-// as f64 before casting. A future refactor should parameterise over the element
-// type (`T: FloatElem`) so f32 pipelines stay f32 end-to-end.
-pub trait DataReader {
+///
+/// Parameterised by [`FloatElem`] (`T` defaults to `f64` for existing readers).
+pub trait DataReader<T: FloatElem = f64> {
     /// Read all data from the source.
     ///
     /// Returns a tuple of:
-    /// - `Vec<f64>`: Flattened batch data (all samples concatenated)
+    /// - `Vec<T>`: Flattened batch data (all samples concatenated)
     /// - `usize`: Number of samples
     /// - `usize`: Sample size (elements per sample)
-    fn read_batch(&mut self) -> Result<(Vec<f64>, usize, usize)>;
+    fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)>;
 
     /// Get the sample size if known before reading.
     ///
@@ -130,7 +136,7 @@ pub trait DataReader {
 ///
 /// This trait enables chunk-by-chunk reading for datasets that don't fit
 /// in memory, maintaining constant memory usage regardless of file size.
-pub trait StreamingDataReader: DataReader {
+pub trait StreamingDataReader<T: FloatElem = f64>: DataReader<T> {
     /// Read a chunk of data into the provided buffer.
     ///
     /// Returns the number of elements written to the buffer.
@@ -138,7 +144,7 @@ pub trait StreamingDataReader: DataReader {
     ///
     /// The implementation should respect sample boundaries - only complete
     /// samples should be written to avoid splitting samples across chunks.
-    fn read_chunk(&mut self, buffer: &mut [f64]) -> Result<usize>;
+    fn read_chunk(&mut self, buffer: &mut [T]) -> Result<usize>;
 
     /// Get the total number of rows/samples in the data source.
     ///
diff --git a/qdp/qdp-core/src/readers/tensorflow.rs b/qdp/qdp-core/src/readers/tensorflow.rs
index 0db45245a..9f8ab2605 100644
--- a/qdp/qdp-core/src/readers/tensorflow.rs
+++ b/qdp/qdp-core/src/readers/tensorflow.rs
@@ -193,7 +193,7 @@ impl TensorFlowReader {
 
     /// Convert `tensor_content` bytes to `Vec<f64>`.
     ///
-    /// Note: Even though `tensor_content` can be zero-copy, `DataReader` requires `Vec<f64>`,
+    /// Note: Even though `tensor_content` can be zero-copy, [`DataReader`] requires `Vec<f64>` today,
     /// so one copy is still needed. Uses memcpy (instead of element-wise `from_le_bytes`) for best performance.
     ///
     /// # Safety
diff --git a/qdp/qdp-core/tests/reader.rs b/qdp/qdp-core/tests/reader.rs
new file mode 100644
index 000000000..e16f72086
--- /dev/null
+++ b/qdp/qdp-core/tests/reader.rs
@@ -0,0 +1,115 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Tests for [`qdp_core::reader::DataReader`], [`StreamingDataReader`], and [`FloatElem`].
+
+use qdp_core::MahoutError;
+use qdp_core::Result;
+use qdp_core::reader::{DataReader, FloatElem, StreamingDataReader};
+
+struct BatchReader<T: FloatElem> {
+    data: Vec<T>,
+    num_samples: usize,
+    sample_size: usize,
+    consumed: bool,
+}
+
+impl<T: FloatElem> DataReader<T> for BatchReader<T> {
+    fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)> {
+        if self.consumed {
+            return Err(MahoutError::InvalidInput(
+                "BatchReader already consumed".to_string(),
+            ));
+        }
+        self.consumed = true;
+        Ok((self.data.clone(), self.num_samples, self.sample_size))
+    }
+}
+
+struct ChunkReader<T: FloatElem> {
+    chunks: Vec<Vec<T>>,
+    index: usize,
+    total_rows: usize,
+}
+
+impl<T: FloatElem> DataReader<T> for ChunkReader<T> {
+    fn read_batch(&mut self) -> Result<(Vec<T>, usize, usize)> {
+        Err(MahoutError::InvalidInput(
+            "ChunkReader supports streaming only".to_string(),
+        ))
+    }
+}
+
+impl<T: FloatElem> StreamingDataReader<T> for ChunkReader<T> {
+    fn read_chunk(&mut self, buffer: &mut [T]) -> Result<usize> {
+        if self.index >= self.chunks.len() {
+            return Ok(0);
+        }
+        let chunk = &self.chunks[self.index];
+        self.index += 1;
+        let n = chunk.len().min(buffer.len());
+        buffer[..n].copy_from_slice(&chunk[..n]);
+        Ok(n)
+    }
+
+    fn total_rows(&self) -> usize {
+        self.total_rows
+    }
+}
+
+#[test]
+fn data_reader_default_elem_type_is_f64() {
+    let mut reader = BatchReader {
+        data: vec![1.0, 2.0, 3.0, 4.0],
+        num_samples: 2,
+        sample_size: 2,
+        consumed: false,
+    };
+    let (data, num_samples, sample_size) = reader.read_batch().unwrap();
+    assert_eq!(num_samples, 2);
+    assert_eq!(sample_size, 2);
+    assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
+}
+
+#[test]
+fn data_reader_f32_returns_vec_f32_without_widening() {
+    let mut reader = BatchReader {
+        data: vec![1.0f32, 2.0, 3.0, 4.0],
+        num_samples: 2,
+        sample_size: 2,
+        consumed: false,
+    };
+    let (data, num_samples, sample_size) = reader.read_batch().unwrap();
+    assert_eq!(num_samples, 2);
+    assert_eq!(sample_size, 2);
+    assert_eq!(data, vec![1.0f32, 2.0, 3.0, 4.0]);
+}
+
+#[test]
+fn streaming_data_reader_f32_read_chunk() {
+    let mut reader = ChunkReader {
+        chunks: vec![vec![1.0f32, 2.0], vec![3.0, 4.0]],
+        index: 0,
+        total_rows: 2,
+    };
+    let mut buf = [0.0f32; 4];
+    assert_eq!(reader.read_chunk(&mut buf[..2]).unwrap(), 2);
+    assert_eq!(&buf[..2], &[1.0, 2.0]);
+    assert_eq!(reader.read_chunk(&mut buf[2..]).unwrap(), 2);
+    assert_eq!(&buf[2..], &[3.0, 4.0]);
+    assert_eq!(reader.read_chunk(&mut buf).unwrap(), 0);
+    assert_eq!(reader.total_rows(), 2);
+}

Reply via email to