This is an automated email from the ASF dual-hosted git repository. blaginin pushed a commit to branch annarose/dict-coercion in repository https://gitbox.apache.org/repos/asf/datafusion-sandbox.git
commit 39da29f5ee6be088961d01df4f0b66e5079b54c5 Author: Jeffrey Vo <[email protected]> AuthorDate: Tue Feb 3 08:49:13 2026 +0900 Add `ScalarValue::RunEndEncoded` variant (#19895) ## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #18563 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Support RunEndEncoded scalar values, similar to how we support Dictionary. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> - Add new `ScalarValue::RunEndEncoded` enum variant - Fix `ScalarValue::new_default` to support `Decimal32` and `Decimal64` - Support RunEndEncoded type in proto for both `ScalarValue` message and `ArrowType` message ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> New variant for `ScalarValue`. Protobuf changes to support RunEndEncoded type. <!-- If there are any breaking changes to public APIs, please add the `api change` label. 
--> --------- Co-authored-by: Andrew Lamb <[email protected]> --- datafusion/common/src/cast.rs | 8 +- datafusion/common/src/scalar/mod.rs | 458 ++++++++++++++++++++- .../proto-common/proto/datafusion_common.proto | 14 + datafusion/proto-common/src/from_proto/mod.rs | 39 ++ datafusion/proto-common/src/generated/pbjson.rs | 265 ++++++++++++ datafusion/proto-common/src/generated/prost.rs | 24 +- datafusion/proto-common/src/to_proto/mod.rs | 81 ++-- .../proto/src/generated/datafusion_proto_common.rs | 24 +- .../proto/tests/cases/roundtrip_logical_plan.rs | 10 + datafusion/sql/src/unparser/expr.rs | 34 +- 10 files changed, 906 insertions(+), 51 deletions(-) diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 29082cc30..bc4313ed9 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -25,8 +25,9 @@ use arrow::array::{ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, - ListViewArray, StringViewArray, UInt16Array, + ListViewArray, RunArray, StringViewArray, UInt16Array, }; +use arrow::datatypes::RunEndIndexType; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, @@ -334,3 +335,8 @@ pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { pub fn as_large_list_view_array(array: &dyn Array) -> Result<&LargeListViewArray> { Ok(downcast_value!(array, LargeListViewArray)) } + +// Downcast Array to RunArray +pub fn as_run_array<T: RunEndIndexType>(array: &dyn Array) -> Result<&RunArray<T>> { + Ok(downcast_value!(array, RunArray, T)) +} diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 064091971..644916d78 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -43,7 +43,7 @@ use crate::cast::{ 
as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array, as_interval_ym_array, as_large_binary_array, as_large_list_array, - as_large_string_array, as_string_array, as_string_view_array, + as_large_string_array, as_run_array, as_string_array, as_string_view_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, @@ -56,8 +56,8 @@ use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - Array, ArrayData, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, - BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Array, ArrayData, ArrayDataBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + AsArray, BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, @@ -65,11 +65,11 @@ use arrow::array::{ Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, - PrimitiveArray, Scalar, StringArray, StringViewArray, StringViewBuilder, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + PrimitiveArray, RunArray, Scalar, StringArray, StringViewArray, StringViewBuilder, + StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, 
TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, - UInt64Array, UnionArray, new_empty_array, new_null_array, + UInt64Array, UnionArray, downcast_run_array, new_empty_array, new_null_array, }; use arrow::buffer::{BooleanBuffer, ScalarBuffer}; use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; @@ -79,11 +79,12 @@ use arrow::compute::kernels::numeric::{ use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, - Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, + FieldRef, Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, - UInt64Type, UnionFields, UnionMode, i256, validate_decimal_precision_and_scale, + IntervalYearMonthType, RunEndIndexType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, UnionFields, UnionMode, i256, + validate_decimal_precision_and_scale, }; use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}; use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; @@ -428,6 +429,8 @@ pub enum ScalarValue { Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box<DataType>, Box<ScalarValue>), + /// (run-ends field, value field, value) + RunEndEncoded(FieldRef, FieldRef, Box<ScalarValue>), } impl Hash for Fl<f16> { @@ -557,6 +560,10 @@ impl PartialEq for ScalarValue { (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => 
false, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + rf1.eq(rf2) && vf1.eq(vf2) && v1.eq(v2) + } + (RunEndEncoded(_, _, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -722,6 +729,15 @@ impl PartialOrd for ScalarValue { if k1 == k2 { v1.partial_cmp(v2) } else { None } } (Dictionary(_, _), _) => None, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + // Don't compare if the run ends fields don't match (it is effectively a different datatype) + if rf1 == rf2 && vf1 == vf2 { + v1.partial_cmp(v2) + } else { + None + } + } + (RunEndEncoded(_, _, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -965,6 +981,11 @@ impl Hash for ScalarValue { k.hash(state); v.hash(state); } + RunEndEncoded(rf, vf, v) => { + rf.hash(state); + vf.hash(state); + v.hash(state); + } // stable hash for Null value Null => 1.hash(state), } @@ -1243,6 +1264,13 @@ impl ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), + DataType::RunEndEncoded(run_ends_field, value_field) => { + ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(value_field.data_type().try_into()?), + ) + } // `ScalarValue::List` contains single element `ListArray`. 
DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), @@ -1573,6 +1601,8 @@ impl ScalarValue { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Timestamp(_, _) @@ -1641,6 +1671,14 @@ impl ScalarValue { Box::new(ScalarValue::new_default(value_type)?), )), + DataType::RunEndEncoded(run_ends_field, value_field) => { + Ok(ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(ScalarValue::new_default(value_field.data_type())?), + )) + } + // Map types DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( ArrayData::new_empty(field.data_type()), @@ -1660,8 +1698,7 @@ impl ScalarValue { } } - // Unsupported types for now - _ => { + DataType::ListView(_) | DataType::LargeListView(_) => { _not_impl_err!( "Default value for data_type \"{datatype}\" is not implemented yet" ) @@ -1952,6 +1989,12 @@ impl ScalarValue { ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } + ScalarValue::RunEndEncoded(run_ends_field, value_field, _) => { + DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + ) + } ScalarValue::Null => DataType::Null, } } @@ -2230,6 +2273,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::RunEndEncoded(_, _, v) => v.is_null(), } } @@ -2597,6 +2641,94 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + DataType::RunEndEncoded(run_ends_field, value_field) => { + fn make_run_array<R: RunEndIndexType>( + scalars: impl IntoIterator<Item = ScalarValue>, + run_ends_field: &FieldRef, + values_field: &FieldRef, + ) -> Result<ArrayRef> { + let mut scalars = scalars.into_iter(); + + let mut run_ends = vec![]; + let mut value_scalars = vec![]; + + let mut len = 
R::Native::ONE; + let mut current = + if let Some(ScalarValue::RunEndEncoded(_, _, scalar)) = + scalars.next() + { + *scalar + } else { + // We are guaranteed to have one element of correct + // type because we peeked above + unreachable!() + }; + for scalar in scalars { + let scalar = match scalar { + ScalarValue::RunEndEncoded( + inner_run_ends_field, + inner_value_field, + scalar, + ) if &inner_run_ends_field == run_ends_field + && &inner_value_field == values_field => + { + *scalar + } + _ => { + return _exec_err!( + "Expected RunEndEncoded scalar with run-ends field {run_ends_field} but got: {scalar:?}" + ); + } + }; + + // new run + if scalar != current { + run_ends.push(len); + value_scalars.push(current); + current = scalar; + } + + len = len.add_checked(R::Native::ONE).map_err(|_| { + DataFusionError::Execution(format!( + "Cannot construct RunArray: Overflows run-ends type {}", + run_ends_field.data_type() + )) + })?; + } + + run_ends.push(len); + value_scalars.push(current); + + let run_ends = PrimitiveArray::<R>::from_iter_values(run_ends); + let values = ScalarValue::iter_to_array(value_scalars)?; + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(RunArray::logical_len(&run_ends)) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::<R>::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + + match run_ends_field.data_type() { + DataType::Int16 => { + make_run_array::<Int16Type>(scalars, run_ends_field, value_field)? + } + DataType::Int32 => { + make_run_array::<Int32Type>(scalars, run_ends_field, value_field)? + } + DataType::Int64 => { + make_run_array::<Int64Type>(scalars, run_ends_field, value_field)? 
+ } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -2625,7 +2757,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { return _not_impl_err!( @@ -3202,6 +3333,54 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } + ScalarValue::RunEndEncoded(run_ends_field, values_field, value) => { + fn make_run_array<R: RunEndIndexType>( + run_ends_field: &Arc<Field>, + values_field: &Arc<Field>, + value: &ScalarValue, + size: usize, + ) -> Result<ArrayRef> { + let size_native = R::Native::from_usize(size) + .ok_or_else(|| DataFusionError::Execution(format!("Cannot construct RunArray of size {size}: Overflows run-ends type {}", R::DATA_TYPE)))?; + let values = value.to_array_of_size(1)?; + let run_ends = + PrimitiveArray::<R>::new(vec![size_native].into(), None); + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(size) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::<R>::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + match run_ends_field.data_type() { + DataType::Int16 => make_run_array::<Int16Type>( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int32 => make_run_array::<Int32Type>( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int64 => make_run_array::<Int64Type>( + run_ends_field, + values_field, + value, + size, + )?, + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ -3552,6 +3731,28 @@ impl ScalarValue { 
Self::Dictionary(key_type.clone(), Box::new(value)) } + DataType::RunEndEncoded(run_ends_field, value_field) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + let scalar = downcast_run_array!( + array => { + let index = array.get_physical_index(index); + ScalarValue::try_from_array(array.values(), index)? + }, + dt => unreachable!("Invalid run-ends type: {dt}") + ); + Self::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(scalar), + ) + } DataType::Struct(_) => { let a = array.slice(index, 1); Self::Struct(Arc::new(a.as_struct().to_owned())) @@ -3664,6 +3865,7 @@ impl ScalarValue { ScalarValue::LargeUtf8(v) => v, ScalarValue::Utf8View(v) => v, ScalarValue::Dictionary(_, v) => return v.try_as_str(), + ScalarValue::RunEndEncoded(_, _, v) => return v.try_as_str(), _ => return None, }; Some(v.as_ref().map(|v| v.as_str())) @@ -4008,6 +4210,34 @@ impl ScalarValue { None => v.is_null(), } } + ScalarValue::RunEndEncoded(run_ends_field, _, value) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + match run_ends_field.data_type() { + DataType::Int16 => { + let array = as_run_array::<Int16Type>(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int32 => { + let array = as_run_array::<Int32Type>(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int64 => { + let array = as_run_array::<Int64Type>(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? 
+ } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => array.is_null(index), }) } @@ -4097,6 +4327,7 @@ impl ScalarValue { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() } + ScalarValue::RunEndEncoded(rf, vf, v) => rf.size() + vf.size() + v.size(), } } @@ -4212,6 +4443,9 @@ impl ScalarValue { ScalarValue::Dictionary(_, value) => { value.compact(); } + ScalarValue::RunEndEncoded(_, _, value) => { + value.compact(); + } } } @@ -4843,6 +5077,7 @@ impl fmt::Display for ScalarValue { None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, + ScalarValue::RunEndEncoded(_, _, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -5021,6 +5256,9 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), + ScalarValue::RunEndEncoded(rf, vf, v) => { + write!(f, "RunEndEncoded({rf:?}, {vf:?}, {v:?})") + } ScalarValue::Null => write!(f, "NULL"), } } @@ -7256,6 +7494,31 @@ mod tests { } } + #[test] + fn roundtrip_run_array() { + // Comparison logic in round_trip_through_scalar doesn't work for RunArrays + // so we have a custom test for them + // TODO: https://github.com/apache/arrow-rs/pull/9213 might fix this ^ + let run_ends = Int16Array::from(vec![2, 3]); + let values = Int64Array::from(vec![Some(1), None]); + let run_array = RunArray::try_new(&run_ends, &values).unwrap(); + let run_array = run_array.downcast::<Int64Array>().unwrap(); + + let expected_values = run_array.into_iter().collect::<Vec<_>>(); + + for i in 0..run_array.len() { + let scalar = ScalarValue::try_from_array(&run_array, i).unwrap(); + let array = scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.data_type(), run_array.data_type()); + let array = array.as_run::<Int16Type>(); + let array = array.downcast::<Int64Array>().unwrap(); + assert_eq!( + array.into_iter().collect::<Vec<_>>(), + 
expected_values[i..i + 1] + ); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -9228,6 +9491,175 @@ mod tests { assert_eq!(value.len(), buffers[0].len()); } + #[test] + fn test_to_array_of_size_run_end_encoded() { + fn run_test<R: RunEndIndexType>() { + let value = Box::new(ScalarValue::Float32(Some(1.0))); + let size = 5; + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", R::DATA_TYPE, false).into(), + Field::new("values", DataType::Float32, true).into(), + value.clone(), + ); + let array = scalar.to_array_of_size(size).unwrap(); + let array = array.as_run::<R>(); + let array = array.downcast::<Float32Array>().unwrap(); + assert_eq!(vec![Some(1.0); size], array.into_iter().collect::<Vec<_>>()); + assert_eq!(1, array.values().len()); + } + + run_test::<Int16Type>(); + run_test::<Int32Type>(); + run_test::<Int64Type>(); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let err = scalar.to_array_of_size(i16::MAX as usize + 10).unwrap_err(); + assert_eq!( + "Execution error: Cannot construct RunArray of size 32777: Overflows run-ends type Int16", + err.to_string() + ) + } + + #[test] + fn test_eq_array_run_end_encoded() { + let run_ends = Int16Array::from(vec![1, 3]); + let values = Float32Array::from(vec![None, Some(1.0)]); + let run_array = + Arc::new(RunArray::try_new(&run_ends, &values).unwrap()) as ArrayRef; + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + assert!(scalar.eq_array(&run_array, 0).unwrap()); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + 
Box::new(ScalarValue::Float32(Some(1.0))), + ); + assert!(scalar.eq_array(&run_array, 1).unwrap()); + assert!(scalar.eq_array(&run_array, 2).unwrap()); + + // value types must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float64, true).into(), + Box::new(ScalarValue::Float64(Some(1.0))), + ); + let err = scalar.eq_array(&run_array, 1).unwrap_err(); + let expected = "Internal error: could not cast array of type Float32 to arrow_array::array::primitive_array::PrimitiveArray<arrow_array::types::Float64Type>"; + assert!(err.to_string().starts_with(expected)); + + // run ends type must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + let err = scalar.eq_array(&run_array, 0).unwrap_err(); + let expected = "Internal error: could not cast array of type RunEndEncoded(\"run_ends\": non-null Int16, \"values\": Float32) to arrow_array::array::run_array::RunArray<arrow_array::types::Int32Type>"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_iter_to_array_run_end_encoded() { + let run_ends_field = Arc::new(Field::new("run_ends", DataType::Int16, false)); + let values_field = Arc::new(Field::new("values", DataType::Int64, true)); + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(None)), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + 
ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ]; + + let run_array = ScalarValue::iter_to_array(scalars).unwrap(); + let expected = RunArray::try_new( + &Int16Array::from(vec![2, 3, 6]), + &Int64Array::from(vec![Some(1), None, Some(2)]), + ) + .unwrap(); + assert_eq!(&expected as &dyn Array, run_array.as_ref()); + + // inconsistent run-ends type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int32 }, Field { name: \"values\", data_type: Int64, nullable: true }, Int64(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent value type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Field::new("values", DataType::Int32, true).into(), + Box::new(ScalarValue::Int32(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int16 }, Field { name: \"values\", data_type: Int32, nullable: true }, Int32(1))"; + assert!(err.to_string().starts_with(expected)); + 
+ // inconsistent scalars type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::Int64(Some(1)), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: Int64(1)"; + assert!(err.to_string().starts_with(expected)); + } + #[test] fn test_convert_array_to_scalar_vec() { // 1: Regular ListArray diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 08bb25bd7..8a9185ca7 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -183,6 +183,11 @@ message Map { bool keys_sorted = 2; } +message RunEndEncoded { + Field run_ends_field = 1; + Field values_field = 2; +} + enum UnionMode{ sparse = 0; dense = 1; @@ -236,6 +241,12 @@ message ScalarDictionaryValue { ScalarValue value = 2; } +message ScalarRunEndEncodedValue { + Field run_ends_field = 1; + Field values_field = 2; + ScalarValue value = 3; +} + message IntervalDayTimeValue { int32 days = 1; int32 milliseconds = 2; @@ -321,6 +332,8 @@ message ScalarValue{ IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; UnionValue union_value = 42; + + ScalarRunEndEncodedValue run_end_encoded_value = 45; } } @@ -389,6 +402,7 @@ message ArrowType{ Union UNION = 29; Dictionary DICTIONARY = 30; Map MAP = 33; + RunEndEncoded RUN_END_ENCODED = 42; } } diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 3c41b8cad..af427ef5a 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -326,6 +326,19 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { let 
keys_sorted = map.keys_sorted; DataType::Map(Arc::new(field), keys_sorted) } + arrow_type::ArrowTypeEnum::RunEndEncoded(run_end_encoded) => { + let run_ends_field: Field = run_end_encoded + .as_ref() + .run_ends_field + .as_deref() + .required("run_ends_field")?; + let value_field: Field = run_end_encoded + .as_ref() + .values_field + .as_deref() + .required("values_field")?; + DataType::RunEndEncoded(run_ends_field.into(), value_field.into()) + } }) } } @@ -578,6 +591,32 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Self::Dictionary(Box::new(index_type), Box::new(value)) } + Value::RunEndEncodedValue(v) => { + let run_ends_field: Field = v + .run_ends_field + .as_ref() + .ok_or_else(|| Error::required("run_ends_field"))? + .try_into()?; + + let values_field: Field = v + .values_field + .as_ref() + .ok_or_else(|| Error::required("values_field"))? + .try_into()?; + + let value: Self = v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? + .as_ref() + .try_into()?; + + Self::RunEndEncoded( + run_ends_field.into(), + values_field.into(), + Box::new(value), + ) + } Value::BinaryValue(v) => Self::Binary(Some(v.clone())), Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index ef0eae198..80dff4410 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -276,6 +276,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Map(v) => { struct_ser.serialize_field("MAP", v)?; } + arrow_type::ArrowTypeEnum::RunEndEncoded(v) => { + struct_ser.serialize_field("RUNENDENCODED", v)?; + } } } struct_ser.end() @@ -333,6 +336,8 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION", "DICTIONARY", "MAP", + "RUN_END_ENCODED", + "RUNENDENCODED", ]; #[allow(clippy::enum_variant_names)] @@ -375,6 
+380,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Union, Dictionary, Map, + RunEndEncoded, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> @@ -434,6 +440,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "UNION" => Ok(GeneratedField::Union), "DICTIONARY" => Ok(GeneratedField::Dictionary), "MAP" => Ok(GeneratedField::Map), + "RUNENDENCODED" | "RUN_END_ENCODED" => Ok(GeneratedField::RunEndEncoded), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -715,6 +722,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("MAP")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) +; + } + GeneratedField::RunEndEncoded => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("RUNENDENCODED")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::RunEndEncoded) ; } } @@ -6600,6 +6614,116 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { deserializer.deserialize_struct("datafusion_common.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RunEndEncoded { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.RunEndEncoded", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RunEndEncoded { + 
#[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E> + where + E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RunEndEncoded; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.RunEndEncoded") + } + + fn visit_map<V>(self, mut map_: V) -> std::result::Result<RunEndEncoded, V::Error> + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + } + } + Ok(RunEndEncoded { + run_ends_field: run_ends_field__, + values_field: values_field__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.RunEndEncoded", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarDictionaryValue { #[allow(deprecated)] fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> @@ -7093,6 +7217,133 @@ impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { deserializer.deserialize_struct("datafusion_common.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.run_ends_field.is_some() { + len += 1; + } + if self.values_field.is_some() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarRunEndEncodedValue", len)?; + if let Some(v) = self.run_ends_field.as_ref() { + struct_ser.serialize_field("runEndsField", v)?; + } + if let Some(v) = self.values_field.as_ref() { + struct_ser.serialize_field("valuesField", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarRunEndEncodedValue { + #[allow(deprecated)] + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ 
+ "run_ends_field", + "runEndsField", + "values_field", + "valuesField", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + RunEndsField, + ValuesField, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E> + where + E: serde::de::Error, + { + match value { + "runEndsField" | "run_ends_field" => Ok(GeneratedField::RunEndsField), + "valuesField" | "values_field" => Ok(GeneratedField::ValuesField), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarRunEndEncodedValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarRunEndEncodedValue") + } + + fn visit_map<V>(self, mut map_: V) -> std::result::Result<ScalarRunEndEncodedValue, V::Error> + where + V: serde::de::MapAccess<'de>, + { + let mut run_ends_field__ = None; + let mut values_field__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::RunEndsField => { + if run_ends_field__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndsField")); + } + run_ends_field__ = map_.next_value()?; + } + GeneratedField::ValuesField => { + if values_field__.is_some() { + return Err(serde::de::Error::duplicate_field("valuesField")); + } + values_field__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + } + } + Ok(ScalarRunEndEncodedValue { + run_ends_field: run_ends_field__, + values_field: values_field__, + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarRunEndEncodedValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarTime32Value { #[allow(deprecated)] fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> @@ -7635,6 +7886,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::UnionValue(v) => { struct_ser.serialize_field("unionValue", v)?; } + scalar_value::Value::RunEndEncodedValue(v) => { + struct_ser.serialize_field("runEndEncodedValue", v)?; + } } } struct_ser.end() @@ -7731,6 +7985,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeBinaryValue", "union_value", "unionValue", + "run_end_encoded_value", + "runEndEncodedValue", ]; #[allow(clippy::enum_variant_names)] @@ -7777,6 +8033,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { IntervalMonthDayNano, FixedSizeBinaryValue, UnionValue, + RunEndEncodedValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error> @@ -7840,6 +8097,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), "unionValue" | 
"union_value" => Ok(GeneratedField::UnionValue), + "runEndEncodedValue" | "run_end_encoded_value" => Ok(GeneratedField::RunEndEncodedValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8130,6 +8388,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("unionValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) +; + } + GeneratedField::RunEndEncodedValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("runEndEncodedValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::RunEndEncodedValue) ; } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 16601dcf4..30ce5a773 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: ::core::option::Option<::prost::alloc::boxed::Box<Field>>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box<Field>>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec<Field>, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option<Field>, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option<Field>, + #[prost(message, optional, boxed, tag = "3")] + pub 
value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option<scalar_value::Value>, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box<super::UnionValue>), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box<super::ScalarRunEndEncodedValue>), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box<super::Dictionary>), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box<super::Map>), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box<super::RunEndEncoded>), } } /// Useful for representing an empty enum variant in rust diff --git a/datafusion/proto-common/src/to_proto/mod.rs 
b/datafusion/proto-common/src/to_proto/mod.rs index fee365648..db405b29a 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -180,7 +180,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, + union_types: convert_arc_fields_to_proto_fields( + fields.iter().map(|(_, item)| item), + )?, union_mode: union_mode.into(), type_ids: fields.iter().map(|(x, _)| x as i32).collect(), }) @@ -191,37 +193,44 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { value: Some(Box::new(value_type.as_ref().try_into()?)), })) } - DataType::Decimal32(precision, scale) => Self::Decimal32(protobuf::Decimal32Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal64(precision, scale) => Self::Decimal64(protobuf::Decimal64Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal128(precision, scale) => Self::Decimal128(protobuf::Decimal128Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Map(field, sorted) => { - Self::Map(Box::new( - protobuf::Map { - field_type: Some(Box::new(field.as_ref().try_into()?)), - keys_sorted: *sorted, - } - )) - } - DataType::RunEndEncoded(_, _) => { - return Err(Error::General( - "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() - )) + DataType::Decimal32(precision, scale) => { + Self::Decimal32(protobuf::Decimal32Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal64(precision, scale) => { + Self::Decimal64(protobuf::Decimal64Type { + precision: *precision as u32, + scale: *scale as i32, 
+ }) + } + DataType::Decimal128(precision, scale) => { + Self::Decimal128(protobuf::Decimal128Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Decimal256(precision, scale) => { + Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }) + } + DataType::Map(field, sorted) => Self::Map(Box::new(protobuf::Map { + field_type: Some(Box::new(field.as_ref().try_into()?)), + keys_sorted: *sorted, + })), + DataType::RunEndEncoded(run_ends_field, values_field) => { + Self::RunEndEncoded(Box::new(protobuf::RunEndEncoded { + run_ends_field: Some(Box::new(run_ends_field.as_ref().try_into()?)), + values_field: Some(Box::new(values_field.as_ref().try_into()?)), + })) } DataType::ListView(_) | DataType::LargeListView(_) => { - return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + return Err(Error::General(format!( + "Proto serialization error: {val} not yet supported" + ))); } }; @@ -680,6 +689,18 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ))), }) } + + ScalarValue::RunEndEncoded(run_ends_field, values_field, val) => { + Ok(protobuf::ScalarValue { + value: Some(Value::RunEndEncodedValue(Box::new( + protobuf::ScalarRunEndEncodedValue { + run_ends_field: Some(run_ends_field.as_ref().try_into()?), + values_field: Some(values_field.as_ref().try_into()?), + value: Some(Box::new(val.as_ref().try_into()?)), + }, + ))), + }) + } } } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 16601dcf4..30ce5a773 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -176,6 +176,13 @@ pub struct Map { pub keys_sorted: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RunEndEncoded { + #[prost(message, optional, boxed, tag = "1")] + pub run_ends_field: 
::core::option::Option<::prost::alloc::boxed::Box<Field>>, + #[prost(message, optional, boxed, tag = "2")] + pub values_field: ::core::option::Option<::prost::alloc::boxed::Box<Field>>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] pub union_types: ::prost::alloc::vec::Vec<Field>, @@ -264,6 +271,15 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarRunEndEncodedValue { + #[prost(message, optional, tag = "1")] + pub run_ends_field: ::core::option::Option<Field>, + #[prost(message, optional, tag = "2")] + pub values_field: ::core::option::Option<Field>, + #[prost(message, optional, boxed, tag = "3")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>, +} #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] @@ -311,7 +327,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42, 45" )] pub value: ::core::option::Option<scalar_value::Value>, } @@ -406,6 +422,8 @@ pub mod scalar_value { FixedSizeBinaryValue(super::ScalarFixedSizeBinary), #[prost(message, tag = "42")] UnionValue(::prost::alloc::boxed::Box<super::UnionValue>), + #[prost(message, tag = "45")] + RunEndEncodedValue(::prost::alloc::boxed::Box<super::ScalarRunEndEncodedValue>), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -449,7 +467,7 @@ pub struct Decimal256 { pub struct ArrowType { 
#[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33, 42" )] pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>, } @@ -538,6 +556,8 @@ pub mod arrow_type { Dictionary(::prost::alloc::boxed::Box<super::Dictionary>), #[prost(message, tag = "33")] Map(::prost::alloc::boxed::Box<super::Map>), + #[prost(message, tag = "42")] + RunEndEncoded(::prost::alloc::boxed::Box<super::RunEndEncoded>), } } /// Useful for representing an empty enum variant in rust diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e5c218e5e..f622cb52a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1534,6 +1534,16 @@ fn round_trip_scalar_values_and_data_types() { Box::new(DataType::Int32), Box::new(ScalarValue::Utf8(None)), ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::from("foo")), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::Utf8(None)), + ), ScalarValue::Binary(Some(b"bar".to_vec())), ScalarValue::Binary(None), ScalarValue::LargeBinary(Some(b"bar".to_vec())), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ac7b46792..5f6612830 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1441,6 +1441,7 @@ impl Unparser<'_> { ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), 
ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(_k, v) => self.scalar_to_sql(v), + ScalarValue::RunEndEncoded(_, _, v) => self.scalar_to_sql(v), } } @@ -1790,6 +1791,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::RunEndEncoded(_, val) => { + self.arrow_dtype_to_ast_dtype(val.data_type()) + } DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1811,9 +1815,6 @@ impl Unparser<'_> { DataType::Map(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::RunEndEncoded(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type}") - } } } } @@ -2316,6 +2317,17 @@ mod tests { ), "'foo'", ), + ( + Expr::Literal( + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), + "'foo'", + ), ( Expr::Literal( ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< @@ -3185,6 +3197,22 @@ mod tests { Ok(()) } + #[test] + fn test_run_end_encoded_to_sql() -> Result<()> { + let dialect = CustomDialectBuilder::new().build(); + + let unparser = Unparser::new(&dialect); + + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + ))?; + + assert_eq!(ast_dtype, ast::DataType::Varchar(None)); + + Ok(()) + } + #[test] fn test_utf8_view_to_sql() -> Result<()> { let dialect = CustomDialectBuilder::new() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
