This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new ea5af851e78 [SPARK-39758][SQL][3.2] Fix NPE from the regexp functions on invalid patterns
ea5af851e78 is described below

commit ea5af851e78f569365b240e3a81b8108d8b2f650
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Jul 14 17:49:39 2022 +0300

    [SPARK-39758][SQL][3.2] Fix NPE from the regexp functions on invalid patterns
    
    ### What changes were proposed in this pull request?
    In this PR, I propose to catch the `PatternSyntaxException` thrown while compiling
    the regexp pattern in `regexp_extract`, `regexp_extract_all` and `regexp_instr`,
    and to replace it with a Spark exception carrying the error class
    `INVALID_PARAMETER_VALUE`. With this change, Spark SQL reports the error in the form:
    ```sql
    org.apache.spark.SparkRuntimeException
    The value of parameter(s) 'regexp' in `regexp_instr` is invalid: ') ?'
    ```
    instead of (on Spark 3.3.0):
    ```java
    java.lang.NullPointerException: null
    ```
    Also, I propose to set `lastRegex` only after the regexp pattern has been compiled successfully.
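    
    For illustration only, here is a minimal standalone sketch of that compile-then-assign
    ordering and the error mapping (the object and method names below are hypothetical,
    not the actual Spark internals; the real change is in the diff further down):
    ```scala
    import java.util.regex.{Matcher, Pattern, PatternSyntaxException}

    // Hypothetical cache, for illustration: compile the new pattern first and
    // publish `lastRegex` only after compilation succeeded, so an invalid
    // pattern can never leave the cache pointing at a null `pattern`.
    object RegexCacheSketch {
      private var lastRegex: String = _
      private var pattern: Pattern = _

      def matcherFor(subject: String, regex: String): Matcher = {
        if (regex != lastRegex) {
          try {
            val compiled = Pattern.compile(regex)
            pattern = compiled
            lastRegex = regex
          } catch {
            case e: PatternSyntaxException =>
              // Map the JDK exception to a clearer, parameter-oriented error message.
              throw new RuntimeException(
                s"The value of parameter(s) 'regexp' is invalid: '${e.getPattern}'")
          }
        }
        pattern.matcher(subject)
      }
    }
    ```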
    
    This is a backport of https://github.com/apache/spark/pull/37171.
    
    ### Why are the changes needed?
    The changes fix the NPE demonstrated by the following query on Spark 3.3.0:
    ```sql
    spark-sql> SELECT regexp_extract('1a 2b 14m', '(?l)');
    22/07/12 19:07:21 ERROR SparkSQLDriver: Failed in [SELECT regexp_extract('1a 2b 14m', '(?l)')]
    java.lang.NullPointerException: null
            at org.apache.spark.sql.catalyst.expressions.RegExpExtractBase.getLastMatcher(regexpExpressions.scala:768) ~[spark-catalyst_2.12-3.3.0.jar:3.3.0]
    ```
    This should improve user experience with Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    No. In regular cases, the behavior is the same but users will observe different exceptions (error messages) after the changes.
    
    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z 
regexp-functions.sql"
    $ build/sbt "test:testOnly *.RegexpExpressionsSuite"
    $ build/sbt "sql/test:testOnly 
org.apache.spark.sql.expressions.ExpressionInfoSuite"
    ```
    
    Authored-by: Max Gekk <max.gekk@gmail.com>
    Signed-off-by: Max Gekk <max.gekk@gmail.com>
    (cherry picked from commit 5b96bd5cf8f44eee7a16cd027d37dec552ed5a6a)
    Signed-off-by: Max Gekk <max.gekk@gmail.com>
    
    Closes #37182 from MaxGekk/pattern-syntax-exception-3.2.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |  4 ++
 .../scala/org/apache/spark/SparkException.scala    | 12 +++++
 .../catalyst/expressions/regexpExpressions.scala   | 62 +++++++++++++---------
 .../spark/sql/errors/QueryExecutionErrors.scala    | 12 ++++-
 .../expressions/RegexpExpressionsSuite.scala       | 17 +++++-
 .../sql-tests/inputs/regexp-functions.sql          |  2 +
 .../sql-tests/results/regexp-functions.sql.out     | 20 ++++++-
 7 files changed, 100 insertions(+), 29 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 9999eb5f6e4..7aa9b26b00d 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -33,6 +33,10 @@
     "message" : [ "Field name %s is invalid: %s is not a struct." ],
     "sqlState" : "42000"
   },
+  "INVALID_PARAMETER_VALUE" : {
+    "message" : [ "The value of parameter(s) '%s' in `%s` is invalid: '%s'" ],
+    "sqlState" : "22023"
+  },
   "MISSING_COLUMN" : {
     "message" : [ "cannot resolve '%s' given input columns: [%s]" ],
     "sqlState" : "42000"
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index 6ba425fe909..f74ecf3e26f 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -79,3 +79,15 @@ private[spark] class SparkArithmeticException(errorClass: String, messageParamet
   override def getErrorClass: String = errorClass
   override def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass)
 }
+
+private[spark] class SparkRuntimeException(
+    errorClass: String,
+    messageParameters: Array[String],
+    cause: Throwable = null)
+  extends RuntimeException(
+    SparkThrowableHelper.getMessage(errorClass, messageParameters), cause)
+    with SparkThrowable {
+
+  override def getErrorClass: String = errorClass
+  override def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 543481e9f4d..9b5a228ea5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Locale
-import java.util.regex.{Matcher, MatchResult, Pattern}
+import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -678,11 +678,42 @@ abstract class RegExpExtractBase
   protected def getLastMatcher(s: Any, p: Any): Matcher = {
     if (p != lastRegex) {
       // regex value changed
-      lastRegex = p.asInstanceOf[UTF8String].clone()
-      pattern = Pattern.compile(lastRegex.toString)
+      try {
+        val r = p.asInstanceOf[UTF8String].clone()
+        pattern = Pattern.compile(r.toString)
+        lastRegex = r
+      } catch {
+        case e: PatternSyntaxException =>
+          throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern)
+
+      }
     }
     pattern.matcher(s.toString)
   }
+
+  protected def initLastMatcherCode(
+      ctx: CodegenContext,
+      subject: String,
+      regexp: String,
+      matcher: String): String = {
+    val classNamePattern = classOf[Pattern].getCanonicalName
+    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
+    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
+
+    s"""
+      |if (!$regexp.equals($termLastRegex)) {
+      |  // regex value changed
+      |  try {
+      |    UTF8String r = $regexp.clone();
+      |    $termPattern = $classNamePattern.compile(r.toString());
+      |    $termLastRegex = r;
+      |  } catch (java.util.regex.PatternSyntaxException e) {
+      |    throw QueryExecutionErrors.invalidPatternError("$prettyName", e.getPattern());
+      |  }
+      |}
+      |java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
+      |""".stripMargin
+  }
 }
 
 /**
@@ -744,14 +775,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   override def prettyName: String = "regexp_extract"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val classNamePattern = classOf[Pattern].getCanonicalName
     val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
     val matcher = ctx.freshName("matcher")
     val matchResult = ctx.freshName("matchResult")
-
-    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
-    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
-
     val setEvNotNull = if (nullable) {
       s"${ev.isNull} = false;"
     } else {
@@ -760,13 +786,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
 
     nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
       s"""
-      if (!$regexp.equals($termLastRegex)) {
-        // regex value changed
-        $termLastRegex = $regexp.clone();
-        $termPattern = $classNamePattern.compile($termLastRegex.toString());
-      }
-      java.util.regex.Matcher $matcher =
-        $termPattern.matcher($subject.toString());
+      ${initLastMatcherCode(ctx, subject, regexp, matcher)}
       if ($matcher.find()) {
         java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
        $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
@@ -848,16 +868,11 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
   override def prettyName: String = "regexp_extract_all"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val classNamePattern = classOf[Pattern].getCanonicalName
     val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
     val arrayClass = classOf[GenericArrayData].getName
     val matcher = ctx.freshName("matcher")
     val matchResult = ctx.freshName("matchResult")
     val matchResults = ctx.freshName("matchResults")
-
-    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
-    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
-
     val setEvNotNull = if (nullable) {
       s"${ev.isNull} = false;"
     } else {
@@ -865,12 +880,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
     }
     nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
       s"""
-         | if (!$regexp.equals($termLastRegex)) {
-         |   // regex value changed
-         |   $termLastRegex = $regexp.clone();
-         |   $termPattern = $classNamePattern.compile($termLastRegex.toString());
-         | }
-         | java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
+         | ${initLastMatcherCode(ctx, subject, regexp, matcher)}
          | java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
          | while ($matcher.find()) {
          |   java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 3c922dec29d..be31f0baaed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.fs.permission.FsPermission
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
 
-import org.apache.spark.{Partition, SparkArithmeticException, SparkException, SparkUpgradeException}
+import org.apache.spark.{Partition, SparkArithmeticException, SparkException, SparkRuntimeException, SparkUpgradeException}
 import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.memory.SparkOutOfMemoryError
@@ -1821,4 +1821,14 @@ private[sql] object QueryExecutionErrors {
     new SparkException(errorClass = "NULL_COMPARISON_RESULT",
       messageParameters = Array(), cause = null)
   }
+
+  def invalidPatternError(funcName: String, pattern: String): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_PARAMETER_VALUE",
+      messageParameters = Array(
+        "regexp",
+        funcName,
+        pattern),
+      cause = null)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 2ca9ede7742..a4211384226 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -483,4 +483,19 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         .likeAll("%foo%", Literal.create(null, StringType)), null)
     }
   }
+
+  test("SPARK-39758: invalid regexp pattern") {
+    val s = $"s".string.at(0)
+    val p = $"p".string.at(1)
+    val r = $"r".int.at(2)
+    val prefix = "The value of parameter(s) 'regexp' in"
+    checkExceptionInExpression[SparkRuntimeException](
+      RegExpExtract(s, p, r),
+      create_row("1a 2b 14m", "(?l)", 0),
+      s"$prefix `regexp_extract` is invalid: '(?l)'")
+    checkExceptionInExpression[SparkRuntimeException](
+      RegExpExtractAll(s, p, r),
+      create_row("abc", "] [", 0),
+      s"$prefix `regexp_extract_all` is invalid: '] ['")
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql
index efe5c278730..b11d2c7ce0d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql
@@ -14,6 +14,7 @@ SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3);
 SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', -1);
 SELECT regexp_extract('1a 2b 14m', '(\\d+)?([a-z]+)', 1);
 SELECT regexp_extract('a b m', '(\\d+)?([a-z]+)', 1);
+SELECT regexp_extract('1a 2b 14m', '(?l)');
 
 -- regexp_extract_all
 SELECT regexp_extract_all('1a 2b 14m', '\\d+');
@@ -31,6 +32,7 @@ SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3);
 SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1);
 SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?([a-z]+)', 1);
 SELECT regexp_extract_all('a 2b 14m', '(\\d+)?([a-z]+)', 1);
+SELECT regexp_extract_all('abc', col0, 1) FROM VALUES('], [') AS t(col0);
 
 -- regexp_replace
 SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something');
diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
index f0a6fa064d0..20d1273f348 100644
--- a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 46
+-- Number of queries: 48
 
 
 -- !query
@@ -128,6 +128,15 @@ struct<regexp_extract(a b m, (\d+)?([a-z]+), 1):string>
 
 
 
+-- !query
+SELECT regexp_extract('1a 2b 14m', '(?l)')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+The value of parameter(s) 'regexp' in `regexp_extract` is invalid: '(?l)'
+
+
 -- !query
 SELECT regexp_extract_all('1a 2b 14m', '\\d+')
 -- !query schema
@@ -254,6 +263,15 @@ struct<regexp_extract_all(a 2b 14m, (\d+)?([a-z]+), 1):array<string>>
 ["","2","14"]
 
 
+-- !query
+SELECT regexp_extract_all('abc', col0, 1) FROM VALUES('], [') AS t(col0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+The value of parameter(s) 'regexp' in `regexp_extract_all` is invalid: '], ['
+
+
 -- !query
 SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something')
 -- !query schema

