This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2c5e237ab4 feat: roundtrip FixedSizeList Scalar to protobuf (#8239)
2c5e237ab4 is described below
commit 2c5e237ab43cb6ba48c4f892120a2a7558466e76
Author: Will Jones <[email protected]>
AuthorDate: Fri Nov 17 06:46:21 2023 -0800
feat: roundtrip FixedSizeList Scalar to protobuf (#8239)
---
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 14 +++++++++++
datafusion/proto/src/generated/prost.rs | 4 +++-
datafusion/proto/src/logical_plan/from_proto.rs | 8 +++++--
datafusion/proto/src/logical_plan/to_proto.rs | 28 ++++++++++++----------
.../proto/tests/cases/roundtrip_logical_plan.rs | 14 ++++++++---
6 files changed, 51 insertions(+), 18 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index a5c3d3b603..8cab62acde 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -983,6 +983,7 @@ message ScalarValue{
int32 date_32_value = 14;
ScalarTime32Value time32_value = 15;
ScalarListValue list_value = 17;
+ ScalarListValue fixed_size_list_value = 18;
Decimal128 decimal128_value = 20;
Decimal256 decimal256_value = 39;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 3faacca18c..c50571dca0 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22042,6 +22042,9 @@ impl serde::Serialize for ScalarValue {
scalar_value::Value::ListValue(v) => {
struct_ser.serialize_field("listValue", v)?;
}
+ scalar_value::Value::FixedSizeListValue(v) => {
+ struct_ser.serialize_field("fixedSizeListValue", v)?;
+ }
scalar_value::Value::Decimal128Value(v) => {
struct_ser.serialize_field("decimal128Value", v)?;
}
@@ -22147,6 +22150,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"time32Value",
"list_value",
"listValue",
+ "fixed_size_list_value",
+ "fixedSizeListValue",
"decimal128_value",
"decimal128Value",
"decimal256_value",
@@ -22202,6 +22207,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
Date32Value,
Time32Value,
ListValue,
+ FixedSizeListValue,
Decimal128Value,
Decimal256Value,
Date64Value,
@@ -22257,6 +22263,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"date32Value" | "date_32_value" =>
Ok(GeneratedField::Date32Value),
"time32Value" | "time32_value" =>
Ok(GeneratedField::Time32Value),
"listValue" | "list_value" =>
Ok(GeneratedField::ListValue),
+ "fixedSizeListValue" | "fixed_size_list_value" =>
Ok(GeneratedField::FixedSizeListValue),
"decimal128Value" | "decimal128_value" =>
Ok(GeneratedField::Decimal128Value),
"decimal256Value" | "decimal256_value" =>
Ok(GeneratedField::Decimal256Value),
"date64Value" | "date_64_value" =>
Ok(GeneratedField::Date64Value),
@@ -22399,6 +22406,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
return
Err(serde::de::Error::duplicate_field("listValue"));
}
value__ =
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue)
+;
+ }
+ GeneratedField::FixedSizeListValue => {
+ if value__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("fixedSizeListValue"));
+ }
+ value__ =
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue)
;
}
GeneratedField::Decimal128Value => {
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 2555a31f6f..213be1c395 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1200,7 +1200,7 @@ pub struct ScalarFixedSizeBinary {
pub struct ScalarValue {
#[prost(
oneof = "scalar_value::Value",
- tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20,
39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
+ tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18,
20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
)]
pub value: ::core::option::Option<scalar_value::Value>,
}
@@ -1246,6 +1246,8 @@ pub mod scalar_value {
Time32Value(super::ScalarTime32Value),
#[prost(message, tag = "17")]
ListValue(super::ScalarListValue),
+ #[prost(message, tag = "18")]
+ FixedSizeListValue(super::ScalarListValue),
#[prost(message, tag = "20")]
Decimal128Value(super::Decimal128),
#[prost(message, tag = "39")]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index f14da70485..a34b1b7beb 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -658,7 +658,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::Float64Value(v) => Self::Float64(Some(*v)),
Value::Date32Value(v) => Self::Date32(Some(*v)),
// ScalarValue::List is serialized using arrow IPC format
- Value::ListValue(scalar_list) => {
+ Value::ListValue(scalar_list) |
Value::FixedSizeListValue(scalar_list) => {
let protobuf::ScalarListValue {
ipc_message,
arrow_data,
@@ -699,7 +699,11 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.map_err(DataFusionError::ArrowError)
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
let arr = record_batch.column(0);
- Self::List(arr.to_owned())
+ match value {
+ Value::ListValue(_) => Self::List(arr.to_owned()),
+ Value::FixedSizeListValue(_) =>
Self::FixedSizeList(arr.to_owned()),
+ _ => unreachable!(),
+ }
}
Value::NullValue(v) => {
let null_type: DataType = v.try_into()?;
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index de81a1f4ca..433c99403e 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1134,13 +1134,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
Value::LargeUtf8Value(s.to_owned())
})
}
- ScalarValue::FixedSizeList(..) => Err(Error::General(
- "Proto serialization error: ScalarValue::Fixedsizelist not
supported"
- .to_string(),
- )),
- // ScalarValue::List is serialized using Arrow IPC messages.
- // as a single column RecordBatch
- ScalarValue::List(arr) => {
+ // ScalarValue::List and ScalarValue::FixedSizeList are serialized
using
+ // Arrow IPC messages as a single column RecordBatch
+ ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => {
// Wrap in a "field_name" column
let batch = RecordBatch::try_from_iter(vec![(
"field_name",
@@ -1168,11 +1164,19 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
schema: Some(schema),
};
- Ok(protobuf::ScalarValue {
- value: Some(protobuf::scalar_value::Value::ListValue(
- scalar_list_value,
- )),
- })
+ match val {
+ ScalarValue::List(_) => Ok(protobuf::ScalarValue {
+ value: Some(protobuf::scalar_value::Value::ListValue(
+ scalar_list_value,
+ )),
+ }),
+ ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
+ value:
Some(protobuf::scalar_value::Value::FixedSizeListValue(
+ scalar_list_value,
+ )),
+ }),
+ _ => unreachable!(),
+ }
}
ScalarValue::Date32(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s|
Value::Date32Value(*s))
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 75af9d2e0a..2d56967ecf 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -19,10 +19,10 @@ use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, FixedSizeListArray};
use arrow::datatypes::{
- DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit,
- Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
+ DataType, Field, Fields, Int32Type, IntervalDayTimeType,
IntervalMonthDayNanoType,
+ IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};
use prost::Message;
@@ -690,6 +690,14 @@ fn round_trip_scalar_values() {
],
&DataType::List(new_arc_field("item", DataType::Float32, true)),
)),
+
ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::<
+ Int32Type,
+ _,
+ _,
+ >(
+ vec![Some(vec![Some(1), Some(2), Some(3)])],
+ 3,
+ ))),
ScalarValue::Dictionary(
Box::new(DataType::Int32),
Box::new(ScalarValue::Utf8(Some("foo".into()))),