This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a1fc6d57b27d [SPARK-47818][CONNECT] Introduce plan cache in 
SparkConnectPlanner to improve performance of Analyze requests
a1fc6d57b27d is described below

commit a1fc6d57b27d24b832b2f2580e6acd64c4488c62
Author: Xi Lyu <xi....@databricks.com>
AuthorDate: Tue Apr 16 16:27:32 2024 +0900

    [SPARK-47818][CONNECT] Introduce plan cache in SparkConnectPlanner to 
improve performance of Analyze requests
    
    ### What changes were proposed in this pull request?
    
    While building a DataFrame step by step, each transformation produces a new
    DataFrame whose schema is empty and only computed lazily on access. If a
    user's code frequently accesses the schema of these new DataFrames through
    methods such as `df.columns`, it results in a large number of Analyze
    requests to the server. Each request reanalyzes the entire plan, leading to
    poor performance, especially when constructing highly complex plans.
    
    Now, by introducing a plan cache in SparkConnectPlanner, we aim to reduce
    the overhead of repeated analysis during this process: significant
    computation is saved whenever the resolved logical plan of a subtree can be
    served from the cache.
    
    A minimal example of the problem:
    
    ```
    import pyspark.sql.functions as F
    df = spark.range(10)
    for i in range(200):
      if str(i) not in df.columns: # <-- The df.columns call causes a new Analyze request in every iteration
        df = df.withColumn(str(i), F.col("id") + i)
    df.show()
    ```
    
    With this patch, the performance of the above code improved from ~110s to 
~5s.
    
    ### Why are the changes needed?
    
    The performance improvement is substantial in cases like the one above.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a static conf `spark.connect.session.planCache.maxSize` and a dynamic 
conf `spark.connect.session.planCache.enabled` are added.
    
    * `spark.connect.session.planCache.maxSize`: Sets the maximum number of
      cached resolved logical plans per Spark Connect session. Setting it to a
      value less than or equal to zero disables the plan cache.
    * `spark.connect.session.planCache.enabled`: When true, the cache of
      resolved logical plans is enabled if
      `spark.connect.session.planCache.maxSize` is greater than zero. When
      false, the cache is disabled even if the max size is greater than zero.
      The caching is best-effort and not guaranteed.
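    For illustration, a minimal sketch of how these confs could be set,
    assuming an existing Spark Connect session named `spark` and that the
    dynamic conf can be set from the client like any other runtime conf (the
    server launch command shown in the comment is only an example):

    ```
    # Dynamic conf: can be toggled per session at runtime.
    spark.conf.set("spark.connect.session.planCache.enabled", "false")

    # Static conf: must be set when the Spark Connect server is started, e.g.
    #   ./sbin/start-connect-server.sh \
    #     --conf spark.connect.session.planCache.maxSize=16
    ```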
    
    ### How was this patch tested?
    
    Some new tests are added in SparkConnectSessionHolderSuite.scala.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46012 from xi-db/SPARK-47818-plan-cache.
    
    Lead-authored-by: Xi Lyu <xi....@databricks.com>
    Co-authored-by: Xi Lyu <159039256+xi...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../apache/spark/sql/connect/config/Connect.scala  |  18 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 201 ++++++++++++---------
 .../spark/sql/connect/service/SessionHolder.scala  |  79 +++++++-
 .../service/SparkConnectAnalyzeHandler.scala       |  26 +--
 .../service/SparkConnectSessionHolderSuite.scala   | 125 ++++++++++++-
 5 files changed, 345 insertions(+), 104 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 6ba100af1bb9..e94e86587393 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -273,4 +273,22 @@ object Connect {
       .version("4.0.0")
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("2s")
+
+  val CONNECT_SESSION_PLAN_CACHE_SIZE =
+    buildStaticConf("spark.connect.session.planCache.maxSize")
+      .doc("Sets the maximum number of cached resolved logical plans in Spark Connect Session." +
+        " If set to a value less than or equal to zero, the plan cache is disabled.")
+      .version("4.0.0")
+      .intConf
+      .createWithDefault(5)
+
+  val CONNECT_SESSION_PLAN_CACHE_ENABLED =
+    buildConf("spark.connect.session.planCache.enabled")
+      .doc("When true, the cache of resolved logical plans is enabled if" +
+        s" '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is greater than zero." +
+        s" When false, the cache is disabled even if 
'${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" +
+        " greater than zero. The caching is best-effort and not guaranteed.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5e7f3b74c299..d8eb044e4f94 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -115,95 +115,118 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", 
sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
-  // The root of the query plan is a relation and we apply the transformations 
to it.
-  def transformRelation(rel: proto.Relation): LogicalPlan = {
-    val plan = rel.getRelTypeCase match {
-      // DataFrame API
-      case proto.Relation.RelTypeCase.SHOW_STRING => 
transformShowString(rel.getShowString)
-      case proto.Relation.RelTypeCase.HTML_STRING => 
transformHtmlString(rel.getHtmlString)
-      case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
-      case proto.Relation.RelTypeCase.PROJECT => 
transformProject(rel.getProject)
-      case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
-      case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
-      case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
-      case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
-      case proto.Relation.RelTypeCase.JOIN => 
transformJoinOrJoinWith(rel.getJoin)
-      case proto.Relation.RelTypeCase.AS_OF_JOIN => 
transformAsOfJoin(rel.getAsOfJoin)
-      case proto.Relation.RelTypeCase.DEDUPLICATE => 
transformDeduplicate(rel.getDeduplicate)
-      case proto.Relation.RelTypeCase.SET_OP => 
transformSetOperation(rel.getSetOp)
-      case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
-      case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
-      case proto.Relation.RelTypeCase.AGGREGATE => 
transformAggregate(rel.getAggregate)
-      case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
-      case proto.Relation.RelTypeCase.WITH_RELATIONS
-          if isValidSQLWithRefs(rel.getWithRelations) =>
-        transformSqlWithRefs(rel.getWithRelations)
-      case proto.Relation.RelTypeCase.LOCAL_RELATION =>
-        transformLocalRelation(rel.getLocalRelation)
-      case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
-      case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
-      case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
-        transformSubqueryAlias(rel.getSubqueryAlias)
-      case proto.Relation.RelTypeCase.REPARTITION => 
transformRepartition(rel.getRepartition)
-      case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
-      case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
-      case proto.Relation.RelTypeCase.REPLACE => 
transformReplace(rel.getReplace)
-      case proto.Relation.RelTypeCase.SUMMARY => 
transformStatSummary(rel.getSummary)
-      case proto.Relation.RelTypeCase.DESCRIBE => 
transformStatDescribe(rel.getDescribe)
-      case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
-      case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
-      case proto.Relation.RelTypeCase.APPROX_QUANTILE =>
-        transformStatApproxQuantile(rel.getApproxQuantile)
-      case proto.Relation.RelTypeCase.CROSSTAB =>
-        transformStatCrosstab(rel.getCrosstab)
-      case proto.Relation.RelTypeCase.FREQ_ITEMS => 
transformStatFreqItems(rel.getFreqItems)
-      case proto.Relation.RelTypeCase.SAMPLE_BY =>
-        transformStatSampleBy(rel.getSampleBy)
-      case proto.Relation.RelTypeCase.TO_SCHEMA => 
transformToSchema(rel.getToSchema)
-      case proto.Relation.RelTypeCase.TO_DF =>
-        transformToDF(rel.getToDf)
-      case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED =>
-        transformWithColumnsRenamed(rel.getWithColumnsRenamed)
-      case proto.Relation.RelTypeCase.WITH_COLUMNS => 
transformWithColumns(rel.getWithColumns)
-      case proto.Relation.RelTypeCase.WITH_WATERMARK =>
-        transformWithWatermark(rel.getWithWatermark)
-      case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
-        transformCachedLocalRelation(rel.getCachedLocalRelation)
-      case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
-      case proto.Relation.RelTypeCase.UNPIVOT => 
transformUnpivot(rel.getUnpivot)
-      case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
-        transformRepartitionByExpression(rel.getRepartitionByExpression)
-      case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
-        transformMapPartitions(rel.getMapPartitions)
-      case proto.Relation.RelTypeCase.GROUP_MAP =>
-        transformGroupMap(rel.getGroupMap)
-      case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
-        transformCoGroupMap(rel.getCoGroupMap)
-      case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
-        transformApplyInPandasWithState(rel.getApplyInPandasWithState)
-      case 
proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION =>
-        
transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction)
-      case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
-        transformCachedRemoteRelation(rel.getCachedRemoteRelation)
-      case proto.Relation.RelTypeCase.COLLECT_METRICS =>
-        transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
-      case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
-      case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
-        throw new IndexOutOfBoundsException("Expected Relation to be set, but 
is empty.")
-
-      // Catalog API (internal-only)
-      case proto.Relation.RelTypeCase.CATALOG => 
transformCatalog(rel.getCatalog)
-
-      // Handle plugins for Spark Connect Relation types.
-      case proto.Relation.RelTypeCase.EXTENSION =>
-        transformRelationPlugin(rel.getExtension)
-      case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
-    }
-
-    if (rel.hasCommon && rel.getCommon.hasPlanId) {
-      plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
-    }
-    plan
+  /**
+   * The root of the query plan is a relation and we apply the transformations 
to it. The resolved
+   * logical plan will not get cached. If the result needs to be cached, use
+   * `transformRelation(rel, cachePlan = true)` instead.
+   * @param rel
+   *   The relation to transform.
+   * @return
+   *   The resolved logical plan.
+   */
+  def transformRelation(rel: proto.Relation): LogicalPlan =
+    transformRelation(rel, cachePlan = false)
+
+  /**
+   * The root of the query plan is a relation and we apply the transformations 
to it.
+   * @param rel
+   *   The relation to transform.
+   * @param cachePlan
+   *   Set to true for a performance optimization, if the plan is likely to be 
reused, e.g. built
+   *   upon by further dataset transformation. The default is false.
+   * @return
+   *   The resolved logical plan.
+   */
+  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan 
= {
+    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+      val plan = rel.getRelTypeCase match {
+        // DataFrame API
+        case proto.Relation.RelTypeCase.SHOW_STRING => 
transformShowString(rel.getShowString)
+        case proto.Relation.RelTypeCase.HTML_STRING => 
transformHtmlString(rel.getHtmlString)
+        case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
+        case proto.Relation.RelTypeCase.PROJECT => 
transformProject(rel.getProject)
+        case proto.Relation.RelTypeCase.FILTER => 
transformFilter(rel.getFilter)
+        case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
+        case proto.Relation.RelTypeCase.OFFSET => 
transformOffset(rel.getOffset)
+        case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
+        case proto.Relation.RelTypeCase.JOIN => 
transformJoinOrJoinWith(rel.getJoin)
+        case proto.Relation.RelTypeCase.AS_OF_JOIN => 
transformAsOfJoin(rel.getAsOfJoin)
+        case proto.Relation.RelTypeCase.DEDUPLICATE => 
transformDeduplicate(rel.getDeduplicate)
+        case proto.Relation.RelTypeCase.SET_OP => 
transformSetOperation(rel.getSetOp)
+        case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
+        case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
+        case proto.Relation.RelTypeCase.AGGREGATE => 
transformAggregate(rel.getAggregate)
+        case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
+        case proto.Relation.RelTypeCase.WITH_RELATIONS
+            if isValidSQLWithRefs(rel.getWithRelations) =>
+          transformSqlWithRefs(rel.getWithRelations)
+        case proto.Relation.RelTypeCase.LOCAL_RELATION =>
+          transformLocalRelation(rel.getLocalRelation)
+        case proto.Relation.RelTypeCase.SAMPLE => 
transformSample(rel.getSample)
+        case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
+        case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
+          transformSubqueryAlias(rel.getSubqueryAlias)
+        case proto.Relation.RelTypeCase.REPARTITION => 
transformRepartition(rel.getRepartition)
+        case proto.Relation.RelTypeCase.FILL_NA => 
transformNAFill(rel.getFillNa)
+        case proto.Relation.RelTypeCase.DROP_NA => 
transformNADrop(rel.getDropNa)
+        case proto.Relation.RelTypeCase.REPLACE => 
transformReplace(rel.getReplace)
+        case proto.Relation.RelTypeCase.SUMMARY => 
transformStatSummary(rel.getSummary)
+        case proto.Relation.RelTypeCase.DESCRIBE => 
transformStatDescribe(rel.getDescribe)
+        case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
+        case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
+        case proto.Relation.RelTypeCase.APPROX_QUANTILE =>
+          transformStatApproxQuantile(rel.getApproxQuantile)
+        case proto.Relation.RelTypeCase.CROSSTAB =>
+          transformStatCrosstab(rel.getCrosstab)
+        case proto.Relation.RelTypeCase.FREQ_ITEMS => 
transformStatFreqItems(rel.getFreqItems)
+        case proto.Relation.RelTypeCase.SAMPLE_BY =>
+          transformStatSampleBy(rel.getSampleBy)
+        case proto.Relation.RelTypeCase.TO_SCHEMA => 
transformToSchema(rel.getToSchema)
+        case proto.Relation.RelTypeCase.TO_DF =>
+          transformToDF(rel.getToDf)
+        case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED =>
+          transformWithColumnsRenamed(rel.getWithColumnsRenamed)
+        case proto.Relation.RelTypeCase.WITH_COLUMNS => 
transformWithColumns(rel.getWithColumns)
+        case proto.Relation.RelTypeCase.WITH_WATERMARK =>
+          transformWithWatermark(rel.getWithWatermark)
+        case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
+          transformCachedLocalRelation(rel.getCachedLocalRelation)
+        case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
+        case proto.Relation.RelTypeCase.UNPIVOT => 
transformUnpivot(rel.getUnpivot)
+        case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
+          transformRepartitionByExpression(rel.getRepartitionByExpression)
+        case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
+          transformMapPartitions(rel.getMapPartitions)
+        case proto.Relation.RelTypeCase.GROUP_MAP =>
+          transformGroupMap(rel.getGroupMap)
+        case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
+          transformCoGroupMap(rel.getCoGroupMap)
+        case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
+          transformApplyInPandasWithState(rel.getApplyInPandasWithState)
+        case 
proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION =>
+          transformCommonInlineUserDefinedTableFunction(
+            rel.getCommonInlineUserDefinedTableFunction)
+        case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
+          transformCachedRemoteRelation(rel.getCachedRemoteRelation)
+        case proto.Relation.RelTypeCase.COLLECT_METRICS =>
+          transformCollectMetrics(rel.getCollectMetrics, 
rel.getCommon.getPlanId)
+        case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
+        case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
+          throw new IndexOutOfBoundsException("Expected Relation to be set, 
but is empty.")
+
+        // Catalog API (internal-only)
+        case proto.Relation.RelTypeCase.CATALOG => 
transformCatalog(rel.getCatalog)
+
+        // Handle plugins for Spark Connect Relation types.
+        case proto.Relation.RelTypeCase.EXTENSION =>
+          transformRelationPlugin(rel.getExtension)
+        case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
+      }
+      if (rel.hasCommon && rel.getCommon.hasPlanId) {
+        plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
+      }
+      plan
+    }
   }
 
   @DeveloperApi
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 306b89148583..3dad57209982 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -27,14 +27,17 @@ import scala.jdk.CollectionConverters._
 import scala.util.Try
 
 import com.google.common.base.Ticker
-import com.google.common.cache.CacheBuilder
+import com.google.common.cache.{Cache, CacheBuilder}
 
-import org.apache.spark.{SparkException, SparkSQLException}
+import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
+import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, 
ERROR_CACHE_TIMEOUT_SEC}
@@ -50,6 +53,27 @@ case class SessionKey(userId: String, sessionId: String)
 case class SessionHolder(userId: String, sessionId: String, session: 
SparkSession)
     extends Logging {
 
+  // Cache which stores recently resolved logical plans to improve the 
performance of plan analysis.
+  // Only plans that explicitly specify "cachePlan = true" in 
transformRelation will be cached.
+  // Analyzing a large plan may be expensive, and it is not uncommon to build 
the plan step-by-step
+  // with several analyses during the process. This cache aids the recursive analysis process by
+  // memoizing `LogicalPlan`s which may be a sub-tree in a subsequent plan.
+  private lazy val planCache: Option[Cache[proto.Relation, LogicalPlan]] = {
+    if (SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE) <= 0) {
+      logWarning(
+        s"Session plan cache is disabled due to non-positive cache size." +
+          s" Current value of '${Connect.CONNECT_SESSION_PLAN_CACHE_SIZE.key}' 
is" +
+          s" 
${SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE)}.")
+      None
+    } else {
+      Some(
+        CacheBuilder
+          .newBuilder()
+          
.maximumSize(SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE))
+          .build[proto.Relation, LogicalPlan]())
+    }
+  }
+
   // Time when the session was started.
   private val startTimeMs: Long = System.currentTimeMillis()
 
@@ -388,6 +412,57 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
    */
   private[connect] val pythonAccumulator: Option[PythonAccumulator] =
     Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption
+
+  /**
+   * Transform a relation into a logical plan, using the plan cache if enabled. The plan cache is
+   * enabled only if `spark.connect.session.planCache.maxSize` is greater than zero AND
+   * `spark.connect.session.planCache.enabled` is true.
+   * @param rel
+   *   The relation to transform.
+   * @param cachePlan
+   *   Whether to cache the result logical plan.
+   * @param transform
+   *   Function to transform the relation into a logical plan.
+   * @return
+   *   The logical plan.
+   */
+  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+      transform: proto.Relation => LogicalPlan): LogicalPlan = {
+    val planCacheEnabled =
+      
Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, 
true))
+    // We only cache plans that have a plan ID.
+    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+
+    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
+      planCache match {
+        case Some(cache) if planCacheEnabled && hasPlanId =>
+          Option(cache.getIfPresent(rel)) match {
+            case Some(plan) =>
+              logDebug(s"Using cached plan for relation '$rel': $plan")
+              Some(plan)
+            case None => None
+          }
+        case _ => None
+      }
+    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+      planCache match {
+        case Some(cache) if planCacheEnabled && hasPlanId =>
+          cache.put(rel, plan)
+        case _ =>
+      }
+
+    getPlanCache(rel)
+      .getOrElse({
+        val plan = transform(rel)
+        if (cachePlan) {
+          putPlanCache(rel, plan)
+        }
+        plan
+      })
+  }
+
+  // For testing. Expose the plan cache for testing purposes.
+  private[service] def getPlanCache: Option[Cache[proto.Relation, 
LogicalPlan]] = planCache
 }
 
 object SessionHolder {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 3dfd29d6a8c6..6c5d95ac67d3 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -58,10 +58,12 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()
 
+    def transformRelation(rel: proto.Relation) = 
planner.transformRelation(rel, cachePlan = true)
+
     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
         val schema = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getSchema.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getSchema.getPlan.getRoot))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -71,7 +73,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
         val queryExecution = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getExplain.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getExplain.getPlan.getRoot))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case 
proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -94,7 +96,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
         val schema = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getTreeString.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getTreeString.getPlan.getRoot))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -109,7 +111,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
         val isLocal = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getIsLocal.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getIsLocal.getPlan.getRoot))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -119,7 +121,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
         val isStreaming = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getIsStreaming.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getIsStreaming.getPlan.getRoot))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -129,7 +131,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
         val inputFiles = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getInputFiles.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getInputFiles.getPlan.getRoot))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -155,10 +157,10 @@ private[connect] class SparkConnectAnalyzeHandler(
       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
         val target = Dataset.ofRows(
           session,
-          
planner.transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
+          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
         val other = Dataset.ofRows(
           session,
-          
planner.transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
@@ -166,7 +168,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
         val semanticHash = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getSemanticHash.getPlan.getRoot))
+          .ofRows(session, 
transformRelation(request.getSemanticHash.getPlan.getRoot))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
@@ -175,7 +177,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
         val target = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getPersist.getRelation))
+          .ofRows(session, transformRelation(request.getPersist.getRelation))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             
StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -186,7 +188,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
         val target = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getUnpersist.getRelation))
+          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -196,7 +198,7 @@ private[connect] class SparkConnectAnalyzeHandler(
 
       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
         val target = Dataset
-          .ofRows(session, 
planner.transformRelation(request.getGetStorageLevel.getRelation))
+          .ofRows(session, 
transformRelation(request.getGetStorageLevel.getRelation))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index bb51b0a79820..62b4151aad8a 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -23,14 +23,18 @@ import java.nio.file.Files
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.sys.process.Process
+import scala.util.Random
 
 import com.google.common.collect.Lists
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.connect.proto
 import org.apache.spark.sql.IntegratedUDFTestUtils
 import org.apache.spark.sql.connect.common.InvalidPlanInput
-import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, 
StreamingForeachBatchHelper}
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, 
SparkConnectPlanner, StreamingForeachBatchHelper}
 import 
org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.ArrayImplicits._
@@ -289,4 +293,123 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
       spark.streams.listListeners().foreach(spark.streams.removeListener)
     }
   }
+
+  private def buildRelation(query: String) = {
+    proto.Relation
+      .newBuilder()
+      .setSql(
+        proto.SQL
+          .newBuilder()
+          .setQuery(query)
+          .build())
+      
.setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build())
+      .build()
+  }
+
+  private def assertPlanCache(
+      sessionHolder: SessionHolder,
+      optionExpectedCachedRelations: Option[Set[proto.Relation]]) = {
+    optionExpectedCachedRelations match {
+      case Some(expectedCachedRelations) =>
+        val cachedRelations = 
sessionHolder.getPlanCache.get.asMap().keySet().asScala
+        assert(cachedRelations.size == expectedCachedRelations.size)
+        expectedCachedRelations.foreach(relation => 
assert(cachedRelations.contains(relation)))
+      case None => assert(sessionHolder.getPlanCache.isEmpty)
+    }
+  }
+
+  test("Test session plan cache") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      // Set cache size to 2
+      SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 2)
+      val planner = new SparkConnectPlanner(sessionHolder)
+
+      val random1 = buildRelation("select 1")
+      val random2 = buildRelation("select 2")
+      val random3 = buildRelation("select 3")
+      val query1 = proto.Relation.newBuilder
+        .setLimit(
+          proto.Limit.newBuilder
+            .setLimit(10)
+            .setInput(
+              proto.Relation
+                .newBuilder()
+                
.setRange(proto.Range.newBuilder().setStart(0).setStep(1).setEnd(20))
+                .build()))
+        
.setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build())
+        .build()
+      val query2 = proto.Relation.newBuilder
+        .setLimit(proto.Limit.newBuilder.setLimit(5).setInput(query1))
+        
.setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build())
+        .build()
+
+      // If cachePlan is false, the cache is still empty.
+      planner.transformRelation(random1, cachePlan = false)
+      assertPlanCache(sessionHolder, Some(Set()))
+
+      // Put a random entry in cache.
+      planner.transformRelation(random1, cachePlan = true)
+      assertPlanCache(sessionHolder, Some(Set(random1)))
+
+      // Put another random entry in cache.
+      planner.transformRelation(random2, cachePlan = true)
+      assertPlanCache(sessionHolder, Some(Set(random1, random2)))
+
+      // Analyze query1. We only cache the root relation, and the random1 is 
evicted.
+      planner.transformRelation(query1, cachePlan = true)
+      assertPlanCache(sessionHolder, Some(Set(random2, query1)))
+
+      // Put another random entry in cache.
+      planner.transformRelation(random3, cachePlan = true)
+      assertPlanCache(sessionHolder, Some(Set(query1, random3)))
+
+      // Analyze query2. As query1 is accessed during the process, it should 
be in the cache.
+      planner.transformRelation(query2, cachePlan = true)
+      assertPlanCache(sessionHolder, Some(Set(query1, query2)))
+    } finally {
+      // Set back to default value.
+      SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5)
+    }
+  }
+
+  test("Test session plan cache - cache size zero or negative") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      // Set cache size to -1
+      SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, -1)
+      val planner = new SparkConnectPlanner(sessionHolder)
+
+      val query = buildRelation("select 1")
+
+      // If cachePlan is false, the cache is still None.
+      planner.transformRelation(query, cachePlan = false)
+      assertPlanCache(sessionHolder, None)
+
+      // Even if we specify "cachePlan = true", the cache is still None.
+      planner.transformRelation(query, cachePlan = true)
+      assertPlanCache(sessionHolder, None)
+    } finally {
+      // Set back to default value.
+      SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5)
+    }
+  }
+
+  test("Test session plan cache - disabled") {
+    val sessionHolder = SessionHolder.forTesting(spark)
+    // Disable plan cache of the session
+    sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, 
false)
+    val planner = new SparkConnectPlanner(sessionHolder)
+
+    val query = buildRelation("select 1")
+
+    // If cachePlan is false, the cache is still empty.
+    // Although the cache is created as cache size is greater than zero, it 
won't be used.
+    planner.transformRelation(query, cachePlan = false)
+    assertPlanCache(sessionHolder, Some(Set()))
+
+    // Even if we specify "cachePlan = true", the cache is still empty.
+    planner.transformRelation(query, cachePlan = true)
+    assertPlanCache(sessionHolder, Some(Set()))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

