This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 62e8c2a79 feat: add support for `width_bucket` expression (#3273)
62e8c2a79 is described below
commit 62e8c2a79e4c7b6e443ed948c06fa35ce5cc8ac1
Author: David López <[email protected]>
AuthorDate: Wed Jan 28 17:50:47 2026 +0100
feat: add support for `width_bucket` expression (#3273)
---
docs/spark_expressions_support.md | 2 +-
native/core/src/execution/jni_api.rs | 2 +
.../org/apache/comet/shims/CometExprShim.scala | 7 ++-
.../org/apache/comet/shims/CometExprShim.scala | 7 ++-
.../apache/comet/CometMathExpressionSuite.scala | 55 ++++++++++++++++++++++
5 files changed, 70 insertions(+), 3 deletions(-)
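For anyone unfamiliar with the expression: `width_bucket(v, minValue, maxValue, numBuckets)` assigns `v` to one of `numBuckets` equal-width buckets spanning [minValue, maxValue), returning 0 for values below the range and numBuckets + 1 for values above it; as the tests below note, Spark returns NULL for invalid arguments such as zero buckets. A minimal illustration in Scala, assuming a Spark 3.5+ session:

    // Bucket width is (10.6 - 0.2) / 5 = 2.08, so
    // floor((5.3 - 0.2) / 2.08) + 1 = 3: the value lands in bucket 3.
    spark.sql("SELECT width_bucket(5.3, 0.2, 10.6, 5)").show() // prints 3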
diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index fa6b3a43f..27b6ad3b5 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -349,7 +349,7 @@
- [x] try_multiply
- [x] try_subtract
- [x] unhex
-- [ ] width_bucket
+- [x] width_bucket
### misc_funcs
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index e9f2d6523..2022aef75 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -49,6 +49,7 @@ use datafusion_spark::function::hash::sha1::SparkSha1;
use datafusion_spark::function::hash::sha2::SparkSha2;
use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::math::hex::SparkHex;
+use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::string::concat::SparkConcat;
use futures::poll;
@@ -351,6 +352,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
+session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
}
/// Prepares arrow arrays for output.
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
index 171216023..d9b80ab48 100644
--- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
@@ -26,7 +26,7 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
-import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
@@ -81,6 +81,11 @@ trait CometExprShim extends CommonStringExprs {
None
}
+ case wb: WidthBucket =>
+ val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding))
+ val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*)
+ optExprWithInfo(optExpr, wb, wb.children: _*)
+
case _ => None
}
}
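For readers new to the Comet serde layer, the hunk above follows the standard pattern: each child is converted to its protobuf form with `exprToProtoInternal`, `scalarFunctionExprToProto` builds the scalar-function node only when every child converted, and `optExprWithInfo` records fallback diagnostics on the original Spark node otherwise. A commented sketch of the same shape, using an invented `MyFunc` expression purely for illustration:

    case mf: MyFunc => // hypothetical expression, not part of this commit
      // Convert each child; a child that cannot be serialized yields None
      val childExprs = mf.children.map(exprToProtoInternal(_, inputs, binding))
      // Some(expr) only if every child converted successfully, else None
      val optExpr = scalarFunctionExprToProto("my_func", childExprs: _*)
      // On None, tags the Spark node with the reason for falling back
      optExprWithInfo(optExpr, mf, mf.children: _*)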
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index fc3db183b..1d4427d15 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -29,7 +29,7 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
-import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
@@ -103,6 +103,11 @@ trait CometExprShim extends CommonStringExprs {
None
}
+ case wb: WidthBucket =>
+ val childExprs = wb.children.map(exprToProtoInternal(_, inputs, binding))
+ val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*)
+ optExprWithInfo(optExpr, wb, wb.children: _*)
+
case _ => None
}
}
diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
index 8ea4a9c88..9d27f2d25 100644
--- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -90,4 +91,58 @@ class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe
1000,
DataGenOptions(generateNegativeZero = generateNegativeZero))
}
+
+ test("width_bucket") {
+ assume(isSpark35Plus, "width_bucket was added in Spark 3.5")
+ withSQLConf("spark.comet.exec.localTableScan.enabled" -> "true") {
+ spark
+ .createDataFrame(
+ Seq((5.3, 0.2, 10.6, 5), (8.1, 0.0, 5.7, 4), (-0.9, 5.2, 0.5, 2), (-2.1, 1.3, 3.4, 3)))
+ .toDF("c1", "c2", "c3", "c4")
+ .createOrReplaceTempView("width_bucket_test")
+ checkSparkAnswerAndOperator(
+ "SELECT c1, width_bucket(c1, c2, c3, c4) FROM width_bucket_test")
+ }
+ }
+
+ test("width_bucket - edge cases") {
+ assume(isSpark35Plus, "width_bucket was added in Spark 3.5")
+ withSQLConf("spark.comet.exec.localTableScan.enabled" -> "true") {
+ spark
+ .createDataFrame(Seq(
+ (0.0, 10.0, 0.0, 5), // Value equals max (reversed bounds)
+ (10.0, 0.0, 10.0, 5), // Value equals max (normal bounds)
+ (10.0, 0.0, 0.0, 5), // Min equals max - returns NULL
+ (5.0, 0.0, 10.0, 0) // Zero buckets - returns NULL
+ ))
+ .toDF("c1", "c2", "c3", "c4")
+ .createOrReplaceTempView("width_bucket_edge")
+ checkSparkAnswerAndOperator(
+ "SELECT c1, width_bucket(c1, c2, c3, c4) FROM width_bucket_edge")
+ }
+ }
+
+ test("width_bucket - NaN values") {
+ assume(isSpark35Plus, "width_bucket was added in Spark 3.5")
+ withSQLConf("spark.comet.exec.localTableScan.enabled" -> "true") {
+ spark
+ .createDataFrame(
+ Seq((Double.NaN, 5.0, 0.0), (5.0, Double.NaN, 0.0), (5.0, 0.0, Double.NaN)))
+ .toDF("c1", "c2", "c3")
+ .createOrReplaceTempView("width_bucket_nan")
+ checkSparkAnswerAndOperator("SELECT c1, width_bucket(c1, c2, c3, 5) FROM width_bucket_nan")
+ }
+ }
+
+ test("width_bucket - with range data") {
+ assume(isSpark35Plus, "width_bucket was added in Spark 3.5")
+ withSQLConf("spark.comet.exec.localTableScan.enabled" -> "true") {
+ spark
+ .range(10)
+ .selectExpr("id", "CAST(id AS DOUBLE) as value")
+ .createOrReplaceTempView("width_bucket_range")
+ checkSparkAnswerAndOperator(
+ "SELECT id, width_bucket(value, 0.0, 10.0, 5) FROM width_bucket_range
ORDER BY id")
+ }
+ }
}
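As a quick manual check outside the suite (a sketch, assuming Comet is enabled via `spark.plugins=org.apache.spark.CometPlugin`), the physical plan should now show Comet operators rather than a Spark fallback for this expression:

    val df = spark
      .range(10)
      .selectExpr("width_bucket(CAST(id AS DOUBLE), 0.0, 10.0, 5) AS bucket")
    df.explain() // expect Comet operators in the plan rather than a fallback
    df.show()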
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]