fqaiser94 commented on a change in pull request #29795:
URL: https://github.com/apache/spark/pull/29795#discussion_r491753622



##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
##########
@@ -537,18 +662,75 @@ class ComplexTypesSuite extends PlanTest with 
ExpressionEvalHelper {
       query(testStructRelation),
       testStructRelation
         .select(
-          GetStructField('struct1, 0, Some("a")) as "struct2A",
+          GetStructField('struct1, 0) as "struct2A",
           Literal(2) as "struct2B",
-          GetStructField('struct1, 0, Some("a")) as "struct3A",
+          GetStructField('struct1, 0) as "struct3A",
           Literal(3) as "struct3B"))
 
     checkRule(
       query(testNullableStructRelation),
       testNullableStructRelation
         .select(
-          GetStructField('struct1, 0, Some("a")) as "struct2A",
+          GetStructField('struct1, 0) as "struct2A",
           If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as 
"struct2B",
-          GetStructField('struct1, 0, Some("a")) as "struct3A",
+          GetStructField('struct1, 0) as "struct3A",
           If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as 
"struct3B"))
   }
+
+  test("simplify add multiple nested fields to struct") {
+    // this scenario is possible if users add multiple nested columns via the 
Column.withField API
+    // ideally, users should not be doing this.
+    val nullableStructLevel2 = LocalRelation(
+      'a1.struct(
+        'a2.struct('a3.int)).withNullability(false))
+
+    val query = {
+      val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2",
+        UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", 
Literal(2)))))))
+
+      nullableStructLevel2.select(
+        UpdateFields(
+          addB3toA1A2,
+          Seq(WithField("a2", UpdateFields(
+            GetStructField(addB3toA1A2, 0), Seq(WithField("c3", 
Literal(3))))))).as("a1"))
+    }
+
+    val expected = nullableStructLevel2.select(
+      UpdateFields('a1, Seq(
+        // scalastyle:off line.size.limit
+        WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 
2) :: Nil)),
+        WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 
2) :: WithField("c3", 3) :: Nil))
+        // scalastyle:on line.size.limit
+      )).as("a1"))
+
+    checkRule(query, expected)
+  }
+
+  test("simplify drop multiple nested fields in struct") {
+    // this scenario is possible if users drop multiple nested columns via the 
Column.dropFields API
+    // ideally, users should not be doing this.
+    val df = LocalRelation(
+      'a1.struct(
+        'a2.struct('a3.int, 'b3.int, 'c3.int).withNullability(false)
+      ).withNullability(false))
+
+    val query = {
+      val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields(
+        GetStructField('a1, 0), Seq(DropField("b3"))))))
+
+      df.select(
+        UpdateFields(
+          dropA1A2B,
+          Seq(WithField("a2", UpdateFields(
+            GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1"))
+    }
+
+    val expected = df.select(
+      UpdateFields('a1, Seq(
+        WithField("a2", UpdateFields(GetStructField('a1, 0), 
Seq(DropField("b3")))),
+        WithField("a2", UpdateFields(GetStructField('a1, 0), 
Seq(DropField("b3"), DropField("c3"))))
+      )).as("a1"))

Review comment:
       This first `WithField` here is also entirely redundant, and ideally we would optimize it away too.
   However, in the interest of keeping this PR simple, I have opted not to write such an optimizer rule.
   If necessary, we can address this in a future PR.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
##########
@@ -1514,27 +1516,578 @@ class ColumnExpressionSuite extends QueryTest with 
SharedSparkSession {
       StructType(Seq(StructField("a", structType, nullable = true))))
 
     // extract newly added field
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("d", lit(4)).getField("d")),
       Row(4) :: Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // extract newly replaced field
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("a", lit(4)).getField("a")),
       Row(4) :: Row(null):: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // add new field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("d", lit(4)).getField("c")),
       Row(3) :: Row(null):: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // replace field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("a", lit(4)).getField("c")),
       Row(3) :: Row(null):: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
+
+
+  test("dropFields should throw an exception if called on a non-StructType 
column") {
+    intercept[AnalysisException] {
+      testData.withColumn("key", $"key".dropFields("a"))
+    }.getMessage should include("struct argument should be struct type, got: 
int")
+  }
+
+  test("dropFields should throw an exception if fieldName argument is null") {
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".dropFields(null))
+    }.getMessage should include("fieldName cannot be null")
+  }
+
+  test("dropFields should throw an exception if any intermediate structs don't 
exist") {
+    intercept[AnalysisException] {
+      structLevel2.withColumn("a", 'a.dropFields("x.b"))
+    }.getMessage should include("No such struct field x in a")
+
+    intercept[AnalysisException] {
+      structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+    }.getMessage should include("No such struct field x in a")
+  }
+
+  test("dropFields should throw an exception if intermediate field is not a 
struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("b.a"))
+    }.getMessage should include("struct argument should be struct type, got: 
int")
+  }
+
+  test("dropFields should throw an exception if intermediate field reference 
is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.dropFields("a.b"))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("dropFields should drop field in struct") {
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in null struct") {
+    checkAnswer(
+      nullStructLevel1.withColumn("a", $"a".dropFields("b")),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop multiple fields in struct") {
+    Seq(
+      structLevel1.withColumn("a", $"a".dropFields("b", "c")),
+      structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+    ).foreach { df =>
+      checkAnswer(
+        df,
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception if no fields will be left in 
struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+    }.getMessage should include("cannot drop all fields in struct")
+  }
+
+  test("dropFields should drop field in nested struct") {
+    checkAnswer(
+      structLevel2.withColumn("a", 'a.dropFields("a.b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested struct") {
+    checkAnswer(
+      structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+      Row(Row(Row(1))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nested null struct") {
+    checkAnswer(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested null struct") {
+    checkAnswer(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in deeply nested struct") {
+    checkAnswer(
+      structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+      Row(Row(Row(Row(1, 3)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop all fields with given name in struct") {
+    val structLevel1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should not drop field in struct because casing is 
different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should drop nested field in struct even if casing is 
different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswer(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")),
+        Row(Row(Row(1), Row(1, 1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("A", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("B", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")),
+        Row(Row(Row(1, 1), Row(1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a"))
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
+
+  test("dropFields should drop only fields that exist") {
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("d")),
+      Row(Row(1, null, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields at arbitrary levels of nesting 
in a single call") {
+    val df: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      df.withColumn("a", $"a".dropFields("a.b", "b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))), nullable = 
false))),
+          nullable = false))))
+  }
+
+  test("dropFields user-facing examples") {
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("c")),
+      Row(Row(1, 2)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+        .select($"struct_col".dropFields("b", "c")),
+      Row(Row(1)))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("a", "b"))
+    }.getMessage should include("cannot drop all fields in struct")
+
+    checkAnswer(
+      sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(null))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+        .select($"struct_col".dropFields("a.b")),
+      Row(Row(Row(1))))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', named_struct('b', 1), 'a', 
named_struct('c', 2)) struct_col")
+        .select($"struct_col".dropFields("a.c"))
+    }.getMessage should include("Ambiguous reference to fields")
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) 
struct_col")
+        .select($"struct_col".dropFields("a.b", "a.c")),
+      Row(Row(Row(1))))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) 
struct_col")
+        .select($"struct_col".withField("a", $"struct_col.a".dropFields("b", 
"c"))),
+      Row(Row(Row(1))))
+  }
+
+  test("should correctly handle different dropField + withField + getField 
combinations") {
+    val structType = StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("b", IntegerType, nullable = false)))
+
+    val structLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2)) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = false))))
+
+    val nullStructLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = true))))
+
+    val nullableStructLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2)) :: Row(null) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = true))))
+
+    def check(
+      fieldOps: Column => Column,
+      getFieldName: String,
+      expectedValue: Option[Int]): Unit = {
+
+      def query(df: DataFrame): DataFrame =
+        df.select(fieldOps(col("a")).getField(getFieldName).as("res"))
+
+      checkAnswer(
+        query(structLevel1),
+        Row(expectedValue.orNull) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = 
expectedValue.isEmpty))))
+
+      checkAnswer(
+        query(nullStructLevel1),
+        Row(null) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = true))))
+
+      checkAnswer(
+        query(nullableStructLevel1),
+        Row(expectedValue.orNull) :: Row(null) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = true))))
+    }
+
+    // add attribute, extract an attribute from the original struct
+    check(_.withField("c", lit(3)), "a", Some(1))
+    check(_.withField("c", lit(3)), "b", Some(2))
+
+    // add attribute, extract added attribute
+    check(_.withField("c", lit(3)), "c", Some(3))
+    check(_.withField("c", col("a.a")), "c", Some(1))
+    check(_.withField("c", col("a.b")), "c", Some(2))
+    check(_.withField("c", lit(null).cast(IntegerType)), "c", None)
+
+    // replace attribute, extract an attribute from the original struct
+    check(_.withField("b", lit(3)), "a", Some(1))
+    check(_.withField("a", lit(3)), "b", Some(2))
+
+    // replace attribute, extract replaced attribute
+    check(_.withField("b", lit(3)), "b", Some(3))
+    check(_.withField("b", lit(null).cast(IntegerType)), "b", None)
+    check(_.withField("a", lit(3)), "a", Some(3))
+    check(_.withField("a", lit(null).cast(IntegerType)), "a", None)
+
+    // drop attribute, extract an attribute from the original struct
+    check(_.dropFields("b"), "a", Some(1))
+    check(_.dropFields("a"), "b", Some(2))
+
+    // drop attribute, add attribute, extract an attribute from the original 
struct
+    check(_.dropFields("b").withField("c", lit(3)), "a", Some(1))
+    check(_.dropFields("a").withField("c", lit(3)), "b", Some(2))
+
+    // drop attribute, add another attribute, extract added attribute
+    check(_.dropFields("a").withField("c", lit(3)), "c", Some(3))
+    check(_.dropFields("b").withField("c", lit(3)), "c", Some(3))
+
+    // add attribute, drop attribute, extract an attribute from the original 
struct
+    check(_.withField("c", lit(3)).dropFields("a"), "b", Some(2))
+    check(_.withField("c", lit(3)).dropFields("b"), "a", Some(1))
+
+    // add attribute, drop another attribute, extract added attribute
+    check(_.withField("c", lit(3)).dropFields("a"), "c", Some(3))
+    check(_.withField("c", lit(3)).dropFields("b"), "c", Some(3))
+
+    // replace attribute, drop same attribute, extract an attribute from the 
original struct
+    check(_.withField("b", lit(3)).dropFields("b"), "a", Some(1))
+    check(_.withField("a", lit(3)).dropFields("a"), "b", Some(2))
+
+    // add attribute, drop same attribute, extract an attribute from the 
original struct
+    check(_.withField("c", lit(3)).dropFields("c"), "a", Some(1))
+    check(_.withField("c", lit(3)).dropFields("c"), "b", Some(2))
+
+    // add attribute, drop another attribute, extract added attribute
+    check(_.withField("b", lit(3)).dropFields("a"), "b", Some(3))
+    check(_.withField("a", lit(3)).dropFields("b"), "a", Some(3))
+    check(_.withField("b", lit(null).cast(IntegerType)).dropFields("a"), "b", 
None)
+    check(_.withField("a", lit(null).cast(IntegerType)).dropFields("b"), "a", 
None)
+
+    // drop attribute, add same attribute, extract added attribute
+    check(_.dropFields("b").withField("b", lit(3)), "b", Some(3))
+    check(_.dropFields("a").withField("a", lit(3)), "a", Some(3))
+    check(_.dropFields("b").withField("b", lit(null).cast(IntegerType)), "b", 
None)
+    check(_.dropFields("a").withField("a", lit(null).cast(IntegerType)), "a", 
None)
+    check(_.dropFields("c").withField("c", lit(3)), "c", Some(3))
+
+    // add attribute, drop same attribute, add same attribute again, extract 
added attribute
+    check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), 
"c", Some(4))
+  }
+
+  test("should move field up one level of nesting") {
+    val nullableStructLevel2: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(1, 2, 3))) :: 
Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = true))),
+          nullable = true))))
+
+    // move a field up one level
+    checkAnswer(
+      nullableStructLevel2.select(
+        col("a").withField("b", col("a.a.b")).dropFields("a.b").as("res")),
+      Row(Row(null, null)) ::  Row(Row(Row(1, 3), 2)) :: Nil,
+      StructType(Seq(
+        StructField("res", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true),
+          StructField("b", IntegerType, nullable = true))),
+          nullable = true))))
+
+    // move a field up one level and then extract it
+    checkAnswer(
+      nullableStructLevel2.select(col("a").withField("b", 
col("a.a.b")).getField("b").as("res")),
+      Row(null) :: Row(2) :: Nil,
+      StructType(Seq(StructField("res", IntegerType, nullable = true))))
+  }
+
+  test("should be able to refer to newly added nested column") {
+    intercept[AnalysisException] {
+      structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 
1).as("a"))
+    }.getMessage should include("No such struct field d in a, b, c")
+
+    checkAnswer(
+      structLevel1
+        .select($"a".withField("d", lit(4)).as("a"))
+        .select($"a".withField("e", $"a.d" + 1).as("a")),
+      Row(Row(1, null, 3, 4, 5)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false),
+          StructField("e", IntegerType, nullable = false))),
+          nullable = false))))
+  }

