This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new dfdda7cb04 fix: Ignore nullability of list elements when consuming
Substrait (#10874)
dfdda7cb04 is described below
commit dfdda7cb04f7f9b640da4f297ce1a16b08f3bf7b
Author: Arttu <[email protected]>
AuthorDate: Wed Jun 12 18:40:40 2024 +0200
fix: Ignore nullability of list elements when consuming Substrait (#10874)
* Ignore nullability of list elements when consuming Substrait
DataFusion (= Arrow) is quite strict about nullability: when using e.g.
LogicalPlan::Values, the given schema must match the given literals
exactly, including nullability.
Ensuring that match is non-trivial when the schema and the literals are converted separately.
The existing implementation for from_substrait_literal already creates
lists that are always nullable
(see ScalarValue::new_list => array_into_list_array).
This reverts part of https://github.com/apache/datafusion/pull/10640 to
align from_substrait_type with that behavior.
This is the error I was hitting:
```
ArrowError(InvalidArgumentError("column types must match schema types,
expected
List(Field { name: \"item\", data_type: Int32, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }) but found
List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }) at column index 0"), None)
```
* use `Field::new_list_field` in `array_into_(large_)list_array`
just for consistency, to reduce the places where "item" is written out
* add a test for non-nullable lists
---
datafusion/common/src/utils/mod.rs | 14 ++---
datafusion/substrait/src/logical_plan/consumer.rs | 4 +-
datafusion/substrait/src/logical_plan/producer.rs | 14 ++---
datafusion/substrait/tests/cases/logical_plans.rs | 32 ++++++++--
.../testdata/non_nullable_lists.substrait.json | 71 ++++++++++++++++++++++
5 files changed, 114 insertions(+), 21 deletions(-)
diff --git a/datafusion/common/src/utils/mod.rs
b/datafusion/common/src/utils/mod.rs
index ae444c2cb2..a0e4d1a76c 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -354,7 +354,7 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
- Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+ Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
offsets,
arr,
None,
@@ -366,7 +366,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
LargeListArray::new(
- Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+ Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
offsets,
arr,
None,
@@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array(
) -> FixedSizeListArray {
let list_size = list_size as i32;
FixedSizeListArray::new(
- Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+ Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
list_size,
arr,
None,
@@ -420,7 +420,7 @@ pub fn arrays_into_list_array(
let data_type = arr[0].data_type().to_owned();
let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
Ok(ListArray::new(
- Arc::new(Field::new("item", data_type, true)),
+ Arc::new(Field::new_list_field(data_type, true)),
OffsetBuffer::from_lengths(lens),
arrow::compute::concat(values.as_slice())?,
None,
@@ -435,7 +435,7 @@ pub fn arrays_into_list_array(
/// use datafusion_common::utils::base_type;
/// use std::sync::Arc;
///
-/// let data_type = DataType::List(Arc::new(Field::new("item",
DataType::Int32, true)));
+/// let data_type =
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// assert_eq!(base_type(&data_type), DataType::Int32);
///
/// let data_type = DataType::Int32;
@@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType {
/// use datafusion_common::utils::coerced_type_with_base_type_only;
/// use std::sync::Arc;
///
-/// let data_type = DataType::List(Arc::new(Field::new("item",
DataType::Int32, true)));
+/// let data_type =
DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// let base_type = DataType::Float64;
/// let coerced_type = coerced_type_with_base_type_only(&data_type,
&base_type);
-/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item",
DataType::Float64, true))));
+/// assert_eq!(coerced_type,
DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
pub fn coerced_type_with_base_type_only(
data_type: &DataType,
base_type: &DataType,
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 648a281832..3f9a895d95 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -1395,7 +1395,9 @@ fn from_substrait_type(
})?;
let field = Arc::new(Field::new_list_field(
from_substrait_type(inner_type, dfs_names, name_idx)?,
- is_substrait_type_nullable(inner_type)?,
+ // We ignore Substrait's nullability here to match
to_substrait_literal
+ // which always creates nullable lists
+ true,
));
match list.type_variation_reference {
DEFAULT_CONTAINER_TYPE_VARIATION_REF =>
Ok(DataType::List(field)),
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 88dc894ecc..c0469d3331 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -2309,14 +2309,12 @@ mod test {
round_trip_type(DataType::Decimal128(10, 2))?;
round_trip_type(DataType::Decimal256(30, 2))?;
- for nullable in [true, false] {
- round_trip_type(DataType::List(
- Field::new_list_field(DataType::Int32, nullable).into(),
- ))?;
- round_trip_type(DataType::LargeList(
- Field::new_list_field(DataType::Int32, nullable).into(),
- ))?;
- }
+ round_trip_type(DataType::List(
+ Field::new_list_field(DataType::Int32, true).into(),
+ ))?;
+ round_trip_type(DataType::LargeList(
+ Field::new_list_field(DataType::Int32, true).into(),
+ ))?;
round_trip_type(DataType::Struct(
vec![
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs
b/datafusion/substrait/tests/cases/logical_plans.rs
index 994a932c30..94572e098b 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -20,6 +20,7 @@
#[cfg(test)]
mod tests {
use datafusion::common::Result;
+ use datafusion::dataframe::DataFrame;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
@@ -38,11 +39,7 @@ mod tests {
// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c
"create table data (d boolean)"
- let path = "tests/testdata/select_not_bool.substrait.json";
- let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
- File::open(path).expect("file not found"),
- ))
- .expect("failed to parse json");
+ let proto = read_json("tests/testdata/select_not_bool.substrait.json");
let plan = from_substrait_plan(&ctx, &proto).await?;
@@ -54,6 +51,31 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn non_nullable_lists() -> Result<()> {
+ // DataFusion's Substrait consumer treats all lists as nullable, even
if the Substrait plan specifies them as non-nullable.
+ // That's because implementing the non-nullability consistently is
non-trivial.
+ // This test confirms that reading a plan with non-nullable lists
works as expected.
+ let ctx = create_context().await?;
+ let proto =
read_json("tests/testdata/non_nullable_lists.substrait.json");
+
+ let plan = from_substrait_plan(&ctx, &proto).await?;
+
+ assert_eq!(format!("{:?}", &plan), "Values: (List([1, 2]))");
+
+ // Need to trigger execution to ensure that Arrow has validated the
plan
+ DataFrame::new(ctx.state(), plan).show().await?;
+
+ Ok(())
+ }
+
+ fn read_json(path: &str) -> Plan {
+ serde_json::from_reader::<_, Plan>(BufReader::new(
+ File::open(path).expect("file not found"),
+ ))
+ .expect("failed to parse json")
+ }
+
async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("DATA", "tests/testdata/data.csv",
CsvReadOptions::new())
diff --git
a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json
b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json
new file mode 100644
index 0000000000..e1c5574f8b
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json
@@ -0,0 +1,71 @@
+{
+ "extensionUris": [],
+ "extensions": [],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "read": {
+ "common": {
+ "direct": {
+ }
+ },
+ "baseSchema": {
+ "names": [
+ "col"
+ ],
+ "struct": {
+ "types": [
+ {
+ "list": {
+ "type": {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ }
+ ],
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "virtualTable": {
+ "values": [
+ {
+ "fields": [
+ {
+ "list": {
+ "values": [
+ {
+ "i32": 1,
+ "nullable": false,
+ "typeVariationReference": 0
+ },
+ {
+ "i32": 2,
+ "nullable": false,
+ "typeVariationReference": 0
+ }
+ ]
+ },
+ "nullable": false,
+ "typeVariationReference": 0
+ }
+ ]
+ }
+ ]
+ }
+ }
+ },
+ "names": [
+ "col"
+ ]
+ }
+ }
+ ],
+ "expectedTypeUrls": []
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]