This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c09ba44c Optimized writing of byte array to parquet (#1764) (2x faster) (#2221)
2c09ba44c is described below

commit 2c09ba44ca7a897f9e049f3fac60e50b92c5df16
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Mon Aug 1 12:33:32 2022 +0100

    Optimized writing of byte array to parquet (#1764) (2x faster) (#2221)
    
    * Optimized writing of byte array to parquet (#1764)
    
    * Review feedback
    
    * Fix logical conflict
---
 parquet/src/arrow/arrow_writer/byte_array.rs   | 555 +++++++++++++++++++++++++
 parquet/src/arrow/arrow_writer/mod.rs          | 222 +++-------
 parquet/src/column/writer/encoder.rs           | 129 +++---
 parquet/src/column/writer/mod.rs               |  25 +-
 parquet/src/encodings/encoding/dict_encoder.rs |   7 +-
 parquet/src/util/interner.rs                   |   6 +
 6 files changed, 727 insertions(+), 217 deletions(-)

diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs
new file mode 100644
index 000000000..52698a31b
--- /dev/null
+++ b/parquet/src/arrow/arrow_writer/byte_array.rs
@@ -0,0 +1,555 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::arrow::arrow_writer::levels::LevelInfo;
+use crate::arrow::arrow_writer::ArrayWriter;
+use crate::basic::Encoding;
+use crate::column::page::PageWriter;
+use crate::column::writer::encoder::{
+    ColumnValueEncoder, DataPageValues, DictionaryPage,
+};
+use crate::column::writer::GenericColumnWriter;
+use crate::data_type::{AsBytes, ByteArray, Int32Type};
+use crate::encodings::encoding::{DeltaBitPackEncoder, Encoder};
+use crate::encodings::rle::RleEncoder;
+use crate::errors::{ParquetError, Result};
+use crate::file::properties::{WriterProperties, WriterPropertiesPtr, 
WriterVersion};
+use crate::file::writer::OnCloseColumnChunk;
+use crate::schema::types::ColumnDescPtr;
+use crate::util::bit_util::num_required_bits;
+use crate::util::interner::{Interner, Storage};
+use arrow::array::{
+    Array, ArrayAccessor, ArrayRef, BinaryArray, LargeBinaryArray, 
LargeStringArray,
+    StringArray,
+};
+use arrow::datatypes::DataType;
+
+/// Downcasts `$array` to the concrete arrow array type matching `$data_type`
+/// and invokes `$op` with it, followed by any extra `$arg`s
+///
+/// Panics if `$data_type` is not `Utf8`, `LargeUtf8`, `Binary` or `LargeBinary`,
+/// or if `$array` is not actually of that type
+macro_rules! downcast_op {
+    ($data_type:expr, $array:ident, $op:expr $(, $arg:expr)*) => {
+        match $data_type {
+            DataType::Utf8 => {
+                $op($array.as_any().downcast_ref::<StringArray>().unwrap()$(, $arg)*)
+            }
+            DataType::LargeUtf8 => {
+                $op($array.as_any().downcast_ref::<LargeStringArray>().unwrap()$(, $arg)*)
+            }
+            DataType::Binary => {
+                $op($array.as_any().downcast_ref::<BinaryArray>().unwrap()$(, $arg)*)
+            }
+            DataType::LargeBinary => {
+                $op($array.as_any().downcast_ref::<LargeBinaryArray>().unwrap()$(, $arg)*)
+            }
+            d => unreachable!("cannot downcast {} to byte array", d)
+        }
+    };
+}
+
+/// Builds a boxed [`ArrayWriter`] for arrow string and binary arrays
+/// (`Utf8`, `LargeUtf8`, `Binary`, `LargeBinary`), backed by a
+/// [`GenericColumnWriter`] using [`ByteArrayEncoder`]
+pub(super) fn make_byte_array_writer<'a>(
+    descr: ColumnDescPtr,
+    data_type: DataType,
+    props: WriterPropertiesPtr,
+    page_writer: Box<dyn PageWriter + 'a>,
+    on_close: OnCloseColumnChunk<'a>,
+) -> Box<dyn ArrayWriter + 'a> {
+    let column_writer = GenericColumnWriter::new(descr, props, page_writer);
+    let array_writer = ByteArrayWriter {
+        writer: Some(column_writer),
+        on_close: Some(on_close),
+        data_type,
+    };
+    Box::new(array_writer)
+}
+
+/// [`ArrayWriter`] that feeds arrow byte/string arrays into a
+/// [`GenericColumnWriter`] driven by [`ByteArrayEncoder`]
+struct ByteArrayWriter<'a> {
+    /// The column writer; taken (left as `None`) by `close`
+    writer: Option<GenericColumnWriter<'a, ByteArrayEncoder>>,
+    /// Callback receiving the column chunk results; taken by `close`
+    on_close: Option<OnCloseColumnChunk<'a>>,
+    /// Arrow data type of the arrays given to this writer
+    data_type: DataType,
+}
+
+impl<'a> ArrayWriter for ByteArrayWriter<'a> {
+    fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> {
+        let writer = self.writer.as_mut().unwrap();
+        // Only the non-null positions reported by `levels` are encoded
+        let non_null = levels.non_null_indices();
+        writer
+            .write_batch_internal(
+                array,
+                Some(non_null),
+                levels.def_levels(),
+                levels.rep_levels(),
+                None,
+                None,
+                None,
+            )
+            .map(|_| ())
+    }
+
+    fn close(&mut self) -> Result<()> {
+        let writer = self.writer.take().unwrap();
+        let (bytes_written, rows_written, metadata, column_index, offset_index) =
+            writer.close()?;
+
+        let callback = match self.on_close.take() {
+            Some(callback) => callback,
+            None => return Ok(()),
+        };
+        callback(
+            bytes_written,
+            rows_written,
+            metadata,
+            column_index,
+            offset_index,
+        )?;
+        Ok(())
+    }
+}
+
+/// A fallback encoder, i.e. non-dictionary, for [`ByteArray`]
+struct FallbackEncoder {
+    encoder: FallbackEncoderImpl,
+    num_values: usize,
+}
+
+/// The fallback encoder in use
+///
+/// Note: DeltaBitPackEncoder is boxed as it is rather large
+enum FallbackEncoderImpl {
+    Plain {
+        buffer: Vec<u8>,
+    },
+    DeltaLength {
+        buffer: Vec<u8>,
+        lengths: Box<DeltaBitPackEncoder<Int32Type>>,
+    },
+    Delta {
+        buffer: Vec<u8>,
+        last_value: Vec<u8>,
+        prefix_lengths: Box<DeltaBitPackEncoder<Int32Type>>,
+        suffix_lengths: Box<DeltaBitPackEncoder<Int32Type>>,
+    },
+}
+
+impl FallbackEncoder {
+    /// Create the fallback encoder for the given [`ColumnDescPtr`] and 
[`WriterProperties`]
+    fn new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result<Self> {
+        // Set either main encoder or fallback encoder.
+        let encoding = props.encoding(descr.path()).unwrap_or_else(|| {
+            match props.writer_version() {
+                WriterVersion::PARQUET_1_0 => Encoding::PLAIN,
+                WriterVersion::PARQUET_2_0 => Encoding::DELTA_BYTE_ARRAY,
+            }
+        });
+
+        let encoder = match encoding {
+            Encoding::PLAIN => FallbackEncoderImpl::Plain { buffer: vec![] },
+            Encoding::DELTA_LENGTH_BYTE_ARRAY => 
FallbackEncoderImpl::DeltaLength {
+                buffer: vec![],
+                lengths: Box::new(DeltaBitPackEncoder::new()),
+            },
+            Encoding::DELTA_BYTE_ARRAY => FallbackEncoderImpl::Delta {
+                buffer: vec![],
+                last_value: vec![],
+                prefix_lengths: Box::new(DeltaBitPackEncoder::new()),
+                suffix_lengths: Box::new(DeltaBitPackEncoder::new()),
+            },
+            _ => {
+                return Err(general_err!(
+                    "unsupported encoding {} for byte array",
+                    encoding
+                ))
+            }
+        };
+
+        Ok(Self {
+            encoder,
+            num_values: 0,
+        })
+    }
+
+    /// Encode `values` to the in-progress page
+    fn encode<T>(&mut self, values: T, indices: &[usize])
+    where
+        T: ArrayAccessor + Copy,
+        T::Item: AsRef<[u8]>,
+    {
+        self.num_values += indices.len();
+        match &mut self.encoder {
+            FallbackEncoderImpl::Plain { buffer } => {
+                for idx in indices {
+                    let value = values.value(*idx);
+                    let value = value.as_ref();
+                    buffer.extend_from_slice((value.len() as u32).as_bytes());
+                    buffer.extend_from_slice(value)
+                }
+            }
+            FallbackEncoderImpl::DeltaLength { buffer, lengths } => {
+                for idx in indices {
+                    let value = values.value(*idx);
+                    let value = value.as_ref();
+                    lengths.put(&[value.len() as i32]).unwrap();
+                    buffer.extend_from_slice(value);
+                }
+            }
+            FallbackEncoderImpl::Delta {
+                buffer,
+                last_value,
+                prefix_lengths,
+                suffix_lengths,
+            } => {
+                for idx in indices {
+                    let value = values.value(*idx);
+                    let value = value.as_ref();
+                    let mut prefix_length = 0;
+
+                    while prefix_length < last_value.len()
+                        && prefix_length < value.len()
+                        && last_value[prefix_length] == value[prefix_length]
+                    {
+                        prefix_length += 1;
+                    }
+
+                    let suffix_length = value.len() - prefix_length;
+
+                    last_value.clear();
+                    last_value.extend_from_slice(value);
+
+                    buffer.extend_from_slice(&value[prefix_length..]);
+                    prefix_lengths.put(&[prefix_length as i32]).unwrap();
+                    suffix_lengths.put(&[suffix_length as i32]).unwrap();
+                }
+            }
+        }
+    }
+
+    fn estimated_data_page_size(&self) -> usize {
+        match &self.encoder {
+            FallbackEncoderImpl::Plain { buffer, .. } => buffer.len(),
+            FallbackEncoderImpl::DeltaLength { buffer, lengths } => {
+                buffer.len() + lengths.estimated_data_encoded_size()
+            }
+            FallbackEncoderImpl::Delta {
+                buffer,
+                prefix_lengths,
+                suffix_lengths,
+                ..
+            } => {
+                buffer.len()
+                    + prefix_lengths.estimated_data_encoded_size()
+                    + suffix_lengths.estimated_data_encoded_size()
+            }
+        }
+    }
+
+    fn flush_data_page(
+        &mut self,
+        min_value: Option<ByteArray>,
+        max_value: Option<ByteArray>,
+    ) -> Result<DataPageValues<ByteArray>> {
+        let (buf, encoding) = match &mut self.encoder {
+            FallbackEncoderImpl::Plain { buffer } => {
+                (std::mem::take(buffer), Encoding::PLAIN)
+            }
+            FallbackEncoderImpl::DeltaLength { buffer, lengths } => {
+                let lengths = lengths.flush_buffer()?;
+
+                let mut out = Vec::with_capacity(lengths.len() + buffer.len());
+                out.extend_from_slice(lengths.data());
+                out.extend_from_slice(buffer);
+                (out, Encoding::DELTA_LENGTH_BYTE_ARRAY)
+            }
+            FallbackEncoderImpl::Delta {
+                buffer,
+                prefix_lengths,
+                suffix_lengths,
+                ..
+            } => {
+                let prefix_lengths = prefix_lengths.flush_buffer()?;
+                let suffix_lengths = suffix_lengths.flush_buffer()?;
+
+                let mut out = Vec::with_capacity(
+                    prefix_lengths.len() + suffix_lengths.len() + buffer.len(),
+                );
+                out.extend_from_slice(prefix_lengths.data());
+                out.extend_from_slice(suffix_lengths.data());
+                out.extend_from_slice(buffer);
+                (out, Encoding::DELTA_BYTE_ARRAY)
+            }
+        };
+
+        Ok(DataPageValues {
+            buf: buf.into(),
+            num_values: std::mem::take(&mut self.num_values),
+            encoding,
+            min_value,
+            max_value,
+        })
+    }
+}
+
+/// Backing [`Storage`] for the [`Interner`] inside [`DictEncoder`]
+///
+/// Each pushed value is appended to `page` as a 4-byte length followed by its
+/// bytes, so `page` can be emitted directly as the dictionary page
+#[derive(Debug, Default)]
+struct ByteArrayStorage {
+    /// Encoded dictionary data
+    page: Vec<u8>,
+
+    /// Byte range in `page` of each stored value, excluding its length prefix
+    values: Vec<std::ops::Range<usize>>,
+}
+
+impl Storage for ByteArrayStorage {
+    type Key = u64;
+    type Value = [u8];
+
+    fn get(&self, idx: Self::Key) -> &Self::Value {
+        let range = &self.values[idx as usize];
+        &self.page[range.start..range.end]
+    }
+
+    fn push(&mut self, value: &Self::Value) -> Self::Key {
+        let key = self.values.len() as u64;
+        let len_prefix = value.len() as u32;
+
+        self.page.reserve(4 + value.len());
+        self.page.extend_from_slice(len_prefix.as_bytes());
+
+        let start = self.page.len();
+        self.page.extend_from_slice(value);
+        let end = self.page.len();
+        self.values.push(start..end);
+
+        key
+    }
+}
+
+/// A dictionary encoder for byte array data
+///
+/// Values are interned into a [`ByteArrayStorage`], whose `page` becomes the
+/// dictionary page; each data page holds the RLE-encoded dictionary keys
+#[derive(Debug, Default)]
+struct DictEncoder {
+    /// Interner assigning a dictionary key to each value
+    interner: Interner<ByteArrayStorage>,
+    /// Dictionary keys of the values buffered for the in-progress data page
+    indices: Vec<u64>,
+}
+
+impl DictEncoder {
+    /// Encode the values of `values` at positions `indices` to the in-progress page
+    fn encode<T>(&mut self, values: T, indices: &[usize])
+    where
+        T: ArrayAccessor + Copy,
+        T::Item: AsRef<[u8]>,
+    {
+        self.indices.reserve(indices.len());
+
+        for idx in indices {
+            let value = values.value(*idx);
+            let interned = self.interner.intern(value.as_ref());
+            self.indices.push(interned);
+        }
+    }
+
+    /// Number of bits needed to represent the largest dictionary key
+    fn bit_width(&self) -> u8 {
+        let length = self.interner.storage().values.len();
+        num_required_bits(length.saturating_sub(1) as u64)
+    }
+
+    /// Estimated size in bytes of the in-progress data page: the bit-width byte
+    /// plus the RLE encoder's buffer sizing for the buffered keys
+    fn estimated_data_page_size(&self) -> usize {
+        let bit_width = self.bit_width();
+        1 + RleEncoder::min_buffer_size(bit_width)
+            + RleEncoder::max_buffer_size(bit_width, self.indices.len())
+    }
+
+    /// Size in bytes of the dictionary page encoded so far
+    fn estimated_dict_page_size(&self) -> usize {
+        self.interner.storage().page.len()
+    }
+
+    /// Consume the encoder, returning the accumulated dictionary page
+    fn flush_dict_page(self) -> DictionaryPage {
+        let storage = self.interner.into_inner();
+
+        DictionaryPage {
+            buf: storage.page.into(),
+            num_values: storage.values.len(),
+            is_sorted: false,
+        }
+    }
+
+    /// Encode the buffered keys as an `RLE_DICTIONARY` data page and clear them
+    ///
+    /// Returns an error if the RLE encoder reports it ran out of space
+    fn flush_data_page(
+        &mut self,
+        min_value: Option<ByteArray>,
+        max_value: Option<ByteArray>,
+    ) -> Result<DataPageValues<ByteArray>> {
+        let num_values = self.indices.len();
+        let bit_width = self.bit_width();
+
+        // The page starts with one byte holding the key bit width, followed by
+        // the RLE-encoded keys
+        let mut buffer = Vec::with_capacity(self.estimated_data_page_size());
+        buffer.push(bit_width);
+
+        let mut encoder = RleEncoder::new_from_buf(bit_width, buffer);
+        for index in &self.indices {
+            if !encoder.put(*index)? {
+                return Err(general_err!("Encoder doesn't have enough space"));
+            }
+        }
+
+        self.indices.clear();
+
+        Ok(DataPageValues {
+            buf: encoder.consume()?.into(),
+            num_values,
+            encoding: Encoding::RLE_DICTIONARY,
+            min_value,
+            max_value,
+        })
+    }
+}
+
+/// The [`ColumnValueEncoder`] for byte array columns written from arrow arrays
+///
+/// Values are dictionary encoded while `dict_encoder` is `Some`, and encoded
+/// with `fallback` once it has been taken (or if dictionary encoding is disabled)
+struct ByteArrayEncoder {
+    fallback: FallbackEncoder,
+    dict_encoder: Option<DictEncoder>,
+    /// Minimum value encoded since the last data page flush
+    min_value: Option<ByteArray>,
+    /// Maximum value encoded since the last data page flush
+    max_value: Option<ByteArray>,
+}
+
+impl ColumnValueEncoder for ByteArrayEncoder {
+    type T = ByteArray;
+    type Values = ArrayRef;
+
+    fn min_max(
+        &self,
+        values: &ArrayRef,
+        value_indices: Option<&[usize]>,
+    ) -> Option<(Self::T, Self::T)> {
+        match value_indices {
+            Some(indices) => {
+                let iter = indices.iter().cloned();
+                downcast_op!(values.data_type(), values, compute_min_max, iter)
+            }
+            None => {
+                let len = Array::len(values);
+                downcast_op!(values.data_type(), values, compute_min_max, 0..len)
+            }
+        }
+    }
+
+    fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        let dictionary = props
+            .dictionary_enabled(descr.path())
+            .then(DictEncoder::default);
+
+        let fallback = FallbackEncoder::new(descr, props)?;
+
+        Ok(Self {
+            fallback,
+            dict_encoder: dictionary,
+            min_value: None,
+            max_value: None,
+        })
+    }
+
+    fn write(
+        &mut self,
+        _values: &Self::Values,
+        _offset: usize,
+        _len: usize,
+    ) -> Result<()> {
+        unreachable!("should call write_gather instead")
+    }
+
+    fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> Result<()> {
+        downcast_op!(values.data_type(), values, encode, indices, self);
+        Ok(())
+    }
+
+    /// Number of values buffered for the in-progress data page
+    fn num_values(&self) -> usize {
+        // Derived from whichever encoder is active, both of which reset their
+        // count when a data page is flushed
+        match &self.dict_encoder {
+            Some(encoder) => encoder.indices.len(),
+            None => self.fallback.num_values,
+        }
+    }
+
+    fn has_dictionary(&self) -> bool {
+        self.dict_encoder.is_some()
+    }
+
+    fn estimated_dict_page_size(&self) -> Option<usize> {
+        Some(self.dict_encoder.as_ref()?.estimated_dict_page_size())
+    }
+
+    fn estimated_data_page_size(&self) -> usize {
+        match &self.dict_encoder {
+            Some(encoder) => encoder.estimated_data_page_size(),
+            None => self.fallback.estimated_data_page_size(),
+        }
+    }
+
+    /// Take the dictionary encoder and return its dictionary page; later values
+    /// go to the fallback encoder
+    ///
+    /// Returns an error, leaving the dictionary in place, if dictionary-encoded
+    /// values are still buffered for an unflushed data page
+    fn flush_dict_page(&mut self) -> Result<Option<DictionaryPage>> {
+        if let Some(encoder) = &self.dict_encoder {
+            if !encoder.indices.is_empty() {
+                return Err(general_err!(
+                    "Must flush data pages before flushing dictionary"
+                ));
+            }
+        }
+
+        Ok(self.dict_encoder.take().map(DictEncoder::flush_dict_page))
+    }
+
+    fn flush_data_page(&mut self) -> Result<DataPageValues<ByteArray>> {
+        let min_value = self.min_value.take();
+        let max_value = self.max_value.take();
+
+        match &mut self.dict_encoder {
+            Some(encoder) => encoder.flush_data_page(min_value, max_value),
+            _ => self.fallback.flush_data_page(min_value, max_value),
+        }
+    }
+}
+
+/// Encodes the values of `values` at positions `indices` into `encoder`,
+/// widening the encoder's running min/max statistics to cover them
+///
+/// This is a free function so it can be used with `downcast_op!`
+fn encode<T>(values: T, indices: &[usize], encoder: &mut ByteArrayEncoder)
+where
+    T: ArrayAccessor + Copy,
+    T::Item: Copy + Ord + AsRef<[u8]>,
+{
+    let batch_stats = compute_min_max(values, indices.iter().copied());
+    if let Some((batch_min, batch_max)) = batch_stats {
+        let replace_min = match &encoder.min_value {
+            Some(current) => current > &batch_min,
+            None => true,
+        };
+        if replace_min {
+            encoder.min_value = Some(batch_min);
+        }
+
+        let replace_max = match &encoder.max_value {
+            Some(current) => current < &batch_max,
+            None => true,
+        };
+        if replace_max {
+            encoder.max_value = Some(batch_max);
+        }
+    }
+
+    if let Some(dict_encoder) = encoder.dict_encoder.as_mut() {
+        dict_encoder.encode(values, indices)
+    } else {
+        encoder.fallback.encode(values, indices)
+    }
+}
+
+/// Returns the smallest and largest values of `array` at the positions yielded
+/// by `valid`, copied into [`ByteArray`]s, or `None` if `valid` is empty
+///
+/// This is a free function so it can be used with `downcast_op!`
+fn compute_min_max<T>(
+    array: T,
+    mut valid: impl Iterator<Item = usize>,
+) -> Option<(ByteArray, ByteArray)>
+where
+    T: ArrayAccessor,
+    T::Item: Copy + Ord + AsRef<[u8]>,
+{
+    let seed = array.value(valid.next()?);
+    let (min, max) = valid.fold((seed, seed), |(lo, hi), idx| {
+        let val = array.value(idx);
+        (lo.min(val), hi.max(val))
+    });
+    let to_byte_array = |v: T::Item| -> ByteArray { v.as_ref().to_vec().into() };
+    Some((to_byte_array(min), to_byte_array(max)))
+}
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 06189ac3a..1c95fcc27 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -33,7 +33,7 @@ use super::schema::{
     decimal_length_from_precision,
 };
 
