This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 04f031a  [SPARK-34086][SQL] RaiseError generates too much code and may fail codegen in length check for char varchar
04f031a is described below

commit 04f031acb38f9473802d2890b82dd6db66e052be
Author: Kent Yao <y...@apache.org>
AuthorDate: Wed Jan 13 09:52:36 2021 +0000

    [SPARK-34086][SQL] RaiseError generates too much code and may fail codegen in length check for char varchar
    
    ### What changes were proposed in this pull request?
    
    
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/133928/testReport/org.apache.spark.sql.execution/LogicalPlanTagInSparkPlanSuite/q41/
    
    We can reduce the generated code by more than 8000 bytes by removing the unnecessary CONCAT expression.
    
    With this fix, for q41 in TPCDS with [Using TPCDS original definitions for char/varchar columns](https://github.com/apache/spark/pull/31012) applied, we can reduce the stage code-gen size from 22523 to 14369 bytes:
    ```
    14369  - 22523 = - 8154
    ```
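    
    To make the shape of the change concrete, here is a minimal sketch (assuming a Spark 3.1+ build on the classpath; the object and method names are illustrative, not part of this patch): the old length-check branch concatenated the input length into the error message for every row and cast the NULL-typed RaiseError back to string, while the new one raises a constant message and lets RaiseError carry the target data type, so far less code is generated per char/varchar column.
    ```scala
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.unsafe.types.UTF8String

    object RaiseErrorCodegenSketch {
      // Before: the message is assembled at runtime from the input length, and the
      // NULL-typed RaiseError has to be cast back to string.
      def oldErrorBranch(expr: Expression, typeName: String, length: Int): Expression =
        Cast(RaiseError(Concat(Seq(
          Literal("input string of length "),
          Cast(Length(expr), StringType),
          Literal(s" exceeds $typeName type length limitation: $length")))), StringType)

      // After: a constant message, and RaiseError already carries the desired data type,
      // so no Concat and no extra Cast appear in the generated code.
      def newErrorBranch(typeName: String, length: Int): Expression =
        RaiseError(
          Literal(UTF8String.fromString(s"Exceeds $typeName type length limitation: $length"), StringType),
          StringType)
    }
    ```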
    
    ### Why are the changes needed?
    
    Fix the performance regression (other improvements are still needed for q41 to work); there will be a huge performance regression if codegen fails.
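    
    One way to observe this (a sketch, not part of this patch; it assumes an active SparkSession named `spark` and an existing table `t` with a char/varchar column) is to print the whole-stage generated code:
    ```scala
    import org.apache.spark.sql.execution.debug._

    // Prints the generated code for each WholeStageCodegen subtree, which makes it
    // easy to compare the size of the length-check branch before and after this
    // change; smaller generated methods keep the query on the codegen path.
    spark.table("t").debugCodegen()
    ```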
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Modified unit tests.
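    
    For reference, a minimal sketch of the behavior the modified tests assert on (assumes an active SparkSession `spark`; the table name and the `parquet` format are illustrative):
    ```scala
    import org.apache.spark.SparkException

    spark.sql("CREATE TABLE t(c CHAR(5)) USING parquet")
    try {
      spark.sql("INSERT INTO t VALUES ('123456')")
    } catch {
      case e: SparkException =>
        // Old message: "input string of length 6 exceeds char type length limitation: 5"
        // New message: "Exceeds char type length limitation: 5"
        assert(e.getCause.getMessage.contains("Exceeds char type length limitation: 5"))
    }
    ```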
    
    Closes #31150 from yaooqinn/SPARK-34086.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/expressions/misc.scala      | 10 ++++-
 .../spark/sql/catalyst/util/CharVarcharUtils.scala | 14 +++----
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 45 ++++++++--------------
 3 files changed, 29 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4ad4c4d..6b3b949 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -66,11 +66,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
   """,
   since = "3.1.0",
   group = "misc_funcs")
-case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class RaiseError(child: Expression, dataType: DataType)
+  extends UnaryExpression with ImplicitCastInputTypes {
+
+  def this(child: Expression) = this(child, NullType)
 
   override def foldable: Boolean = false
   override def nullable: Boolean = true
-  override def dataType: DataType = NullType
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
 
   override def prettyName: String = "raise_error"
@@ -100,6 +102,10 @@ case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCa
   }
 }
 
+object RaiseError {
+  def apply(child: Expression): RaiseError = new RaiseError(child)
+}
+
 /**
  * A function that throws an exception if 'condition' is not true.
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index eaafe35..5fc070a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 object CharVarcharUtils extends Logging {
 
@@ -202,12 +203,9 @@ object CharVarcharUtils extends Logging {
     }.getOrElse(expr)
   }
 
-  private def raiseError(expr: Expression, typeName: String, length: Int): Expression = {
-    val errorMsg = Concat(Seq(
-      Literal("input string of length "),
-      Cast(Length(expr), StringType),
-      Literal(s" exceeds $typeName type length limitation: $length")))
-    Cast(RaiseError(errorMsg), StringType)
+  private def raiseError(typeName: String, length: Int): Expression = {
+    val errMsg = UTF8String.fromString(s"Exceeds $typeName type length limitation: $length")
+    RaiseError(Literal(errMsg, StringType), StringType)
   }
 
   private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match {
@@ -217,7 +215,7 @@ object CharVarcharUtils extends Logging {
       // spaces, as we will pad char type columns/fields at read time.
       If(
         GreaterThan(Length(trimmed), Literal(length)),
-        raiseError(expr, "char", length),
+        raiseError("char", length),
         trimmed)
 
     case VarcharType(length) =>
@@ -230,7 +228,7 @@ object CharVarcharUtils extends Logging {
         expr,
         If(
           GreaterThan(Length(trimmed), Literal(length)),
-          raiseError(expr, "varchar", length),
+          raiseError("varchar", length),
           StringRPad(trimmed, Literal(length))))
 
     case StructType(fields) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index 7546e88..fbf3f2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -189,8 +189,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (null)")
       checkAnswer(spark.table("t"), Row(null))
       val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -202,8 +201,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         sql("INSERT INTO t VALUES (1, null)")
         checkAnswer(spark.table("t"), Row(1, null))
        val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')"))
-        assert(e.getCause.getMessage.contains(
-          s"input string of length 6 exceeds $typeName type length limitation: 5"))
+        assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
       }
     }
   }
@@ -214,8 +212,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t SELECT struct(null)")
       checkAnswer(spark.table("t"), Row(Row(null)))
      val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -225,8 +222,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (array(null))")
       checkAnswer(spark.table("t"), Row(Seq(null)))
      val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -234,8 +230,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
     testTableWrite { typeName =>
       sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format")
      val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -245,8 +240,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (map('a', null))")
       checkAnswer(spark.table("t"), Row(Map("a" -> null)))
      val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -254,11 +248,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
     testTableWrite { typeName =>
       sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format")
      val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
-      assert(e1.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e1.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
      val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
-      assert(e2.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e2.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -268,8 +260,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t SELECT struct(array(null))")
       checkAnswer(spark.table("t"), Row(Row(Seq(null))))
      val e = intercept[SparkException](sql("INSERT INTO t SELECT struct(array('123456'))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -279,8 +270,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (array(struct(null)))")
       checkAnswer(spark.table("t"), Row(Seq(Row(null))))
      val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -290,8 +280,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (array(array(null)))")
       checkAnswer(spark.table("t"), Row(Seq(Seq(null))))
      val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))"))
-      assert(e.getCause.getMessage.contains(
-        s"input string of length 6 exceeds $typeName type length limitation: 5"))
+      assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
     }
   }
 
@@ -312,11 +301,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       sql("INSERT INTO t VALUES (1234, 1234)")
       checkAnswer(spark.table("t"), Row("1234 ", "1234"))
      val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)"))
-      assert(e1.getCause.getMessage.contains(
-        "input string of length 6 exceeds char type length limitation: 5"))
+      assert(e1.getCause.getMessage.contains("Exceeds char type length limitation: 5"))
      val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)"))
-      assert(e2.getCause.getMessage.contains(
-        "input string of length 6 exceeds varchar type length limitation: 5"))
+      assert(e2.getCause.getMessage.contains("Exceeds varchar type length limitation: 5"))
     }
   }
 
@@ -626,8 +613,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
           sql("SELECT '123456' as col").write.format(format).save(dir.toString)
           sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'")
          val e = intercept[SparkException] { sql("select * from t").collect() }
-          assert(e.getCause.getMessage.contains(
-            s"input string of length 6 exceeds $typ type length limitation: 2"))
+          assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
         }
       }
     }
@@ -654,8 +640,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
           sql(s"CREATE TABLE t (col $typ(2)) using $format")
           sql(s"ALTER TABLE t SET LOCATION '$dir'")
           val e = intercept[SparkException] { spark.table("t").collect() }
-          assert(e.getCause.getMessage.contains(
-            s"input string of length 6 exceeds $typ type length limitation: 2"))
+          assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
