This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fffc8bef30 feat: support the ergonomics of getting list slice with
stride (#8946)
fffc8bef30 is described below
commit fffc8bef30c9cb84fe89f7c17b0803e0b665aa40
Author: Alex Huang <[email protected]>
AuthorDate: Mon Jan 29 20:17:24 2024 +0800
feat: support the ergonomics of getting list slice with stride (#8946)
* support list stride
* add test
* fix fmt
* rename and extend ListRange to ListStride
* fix ci
* fix doctest
* fix conflict and keep ListRange
* clean up the code
* chore
* fix ci
---
datafusion/core/src/physical_planner.rs | 9 ++-
datafusion/expr/src/expr.rs | 28 +++++--
datafusion/expr/src/expr_schema.rs | 7 +-
datafusion/expr/src/field_util.rs | 13 ++--
datafusion/expr/src/tree_node/expr.rs | 4 +-
.../src/expressions/get_indexed_field.rs | 85 ++++++++++++++++------
datafusion/physical-expr/src/planner.rs | 27 ++++---
datafusion/proto/proto/datafusion.proto | 2 +
datafusion/proto/src/generated/pbjson.rs | 34 +++++++++
datafusion/proto/src/generated/prost.rs | 4 +
datafusion/proto/src/logical_plan/from_proto.rs | 5 ++
datafusion/proto/src/logical_plan/to_proto.rs | 19 +++--
datafusion/proto/src/physical_plan/from_proto.rs | 9 ++-
datafusion/proto/src/physical_plan/to_proto.rs | 15 ++--
.../proto/tests/cases/roundtrip_physical_plan.rs | 3 +
datafusion/sql/src/expr/mod.rs | 53 +++++++++++---
datafusion/sqllogictest/test_files/array.slt | 45 ++++++++++++
17 files changed, 282 insertions(+), 80 deletions(-)
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index ac3b7ebaea..d383ddce92 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -209,10 +209,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
let key = create_physical_name(key, false)?;
format!("{expr}[{key}]")
}
- GetFieldAccess::ListRange { start, stop } => {
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => {
let start = create_physical_name(start, false)?;
let stop = create_physical_name(stop, false)?;
- format!("{expr}[{start}:{stop}]")
+ let stride = create_physical_name(stride, false)?;
+ format!("{expr}[{start}:{stop}:{stride}]")
}
};
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c5d158d876..9da1f4bb4d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -421,8 +421,12 @@ pub enum GetFieldAccess {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key: Box<Expr> },
- /// List range, for example `list[i:j]`
- ListRange { start: Box<Expr>, stop: Box<Expr> },
+ /// List stride, for example `list[i:j:k]`
+ ListRange {
+ start: Box<Expr>,
+ stop: Box<Expr>,
+ stride: Box<Expr>,
+ },
}
/// Returns the field of a [`arrow::array::ListArray`] or
@@ -1209,7 +1213,7 @@ impl Expr {
/// # use datafusion_expr::{lit, col};
/// let expr = col("c1")
/// .range(lit(2), lit(4));
- /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]");
+ /// assert_eq!(expr.display_name().unwrap(),
"c1[Int32(2):Int32(4):Int64(1)]");
/// ```
pub fn range(self, start: Expr, stop: Expr) -> Self {
Expr::GetIndexedField(GetIndexedField {
@@ -1217,6 +1221,7 @@ impl Expr {
field: GetFieldAccess::ListRange {
start: Box::new(start),
stop: Box::new(stop),
+ stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
},
})
}
@@ -1530,8 +1535,12 @@ impl fmt::Display for Expr {
write!(f, "({expr})[{name}]")
}
GetFieldAccess::ListIndex { key } => write!(f,
"({expr})[{key}]"),
- GetFieldAccess::ListRange { start, stop } => {
- write!(f, "({expr})[{start}:{stop}]")
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => {
+ write!(f, "({expr})[{start}:{stop}:{stride}]")
}
},
Expr::GroupingSet(grouping_sets) => match grouping_sets {
@@ -1732,10 +1741,15 @@ fn create_name(e: &Expr) -> Result<String> {
let key = create_name(key)?;
Ok(format!("{expr}[{key}]"))
}
- GetFieldAccess::ListRange { start, stop } => {
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => {
let start = create_name(start)?;
let stop = create_name(stop)?;
- Ok(format!("{expr}[{start}:{stop}]"))
+ let stride = create_name(stride)?;
+ Ok(format!("{expr}[{start}:{stop}:{stride}]"))
}
}
}
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index ba21d09f06..4967e66fed 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -374,9 +374,14 @@ fn field_for_index<S: ExprSchema>(
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.get_type(schema)?,
},
- GetFieldAccess::ListRange { start, stop } =>
GetFieldAccessSchema::ListRange {
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => GetFieldAccessSchema::ListRange {
start_dt: start.get_type(schema)?,
stop_dt: stop.get_type(schema)?,
+ stride_dt: stride.get_type(schema)?,
},
}
.get_accessed_field(&expr_dt)
diff --git a/datafusion/expr/src/field_util.rs
b/datafusion/expr/src/field_util.rs
index 3829a2086b..c46ec50234 100644
--- a/datafusion/expr/src/field_util.rs
+++ b/datafusion/expr/src/field_util.rs
@@ -28,10 +28,11 @@ pub enum GetFieldAccessSchema {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key_dt: DataType },
- /// List range, for example `list[i:j]`
+ /// List stride, for example `list[i:j:k]`
ListRange {
start_dt: DataType,
stop_dt: DataType,
+ stride_dt: DataType,
},
}
@@ -85,13 +86,13 @@ impl GetFieldAccessSchema {
(other, _) => plan_err!("The expression to get an indexed
field is only valid for `List` or `Struct` types, got {other}"),
}
}
- Self::ListRange{ start_dt, stop_dt } => {
- match (data_type, start_dt, stop_dt) {
- (DataType::List(_), DataType::Int64, DataType::Int64) =>
Ok(Field::new("list", data_type.clone(), true)),
- (DataType::List(_), _, _) => plan_err!(
+ Self::ListRange { start_dt, stop_dt, stride_dt } => {
+ match (data_type, start_dt, stop_dt, stride_dt) {
+ (DataType::List(_), DataType::Int64, DataType::Int64,
DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
+ (DataType::List(_), _, _, _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
- (other, _, _) => plan_err!("The expression to get an
indexed field is only valid for `List` or `Struct` types, got {other}"),
+ (other, _, _, _) => plan_err!("The expression to get an
indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
}
diff --git a/datafusion/expr/src/tree_node/expr.rs
b/datafusion/expr/src/tree_node/expr.rs
index 05464c96d0..8b38d1cf01 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -52,8 +52,8 @@ impl TreeNode for Expr {
let expr = expr.as_ref();
match field {
GetFieldAccess::ListIndex {key} => vec![key.as_ref(),
expr],
- GetFieldAccess::ListRange {start, stop} => {
- vec![start.as_ref(), stop.as_ref(), expr]
+ GetFieldAccess::ListRange {start, stop, stride} => {
+ vec![start.as_ref(), stop.as_ref(),stride.as_ref(),
expr]
}
GetFieldAccess::NamedStructField { .. } => vec![expr],
}
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 43fd5a812a..58fe472854 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -21,6 +21,7 @@ use crate::PhysicalExpr;
use datafusion_common::exec_err;
use crate::array_expressions::{array_element, array_slice};
+use crate::expressions::Literal;
use crate::physical_expr::down_cast_any_ref;
use arrow::{
array::{Array, Scalar, StringArray},
@@ -43,10 +44,11 @@ pub enum GetFieldAccessExpr {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key: Arc<dyn PhysicalExpr> },
- /// List range, for example `list[i:j]`
+ /// List stride, for example `list[i:j:k]`
ListRange {
start: Arc<dyn PhysicalExpr>,
stop: Arc<dyn PhysicalExpr>,
+ stride: Arc<dyn PhysicalExpr>,
},
}
@@ -55,8 +57,12 @@ impl std::fmt::Display for GetFieldAccessExpr {
match self {
GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]",
name),
GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key),
- GetFieldAccessExpr::ListRange { start, stop } => {
- write!(f, "[{}:{}]", start, stop)
+ GetFieldAccessExpr::ListRange {
+ start,
+ stop,
+ stride,
+ } => {
+ write!(f, "[{}:{}:{}]", start, stop, stride)
}
}
}
@@ -76,12 +82,18 @@ impl PartialEq<dyn Any> for GetFieldAccessExpr {
ListRange {
start: start_lhs,
stop: stop_lhs,
+ stride: stride_lhs,
},
ListRange {
start: start_rhs,
stop: stop_rhs,
+ stride: stride_rhs,
},
- ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs),
+ ) => {
+ start_lhs.eq(start_rhs)
+ && stop_lhs.eq(stop_rhs)
+ && stride_lhs.eq(stride_rhs)
+ }
(NamedStructField { .. }, ListIndex { .. } | ListRange { .. })
=> false,
(ListIndex { .. }, NamedStructField { .. } | ListRange { .. })
=> false,
(ListRange { .. }, NamedStructField { .. } | ListIndex { .. })
=> false,
@@ -126,7 +138,32 @@ impl GetIndexedFieldExpr {
start: Arc<dyn PhysicalExpr>,
stop: Arc<dyn PhysicalExpr>,
) -> Self {
- Self::new(arg, GetFieldAccessExpr::ListRange { start, stop })
+ Self::new(
+ arg,
+ GetFieldAccessExpr::ListRange {
+ start,
+ stop,
+ stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1))))
+ as Arc<dyn PhysicalExpr>,
+ },
+ )
+ }
+
+ /// Create a new [`GetIndexedFieldExpr`] for accessing the stride
+ pub fn new_stride(
+ arg: Arc<dyn PhysicalExpr>,
+ start: Arc<dyn PhysicalExpr>,
+ stop: Arc<dyn PhysicalExpr>,
+ stride: Arc<dyn PhysicalExpr>,
+ ) -> Self {
+ Self::new(
+ arg,
+ GetFieldAccessExpr::ListRange {
+ start,
+ stop,
+ stride,
+ },
+ )
}
/// Get the description of what field should be accessed
@@ -147,12 +184,15 @@ impl GetIndexedFieldExpr {
GetFieldAccessExpr::ListIndex { key } =>
GetFieldAccessSchema::ListIndex {
key_dt: key.data_type(input_schema)?,
},
- GetFieldAccessExpr::ListRange { start, stop } => {
- GetFieldAccessSchema::ListRange {
- start_dt: start.data_type(input_schema)?,
- stop_dt: stop.data_type(input_schema)?,
- }
- }
+ GetFieldAccessExpr::ListRange {
+ start,
+ stop,
+ stride,
+ } => GetFieldAccessSchema::ListRange {
+ start_dt: start.data_type(input_schema)?,
+ stop_dt: stop.data_type(input_schema)?,
+ stride_dt: stride.data_type(input_schema)?,
+ },
})
}
}
@@ -223,21 +263,24 @@ impl PhysicalExpr for GetIndexedFieldExpr {
with utf8 indexes. Tried
{dt:?} with {key:?} index"),
}
},
- GetFieldAccessExpr::ListRange{start, stop} => {
+ GetFieldAccessExpr::ListRange { start, stop, stride } => {
let start =
start.evaluate(batch)?.into_array(batch.num_rows())?;
let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?;
- match (array.data_type(), start.data_type(), stop.data_type())
{
- (DataType::List(_), DataType::Int64, DataType::Int64) =>
Ok(ColumnarValue::Array(array_slice(&[
- array, start, stop
- ])?)),
- (DataType::List(_), start, stop) => exec_err!(
+ let stride =
stride.evaluate(batch)?.into_array(batch.num_rows())?;
+ match (array.data_type(), start.data_type(), stop.data_type(),
stride.data_type()) {
+ (DataType::List(_), DataType::Int64, DataType::Int64,
DataType::Int64) => {
+ Ok(ColumnarValue::Array((array_slice(&[
+ array, start, stop, stride
+ ]))?))
+ },
+ (DataType::List(_), start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with
int64 indexes. \
- Tried with {start:?} and {stop:?} indices"),
- (dt, start, stop) => exec_err!(
+ Tried with {start:?}, {stop:?} and {stride:?}
indices"),
+ (dt, start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with
int64 indexes or struct \
- with utf8 indexes. Tried {dt:?} with
{start:?} and {stop:?} indices"),
+ with utf8 indexes. Tried {dt:?} with
{start:?}, {stop:?} and {stride:?} indices"),
}
- },
+ }
}
}
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 09b8da836c..ee5da05d11 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -238,20 +238,19 @@ pub fn create_physical_expr(
GetFieldAccess::ListIndex { key } =>
GetFieldAccessExpr::ListIndex {
key: create_physical_expr(key, input_dfschema,
execution_props)?,
},
- GetFieldAccess::ListRange { start, stop } => {
- GetFieldAccessExpr::ListRange {
- start: create_physical_expr(
- start,
- input_dfschema,
- execution_props,
- )?,
- stop: create_physical_expr(
- stop,
- input_dfschema,
- execution_props,
- )?,
- }
- }
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => GetFieldAccessExpr::ListRange {
+ start: create_physical_expr(start, input_dfschema,
execution_props)?,
+ stop: create_physical_expr(stop, input_dfschema,
execution_props)?,
+ stride: create_physical_expr(
+ stride,
+ input_dfschema,
+ execution_props,
+ )?,
+ },
};
Ok(Arc::new(GetIndexedFieldExpr::new(
create_physical_expr(expr, input_dfschema, execution_props)?,
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 66c1271e65..c8468e1709 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -466,6 +466,7 @@ message ListIndex {
message ListRange {
LogicalExprNode start = 1;
LogicalExprNode stop = 2;
+ LogicalExprNode stride = 3;
}
message GetIndexedField {
@@ -1773,6 +1774,7 @@ message ListIndexExpr {
message ListRangeExpr {
PhysicalExprNode start = 1;
PhysicalExprNode stop = 2;
+ PhysicalExprNode stride = 3;
}
message PhysicalGetIndexedFieldExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 39a8678ef2..47667fb68c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12542,6 +12542,9 @@ impl serde::Serialize for ListRange {
if self.stop.is_some() {
len += 1;
}
+ if self.stride.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.ListRange", len)?;
if let Some(v) = self.start.as_ref() {
struct_ser.serialize_field("start", v)?;
@@ -12549,6 +12552,9 @@ impl serde::Serialize for ListRange {
if let Some(v) = self.stop.as_ref() {
struct_ser.serialize_field("stop", v)?;
}
+ if let Some(v) = self.stride.as_ref() {
+ struct_ser.serialize_field("stride", v)?;
+ }
struct_ser.end()
}
}
@@ -12561,12 +12567,14 @@ impl<'de> serde::Deserialize<'de> for ListRange {
const FIELDS: &[&str] = &[
"start",
"stop",
+ "stride",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Start,
Stop,
+ Stride,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -12590,6 +12598,7 @@ impl<'de> serde::Deserialize<'de> for ListRange {
match value {
"start" => Ok(GeneratedField::Start),
"stop" => Ok(GeneratedField::Stop),
+ "stride" => Ok(GeneratedField::Stride),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -12611,6 +12620,7 @@ impl<'de> serde::Deserialize<'de> for ListRange {
{
let mut start__ = None;
let mut stop__ = None;
+ let mut stride__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Start => {
@@ -12625,11 +12635,18 @@ impl<'de> serde::Deserialize<'de> for ListRange {
}
stop__ = map_.next_value()?;
}
+ GeneratedField::Stride => {
+ if stride__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("stride"));
+ }
+ stride__ = map_.next_value()?;
+ }
}
}
Ok(ListRange {
start: start__,
stop: stop__,
+ stride: stride__,
})
}
}
@@ -12650,6 +12667,9 @@ impl serde::Serialize for ListRangeExpr {
if self.stop.is_some() {
len += 1;
}
+ if self.stride.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.ListRangeExpr", len)?;
if let Some(v) = self.start.as_ref() {
struct_ser.serialize_field("start", v)?;
@@ -12657,6 +12677,9 @@ impl serde::Serialize for ListRangeExpr {
if let Some(v) = self.stop.as_ref() {
struct_ser.serialize_field("stop", v)?;
}
+ if let Some(v) = self.stride.as_ref() {
+ struct_ser.serialize_field("stride", v)?;
+ }
struct_ser.end()
}
}
@@ -12669,12 +12692,14 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr {
const FIELDS: &[&str] = &[
"start",
"stop",
+ "stride",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Start,
Stop,
+ Stride,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -12698,6 +12723,7 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr {
match value {
"start" => Ok(GeneratedField::Start),
"stop" => Ok(GeneratedField::Stop),
+ "stride" => Ok(GeneratedField::Stride),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -12719,6 +12745,7 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr {
{
let mut start__ = None;
let mut stop__ = None;
+ let mut stride__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Start => {
@@ -12733,11 +12760,18 @@ impl<'de> serde::Deserialize<'de> for ListRangeExpr {
}
stop__ = map_.next_value()?;
}
+ GeneratedField::Stride => {
+ if stride__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("stride"));
+ }
+ stride__ = map_.next_value()?;
+ }
}
}
Ok(ListRangeExpr {
start: start__,
stop: stop__,
+ stride: stride__,
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 7bf1d8ed04..a5582cc2dc 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -731,6 +731,8 @@ pub struct ListRange {
pub start:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
#[prost(message, optional, boxed, tag = "2")]
pub stop:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
+ #[prost(message, optional, boxed, tag = "3")]
+ pub stride:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -2538,6 +2540,8 @@ pub struct ListRangeExpr {
pub start:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
#[prost(message, optional, boxed, tag = "2")]
pub stop:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
+ #[prost(message, optional, boxed, tag = "3")]
+ pub stride:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 42d39b5c51..eb72d1f9c3 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1067,6 +1067,11 @@ pub fn parse_expr(
registry,
"stop",
)?),
+ stride: Box::new(parse_required_expr(
+ list_range.stride.as_deref(),
+ registry,
+ "stride",
+ )?),
}
}
None => return Err(proto_error("Field must not be None")),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index dbb52eced3..e1fc3f0c85 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1033,14 +1033,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
},
))
}
- GetFieldAccess::ListRange { start, stop } => {
- protobuf::get_indexed_field::Field::ListRange(Box::new(
- protobuf::ListRange {
- start:
Some(Box::new(start.as_ref().try_into()?)),
- stop:
Some(Box::new(stop.as_ref().try_into()?)),
- },
- ))
- }
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } =>
protobuf::get_indexed_field::Field::ListRange(Box::new(
+ protobuf::ListRange {
+ start: Some(Box::new(start.as_ref().try_into()?)),
+ stop: Some(Box::new(stop.as_ref().try_into()?)),
+ stride:
Some(Box::new(stride.as_ref().try_into()?)),
+ },
+ )),
};
Self {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index dc827d02bf..454f74dfd1 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -411,8 +411,15 @@ pub fn parse_physical_expr(
"stop",
input_schema
)?,
+ stride: parse_required_physical_expr(
+ list_range_expr.stride.as_deref(),
+ registry,
+ "stride",
+ input_schema
+ )?,
},
- None => return Err(proto_error(
+ None =>
+ return Err(proto_error(
"Field must not be None",
)),
};
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index cff32ca2f8..a67410da57 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -562,12 +562,15 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for
protobuf::PhysicalExprNode {
key: Some(Box::new(key.to_owned().try_into()?))
}))
),
- GetFieldAccessExpr::ListRange{start, stop} => Some(
-
protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr
{
- start: Some(Box::new(start.to_owned().try_into()?)),
- stop: Some(Box::new(stop.to_owned().try_into()?)),
- }))
- ),
+ GetFieldAccessExpr::ListRange { start, stop, stride } => {
+ Some(
+
protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr
{
+ start:
Some(Box::new(start.to_owned().try_into()?)),
+ stop: Some(Box::new(stop.to_owned().try_into()?)),
+ stride:
Some(Box::new(stride.to_owned().try_into()?)),
+ }))
+ )
+ }
};
Ok(protobuf::PhysicalExprNode {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 38eb390003..eba3db298f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -35,6 +35,7 @@ use datafusion::logical_expr::{
create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility,
};
use datafusion::parquet::file::properties::WriterProperties;
+use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_plan::aggregates::{
@@ -750,6 +751,8 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> {
GetFieldAccessExpr::ListRange {
start: col_start,
stop: col_stop,
+ stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1))))
+ as Arc<dyn PhysicalExpr>,
},
));
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 9fded63af3..b22c458b6d 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -753,18 +753,47 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
operator: JsonOperator::Colon,
right,
} => {
- let start = Box::new(self.sql_expr_to_logical_expr(
- *left,
- schema,
- planner_context,
- )?);
- let stop = Box::new(self.sql_expr_to_logical_expr(
- *right,
- schema,
- planner_context,
- )?);
-
- GetFieldAccess::ListRange { start, stop }
+ let (start, stop, stride) = if let SQLExpr::JsonAccess {
+ left: l,
+ operator: JsonOperator::Colon,
+ right: r,
+ } = *left
+ {
+ let start = Box::new(self.sql_expr_to_logical_expr(
+ *l,
+ schema,
+ planner_context,
+ )?);
+ let stop = Box::new(self.sql_expr_to_logical_expr(
+ *r,
+ schema,
+ planner_context,
+ )?);
+ let stride = Box::new(self.sql_expr_to_logical_expr(
+ *right,
+ schema,
+ planner_context,
+ )?);
+ (start, stop, stride)
+ } else {
+ let start = Box::new(self.sql_expr_to_logical_expr(
+ *left,
+ schema,
+ planner_context,
+ )?);
+ let stop = Box::new(self.sql_expr_to_logical_expr(
+ *right,
+ schema,
+ planner_context,
+ )?);
+ let stride =
Box::new(Expr::Literal(ScalarValue::Int64(Some(1))));
+ (start, stop, stride)
+ };
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ }
}
_ => GetFieldAccess::ListIndex {
key: Box::new(self.sql_expr_to_logical_expr(
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index b7d92aec88..e072e4146f 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -759,6 +759,51 @@ select column1[0:5], column2[0:3], column3[0:9] from
arrays;
# select column1[column2:column3] from arrays_with_repeating_elements;
# ----
+# array[i:j:k]
+
+# multiple index with columns #1 (positive index)
+query ???
+select make_array(1, 2, 3)[1:2:2], make_array(1.0, 2.0, 3.0)[2:3:2],
make_array('h', 'e', 'l', 'l', 'o')[2:4:2];
+----
+[1] [2.0] [e, l]
+
+# multiple index with columns #2 (zero index)
+query ???
+select make_array(1, 2, 3)[0:0:2], make_array(1.0, 2.0, 3.0)[0:2:2],
make_array('h', 'e', 'l', 'l', 'o')[0:6:2];
+----
+[] [1.0] [h, l, o]
+
+#TODO: sqlparser does not support negative index
+## multiple index with columns #3 (negative index)
+#query ???
+#select make_array(1, 2, 3)[-1:-2:-2], make_array(1.0, 2.0, 3.0)[-2:-3:-2],
make_array('h', 'e', 'l', 'l', 'o')[-2:-4:-2];
+#----
+#[1] [2.0] [e, l]
+
+# multiple index with columns #1 (positive index)
+query ???
+select column1[2:4:2], column2[1:4:2], column3[3:4:2] from arrays;
+----
+[[3, ]] [1.1, 3.3] [r]
+[[5, 6]] [, 6.6] []
+[[7, 8]] [7.7, 9.9] [l]
+[[9, 10]] [10.1, 12.2] [t]
+[] [13.3, 15.5] [e]
+[[13, 14]] [] []
+[[, 18]] [16.6, 18.8] []
+
+# multiple index with columns #2 (zero index)
+query ???
+select column1[0:5:2], column2[0:3:2], column3[0:9:2] from arrays;
+----
+[[, 2]] [1.1, 3.3] [L, r, m]
+[[3, 4]] [, 6.6] [i, , m]
+[[5, 6]] [7.7, 9.9] [d, l, r]
+[[7, ]] [10.1, 12.2] [s, t]
+[] [13.3, 15.5] [a, e]
+[[11, 12]] [] [,]
+[[15, 16]] [16.6, 18.8] []
+
### Array function tests