fqaiser94 commented on a change in pull request #27066:
URL: https://github.com/apache/spark/pull/27066#discussion_r448589210
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
##########
@@ -923,4 +923,501 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
assert(inSet.sql === "('a' IN ('a', 'b'))")
}
+
+ def checkAnswerAndSchema(
+ df: => DataFrame,
+ expectedAnswer: Seq[Row],
+ expectedSchema: StructType): Unit = {
+
+ checkAnswer(df, expectedAnswer)
+ assert(df.schema == expectedSchema)
+ }
+
+ private lazy val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false)))
+
+ private lazy val structLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = false))))
+
+ private lazy val nullStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = true))))
+
+ private lazy val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))))
+
+ private lazy val nullStructLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(null)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = true))),
+ nullable = false))))
+
+ private lazy val structLevel3: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ test("withField should throw an exception if called on a non-StructType column") {
+ intercept[AnalysisException] {
+ testData.withColumn("key", $"key".withField("a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+ }
+
+ test("withField should throw an exception if either fieldName or col argument are null") {
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField(null, lit(2)))
+ }
+
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField("b", null))
+ }
+
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField(null, null))
+ }
+ }
+
+ test("withField should throw an exception if any intermediate structs don't exist") {
+ intercept[AnalysisException] {
+ structLevel2.withColumn("a", 'a.withField("x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+
+ intercept[AnalysisException] {
+ structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+ }
+
+ test("withField should throw an exception if any intermediate field is not a struct") {
+ intercept[AnalysisException] {
+ structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+
+ intercept[AnalysisException] {
+ val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false),
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))))
+
+ structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+ }.getMessage should include("Ambiguous reference to fields")
+ }
+
+ test("withField should add field with no name") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", $"a".withField("", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should add field to nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add null field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+ Row(Row(1, null, 3, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add multiple fields to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+ Row(Row(1, null, 3, 4, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false),
+ StructField("e", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to nested struct") {
+ Seq(
+ structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, null, 3, 4))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should add field to deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+ Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 2, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should replace field in nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace field with null value in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+ Row(Row(1, null, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace multiple fields in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+ Row(Row(10, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in nested struct") {
+ Seq(
+ structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should replace field in deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
+ Row(Row(Row(Row(1, 2, 3)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace all fields with given name in struct") {
+ val structLevel1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(100))),
+ Row(Row(1, 100, 100)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace fields in struct in given order") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
+ Row(Row(1, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field and then replace same field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
+ Row(Row(1, null, 3, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should handle fields with dots in their name if correctly quoted") {
+ val df: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = true),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = false),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ intercept[AnalysisException] {
+ df.withColumn("a", 'a.withField("a.b.e.f", lit(2)))
+ }.getMessage should include("No such struct field a in a.b")
+ }
+
+ private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 1)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ test("withField should replace field in struct even if casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> false.toString) {
Review comment:
done.
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
##########
@@ -923,4 +923,501 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
assert(inSet.sql === "('a' IN ('a', 'b'))")
}
+
+ def checkAnswerAndSchema(
+ df: => DataFrame,
+ expectedAnswer: Seq[Row],
+ expectedSchema: StructType): Unit = {
+
+ checkAnswer(df, expectedAnswer)
+ assert(df.schema == expectedSchema)
+ }
+
+ private lazy val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false)))
+
+ private lazy val structLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = false))))
+
+ private lazy val nullStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(null) :: Nil),
+ StructType(Seq(StructField("a", structType, nullable = true))))
+
+ private lazy val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))))
+
+ private lazy val nullStructLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(null)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = true))),
+ nullable = false))))
+
+ private lazy val structLevel3: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ test("withField should throw an exception if called on a non-StructType column") {
+ intercept[AnalysisException] {
+ testData.withColumn("key", $"key".withField("a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+ }
+
+ test("withField should throw an exception if either fieldName or col argument are null") {
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField(null, lit(2)))
+ }
+
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField("b", null))
+ }
+
+ an[java.lang.NullPointerException] should be thrownBy {
+ structLevel1.withColumn("a", $"a".withField(null, null))
+ }
+ }
+
+ test("withField should throw an exception if any intermediate structs don't exist") {
+ intercept[AnalysisException] {
+ structLevel2.withColumn("a", 'a.withField("x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+
+ intercept[AnalysisException] {
+ structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2)))
+ }.getMessage should include("No such struct field x in a")
+ }
+
+ test("withField should throw an exception if any intermediate field is not a struct") {
+ intercept[AnalysisException] {
+ structLevel1.withColumn("a", 'a.withField("b.a", lit(2)))
+ }.getMessage should include("struct argument should be struct type, got: int")
+
+ intercept[AnalysisException] {
+ val structLevel2: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", structType, nullable = false),
+ StructField("a", IntegerType, nullable = false))),
+ nullable = false))))
+
+ structLevel2.withColumn("a", 'a.withField("a.b", lit(2)))
+ }.getMessage should include("Ambiguous reference to fields")
+ }
+
+ test("withField should add field with no name") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", $"a".withField("", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4))),
+ Row(Row(1, null, 3, 4)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should add field to nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add null field to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))),
+ Row(Row(1, null, 3, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should add multiple fields to struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))),
+ Row(Row(1, null, 3, 4, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false),
+ StructField("e", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field to nested struct") {
+ Seq(
+ structLevel2.withColumn("a", 'a.withField("a.d", lit(4))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, null, 3, 4))) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should add field to deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))),
+ Row(Row(Row(Row(1, null, 3, 4)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 2, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(null) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ test("withField should replace field in nested null struct") {
+ checkAnswerAndSchema(
+ nullStructLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+ Row(Row(null)) :: Nil,
+ StructType(
+ Seq(StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace field with null value in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
+ Row(Row(1, null, null)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true))),
+ nullable = false))))
+ }
+
+ test("withField should replace multiple fields in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+ Row(Row(10, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace field in nested struct") {
+ Seq(
+ structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
+ structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
+ ).foreach { df =>
+ checkAnswerAndSchema(
+ df,
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should replace field in deeply nested struct") {
+ checkAnswerAndSchema(
+ structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
+ Row(Row(Row(Row(1, 2, 3)))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace all fields with given name in struct") {
+ val structLevel1 = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(100))),
+ Row(Row(1, 100, 100)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should replace fields in struct in given order") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
+ Row(Row(1, 20, 3)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false),
+ StructField("c", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should add field and then replace same field in struct") {
+ checkAnswerAndSchema(
+ structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
+ Row(Row(1, null, 3, 5)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = false),
+ StructField("d", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+
+ test("withField should handle fields with dots in their name if correctly quoted") {
+ val df: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = true),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
+ Row(Row(Row(1, 2, 3))) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a.b", StructType(Seq(
+ StructField("c.d", IntegerType, nullable = false),
+ StructField("e.f", IntegerType, nullable = false),
+ StructField("g.h", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ intercept[AnalysisException] {
+ df.withColumn("a", 'a.withField("a.b.e.f", lit(2)))
+ }.getMessage should include("No such struct field a in a.b")
+ }
+
+ private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame(
+ sparkContext.parallelize(Row(Row(1, 1)) :: Nil),
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ test("withField should replace field in struct even if casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> false.toString) {
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
+ Row(Row(2, 1)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("A", IntegerType, nullable = false),
+ StructField("B", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswerAndSchema(
+ mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
+ Row(Row(1, 2)) :: Nil,
+ StructType(Seq(
+ StructField("a", StructType(Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", IntegerType, nullable = false))),
+ nullable = false))))
+ }
+ }
+
+ test("withField should add field to struct because casing is different") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) {
Review comment:
done.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]