AnishMahto commented on code in PR #51544:
URL: https://github.com/apache/spark/pull/51544#discussion_r2216618046


##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala:
##########
@@ -486,6 +493,48 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
     Option(pipelineExecutions.get(graphId))
   }
 
+  private[connect] def createDataflowGraph(
+      defaultCatalog: String,
+      defaultDatabase: String,
+      defaultSqlConf: Map[String, String]): String = {
+    dataflowGraphRegistry.createDataflowGraph(defaultCatalog, defaultDatabase, 
defaultSqlConf)
+  }
+
+  /**
+   * Retrieves the dataflow graph for the given graph ID.
+   */
+  private[connect] def getDataflowGraph(graphId: String): 
Option[GraphRegistrationContext] = {
+    dataflowGraphRegistry.getDataflowGraph(graphId)
+  }
+
+  /**
+   * Retrieves the dataflow graph for the given graph ID, throwing if not 
found.
+   */
+  private[connect] def getDataflowGraphOrThrow(graphId: String): 
GraphRegistrationContext = {
+    dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)
+  }
+
+  /**
+   * Removes the dataflow graph with the given ID.
+   */
+  private[connect] def dropDataflowGraph(graphId: String): Unit = {
+    dataflowGraphRegistry.dropDataflowGraph(graphId)
+  }
+
+  /**
+   * Returns all dataflow graphs in this session.
+   */
+  private[connect] def getAllDataflowGraphs: Seq[GraphRegistrationContext] = {
+    dataflowGraphRegistry.getAllDataflowGraphs
+  }
+
+  /**
+   * Removes all dataflow graphs from this session. Called during session 
cleanup.
+   */
+  private[connect] def dropAllDataflowGraphs(): Unit = {
+    dataflowGraphRegistry.dropAllDataflowGraphs()
+  }
+

Review Comment:
   Is there any particular reason why we added these delegator methods, rather than just having callers call `SessionHolder.dataflowGraphRegistry.blah()` directly?
   
   If it's for access-modifier reasons, why not just declare `private[connect] lazy val dataflowGraphRegistry`?



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala:
##########
@@ -78,8 +83,16 @@ class PythonPipelineSuite
       throw new RuntimeException(
         s"Python process failed with exit code $exitCode. Output: 
${output.mkString("\n")}")
     }
+    val activateSessions = 
SparkConnectService.sessionManager.listActiveSessions

Review Comment:
   nit: typo — rename `activateSessions` to `activeSessions`.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala:
##########
@@ -251,4 +251,243 @@ class SparkDeclarativePipelinesServerSuite
       assert(spark.table("spark_catalog.other.tableD").count() == 5)
     }
   }
