This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1a691131f5c6 [SPARK-50173][PS][SQL] Make pandas expressions accept
more datatypes
1a691131f5c6 is described below
commit 1a691131f5c6ecf84acd01b63f1e3ce53da36da9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Oct 31 15:46:16 2024 +0900
[SPARK-50173][PS][SQL] Make pandas expressions accept more datatypes
### What changes were proposed in this pull request?
Make `ddof` in pandas expressions accept all IntegralType
### Why are the changes needed?
E.g. for `ddof`, the existing implementation requires it to be exactly an int;
this is too strict for directly using those expressions.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48706 from zhengruifeng/ps_ddof_int.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../catalyst/expressions/aggregate/PandasAggregate.scala | 13 ++++++++-----
.../spark/sql/catalyst/expressions/aggregate/collect.scala | 11 +++++++----
.../spark/sql/catalyst/expressions/windowExpressions.scala | 5 +++--
3 files changed, 18 insertions(+), 11 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
index 10ea6069ed3d..07f00ad03504 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
@@ -16,17 +16,20 @@
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.expressions.{BooleanLiteral, Expression,
IntegerLiteral}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.errors.QueryCompilationErrors
private[expressions] object PandasAggregate {
- def expressionToIgnoreNA(e: Expression, source: String): Boolean = e match {
- case BooleanLiteral(ignoreNA) => ignoreNA
+ def expressionToIgnoreNA(e: Expression, source: String): Boolean = e.eval()
match {
+ case b: Boolean => b
case _ => throw QueryCompilationErrors.invalidIgnoreNAParameter(source, e)
}
- def expressionToDDOF(e: Expression, source: String): Int = e match {
- case IntegerLiteral(ddof) => ddof
+ def expressionToDDOF(e: Expression, source: String): Int = e.eval() match {
+ case l: Long => l.toInt
+ case i: Int => i
+ case s: Short => s.toInt
+ case b: Byte => b.toInt
case _ => throw QueryCompilationErrors.invalidDdofParameter(source, e)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index adad04e7d749..3aaf353043a9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -252,13 +252,16 @@ case class CollectTopK(
}
private[aggregate] object CollectTopK {
- def expressionToReverse(e: Expression): Boolean = e match {
- case BooleanLiteral(reverse) => reverse
+ def expressionToReverse(e: Expression): Boolean = e.eval() match {
+ case b: Boolean => b
case _ => throw QueryCompilationErrors.invalidReverseParameter(e)
}
- def expressionToNum(e: Expression): Int = e match {
- case IntegerLiteral(num) => num
+ def expressionToNum(e: Expression): Int = e.eval() match {
+ case l: Long => l.toInt
+ case i: Int => i
+ case s: Short => s.toInt
+ case b: Byte => b.toInt
case _ => throw QueryCompilationErrors.invalidNumParameter(e)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index f31f3d26ee59..ecc32bc8d0ef 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -1171,8 +1171,9 @@ case class EWM(input: Expression, alpha: Double,
ignoreNA: Boolean)
}
private[expressions] object EWM {
- def expressionToAlpha(e: Expression): Double = e match {
- case DoubleLiteral(alpha) => alpha
+ def expressionToAlpha(e: Expression): Double = e.eval() match {
+ case d: Double => d
+ case f: Float => f.toDouble
case _ => throw QueryCompilationErrors.invalidAlphaParameter(e)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]