This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 576069a31 Add ArrayWriter indirection (#1764) (#2091)
576069a31 is described below
commit 576069a3111094b63d39c3c3973b7cebc90b5a94
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Jul 21 08:06:08 2022 -0400
Add ArrayWriter indirection (#1764) (#2091)
---
parquet/src/arrow/arrow_writer/mod.rs | 65 ++++++++++++++++++++++++-----------
parquet/src/file/writer.rs | 51 +++++++++++++++++++--------
2 files changed, 81 insertions(+), 35 deletions(-)
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 75bd6f6aa..53b094a9e 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -33,7 +33,7 @@ use super::schema::{
decimal_length_from_precision,
};
-use crate::column::writer::ColumnWriter;
+use crate::column::writer::{get_column_writer, ColumnWriter};
use crate::errors::{ParquetError, Result};
use crate::file::metadata::RowGroupMetaDataPtr;
use crate::file::properties::WriterProperties;
@@ -43,6 +43,44 @@ use levels::{calculate_array_levels, LevelInfo};
mod levels;
+/// An object-safe API for writing an [`ArrayRef`]
+trait ArrayWriter {
+ fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>;
+
+ fn close(&mut self) -> Result<()>;
+}
+
+/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`]
+struct ColumnArrayWriter<'a>(Option<SerializedColumnWriter<'a>>);
+
+impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
+ fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
+ write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?;
+ Ok(())
+ }
+
+ fn close(&mut self) -> Result<()> {
+ self.0.take().unwrap().close()
+ }
+}
+
+fn get_writer<'a, W: Write>(
+ row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
+) -> Result<Box<dyn ArrayWriter + 'a>> {
+ let array_writer = row_group_writer
+ .next_column_with_factory(|descr, props, page_writer, on_close| {
+ // TODO: Special case array readers (#1764)
+
+ let column_writer = get_column_writer(descr, props.clone(), page_writer);
+ let serialized_writer =
+ SerializedColumnWriter::new(column_writer, Some(on_close));
+
+ Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
+ })?
+ .expect("Unable to get column writer");
+ Ok(array_writer)
+}
+
/// Arrow writer
///
/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
@@ -229,17 +267,6 @@ impl<W: Write> ArrowWriter<W> {
}
}
-/// Convenience method to get the next ColumnWriter from the RowGroupWriter
-#[inline]
-fn get_col_writer<'a, W: Write>(
- row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
-) -> Result<SerializedColumnWriter<'a>> {
- let col_writer = row_group_writer
- .next_column()?
- .expect("Unable to get column writer");
- Ok(col_writer)
-}
-
fn write_leaves<W: Write>(
row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
arrays: &[ArrayRef],
@@ -277,15 +304,14 @@ fn write_leaves<W: Write>(
| ArrowDataType::LargeUtf8
| ArrowDataType::Decimal(_, _)
| ArrowDataType::FixedSizeBinary(_) => {
- let mut col_writer = get_col_writer(row_group_writer)?;
+ let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
- write_leaf(
- col_writer.untyped(),
+ writer.write(
array,
levels.pop().expect("Levels exhausted"),
)?;
}
- col_writer.close()?;
+ writer.close()?;
Ok(())
}
ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
@@ -338,17 +364,16 @@ fn write_leaves<W: Write>(
Ok(())
}
ArrowDataType::Dictionary(_, value_type) => {
- let mut col_writer = get_col_writer(row_group_writer)?;
+ let mut writer = get_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
// cast dictionary to a primitive
let array = arrow::compute::cast(array, value_type)?;
- write_leaf(
- col_writer.untyped(),
+ writer.write(
&array,
levels.pop().expect("Levels exhausted"),
)?;
}
- col_writer.close()?;
+ writer.close()?;
Ok(())
}
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs
index 10983c741..467273aaa 100644
--- a/parquet/src/file/writer.rs
+++ b/parquet/src/file/writer.rs
@@ -37,7 +37,9 @@ use crate::file::{
metadata::*, properties::WriterPropertiesPtr,
statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC,
};
-use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr};
+use crate::schema::types::{
+ self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr,
+};
use crate::util::io::TryClone;
/// A wrapper around a [`Write`] that keeps track of the number
@@ -367,22 +369,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
}
}
- /// Returns the next column writer, if available; otherwise returns `None`.
- /// In case of any IO error or Thrift error, or if row group writer has already been
- /// closed returns `Err`.
- pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
+ /// Returns the next column writer, if available, using the factory function;
+ /// otherwise returns `None`.
+ pub(crate) fn next_column_with_factory<'b, F, C>(
+ &'b mut self,
+ factory: F,
+ ) -> Result<Option<C>>
+ where
+ F: FnOnce(
+ ColumnDescPtr,
+ &'b WriterPropertiesPtr,
+ Box<dyn PageWriter + 'b>,
+ OnCloseColumnChunk<'b>,
+ ) -> Result<C>,
+ {
self.assert_previous_writer_closed()?;
if self.column_index >= self.descr.num_columns() {
return Ok(None);
}
let page_writer = Box::new(SerializedPageWriter::new(self.buf));
- let column_writer = get_column_writer(
- self.descr.column(self.column_index),
- self.props.clone(),
- page_writer,
- );
- self.column_index += 1;
let total_bytes_written = &mut self.total_bytes_written;
let total_rows_written = &mut self.total_rows_written;
@@ -413,10 +419,25 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> {
Ok(())
};
- Ok(Some(SerializedColumnWriter::new(
- column_writer,
- Some(Box::new(on_close)),
- )))
+ let column = self.descr.column(self.column_index);
+ self.column_index += 1;
+
+ Ok(Some(factory(
+ column,
+ &self.props,
+ page_writer,
+ Box::new(on_close),
+ )?))
+ }
+
+ /// Returns the next column writer, if available; otherwise returns `None`.
+ /// In case of any IO error or Thrift error, or if row group writer has already been
+ /// closed returns `Err`.
+ pub fn next_column(&mut self) -> Result<Option<SerializedColumnWriter<'_>>> {
+ self.next_column_with_factory(|descr, props, page_writer, on_close| {
+ let column_writer = get_column_writer(descr, props.clone(), page_writer);
+ Ok(SerializedColumnWriter::new(column_writer, Some(on_close)))
+ })
}
/// Closes this row group writer and returns row group metadata.