This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a691131f5c6 [SPARK-50173][PS][SQL] Make pandas expressions accept 
more datatypes
1a691131f5c6 is described below

commit 1a691131f5c6ecf84acd01b63f1e3ce53da36da9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Oct 31 15:46:16 2024 +0900

    [SPARK-50173][PS][SQL] Make pandas expressions accept more datatypes
    
    ### What changes were proposed in this pull request?
    Make `ddof` in pandas expressions accept all IntegralType
    
    ### Why are the changes needed?
    E.g. for `ddof`, the existing implementation requires it to be exactly an int;
 this is too strict for directly using those expressions.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48706 from zhengruifeng/ps_ddof_int.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../catalyst/expressions/aggregate/PandasAggregate.scala    | 13 ++++++++-----
 .../spark/sql/catalyst/expressions/aggregate/collect.scala  | 11 +++++++----
 .../spark/sql/catalyst/expressions/windowExpressions.scala  |  5 +++--
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
index 10ea6069ed3d..07f00ad03504 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PandasAggregate.scala
@@ -16,17 +16,20 @@
  */
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.{BooleanLiteral, Expression, 
IntegerLiteral}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.errors.QueryCompilationErrors
 
 private[expressions] object PandasAggregate {
-  def expressionToIgnoreNA(e: Expression, source: String): Boolean = e match {
-    case BooleanLiteral(ignoreNA) => ignoreNA
+  def expressionToIgnoreNA(e: Expression, source: String): Boolean = e.eval() 
match {
+    case b: Boolean => b
     case _ => throw QueryCompilationErrors.invalidIgnoreNAParameter(source, e)
   }
 
-  def expressionToDDOF(e: Expression, source: String): Int = e match {
-    case IntegerLiteral(ddof) => ddof
+  def expressionToDDOF(e: Expression, source: String): Int = e.eval() match {
+    case l: Long => l.toInt
+    case i: Int => i
+    case s: Short => s.toInt
+    case b: Byte => b.toInt
     case _ => throw QueryCompilationErrors.invalidDdofParameter(source, e)
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index adad04e7d749..3aaf353043a9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -252,13 +252,16 @@ case class CollectTopK(
 }
 
 private[aggregate] object CollectTopK {
-  def expressionToReverse(e: Expression): Boolean = e match {
-    case BooleanLiteral(reverse) => reverse
+  def expressionToReverse(e: Expression): Boolean = e.eval() match {
+    case b: Boolean => b
     case _ => throw QueryCompilationErrors.invalidReverseParameter(e)
   }
 
-  def expressionToNum(e: Expression): Int = e match {
-    case IntegerLiteral(num) => num
+  def expressionToNum(e: Expression): Int = e.eval() match {
+    case l: Long => l.toInt
+    case i: Int => i
+    case s: Short => s.toInt
+    case b: Byte => b.toInt
     case _ => throw QueryCompilationErrors.invalidNumParameter(e)
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index f31f3d26ee59..ecc32bc8d0ef 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -1171,8 +1171,9 @@ case class EWM(input: Expression, alpha: Double, 
ignoreNA: Boolean)
 }
 
 private[expressions] object EWM {
-  def expressionToAlpha(e: Expression): Double = e match {
-    case DoubleLiteral(alpha) => alpha
+  def expressionToAlpha(e: Expression): Double = e.eval() match {
+    case d: Double => d
+    case f: Float => f.toDouble
     case _ => throw QueryCompilationErrors.invalidAlphaParameter(e)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to