Review comment:
       I don't expect anyone to be surprised by this or to consider it wrong, but I nevertheless wanted to highlight this behaviour. The same goes for the two tests below.

##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
##########
@@ -537,18 +662,75 @@ class ComplexTypesSuite extends PlanTest with 
ExpressionEvalHelper {
       query(testStructRelation),
       testStructRelation
         .select(
-          GetStructField('struct1, 0, Some("a")) as "struct2A",
+          GetStructField('struct1, 0) as "struct2A",
           Literal(2) as "struct2B",
-          GetStructField('struct1, 0, Some("a")) as "struct3A",
+          GetStructField('struct1, 0) as "struct3A",
           Literal(3) as "struct3B"))
 
     checkRule(
       query(testNullableStructRelation),
       testNullableStructRelation
         .select(
-          GetStructField('struct1, 0, Some("a")) as "struct2A",
+          GetStructField('struct1, 0) as "struct2A",
           If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as 
"struct2B",
-          GetStructField('struct1, 0, Some("a")) as "struct3A",
+          GetStructField('struct1, 0) as "struct3A",
           If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as 
"struct3B"))
   }
+
+  test("simplify add multiple nested fields to struct") {
+    // this scenario is possible if users add multiple nested columns via the 
Column.withField API
+    // ideally, users should not be doing this.
+    val nullableStructLevel2 = LocalRelation(
+      'a1.struct(
+        'a2.struct('a3.int)).withNullability(false))
+
+    val query = {
+      val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2",
+        UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", 
Literal(2)))))))
+
+      nullableStructLevel2.select(
+        UpdateFields(
+          addB3toA1A2,
+          Seq(WithField("a2", UpdateFields(
+            GetStructField(addB3toA1A2, 0), Seq(WithField("c3", 
Literal(3))))))).as("a1"))
+    }
+
+    val expected = nullableStructLevel2.select(
+      UpdateFields('a1, Seq(
+        // scalastyle:off line.size.limit
+        WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 
2) :: Nil)),
+        WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 
2) :: WithField("c3", 3) :: Nil))
+        // scalastyle:on line.size.limit
+      )).as("a1"))

