This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 5687b31 [SPARK-30532] DataFrameStatFunctions to work with
TABLE.COLUMN syntax
5687b31 is described below
commit 5687b31be3fd1ff991a2e3d918a2dcd79ce1449c
Author: Oleksii Kachaiev <[email protected]>
AuthorDate: Mon Mar 30 13:20:57 2020 +0800
[SPARK-30532] DataFrameStatFunctions to work with TABLE.COLUMN syntax
### What changes were proposed in this pull request?
`DataFrameStatFunctions` now works correctly with fully qualified column
names (Table.Column syntax) by properly resolving the name instead of relying on
field names from the schema, notably:
* `approxQuantile`
* `freqItems`
* `cov`
* `corr`
(other functions from `DataFrameStatFunctions` already work correctly).
See code examples below.
### Why are the changes needed?
With the current implementation, some stat functions are impossible to use when
joining datasets with similar column names.
### Does this PR introduce any user-facing change?
Yes. Before the change, the following code would fail with
`AnalysisException`.
```scala
scala> val df1 = sc.parallelize(0 to 10).toDF("num").as("table1")
df1: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [num: int]
scala> val df2 = sc.parallelize(0 to 10).toDF("num").as("table2")
df2: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [num: int]
scala> val dfx = df2.crossJoin(df1)
dfx: org.apache.spark.sql.DataFrame = [num: int, num: int]
scala> dfx.stat.approxQuantile("table1.num", Array(0.1), 0.0)
res0: Array[Double] = Array(1.0)
scala> dfx.stat.corr("table1.num", "table2.num")
res1: Double = 1.0
scala> dfx.stat.cov("table1.num", "table2.num")
res2: Double = 11.0
scala> dfx.stat.freqItems(Array("table1.num", "table2.num"))
res3: org.apache.spark.sql.DataFrame = [table1.num_freqItems: array<int>,
table2.num_freqItems: array<int>]
```
### How was this patch tested?
Corresponding unit tests are added to `DataFrameStatSuite.scala` (marked as
"SPARK-30532").
Closes #27916 from kachayev/fix-spark-30532.
Authored-by: Oleksii Kachaiev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 22bb6b0fddb3ecd3ac0ad2b41a5024c86b8a6fc7)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/execution/stat/FrequentItems.scala | 4 +---
.../spark/sql/execution/stat/StatFunctions.scala | 9 ++++----
.../org/apache/spark/sql/DataFrameStatSuite.scala | 26 ++++++++++++++++++++++
3 files changed, 31 insertions(+), 8 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 6f1b678..bcd226f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -113,10 +113,8 @@ object FrequentItems extends Logging {
val justItems = freqItems.map(m => m.baseMap.keys.toArray)
val resultRow = Row(justItems : _*)
- val originalSchema = df.schema
val outputCols = cols.map { name =>
- val index = originalSchema.fieldIndex(name)
- val originalField = originalSchema.fields(index)
+ val originalField = df.resolve(name)
// append frequent Items to the column name for easy debugging
StructField(name + "_freqItems", ArrayType(originalField.dataType,
originalField.nullable))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index fffd880..5094e5e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -70,7 +70,7 @@ object StatFunctions extends Logging {
require(relativeError >= 0,
s"Relative Error must be non-negative but got $relativeError")
val columns: Seq[Column] = cols.map { colName =>
- val field = df.schema(colName)
+ val field = df.resolve(colName)
require(field.dataType.isInstanceOf[NumericType],
s"Quantile calculation for column $colName with data type
${field.dataType}" +
" is not supported.")
@@ -154,10 +154,9 @@ object StatFunctions extends Logging {
functionName: String): CovarianceCounter = {
require(cols.length == 2, s"Currently $functionName calculation is
supported " +
"between two columns.")
- cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach {
case (name, data) =>
- require(data.nonEmpty, s"Couldn't find column with name $name")
- require(data.get.dataType.isInstanceOf[NumericType], s"Currently
$functionName calculation " +
- s"for columns with dataType ${data.get.dataType.catalogString} not
supported.")
+ cols.map(name => (name, df.resolve(name))).foreach { case (name, data) =>
+ require(data.dataType.isInstanceOf[NumericType], s"Currently
$functionName calculation " +
+ s"for columns with dataType ${data.dataType.catalogString} not
supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
df.select(columns: _*).queryExecution.toRdd.treeAggregate(new
CovarianceCounter)(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 394bad7..1960172 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -126,6 +126,32 @@ class DataFrameStatSuite extends QueryTest with
SharedSparkSession {
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}
+ test("SPARK-30532 stat functions to understand fully-qualified column name")
{
+ val df1 = spark.sparkContext.parallelize(0 to 10).toDF("num").as("table1")
+ val df2 = spark.sparkContext.parallelize(0 to 10).toDF("num").as("table2")
+ val dfx = df2.crossJoin(df1)
+
+ assert(dfx.stat.corr("table1.num", "table2.num") != 0.0)
+ assert(dfx.stat.cov("table1.num", "table2.num") != 0.0)
+ assert(dfx.stat.approxQuantile("table1.num", Array(0.1), 0.0).length == 1)
+ assert(dfx.stat.approxQuantile("table2.num", Array(0.1), 0.0).length == 1)
+ assert(dfx.stat.freqItems(Array("table1.num",
"table2.num")).collect()(0).length == 2)
+
+ // this should throw "Reference 'num' is ambiguous"
+ intercept[AnalysisException] {
+ dfx.stat.freqItems(Array("num"))
+ }
+ intercept[AnalysisException] {
+ dfx.stat.approxQuantile("num", Array(0.1), 0.0)
+ }
+ intercept[AnalysisException] {
+ dfx.stat.cov("num", "num")
+ }
+ intercept[AnalysisException] {
+ dfx.stat.corr("num", "num")
+ }
+ }
+
test("covariance") {
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles",
"doubles", "letters")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]