This is an automated email from the ASF dual-hosted git repository.

ggal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-livy.git


The following commit(s) were added to refs/heads/master by this push:
     new 728fcf0a [LIVY-998][THRIFT] Support connecting to existing sessions using session name via Thrift Server
728fcf0a is described below

commit 728fcf0ae9b94b7690c8a3f1eaa4d05507bfbea6
Author: Asif Khatri <[email protected]>
AuthorDate: Mon Jun 10 13:46:06 2024 +0530

    [LIVY-998][THRIFT] Support connecting to existing sessions using session name via Thrift Server
    
    Co-authored-by: Asif Khatri <[email protected]>
---
 .../thriftserver/LivyThriftSessionManager.scala    | 32 +++++--
 .../TestLivyThriftSessionManager.scala             | 98 ++++++++++++++++++----
 2 files changed, 111 insertions(+), 19 deletions(-)

diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
index cd8d2f50..03e10c35 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
@@ -155,12 +155,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
   }
 
   /**
-   * If the user specified an existing sessionId to use, the corresponding session is returned,
-   * otherwise a new session is created and returned.
+   * If the user specified an existing sessionId or session name to use, the corresponding session
+   * is returned, otherwise a new session is created and returned.
    */
-  private def getOrCreateLivySession(
+  def getOrCreateLivySession(
       sessionHandle: SessionHandle,
       sessionId: Option[Int],
+      sessionName: Option[String],
       username: String,
       createLivySession: () => InteractiveSession): InteractiveSession = {
     sessionId match {
@@ -183,7 +184,27 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
             }
         }
       case None =>
-        createLivySession()
+        sessionName match {
+          case Some(name) =>
+            server.livySessionManager.get(name) match {
+              case None =>
+                createLivySession()
+              case Some(session) if !server.isAllowedToUse(username, session) =>
+                warn(s"$username has no modify permissions to InteractiveSession $name.")
+                throw new IllegalAccessException(
+                  s"$username is not allowed to use InteractiveSession $name.")
+              case Some(session) =>
+                if (session.state.isActive) {
+                  info(s"Reusing Session $name for $sessionHandle.")
+                  session
+                } else {
+                  warn(s"InteractiveSession $name is not active anymore.")
+                  throw new IllegalArgumentException(s"Session $name is not active anymore.")
+                }
+            }
+          case None =>
+            createLivySession()
+        }
     }
   }
 
