This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new bfdc0f62c fix: Correct GetArrayItem null handling for dynamic indices
and re-enable native execution (#3709)
bfdc0f62c is described below
commit bfdc0f62c662c777c848a7c22c86dbc89a577b24
Author: ChenChen Lai <[email protected]>
AuthorDate: Wed Mar 18 04:45:40 2026 +0800
fix: Correct GetArrayItem null handling for dynamic indices and re-enable
native execution (#3709)
* Correct GetArrayItem null handling for dynamic indices
* Run the get_array_item.sql queries in the default query mode
---
native/spark-expr/src/array_funcs/list_extract.rs | 83 ++++++++++++++++------
.../main/scala/org/apache/comet/serde/arrays.scala | 7 --
.../sql-tests/expressions/array/get_array_item.sql | 6 +-
.../org/apache/comet/CometExpressionSuite.scala | 5 +-
.../apache/comet/exec/CometNativeReaderSuite.scala | 4 +-
.../apache/comet/parquet/ParquetReadSuite.scala | 5 +-
6 files changed, 69 insertions(+), 41 deletions(-)
diff --git a/native/spark-expr/src/array_funcs/list_extract.rs
b/native/spark-expr/src/array_funcs/list_extract.rs
index 2cbd62a2c..d3661f496 100644
--- a/native/spark-expr/src/array_funcs/list_extract.rs
+++ b/native/spark-expr/src/array_funcs/list_extract.rs
@@ -277,33 +277,38 @@ fn list_extract<O: OffsetSizeTrait>(
let mut mutable = MutableArrayData::new(vec![&data, &default_data], true,
index_array.len());
- for (row, (offset_window, index)) in
offsets.windows(2).zip(index_array.values()).enumerate() {
+ for (row, (offset_window, index)) in
offsets.windows(2).zip(index_array.iter()).enumerate() {
let start = offset_window[0].as_usize();
let len = offset_window[1].as_usize() - start;
- if let Some(i) = adjust_index(*index, len)? {
- mutable.extend(0, start + i, start + i + 1);
- } else if list_array.is_null(row) {
+ if list_array.is_null(row) {
mutable.extend_nulls(1);
- } else if fail_on_error {
- // Throw appropriate error based on whether this is element_at
(one_based=true)
- // or GetArrayItem (one_based=false)
- let error = if one_based {
- // element_at function
- SparkError::InvalidElementAtIndex {
- index_value: *index,
- array_size: len as i32,
- }
+ } else if let Some(index) = index {
+ if let Some(i) = adjust_index(index, len)? {
+ mutable.extend(0, start + i, start + i + 1);
+ } else if fail_on_error {
+ // Throw appropriate error based on whether this is element_at
(one_based=true)
+ // or GetArrayItem (one_based=false)
+ let error = if one_based {
+ // element_at function
+ SparkError::InvalidElementAtIndex {
+ index_value: index,
+ array_size: len as i32,
+ }
+ } else {
+ // GetArrayItem (arr[index])
+ SparkError::InvalidArrayIndex {
+ index_value: index,
+ array_size: len as i32,
+ }
+ };
+ return Err(error_wrapper(error));
} else {
- // GetArrayItem (arr[index])
- SparkError::InvalidArrayIndex {
- index_value: *index,
- array_size: len as i32,
- }
- };
- return Err(error_wrapper(error));
+ mutable.extend(1, 0, 1);
+ }
} else {
- mutable.extend(1, 0, 1);
+ // index is NULL → result is NULL
+ mutable.extend_nulls(1);
}
}
@@ -382,4 +387,40 @@ mod test {
);
Ok(())
}
+
+ #[test]
+ fn test_list_extract_null_index() -> Result<()> {
+ // GetArrayItem returns incorrect results with dynamic (column) index
containing nulls
+ let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(10), Some(20), Some(30)]),
+ Some(vec![Some(10), Some(20), Some(30)]),
+ Some(vec![Some(10), Some(20), Some(30)]),
+ Some(vec![Some(1)]),
+ None,
+ Some(vec![Some(10), Some(20)]),
+ ]);
+ let indices = Int32Array::from(vec![Some(0), Some(1), Some(2),
Some(0), Some(0), None]);
+
+ let null_default = ScalarValue::Int32(None);
+ let error_wrapper = |error: SparkError| DataFusionError::from(error);
+
+ let ColumnarValue::Array(result) = list_extract(
+ &list,
+ &indices,
+ &null_default,
+ false,
+ false,
+ |idx, len| zero_based_index(idx, len, &error_wrapper),
+ &error_wrapper,
+ )?
+ else {
+ unreachable!()
+ };
+
+ assert_eq!(
+ &result.to_data(),
+ &Int32Array::from(vec![Some(10), Some(20), Some(30), Some(1),
None, None]).to_data()
+ );
+ Ok(())
+ }
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index c82018fe6..298c47308 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -486,13 +486,6 @@ object CometCreateArray extends
CometExpressionSerde[CreateArray] {
object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
- override def getSupportLevel(expr: GetArrayItem): SupportLevel =
- Incompatible(
- Some(
- "Known correctness issues with index handling" +
- " (https://github.com/apache/datafusion-comet/issues/3330," +
- " https://github.com/apache/datafusion-comet/issues/3332)"))
-
override def convert(
expr: GetArrayItem,
inputs: Seq[Attribute],
diff --git
a/spark/src/test/resources/sql-tests/expressions/array/get_array_item.sql
b/spark/src/test/resources/sql-tests/expressions/array/get_array_item.sql
index 99e6d68f9..c7ebdb514 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/get_array_item.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/get_array_item.sql
@@ -23,12 +23,12 @@ CREATE TABLE test_get_array_item(arr array<int>, idx int)
USING parquet
statement
INSERT INTO test_get_array_item VALUES (array(10, 20, 30), 0), (array(10, 20,
30), 1), (array(10, 20, 30), 2), (array(1), 0), (NULL, 0), (array(10, 20), NULL)
-query spark_answer_only
+query
SELECT arr[0], arr[1], arr[2] FROM test_get_array_item
-query ignore(https://github.com/apache/datafusion-comet/issues/3332)
+query
SELECT arr[idx] FROM test_get_array_item
-- literal arguments
-query spark_answer_only
+query
SELECT array(10, 20, 30)[0], array(10, 20, 30)[2], array()[0]
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index eeaf1ed91..68c1a82f1 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.Tag
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Ceil, Floor,
FromUnixTime, GetArrayItem, Literal, StructsToJson, Tan, TruncDate,
TruncTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Ceil, Floor,
FromUnixTime, Literal, StructsToJson, Tan, TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
@@ -2587,8 +2587,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
withSQLConf(
SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString(),
// Prevent the optimizer from collapsing an extract value of a
create array
- SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
SimplifyExtractValueOps.ruleName,
- CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem]) ->
"true") {
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
SimplifyExtractValueOps.ruleName) {
val df = spark.read.parquet(path.toString)
val stringArray = df.select(array(col("_8"), col("_8"),
lit(null)).alias("arr"))
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index 05e82bdfb..a369e183c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.types.{IntegerType, StringType,
StructType}
import org.apache.comet.CometConf
class CometNativeReaderSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
- import org.apache.spark.sql.catalyst.expressions.GetArrayItem
override protected def test(testName: String, testTags: Tag*)(testFun: =>
Any)(implicit
pos: Position): Unit = {
@@ -42,8 +41,7 @@ class CometNativeReaderSuite extends CometTestBase with
AdaptiveSparkPlanHelper
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false",
- CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan,
- CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem]) ->
"true") {
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan) {
testFun
}
})
diff --git
a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
index 09a2308e3..b9caa9430 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
@@ -35,7 +35,6 @@ import org.apache.parquet.example.data.simple.SimpleGroup
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkException
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.GetArrayItem
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1506,9 +1505,7 @@ class ParquetReadV1Suite extends ParquetReadSuite with
AdaptiveSparkPlanHelper {
withParquetTable(path.toUri.toString, "complex_types") {
Seq(CometConf.SCAN_NATIVE_DATAFUSION,
CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach(
scanMode => {
- withSQLConf(
- CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode,
- CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem])
-> "true") {
+ withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode) {
checkSparkAnswerAndOperator(sql("select * from complex_types"))
// First level
checkSparkAnswerAndOperator(sql(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]