sririshindra commented on a change in pull request #31477:
URL: https://github.com/apache/spark/pull/31477#discussion_r576938141



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -137,10 +144,19 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
   protected def streamSideKeyGenerator(): UnsafeProjection =
     UnsafeProjection.create(streamedBoundKeys)
 
-  @transient protected[this] lazy val boundCondition = if 
(condition.isDefined) {
-    Predicate.create(condition.get, streamedPlan.output ++ 
buildPlan.output).eval _
-  } else {
-    (r: InternalRow) => true
+  private val numMatchedRows = longMetric("numMatchedRows")
+
+  @transient protected[this] lazy val boundCondition: InternalRow => Boolean =
+    if (condition.isDefined) {
+      (r: InternalRow) => {
+        numMatchedRows += 1
+        Predicate.create(condition.get, streamedPlan.output ++ 
buildPlan.output).eval(r)
+      }
+    } else {
+      (_: InternalRow) => {
+        numMatchedRows += 1
+        true
+      }
   }

Review comment:
       Done. Fixed in the latest commit.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -433,6 +450,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
     } else {
       ""
     }
+    val checkCondition = s"$numMatched.add(1);\n$conditionDef"

Review comment:
       Done. Fixed in the latest commit.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -716,6 +741,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
     } else {
       s"$existsVar = true;"
     }
+    val checkCondition = s"$numMatched.add(1);\n$conditionDef"

Review comment:
       @dongjoon-hyun I am slightly confused about this part. I added 
`$numMatched.add(1);` within checkCondition itself, directly before the 
condition code at lines 737 and 742. That way, the metric is incremented every 
time checkCondition is invoked. I hope that is what you are looking for; if 
not, could you please elaborate on where I am making a mistake?

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -277,20 +277,34 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
     withTempView("testDataForJoin") {
       // Assume the execution plan is
       // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val query = "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a 
= testDataForJoin.a"
+      val query1 = "SELECT * FROM testData2 JOIN testDataForJoin ON 
testData2.a = testDataForJoin.a"
       Seq((0L, 2L, false), (1L, 4L, true)).foreach { case (nodeId1, nodeId2, 
enableWholeStage) =>
-        val df = spark.sql(query)
+        val df = spark.sql(query1)
         testSparkPlanMetrics(df, 1, Map(
           nodeId1 -> (("SortMergeJoin", Map(
             // It's 4 because we only read 3 rows in the first partition and 1 
row in the second one
-            "number of output rows" -> 4L))),
+            "number of output rows" -> 4L,
+            "number of matched rows" -> 4L))),
           nodeId2 -> (("Exchange", Map(
             "records read" -> 4L,
             "local blocks read" -> 2L,
             "remote blocks read" -> 0L,
             "shuffle records written" -> 2L)))),
           enableWholeStage
         )
+
+        val query2 = "SELECT * FROM testData2 JOIN testDataForJoin ON " +
+          "testData2.a = testDataForJoin.a AND testData2.b <= 
testDataForJoin.b"
+        val df2 = spark.sql(query2)
+        Seq(false, true).foreach { case  enableWholeStage =>
+          testSparkPlanMetrics(df2, 1, Map(
+            0L -> ("SortMergeJoin", Map(
+              "number of output rows" -> 3L,
+              "number of matched rows" -> 4L))),
+            enableWholeStage
+          )
+        }
+

Review comment:
       Done. Fixed in the latest commit.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -277,20 +277,34 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
     withTempView("testDataForJoin") {
       // Assume the execution plan is
       // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val query = "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a 
= testDataForJoin.a"
+      val query1 = "SELECT * FROM testData2 JOIN testDataForJoin ON 
testData2.a = testDataForJoin.a"
       Seq((0L, 2L, false), (1L, 4L, true)).foreach { case (nodeId1, nodeId2, 
enableWholeStage) =>
-        val df = spark.sql(query)
+        val df = spark.sql(query1)
         testSparkPlanMetrics(df, 1, Map(
           nodeId1 -> (("SortMergeJoin", Map(
             // It's 4 because we only read 3 rows in the first partition and 1 
row in the second one
-            "number of output rows" -> 4L))),
+            "number of output rows" -> 4L,
+            "number of matched rows" -> 4L))),
           nodeId2 -> (("Exchange", Map(
             "records read" -> 4L,
             "local blocks read" -> 2L,
             "remote blocks read" -> 0L,
             "shuffle records written" -> 2L)))),
           enableWholeStage
         )
+
+        val query2 = "SELECT * FROM testData2 JOIN testDataForJoin ON " +

