This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 06db9ed865 Deduplicate and standardize deserialization logic for streams (#13412)
06db9ed865 is described below

commit 06db9ed865dc48bb1c87ce60d85331d385ee0f17
Author: Alihan Çelikcan <[email protected]>
AuthorDate: Sat Nov 16 09:34:16 2024 +0300

    Deduplicate and standardize deserialization logic for streams (#13412)
    
    * Add BatchDeserializer
    
    * Fix formatting
    
    * Remove unused enum value
    
    * Update datafusion/core/src/datasource/file_format/mod.rs
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/src/datasource/file_format/csv.rs  | 235 ++++++++++++++++++++-
 datafusion/core/src/datasource/file_format/json.rs | 134 +++++++++++-
 datafusion/core/src/datasource/file_format/mod.rs  | 171 ++++++++++++++-
 .../core/src/datasource/physical_plan/csv.rs       |  41 +---
 .../core/src/datasource/physical_plan/json.rs      |  38 +---
 5 files changed, 547 insertions(+), 72 deletions(-)

diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index d59e2bf71d..9f979ddf01 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
 use std::sync::Arc;
 
 use super::write::orchestration::stateless_multipart_put;
-use super::{FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD};
+use super::{
+    Decoder, DecoderDeserializer, FileFormat, FileFormatFactory,
+    DEFAULT_SCHEMA_INFER_MAX_RECORD,
+};
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
 use crate::datasource::file_format::write::BatchSerializer;
 use crate::datasource::physical_plan::{
@@ -38,8 +41,8 @@ use crate::physical_plan::{
 
 use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
-use arrow::datatypes::SchemaRef;
-use arrow::datatypes::{DataType, Field, Fields, Schema};
+use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
+use arrow_schema::ArrowError;
 use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
 use datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::{
@@ -293,6 +296,45 @@ impl CsvFormat {
     }
 }
 
+/// Adapter exposing an Arrow [`arrow::csv::reader::Decoder`] through the
+/// crate's [`Decoder`] trait, so it can drive a [`DecoderDeserializer`].
+#[derive(Debug)]
+pub(crate) struct CsvDecoder {
+    /// The wrapped Arrow CSV decoder.
+    inner: arrow::csv::reader::Decoder,
+}
+
+impl CsvDecoder {
+    /// Wraps `decoder` without changing its configuration.
+    pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self {
+        let inner = decoder;
+        Self { inner }
+    }
+}
+
+impl Decoder for CsvDecoder {
+    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
+        self.inner.decode(buf)
+    }
+
+    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        self.inner.flush()
+    }
+
+    /// A CSV batch may be flushed as soon as the wrapped decoder reports no
+    /// remaining row capacity (i.e. a full batch is buffered).
+    fn can_flush_early(&self) -> bool {
+        let remaining_rows = self.inner.capacity();
+        remaining_rows == 0
+    }
+}
+
+impl Debug for CsvSerializer {
+    /// Formats the serializer, exposing only its `header` flag.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut out = f.debug_struct("CsvSerializer");
+        out.field("header", &self.header);
+        out.finish()
+    }
+}
+
+/// Lets an Arrow CSV decoder be converted directly into a [`DecoderDeserializer`].
+impl From<arrow::csv::reader::Decoder> for DecoderDeserializer<CsvDecoder> {
+    fn from(decoder: arrow::csv::reader::Decoder) -> Self {
+        let adapter = CsvDecoder::new(decoder);
+        Self::new(adapter)
+    }
+}
+
 #[async_trait]
 impl FileFormat for CsvFormat {
     fn as_any(&self) -> &dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
 mod tests {
     use super::super::test_util::scan_format;
     use super::*;
-    use crate::arrow::util::pretty;
     use crate::assert_batches_eq;
     use 
crate::datasource::file_format::file_compression_type::FileCompressionType;
     use crate::datasource::file_format::test_util::VariableStream;
+    use crate::datasource::file_format::{
+        BatchDeserializer, DecoderDeserializer, DeserializerOutput,
+    };
     use crate::datasource::listing::ListingOptions;
+    use crate::execution::session_state::SessionStateBuilder;
     use crate::physical_plan::collect;
     use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
     use crate::test_util::arrow_test_data;
 
     use arrow::compute::concat_batches;
+    use arrow::csv::ReaderBuilder;
+    use arrow::util::pretty::pretty_format_batches;
+    use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray};
     use datafusion_common::cast::as_string_array;
     use datafusion_common::internal_err;
     use datafusion_common::stats::Precision;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use datafusion_expr::{col, lit};
 
-    use crate::execution::session_state::SessionStateBuilder;
     use chrono::DateTime;
     use object_store::local::LocalFileSystem;
     use object_store::path::Path;
@@ -1097,7 +1144,7 @@ mod tests {
     ) -> Result<usize> {
         let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
         let result = df.collect().await?;
-        let plan = format!("{}", &pretty::pretty_format_batches(&result)?);
+        let plan = format!("{}", &pretty_format_batches(&result)?);
 
         let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap();
 
@@ -1464,4 +1511,180 @@ mod tests {
 
         Ok(())
     }
+
+    #[rstest]
+    fn test_csv_deserializer_with_finish(
+        #[values(1, 5, 17)] batch_size: usize,
+        #[values(0, 5, 93)] line_count: usize,
+    ) -> Result<()> {
+        let schema = csv_schema();
+        let mut deserializer = csv_deserializer(batch_size, &schema);
+
+        // Feed every chunk, then mark the end of input so the trailing partial
+        // batch (if any) is flushed as well.
+        for chunk in CsvBatchGenerator::new(batch_size, line_count) {
+            deserializer.digest(chunk);
+        }
+        deserializer.finish();
+
+        // With `finish`, every line lands in some batch: ceil(lines / size).
+        let batch_count = line_count.div_ceil(batch_size);
+        let mut batches = Vec::with_capacity(batch_count);
+        for _ in 0..batch_count {
+            match deserializer.next()? {
+                DeserializerOutput::RecordBatch(batch) => batches.push(batch),
+                output => panic!("Expected RecordBatch, got {:?}", output),
+            }
+        }
+        assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);
+
+        let all_batches = concat_batches(&schema, &batches)?;
+        let expected = csv_expected_batch(schema, line_count)?;
+
+        assert_eq!(
+            expected.clone(),
+            all_batches.clone(),
+            "Expected:\n{}\nActual:\n{}",
+            pretty_format_batches(&[expected])?,
+            pretty_format_batches(&[all_batches])?,
+        );
+
+        Ok(())
+    }
+
+    #[rstest]
+    fn test_csv_deserializer_without_finish(
+        #[values(1, 5, 17)] batch_size: usize,
+        #[values(0, 5, 93)] line_count: usize,
+    ) -> Result<()> {
+        let schema = csv_schema();
+        let mut deserializer = csv_deserializer(batch_size, &schema);
+
+        // Input is never finished, so only complete batches can come out.
+        for chunk in CsvBatchGenerator::new(batch_size, line_count) {
+            deserializer.digest(chunk);
+        }
+
+        let batch_count = line_count / batch_size;
+        let mut batches = Vec::with_capacity(batch_count);
+        for _ in 0..batch_count {
+            match deserializer.next()? {
+                DeserializerOutput::RecordBatch(batch) => batches.push(batch),
+                output => panic!("Expected RecordBatch, got {:?}", output),
+            }
+        }
+        // Any trailing partial batch stays buffered, waiting for more input.
+        assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);
+
+        let all_batches = concat_batches(&schema, &batches)?;
+        let expected = csv_expected_batch(schema, batch_count * batch_size)?;
+
+        assert_eq!(
+            expected.clone(),
+            all_batches.clone(),
+            "Expected:\n{}\nActual:\n{}",
+            pretty_format_batches(&[expected])?,
+            pretty_format_batches(&[all_batches])?,
+        );
+
+        Ok(())
+    }
+
+    /// Iterator over the CSV text of lines `0..line_count` (see `csv_line`),
+    /// yielding at most `batch_size` lines per chunk.
+    struct CsvBatchGenerator {
+        batch_size: usize,
+        line_count: usize,
+        /// Index of the next line to emit.
+        offset: usize,
+    }
+
+    impl CsvBatchGenerator {
+        fn new(batch_size: usize, line_count: usize) -> Self {
+            Self {
+                offset: 0,
+                batch_size,
+                line_count,
+            }
+        }
+    }
+
+    impl Iterator for CsvBatchGenerator {
+        type Item = Bytes;
+
+        /// Returns the next chunk of lines, or `None` once every line has been
+        /// emitted (also `None` immediately when `batch_size` is zero).
+        fn next(&mut self) -> Option<Self::Item> {
+            let start = self.offset;
+            let end = start.saturating_add(self.batch_size).min(self.line_count);
+            if end <= start {
+                return None;
+            }
+
+            let mut chunk = Vec::new();
+            for line in start..end {
+                chunk.extend_from_slice(&csv_line(line));
+            }
+            self.offset = end;
+            Some(Bytes::from(chunk))
+        }
+    }
+
+    /// Builds the batch that decoding the first `line_count` generated lines
+    /// (see `csv_line` / `csv_values`) should produce under `schema`.
+    fn csv_expected_batch(
+        schema: SchemaRef,
+        line_count: usize,
+    ) -> Result<RecordBatch, DataFusionError> {
+        let mut c1 = Vec::with_capacity(line_count);
+        let mut c2 = Vec::with_capacity(line_count);
+        let mut c3 = Vec::with_capacity(line_count);
+        let mut c4 = Vec::with_capacity(line_count);
+
+        for (int_value, float_value, bool_value, char_value) in
+            (0..line_count).map(csv_values)
+        {
+            c1.push(int_value);
+            c2.push(float_value);
+            c3.push(bool_value);
+            c4.push(char_value);
+        }
+
+        // `schema` is owned here, so it is moved into the batch (no clone).
+        let expected = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(c1)),
+                Arc::new(Float64Array::from(c2)),
+                Arc::new(BooleanArray::from(c3)),
+                Arc::new(StringArray::from(c4)),
+            ],
+        )?;
+        Ok(expected)
+    }
+
+    /// Renders generated line `line_number` as one newline-terminated CSV
+    /// record whose fields come from `csv_values`.
+    fn csv_line(line_number: usize) -> Bytes {
+        let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
+        Bytes::from(format!("{int_value},{float_value},{bool_value},{char_value}\n"))
+    }
+
+    /// Deterministic column values for generated line `line_number`: the index
+    /// as `i32` and `f64`, whether it is even, and `"<index>-string"`.
+    fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
+        (
+            line_number as i32,
+            line_number as f64,
+            line_number % 2 == 0,
+            format!("{line_number}-string"),
+        )
+    }
+
+    /// Nullable schema matching the four columns produced by `csv_values`.
+    fn csv_schema() -> Arc<Schema> {
+        let fields = vec![
+            Field::new("c1", DataType::Int32, true),
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::Boolean, true),
+            Field::new("c4", DataType::Utf8, true),
+        ];
+        Arc::new(Schema::new(fields))
+    }
+
+    /// Builds a CSV `BatchDeserializer` for `schema`, configured with `batch_size`.
+    fn csv_deserializer(
+        batch_size: usize,
+        schema: &Arc<Schema>,
+    ) -> impl BatchDeserializer<Bytes> {
+        let decoder = ReaderBuilder::new(Arc::clone(schema))
+            .with_batch_size(batch_size)
+            .build_decoder();
+        DecoderDeserializer::from(decoder)
+    }
 }
diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs
index 4f51dd5ae1..e97853e9e7 100644
--- a/datafusion/core/src/datasource/file_format/json.rs
+++ b/datafusion/core/src/datasource/file_format/json.rs
@@ -26,7 +26,8 @@ use std::sync::Arc;
 
 use super::write::orchestration::stateless_multipart_put;
 use super::{
-    FileFormat, FileFormatFactory, FileScanConfig, 
DEFAULT_SCHEMA_INFER_MAX_RECORD,
+    Decoder, DecoderDeserializer, FileFormat, FileFormatFactory, 
FileScanConfig,
+    DEFAULT_SCHEMA_INFER_MAX_RECORD,
 };
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
 use crate::datasource::file_format::write::BatchSerializer;
@@ -44,6 +45,7 @@ use arrow::datatypes::SchemaRef;
 use arrow::json;
 use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter};
 use arrow_array::RecordBatch;
+use arrow_schema::ArrowError;
 use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions};
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION};
@@ -384,16 +386,53 @@ impl DataSink for JsonSink {
     }
 }
 
+/// Adapter exposing an Arrow [`json::reader::Decoder`] through the crate's
+/// [`Decoder`] trait, so it can drive a [`DecoderDeserializer`].
+#[derive(Debug)]
+pub(crate) struct JsonDecoder {
+    /// The wrapped Arrow JSON decoder.
+    inner: json::reader::Decoder,
+}
+
+impl JsonDecoder {
+    /// Wraps `decoder` without changing its configuration.
+    pub(crate) fn new(decoder: json::reader::Decoder) -> Self {
+        Self { inner: decoder }
+    }
+}
+
+impl Decoder for JsonDecoder {
+    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
+        self.inner.decode(buf)
+    }
+
+    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        self.inner.flush()
+    }
+
+    /// Never flushes early: a JSON batch is only flushed once `decode` stops
+    /// consuming input (or the end-of-input buffer is reached).
+    fn can_flush_early(&self) -> bool {
+        false
+    }
+}
+
+/// Lets an Arrow JSON decoder be converted directly into a [`DecoderDeserializer`].
+impl From<json::reader::Decoder> for DecoderDeserializer<JsonDecoder> {
+    fn from(decoder: json::reader::Decoder) -> Self {
+        let adapter = JsonDecoder::new(decoder);
+        Self::new(adapter)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::super::test_util::scan_format;
     use super::*;
+    use crate::datasource::file_format::{
+        BatchDeserializer, DecoderDeserializer, DeserializerOutput,
+    };
     use crate::execution::options::NdJsonReadOptions;
     use crate::physical_plan::collect;
     use crate::prelude::{SessionConfig, SessionContext};
     use crate::test::object_store::local_unpartitioned_file;
 
+    use arrow::compute::concat_batches;
+    use arrow::json::ReaderBuilder;
     use arrow::util::pretty;
+    use arrow_schema::{DataType, Field};
     use datafusion_common::cast::as_int64_array;
     use datafusion_common::stats::Precision;
     use datafusion_common::{assert_batches_eq, internal_err};
@@ -612,4 +651,97 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_json_deserializer_finish() -> Result<()> {
+        let fields: Vec<Field> = (1..=5)
+            .map(|i| Field::new(format!("c{i}"), DataType::Int64, true))
+            .collect();
+        let schema = Arc::new(Schema::new(fields));
+        let mut deserializer = json_deserializer(1, &schema)?;
+
+        for record in [
+            r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#,
+            r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#,
+            r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#,
+        ] {
+            deserializer.digest(Bytes::from(record));
+        }
+        // Marking the input finished lets the last buffered record flush too.
+        deserializer.finish();
+
+        let mut batches = Vec::with_capacity(3);
+        for _ in 0..3 {
+            match deserializer.next()? {
+                DeserializerOutput::RecordBatch(batch) => batches.push(batch),
+                output => panic!("Expected RecordBatch, got {:?}", output),
+            }
+        }
+        assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);
+        let all_batches = concat_batches(&schema, &batches)?;
+
+        let expected = [
+            "+----+----+----+----+----+",
+            "| c1 | c2 | c3 | c4 | c5 |",
+            "+----+----+----+----+----+",
+            "| 1  | 2  | 3  | 4  | 5  |",
+            "| 6  | 7  | 8  | 9  | 10 |",
+            "| 11 | 12 | 13 | 14 | 15 |",
+            "+----+----+----+----+----+",
+        ];
+
+        assert_batches_eq!(expected, &[all_batches]);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_json_deserializer_no_finish() -> Result<()> {
+        let fields: Vec<Field> = (1..=5)
+            .map(|i| Field::new(format!("c{i}"), DataType::Int64, true))
+            .collect();
+        let schema = Arc::new(Schema::new(fields));
+        let mut deserializer = json_deserializer(1, &schema)?;
+
+        for record in [
+            r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#,
+            r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#,
+            r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#,
+        ] {
+            deserializer.digest(Bytes::from(record));
+        }
+
+        // Only two batches come out. `JsonDecoder` never flushes early, so a
+        // batch is flushed only when a later `decode` call consumes nothing
+        // (its batch is full). The third record is consumed from the queue but
+        // remains inside the decoder until more input arrives or `finish` is called.
+        let mut batches = Vec::with_capacity(2);
+        for _ in 0..2 {
+            match deserializer.next()? {
+                DeserializerOutput::RecordBatch(batch) => batches.push(batch),
+                output => panic!("Expected RecordBatch, got {:?}", output),
+            }
+        }
+        assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);
+        let all_batches = concat_batches(&schema, &batches)?;
+
+        let expected = [
+            "+----+----+----+----+----+",
+            "| c1 | c2 | c3 | c4 | c5 |",
+            "+----+----+----+----+----+",
+            "| 1  | 2  | 3  | 4  | 5  |",
+            "| 6  | 7  | 8  | 9  | 10 |",
+            "+----+----+----+----+----+",
+        ];
+
+        assert_batches_eq!(expected, &[all_batches]);
+
+        Ok(())
+    }
+
+    /// Builds a JSON `BatchDeserializer` for `schema`, configured with `batch_size`.
+    fn json_deserializer(
+        batch_size: usize,
+        schema: &Arc<Schema>,
+    ) -> Result<impl BatchDeserializer<Bytes>> {
+        let decoder = ReaderBuilder::new(Arc::clone(schema))
+            .with_batch_size(batch_size)
+            .build_decoder()?;
+        Ok(DecoderDeserializer::from(decoder))
+    }
 }
diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs
index 5c9eb7f20a..eb2a85367f 100644
--- a/datafusion/core/src/datasource/file_format/mod.rs
+++ b/datafusion/core/src/datasource/file_format/mod.rs
@@ -32,9 +32,10 @@ pub mod parquet;
 pub mod write;
 
 use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::{self, Display};
+use std::collections::{HashMap, VecDeque};
+use std::fmt::{self, Debug, Display};
 use std::sync::Arc;
+use std::task::Poll;
 
 use crate::arrow::datatypes::SchemaRef;
 use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig};
@@ -42,17 +43,20 @@ use crate::error::Result;
 use crate::execution::context::SessionState;
 use crate::physical_plan::{ExecutionPlan, Statistics};
 
-use arrow_schema::{DataType, Field, FieldRef, Schema};
+use arrow_array::RecordBatch;
+use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema};
 use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::{internal_err, not_impl_err, GetExt};
 use datafusion_expr::Expr;
 use datafusion_physical_expr::PhysicalExpr;
 
 use async_trait::async_trait;
+use bytes::{Buf, Bytes};
 use datafusion_physical_expr_common::sort_expr::LexRequirement;
 use file_compression_type::FileCompressionType;
+use futures::stream::BoxStream;
+use futures::{ready, Stream, StreamExt};
 use object_store::{ObjectMeta, ObjectStore};