-use crate::column::writer::{get_column_writer, ColumnWriter};
+use crate::column::writer::{get_column_writer, ColumnWriter, ColumnWriterImpl};
 use crate::errors::{ParquetError, Result};
 use crate::file::metadata::RowGroupMetaDataPtr;
 use crate::file::properties::WriterProperties;
@@ -41,6 +41,7 @@ use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter};
 use crate::{data_type::*, file::writer::SerializedFileWriter};
 use levels::{calculate_array_levels, LevelInfo};
 
+mod byte_array;
 mod levels;
 
 /// An object-safe API for writing an [`ArrayRef`]
@@ -66,17 +67,32 @@ impl<'a> ArrayWriter for ColumnArrayWriter<'a> {
 
 fn get_writer<'a, W: Write>(
     row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
+    data_type: &ArrowDataType,
 ) -> Result<Box<dyn ArrayWriter + 'a>> {
     let array_writer = row_group_writer
-        .next_column_with_factory(|descr, props, page_writer, on_close| {
-            // TODO: Special case array readers (#1764)
+        .next_column_with_factory(
+            |descr, props, page_writer, on_close| match data_type {
+                // Byte and string arrays use the specialised writer in `byte_array`
+                ArrowDataType::Utf8
+                | ArrowDataType::LargeUtf8
+                | ArrowDataType::Binary
+                | ArrowDataType::LargeBinary => Ok(byte_array::make_byte_array_writer(
+                    descr,
+                    data_type.clone(),
+                    props.clone(),
+                    page_writer,
+                    on_close,
+                )),
+                _ => {
+                    let column_writer =
+                        get_column_writer(descr, props.clone(), page_writer);
 
-            let column_writer = get_column_writer(descr, props.clone(), page_writer);
-            let serialized_writer =
-                SerializedColumnWriter::new(column_writer, Some(on_close));
+                    let serialized_writer =
+                        SerializedColumnWriter::new(column_writer, Some(on_close));
 
-            Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
-        })?
+                    Ok(Box::new(ColumnArrayWriter(Some(serialized_writer))))
+                }
+            },
+        )?
         .expect("Unable to get column writer");
     Ok(array_writer)
 }
