This is an automated email from the ASF dual-hosted git repository. blaginin pushed a commit to branch annarose/dict-coercion in repository https://gitbox.apache.org/repos/asf/datafusion-sandbox.git
commit 8ba3d2617b49afa7473e0b583acbb2bf93928523 Author: Kumar Ujjawal <[email protected]> AuthorDate: Thu Feb 5 03:25:51 2026 +0530 fix: regression of `dict_id` in physical plan proto (#20063) ## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20011. ## Rationale for this change - `dict_id` is intentionally not preserved in protobuf (it’s deprecated in Arrow schema metadata), but Arrow IPC still requires dict IDs for dictionary encoding/decoding. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Fix protobuf serde for nested ScalarValue (list/struct/map) containing dictionary arrays by using Arrow IPC’s dictionary handling correctly. - Seed DictionaryTracker by encoding the schema before encoding the nested scalar batch. - On decode, reconstruct an IPC schema from the protobuf schema and use arrow_ipc::reader::read_dictionary to build dict_by_id before reading the record batch. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes, added a test for this. <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? 
No <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Andrew Lamb <[email protected]> --- datafusion/proto-common/src/from_proto/mod.rs | 79 +++++++++++++++------- datafusion/proto-common/src/to_proto/mod.rs | 13 +++- .../proto/tests/cases/roundtrip_physical_plan.rs | 19 ++++++ 3 files changed, 85 insertions(+), 26 deletions(-) diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index af427ef5a..967bda627 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -28,7 +28,12 @@ use arrow::datatypes::{ DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, i256, }; -use arrow::ipc::{reader::read_record_batch, root_as_message}; +use arrow::ipc::{ + convert::fb_to_schema, + reader::{read_dictionary, read_record_batch}, + root_as_message, + writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}, +}; use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, @@ -397,7 +402,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - // ScalarValue::List is serialized using arrow IPC format + // Nested ScalarValue types are serialized using arrow IPC format Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) @@ -414,55 +419,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { schema_ref.try_into()? 
} else { return Err(Error::General( - "Invalid schema while deserializing ScalarValue::List" + "Invalid schema while deserializing nested ScalarValue" .to_string(), )); }; + // IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf + // `Schema` doesn't preserve those IDs. Reconstruct them deterministically by + // round-tripping the schema through IPC. + let schema: Schema = { + let ipc_gen = IpcDataGenerator {}; + let write_options = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker( + &schema, + &mut dict_tracker, + &write_options, + ); + let message = + root_as_message(encoded_schema.ipc_message.as_slice()).map_err( + |e| { + Error::General(format!( + "Error IPC schema message while deserializing nested ScalarValue: {e}" + )) + }, + )?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing nested ScalarValue schema" + .to_string(), + ) + })?; + fb_to_schema(ipc_schema) + }; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List: {e}" + "Error IPC message while deserializing nested ScalarValue: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let ipc_batch = message.header_as_record_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List" + "Unexpected message type deserializing nested ScalarValue" .to_string(), ) })?; - let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new(); + for protobuf::scalar_nested_value::Dictionary { + ipc_message, + arrow_data, + } in dictionaries + { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - 
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + "Error IPC message while deserializing nested ScalarValue dictionary message: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List dictionary message" + "Unexpected message type deserializing nested ScalarValue dictionary message" .to_string(), ) })?; - - let id = dict_batch.id(); - - let record_batch = read_record_batch( + read_dictionary( &buffer, - dict_batch.data().unwrap(), - Arc::new(schema.clone()), - &Default::default(), - None, + dict_batch, + &schema, + &mut dict_by_id, &message.version(), - )?; - - let values: ArrayRef = Arc::clone(record_batch.column(0)); - - Ok((id, values)) - }).collect::<datafusion_common::Result<HashMap<_, _>>>()?; + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?; + } let record_batch = read_record_batch( &buffer, @@ -473,7 +506,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { &message.version(), ) .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + .map_err(|e| e.context("Decoding nested ScalarValue value"))?; let arr = record_batch.column(0); match value { Value::ListValue(_) => { diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index db405b29a..01b671e37 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -1031,7 +1031,7 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using +// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using // Arrow IPC messages as a 
single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -1039,13 +1039,20 @@ fn encode_scalar_nested_value( ) -> Result<protobuf::ScalarValue, Error> { let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { Error::General(format!( - "Error creating temporary batch while encoding ScalarValue::List: {e}" + "Error creating temporary batch while encoding nested ScalarValue: {e}" )) })?; let ipc_gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let write_options = IpcWriteOptions::default(); + // The IPC writer requires pre-allocated dictionary IDs (normally assigned when + // serializing the schema). Populate `dict_tracker` by encoding the schema first. + ipc_gen.schema_to_bytes_with_dictionary_tracker( + batch.schema().as_ref(), + &mut dict_tracker, + &write_options, + ); let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_message) = ipc_gen .encode( @@ -1055,7 +1062,7 @@ fn encode_scalar_nested_value( &mut compression_context, ) .map_err(|e| { - Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + Error::General(format!("Error encoding nested ScalarValue as IPC: {e}")) })?; let schema: protobuf::Schema = batch.schema().try_into()?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 0a26025a3..fd4de8114 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2566,6 +2566,25 @@ fn custom_proto_converter_intercepts() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = 
Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +} + /// Test that expression deduplication works during deserialization. /// When the same expression Arc is serialized multiple times, it should be /// deduplicated on deserialization (sharing the same Arc). --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
