This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fe2174d8caaf [SPARK-47304][SQL][TESTS] Distribute tests from `DataFrameSuite` to more specific suites
fe2174d8caaf is described below
commit fe2174d8caaf6f9474319a8d729a2f537038b4c1
Author: Max Gekk <[email protected]>
AuthorDate: Wed Mar 6 18:45:00 2024 +0300
[SPARK-47304][SQL][TESTS] Distribute tests from `DataFrameSuite` to more specific suites
### What changes were proposed in this pull request?
In the PR, I propose to move some tests out of `DataFrameSuite` and distribute them across more specific test suites and a new one, `DataFrameShowSuite`.
### Why are the changes needed?
1. Improve maintainability of `DataFrameSuite`
2. Speed up execution of the test suite: its execution time dropped to under 1 minute.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By running the modified test suites.
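For example, an individual moved suite can be run on its own; a typical invocation (illustrative, assuming the standard Spark SBT workflow) is:

  build/sbt "sql/testOnly org.apache.spark.sql.DataFrameShowSuite"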
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45392 from MaxGekk/split-DataFrameSuite.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../apache/spark/sql/DataFrameAggregateSuite.scala | 109 ++
.../spark/sql/DataFrameComplexTypeSuite.scala | 232 +++-
.../org/apache/spark/sql/DataFrameJoinSuite.scala | 19 +
.../apache/spark/sql/DataFrameSelfJoinSuite.scala | 10 +
.../org/apache/spark/sql/DataFrameShowSuite.scala | 487 ++++++++
.../org/apache/spark/sql/DataFrameStatSuite.scala | 65 ++
.../org/apache/spark/sql/DataFrameSuite.scala | 1202 +-------------------
.../spark/sql/StatisticsCollectionSuite.scala | 125 +-
.../test/scala/org/apache/spark/sql/UDFSuite.scala | 93 +-
.../execution/datasources/DataSourceSuite.scala | 49 +
.../sql/execution/datasources/json/JsonSuite.scala | 55 +
11 files changed, 1246 insertions(+), 1200 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ec589fa77241..21d7156a62b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2200,6 +2200,115 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
}
}
+
+ private def assertDecimalSumOverflow(
+ df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+ if (!ansiEnabled) {
+ checkAnswer(df, expectedAnswer)
+ } else {
+ val e = intercept[ArithmeticException] {
+ df.collect()
+ }
+ assert(e.getMessage.contains("cannot be represented as Decimal") ||
+ e.getMessage.contains("Overflow in sum of decimals"))
+ }
+ }
+
+ def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
+ Seq("true", "false").foreach { wholeStageEnabled =>
+ withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val df0 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df1 = Seq(
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df = df0.union(df1)
+ val df2 = df.withColumnRenamed("decNum", "decNum2").
+ join(df, "intNum").agg(aggFn($"decNum"))
+
+ val expectedAnswer = Row(null)
+ assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+
+ val decStr = "1" + "0" * 19
+ val d1 = spark.range(0, 12, 1, 1)
+ val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
+ assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+
+ val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+ val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
+ assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+
+ val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
+ lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
+ assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+
+ val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+
+ val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
+ toDF("d")
+ assertDecimalSumOverflow(
+ nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
+
+ val df3 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("50000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df4 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df5 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df6 = df3.union(df4).union(df5)
+ val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
+ filter("intNum == 1")
+ assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-28067: Aggregate sum should not return wrong results for decimal
overflow") {
+ checkAggResultsForDecimalOverflow(c => sum(c))
+ }
+
+ test("SPARK-35955: Aggregate avg should not return wrong results for decimal
overflow") {
+ checkAggResultsForDecimalOverflow(c => avg(c))
+ }
+
+ test("SPARK-28224: Aggregate sum big decimal overflow") {
+ val largeDecimals = spark.sparkContext.parallelize(
+ DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123"))
::
+ DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 +
".123")) :: Nil).toDF()
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val structDf = largeDecimals.select("a").agg(sum("a"))
+ assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+ }
+ }
+ }
+
+ test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
+ checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1,
1))
+ }
}
case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 81a797cc71bc..d982a000ad37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -17,13 +17,20 @@
package org.apache.spark.sql
+import java.sql.{Date, Timestamp}
+
+import scala.reflect.runtime.universe.TypeTag
+
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal,
DoubleType, IntegerType, MapType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
@@ -86,6 +93,229 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
assert(result === Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3)))) :: Nil)
}
}
+
+ test("access complex data") {
+ assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
+ if (!conf.ansiEnabled) {
+ assert(complexData.filter(complexData("m").getItem("1") === 1).count()
== 1)
+ }
+ assert(complexData.filter(complexData("s").getField("key") === 1).count()
== 1)
+ }
+
+ test("SPARK-7133: Implement struct, array, and map field accessor") {
+ assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+ if (!conf.ansiEnabled) {
+ assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+ }
+ assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+ assert(complexData.filter(complexData("m")(complexData("s")("value")) ===
1).count() == 1)
+ assert(complexData.filter(complexData("a")(complexData("s")("key")) ===
1).count() == 1)
+ }
+
+ test("SPARK-24313: access map with binary keys") {
+ val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
+ checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
+ }
+
+ test("SPARK-37855: IllegalStateException when transforming an array inside a
nested struct") {
+ def makeInput(): DataFrame = {
+ val innerElement1 = Row(3, 3.12)
+ val innerElement2 = Row(4, 2.1)
+ val innerElement3 = Row(1, 985.2)
+ val innerElement4 = Row(10, 757548.0)
+ val innerElement5 = Row(1223, 0.665)
+
+ val outerElement1 = Row(1, Row(List(innerElement1, innerElement2)))
+ val outerElement2 = Row(2, Row(List(innerElement3)))
+ val outerElement3 = Row(3, Row(List(innerElement4, innerElement5)))
+
+ val data = Seq(
+ Row("row1", List(outerElement1)),
+ Row("row2", List(outerElement2, outerElement3))
+ )
+
+ val schema = new StructType()
+ .add("name", StringType)
+ .add("outer_array", ArrayType(new StructType()
+ .add("id", IntegerType)
+ .add("inner_array_struct", new StructType()
+ .add("inner_array", ArrayType(new StructType()
+ .add("id", IntegerType)
+ .add("value", DoubleType)
+ ))
+ )
+ ))
+
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ val df = makeInput().limit(2)
+
+ val res = df.withColumn("extracted", transform(
+ col("outer_array"),
+ c1 => {
+ struct(
+ c1.getField("id").alias("outer_id"),
+ transform(
+ c1.getField("inner_array_struct").getField("inner_array"),
+ c2 => {
+ struct(
+ c2.getField("value").alias("inner_value")
+ )
+ }
+ )
+ )
+ }
+ ))
+
+ assert(res.collect().length == 2)
+ }
+
+ test("SPARK-39293: The accumulator of ArrayAggregate to handle complex types
properly") {
+ val reverse = udf((s: String) => s.reverse)
+
+ val df = Seq(Array("abc", "def")).toDF("array")
+ val testArray = df.select(
+ aggregate(
+ col("array"),
+ array().cast("array<string>"),
+ (acc, s) => concat(acc, array(reverse(s)))))
+ checkAnswer(testArray, Row(Array("cba", "fed")) :: Nil)
+
+ val testMap = df.select(
+ aggregate(
+ col("array"),
+ map().cast("map<string, string>"),
+ (acc, s) => map_concat(acc, map(s, reverse(s)))))
+ checkAnswer(testMap, Row(Map("abc" -> "cba", "def" -> "fed")) :: Nil)
+ }
+
+ test("SPARK-31552: array encoder with different types") {
+ // primitives
+ val booleans = Array(true, false)
+ checkAnswer(Seq(booleans).toDF(), Row(booleans))
+
+ val bytes = Array(1.toByte, 2.toByte)
+ checkAnswer(Seq(bytes).toDF(), Row(bytes))
+ val shorts = Array(1.toShort, 2.toShort)
+ checkAnswer(Seq(shorts).toDF(), Row(shorts))
+ val ints = Array(1, 2)
+ checkAnswer(Seq(ints).toDF(), Row(ints))
+ val longs = Array(1L, 2L)
+ checkAnswer(Seq(longs).toDF(), Row(longs))
+
+ val floats = Array(1.0F, 2.0F)
+ checkAnswer(Seq(floats).toDF(), Row(floats))
+ val doubles = Array(1.0D, 2.0D)
+ checkAnswer(Seq(doubles).toDF(), Row(doubles))
+
+ val strings = Array("2020-04-24", "2020-04-25")
+ checkAnswer(Seq(strings).toDF(), Row(strings))
+
+ // tuples
+ val decOne = Decimal(1, 38, 18)
+ val decTwo = Decimal(2, 38, 18)
+ val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
+ val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
+ checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
+
+ // case classes
+ val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
+ checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
+
+ // We can move this implicit def to [[SQLImplicits]] when we eventually make fully
+ // support for array encoder like Seq and Set
+ // For now cases below, decimal/datetime/interval/binary/nested types, etc,
+ // are not supported by array
+ implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+ // decimals
+ val decSpark = Array(decOne, decTwo)
+ val decScala = decSpark.map(_.toBigDecimal)
+ val decJava = decSpark.map(_.toJavaBigDecimal)
+ checkAnswer(Seq(decSpark).toDF(), Row(decJava))
+ checkAnswer(Seq(decScala).toDF(), Row(decJava))
+ checkAnswer(Seq(decJava).toDF(), Row(decJava))
+
+ // datetimes and intervals
+ val dates = strings.map(Date.valueOf)
+ checkAnswer(Seq(dates).toDF(), Row(dates))
+ val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d)))
+ checkAnswer(Seq(localDates).toDF(), Row(dates))
+
+ val timestamps =
+ Array(Timestamp.valueOf("2020-04-24 12:34:56"),
Timestamp.valueOf("2020-04-24 11:22:33"))
+ checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
+ val instants =
+ timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t)))
+ checkAnswer(Seq(instants).toDF(), Row(timestamps))
+
+ val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6))
+ checkAnswer(Seq(intervals).toDF(), Row(intervals))
+
+ // binary
+ val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
+ checkAnswer(Seq(bins).toDF(), Row(bins))
+
+ // nested
+ val nestedIntArray = Array(Array(1), Array(2))
+ checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
+ val nestedDecArray = Array(decSpark)
+ checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
+ }
+
+ test("SPARK-24165: CaseWhen/If - nullability of nested types") {
+ val rows = new java.util.ArrayList[Row]()
+ rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x")))
+ rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null)))
+ val schema = StructType(Seq(
+ StructField("cond", BooleanType, true),
+ StructField("s", StructType(Seq(
+ StructField("val1", StringType, true),
+ StructField("val2", IntegerType, false)
+ )), false),
+ StructField("a", ArrayType(StringType, true)),
+ StructField("m", MapType(IntegerType, StringType, true))
+ ))
+
+ val sourceDF = spark.createDataFrame(rows, schema)
+
+ def structWhenDF: DataFrame = sourceDF
+ .select(when($"cond",
+ struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as
"res")
+ .select($"res".getField("val1"))
+ def arrayWhenDF: DataFrame = sourceDF
+ .select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as
"res")
+ .select($"res".getItem(0))
+ def mapWhenDF: DataFrame = sourceDF
+ .select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res")
+ .select($"res".getItem(0))
+
+ def structIfDF: DataFrame = sourceDF
+ .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res")
+ .select($"res".getField("val1"))
+ def arrayIfDF: DataFrame = sourceDF
+ .select(expr("if(cond, array('a', 'b'), a)") as "res")
+ .select($"res".getItem(0))
+ def mapIfDF: DataFrame = sourceDF
+ .select(expr("if(cond, map(0, 'a'), m)") as "res")
+ .select($"res".getItem(0))
+
+ def checkResult(): Unit = {
+ checkAnswer(structWhenDF, Seq(Row("a"), Row(null)))
+ checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null)))
+ checkAnswer(mapWhenDF, Seq(Row("a"), Row(null)))
+ checkAnswer(structIfDF, Seq(Row("a"), Row(null)))
+ checkAnswer(arrayIfDF, Seq(Row("a"), Row(null)))
+ checkAnswer(mapIfDF, Seq(Row("a"), Row(null)))
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ checkResult()
+ // Test with cached relation, the Project will be evaluated with codegen
+ sourceDF.cache()
+ checkResult()
+ }
}
class S100(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 4c6f34d3d2dd..01905e2c05fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -602,4 +602,23 @@ class DataFrameJoinSuite extends QueryTest
)
}
}
+
+ test("SPARK-20359: catalyst outer join optimization should not throw npe") {
+ val df1 = Seq("a", "b", "c").toDF("x")
+ .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!"
}.apply($"x"))
+ val df2 = Seq("a", "b").toDF("x1")
+ df1
+ .join(df2, df1("x") === df2("x1"), "left_outer")
+ .filter($"x1".isNotNull || !$"y".isin("a!"))
+ .count()
+ }
+
+ test("SPARK-16181: outer join with isNull filter") {
+ val left = Seq("x").toDF("col")
+ val right = Seq("y").toDF("col").withColumn("new", lit(true))
+ val joined = left.join(right, left("col") === right("col"), "left_outer")
+
+ checkAnswer(joined, Row("x", null, null))
+ checkAnswer(joined.filter($"new".isNull), Row("x", null, null))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index c777d2207584..7dc40549a17b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -488,4 +488,14 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern))
}
}
+
+ test("SPARK-20897: cached self-join should not fail") {
+ // force to plan sort merge join
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ val df = Seq(1 -> "a").toDF("i", "j")
+ val df1 = df.as("t1")
+ val df2 = df.as("t2")
+ assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala
new file mode 100644
index 000000000000..e889fe2545af
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameShowSuite.scala
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.functions.rand
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DataFrameShowSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ ignore("show") {
+ // This test case is intended ignored, but to make sure it compiles correctly
+ testData.select($"*").show()
+ testData.select($"*").show(1000)
+ }
+
+
+ test("showString: truncate = [0, 20]") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = """+---------------------+
+ ||value |
+ |+---------------------+
+ ||1 |
+ ||111111111111111111111|
+ |+---------------------+
+ |""".stripMargin
+ assert(df.showString(10, truncate = 0) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = """+--------------------+
+ || value|
+ |+--------------------+
+ || 1|
+ ||11111111111111111...|
+ |+--------------------+
+ |""".stripMargin
+ assert(df.showString(10, truncate = 20) === expectedAnswerForTrue)
+ }
+
+ test("showString: truncate = [0, 20], vertical = true") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = "-RECORD 0----------------------\n" +
+ " value | 1 \n" +
+ "-RECORD 1----------------------\n" +
+ " value | 111111111111111111111 \n"
+ assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = "-RECORD 0---------------------\n" +
+ " value | 1 \n" +
+ "-RECORD 1---------------------\n" +
+ " value | 11111111111111111... \n"
+ assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue)
+ }
+
+ test("showString: truncate = [3, 17]") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = """+-----+
+ ||value|
+ |+-----+
+ || 1|
+ || 111|
+ |+-----+
+ |""".stripMargin
+ assert(df.showString(10, truncate = 3) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = """+-----------------+
+ || value|
+ |+-----------------+
+ || 1|
+ ||11111111111111...|
+ |+-----------------+
+ |""".stripMargin
+ assert(df.showString(10, truncate = 17) === expectedAnswerForTrue)
+ }
+
+ test("showString: truncate = [3, 17], vertical = true") {
+ val longString = Array.fill(21)("1").mkString
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+ val expectedAnswerForFalse = "-RECORD 0----\n" +
+ " value | 1 \n" +
+ "-RECORD 1----\n" +
+ " value | 111 \n"
+ assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse)
+ val expectedAnswerForTrue = "-RECORD 0------------------\n" +
+ " value | 1 \n" +
+ "-RECORD 1------------------\n" +
+ " value | 11111111111111... \n"
+ assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue)
+ }
+
+ test("showString(negative)") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ |+---+-----+
+ |only showing top 0 rows
+ |""".stripMargin
+ assert(testData.select($"*").showString(-1) === expectedAnswer)
+ }
+
+ test("showString(negative), vertical = true") {
+ val expectedAnswer = "(0 rows)\n"
+ assert(testData.select($"*").showString(-1, vertical = true) ===
expectedAnswer)
+ }
+
+ test("showString(0)") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ |+---+-----+
+ |only showing top 0 rows
+ |""".stripMargin
+ assert(testData.select($"*").showString(0) === expectedAnswer)
+ }
+
+ test("showString(Int.MaxValue)") {
+ val df = Seq((1, 2), (3, 4)).toDF("a", "b")
+ val expectedAnswer = """+---+---+
+ || a| b|
+ |+---+---+
+ || 1| 2|
+ || 3| 4|
+ |+---+---+
+ |""".stripMargin
+ assert(df.showString(Int.MaxValue) === expectedAnswer)
+ }
+
+ test("showString(0), vertical = true") {
+ val expectedAnswer = "(0 rows)\n"
+ assert(testData.select($"*").showString(0, vertical = true) ===
expectedAnswer)
+ }
+
+ test("showString: array") {
+ val df = Seq(
+ (Array(1, 2, 3), Array(1, 2, 3)),
+ (Array(2, 3, 4), Array(2, 3, 4))
+ ).toDF()
+ val expectedAnswer = """+---------+---------+
+ || _1| _2|
+ |+---------+---------+
+ ||[1, 2, 3]|[1, 2, 3]|
+ ||[2, 3, 4]|[2, 3, 4]|
+ |+---------+---------+
+ |""".stripMargin
+ assert(df.showString(10) === expectedAnswer)
+ }
+
+ test("showString: array, vertical = true") {
+ val df = Seq(
+ (Array(1, 2, 3), Array(1, 2, 3)),
+ (Array(2, 3, 4), Array(2, 3, 4))
+ ).toDF()
+ val expectedAnswer = "-RECORD 0--------\n" +
+ " _1 | [1, 2, 3] \n" +
+ " _2 | [1, 2, 3] \n" +
+ "-RECORD 1--------\n" +
+ " _1 | [2, 3, 4] \n" +
+ " _2 | [2, 3, 4] \n"
+ assert(df.showString(10, vertical = true) === expectedAnswer)
+ }
+
+ test("showString: binary") {
+ val df = Seq(
+ ("12".getBytes(StandardCharsets.UTF_8),
"ABC.".getBytes(StandardCharsets.UTF_8)),
+ ("34".getBytes(StandardCharsets.UTF_8),
"12346".getBytes(StandardCharsets.UTF_8))
+ ).toDF()
+ val expectedAnswer = """+-------+----------------+
+ || _1| _2|
+ |+-------+----------------+
+ ||[31 32]| [41 42 43 2E]|
+ ||[33 34]|[31 32 33 34 36]|
+ |+-------+----------------+
+ |""".stripMargin
+ assert(df.showString(10) === expectedAnswer)
+ }
+
+ test("showString: binary, vertical = true") {
+ val df = Seq(
+ ("12".getBytes(StandardCharsets.UTF_8),
"ABC.".getBytes(StandardCharsets.UTF_8)),
+ ("34".getBytes(StandardCharsets.UTF_8),
"12346".getBytes(StandardCharsets.UTF_8))
+ ).toDF()
+ val expectedAnswer = "-RECORD 0---------------\n" +
+ " _1 | [31 32] \n" +
+ " _2 | [41 42 43 2E] \n" +
+ "-RECORD 1---------------\n" +
+ " _1 | [33 34] \n" +
+ " _2 | [31 32 33 34 36] \n"
+ assert(df.showString(10, vertical = true) === expectedAnswer)
+ }
+
+ test("showString: minimum column width") {
+ val df = Seq(
+ (1, 1),
+ (2, 2)
+ ).toDF()
+ val expectedAnswer = """+---+---+
+ || _1| _2|
+ |+---+---+
+ || 1| 1|
+ || 2| 2|
+ |+---+---+
+ |""".stripMargin
+ assert(df.showString(10) === expectedAnswer)
+ }
+
+ test("showString: minimum column width, vertical = true") {
+ val df = Seq(
+ (1, 1),
+ (2, 2)
+ ).toDF()
+ val expectedAnswer = "-RECORD 0--\n" +
+ " _1 | 1 \n" +
+ " _2 | 1 \n" +
+ "-RECORD 1--\n" +
+ " _1 | 2 \n" +
+ " _2 | 2 \n"
+ assert(df.showString(10, vertical = true) === expectedAnswer)
+ }
+
+ test("SPARK-33690: showString: escape meta-characters") {
+ val df1 = spark.sql("SELECT
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
+ assert(df1.showString(1, truncate = 0) ===
+ """+--------------------------------------+
+ ||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
+ |+--------------------------------------+
+ ||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
+ |+--------------------------------------+
+ |""".stripMargin)
+
+ val df2 = spark.sql("SELECT
array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ assert(df2.showString(1, truncate = 0) ===
+ """+---------------------------------------------+
+ ||array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
+ |+---------------------------------------------+
+ ||[aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh] |
+ |+---------------------------------------------+
+ |""".stripMargin)
+
+ val df3 =
+ spark.sql("SELECT map('aaa\nbbb\tccc',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ assert(df3.showString(1, truncate = 0) ===
+ """+----------------------------------------------------------+
+ ||map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
+ |+----------------------------------------------------------+
+ ||{aaa\nbbb\tccc -> aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
+ |+----------------------------------------------------------+
+ |""".stripMargin)
+
+ val df4 =
+ spark.sql("SELECT named_struct('v',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ assert(df4.showString(1, truncate = 0) ===
+ """+-------------------------------------------------------+
+ ||named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
+ |+-------------------------------------------------------+
+ ||{aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
+ |+-------------------------------------------------------+
+ |""".stripMargin)
+ }
+
+ test("SPARK-7319 showString") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ || 1| 1|
+ |+---+-----+
+ |only showing top 1 row
+ |""".stripMargin
+ assert(testData.select($"*").showString(1) === expectedAnswer)
+ }
+
+ test("SPARK-7319 showString, vertical = true") {
+ val expectedAnswer = "-RECORD 0----\n" +
+ " key | 1 \n" +
+ " value | 1 \n" +
+ "only showing top 1 row\n"
+ assert(testData.select($"*").showString(1, vertical = true) ===
expectedAnswer)
+ }
+
+ test("SPARK-23023 Cast rows to strings in showString") {
+ val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a")
+ assert(df1.showString(10) ===
+ s"""+------------+
+ || a|
+ |+------------+
+ ||[1, 2, 3, 4]|
+ |+------------+
+ |""".stripMargin)
+ val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a")
+ assert(df2.showString(10) ===
+ s"""+----------------+
+ || a|
+ |+----------------+
+ ||{1 -> a, 2 -> b}|
+ |+----------------+
+ |""".stripMargin)
+ val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b")
+ assert(df3.showString(10) ===
+ s"""+------+---+
+ || a| b|
+ |+------+---+
+ ||{1, a}| 0|
+ ||{2, b}| 0|
+ |+------+---+
+ |""".stripMargin)
+ }
+
+ test("SPARK-7327 show with empty dataFrame") {
+ val expectedAnswer = """+---+-----+
+ ||key|value|
+ |+---+-----+
+ |+---+-----+
+ |""".stripMargin
+ assert(testData.select($"*").filter($"key" < 0).showString(1) ===
expectedAnswer)
+ }
+
+ test("SPARK-7327 show with empty dataFrame, vertical = true") {
+ assert(testData.select($"*").filter($"key" < 0).showString(1, vertical =
true) === "(0 rows)\n")
+ }
+
+ test("SPARK-18350 show with session local timezone") {
+ val d = Date.valueOf("2016-12-01")
+ val ts = Timestamp.valueOf("2016-12-01 00:00:00")
+ val df = Seq((d, ts)).toDF("d", "ts")
+ val expectedAnswer = """+----------+-------------------+
+ ||d |ts |
+ |+----------+-------------------+
+ ||2016-12-01|2016-12-01 00:00:00|
+ |+----------+-------------------+
+ |""".stripMargin
+ assert(df.showString(1, truncate = 0) === expectedAnswer)
+
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+
+ val expectedAnswer = """+----------+-------------------+
+ ||d |ts |
+ |+----------+-------------------+
+ ||2016-12-01|2016-12-01 08:00:00|
+ |+----------+-------------------+
+ |""".stripMargin
+ assert(df.showString(1, truncate = 0) === expectedAnswer)
+ }
+ }
+
+ test("SPARK-18350 show with session local timezone, vertical = true") {
+ val d = Date.valueOf("2016-12-01")
+ val ts = Timestamp.valueOf("2016-12-01 00:00:00")
+ val df = Seq((d, ts)).toDF("d", "ts")
+ val expectedAnswer = "-RECORD 0------------------\n" +
+ " d | 2016-12-01 \n" +
+ " ts | 2016-12-01 00:00:00 \n"
+ assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
+
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+
+ val expectedAnswer = "-RECORD 0------------------\n" +
+ " d | 2016-12-01 \n" +
+ " ts | 2016-12-01 08:00:00 \n"
+ assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
+ }
+ }
+
+ test("SPARK-8608: call `show` on local DataFrame with random columns should
return same value") {
+ val df = testData.select(rand(33))
+ assert(df.showString(5) == df.showString(5))
+
+ // We will reuse the same Expression object for LocalRelation.
+ val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33))
+ assert(df1.showString(5) == df1.showString(5))
+ }
+
+ test("dataframe toString") {
+ assert(testData.toString === "[key: int, value: string]")
+ assert(testData("key").toString === "key")
+ assert($"test".toString === "test")
+ }
+
+ test("SPARK-12398 truncated toString") {
+ val df1 = Seq((1L, "row1")).toDF("id", "name")
+ assert(df1.toString() === "[id: bigint, name: string]")
+
+ val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3")
+ assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]")
+
+ val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4")
+ assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]")
+
+ val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2")
+ assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]")
+
+ val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3")
+ assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 1 more field]")
+
+ val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
+ assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 2 more fields]")
+
+ val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
+ assert(
+ df7.toString ===
+ "[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ...
2 more fields]")
+
+ val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
+ assert(
+ df8.toString ===
+ "[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ...
2 more fields]")
+
+ val df9 =
+ Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
+ assert(
+ df9.toString ===
+ "[c1: bigint, c2: struct<_1: bigint," +
+ " _2: struct<_1: bigint," +
+ " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more
fields]")
+
+ }
+
+ test("SPARK-34308: printSchema: escape meta-characters") {
+ val captured = new ByteArrayOutputStream()
+
+ val df1 = spark.sql("SELECT
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
+ Console.withOut(captured) {
+ df1.printSchema()
+ }
+ assert(captured.toString ===
+ """root
+ | |-- aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh: string (nullable = false)
+ |
+ |""".stripMargin)
+ captured.reset()
+
+ val df2 = spark.sql("SELECT
array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ Console.withOut(captured) {
+ df2.printSchema()
+ }
+ assert(captured.toString ===
+ """root
+ | |-- array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): array (nullable = false)
+ | | |-- element: string (containsNull = false)
+ |
+ |""".stripMargin)
+ captured.reset()
+
+ val df3 =
+ spark.sql("SELECT map('aaa\nbbb\tccc',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ Console.withOut(captured) {
+ df3.printSchema()
+ }
+ assert(captured.toString ===
+ """root
+ | |-- map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): map (nullable = false)
+ | | |-- key: string
+ | | |-- value: string (valueContainsNull = false)
+ |
+ |""".stripMargin)
+ captured.reset()
+
+ val df4 =
+ spark.sql("SELECT named_struct('v',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
+ Console.withOut(captured) {
+ df4.printSchema()
+ }
+ assert(captured.toString ===
+ """root
+ | |-- named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): struct (nullable = false)
+ | | |-- v: string (nullable = false)
+ |
+ |""".stripMargin)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8b53edb0d3c1..8eee8fc37661 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -21,6 +21,7 @@ import java.util.Random
import org.scalatest.matchers.must.Matchers._
+import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.functions.{col, lit, struct, when}
@@ -540,6 +541,70 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession {
assert(filter4.bitSize() == 64 * 5)
assert(0.until(1000).forall(i => filter4.mightContain(i * 3)))
}
+
+ test("SPARK-34165: Add count_distinct to summary") {
+ val person3: DataFrame = Seq(
+ ("Luis", 1, 99),
+ ("Luis", 16, 99),
+ ("Luis", 16, 176),
+ ("Fernando", 32, 99),
+ ("Fernando", 32, 164),
+ ("David", 60, 99),
+ ("Amy", 24, 99)).toDF("name", "age", "height")
+ val summaryDF = person3.summary("count", "count_distinct")
+
+ val summaryResult = Seq(
+ Row("count", "7", "7", "7"),
+ Row("count_distinct", "4", "5", "3"))
+
+ def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
+ assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
+ checkAnswer(summaryDF, summaryResult)
+
+ val approxSummaryDF = person3.summary("count", "approx_count_distinct")
+ val approxSummaryResult = Seq(
+ Row("count", "7", "7", "7"),
+ Row("approx_count_distinct", "4", "5", "3"))
+ assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
+ checkAnswer(approxSummaryDF, approxSummaryResult)
+ }
+
+ test("summary advanced") {
+ import org.apache.spark.util.ArrayImplicits._
+ val person2: DataFrame = Seq(
+ ("Bob", 16, 176),
+ ("Alice", 32, 164),
+ ("David", 60, 192),
+ ("Amy", 24, 180)).toDF("name", "age", "height")
+
+ val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
+ val orderMatters = person2.summary(stats.toImmutableArraySeq: _*)
+ assert(orderMatters.collect().map(_.getString(0)) === stats)
+
+ val onlyPercentiles = person2.summary("0.1%", "99.9%")
+ assert(onlyPercentiles.count() === 2)
+
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ person2.summary("foo")
+ },
+ errorClass = "_LEGACY_ERROR_TEMP_2114",
+ parameters = Map("stats" -> "foo")
+ )
+
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ person2.summary("foo%")
+ },
+ errorClass = "_LEGACY_ERROR_TEMP_2113",
+ parameters = Map("stats" -> "foo%")
+ )
+ }
+
+ test("SPARK-19691 Calculating percentile of decimal column fails with
ClassCastException") {
+ val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as
x").selectExpr("percentile(x, 0.5)")
+ checkAnswer(df, Row(BigDecimal(0)) :: Nil)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8ef95e6fd129..24a548fa7f81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -20,28 +20,22 @@ package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, File}
import java.lang.{Long => JLong}
import java.nio.charset.StandardCharsets
-import java.sql.{Date, Timestamp}
-import java.util.{Locale, UUID}
-import java.util.concurrent.atomic.AtomicLong
+import java.util.Locale
import scala.collection.immutable.ListMap
-import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
import org.scalatest.matchers.should.Matchers._
-import org.apache.spark.{SparkException, SparkIllegalArgumentException}
+import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.GZIP
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, OneRowRelation}
import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -51,12 +45,11 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, StringWrapper, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.tags.SlowSQLTest
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.ArrayImplicits._
-import org.apache.spark.util.Utils
@SlowSQLTest
class DataFrameSuite extends QueryTest
@@ -77,12 +70,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("dataframe toString") {
- assert(testData.toString === "[key: int, value: string]")
- assert(testData("key").toString === "key")
- assert($"test".toString === "test")
- }
-
test("rename nested groupby") {
val df = Seq((1, (1, 1))).toDF()
@@ -91,14 +78,6 @@ class DataFrameSuite extends QueryTest
Row(1, 1) :: Nil)
}
- test("access complex data") {
- assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
- if (!conf.ansiEnabled) {
- assert(complexData.filter(complexData("m").getItem("1") === 1).count()
== 1)
- }
- assert(complexData.filter(complexData("s").getField("key") === 1).count()
== 1)
- }
-
test("table scan") {
checkAnswer(
testData,
@@ -204,32 +183,6 @@ class DataFrameSuite extends QueryTest
structDf.select(xxhash64($"a", $"record.*")))
}
- private def assertDecimalSumOverflow(
- df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
- if (!ansiEnabled) {
- checkAnswer(df, expectedAnswer)
- } else {
- val e = intercept[ArithmeticException] {
- df.collect()
- }
- assert(e.getMessage.contains("cannot be represented as Decimal") ||
- e.getMessage.contains("Overflow in sum of decimals"))
- }
- }
-
- test("SPARK-28224: Aggregate sum big decimal overflow") {
- val largeDecimals = spark.sparkContext.parallelize(
- DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123"))
::
- DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 +
".123")) :: Nil).toDF()
-
- Seq(true, false).foreach { ansiEnabled =>
- withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
- val structDf = largeDecimals.select("a").agg(sum("a"))
- assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
- }
- }
- }
-
test("SPARK-28067: sum of null decimal values") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
@@ -243,85 +196,6 @@ class DataFrameSuite extends QueryTest
}
}
- def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
- Seq("true", "false").foreach { wholeStageEnabled =>
- withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
- Seq(true, false).foreach { ansiEnabled =>
- withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
- val df0 = Seq(
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
- val df1 = Seq(
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2),
- (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
- val df = df0.union(df1)
- val df2 = df.withColumnRenamed("decNum", "decNum2").
- join(df, "intNum").agg(aggFn($"decNum"))
-
- val expectedAnswer = Row(null)
- assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
-
- val decStr = "1" + "0" * 19
- val d1 = spark.range(0, 12, 1, 1)
- val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
- assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
-
- val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
- val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
- assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
-
- val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
- lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
- assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
-
- val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
-
- val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
- toDF("d")
- assertDecimalSumOverflow(
- nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
-
- val df3 = Seq(
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("50000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
-
- val df4 = Seq(
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
-
- val df5 = Seq(
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("10000000000000000000"), 1),
- (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
-
- val df6 = df3.union(df4).union(df5)
- val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
- filter("intNum == 1")
- assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
- }
- }
- }
- }
- }
-
- test("SPARK-28067: Aggregate sum should not return wrong results for decimal
overflow") {
- checkAggResultsForDecimalOverflow(c => sum(c))
- }
-
- test("SPARK-35955: Aggregate avg should not return wrong results for decimal
overflow") {
- checkAggResultsForDecimalOverflow(c => avg(c))
- }
-
test("Star Expansion - ds.explode should fail with a meaningful message if
it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix",
"csv")
val e = intercept[AnalysisException] {
@@ -663,24 +537,6 @@ class DataFrameSuite extends QueryTest
testData.take(15).drop(5).toSeq)
}
- test("udf") {
- val foo = udf((a: Int, b: String) => a.toString + b)
-
- checkAnswer(
- // SELECT *, foo(key, value) FROM testData
- testData.select($"*", foo($"key", $"value")).limit(3),
- Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
- )
- }
-
- test("callUDF without Hive Support") {
- val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
- df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
- checkAnswer(
- df.select($"id", callUDF("simpleUDF", $"value")), // test deprecated one
- Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
- }
-
test("withColumn") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
checkAnswer(
@@ -1048,15 +904,6 @@ class DataFrameSuite extends QueryTest
("David", 60, 192),
("Amy", 24, 180)).toDF("name", "age", "height")
- private lazy val person3: DataFrame = Seq(
- ("Luis", 1, 99),
- ("Luis", 16, 99),
- ("Luis", 16, 176),
- ("Fernando", 32, 99),
- ("Fernando", 32, 164),
- ("David", 60, 99),
- ("Amy", 24, 99)).toDF("name", "age", "height")
-
test("describe") {
val describeResult = Seq(
Row("count", "4", "4", "4"),
@@ -1156,25 +1003,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("SPARK-34165: Add count_distinct to summary") {
- val summaryDF = person3.summary("count", "count_distinct")
-
- val summaryResult = Seq(
- Row("count", "7", "7", "7"),
- Row("count_distinct", "4", "5", "3"))
-
- def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
- assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
- checkAnswer(summaryDF, summaryResult)
-
- val approxSummaryDF = person3.summary("count", "approx_count_distinct")
- val approxSummaryResult = Seq(
- Row("count", "7", "7", "7"),
- Row("approx_count_distinct", "4", "5", "3"))
- assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
- checkAnswer(approxSummaryDF, approxSummaryResult)
- }
-
test("SPARK-41391: Correct the output column name of
groupBy.agg(count_distinct)") {
withTempView("person") {
person.createOrReplaceTempView("person")
@@ -1187,32 +1015,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("summary advanced") {
- import org.apache.spark.util.ArrayImplicits._
- val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
- val orderMatters = person2.summary(stats.toImmutableArraySeq: _*)
- assert(orderMatters.collect().map(_.getString(0)) === stats)
-
- val onlyPercentiles = person2.summary("0.1%", "99.9%")
- assert(onlyPercentiles.count() === 2)
-
- checkError(
- exception = intercept[SparkIllegalArgumentException] {
- person2.summary("foo")
- },
- errorClass = "_LEGACY_ERROR_TEMP_2114",
- parameters = Map("stats" -> "foo")
- )
-
- checkError(
- exception = intercept[SparkIllegalArgumentException] {
- person2.summary("foo%")
- },
- errorClass = "_LEGACY_ERROR_TEMP_2113",
- parameters = Map("stats" -> "foo%")
- )
- }
-
test("apply on query results (SPARK-5462)") {
val df = testData.sparkSession.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select("key").collect().toSeq)
@@ -1242,12 +1044,6 @@ class DataFrameSuite extends QueryTest
}
}
- ignore("show") {
- // This test case is intended ignored, but to make sure it compiles correctly
- testData.select($"*").show()
- testData.select($"*").show(1000)
- }
-
test("getRows: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
@@ -1307,402 +1103,6 @@ class DataFrameSuite extends QueryTest
assert(df.getRows(10, 20) === expectedAnswer)
}
- test("showString: truncate = [0, 20]") {
- val longString = Array.fill(21)("1").mkString
- val df = sparkContext.parallelize(Seq("1", longString)).toDF()
- val expectedAnswerForFalse = """+---------------------+
- ||value |
- |+---------------------+
- ||1 |
- ||111111111111111111111|
- |+---------------------+
- |""".stripMargin
- assert(df.showString(10, truncate = 0) === expectedAnswerForFalse)
- val expectedAnswerForTrue = """+--------------------+
- || value|
- |+--------------------+
- || 1|
- ||11111111111111111...|
- |+--------------------+
- |""".stripMargin
- assert(df.showString(10, truncate = 20) === expectedAnswerForTrue)
- }
-
- test("showString: truncate = [0, 20], vertical = true") {
- val longString = Array.fill(21)("1").mkString
- val df = sparkContext.parallelize(Seq("1", longString)).toDF()
- val expectedAnswerForFalse = "-RECORD 0----------------------\n" +
- " value | 1 \n" +
- "-RECORD 1----------------------\n" +
- " value | 111111111111111111111 \n"
- assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse)
- val expectedAnswerForTrue = "-RECORD 0---------------------\n" +
- " value | 1 \n" +
- "-RECORD 1---------------------\n" +
- " value | 11111111111111111... \n"
- assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue)
- }
-
- test("showString: truncate = [3, 17]") {
- val longString = Array.fill(21)("1").mkString
- val df = sparkContext.parallelize(Seq("1", longString)).toDF()
- val expectedAnswerForFalse = """+-----+
- ||value|
- |+-----+
- || 1|
- || 111|
- |+-----+
- |""".stripMargin
- assert(df.showString(10, truncate = 3) === expectedAnswerForFalse)
- val expectedAnswerForTrue = """+-----------------+
- || value|
- |+-----------------+
- || 1|
- ||11111111111111...|
- |+-----------------+
- |""".stripMargin
- assert(df.showString(10, truncate = 17) === expectedAnswerForTrue)
- }
-
- test("showString: truncate = [3, 17], vertical = true") {
- val longString = Array.fill(21)("1").mkString
- val df = sparkContext.parallelize(Seq("1", longString)).toDF()
- val expectedAnswerForFalse = "-RECORD 0----\n" +
- " value | 1 \n" +
- "-RECORD 1----\n" +
- " value | 111 \n"
- assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse)
- val expectedAnswerForTrue = "-RECORD 0------------------\n" +
- " value | 1 \n" +
- "-RECORD 1------------------\n" +
- " value | 11111111111111... \n"
- assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue)
- }
-
- test("showString(negative)") {
- val expectedAnswer = """+---+-----+
- ||key|value|
- |+---+-----+
- |+---+-----+
- |only showing top 0 rows
- |""".stripMargin
- assert(testData.select($"*").showString(-1) === expectedAnswer)
- }
-
- test("showString(negative), vertical = true") {
- val expectedAnswer = "(0 rows)\n"
- assert(testData.select($"*").showString(-1, vertical = true) ===
expectedAnswer)
- }
-
- test("showString(0)") {
- val expectedAnswer = """+---+-----+
- ||key|value|
- |+---+-----+
- |+---+-----+
- |only showing top 0 rows
- |""".stripMargin
- assert(testData.select($"*").showString(0) === expectedAnswer)
- }
-
- test("showString(Int.MaxValue)") {
- val df = Seq((1, 2), (3, 4)).toDF("a", "b")
- val expectedAnswer = """+---+---+
- || a| b|
- |+---+---+
- || 1| 2|
- || 3| 4|
- |+---+---+
- |""".stripMargin
- assert(df.showString(Int.MaxValue) === expectedAnswer)
- }
-
- test("showString(0), vertical = true") {
- val expectedAnswer = "(0 rows)\n"
- assert(testData.select($"*").showString(0, vertical = true) ===
expectedAnswer)
- }
-
- test("showString: array") {
- val df = Seq(
- (Array(1, 2, 3), Array(1, 2, 3)),
- (Array(2, 3, 4), Array(2, 3, 4))
- ).toDF()
- val expectedAnswer = """+---------+---------+
- || _1| _2|
- |+---------+---------+
- ||[1, 2, 3]|[1, 2, 3]|
- ||[2, 3, 4]|[2, 3, 4]|
- |+---------+---------+
- |""".stripMargin
- assert(df.showString(10) === expectedAnswer)
- }
-
- test("showString: array, vertical = true") {
- val df = Seq(
- (Array(1, 2, 3), Array(1, 2, 3)),
- (Array(2, 3, 4), Array(2, 3, 4))
- ).toDF()
- val expectedAnswer = "-RECORD 0--------\n" +
- " _1 | [1, 2, 3] \n" +
- " _2 | [1, 2, 3] \n" +
- "-RECORD 1--------\n" +
- " _1 | [2, 3, 4] \n" +
- " _2 | [2, 3, 4] \n"
- assert(df.showString(10, vertical = true) === expectedAnswer)
- }
-
- test("showString: binary") {
- val df = Seq(
- ("12".getBytes(StandardCharsets.UTF_8),
"ABC.".getBytes(StandardCharsets.UTF_8)),
- ("34".getBytes(StandardCharsets.UTF_8),
"12346".getBytes(StandardCharsets.UTF_8))
- ).toDF()
- val expectedAnswer = """+-------+----------------+
- || _1| _2|
- |+-------+----------------+
- ||[31 32]| [41 42 43 2E]|
- ||[33 34]|[31 32 33 34 36]|
- |+-------+----------------+
- |""".stripMargin
- assert(df.showString(10) === expectedAnswer)
- }
-
- test("showString: binary, vertical = true") {
- val df = Seq(
- ("12".getBytes(StandardCharsets.UTF_8),
"ABC.".getBytes(StandardCharsets.UTF_8)),
- ("34".getBytes(StandardCharsets.UTF_8),
"12346".getBytes(StandardCharsets.UTF_8))
- ).toDF()
- val expectedAnswer = "-RECORD 0---------------\n" +
- " _1 | [31 32] \n" +
- " _2 | [41 42 43 2E] \n" +
- "-RECORD 1---------------\n" +
- " _1 | [33 34] \n" +
- " _2 | [31 32 33 34 36] \n"
- assert(df.showString(10, vertical = true) === expectedAnswer)
- }
-
- test("showString: minimum column width") {
- val df = Seq(
- (1, 1),
- (2, 2)
- ).toDF()
- val expectedAnswer = """+---+---+
- || _1| _2|
- |+---+---+
- || 1| 1|
- || 2| 2|
- |+---+---+
- |""".stripMargin
- assert(df.showString(10) === expectedAnswer)
- }
-
- test("showString: minimum column width, vertical = true") {
- val df = Seq(
- (1, 1),
- (2, 2)
- ).toDF()
- val expectedAnswer = "-RECORD 0--\n" +
- " _1 | 1 \n" +
- " _2 | 1 \n" +
- "-RECORD 1--\n" +
- " _1 | 2 \n" +
- " _2 | 2 \n"
- assert(df.showString(10, vertical = true) === expectedAnswer)
- }
-
- test("SPARK-33690: showString: escape meta-characters") {
- val df1 = spark.sql("SELECT
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
- assert(df1.showString(1, truncate = 0) ===
- """+--------------------------------------+
- ||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
- |+--------------------------------------+
- ||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
- |+--------------------------------------+
- |""".stripMargin)
-
- val df2 = spark.sql("SELECT
array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- assert(df2.showString(1, truncate = 0) ===
- """+---------------------------------------------+
- ||array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
- |+---------------------------------------------+
- ||[aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh] |
- |+---------------------------------------------+
- |""".stripMargin)
-
- val df3 =
- spark.sql("SELECT map('aaa\nbbb\tccc',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- assert(df3.showString(1, truncate = 0) ===
- """+----------------------------------------------------------+
- ||map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
- |+----------------------------------------------------------+
- ||{aaa\nbbb\tccc -> aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
- |+----------------------------------------------------------+
- |""".stripMargin)
-
- val df4 =
- spark.sql("SELECT named_struct('v',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- assert(df4.showString(1, truncate = 0) ===
- """+-------------------------------------------------------+
- ||named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
- |+-------------------------------------------------------+
- ||{aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
- |+-------------------------------------------------------+
- |""".stripMargin)
- }
-
- test("SPARK-34308: printSchema: escape meta-characters") {
- val captured = new ByteArrayOutputStream()
-
- val df1 = spark.sql("SELECT
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
- Console.withOut(captured) {
- df1.printSchema()
- }
- assert(captured.toString ===
- """root
- | |-- aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh: string (nullable = false)
- |
- |""".stripMargin)
- captured.reset()
-
- val df2 = spark.sql("SELECT
array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- Console.withOut(captured) {
- df2.printSchema()
- }
- assert(captured.toString ===
- """root
- | |-- array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): array (nullable =
false)
- | | |-- element: string (containsNull = false)
- |
- |""".stripMargin)
- captured.reset()
-
- val df3 =
- spark.sql("SELECT map('aaa\nbbb\tccc',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- Console.withOut(captured) {
- df3.printSchema()
- }
- assert(captured.toString ===
- """root
- | |-- map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): map
(nullable = false)
- | | |-- key: string
- | | |-- value: string (valueContainsNull = false)
- |
- |""".stripMargin)
- captured.reset()
-
- val df4 =
- spark.sql("SELECT named_struct('v',
'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
- Console.withOut(captured) {
- df4.printSchema()
- }
- assert(captured.toString ===
- """root
- | |-- named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): struct
(nullable = false)
- | | |-- v: string (nullable = false)
- |
- |""".stripMargin)
- }
-
- test("SPARK-7319 showString") {
- val expectedAnswer = """+---+-----+
- ||key|value|
- |+---+-----+
- || 1| 1|
- |+---+-----+
- |only showing top 1 row
- |""".stripMargin
- assert(testData.select($"*").showString(1) === expectedAnswer)
- }
-
- test("SPARK-7319 showString, vertical = true") {
- val expectedAnswer = "-RECORD 0----\n" +
- " key | 1 \n" +
- " value | 1 \n" +
- "only showing top 1 row\n"
- assert(testData.select($"*").showString(1, vertical = true) ===
expectedAnswer)
- }
-
- test("SPARK-23023 Cast rows to strings in showString") {
- val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a")
- assert(df1.showString(10) ===
- s"""+------------+
- || a|
- |+------------+
- ||[1, 2, 3, 4]|
- |+------------+
- |""".stripMargin)
- val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a")
- assert(df2.showString(10) ===
- s"""+----------------+
- || a|
- |+----------------+
- ||{1 -> a, 2 -> b}|
- |+----------------+
- |""".stripMargin)
- val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b")
- assert(df3.showString(10) ===
- s"""+------+---+
- || a| b|
- |+------+---+
- ||{1, a}| 0|
- ||{2, b}| 0|
- |+------+---+
- |""".stripMargin)
- }
-
- test("SPARK-7327 show with empty dataFrame") {
- val expectedAnswer = """+---+-----+
- ||key|value|
- |+---+-----+
- |+---+-----+
- |""".stripMargin
- assert(testData.select($"*").filter($"key" < 0).showString(1) ===
expectedAnswer)
- }
-
- test("SPARK-7327 show with empty dataFrame, vertical = true") {
- assert(testData.select($"*").filter($"key" < 0).showString(1, vertical =
true) === "(0 rows)\n")
- }
-
- test("SPARK-18350 show with session local timezone") {
- val d = Date.valueOf("2016-12-01")
- val ts = Timestamp.valueOf("2016-12-01 00:00:00")
- val df = Seq((d, ts)).toDF("d", "ts")
- val expectedAnswer = """+----------+-------------------+
- ||d |ts |
- |+----------+-------------------+
- ||2016-12-01|2016-12-01 00:00:00|
- |+----------+-------------------+
- |""".stripMargin
- assert(df.showString(1, truncate = 0) === expectedAnswer)
-
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
-
- val expectedAnswer = """+----------+-------------------+
- ||d |ts |
- |+----------+-------------------+
- ||2016-12-01|2016-12-01 08:00:00|
- |+----------+-------------------+
- |""".stripMargin
- assert(df.showString(1, truncate = 0) === expectedAnswer)
- }
- }
-
- test("SPARK-18350 show with session local timezone, vertical = true") {
- val d = Date.valueOf("2016-12-01")
- val ts = Timestamp.valueOf("2016-12-01 00:00:00")
- val df = Seq((d, ts)).toDF("d", "ts")
- val expectedAnswer = "-RECORD 0------------------\n" +
- " d | 2016-12-01 \n" +
- " ts | 2016-12-01 00:00:00 \n"
- assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
-
- withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
-
- val expectedAnswer = "-RECORD 0------------------\n" +
- " d | 2016-12-01 \n" +
- " ts | 2016-12-01 08:00:00 \n"
- assert(df.showString(1, truncate = 0, vertical = true) ===
expectedAnswer)
- }
- }
-
test("createDataFrame(RDD[Row], StructType) should convert UDTs
(SPARK-6672)") {
val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(),
false)))
@@ -1714,16 +1114,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(decimalData.agg(avg("a")), Row(new java.math.BigDecimal(2)))
}
- test("SPARK-7133: Implement struct, array, and map field accessor") {
- assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
- if (!conf.ansiEnabled) {
- assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
- }
- assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
- assert(complexData.filter(complexData("m")(complexData("s")("value")) ===
1).count() == 1)
- assert(complexData.filter(complexData("a")(complexData("s")("key")) ===
1).count() == 1)
- }
-
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f":
1}}}}""").toDS())
@@ -1900,15 +1290,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("SPARK-8608: call `show` on local DataFrame with random columns should
return same value") {
- val df = testData.select(rand(33))
- assert(df.showString(5) == df.showString(5))
-
- // We will reuse the same Expression object for LocalRelation.
- val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33))
- assert(df1.showString(5) == df1.showString(5))
- }
-
test("SPARK-8609: local DataFrame with random columns should return same
value after sort") {
checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
@@ -2148,82 +1529,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("SPARK-39834: build the stats for LogicalRDD based on origin stats") {
- def buildExpectedColumnStats(attrs: Seq[Attribute]):
AttributeMap[ColumnStat] = {
- AttributeMap(
- attrs.map {
- case attr if attr.dataType == BooleanType =>
- attr -> ColumnStat(
- distinctCount = Some(2),
- min = Some(false),
- max = Some(true),
- nullCount = Some(0),
- avgLen = Some(1),
- maxLen = Some(1))
-
- case attr if attr.dataType == ByteType =>
- attr -> ColumnStat(
- distinctCount = Some(2),
- min = Some(1),
- max = Some(2),
- nullCount = Some(0),
- avgLen = Some(1),
- maxLen = Some(1))
-
- case attr => attr -> ColumnStat()
- }
- )
- }
-
- val outputList = Seq(
- AttributeReference("cbool", BooleanType)(),
- AttributeReference("cbyte", ByteType)(),
- AttributeReference("cint", IntegerType)()
- )
-
- val expectedSize = 16
- val statsPlan = OutputListAwareStatsTestPlan(
- outputList = outputList,
- rowCount = 2,
- size = Some(expectedSize))
-
- withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
- val df = Dataset.ofRows(spark, statsPlan)
- // add some map-like operations which optimizer will optimize away,
and make a divergence
- // for output between logical plan and optimized plan
- // logical plan
- // Project [cb#6 AS cbool#12, cby#7 AS cbyte#13, ci#8 AS cint#14]
- // +- Project [cbool#0 AS cb#6, cbyte#1 AS cby#7, cint#2 AS ci#8]
- // +- OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
- // optimized plan
- // OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
- .selectExpr("cbool AS cb", "cbyte AS cby", "cint AS ci")
- .selectExpr("cb AS cbool", "cby AS cbyte", "ci AS cint")
-
- // We can't leverage LogicalRDD.fromDataset here, since it triggers
physical planning and
- // there is no matching physical node for OutputListAwareStatsTestPlan.
- val optimizedPlan = df.queryExecution.optimizedPlan
- val rewrite = LogicalRDD.buildOutputAssocForRewrite(optimizedPlan.output,
- df.logicalPlan.output)
- val logicalRDD = LogicalRDD(
- df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow],
isStreaming = true)(
- spark, Some(LogicalRDD.rewriteStatistics(optimizedPlan.stats,
rewrite.get)), None)
-
- val stats = logicalRDD.computeStats()
- val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount =
Some(2),
- attributeStats = buildExpectedColumnStats(logicalRDD.output))
- assert(stats === expectedStats)
-
- // This method re-issues expression IDs for all outputs. We expect
column stats to be
- // reflected as well.
- val newLogicalRDD = logicalRDD.newInstance()
- val newStats = newLogicalRDD.computeStats()
- val newExpectedStats = Statistics(sizeInBytes = expectedSize, rowCount =
Some(2),
- attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
- assert(newStats === newExpectedStats)
- }
- }
-
test("SPARK-39834: build the constraints for LogicalRDD based on origin
constraints") {
def buildExpectedConstraints(attrs: Seq[Attribute]): ExpressionSet = {
val exprs = attrs.flatMap { attr =>
@@ -2296,64 +1601,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
}
- test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
- val df = sparkContext.parallelize(Seq(
- java.lang.Integer.valueOf(22) -> "John",
- null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name")
-
- // passing null into the UDF that could handle it
- val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
- (i: java.lang.Integer) => if (i == null) -10 else null
- }
- checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
-
- spark.udf.register("boxedUDF",
- (i: java.lang.Integer) => (if (i == null) -10 else null):
java.lang.Integer)
- checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) ::
Nil)
-
- val primitiveUDF = udf((i: Int) => i * 2)
- checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
- }
-
- test("SPARK-12398 truncated toString") {
- val df1 = Seq((1L, "row1")).toDF("id", "name")
- assert(df1.toString() === "[id: bigint, name: string]")
-
- val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3")
- assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]")
-
- val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4")
- assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]")
-
- val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2")
- assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]")
-
- val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3")
- assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>
... 1 more field]")
-
- val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3",
"c4")
- assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>
... 2 more fields]")
-
- val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3",
"c4")
- assert(
- df7.toString ===
- "[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ...
2 more fields]")
-
- val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1",
"c2", "c3", "c4")
- assert(
- df8.toString ===
- "[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ...
2 more fields]")
-
- val df9 =
- Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0,
1)).toDF("c1", "c2", "c3", "c4")
- assert(
- df9.toString ===
- "[c1: bigint, c2: struct<_1: bigint," +
- " _2: struct<_1: bigint," +
- " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more
fields]")
-
- }
-
test("reuse exchange") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
val df = spark.range(100).toDF()
@@ -2431,66 +1678,11 @@ class DataFrameSuite extends QueryTest
assert(e.getStackTrace.head.getClassName !=
classOf[QueryExecution].getName)
}
- test("SPARK-13774: Check error message for non existent path without globbed
paths") {
- val uuid = UUID.randomUUID().toString
- val baseDir = Utils.createTempDir()
- checkError(
- exception = intercept[AnalysisException] {
- spark.read.format("csv").load(
- new File(baseDir, "file").getAbsolutePath,
- new File(baseDir, "file2").getAbsolutePath,
- new File(uuid, "file3").getAbsolutePath,
- uuid).rdd
- },
- errorClass = "PATH_NOT_FOUND",
- parameters = Map("path" -> "file:.*"),
- matchPVals = true
- )
- }
-
- test("SPARK-13774: Check error message for not existent globbed paths") {
- // Non-existent initial path component:
- val nonExistentBasePath = "/" + UUID.randomUUID().toString
- assert(!new File(nonExistentBasePath).exists())
- checkError(
- exception = intercept[AnalysisException] {
- spark.read.format("text").load(s"$nonExistentBasePath/*")
- },
- errorClass = "PATH_NOT_FOUND",
- parameters = Map("path" -> s"file:$nonExistentBasePath/*")
- )
-
- // Existent initial path component, but no matching files:
- val baseDir = Utils.createTempDir()
- val childDir = Utils.createTempDir(baseDir.getAbsolutePath)
- assert(childDir.exists())
- try {
- checkError(
- exception = intercept[AnalysisException] {
- spark.read.json(s"${baseDir.getAbsolutePath}/*/*-xyz.json").rdd
- },
- errorClass = "PATH_NOT_FOUND",
- parameters = Map("path" ->
s"file:${baseDir.getAbsolutePath}/*/*-xyz.json")
- )
- } finally {
- Utils.deleteRecursively(baseDir)
- }
- }
-
test("SPARK-15230: distinct() does not handle column name with dot
properly") {
val df = Seq(1, 1, 2).toDF("column.with.dot")
checkAnswer(df.distinct(), Row(1) :: Row(2) :: Nil)
}
- test("SPARK-16181: outer join with isNull filter") {
- val left = Seq("x").toDF("col")
- val right = Seq("y").toDF("col").withColumn("new", lit(true))
- val joined = left.join(right, left("col") === right("col"), "left_outer")
-
- checkAnswer(joined, Row("x", null, null))
- checkAnswer(joined.filter($"new".isNull), Row("x", null, null))
- }
-
test("SPARK-16664: persist with more than 200 columns") {
val size = 201L
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size))))
@@ -2611,21 +1803,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
}
- test("SPARK-19691 Calculating percentile of decimal column fails with
ClassCastException") {
- val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as
x").selectExpr("percentile(x, 0.5)")
- checkAnswer(df, Row(BigDecimal(0)) :: Nil)
- }
-
- test("SPARK-20359: catalyst outer join optimization should not throw npe") {
- val df1 = Seq("a", "b", "c").toDF("x")
- .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!"
}.apply($"x"))
- val df2 = Seq("a", "b").toDF("x1")
- df1
- .join(df2, df1("x") === df2("x1"), "left_outer")
- .filter($"x1".isNotNull || !$"y".isin("a!"))
- .count()
- }
-
// The fix of SPARK-21720 avoid an exception regarding JVM code size limit
// TODO: When we make a threshold of splitting statements (1024)
configurable,
// we will re-enable this with max threshold to cause an exception
@@ -2651,16 +1828,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("SPARK-20897: cached self-join should not fail") {
- // force to plan sort merge join
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
- val df = Seq(1 -> "a").toDF("i", "j")
- val df1 = df.as("t1")
- val df2 = df.as("t2")
- assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
- }
- }
-
test("order-by ordinal.") {
checkAnswer(
testData2.select(lit(7), $"a", $"b").orderBy(lit(1), lit(2), lit(3)),
@@ -2689,69 +1856,11 @@ class DataFrameSuite extends QueryTest
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
}
- test("SPARK-24165: CaseWhen/If - nullability of nested types") {
- val rows = new java.util.ArrayList[Row]()
- rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x")))
- rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null)))
- val schema = StructType(Seq(
- StructField("cond", BooleanType, true),
- StructField("s", StructType(Seq(
- StructField("val1", StringType, true),
- StructField("val2", IntegerType, false)
- )), false),
- StructField("a", ArrayType(StringType, true)),
- StructField("m", MapType(IntegerType, StringType, true))
- ))
-
- val sourceDF = spark.createDataFrame(rows, schema)
-
- def structWhenDF: DataFrame = sourceDF
- .select(when($"cond",
- struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as
"res")
- .select($"res".getField("val1"))
- def arrayWhenDF: DataFrame = sourceDF
- .select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as
"res")
- .select($"res".getItem(0))
- def mapWhenDF: DataFrame = sourceDF
- .select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res")
- .select($"res".getItem(0))
-
- def structIfDF: DataFrame = sourceDF
- .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res")
- .select($"res".getField("val1"))
- def arrayIfDF: DataFrame = sourceDF
- .select(expr("if(cond, array('a', 'b'), a)") as "res")
- .select($"res".getItem(0))
- def mapIfDF: DataFrame = sourceDF
- .select(expr("if(cond, map(0, 'a'), m)") as "res")
- .select($"res".getItem(0))
-
- def checkResult(): Unit = {
- checkAnswer(structWhenDF, Seq(Row("a"), Row(null)))
- checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null)))
- checkAnswer(mapWhenDF, Seq(Row("a"), Row(null)))
- checkAnswer(structIfDF, Seq(Row("a"), Row(null)))
- checkAnswer(arrayIfDF, Seq(Row("a"), Row(null)))
- checkAnswer(mapIfDF, Seq(Row("a"), Row(null)))
- }
-
- // Test with local relation, the Project will be evaluated without codegen
- checkResult()
- // Test with cached relation, the Project will be evaluated with codegen
- sourceDF.cache()
- checkResult()
- }
-
test("Uuid expressions should produce same results at retries in the same
DataFrame") {
val df = spark.range(1).select($"id", new Column(Uuid()))
checkAnswer(df, df.collect())
}
- test("SPARK-24313: access map with binary keys") {
- val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
-
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))),
Row(1))
- }
-
test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
val filter1 = df.select(df("name")).filter(df("id") === 0)
@@ -2779,31 +1888,6 @@ class DataFrameSuite extends QueryTest
}
}
- test("SPARK-25159: json schema inference should only trigger one job") {
- withTempPath { path =>
- // This test is to prove that the `JsonInferSchema` does not use
`RDD#toLocalIterator` which
- // triggers one Spark job per RDD partition.
- Seq(1 -> "a", 2 -> "b").toDF("i", "p")
- // The data set has 2 partitions, so Spark will write at least 2 json
files.
- // Use a non-splittable compression (gzip), to make sure the json scan
RDD has at least 2
- // partitions.
- .write.partitionBy("p")
- .option("compression",
GZIP.lowerCaseName()).json(path.getCanonicalPath)
-
- val numJobs = new AtomicLong(0)
- sparkContext.addSparkListener(new SparkListener {
- override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
- numJobs.incrementAndGet()
- }
- })
-
- val df = spark.read.json(path.getCanonicalPath)
- assert(df.columns === Array("i", "p"))
- spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(numJobs.get() == 1L)
- }
- }
-
test("SPARK-25402 Null handling in BooleanSimplification") {
val schema = StructType.fromDDL("a boolean, b int")
val rows = Seq(Row(null, 1))
@@ -3038,80 +2122,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3)))
}
- test("SPARK-31552: array encoder with different types") {
- // primitives
- val booleans = Array(true, false)
- checkAnswer(Seq(booleans).toDF(), Row(booleans))
-
- val bytes = Array(1.toByte, 2.toByte)
- checkAnswer(Seq(bytes).toDF(), Row(bytes))
- val shorts = Array(1.toShort, 2.toShort)
- checkAnswer(Seq(shorts).toDF(), Row(shorts))
- val ints = Array(1, 2)
- checkAnswer(Seq(ints).toDF(), Row(ints))
- val longs = Array(1L, 2L)
- checkAnswer(Seq(longs).toDF(), Row(longs))
-
- val floats = Array(1.0F, 2.0F)
- checkAnswer(Seq(floats).toDF(), Row(floats))
- val doubles = Array(1.0D, 2.0D)
- checkAnswer(Seq(doubles).toDF(), Row(doubles))
-
- val strings = Array("2020-04-24", "2020-04-25")
- checkAnswer(Seq(strings).toDF(), Row(strings))
-
- // tuples
- val decOne = Decimal(1, 38, 18)
- val decTwo = Decimal(2, 38, 18)
- val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
- val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
- checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1,
tuple2)).toDF())
-
- // case classes
- val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
- checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
-
- // We can move this implicit def to [[SQLImplicits]] when we eventually
make fully
- // support for array encoder like Seq and Set
- // For now cases below, decimal/datetime/interval/binary/nested types, etc,
- // are not supported by array
- implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] =
ExpressionEncoder()
-
- // decimals
- val decSpark = Array(decOne, decTwo)
- val decScala = decSpark.map(_.toBigDecimal)
- val decJava = decSpark.map(_.toJavaBigDecimal)
- checkAnswer(Seq(decSpark).toDF(), Row(decJava))
- checkAnswer(Seq(decScala).toDF(), Row(decJava))
- checkAnswer(Seq(decJava).toDF(), Row(decJava))
-
- // datetimes and intervals
- val dates = strings.map(Date.valueOf)
- checkAnswer(Seq(dates).toDF(), Row(dates))
- val localDates = dates.map(d =>
DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d)))
- checkAnswer(Seq(localDates).toDF(), Row(dates))
-
- val timestamps =
- Array(Timestamp.valueOf("2020-04-24 12:34:56"),
Timestamp.valueOf("2020-04-24 11:22:33"))
- checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
- val instants =
- timestamps.map(t =>
DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t)))
- checkAnswer(Seq(instants).toDF(), Row(timestamps))
-
- val intervals = Array(new CalendarInterval(1, 2, 3), new
CalendarInterval(4, 5, 6))
- checkAnswer(Seq(intervals).toDF(), Row(intervals))
-
- // binary
- val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte),
Array(4.toByte))
- checkAnswer(Seq(bins).toDF(), Row(bins))
-
- // nested
- val nestedIntArray = Array(Array(1), Array(2))
- checkAnswer(Seq(nestedIntArray).toDF(),
Row(nestedIntArray.map(wrapIntArray)))
- val nestedDecArray = Array(decSpark)
- checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
- }
-
test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
withTempPath { f =>
sql("select cast(1 as decimal(38, 0)) as d")
@@ -3128,10 +2138,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(df.selectExpr("ln(d)"), Row(Double.NaN))
}
- test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
- checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1,
1))
- }
-
test("SPARK-32764: -0.0 and 0.0 should be equal") {
val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
checkAnswer(df.select($"pos" > $"neg"), Row(false))
@@ -3232,79 +2238,6 @@ class DataFrameSuite extends QueryTest
checkAnswer(test.select($"best_name.name"), Row("bob") :: Row("bob") ::
Row("sam") :: Nil)
}
- test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order
functions work") {
- val reverse = udf((s: String) => s.reverse)
- val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))
-
- val df = Seq(Array("abc", "def")).toDF("array")
- val test = df.select(transform(col("array"), s => reverse(s)))
- checkAnswer(test, Row(Array("cba", "fed")) :: Nil)
-
- val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
- val test2 = df2.select(transform(col("array"), b => reverse2(b)))
- checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)
-
- val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
- val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
- checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)
-
- val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
- val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
- checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)
-
- val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
- val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
- checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)
-
- val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
- val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
- checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)
-
- val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++
s2.reverse)
- val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++
b2.s.reverse))
-
- val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 ->
"jkl"))).toDF("map1", "map2")
- val test7 =
- df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) =>
reverseThenConcat(s1, s2)))
- checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)
-
- val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
- Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
- val test8 =
- df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) =>
reverseThenConcat2(b1, b2)))
- checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)
-
- val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1",
"array2")
- val test9 =
- df9.select(zip_with(col("array1"), col("array2"), (s1, s2) =>
reverseThenConcat(s1, s2)))
- checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)
-
- val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"),
Bar2("jkl"))))
- .toDF("array1", "array2")
- val test10 =
- df10.select(zip_with(col("array1"), col("array2"), (b1, b2) =>
reverseThenConcat2(b1, b2)))
- checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
- }
-
- test("SPARK-39293: The accumulator of ArrayAggregate to handle complex types
properly") {
- val reverse = udf((s: String) => s.reverse)
-
- val df = Seq(Array("abc", "def")).toDF("array")
- val testArray = df.select(
- aggregate(
- col("array"),
- array().cast("array<string>"),
- (acc, s) => concat(acc, array(reverse(s)))))
- checkAnswer(testArray, Row(Array("cba", "fed")) :: Nil)
-
- val testMap = df.select(
- aggregate(
- col("array"),
- map().cast("map<string, string>"),
- (acc, s) => map_concat(acc, map(s, reverse(s)))))
- checkAnswer(testMap, Row(Map("abc" -> "cba", "def" -> "fed")) :: Nil)
- }
-
test("SPARK-34882: Aggregate with multiple distinct null sensitive
aggregators") {
withUserDefinedFunction(("countNulls", true)) {
spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong,
JLong] {
@@ -3405,87 +2338,6 @@ class DataFrameSuite extends QueryTest
assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
}
- test("SPARK-35320: Reading JSON with key type different to String in a map
should fail") {
- Seq(
- (MapType(IntegerType, StringType), """{"1": "test"}"""),
- (StructType(Seq(StructField("test", MapType(IntegerType, StringType)))),
- """"test": {"1": "test"}"""),
- (ArrayType(MapType(IntegerType, StringType)), """[{"1": "test"}]"""),
- (MapType(StringType, MapType(IntegerType, StringType)), """{"key": {"1"
: "test"}}""")
- ).foreach { case (schema, jsonData) =>
- withTempDir { dir =>
- val colName = "col"
- val msg = "can only contain STRING as a key type for a MAP"
-
- val thrown1 = intercept[AnalysisException](
- spark.read.schema(StructType(Seq(StructField(colName, schema))))
- .json(Seq(jsonData).toDS()).collect())
- assert(thrown1.getMessage.contains(msg))
-
- val jsonDir = new File(dir, "json").getCanonicalPath
- Seq(jsonData).toDF(colName).write.json(jsonDir)
- val thrown2 = intercept[AnalysisException](
- spark.read.schema(StructType(Seq(StructField(colName, schema))))
- .json(jsonDir).collect())
- assert(thrown2.getMessage.contains(msg))
- }
- }
- }
-
- test("SPARK-37855: IllegalStateException when transforming an array inside a
nested struct") {
- def makeInput(): DataFrame = {
- val innerElement1 = Row(3, 3.12)
- val innerElement2 = Row(4, 2.1)
- val innerElement3 = Row(1, 985.2)
- val innerElement4 = Row(10, 757548.0)
- val innerElement5 = Row(1223, 0.665)
-
- val outerElement1 = Row(1, Row(List(innerElement1, innerElement2)))
- val outerElement2 = Row(2, Row(List(innerElement3)))
- val outerElement3 = Row(3, Row(List(innerElement4, innerElement5)))
-
- val data = Seq(
- Row("row1", List(outerElement1)),
- Row("row2", List(outerElement2, outerElement3))
- )
-
- val schema = new StructType()
- .add("name", StringType)
- .add("outer_array", ArrayType(new StructType()
- .add("id", IntegerType)
- .add("inner_array_struct", new StructType()
- .add("inner_array", ArrayType(new StructType()
- .add("id", IntegerType)
- .add("value", DoubleType)
- ))
- )
- ))
-
- spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
- }
-
- val df = makeInput().limit(2)
-
- val res = df.withColumn("extracted", transform(
- col("outer_array"),
- c1 => {
- struct(
- c1.getField("id").alias("outer_id"),
- transform(
- c1.getField("inner_array_struct").getField("inner_array"),
- c2 => {
- struct(
- c2.getField("value").alias("inner_value")
- )
- }
- )
- )
- }
- ))
-
- assert(res.collect().length == 2)
- }
-
test("SPARK-38285: Fix ClassCastException: GenericArrayData cannot be cast
to InternalRow") {
withTempView("v1") {
val sqlText =
@@ -3782,50 +2634,6 @@ case class GroupByKey(a: Int, b: Int)
case class Bar2(s: String)
-/**
- * This class is used for unit-testing. It's a logical plan whose output and
stats are passed in.
- */
-case class OutputListAwareStatsTestPlan(
- outputList: Seq[Attribute],
- rowCount: BigInt,
- size: Option[BigInt] = None) extends LeafNode with MultiInstanceRelation {
- override def output: Seq[Attribute] = outputList
- override def computeStats(): Statistics = {
- val columnInfo = outputList.map { attr =>
- attr.dataType match {
- case BooleanType =>
- attr -> ColumnStat(
- distinctCount = Some(2),
- min = Some(false),
- max = Some(true),
- nullCount = Some(0),
- avgLen = Some(1),
- maxLen = Some(1))
-
- case ByteType =>
- attr -> ColumnStat(
- distinctCount = Some(2),
- min = Some(1),
- max = Some(2),
- nullCount = Some(0),
- avgLen = Some(1),
- maxLen = Some(1))
-
- case _ =>
- attr -> ColumnStat()
- }
- }
- val attrStats = AttributeMap(columnInfo)
-
- Statistics(
- // If sizeInBytes is useless in testing, we just use a fake value
- sizeInBytes = size.getOrElse(Int.MaxValue),
- rowCount = Some(rowCount),
- attributeStats = attrStats)
- }
- override def newInstance(): LogicalPlan = copy(outputList =
outputList.map(_.newInstance()))
-}
-
/**
* This class is used for unit-testing. It's a logical plan whose output is
passed in.
*/
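
(A side note for readers tracing the relocated show/showString tests above:
`showString` is package-private, so outside the test suites the same vertical
rendering is reached through the public `show(numRows, truncate, vertical)`
overload. The sketch below is a minimal, self-contained illustration of that
call; the object name, the local-mode SparkSession and the sample data are
assumptions for the example, not part of this patch.)

  import org.apache.spark.sql.SparkSession

  object VerticalShowSketch {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("vertical-show-sketch")
        .getOrCreate()
      import spark.implicits._

      // Same shape as the "minimum column width" test data: columns _1, _2.
      val df = Seq((1, 1), (2, 2)).toDF()

      // Prints one "-RECORD n--" header per row, then "column | value" lines,
      // which is the layout the moved tests assert on via showString.
      df.show(numRows = 10, truncate = 0, vertical = true)

      spark.stop()
    }
  }
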
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index e827396009d0..393ecc95b66b 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -24,12 +24,15 @@ import java.util.concurrent.TimeUnit
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap,
AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils}
import
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, PST,
UTC}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId,
TimeZoneUTC}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.functions.timestamp_seconds
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -860,4 +863,124 @@ class StatisticsCollectionSuite extends
StatisticsCollectionTestBase with Shared
assert(stats.rowCount.isDefined && stats.rowCount.get == 6)
}
}
+
+ test("SPARK-39834: build the stats for LogicalRDD based on origin stats") {
+ def buildExpectedColumnStats(attrs: Seq[Attribute]):
AttributeMap[ColumnStat] = {
+ AttributeMap(
+ attrs.map {
+ case attr if attr.dataType == BooleanType =>
+ attr -> ColumnStat(
+ distinctCount = Some(2),
+ min = Some(false),
+ max = Some(true),
+ nullCount = Some(0),
+ avgLen = Some(1),
+ maxLen = Some(1))
+
+ case attr if attr.dataType == ByteType =>
+ attr -> ColumnStat(
+ distinctCount = Some(2),
+ min = Some(1),
+ max = Some(2),
+ nullCount = Some(0),
+ avgLen = Some(1),
+ maxLen = Some(1))
+
+ case attr => attr -> ColumnStat()
+ }
+ )
+ }
+
+ val outputList = Seq(
+ AttributeReference("cbool", BooleanType)(),
+ AttributeReference("cbyte", ByteType)(),
+ AttributeReference("cint", IntegerType)()
+ )
+
+ val expectedSize = 16
+ val statsPlan = OutputListAwareStatsTestPlan(
+ outputList = outputList,
+ rowCount = 2,
+ size = Some(expectedSize))
+
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+ val df = Dataset.ofRows(spark, statsPlan)
+ // add some map-like operations which optimizer will optimize away,
and make a divergence
+ // for output between logical plan and optimized plan
+ // logical plan
+ // Project [cb#6 AS cbool#12, cby#7 AS cbyte#13, ci#8 AS cint#14]
+ // +- Project [cbool#0 AS cb#6, cbyte#1 AS cby#7, cint#2 AS ci#8]
+ // +- OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
+ // optimized plan
+ // OutputListAwareStatsTestPlan [cbool#0, cbyte#1, cint#2], 2, 16
+ .selectExpr("cbool AS cb", "cbyte AS cby", "cint AS ci")
+ .selectExpr("cb AS cbool", "cby AS cbyte", "ci AS cint")
+
+ // We can't leverage LogicalRDD.fromDataset here, since it triggers
physical planning and
+ // there is no matching physical node for OutputListAwareStatsTestPlan.
+ val optimizedPlan = df.queryExecution.optimizedPlan
+ val rewrite = LogicalRDD.buildOutputAssocForRewrite(optimizedPlan.output,
+ df.logicalPlan.output)
+ val logicalRDD = LogicalRDD(
+ df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow],
isStreaming = true)(
+ spark, Some(LogicalRDD.rewriteStatistics(optimizedPlan.stats,
rewrite.get)), None)
+
+ val stats = logicalRDD.computeStats()
+ val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount =
Some(2),
+ attributeStats = buildExpectedColumnStats(logicalRDD.output))
+ assert(stats === expectedStats)
+
+ // This method re-issues expression IDs for all outputs. We expect
column stats to be
+ // reflected as well.
+ val newLogicalRDD = logicalRDD.newInstance()
+ val newStats = newLogicalRDD.computeStats()
+ val newExpectedStats = Statistics(sizeInBytes = expectedSize, rowCount =
Some(2),
+ attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
+ assert(newStats === newExpectedStats)
+ }
+ }
+}
+
+/**
+ * This class is used for unit-testing. It's a logical plan whose output and
stats are passed in.
+ */
+case class OutputListAwareStatsTestPlan(
+ outputList: Seq[Attribute],
+ rowCount: BigInt,
+ size: Option[BigInt] = None) extends LeafNode with MultiInstanceRelation {
+ override def output: Seq[Attribute] = outputList
+ override def computeStats(): Statistics = {
+ val columnInfo = outputList.map { attr =>
+ attr.dataType match {
+ case BooleanType =>
+ attr -> ColumnStat(
+ distinctCount = Some(2),
+ min = Some(false),
+ max = Some(true),
+ nullCount = Some(0),
+ avgLen = Some(1),
+ maxLen = Some(1))
+
+ case ByteType =>
+ attr -> ColumnStat(
+ distinctCount = Some(2),
+ min = Some(1),
+ max = Some(2),
+ nullCount = Some(0),
+ avgLen = Some(1),
+ maxLen = Some(1))
+
+ case _ =>
+ attr -> ColumnStat()
+ }
+ }
+ val attrStats = AttributeMap(columnInfo)
+
+ Statistics(
+ // If sizeInBytes is useless in testing, we just use a fake value
+ sizeInBytes = size.getOrElse(Int.MaxValue),
+ rowCount = Some(rowCount),
+ attributeStats = attrStats)
+ }
+ override def newInstance(): LogicalPlan = copy(outputList =
outputList.map(_.newInstance()))
}
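
(The SPARK-39834 test relocated above checks that per-column `ColumnStat`s
survive the rewrite into `LogicalRDD`. Outside the test harness, roughly the
same numbers can be inspected on the optimized plan of an analyzed table.
The sketch below is an assumption-laden illustration, not part of the patch:
the local SparkSession, the `stats_demo` table name and the sample data are
made up, and it relies on `queryExecution.optimizedPlan.stats`, which is an
internal Catalyst surface rather than a stable public API.)

  import org.apache.spark.sql.SparkSession

  object ColumnStatsSketch {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[*]")
        .config("spark.sql.cbo.enabled", "true") // cost-based optimizer on
        .getOrCreate()
      import spark.implicits._

      Seq((true, 1.toByte), (false, 2.toByte)).toDF("cbool", "cbyte")
        .write.mode("overwrite").saveAsTable("stats_demo")

      // Populate table-level and column-level statistics in the catalog.
      spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS FOR ALL COLUMNS")

      // With CBO enabled, the optimized plan carries row count and ColumnStats.
      val stats = spark.table("stats_demo").queryExecution.optimizedPlan.stats
      println(s"rowCount = ${stats.rowCount}")
      stats.attributeStats.foreach { case (attr, colStat) =>
        println(s"${attr.name}: distinct=${colStat.distinctCount}, " +
          s"nulls=${colStat.nullCount}")
      }

      spark.stop()
    }
  }
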
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 734a4db89627..87ca3a07c4d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -40,7 +40,7 @@ import
org.apache.spark.sql.execution.columnar.InMemoryRelation
import
org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand,
ExplainCommand}
import
org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer,
SparkUserDefinedFunction, UserDefinedAggregateFunction}
-import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData._
@@ -56,6 +56,24 @@ private case class TimestampInstantType(t: Timestamp,
instant: Instant)
class UDFSuite extends QueryTest with SharedSparkSession {
import testImplicits._
+ test("udf") {
+ val foo = udf((a: Int, b: String) => a.toString + b)
+
+ checkAnswer(
+ // SELECT *, foo(key, value) FROM testData
+ testData.select($"*", foo($"key", $"value")).limit(3),
+ Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
+ )
+ }
+
+ test("callUDF without Hive Support") {
+ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+ df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
+ checkAnswer(
+ df.select($"id", callUDF("simpleUDF", $"value")), // test deprecated one
+ Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
+ }
+
test("built-in fixed arity expressions") {
val df = spark.emptyDataFrame
df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
@@ -1092,4 +1110,77 @@ class UDFSuite extends QueryTest with SharedSparkSession
{
.lookupFunctionInfo(FunctionIdentifier("dummyUDF"))
assert(expressionInfo.getClassName.contains("org.apache.spark.sql.UDFRegistration$$Lambda"))
}
+
+ test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+ val df = sparkContext.parallelize(Seq(
+ java.lang.Integer.valueOf(22) -> "John",
+ null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name")
+
+ // passing null into the UDF that could handle it
+ val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
+ (i: java.lang.Integer) => if (i == null) -10 else null
+ }
+ checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
+
+ spark.udf.register("boxedUDF",
+ (i: java.lang.Integer) => (if (i == null) -10 else null):
java.lang.Integer)
+ checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) ::
Nil)
+
+ val primitiveUDF = udf((i: Int) => i * 2)
+ checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
+ }
+
+ test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order
functions work") {
+ val reverse = udf((s: String) => s.reverse)
+ val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))
+
+ val df = Seq(Array("abc", "def")).toDF("array")
+ val test = df.select(transform(col("array"), s => reverse(s)))
+ checkAnswer(test, Row(Array("cba", "fed")) :: Nil)
+
+ val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
+ val test2 = df2.select(transform(col("array"), b => reverse2(b)))
+ checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)
+
+ val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
+ val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
+ checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)
+
+ val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
+ val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
+ checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)
+
+ val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
+ val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
+ checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)
+
+ val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
+ val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
+ checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)
+
+ val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++
s2.reverse)
+ val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++
b2.s.reverse))
+
+ val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 ->
"jkl"))).toDF("map1", "map2")
+ val test7 =
+ df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) =>
reverseThenConcat(s1, s2)))
+ checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)
+
+ val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
+ Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
+ val test8 =
+ df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) =>
reverseThenConcat2(b1, b2)))
+ checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)
+
+ val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1",
"array2")
+ val test9 =
+ df9.select(zip_with(col("array1"), col("array2"), (s1, s2) =>
reverseThenConcat(s1, s2)))
+ checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)
+
+ val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"),
Bar2("jkl"))))
+ .toDF("array1", "array2")
+ val test10 =
+ df10.select(zip_with(col("array1"), col("array2"), (b1, b2) =>
reverseThenConcat2(b1, b2)))
+ checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala
index 90b341ae1f2c..7f886940473d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.datasources
+import java.io.File
import java.net.URI
+import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}
@@ -25,6 +27,7 @@ import org.scalatest.PrivateMethodTester
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
class DataSourceSuite extends SharedSparkSession with PrivateMethodTester {
import TestPaths._
@@ -158,6 +161,52 @@ class DataSourceSuite extends SharedSparkSession with
PrivateMethodTester {
val expectMessage = "No FileSystem for scheme nonexistentFs"
assert(message.filterNot(Set(':', '"').contains) == expectMessage)
}
+
+ test("SPARK-13774: Check error message for non existent path without globbed
paths") {
+ val uuid = UUID.randomUUID().toString
+ val baseDir = Utils.createTempDir()
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.read.format("csv").load(
+ new File(baseDir, "file").getAbsolutePath,
+ new File(baseDir, "file2").getAbsolutePath,
+ new File(uuid, "file3").getAbsolutePath,
+ uuid).rdd
+ },
+ errorClass = "PATH_NOT_FOUND",
+ parameters = Map("path" -> "file:.*"),
+ matchPVals = true
+ )
+ }
+
+ test("SPARK-13774: Check error message for not existent globbed paths") {
+ // Non-existent initial path component:
+ val nonExistentBasePath = "/" + UUID.randomUUID().toString
+ assert(!new File(nonExistentBasePath).exists())
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.read.format("text").load(s"$nonExistentBasePath/*")
+ },
+ errorClass = "PATH_NOT_FOUND",
+ parameters = Map("path" -> s"file:$nonExistentBasePath/*")
+ )
+
+ // Existent initial path component, but no matching files:
+ val baseDir = Utils.createTempDir()
+ val childDir = Utils.createTempDir(baseDir.getAbsolutePath)
+ assert(childDir.exists())
+ try {
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.read.json(s"${baseDir.getAbsolutePath}/*/*-xyz.json").rdd
+ },
+ errorClass = "PATH_NOT_FOUND",
+ parameters = Map("path" ->
s"file:${baseDir.getAbsolutePath}/*/*-xyz.json")
+ )
+ } finally {
+ Utils.deleteRecursively(baseDir)
+ }
+ }
}
object TestPaths {
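
(The two SPARK-13774 tests added to DataSourceSuite above assert the
PATH_NOT_FOUND error class through the `checkError` helper. From plain user
code the same condition shows up as an `AnalysisException` raised eagerly at
read time, before any job runs. A small sketch follows; the object name, the
local SparkSession and the deliberately bogus path are assumptions for the
example only.)

  import org.apache.spark.sql.{AnalysisException, SparkSession}

  object MissingPathSketch {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[*]").getOrCreate()

      val bogus = s"/tmp/${java.util.UUID.randomUUID()}/does-not-exist.csv"
      try {
        // File sources validate paths during analysis, so this fails here.
        spark.read.format("csv").load(bogus)
      } catch {
        case e: AnalysisException =>
          // Reported under the PATH_NOT_FOUND error class checked above.
          println(e.getMessage)
      }

      spark.stop()
    }
  }
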
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 5db12c124f91..ad113f1bb896 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -23,6 +23,7 @@ import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneId}
import java.util.Locale
+import java.util.concurrent.atomic.AtomicLong
import com.fasterxml.jackson.core.JsonFactory
import org.apache.commons.lang3.exception.ExceptionUtils
@@ -32,9 +33,11 @@ import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.{SparkConf, SparkException,
SparkFileNotFoundException, SparkRuntimeException, SparkUpgradeException,
TestUtils}
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.{functions => F, _}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils,
HadoopCompressionCodec}
+import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.GZIP
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLType
import org.apache.spark.sql.execution.ExternalRDD
import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite,
DataSource, InMemoryFileIndex, NoopCache}
@@ -3703,6 +3706,58 @@ abstract class JsonSuite
assert(JSONOptions.getAlternativeOption("charset").contains("encoding"))
assert(JSONOptions.getAlternativeOption("dateFormat").isEmpty)
}
+
+ test("SPARK-25159: json schema inference should only trigger one job") {
+ withTempPath { path =>
+ // This test is to prove that the `JsonInferSchema` does not use
`RDD#toLocalIterator` which
+ // triggers one Spark job per RDD partition.
+ Seq(1 -> "a", 2 -> "b").toDF("i", "p")
+ // The data set has 2 partitions, so Spark will write at least 2 json
files.
+ // Use a non-splittable compression (gzip), to make sure the json scan
RDD has at least 2
+ // partitions.
+ .write.partitionBy("p")
+ .option("compression",
GZIP.lowerCaseName()).json(path.getCanonicalPath)
+
+ val numJobs = new AtomicLong(0)
+ sparkContext.addSparkListener(new SparkListener {
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+ numJobs.incrementAndGet()
+ }
+ })
+
+ val df = spark.read.json(path.getCanonicalPath)
+ assert(df.columns === Array("i", "p"))
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(numJobs.get() == 1L)
+ }
+ }
+
+ test("SPARK-35320: Reading JSON with key type different to String in a map
should fail") {
+ Seq(
+ (MapType(IntegerType, StringType), """{"1": "test"}"""),
+ (StructType(Seq(StructField("test", MapType(IntegerType, StringType)))),
+ """"test": {"1": "test"}"""),
+ (ArrayType(MapType(IntegerType, StringType)), """[{"1": "test"}]"""),
+ (MapType(StringType, MapType(IntegerType, StringType)), """{"key": {"1"
: "test"}}""")
+ ).foreach { case (schema, jsonData) =>
+ withTempDir { dir =>
+ val colName = "col"
+ val msg = "can only contain STRING as a key type for a MAP"
+
+ val thrown1 = intercept[AnalysisException](
+ spark.read.schema(StructType(Seq(StructField(colName, schema))))
+ .json(Seq(jsonData).toDS()).collect())
+ assert(thrown1.getMessage.contains(msg))
+
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ Seq(jsonData).toDF(colName).write.json(jsonDir)
+ val thrown2 = intercept[AnalysisException](
+ spark.read.schema(StructType(Seq(StructField(colName, schema))))
+ .json(jsonDir).collect())
+ assert(thrown2.getMessage.contains(msg))
+ }
+ }
+ }
}
class JsonV1Suite extends JsonSuite {
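
(The SPARK-35320 test moved into JsonSuite above relies on the JSON source
rejecting any user-supplied schema whose map keys are not STRING, since JSON
object keys are always strings. A minimal sketch of the user-visible behavior
follows; the object name, SparkSession and inline JSON are assumptions for
the example, not part of this patch.)

  import org.apache.spark.sql.{AnalysisException, SparkSession}
  import org.apache.spark.sql.types._

  object JsonMapKeySketch {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[*]").getOrCreate()
      import spark.implicits._

      // Map keys must be STRING for JSON; IntegerType keys are rejected.
      val schema = StructType(Seq(
        StructField("col", MapType(IntegerType, StringType))))
      try {
        spark.read.schema(schema)
          .json(Seq("""{"1": "test"}""").toDS())
          .collect()
      } catch {
        case e: AnalysisException =>
          // Message contains: "can only contain STRING as a key type for a MAP"
          println(e.getMessage)
      }

      spark.stop()
    }
  }
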
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]