This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 22a7d74 Indexed field access for List (#1006)
22a7d74 is described below
commit 22a7d74e9dece44adb790b5ea3fca27d26a2fe51
Author: Guillaume Balaine <[email protected]>
AuthorDate: Fri Oct 29 20:00:17 2021 +0200
Indexed field access for List (#1006)
* enable GetIndexedField for Array and Dictionary
* fix GetIndexedField which should index slices not values
* Compat with latest sqlparser
* Add two tests for indexed_field access, level one and level two nesting
* fix compilation issues
* try fixing dictionary lookup
* address clippy warnings
* fix test
* Revert dictionary lookup for indexed fields
* Reject negative ints when accessing list values in get indexed field
* Fix doc in get_indexed_field
* use GetIndexedFieldExpr directly
* return the data type in unavailable field indexation error message
* Add unit tests for the physical plan of get_indexed_field
* Fix missing clause for const evaluator
---
datafusion/src/field_util.rs | 50 +++++
datafusion/src/lib.rs | 1 +
datafusion/src/logical_plan/expr.rs | 29 +++
.../src/optimizer/common_subexpr_eliminate.rs | 4 +
datafusion/src/optimizer/utils.rs | 7 +
.../physical_plan/expressions/get_indexed_field.rs | 225 +++++++++++++++++++++
datafusion/src/physical_plan/expressions/mod.rs | 2 +
datafusion/src/physical_plan/planner.rs | 20 +-
datafusion/src/sql/planner.rs | 37 ++++
datafusion/src/sql/utils.rs | 4 +
datafusion/tests/sql.rs | 80 ++++++++
11 files changed, 458 insertions(+), 1 deletion(-)
diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs
new file mode 100644
index 0000000..9d5face
--- /dev/null
+++ b/datafusion/src/field_util.rs
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utility functions for complex field access
+
+use arrow::datatypes::{DataType, Field};
+
+use crate::error::{DataFusionError, Result};
+use crate::scalar::ScalarValue;
+
+/// Returns the [`Field`] obtained by indexing a value of type `data_type`
+/// with `key`.
+///
+/// Only a [`DataType::List`] indexed by a non-negative
+/// [`ScalarValue::Int64`] is supported. The returned field is named after the
+/// index and carries the element type of the list.
+///
+/// # Errors
+/// Returns [`DataFusionError::Plan`] if
+/// * `data_type` is not a `List`,
+/// * `key` is not a non-null `Int64`, or
+/// * `key` is negative.
+pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Field> {
+    match (data_type, key) {
+        (DataType::List(lt), ScalarValue::Int64(Some(i))) => {
+            if *i < 0 {
+                // Zero is a valid index, so the requirement is "non-negative".
+                Err(DataFusionError::Plan(format!(
+                    "List based indexed access requires a non-negative int, was {0}",
+                    i
+                )))
+            } else {
+                // The accessed element may be null (list items can be null),
+                // so the resulting field must not be declared non-nullable.
+                Ok(Field::new(&i.to_string(), lt.data_type().clone(), true))
+            }
+        }
+        (DataType::List(_), _) => Err(DataFusionError::Plan(
+            "Only ints are valid as an indexed field in a list".to_string(),
+        )),
+        _ => Err(DataFusionError::Plan(
+            "The expression to get an indexed field is only valid for `List` types"
+                .to_string(),
+        )),
+    }
+}
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index a4a5a88..fa140ce 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -231,6 +231,7 @@ pub mod variable;
pub use arrow;
pub use parquet;
+pub(crate) mod field_util;
#[cfg(test)]
pub mod test;
pub mod test_util;
diff --git a/datafusion/src/logical_plan/expr.rs
b/datafusion/src/logical_plan/expr.rs
index 011068d..499a8c7 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -20,6 +20,7 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
+use crate::field_util::get_indexed_field;
use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{
@@ -245,6 +246,13 @@ pub enum Expr {
IsNull(Box<Expr>),
/// arithmetic negation of an expression, the operand must be of a signed
numeric data type
Negative(Box<Expr>),
+ /// Returns the field of a [`ListArray`] by key
+ GetIndexedField {
+ /// the expression to take the field from
+ expr: Box<Expr>,
+ /// The name of the field to take
+ key: ScalarValue,
+ },
/// Whether an expression is between a given range.
Between {
/// The value to compare
@@ -433,6 +441,11 @@ impl Expr {
Expr::Wildcard => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query
plan".to_owned(),
)),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(schema)?;
+
+ get_indexed_field(&data_type, key).map(|x|
x.data_type().clone())
+ }
}
}
@@ -488,6 +501,10 @@ impl Expr {
Expr::Wildcard => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query
plan".to_owned(),
)),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(input_schema)?;
+ get_indexed_field(&data_type, key).map(|x| x.is_nullable())
+ }
}
}
@@ -763,6 +780,7 @@ impl Expr {
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
}
Expr::Wildcard => Ok(visitor),
+ Expr::GetIndexedField { ref expr, .. } => expr.accept(visitor),
}?;
visitor.post_visit(self)
@@ -923,6 +941,10 @@ impl Expr {
negated,
},
Expr::Wildcard => Expr::Wildcard,
+ Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
+ expr: rewrite_boxed(expr, rewriter)?,
+ key,
+ },
};
// now rewrite this expression itself
@@ -1799,6 +1821,9 @@ impl fmt::Debug for Expr {
}
}
Expr::Wildcard => write!(f, "*"),
+ Expr::GetIndexedField { ref expr, key } => {
+ write!(f, "({:?})[{}]", expr, key)
+ }
}
}
}
@@ -1879,6 +1904,10 @@ fn create_name(e: &Expr, input_schema: &DFSchema) ->
Result<String> {
let expr = create_name(expr, input_schema)?;
Ok(format!("{} IS NOT NULL", expr))
}
+ Expr::GetIndexedField { expr, key } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("{}[{}]", expr, key))
+ }
Expr::ScalarFunction { fun, args, .. } => {
create_function_name(&fun.to_string(), false, args, input_schema)
}
diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs
b/datafusion/src/optimizer/common_subexpr_eliminate.rs
index 8d87b22..ea60286 100644
--- a/datafusion/src/optimizer/common_subexpr_eliminate.rs
+++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs
@@ -442,6 +442,10 @@ impl ExprIdentifierVisitor<'_> {
Expr::Wildcard => {
desc.push_str("Wildcard-");
}
+ Expr::GetIndexedField { key, .. } => {
+ desc.push_str("GetIndexedField-");
+ desc.push_str(&key.to_string());
+ }
}
desc
diff --git a/datafusion/src/optimizer/utils.rs
b/datafusion/src/optimizer/utils.rs
index 00ea31e..f36330e 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -85,6 +85,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
Expr::AggregateUDF { .. } => {}
Expr::InList { .. } => {}
Expr::Wildcard => {}
+ Expr::GetIndexedField { .. } => {}
}
Ok(Recursion::Continue(self))
}
@@ -337,6 +338,7 @@ pub fn expr_sub_expressions(expr: &Expr) ->
Result<Vec<Expr>> {
Expr::Wildcard { .. } => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query
plan".to_owned(),
)),
+ Expr::GetIndexedField { expr, .. } =>
Ok(vec![expr.as_ref().to_owned()]),
}
}
@@ -496,6 +498,10 @@ pub fn rewrite_expression(expr: &Expr, expressions:
&[Expr]) -> Result<Expr> {
Expr::Wildcard { .. } => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query
plan".to_owned(),
)),
+ Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
+ expr: Box::new(expressions[0].clone()),
+ key: key.clone(),
+ }),
}
}
@@ -650,6 +656,7 @@ impl ConstEvaluator {
Expr::Cast { .. } => true,
Expr::TryCast { .. } => true,
Expr::InList { .. } => true,
+ Expr::GetIndexedField { .. } => true,
}
}
diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs
b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
new file mode 100644
index 0000000..8a9191e
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! get field of a `ListArray`
+
+use std::convert::TryInto;
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+ datatypes::{DataType, Schema},
+ record_batch::RecordBatch,
+};
+
+use crate::arrow::array::Array;
+use crate::arrow::compute::concat;
+use crate::scalar::ScalarValue;
+use crate::{
+ error::DataFusionError,
+ error::Result,
+ field_util::get_indexed_field as get_data_type_field,
+ physical_plan::{ColumnarValue, PhysicalExpr},
+};
+use arrow::array::ListArray;
+use std::fmt::Debug;
+
+/// Physical expression that indexes into the value produced by `arg` using
+/// `key` (for example `some_list[0]`).
+#[derive(Debug)]
+pub struct GetIndexedFieldExpr {
+ /// the expression whose result is indexed
+ arg: Arc<dyn PhysicalExpr>,
+ /// the index to look up in the result of `arg`
+ key: ScalarValue,
+}
+
+impl GetIndexedFieldExpr {
+    /// Build an expression that looks up `key` in the value produced by `arg`
+    pub fn new(arg: Arc<dyn PhysicalExpr>, key: ScalarValue) -> Self {
+        GetIndexedFieldExpr { arg, key }
+    }
+
+    /// The child expression whose output is being indexed
+    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
+        let GetIndexedFieldExpr { arg, .. } = self;
+        arg
+    }
+}
+
+impl std::fmt::Display for GetIndexedFieldExpr {
+    /// Renders the expression as `(<arg>).[<key>]`
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let Self { arg, key } = self;
+        write!(f, "({}).[{}]", arg, key)
+    }
+}
+
+impl PhysicalExpr for GetIndexedFieldExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Data type of the indexed element, derived from the type of `arg`.
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        let data_type = self.arg.data_type(input_schema)?;
+        get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone())
+    }
+
+    /// Nullability of the indexed element, derived from the type of `arg`.
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        let data_type = self.arg.data_type(input_schema)?;
+        get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable())
+    }
+
+    /// Evaluates `arg` against `batch` and extracts the element at `key` from
+    /// every list in the result.
+    ///
+    /// The output has exactly one value per input row: rows whose list is null
+    /// or too short for the index produce a null element.
+    ///
+    /// # Errors
+    /// * `NotImplemented` if `arg` evaluates to a scalar, or if the array is
+    ///   not a list indexed by an `Int64` key
+    /// * `Plan` if the index is negative
+    /// * any error raised while concatenating the extracted elements
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let arg = self.arg.evaluate(batch)?;
+        match arg {
+            ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
+                (DataType::List(_), _) if self.key.is_null() => {
+                    let scalar_null: ScalarValue = array.data_type().try_into()?;
+                    Ok(ColumnarValue::Scalar(scalar_null))
+                }
+                (DataType::List(field), ScalarValue::Int64(Some(i))) => {
+                    let as_list_array = array
+                        .as_any()
+                        .downcast_ref::<ListArray>()
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(
+                                "failed to downcast a List typed array to ListArray"
+                                    .to_string(),
+                            )
+                        })?;
+                    if as_list_array.is_empty() {
+                        let scalar_null: ScalarValue = array.data_type().try_into()?;
+                        return Ok(ColumnarValue::Scalar(scalar_null));
+                    }
+                    // Reject negative indexes explicitly: `as usize` would wrap
+                    // them around to a huge offset.
+                    if *i < 0 {
+                        return Err(DataFusionError::Plan(format!(
+                            "List based indexed access requires a non-negative int, was {}",
+                            i
+                        )));
+                    }
+                    let index = *i as usize;
+                    // Placeholder for rows that have no element at `index`.
+                    let null_element =
+                        arrow::array::new_null_array(field.data_type(), 1);
+                    // Keep one entry per input row so that the result stays
+                    // aligned with the other columns of the batch; slicing
+                    // past the end of a list would panic, so guard the length.
+                    let sliced_array: Vec<Arc<dyn Array>> = as_list_array
+                        .iter()
+                        .map(|o| match o {
+                            Some(list) if index < list.len() => list.slice(index, 1),
+                            _ => null_element.clone(),
+                        })
+                        .collect();
+                    let vec = sliced_array
+                        .iter()
+                        .map(|a| a.as_ref())
+                        .collect::<Vec<&dyn Array>>();
+                    let iter = concat(vec.as_slice())?;
+                    Ok(ColumnarValue::Array(iter))
+                }
+                (dt, key) => Err(DataFusionError::NotImplemented(format!(
+                    "get indexed field is only possible on lists with int64 indexes. Tried {} with {} index",
+                    dt, key
+                ))),
+            },
+            ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
+                "field access is not yet implemented for scalar values".to_string(),
+            )),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::error::Result;
+ use crate::physical_plan::expressions::{col, lit};
+ use arrow::array::{ListBuilder, StringBuilder};
+ use arrow::{array::StringArray, datatypes::Field};
+
+ /// Builds a single-column batch whose column `l` holds `list_of_lists` as a
+ /// list-of-strings array, evaluates `l[index]` on it and asserts that the
+ /// result equals `expected`.
+ fn get_indexed_field_test(
+ list_of_lists: Vec<Vec<Option<&str>>>,
+ index: i64,
+ expected: Vec<Option<&str>>,
+ ) -> Result<()> {
+ let schema = list_schema("l");
+ let builder = StringBuilder::new(3);
+ let mut lb = ListBuilder::new(builder);
+ for values in list_of_lists {
+ let builder = lb.values();
+ for value in values {
+ match value {
+ None => builder.append_null(),
+ Some(v) => builder.append_value(v),
+ }
+ .unwrap()
+ }
+ // close the current list (one list per input row)
+ lb.append(true).unwrap();
+ }
+
+ let expr = col("l", &schema).unwrap();
+ let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(lb.finish())])?;
+
+ let key = ScalarValue::Int64(Some(index));
+ let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
+ let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = result
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .expect("failed to downcast to StringArray");
+ let expected = &StringArray::from(expected);
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
+ /// Schema with a single nullable column `col` of type list of nullable
+ /// `Utf8` items.
+ fn list_schema(col: &str) -> Schema {
+ Schema::new(vec![Field::new(
+ col,
+ DataType::List(Box::new(Field::new("item", DataType::Utf8, true))),
+ true,
+ )])
+ }
+
+ /// Indexing every position of three 3-element lists: `expected_list[i]`
+ /// holds the `i`-th element of each input list, nulls included.
+ #[test]
+ fn get_indexed_field_list() -> Result<()> {
+ let list_of_lists = vec![
+ vec![Some("a"), Some("b"), None],
+ vec![None, Some("c"), Some("d")],
+ vec![Some("e"), None, Some("f")],
+ ];
+ let expected_list = vec![
+ vec![Some("a"), None, Some("e")],
+ vec![Some("b"), Some("c"), None],
+ vec![None, Some("d"), Some("f")],
+ ];
+
+ for (i, expected) in expected_list.into_iter().enumerate() {
+ get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?;
+ }
+ Ok(())
+ }
+
+ /// Indexing a list column that has no rows yields an empty result.
+ #[test]
+ fn get_indexed_field_empty_list() -> Result<()> {
+ let schema = list_schema("l");
+ let builder = StringBuilder::new(0);
+ let mut lb = ListBuilder::new(builder);
+ let expr = col("l", &schema).unwrap();
+ let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(lb.finish())])?;
+ let key = ScalarValue::Int64(Some(0));
+ let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
+ let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+ assert!(result.is_empty());
+ Ok(())
+ }
+
+ /// Evaluates `expr` indexed by `key` against a batch whose only column is
+ /// a list array with no rows, and asserts that evaluation fails with an
+ /// error whose `Display` output is exactly `expected`.
+ fn get_indexed_field_test_failure(
+ schema: Schema,
+ expr: Arc<dyn PhysicalExpr>,
+ key: ScalarValue,
+ expected: &str,
+ ) -> Result<()> {
+ let builder = StringBuilder::new(3);
+ let mut lb = ListBuilder::new(builder);
+ let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(lb.finish())])?;
+ let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
+ let r = expr.evaluate(&batch).map(|_| ());
+ assert!(r.is_err());
+ assert_eq!(format!("{}", r.unwrap_err()), expected);
+ Ok(())
+ }
+
+ /// Indexing a literal is expected to fail with the "not implemented for
+ /// scalar values" error asserted below.
+ #[test]
+ fn get_indexed_field_invalid_scalar() -> Result<()> {
+ let schema = list_schema("l");
+ let expr = lit(ScalarValue::Utf8(Some("a".to_string())));
+ get_indexed_field_test_failure(schema, expr,
ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is
not yet implemented for scalar values")
+ }
+
+ /// A key that is not an `Int64` (here an `Int8`) is rejected.
+ #[test]
+ fn get_indexed_field_invalid_list_index() -> Result<()> {
+ let schema = list_schema("l");
+ let expr = col("l", &schema).unwrap();
+ get_indexed_field_test_failure(schema, expr,
ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field
is only possible on lists with int64 indexes. Tried List(Field { name:
\"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false,
metadata: None }) with 0 index")
+ }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs
b/datafusion/src/physical_plan/expressions/mod.rs
index 843bee7..dba3bde 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -35,6 +35,7 @@ mod coercion;
mod column;
mod count;
mod cume_dist;
+mod get_indexed_field;
mod in_list;
mod is_not_null;
mod is_null;
@@ -66,6 +67,7 @@ pub use cast::{
pub use column::{col, Column};
pub use count::Count;
pub use cume_dist::cume_dist;
+pub use get_indexed_field::GetIndexedFieldExpr;
pub use in_list::{in_list, InListExpr};
pub use is_not_null::{is_not_null, IsNotNullExpr};
pub use is_null::{is_null, IsNullExpr};
diff --git a/datafusion/src/physical_plan/planner.rs
b/datafusion/src/physical_plan/planner.rs
index 8cfb907..fd0421b 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -32,7 +32,9 @@ use
crate::physical_optimizer::optimizer::PhysicalOptimizerRule;
use crate::physical_plan::cross_join::CrossJoinExec;
use crate::physical_plan::explain::ExplainExec;
use crate::physical_plan::expressions;
-use crate::physical_plan::expressions::{CaseExpr, Column, Literal,
PhysicalSortExpr};
+use crate::physical_plan::expressions::{
+ CaseExpr, Column, GetIndexedFieldExpr, Literal, PhysicalSortExpr,
+};
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use crate::physical_plan::hash_join::HashJoinExec;
@@ -141,6 +143,10 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
let expr = create_physical_name(expr, false)?;
Ok(format!("{} IS NOT NULL", expr))
}
+ Expr::GetIndexedField { expr, key } => {
+ let expr = create_physical_name(expr, false)?;
+ Ok(format!("{}[{}]", expr, key))
+ }
Expr::ScalarFunction { fun, args, .. } => {
create_function_physical_name(&fun.to_string(), false, args)
}
@@ -989,6 +995,18 @@ impl DefaultPhysicalPlanner {
Expr::IsNotNull(expr) => expressions::is_not_null(
self.create_physical_expr(expr, input_dfschema, input_schema,
ctx_state)?,
),
+ Expr::GetIndexedField { expr, key } => {
+ Ok(Arc::new(GetIndexedFieldExpr::new(
+ self.create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?,
+ key.clone(),
+ )))
+ }
+
Expr::ScalarFunction { fun, args } => {
let physical_args = args
.iter()
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 7bdb7b8..60d2da8 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -81,6 +81,32 @@ pub struct SqlToRel<'a, S: ContextProvider> {
schema_provider: &'a S,
}
+/// Converts a SQL literal used as an index key into a [`ScalarValue`]:
+/// numbers become `Int64`, single-quoted strings become `Utf8`.
+///
+/// # Panics
+/// * if a numeric literal cannot be parsed into the `Int64` payload
+/// * if `key` is any other kind of [`Value`]
+///
+/// NOTE(review): the keys appear to originate from the SQL text being
+/// planned, so these panics look reachable from user input — consider
+/// returning a `Result` instead; confirm against the caller.
+fn plan_key(key: Value) -> ScalarValue {
+ match key {
+ Value::Number(s, _) => ScalarValue::Int64(Some(s.parse().unwrap())),
+ Value::SingleQuotedString(s) => ScalarValue::Utf8(Some(s)),
+ _ => unreachable!(),
+ }
+}
+
+/// Wraps `expr` in one `GetIndexedField` per entry of `keys`.
+///
+/// The first key ends up innermost and the last key outermost, so
+/// `a[0][1]` is planned as `(a[0])[1]`.
+///
+/// # Panics
+/// Panics if `keys` is empty.
+fn plan_indexed(expr: Expr, mut keys: Vec<Value>) -> Expr {
+    // The last key always forms the outermost access.
+    let outer_key = keys.pop().unwrap();
+    let inner = if keys.is_empty() {
+        expr
+    } else {
+        // Remaining keys are planned first and become the indexed operand.
+        plan_indexed(expr, keys)
+    };
+    Expr::GetIndexedField {
+        expr: Box::new(inner),
+        key: plan_key(outer_key),
+    }
+}
+
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Create a new query planner
pub fn new(schema_provider: &'a S) -> Self {
@@ -1197,6 +1223,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
+ SQLExpr::MapAccess { ref column, keys } => {
+ if let SQLExpr::Identifier(ref id) = column.as_ref() {
+ Ok(plan_indexed(col(&id.value), keys.clone()))
+ } else {
+ Err(DataFusionError::NotImplemented(format!(
+ "map access requires an identifier, found column {}
instead",
+ column
+ )))
+ }
+ }
+
SQLExpr::CompoundIdentifier(ids) => {
let mut var_names = vec![];
for id in ids {
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 41bcd20..45e78cb 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -368,6 +368,10 @@ where
Ok(expr.clone())
}
Expr::Wildcard => Ok(Expr::Wildcard),
+ Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField {
+ expr: Box::new(clone_with_replacement(expr.as_ref(),
replacement_fn)?),
+ key: key.clone(),
+ }),
},
}
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index f3dba3f..f1e9888 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -5305,3 +5305,83 @@ async fn case_with_bool_type_result() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}
+
+/// `some_list[0]` on a table with a list-of-`Int64` column returns the first
+/// element of each row's list.
+#[tokio::test]
+async fn query_get_indexed_field() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "some_list",
+ DataType::List(Box::new(Field::new("item", DataType::Int64, true))),
+ false,
+ )]));
+ let builder = PrimitiveBuilder::<Int64Type>::new(3);
+ let mut lb = ListBuilder::new(builder);
+ // one list per row: [0, 1, 2], [4, 5, 6], [7, 8, 9]
+ for int_vec in vec![vec![0, 1, 2], vec![4, 5, 6], vec![7, 8, 9]] {
+ let builder = lb.values();
+ for int in int_vec {
+ builder.append_value(int).unwrap();
+ }
+ lb.append(true).unwrap();
+ }
+
+ let data = RecordBatch::try_new(schema.clone(),
vec![Arc::new(lb.finish())])?;
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+ let table_a = Arc::new(table);
+
+ ctx.register_table("ints", table_a)?;
+
+ // Index 0 selects the first element of every list
+ let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["0"], vec!["4"], vec!["7"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+/// Indexed access on a list-of-lists column: one level of indexing returns
+/// the inner list, two levels return the inner list's element.
+#[tokio::test]
+async fn query_nested_get_indexed_field() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ let nested_dt = DataType::List(Box::new(Field::new("item",
DataType::Int64, true)));
+ // Nested schema of { "some_list": [[i64]] }
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "some_list",
+ DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))),
+ false,
+ )]));
+
+ let builder = PrimitiveBuilder::<Int64Type>::new(3);
+ let nested_lb = ListBuilder::new(builder);
+ let mut lb = ListBuilder::new(nested_lb);
+ // each row is a list of three two-element lists
+ for int_vec_vec in vec![
+ vec![vec![0, 1], vec![2, 3], vec![3, 4]],
+ vec![vec![5, 6], vec![7, 8], vec![9, 10]],
+ vec![vec![11, 12], vec![13, 14], vec![15, 16]],
+ ] {
+ let nested_builder = lb.values();
+ for int_vec in int_vec_vec {
+ let builder = nested_builder.values();
+ for int in int_vec {
+ builder.append_value(int).unwrap();
+ }
+ nested_builder.append(true).unwrap();
+ }
+ lb.append(true).unwrap();
+ }
+
+ let data = RecordBatch::try_new(schema.clone(),
vec![Arc::new(lb.finish())])?;
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+ let table_a = Arc::new(table);
+
+ ctx.register_table("ints", table_a)?;
+
+ // One level of indexing returns the first inner list of every row
+ let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["[0, 1]"], vec!["[5, 6]"], vec!["[11, 12]"]];
+ assert_eq!(expected, actual);
+ // Two levels of indexing return the first element of that inner list
+ let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["0"], vec!["5"], vec!["11"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}