This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b618c9bd [KYUUBI #6223] Fix Scala interpreter can not access spark.jars issue
9b618c9bd is described below

commit 9b618c9bdb34dfefbce6c3fcb98d28c1fdc437ad
Author: Wang, Fei <[email protected]>
AuthorDate: Wed Apr 3 18:36:27 2024 +0800

    [KYUUBI #6223] Fix Scala interpreter can not access spark.jars issue
    
    # :mag: Description
    ## Issue References 🔗
    
    This pull request fixes #6223
    
    Even when users specify `spark.jars`, they cannot access the classes in those jars from Scala code.
    
    ## Describe Your Solution 🔧
    
    Add the jars to the REPL classpath.
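
    In sketch form, the fix walks the context class loader chain to the nearest `MutableURLClassLoader` (where Spark registers jars from `spark.jars`), collects its file-backed jar URLs, and hands them to the interpreter via `-classpath`. A minimal standalone sketch of that idea (the helper name `collectJars` is illustrative; the real change lives in `KyuubiSparkILoop` below):

    ```scala
    import java.io.File
    import java.net.URL

    import org.apache.spark.util.MutableURLClassLoader

    // Walk the class loader chain to the nearest MutableURLClassLoader,
    // which is where Spark places jars specified via spark.jars.
    def collectJars(start: ClassLoader): Array[URL] = {
      var cl: ClassLoader = start
      while (cl != null) {
        cl match {
          case loader: MutableURLClassLoader =>
            // Keep only URLs that point at existing local jar files.
            return loader.getURLs.filter { u =>
              u.getProtocol == "file" && new File(u.getPath).isFile
            }
          case _ => cl = cl.getParent
        }
      }
      Array.empty[URL]
    }

    // Joined with the platform path separator, this becomes the value
    // of the interpreter's "-classpath" argument.
    val interpreterClasspath =
      collectJars(Thread.currentThread().getContextClassLoader)
        .mkString(File.pathSeparator)
    ```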
    
    ## Types of changes :bookmark:
    
    - [x] Bugfix (non-breaking change which fixes an issue)
    - [ ] New feature (non-breaking change which adds functionality)
    - [ ] Breaking change (fix or feature that would cause existing functionality to change)
    
    ## Test Plan 🧪
    
    #### Behavior Without This Pull Request :coffin:

    Scala statements that reference classes from jars specified via `spark.jars` fail, because those jars are not on the interpreter's classpath.

    #### Behavior With This Pull Request :tada:

    Classes from jars specified via `spark.jars` resolve normally in the Scala interpreter.

    #### Related Unit Tests

    A new UT in `KyuubiOperationPerConnectionSuite` imports and calls a class from a `spark.jars` jar through the Scala REPL.
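
    For a manual check, the same behavior can be exercised over JDBC. A sketch only: the host, port, jar path, and the `#` config fragment are placeholders assumed from Kyuubi's JDBC URL format, not taken from this patch.

    ```scala
    import java.sql.DriverManager

    import org.apache.kyuubi.jdbc.hive.KyuubiStatement

    // Hypothetical connection: point spark.jars at a jar that contains
    // test.utils.Math, then run Scala statements through the engine's REPL.
    val conn = DriverManager.getConnection(
      "jdbc:hive2://localhost:10009/default#spark.jars=/path/to/test-function.jar")
    val stmt = conn.createStatement().asInstanceOf[KyuubiStatement]
    try {
      // Before this fix the import fails because the jar is missing from
      // the interpreter's classpath; with the fix the class resolves.
      stmt.executeScala("import test.utils.{Math => TMath}")
      val rs = stmt.executeScala("println(TMath.add(1, 2))")
      rs.next()
      assert(rs.getString(1) == "3")
    } finally {
      conn.close()
    }
    ```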
    
    ---
    
    # Checklist 📝
    
    - [ ] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)
    
    **Be nice. Be informative.**
    
    Closes #6235 from turboFei/scala_repl_urls.
    
    Closes #6223
    
    344502660 [Wang, Fei] scala 2.13
    cc6e28989 [Wang, Fei] todo
    a8b373167 [Wang, Fei] refine
    65b438ccf [Wang, Fei] remove scala reflect check
    eb257c7a8 [Wang, Fei] using -classpath
    e1c6f0e11 [Wang, Fei] revert 2.13
    15d37662d [Wang, Fei] repl
    41ebe1011 [Wang, Fei] fix ut
    ed5d344f8 [Wang, Fei] info
    1cdd82ab4 [Wang, Fei] comment
    aa4292dac [Wang, Fei] fix
    a10cfa5e0 [Wang, Fei] ut
    63fdb8877 [Wang, Fei] Use global.classPath.asURLs instead of class loader urls
    
    Authored-by: Wang, Fei <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../engine/spark/repl/KyuubiSparkILoop.scala       | 47 +++++++++++++---------
 .../engine/spark/repl/KyuubiSparkILoop.scala       | 46 ++++++++++++---------
 .../KyuubiOperationPerConnectionSuite.scala        | 35 ++++++++++++++--
 3 files changed, 85 insertions(+), 43 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala-2.12/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala b/externals/kyuubi-spark-sql-engine/src/main/scala-2.12/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
index fbbda89ed..bf20a8fb9 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala-2.12/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala-2.12/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.engine.spark.repl
 
 import java.io.{ByteArrayOutputStream, File, PrintWriter}
+import java.net.URL
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.tools.nsc.Settings
@@ -28,47 +29,35 @@ import org.apache.spark.repl.SparkILoop
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.util.MutableURLClassLoader
 
-import org.apache.kyuubi.Utils
+import org.apache.kyuubi.{Logging, Utils}
 
 private[spark] case class KyuubiSparkILoop private (
     spark: SparkSession,
     output: ByteArrayOutputStream)
-  extends SparkILoop(None, new PrintWriter(output)) {
+  extends SparkILoop(None, new PrintWriter(output)) with Logging {
   import KyuubiSparkILoop._
 
   val result = new DataFrameHolder(spark)
 
   private def initialize(): Unit = withLockRequired {
+    val currentClassLoader = Thread.currentThread().getContextClassLoader
+    val interpreterClasspath = getAllJars(currentClassLoader).mkString(File.pathSeparator)
+    info(s"Adding jars to Scala interpreter's class path: $interpreterClasspath")
     settings = new Settings
     val interpArguments = List(
       "-Yrepl-class-based",
       "-Yrepl-outdir",
-      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}")
+      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}",
+      "-classpath",
+      interpreterClasspath)
     settings.processArguments(interpArguments, processAll = true)
     settings.usejavacp.value = true
-    val currentClassLoader = Thread.currentThread().getContextClassLoader
     settings.embeddedDefaults(currentClassLoader)
     this.createInterpreter()
     this.initializeSynchronous()
     try {
       this.compilerClasspath
       this.ensureClassLoader()
-      var classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
-      while (classLoader != null) {
-        classLoader match {
-          case loader: MutableURLClassLoader =>
-            val allJars = loader.getURLs.filter { u =>
-              val file = new File(u.getPath)
-              u.getProtocol == "file" && file.isFile &&
-              file.getName.contains("scala-lang_scala-reflect")
-            }
-            this.addUrlsToClassPath(allJars: _*)
-            classLoader = null
-          case _ =>
-            classLoader = classLoader.getParent
-        }
-      }
-
       this.addUrlsToClassPath(
         classOf[DataFrameHolder].getProtectionDomain.getCodeSource.getLocation)
     } finally {
@@ -97,6 +86,24 @@ private[spark] case class KyuubiSparkILoop private (
     }
   }
 
+  private def getAllJars(currentClassLoader: ClassLoader): Array[URL] = {
+    var classLoader: ClassLoader = currentClassLoader
+    var allJars = Array.empty[URL]
+    while (classLoader != null) {
+      classLoader match {
+        case loader: MutableURLClassLoader =>
+          allJars = loader.getURLs.filter { u =>
+            // TODO: handle SPARK-47475 since Spark 4.0.0 in the future
+            u.getProtocol == "file" && new File(u.getPath).isFile
+          }
+          classLoader = null
+        case _ =>
+          classLoader = classLoader.getParent
+      }
+    }
+    allJars
+  }
+
   def getResult(statementId: String): DataFrame = result.get(statementId)
 
   def clearResult(statementId: String): Unit = result.unset(statementId)
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala-2.13/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala b/externals/kyuubi-spark-sql-engine/src/main/scala-2.13/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
index a63d71a78..c6e216fd6 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala-2.13/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala-2.13/org/apache/kyuubi/engine/spark/repl/KyuubiSparkILoop.scala
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.engine.spark.repl
 
 import java.io.{ByteArrayOutputStream, File, PrintWriter}
+import java.net.URL
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.tools.nsc.Settings
@@ -28,25 +29,29 @@ import org.apache.spark.repl.SparkILoop
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.util.MutableURLClassLoader
 
-import org.apache.kyuubi.Utils
+import org.apache.kyuubi.{Logging, Utils}
 
 private[spark] case class KyuubiSparkILoop private (
     spark: SparkSession,
     output: ByteArrayOutputStream)
-  extends SparkILoop(null, new PrintWriter(output)) {
+  extends SparkILoop(null, new PrintWriter(output)) with Logging {
   import KyuubiSparkILoop._
 
   val result = new DataFrameHolder(spark)
 
   private def initialize(): Unit = withLockRequired {
+    val currentClassLoader = Thread.currentThread().getContextClassLoader
+    val interpreterClasspath = getAllJars(currentClassLoader).mkString(File.pathSeparator)
+    info(s"Adding jars to Scala interpreter's class path: $interpreterClasspath")
     val settings = new Settings
     val interpArguments = List(
       "-Yrepl-class-based",
       "-Yrepl-outdir",
-      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}")
+      s"${spark.sparkContext.getConf.get("spark.repl.class.outputDir")}",
+      "-classpath",
+      interpreterClasspath)
     settings.processArguments(interpArguments, processAll = true)
     settings.usejavacp.value = true
-    val currentClassLoader = Thread.currentThread().getContextClassLoader
     settings.embeddedDefaults(currentClassLoader)
     this.createInterpreter(settings)
     val iMain = this.intp.asInstanceOf[IMain]
@@ -54,22 +59,6 @@ private[spark] case class KyuubiSparkILoop private (
     try {
       this.compilerClasspath
       iMain.ensureClassLoader()
-      var classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
-      while (classLoader != null) {
-        classLoader match {
-          case loader: MutableURLClassLoader =>
-            val allJars = loader.getURLs.filter { u =>
-              val file = new File(u.getPath)
-              u.getProtocol == "file" && file.isFile &&
-              file.getName.contains("scala-lang_scala-reflect")
-            }
-            this.addUrlsToClassPath(allJars: _*)
-            classLoader = null
-          case _ =>
-            classLoader = classLoader.getParent
-        }
-      }
-
       this.addUrlsToClassPath(
         classOf[DataFrameHolder].getProtectionDomain.getCodeSource.getLocation)
     } finally {
@@ -98,6 +87,23 @@ private[spark] case class KyuubiSparkILoop private (
     }
   }
 
+  private def getAllJars(currentClassLoader: ClassLoader): Array[URL] = {
+    var classLoader: ClassLoader = currentClassLoader
+    var allJars = Array.empty[URL]
+    while (classLoader != null) {
+      classLoader match {
+        case loader: MutableURLClassLoader =>
+          allJars = loader.getURLs.filter { u =>
+            u.getProtocol == "file" && new File(u.getPath).isFile
+          }
+          classLoader = null
+        case _ =>
+          classLoader = classLoader.getParent
+      }
+    }
+    allJars
+  }
+
   def getResult(statementId: String): DataFrame = result.get(statementId)
 
   def clearResult(statementId: String): Unit = result.unset(statementId)
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala
index 859ccac98..12f8520b6 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala
@@ -19,18 +19,19 @@ package org.apache.kyuubi.operation
 
 import java.sql.SQLException
 import java.util
-import java.util.Properties
+import java.util.{Properties, UUID}
 
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.fs.Path
 import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 
-import org.apache.kyuubi.{KYUUBI_VERSION, WithKyuubiServer}
+import org.apache.kyuubi.{KYUUBI_VERSION, Utils, WithKyuubiServer}
 import org.apache.kyuubi.config.{KyuubiConf, KyuubiReservedKeys}
 import org.apache.kyuubi.config.KyuubiConf.SESSION_CONF_ADVISOR
 import org.apache.kyuubi.engine.{ApplicationManagerInfo, ApplicationState}
 import org.apache.kyuubi.jdbc.KyuubiHiveDriver
-import org.apache.kyuubi.jdbc.hive.{KyuubiConnection, KyuubiSQLException}
+import org.apache.kyuubi.jdbc.hive.{KyuubiConnection, KyuubiSQLException, KyuubiStatement}
 import org.apache.kyuubi.metrics.{MetricsConstants, MetricsSystem}
 import org.apache.kyuubi.plugin.SessionConfAdvisor
 import org.apache.kyuubi.session.{KyuubiSessionImpl, KyuubiSessionManager, SessionHandle, SessionType}
@@ -346,6 +347,34 @@ class KyuubiOperationPerConnectionSuite extends WithKyuubiServer with HiveJDBCTe
       }
     }
   }
+
+  test("Scala REPL should see jars added by spark.jars") {
+    val jarDir = Utils.createTempDir().toFile
+    val udfCode =
+      """
+        |package test.utils
+        |
+        |object Math {
+        |  def add(x: Int, y: Int): Int = x + y
+        |}
+        |
+        |""".stripMargin
+    val jarFile = UserJarTestUtils.createJarFile(
+      udfCode,
+      "test",
+      s"test-function-${UUID.randomUUID}.jar",
+      jarDir.toString)
+    val localPath = new Path(jarFile.getAbsolutePath)
+    withSessionConf()(Map("spark.jars" -> localPath.toString))() {
+      withJdbcStatement() { statement =>
+        val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+        kyuubiStatement.executeScala("import test.utils.{Math => TMath}")
+        val rs = kyuubiStatement.executeScala("println(TMath.add(1,2))")
+        rs.next()
+        assert(rs.getString(1) === "3")
+      }
+    }
+  }
 }
 
 class TestSessionConfAdvisor extends SessionConfAdvisor {
