This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9863486299 Add IPC FileDecoder (#5249)
9863486299 is described below

commit 98634862994f88f5bebb4ef34cb49dcf6361c95a
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Dec 30 06:46:17 2023 +0000

    Add IPC FileDecoder (#5249)
    
    * Add IPC FileDecoder
    
    * Clippy
    
    * Update arrow-ipc/src/reader.rs
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 arrow-ipc/src/reader.rs | 283 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 186 insertions(+), 97 deletions(-)

diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 39365b7829..8ac3a387d5 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -451,7 +451,7 @@ pub fn read_dictionary(
     batch: crate::DictionaryBatch,
     schema: &Schema,
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
-    metadata: &crate::MetadataVersion,
+    metadata: &MetadataVersion,
 ) -> Result<(), ArrowError> {
     if batch.isDelta() {
         return Err(ArrowError::InvalidArgumentError(
@@ -522,6 +522,174 @@ fn parse_message(buf: &[u8]) -> Result<Message, ArrowError> {
         .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
 }
 
+/// Read the footer length from the last 10 bytes of an Arrow IPC file
+///
+/// Expects a 4 byte footer length followed by `b"ARROW1"`
+pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
+    if buf[4..] != super::ARROW_MAGIC {
+        return Err(ArrowError::ParseError(
+            "Arrow file does not contain correct footer".to_string(),
+        ));
+    }
+
+    // read footer length
+    let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
+    footer_len
+        .try_into()
+        .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
+}
+
+/// A low-level, push-based interface for reading an IPC file
+///
+/// For a higher-level interface see [`FileReader`]
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use arrow_array::*;
+/// # use arrow_array::types::Int32Type;
+/// # use arrow_buffer::Buffer;
+/// # use arrow_ipc::convert::fb_to_schema;
+/// # use arrow_ipc::reader::{FileDecoder, read_footer_length};
+/// # use arrow_ipc::root_as_footer;
+/// # use arrow_ipc::writer::FileWriter;
+/// // Write an IPC file
+///
+/// let batch = RecordBatch::try_from_iter([
+///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
+///     ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
+///     ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _),
+/// ]).unwrap();
+///
+/// let schema = batch.schema();
+///
+/// let mut out = Vec::with_capacity(1024);
+/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap();
+/// writer.write(&batch).unwrap();
+/// writer.finish().unwrap();
+///
+/// drop(writer);
+///
+/// // Read IPC file
+///
+/// let buffer = Buffer::from_vec(out);
+/// let trailer_start = buffer.len() - 10;
+/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
+/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
+///
+/// let back = fb_to_schema(footer.schema().unwrap());
+/// assert_eq!(&back, schema.as_ref());
+///
+/// let mut decoder = FileDecoder::new(schema, footer.version());
+///
+/// // Read dictionaries
+/// for block in footer.dictionaries().iter().flatten() {
+///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+///     let data = buffer.slice_with_length(block.offset() as _, block_len);
+///     decoder.read_dictionary(&block, &data).unwrap();
+/// }
+///
+/// // Read record batch
+/// let batches = footer.recordBatches().unwrap();
+/// assert_eq!(batches.len(), 1); // Only wrote a single batch
+///
+/// let block = batches.get(0);
+/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+/// let data = buffer.slice_with_length(block.offset() as _, block_len);
+/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap();
+///
+/// assert_eq!(batch, back);
+/// ```
+#[derive(Debug)]
+pub struct FileDecoder {
+    schema: SchemaRef,
+    dictionaries: HashMap<i64, ArrayRef>,
+    version: MetadataVersion,
+    projection: Option<Vec<usize>>,
+}
+
+impl FileDecoder {
+    /// Create a new [`FileDecoder`] with the given schema and version
+    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
+        Self {
+            schema,
+            version,
+            dictionaries: Default::default(),
+            projection: None,
+        }
+    }
+
+    /// Specify a projection
+    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
+        self.projection = Some(projection);
+        self
+    }
+
+    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
+        let message = parse_message(buf)?;
+
+        // some old test data's footer metadata is not set, so we account for that
+        if self.version != MetadataVersion::V1 && message.version() != self.version {
+            return Err(ArrowError::IpcError(
+                "Could not read IPC message as metadata versions mismatch".to_string(),
+            ));
+        }
+        Ok(message)
+    }
+
+    /// Read the dictionary with the given block and data buffer
+    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
+        let message = self.read_message(buf)?;
+        match message.header_type() {
+            crate::MessageHeader::DictionaryBatch => {
+                let batch = message.header_as_dictionary_batch().unwrap();
+                read_dictionary(
+                    &buf.slice(block.metaDataLength() as _),
+                    batch,
+                    &self.schema,
+                    &mut self.dictionaries,
+                    &message.version(),
+                )
+            }
+            t => Err(ArrowError::ParseError(format!(
+                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
+            ))),
+        }
+    }
+
+    /// Read the RecordBatch with the given block and data buffer
+    pub fn read_record_batch(
+        &self,
+        block: &Block,
+        buf: &Buffer,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
+        let message = self.read_message(buf)?;
+        match message.header_type() {
+            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
+                "Not expecting a schema when messages are read".to_string(),
+            )),
+            crate::MessageHeader::RecordBatch => {
+                let batch = message.header_as_record_batch().ok_or_else(|| {
+                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
+                })?;
+                // read the block that makes up the record batch into a buffer
+                read_record_batch(
+                    &buf.slice(block.metaDataLength() as _),
+                    batch,
+                    self.schema.clone(),
+                    &self.dictionaries,
+                    self.projection.as_deref(),
+                    &message.version(),
+                )
+                .map(Some)
+            }
+            crate::MessageHeader::NONE => Ok(None),
+            t => Err(ArrowError::InvalidArgumentError(format!(
+                "Reading types other than record batches not yet supported, unable to read {t:?}"
+            ))),
+        }
+    }
+}
+
 /// Build an Arrow [`FileReader`] with custom options.
 #[derive(Debug)]
 pub struct FileReaderBuilder {
@@ -599,17 +767,10 @@ impl FileReaderBuilder {
         reader.seek(SeekFrom::End(-10))?;
         reader.read_exact(&mut buffer)?;
 
-        if buffer[4..] != super::ARROW_MAGIC {
-            return Err(ArrowError::ParseError(
-                "Arrow file does not contain correct footer".to_string(),
-            ));
-        }
-
-        // read footer length
-        let footer_len = i32::from_le_bytes(buffer[..4].try_into().unwrap());
+        let footer_len = read_footer_length(buffer)?;
 
         // read footer
-        let mut footer_data = vec![0; footer_len as usize];
+        let mut footer_data = vec![0; footer_len];
         reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
         reader.read_exact(&mut footer_data)?;
 
@@ -641,50 +802,26 @@ impl FileReaderBuilder {
             }
         }
 
+        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
+        if let Some(projection) = self.projection {
+            decoder = decoder.with_projection(projection)
+        }
+
         // Create an array of optional dictionary value arrays, one per field.
-        let mut dictionaries_by_id = HashMap::new();
         if let Some(dictionaries) = footer.dictionaries() {
             for block in dictionaries {
                 let buf = read_block(&mut reader, block)?;
-                let message = parse_message(&buf)?;
-
-                match message.header_type() {
-                    crate::MessageHeader::DictionaryBatch => {
-                        let batch = message.header_as_dictionary_batch().unwrap();
-                        read_dictionary(
-                            &buf.slice(block.metaDataLength() as _),
-                            batch,
-                            &schema,
-                            &mut dictionaries_by_id,
-                            &message.version(),
-                        )?;
-                    }
-                    t => {
-                        return Err(ArrowError::ParseError(format!(
-                            "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
-                        )));
-                    }
-                }
+                decoder.read_dictionary(block, &buf)?;
             }
         }
-        let projection = match self.projection {
-            Some(projection_indices) => {
-                let schema = schema.project(&projection_indices)?;
-                Some((projection_indices, schema))
-            }
-            _ => None,
-        };
 
         Ok(FileReader {
             reader,
-            schema: Arc::new(schema),
             blocks: blocks.iter().copied().collect(),
             current_block: 0,
             total_blocks,
-            dictionaries_by_id,
-            metadata_version: footer.version(),
+            decoder,
             custom_metadata,
-            projection,
         })
     }
 }
