This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new b8fd43248 Don't hydrate string dictionaries when writing to parquet (#1764) (#2322)
b8fd43248 is described below

commit b8fd4324893bdd2acff36e31583157564f9d0f4c
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Aug 5 12:35:17 2022 +0100

    Don't hydrate string dictionaries when writing to parquet (#1764) (#2322)
---
 arrow/src/array/array_dictionary.rs          |  13 +++-
 arrow/src/util/data_gen.rs                   |  11 +++
 parquet/benches/arrow_writer.rs              |  31 ++++++++
 parquet/src/arrow/arrow_writer/byte_array.rs |  92 +++++++++++++++--------
 parquet/src/arrow/arrow_writer/mod.rs        | 107 ++++++++-------------------
 5 files changed, 147 insertions(+), 107 deletions(-)

diff --git a/arrow/src/array/array_dictionary.rs b/arrow/src/array/array_dictionary.rs
index 2afc7a69e..2acb51750 100644
--- a/arrow/src/array/array_dictionary.rs
+++ b/arrow/src/array/array_dictionary.rs
@@ -421,7 +421,6 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
 ///     assert_eq!(maybe_val.unwrap(), orig)
 /// }
 /// ```
-#[derive(Copy, Clone)]
 pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
     /// The dictionary array
     dictionary: &'a DictionaryArray<K>,
@@ -429,6 +428,18 @@ pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
     values: &'a V,
 }
 
+// Manually implement `Clone` to avoid `V: Clone` type constraint
+impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
+    fn clone(&self) -> Self {
+        Self {
+            dictionary: self.dictionary,
+            values: self.values,
+        }
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {}
+
 impl<'a, K: ArrowPrimitiveType, V> fmt::Debug for TypedDictionaryArray<'a, K, V> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         writeln!(f, "TypedDictionaryArray({:?})", self.dictionary)
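
For context on the change above: #[derive(Copy, Clone)] on a generic struct adds K: Copy/Clone and V: Copy/Clone bounds even when the fields are only references, which is why the derive is replaced with manual impls. A minimal standalone sketch of the pattern (the Wrapper type is hypothetical, not part of arrow):

    // Hypothetical type illustrating the pattern only.
    struct Wrapper<'a, V> {
        values: &'a V,
    }

    // Manual impls: a shared reference is always Copy, so no `V: Clone` or
    // `V: Copy` bound is needed (a derive would have required them).
    impl<'a, V> Clone for Wrapper<'a, V> {
        fn clone(&self) -> Self {
            Self { values: self.values }
        }
    }

    impl<'a, V> Copy for Wrapper<'a, V> {}

    fn main() {
        let s = String::from("dictionary values"); // String itself is not Copy
        let w = Wrapper { values: &s };
        let copy = w; // copies the reference, not the String
        let _ = (w.values, copy.values); // the original is still usable
    }
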
diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs
index 21b8ee8c9..4d974409a 100644
--- a/arrow/src/util/data_gen.rs
+++ b/arrow/src/util/data_gen.rs
@@ -143,6 +143,17 @@ pub fn create_random_array(
                 })
                 .collect::<Result<Vec<(&str, ArrayRef)>>>()?,
         )?),
+        d @ Dictionary(_, value_type)
+            if crate::compute::can_cast_types(value_type, d) =>
+        {
+            let f = Field::new(
+                field.name(),
+                value_type.as_ref().clone(),
+                field.is_nullable(),
+            );
+            let v = create_random_array(&f, size, null_density, true_density)?;
+            crate::compute::cast(&v, d)?
+        }
         other => {
             return Err(ArrowError::NotYetImplemented(format!(
                 "Generating random arrays not yet implemented for {:?}",
diff --git a/parquet/benches/arrow_writer.rs b/parquet/benches/arrow_writer.rs
index 25ff1ca90..ddca1e53c 100644
--- a/parquet/benches/arrow_writer.rs
+++ b/parquet/benches/arrow_writer.rs
@@ -92,6 +92,25 @@ fn create_string_bench_batch(
     )?)
 }
 
+fn create_string_dictionary_bench_batch(
+    size: usize,
+    null_density: f32,
+    true_density: f32,
+) -> Result<RecordBatch> {
+    let fields = vec![Field::new(
+        "_1",
+        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+        true,
+    )];
+    let schema = Schema::new(fields);
+    Ok(create_random_batch(
+        Arc::new(schema),
+        size,
+        null_density,
+        true_density,
+    )?)
+}
+
 fn create_string_bench_batch_non_null(
     size: usize,
     null_density: f32,
@@ -346,6 +365,18 @@ fn bench_primitive_writer(c: &mut Criterion) {
         b.iter(|| write_batch(&batch).unwrap())
     });
 
+    let batch = create_string_dictionary_bench_batch(4096, 0.25, 0.75).unwrap();
+    group.throughput(Throughput::Bytes(
+        batch
+            .columns()
+            .iter()
+            .map(|f| f.get_array_memory_size() as u64)
+            .sum(),
+    ));
+    group.bench_function("4096 values string dictionary", |b| {
+        b.iter(|| write_batch(&batch).unwrap())
+    });
+
     let batch = create_string_bench_batch_non_null(4096, 0.25, 0.75).unwrap();
     group.throughput(Throughput::Bytes(
         batch
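
The added benchmark above measures writing a 4096-row Int32-keyed Utf8 dictionary column (25% nulls) alongside the existing string cases; assuming the usual workspace layout, it can be run from the repository root with: cargo bench -p parquet --bench arrow_writer
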
diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs
index d1a0da5b3..a7b6ccc3f 100644
--- a/parquet/src/arrow/arrow_writer/byte_array.rs
+++ b/parquet/src/arrow/arrow_writer/byte_array.rs
@@ -16,7 +16,6 @@
 // under the License.
 
 use crate::arrow::arrow_writer::levels::LevelInfo;
-use crate::arrow::arrow_writer::ArrayWriter;
 use crate::basic::Encoding;
 use crate::column::page::PageWriter;
 use crate::column::writer::encoder::{
@@ -33,11 +32,38 @@ use crate::schema::types::ColumnDescPtr;
 use crate::util::bit_util::num_required_bits;
 use crate::util::interner::{Interner, Storage};
 use arrow::array::{
-    Array, ArrayAccessor, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray,
-    StringArray,
+    Array, ArrayAccessor, ArrayRef, BinaryArray, DictionaryArray, LargeBinaryArray,
+    LargeStringArray, StringArray,
 };
 use arrow::datatypes::DataType;
 
+macro_rules! downcast_dict_impl {
+    ($array:ident, $key:ident, $val:ident, $op:expr $(, $arg:expr)*) => {{
+        $op($array
+            .as_any()
+            .downcast_ref::<DictionaryArray<arrow::datatypes::$key>>()
+            .unwrap()
+            .downcast_dict::<$val>()
+            .unwrap()$(, $arg)*)
+    }};
+}
+
+macro_rules! downcast_dict_op {
+    ($key_type:expr, $val:ident, $array:ident, $op:expr $(, $arg:expr)*) => {
+        match $key_type.as_ref() {
+            DataType::UInt8 => downcast_dict_impl!($array, UInt8Type, $val, $op$(, $arg)*),
+            DataType::UInt16 => downcast_dict_impl!($array, UInt16Type, $val, $op$(, $arg)*),
+            DataType::UInt32 => downcast_dict_impl!($array, UInt32Type, $val, $op$(, $arg)*),
+            DataType::UInt64 => downcast_dict_impl!($array, UInt64Type, $val, $op$(, $arg)*),
+            DataType::Int8 => downcast_dict_impl!($array, Int8Type, $val, $op$(, $arg)*),
+            DataType::Int16 => downcast_dict_impl!($array, Int16Type, $val, $op$(, $arg)*),
+            DataType::Int32 => downcast_dict_impl!($array, Int32Type, $val, $op$(, $arg)*),
+            DataType::Int64 => downcast_dict_impl!($array, Int64Type, $val, $op$(, $arg)*),
+            _ => unreachable!(),
+        }
+    };
+}
+
 macro_rules! downcast_op {
     ($data_type:expr, $array:ident, $op:expr $(, $arg:expr)*) => {
         match $data_type {
@@ -51,36 +77,44 @@ macro_rules! downcast_op {
             DataType::LargeBinary => {
                 $op($array.as_any().downcast_ref::<LargeBinaryArray>().unwrap()$(, $arg)*)
             }
-            d => unreachable!("cannot downcast {} to byte array", d)
+            DataType::Dictionary(key, value) => match value.as_ref() {
+                DataType::Utf8 => downcast_dict_op!(key, StringArray, $array, $op$(, $arg)*),
+                DataType::LargeUtf8 => {
+                    downcast_dict_op!(key, LargeStringArray, $array, $op$(, $arg)*)
+                }
+                DataType::Binary => downcast_dict_op!(key, BinaryArray, $array, $op$(, $arg)*),
+                DataType::LargeBinary => {
+                    downcast_dict_op!(key, LargeBinaryArray, $array, $op$(, $arg)*)
+                }
+                d => unreachable!("cannot downcast {} dictionary value to byte array", d),
+            },
+            d => unreachable!("cannot downcast {} to byte array", d),
         }
     };
 }
 
-/// Returns an [`ArrayWriter`] for byte or string arrays
-pub(super) fn make_byte_array_writer<'a>(
-    descr: ColumnDescPtr,
-    data_type: DataType,
-    props: WriterPropertiesPtr,
-    page_writer: Box<dyn PageWriter + 'a>,
-    on_close: OnCloseColumnChunk<'a>,
-) -> Box<dyn ArrayWriter + 'a> {
-    Box::new(ByteArrayWriter {
-        writer: Some(GenericColumnWriter::new(descr, props, page_writer)),
-        on_close: Some(on_close),
-        data_type,
-    })
-}
-
-/// An [`ArrayWriter`] for [`ByteArray`]
-struct ByteArrayWriter<'a> {
-    writer: Option<GenericColumnWriter<'a, ByteArrayEncoder>>,
+/// A writer for byte array types
+pub(super) struct ByteArrayWriter<'a> {
+    writer: GenericColumnWriter<'a, ByteArrayEncoder>,
     on_close: Option<OnCloseColumnChunk<'a>>,
-    data_type: DataType,
 }
 
-impl<'a> ArrayWriter for ByteArrayWriter<'a> {
-    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
-        self.writer.as_mut().unwrap().write_batch_internal(
+impl<'a> ByteArrayWriter<'a> {
+    /// Returns a new [`ByteArrayWriter`]
+    pub fn new(
+        descr: ColumnDescPtr,
+        props: &'a WriterPropertiesPtr,
+        page_writer: Box<dyn PageWriter + 'a>,
+        on_close: OnCloseColumnChunk<'a>,
+    ) -> Result<Self> {
+        Ok(Self {
+            writer: GenericColumnWriter::new(descr, props.clone(), page_writer),
+            on_close: Some(on_close),
+        })
+    }
+
+    pub fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
+        self.writer.write_batch_internal(
             array,
             Some(levels.non_null_indices()),
             levels.def_levels(),
@@ -92,11 +126,11 @@ impl<'a> ArrayWriter for ByteArrayWriter<'a> {
         Ok(())
     }
 
-    fn close(&mut self) -> Result<()> {
+    pub fn close(self) -> Result<()> {
         let (bytes_written, rows_written, metadata, column_index, offset_index) =
-            self.writer.take().unwrap().close()?;
+            self.writer.close()?;
 
-        if let Some(on_close) = self.on_close.take() {
+        if let Some(on_close) = self.on_close {
             on_close(
                 bytes_written,
                 rows_written,
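
The downcast_dict_impl/downcast_dict_op macros above route a DictionaryArray with any integer key type to the existing byte-array encoding ops as a typed dictionary view, so values are encoded straight from the dictionary instead of being hydrated into a flat array first. A rough sketch of the underlying arrow call, assuming DictionaryArray::downcast_dict and the ArrayAccessor trait as imported by this file:

    use arrow::array::{ArrayAccessor, DictionaryArray, StringArray};
    use arrow::datatypes::Int32Type;

    fn main() {
        // A small Int32-keyed string dictionary.
        let dict: DictionaryArray<Int32Type> =
            vec!["parquet", "arrow", "parquet"].into_iter().collect();

        // View it through its string values without materializing a StringArray;
        // the writer's macros do the same for every supported key type.
        let typed = dict.downcast_dict::<StringArray>().unwrap();
        assert_eq!(typed.value(0), "parquet");
        assert_eq!(typed.value(1), "arrow");
        assert_eq!(typed.value(2), "parquet");
    }
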
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 49531d972..800aff98a 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -33,70 +33,18 @@ use super::schema::{
     decimal_length_from_precision,
 };
 
-use crate::column::writer::{get_column_writer, ColumnWriter, ColumnWriterImpl};
+use crate::arrow::arrow_writer::byte_array::ByteArrayWriter;
+use crate::column::writer::{ColumnWriter, ColumnWriterImpl};
 use crate::errors::{ParquetError, Result};
 use crate::file::metadata::RowGroupMetaDataPtr;
 use crate::file::properties::WriterProperties;
-use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter};
+use crate::file::writer::SerializedRowGroupWriter;
 use crate::{data_type::*, file::writer::SerializedFileWriter};
 use levels::{calculate_array_levels, LevelInfo};
 
 mod byte_array;
 mod levels;
 
-/// An object-safe API for writing an [`ArrayRef`]
-trait ArrayWriter {
-    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()>;
-
-    fn close(&mut self) -> Result<()>;
-}
-
-/// Fallback implementation for writing an [`ArrayRef`] that uses [`SerializedColumnWriter`]
-struct ColumnArrayWriter<'a>(Option<SerializedColumnWriter<'a>>);
-
-impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
-    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
-        write_leaf(self.0.as_mut().unwrap().untyped(), array, levels)?;
-        Ok(())
-    }
-
-    fn close(&mut self) -> Result<()> {
-        self.0.take().unwrap().close()
-    }
-}
-
-fn get_writer<'a, W: Write>(
-    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
-    data_type: &ArrowDataType,
-) -> Result<Box<dyn ArrayWriter + 'a>> {
-    let array_writer = row_group_writer
-        .next_column_with_factory(
-            |descr, props, page_writer, on_close| match data_type {
-                ArrowDataType::Utf8
-                | ArrowDataType::LargeUtf8
-                | ArrowDataType::Binary
-                | ArrowDataType::LargeBinary => Ok(byte_array::make_byte_array_writer(
-                    descr,
-                    data_type.clone(),
-                    props.clone(),
-                    page_writer,
-                    on_close,
-                )),
-                _ => {
-                    let column_writer =
-                        get_column_writer(descr, props.clone(), page_writer);
-
-                    let serialized_writer =
-                        SerializedColumnWriter::new(column_writer, Some(on_close));
-
-                    Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
-                }
-            },
-        )?
-        .expect("Unable to get column writer");
-    Ok(array_writer)
-}
-
 /// Arrow writer
 ///
 /// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
@@ -314,22 +262,24 @@ fn write_leaves<W: Write>(
         | ArrowDataType::Time64(_)
         | ArrowDataType::Duration(_)
         | ArrowDataType::Interval(_)
-        | ArrowDataType::LargeBinary
-        | ArrowDataType::Binary
-        | ArrowDataType::Utf8
-        | ArrowDataType::LargeUtf8
         | ArrowDataType::Decimal128(_, _)
         | ArrowDataType::Decimal256(_, _)
         | ArrowDataType::FixedSizeBinary(_) => {
-            let mut writer = get_writer(row_group_writer, &data_type)?;
+            let mut col_writer = row_group_writer.next_column()?.unwrap();
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
-                writer.write(
-                    array,
-                    levels.pop().expect("Levels exhausted"),
-                )?;
+                write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?;
             }
-            writer.close()?;
-            Ok(())
+            col_writer.close()
+        }
+        ArrowDataType::LargeBinary
+        | ArrowDataType::Binary
+        | ArrowDataType::Utf8
+        | ArrowDataType::LargeUtf8 => {
+            let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap();
+            for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
+                col_writer.write(array, levels.pop().expect("Levels exhausted"))?;
+            }
+            col_writer.close()
         }
         ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
             let arrays: Vec<_> = arrays.iter().map(|array|{
@@ -380,18 +330,21 @@ fn write_leaves<W: Write>(
             write_leaves(row_group_writer, &values, levels)?;
             Ok(())
         }
-        ArrowDataType::Dictionary(_, value_type) => {
-            let mut writer = get_writer(row_group_writer, value_type)?;
-            for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
-                // cast dictionary to a primitive
-                let array = arrow::compute::cast(array, value_type)?;
-                writer.write(
-                    &array,
-                    levels.pop().expect("Levels exhausted"),
-                )?;
+        ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() {
+            ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Binary | ArrowDataType::LargeBinary => {
+                let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap();
+                for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
+                    col_writer.write(array, levels.pop().expect("Levels exhausted"))?;
+                }
+                col_writer.close()
+            }
+            _ => {
+                let mut col_writer = row_group_writer.next_column()?.unwrap();
+                for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
+                    write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?;
+                }
+                col_writer.close()
             }
-            writer.close()?;
-            Ok(())
         }
         ArrowDataType::Float16 => Err(ParquetError::ArrowError(
             "Float16 arrays not supported".to_string(),
