This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d1975a1da3f [SPARK-39836][SQL] Simplify V2ExpressionBuilder by extracting a common method
d1975a1da3f is described below

commit d1975a1da3f3262c8df3003604cc72c2be290014
Author: Jiaan Geng <[email protected]>
AuthorDate: Mon Jul 25 15:47:33 2022 +0800

    [SPARK-39836][SQL] Simplify V2ExpressionBuilder by extracting a common method
    
    ### What changes were proposed in this pull request?
    Currently, `V2ExpressionBuilder` has a lot of similar code; we can extract it into one common method.
    
    We can then simplify the implementation by calling this common method.
    
    ### Why are the changes needed?
    Simplify `V2ExpressionBuilder` by extracting a common method.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    This only changes the internal implementation.
    
    ### How was this patch tested?
    N/A
    
    Closes #37249 from beliefer/SPARK-39836.
    
    Authored-by: Jiaan Geng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 252 +++++----------------
 1 file changed, 59 insertions(+), 193 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 59cbcf48334..70c85def45d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -90,124 +90,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       }
     case Cast(child, dataType, _, true) =>
       generateExpression(child).map(v => new V2Cast(v, dataType))
-    case Abs(child, true) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
-    case Coalesce(children) =>
-      val childrenExpressions = children.flatMap(generateExpression(_))
-      if (children.length == childrenExpressions.length) {
-        Some(new GeneralScalarExpression("COALESCE", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case Greatest(children) =>
-      val childrenExpressions = children.flatMap(generateExpression(_))
-      if (children.length == childrenExpressions.length) {
-        Some(new GeneralScalarExpression("GREATEST", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case Least(children) =>
-      val childrenExpressions = children.flatMap(generateExpression(_))
-      if (children.length == childrenExpressions.length) {
-        Some(new GeneralScalarExpression("LEAST", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
+    case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
+    case Coalesce(children) => generateExpressionWithName("COALESCE", children)
+    case Greatest(children) => generateExpressionWithName("GREATEST", children)
+    case Least(children) => generateExpressionWithName("LEAST", children)
     case Rand(child, hideSeed) =>
       if (hideSeed) {
         Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
       } else {
-        generateExpression(child)
-          .map(v => new GeneralScalarExpression("RAND", 
Array[V2Expression](v)))
-      }
-    case log: Logarithm =>
-      val l = generateExpression(log.left)
-      val r = generateExpression(log.right)
-      if (l.isDefined && r.isDefined) {
-        Some(new GeneralScalarExpression("LOG", Array[V2Expression](l.get, 
r.get)))
-      } else {
-        None
-      }
-    case Log10(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v)))
-    case Log2(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("LOG2", Array[V2Expression](v)))
-    case Log(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
-    case Exp(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
-    case Pow(left, right) =>
-      val l = generateExpression(left)
-      val r = generateExpression(right)
-      if (l.isDefined && r.isDefined) {
-        Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, 
r.get)))
-      } else {
-        None
-      }
-    case Sqrt(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
-    case Floor(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
-    case Ceil(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
-    case round: Round =>
-      val l = generateExpression(round.left)
-      val r = generateExpression(round.right)
-      if (l.isDefined && r.isDefined) {
-        Some(new GeneralScalarExpression("ROUND", Array[V2Expression](l.get, 
r.get)))
-      } else {
-        None
-      }
-    case Sin(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("SIN", Array[V2Expression](v)))
-    case Sinh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("SINH", Array[V2Expression](v)))
-    case Cos(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("COS", Array[V2Expression](v)))
-    case Cosh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("COSH", Array[V2Expression](v)))
-    case Tan(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("TAN", Array[V2Expression](v)))
-    case Tanh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("TANH", Array[V2Expression](v)))
-    case Cot(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("COT", Array[V2Expression](v)))
-    case Asin(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ASIN", Array[V2Expression](v)))
-    case Asinh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ASINH", Array[V2Expression](v)))
-    case Acos(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ACOS", Array[V2Expression](v)))
-    case Acosh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ACOSH", Array[V2Expression](v)))
-    case Atan(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v)))
-    case Atanh(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v)))
-    case atan2: Atan2 =>
-      val l = generateExpression(atan2.left)
-      val r = generateExpression(atan2.right)
-      if (l.isDefined && r.isDefined) {
-        Some(new GeneralScalarExpression("ATAN2", Array[V2Expression](l.get, 
r.get)))
-      } else {
-        None
-      }
-    case Cbrt(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v)))
-    case ToDegrees(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("DEGREES", Array[V2Expression](v)))
-    case ToRadians(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v)))
-    case Signum(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v)))
-    case wb: WidthBucket =>
-      val childrenExpressions = wb.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == wb.children.length) {
-        Some(new GeneralScalarExpression("WIDTH_BUCKET",
-          childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
+        generateExpressionWithName("RAND", Seq(child))
+      }
+    case log: Logarithm => generateExpressionWithName("LOG", log.children)
+    case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
+    case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
+    case Log(child) => generateExpressionWithName("LN", Seq(child))
+    case Exp(child) => generateExpressionWithName("EXP", Seq(child))
+    case pow: Pow => generateExpressionWithName("POWER", pow.children)
+    case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
+    case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
+    case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
+    case round: Round => generateExpressionWithName("ROUND", round.children)
+    case Sin(child) => generateExpressionWithName("SIN", Seq(child))
+    case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
+    case Cos(child) => generateExpressionWithName("COS", Seq(child))
+    case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
+    case Tan(child) => generateExpressionWithName("TAN", Seq(child))
+    case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
+    case Cot(child) => generateExpressionWithName("COT", Seq(child))
+    case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
+    case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
+    case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
+    case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
+    case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
+    case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
+    case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
+    case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
+    case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
+    case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
+    case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
+    case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", 
wb.children)
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
@@ -258,10 +179,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
         assert(v.isInstanceOf[V2Predicate])
         new V2Not(v.asInstanceOf[V2Predicate])
       }
