fqaiser94 commented on a change in pull request #29795:
URL: https://github.com/apache/spark/pull/29795#discussion_r494700714



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
    *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference 
to fields
    * }}}
    *
+   * This method supports adding/replacing nested fields directly e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", 
lit(4)))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
+   * However, if you are going to add/replace multiple nested fields, it is 
more optimal to extract
+   * out the nested struct before adding/replacing multiple fields e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a", $"struct_col.a".withField("c", 
lit(3)).withField("d", lit(4))))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
    * @group expr_ops
    * @since 3.1.0
    */
   // scalastyle:on line.size.limit
   def withField(fieldName: String, col: Column): Column = withExpr {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
+    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, 
col.expr))
+  }
 
-    val nameParts = if (fieldName.isEmpty) {
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that drops fields in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("c"))
+   *   // result: {"a":1,"b":2}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b", "c"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("a", "b"))
+   *   // result: org.apache.spark.sql.AnalysisException: cannot resolve 
'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot 
drop all fields in struct
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: null of type struct<a:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".dropFields("a.b"))
+   *   // result: {"a":{"a":1}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', 
named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.c"))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference 
to fields
+   * }}}
+   *
+   * This method supports dropping multiple nested fields directly e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".dropFields("a.b", "a.c"))
+   *   // result: {"a":{"a":1}}
+   * }}}
+   *
+   * However, if you are going to drop multiple nested fields, it is more 
optimal to extract
+   * out the nested struct before dropping multiple fields from it e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", 
"c")))
+   *   // result: {"a":{"a":1}}
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def dropFields(fieldNames: String*): Column = withExpr {
+    def dropField(structExpr: Expression, fieldName: String): UpdateFields =
+      updateFieldsHelper(structExpr, nameParts(fieldName), name => 
DropField(name))
+
+    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+      (resExpr, fieldName) => dropField(resExpr, fieldName)
+    }
+  }
+
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
+
+    if (fieldName.isEmpty) {
       fieldName :: Nil

Review comment:
       we've discussed this before 
[here](https://github.com/apache/spark/pull/27066#discussion_r448416127) :)
   It's needed for `withField`, and I think we should support it in `dropFields` 
as well because `Dataset.drop` supports it: 
   ```
   scala> Seq((1, 2)).toDF("a", "").drop("").printSchema
   root
    |-- a: integer (nullable = false)
   ```
   I've added a test case to demonstrate this works on the `dropFields` side 
but otherwise left the code unchanged. 
   

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
    *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference 
to fields
    * }}}
    *
+   * This method supports adding/replacing nested fields directly e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", 
lit(4)))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
+   * However, if you are going to add/replace multiple nested fields, it is 
more optimal to extract
+   * out the nested struct before adding/replacing multiple fields e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a", $"struct_col.a".withField("c", 
lit(3)).withField("d", lit(4))))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
    * @group expr_ops
    * @since 3.1.0
    */
   // scalastyle:on line.size.limit
   def withField(fieldName: String, col: Column): Column = withExpr {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
+    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, 
col.expr))
+  }
 
-    val nameParts = if (fieldName.isEmpty) {
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that drops fields in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("c"))
+   *   // result: {"a":1,"b":2}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b", "c"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("a", "b"))
+   *   // result: org.apache.spark.sql.AnalysisException: cannot resolve 
'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot 
drop all fields in struct
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: null of type struct<a:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".dropFields("a.b"))
+   *   // result: {"a":{"a":1}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', 
named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.c"))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference 
to fields
+   * }}}
+   *
+   * This method supports dropping multiple nested fields directly e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".dropFields("a.b", "a.c"))
+   *   // result: {"a":{"a":1}}
+   * }}}
+   *
+   * However, if you are going to drop multiple nested fields, it is more 
optimal to extract
+   * out the nested struct before dropping multiple fields from it e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", 
"c")))
+   *   // result: {"a":{"a":1}}
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def dropFields(fieldNames: String*): Column = withExpr {
+    def dropField(structExpr: Expression, fieldName: String): UpdateFields =
+      updateFieldsHelper(structExpr, nameParts(fieldName), name => 
DropField(name))
+
+    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+      (resExpr, fieldName) => dropField(resExpr, fieldName)
+    }
+  }
+
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
+
+    if (fieldName.isEmpty) {
       fieldName :: Nil
     } else {
       CatalystSqlParser.parseMultipartIdentifier(fieldName)
     }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
   }
 