@@ -694,13 +831,13 @@ pub struct FileReader<R: Read + Seek> {
     /// Buffered file reader that supports reading and seeking
     reader: R,
 
-    /// The schema that is read from the file header
-    schema: SchemaRef,
+    /// The decoder
+    decoder: FileDecoder,
 
     /// The blocks in the file
     ///
     /// A block indicates the regions in the file to read to get data
-    blocks: Vec<crate::Block>,
+    blocks: Vec<Block>,
 
     /// A counter to keep track of the current block that should be read
     current_block: usize,
@@ -708,31 +845,17 @@ pub struct FileReader<R: Read + Seek> {
     /// The total number of blocks, which may contain record batches and other types
     total_blocks: usize,
 
-    /// Optional dictionaries for each schema field.
-    ///
-    /// Dictionaries may be appended to in the streaming format.
-    dictionaries_by_id: HashMap<i64, ArrayRef>,
-
-    /// Metadata version
-    metadata_version: crate::MetadataVersion,
-
     /// User defined metadata
     custom_metadata: HashMap<String, String>,
-
-    /// Optional projection and projected_schema
-    projection: Option<(Vec<usize>, Schema)>,
 }
 
 impl<R: Read + Seek> fmt::Debug for FileReader<R> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
         f.debug_struct("FileReader<R>")
-            .field("schema", &self.schema)
+            .field("decoder", &self.decoder)
             .field("blocks", &self.blocks)
             .field("current_block", &self.current_block)
             .field("total_blocks", &self.total_blocks)
-            .field("dictionaries_by_id", &self.dictionaries_by_id)
-            .field("metadata_version", &self.metadata_version)
-            .field("projection", &self.projection)
             .finish_non_exhaustive()
     }
 }
