This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a1fc6d57b27d [SPARK-47818][CONNECT] Introduce plan cache in SparkConnectPlanner to improve performance of Analyze requests a1fc6d57b27d is described below commit a1fc6d57b27d24b832b2f2580e6acd64c4488c62 Author: Xi Lyu <xi....@databricks.com> AuthorDate: Tue Apr 16 16:27:32 2024 +0900 [SPARK-47818][CONNECT] Introduce plan cache in SparkConnectPlanner to improve performance of Analyze requests ### What changes were proposed in this pull request? While a DataFrame is built step by step, each step generates a new DataFrame with an empty schema, which is lazily computed on access. However, if a user's code frequently accesses the schema of these new DataFrames using methods such as `df.columns`, it will result in a large number of Analyze requests to the server. Each time, the entire plan needs to be reanalyzed, leading to poor performance, especially when constructing highly complex plans. Now, by introducing a plan cache in SparkConnectPlanner, we aim to reduce the overhead of repeated analysis during this process. This is achieved by saving significant computation if the resolved logical plan of a subtree can be cached. A minimal example of the problem:
```
import pyspark.sql.functions as F
df = spark.range(10)
for i in range(200):
  if str(i) not in df.columns: # <-- The df.columns call causes a new Analyze request in every iteration
    df = df.withColumn(str(i), F.col("id") + i)
df.show()
```
With this patch, the performance of the above code improved from ~110s to ~5s. ### Why are the changes needed? The performance improvement is huge in the above cases. ### Does this PR introduce _any_ user-facing change? Yes, a static conf `spark.connect.session.planCache.maxSize` and a dynamic conf `spark.connect.session.planCache.enabled` are added. * `spark.connect.session.planCache.maxSize`: Sets the maximum number of cached resolved logical plans in Spark Connect Session. 
If set to a value less than or equal to zero, the plan cache is disabled. * `spark.connect.session.planCache.enabled`: When true, the cache of resolved logical plans is enabled if `spark.connect.session.planCache.maxSize` is greater than zero. When false, the cache is disabled even if `spark.connect.session.planCache.maxSize` is greater than zero. The caching is best-effort and not guaranteed. ### How was this patch tested? Some new tests are added in SparkConnectSessionHolderSuite.scala. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46012 from xi-db/SPARK-47818-plan-cache. Lead-authored-by: Xi Lyu <xi....@databricks.com> Co-authored-by: Xi Lyu <159039256+xi...@users.noreply.github.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../apache/spark/sql/connect/config/Connect.scala | 18 ++ .../sql/connect/planner/SparkConnectPlanner.scala | 201 ++++++++++++--------- .../spark/sql/connect/service/SessionHolder.scala | 79 +++++++- .../service/SparkConnectAnalyzeHandler.scala | 26 +-- .../service/SparkConnectSessionHolderSuite.scala | 125 ++++++++++++- 5 files changed, 345 insertions(+), 104 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 6ba100af1bb9..e94e86587393 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -273,4 +273,22 @@ object Connect { .version("4.0.0") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("2s") + + val CONNECT_SESSION_PLAN_CACHE_SIZE = + buildStaticConf("spark.connect.session.planCache.maxSize") + .doc("Sets the maximum number of cached resolved logical plans in Spark Connect Session." 
+ + " If set to a value less or equal than zero will disable the plan cache.") + .version("4.0.0") + .intConf + .createWithDefault(5) + + val CONNECT_SESSION_PLAN_CACHE_ENABLED = + buildConf("spark.connect.session.planCache.enabled") + .doc("When true, the cache of resolved logical plans is enabled if" + + s" '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is greater than zero." + + s" When false, the cache is disabled even if '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" + + " greater than zero. The caching is best-effort and not guaranteed.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5e7f3b74c299..d8eb044e4f94 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -115,95 +115,118 @@ class SparkConnectPlanner( private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) - // The root of the query plan is a relation and we apply the transformations to it. 
- def transformRelation(rel: proto.Relation): LogicalPlan = { - val plan = rel.getRelTypeCase match { - // DataFrame API - case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString) - case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString) - case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead) - case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject) - case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter) - case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit) - case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset) - case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail) - case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin) - case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin) - case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate) - case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp) - case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) - case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop) - case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) - case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.WITH_RELATIONS - if isValidSQLWithRefs(rel.getWithRelations) => - transformSqlWithRefs(rel.getWithRelations) - case proto.Relation.RelTypeCase.LOCAL_RELATION => - transformLocalRelation(rel.getLocalRelation) - case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) - case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange) - case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => - transformSubqueryAlias(rel.getSubqueryAlias) - case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) - case proto.Relation.RelTypeCase.FILL_NA => 
transformNAFill(rel.getFillNa) - case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa) - case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace) - case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary) - case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe) - case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov) - case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr) - case proto.Relation.RelTypeCase.APPROX_QUANTILE => - transformStatApproxQuantile(rel.getApproxQuantile) - case proto.Relation.RelTypeCase.CROSSTAB => - transformStatCrosstab(rel.getCrosstab) - case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems) - case proto.Relation.RelTypeCase.SAMPLE_BY => - transformStatSampleBy(rel.getSampleBy) - case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema) - case proto.Relation.RelTypeCase.TO_DF => - transformToDF(rel.getToDf) - case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED => - transformWithColumnsRenamed(rel.getWithColumnsRenamed) - case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) - case proto.Relation.RelTypeCase.WITH_WATERMARK => - transformWithWatermark(rel.getWithWatermark) - case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION => - transformCachedLocalRelation(rel.getCachedLocalRelation) - case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) - case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) - case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => - transformRepartitionByExpression(rel.getRepartitionByExpression) - case proto.Relation.RelTypeCase.MAP_PARTITIONS => - transformMapPartitions(rel.getMapPartitions) - case proto.Relation.RelTypeCase.GROUP_MAP => - transformGroupMap(rel.getGroupMap) - case proto.Relation.RelTypeCase.CO_GROUP_MAP => - transformCoGroupMap(rel.getCoGroupMap) - case 
proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE => - transformApplyInPandasWithState(rel.getApplyInPandasWithState) - case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION => - transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction) - case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => - transformCachedRemoteRelation(rel.getCachedRemoteRelation) - case proto.Relation.RelTypeCase.COLLECT_METRICS => - transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) - case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) - case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => - throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") - - // Catalog API (internal-only) - case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog) - - // Handle plugins for Spark Connect Relation types. - case proto.Relation.RelTypeCase.EXTENSION => - transformRelationPlugin(rel.getExtension) - case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") - } - - if (rel.hasCommon && rel.getCommon.hasPlanId) { - plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId) - } - plan + /** + * The root of the query plan is a relation and we apply the transformations to it. The resolved + * logical plan will not get cached. If the result needs to be cached, use + * `transformRelation(rel, cachePlan = true)` instead. + * @param rel + * The relation to transform. + * @return + * The resolved logical plan. + */ + def transformRelation(rel: proto.Relation): LogicalPlan = + transformRelation(rel, cachePlan = false) + + /** + * The root of the query plan is a relation and we apply the transformations to it. + * @param rel + * The relation to transform. + * @param cachePlan + * Set to true for a performance optimization, if the plan is likely to be reused, e.g. built + * upon by further dataset transformation. The default is false. 
+ * @return + * The resolved logical plan. + */ + def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = { + sessionHolder.usePlanCache(rel, cachePlan) { rel => + val plan = rel.getRelTypeCase match { + // DataFrame API + case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString) + case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString) + case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead) + case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject) + case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter) + case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit) + case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset) + case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail) + case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin) + case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin) + case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate) + case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp) + case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) + case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop) + case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) + case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) + case proto.Relation.RelTypeCase.WITH_RELATIONS + if isValidSQLWithRefs(rel.getWithRelations) => + transformSqlWithRefs(rel.getWithRelations) + case proto.Relation.RelTypeCase.LOCAL_RELATION => + transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) + case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange) + case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => + transformSubqueryAlias(rel.getSubqueryAlias) + case 
proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) + case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa) + case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa) + case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace) + case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary) + case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe) + case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov) + case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr) + case proto.Relation.RelTypeCase.APPROX_QUANTILE => + transformStatApproxQuantile(rel.getApproxQuantile) + case proto.Relation.RelTypeCase.CROSSTAB => + transformStatCrosstab(rel.getCrosstab) + case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems) + case proto.Relation.RelTypeCase.SAMPLE_BY => + transformStatSampleBy(rel.getSampleBy) + case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema) + case proto.Relation.RelTypeCase.TO_DF => + transformToDF(rel.getToDf) + case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED => + transformWithColumnsRenamed(rel.getWithColumnsRenamed) + case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) + case proto.Relation.RelTypeCase.WITH_WATERMARK => + transformWithWatermark(rel.getWithWatermark) + case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION => + transformCachedLocalRelation(rel.getCachedLocalRelation) + case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) + case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) + case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => + transformRepartitionByExpression(rel.getRepartitionByExpression) + case proto.Relation.RelTypeCase.MAP_PARTITIONS => + transformMapPartitions(rel.getMapPartitions) + case proto.Relation.RelTypeCase.GROUP_MAP => + 
transformGroupMap(rel.getGroupMap) + case proto.Relation.RelTypeCase.CO_GROUP_MAP => + transformCoGroupMap(rel.getCoGroupMap) + case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE => + transformApplyInPandasWithState(rel.getApplyInPandasWithState) + case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION => + transformCommonInlineUserDefinedTableFunction( + rel.getCommonInlineUserDefinedTableFunction) + case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => + transformCachedRemoteRelation(rel.getCachedRemoteRelation) + case proto.Relation.RelTypeCase.COLLECT_METRICS => + transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) + case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) + case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => + throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") + + // Catalog API (internal-only) + case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog) + + // Handle plugins for Spark Connect Relation types. 
+ case proto.Relation.RelTypeCase.EXTENSION => + transformRelationPlugin(rel.getExtension) + case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") + } + if (rel.hasCommon && rel.getCommon.hasPlanId) { + plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId) + } + plan + } } @DeveloperApi diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 306b89148583..3dad57209982 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,14 +27,17 @@ import scala.jdk.CollectionConverters._ import scala.util.Try import com.google.common.base.Ticker -import com.google.common.cache.CacheBuilder +import com.google.common.cache.{Cache, CacheBuilder} -import org.apache.spark.{SparkException, SparkSQLException} +import org.apache.spark.{SparkEnv, SparkException, SparkSQLException} import org.apache.spark.api.python.PythonFunction.PythonAccumulator +import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} @@ -50,6 +53,27 @@ case class SessionKey(userId: String, sessionId: String) case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { + // Cache which stores recently resolved logical 
plans to improve the performance of plan analysis. + // Only plans that explicitly specify "cachePlan = true" in transformRelation will be cached. + // Analyzing a large plan may be expensive, and it is not uncommon to build the plan step-by-step + // with several analysis during the process. This cache aids the recursive analysis process by + // memorizing `LogicalPlan`s which may be a sub-tree in a subsequent plan. + private lazy val planCache: Option[Cache[proto.Relation, LogicalPlan]] = { + if (SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE) <= 0) { + logWarning( + s"Session plan cache is disabled due to non-positive cache size." + + s" Current value of '${Connect.CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" + + s" ${SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE)}.") + None + } else { + Some( + CacheBuilder + .newBuilder() + .maximumSize(SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE)) + .build[proto.Relation, LogicalPlan]()) + } + } + // Time when the session was started. private val startTimeMs: Long = System.currentTimeMillis() @@ -388,6 +412,57 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[connect] val pythonAccumulator: Option[PythonAccumulator] = Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption + + /** + * Transform a relation into a logical plan, using the plan cache if enabled. The plan cache is + * enable only if `spark.connect.session.planCache.maxSize` is greater than zero AND + * `spark.connect.session.planCache.enabled` is true. + * @param rel + * The relation to transform. + * @param cachePlan + * Whether to cache the result logical plan. + * @param transform + * Function to transform the relation into a logical plan. + * @return + * The logical plan. 
+ */ + private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)( + transform: proto.Relation => LogicalPlan): LogicalPlan = { + val planCacheEnabled = + Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true)) + // We only cache plans that have a plan ID. + val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId + + def getPlanCache(rel: proto.Relation): Option[LogicalPlan] = + planCache match { + case Some(cache) if planCacheEnabled && hasPlanId => + Option(cache.getIfPresent(rel)) match { + case Some(plan) => + logDebug(s"Using cached plan for relation '$rel': $plan") + Some(plan) + case None => None + } + case _ => None + } + def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit = + planCache match { + case Some(cache) if planCacheEnabled && hasPlanId => + cache.put(rel, plan) + case _ => + } + + getPlanCache(rel) + .getOrElse({ + val plan = transform(rel) + if (cachePlan) { + putPlanCache(rel, plan) + } + plan + }) + } + + // For testing. Expose the plan cache for testing purposes. 
+ private[service] def getPlanCache: Option[Cache[proto.Relation, LogicalPlan]] = planCache } object SessionHolder { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 3dfd29d6a8c6..6c5d95ac67d3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -58,10 +58,12 @@ private[connect] class SparkConnectAnalyzeHandler( val session = sessionHolder.session val builder = proto.AnalyzePlanResponse.newBuilder() + def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true) + request.getAnalyzeCase match { case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => val schema = Dataset - .ofRows(session, planner.transformRelation(request.getSchema.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot)) .schema builder.setSchema( proto.AnalyzePlanResponse.Schema @@ -71,7 +73,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => val queryExecution = Dataset - .ofRows(session, planner.transformRelation(request.getExplain.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot)) .queryExecution val explainString = request.getExplain.getExplainMode match { case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE => @@ -94,7 +96,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => val schema = Dataset - .ofRows(session, planner.transformRelation(request.getTreeString.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot)) .schema val treeString = if 
(request.getTreeString.hasLevel) { schema.treeString(request.getTreeString.getLevel) @@ -109,7 +111,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => val isLocal = Dataset - .ofRows(session, planner.transformRelation(request.getIsLocal.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot)) .isLocal builder.setIsLocal( proto.AnalyzePlanResponse.IsLocal @@ -119,7 +121,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => val isStreaming = Dataset - .ofRows(session, planner.transformRelation(request.getIsStreaming.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot)) .isStreaming builder.setIsStreaming( proto.AnalyzePlanResponse.IsStreaming @@ -129,7 +131,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => val inputFiles = Dataset - .ofRows(session, planner.transformRelation(request.getInputFiles.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot)) .inputFiles builder.setInputFiles( proto.AnalyzePlanResponse.InputFiles @@ -155,10 +157,10 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS => val target = Dataset.ofRows( session, - planner.transformRelation(request.getSameSemantics.getTargetPlan.getRoot)) + transformRelation(request.getSameSemantics.getTargetPlan.getRoot)) val other = Dataset.ofRows( session, - planner.transformRelation(request.getSameSemantics.getOtherPlan.getRoot)) + transformRelation(request.getSameSemantics.getOtherPlan.getRoot)) builder.setSameSemantics( proto.AnalyzePlanResponse.SameSemantics .newBuilder() @@ -166,7 +168,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH => val semanticHash = Dataset - .ofRows(session, 
planner.transformRelation(request.getSemanticHash.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot)) .semanticHash() builder.setSemanticHash( proto.AnalyzePlanResponse.SemanticHash @@ -175,7 +177,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST => val target = Dataset - .ofRows(session, planner.transformRelation(request.getPersist.getRelation)) + .ofRows(session, transformRelation(request.getPersist.getRelation)) if (request.getPersist.hasStorageLevel) { target.persist( StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel)) @@ -186,7 +188,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST => val target = Dataset - .ofRows(session, planner.transformRelation(request.getUnpersist.getRelation)) + .ofRows(session, transformRelation(request.getUnpersist.getRelation)) if (request.getUnpersist.hasBlocking) { target.unpersist(request.getUnpersist.getBlocking) } else { @@ -196,7 +198,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL => val target = Dataset - .ofRows(session, planner.transformRelation(request.getGetStorageLevel.getRelation)) + .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation)) val storageLevel = target.storageLevel builder.setGetStorageLevel( proto.AnalyzePlanResponse.GetStorageLevel diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index bb51b0a79820..62b4151aad8a 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -23,14 +23,18 @@ import java.nio.file.Files import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.sys.process.Process +import scala.util.Random import com.google.common.collect.Lists import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkEnv import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.connect.proto import org.apache.spark.sql.IntegratedUDFTestUtils import org.apache.spark.sql.connect.common.InvalidPlanInput -import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, StreamingForeachBatchHelper} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper} import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ArrayImplicits._ @@ -289,4 +293,123 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { spark.streams.listListeners().foreach(spark.streams.removeListener) } } + + private def buildRelation(query: String) = { + proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery(query) + .build()) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + } + + private def assertPlanCache( + sessionHolder: SessionHolder, + optionExpectedCachedRelations: Option[Set[proto.Relation]]) = { + optionExpectedCachedRelations match { + case Some(expectedCachedRelations) => + val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala + assert(cachedRelations.size == expectedCachedRelations.size) + expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation))) + case None => assert(sessionHolder.getPlanCache.isEmpty) + } + } 
+ + test("Test session plan cache") { + val sessionHolder = SessionHolder.forTesting(spark) + try { + // Set cache size to 2 + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 2) + val planner = new SparkConnectPlanner(sessionHolder) + + val random1 = buildRelation("select 1") + val random2 = buildRelation("select 2") + val random3 = buildRelation("select 3") + val query1 = proto.Relation.newBuilder + .setLimit( + proto.Limit.newBuilder + .setLimit(10) + .setInput( + proto.Relation + .newBuilder() + .setRange(proto.Range.newBuilder().setStart(0).setStep(1).setEnd(20)) + .build())) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + val query2 = proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(5).setInput(query1)) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + + // If cachePlan is false, the cache is still empty. + planner.transformRelation(random1, cachePlan = false) + assertPlanCache(sessionHolder, Some(Set())) + + // Put a random entry in cache. + planner.transformRelation(random1, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random1))) + + // Put another random entry in cache. + planner.transformRelation(random2, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random1, random2))) + + // Analyze query1. We only cache the root relation, and the random1 is evicted. + planner.transformRelation(query1, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random2, query1))) + + // Put another random entry in cache. + planner.transformRelation(random3, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(query1, random3))) + + // Analyze query2. As query1 is accessed during the process, it should be in the cache. + planner.transformRelation(query2, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(query1, query2))) + } finally { + // Set back to default value. 
+ SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) + } + } + + test("Test session plan cache - cache size zero or negative") { + val sessionHolder = SessionHolder.forTesting(spark) + try { + // Set cache size to -1 + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, -1) + val planner = new SparkConnectPlanner(sessionHolder) + + val query = buildRelation("select 1") + + // If cachePlan is false, the cache is still None. + planner.transformRelation(query, cachePlan = false) + assertPlanCache(sessionHolder, None) + + // Even if we specify "cachePlan = true", the cache is still None. + planner.transformRelation(query, cachePlan = true) + assertPlanCache(sessionHolder, None) + } finally { + // Set back to default value. + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) + } + } + + test("Test session plan cache - disabled") { + val sessionHolder = SessionHolder.forTesting(spark) + // Disable plan cache of the session + sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, false) + val planner = new SparkConnectPlanner(sessionHolder) + + val query = buildRelation("select 1") + + // If cachePlan is false, the cache is still empty. + // Although the cache is created as cache size is greater than zero, it won't be used. + planner.transformRelation(query, cachePlan = false) + assertPlanCache(sessionHolder, Some(Set())) + + // Even if we specify "cachePlan = true", the cache is still empty. + planner.transformRelation(query, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set())) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org