-  private def withFieldHelper(
-      struct: Expression,
-      namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
+  private def updateFieldsHelper(
+    structExpr: Expression,

Review comment:
       done

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
##########
@@ -159,6 +160,14 @@ abstract class QueryTest extends PlanTest {
     checkAnswer(df, expectedAnswer.collect())
   }
 
+  protected def checkAnswer(
+    df: => DataFrame,
+    expectedAnswer: Seq[Row],

Review comment:
       done

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+  private def nestedColName(d: Int, colNum: Int): String = 
s"nested${d}Col$colNum"
+
+  private def nestedStructType(
+    depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {

Review comment:
       done

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+  private def nestedColName(d: Int, colNum: Int): String = 
s"nested${d}Col$colNum"
+
+  private def nestedStructType(
+    depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {
+    if (depths.length == 1) {
+      StructType(colNums.map { colNum =>
+        StructField(nestedColName(depths.head, colNum), IntegerType, nullable 
= false)
+      })
+    } else {
+      val depth = depths.head
+      val fields = colNums.foldLeft(Seq.empty[StructField]) {
+        case (structFields, colNum) if colNum == 0 =>
+          val nested = nestedStructType(depths.tail, colNums, nullable)
+          structFields :+ StructField(nestedColName(depth, colNum), nested, 
nullable)
+        case (structFields, colNum) =>
+          structFields :+ StructField(nestedColName(depth, colNum), 
IntegerType, nullable = false)
+      }
+      StructType(fields)
+    }
+  }
+
+  private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = {
+    if (depths.length == 1) {
+      Row.fromSeq(colNums)
+    } else {
+      val values = colNums.foldLeft(Seq.empty[Any]) {
+        case (values, colNum) if colNum == 0 => values :+ 
nestedRow(depths.tail, colNums)
+        case (values, colNum) => values :+ colNum
+      }
+      Row.fromSeq(values)
+    }
+  }
+
+  /**
+   * Utility function for generating a DataFrame with nested columns.
+   *
+   * @param depth: The depth to which to create nested columns.
+   * @param numColsAtEachDepth: The number of columns to create at each depth. 
The value of each
+   *                          column will be the same as its index 
(IntegerType) at that depth
+   *                          unless the index = 0, in which case it may be a 
StructType which
+   *                          represents the next depth.
+   * @param nullable: This value is used to set the nullability of StructType 
columns.
+   */
+  private def nestedDf(
+    depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame 
= {

Review comment:
       done

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+  private def nestedColName(d: Int, colNum: Int): String = 
s"nested${d}Col$colNum"
+
+  private def nestedStructType(
+    depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {
+    if (depths.length == 1) {
+      StructType(colNums.map { colNum =>
+        StructField(nestedColName(depths.head, colNum), IntegerType, nullable 
= false)
+      })
+    } else {
+      val depth = depths.head
+      val fields = colNums.foldLeft(Seq.empty[StructField]) {
+        case (structFields, colNum) if colNum == 0 =>
+          val nested = nestedStructType(depths.tail, colNums, nullable)
+          structFields :+ StructField(nestedColName(depth, colNum), nested, 
nullable)
+        case (structFields, colNum) =>
+          structFields :+ StructField(nestedColName(depth, colNum), 
IntegerType, nullable = false)
+      }
+      StructType(fields)
+    }
+  }
+
+  private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = {
+    if (depths.length == 1) {
+      Row.fromSeq(colNums)
+    } else {
+      val values = colNums.foldLeft(Seq.empty[Any]) {
+        case (values, colNum) if colNum == 0 => values :+ 
nestedRow(depths.tail, colNums)
+        case (values, colNum) => values :+ colNum
+      }
+      Row.fromSeq(values)
+    }
+  }
+
+  /**
+   * Utility function for generating a DataFrame with nested columns.
+   *
+   * @param depth: The depth to which to create nested columns.
+   * @param numColsAtEachDepth: The number of columns to create at each depth. 
The value of each
+   *                          column will be the same as its index 
(IntegerType) at that depth
+   *                          unless the index = 0, in which case it may be a 
StructType which
+   *                          represents the next depth.
+   * @param nullable: This value is used to set the nullability of StructType 
columns.
+   */
+  private def nestedDf(
+    depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame 
= {
+    require(depth > 0)
+    require(numColsAtEachDepth > 0)
+
+    val depths = 1 to depth
+    val colNums = 0 until numColsAtEachDepth
+    val nestedColumn = nestedRow(depths, colNums)
+    val nestedColumnDataType = nestedStructType(depths, colNums, nullable)
+
+    spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedColumn) :: Nil),
+      StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, 
nullable))))
+  }
+
+  test("nestedDf should generate nested DataFrames") {
+    checkAnswer(
+      nestedDf(1, 1),
+      Row(Row(0)) :: Nil,
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      nestedDf(1, 2),
+      Row(Row(0, 1)) :: Nil,
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      nestedDf(2, 1),
+      Row(Row(Row(0))) :: Nil,
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false))),
+          nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      nestedDf(2, 2),
+      Row(Row(Row(0, 1), 1)) :: Nil,
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false),
+          StructField("nested2Col1", IntegerType, nullable = false))),
+          nullable = false),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      nestedDf(2, 2, nullable = true),
+      Row(Row(Row(0, 1), 1)) :: Nil,
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false),
+          StructField("nested2Col1", IntegerType, nullable = false))),
+          nullable = true),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = true))))
+  }
+
+  // simulates how a user would add/drop nested fields in a performant manner
+  private def addDropNestedColumns(
+    column: Column,

Review comment:
       done




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to