-use std::fmt::Debug;
 
 /// Factory for creating [`FileFormat`] instances based on session and command 
level options
 ///
@@ -168,6 +172,165 @@ pub enum FilePushdownSupport {
     Supported,
 }
 
+/// Outcome of a single [`BatchDeserializer::next`] call.
+#[derive(Debug, PartialEq)]
+pub enum DeserializerOutput {
+    /// A [`RecordBatch`] was successfully deserialized.
+    RecordBatch(RecordBatch),
+    /// No progress is possible until more input is supplied via
+    /// [`BatchDeserializer::digest`] (or the input is finished).
+    RequiresMoreData,
+    /// The input has been marked finished and no batches remain.
+    InputExhausted,
+}
+
+/// Trait defining a scheme for deserializing byte streams into structured data.
+///
+/// Implementors convert queued input messages of type `T` into [`RecordBatch`]
+/// objects. Callers feed input with [`digest`](Self::digest), signal the end of
+/// input with [`finish`](Self::finish), and pull results with [`next`](Self::next).
+pub trait BatchDeserializer<T>: Send + Debug {
+    /// Queues `message` for deserialization, updating the internal state of
+    /// this `BatchDeserializer`. May be called several times before `next`
+    /// to queue multiple messages. Returns the number of bytes consumed.
+    fn digest(&mut self, message: T) -> usize;
+
+    /// Attempts to deserialize pending messages, returning at most one batch
+    /// per call wrapped in a [`DeserializerOutput`] that indicates progress.
+    fn next(&mut self) -> Result<DeserializerOutput, ArrowError>;
+
+    /// Informs the deserializer that no more messages will be provided.
+    fn finish(&mut self);
+}
+
+/// A general interface over incremental decoders such as
+/// [`arrow::json::reader::Decoder`] and [`arrow::csv::reader::Decoder`].
+///
+/// It mirrors their [`Decoder::decode`] and [`Decoder::flush`] methods and adds
+/// `can_flush_early`, which tells a [`DecoderDeserializer`] whether a batch may
+/// be flushed before `decode` stops consuming input.
+///
+/// [`arrow::json::reader::Decoder`]: ::arrow::json::reader::Decoder
+/// [`arrow::csv::reader::Decoder`]: ::arrow::csv::reader::Decoder
+/// [`Decoder::decode`]: ::arrow::json::reader::Decoder::decode
+/// [`Decoder::flush`]: ::arrow::json::reader::Decoder::flush
+pub(crate) trait Decoder: Send + Debug {
+    /// Feeds `buf` to the decoder and returns the number of bytes consumed.
+    /// See [`arrow::json::reader::Decoder::decode`].
+    ///
+    /// [`arrow::json::reader::Decoder::decode`]: ::arrow::json::reader::Decoder::decode
+    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError>;
+
+    /// Emits the rows decoded so far as a [`RecordBatch`]; per Arrow's docs,
+    /// `Ok(None)` means nothing was buffered.
+    /// See [`arrow::json::reader::Decoder::flush`].
+    ///
+    /// [`arrow::json::reader::Decoder::flush`]: ::arrow::json::reader::Decoder::flush
+    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError>;
+
+    /// Whether the decoder can flush early in its current state.
+    fn can_flush_early(&self) -> bool;
+}
+
+impl<T: Decoder> Debug for DecoderDeserializer<T> {
+    /// Formats under the type's real name. The wrapped decoder is included
+    /// because `T: Decoder` already requires `T: Debug`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("DecoderDeserializer")
+            .field("decoder", &self.decoder)
+            .field("buffered_queue", &self.buffered_queue)
+            .field("finalized", &self.finalized)
+            .finish()
+    }
+}
+
+impl<T: Decoder> BatchDeserializer<Bytes> for DecoderDeserializer<T> {
+    /// Queues `message` and returns its length. Empty messages are ignored, so
+    /// the only empty buffer ever queued is the end-of-input marker from `finish`.
+    fn digest(&mut self, message: Bytes) -> usize {
+        if message.is_empty() {
+            return 0;
+        }
+
+        let consumed = message.len();
+        self.buffered_queue.push_back(message);
+        consumed
+    }
+
+    /// Decodes queued buffers until one batch can be emitted or the queue drains.
+    ///
+    /// A flush is attempted when `decode` consumes nothing (batch size reached,
+    /// or the empty end-of-input buffer), or when the decoder can flush early.
+    fn next(&mut self) -> Result<DeserializerOutput, ArrowError> {
+        while let Some(buffered) = self.buffered_queue.front_mut() {
+            let decoded = self.decoder.decode(buffered)?;
+            buffered.advance(decoded);
+
+            if buffered.is_empty() {
+                self.buffered_queue.pop_front();
+            }
+
+            // NOTE(review): if `decode` kept returning 0 on a non-empty buffer
+            // while `flush` yields `None` (presumably possible with a zero batch
+            // size), this loop would never terminate -- confirm callers rule it out.
+            if decoded == 0 || self.decoder.can_flush_early() {
+                if let Some(batch) = self.decoder.flush()? {
+                    return Ok(DeserializerOutput::RecordBatch(batch));
+                }
+            }
+        }
+
+        Ok(if self.finalized {
+            DeserializerOutput::InputExhausted
+        } else {
+            DeserializerOutput::RequiresMoreData
+        })
+    }
+
+    /// Marks the input as finished. Only the first call queues the empty buffer
+    /// that makes the decoder flush its trailing partial batch. Later calls are
+    /// no-ops, so a caller that keeps observing end-of-input (such as a fused
+    /// stream that is polled again) does not keep growing the queue.
+    fn finish(&mut self) {
+        if !self.finalized {
+            self.finalized = true;
+            self.buffered_queue.push_back(Bytes::new());
+        }
+    }
+}
+
+/// A generic, decoder-based deserialization scheme for processing encoded data.
+///
+/// Converts a sequence of byte buffers into [`RecordBatch`] objects by driving
+/// a [`Decoder`]. Buffers are queued until fully consumed, because a single
+/// `decode` call may consume only part of a buffer.
+pub(crate) struct DecoderDeserializer<T: Decoder> {
+    /// The underlying decoder used for deserialization.
+    pub(crate) decoder: T,
+    /// Input buffers not yet fully consumed by `decoder`, oldest first.
+    pub(crate) buffered_queue: VecDeque<Bytes>,
+    /// Set once the producer has signalled that no more input will arrive.
+    pub(crate) finalized: bool,
+}
+
+impl<T: Decoder> DecoderDeserializer<T> {
+    /// Wraps `decoder` with an empty input queue, not yet finalized.
+    pub(crate) fn new(decoder: T) -> Self {
+        Self {
+            decoder,
+            buffered_queue: VecDeque::new(),
+            finalized: false,
+        }
+    }
+}
+
+/// Deserializes a stream of bytes into a stream of [`RecordBatch`] objects 
using the
+/// provided deserializer.
+///
+/// Returns a boxed stream of `Result<RecordBatch, ArrowError>`. The stream 
yields [`RecordBatch`]
+/// objects as they are produced by the deserializer, or an [`ArrowError`] if 
an error
+/// occurs while polling the input or deserializing.
+pub(crate) fn deserialize_stream<'a>(
+    mut input: impl Stream<Item = Result<Bytes>> + Unpin + Send + 'a,
+    mut deserializer: impl BatchDeserializer<Bytes> + 'a,
+) -> BoxStream<'a, Result<RecordBatch, ArrowError>> {
+    futures::stream::poll_fn(move |cx| loop {
+        match ready!(input.poll_next_unpin(cx)).transpose()? {
+            Some(b) => _ = deserializer.digest(b),
+            None => deserializer.finish(),
+        };
+
+        return match deserializer.next()? {
+            DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))),
+            DeserializerOutput::InputExhausted => Poll::Ready(None),
+            DeserializerOutput::RequiresMoreData => continue,
+        };
+    })
+    .boxed()
+}
+
 /// A container of [FileFormatFactory] which also implements [FileType].
 /// This enables converting a dyn FileFormat to a dyn FileType.
 /// The former trait is a superset of the latter trait, which includes 