@@ -305,7 +321,7 @@ fn write_leaves<W: Write>(
         | ArrowDataType::Decimal128(_, _)
         | ArrowDataType::Decimal256(_, _)
         | ArrowDataType::FixedSizeBinary(_) => {
-            let mut writer = get_writer(row_group_writer)?;
+            let mut writer = get_writer(row_group_writer, &data_type)?;
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
                 writer.write(
                     array,
@@ -365,7 +381,7 @@ fn write_leaves<W: Write>(
             Ok(())
         }
         ArrowDataType::Dictionary(_, value_type) => {
-            let mut writer = get_writer(row_group_writer)?;
+            let mut writer = get_writer(row_group_writer, value_type)?;
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
                 // cast dictionary to a primitive
                 let array = arrow::compute::cast(array, value_type)?;
@@ -399,33 +415,25 @@ fn write_leaf(
     let indices = levels.non_null_indices();
     let written = match writer {
         ColumnWriter::Int32ColumnWriter(ref mut typed) => {
-            let values = match column.data_type() {
+            match column.data_type() {
                 ArrowDataType::Date64 => {
                     // If the column is a Date64, we cast it to a Date32, and 
then interpret that as Int32
-                    let array = if let ArrowDataType::Date64 = 
column.data_type() {
-                        let array = arrow::compute::cast(column, 
&ArrowDataType::Date32)?;
-                        arrow::compute::cast(&array, &ArrowDataType::Int32)?
-                    } else {
-                        arrow::compute::cast(column, &ArrowDataType::Int32)?
-                    };
+                    let array = arrow::compute::cast(column, 
&ArrowDataType::Date32)?;
+                    let array = arrow::compute::cast(&array, 
&ArrowDataType::Int32)?;
+
                     let array = array
                         .as_any()
                         .downcast_ref::<arrow_array::Int32Array>()
                         .expect("Unable to get int32 array");
-                    get_numeric_array_slice::<Int32Type, _>(array, indices)
+                    write_primitive(typed, array.values(), levels)?
                 }
                 ArrowDataType::UInt32 => {
+                    let data = column.data();
+                    let offset = data.offset();
                     // follow C++ implementation and use overflow/reinterpret 
cast from  u32 to i32 which will map
                     // `(i32::MAX as u32)..u32::MAX` to `i32::MIN..0`
-                    let array = column
-                        .as_any()
-                        .downcast_ref::<arrow_array::UInt32Array>()
-                        .expect("Unable to get u32 array");
-                    let array = arrow::compute::unary::<_, _, 
arrow::datatypes::Int32Type>(
-                        array,
-                        |x| x as i32,
-                    );
-                    get_numeric_array_slice::<Int32Type, _>(&array, indices)
+                    let array: &[i32] = data.buffers()[0].typed_data();
+                    write_primitive(typed, &array[offset..offset + 
data.len()], levels)?
                 }
                 _ => {
                     let array = arrow::compute::cast(column, 
&ArrowDataType::Int32)?;
@@ -433,14 +441,9 @@ fn write_leaf(
                         .as_any()
                         .downcast_ref::<arrow_array::Int32Array>()
                         .expect("Unable to get i32 array");
-                    get_numeric_array_slice::<Int32Type, _>(array, indices)
+                    write_primitive(typed, array.values(), levels)?
                 }
-            };
-            typed.write_batch(
-                values.as_slice(),
-                levels.def_levels(),
-                levels.rep_levels(),
-            )?
+            }
         }
         ColumnWriter::BoolColumnWriter(ref mut typed) => {
             let array = column
@@ -454,26 +457,21 @@ fn write_leaf(
             )?
         }
         ColumnWriter::Int64ColumnWriter(ref mut typed) => {
-            let values = match column.data_type() {
+            match column.data_type() {
                 ArrowDataType::Int64 => {
                     let array = column
                         .as_any()
                         .downcast_ref::<arrow_array::Int64Array>()
                         .expect("Unable to get i64 array");
-                    get_numeric_array_slice::<Int64Type, _>(array, indices)
+                    write_primitive(typed, array.values(), levels)?
                 }
                 ArrowDataType::UInt64 => {
                     // follow C++ implementation and use overflow/reinterpret 
cast from  u64 to i64 which will map
                     // `(i64::MAX as u64)..u64::MAX` to `i64::MIN..0`
-                    let array = column
-                        .as_any()
-                        .downcast_ref::<arrow_array::UInt64Array>()
-                        .expect("Unable to get u64 array");
-                    let array = arrow::compute::unary::<_, _, 
arrow::datatypes::Int64Type>(
-                        array,
-                        |x| x as i64,
-                    );
-                    get_numeric_array_slice::<Int64Type, _>(&array, indices)
+                    let data = column.data();
+                    let offset = data.offset();
+                    let array: &[i64] = data.buffers()[0].typed_data();
+                    write_primitive(typed, &array[offset..offset + 
data.len()], levels)?
                 }
                 _ => {
                     let array = arrow::compute::cast(column, 
&ArrowDataType::Int64)?;
@@ -481,14 +479,9 @@ fn write_leaf(
                         .as_any()
                         .downcast_ref::<arrow_array::Int64Array>()
                         .expect("Unable to get i64 array");
-                    get_numeric_array_slice::<Int64Type, _>(array, indices)
+                    write_primitive(typed, array.values(), levels)?
                 }
-            };
-            typed.write_batch(
-                values.as_slice(),
-                levels.def_levels(),
-                levels.rep_levels(),
-            )?
+            }
         }
         ColumnWriter::Int96ColumnWriter(ref mut _typed) => {
             unreachable!("Currently unreachable because data type not 
supported")
@@ -498,70 +491,18 @@ fn write_leaf(
                 .as_any()
                 .downcast_ref::<arrow_array::Float32Array>()
                 .expect("Unable to get Float32 array");
-            typed.write_batch(
-                get_numeric_array_slice::<FloatType, _>(array, 
indices).as_slice(),
-                levels.def_levels(),
-                levels.rep_levels(),
-            )?
+            write_primitive(typed, array.values(), levels)?
         }
         ColumnWriter::DoubleColumnWriter(ref mut typed) => {
             let array = column
                 .as_any()
                 .downcast_ref::<arrow_array::Float64Array>()
                 .expect("Unable to get Float64 array");
-            typed.write_batch(
-                get_numeric_array_slice::<DoubleType, _>(array, 
indices).as_slice(),
-                levels.def_levels(),
-                levels.rep_levels(),
-            )?
+            write_primitive(typed, array.values(), levels)?
+        }
+        ColumnWriter::ByteArrayColumnWriter(_) => {
+            unreachable!("should use ByteArrayWriter")
         }
-        ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match 
column.data_type() {
-            ArrowDataType::Binary => {
-                let array = column
-                    .as_any()
-                    .downcast_ref::<arrow_array::BinaryArray>()
-                    .expect("Unable to get BinaryArray array");
-                typed.write_batch(
-                    get_binary_array(array).as_slice(),
-                    levels.def_levels(),
-                    levels.rep_levels(),
-                )?
-            }
-            ArrowDataType::Utf8 => {
-                let array = column
-                    .as_any()
-                    .downcast_ref::<arrow_array::StringArray>()
-                    .expect("Unable to get LargeBinaryArray array");
-                typed.write_batch(
-                    get_string_array(array).as_slice(),
-                    levels.def_levels(),
-                    levels.rep_levels(),
-                )?
-            }
-            ArrowDataType::LargeBinary => {
-                let array = column
-                    .as_any()
-                    .downcast_ref::<arrow_array::LargeBinaryArray>()
-                    .expect("Unable to get LargeBinaryArray array");
-                typed.write_batch(
-                    get_large_binary_array(array).as_slice(),
-                    levels.def_levels(),
-                    levels.rep_levels(),
-                )?
-            }
-            ArrowDataType::LargeUtf8 => {
-                let array = column
-                    .as_any()
-                    .downcast_ref::<arrow_array::LargeStringArray>()
-                    .expect("Unable to get LargeUtf8 array");
-                typed.write_batch(
-                    get_large_string_array(array).as_slice(),
-                    levels.def_levels(),
-                    levels.rep_levels(),
-                )?
-            }
-            _ => unreachable!("Currently unreachable because data type not 
supported"),
-        },
         ColumnWriter::FixedLenByteArrayColumnWriter(ref mut typed) => {
             let bytes = match column.data_type() {
                 ArrowDataType::Interval(interval_unit) => match interval_unit {
@@ -619,55 +560,20 @@ fn write_leaf(
     Ok(written as i64)
 }
 
-macro_rules! def_get_binary_array_fn {
-    ($name:ident, $ty:ty) => {
-        fn $name(array: &$ty) -> Vec<ByteArray> {
-            let mut byte_array = ByteArray::new();
-            let ptr = crate::util::memory::ByteBufferPtr::new(
-                array.value_data().as_slice().to_vec(),
-            );
-            byte_array.set_data(ptr);
-            array
-                .value_offsets()
-                .windows(2)
-                .enumerate()
-                .filter_map(|(i, offsets)| {
-                    if array.is_valid(i) {
-                        let start = offsets[0] as usize;
-                        let len = offsets[1] as usize - start;
-                        Some(byte_array.slice(start, len))
-                    } else {
-                        None
-                    }
-                })
-                .collect()
-        }
-    };
-}
-
-// TODO: These methods don't handle non null indices correctly (#1753)
-def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray);
-def_get_binary_array_fn!(get_string_array, arrow_array::StringArray);
-def_get_binary_array_fn!(get_large_binary_array, 
arrow_array::LargeBinaryArray);
-def_get_binary_array_fn!(get_large_string_array, 
arrow_array::LargeStringArray);
-
-/// Get the underlying numeric array slice, skipping any null values.
-/// If there are no null values, it might be quicker to get the slice directly 
instead of
-/// calling this function.
-fn get_numeric_array_slice<T, A>(
-    array: &arrow_array::PrimitiveArray<A>,
-    indices: &[usize],
-) -> Vec<T::T>
-where
-    T: DataType,
-    A: arrow::datatypes::ArrowNumericType,
-    T::T: From<A::Native>,
-{
-    let mut values = Vec::with_capacity(indices.len());
-    for i in indices {
-        values.push(array.value(*i).into())
-    }
-    values
+fn write_primitive<'a, T: DataType>(
+    writer: &mut ColumnWriterImpl<'a, T>,
+    values: &[T::T],
+    levels: LevelInfo,
+) -> Result<usize> {
+    writer.write_batch_internal(
+        values,
+        Some(levels.non_null_indices()),
+        levels.def_levels(),
+        levels.rep_levels(),
+        None,
+        None,
+        None,
+    )
 }
 
 fn get_bool_array_slice(
diff --git a/parquet/src/column/writer/encoder.rs 
b/parquet/src/column/writer/encoder.rs
index 54003732a..d7363129f 100644
--- a/parquet/src/column/writer/encoder.rs
+++ b/parquet/src/column/writer/encoder.rs
@@ -30,16 +30,21 @@ use crate::util::memory::ByteBufferPtr;
 
 /// A collection of [`ParquetValueType`] encoded by a [`ColumnValueEncoder`]
 pub trait ColumnValues {
-    /// The underlying value type
-    type T: ParquetValueType;
-
     /// The number of values in this collection
     fn len(&self) -> usize;
+}
 
-    /// Returns the min and max values in this collection, skipping any NaN 
values
-    ///
-    /// Returns `None` if no values found
-    fn min_max(&self, descr: &ColumnDescriptor) -> Option<(&Self::T, 
&Self::T)>;
+#[cfg(any(feature = "arrow", test))]
+impl<T: arrow::array::Array> ColumnValues for T {
+    fn len(&self) -> usize {
+        arrow::array::Array::len(self)
+    }
+}
+
+impl<T: ParquetValueType> ColumnValues for [T] {
+    fn len(&self) -> usize {
+        self.len()
+    }
 }
 
 /// The encoded data for a dictionary page
@@ -67,7 +72,16 @@ pub trait ColumnValueEncoder {
     type T: ParquetValueType;
 
     /// The values encoded by this encoder
-    type Values: ColumnValues<T = Self::T> + ?Sized;
+    type Values: ColumnValues + ?Sized;
+
+    /// Returns the min and max values in this collection, skipping any NaN 
values
+    ///
+    /// Returns `None` if no values found
+    fn min_max(
+        &self,
+        values: &Self::Values,
+        value_indices: Option<&[usize]>,
+    ) -> Option<(Self::T, Self::T)>;
 
     /// Create a new [`ColumnValueEncoder`]
     fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result<Self>
@@ -77,6 +91,9 @@ pub trait ColumnValueEncoder {
     /// Write the corresponding values to this [`ColumnValueEncoder`]
     fn write(&mut self, values: &Self::Values, offset: usize, len: usize) -> 
Result<()>;
 
+    /// Write the values at the indexes in `indices` to this 
[`ColumnValueEncoder`]
+    fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> 
Result<()>;
+
     /// Returns the number of buffered values
     fn num_values(&self) -> usize;
 
@@ -110,11 +127,40 @@ pub struct ColumnValueEncoderImpl<T: DataType> {
     max_value: Option<T::T>,
 }
 
+impl<T: DataType> ColumnValueEncoderImpl<T> {
+    fn write_slice(&mut self, slice: &[T::T]) -> Result<()> {
+        if self.statistics_enabled == EnabledStatistics::Page {
+            if let Some((min, max)) = self.min_max(slice, None) {
+                update_min(&self.descr, &min, &mut self.min_value);
+                update_max(&self.descr, &max, &mut self.max_value);
+            }
+        }
+
+        match &mut self.dict_encoder {
+            Some(encoder) => encoder.put(slice),
+            _ => self.encoder.put(slice),
+        }
+    }
+}
+
 impl<T: DataType> ColumnValueEncoder for ColumnValueEncoderImpl<T> {
     type T = T::T;
 
     type Values = [T::T];
 
+    fn min_max(
+        &self,
+        values: &Self::Values,
+        value_indices: Option<&[usize]>,
+    ) -> Option<(Self::T, Self::T)> {
+        match value_indices {
+            Some(indices) => {
+                get_min_max(&self.descr, indices.iter().map(|x| &values[*x]))
+            }
+            None => get_min_max(&self.descr, values.iter()),
+        }
+    }
+
     fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> 
Result<Self> {
         let dict_supported = props.dictionary_enabled(descr.path())
             && has_dictionary_support(T::get_physical_type(), props);
@@ -152,17 +198,12 @@ impl<T: DataType> ColumnValueEncoder for 
ColumnValueEncoderImpl<T> {
             )
         })?;
 
-        if self.statistics_enabled == EnabledStatistics::Page {
-            if let Some((min, max)) = slice.min_max(&self.descr) {
-                update_min(&self.descr, min, &mut self.min_value);
-                update_max(&self.descr, max, &mut self.max_value);
-            }
-        }
+        self.write_slice(slice)
+    }
 
-        match &mut self.dict_encoder {
-            Some(encoder) => encoder.put(slice),
-            _ => self.encoder.put(slice),
-        }
+    fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> 
Result<()> {
+        let slice: Vec<_> = indices.iter().map(|idx| 
values[*idx].clone()).collect();
+        self.write_slice(&slice)
     }
 
     fn num_values(&self) -> usize {
@@ -221,36 +262,30 @@ impl<T: DataType> ColumnValueEncoder for 
ColumnValueEncoderImpl<T> {
     }
 }
 
-impl<T: ParquetValueType> ColumnValues for [T] {
-    type T = T;
-
-    fn len(&self) -> usize {
-        self.len()
-    }
-
-    fn min_max(&self, descr: &ColumnDescriptor) -> Option<(&T, &T)> {
-        let mut iter = self.iter();
-
-        let first = loop {
-            let next = iter.next()?;
-            if !is_nan(next) {
-                break next;
-            }
-        };
+fn get_min_max<'a, T, I>(descr: &ColumnDescriptor, mut iter: I) -> Option<(T, 
T)>
+where
+    T: ParquetValueType + 'a,
+    I: Iterator<Item = &'a T>,
+{
+    let first = loop {
+        let next = iter.next()?;
+        if !is_nan(next) {
+            break next;
+        }
+    };
 
-        let mut min = first;
-        let mut max = first;
-        for val in iter {
-            if is_nan(val) {
-                continue;
-            }
-            if compare_greater(descr, min, val) {
-                min = val;
-            }
-            if compare_greater(descr, val, max) {
-                max = val;
-            }
+    let mut min = first;
+    let mut max = first;
+    for val in iter {
+        if is_nan(val) {
+            continue;
+        }
+        if compare_greater(descr, min, val) {
+            min = val;
+        }
+        if compare_greater(descr, val, max) {
+            max = val;
         }
-        Some((min, max))
     }
+    Some((min.clone(), max.clone()))
 }
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 2d5556111..6c467b5e4 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -249,9 +249,10 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> 
{
         }
     }
 
-    fn write_batch_internal(
+    pub(crate) fn write_batch_internal(
         &mut self,
         values: &E::Values,
+        value_indices: Option<&[usize]>,
         def_levels: Option<&[i16]>,
         rep_levels: Option<&[i16]>,
         min: Option<&E::T>,
@@ -290,9 +291,10 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> 
{
                     panic!("min/max should be both set or both None")
                 }
                 (None, None) => {
-                    if let Some((min, max)) = values.min_max(&self.descr) {
-                        update_min(&self.descr, min, &mut 
self.min_column_value);
-                        update_max(&self.descr, max, &mut 
self.max_column_value);
+                    if let Some((min, max)) = self.encoder.min_max(values, 
value_indices)
+                    {
+                        update_min(&self.descr, &min, &mut 
self.min_column_value);
+                        update_max(&self.descr, &max, &mut 
self.max_column_value);
                     }
                 }
             };
@@ -311,6 +313,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
             values_offset += self.write_mini_batch(
                 values,
                 values_offset,
+                value_indices,
                 write_batch_size,
                 def_levels.map(|lv| &lv[levels_offset..levels_offset + 
write_batch_size]),
                 rep_levels.map(|lv| &lv[levels_offset..levels_offset + 
write_batch_size]),
@@ -321,6 +324,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
         values_offset += self.write_mini_batch(
             values,
             values_offset,
+            value_indices,
             num_levels - levels_offset,
             def_levels.map(|lv| &lv[levels_offset..]),
             rep_levels.map(|lv| &lv[levels_offset..]),
@@ -348,7 +352,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
         def_levels: Option<&[i16]>,
         rep_levels: Option<&[i16]>,
     ) -> Result<usize> {
-        self.write_batch_internal(values, def_levels, rep_levels, None, None, 
None)
+        self.write_batch_internal(values, None, def_levels, rep_levels, None, 
None, None)
     }
 
     /// Writer may optionally provide pre-calculated statistics for use when 
computing
@@ -369,6 +373,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
     ) -> Result<usize> {
         self.write_batch_internal(
             values,
+            None,
             def_levels,
             rep_levels,
             min,
@@ -427,6 +432,7 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
         &mut self,
         values: &E::Values,
         values_offset: usize,
+        value_indices: Option<&[usize]>,
         num_levels: usize,
         def_levels: Option<&[i16]>,
         rep_levels: Option<&[i16]>,
@@ -490,7 +496,14 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> 
{
             self.num_buffered_rows += num_levels as u32;
         }
 
-        self.encoder.write(values, values_offset, values_to_write)?;
+        match value_indices {
+            Some(indices) => {
+                let indices = &indices[values_offset..values_offset + 
values_to_write];
+                self.encoder.write_gather(values, indices)?;
+            }
+            None => self.encoder.write(values, values_offset, 
values_to_write)?,
+        }
+
         self.num_buffered_values += num_levels as u32;
 
         if self.should_add_data_page() {
diff --git a/parquet/src/encodings/encoding/dict_encoder.rs 
b/parquet/src/encodings/encoding/dict_encoder.rs
index 1b386448b..8f3c98aca 100644
--- a/parquet/src/encodings/encoding/dict_encoder.rs
+++ b/parquet/src/encodings/encoding/dict_encoder.rs
@@ -146,12 +146,7 @@ impl<T: DataType> DictEncoder<T> {
 
     #[inline]
     fn bit_width(&self) -> u8 {
-        let num_entries = self.num_entries();
-        if num_entries <= 1 {
-            num_entries as u8
-        } else {
-            num_required_bits(num_entries as u64 - 1)
-        }
+        num_required_bits(self.num_entries().saturating_sub(1) as u64)
     }
 }
 
diff --git a/parquet/src/util/interner.rs b/parquet/src/util/interner.rs
index e64ae0179..c0afad8e5 100644
--- a/parquet/src/util/interner.rs
+++ b/parquet/src/util/interner.rs
@@ -34,6 +34,7 @@ pub trait Storage {
 }
 
 /// A generic value interner supporting various different [`Storage`]
+#[derive(Debug, Default)]
 pub struct Interner<S: Storage> {
     state: ahash::RandomState,
 
@@ -84,6 +85,11 @@ impl<S: Storage> Interner<S> {
     pub fn storage(&self) -> &S {
         &self.storage
     }
+
+    /// Unwraps the inner storage
+    pub fn into_inner(self) -> S {
+        self.storage
+    }
 }
 
 fn compute_hash<T: AsBytes + ?Sized>(state: &ahash::RandomState, value: &T) -> 
u64 {

Reply via email to