This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c4ef5ed9d chore: extract conversion_funcs, conditional_funcs,
bitwise_funcs and array_funcs expressions to folders based on spark grouping
(#1223)
c4ef5ed9d is described below
commit c4ef5ed9df630e6f345853246d13cf58fe6bd9b9
Author: Raz Luvaton <[email protected]>
AuthorDate: Sat Jan 18 22:56:09 2025 +0200
chore: extract conversion_funcs, conditional_funcs, bitwise_funcs and
array_funcs expressions to folders based on spark grouping (#1223)
---
.../src/{list.rs => array_funcs/array_insert.rs} | 418 +--------------------
.../src/array_funcs/get_array_struct_fields.rs | 166 ++++++++
native/spark-expr/src/array_funcs/list_extract.rs | 310 +++++++++++++++
native/spark-expr/src/array_funcs/mod.rs | 24 ++
.../src/{ => bitwise_funcs}/bitwise_not.rs | 0
native/spark-expr/src/bitwise_funcs/mod.rs | 20 +
.../src/{ => conditional_funcs}/if_expr.rs | 0
native/spark-expr/src/conditional_funcs/mod.rs | 20 +
.../spark-expr/src/{ => conversion_funcs}/cast.rs | 0
native/spark-expr/src/conversion_funcs/mod.rs | 18 +
native/spark-expr/src/lib.rs | 17 +-
11 files changed, 574 insertions(+), 419 deletions(-)
diff --git a/native/spark-expr/src/list.rs
b/native/spark-expr/src/array_funcs/array_insert.rs
similarity index 54%
rename from native/spark-expr/src/list.rs
rename to native/spark-expr/src/array_funcs/array_insert.rs
index fc31b11a0..08fb78905 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/array_funcs/array_insert.rs
@@ -21,14 +21,12 @@ use arrow::{
datatypes::ArrowNativeType,
record_batch::RecordBatch,
};
-use arrow_array::{
- make_array, Array, ArrayRef, GenericListArray, Int32Array,
OffsetSizeTrait, StructArray,
-};
-use arrow_schema::{DataType, Field, FieldRef, Schema};
+use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array,
OffsetSizeTrait};
+use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{
- cast::{as_int32_array, as_large_list_array, as_list_array},
- internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
+ cast::{as_large_list_array, as_list_array},
+ internal_err, DataFusionError, Result as DataFusionResult,
};
use datafusion_physical_expr::PhysicalExpr;
use std::hash::Hash;
@@ -43,372 +41,6 @@ use std::{
//
https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java
const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632;
-#[derive(Debug, Eq)]
-pub struct ListExtract {
- child: Arc<dyn PhysicalExpr>,
- ordinal: Arc<dyn PhysicalExpr>,
- default_value: Option<Arc<dyn PhysicalExpr>>,
- one_based: bool,
- fail_on_error: bool,
-}
-
-impl Hash for ListExtract {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.child.hash(state);
- self.ordinal.hash(state);
- self.default_value.hash(state);
- self.one_based.hash(state);
- self.fail_on_error.hash(state);
- }
-}
-impl PartialEq for ListExtract {
- fn eq(&self, other: &Self) -> bool {
- self.child.eq(&other.child)
- && self.ordinal.eq(&other.ordinal)
- && self.default_value.eq(&other.default_value)
- && self.one_based.eq(&other.one_based)
- && self.fail_on_error.eq(&other.fail_on_error)
- }
-}
-
-impl ListExtract {
- pub fn new(
- child: Arc<dyn PhysicalExpr>,
- ordinal: Arc<dyn PhysicalExpr>,
- default_value: Option<Arc<dyn PhysicalExpr>>,
- one_based: bool,
- fail_on_error: bool,
- ) -> Self {
- Self {
- child,
- ordinal,
- default_value,
- one_based,
- fail_on_error,
- }
- }
-
- fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef>
{
- match self.child.data_type(input_schema)? {
- DataType::List(field) | DataType::LargeList(field) => Ok(field),
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected data type in ListExtract: {:?}",
- data_type
- ))),
- }
- }
-}
-
-impl PhysicalExpr for ListExtract {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
- Ok(self.child_field(input_schema)?.data_type().clone())
- }
-
- fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
- // Only non-nullable if fail_on_error is enabled and the element is
non-nullable
- Ok(!self.fail_on_error ||
self.child_field(input_schema)?.is_nullable())
- }
-
- fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue>
{
- let child_value =
self.child.evaluate(batch)?.into_array(batch.num_rows())?;
- let ordinal_value =
self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
-
- let default_value = self
- .default_value
- .as_ref()
- .map(|d| {
- d.evaluate(batch).map(|value| match value {
- ColumnarValue::Scalar(scalar)
- if
!scalar.data_type().equals_datatype(child_value.data_type()) =>
- {
- scalar.cast_to(child_value.data_type())
- }
- ColumnarValue::Scalar(scalar) => Ok(scalar),
- v => Err(DataFusionError::Execution(format!(
- "Expected scalar default value for ListExtract, got
{:?}",
- v
- ))),
- })
- })
- .transpose()?
- .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
-
- let adjust_index = if self.one_based {
- one_based_index
- } else {
- zero_based_index
- };
-
- match child_value.data_type() {
- DataType::List(_) => {
- let list_array = as_list_array(&child_value)?;
- let index_array = as_int32_array(&ordinal_value)?;
-
- list_extract(
- list_array,
- index_array,
- &default_value,
- self.fail_on_error,
- adjust_index,
- )
- }
- DataType::LargeList(_) => {
- let list_array = as_large_list_array(&child_value)?;
- let index_array = as_int32_array(&ordinal_value)?;
-
- list_extract(
- list_array,
- index_array,
- &default_value,
- self.fail_on_error,
- adjust_index,
- )
- }
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected child type for ListExtract: {:?}",
- data_type
- ))),
- }
- }
-
- fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
- vec![&self.child, &self.ordinal]
- }
-
- fn with_new_children(
- self: Arc<Self>,
- children: Vec<Arc<dyn PhysicalExpr>>,
- ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
- match children.len() {
- 2 => Ok(Arc::new(ListExtract::new(
- Arc::clone(&children[0]),
- Arc::clone(&children[1]),
- self.default_value.clone(),
- self.one_based,
- self.fail_on_error,
- ))),
- _ => internal_err!("ListExtract should have exactly two children"),
- }
- }
-}
-
-fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
- if index == 0 {
- return Err(DataFusionError::Execution(
- "Invalid index of 0 for one-based ListExtract".to_string(),
- ));
- }
-
- let abs_index = index.abs().as_usize();
- if abs_index <= len {
- if index > 0 {
- Ok(Some(abs_index - 1))
- } else {
- Ok(Some(len - abs_index))
- }
- } else {
- Ok(None)
- }
-}
-
-fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>>
{
- if index < 0 {
- Ok(None)
- } else {
- let positive_index = index.as_usize();
- if positive_index < len {
- Ok(Some(positive_index))
- } else {
- Ok(None)
- }
- }
-}
-
-fn list_extract<O: OffsetSizeTrait>(
- list_array: &GenericListArray<O>,
- index_array: &Int32Array,
- default_value: &ScalarValue,
- fail_on_error: bool,
- adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
-) -> DataFusionResult<ColumnarValue> {
- let values = list_array.values();
- let offsets = list_array.offsets();
-
- let data = values.to_data();
-
- let default_data = default_value.to_array()?.to_data();
-
- let mut mutable = MutableArrayData::new(vec![&data, &default_data], true,
index_array.len());
-
- for (row, (offset_window, index)) in
offsets.windows(2).zip(index_array.values()).enumerate() {
- let start = offset_window[0].as_usize();
- let len = offset_window[1].as_usize() - start;
-
- if let Some(i) = adjust_index(*index, len)? {
- mutable.extend(0, start + i, start + i + 1);
- } else if list_array.is_null(row) {
- mutable.extend_nulls(1);
- } else if fail_on_error {
- return Err(DataFusionError::Execution(
- "Index out of bounds for array".to_string(),
- ));
- } else {
- mutable.extend(1, 0, 1);
- }
- }
-
- let data = mutable.freeze();
- Ok(ColumnarValue::Array(arrow::array::make_array(data)))
-}
-
-impl Display for ListExtract {
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?},
one_based: {:?}, fail_on_error: {:?}]",
- self.child, self.ordinal, self.default_value, self.one_based,
self.fail_on_error
- )
- }
-}
-
-#[derive(Debug, Eq)]
-pub struct GetArrayStructFields {
- child: Arc<dyn PhysicalExpr>,
- ordinal: usize,
-}
-
-impl Hash for GetArrayStructFields {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.child.hash(state);
- self.ordinal.hash(state);
- }
-}
-impl PartialEq for GetArrayStructFields {
- fn eq(&self, other: &Self) -> bool {
- self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
- }
-}
-
-impl GetArrayStructFields {
- pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
- Self { child, ordinal }
- }
-
- fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
- match self.child.data_type(input_schema)? {
- DataType::List(field) | DataType::LargeList(field) => Ok(field),
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected data type in GetArrayStructFields: {:?}",
- data_type
- ))),
- }
- }
-
- fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef>
{
- match self.list_field(input_schema)?.data_type() {
- DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected data type in GetArrayStructFields: {:?}",
- data_type
- ))),
- }
- }
-}
-
-impl PhysicalExpr for GetArrayStructFields {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
- let struct_field = self.child_field(input_schema)?;
- match self.child.data_type(input_schema)? {
- DataType::List(_) => Ok(DataType::List(struct_field)),
- DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected data type in GetArrayStructFields: {:?}",
- data_type
- ))),
- }
- }
-
- fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
- Ok(self.list_field(input_schema)?.is_nullable()
- || self.child_field(input_schema)?.is_nullable())
- }
-
- fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue>
{
- let child_value =
self.child.evaluate(batch)?.into_array(batch.num_rows())?;
-
- match child_value.data_type() {
- DataType::List(_) => {
- let list_array = as_list_array(&child_value)?;
-
- get_array_struct_fields(list_array, self.ordinal)
- }
- DataType::LargeList(_) => {
- let list_array = as_large_list_array(&child_value)?;
-
- get_array_struct_fields(list_array, self.ordinal)
- }
- data_type => Err(DataFusionError::Internal(format!(
- "Unexpected child type for ListExtract: {:?}",
- data_type
- ))),
- }
- }
-
- fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
- vec![&self.child]
- }
-
- fn with_new_children(
- self: Arc<Self>,
- children: Vec<Arc<dyn PhysicalExpr>>,
- ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
- match children.len() {
- 1 => Ok(Arc::new(GetArrayStructFields::new(
- Arc::clone(&children[0]),
- self.ordinal,
- ))),
- _ => internal_err!("GetArrayStructFields should have exactly one
child"),
- }
- }
-}
-
-fn get_array_struct_fields<O: OffsetSizeTrait>(
- list_array: &GenericListArray<O>,
- ordinal: usize,
-) -> DataFusionResult<ColumnarValue> {
- let values = list_array
- .values()
- .as_any()
- .downcast_ref::<StructArray>()
- .expect("A struct is expected");
-
- let column = Arc::clone(values.column(ordinal));
- let field = Arc::clone(&values.fields()[ordinal]);
-
- let offsets = list_array.offsets();
- let array = GenericListArray::new(field, offsets.clone(), column,
list_array.nulls().cloned());
-
- Ok(ColumnarValue::Array(Arc::new(array)))
-}
-
-impl Display for GetArrayStructFields {
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
- self.child, self.ordinal
- )
- }
-}
-
#[derive(Debug, Eq)]
pub struct ArrayInsert {
src_array_expr: Arc<dyn PhysicalExpr>,
@@ -687,51 +319,13 @@ impl Display for ArrayInsert {
#[cfg(test)]
mod test {
- use crate::list::{array_insert, list_extract, zero_based_index};
-
+ use super::*;
use arrow::datatypes::Int32Type;
use arrow_array::{Array, ArrayRef, Int32Array, ListArray};
- use datafusion_common::{Result, ScalarValue};
+ use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;
- #[test]
- fn test_list_extract_default_value() -> Result<()> {
- let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
- Some(vec![Some(1)]),
- None,
- Some(vec![]),
- ]);
- let indices = Int32Array::from(vec![0, 0, 0]);
-
- let null_default = ScalarValue::Int32(None);
-
- let ColumnarValue::Array(result) =
- list_extract(&list, &indices, &null_default, false,
zero_based_index)?
- else {
- unreachable!()
- };
-
- assert_eq!(
- &result.to_data(),
- &Int32Array::from(vec![Some(1), None, None]).to_data()
- );
-
- let zero_default = ScalarValue::Int32(Some(0));
-
- let ColumnarValue::Array(result) =
- list_extract(&list, &indices, &zero_default, false,
zero_based_index)?
- else {
- unreachable!()
- };
-
- assert_eq!(
- &result.to_data(),
- &Int32Array::from(vec![Some(1), None, Some(0)]).to_data()
- );
- Ok(())
- }
-
#[test]
fn test_array_insert() -> Result<()> {
// Test inserting an item into a list array
diff --git a/native/spark-expr/src/array_funcs/get_array_struct_fields.rs
b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs
new file mode 100644
index 000000000..8b1633649
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray};
+use arrow_schema::{DataType, FieldRef, Schema};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{
+ cast::{as_large_list_array, as_list_array},
+ internal_err, DataFusionError, Result as DataFusionResult,
+};
+use datafusion_physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+ any::Any,
+ fmt::{Debug, Display, Formatter},
+ sync::Arc,
+};
+
+#[derive(Debug, Eq)]
+pub struct GetArrayStructFields {
+ child: Arc<dyn PhysicalExpr>,
+ ordinal: usize,
+}
+
+impl Hash for GetArrayStructFields {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.child.hash(state);
+ self.ordinal.hash(state);
+ }
+}
+impl PartialEq for GetArrayStructFields {
+ fn eq(&self, other: &Self) -> bool {
+ self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
+ }
+}
+
+impl GetArrayStructFields {
+ pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+ Self { child, ordinal }
+ }
+
+ fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+ match self.child.data_type(input_schema)? {
+ DataType::List(field) | DataType::LargeList(field) => Ok(field),
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected data type in GetArrayStructFields: {:?}",
+ data_type
+ ))),
+ }
+ }
+
+ fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef>
{
+ match self.list_field(input_schema)?.data_type() {
+ DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected data type in GetArrayStructFields: {:?}",
+ data_type
+ ))),
+ }
+ }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+ let struct_field = self.child_field(input_schema)?;
+ match self.child.data_type(input_schema)? {
+ DataType::List(_) => Ok(DataType::List(struct_field)),
+ DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected data type in GetArrayStructFields: {:?}",
+ data_type
+ ))),
+ }
+ }
+
+ fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+ Ok(self.list_field(input_schema)?.is_nullable()
+ || self.child_field(input_schema)?.is_nullable())
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue>
{
+ let child_value =
self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+ match child_value.data_type() {
+ DataType::List(_) => {
+ let list_array = as_list_array(&child_value)?;
+
+ get_array_struct_fields(list_array, self.ordinal)
+ }
+ DataType::LargeList(_) => {
+ let list_array = as_large_list_array(&child_value)?;
+
+ get_array_struct_fields(list_array, self.ordinal)
+ }
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected child type for ListExtract: {:?}",
+ data_type
+ ))),
+ }
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.child]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+ match children.len() {
+ 1 => Ok(Arc::new(GetArrayStructFields::new(
+ Arc::clone(&children[0]),
+ self.ordinal,
+ ))),
+ _ => internal_err!("GetArrayStructFields should have exactly one
child"),
+ }
+ }
+}
+
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+ list_array: &GenericListArray<O>,
+ ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+ let values = list_array
+ .values()
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .expect("A struct is expected");
+
+ let column = Arc::clone(values.column(ordinal));
+ let field = Arc::clone(&values.fields()[ordinal]);
+
+ let offsets = list_array.offsets();
+ let array = GenericListArray::new(field, offsets.clone(), column,
list_array.nulls().cloned());
+
+ Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+impl Display for GetArrayStructFields {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
+ self.child, self.ordinal
+ )
+ }
+}
diff --git a/native/spark-expr/src/array_funcs/list_extract.rs
b/native/spark-expr/src/array_funcs/list_extract.rs
new file mode 100644
index 000000000..c0f2291d9
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/list_extract.rs
@@ -0,0 +1,310 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::MutableArrayData, datatypes::ArrowNativeType,
record_batch::RecordBatch};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_schema::{DataType, FieldRef, Schema};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{
+ cast::{as_int32_array, as_large_list_array, as_list_array},
+ internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
+};
+use datafusion_physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+ any::Any,
+ fmt::{Debug, Display, Formatter},
+ sync::Arc,
+};
+
+#[derive(Debug, Eq)]
+pub struct ListExtract {
+ child: Arc<dyn PhysicalExpr>,
+ ordinal: Arc<dyn PhysicalExpr>,
+ default_value: Option<Arc<dyn PhysicalExpr>>,
+ one_based: bool,
+ fail_on_error: bool,
+}
+
+impl Hash for ListExtract {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.child.hash(state);
+ self.ordinal.hash(state);
+ self.default_value.hash(state);
+ self.one_based.hash(state);
+ self.fail_on_error.hash(state);
+ }
+}
+impl PartialEq for ListExtract {
+ fn eq(&self, other: &Self) -> bool {
+ self.child.eq(&other.child)
+ && self.ordinal.eq(&other.ordinal)
+ && self.default_value.eq(&other.default_value)
+ && self.one_based.eq(&other.one_based)
+ && self.fail_on_error.eq(&other.fail_on_error)
+ }
+}
+
+impl ListExtract {
+ pub fn new(
+ child: Arc<dyn PhysicalExpr>,
+ ordinal: Arc<dyn PhysicalExpr>,
+ default_value: Option<Arc<dyn PhysicalExpr>>,
+ one_based: bool,
+ fail_on_error: bool,
+ ) -> Self {
+ Self {
+ child,
+ ordinal,
+ default_value,
+ one_based,
+ fail_on_error,
+ }
+ }
+
+ fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef>
{
+ match self.child.data_type(input_schema)? {
+ DataType::List(field) | DataType::LargeList(field) => Ok(field),
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected data type in ListExtract: {:?}",
+ data_type
+ ))),
+ }
+ }
+}
+
+impl PhysicalExpr for ListExtract {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+ Ok(self.child_field(input_schema)?.data_type().clone())
+ }
+
+ fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+ // Only non-nullable if fail_on_error is enabled and the element is
non-nullable
+ Ok(!self.fail_on_error ||
self.child_field(input_schema)?.is_nullable())
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue>
{
+ let child_value =
self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+ let ordinal_value =
self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
+
+ let default_value = self
+ .default_value
+ .as_ref()
+ .map(|d| {
+ d.evaluate(batch).map(|value| match value {
+ ColumnarValue::Scalar(scalar)
+ if
!scalar.data_type().equals_datatype(child_value.data_type()) =>
+ {
+ scalar.cast_to(child_value.data_type())
+ }
+ ColumnarValue::Scalar(scalar) => Ok(scalar),
+ v => Err(DataFusionError::Execution(format!(
+ "Expected scalar default value for ListExtract, got
{:?}",
+ v
+ ))),
+ })
+ })
+ .transpose()?
+ .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
+
+ let adjust_index = if self.one_based {
+ one_based_index
+ } else {
+ zero_based_index
+ };
+
+ match child_value.data_type() {
+ DataType::List(_) => {
+ let list_array = as_list_array(&child_value)?;
+ let index_array = as_int32_array(&ordinal_value)?;
+
+ list_extract(
+ list_array,
+ index_array,
+ &default_value,
+ self.fail_on_error,
+ adjust_index,
+ )
+ }
+ DataType::LargeList(_) => {
+ let list_array = as_large_list_array(&child_value)?;
+ let index_array = as_int32_array(&ordinal_value)?;
+
+ list_extract(
+ list_array,
+ index_array,
+ &default_value,
+ self.fail_on_error,
+ adjust_index,
+ )
+ }
+ data_type => Err(DataFusionError::Internal(format!(
+ "Unexpected child type for ListExtract: {:?}",
+ data_type
+ ))),
+ }
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.child, &self.ordinal]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+ match children.len() {
+ 2 => Ok(Arc::new(ListExtract::new(
+ Arc::clone(&children[0]),
+ Arc::clone(&children[1]),
+ self.default_value.clone(),
+ self.one_based,
+ self.fail_on_error,
+ ))),
+ _ => internal_err!("ListExtract should have exactly two children"),
+ }
+ }
+}
+
+fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
+ if index == 0 {
+ return Err(DataFusionError::Execution(
+ "Invalid index of 0 for one-based ListExtract".to_string(),
+ ));
+ }
+
+ let abs_index = index.abs().as_usize();
+ if abs_index <= len {
+ if index > 0 {
+ Ok(Some(abs_index - 1))
+ } else {
+ Ok(Some(len - abs_index))
+ }
+ } else {
+ Ok(None)
+ }
+}
+
+fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>>
{
+ if index < 0 {
+ Ok(None)
+ } else {
+ let positive_index = index.as_usize();
+ if positive_index < len {
+ Ok(Some(positive_index))
+ } else {
+ Ok(None)
+ }
+ }
+}
+
+fn list_extract<O: OffsetSizeTrait>(
+ list_array: &GenericListArray<O>,
+ index_array: &Int32Array,
+ default_value: &ScalarValue,
+ fail_on_error: bool,
+ adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
+) -> DataFusionResult<ColumnarValue> {
+ let values = list_array.values();
+ let offsets = list_array.offsets();
+
+ let data = values.to_data();
+
+ let default_data = default_value.to_array()?.to_data();
+
+ let mut mutable = MutableArrayData::new(vec![&data, &default_data], true,
index_array.len());
+
+ for (row, (offset_window, index)) in
offsets.windows(2).zip(index_array.values()).enumerate() {
+ let start = offset_window[0].as_usize();
+ let len = offset_window[1].as_usize() - start;
+
+ if let Some(i) = adjust_index(*index, len)? {
+ mutable.extend(0, start + i, start + i + 1);
+ } else if list_array.is_null(row) {
+ mutable.extend_nulls(1);
+ } else if fail_on_error {
+ return Err(DataFusionError::Execution(
+ "Index out of bounds for array".to_string(),
+ ));
+ } else {
+ mutable.extend(1, 0, 1);
+ }
+ }
+
+ let data = mutable.freeze();
+ Ok(ColumnarValue::Array(arrow::array::make_array(data)))
+}
+
+impl Display for ListExtract {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?},
one_based: {:?}, fail_on_error: {:?}]",
+ self.child, self.ordinal, self.default_value, self.one_based,
self.fail_on_error
+ )
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use arrow::datatypes::Int32Type;
+ use arrow_array::{Array, Int32Array, ListArray};
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::ColumnarValue;
+
+ #[test]
+ fn test_list_extract_default_value() -> Result<()> {
+ let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(1)]),
+ None,
+ Some(vec![]),
+ ]);
+ let indices = Int32Array::from(vec![0, 0, 0]);
+
+ let null_default = ScalarValue::Int32(None);
+
+ let ColumnarValue::Array(result) =
+ list_extract(&list, &indices, &null_default, false,
zero_based_index)?
+ else {
+ unreachable!()
+ };
+
+ assert_eq!(
+ &result.to_data(),
+ &Int32Array::from(vec![Some(1), None, None]).to_data()
+ );
+
+ let zero_default = ScalarValue::Int32(Some(0));
+
+ let ColumnarValue::Array(result) =
+ list_extract(&list, &indices, &zero_default, false,
zero_based_index)?
+ else {
+ unreachable!()
+ };
+
+ assert_eq!(
+ &result.to_data(),
+ &Int32Array::from(vec![Some(1), None, Some(0)]).to_data()
+ );
+ Ok(())
+ }
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs
b/native/spark-expr/src/array_funcs/mod.rs
new file mode 100644
index 000000000..0a215f96c
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod array_insert;
+mod get_array_struct_fields;
+mod list_extract;
+
+pub use array_insert::ArrayInsert;
+pub use get_array_struct_fields::GetArrayStructFields;
+pub use list_extract::ListExtract;
diff --git a/native/spark-expr/src/bitwise_not.rs
b/native/spark-expr/src/bitwise_funcs/bitwise_not.rs
similarity index 100%
rename from native/spark-expr/src/bitwise_not.rs
rename to native/spark-expr/src/bitwise_funcs/bitwise_not.rs
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs
b/native/spark-expr/src/bitwise_funcs/mod.rs
new file mode 100644
index 000000000..9c2636331
--- /dev/null
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod bitwise_not;
+
+pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
diff --git a/native/spark-expr/src/if_expr.rs
b/native/spark-expr/src/conditional_funcs/if_expr.rs
similarity index 100%
rename from native/spark-expr/src/if_expr.rs
rename to native/spark-expr/src/conditional_funcs/if_expr.rs
diff --git a/native/spark-expr/src/conditional_funcs/mod.rs
b/native/spark-expr/src/conditional_funcs/mod.rs
new file mode 100644
index 000000000..70c459ef7
--- /dev/null
+++ b/native/spark-expr/src/conditional_funcs/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod if_expr;
+
+pub use if_expr::IfExpr;
diff --git a/native/spark-expr/src/cast.rs
b/native/spark-expr/src/conversion_funcs/cast.rs
similarity index 100%
rename from native/spark-expr/src/cast.rs
rename to native/spark-expr/src/conversion_funcs/cast.rs
diff --git a/native/spark-expr/src/conversion_funcs/mod.rs
b/native/spark-expr/src/conversion_funcs/mod.rs
new file mode 100644
index 000000000..f2c6f7ca3
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/mod.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod cast;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 22bec87ee..14982264d 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -19,17 +19,12 @@
// The lint makes easier for code reader/reviewer separate references clones
from more heavyweight ones
#![deny(clippy::clone_on_ref_ptr)]
-mod cast;
mod error;
-mod if_expr;
-mod bitwise_not;
-pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
mod checkoverflow;
pub use checkoverflow::CheckOverflow;
mod kernels;
-mod list;
pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
@@ -52,6 +47,8 @@ mod predicate_funcs;
pub use predicate_funcs::{spark_isnan, RLike};
mod agg_funcs;
+mod array_funcs;
+mod bitwise_funcs;
mod comet_scalar_funcs;
pub mod hash_funcs;
@@ -63,13 +60,19 @@ pub use agg_funcs::*;
pub use crate::{CreateNamedStruct, GetStructField};
pub use crate::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
TimestampTruncExpr};
pub use cast::{spark_cast, Cast, SparkCastOptions};
+mod conditional_funcs;
+mod conversion_funcs;
+
+pub use array_funcs::*;
+pub use bitwise_funcs::*;
+pub use conditional_funcs::*;
+pub use conversion_funcs::*;
+
pub use comet_scalar_funcs::create_comet_physical_fun;
pub use datetime_funcs::*;
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
-pub use if_expr::IfExpr;
pub use json_funcs::ToJson;
-pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
pub use string_funcs::*;
pub use struct_funcs::*;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]