This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new b5c831a1d9 Allow Users to Provide Custom `ArrayFormatter`s when 
Pretty-Printing Record Batches (#8829)
b5c831a1d9 is described below

commit b5c831a1d90b0c5a20f974c3d1c2bda776026f96
Author: Tobias Schwarzinger <[email protected]>
AuthorDate: Wed Nov 19 15:00:49 2025 +0100

    Allow Users to Provide Custom `ArrayFormatter`s when Pretty-Printing Record 
Batches (#8829)
    
    # Which issue does this PR close?
    
    - Closes #8821.
    
    # Rationale for this change
    
    Allows users that require custom pretty-printing logic for batches to
    supply this implementation.
    
    # What changes are included in this PR?
    
    Changes to existing code:
    
    - Add public getters for the fields in `FormatOptions`. This is necessary
    as the custom `ArrayFormatter` must also have access to the formatting
    options. (see `<NULL>` in the test)
    - Keep the existing `types_info()` getter alongside the new getters
    - Allow directly creating `ArrayFormatter` with a `DisplayIndex`
    implementation
    - Make `FormatError`, `FormatResult`, and `DisplayIndex` public. I do
    have some second thoughts about `DisplayIndex` not having any concept of
    length even though it takes an index as input. However, it may be fine
    for now.
    
    New code:
    
    - `ArrayFormatterFactory`: Allows creating `ArrayFormatters` with custom
    behavior
    - `pretty_format_batches_with_options_and_formatters` pretty printing
    with custom formatters
    - Similar thing for format column
    
    # Are these changes tested?
    
    Yes, existing tests cover the default formatting path.
    
    New tests and checks:
    - Format record batch with custom type (append € sign)
    - Format column with custom formatter (append (32-Bit) for `Int32`)
    - Allow overriding the custom types with a custom schema (AFAIK this is
    not possible with the current API but might make sense).
    - Added a sanity check that the number of fields in a custom schema must
    match the number of columns in the record batch.
    
    # Are there any user-facing changes?
    
    Yes, multiple things become public, `FormatOptions` gains getters for its
    fields, and there are new APIs for custom pretty printing of batches.
---
 arrow-array/src/array/map_array.rs |   9 +
 arrow-cast/src/display.rs          | 324 +++++++++++++++++++----
 arrow-cast/src/pretty.rs           | 510 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 781 insertions(+), 62 deletions(-)

diff --git a/arrow-array/src/array/map_array.rs 
b/arrow-array/src/array/map_array.rs
index fcf3f621b2..b5e611a92b 100644
--- a/arrow-array/src/array/map_array.rs
+++ b/arrow-array/src/array/map_array.rs
@@ -173,6 +173,15 @@ impl MapArray {
         &self.entries
     }
 
+    /// Returns a reference to the fields of the [`StructArray`] that backs 
this map.
+    pub fn entries_fields(&self) -> (&Field, &Field) {
+        let fields = self.entries.fields().iter().collect::<Vec<_>>();
+        let fields = TryInto::<[&FieldRef; 2]>::try_into(fields)
+            .expect("Every map has a key and value field");
+
+        (fields[0].as_ref(), fields[1].as_ref())
+    }
+
     /// Returns the data type of the map's keys.
     pub fn key_type(&self) -> &DataType {
         self.keys().data_type()
diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs
index caa9804507..6fd454d2fb 100644
--- a/arrow-cast/src/display.rs
+++ b/arrow-cast/src/display.rs
@@ -23,7 +23,8 @@
 //! record batch pretty printing.
 //!
 //! [`pretty`]: crate::pretty
-use std::fmt::{Display, Formatter, Write};
+use std::fmt::{Debug, Display, Formatter, Write};
+use std::hash::{Hash, Hasher};
 use std::ops::Range;
 
 use arrow_array::cast::*;
@@ -53,7 +54,12 @@ pub enum DurationFormat {
 /// By default nulls are formatted as `""` and temporal types formatted
 /// according to RFC3339
 ///
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// # Equality
+///
+/// Most fields in [`FormatOptions`] are compared by value, except 
`formatter_factory`. As the trait
+/// does not require an [`Eq`] and [`Hash`] implementation, this struct only 
compares the pointer of
+/// the factories.
+#[derive(Debug, Clone)]
 pub struct FormatOptions<'a> {
     /// If set to `true` any formatting errors will be written to the output
     /// instead of being converted into a [`std::fmt::Error`]
@@ -74,6 +80,9 @@ pub struct FormatOptions<'a> {
     duration_format: DurationFormat,
     /// Show types in visual representation batches
     types_info: bool,
+    /// Formatter factory used to instantiate custom [`ArrayFormatter`]s. This 
allows users to
+    /// provide custom formatters.
+    formatter_factory: Option<&'a dyn ArrayFormatterFactory>,
 }
 
 impl Default for FormatOptions<'_> {
@@ -82,6 +91,44 @@ impl Default for FormatOptions<'_> {
     }
 }
 
+impl PartialEq for FormatOptions<'_> {
+    fn eq(&self, other: &Self) -> bool {
+        self.safe == other.safe
+            && self.null == other.null
+            && self.date_format == other.date_format
+            && self.datetime_format == other.datetime_format
+            && self.timestamp_format == other.timestamp_format
+            && self.timestamp_tz_format == other.timestamp_tz_format
+            && self.time_format == other.time_format
+            && self.duration_format == other.duration_format
+            && self.types_info == other.types_info
+            && match (self.formatter_factory, other.formatter_factory) {
+                (Some(f1), Some(f2)) => std::ptr::eq(f1, f2),
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+
+impl Eq for FormatOptions<'_> {}
+
+impl Hash for FormatOptions<'_> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.safe.hash(state);
+        self.null.hash(state);
+        self.date_format.hash(state);
+        self.datetime_format.hash(state);
+        self.timestamp_format.hash(state);
+        self.timestamp_tz_format.hash(state);
+        self.time_format.hash(state);
+        self.duration_format.hash(state);
+        self.types_info.hash(state);
+        self.formatter_factory
+            .map(|f| f as *const dyn ArrayFormatterFactory)
+            .hash(state);
+    }
+}
+
 impl<'a> FormatOptions<'a> {
     /// Creates a new set of format options
     pub const fn new() -> Self {
@@ -95,6 +142,7 @@ impl<'a> FormatOptions<'a> {
             time_format: None,
             duration_format: DurationFormat::ISO8601,
             types_info: false,
+            formatter_factory: None,
         }
     }
 
@@ -169,10 +217,170 @@ impl<'a> FormatOptions<'a> {
         Self { types_info, ..self }
     }
 
-    /// Returns true if type info should be included in visual representation 
of batches
+    /// Overrides the [`ArrayFormatterFactory`] used to instantiate custom 
[`ArrayFormatter`]s.
+    ///
+    /// Using [`None`] causes pretty-printers to use the default 
[`ArrayFormatter`]s.
+    pub const fn with_formatter_factory(
+        self,
+        formatter_factory: Option<&'a dyn ArrayFormatterFactory>,
+    ) -> Self {
+        Self {
+            formatter_factory,
+            ..self
+        }
+    }
+
+    /// Returns whether formatting errors should be written to the output 
instead of being converted
+    /// into a [`std::fmt::Error`].
+    pub const fn safe(&self) -> bool {
+        self.safe
+    }
+
+    /// Returns the string used for displaying nulls.
+    pub const fn null(&self) -> &'a str {
+        self.null
+    }
+
+    /// Returns the format used for [`DataType::Date32`] columns.
+    pub const fn date_format(&self) -> TimeFormat<'a> {
+        self.date_format
+    }
+
+    /// Returns the format used for [`DataType::Date64`] columns.
+    pub const fn datetime_format(&self) -> TimeFormat<'a> {
+        self.datetime_format
+    }
+
+    /// Returns the format used for [`DataType::Timestamp`] columns without a 
timezone.
+    pub const fn timestamp_format(&self) -> TimeFormat<'a> {
+        self.timestamp_format
+    }
+
+    /// Returns the format used for [`DataType::Timestamp`] columns with a 
timezone.
+    pub const fn timestamp_tz_format(&self) -> TimeFormat<'a> {
+        self.timestamp_tz_format
+    }
+
+    /// Returns the format used for [`DataType::Time32`] and 
[`DataType::Time64`] columns.
+    pub const fn time_format(&self) -> TimeFormat<'a> {
+        self.time_format
+    }
+
+    /// Returns the [`DurationFormat`] used for duration columns.
+    pub const fn duration_format(&self) -> DurationFormat {
+        self.duration_format
+    }
+
+    /// Returns true if type info should be included in a visual 
representation of batches.
     pub const fn types_info(&self) -> bool {
         self.types_info
     }
