fqaiser94 commented on a change in pull request #29795:
URL: https://github.com/apache/spark/pull/29795#discussion_r491753622
##########
File path:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
##########
@@ -537,18 +662,75 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
query(testStructRelation),
testStructRelation
.select(
- GetStructField('struct1, 0, Some("a")) as "struct2A",
+ GetStructField('struct1, 0) as "struct2A",
Literal(2) as "struct2B",
- GetStructField('struct1, 0, Some("a")) as "struct3A",
+ GetStructField('struct1, 0) as "struct3A",
Literal(3) as "struct3B"))
checkRule(
query(testNullableStructRelation),
testNullableStructRelation
.select(
- GetStructField('struct1, 0, Some("a")) as "struct2A",
+ GetStructField('struct1, 0) as "struct2A",
If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B",
- GetStructField('struct1, 0, Some("a")) as "struct3A",
+ GetStructField('struct1, 0) as "struct3A",
If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B"))
}
+
+ test("simplify add multiple nested fields to struct") {
+ // this scenario is possible if users add multiple nested columns via the Column.withField API
+ // ideally, users should not be doing this.
+ val nullableStructLevel2 = LocalRelation(
+ 'a1.struct(
+ 'a2.struct('a3.int)).withNullability(false))
+
+ val query = {
+ val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2",
+ UpdateFields(GetStructField('a1, 0), Seq(WithField("b3",
Literal(2)))))))
+
+ nullableStructLevel2.select(
+ UpdateFields(
+ addB3toA1A2,
+ Seq(WithField("a2", UpdateFields(
+ GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1"))
+ }
+
+ val expected = nullableStructLevel2.select(
+ UpdateFields('a1, Seq(
+ // scalastyle:off line.size.limit
+ WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3",
2) :: Nil)),
+ WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3",
2) :: WithField("c3", 3) :: Nil))
+ // scalastyle:on line.size.limit
+ )).as("a1"))
+
+ checkRule(query, expected)
+ }
+
+ test("simplify drop multiple nested fields in struct") {
+ // this scenario is possible if users drop multiple nested columns via the Column.dropFields API
+ // ideally, users should not be doing this.
+ val df = LocalRelation(
+ 'a1.struct(
+ 'a2.struct('a3.int, 'b3.int, 'c3.int).withNullability(false)
+ ).withNullability(false))
+
+ val query = {
+ val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields(
+ GetStructField('a1, 0), Seq(DropField("b3"))))))
+
+ df.select(
+ UpdateFields(
+ dropA1A2B,
+ Seq(WithField("a2", UpdateFields(
+ GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1"))
+ }
+
+ val expected = df.select(
+ UpdateFields('a1, Seq(
+ WithField("a2", UpdateFields(GetStructField('a1, 0),
Seq(DropField("b3")))),
+ WithField("a2", UpdateFields(GetStructField('a1, 0),
Seq(DropField("b3"), DropField("c3"))))
+ )).as("a1"))
Review comment:
The first `WithField` here is entirely redundant as well, and ideally we would optimize it away too.
However, in the interests of keeping this PR simple, I have opted to forgo writing any such optimizer rule.
If necessary, we can address this in a future PR.
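For anyone curious, here is a rough sketch of the kind of rule I have in mind. This is hypothetical and untested, not something this PR adds; it assumes the `UpdateFields`/`WithField`/`StructFieldsOperation` definitions introduced here and ignores case sensitivity. Because field operations are applied in order and the later value wins, a `WithField` that is immediately followed by another `WithField` for the same name can be dropped, and collapsing only *adjacent* ops keeps the resulting field order unchanged. Applied to the expected plan above, this would keep just the second `WithField("a2", ...)`.
```
import org.apache.spark.sql.catalyst.expressions.{StructFieldsOperation, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule sketch, not part of this PR: drop a WithField when the
// very next op writes the same field name, since the later value supersedes it.
object CollapseRedundantWithFields extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case UpdateFields(structExpr, fieldOps) =>
      val collapsed = fieldOps.foldRight(List.empty[StructFieldsOperation]) {
        // the next op writes the same field, so this op is redundant
        case (WithField(name, _), acc @ (WithField(nextName, _) :: _)) if name == nextName => acc
        case (op, acc) => op :: acc
      }
      UpdateFields(structExpr, collapsed)
  }
}
```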
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
##########
@@ -1514,27 +1516,578 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
StructType(Seq(StructField("a", structType, nullable = true))))
// extract newly added field
- checkAnswerAndSchema(
+ checkAnswer(
df.withColumn("a", $"a".withField("d", lit(4)).getField("d")),
Row(4) :: Row(null) :: Nil,
StructType(Seq(StructField("a", IntegerType, nullable = true))))
// extract newly replaced field
- checkAnswerAndSchema(
+ checkAnswer(
df.withColumn("a", $"a".withField("a", lit(4)).getField("a")),
Row(4) :: Row(null):: Nil,
StructType(Seq(StructField("a", IntegerType, nullable = true))))
// add new field, extract another field from original struct
- checkAnswerAndSchema(
+ checkAnswer(
df.withColumn("a", $"a".withField("d", lit(4)).getField("c")),
Row(3) :: Row(null):: Nil,
StructType(Seq(StructField("a", IntegerType, nullable = true))))
// replace field, extract another field from original struct
- checkAnswerAndSchema(
+ checkAnswer(
df.withColumn("a", $"a".withField("a", lit(4)).getField("c")),
Row(3) :: Row(null):: Nil,
StructType(Seq(StructField("a", IntegerType, nullable = true))))
}
+
+
+ test("dropFields should throw an exception if called on a non-StructType
column") {
+ intercept[AnalysisException] {
+ testData.withColumn("key", $"key".dropFields("a"))
+ }.getMessage should include("struct argument should be struct type, got:
int")
+ }
+
+ test("dropFields should throw an exception if fieldName argument is null") {
+ intercept[IllegalArgumentException] {
+ structLevel1.withColumn("a", $"a".dropFields(null))
+ }.getMessage should include("fieldName cannot be null")
+ }
+
+ test("dropFields should throw an exception if any intermediate structs don't
exist") {
+ intercept[AnalysisException] {
+ structLevel2.withColumn("a", 'a.dropFields("x.b"))
+ }.getMessage should include("No such struct field x in a")
+
+ intercept[AnalysisException] {
+ structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+ }.getMessage should include("No such struct field x in a")
+ }
+
+ test("dropFields should throw an exception if intermediate field is not a
struct") {
+ intercept[AnalysisException] {
+ structLevel1.withColumn("a", 'a.dropFields("b.a"))
+ }.getMessage should include("struct argument should be struct type, got:
int")
+ }
+
+ test("dropFields should throw an exception if intermediate field reference
is ambiguous") {
+ intercept[AnalysisException] {
+ val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false),
+ StructField("a", structType, nullable = false))),
+ nullable = false))))
+
+ structLevel2.withColumn("a", 'a.dropFields("a.b"))
+ }.getMessage should include("Ambiguous reference to fields")
+ }
+
+ test("dropFields should drop field in struct") {
+ checkAnswer(
+ structLevel1.withColumn("a", 'a.dropFields("b")),
+ Row(Row(1, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop field in null struct") {
+ checkAnswer(
+ nullStructLevel1.withColumn("a", $"a".dropFields("b")),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("dropFields should drop multiple fields in struct") {
+ Seq(
+ structLevel1.withColumn("a", $"a".dropFields("b", "c")),
+ structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+ ).foreach { df =>
+ checkAnswer(
+ df,
+ Row(Row(1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("dropFields should throw an exception if no fields will be left in
struct") {
+ intercept[AnalysisException] {
+ structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+ }.getMessage should include("cannot drop all fields in struct")
+ }
+
+ test("dropFields should drop field in nested struct") {
+ checkAnswer(
+ structLevel2.withColumn("a", 'a.dropFields("a.b")),
+ Row(Row(Row(1, 3))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop multiple fields in nested struct") {
+ checkAnswer(
+ structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+ Row(Row(Row(1))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop field in nested null struct") {
+ checkAnswer(
+ nullStructLevel2.withColumn("a", $"a".dropFields("a.b")),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop multiple fields in nested null struct") {
+ checkAnswer(
+ nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop field in deeply nested struct") {
+ checkAnswer(
+ structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+ Row(Row(Row(Row(1, 3)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop all fields with given name in struct") {
+ val structLevel1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ structLevel1.withColumn("a", 'a.dropFields("b")),
+ Row(Row(1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop field in struct even if casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ checkAnswer(
+ mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+ Row(Row(1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+ Row(Row(1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("dropFields should not drop field in struct because casing is
different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ checkAnswer(
+ mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+ Row(Row(1, 1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+ Row(Row(1, 1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("dropFields should drop nested field in struct even if casing is
different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ checkAnswer(
+ mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")),
+ Row(Row(Row(1), Row(1, 1))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("A", StructType(Seq(
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("B", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")),
+ Row(Row(Row(1, 1), Row(1))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("b", StructType(Seq(
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("dropFields should throw an exception because casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ intercept[AnalysisException] {
+ mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a"))
+ }.getMessage should include("No such struct field A in a, B")
+
+ intercept[AnalysisException] {
+ mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+ }.getMessage should include("No such struct field b in a, B")
+ }
+ }
+
+ test("dropFields should drop only fields that exist") {
+ checkAnswer(
+ structLevel1.withColumn("a", 'a.dropFields("d")),
+ Row(Row(1, null, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+ Row(Row(1, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")),
+ Row(Row(Row(1, 3))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("dropFields should drop multiple fields at arbitrary levels of nesting
in a single call") {
+ val df: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ df.withColumn("a", $"a".dropFields("a.b", "b")),
+ Row(Row(Row(1, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))), nullable =
false))),
+ nullable = false))))
+ }
+
+ test("dropFields user-facing examples") {
+ checkAnswer(
+ sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ .select($"struct_col".dropFields("b")),
+ Row(Row(1)))
+
+ checkAnswer(
+ sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ .select($"struct_col".dropFields("c")),
+ Row(Row(1, 2)))
+
+ checkAnswer(
+ sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+ .select($"struct_col".dropFields("b", "c")),
+ Row(Row(1)))
+
+ intercept[AnalysisException] {
+ sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ .select($"struct_col".dropFields("a", "b"))
+ }.getMessage should include("cannot drop all fields in struct")
+
+ checkAnswer(
+ sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+ .select($"struct_col".dropFields("b")),
+ Row(null))
+
+ checkAnswer(
+ sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ .select($"struct_col".dropFields("b")),
+ Row(Row(1)))
+
+ checkAnswer(
+ sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ .select($"struct_col".dropFields("a.b")),
+ Row(Row(Row(1))))
+
+ intercept[AnalysisException] {
+ sql("SELECT named_struct('a', named_struct('b', 1), 'a',
named_struct('c', 2)) struct_col")
+ .select($"struct_col".dropFields("a.c"))
+ }.getMessage should include("Ambiguous reference to fields")
+
+ checkAnswer(
+ sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3))
struct_col")
+ .select($"struct_col".dropFields("a.b", "a.c")),
+ Row(Row(Row(1))))
+
+ checkAnswer(
+ sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3))
struct_col")
+ .select($"struct_col".withField("a", $"struct_col.a".dropFields("b",
"c"))),
+ Row(Row(Row(1))))
+ }
+
+ test("should correctly handle different dropField + withField + getField
combinations") {
+ val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false)))
+
+ val structLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2)) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = false))))
+
+ val nullStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = true))))
+
+ val nullableStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2)) :: Row(null) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = true))))
+
+ def check(
+ fieldOps: Column => Column,
+ getFieldName: String,
+ expectedValue: Option[Int]): Unit = {
+
+ def query(df: DataFrame): DataFrame =
+ df.select(fieldOps(col("a")).getField(getFieldName).as("res"))
+
+ checkAnswer(
+ query(structLevel1),
+ Row(expectedValue.orNull) :: Nil,
+ StructType(Seq(StructField("res", IntegerType, nullable =
expectedValue.isEmpty))))
+
+ checkAnswer(
+ query(nullStructLevel1),
+ Row(null) :: Nil,
+ StructType(Seq(StructField("res", IntegerType, nullable = true))))
+
+ checkAnswer(
+ query(nullableStructLevel1),
+ Row(expectedValue.orNull) :: Row(null) :: Nil,
+ StructType(Seq(StructField("res", IntegerType, nullable = true))))
+ }
+
+ // add attribute, extract an attribute from the original struct
+ check(_.withField("c", lit(3)), "a", Some(1))
+ check(_.withField("c", lit(3)), "b", Some(2))
+
+ // add attribute, extract added attribute
+ check(_.withField("c", lit(3)), "c", Some(3))
+ check(_.withField("c", col("a.a")), "c", Some(1))
+ check(_.withField("c", col("a.b")), "c", Some(2))
+ check(_.withField("c", lit(null).cast(IntegerType)), "c", None)
+
+ // replace attribute, extract an attribute from the original struct
+ check(_.withField("b", lit(3)), "a", Some(1))
+ check(_.withField("a", lit(3)), "b", Some(2))
+
+ // replace attribute, extract replaced attribute
+ check(_.withField("b", lit(3)), "b", Some(3))
+ check(_.withField("b", lit(null).cast(IntegerType)), "b", None)
+ check(_.withField("a", lit(3)), "a", Some(3))
+ check(_.withField("a", lit(null).cast(IntegerType)), "a", None)
+
+ // drop attribute, extract an attribute from the original struct
+ check(_.dropFields("b"), "a", Some(1))
+ check(_.dropFields("a"), "b", Some(2))
+
+ // drop attribute, add attribute, extract an attribute from the original struct
+ check(_.dropFields("b").withField("c", lit(3)), "a", Some(1))
+ check(_.dropFields("a").withField("c", lit(3)), "b", Some(2))
+
+ // drop attribute, add another attribute, extract added attribute
+ check(_.dropFields("a").withField("c", lit(3)), "c", Some(3))
+ check(_.dropFields("b").withField("c", lit(3)), "c", Some(3))
+
+ // add attribute, drop attribute, extract an attribute from the original struct
+ check(_.withField("c", lit(3)).dropFields("a"), "b", Some(2))
+ check(_.withField("c", lit(3)).dropFields("b"), "a", Some(1))
+
+ // add attribute, drop another attribute, extract added attribute
+ check(_.withField("c", lit(3)).dropFields("a"), "c", Some(3))
+ check(_.withField("c", lit(3)).dropFields("b"), "c", Some(3))
+
+ // replace attribute, drop same attribute, extract an attribute from the original struct
+ check(_.withField("b", lit(3)).dropFields("b"), "a", Some(1))
+ check(_.withField("a", lit(3)).dropFields("a"), "b", Some(2))
+
+ // add attribute, drop same attribute, extract an attribute from the original struct
+ check(_.withField("c", lit(3)).dropFields("c"), "a", Some(1))
+ check(_.withField("c", lit(3)).dropFields("c"), "b", Some(2))
+
+ // add attribute, drop another attribute, extract added attribute
+ check(_.withField("b", lit(3)).dropFields("a"), "b", Some(3))
+ check(_.withField("a", lit(3)).dropFields("b"), "a", Some(3))
+ check(_.withField("b", lit(null).cast(IntegerType)).dropFields("a"), "b",
None)
+ check(_.withField("a", lit(null).cast(IntegerType)).dropFields("b"), "a",
None)
+
+ // drop attribute, add same attribute, extract added attribute
+ check(_.dropFields("b").withField("b", lit(3)), "b", Some(3))
+ check(_.dropFields("a").withField("a", lit(3)), "a", Some(3))
+ check(_.dropFields("b").withField("b", lit(null).cast(IntegerType)), "b",
None)
+ check(_.dropFields("a").withField("a", lit(null).cast(IntegerType)), "a",
None)
+ check(_.dropFields("c").withField("c", lit(3)), "c", Some(3))
+
+ // add attribute, drop same attribute, add same attribute again, extract added attribute
+ check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), "c", Some(4))
+ }
+
+ test("should move field up one level of nesting") {
+ val nullableStructLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(1, 2, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = true))),
+ nullable = true))))
+
+ // move a field up one level
+ checkAnswer(
+ nullableStructLevel2.select(
+ col("a").withField("b", col("a.a.b")).dropFields("a.b").as("res")),
+ Row(Row(null, null)) :: Row(Row(Row(1, 3), 2)) :: Nil,
+ StructType(Seq(
+ StructField("res", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true),
+ StructField("b", IntegerType, nullable = true))),
+ nullable = true))))
+
+ // move a field up one level and then extract it
+ checkAnswer(
+ nullableStructLevel2.select(col("a").withField("b",
col("a.a.b")).getField("b").as("res")),
+ Row(null) :: Row(2) :: Nil,
+ StructType(Seq(StructField("res", IntegerType, nullable = true))))
+ }
+
+ test("should be able to refer to newly added nested column") {
+ intercept[AnalysisException] {
+ structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" +
1).as("a"))
+ }.getMessage should include("No such struct field d in a, b, c")
+
+ checkAnswer(
+ structLevel1
+ .select($"a".withField("d", lit(4)).as("a"))
+ .select($"a".withField("e", $"a.d" + 1).as("a")),
+ Row(Row(1, null, 3, 4, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false),
+ StructField("e", IntegerType, nullable = false))),
+ nullable = false))))
+ }
Review comment:
I don't expect anyone to be surprised by this or to feel that it is wrong, but I nevertheless wanted to highlight this behaviour. The same goes for the two tests below.
##########
File path:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
##########
@@ -537,18 +662,75 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
query(testStructRelation),
testStructRelation
.select(
- GetStructField('struct1, 0, Some("a")) as "struct2A",
+ GetStructField('struct1, 0) as "struct2A",
Literal(2) as "struct2B",
- GetStructField('struct1, 0, Some("a")) as "struct3A",
+ GetStructField('struct1, 0) as "struct3A",
Literal(3) as "struct3B"))
checkRule(
query(testNullableStructRelation),
testNullableStructRelation
.select(
- GetStructField('struct1, 0, Some("a")) as "struct2A",
+ GetStructField('struct1, 0) as "struct2A",
If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B",
- GetStructField('struct1, 0, Some("a")) as "struct3A",
+ GetStructField('struct1, 0) as "struct3A",
If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B"))
}
+
+ test("simplify add multiple nested fields to struct") {
+ // this scenario is possible if users add multiple nested columns via the Column.withField API
+ // ideally, users should not be doing this.
+ val nullableStructLevel2 = LocalRelation(
+ 'a1.struct(
+ 'a2.struct('a3.int)).withNullability(false))
+
+ val query = {
+ val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2",
+ UpdateFields(GetStructField('a1, 0), Seq(WithField("b3",
Literal(2)))))))
+
+ nullableStructLevel2.select(
+ UpdateFields(
+ addB3toA1A2,
+ Seq(WithField("a2", UpdateFields(
+ GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1"))
+ }
+
+ val expected = nullableStructLevel2.select(
+ UpdateFields('a1, Seq(
+ // scalastyle:off line.size.limit
+ WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3",
2) :: Nil)),
+ WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3",
2) :: WithField("c3", 3) :: Nil))
+ // scalastyle:on line.size.limit
+ )).as("a1"))
Review comment:
The first `WithField` here is entirely redundant, and ideally we would optimize it away as well.
However, in the interests of keeping this PR simple, I have opted to forgo writing any such optimizer rule.
If necessary, we can address this in a future PR.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
##########
@@ -39,19 +40,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
- case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
- val name = w.dataType(ordinal).name
- val matches = names.zip(valExprs).filter(_._1 == name)
- if (matches.nonEmpty) {
- // return last matching element as that is the final value for the field being extracted.
- // For example, if a user submits a query like this:
- // `$"struct_col".withField("b", lit(1)).withField("b",
lit(2)).getField("b")`
- // we want to return `lit(2)` (and not `lit(1)`).
- val expr = matches.last._2
- If(IsNull(struct), Literal(null, expr.dataType), expr)
- } else {
- GetStructField(struct, ordinal, maybeName)
- }
+ case GetStructField(updateFields: UpdateFields, ordinal, _) =>
+ val structExpr = updateFields.structExpr
+ updateFields.newExprs(ordinal) match {
+ // if the struct itself is null, then any value extracted from it (expr) will be null
+ // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
+ case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
Review comment:
Should I use `semanticEquals` or `fastEquals` here? The difference isn't entirely clear to me, and my tests seem to pass either way.
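For reference, a minimal sketch of the difference as far as I can tell (my assumptions: `fastEquals` is reference equality falling back to structural `equals`, while `semanticEquals` compares canonicalized trees, and canonicalization strips `GetStructField`'s cosmetic name hint):
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.GetStructField

// two extractions that differ only in the optional, purely cosmetic name hint
val x = GetStructField('struct1, 0, Some("a"))
val y = GetStructField('struct1, 0, None)

x.fastEquals(y)      // false: structural equals compares the Option[String] name too
x.semanticEquals(y)  // true: canonicalization drops the name before comparing
```
In this rule, `expr.child` and `structExpr` are typically the exact same subtree, so both methods would return true, which would explain why the tests pass either way.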
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
* // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
* }}}
*
+ * This method supports adding/replacing nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2))
struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d",
lit(4)))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to extract
+ * out the nested struct before adding/replacing multiple fields e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2))
struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".withField("c",
lit(3)).withField("d", lit(4))))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
Review comment:
One of the issues in the master branch with the current `Column.withField` implementation is that the size of the parsed logical plan scales non-linearly with the number of directly-add-**nested**-column operations. This results in the driver spending a considerable amount of time analyzing and optimizing the logical plan (literally minutes, if it ever completes).
Users can avoid this issue entirely by writing their queries in a performant manner.
For example:
```
lazy val nullableStructLevel2: DataFrame = spark.createDataFrame(
  sparkContext.parallelize(Row(Row(Row(0))) :: Nil),
  StructType(Seq(
    StructField("a1", StructType(Seq(
      StructField("a2", StructType(Seq(
        StructField("col0", IntegerType, nullable = false))),
        nullable = true))),
      nullable = true))))

val numColsToAdd = 100
val expectedRows = Row(Row(Row(0 to numColsToAdd: _*))) :: Nil
val expectedSchema =
  StructType(Seq(
    StructField("a1", StructType(Seq(
      StructField("a2", StructType((0 to numColsToAdd).map(num =>
        StructField(s"col$num", IntegerType, nullable = false))),
        nullable = true))),
      nullable = true)))

test("good way of writing query") {
  // Spark can easily analyze and optimize the parsed logical plan in seconds
  checkAnswer(
    nullableStructLevel2
      .select(col("a1").withField("a2", (1 to numColsToAdd).foldLeft(col("a1.a2")) {
        (column, num) => column.withField(s"col$num", lit(num))
      }).as("a1")),
    expectedRows,
    expectedSchema)
}

test("bad way of writing the same query that will eventually fail with timeout exception with as little as numColsToAdd = 10") {
  checkAnswer(
    nullableStructLevel2
      .select((1 to numColsToAdd).foldLeft(col("a1")) {
        (column, num) => column.withField(s"a2.col$num", lit(num))
      }.as("a1")),
    expectedRows,
    expectedSchema)
}
```
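For what it's worth, a rough way to observe the blow-up is to count expression nodes in the parsed plan (a hypothetical helper, not part of this PR):
```
// hypothetical helper: counts expression nodes in the parsed (pre-analysis)
// logical plan; the count grows non-linearly for the "bad" query above
def parsedPlanExprNodeCount(df: DataFrame): Int =
  df.queryExecution.logical
    .map(plan => plan.expressions.map(_.collect { case e => e }.size).sum)
    .sum
```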
This issue and its solution are what I've attempted to capture here as part of the method doc.
There are other options here besides a method-doc note:
- We could potentially write some kind of optimization in `updateFieldsHelper` (I've bashed my head against this for a while but haven't been able to come up with anything satisfactory).
- We could remove the ability to change nested fields directly altogether. While this has the advantage that there would be absolutely no way to run into this "performance" issue, the user experience definitely suffers for more advanced users who would know how to use these methods properly.
I've gone with what made the most sense to me (a method-doc note) but am open to hearing other people's thoughts on the matter.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]