execution time
diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs
index 1679acf303..0c41f69c76 100644
--- a/datafusion/core/src/datasource/physical_plan/csv.rs
+++ b/datafusion/core/src/datasource/physical_plan/csv.rs
@@ -24,6 +24,7 @@ use std::task::Poll;
 
 use super::{calculate_range, FileGroupPartitioner, FileScanConfig, 
RangeCalculation};
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
+use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer};
 use crate::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile};
 use crate::datasource::physical_plan::file_stream::{
     FileOpenFuture, FileOpener, FileStream,
@@ -42,8 +43,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
 
-use bytes::{Buf, Bytes};
-use futures::{ready, StreamExt, TryStreamExt};
+use futures::{StreamExt, TryStreamExt};
 use object_store::buffered::BufWriter;
 use object_store::{GetOptions, GetResultPayload, ObjectStore};
 use tokio::io::AsyncWriteExt;
@@ -651,36 +651,14 @@ impl FileOpener for CsvOpener {
                     Ok(futures::stream::iter(config.open(decoder)?).boxed())
                 }
                 GetResultPayload::Stream(s) => {
-                    let mut decoder = config.builder().build_decoder();
+                    let decoder = config.builder().build_decoder();
                     let s = s.map_err(DataFusionError::from);
-                    let mut input =
-                        
file_compression_type.convert_stream(s.boxed())?.fuse();
-                    let mut buffered = Bytes::new();
-
-                    let s = futures::stream::poll_fn(move |cx| {
-                        loop {
-                            if buffered.is_empty() {
-                                match ready!(input.poll_next_unpin(cx)) {
-                                    Some(Ok(b)) => buffered = b,
-                                    Some(Err(e)) => {
-                                        return Poll::Ready(Some(Err(e.into())))
-                                    }
-                                    None => {}
-                                };
-                            }
-                            let decoded = match 
decoder.decode(buffered.as_ref()) {
-                                // Note: the decoder needs to be called with 
an empty
-                                // array to delimt the final record
-                                Ok(0) => break,
-                                Ok(decoded) => decoded,
-                                Err(e) => return Poll::Ready(Some(Err(e))),
-                            };
-                            buffered.advance(decoded);
-                        }
-
-                        Poll::Ready(decoder.flush().transpose())
-                    });
-                    Ok(s.boxed())
+                    let input = 
file_compression_type.convert_stream(s.boxed())?.fuse();
+
+                    Ok(deserialize_stream(
+                        input,
+                        DecoderDeserializer::from(decoder),
+                    ))
                 }
             }
         }))