+
+    /// Returns the [`ArrayFormatterFactory`] used to instantiate custom 
[`ArrayFormatter`]s.
+    pub const fn formatter_factory(&self) -> Option<&'a dyn 
ArrayFormatterFactory> {
+        self.formatter_factory
+    }
+}
+
+/// Allows creating a new [`ArrayFormatter`] for a given [`Array`] and an 
optional [`Field`].
+///
+/// # Example
+///
+/// The example below shows how to create a custom formatter for a custom type 
`my_money`. Note that
+/// this example requires the `prettyprint` feature.
+///
+/// ```rust
+/// use std::fmt::Write;
+/// use arrow_array::{cast::AsArray, Array, Int32Array};
+/// use arrow_cast::display::{ArrayFormatter, ArrayFormatterFactory, 
DisplayIndex, FormatOptions, FormatResult};
+/// use arrow_cast::pretty::pretty_format_batches_with_options;
+/// use arrow_schema::{ArrowError, Field};
+///
+/// /// A custom formatter factory that can create a formatter for the special 
type `my_money`.
+/// ///
+/// /// This struct could have access to some kind of extension type registry 
that can lookup the
+/// /// correct formatter for an extension type on-demand.
+/// #[derive(Debug)]
+/// struct MyFormatters {}
+///
+/// impl ArrayFormatterFactory for MyFormatters {
+///     fn create_array_formatter<'formatter>(
+///         &self,
+///         array: &'formatter dyn Array,
+///         options: &FormatOptions<'formatter>,
+///         field: Option<&'formatter Field>,
+///     ) -> Result<Option<ArrayFormatter<'formatter>>, ArrowError> {
+///         // check if this is the money type
+///         if field
+///             .map(|f| f.extension_type_name() == Some("my_money"))
+///             .unwrap_or(false)
+///         {
+///             // We assume that my_money always is an Int32.
+///             let array = array.as_primitive();
+///             let display_index = Box::new(MyMoneyFormatter { array, 
options: options.clone() });
+///             return Ok(Some(ArrayFormatter::new(display_index, 
options.safe())));
+///         }
+///
+///         Ok(None) // None indicates that the default formatter should be 
used.
+///     }
+/// }
+///
+/// /// A formatter for the type `my_money` that wraps a specific array and 
has access to the
+/// /// formatting options.
+/// struct MyMoneyFormatter<'a> {
+///     array: &'a Int32Array,
+///     options: FormatOptions<'a>,
+/// }
+///
+/// impl<'a> DisplayIndex for MyMoneyFormatter<'a> {
+///     fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
+///         match self.array.is_valid(idx) {
+///             true => write!(f, "{} €", self.array.value(idx))?,
+///             false => write!(f, "{}", self.options.null())?,
+///         }
+///
+///         Ok(())
+///     }
+/// }
+///
+/// // Usually, here you would provide your record batches.
+/// let my_batches = vec![];
+///
+/// // Call the pretty printer with the custom formatter factory.
+/// pretty_format_batches_with_options(
+///        &my_batches,
+///        &FormatOptions::new().with_formatter_factory(Some(&MyFormatters {}))
+/// );
+/// ```
+pub trait ArrayFormatterFactory: Debug {
+    /// Creates a new [`ArrayFormatter`] for the given [`Array`] and an 
optional [`Field`]. If the
+    /// default implementation should be used, return [`None`].
+    ///
+    /// The field shall be used to look up metadata about the `array` while 
`options` provide
+    /// information on formatting, for example, dates and times which should 
be considered by an
+    /// implementor.
+    fn create_array_formatter<'formatter>(
+        &self,
+        array: &'formatter dyn Array,
+        options: &FormatOptions<'formatter>,
+        field: Option<&'formatter Field>,
+    ) -> Result<Option<ArrayFormatter<'formatter>>, ArrowError>;
+}
+
+/// Used to create a new [`ArrayFormatter`] from the given `array`, while also 
checking whether
+/// there is an override available in the [`ArrayFormatterFactory`].
+pub(crate) fn make_array_formatter<'a>(
+    array: &'a dyn Array,
+    options: &FormatOptions<'a>,
+    field: Option<&'a Field>,
+) -> Result<ArrayFormatter<'a>, ArrowError> {
+    match options.formatter_factory() {
+        None => ArrayFormatter::try_new(array, options),
+        Some(formatters) => formatters
+            .create_array_formatter(array, options, field)
+            .transpose()
+            .unwrap_or_else(|| ArrayFormatter::try_new(array, options)),
+    }
 }
 
 /// Implements [`Display`] for a specific array value
