This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 3bca65a68ef [SPARK-42482][CONNECT] Scala Client Write API V1
3bca65a68ef is described below
commit 3bca65a68ef9c0716eef3e7a965fe280ba751673
Author: Zhen Li <[email protected]>
AuthorDate: Sun Feb 19 13:04:37 2023 -0400
[SPARK-42482][CONNECT] Scala Client Write API V1
### What changes were proposed in this pull request?
Implemented the basic Dataset#write API to allow users to write the DataFrame to
tables, CSV files, and other data sources.
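For illustration, a minimal usage sketch of the new API (the output path and table
name below are placeholders, not part of this change):

    val df = spark.range(10)

    // Save to a file-based data source with an explicit format, mode and option.
    df.write
      .format("csv")
      .mode("overwrite")
      .option("header", "true")
      .save("/tmp/write-api-demo")

    // Save as a table, then append to it with position-based insertInto.
    df.write.mode(SaveMode.Overwrite).saveAsTable("demo_table")
    spark.range(2).write.insertInto("demo_table")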
### Why are the changes needed?
Basic write operation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Integration tests.
Closes #40061 from zhenlineo/write.
Authored-by: Zhen Li <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
(cherry picked from commit ede1a541182e043b1f79a9ffbfc4a7fa97604078)
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/java/org/apache/spark/sql/SaveMode.java | 58 +++
.../org/apache/spark/sql/DataFrameWriter.scala | 457 +++++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 10 +
.../scala/org/apache/spark/sql/SparkSession.scala | 7 +
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 51 ++-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 32 ++
.../sql/connect/client/CompatibilitySuite.scala | 4 +-
.../connect/client/SparkConnectClientSuite.scala | 16 +-
8 files changed, 626 insertions(+), 9 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java
new file mode 100644
index 00000000000..95af157687c
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+import org.apache.spark.annotation.Stable;
+
+/**
+ * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
+ *
+ * @since 3.4.0
+ */
+@Stable
+public enum SaveMode {
+ /**
+ * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ *
+ * @since 3.4.0
+ */
+ Append,
+ /**
+ * Overwrite mode means that when saving a DataFrame to a data source,
+ * if data/table already exists, existing data is expected to be overwritten by the contents of
+ * the DataFrame.
+ *
+ * @since 3.4.0
+ */
+ Overwrite,
+ /**
+ * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
+ * an exception is expected to be thrown.
+ *
+ * @since 3.4.0
+ */
+ ErrorIfExists,
+ /**
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ *
+ * @since 3.4.0
+ */
+ Ignore
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
new file mode 100644
index 00000000000..b7c4ed7bcab
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -0,0 +1,457 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
+/**
+ * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value
+ * stores, etc). Use `Dataset.write` to access this.
+ *
+ * @since 3.4.0
+ */
+@Stable
+final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
+
+ /**
+ * Specifies the behavior when data or table already exists. Options include: <ul>
+ * <li>`SaveMode.Overwrite`: overwrite the existing data.</li> <li>`SaveMode.Append`: append the
+ * data.</li> <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li> </ul> <p> The default
+ * option is `ErrorIfExists`.
+ *
+ * @since 3.4.0
+ */
+ def mode(saveMode: SaveMode): DataFrameWriter[T] = {
+ this.mode = saveMode
+ this
+ }
+
+ /**
+ * Specifies the behavior when data or table already exists. Options include: <ul>
+ * <li>`overwrite`: overwrite the existing data.</li> <li>`append`: append the data.</li>
+ * <li>`ignore`: ignore the operation (i.e. no-op).</li> <li>`error` or `errorifexists`: default
+ * option, throw an exception at runtime.</li> </ul>
+ *
+ * @since 3.4.0
+ */
+ def mode(saveMode: String): DataFrameWriter[T] = {
+ saveMode.toLowerCase(Locale.ROOT) match {
+ case "overwrite" => mode(SaveMode.Overwrite)
+ case "append" => mode(SaveMode.Append)
+ case "ignore" => mode(SaveMode.Ignore)
+ case "error" | "errorifexists" | "default" =>
mode(SaveMode.ErrorIfExists)
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown save mode: $saveMode.
Accepted " +
+ "save modes are 'overwrite', 'append', 'ignore', 'error',
'errorifexists', 'default'.")
+ }
+ }
+
+ /**
+ * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
+ *
+ * @since 3.4.0
+ */
+ def format(source: String): DataFrameWriter[T] = {
+ this.source = Some(source)
+ this
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: String): DataFrameWriter[T] = {
+ this.extraOptions = this.extraOptions + (key -> value)
+ this
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * (Scala-specific) Adds output options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
+ this.extraOptions ++= options
+ this
+ }
+
+ /**
+ * Adds output options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
+ this.options(options.asScala)
+ this
+ }
+
+ /**
+ * Partitions the output by the given columns on the file system. If specified, the output is
+ * laid out on the file system similar to Hive's partitioning scheme. As an example, when we
+ * partition a dataset by year and then month, the directory layout would look like: <ul>
+ * <li>year=2016/month=01/</li> <li>year=2016/month=02/</li> </ul>
+ *
+ * Partitioning is one of the most widely used techniques to optimize physical data layout. It
+ * provides a coarse-grained index for skipping unnecessary data reads when queries have
+ * predicates on the partitioned columns. In order for partitioning to work well, the number of
+ * distinct values in each column should typically be less than tens of thousands.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(colNames: String*): DataFrameWriter[T] = {
+ this.partitioningColumns = Option(colNames)
+ this
+ }
+
+ /**
+ * Buckets the output by the given columns. If specified, the output is laid out on the file
+ * system similar to Hive's bucketing scheme, but with a different bucket hash function and is
+ * not compatible with Hive's bucketing.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
+ require(numBuckets > 0, "The numBuckets should be > 0.")
+ this.numBuckets = Option(numBuckets)
+ this.bucketColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
+ * Sorts the output in each bucket by the given columns.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
+ this.sortColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
+ * Saves the content of the `DataFrame` at the specified path.
+ *
+ * @since 3.4.0
+ */
+ def save(path: String): Unit = {
+ saveInternal(Some(path))
+ }
+
+ /**
+ * Saves the content of the `DataFrame` as the specified table.
+ *
+ * @since 3.4.0
+ */
+ def save(): Unit = saveInternal(None)
+
+ private def saveInternal(path: Option[String]): Unit = {
+ executeWriteOperation(builder => path.foreach(builder.setPath))
+ }
+
+ private def executeWriteOperation(f: proto.WriteOperation.Builder => Unit): Unit = {
+ val builder = proto.WriteOperation.newBuilder()
+
+ builder.setInput(ds.plan.getRoot)
+
+ // Set path or table
+ f(builder)
+ require(builder.hasPath != builder.hasTable) // Only one can be set
+
+ builder.setMode(mode match {
+ case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
+ case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
+ case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
+ case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
+ })
+
+ source.foreach(builder.setSource)
+ sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava))
+ partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava))
+
+ numBuckets.foreach(n => {
+ val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder()
+ bucketBuilder.setNumBuckets(n)
+ bucketColumnNames.foreach(names => bucketBuilder.addAllBucketColumnNames(names.asJava))
+ builder.setBucketBy(bucketBuilder)
+ })
+
+ extraOptions.foreach { case (k, v) =>
+ builder.putOptions(k, v)
+ }
+
+ ds.session.execute(proto.Command.newBuilder().setWriteOperation(builder).build())
+ }
+
+ /**
+ * Inserts the content of the `DataFrame` to the specified table. It requires that the schema of
+ * the `DataFrame` is the same as the schema of the table.
+ *
+ * @note
+ * Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based
+ * resolution. For example:
+ *
+ * @note
+ * SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as
+ * `insertInto` is not a table creating operation.
+ *
+ * {{{
+ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
+ * scala> Seq((3, 4)).toDF("j", "i").write.insertInto("t1")
+ * scala> Seq((5, 6)).toDF("a", "b").write.insertInto("t1")
+ * scala> sql("select * from t1").show
+ * +---+---+
+ * | i| j|
+ * +---+---+
+ * | 5| 6|
+ * | 3| 4|
+ * | 1| 2|
+ * +---+---+
+ * }}}
+ *
+ * Because it inserts data to an existing table, format or options will be ignored.
+ *
+ * @since 3.4.0
+ */
+ def insertInto(tableName: String): Unit = {
+ executeWriteOperation(builder => {
+ builder.setTable(
+ proto.WriteOperation.SaveTable
+ .newBuilder()
+ .setTableName(tableName)
+ .setSaveMethod(
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO))
+ })
+ }
+
+ /**
+ * Saves the content of the `DataFrame` as the specified table.
+ *
+ * In the case the table already exists, behavior of this function depends on the save mode,
+ * specified by the `mode` function (default to throwing an exception). When `mode` is
+ * `Overwrite`, the schema of the `DataFrame` does not need to be the same as that of the
+ * existing table.
+ *
+ * When `mode` is `Append`, if there is an existing table, we will use the format and options of
+ * the existing table. The column order in the schema of the `DataFrame` doesn't need to be same
+ * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names
+ * to find the correct column positions. For example:
+ *
+ * {{{
+ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
+ * scala> Seq((3, 4)).toDF("j", "i").write.mode("append").saveAsTable("t1")
+ * scala> sql("select * from t1").show
+ * +---+---+
+ * | i| j|
+ * +---+---+
+ * | 1| 2|
+ * | 4| 3|
+ * +---+---+
+ * }}}
+ *
+ * In this method, save mode is used to determine the behavior if the data source table exists
+ * in Spark catalog. We will always overwrite the underlying data of data source (e.g. a table
+ * in JDBC data source) if the table doesn't exist in Spark catalog, and will always append to
+ * the underlying data of data source if the table already exists.
+ *
+ * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
+ * @since 3.4.0
+ */
+ def saveAsTable(tableName: String): Unit = {
+ executeWriteOperation(builder => {
+ builder.setTable(
+ proto.WriteOperation.SaveTable
+ .newBuilder()
+ .setTableName(tableName)
+ .setSaveMethod(
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE))
+ })
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in JSON format (<a href="http://jsonlines.org/"> JSON
+ * Lines text format or newline-delimited JSON</a>) at the specified path. This is equivalent
+ * to:
+ * {{{
+ * format("json").save(path)
+ * }}}
+ *
+ * You can find the JSON-specific options for writing JSON files in <a
+ * href="https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option">
+ * Data Source Option</a> in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def json(path: String): Unit = {
+ format("json").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in Parquet format at the specified path. This is
+ * equivalent to:
+ * {{{
+ * format("parquet").save(path)
+ * }}}
+ *
+ * Parquet-specific option(s) for writing Parquet files can be found in <a href=
+ * "https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option"> Data
+ * Source Option</a> in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def parquet(path: String): Unit = {
+ format("parquet").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in ORC format at the specified path. This is equivalent
+ * to:
+ * {{{
+ * format("orc").save(path)
+ * }}}
+ *
+ * ORC-specific option(s) for writing ORC files can be found in <a href=
+ * "https://spark.apache.org/docs/latest/sql-data-sources-orc.html#data-source-option"> Data
+ * Source Option</a> in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def orc(path: String): Unit = {
+ format("orc").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in a text file at the specified path. The DataFrame must
+ * have only one column that is of string type. Each row becomes a new line in the output file.
+ * For example:
+ * {{{
+ * // Scala:
+ * df.write.text("/path/to/output")
+ *
+ * // Java:
+ * df.write().text("/path/to/output")
+ * }}}
+ * The text files will be encoded as UTF-8.
+ *
+ * You can find the text-specific options for writing text files in <a
+ * href="https://spark.apache.org/docs/latest/sql-data-sources-text.html#data-source-option">
+ * Data Source Option</a> in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def text(path: String): Unit = {
+ format("text").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in CSV format at the specified path. This is equivalent
+ * to:
+ * {{{
+ * format("csv").save(path)
+ * }}}
+ *
+ * You can find the CSV-specific options for writing CSV files in <a
+ * href="https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option">
+ * Data Source Option</a> in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def csv(path: String): Unit = {
+ format("csv").save(path)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Builder pattern config options
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ private var source: Option[String] = None
+
+ private var mode: SaveMode = SaveMode.ErrorIfExists
+
+ private var extraOptions = CaseInsensitiveMap[String](Map.empty)
+
+ private var partitioningColumns: Option[Seq[String]] = None
+
+ private var bucketColumnNames: Option[Seq[String]] = None
+
+ private var numBuckets: Option[Int] = None
+
+ private var sortColumnNames: Option[Seq[String]] = None
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3c34b45fccb..3c876c05432 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2246,6 +2246,16 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
*/
def inputFiles: Array[String] = analyze.getInputFilesList.asScala.toArray
+ /**
+ * Interface for saving the content of the non-streaming Dataset out into external storage.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def write: DataFrameWriter[T] = {
+ new DataFrameWriter[T](this)
+ }
+
private[sql] def analyze: proto.AnalyzePlanResponse = {
session.analyze(plan, proto.Explain.ExplainMode.SIMPLE)
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 54871c99b56..1761e8ce42d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql
import java.io.Closeable
+import scala.collection.JavaConverters._
+
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.connect.proto
@@ -162,6 +164,11 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
result
}
+ private[sql] def execute(command: proto.Command): Unit = {
+ val plan = proto.Plan.newBuilder().setCommand(command).build()
+ client.execute(plan).asScala.foreach(_ => ())
+ }
+
override def close(): Unit = {
client.shutdown()
allocator.close()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 058ba1a8efc..145d62feefc 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, PrintStream}
import scala.collection.JavaConverters._
import io.grpc.StatusRuntimeException
+import java.nio.file.Files
+import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
@@ -67,7 +69,7 @@ class ClientE2ETestSuite extends RemoteSparkSession {
}
}
- test("read") {
+ test("read and write") {
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
@@ -91,11 +93,20 @@ class ClientE2ETestSuite extends RemoteSparkSession {
StructField("age", IntegerType) ::
StructField("job", StringType) :: Nil))
.load()
- val array = df.collectResult().toArray
- assert(array.length == 2)
- assert(array(0).getString(0) == "Jorge")
- assert(array(0).getInt(1) == 30)
- assert(array(0).getString(2) == "Developer")
+ val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath
+
+ df.write
+ .format("csv")
+ .mode("overwrite")
+ .options(Map("header" -> "true", "delimiter" -> ";"))
+ .save(outputFolderPath.toString)
+
+ // We expect only one csv file saved.
+ val outputFile = outputFolderPath.toFile
+ .listFiles()
+ .filter(file => file.getPath.endsWith(".csv"))(0)
+
+ assert(FileUtils.contentEquals(testDataPath.toFile, outputFile))
}
test("read path collision") {
@@ -124,7 +135,33 @@ class ClientE2ETestSuite extends RemoteSparkSession {
.csv(testDataPath.toString)
// Failed because the path cannot be provided both via option and load method (csv).
assertThrows[StatusRuntimeException] {
- df.collectResult().toArray
+ df.collect()
+ }
+ }
+
+ test("write table") {
+ try {
+ val df = spark.range(10).limit(3)
+ df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
+ spark.range(2).write.insertInto("myTable")
+ val result = spark.sql("select * from myTable").sort("id").collect()
+ assert(result.length == 5)
+ assert(result(0).getLong(0) == 0)
+ assert(result(1).getLong(0) == 0)
+ assert(result(2).getLong(0) == 1)
+ assert(result(3).getLong(0) == 1)
+ assert(result(4).getLong(0) == 2)
+ } finally {
+ spark.sql("drop table if exists myTable").collect()
+ }
+ }
+
+ test("write path collision") {
+ val df = spark.range(10)
+ val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath
+ // Failed because the path cannot be provided both via option and save method.
+ assertThrows[StatusRuntimeException] {
+ df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString)
}
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 087dcbb360a..66e597f2457 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -115,4 +115,36 @@ class DatasetSuite
val actualPlan = service.getAndClearLatestInputPlan()
assert(actualPlan.equals(expectedPlan))
}
+
+ test("write") {
+ val df = ss.newDataset(_ => ()).limit(10)
+
+ val builder = proto.WriteOperation.newBuilder()
+ builder
+ .setInput(df.plan.getRoot)
+ .setPath("my/test/path")
+ .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS)
+ .setSource("parquet")
+ .addSortColumnNames("col1")
+ .addPartitioningColumns("col99")
+ .setBucketBy(
+ proto.WriteOperation.BucketBy
+ .newBuilder()
+ .setNumBuckets(2)
+ .addBucketColumnNames("col1")
+ .addBucketColumnNames("col2"))
+
+ val expectedPlan = proto.Plan
+ .newBuilder()
+ .setCommand(proto.Command.newBuilder().setWriteOperation(builder))
+ .build()
+
+ df.write
+ .sortBy("col1")
+ .partitionBy("col99")
+ .bucketBy(2, "col1", "col2")
+ .parquet("my/test/path")
+ val actualPlan = service.getAndClearLatestInputPlan()
+ assert(actualPlan.equals(expectedPlan))
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index fa2cb18cda2..81d58566cd9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -76,6 +76,7 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
// IncludeByName("org.apache.spark.sql.Dataset$"),
IncludeByName("org.apache.spark.sql.DataFrame"),
IncludeByName("org.apache.spark.sql.DataFrameReader"),
+ IncludeByName("org.apache.spark.sql.DataFrameWriter"),
IncludeByName("org.apache.spark.sql.SparkSession"),
IncludeByName("org.apache.spark.sql.SparkSession$")) ++
includeImplementedMethods(clientJar)
val excludeRules = Seq(
@@ -135,7 +136,8 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
// the Dataset methods, as too many overload methods are missing.
// "org.apache.spark.sql.Dataset",
"org.apache.spark.sql.SparkSession",
- "org.apache.spark.sql.DataFrameReader")
+ "org.apache.spark.sql.DataFrameReader",
+ "org.apache.spark.sql.DataFrameWriter")
val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
clsNames
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index f3caba28ffd..908eddbe7bf 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc}
+import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.sql.connect.common.config.ConnectCommon
class SparkConnectClientSuite
@@ -160,6 +160,20 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
plan
}
+ override def executePlan(
+ request: ExecutePlanRequest,
+ responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+ // Reply with a dummy response using the same client ID
+ val requestClientId = request.getClientId
+ inputPlan = request.getPlan
+ val response = ExecutePlanResponse
+ .newBuilder()
+ .setClientId(requestClientId)
+ .build()
+ responseObserver.onNext(response)
+ responseObserver.onCompleted()
+ }
+
override def analyzePlan(
request: AnalyzePlanRequest,
responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = {