This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8853f28  [SPARK-35173][SQL][PYTHON] Add multiple columns adding support
8853f28 is described below

commit 8853f286371bcc1d44762a0a8ed5bf1a40cdbbd5
Author: Yikun Jiang <[email protected]>
AuthorDate: Tue Feb 15 09:40:27 2022 +0900

    [SPARK-35173][SQL][PYTHON] Add multiple columns adding support
    
    ### What changes were proposed in this pull request?
    This PR added the multiple columns adding support for Spark 
scala/java/python API.
    - Expose `withColumns` with Map input as public API in Scala/Java
    - Add `withColumns` in PySpark
    
    There was also some discussion about adding multiple columns in past 
JIRA([SPARK-12225](https://issues.apache.org/jira/browse/SPARK-12225), 
[SPARK-26224](https://issues.apache.org/jira/browse/SPARK-26224)) and 
[ML](http://apache-spark-developers-list.1001551.n3.nabble.com/DISCUSS-Multiple-columns-adding-replacing-support-in-PySpark-DataFrame-API-td31164.html).
    
    ### Why are the changes needed?
    There was a private method `withColumns` that can add columns in one pass [1]:
    
https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2402
    
    However, it was not exposed as a public API in Scala/Java, and PySpark 
users can only use `withColumn` to add one column or replace an existing 
column that has the same name.
    
    For example, if PySpark users want to add multiple columns, they have to 
call `withColumn` again and again like:
    ```Python
    df.withColumn("key1", col("key1")).withColumn("key2", 
col("key2")).withColumn("key3", col("key3"))
    ```
    After this patch, the user can use `withColumns` with a map of column 
names and columns:
    ```Python
    df.withColumns({"key1":  col("key1"), "key2":col("key2"), "key3": 
col("key3")})
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this PR exposes `withColumns` as a public API, and also adds the 
`withColumns` API in PySpark.
    
    ### How was this patch tested?
    - Add new multiple columns adding test, passed
    - Existing test, passed
    
    Closes #32431 from Yikun/SPARK-35173-cols.
    
    Authored-by: Yikun Jiang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/dataframe.py                    | 35 ++++++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 27 +++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 29 ++++++++++++++++++
 .../org/apache/spark/sql/JavaDataFrameSuite.java   | 16 ++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 18 +++++++++--
 5 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ee68865..0372527 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2911,6 +2911,41 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
             support = 0.01
         return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), 
support), self.sql_ctx)
 
+    def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` by adding multiple columns or 
replacing the
+        existing columns that has the same names.
+
+        The colsMap is a map of column name and column, the column must only 
refer to attributes
+        supplied by this Dataset. It is an error to add columns that refer to 
some other Dataset.
+
+        .. versionadded:: 3.3.0
+           Added support for multiple columns adding
+
+        Parameters
+        ----------
+        colsMap : dict
+            a dict of column name and :class:`Column`. Currently, only single 
map is supported.
+
+        Examples
+        --------
+        >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).collect()
+        [Row(age=2, name='Alice', age2=4, age3=5), Row(age=5, name='Bob', 
age2=7, age3=8)]
+        """
+        # Below code is to help enable kwargs in future.
+        assert len(colsMap) == 1
+        colsMap = colsMap[0]  # type: ignore[assignment]
+
+        if not isinstance(colsMap, dict):
+            raise TypeError("colsMap must be dict of column name and column.")
+
+        col_names = list(colsMap.keys())
+        cols = list(colsMap.values())
+
+        return DataFrame(
+            self._jdf.withColumns(_to_seq(self._sc, col_names), 
self._jcols(*cols)), self.sql_ctx
+        )
+
     def withColumn(self, colName: str, col: Column) -> "DataFrame":
         """
         Returns a new :class:`DataFrame` by adding a column or replacing the
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index 1367fe7..5f5e88f 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -479,6 +479,33 @@ class DataFrameTests(ReusedSQLTestCase):
 
         self.assertRaises(TypeError, foo)
 
+    def test_with_columns(self):
+        # With single column
+        keys = self.df.withColumns({"key": 
self.df.key}).select("key").collect()
+        self.assertEqual([r.key for r in keys], list(range(100)))
+
+        # With key and value columns
+        kvs = (
+            self.df.withColumns({"key": self.df.key, "value": self.df.value})
+            .select("key", "value")
+            .collect()
+        )
+        self.assertEqual([(r.key, r.value) for r in kvs], [(i, str(i)) for i 
in range(100)])
+
+        # Columns rename
+        kvs = (
+            self.df.withColumns({"key_alias": self.df.key, "value_alias": 
self.df.value})
+            .select("key_alias", "value_alias")
+            .collect()
+        )
+        self.assertEqual(
+            [(r.key_alias, r.value_alias) for r in kvs], [(i, str(i)) for i in 
range(100)]
+        )
+
+        # Type check
+        self.assertRaises(TypeError, self.df.withColumns, ["key"])
+        self.assertRaises(AssertionError, self.df.withColumns)
+
     def test_generic_hints(self):
         from pyspark.sql import DataFrame
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9dd38d8..4a921b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2479,6 +2479,35 @@ class Dataset[T] private[sql](
   def withColumn(colName: String, col: Column): DataFrame = 
withColumns(Seq(colName), Seq(col))
 
   /**
+   * (Scala-specific) Returns a new Dataset by adding columns or replacing the 
existing columns
+   * that has the same names.
+   *
+   * `colsMap` is a map of column name and column, the column must only refer 
to attributes
+   * supplied by this Dataset. It is an error to add columns that refers to 
some other Dataset.
+   *
+   * @group untypedrel
+   * @since 3.3.0
+   */
+  def withColumns(colsMap: Map[String, Column]): DataFrame = {
+    val (colNames, newCols) = colsMap.toSeq.unzip
+    withColumns(colNames, newCols)
+  }
+
+  /**
+   * (Java-specific) Returns a new Dataset by adding columns or replacing the 
existing columns
+   * that has the same names.
+   *
+   * `colsMap` is a map of column name and column, the column must only refer 
to attribute
+   * supplied by this Dataset. It is an error to add columns that refers to 
some other Dataset.
+   *
+   * @group untypedrel
+   * @since 3.3.0
+   */
+  def withColumns(colsMap: java.util.Map[String, Column]): DataFrame = 
withColumns(
+    colsMap.asScala.toMap
+  )
+
+  /**
    * Returns a new Dataset by adding columns or replacing the existing columns 
that has
    * the same names.
    */
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index da7c622..c0b4690 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -33,6 +33,7 @@ import org.junit.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -319,6 +320,21 @@ public class JavaDataFrameSuite {
   }
 
   @Test
+  public void testwithColumns() {
+    Dataset<Row> df = spark.table("testData2");
+    Map<String, Column> colMaps = new HashMap<>();
+    colMaps.put("a1", col("a"));
+    colMaps.put("b1", col("b"));
+
+    StructType expected = df.withColumn("a1", col("a")).withColumn("b1", 
col("b")).schema();
+    StructType actual = df.withColumns(colMaps).schema();
+    // Validate geting same result with withColumn loop call
+    Assert.assertEquals(expected, actual);
+    // Validate the col names
+    Assert.assertArrayEquals(actual.fieldNames(), new String[] {"a", "b", 
"a1", "b1"});
+  }
+
+  @Test
   public void testSampleByColumn() {
     Dataset<Row> df = spark.range(0, 100, 1, 
2).select(col("id").mod(3).as("key"));
     Dataset<Row> sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 
0.1, 1, 0.2), 0L);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7482d76..cd0bd06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -631,7 +631,19 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
-  test("withColumns") {
+  test("withColumns: public API, with Map input") {
+    val df = testData.toDF().withColumns(Map(
+      "newCol1" -> (col("key") + 1), "newCol2" -> (col("key")  + 2)
+    ))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+  }
+
+  test("withColumns: internal method") {
     val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
       Seq(col("key") + 1, col("key") + 2))
     checkAnswer(
@@ -655,7 +667,7 @@ class DataFrameSuite extends QueryTest
     assert(err2.getMessage.contains("Found duplicate column(s)"))
   }
 
-  test("withColumns: case sensitive") {
+  test("withColumns: internal method, case sensitive") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
       val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
         Seq(col("key") + 1, col("key") + 2))
@@ -674,7 +686,7 @@ class DataFrameSuite extends QueryTest
     }
   }
 
-  test("withColumns: given metadata") {
+  test("withColumns: internal method, given metadata") {
     def buildMetadata(num: Int): Seq[Metadata] = {
       (0 until num).map { n =>
         val builder = new MetadataBuilder

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to