@@ -272,14 +480,19 @@ pub struct ArrayFormatter<'a> {
 }
 
 impl<'a> ArrayFormatter<'a> {
+    /// Returns an [`ArrayFormatter`] using the provided formatter.
+    pub fn new(format: Box<dyn DisplayIndex + 'a>, safe: bool) -> Self {
+        Self { format, safe }
+    }
+
     /// Returns an [`ArrayFormatter`] that can be used to format `array`
     ///
     /// This returns an error if an array of the given data type cannot be 
formatted
     pub fn try_new(array: &'a dyn Array, options: &FormatOptions<'a>) -> 
Result<Self, ArrowError> {
-        Ok(Self {
-            format: make_formatter(array, options)?,
-            safe: options.safe,
-        })
+        Ok(Self::new(
+            make_default_display_index(array, options)?,
+            options.safe,
+        ))
     }
 
     /// Returns a [`ValueFormatter`] that implements [`Display`] for
@@ -292,7 +505,7 @@ impl<'a> ArrayFormatter<'a> {
     }
 }
 
-fn make_formatter<'a>(
+fn make_default_display_index<'a>(
     array: &'a dyn Array,
     options: &FormatOptions<'a>,
 ) -> Result<Box<dyn DisplayIndex + 'a>, ArrowError> {