@@ -761,7 +884,7 @@ impl<R: Read + Seek> FileReader<R> {
 
     /// Return the schema of the file
     pub fn schema(&self) -> SchemaRef {
-        self.schema.clone()
+        self.decoder.schema.clone()
     }
 
     /// Read a specific record batch
@@ -785,41 +908,7 @@ impl<R: Read + Seek> FileReader<R> {
 
         // read length
         let buffer = read_block(&mut self.reader, block)?;
-        let message = parse_message(&buffer)?;
-
-        // some old test data's footer metadata is not set, so we account for that
-        if self.metadata_version != MetadataVersion::V1
-            && message.version() != self.metadata_version
-        {
-            return Err(ArrowError::IpcError(
-                "Could not read IPC message as metadata versions mismatch".to_string(),
-            ));
-        }
-
-        match message.header_type() {
-            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
-                "Not expecting a schema when messages are read".to_string(),
-            )),
-            crate::MessageHeader::RecordBatch => {
-                let batch = message.header_as_record_batch().ok_or_else(|| {
-                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
-                })?;
-                // read the block that makes up the record batch into a buffer
-                read_record_batch(
-                    &buffer.slice(block.metaDataLength() as _),
-                    batch,
-                    self.schema(),
-                    &self.dictionaries_by_id,
-                    self.projection.as_ref().map(|x| x.0.as_ref()),
-                    &message.version(),
-                )
-                .map(Some)
-            }
-            crate::MessageHeader::NONE => Ok(None),
-            t => Err(ArrowError::InvalidArgumentError(format!(
-                "Reading types other than record batches not yet supported, unable to read {t:?}"
-            ))),
-        }
+        self.decoder.read_record_batch(block, &buffer)
     }
 
     /// Gets a reference to the underlying reader.
@@ -852,7 +941,7 @@ impl<R: Read + Seek> Iterator for FileReader<R> {
 
 impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
     fn schema(&self) -> SchemaRef {
-        self.schema.clone()
+        self.schema()
     }
 }
 

Reply via email to