This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push: new ea5af851e78 [SPARK-39758][SQL][3.2] Fix NPE from the regexp functions on invalid patterns ea5af851e78 is described below commit ea5af851e78f569365b240e3a81b8108d8b2f650 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Thu Jul 14 17:49:39 2022 +0300 [SPARK-39758][SQL][3.2] Fix NPE from the regexp functions on invalid patterns ### What changes were proposed in this pull request? In the PR, I propose to catch `PatternSyntaxException` while compiling the regexp pattern by the `regexp_extract`, `regexp_extract_all` and `regexp_instr`, and substitute the exception by Spark's exception w/ the error class `INVALID_PARAMETER_VALUE`. In this way, Spark SQL will output the error in the form: ```sql org.apache.spark.SparkRuntimeException The value of parameter(s) 'regexp' in `regexp_instr` is invalid: ') ?' ``` instead of (on Spark 3.3.0): ```java java.lang.NullPointerException: null ``` Also I propose to set `lastRegex` only after the compilation of the regexp pattern completes successfully. This is a backport of https://github.com/apache/spark/pull/37171. ### Why are the changes needed? The changes fix NPE portrayed by the code on Spark 3.3.0: ```sql spark-sql> SELECT regexp_extract('1a 2b 14m', '(?l)'); 22/07/12 19:07:21 ERROR SparkSQLDriver: Failed in [SELECT regexp_extract('1a 2b 14m', '(?l)')] java.lang.NullPointerException: null at org.apache.spark.sql.catalyst.expressions.RegExpExtractBase.getLastMatcher(regexpExpressions.scala:768) ~[spark-catalyst_2.12-3.3.0.jar:3.3.0] ``` This should improve user experience with Spark SQL. ### Does this PR introduce _any_ user-facing change? No. In regular cases, the behavior is the same but users will observe different exceptions (error messages) after the changes. ### How was this patch tested? 
By running new tests: ``` $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z regexp-functions.sql" $ build/sbt "test:testOnly *.RegexpExpressionsSuite" $ build/sbt "sql/test:testOnly org.apache.spark.sql.expressions.ExpressionInfoSuite" ``` Authored-by: Max Gekk <max.gekk@gmail.com> Signed-off-by: Max Gekk <max.gekk@gmail.com> (cherry picked from commit 5b96bd5cf8f44eee7a16cd027d37dec552ed5a6a) Signed-off-by: Max Gekk <max.gekk@gmail.com> Closes #37182 from MaxGekk/pattern-syntax-exception-3.2. Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- core/src/main/resources/error/error-classes.json | 4 ++ .../scala/org/apache/spark/SparkException.scala | 12 +++++ .../catalyst/expressions/regexpExpressions.scala | 62 +++++++++++++--------- .../spark/sql/errors/QueryExecutionErrors.scala | 12 ++++- .../expressions/RegexpExpressionsSuite.scala | 17 +++++- .../sql-tests/inputs/regexp-functions.sql | 2 + .../sql-tests/results/regexp-functions.sql.out | 20 ++++++- 7 files changed, 100 insertions(+), 29 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 9999eb5f6e4..7aa9b26b00d 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -33,6 +33,10 @@ "message" : [ "Field name %s is invalid: %s is not a struct."
], "sqlState" : "42000" }, + "INVALID_PARAMETER_VALUE" : { + "message" : [ "The value of parameter(s) '%s' in `%s` is invalid: '%s'" ], + "sqlState" : "22023" + }, "MISSING_COLUMN" : { "message" : [ "cannot resolve '%s' given input columns: [%s]" ], "sqlState" : "42000" diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 6ba425fe909..f74ecf3e26f 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -79,3 +79,15 @@ private[spark] class SparkArithmeticException(errorClass: String, messageParamet override def getErrorClass: String = errorClass override def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass) } + +private[spark] class SparkRuntimeException( + errorClass: String, + messageParameters: Array[String], + cause: Throwable = null) + extends RuntimeException( + SparkThrowableHelper.getMessage(errorClass, messageParameters), cause) + with SparkThrowable { + + override def getErrorClass: String = errorClass + override def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 543481e9f4d..9b5a228ea5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import java.util.regex.{Matcher, MatchResult, Pattern} +import java.util.regex.{Matcher, MatchResult, Pattern, PatternSyntaxException} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -678,11 +678,42 @@ abstract class 
RegExpExtractBase protected def getLastMatcher(s: Any, p: Any): Matcher = { if (p != lastRegex) { // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) + try { + val r = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(r.toString) + lastRegex = r + } catch { + case e: PatternSyntaxException => + throw QueryExecutionErrors.invalidPatternError(prettyName, e.getPattern) + + } } pattern.matcher(s.toString) } + + protected def initLastMatcherCode( + ctx: CodegenContext, + subject: String, + regexp: String, + matcher: String): String = { + val classNamePattern = classOf[Pattern].getCanonicalName + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + + s""" + |if (!$regexp.equals($termLastRegex)) { + | // regex value changed + | try { + | UTF8String r = $regexp.clone(); + | $termPattern = $classNamePattern.compile(r.toString()); + | $termLastRegex = r; + | } catch (java.util.regex.PatternSyntaxException e) { + | throw QueryExecutionErrors.invalidPatternError("$prettyName", e.getPattern()); + | } + |} + |java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); + |""".stripMargin + } } /** @@ -744,14 +775,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def prettyName: String = "regexp_extract" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val classNamePattern = classOf[Pattern].getCanonicalName val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") - val termPattern = ctx.addMutableState(classNamePattern, "pattern") - val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" } else { @@ -760,13 +786,7 @@ case class 
RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals($termLastRegex)) { - // regex value changed - $termLastRegex = $regexp.clone(); - $termPattern = $classNamePattern.compile($termLastRegex.toString()); - } - java.util.regex.Matcher $matcher = - $termPattern.matcher($subject.toString()); + ${initLastMatcherCode(ctx, subject, regexp, matcher)} if ($matcher.find()) { java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx); @@ -848,16 +868,11 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres override def prettyName: String = "regexp_extract_all" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val classNamePattern = classOf[Pattern].getCanonicalName val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName val arrayClass = classOf[GenericArrayData].getName val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") val matchResults = ctx.freshName("matchResults") - - val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") - val termPattern = ctx.addMutableState(classNamePattern, "pattern") - val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" } else { @@ -865,12 +880,7 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres } nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - | if (!$regexp.equals($termLastRegex)) { - | // regex value changed - | $termLastRegex = $regexp.clone(); - | $termPattern = $classNamePattern.compile($termLastRegex.toString()); - | } - | java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); + | ${initLastMatcherCode(ctx, subject, regexp, matcher)} | java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>(); | while ($matcher.find()) { | 
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3c922dec29d..be31f0baaed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.fs.permission.FsPermission import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException -import org.apache.spark.{Partition, SparkArithmeticException, SparkException, SparkUpgradeException} +import org.apache.spark.{Partition, SparkArithmeticException, SparkException, SparkRuntimeException, SparkUpgradeException} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.launcher.SparkLauncher import org.apache.spark.memory.SparkOutOfMemoryError @@ -1821,4 +1821,14 @@ private[sql] object QueryExecutionErrors { new SparkException(errorClass = "NULL_COMPARISON_RESULT", messageParameters = Array(), cause = null) } + + def invalidPatternError(funcName: String, pattern: String): RuntimeException = { + new SparkRuntimeException( + errorClass = "INVALID_PARAMETER_VALUE", + messageParameters = Array( + "regexp", + funcName, + pattern), + cause = null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2ca9ede7742..a4211384226 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import 
org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext @@ -483,4 +483,19 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { .likeAll("%foo%", Literal.create(null, StringType)), null) } } + + test("SPARK-39758: invalid regexp pattern") { + val s = $"s".string.at(0) + val p = $"p".string.at(1) + val r = $"r".int.at(2) + val prefix = "The value of parameter(s) 'regexp' in" + checkExceptionInExpression[SparkRuntimeException]( + RegExpExtract(s, p, r), + create_row("1a 2b 14m", "(?l)", 0), + s"$prefix `regexp_extract` is invalid: '(?l)'") + checkExceptionInExpression[SparkRuntimeException]( + RegExpExtractAll(s, p, r), + create_row("abc", "] [", 0), + s"$prefix `regexp_extract_all` is invalid: '] ['") + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql index efe5c278730..b11d2c7ce0d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql @@ -14,6 +14,7 @@ SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3); SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', -1); SELECT regexp_extract('1a 2b 14m', '(\\d+)?([a-z]+)', 1); SELECT regexp_extract('a b m', '(\\d+)?([a-z]+)', 1); +SELECT regexp_extract('1a 2b 14m', '(?l)'); -- regexp_extract_all SELECT regexp_extract_all('1a 2b 14m', '\\d+'); @@ -31,6 +32,7 @@ SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?([a-z]+)', 1); SELECT regexp_extract_all('a 2b 14m', '(\\d+)?([a-z]+)', 1); +SELECT regexp_extract_all('abc', col0, 1) FROM VALUES('], [') AS t(col0); -- 
regexp_replace SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something'); diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out index f0a6fa064d0..20d1273f348 100644 --- a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 46 +-- Number of queries: 48 -- !query @@ -128,6 +128,15 @@ struct<regexp_extract(a b m, (\d+)?([a-z]+), 1):string> +-- !query +SELECT regexp_extract('1a 2b 14m', '(?l)') +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +The value of parameter(s) 'regexp' in `regexp_extract` is invalid: '(?l)' + + -- !query SELECT regexp_extract_all('1a 2b 14m', '\\d+') -- !query schema @@ -254,6 +263,15 @@ struct<regexp_extract_all(a 2b 14m, (\d+)?([a-z]+), 1):array<string>> ["","2","14"] +-- !query +SELECT regexp_extract_all('abc', col0, 1) FROM VALUES('], [') AS t(col0) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +The value of parameter(s) 'regexp' in `regexp_extract_all` is invalid: '], [' + + -- !query SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something') -- !query schema --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org