This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch branch-1.7
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/branch-1.7 by this push:
     new ca91870f3 [KYUUBI #4323] Improve trino session context
ca91870f3 is described below

commit ca91870f3a3b35786803d4f4b4c0d61c69fcf95a
Author: ulysses-you <[email protected]>
AuthorDate: Tue Feb 14 15:33:36 2023 +0800

    [KYUUBI #4323] Improve trino session context
    
    ### _Why are the changes needed?_
    
    This PR improves the Trino session context:
    1. Always reuse the Kyuubi session if a session id exists, so that the session
       context can be restored for the next query.
    2. Propagate Trino client information to the Kyuubi session, e.g. the Trino
       request source (trino-cli).
    
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4323 from ulysses-you/trino-session.
    
    Closes #4323
    
    59804bad [ulysses-you] style
    fcf540a6 [ulysses-you] Improve trino session context
    
    Authored-by: ulysses-you <[email protected]>
    Signed-off-by: ulyssesyou <[email protected]>
    (cherry picked from commit 763c088d8197b168894b3e5d84aefc248fd3273b)
    Signed-off-by: ulyssesyou <[email protected]>
---
 .../org/apache/kyuubi/server/trino/api/Query.scala | 54 ++++++++++++++++------
 .../kyuubi/server/trino/api/TrinoContext.scala     | 20 ++++----
 .../server/trino/api/v1/StatementResource.scala    |  7 +--
 .../server/trino/api/TrinoClientApiSuite.scala     | 18 ++++++--
 .../trino/api/v1/StatementResourceSuite.scala      |  8 ++--
 5 files changed, 74 insertions(+), 33 deletions(-)

diff --git 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala
index c8c9e2fd6..925875579 100644
--- 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala
+++ 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala
@@ -25,6 +25,8 @@ import java.util.concurrent.atomic.AtomicLong
 import javax.ws.rs.WebApplicationException
 import javax.ws.rs.core.{Response, UriInfo}
 
+import scala.collection.mutable
+
 import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
 import com.google.common.hash.Hashing
 import io.trino.client.QueryResults
@@ -32,6 +34,7 @@ import org.apache.hive.service.rpc.thrift.TProtocolVersion
 
 import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
 import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, 
OperationState, PENDING}
+import org.apache.kyuubi.server.trino.api.Query.KYUUBI_SESSION_ID
 import org.apache.kyuubi.service.BackendService
 import org.apache.kyuubi.session.SessionHandle
 
@@ -90,7 +93,7 @@ case class Query(
 
   private def clear = {
     be.closeOperation(queryId.operationHandle)
-    context.session.get("sessionId").foreach { id =>
+    context.session.get(KYUUBI_SESSION_ID).foreach { id =>
       be.closeSession(SessionHandle.fromUUID(id))
     }
   }
@@ -128,23 +131,24 @@ case class Query(
 
 object Query {
 
+  val KYUUBI_SESSION_ID = "kyuubi.session.id"
+
   def apply(
       statement: String,
       context: TrinoContext,
       translator: KyuubiTrinoOperationTranslator,
       backendService: BackendService,
       queryTimeout: Long = 0): Query = {
-
-    val sessionHandle = createSession(context, backendService)
+    val sessionHandle = getOrCreateSession(context, backendService)
     val operationHandle = translator.transform(
       statement,
       sessionHandle,
       context.session,
       true,
       queryTimeout)
-    val newSessionProperties =
-      context.session + ("sessionId" -> sessionHandle.identifier.toString)
-    val updatedContext = context.copy(session = newSessionProperties)
+    val sessionWithId =
+      context.session + (KYUUBI_SESSION_ID -> 
sessionHandle.identifier.toString)
+    val updatedContext = context.copy(session = sessionWithId)
     Query(QueryId(operationHandle), updatedContext, backendService)
   }
 
@@ -152,15 +156,39 @@ object Query {
     Query(QueryId(id), context, backendService)
   }
 
-  private def createSession(
+  private def getOrCreateSession(
       context: TrinoContext,
       backendService: BackendService): SessionHandle = {
-    backendService.openSession(
-      TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
-      context.user,
-      "",
-      context.remoteUserAddress.getOrElse(""),
-      context.session)
+    
context.session.get(KYUUBI_SESSION_ID).map(SessionHandle.fromUUID).getOrElse {
+      // transform Trino information to session and engine as far as possible.
+      val trinoInfo = new mutable.HashMap[String, String]()
+      context.clientInfo.foreach { info =>
+        trinoInfo.put("trino.client.info", info)
+      }
+      context.source.foreach { source =>
+        trinoInfo.put("trino.request.source", source)
+      }
+      context.traceToken.foreach { traceToken =>
+        trinoInfo.put("trino.trace.token", traceToken)
+      }
+      context.timeZone.foreach { timeZone =>
+        trinoInfo.put("trino.time.zone", timeZone)
+      }
+      context.language.foreach { language =>
+        trinoInfo.put("trino.language", language)
+      }
+      if (context.clientTags.nonEmpty) {
+        trinoInfo.put("trino.client.info", context.clientTags.mkString(","))
+      }
+
+      val newSessionConfigs = context.session ++ trinoInfo
+      backendService.openSession(
+        TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
+        context.user,
+        "",
+        context.remoteUserAddress.getOrElse(""),
+        newSessionConfigs)
+    }
   }
 
 }
diff --git 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
index 9e7713904..0c7911a46 100644
--- 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
+++ 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
@@ -30,7 +30,9 @@ import 
org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, T
 
 import org.apache.kyuubi.operation.OperationState.FINISHED
 import org.apache.kyuubi.operation.OperationStatus
+import org.apache.kyuubi.server.trino.api.Query.KYUUBI_SESSION_ID
 
+// TODO: Support replace `preparedStatement` for Trino-jdbc
 /**
  * The description and functionality of trino request
  * and response's context
@@ -140,15 +142,17 @@ object TrinoContext {
   def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): 
Response = {
     val responseBuilder = Response.ok(qr)
 
-    trinoContext.catalog.foreach(
-      responseBuilder.header(TRINO_HEADERS.responseSetCatalog, _))
-    trinoContext.schema.foreach(
-      responseBuilder.header(TRINO_HEADERS.responseSetSchema, _))
+    // Note, We have injected kyuubi session id to session context so that the 
next query can find
+    // the previous session to restore the query context.
+    // It's hard to follow the Trino style that set all context to http 
headers.
+    // Because we do not know the context at server side. e.g. `set k=v`, `use 
database`.
+    // We also can not inject other session context into header before we 
supporting to map
+    // query result to session context.
+    require(trinoContext.session.contains(KYUUBI_SESSION_ID), 
s"$KYUUBI_SESSION_ID must be set.")
+    responseBuilder.header(
+      TRINO_HEADERS.responseSetSession,
+      
s"$KYUUBI_SESSION_ID=${urlEncode(trinoContext.session(KYUUBI_SESSION_ID))}")
 
-    trinoContext.session.foreach {
-      case (k, v) =>
-        responseBuilder.header(TRINO_HEADERS.responseSetSession, 
s"${k}=${urlEncode(v)}")
-    }
     trinoContext.preparedStatement.foreach {
       case (k, v) =>
         responseBuilder.header(TRINO_HEADERS.responseAddedPrepare, 
s"${k}=${urlEncode(v)}")
diff --git 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
index ab783f8ac..ee23c61f3 100644
--- 
a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
+++ 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
@@ -215,9 +215,10 @@ private[v1] class StatementResource extends 
ApiRequestContext with Logging {
       slug: String,
       token: Long,
       slugContext: Slug.Context.Context): Try[Query] = {
-
-    
Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map
 { _ =>
-      Query(queryId, context, be)
+    
Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map
 { op =>
+      val sessionWithId = context.session ++
+        Map(Query.KYUUBI_SESSION_ID -> 
op.getSession.handle.identifier.toString)
+      Query(queryId, context.copy(session = sessionWithId), be)
     }.filter(_.getSlug.isValid(slugContext, slug, token))
   }
 
diff --git 
a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
 
b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
index c88b5c940..13e10a112 100644
--- 
a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
+++ 
b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
@@ -45,16 +45,24 @@ class TrinoClientApiSuite extends KyuubiFunSuite with 
TrinoRestFrontendTestHelpe
   test("submit query with trino client api") {
     val trino = getTrinoStatementClient("select 1")
     val result = execute(trino)
-    val sessionId = trino.getSetSessionProperties.asScala.get("sessionId")
+    val sessionId = 
trino.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
     assert(result == List(List(1)))
 
     updateClientSession(trino)
 
-    val trino1 = getTrinoStatementClient("select 2")
+    val trino1 = getTrinoStatementClient("set k=v")
     val result1 = execute(trino1)
-    val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId")
-    assert(result1 == List(List(2)))
-    assert(sessionId != sessionId1)
+    val sessionId1 = 
trino1.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
+    assert(result1 == List(List("k", "v")))
+    assert(sessionId == sessionId1)
+
+    updateClientSession(trino)
+
+    val trino2 = getTrinoStatementClient("set k")
+    val result2 = execute(trino2)
+    val sessionId2 = 
trino2.getSetSessionProperties.asScala.get(Query.KYUUBI_SESSION_ID)
+    assert(result2 == List(List("k", "v")))
+    assert(sessionId == sessionId2)
 
     trino.close()
   }
diff --git 
a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
 
b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
index adbf389c9..5740f6d38 100644
--- 
a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
+++ 
b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
@@ -27,7 +27,7 @@ import io.trino.client.ProtocolHeaders.TRINO_HEADERS
 
 import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, 
TrinoRestFrontendTestHelper}
 import org.apache.kyuubi.operation.{OperationHandle, OperationState}
-import org.apache.kyuubi.server.trino.api.TrinoContext
+import org.apache.kyuubi.server.trino.api.{Query, TrinoContext}
 import org.apache.kyuubi.server.trino.api.v1.dto.Ok
 import org.apache.kyuubi.session.SessionHandle
 
@@ -78,7 +78,7 @@ class StatementResourceSuite extends KyuubiFunSuite with 
TrinoRestFrontendTestHe
       response.getStringHeaders.get(TRINO_HEADERS.responseSetSession).asScala
         .map(_.split("="))
         .find {
-          case Array("sessionId", _) => true
+          case Array(Query.KYUUBI_SESSION_ID, _) => true
         }
         .map {
           case Array(_, value) => 
SessionHandle.fromUUID(TrinoContext.urlDecode(value))
@@ -90,12 +90,12 @@ class StatementResourceSuite extends KyuubiFunSuite with 
TrinoRestFrontendTestHe
     val path = qr.getNextUri.getPath
     val nextResponse = webTarget.path(path).request().header(
       TRINO_HEADERS.requestSession(),
-      
s"sessionId=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}").delete()
+      
s"${Query.KYUUBI_SESSION_ID}=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}")
+      .delete()
     assert(nextResponse.getStatus == 204)
     assert(operation.getStatus.state == OperationState.CLOSED)
     val exception = 
intercept[KyuubiSQLException](sessionManager.getSession(sessionHandle))
     assert(exception.getMessage === s"Invalid $sessionHandle")
-
   }
 
   private def getData(current: TrinoResponse): TrinoResponse = {

Reply via email to