This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8853f28 [SPARK-35173][SQL][PYTHON] Add multiple columns adding support
8853f28 is described below
commit 8853f286371bcc1d44762a0a8ed5bf1a40cdbbd5
Author: Yikun Jiang <[email protected]>
AuthorDate: Tue Feb 15 09:40:27 2022 +0900
[SPARK-35173][SQL][PYTHON] Add multiple columns adding support
### What changes were proposed in this pull request?
This PR added the multiple columns adding support for Spark
scala/java/python API.
- Expose `withColumns` with Map input as public API in Scala/Java
- Add `withColumns` in PySpark
There was also some discussion about adding multiple columns in past
JIRA([SPARK-12225](https://issues.apache.org/jira/browse/SPARK-12225),
[SPARK-26224](https://issues.apache.org/jira/browse/SPARK-26224)) and
[ML](http://apache-spark-developers-list.1001551.n3.nabble.com/DISCUSS-Multiple-columns-adding-replacing-support-in-PySpark-DataFrame-API-td31164.html).
### Why are the changes needed?
There was a private method `withColumns` that can add columns in one pass [1]:
https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2402
However, it was not exposed as a public API in Scala/Java, and PySpark
users could only use `withColumn` to add one column or replace an existing
column that has the same name.
For example, if PySpark users want to add multiple columns, they have to
call `withColumn` again and again like:
```Python
df.withColumn("key1", col("key1")).withColumn("key2",
col("key2")).withColumn("key3", col("key3"))
```
After this patch, the user can use `withColumns` with a map of column
name and column:
```Python
df.withColumns({"key1": col("key1"), "key2":col("key2"), "key3":
col("key3")})
```
### Does this PR introduce _any_ user-facing change?
Yes, this PR exposes `withColumns` as a public API, and also adds
the `withColumns` API in PySpark.
### How was this patch tested?
- Add new multiple columns adding test, passed
- Existing test, passed
Closes #32431 from Yikun/SPARK-35173-cols.
Authored-by: Yikun Jiang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/dataframe.py | 35 ++++++++++++++++++++++
python/pyspark/sql/tests/test_dataframe.py | 27 +++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 29 ++++++++++++++++++
.../org/apache/spark/sql/JavaDataFrameSuite.java | 16 ++++++++++
.../org/apache/spark/sql/DataFrameSuite.scala | 18 +++++++++--
5 files changed, 122 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ee68865..0372527 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2911,6 +2911,41 @@ class DataFrame(PandasMapOpsMixin,
PandasConversionMixin):
support = 0.01
return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols),
support), self.sql_ctx)
+ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
+ """
+ Returns a new :class:`DataFrame` by adding multiple columns or
replacing the
+ existing columns that has the same names.
+
+ The colsMap is a map of column name and column, the column must only
refer to attributes
+ supplied by this Dataset. It is an error to add columns that refer to
some other Dataset.
+
+ .. versionadded:: 3.3.0
+ Added support for multiple columns adding
+
+ Parameters
+ ----------
+ colsMap : dict
+ a dict of column name and :class:`Column`. Currently, only single
map is supported.
+
+ Examples
+ --------
+ >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).collect()
+ [Row(age=2, name='Alice', age2=4, age3=5), Row(age=5, name='Bob',
age2=7, age3=8)]
+ """
+ # Below code is to help enable kwargs in future.
+ assert len(colsMap) == 1
+ colsMap = colsMap[0] # type: ignore[assignment]
+
+ if not isinstance(colsMap, dict):
+ raise TypeError("colsMap must be dict of column name and column.")
+
+ col_names = list(colsMap.keys())
+ cols = list(colsMap.values())
+
+ return DataFrame(
+ self._jdf.withColumns(_to_seq(self._sc, col_names),
self._jcols(*cols)), self.sql_ctx
+ )
+
def withColumn(self, colName: str, col: Column) -> "DataFrame":
"""
Returns a new :class:`DataFrame` by adding a column or replacing the
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index 1367fe7..5f5e88f 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -479,6 +479,33 @@ class DataFrameTests(ReusedSQLTestCase):
self.assertRaises(TypeError, foo)
+ def test_with_columns(self):
+ # With single column
+ keys = self.df.withColumns({"key":
self.df.key}).select("key").collect()
+ self.assertEqual([r.key for r in keys], list(range(100)))
+
+ # With key and value columns
+ kvs = (
+ self.df.withColumns({"key": self.df.key, "value": self.df.value})
+ .select("key", "value")
+ .collect()
+ )
+ self.assertEqual([(r.key, r.value) for r in kvs], [(i, str(i)) for i
in range(100)])
+
+ # Columns rename
+ kvs = (
+ self.df.withColumns({"key_alias": self.df.key, "value_alias":
self.df.value})
+ .select("key_alias", "value_alias")
+ .collect()
+ )
+ self.assertEqual(
+ [(r.key_alias, r.value_alias) for r in kvs], [(i, str(i)) for i in
range(100)]
+ )
+
+ # Type check
+ self.assertRaises(TypeError, self.df.withColumns, ["key"])
+ self.assertRaises(AssertionError, self.df.withColumns)
+
def test_generic_hints(self):
from pyspark.sql import DataFrame
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9dd38d8..4a921b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2479,6 +2479,35 @@ class Dataset[T] private[sql](
def withColumn(colName: String, col: Column): DataFrame =
withColumns(Seq(colName), Seq(col))
/**
+ * (Scala-specific) Returns a new Dataset by adding columns or replacing the
existing columns
+ * that has the same names.
+ *
+ * `colsMap` is a map of column name and column, the column must only refer
to attributes
+ * supplied by this Dataset. It is an error to add columns that refers to
some other Dataset.
+ *
+ * @group untypedrel
+ * @since 3.3.0
+ */
+ def withColumns(colsMap: Map[String, Column]): DataFrame = {
+ val (colNames, newCols) = colsMap.toSeq.unzip
+ withColumns(colNames, newCols)
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset by adding columns or replacing the
existing columns
+ * that has the same names.
+ *
+ * `colsMap` is a map of column name and column, the column must only refer
to attribute
+ * supplied by this Dataset. It is an error to add columns that refers to
some other Dataset.
+ *
+ * @group untypedrel
+ * @since 3.3.0
+ */
+ def withColumns(colsMap: java.util.Map[String, Column]): DataFrame =
withColumns(
+ colsMap.asScala.toMap
+ )
+
+ /**
* Returns a new Dataset by adding columns or replacing the existing columns
that has
* the same names.
*/
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index da7c622..c0b4690 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -33,6 +33,7 @@ import org.junit.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -319,6 +320,21 @@ public class JavaDataFrameSuite {
}
@Test
+ public void testwithColumns() {
+ Dataset<Row> df = spark.table("testData2");
+ Map<String, Column> colMaps = new HashMap<>();
+ colMaps.put("a1", col("a"));
+ colMaps.put("b1", col("b"));
+
+ StructType expected = df.withColumn("a1", col("a")).withColumn("b1",
col("b")).schema();
+ StructType actual = df.withColumns(colMaps).schema();
+ // Validate geting same result with withColumn loop call
+ Assert.assertEquals(expected, actual);
+ // Validate the col names
+ Assert.assertArrayEquals(actual.fieldNames(), new String[] {"a", "b",
"a1", "b1"});
+ }
+
+ @Test
public void testSampleByColumn() {
Dataset<Row> df = spark.range(0, 100, 1,
2).select(col("id").mod(3).as("key"));
Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0,
0.1, 1, 0.2), 0L);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7482d76..cd0bd06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -631,7 +631,19 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
- test("withColumns") {
+ test("withColumns: public API, with Map input") {
+ val df = testData.toDF().withColumns(Map(
+ "newCol1" -> (col("key") + 1), "newCol2" -> (col("key") + 2)
+ ))
+ checkAnswer(
+ df,
+ testData.collect().map { case Row(key: Int, value: String) =>
+ Row(key, value, key + 1, key + 2)
+ }.toSeq)
+ assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+ }
+
+ test("withColumns: internal method") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
@@ -655,7 +667,7 @@ class DataFrameSuite extends QueryTest
assert(err2.getMessage.contains("Found duplicate column(s)"))
}
- test("withColumns: case sensitive") {
+ test("withColumns: internal method, case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
@@ -674,7 +686,7 @@ class DataFrameSuite extends QueryTest
}
}
- test("withColumns: given metadata") {
+ test("withColumns: internal method, given metadata") {
def buildMetadata(num: Int): Seq[Metadata] = {
(0 until num).map { n =>
val builder = new MetadataBuilder
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]