Review comment:
       Yes, that was a mistake. Fixed in the latest commit.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -303,18 +317,35 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
     withTempView("testDataForJoin") {
       // Assume the execution plan is
       // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val leftJoinQuery = "SELECT * FROM testData2 left JOIN testDataForJoin 
ON " +
+      val query1 = "SELECT * FROM testData2 LEFT JOIN testDataForJoin ON " +
         "testData2.a = testDataForJoin.a"
-      val rightJoinQuery = "SELECT * FROM testDataForJoin right JOIN testData2 
ON " +
+      val query2 = "SELECT * FROM testDataForJoin RIGHT JOIN testData2 ON " +
         "testData2.a = testDataForJoin.a"
-
-      Seq((leftJoinQuery, false), (leftJoinQuery, true), (rightJoinQuery, 
false),
-        (rightJoinQuery, true)).foreach { case (query, enableWholeStage) =>
+      val query3 = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " +
+        "testData2.a = testDataForJoin.a"
+      val query4 = "SELECT * FROM testData2 FULL OUTER JOIN testDataForJoin ON 
" +
+        "testData2.a = testDataForJoin.a"
+      val boundCondition1 = " AND testData2.b >= testDataForJoin.b"
+      val boundCondition2 = " AND testData2.a >= testDataForJoin.b"
+
+      Seq((query1, 8L, false),
+        (query1 + boundCondition1, 7L, false),
+        (query1 + boundCondition1, 7L, true),
+        (query3 + boundCondition2, 3L, false),
+        (query3 + boundCondition2, 3L, true),
+        (query4, 8L, false),
+        (query4, 8L, true),
+        (query4 + boundCondition1, 7L, false),
+        (query4 + boundCondition1, 7L, true),
+        (query1, 8L, true),
+        (query2, 8L, false),
+        (query2, 8L, true)).foreach { case (query, rows, enableWholeStage) =>
         val df = spark.sql(query)
         testSparkPlanMetrics(df, 1, Map(
           0L -> (("SortMergeJoin", Map(
             // It's 8 because we read 6 rows in the left and 2 row in the 
right one

Review comment:
       Yes, I removed the comment entirely, since there are now multiple queries 
and the comment no longer holds for all of them.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -324,13 +355,22 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
   test("BroadcastHashJoin metrics") {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value")
-    // Assume the execution plan is
-    // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-    Seq((1L, false), (2L, true)).foreach { case (nodeId, enableWholeStage) =>
-      val df = df1.join(broadcast(df2), "key")
+    val df3 = Seq((1, 1), (2, 4)).toDF("key", "value1")
+    val df4 = Seq((1, 1), (2, 2), (3, 3), (4, 4)).toDF("key", "value2")
+
+    Seq((false, df1, df2, 1L, 2L, false),
+      (false, df1, df2, 2L, 2L, true),
+      (true, df3, df4, 2L, 1L, true),
+      (true, df3, df4, 1L, 1L, false)
+    ).foreach { case (boundCondition, dfLeft, dfRight, nodeId, rows, 
enableWholeStage) =>
+      var df = dfLeft.join(broadcast(dfRight), "key")
+      if (boundCondition) {
+        df = df.filter("value1 > value2")
+      }
       testSparkPlanMetrics(df, 2, Map(
         nodeId -> (("BroadcastHashJoin", Map(
-          "number of output rows" -> 2L)))),
+          "number of output rows" -> rows,
+          "number of matched rows" -> 2L)))),

Review comment:
       Done. Fixed in the latest commit.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -369,23 +409,38 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
     }
   }
 
-  test("ShuffledHashJoin(left, outer) metrics") {
-    val leftDf = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    val rightDf = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key2", 
"value")
-    Seq((0L, "right_outer", leftDf, rightDf, 10L, false),
-      (0L, "left_outer", rightDf, leftDf, 10L, false),
-      (1L, "right_outer", leftDf, rightDf, 10L, true),
-      (1L, "left_outer", rightDf, leftDf, 10L, true),
-      (2L, "left_anti", rightDf, leftDf, 8L, true),
-      (2L, "left_semi", rightDf, leftDf, 2L, true),
-      (1L, "left_anti", rightDf, leftDf, 8L, false),
-      (1L, "left_semi", rightDf, leftDf, 2L, false))
-      .foreach { case (nodeId, joinType, leftDf, rightDf, rows, 
enableWholeStage) =>
+  test("ShuffledHashJoin(left/right/outer/semi/anti, outer) metrics") {

Review comment:
       Done! Fixed in the latest commit.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to