@@ -753,6 +731,7 @@ mod tests {
     use crate::{scalar::ScalarValue, test_util::aggr_test_schema};
 
     use arrow::datatypes::*;
+    use bytes::Bytes;
     use datafusion_common::test_util::arrow_test_data;
 
     use datafusion_common::config::CsvOptions;
diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs
index 7b0a605aed..c86f8fbd26 100644
--- a/datafusion/core/src/datasource/physical_plan/json.rs
+++ b/datafusion/core/src/datasource/physical_plan/json.rs
@@ -24,6 +24,7 @@ use std::task::Poll;
 
 use super::{calculate_range, FileGroupPartitioner, FileScanConfig, 
RangeCalculation};
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
+use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer};
 use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
 use crate::datasource::physical_plan::file_stream::{
     FileOpenFuture, FileOpener, FileStream,
@@ -41,8 +42,7 @@ use arrow::{datatypes::SchemaRef, json};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
 
-use bytes::{Buf, Bytes};
-use futures::{ready, StreamExt, TryStreamExt};
+use futures::{StreamExt, TryStreamExt};
 use object_store::buffered::BufWriter;
 use object_store::{GetOptions, GetResultPayload, ObjectStore};
 use tokio::io::AsyncWriteExt;
@@ -312,37 +312,15 @@ impl FileOpener for JsonOpener {
                 GetResultPayload::Stream(s) => {
                     let s = s.map_err(DataFusionError::from);
 
-                    let mut decoder = ReaderBuilder::new(schema)
+                    let decoder = ReaderBuilder::new(schema)
                         .with_batch_size(batch_size)
                         .build_decoder()?;
-                    let mut input =
-                        
file_compression_type.convert_stream(s.boxed())?.fuse();
-                    let mut buffer = Bytes::new();
-
-                    let s = futures::stream::poll_fn(move |cx| {
-                        loop {
-                            if buffer.is_empty() {
-                                match ready!(input.poll_next_unpin(cx)) {
-                                    Some(Ok(b)) => buffer = b,
-                                    Some(Err(e)) => {
-                                        return Poll::Ready(Some(Err(e.into())))
-                                    }
-                                    None => {}
-                                };
-                            }
-
-                            let decoded = match 
decoder.decode(buffer.as_ref()) {
-                                Ok(0) => break,
-                                Ok(decoded) => decoded,
-                                Err(e) => return Poll::Ready(Some(Err(e))),
-                            };
-
-                            buffer.advance(decoded);
-                        }
+                    let input = 
file_compression_type.convert_stream(s.boxed())?.fuse();
 
-                        Poll::Ready(decoder.flush().transpose())
-                    });
-                    Ok(s.boxed())
+                    Ok(deserialize_stream(
+                        input,
+                        DecoderDeserializer::from(decoder),
+                    ))
                 }
             }
         }))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to