hvanhovell commented on code in PR #47785:
URL: https://github.com/apache/spark/pull/47785#discussion_r1720885706


##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala:
##########
@@ -974,75 +974,40 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("SPARK-35876: arrays_zip should retain field names") {
-    withTempDir { dir =>
-      val df = spark.sparkContext.parallelize(
-        Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6)))).toDF("val1", "val2")
-      val qualifiedDF = df.as("foo")
-
-      // Fields are UnresolvedAttribute
-      val zippedDF1 =
-        qualifiedDF.select(Column(ArraysZip(Seq($"foo.val1".expr, 
$"foo.val2".expr))) as "zipped")
-      val maybeAlias1 = zippedDF1.queryExecution.logical.expressions.head
-      assert(maybeAlias1.isInstanceOf[Alias])
-      val maybeArraysZip1 = maybeAlias1.children.head
-      assert(maybeArraysZip1.isInstanceOf[ArraysZip])
-      
assert(maybeArraysZip1.children.forall(_.isInstanceOf[UnresolvedAttribute]))
-      val file1 = new File(dir, "arrays_zip1")
-      zippedDF1.write.parquet(file1.getAbsolutePath)
-      val restoredDF1 = spark.read.parquet(file1.getAbsolutePath)
-      val fieldNames1 = 
restoredDF1.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames1.toSeq === Seq("val1", "val2"))
-
-      // Fields are resolved NamedExpression
-      val zippedDF2 =
-        df.select(Column(ArraysZip(Seq(df("val1").expr, df("val2").expr))) as 
"zipped")
-      val maybeAlias2 = zippedDF2.queryExecution.logical.expressions.head
-      assert(maybeAlias2.isInstanceOf[Alias])
-      val maybeArraysZip2 = maybeAlias2.children.head
-      assert(maybeArraysZip2.isInstanceOf[ArraysZip])
-      assert(maybeArraysZip2.children.forall(
-        e => e.isInstanceOf[AttributeReference] && e.resolved))
-      val file2 = new File(dir, "arrays_zip2")
-      zippedDF2.write.parquet(file2.getAbsolutePath)
-      val restoredDF2 = spark.read.parquet(file2.getAbsolutePath)
-      val fieldNames2 = 
restoredDF2.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames2.toSeq === Seq("val1", "val2"))
-
-      // Fields are unresolved NamedExpression
-      val zippedDF3 = df.select(
-        Column(ArraysZip(Seq(($"val1" as "val3").expr, ($"val2" as 
"val4").expr))) as "zipped")
-      val maybeAlias3 = zippedDF3.queryExecution.logical.expressions.head
-      assert(maybeAlias3.isInstanceOf[Alias])
-      val maybeArraysZip3 = maybeAlias3.children.head
-      assert(maybeArraysZip3.isInstanceOf[ArraysZip])
-      assert(maybeArraysZip3.children.forall(e => e.isInstanceOf[Alias] && 
!e.resolved))
-      val file3 = new File(dir, "arrays_zip3")
-      zippedDF3.write.parquet(file3.getAbsolutePath)
-      val restoredDF3 = spark.read.parquet(file3.getAbsolutePath)
-      val fieldNames3 = 
restoredDF3.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames3.toSeq === Seq("val3", "val4"))
-
-      // Fields are neither UnresolvedAttribute nor NamedExpression
-      val zippedDF4 = df.select(
-        Column(ArraysZip(Seq(array_sort($"val1").expr, 
array_sort($"val2").expr))) as "zipped")
-      val maybeAlias4 = zippedDF4.queryExecution.logical.expressions.head
-      assert(maybeAlias4.isInstanceOf[Alias])
-      val maybeArraysZip4 = maybeAlias4.children.head
-      assert(maybeArraysZip4.isInstanceOf[ArraysZip])
-      assert(maybeArraysZip4.children.forall {
-        case _: UnresolvedAttribute | _: NamedExpression => false
-        case _ => true
-      })
-      val file4 = new File(dir, "arrays_zip4")
-      zippedDF4.write.parquet(file4.getAbsolutePath)
-      val restoredDF4 = spark.read.parquet(file4.getAbsolutePath)
-      val fieldNames4 = 
restoredDF4.schema.head.dataType.asInstanceOf[ArrayType]
-        .elementType.asInstanceOf[StructType].fieldNames
-      assert(fieldNames4.toSeq === Seq("0", "1"))
-    }
+    val df = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")

Review Comment:
   I mostly rewrote this because it was doing some crazy stuff... If you want 
to validate a schema post-optimization, then check that directly instead of 
writing a file and making all kinds of assumptions about the expression tree 
structure.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to