This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0863af4254aa [SPARK-49249][SPARK-49122][CONNECT][SQL] Add `addArtifact` API to the Spark SQL Core
0863af4254aa is described below
commit 0863af4254aae9f9ec352ffddaf7aa004492b0a8
Author: Paddy Xu <[email protected]>
AuthorDate: Fri Aug 30 07:59:44 2024 -0400
[SPARK-49249][SPARK-49122][CONNECT][SQL] Add `addArtifact` API to the Spark SQL Core
### What changes were proposed in this pull request?
This PR adds a set of `addArtifact` APIs to `SparkSession` in Spark SQL Core. These APIs were first introduced in Spark Connect. A follow-up task will add the same APIs to PySpark.
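Below is a minimal usage sketch of the new API surface (the file names and paths are illustrative, not from this patch; the overload signatures match the ones added here):

```scala
import java.nio.file.{Files, Paths}

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Add a local .jar or .class file to the session by path.
spark.addArtifact("/tmp/artifacts/IntSumUdf.class")

// Add an in-memory artifact; `target` places it under the session's
// working directory for that file extension.
val bytes = Files.readAllBytes(Paths.get("/tmp/artifacts/HelloWithPackage.class"))
spark.addArtifact(bytes, "my/custom/pkg/HelloWithPackage.class")

// Add several artifacts at once via URIs.
spark.addArtifacts(
  Paths.get("/tmp/artifacts/a.jar").toUri,
  Paths.get("/tmp/artifacts/b.jar").toUri)
```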
### Why are the changes needed?
To close the API compatibility gap between Spark Connect and Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now call `addArtifact(path: String)`, `addArtifact(uri: URI)`, `addArtifact(bytes: Array[Byte], target: String)`, `addArtifact(source: String, target: String)`, and `addArtifacts(uri: URI*)` on a Spark Classic `SparkSession`.
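The URI overload also accepts Apache Ivy URIs, whose authority must be `org:module:version` (see the validation in the new `Artifact.newIvyArtifacts`). A hedged example follows; the coordinate is illustrative, and the `transitive` query parameter follows Spark's existing Ivy URI conventions rather than anything added in this patch:

```scala
import java.net.URI

// Resolve a Maven coordinate (plus transitive dependencies) and add the jars.
spark.addArtifact(new URI("ivy://org.apache.commons:commons-lang3:3.12.0?transitive=true"))
```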
### How was this patch tested?
Added new tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47631 from xupefei/core-add-artifact.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 75 ++--------
.../sql/connect/client/AmmoniteClassFinder.scala | 2 +
.../spark/sql/connect/client/ArtifactSuite.scala | 1 +
.../CheckConnectJvmClientCompatibility.scala | 10 +-
.../main/scala/org/apache/spark/sql/Artifact.scala | 160 +++++++++++++++++++++
.../org/apache/spark/sql/api/SparkSession.scala | 55 +++++++
.../org/apache/spark/sql}/util/ArtifactUtils.scala | 2 +-
.../spark/sql/connect/client/ArtifactManager.scala | 146 ++-----------------
.../spark/sql/connect/client/ClassFinder.scala | 4 +-
.../sql/connect/client/SparkConnectClient.scala | 2 +-
.../service/SparkConnectAddArtifactsHandler.scala | 2 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 43 +++++-
.../spark/sql/artifact/ArtifactManager.scala | 67 +++++++--
.../artifact-tests/HelloWithPackage.class | Bin 0 -> 635 bytes
.../resources/artifact-tests/HelloWithPackage.java | 35 +++++
.../test/resources/artifact-tests/IntSumUdf.class | Bin 0 -> 1333 bytes
.../test/resources/artifact-tests/IntSumUdf.scala | 22 +++
.../spark/sql/artifact/ArtifactManagerSuite.scala | 84 ++++++++++-
18 files changed, 489 insertions(+), 221 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index e16ff2e39173..3837db00acc6 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -390,77 +390,30 @@ class SparkSession private[sql] (
execute(command)
}
- /**
- * Add a single artifact to the client session.
- *
- * Currently only local files with extensions .jar and .class are supported.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@Experimental
- def addArtifact(path: String): Unit = client.addArtifact(path)
+ override def addArtifact(path: String): Unit = client.addArtifact(path)
- /**
- * Add a single artifact to the client session.
- *
- * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@Experimental
- def addArtifact(uri: URI): Unit = client.addArtifact(uri)
+ override def addArtifact(uri: URI): Unit = client.addArtifact(uri)
- /**
- * Add a single in-memory artifact to the session while preserving the directory structure
- * specified by `target` under the session's working directory of that particular file
- * extension.
- *
- * Supported target file extensions are .jar and .class.
- *
- * ==Example==
- * {{{
- * addArtifact(bytesBar, "foo/bar.class")
- * addArtifact(bytesFlat, "flat.class")
- * // Directory structure of the session's working directory for class files would look like:
- * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class
- * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class
- * }}}
- *
- * @since 4.0.0
- */
+ /** @inheritdoc */
@Experimental
- def addArtifact(bytes: Array[Byte], target: String): Unit = client.addArtifact(bytes, target)
+ override def addArtifact(bytes: Array[Byte], target: String): Unit = {
+ client.addArtifact(bytes, target)
+ }
- /**
- * Add a single artifact to the session while preserving the directory structure specified by
- * `target` under the session's working directory of that particular file extension.
- *
- * Supported target file extensions are .jar and .class.
- *
- * ==Example==
- * {{{
- * addArtifact("/Users/dummyUser/files/foo/bar.class", "foo/bar.class")
- * addArtifact("/Users/dummyUser/files/flat.class", "flat.class")
- * // Directory structure of the session's working directory for class files would look like:
- * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class
- * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class
- * }}}
- *
- * @since 4.0.0
- */
+ /** @inheritdoc */
@Experimental
- def addArtifact(source: String, target: String): Unit = client.addArtifact(source, target)
+ override def addArtifact(source: String, target: String): Unit = {
+ client.addArtifact(source, target)
+ }
- /**
- * Add one or more artifacts to the session.
- *
- * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@Experimental
@scala.annotation.varargs
- def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
+ override def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
/**
* Register a ClassFinder for dynamically generated classes.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala
index 4ebc22202b0b..b359a871d8c2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala
@@ -22,6 +22,8 @@ import java.nio.file.Paths
import ammonite.repl.api.Session
import ammonite.runtime.SpecialClassLoader
+import org.apache.spark.sql.Artifact
+
/**
* A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files.
*
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index bbc396a937c3..66a2c943af5f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -30,6 +30,7 @@ import org.apache.commons.codec.digest.DigestUtils.sha256Hex
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.connect.proto.AddArtifactsRequest
+import org.apache.spark.sql.Artifact
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.IvyTestUtils
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 94bf18027b43..af9168339dcf 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -282,15 +282,11 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"),
- // Artifact Manager
+ // Artifact Manager: the client has a completely different implementation.
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager$"),
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.artifact.util.ArtifactUtils"),
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.artifact.util.ArtifactUtils$"),
// UDFRegistration
ProblemFilters.exclude[DirectMissingMethodProblem](
@@ -391,10 +387,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.execute"),
// Experimental
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.addArtifact"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.addArtifacts"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.registerClassFinder"),
// public
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala
new file mode 100644
index 000000000000..c78280af6e02
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.{ByteArrayInputStream, InputStream, PrintStream}
+import java.net.URI
+import java.nio.file.{Files, Path, Paths}
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.sql.Artifact.LocalData
+import org.apache.spark.sql.util.ArtifactUtils
+import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.MavenUtils
+
+
+private[sql] class Artifact private(val path: Path, val storage: LocalData) {
+ require(!path.isAbsolute, s"Bad path: $path")
+
+ lazy val size: Long = storage match {
+ case localData: LocalData => localData.size
+ }
+}
+
+private[sql] object Artifact {
+ val CLASS_PREFIX: Path = Paths.get("classes")
+ val JAR_PREFIX: Path = Paths.get("jars")
+ val CACHE_PREFIX: Path = Paths.get("cache")
+
+ def newArtifactFromExtension(
+ fileName: String,
+ targetFilePath: Path,
+ storage: LocalData): Artifact = {
+ fileName match {
+ case jar if jar.endsWith(".jar") =>
+ newJarArtifact(targetFilePath, storage)
+ case cf if cf.endsWith(".class") =>
+ newClassArtifact(targetFilePath, storage)
+ case other =>
+ throw new UnsupportedOperationException(s"Unsupported file format: $other")
+ }
+ }
+
+ def parseArtifacts(uri: URI): Seq[Artifact] = {
+ // Currently only local files with extensions .jar and .class are supported.
+ uri.getScheme match {
+ case "file" =>
+ val path = Paths.get(uri)
+ val artifact = Artifact.newArtifactFromExtension(
+ path.getFileName.toString,
+ path.getFileName,
+ new LocalFile(path))
+ Seq[Artifact](artifact)
+
+ case "ivy" =>
+ newIvyArtifacts(uri)
+
+ case other =>
+ throw new UnsupportedOperationException(s"Unsupported scheme: $other")
+ }
+ }
+
+ def newJarArtifact(targetFilePath: Path, storage: LocalData): Artifact = {
+ newArtifact(JAR_PREFIX, ".jar", targetFilePath, storage)
+ }
+
+ def newClassArtifact(targetFilePath: Path, storage: LocalData): Artifact = {
+ newArtifact(CLASS_PREFIX, ".class", targetFilePath, storage)
+ }
+
+ def newCacheArtifact(id: String, storage: LocalData): Artifact = {
+ newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
+ }
+
+ def newIvyArtifacts(uri: URI): Seq[Artifact] = {
+ implicit val printStream: PrintStream = System.err
+
+ val authority = uri.getAuthority
+ if (authority == null) {
+ throw new IllegalArgumentException(
+ s"Invalid Ivy URI authority in uri ${uri.toString}:" +
+ " Expected 'org:module:version', found null.")
+ }
+ if (authority.split(":").length != 3) {
+ throw new IllegalArgumentException(
+ s"Invalid Ivy URI authority in uri ${uri.toString}:" +
+ s" Expected 'org:module:version', found $authority.")
+ }
+
+ val (transitive, exclusions, repos) = MavenUtils.parseQueryParams(uri)
+
+ val exclusionsList: Seq[String] =
+ if (!StringUtils.isBlank(exclusions)) {
+ exclusions.split(",").toImmutableArraySeq
+ } else {
+ Nil
+ }
+
+ val ivySettings = MavenUtils.buildIvySettings(Some(repos), None)
+
+ val jars = MavenUtils.resolveMavenCoordinates(
+ authority,
+ ivySettings,
+ transitive = transitive,
+ exclusions = exclusionsList)
+ jars.map(p => Paths.get(p)).map(path => newJarArtifact(path.getFileName, new LocalFile(path)))
+ }
+
+ private def newArtifact(
+ prefix: Path,
+ requiredSuffix: String,
+ targetFilePath: Path,
+ storage: LocalData): Artifact = {
+ require(targetFilePath.toString.endsWith(requiredSuffix))
+ new Artifact(ArtifactUtils.concatenatePaths(prefix, targetFilePath), storage)
+ }
+
+ /**
+ * Payload stored on this machine.
+ */
+ sealed trait LocalData {
+ def stream: InputStream
+
+ def size: Long
+ }
+
+ /**
+ * Payload stored in a local file.
+ */
+ class LocalFile(val path: Path) extends LocalData {
+ override def size: Long = Files.size(path)
+
+ override def stream: InputStream = Files.newInputStream(path)
+ }
+
+ /**
+ * Payload stored in memory.
+ */
+ class InMemory(bytes: Array[Byte]) extends LocalData {
+ override def size: Long = bytes.length
+
+ override def stream: InputStream = new ByteArrayInputStream(bytes)
+ }
+
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
index 12a1a1361903..d156aba934b6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import _root_.java.io.Closeable
import _root_.java.lang
+import _root_.java.net.URI
import _root_.java.util
import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -312,6 +313,60 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*/
def sql(sqlText: String): DS[Row] = sql(sqlText, Map.empty[String, Any])
+ /**
+ * Add a single artifact to the current session.
+ *
+ * Currently only local files with extensions .jar and .class are supported.
+ *
+ * @since 4.0.0
+ */
+ @Experimental
+ def addArtifact(path: String): Unit
+
+ /**
+ * Add a single artifact to the current session.
+ *
+ * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs.
+ *
+ * @since 4.0.0
+ */
+ @Experimental
+ def addArtifact(uri: URI): Unit
+
+ @Experimental
+ def addArtifact(bytes: Array[Byte], target: String): Unit
+
+ /**
+ * Add a single artifact to the session while preserving the directory structure specified by
+ * `target` under the session's working directory of that particular file extension.
+ *
+ * Supported target file extensions are .jar and .class.
+ *
+ * ==Example==
+ * {{{
+ * addArtifact("/Users/dummyUser/files/foo/bar.class", "foo/bar.class")
+ * addArtifact("/Users/dummyUser/files/flat.class", "flat.class")
+ * // Directory structure of the session's working directory for class files would look like:
+ * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class
+ * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class
+ * }}}
+ *
+ * @since 4.0.0
+ */
+ @Experimental
+ def addArtifact(source: String, target: String): Unit
+
+ /**
+ * Add one or more artifacts to the session.
+ *
+ * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs
+ *
+ * @since 4.0.0
+ */
+ @Experimental
+ @scala.annotation.varargs
+ def addArtifacts(uri: URI*): Unit
+
/**
* Executes some code block and prints to stdout the time taken to execute the block. This is
* available in Scala only and is used primarily for interactive testing and debugging.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
similarity index 97%
rename from sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
index f16d01501d7c..8cd239b55cff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.artifact.util
+package org.apache.spark.sql.util
import java.nio.file.{Path, Paths}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index 6eb59bd37574..e9411dc3db61 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.connect.client
-import java.io.{ByteArrayInputStream, File, InputStream, PrintStream}
+import java.io.InputStream
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.Arrays
@@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
-import Artifact._
import com.google.protobuf.ByteString
import io.grpc.StatusRuntimeException
import io.grpc.stub.StreamObserver
@@ -40,8 +39,9 @@ import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
-import org.apache.spark.util.{MavenUtils, SparkFileUtils, SparkThreadUtils}
-import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.sql.Artifact
+import org.apache.spark.sql.Artifact.{newCacheArtifact, newIvyArtifacts}
+import org.apache.spark.util.{SparkFileUtils, SparkThreadUtils}
/**
* The Artifact Manager is responsible for handling and transferring artifacts from the local
@@ -89,7 +89,7 @@ class ArtifactManager(
val artifact = Artifact.newArtifactFromExtension(
path.getFileName.toString,
path.getFileName,
- new LocalFile(path))
+ new Artifact.LocalFile(path))
Seq[Artifact](artifact)
case "ivy" =>
@@ -128,7 +128,7 @@ class ArtifactManager(
val artifact = Artifact.newArtifactFromExtension(
targetPath.getFileName.toString,
targetPath,
- new InMemory(bytes))
+ new Artifact.InMemory(bytes))
addArtifacts(artifact :: Nil)
}
@@ -152,7 +152,7 @@ class ArtifactManager(
val artifact = Artifact.newArtifactFromExtension(
targetPath.getFileName.toString,
targetPath,
- new LocalFile(Paths.get(source)))
+ new Artifact.LocalFile(Paths.get(source)))
addArtifacts(artifact :: Nil)
}
@@ -164,7 +164,7 @@ class ArtifactManager(
def addArtifacts(uris: Seq[URI]): Unit =
addArtifacts(uris.flatMap(parseArtifacts))
private[client] def isCachedArtifact(hash: String): Boolean = {
- val artifactName = s"$CACHE_PREFIX/$hash"
+ val artifactName = s"${Artifact.CACHE_PREFIX}/$hash"
val request = proto.ArtifactStatusesRequest
.newBuilder()
.setUserContext(clientConfig.userContext)
@@ -191,7 +191,7 @@ class ArtifactManager(
def cacheArtifact(blob: Array[Byte]): String = {
val hash = sha256Hex(blob)
if (!isCachedArtifact(hash)) {
- addArtifacts(newCacheArtifact(hash, new InMemory(blob)) :: Nil)
+ addArtifacts(newCacheArtifact(hash, new Artifact.InMemory(blob)) :: Nil)
}
hash
}
@@ -214,7 +214,9 @@ class ArtifactManager(
try {
stream.forEach { path =>
if (Files.isRegularFile(path) && path.toString.endsWith(".class")) {
- builder += Artifact.newClassArtifact(base.relativize(path), new LocalFile(path))
+ builder += Artifact.newClassArtifact(
+ base.relativize(path),
+ new Artifact.LocalFile(path))
}
}
} finally {
@@ -414,127 +416,3 @@ class ArtifactManager(
}
}
}
-
-class Artifact private (val path: Path, val storage: LocalData) {
- require(!path.isAbsolute, s"Bad path: $path")
-
- lazy val size: Long = storage match {
- case localData: LocalData => localData.size
- }
-}
-
-object Artifact {
- val CLASS_PREFIX: Path = Paths.get("classes")
- val JAR_PREFIX: Path = Paths.get("jars")
- val CACHE_PREFIX: Path = Paths.get("cache")
-
- def newArtifactFromExtension(
- fileName: String,
- targetFilePath: Path,
- storage: LocalData): Artifact = {
- fileName match {
- case jar if jar.endsWith(".jar") =>
- newJarArtifact(targetFilePath, storage)
- case cf if cf.endsWith(".class") =>
- newClassArtifact(targetFilePath, storage)
- case other =>
- throw new UnsupportedOperationException(s"Unsupported file format: $other")
- }
- }
-
- def newJarArtifact(targetFilePath: Path, storage: LocalData): Artifact = {
- newArtifact(JAR_PREFIX, ".jar", targetFilePath, storage)
- }
-
- def newClassArtifact(targetFilePath: Path, storage: LocalData): Artifact = {
- newArtifact(CLASS_PREFIX, ".class", targetFilePath, storage)
- }
-
- def newCacheArtifact(id: String, storage: LocalData): Artifact = {
- newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
- }
-
- def newIvyArtifacts(uri: URI): Seq[Artifact] = {
- implicit val printStream: PrintStream = System.err
-
- val authority = uri.getAuthority
- if (authority == null) {
- throw new IllegalArgumentException(
- s"Invalid Ivy URI authority in uri ${uri.toString}:" +
- " Expected 'org:module:version', found null.")
- }
- if (authority.split(":").length != 3) {
- throw new IllegalArgumentException(
- s"Invalid Ivy URI authority in uri ${uri.toString}:" +
- s" Expected 'org:module:version', found $authority.")
- }
-
- val (transitive, exclusions, repos) = MavenUtils.parseQueryParams(uri)
-
- val exclusionsList: Seq[String] =
- if (!StringUtils.isBlank(exclusions)) {
- exclusions.split(",").toImmutableArraySeq
- } else {
- Nil
- }
-
- val ivySettings = MavenUtils.buildIvySettings(Some(repos), None)
-
- val jars = MavenUtils.resolveMavenCoordinates(
- authority,
- ivySettings,
- transitive = transitive,
- exclusions = exclusionsList)
- jars.map(p => Paths.get(p)).map(path => newJarArtifact(path.getFileName, new LocalFile(path)))
- }
-
- private def concatenatePaths(basePath: Path, otherPath: Path): Path = {
- // We avoid using the `.resolve()` method here to ensure that we're concatenating the two
- // paths even if `otherPath` is absolute.
- val concatenatedPath = Paths.get(basePath.toString, otherPath.toString)
- // Note: The normalized resulting path may still reference parent directories if the
- // `otherPath` contains sufficient number of parent operators (i.e "..").
- // Example: `basePath` = "/base", `otherPath` = "subdir/../../file.txt"
- // Then, `concatenatedPath` = "/base/subdir/../../file.txt"
- // and `normalizedPath` = "/base/file.txt".
- val normalizedPath = concatenatedPath.normalize()
- // Verify that the prefix of the `normalizedPath` starts with `basePath/`.
- require(
- normalizedPath != basePath && normalizedPath.startsWith(s"$basePath${File.separator}"))
- normalizedPath
- }
-
- private def newArtifact(
- prefix: Path,
- requiredSuffix: String,
- targetFilePath: Path,
- storage: LocalData): Artifact = {
- require(targetFilePath.toString.endsWith(requiredSuffix))
- new Artifact(concatenatePaths(prefix, targetFilePath), storage)
- }
-
- /**
- * Payload stored on this machine.
- */
- sealed trait LocalData {
- def stream: InputStream
- def size: Long
- }
-
- /**
- * Payload stored in a local file.
- */
- class LocalFile(val path: Path) extends LocalData {
- override def size: Long = Files.size(path)
- override def stream: InputStream = Files.newInputStream(path)
- }
-
- /**
- * Payload stored in memory.
- */
- class InMemory(bytes: Array[Byte]) extends LocalData {
- override def size: Long = bytes.length
- override def stream: InputStream = new ByteArrayInputStream(bytes)
- }
-
-}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala
index 94486c31a163..261bcc01d701 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ClassFinder.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, LinkOption, Path, Paths}
import scala.jdk.CollectionConverters._
-import org.apache.spark.sql.connect.client.Artifact.LocalFile
+import org.apache.spark.sql.Artifact
trait ClassFinder {
def findClasses(): Iterator[Artifact]
@@ -48,7 +48,7 @@ class REPLClassDirMonitor(_rootDir: String) extends ClassFinder {
private def toArtifact(path: Path): Artifact = {
// Persist the relative path of the classfile
- Artifact.newClassArtifact(rootDir.relativize(path), new LocalFile(path))
+ Artifact.newClassArtifact(rootDir.relativize(path), new Artifact.LocalFile(path))
}
private def isClass(path: Path): Boolean = path.toString.endsWith(".class")
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 224208b96ebf..ff734ee9c3a9 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -386,7 +386,7 @@ private[sql] class SparkConnectClient(
/**
* Cache the given local relation at the server, and return its key in the remote cache.
*/
- def cacheLocalRelation(data: ByteString, schema: String): String = {
+ private[sql] def cacheLocalRelation(data: ByteString, schema: String): String = {
val localRelation = proto.Relation
.newBuilder()
.getLocalRelationBuilder
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index b0d9337c6448..72403016404c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -30,8 +30,8 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.sql.artifact.ArtifactManager
-import org.apache.spark.sql.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.util.ArtifactUtils
import org.apache.spark.util.Utils
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 358541c942f1..fa2d1b163322 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.net.URI
+import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
@@ -55,7 +57,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.{CallSite, Utils}
+import org.apache.spark.util.{CallSite, SparkFileUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -608,6 +610,45 @@ class SparkSession private(
}
}
+ /** @inheritdoc */
+ @Experimental
+ override def addArtifact(path: String): Unit = addArtifact(SparkFileUtils.resolveURI(path))
+
+ /** @inheritdoc */
+ @Experimental
+ override def addArtifact(uri: URI): Unit = {
+ artifactManager.addLocalArtifacts(Artifact.parseArtifacts(uri))
+ }
+
+ /** @inheritdoc */
+ @Experimental
+ override def addArtifact(bytes: Array[Byte], target: String): Unit = {
+ val targetPath = Paths.get(target)
+ val artifact = Artifact.newArtifactFromExtension(
+ targetPath.getFileName.toString,
+ targetPath,
+ new Artifact.InMemory(bytes))
+ artifactManager.addLocalArtifacts(artifact :: Nil)
+ }
+
+ /** @inheritdoc */
+ @Experimental
+ override def addArtifact(source: String, target: String): Unit = {
+ val targetPath = Paths.get(target)
+ val artifact = Artifact.newArtifactFromExtension(
+ targetPath.getFileName.toString,
+ targetPath,
+ new Artifact.LocalFile(Paths.get(source)))
+ artifactManager.addLocalArtifacts(artifact :: Nil)
+ }
+
+ /** @inheritdoc */
+ @Experimental
+ @scala.annotation.varargs
+ override def addArtifacts(uri: URI*): Unit = {
+ artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts))
+ }
+
/**
* Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
* `DataFrame`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index ea7a33f6887d..89aa7cfd5607 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.artifact
import java.io.File
import java.net.{URI, URL, URLClassLoader}
-import java.nio.file.{Files, Path, Paths, StandardCopyOption}
+import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import scala.jdk.CollectionConverters._
@@ -28,12 +28,12 @@ import scala.reflect.ClassTag
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
-import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkEnv, SparkUnsupportedOperationException}
+import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkEnv, SparkException, SparkUnsupportedOperationException}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES, EXECUTOR_USER_CLASS_PATH_FIRST}
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.artifact.util.ArtifactUtils
+import org.apache.spark.sql.{Artifact, SparkSession}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.util.ArtifactUtils
import org.apache.spark.storage.{CacheId, StorageLevel}
import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}
@@ -96,12 +96,19 @@ class ArtifactManager(session: SparkSession) extends Logging {
*/
def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
- protected def moveFile(source: Path, target: Path, allowOverwrite: Boolean = false): Unit = {
+ private def transferFile(
+ source: Path,
+ target: Path,
+ allowOverwrite: Boolean = false,
+ deleteSource: Boolean = true): Unit = {
+ def execute(s: Path, t: Path, opt: CopyOption*): Path =
+ if (deleteSource) Files.move(s, t, opt: _*) else Files.copy(s, t, opt: _*)
+
Files.createDirectories(target.getParent)
if (allowOverwrite) {
- Files.move(source, target, StandardCopyOption.REPLACE_EXISTING)
+ execute(source, target, StandardCopyOption.REPLACE_EXISTING)
} else {
- Files.move(source, target)
+ execute(source, target)
}
}
@@ -112,12 +119,16 @@ class ArtifactManager(session: SparkSession) extends Logging {
* @param remoteRelativePath
* @param serverLocalStagingPath
* @param fragment
+ * @param deleteStagedFile
*/
def addArtifact(
remoteRelativePath: Path,
serverLocalStagingPath: Path,
- fragment: Option[String]): Unit = JobArtifactSet.withActiveJobArtifactState(state) {
+ fragment: Option[String],
+ deleteStagedFile: Boolean = true
+ ): Unit = JobArtifactSet.withActiveJobArtifactState(state) {
require(!remoteRelativePath.isAbsolute)
+
if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
@@ -142,7 +153,11 @@ class ArtifactManager(session: SparkSession) extends Logging {
// Allow overwriting class files to capture updates to classes.
// This is required because the client currently sends all the class files in each class file
// transfer.
- moveFile(serverLocalStagingPath, target, allowOverwrite = true)
+ transferFile(
+ serverLocalStagingPath,
+ target,
+ allowOverwrite = true,
+ deleteSource = deleteStagedFile)
} else {
val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
// Disallow overwriting with modified version
@@ -155,7 +170,7 @@ class ArtifactManager(session: SparkSession) extends Logging {
throw new RuntimeException(s"Duplicate Artifact: $remoteRelativePath. " +
"Artifacts cannot be overwritten.")
}
- moveFile(serverLocalStagingPath, target)
+ transferFile(serverLocalStagingPath, target, deleteSource =
deleteStagedFile)
// This URI is for Spark file server that starts with "spark://".
val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath(
@@ -181,6 +196,38 @@ class ArtifactManager(session: SparkSession) extends Logging {
}
}
+ /**
+ * Add locally-stored artifacts to the session. These artifacts come from a user-provided
+ * permanent path that is directly accessible by the driver.
+ *
+ * Unlike the [[addArtifact]] method, this method does not delete staged artifacts, since they
+ * come from a permanent location.
+ */
+ private[sql] def addLocalArtifacts(artifacts: Seq[Artifact]): Unit = {
+ artifacts.foreach { artifact =>
+ artifact.storage match {
+ case d: Artifact.LocalFile =>
+ addArtifact(
+ artifact.path,
+ d.path,
+ fragment = None,
+ deleteStagedFile = false)
+ case d: Artifact.InMemory =>
+ val tempDir = Utils.createTempDir().toPath
+ val tempFile = tempDir.resolve(artifact.path.getFileName)
+ val outStream = Files.newOutputStream(tempFile)
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ d.stream.transferTo(outStream)
+ addArtifact(artifact.path, tempFile, fragment = None)
+ }(finallyBlock = {
+ outStream.close()
+ })
+ case _ =>
+ throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
+ }
+ }
+ }
+
/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
diff --git a/sql/core/src/test/resources/artifact-tests/HelloWithPackage.class b/sql/core/src/test/resources/artifact-tests/HelloWithPackage.class
new file mode 100644
index 000000000000..f0ff0c4f5cf0
Binary files /dev/null and b/sql/core/src/test/resources/artifact-tests/HelloWithPackage.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/HelloWithPackage.java b/sql/core/src/test/resources/artifact-tests/HelloWithPackage.java
new file mode 100644
index 000000000000..5c2c9fcdb178
--- /dev/null
+++ b/sql/core/src/test/resources/artifact-tests/HelloWithPackage.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Compile: javac --source 8 --target 8 HelloWithPackage.java
+
+package my.custom.pkg;
+
+public class HelloWithPackage {
+ String name = "there";
+
+ public HelloWithPackage() {
+ }
+
+ public HelloWithPackage(String name) {
+ this.name = name;
+ }
+
+ public String msg() {
+ return "Hello " + name + "! Nice to meet you!";
+ }
+}
diff --git a/sql/core/src/test/resources/artifact-tests/IntSumUdf.class b/sql/core/src/test/resources/artifact-tests/IntSumUdf.class
new file mode 100644
index 000000000000..75a41446cfca
Binary files /dev/null and b/sql/core/src/test/resources/artifact-tests/IntSumUdf.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/IntSumUdf.scala b/sql/core/src/test/resources/artifact-tests/IntSumUdf.scala
new file mode 100644
index 000000000000..9678caaed5db
--- /dev/null
+++ b/sql/core/src/test/resources/artifact-tests/IntSumUdf.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.spark.sql.api.java.UDF2
+
+class IntSumUdf extends UDF2[Long, Long, Long] {
+ override def call(t1: Long, t2: Long): Long = t1 + t2
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index 5f68f235485e..66cccb497bc7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -18,14 +18,16 @@ package org.apache.spark.sql.artifact
import java.io.File
import java.nio.charset.StandardCharsets
-import java.nio.file.{Files, Paths}
+import java.nio.file.{Files, Path, Paths}
import org.apache.commons.io.FileUtils
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.DataTypes
import org.apache.spark.storage.CacheId
import org.apache.spark.util.Utils
@@ -293,4 +295,84 @@ class ArtifactManagerSuite extends SharedSparkSession {
assert(!sessionDirectory.exists())
assert(ArtifactManager.artifactRootDirectory.toFile.exists())
}
+
+ test("Add artifact to local session - by path") {
+ val (fileName, binaryName) = ("Hello.class", "Hello")
+ testAddArtifactToLocalSession(fileName, binaryName) { classPath =>
+ spark.addArtifact(classPath.toString)
+ fileName
+ }
+ }
+
+ test("Add artifact to local session - by URI") {
+ val (fileName, binaryName) = ("Hello.class", "Hello")
+ testAddArtifactToLocalSession(fileName, binaryName) { classPath =>
+ spark.addArtifact(classPath.toUri)
+ fileName
+ }
+ }
+
+ test("Add artifact to local session - custom target path") {
+ val (fileName, binaryName) = ("HelloWithPackage.class", "my.custom.pkg.HelloWithPackage")
+ val filePath = "my/custom/pkg/HelloWithPackage.class"
+ testAddArtifactToLocalSession(fileName, binaryName) { classPath =>
+ spark.addArtifact(classPath.toString, filePath)
+ filePath
+ }
+ }
+
+ test("Add artifact to local session - in memory") {
+ val (fileName, binaryName) = ("HelloWithPackage.class", "my.custom.pkg.HelloWithPackage")
+ val filePath = "my/custom/pkg/HelloWithPackage.class"
+ testAddArtifactToLocalSession(fileName, binaryName) { classPath =>
+ val buffer = Files.readAllBytes(classPath)
+ spark.addArtifact(buffer, filePath)
+ filePath
+ }
+ }
+
+ test("Add UDF as artifact") {
+ val buffer = Files.readAllBytes(artifactPath.resolve("IntSumUdf.class"))
+ spark.addArtifact(buffer, "IntSumUdf.class")
+
+ val instance = artifactManager.classloader
+ .loadClass("IntSumUdf")
+ .getDeclaredConstructor()
+ .newInstance()
+ .asInstanceOf[UDF2[Long, Long, Long]]
+ spark.udf.register("intSum", instance, DataTypes.LongType)
+
+ artifactManager.withResources {
+ val r = spark.range(5)
+ .withColumn("id2", col("id") + 1)
+ .selectExpr("intSum(id, id2)")
+ .collect()
+ assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
+ }
+ }
+
+ private def testAddArtifactToLocalSession(
+ classFileToUse: String, binaryName: String)(addFunc: Path => String): Unit = {
+ val copyDir = Utils.createTempDir().toPath
+ FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
+ val classPath = copyDir.resolve(classFileToUse)
+ assert(classPath.toFile.exists())
+
+ val movedClassPath = addFunc(classPath)
+
+ val movedClassFile = ArtifactManager.artifactRootDirectory
+ .resolve(s"$sessionUUID/classes/$movedClassPath")
+ .toFile
+ assert(movedClassFile.exists())
+
+ val classLoader = artifactManager.classloader
+
+ val instance = classLoader
+ .loadClass(binaryName)
+ .getDeclaredConstructor(classOf[String])
+ .newInstance("Talon")
+
+ val msg = instance.getClass.getMethod("msg").invoke(instance)
+ assert(msg == "Hello Talon! Nice to meet you!")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]