@@ -332,12 +545,15 @@ fn make_formatter<'a>(
 }
 
 /// Either an [`ArrowError`] or [`std::fmt::Error`]
-enum FormatError {
+pub enum FormatError {
+    /// An error occurred while formatting the array
     Format(std::fmt::Error),
+    /// An Arrow error occurred while formatting the array.
     Arrow(ArrowError),
 }
 
-type FormatResult = Result<(), FormatError>;
+/// The result of formatting an array element via [`DisplayIndex::write`].
+pub type FormatResult = Result<(), FormatError>;
 
 impl From<std::fmt::Error> for FormatError {
     fn from(value: std::fmt::Error) -> Self {
@@ -352,7 +568,8 @@ impl From<ArrowError> for FormatError {
 }
 
 /// [`Display`] but accepting an index
-trait DisplayIndex {
+pub trait DisplayIndex {
+    /// Write the value of the underlying array at `idx` to `f`.
     fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult;
 }
 
@@ -896,7 +1113,7 @@ impl<'a, K: ArrowDictionaryKeyType> DisplayIndexState<'a> 
for &'a DictionaryArra
     type State = Box<dyn DisplayIndex + 'a>;
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        make_formatter(self.values().as_ref(), options)
+        make_default_display_index(self.values().as_ref(), options)
     }
 
     fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> 
FormatResult {
@@ -906,68 +1123,82 @@ impl<'a, K: ArrowDictionaryKeyType> 
DisplayIndexState<'a> for &'a DictionaryArra
 }
 
 impl<'a, K: RunEndIndexType> DisplayIndexState<'a> for &'a RunArray<K> {
-    type State = Box<dyn DisplayIndex + 'a>;
+    type State = ArrayFormatter<'a>;
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        make_formatter(self.values().as_ref(), options)
+        let field = match (*self).data_type() {
+            DataType::RunEndEncoded(_, values_field) => values_field,
+            _ => unreachable!(),
+        };
+        make_array_formatter(self.values().as_ref(), options, Some(field))
     }
 
     fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> 
FormatResult {
         let value_idx = self.get_physical_index(idx);
-        s.as_ref().write(value_idx, f)
+        write!(f, "{}", s.value(value_idx))?;
+        Ok(())
     }
 }
 
 fn write_list(
     f: &mut dyn Write,
     mut range: Range<usize>,
-    values: &dyn DisplayIndex,
+    values: &ArrayFormatter<'_>,
 ) -> FormatResult {
     f.write_char('[')?;
     if let Some(idx) = range.next() {
-        values.write(idx, f)?;
+        write!(f, "{}", values.value(idx))?;
     }
     for idx in range {
-        write!(f, ", ")?;
-        values.write(idx, f)?;
+        write!(f, ", {}", values.value(idx))?;
     }
     f.write_char(']')?;
     Ok(())
 }
 
 impl<'a, O: OffsetSizeTrait> DisplayIndexState<'a> for &'a GenericListArray<O> 
{
-    type State = Box<dyn DisplayIndex + 'a>;
+    type State = ArrayFormatter<'a>;
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        make_formatter(self.values().as_ref(), options)
+        let field = match (*self).data_type() {
+            DataType::List(f) => f,
+            DataType::LargeList(f) => f,
+            _ => unreachable!(),
+        };
+        make_array_formatter(self.values().as_ref(), options, 
Some(field.as_ref()))
     }
 
     fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> 
FormatResult {
         let offsets = self.value_offsets();
         let end = offsets[idx + 1].as_usize();
         let start = offsets[idx].as_usize();
-        write_list(f, start..end, s.as_ref())
+        write_list(f, start..end, s)
     }
 }
 
 impl<'a> DisplayIndexState<'a> for &'a FixedSizeListArray {
