This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d1375fe0e90 [SPARK-47404][SQL] Add configurable size limits for ANTLR 
DFA cache
0d1375fe0e90 is described below

commit 0d1375fe0e90433e98e8034ce37454c24b3f5e4f
Author: Tynan Sigg <[email protected]>
AuthorDate: Tue Jul 15 21:21:56 2025 +0800

    [SPARK-47404][SQL] Add configurable size limits for ANTLR DFA cache
    
    ### What changes were proposed in this pull request?
    
    Add hooks that release the ANTLR DFA cache after parsing SQL once the cache
    exceeds a configurable size limit.
    
    ### Why are the changes needed?
    
    ANTLR builds a DFA cache while parsing to speed up parsing of similar 
future inputs. However, this cache is never cleared and can only grow. 
Extremely large SQL inputs can lead to very large DFA caches (>20GiB in one 
extreme case I've seen).
    
    Spark’s ANTLR SQL parser is derived from the Presto ANTLR SQL Parser (see 
[SPARK-13713](https://issues.apache.org/jira/browse/SPARK-13713) and 
https://github.com/apache/spark/pull/11557), and Presto has added hooks to be 
able to clear this DFA cache (https://github.com/trinodb/trino/pull/3186). I 
think Spark should have similar hooks.
    
    ### Does this PR introduce _any_ user-facing change?
    
    New internal `SQLConf`s to control the behavior:
    - `spark.sql.parser.manageParserCaches`: turns the feature (and its logging) on
      or off
    - `spark.sql.parser.parserDfaCacheFlushThreshold`: sets a limit on the number of
      states in the DFA cache
    - `spark.sql.parser.parserDfaCacheFlushRatio`: sets a limit on the (estimated) DFA
      cache size, as a percentage of the driver JVM's maximum heap size
    
    ### How was this patch tested?
    
    New unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51069 from trsigg/antlr-cache.
    
    Authored-by: Tynan Sigg <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../scala/org/apache/spark/internal/LogKey.scala   |   3 +
 .../apache/spark/sql/catalyst/parser/parsers.scala | 126 +++++++++++++++++-
 .../org/apache/spark/sql/internal/SqlApiConf.scala |  11 ++
 .../spark/sql/internal/SqlApiConfHelper.scala      |   4 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  63 +++++++++
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  90 ++++++-------
 .../spark/sql/execution/SparkSqlParserSuite.scala  | 142 ++++++++++++++++++++-
 7 files changed, 391 insertions(+), 48 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala 
b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index 877ca7f4a9cb..61913670d4b4 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -71,6 +71,8 @@ private[spark] object LogKeys {
   case object ALIGNED_TO_TIME extends LogKey
   case object ALPHA extends LogKey
   case object ANALYSIS_ERROR extends LogKey
+  // New-state delta and total state count of the SQL parser's managed ANTLR DFA cache.
+  case object ANTLR_DFA_CACHE_DELTA extends LogKey
+  case object ANTLR_DFA_CACHE_SIZE extends LogKey
   case object APP_ATTEMPT_ID extends LogKey
   case object APP_ATTEMPT_SHUFFLE_MERGE_ID extends LogKey
   case object APP_DESC extends LogKey
@@ -209,6 +211,7 @@ private[spark] object LogKeys {
   case object DIFF_DELTA extends LogKey
   case object DIVISIBLE_CLUSTER_INDICES_SIZE extends LogKey
   case object DRIVER_ID extends LogKey
+  // JVM maximum heap size (Runtime.maxMemory) as logged by the SQL parser cache management.
+  case object DRIVER_JVM_MEMORY extends LogKey
   case object DRIVER_MEMORY_SIZE extends LogKey
   case object DRIVER_STATE extends LogKey
   case object DROPPED_PARTITIONS extends LogKey
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 28fccd2092b3..ab27d15d8540 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -16,15 +16,18 @@
  */
 package org.apache.spark.sql.catalyst.parser
 
+import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
+
 import scala.jdk.CollectionConverters._
 
 import org.antlr.v4.runtime._
-import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.atn.{ATN, ParserATNSimulator, 
PredictionContextCache, PredictionMode}
+import org.antlr.v4.runtime.dfa.DFA
 import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
 import org.antlr.v4.runtime.tree.TerminalNodeImpl
 
 import org.apache.spark.{QueryContext, SparkException, SparkThrowable, 
SparkThrowableHelper}
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, 
SQLQueryContext, WithOrigin}
 import org.apache.spark.sql.catalyst.util.SparkParserUtils
@@ -62,6 +65,7 @@ abstract class AbstractParser extends DataTypeParserInterface 
with Logging {
 
     val tokenStream = new CommonTokenStream(lexer)
     val parser = new SqlBaseParser(tokenStream)
+    if (conf.manageParserCaches) AbstractParser.installCaches(parser)
     parser.addParseListener(PostProcessor)
     parser.addParseListener(UnclosedCommentProcessor(command, tokenStream))
     parser.removeErrorListeners()
@@ -102,6 +106,18 @@ abstract class AbstractParser extends 
DataTypeParserInterface with Logging {
           errorClass = e.getCondition,
           messageParameters = e.getMessageParameters.asScala.toMap,
           queryContext = e.getQueryContext)
+    } finally {
+      // Antlr4 uses caches to make parsing faster but its caches are unbounded and never purged,
+      // which can cause OOMs when parsing a huge number of SQL queries. Clearing these caches too
+      // often will slow down parsing and cause performance regressions, but will prevent OOMs
+      // caused by the parser cache. We use a heuristic and clear the cache once the number of
+      // states in the DFA cache reaches `spark.sql.parser.parserDfaCacheFlushThreshold`, or once
+      // the estimated cache size (about `BYTES_PER_DFA_STATE` bytes per state) reaches
+      // `spark.sql.parser.parserDfaCacheFlushRatio` percent of the driver's maximum heap. These
+      // states generally represent the bulk of the memory consumed by the parser.
+      //
+      // A negative value disables the corresponding check, and nothing is ever cleared unless
+      // `spark.sql.parser.manageParserCaches` is enabled.
+      AbstractParser.maybeClearParserCaches(parser, conf)
     }
   }
 
@@ -439,3 +455,109 @@ case class UnclosedCommentProcessor(command: String, 
tokenStream: CommonTokenStr
+// AbstractParser specialization whose AST builder is a DataTypeAstBuilder.
 object DataTypeParser extends AbstractParser {
   override protected def astBuilder: DataTypeAstBuilder = new 
DataTypeAstBuilder
 }
+
+object AbstractParser extends Logging {
+  // Approximation based on experiments. Used to estimate the size of the DFA cache for the
+  // `parserDfaCacheFlushRatio` threshold.
+  final val BYTES_PER_DFA_STATE = 9700
+
+  // Maximum heap size of this JVM (`Runtime.maxMemory`). The `parserDfaCacheFlushRatio`
+  // threshold is expressed as a percentage of this value.
+  private val DRIVER_MEMORY = Runtime.getRuntime.maxMemory()
+
+  /**
+   * One generation of managed ANTLR caches for `atn`: a prediction-context cache plus one DFA
+   * per decision. Parsers are handed the current generation until `clearParserCaches` replaces
+   * it with a fresh one.
+   */
+  private case class AntlrCaches(atn: ATN) {
+    private[parser] val predictionContextCache: PredictionContextCache =
+      new PredictionContextCache
+    private[parser] val decisionToDFACache: Array[DFA] = AntlrCaches.makeDecisionToDFACache(atn)
+
+    /** Points `parser` at these caches by giving it a new `ParserATNSimulator` over them. */
+    def installManagedParserCaches(parser: SqlBaseParser): Unit = {
+      parser.setInterpreter(
+        new ParserATNSimulator(parser, atn, decisionToDFACache, predictionContextCache))
+    }
+  }
+
+  private object AntlrCaches {
+    /** Creates an empty DFA for every decision state of `atn`, indexed by decision number. */
+    private def makeDecisionToDFACache(atn: ATN): Array[DFA] =
+      Array.tabulate(atn.getNumberOfDecisions)(i => new DFA(atn.getDecisionState(i), i))
+  }
+
+  // The generation of caches currently installed into new parsers.
+  private val parserCaches = new AtomicReference[AntlrCaches](AntlrCaches(SqlBaseParser._ATN))
+
+  // DFA state count recorded at the end of the latest managed parse (0 after a flush). Parsing can
+  // happen on many threads at once, so this is an AtomicLong: updates are published safely and
+  // each parse can swap in its count and read the previous one in a single step.
+  private val numDFACacheStates = new AtomicLong(0L)
+
+  /** DFA state count recorded after the latest managed parse (0 right after a flush). */
+  def getDFACacheNumStates: Long = numDFACacheStates.get()
+
+  /**
+   * Returns the number of DFA states in the current managed DFA cache.
+   *
+   * DFA states empirically consume about `BYTES_PER_DFA_STATE` bytes of memory each. Sizes are
+   * summed as Longs, and the per-decision state maps are read without locking, so the result may
+   * be slightly stale while other threads are parsing.
+   */
+  private def computeDFACacheNumStates: Long =
+    parserCaches.get().decisionToDFACache.iterator.map(_.states.size.toLong).sum
+
+  /**
+   * Install the managed parser caches into the given parser. Configuring the parser to use the
+   * managed `AntlrCaches` enables us to manage the size of the cache and clear it when required
+   * as the parser caches are unbounded by default.
+   *
+   * This method should be called before parsing any input.
+   */
+  private[parser] def installCaches(parser: SqlBaseParser): Unit = {
+    parserCaches.get().installManagedParserCaches(parser)
+  }
+
+  /**
+   * Drop the existing parser caches and create a new one.
+   *
+   * ANTLR retains caches in its parser that are never released. This speeds up parsing of future
+   * input, but it can consume a lot of memory depending on the input seen so far.
+   *
+   * This method provides a mechanism to free the retained caches, which can be useful after
+   * parsing very large SQL inputs, especially if those large inputs are unlikely to be similar to
+   * future inputs seen by the driver. The recorded state count is reset to 0, and `parser` is
+   * re-pointed at the fresh caches.
+   */
+  private[parser] def clearParserCaches(parser: SqlBaseParser): Unit = {
+    parserCaches.set(AntlrCaches(SqlBaseParser._ATN))
+    logInfo(log"ANTLR parser caches cleared")
+    numDFACacheStates.set(0L)
+    installCaches(parser)
+  }
+
+  /**
+   * Check cache size and config values to determine if we should clear the parser caches. Also
+   * logs the current cache size and the delta since the last check. This method should be called
+   * after parsing each input.
+   *
+   * Does nothing unless `conf.manageParserCaches` is set. The caches are cleared when either
+   * limit is reached:
+   *  - `parserDfaCacheFlushThreshold`: the number of DFA states (negative disables the check);
+   *  - `parserDfaCacheFlushRatio`: the estimated cache size in bytes, as a percentage of the
+   *    maximum JVM heap (negative disables the check).
+   */
+  private[parser] def maybeClearParserCaches(parser: SqlBaseParser, conf: SqlApiConf): Unit = {
+    if (conf.manageParserCaches) {
+      val numStates = computeDFACacheNumStates
+      // Record the new count and fetch the previous one atomically so the delta is consistent.
+      val numStatesDelta = numStates - numDFACacheStates.getAndSet(numStates)
+      logInfo(
+        log"EXPERIMENTAL: Query cached " +
+          log"${MDC(LogKeys.ANTLR_DFA_CACHE_DELTA, numStatesDelta)} " +
+          log"DFA states in the parser. Total cached DFA states: " +
+          log"${MDC(LogKeys.ANTLR_DFA_CACHE_SIZE, numStates)}. " +
+          log"Driver memory: ${MDC(LogKeys.DRIVER_JVM_MEMORY, DRIVER_MEMORY)}.")
+
+      // Read each config once; a negative value disables the corresponding check.
+      val stateThreshold = conf.parserDfaCacheFlushThreshold
+      val staticThresholdExceeded = stateThreshold >= 0 && numStates >= stateThreshold
+
+      val estCacheBytes: Long = numStates * BYTES_PER_DFA_STATE
+      if (estCacheBytes < 0) {
+        logWarning(log"Estimated cache size is negative, likely due to an integer overflow.")
+      }
+      val flushRatio = conf.parserDfaCacheFlushRatio
+      val dynamicThresholdExceeded =
+        flushRatio >= 0 && estCacheBytes >= flushRatio * DRIVER_MEMORY / 100
+
+      if (staticThresholdExceeded || dynamicThresholdExceeded) {
+        clearParserCaches(parser)
+      }
+    }
+  }
+}
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index 76449f1704d2..3ab9b312feea 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -47,6 +47,9 @@ private[sql] trait SqlApiConf {
   def stackTracesInDataFrameContext: Int
   def dataFrameQueryContextEnabled: Boolean
   def legacyAllowUntypedScalaUDFs: Boolean
+  // Switch and flush limits for AbstractParser's managed ANTLR DFA cache (documented in SQLConf).
+  def manageParserCaches: Boolean
+  def parserDfaCacheFlushThreshold: Int
+  def parserDfaCacheFlushRatio: Double
 }
 
 private[sql] object SqlApiConf {
@@ -60,6 +63,11 @@ private[sql] object SqlApiConf {
   val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = {
     SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
   }
+  // Keys of the parser DFA cache configs, re-exported from SqlApiConfHelper.
+  val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
+    SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY
+  val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String =
+    SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_RATIO_KEY
+  val MANAGE_PARSER_CACHES_KEY: String = 
SqlApiConfHelper.MANAGE_PARSER_CACHES_KEY
 
   def get: SqlApiConf = SqlApiConfHelper.getConfGetter.get()()
 
@@ -88,4 +96,7 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
   override def stackTracesInDataFrameContext: Int = 1
   override def dataFrameQueryContextEnabled: Boolean = true
   override def legacyAllowUntypedScalaUDFs: Boolean = false
+  // Managed parser caches are off by default; -1 disables both flush limits.
+  override def manageParserCaches: Boolean = false
+  override def parserDfaCacheFlushThreshold: Int = -1
+  override def parserDfaCacheFlushRatio: Double = -1.0
 }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index dace1dbaecfa..727620bd5bd0 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -33,6 +33,10 @@ private[sql] object SqlApiConfHelper {
   val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
   val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = 
"spark.sql.session.localRelationCacheThreshold"
   val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = 
"spark.sql.execution.arrow.useLargeVarTypes"
+  // Config keys controlling the ANTLR parser's managed DFA cache (entries defined in SQLConf).
+  val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
+    "spark.sql.parser.parserDfaCacheFlushThreshold"
+  val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String = 
"spark.sql.parser.parserDfaCacheFlushRatio"
+  val MANAGE_PARSER_CACHES_KEY: String = "spark.sql.parser.manageParserCaches"
 
   val confGetter: AtomicReference[() => SqlApiConf] = {
     new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4d2982e91f76..69a90f87cb0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1197,6 +1197,63 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val PARSER_DFA_CACHE_FLUSH_RATIO =
+    buildConf("spark.sql.parser.parserDfaCacheFlushRatio")
+      .internal()
+      .doc(
+        """Like `spark.sql.parser.parserDfaCacheFlushThreshold`, but uses a threshold that is a
+          |linear function of the maximum heap size of the driver JVM. Represents the percentage
+          |of that memory the DFA cache can consume before it is flushed.
+          |
+          |Estimates the memory used by the DFA cache, assuming each state consumes
+          |`AbstractParser.BYTES_PER_DFA_STATE` bytes. If this estimate reaches the product of the
+          |driver memory with the config value (interpreted as a percentage), the cache is flushed.
+          |
+          |Active values should be in the range 0-100, and a negative value disables the feature.
+          |If both this config and `spark.sql.parser.parserDfaCacheFlushThreshold` are set, the
+          |cache is flushed if either condition is met.
+          |Requires `spark.sql.parser.manageParserCaches` to be true to take effect.
+          |""".stripMargin)
+      .version("4.1.0")
+      .doubleConf
+      // Negative values are accepted (they disable the check); only the upper bound is enforced.
+      .checkValue(_ <= 100.0, "The ratio must be at most 100%")
+      .createWithDefault(-1.0)
+
+  val PARSER_DFA_CACHE_FLUSH_THRESHOLD =
+    buildConf("spark.sql.parser.parserDfaCacheFlushThreshold")
+      .internal()
+      .doc(
+        """When non-negative, release ANTLR caches after parsing a SQL query once the number of
+          |states in the DFA cache reaches the value of the config (so 0 flushes after every
+          |query). DFA states empirically consume about `AbstractParser.BYTES_PER_DFA_STATE`
+          |bytes of memory each.
+          |
+          |ANTLR parsers retain a DFA cache designed to speed up parsing future input. However,
+          |there is no limit to how large this cache can become. Parsing large SQL statements can
+          |lead to an accumulation of objects in the cache that are unlikely to be reused, causing
+          |high GC overhead and eventually OOMs.
+          |
+          |If this config is set to a negative value, it is ignored.
+          |If both this config and `spark.sql.parser.parserDfaCacheFlushRatio` are set, the
+          |cache is flushed if either condition is met.
+          |Requires `spark.sql.parser.manageParserCaches` to be true to take effect.
+          |
+          |Can significantly slow down parsing in exchange for better memory stability.
+          |""".stripMargin)
+      .version("4.1.0")
+      .intConf
+      .createWithDefault(-1)
+
+  val MANAGE_PARSER_CACHES =
+    buildConf("spark.sql.parser.manageParserCaches")
+      .internal()
+      .doc(
+        """When true, we install our own ANTLR caches to manage memory usage. When false, we use
+          |the default ANTLR caches. Required for `spark.sql.parser.parserDfaCacheFlushThreshold`
+          |and `spark.sql.parser.parserDfaCacheFlushRatio` to take effect.""".stripMargin)
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val FILE_COMPRESSION_FACTOR = 
buildConf("spark.sql.sources.fileCompressionFactor")
     .internal()
     .doc("When estimating the output data size of a table scan, multiply the 
file size with this " +
@@ -6498,6 +6555,12 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
+  // Master switch and flush limits for the SQL parser's managed ANTLR DFA cache.
+  def manageParserCaches: Boolean = getConf(MANAGE_PARSER_CACHES)
+
+  def parserDfaCacheFlushThreshold: Int =
+    getConf(PARSER_DFA_CACHE_FLUSH_THRESHOLD)
+
+  def parserDfaCacheFlushRatio: Double =
+    getConf(PARSER_DFA_CACHE_FLUSH_RATIO)
+
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
 
   def stringRedactionPattern: Option[Regex] = 
getConf(SQL_STRING_REDACTION_PATTERN)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 23aab31c103b..0c8d2bae418a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1034,51 +1034,53 @@ class AnalysisSuite extends AnalysisTest with Matchers {
 
   test("SPARK-30886 Deprecate two-parameter TRIM/LTRIM/RTRIM") {
     Seq("trim", "ltrim", "rtrim").foreach { f =>
-      val logAppender = new LogAppender("deprecated two-parameter 
TRIM/LTRIM/RTRIM functions")
-      def check(count: Int): Unit = {
-        val message = "Two-parameter TRIM/LTRIM/RTRIM function signatures are 
deprecated."
-        assert(logAppender.loggingEvents.size == count)
-        assert(logAppender.loggingEvents.exists(
-          e => e.getLevel == Level.WARN &&
-            e.getMessage.getFormattedMessage.contains(message)))
-      }
+      // Avoid additional logging: with managed parser caches on, every parse logs an extra INFO
+      // line, which (if anything is parsed while the appender is installed) could change the
+      // exact event counts asserted by `check`.
+      withSQLConf(SQLConf.MANAGE_PARSER_CACHES.key -> "false") {
+        val logAppender = new LogAppender("deprecated two-parameter 
TRIM/LTRIM/RTRIM functions")
+        def check(count: Int): Unit = {
+          val message = "Two-parameter TRIM/LTRIM/RTRIM function signatures 
are deprecated."
+          assert(logAppender.loggingEvents.size == count)
+          assert(logAppender.loggingEvents.exists(
+            e => e.getLevel == Level.WARN &&
+              e.getMessage.getFormattedMessage.contains(message)))
+        }
 
-      withLogAppender(logAppender) {
-        val testAnalyzer1 = new Analyzer(
-          new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin))
-
-        val plan1 = testRelation2.select(
-          UnresolvedFunction(f, $"a" :: Nil, isDistinct = false))
-        testAnalyzer1.execute(plan1)
-        // One-parameter is not deprecated.
-        assert(logAppender.loggingEvents.isEmpty)
-
-        val plan2 = testRelation2.select(
-          UnresolvedFunction(f, $"a" :: $"b" :: Nil, isDistinct = false))
-        testAnalyzer1.execute(plan2)
-        // Deprecation warning is printed out once.
-        check(1)
-
-        val plan3 = testRelation2.select(
-          UnresolvedFunction(f, $"b" :: $"a" :: Nil, isDistinct = false))
-        testAnalyzer1.execute(plan3)
-        // There is no change in the log.
-        check(1)
-
-        // New analyzer from new SessionState
-        val testAnalyzer2 = new Analyzer(
-          new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin))
-        val plan4 = testRelation2.select(
-          UnresolvedFunction(f, $"c" :: $"d" :: Nil, isDistinct = false))
-        testAnalyzer2.execute(plan4)
-        // Additional deprecation warning from new analyzer
-        check(2)
-
-        val plan5 = testRelation2.select(
-          UnresolvedFunction(f, $"c" :: $"d" :: Nil, isDistinct = false))
-        testAnalyzer2.execute(plan5)
-        // There is no change in the log.
-        check(2)
+        withLogAppender(logAppender) {
+          val testAnalyzer1 = new Analyzer(
+            new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin))
+
+          val plan1 = testRelation2.select(
+            UnresolvedFunction(f, $"a" :: Nil, isDistinct = false))
+          testAnalyzer1.execute(plan1)
+          // One-parameter is not deprecated.
+          assert(logAppender.loggingEvents.isEmpty)
+
+          val plan2 = testRelation2.select(
+            UnresolvedFunction(f, $"a" :: $"b" :: Nil, isDistinct = false))
+          testAnalyzer1.execute(plan2)
+          // Deprecation warning is printed out once.
+          check(1)
+
+          val plan3 = testRelation2.select(
+            UnresolvedFunction(f, $"b" :: $"a" :: Nil, isDistinct = false))
+          testAnalyzer1.execute(plan3)
+          // There is no change in the log.
+          check(1)
+
+          // New analyzer from new SessionState
+          val testAnalyzer2 = new Analyzer(
+            new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin))
+          val plan4 = testRelation2.select(
+            UnresolvedFunction(f, $"c" :: $"d" :: Nil, isDistinct = false))
+          testAnalyzer2.execute(plan4)
+          // Additional deprecation warning from new analyzer
+          check(2)
+
+          val plan5 = testRelation2.select(
+            UnresolvedFunction(f, $"c" :: $"d" :: Nil, isDistinct = false))
+          testAnalyzer2.execute(plan5)
+          // There is no change in the log.
+          check(2)
+        }
       }
     }
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index fbcc8a582bfb..94e60db67ac7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.SparkThrowable
+import org.apache.spark.{SparkConf, SparkThrowable}
 import org.apache.spark.internal.config.ConfigEntry
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedHaving, 
UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, 
AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, 
UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, 
WindowSpecReference}
-import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.parser.{AbstractParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.connector.catalog.TableCatalog
@@ -44,6 +44,10 @@ import org.apache.spark.util.ArrayImplicits._
 class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession {
   import org.apache.spark.sql.catalyst.dsl.expressions._
 
+  // Every test in this suite runs with the Spark-managed ANTLR caches enabled; the SPARK-47404
+  // flush-threshold tests below depend on it.
+  override protected def sparkConf: SparkConf = {
+    val baseConf = super.sparkConf
+    baseConf.set(SQLConf.MANAGE_PARSER_CACHES.key, "true")
+  }
+
   private lazy val parser = new SparkSqlParser()
 
   private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
@@ -1026,4 +1030,138 @@ class SparkSqlParserSuite extends AnalysisTest with 
SharedSparkSession {
           stop = sql.length - 1))
     }
   }
+
+  /**
+   * Builds a deliberately deep nested CASE WHEN expression over `rand()`. Each level embeds three
+   * copies of the previous level, so the text grows roughly as 3^depth.
+   */
+  private def awfulQuery(depth: Int): String = depth match {
+    case 0 => "rand()"
+    case _ =>
+      val inner = awfulQuery(depth - 1)
+      s"case when $inner > 0.5 then $inner else $inner end"
+  }
+
+  test("SPARK-47404: Managed parsers killswitch works") {
+    // Start from an empty managed cache so this test does not depend on what earlier tests left
+    // behind: a flush threshold of 0 is always reached, so the managed cache is cleared.
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> 0.toString) {
+      parser.parsePlan("select 1")
+    }
+    val initialSize = AbstractParser.getDFACacheNumStates
+    assert(initialSize == 0)
+    val mediumQuery = s"select ${awfulQuery(2)} from range(10)"
+
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> 10000.toString,
+        SQLConf.PARSER_DFA_CACHE_FLUSH_RATIO.key -> 100.toString) {
+      withSQLConf(SQLConf.MANAGE_PARSER_CACHES.key -> false.toString) {
+        parser.parsePlan(mediumQuery)
+      }
+      val disabledSize = AbstractParser.getDFACacheNumStates
+      // There should be no change to the state of the managed caches when not enabled
+      assert(disabledSize == initialSize)
+
+      withSQLConf(SQLConf.MANAGE_PARSER_CACHES.key -> true.toString) {
+        parser.parsePlan(mediumQuery)
+      }
+      val enabledSize = AbstractParser.getDFACacheNumStates
+      // Now the cache should be populated
+      assert(enabledSize > initialSize)
+    }
+  }
+
+  test("SPARK-47404: Always release Antlr cache when cache limit is 0") {
+    val query = "select id from range(10)"
+    // Parses `query` under the given flush threshold and reports the recorded cache size.
+    def parseWithThreshold(threshold: Int): Long = {
+      withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> threshold.toString) {
+        parser.parsePlan(query)
+      }
+      AbstractParser.getDFACacheNumStates
+    }
+
+    // With flushing disabled, the parse leaves states behind in the managed cache.
+    assert(parseWithThreshold(-1) > 0)
+    // A threshold of 0 is always reached, so the cache is flushed right after the parse.
+    assert(parseWithThreshold(0) == 0)
+  }
+
+  test("SPARK-47404: Release ANTLR cache based on threshold") {
+    val smallQuery = "select id from range(10)"
+    val bigQuery = s"select ${awfulQuery(8)} from range(10)"
+
+    // Chose this value based on the observed size of the parser cache being ~27k states after
+    // parsing `bigQuery` on my machine.
+    val threshold = 10000
+
+    // Start from an empty managed cache so the size assertions below do not depend on what
+    // earlier tests left behind (a threshold of 0 flushes the cache after the parse).
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> 0.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    assert(AbstractParser.getDFACacheNumStates == 0)
+
+    // Fill the cache a little
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> threshold.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    val smallQueryCacheSize = AbstractParser.getDFACacheNumStates
+    assert(smallQueryCacheSize > 0)
+    assert(smallQueryCacheSize < threshold)
+
+    // Parse a big query to fill the cache
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> (-1).toString) {
+      parser.parsePlan(bigQuery)
+    }
+    val bigQueryCacheSize = AbstractParser.getDFACacheNumStates
+    assert(bigQueryCacheSize > threshold)
+
+    // Parse a small query to release the cache
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> threshold.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    val clearedCacheSize = AbstractParser.getDFACacheNumStates
+    assert(clearedCacheSize == 0)
+  }
+
+  test("SPARK-47404: Release Antlr cache based on memory ratio") {
+    val smallQuery = "select id from range(10)"
+    val bigQuery = s"select ${awfulQuery(8)} from range(10)"
+
+    // Start from an empty managed cache so the size assertions below do not depend on what
+    // earlier tests left behind (a threshold of 0 flushes the cache after the parse).
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> 0.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    assert(AbstractParser.getDFACacheNumStates == 0)
+
+    val driverMemory = Runtime.getRuntime.maxMemory()
+    // `bigQuery` fills the cache to about 27k states
+    val stateThreshold = 15000
+    // Calculate what ratio will give us this threshold based on driver memory
+    val ratio = stateThreshold * AbstractParser.BYTES_PER_DFA_STATE * 100.0 / driverMemory
+    // The config rejects ratios above 100, which a very small test heap would produce.
+    assume(ratio <= 100.0, s"driver heap of $driverMemory bytes is too small for this test")
+
+    // Fill the cache a little
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_RATIO.key -> ratio.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    val smallQueryCacheSize = AbstractParser.getDFACacheNumStates
+    assert(smallQueryCacheSize > 0)
+    assert(smallQueryCacheSize < stateThreshold)
+
+    // Parse a big query to fill the cache
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_RATIO.key -> 100.toString) {
+      parser.parsePlan(bigQuery)
+    }
+    val bigQueryCacheSize = AbstractParser.getDFACacheNumStates
+    assert(bigQueryCacheSize > smallQueryCacheSize)
+
+    // Parse a small query to release the cache
+    withSQLConf(SQLConf.PARSER_DFA_CACHE_FLUSH_RATIO.key -> ratio.toString) {
+      parser.parsePlan(smallQuery)
+    }
+    val clearedCacheSize = AbstractParser.getDFACacheNumStates
+    assert(clearedCacheSize == 0)
+  }
+
+  Seq(
+    (-1, -1, false),
+    (10000, -1, true),
+    (-1, 1, true),
+    (10000, 1, true)
+  ).foreach { case (threshold, ratio, shouldFlush) =>
+    test(s"SPARK-47404: Antlr cache combined thresholds. States: $threshold, Ratio: $ratio") {
+      // The cache is flushed when either limit is reached, and kept when both are disabled.
+      val bigQuery = s"select ${awfulQuery(8)} from range(10)"
+      val limits = Seq(
+        SQLConf.PARSER_DFA_CACHE_FLUSH_THRESHOLD.key -> threshold.toString,
+        SQLConf.PARSER_DFA_CACHE_FLUSH_RATIO.key -> ratio.toString)
+      withSQLConf(limits: _*) {
+        parser.parsePlan(bigQuery)
+        val cacheSize = AbstractParser.getDFACacheNumStates
+        if (shouldFlush) assert(cacheSize == 0) else assert(cacheSize > 0)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to