This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 51bf8a40f7 [Variant] extend shredded null handling for arrays (#9599)
51bf8a40f7 is described below
commit 51bf8a40f72e37528cf36419f8f453ccd0e45868
Author: Konstantin Tarasov <[email protected]>
AuthorDate: Tue Mar 31 15:44:32 2026 -0400
[Variant] extend shredded null handling for arrays (#9599)
# Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax.
-->
- Closes #8400.
# Rationale for this change
See issue #8400 for the rationale.
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
# What changes are included in this PR?
- Added a `NullValue` enum (`TopLevelVariant`, `ObjectField`, `ArrayElement`) covering all three `append_null` semantics.
- Replaced the `top_level: bool` flag with the new enum.
- Fixed the expected test outputs for the List Array cases.
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
# Are these changes tested?
- Added unit tests
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
# Are there any user-facing changes?
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
If there are any breaking changes to public APIs, please call them out.
-->
---
parquet-variant-compute/src/shred_variant.rs | 242 ++++++++++++++++++------
parquet-variant-compute/src/type_conversion.rs | 24 ++-
parquet-variant-compute/src/variant_get.rs | 53 ++++++
parquet-variant-compute/src/variant_to_arrow.rs | 238 ++++++++++++++++-------
4 files changed, 436 insertions(+), 121 deletions(-)
diff --git a/parquet-variant-compute/src/shred_variant.rs
b/parquet-variant-compute/src/shred_variant.rs
index 6520ea700b..d80d2f9863 100644
--- a/parquet-variant-compute/src/shred_variant.rs
+++ b/parquet-variant-compute/src/shred_variant.rs
@@ -84,7 +84,7 @@ pub fn shred_variant(array: &VariantArray, as_type:
&DataType) -> Result<Variant
as_type,
&cast_options,
array.len(),
- true,
+ NullValue::TopLevelVariant,
)?;
for i in 0..array.len() {
if array.is_null(i) {
@@ -102,11 +102,42 @@ pub fn shred_variant(array: &VariantArray, as_type:
&DataType) -> Result<Variant
))
}
+/// Controls how `append_null` is encoded for a shredded `(value,
typed_value)` pair.
+///
+/// | Mode | Struct validity bit | `value` | `typed_value` | Meaning |
+/// | --- | --- | --- | --- | --- |
+/// | `TopLevelVariant` | null | NULL | NULL | SQL NULL at the top-level
variant row |
+/// | `ObjectField` | non-null | NULL | NULL | Missing object field |
+/// | `ArrayElement` | non-null | `Variant::Null` | NULL | Explicit null array
element |
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum NullValue {
+ TopLevelVariant,
+ ObjectField,
+ ArrayElement,
+}
+
+impl NullValue {
+ fn append_to(
+ self,
+ nulls: &mut NullBufferBuilder,
+ value_builder: &mut VariantValueArrayBuilder,
+ ) {
+ match self {
+ Self::TopLevelVariant => nulls.append_null(),
+ Self::ObjectField | Self::ArrayElement => nulls.append_non_null(),
+ }
+ match self {
+ Self::TopLevelVariant | Self::ObjectField =>
value_builder.append_null(),
+ Self::ArrayElement => value_builder.append_value(Variant::Null),
+ }
+ }
+}
+
pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
data_type: &'a DataType,
cast_options: &'a CastOptions,
capacity: usize,
- top_level: bool,
+ null_value: NullValue,
) -> Result<VariantToShreddedVariantRowBuilder<'a>> {
let builder = match data_type {
DataType::Struct(fields) => {
@@ -114,7 +145,7 @@ pub(crate) fn
make_variant_to_shredded_variant_arrow_row_builder<'a>(
fields,
cast_options,
capacity,
- top_level,
+ null_value,
)?;
VariantToShreddedVariantRowBuilder::Object(typed_value_builder)
}
@@ -127,6 +158,7 @@ pub(crate) fn
make_variant_to_shredded_variant_arrow_row_builder<'a>(
data_type,
cast_options,
capacity,
+ null_value,
)?;
VariantToShreddedVariantRowBuilder::Array(typed_value_builder)
}
@@ -156,7 +188,7 @@ pub(crate) fn
make_variant_to_shredded_variant_arrow_row_builder<'a>(
let builder =
make_primitive_variant_to_arrow_row_builder(data_type,
cast_options, capacity)?;
let typed_value_builder =
- VariantToShreddedPrimitiveVariantRowBuilder::new(builder,
capacity, top_level);
+ VariantToShreddedPrimitiveVariantRowBuilder::new(builder,
capacity, null_value);
VariantToShreddedVariantRowBuilder::Primitive(typed_value_builder)
}
DataType::FixedSizeBinary(_) => {
@@ -204,33 +236,31 @@ impl<'a> VariantToShreddedVariantRowBuilder<'a> {
}
}
-/// A top-level variant shredder -- appending NULL produces typed_value=NULL
and value=Variant::Null
+/// A shredded primitive field builder.
pub(crate) struct VariantToShreddedPrimitiveVariantRowBuilder<'a> {
value_builder: VariantValueArrayBuilder,
typed_value_builder: PrimitiveVariantToArrowRowBuilder<'a>,
nulls: NullBufferBuilder,
- top_level: bool,
+ null_value: NullValue,
}
impl<'a> VariantToShreddedPrimitiveVariantRowBuilder<'a> {
pub(crate) fn new(
typed_value_builder: PrimitiveVariantToArrowRowBuilder<'a>,
capacity: usize,
- top_level: bool,
+ null_value: NullValue,
) -> Self {
Self {
value_builder: VariantValueArrayBuilder::new(capacity),
typed_value_builder,
nulls: NullBufferBuilder::new(capacity),
- top_level,
+ null_value,
}
}
fn append_null(&mut self) -> Result<()> {
- // Only the top-level struct that represents the variant can be
nullable; object fields and
- // array elements are non-nullable.
- self.nulls.append(!self.top_level);
- self.value_builder.append_null();
+ self.null_value
+ .append_to(&mut self.nulls, &mut self.value_builder);
self.typed_value_builder.append_null()
}
@@ -256,6 +286,8 @@ impl<'a> VariantToShreddedPrimitiveVariantRowBuilder<'a> {
pub(crate) struct VariantToShreddedArrayVariantRowBuilder<'a> {
value_builder: VariantValueArrayBuilder,
typed_value_builder: ArrayVariantToArrowRowBuilder<'a>,
+ nulls: NullBufferBuilder,
+ null_value: NullValue,
}
impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
@@ -263,6 +295,7 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
data_type: &'a DataType,
cast_options: &'a CastOptions,
capacity: usize,
+ null_value: NullValue,
) -> Result<Self> {
Ok(Self {
value_builder: VariantValueArrayBuilder::new(capacity),
@@ -271,11 +304,14 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
cast_options,
capacity,
)?,
+ nulls: NullBufferBuilder::new(capacity),
+ null_value,
})
}
fn append_null(&mut self) -> Result<()> {
- self.value_builder.append_value(Variant::Null);
+ self.null_value
+ .append_to(&mut self.nulls, &mut self.value_builder);
self.typed_value_builder.append_null()?;
Ok(())
}
@@ -285,12 +321,14 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
// If the variant is an array, value must be null.
match variant {
Variant::List(list) => {
+ self.nulls.append_non_null();
self.value_builder.append_null();
self.typed_value_builder
.append_value(&Variant::List(list))?;
Ok(true)
}
other => {
+ self.nulls.append_non_null();
self.value_builder.append_value(other);
self.typed_value_builder.append_null()?;
Ok(false)
@@ -298,13 +336,11 @@ impl<'a> VariantToShreddedArrayVariantRowBuilder<'a> {
}
}
- fn finish(self) -> Result<(BinaryViewArray, ArrayRef, Option<NullBuffer>)>
{
+ fn finish(mut self) -> Result<(BinaryViewArray, ArrayRef,
Option<NullBuffer>)> {
Ok((
self.value_builder.build()?,
self.typed_value_builder.finish()?,
- // All elements of an array must be present (not missing) because
- // the array Variant encoding does not allow missing elements
- None,
+ self.nulls.finish(),
))
}
}
@@ -314,7 +350,7 @@ pub(crate) struct
VariantToShreddedObjectVariantRowBuilder<'a> {
typed_value_builders: IndexMap<&'a str,
VariantToShreddedVariantRowBuilder<'a>>,
typed_value_nulls: NullBufferBuilder,
nulls: NullBufferBuilder,
- top_level: bool,
+ null_value: NullValue,
}
impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
@@ -322,14 +358,14 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
fields: &'a Fields,
cast_options: &'a CastOptions,
capacity: usize,
- top_level: bool,
+ null_value: NullValue,
) -> Result<Self> {
let typed_value_builders = fields.iter().map(|field| {
let builder = make_variant_to_shredded_variant_arrow_row_builder(
field.data_type(),
cast_options,
capacity,
- false,
+ NullValue::ObjectField,
)?;
Ok((field.name().as_str(), builder))
});
@@ -338,15 +374,13 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
typed_value_builders: typed_value_builders.collect::<Result<_>>()?,
typed_value_nulls: NullBufferBuilder::new(capacity),
nulls: NullBufferBuilder::new(capacity),
- top_level,
+ null_value,
})
}
fn append_null(&mut self) -> Result<()> {
- // Only the top-level struct that represents the variant can be
nullable; object fields and
- // array elements are non-nullable.
- self.nulls.append(!self.top_level);
- self.value_builder.append_null();
+ self.null_value
+ .append_to(&mut self.nulls, &mut self.value_builder);
self.typed_value_nulls.append_null();
for (_, typed_value_builder) in &mut self.typed_value_builders {
typed_value_builder.append_null()?;
@@ -669,6 +703,12 @@ mod tests {
use std::sync::Arc;
use uuid::Uuid;
+ const NULL_VALUES: [NullValue; 3] = [
+ NullValue::TopLevelVariant,
+ NullValue::ObjectField,
+ NullValue::ArrayElement,
+ ];
+
#[derive(Clone)]
enum VariantValue<'a> {
Value(Variant<'a, 'a>),
@@ -881,7 +921,9 @@ mod tests {
expected_variant.clone()
);
}
- None => unreachable!(),
+ None => {
+ assert!(fallbacks.0.is_null(idx));
+ }
}
}
}
@@ -949,6 +991,121 @@ mod tests {
}
}
+ fn assert_append_null_mode_value_and_struct_nulls(
+ mode: NullValue,
+ value: &BinaryViewArray,
+ nulls: Option<&arrow::buffer::NullBuffer>,
+ ) {
+ if mode == NullValue::TopLevelVariant {
+ assert!(nulls.is_some_and(|n| n.is_null(0)));
+ } else {
+ assert!(nulls.is_none());
+ }
+
+ if mode == NullValue::ArrayElement {
+ assert!(value.is_valid(0));
+ assert_eq!(
+ Variant::new(EMPTY_VARIANT_METADATA_BYTES, value.value(0)),
+ Variant::Null
+ );
+ } else {
+ assert!(value.is_null(0));
+ }
+ }
+
+ #[test]
+ fn test_append_null_mode_semantics_primitive_builder() {
+ let cast_options = arrow::compute::CastOptions::default();
+
+ for mode in NULL_VALUES {
+ let mut primitive_builder =
make_variant_to_shredded_variant_arrow_row_builder(
+ &DataType::Int64,
+ &cast_options,
+ 1,
+ mode,
+ )
+ .unwrap();
+ primitive_builder.append_null().unwrap();
+ let (primitive_value, primitive_typed_value, primitive_nulls) =
+ primitive_builder.finish().unwrap();
+ let primitive_typed_value = primitive_typed_value
+ .as_any()
+ .downcast_ref::<Int64Array>()
+ .unwrap();
+
+ assert!(primitive_typed_value.is_null(0));
+ assert_append_null_mode_value_and_struct_nulls(
+ mode,
+ &primitive_value,
+ primitive_nulls.as_ref(),
+ );
+ }
+ }
+
+ #[test]
+ fn test_append_null_mode_semantics_array_builder() {
+ let cast_options = arrow::compute::CastOptions::default();
+ let list_type = DataType::List(Arc::new(Field::new("item",
DataType::Int64, true)));
+
+ for mode in NULL_VALUES {
+ let mut array_builder =
make_variant_to_shredded_variant_arrow_row_builder(
+ &list_type,
+ &cast_options,
+ 1,
+ mode,
+ )
+ .unwrap();
+ array_builder.append_null().unwrap();
+ let (value, typed_value, nulls) = array_builder.finish().unwrap();
+
+ assert_append_null_mode_value_and_struct_nulls(mode, &value,
nulls.as_ref());
+
+ let typed_value =
typed_value.as_any().downcast_ref::<ListArray>().unwrap();
+ assert_eq!(typed_value.len(), 1);
+ assert!(typed_value.is_null(0));
+ assert_eq!(typed_value.values().len(), 0);
+ }
+ }
+
+ #[test]
+ fn test_append_null_mode_semantics_object_builder() {
+ let cast_options = arrow::compute::CastOptions::default();
+ let object_type = DataType::Struct(Fields::from(vec![
+ Field::new("id", DataType::Int64, true),
+ Field::new("name", DataType::Utf8, true),
+ ]));
+
+ for mode in NULL_VALUES {
+ let mut object_builder =
make_variant_to_shredded_variant_arrow_row_builder(
+ &object_type,
+ &cast_options,
+ 1,
+ mode,
+ )
+ .unwrap();
+ object_builder.append_null().unwrap();
+ let (value, typed_value, nulls) = object_builder.finish().unwrap();
+
+ assert_append_null_mode_value_and_struct_nulls(mode, &value,
nulls.as_ref());
+
+ let typed_struct = typed_value
+ .as_any()
+ .downcast_ref::<arrow::array::StructArray>()
+ .unwrap();
+ assert_eq!(typed_struct.len(), 1);
+ assert!(typed_struct.is_null(0));
+
+ for field_name in ["id", "name"] {
+ let field = ShreddedVariantFieldArray::try_new(
+ typed_struct.column_by_name(field_name).unwrap(),
+ )
+ .unwrap();
+ assert!(field.value_field().unwrap().is_null(0));
+ assert!(field.typed_value_field().unwrap().is_null(0));
+ }
+ }
+ }
+
#[test]
fn test_already_shredded_input_error() {
// Create a VariantArray that already has typed_value_field
@@ -1338,13 +1495,7 @@ mod tests {
5,
&[0, 3, 6, 6, 6, 6],
&[Some(3), Some(3), None, None, Some(0)],
- &[
- None,
- None,
- Some(Variant::from("not a list")),
- Some(Variant::Null),
- None,
- ],
+ &[None, None, Some(Variant::from("not a list")), None, None],
(
&[Some(1), Some(2), Some(3), Some(1), None, None],
&[
@@ -1414,13 +1565,7 @@ mod tests {
5,
&[0, 3, 6, 6, 6],
&[Some(3), Some(3), None, None, Some(0)],
- &[
- None,
- None,
- Some(Variant::from("not a list")),
- Some(Variant::Null),
- None,
- ],
+ &[None, None, Some(Variant::from("not a list")), None, None],
(
&[Some(1), Some(2), Some(3), Some(1), None, None],
&[
@@ -1522,12 +1667,7 @@ mod tests {
4,
&[0, 3, 6, 6, 6],
&[Some(3), Some(3), None, None],
- &[
- None,
- None,
- Some(Variant::from("not a list")),
- Some(Variant::Null),
- ],
+ &[None, None, Some(Variant::from("not a list")), None],
);
let outer_elements =
@@ -1615,7 +1755,7 @@ mod tests {
3,
&[0, 2, 2, 2],
&[Some(2), None, None],
- &[None, Some(Variant::from("not a list")), Some(Variant::Null)],
+ &[None, Some(Variant::from("not a list")), None],
);
// Validate nested struct fields for each element
@@ -2101,13 +2241,7 @@ mod tests {
scores_field.len(),
&[0i32, 2, 4, 4, 4, 4],
&[Some(2), Some(2), None, None, None],
- &[
- None,
- None,
- Some(Variant::Null),
- Some(Variant::Null),
- Some(Variant::Null),
- ],
+ &[None, None, None, None, None],
(
&[Some(10), Some(20), None, None],
&[None, None, Some(Variant::from("oops")),
Some(Variant::Null)],
diff --git a/parquet-variant-compute/src/type_conversion.rs
b/parquet-variant-compute/src/type_conversion.rs
index 4086a24107..7b9eb67d1a 100644
--- a/parquet-variant-compute/src/type_conversion.rs
+++ b/parquet-variant-compute/src/type_conversion.rs
@@ -17,11 +17,12 @@
//! Module for transforming a typed arrow `Array` to `VariantArray`.
-use arrow::compute::{DecimalCast, rescale_decimal};
+use arrow::compute::{CastOptions, DecimalCast, rescale_decimal};
use arrow::datatypes::{
self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type,
Decimal64Type, Decimal128Type,
DecimalType,
};
+use arrow::error::{ArrowError, Result};
use chrono::Timelike;
use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8,
VariantDecimal16};
@@ -37,6 +38,27 @@ pub(crate) trait TimestampFromVariant<const NTZ: bool>:
ArrowTimestampType {
fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
}
+/// Cast a single `Variant` value with safe/strict semantics.
+///
+/// Returns `Ok(Some(_))` on successful conversion.
+/// Returns `Ok(None)` when conversion fails in safe mode or the source value
is `Variant::Null`.
+/// Returns `Err(_)` when conversion fails in strict mode.
+pub(crate) fn variant_cast_with_options<'a, 'm, 'v, T>(
+ variant: &'a Variant<'m, 'v>,
+ cast_options: &CastOptions<'_>,
+ cast: impl FnOnce(&'a Variant<'m, 'v>) -> Option<T>,
+) -> Result<Option<T>> {
+ if let Some(value) = cast(variant) {
+ Ok(Some(value))
+ } else if matches!(variant, Variant::Null) || cast_options.safe {
+ Ok(None)
+ } else {
+ Err(ArrowError::CastError(format!(
+ "Failed to cast variant value {variant:?}"
+ )))
+ }
+}
+
/// Macro to generate PrimitiveFromVariant implementations for Arrow primitive
types
macro_rules! impl_primitive_from_variant {
($arrow_type:ty, $variant_method:ident $(, $cast_fn:expr)?) => {
diff --git a/parquet-variant-compute/src/variant_get.rs
b/parquet-variant-compute/src/variant_get.rs
index 3e9892cacf..73906f70eb 100644
--- a/parquet-variant-compute/src/variant_get.rs
+++ b/parquet-variant-compute/src/variant_get.rs
@@ -4270,6 +4270,59 @@ mod test {
}
}
+ #[test]
+ fn test_variant_get_list_like_unsafe_cast_preserves_null_elements() {
+ let string_array: ArrayRef = Arc::new(StringArray::from(vec![r#"[1,
null, 3]"#]));
+ let variant_array =
ArrayRef::from(json_to_variant(&string_array).unwrap());
+ let cast_options = CastOptions {
+ safe: false,
+ ..Default::default()
+ };
+ let options = GetOptions::new()
+ .with_as_type(Some(FieldRef::from(Field::new(
+ "result",
+ DataType::List(Arc::new(Field::new("item", DataType::Int64,
true))),
+ true,
+ ))))
+ .with_cast_options(cast_options);
+
+ let result = variant_get(&variant_array, options).unwrap();
+ let element_struct = result
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .unwrap()
+ .values()
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .unwrap();
+
+ let value = element_struct
+ .column_by_name("value")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<BinaryViewArray>()
+ .unwrap();
+ let typed_value = element_struct
+ .column_by_name("typed_value")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<Int64Array>()
+ .unwrap();
+
+ assert_eq!(typed_value.len(), 3);
+ assert_eq!(typed_value.value(0), 1);
+ assert!(typed_value.is_null(1));
+ assert_eq!(typed_value.value(2), 3);
+
+ assert!(value.is_null(0));
+ assert!(value.is_valid(1));
+ assert_eq!(
+ Variant::new(EMPTY_VARIANT_METADATA_BYTES, value.value(1)),
+ Variant::Null
+ );
+ assert!(value.is_null(2));
+ }
+
#[test]
fn test_variant_get_list_like_unsafe_cast_errors_on_non_list() {
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["[1, 2]",
"\"not a list\""]));
diff --git a/parquet-variant-compute/src/variant_to_arrow.rs
b/parquet-variant-compute/src/variant_to_arrow.rs
index dc8fbcd223..dd396117d2 100644
--- a/parquet-variant-compute/src/variant_to_arrow.rs
+++ b/parquet-variant-compute/src/variant_to_arrow.rs
@@ -16,10 +16,12 @@
// under the License.
use crate::shred_variant::{
- VariantToShreddedVariantRowBuilder,
make_variant_to_shredded_variant_arrow_row_builder,
+ NullValue, VariantToShreddedVariantRowBuilder,
+ make_variant_to_shredded_variant_arrow_row_builder,
};
use crate::type_conversion::{
- PrimitiveFromVariant, TimestampFromVariant, variant_to_unscaled_decimal,
+ PrimitiveFromVariant, TimestampFromVariant, variant_cast_with_options,
+ variant_to_unscaled_decimal,
};
use crate::variant_array::ShreddedVariantFieldArray;
use crate::{VariantArray, VariantValueArrayBuilder};
@@ -545,30 +547,30 @@ impl<'a> StructVariantToArrowRowBuilder<'a> {
}
fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
- let Variant::Object(obj) = value else {
- if self.cast_options.safe {
- self.append_null()?;
- return Ok(false);
- }
- return Err(ArrowError::CastError(format!(
- "Failed to extract struct from variant {:?}",
- value
- )));
- };
-
- for (index, field) in self.fields.iter().enumerate() {
- match obj.get(field.name()) {
- Some(field_value) => {
- self.field_builders[index].append_value(field_value)?;
- }
- None => {
- self.field_builders[index].append_null()?;
+ match variant_cast_with_options(value, self.cast_options,
Variant::as_object) {
+ Ok(Some(obj)) => {
+ for (index, field) in self.fields.iter().enumerate() {
+ match obj.get(field.name()) {
+ Some(field_value) => {
+
self.field_builders[index].append_value(field_value)?;
+ }
+ None => {
+ self.field_builders[index].append_null()?;
+ }
+ }
}
+
+ self.nulls.append_non_null();
+ Ok(true)
+ }
+ Ok(None) => {
+ self.append_null()?;
+ Ok(false)
}
+ Err(_) => Err(ArrowError::CastError(format!(
+ "Failed to extract struct from variant {value:?}"
+ ))),
}
-
- self.nulls.append_non_null();
- Ok(true)
}
fn finish(mut self) -> Result<ArrayRef> {
@@ -707,21 +709,24 @@ macro_rules! define_variant_to_primitive_builder {
}
fn append_value(&mut self, $value: &Variant<'_, '_>) ->
Result<bool> {
- if let Some(v) = $value_transform {
- self.builder.append_value(v);
- Ok(true)
- } else {
- if !self.cast_options.safe {
- // Unsafe casting: return error on conversion failure
- return Err(ArrowError::CastError(format!(
- "Failed to extract primitive of type {} from
variant {:?} at path VariantPath([])",
- $type_name,
- $value
- )));
+ match variant_cast_with_options(
+ $value,
+ self.cast_options,
+ |$value| $value_transform,
+ ) {
+ Ok(Some(v)) => {
+ self.builder.append_value(v);
+ Ok(true)
+ }
+ Ok(None) => {
+ self.builder.append_null();
+ Ok(false)
}
- // Safe casting: append null on conversion failure
- self.builder.append_null();
- Ok(false)
+ Err(_) => Err(ArrowError::CastError(format!(
+ "Failed to extract primitive of type {type_name} from
variant {value:?} at path VariantPath([])",
+ type_name = $type_name,
+ value = $value
+ ))),
}
}
@@ -748,7 +753,7 @@ define_variant_to_primitive_builder!(
define_variant_to_primitive_builder!(
struct VariantToBooleanArrowRowBuilder<'a>
|capacity| -> BooleanBuilder { BooleanBuilder::with_capacity(capacity) },
- |value| value.as_boolean(),
+ |value| value.as_boolean(),
type_name: datatypes::BooleanType::DATA_TYPE
);
@@ -821,20 +826,23 @@ where
}
fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
- if let Some(scaled) = variant_to_unscaled_decimal::<T>(value,
self.precision, self.scale) {
- self.builder.append_value(scaled);
- Ok(true)
- } else if self.cast_options.safe {
- self.builder.append_null();
- Ok(false)
- } else {
- Err(ArrowError::CastError(format!(
- "Failed to cast to {}(precision={}, scale={}) from variant
{:?}",
- T::PREFIX,
- self.precision,
- self.scale,
- value
- )))
+ match variant_cast_with_options(value, self.cast_options, |value| {
+ variant_to_unscaled_decimal::<T>(value, self.precision, self.scale)
+ }) {
+ Ok(Some(scaled)) => {
+ self.builder.append_value(scaled);
+ Ok(true)
+ }
+ Ok(None) => {
+ self.builder.append_null();
+ Ok(false)
+ }
+ Err(_) => Err(ArrowError::CastError(format!(
+ "Failed to cast to {prefix}(precision={precision},
scale={scale}) from variant {value:?}",
+ prefix = T::PREFIX,
+ precision = self.precision,
+ scale = self.scale
+ ))),
}
}
@@ -863,20 +871,19 @@ impl<'a> VariantToUuidArrowRowBuilder<'a> {
}
fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
- match value.as_uuid() {
- Some(uuid) => {
+ match variant_cast_with_options(value, self.cast_options,
Variant::as_uuid) {
+ Ok(Some(uuid)) => {
self.builder
.append_value(uuid.as_bytes())
.map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
-
Ok(true)
}
- None if self.cast_options.safe => {
+ Ok(None) => {
self.builder.append_null();
Ok(false)
}
- None => Err(ArrowError::CastError(format!(
- "Failed to extract UUID from variant {value:?}",
+ Err(_) => Err(ArrowError::CastError(format!(
+ "Failed to extract UUID from variant {value:?}"
))),
}
}
@@ -919,7 +926,7 @@ where
element_data_type,
cast_options,
capacity,
- false,
+ NullValue::ArrayElement,
)?;
Ok(Self {
field,
@@ -938,8 +945,8 @@ where
}
fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
- match value {
- Variant::List(list) => {
+ match variant_cast_with_options(value, self.cast_options,
Variant::as_list) {
+ Ok(Some(list)) => {
for element in list.iter() {
self.element_builder.append_value(element)?;
self.current_offset =
self.current_offset.add_checked(O::ONE)?;
@@ -948,13 +955,12 @@ where
self.nulls.append_non_null();
Ok(true)
}
- _ if self.cast_options.safe => {
+ Ok(None) => {
self.append_null()?;
Ok(false)
}
- _ => Err(ArrowError::CastError(format!(
- "Failed to extract list from variant {:?}",
- value
+ Err(_) => Err(ArrowError::CastError(format!(
+ "Failed to extract list from variant {value:?}"
))),
}
}
@@ -1067,11 +1073,18 @@ define_variant_to_primitive_builder!(
#[cfg(test)]
mod tests {
- use super::make_primitive_variant_to_arrow_row_builder;
+ use super::{
+ make_primitive_variant_to_arrow_row_builder,
make_typed_variant_to_arrow_row_builder,
+ };
+ use arrow::array::{
+ Array, Decimal32Array, FixedSizeBinaryArray, Int32Array, ListArray,
StructArray,
+ };
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Field, Fields, UnionFields, UnionMode};
use arrow::error::ArrowError;
+ use parquet_variant::{Variant, VariantDecimal4};
use std::sync::Arc;
+ use uuid::Uuid;
#[test]
fn make_primitive_builder_rejects_non_primitive_types() {
@@ -1120,4 +1133,97 @@ mod tests {
}
}
}
+
+ #[test]
+ fn strict_cast_allows_variant_null_for_primitive_builder() {
+ let cast_options = CastOptions {
+ safe: false,
+ ..Default::default()
+ };
+ let mut builder =
+ make_primitive_variant_to_arrow_row_builder(&DataType::Int32,
&cast_options, 2)
+ .unwrap();
+
+ assert!(!builder.append_value(&Variant::Null).unwrap());
+ assert!(builder.append_value(&Variant::Int32(42)).unwrap());
+
+ let array = builder.finish().unwrap();
+ let int_array = array.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert!(int_array.is_null(0));
+ assert_eq!(int_array.value(1), 42);
+ }
+
+ #[test]
+ fn strict_cast_allows_variant_null_for_decimal_builder() {
+ let cast_options = CastOptions {
+ safe: false,
+ ..Default::default()
+ };
+ let mut builder = make_primitive_variant_to_arrow_row_builder(
+ &DataType::Decimal32(9, 2),
+ &cast_options,
+ 2,
+ )
+ .unwrap();
+ let decimal_variant: Variant<'_, '_> = VariantDecimal4::try_new(1234,
2).unwrap().into();
+
+ assert!(!builder.append_value(&Variant::Null).unwrap());
+ assert!(builder.append_value(&decimal_variant).unwrap());
+
+ let array = builder.finish().unwrap();
+ let decimal_array =
array.as_any().downcast_ref::<Decimal32Array>().unwrap();
+ assert!(decimal_array.is_null(0));
+ assert_eq!(decimal_array.value(1), 1234);
+ }
+
+ #[test]
+ fn strict_cast_allows_variant_null_for_uuid_builder() {
+ let cast_options = CastOptions {
+ safe: false,
+ ..Default::default()
+ };
+ let mut builder = make_primitive_variant_to_arrow_row_builder(
+ &DataType::FixedSizeBinary(16),
+ &cast_options,
+ 2,
+ )
+ .unwrap();
+ let uuid = Uuid::nil();
+
+ assert!(!builder.append_value(&Variant::Null).unwrap());
+ assert!(builder.append_value(&Variant::Uuid(uuid)).unwrap());
+
+ let array = builder.finish().unwrap();
+ let uuid_array = array
+ .as_any()
+ .downcast_ref::<FixedSizeBinaryArray>()
+ .unwrap();
+ assert!(uuid_array.is_null(0));
+ assert_eq!(uuid_array.value(1), uuid.as_bytes());
+ }
+
+ #[test]
+ fn strict_cast_allows_variant_null_for_list_and_struct_builders() {
+ let cast_options = CastOptions {
+ safe: false,
+ ..Default::default()
+ };
+
+ let list_type = DataType::List(Arc::new(Field::new("item",
DataType::Int64, true)));
+ let mut list_builder =
+ make_typed_variant_to_arrow_row_builder(&list_type, &cast_options,
1).unwrap();
+ assert!(!list_builder.append_value(Variant::Null).unwrap());
+ let list_array = list_builder.finish().unwrap();
+ let list_array =
list_array.as_any().downcast_ref::<ListArray>().unwrap();
+ assert!(list_array.is_null(0));
+
+ let struct_type =
+ DataType::Struct(Fields::from(vec![Field::new("a",
DataType::Int32, true)]));
+ let mut struct_builder =
+ make_typed_variant_to_arrow_row_builder(&struct_type,
&cast_options, 1).unwrap();
+ assert!(!struct_builder.append_value(Variant::Null).unwrap());
+ let struct_array = struct_builder.finish().unwrap();
+ let struct_array =
struct_array.as_any().downcast_ref::<StructArray>().unwrap();
+ assert!(struct_array.is_null(0));
+ }
}