This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 5687b31  [SPARK-30532] DataFrameStatFunctions to work with 
TABLE.COLUMN syntax
5687b31 is described below

commit 5687b31be3fd1ff991a2e3d918a2dcd79ce1449c
Author: Oleksii Kachaiev <[email protected]>
AuthorDate: Mon Mar 30 13:20:57 2020 +0800

    [SPARK-30532] DataFrameStatFunctions to work with TABLE.COLUMN syntax
    
    ### What changes were proposed in this pull request?
    `DataFrameStatFunctions` now works correctly with fully qualified column
    names (`TABLE.COLUMN` syntax) by properly resolving each name instead of
    relying on the field names in the schema, notably:
    * `approxQuantile`
    * `freqItems`
    * `cov`
    * `corr`
    
    (other functions from `DataFrameStatFunctions` already work correctly).
    
    See code examples below.
    
    ### Why are the changes needed?
    With the current implementation, some stat functions are impossible to use
    on the result of joining datasets that have identically named columns.
    
    ### Does this PR introduce any user-facing change?
    Yes. Before the change, the following code would fail with
    `AnalysisException`; with the change, it produces the results shown.
    
    ```scala
    scala> val df1 = sc.parallelize(0 to 10).toDF("num").as("table1")
    df1: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [num: int]
    
    scala> val df2 = sc.parallelize(0 to 10).toDF("num").as("table2")
    df2: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [num: int]
    
    scala> val dfx = df2.crossJoin(df1)
    dfx: org.apache.spark.sql.DataFrame = [num: int, num: int]
    
    scala> dfx.stat.approxQuantile("table1.num", Array(0.1), 0.0)
    res0: Array[Double] = Array(1.0)
    
    scala> dfx.stat.corr("table1.num", "table2.num")
    res1: Double = 1.0
    
    scala> dfx.stat.cov("table1.num", "table2.num")
    res2: Double = 11.0
    
    scala> dfx.stat.freqItems(Array("table1.num", "table2.num"))
    res3: org.apache.spark.sql.DataFrame = [table1.num_freqItems: array<int>, 
table2.num_freqItems: array<int>]
    ```
    
    ### How was this patch tested?
    Corresponding unit tests were added to `DataFrameStatSuite.scala` (the test
    name is prefixed with "SPARK-30532").
    
    Closes #27916 from kachayev/fix-spark-30532.
    
    Authored-by: Oleksii Kachaiev <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 22bb6b0fddb3ecd3ac0ad2b41a5024c86b8a6fc7)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/execution/stat/FrequentItems.scala   |  4 +---
 .../spark/sql/execution/stat/StatFunctions.scala   |  9 ++++----
 .../org/apache/spark/sql/DataFrameStatSuite.scala  | 26 ++++++++++++++++++++++
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 6f1b678..bcd226f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -113,10 +113,8 @@ object FrequentItems extends Logging {
     val justItems = freqItems.map(m => m.baseMap.keys.toArray)
     val resultRow = Row(justItems : _*)
 
-    val originalSchema = df.schema
     val outputCols = cols.map { name =>
-      val index = originalSchema.fieldIndex(name)
-      val originalField = originalSchema.fields(index)
+      val originalField = df.resolve(name)
 
       // append frequent Items to the column name for easy debugging
       StructField(name + "_freqItems", ArrayType(originalField.dataType, 
originalField.nullable))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index fffd880..5094e5e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -70,7 +70,7 @@ object StatFunctions extends Logging {
     require(relativeError >= 0,
       s"Relative Error must be non-negative but got $relativeError")
     val columns: Seq[Column] = cols.map { colName =>
-      val field = df.schema(colName)
+      val field = df.resolve(colName)
       require(field.dataType.isInstanceOf[NumericType],
         s"Quantile calculation for column $colName with data type 
${field.dataType}" +
         " is not supported.")
@@ -154,10 +154,9 @@ object StatFunctions extends Logging {
               functionName: String): CovarianceCounter = {
     require(cols.length == 2, s"Currently $functionName calculation is 
supported " +
       "between two columns.")
-    cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { 
case (name, data) =>
-      require(data.nonEmpty, s"Couldn't find column with name $name")
-      require(data.get.dataType.isInstanceOf[NumericType], s"Currently 
$functionName calculation " +
-        s"for columns with dataType ${data.get.dataType.catalogString} not 
supported.")
+    cols.map(name => (name, df.resolve(name))).foreach { case (name, data) =>
+      require(data.dataType.isInstanceOf[NumericType], s"Currently 
$functionName calculation " +
+        s"for columns with dataType ${data.dataType.catalogString} not 
supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
     df.select(columns: _*).queryExecution.toRdd.treeAggregate(new 
CovarianceCounter)(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 394bad7..1960172 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -126,6 +126,32 @@ class DataFrameStatSuite extends QueryTest with 
SharedSparkSession {
     assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
   }
 
+  test("SPARK-30532 stat functions to understand fully-qualified column name") 
{
+    val df1 = spark.sparkContext.parallelize(0 to 10).toDF("num").as("table1")
+    val df2 = spark.sparkContext.parallelize(0 to 10).toDF("num").as("table2")
+    val dfx = df2.crossJoin(df1)
+
+    assert(dfx.stat.corr("table1.num", "table2.num") != 0.0)
+    assert(dfx.stat.cov("table1.num", "table2.num") != 0.0)
+    assert(dfx.stat.approxQuantile("table1.num", Array(0.1), 0.0).length == 1)
+    assert(dfx.stat.approxQuantile("table2.num", Array(0.1), 0.0).length == 1)
+    assert(dfx.stat.freqItems(Array("table1.num", 
"table2.num")).collect()(0).length == 2)
+
+    // this should throw "Reference 'num' is ambiguous"
+    intercept[AnalysisException] {
+      dfx.stat.freqItems(Array("num"))
+    }
+    intercept[AnalysisException] {
+      dfx.stat.approxQuantile("num", Array(0.1), 0.0)
+    }
+    intercept[AnalysisException] {
+      dfx.stat.cov("num", "num")
+    }
+    intercept[AnalysisException] {
+      dfx.stat.corr("num", "num")
+    }
+  }
+
   test("covariance") {
     val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", 
"doubles", "letters")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to