-    case UnaryMinus(child, true) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
-    case BitwiseNot(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
+    case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
+    case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
     case CaseWhen(branches, elseValue) =>
       val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
       val values = branches.map(_._2).flatMap(generateExpression(_, true))
@@ -282,93 +201,30 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       } else {
         None
       }
-    case iff: If =>
-      val childrenExpressions = iff.children.flatMap(generateExpression(_))
-      if (iff.children.length == childrenExpressions.length) {
-        Some(new GeneralScalarExpression("CASE_WHEN", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
+    case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
     case substring: Substring =>
       val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
         Seq(substring.str, substring.pos)
       } else {
         substring.children
       }
-      val childrenExpressions = children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == children.length) {
-        Some(new GeneralScalarExpression("SUBSTRING",
-          childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case Upper(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
-    case Lower(child) => generateExpression(child)
-      .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
-    case translate: StringTranslate =>
-      val childrenExpressions = 
translate.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == translate.children.length) {
-        Some(new GeneralScalarExpression("TRANSLATE",
-          childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case trim: StringTrim =>
-      val childrenExpressions = trim.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == trim.children.length) {
-        Some(new GeneralScalarExpression("TRIM", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case trim: StringTrimLeft =>
-      val childrenExpressions = trim.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == trim.children.length) {
-        Some(new GeneralScalarExpression("LTRIM", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case trim: StringTrimRight =>
-      val childrenExpressions = trim.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == trim.children.length) {
-        Some(new GeneralScalarExpression("RTRIM", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
+      generateExpressionWithName("SUBSTRING", children)
+    case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
+    case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
+    case translate: StringTranslate => generateExpressionWithName("TRANSLATE", 
translate.children)
+    case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
+    case trim: StringTrimLeft => generateExpressionWithName("LTRIM", 
trim.children)
+    case trim: StringTrimRight => generateExpressionWithName("RTRIM", 
trim.children)
     case overlay: Overlay =>
       val children = if (overlay.len == Literal(-1)) {
         Seq(overlay.input, overlay.replace, overlay.pos)
       } else {
         overlay.children
       }
-      val childrenExpressions = children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == children.length) {
-        Some(new GeneralScalarExpression("OVERLAY",
-          childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case date: DateAdd =>
-      val childrenExpressions = date.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == date.children.length) {
-        Some(new GeneralScalarExpression("DATE_ADD", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case date: DateDiff =>
-      val childrenExpressions = date.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == date.children.length) {
-        Some(new GeneralScalarExpression("DATE_DIFF", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
-    case date: TruncDate =>
-      val childrenExpressions = date.children.flatMap(generateExpression(_))
-      if (childrenExpressions.length == date.children.length) {
-        Some(new GeneralScalarExpression("TRUNC", 
childrenExpressions.toArray[V2Expression]))
-      } else {
-        None
-      }
+      generateExpressionWithName("OVERLAY", children)
+    case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
+    case date: DateDiff => generateExpressionWithName("DATE_DIFF", 
date.children)
+    case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
     case Second(child, _) =>
       generateExpression(child).map(v => new V2Extract("SECOND", v))
     case Minute(child, _) =>
@@ -429,6 +285,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) {
       case _ => operatorName
     }
   }
+
+  private def generateExpressionWithName(
+      v2ExpressionName: String, children: Seq[Expression]): 
Option[V2Expression] = {
+    val childrenExpressions = children.flatMap(generateExpression(_))
+    if (childrenExpressions.length == children.length) {
+      Some(new GeneralScalarExpression(v2ExpressionName, 
childrenExpressions.toArray[V2Expression]))
+    } else {
+      None
+    }
+  }
 }
 
 object ColumnOrField {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to