This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 6fd1c649c72  [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts
6fd1c649c72 is described below

commit 6fd1c649c72d4b53ecf83c1643d38002d80c9288
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Jun 7 10:54:24 2023 +0900

    [SPARK-43906][PYTHON][CONNECT] Implement the file support in SparkSession.addArtifacts

    ### What changes were proposed in this pull request?

    This PR proposes to add support for regular files in `SparkSession.addArtifacts`.

    ### Why are the changes needed?

    So users can add regular files to the worker nodes.

    ### Does this PR introduce _any_ user-facing change?

    Yes, it adds support for arbitrary regular files in `SparkSession.addArtifacts`.

    ### How was this patch tested?

    Added a couple of unit tests. Also manually tested in `local-cluster`:

    ```bash
    ./sbin/start-connect-server.sh --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar` --master "local-cluster[2,2,1024]"
    ./bin/pyspark --remote "sc://localhost:15002"
    ```

    ```python
    import os
    import tempfile
    from pyspark.sql.functions import udf
    from pyspark import SparkFiles

    with tempfile.TemporaryDirectory() as d:
        file_path = os.path.join(d, "my_file.txt")
        with open(file_path, "w") as f:
            f.write("Hello world!!")

        @udf("string")
        def func(x):
            with open(
                os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
            ) as my_file:
                return my_file.read().strip()

        spark.addArtifacts(file_path, file=True)
        spark.range(1).select(func("id")).show()
    ```

    Closes #41415 from HyukjinKwon/addFile.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../artifact/SparkConnectArtifactManager.scala     |  6 ++---
 python/pyspark/sql/connect/client/artifact.py      | 21 ++++++++++----
 python/pyspark/sql/connect/client/core.py          |  4 +--
 python/pyspark/sql/connect/session.py              | 13 ++++++---
 .../sql/tests/connect/client/test_artifact.py      | 31 ++++++++++++++++++---
 5 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 604108f68d2..47c48d8e083 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -97,6 +97,7 @@ class SparkConnectArtifactManager private[connect] {
    * @param session
    * @param remoteRelativePath
    * @param serverLocalStagingPath
+   * @param fragment
    */
   private[connect] def addArtifact(
       sessionHolder: SessionHolder,
@@ -135,8 +136,7 @@ class SparkConnectArtifactManager private[connect] {
     // previously added,
     if (Files.exists(target)) {
       throw new RuntimeException(
-        s"Duplicate Jar: $remoteRelativePath. " +
-          s"Jars cannot be overwritten.")
+        s"Duplicate file: $remoteRelativePath. Files cannot be overwritten.")
     }
     Files.move(serverLocalStagingPath, target)
     if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
@@ -154,6 +154,8 @@ class SparkConnectArtifactManager private[connect] {
       val canonicalUri =
         fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
       sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+    } else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
+      sessionHolder.session.sparkContext.addFile(target.toString)
     }
   }
 }
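For readers skimming the diff: the server-side change above slots the new `files/` case into an existing dispatch on the artifact's remote relative path, ending at `SparkContext.addFile`. Below is a minimal Python sketch of that routing idea, not the actual Scala implementation; `route_artifact` and the handler map are hypothetical names used purely for illustration:

```python
import os
from typing import Callable, Dict

# Illustrative model of the prefix dispatch in addArtifact above: the first
# path segment of the remote relative path selects which SparkContext
# registration method receives the staged file.
def route_artifact(
    remote_relative_path: str,
    target: str,
    handlers: Dict[str, Callable[[str], None]],
) -> None:
    prefix, _, _ = remote_relative_path.partition(os.sep)
    handler = handlers.get(prefix)
    if handler is None:
        raise ValueError(f"Unsupported artifact prefix: {prefix!r}")
    handler(target)

# e.g. handlers = {"jars": add_jar, "archives": add_archive, "files": add_file},
# standing in for sparkContext.addJar / addArchive / addFile on the server.
```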
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 64f89119e4f..9a848bd96b8 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -39,6 +39,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 JAR_PREFIX: str = "jars"
 PYFILE_PREFIX: str = "pyfiles"
 ARCHIVE_PREFIX: str = "archives"
+FILE_PREFIX: str = "files"
 
 
 class LocalData(metaclass=abc.ABCMeta):
@@ -107,6 +108,10 @@ def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
 
 
+def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
+    return _new_artifact(FILE_PREFIX, "", file_name, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -141,7 +146,9 @@ class ArtifactManager:
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id
 
-    def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
+    def _parse_artifacts(
+        self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
+    ) -> List[Artifact]:
         # Currently only local files with .jar extension is supported.
         parsed = urlparse(path_or_uri)
         # Check if it is a file from the scheme
@@ -180,6 +187,8 @@ class ArtifactManager:
                     name = f"{name}#{parsed.fragment}"
 
                 artifact = new_archive_artifact(name, LocalFile(local_path))
+            elif file:
+                artifact = new_file_artifact(name, LocalFile(local_path))
             elif name.endswith(".jar"):
                 artifact = new_jar_artifact(name, LocalFile(local_path))
             else:
@@ -188,11 +197,13 @@ class ArtifactManager:
         raise RuntimeError(f"Unsupported scheme: {parsed.scheme}")
 
     def _create_requests(
-        self, *path: str, pyfile: bool, archive: bool
+        self, *path: str, pyfile: bool, archive: bool, file: bool
     ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
         return self._add_artifacts(
-            chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path))
+            chain(
+                *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+            )
         )
 
     def _retrieve_responses(
@@ -201,13 +212,13 @@ class ArtifactManager:
         """Separated for the testing purpose."""
         return self._stub.AddArtifacts(requests)
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
         """
         Add a single artifact to the session.
         Currently only local files with .jar extension is supported.
         """
         requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
-            *path, pyfile=pyfile, archive=archive
+            *path, pyfile=pyfile, archive=archive, file=file
         )
         response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
         summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 8da649e7765..b2071641a26 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1251,8 +1251,8 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive)
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
+        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
 
 
 class RetryState:
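On the client side, the effect of `FILE_PREFIX` is simply the artifact name the file is uploaded under: `new_file_artifact` prepends `files/`, which is the same prefix the server branch above matches with `files${File.separator}`. A quick sketch of the resulting name, using a hypothetical `file_artifact_name` helper (the real code path is `new_file_artifact` -> `_new_artifact(FILE_PREFIX, "", ...)`):

```python
import os

FILE_PREFIX = "files"  # mirrors FILE_PREFIX in artifact.py above

# Hypothetical helper: the remote name a plain local file is stored under.
def file_artifact_name(local_path: str) -> str:
    return f"{FILE_PREFIX}/{os.path.basename(local_path)}"

assert file_artifact_name("/tmp/my_file.txt") == "files/my_file.txt"
```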
""" requests: Iterator[proto.AddArtifactsRequest] = self._create_requests( - *path, pyfile=pyfile, archive=archive + *path, pyfile=pyfile, archive=archive, file=file ) response: proto.AddArtifactsResponse = self._retrieve_responses(requests) summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = [] diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 8da649e7765..b2071641a26 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1251,8 +1251,8 @@ class SparkConnectClient(object): else: raise SparkConnectGrpcException(str(rpc_error)) from None - def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None: - self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive) + def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None: + self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file) class RetryState: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 2d58ce1daf0..341db448955 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -613,7 +613,9 @@ class SparkSession: """ return self._client - def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None: + def addArtifacts( + self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False + ) -> None: """ Add artifact(s) to the client session. Currently only local files are supported. @@ -630,10 +632,13 @@ class SparkSession: archive : bool Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files. The archives are unpacked on the executor side automatically. + file : bool + Add a file to be downloaded with this Spark job on every node. + The ``path`` passed can only be a local file for now. 
""" - if pyfile and archive: - raise ValueError("'pyfile' and 'archive' cannot be True together.") - self._client.add_artifacts(*path, pyfile=pyfile, archive=archive) + if sum([file, pyfile, archive]) > 1: + raise ValueError("'pyfile', 'archive' and/or 'file' cannot be True together.") + self._client.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file) addArtifact = addArtifacts diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index 2bff3fd5bc4..1725e0f6e4c 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -49,7 +49,9 @@ class ArtifactTests(ReusedConnectTestCase): file_name = "smallJar" small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") response = self.artifact_manager._retrieve_responses( - self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) + self.artifact_manager._create_requests( + small_jar_path, pyfile=False, archive=False, file=False + ) ) self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar")) @@ -59,7 +61,9 @@ class ArtifactTests(ReusedConnectTestCase): small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") requests = list( - self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) + self.artifact_manager._create_requests( + small_jar_path, pyfile=False, archive=False, file=False + ) ) self.assertEqual(len(requests), 1) @@ -83,7 +87,9 @@ class ArtifactTests(ReusedConnectTestCase): large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") requests = list( - self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False) + self.artifact_manager._create_requests( + large_jar_path, pyfile=False, archive=False, file=False + ) ) # Expected chunks = roundUp( file_size / chunk_size) = 12 # File size of `junitLargeJar.jar` is 384581 bytes. @@ -117,7 +123,7 @@ class ArtifactTests(ReusedConnectTestCase): requests = list( self.artifact_manager._create_requests( - small_jar_path, small_jar_path, pyfile=False, archive=False + small_jar_path, small_jar_path, pyfile=False, archive=False, file=False ) ) # Single request containing 2 artifacts. @@ -160,6 +166,7 @@ class ArtifactTests(ReusedConnectTestCase): small_jar_path, pyfile=False, archive=False, + file=False, ) ) # There are a total of 14 requests. @@ -271,6 +278,22 @@ class ArtifactTests(ReusedConnectTestCase): self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True) self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!") + def test_add_file(self): + with tempfile.TemporaryDirectory() as d: + file_path = os.path.join(d, "my_file.txt") + with open(file_path, "w") as f: + f.write("Hello world!!") + + @udf("string") + def func(x): + with open( + os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r" + ) as my_file: + return my_file.read().strip() + + self.spark.addArtifacts(file_path, file=True) + self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "Hello world!!") + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org