-    type State = (usize, Box<dyn DisplayIndex + 'a>);
+    type State = (usize, ArrayFormatter<'a>);
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        let values = make_formatter(self.values().as_ref(), options)?;
+        let field = match (*self).data_type() {
+            DataType::FixedSizeList(f, _) => f,
+            _ => unreachable!(),
+        };
+        let formatter =
+            make_array_formatter(self.values().as_ref(), options, 
Some(field.as_ref()))?;
         let length = self.value_length();
-        Ok((length as usize, values))
+        Ok((length as usize, formatter))
     }
 
     fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> 
FormatResult {
         let start = idx * s.0;
         let end = start + s.0;
-        write_list(f, start..end, s.1.as_ref())
+        write_list(f, start..end, &s.1)
     }
 }
 
-/// Pairs a boxed [`DisplayIndex`] with its field name
-type FieldDisplay<'a> = (&'a str, Box<dyn DisplayIndex + 'a>);
+/// Pairs an [`ArrayFormatter`] with its field name
+type FieldDisplay<'a> = (&'a str, ArrayFormatter<'a>);
 
 impl<'a> DisplayIndexState<'a> for &'a StructArray {
     type State = Vec<FieldDisplay<'a>>;
@@ -982,7 +1213,7 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray {
             .iter()
             .zip(fields)
             .map(|(a, f)| {
-                let format = make_formatter(a.as_ref(), options)?;
+                let format = make_array_formatter(a.as_ref(), options, 
Some(f))?;
                 Ok((f.name().as_str(), format))
             })
             .collect()
@@ -992,12 +1223,10 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray {
         let mut iter = s.iter();
         f.write_char('{')?;
         if let Some((name, display)) = iter.next() {
-            write!(f, "{name}: ")?;
-            display.as_ref().write(idx, f)?;
+            write!(f, "{name}: {}", display.value(idx))?;
         }
         for (name, display) in iter {
-            write!(f, ", {name}: ")?;
-            display.as_ref().write(idx, f)?;
+            write!(f, ", {name}: {}", display.value(idx))?;
         }
         f.write_char('}')?;
         Ok(())
@@ -1005,11 +1234,13 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray {
 }
 
 impl<'a> DisplayIndexState<'a> for &'a MapArray {
-    type State = (Box<dyn DisplayIndex + 'a>, Box<dyn DisplayIndex + 'a>);
+    type State = (ArrayFormatter<'a>, ArrayFormatter<'a>);
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
-        let keys = make_formatter(self.keys().as_ref(), options)?;
-        let values = make_formatter(self.values().as_ref(), options)?;
+        let (key_field, value_field) = (*self).entries_fields();
+
+        let keys = make_array_formatter(self.keys().as_ref(), options, 
Some(key_field))?;
+        let values = make_array_formatter(self.values().as_ref(), options, 
Some(value_field))?;
         Ok((keys, values))
     }
 
@@ -1021,16 +1252,12 @@ impl<'a> DisplayIndexState<'a> for &'a MapArray {
 
         f.write_char('{')?;
         if let Some(idx) = iter.next() {
-            s.0.write(idx, f)?;
-            write!(f, ": ")?;
-            s.1.write(idx, f)?;
+            write!(f, "{}: {}", s.0.value(idx), s.1.value(idx))?;
         }
 
         for idx in iter {
-            write!(f, ", ")?;
-            s.0.write(idx, f)?;
-            write!(f, ": ")?;
-            s.1.write(idx, f)?;
+            write!(f, ", {}", s.0.value(idx))?;
+            write!(f, ": {}", s.1.value(idx))?;
         }
 
         f.write_char('}')?;
@@ -1039,10 +1266,7 @@ impl<'a> DisplayIndexState<'a> for &'a MapArray {
 }
 
 impl<'a> DisplayIndexState<'a> for &'a UnionArray {