@@ -248,7 +269,8 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
         livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
           override def run(): InteractiveSession = {
             livySession =
-              getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
+              getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name,
+                username, createLivySession)
             synchronized {
               managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
                 managedLivySessionActiveUsers(livySession.id) = numUsers + 1
diff --git a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
index 11eea31f..cbfc006c 100644
--- a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
+++ b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
@@ -27,13 +27,13 @@ import scala.concurrent.duration.Duration
 import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
 import org.junit.Assert._
 import org.junit.Test
-import org.mockito.Mockito.mock
+import org.mockito.Mockito.{mock, when}
 
 import org.apache.livy.LivyConf
-import org.apache.livy.server.AccessManager
 import org.apache.livy.server.interactive.InteractiveSession
-import org.apache.livy.server.recovery.{SessionStore, StateStore}
-import org.apache.livy.sessions.InteractiveSessionManager
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.server.AccessManager
+import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
 import org.apache.livy.utils.Clock.sleep
 
 object ConnectionLimitType extends Enumeration {
@@ -46,7 +46,7 @@ class TestLivyThriftSessionManager {
   import ConnectionLimitType._
 
   private def createThriftSessionManager(
-      limitTypes: ConnectionLimitType*): LivyThriftSessionManager = {
+      limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = {
     val conf = new LivyConf()
     conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
     val limit = 3
@@ -62,21 +62,23 @@ class TestLivyThriftSessionManager {
   }
 
   private def createThriftSessionManager(
-      maxSessionWait: Option[String]): LivyThriftSessionManager = {
+      maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = {
     val conf = new LivyConf()
     conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
     maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _))
     this.createThriftSessionManager(conf)
   }
 
-  private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = {
+  private def createThriftSessionManager(conf: LivyConf): (LivyThriftSessionManager,
+    LivyThriftServer) = {
     val server = new LivyThriftServer(
       conf,
       mock(classOf[InteractiveSessionManager]),
       mock(classOf[SessionStore]),
       mock(classOf[AccessManager])
     )
-    new LivyThriftSessionManager(server, conf)
+    val sessionManager = new LivyThriftSessionManager(server, conf)
+    (sessionManager, server)
   }
 
   private def testLimit(
@@ -99,7 +101,7 @@ class TestLivyThriftSessionManager {
 
   @Test
   def testLimitConnectionsByUser(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(User)
+    val (thriftSessionMgr, _) = createThriftSessionManager(User)
     val user = "alice"
     val forwardedAddresses = new java.util.ArrayList[String]()
     thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses)
@@ -111,7 +113,7 @@ class TestLivyThriftSessionManager {
 
   @Test
   def testLimitConnectionsByIpAddress(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(IpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress)
     val ipAddress = "10.20.30.40"
     val forwardedAddresses = new java.util.ArrayList[String]()
     thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses)
@@ -123,7 +125,7 @@ class TestLivyThriftSessionManager {
 
   @Test
   def testLimitConnectionsByUserAndIpAddress(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(UserIpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress)
     val user = "alice"
     val ipAddress = "10.20.30.40"
     val userAndAddress = user + ":" + ipAddress
@@ -149,7 +151,7 @@ class TestLivyThriftSessionManager {
 
   @Test
   def testMultipleConnectionLimits(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(User, IpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress)
     val user = "alice"
     val ipAddress = "10.20.30.40"
     val forwardedAddresses = new java.util.ArrayList[String]()
@@ -166,7 +168,7 @@ class TestLivyThriftSessionManager {
 
   @Test(expected = classOf[TimeoutException])
   def testGetLivySessionWaitForTimeout(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(Some("10ms"))
+    val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms"))
     val sessionHandle = mock(classOf[SessionHandle])
     val future = Future[InteractiveSession] {
       sleep(100)
@@ -178,7 +180,7 @@ class TestLivyThriftSessionManager {
 
   @Test(expected = classOf[TimeoutException])
   def testGetLivySessionWithTimeoutException(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(None)
+    val (thriftSessionMgr, _) = createThriftSessionManager(None)
     val sessionHandle = mock(classOf[SessionHandle])
     val future = Future[InteractiveSession] {
       throw new TimeoutException("Actively throw TimeoutException in Future.")
@@ -187,4 +189,72 @@ class TestLivyThriftSessionManager {
     Await.ready(future, Duration(30, TimeUnit.SECONDS))
     thriftSessionMgr.getLivySession(sessionHandle)
   }
+
+
+  @Test
+  def testGetOrCreateLivySessionDifferentSessions(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val sessionId1 = Some(1)
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+    val sessionId2 = Some(2)
+    val session2 = mock(classOf[InteractiveSession])
+    when(session2.state).thenReturn(SessionState.Running)
+    when(session2.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(2)).thenReturn(Some(session2))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId1, None,
+      sessionUser, () => null)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId2, None,
+      sessionUser, () => null)
+
+    assertNotNull(result1)
+    assertNotNull(result2)
+    assertNotEquals(result1, result2)
+  }
+
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByID(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val sessionId = Some(1)
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
+      sessionUser, () => null)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, sessionId, None,
+      sessionUser, () => null)
+
+    assertNotNull(result1)
+    assertNotNull(result2)
+    assertEquals(result1, result2)
+  }
+
+
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByName(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val sessionName = Some("sessionName")
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
+      sessionUser, () => null)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
+      sessionUser, () => null)
+
+    assertNotNull(result1)
+    assertNotNull(result2)
+    assertEquals(result1, result2)
+  }
+
 }

Reply via email to