+
+  test("dataflow graphs are session-specific") {
+    withRawBlockingStub { implicit stub =>
+      // Create a dataflow graph in the default session
+      val graphId1 = createDataflowGraph
+
+      // Register a dataset in the default session
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId1)
+                .setDatasetName("session1_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      // Verify the graph exists in the default session
+      assert(getDefaultSessionHolder.getAllDataflowGraphs.size == 1)
+    }
+
+    // Create a second session with different user/session ID
+    val newSessionId = UUID.randomUUID().toString
+    val newSessionUserId = "session2_user"
+
+    withRawBlockingStub { implicit stub =>
+      // Override the test context to use different session
+      val newSessionExecuteRequest = buildExecutePlanRequest(
+        buildCreateDataflowGraphPlan(
+          proto.PipelineCommand.CreateDataflowGraph
+            .newBuilder()
+            .setDefaultCatalog("spark_catalog")
+            .setDefaultDatabase("default")
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(newSessionUserId)
+          .build())
+        .setSessionId(newSessionId)
+        .build()
+
+      val response = stub.executePlan(newSessionExecuteRequest)
+      val graphId2 =
+        
response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+
+      // Register a different dataset in second session
+      val session2DefineRequest = buildExecutePlanRequest(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId2)
+                .setDatasetName("session2_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(newSessionUserId)
+          .build())
+        .setSessionId(newSessionId)
+        .build()
+
+      stub.executePlan(session2DefineRequest).next()
+
+      // Verify session isolation - each session should only see its own graphs
+      val newSessionHolder = SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(newSessionUserId, 
newSessionId))
+        .getOrElse(throw new RuntimeException("New session not found"))
+
+      val defaultSessionGraphs = getDefaultSessionHolder.getAllDataflowGraphs
+      val newSessionGraphs = newSessionHolder.getAllDataflowGraphs
+
+      assert(defaultSessionGraphs.size == 1)
+      assert(newSessionGraphs.size == 1)
+
+      assert(
+        defaultSessionGraphs.head.toDataflowGraph.tables
+          .exists(_.identifier.table == "session1_table"),
+        "Session 1 should have its own table")
+      assert(
+        newSessionGraphs.head.toDataflowGraph.tables
+          .exists(_.identifier.table == "session2_table"),
+        "Session 2 should have its own table")
+    }
+  }
+
+  test("dataflow graphs are cleaned up when session is closed") {
+    val testUserId = "test_user"
+    val testSessionId = UUID.randomUUID().toString
+
+    // Create a session and dataflow graph
+    withRawBlockingStub { implicit stub =>
+      val createGraphRequest = buildExecutePlanRequest(
+        buildCreateDataflowGraphPlan(
+          proto.PipelineCommand.CreateDataflowGraph
+            .newBuilder()
+            .setDefaultCatalog("spark_catalog")
+            .setDefaultDatabase("default")
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(testUserId)
+          .build())
+        .setSessionId(testSessionId)
+        .build()
+
+      val response = stub.executePlan(createGraphRequest)
+      val graphId =
+        
response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+
+      // Register a dataset
+      val defineRequest = buildExecutePlanRequest(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId)
+                .setDatasetName("test_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build())).toBuilder
+        .setUserContext(proto.UserContext
+          .newBuilder()
+          .setUserId(testUserId)
+          .build())
+        .setSessionId(testSessionId)
+        .build()
+
+      stub.executePlan(defineRequest).next()
+
+      // Verify the graph exists
+      val sessionHolder = SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId))
+        .get
+
+      val graphsBefore = sessionHolder.getAllDataflowGraphs
+      assert(graphsBefore.size == 1)
+
+      // Close the session
+      SparkConnectService.sessionManager.closeSession(SessionKey(testUserId, 
testSessionId))
+
+      // Verify the session is no longer available
+      val sessionAfterClose = SparkConnectService.sessionManager
+        .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId))
+
+      assert(sessionAfterClose.isEmpty, "Session should be cleaned up after 
close")
+      // Verify the graph is removed
+      val graphsAfter = sessionHolder.getAllDataflowGraphs
+      assert(graphsAfter.isEmpty, "Graph should be removed after session 
close")
+    }
+  }
+
+  test("multiple dataflow graphs can exist in the same session") {
+    withRawBlockingStub { implicit stub =>
+      // Create two dataflow graphs in the same session
+      val graphId1 = createDataflowGraph
+      val graphId2 = createDataflowGraph
+
+      // Register datasets in both graphs
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId1)
+                .setDatasetName("graph1_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      sendPlan(
+        buildPlanFromPipelineCommand(
+          PipelineCommand
+            .newBuilder()
+            .setDefineDataset(
+              DefineDataset
+                .newBuilder()
+                .setDataflowGraphId(graphId2)
+                .setDatasetName("graph2_table")
+                .setDatasetType(DatasetType.MATERIALIZED_VIEW))
+            .build()))
+
+      // Verify both graphs exist in the session
+      val sessionHolder = getDefaultSessionHolder
+      val graph1 = sessionHolder.getDataflowGraph(graphId1).getOrElse {
+        fail(s"Graph with ID $graphId1 not found in session")
+      }
+      val graph2 = sessionHolder.getDataflowGraph(graphId2).getOrElse {
+        fail(s"Graph with ID $graphId2 not found in session")
+      }

Review Comment:
   nit: just call `getDataflowGraphOrThrow` here instead of `getDataflowGraph(...).getOrElse { fail(...) }`.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala:
##########
@@ -23,20 +23,29 @@ import org.apache.spark.connect.{proto => sc}
 import org.apache.spark.connect.proto.{PipelineCommand, PipelineEvent}
 import org.apache.spark.sql.connect.{SparkConnectServerTest, 
SparkConnectTestUtils}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
+import org.apache.spark.sql.connect.service.{SessionHolder, SessionKey, 
SparkConnectService}
 import org.apache.spark.sql.pipelines.utils.PipelineTest
 
 class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {
 
   override def afterEach(): Unit = {
     SparkConnectService.sessionManager
       .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId))
-      .foreach(_.removeAllPipelineExecutions())
-    DataflowGraphRegistry.dropAllDataflowGraphs()
+      .foreach(s => {
+        s.removeAllPipelineExecutions()
+        s.dropAllDataflowGraphs()
+      })
     PipelineTest.cleanupMetastore(spark)
     super.afterEach()
   }
 
+  // Helper method to get the session holder
+  protected def getDefaultSessionHolder: SessionHolder = {
+    SparkConnectService.sessionManager
+      .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId))
+      .getOrElse(throw new RuntimeException("Session not found"))

Review Comment:
   nit: just call `getIsolatedSession` here instead of `getIsolatedSessionIfPresent(...).getOrElse(throw ...)`.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala:
##########
@@ -78,8 +83,16 @@ class PythonPipelineSuite
       throw new RuntimeException(
         s"Python process failed with exit code $exitCode. Output: 
${output.mkString("\n")}")
     }
+    val activateSessions = 
SparkConnectService.sessionManager.listActiveSessions
 
-    val dataflowGraphContexts = DataflowGraphRegistry.getAllDataflowGraphs
+    // get the session holder by finding the session with the custom UUID set 
in the conf
+    val sessionHolder = activateSessions
+      .map(info => 
SparkConnectService.sessionManager.getIsolatedSessionIfPresent(info.key).get)

Review Comment:
   nit: use `getIsolatedSession(...)` instead of `getIsolatedSessionIfPresent(...).get`.



##########
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala:
##########
@@ -42,6 +44,8 @@ class PythonPipelineSuite
 
   def buildGraph(pythonText: String): DataflowGraph = {
     val indentedPythonText = pythonText.linesIterator.map("    " + 
_).mkString("\n")
+    // create a unique identifier to allow identifying the session and 
dataflow graph
+    val identifier = UUID.randomUUID().toString

Review Comment:
   nit: rename to something more descriptive like `customSessionIdentifier`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to