-    type State = (
-        Vec<Option<(&'a str, Box<dyn DisplayIndex + 'a>)>>,
-        UnionMode,
-    );
+    type State = (Vec<Option<FieldDisplay<'a>>>, UnionMode);
 
     fn prepare(&self, options: &FormatOptions<'a>) -> Result<Self::State, 
ArrowError> {
         let (fields, mode) = match (*self).data_type() {
@@ -1053,7 +1277,7 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
         let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() 
as usize;
         let mut out: Vec<Option<FieldDisplay>> = (0..max_id + 1).map(|_| 
None).collect();
         for (i, field) in fields.iter() {
-            let formatter = make_formatter(self.child(i).as_ref(), options)?;
+            let formatter = make_array_formatter(self.child(i).as_ref(), 
options, Some(field))?;
             out[i as usize] = Some((field.name().as_str(), formatter))
         }
         Ok((out, *mode))
@@ -1067,9 +1291,7 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray {
         };
         let (name, field) = s.0[id as usize].as_ref().unwrap();
 
-        write!(f, "{{{name}=")?;
-        field.write(idx, f)?;
-        f.write_char('}')?;
+        write!(f, "{{{name}={}}}", field.value(idx))?;
         Ok(())
     }
 }
diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs
index 49fb972684..1e6535bb12 100644
--- a/arrow-cast/src/pretty.rs
+++ b/arrow-cast/src/pretty.rs
@@ -22,14 +22,12 @@
 //! [`RecordBatch`]: arrow_array::RecordBatch
 //! [`Array`]: arrow_array::Array
 
-use std::fmt::Display;
-
-use comfy_table::{Cell, Table};
-
 use arrow_array::{Array, ArrayRef, RecordBatch};
 use arrow_schema::{ArrowError, SchemaRef};
+use comfy_table::{Cell, Table};
+use std::fmt::Display;
 
-use crate::display::{ArrayFormatter, FormatOptions};
+use crate::display::{ArrayFormatter, FormatOptions, make_array_formatter};
 
 /// Create a visual representation of [`RecordBatch`]es
 ///
@@ -187,7 +185,7 @@ fn create_table(
         }
     });
 