Review comment:
       This first `WithField` here is entirely redundant, and ideally we would optimize it away as well.
   However, in the interest of keeping this PR simple, I have opted not to write such an optimizer rule.
   If necessary, we can address this in a future PR.

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
##########
@@ -39,19 +40,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, 
maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
-        if (matches.nonEmpty) {
-          // return last matching element as that is the final value for the 
field being extracted.
-          // For example, if a user submits a query like this:
-          // `$"struct_col".withField("b", lit(1)).withField("b", 
lit(2)).getField("b")`
-          // we want to return `lit(2)` (and not `lit(1)`).
-          val expr = matches.last._2
-          If(IsNull(struct), Literal(null, expr.dataType), expr)
-        } else {
-          GetStructField(struct, ordinal, maybeName)
-        }
+    case GetStructField(updateFields: UpdateFields, ordinal, _) =>
+      val structExpr = updateFields.structExpr
+      updateFields.newExprs(ordinal) match {
+        // if the struct itself is null, then any value extracted from it 
(expr) will be null
+        // so we don't need to wrap expr in If(IsNull(struct), Literal(null, 
expr.dataType), expr)
+        case expr: GetStructField if expr.child.semanticEquals(structExpr) => 
expr

Review comment:
       Should I use `semanticEquals` or `fastEquals` here? The difference isn't entirely clear to me, and my tests seem to pass in either case.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
    *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference 
to fields
    * }}}
    *
+   * This method supports adding/replacing nested fields directly e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", 
lit(4)))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
+   * However, if you are going to add/replace multiple nested fields, it is 
more optimal to extract
+   * out the nested struct before adding/replacing multiple fields e.g.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) 
struct_col")
+   *   df.select($"struct_col".withField("a", $"struct_col.a".withField("c", 
lit(3)).withField("d", lit(4))))
+   *   // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *

Review comment:
       One of the issues with the current `Column.withField` implementation in the master branch is that the size of the parsed logical plan grows non-linearly with the number of operations that directly add **nested** columns. As a result, the driver spends a considerable amount of time analyzing and optimizing the logical plan (literally minutes, if it completes at all).
   Users can avoid this issue entirely by writing their queries in a more performant way.
   For example: 
   
   ```
     lazy val nullableStructLevel2: DataFrame = spark.createDataFrame(
       sparkContext.parallelize(Row(Row(Row(0))) :: Nil),
       StructType(Seq(
         StructField("a1", StructType(Seq(
           StructField("a2", StructType(Seq(
             StructField("col0", IntegerType, nullable = false))),
             nullable = true))),
           nullable = true))))
   
     val numColsToAdd = 100
   
     val expectedRows = Row(Row(Row(0 to numColsToAdd: _*))) :: Nil
     val expectedSchema =
       StructType(Seq(
         StructField("a1", StructType(Seq(
           StructField("a2", StructType((0 to numColsToAdd).map(num =>
             StructField(s"col$num", IntegerType, nullable = false))),
             nullable = true))),
           nullable = true)))
   
     test("good way of writing query") {
       // Spark can easily analyze and optimize the parsed logical plan in 
seconds
       checkAnswer(
         nullableStructLevel2
           .select(col("a1").withField("a2", (1 to 
numColsToAdd).foldLeft(col("a1.a2")) {
             (column, num) => column.withField(s"col$num", lit(num))
           }).as("a1")),
         expectedRows,
         expectedSchema)
     }
   
     test("bad way of writing the same query that will eventually fail with 
timeout exception with as little as numColsToAdd = 10") {
       checkAnswer(
         nullableStructLevel2
           .select((1 to numColsToAdd).foldLeft(col("a1")) {
             (column, num) => column.withField(s"a2.col$num", lit(num))
           }.as("a1")),
         expectedRows,
         expectedSchema)
     }
   ```
   
   This issue and its solution are what I've attempted to capture here as part of the method doc.
   
   There are other options besides the method-doc note:
   - We could potentially write some kind of optimization in `updateFieldsHelper` (I've bashed my head against this for a while but haven't been able to come up with anything satisfactory).
   - We could remove the ability to change nested fields directly. While this has the advantage that there would be absolutely no way to run into this "performance" issue, the user experience definitely suffers for more advanced users who know how to use these methods properly.
   
   I've gone with what made the most sense to me (the method-doc note), but I am open to hearing other people's thoughts on the matter.
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to