-    if let Some(schema) = schema_opt {
+    if let Some(schema) = &schema_opt {
         let mut header = Vec::new();
         for field in schema.fields() {
             if options.types_info() {
@@ -208,10 +206,22 @@ fn create_table(
     }
 
     for batch in results {
+        let schema = schema_opt.as_ref().unwrap_or(batch.schema_ref());
+
+        // Could be a custom schema that was provided.
+        if batch.columns().len() != schema.fields().len() {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Expected the same number of columns in a record batch ({}) as 
the number of fields ({}) in the schema",
+                batch.columns().len(),
+                schema.fields.len()
+            )));
+        }
+
         let formatters = batch
             .columns()
             .iter()
-            .map(|c| ArrayFormatter::try_new(c.as_ref(), options))
+            .zip(schema.fields().iter())
+            .map(|(c, field)| make_array_formatter(c, options, Some(field)))
             .collect::<Result<Vec<_>, ArrowError>>()?;
 
         for row in 0..batch.num_rows() {
@@ -242,7 +252,13 @@ fn create_column(
     table.set_header(header);
 
     for col in columns {
-        let formatter = ArrayFormatter::try_new(col.as_ref(), options)?;
+        let formatter = match options.formatter_factory() {
+            None => ArrayFormatter::try_new(col.as_ref(), options)?,
+            Some(formatters) => formatters
+                .create_array_formatter(col.as_ref(), options, None)
+                .transpose()
+                .unwrap_or_else(|| ArrayFormatter::try_new(col.as_ref(), 
options))?,
+        };
         for row in 0..col.len() {
             let cells = vec![Cell::new(formatter.value(row))];
             table.add_row(cells);
@@ -254,18 +270,21 @@ fn create_column(
 
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
     use std::fmt::Write;
     use std::sync::Arc;
 
-    use half::f16;
-
     use arrow_array::builder::*;
+    use arrow_array::cast::AsArray;
     use arrow_array::types::*;
     use arrow_array::*;
     use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer};
     use arrow_schema::*;
+    use half::f16;
 
-    use crate::display::{DurationFormat, array_value_to_string};
+    use crate::display::{
+        ArrayFormatterFactory, DisplayIndex, DurationFormat, 
array_value_to_string,
+    };
 
     use super::*;
 
@@ -1283,4 +1302,473 @@ mod tests {
         let actual: Vec<&str> = iso.lines().collect();
         assert_eq!(expected_iso, actual, "Actual result:\n{iso}");
     }
+
+    //
+    // Custom Formatting
+    //
+
+    /// The factory that will create the [`ArrayFormatter`]s.
+    #[derive(Debug)]
+    struct TestFormatters {}
+
+    impl ArrayFormatterFactory for TestFormatters {
+        fn create_array_formatter<'formatter>(
+            &self,
+            array: &'formatter dyn Array,
+            options: &FormatOptions<'formatter>,
+            field: Option<&'formatter Field>,
+        ) -> Result<Option<ArrayFormatter<'formatter>>, ArrowError> {
+            if field
+                .map(|f| f.extension_type_name() == Some("my_money"))
+                .unwrap_or(false)
+            {
+                // We assume that my_money always is an Int32.
+                let array = array.as_primitive();
+                let display_index = Box::new(MyMoneyFormatter {
+                    array,
+                    options: options.clone(),
+                });
+                return Ok(Some(ArrayFormatter::new(display_index, options.safe())));
+            }
+
+            if array.data_type() == &DataType::Int32 {
+                let array = array.as_primitive();
+                let display_index = Box::new(MyInt32Formatter {
+                    array,
+                    options: options.clone(),
+                });
+                return Ok(Some(ArrayFormatter::new(display_index, options.safe())));
+            }
+
+            Ok(None)
+        }
+    }
+
+    /// A format that will append a "€" sign to the end of the Int32 values.
+    struct MyMoneyFormatter<'a> {
+        array: &'a Int32Array,
+        options: FormatOptions<'a>,
+    }
+
+    impl<'a> DisplayIndex for MyMoneyFormatter<'a> {
+        fn write(&self, idx: usize, f: &mut dyn Write) -> crate::display::FormatResult {
+            match self.array.is_valid(idx) {
+                true => write!(f, "{} €", self.array.value(idx))?,
+                false => write!(f, "{}", self.options.null())?,
+            }
+
+            Ok(())
+        }
+    }
+
+    /// The actual formatter
+    struct MyInt32Formatter<'a> {
+        array: &'a Int32Array,
+        options: FormatOptions<'a>,
+    }
+
+    impl<'a> DisplayIndex for MyInt32Formatter<'a> {
+        fn write(&self, idx: usize, f: &mut dyn Write) -> crate::display::FormatResult {
+            match self.array.is_valid(idx) {
+                true => write!(f, "{} (32-Bit)", self.array.value(idx))?,
+                false => write!(f, "{}", self.options.null())?,
+            }
+
+            Ok(())
+        }
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters() {
+        // define a schema.
+        let options = FormatOptions::new()
+            .with_null("<NULL>")
+            .with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()),
+        ]));
+
+        // define data.
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(array::Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(10),
+                Some(100),
+            ]))],
+        )
+        .unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_batches_with_options(&[batch], &options).unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+--------+",
+            "| income |",
+            "+--------+",
+            "| 1 €    |",
+            "| <NULL> |",
+            "| 10 €   |",
+            "| 100 €  |",
+            "+--------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters_multi_nested_list() {
+        // define a schema.
+        let options = FormatOptions::new()
+            .with_null("<NULL>")
+            .with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+        let nested_field = Arc::new(
+            Field::new_list_field(DataType::Int32, true).with_metadata(money_metadata.clone()),
+        );
+
+        // Create nested data
+        let inner_list = ListBuilder::new(Int32Builder::new()).with_field(nested_field);
+        let mut outer_list = FixedSizeListBuilder::new(inner_list, 2);
+        outer_list.values().append_value([Some(1)]);
+        outer_list.values().append_null();
+        outer_list.append(true);
+        outer_list.values().append_value([Some(2), Some(8)]);
+        outer_list
+            .values()
+            .append_value([Some(50), Some(25), Some(25)]);
+        outer_list.append(true);
+        let outer_list = outer_list.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "income",
+            outer_list.data_type().clone(),
+            true,
+        )]));
+
+        // define data.
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(outer_list)]).unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_batches_with_options(&[batch], &options).unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+----------------------------------+",
+            "| income                           |",
+            "+----------------------------------+",
+            "| [[1 €], <NULL>]                  |",
+            "| [[2 €, 8 €], [50 €, 25 €, 25 €]] |",
+            "+----------------------------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters_nested_struct() {
+        // define a schema.
+        let options = FormatOptions::new()
+            .with_null("<NULL>")
+            .with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+        let fields = Fields::from(vec![
+            Field::new("name", DataType::Utf8, true),
+            Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()),
+        ]);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "income",
+            DataType::Struct(fields.clone()),
+            true,
+        )]));
+
+        // Create nested data
+        let mut nested_data = StructBuilder::new(
+            fields,
+            vec![
+                Box::new(StringBuilder::new()),
+                Box::new(Int32Builder::new()),
+            ],
+        );
+        nested_data
+            .field_builder::<StringBuilder>(0)
+            .unwrap()
+            .extend([Some("Gimli"), Some("Legolas"), Some("Aragorn")]);
+        nested_data
+            .field_builder::<Int32Builder>(1)
+            .unwrap()
+            .extend([Some(10), None, Some(30)]);
+        nested_data.append(true);
+        nested_data.append(true);
+        nested_data.append(true);
+
+        // define data.
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_data.finish())]).unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_batches_with_options(&[batch], &options).unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+---------------------------------+",
+            "| income                          |",
+            "+---------------------------------+",
+            "| {name: Gimli, income: 10 €}     |",
+            "| {name: Legolas, income: <NULL>} |",
+            "| {name: Aragorn, income: 30 €}   |",
+            "+---------------------------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters_nested_map() {
+        // define a schema.
+        let options = FormatOptions::new()
+            .with_null("<NULL>")
+            .with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+
+        let mut array = MapBuilder::<StringBuilder, Int32Builder>::new(
+            None,
+            StringBuilder::new(),
+            Int32Builder::new(),
+        )
+        .with_values_field(
+            Field::new("values", DataType::Int32, true).with_metadata(money_metadata.clone()),
+        );
+        array
+            .keys()
+            .extend([Some("Gimli"), Some("Legolas"), Some("Aragorn")]);
+        array.values().extend([Some(10), None, Some(30)]);
+        array.append(true).unwrap();
+        let array = array.finish();
+
+        // define data.
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "income",
+            array.data_type().clone(),
+            true,
+        )]));
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_batches_with_options(&[batch], &options).unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+-----------------------------------------------+",
+            "| income                                        |",
+            "+-----------------------------------------------+",
+            "| {Gimli: 10 €, Legolas: <NULL>, Aragorn: 30 €} |",
+            "+-----------------------------------------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters_nested_union() {
+        // define a schema.
+        let options = FormatOptions::new()
+            .with_null("<NULL>")
+            .with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+        let fields = UnionFields::new(
+            vec![0],
+            vec![Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone())],
+        );
+
+        // Create nested data and construct it with the correct metadata
+        let mut array_builder = UnionBuilder::new_dense();
+        array_builder.append::<Int32Type>("income", 1).unwrap();
+        let (_, type_ids, offsets, children) = array_builder.build().unwrap().into_parts();
+        let array = UnionArray::try_new(fields, type_ids, offsets, children).unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "income",
+            array.data_type().clone(),
+            true,
+        )]));
+
+        // define data.
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_batches_with_options(&[batch], &options).unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+--------------+",
+            "| income       |",
+            "+--------------+",
+            "| {income=1 €} |",
+            "+--------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_batches_with_custom_formatters_custom_schema_overrules_batch_schema() {
+        // define a schema.
+        let options = FormatOptions::new().with_formatter_factory(Some(&TestFormatters {}));
+        let money_metadata = HashMap::from([(
+            extension::EXTENSION_TYPE_NAME_KEY.to_owned(),
+            "my_money".to_owned(),
+        )]);
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()),
+        ]));
+
+        // define data.
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(array::Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(10),
+                Some(100),
+            ]))],
+        )
+        .unwrap();
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            create_table(
+                // No metadata compared to test_format_batches_with_custom_formatters
+                Some(Arc::new(Schema::new(vec![Field::new(
+                    "income",
+                    DataType::Int32,
+                    true
+                ),]))),
+                &[batch],
+                &options,
+            )
+            .unwrap()
+        )
+        .unwrap();
+
+        // No € formatting as in test_format_batches_with_custom_formatters
+        let s = [
+            "+--------------+",
+            "| income       |",
+            "+--------------+",
+            "| 1 (32-Bit)   |",
+            "|              |",
+            "| 10 (32-Bit)  |",
+            "| 100 (32-Bit) |",
+            "+--------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_format_column_with_custom_formatters() {
+        // define data.
+        let array = Arc::new(array::Int32Array::from(vec![
+            Some(1),
+            None,
+            Some(10),
+            Some(100),
+        ]));
+
+        let mut buf = String::new();
+        write!(
+            &mut buf,
+            "{}",
+            pretty_format_columns_with_options(
+                "income",
+                &[array],
+                &FormatOptions::default().with_formatter_factory(Some(&TestFormatters {}))
+            )
+            .unwrap()
+        )
+        .unwrap();
+
+        let s = [
+            "+--------------+",
+            "| income       |",
+            "+--------------+",
+            "| 1 (32-Bit)   |",
+            "|              |",
+            "| 10 (32-Bit)  |",
+            "| 100 (32-Bit) |",
+            "+--------------+",
+        ];
+        let expected = s.join("\n");
+        assert_eq!(expected, buf);
+    }
+
+    #[test]
+    fn test_pretty_format_batches_with_schema_with_wrong_number_of_fields() {
+        let schema_a = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Utf8, true),
+        ]));
+        let schema_b = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+
+        // define data.
+        let batch = RecordBatch::try_new(
+            schema_b,
+            vec![Arc::new(array::Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(10),
+                Some(100),
+            ]))],
+        )
+        .unwrap();
+
+        let error = pretty_format_batches_with_schema(schema_a, &[batch])
+            .err()
+            .unwrap();
+        assert_eq!(
+            &error.to_string(),
+            "Invalid argument error: Expected the same number of columns in a record batch (1) as the number of fields (2) in the schema"
+        );
+    }
 }


Reply via email to