This is an automated email from the ASF dual-hosted git repository.
zhanglistar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a76c92e82 [GLUTEN-5668][CH] Support mixed conditions in shuffle hash
join (#5735)
a76c92e82 is described below
commit a76c92e82f75250cd834f51ee88f82a7664c6562
Author: lgbo <[email protected]>
AuthorDate: Mon Jun 3 16:09:41 2024 +0800
[GLUTEN-5668][CH] Support mixed conditions in shuffle hash join (#5735)
* support inequal join
* fixed bugs in CI
* fixed performance issue in broadcast join
* broadcast join build changed
---
.../clickhouse/CHSparkPlanExecApi.scala | 14 +-
.../execution/CHHashJoinExecTransformer.scala | 14 +-
.../execution/CHSortMergeJoinExecTransformer.scala | 8 +-
.../apache/gluten/utils/CHJoinValidateUtil.scala | 94 ++---
...nClickHouseTPCDSParquetGraceHashJoinSuite.scala | 153 +-------
.../GlutenClickHouseTPCDSParquetSuite.scala | 69 +---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 18 +
cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp | 3 +-
.../Join/StorageJoinFromReadBuffer.cpp | 129 ++++--
.../local-engine/Join/StorageJoinFromReadBuffer.h | 29 +-
cpp-ch/local-engine/Parser/JoinRelParser.cpp | 434 +++++++++++++++------
cpp-ch/local-engine/Parser/JoinRelParser.h | 24 +-
.../local-engine/Parser/SerializedPlanParser.cpp | 2 +-
13 files changed, 522 insertions(+), 469 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1403c8261..bdbdfed0d 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -27,7 +27,7 @@ import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverri
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
import org.apache.gluten.vectorized.CHColumnarBatchSerializer
import org.apache.spark.{ShuffleDependency, SparkException}
@@ -694,15 +694,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
}
/**
- * Define whether the join operator is fallback because of the join operator
is not supported by
- * backend
+ * This is only used to control whether to transform smj into shj or not at
present. We always prefer
+ * shj.
*/
override def joinFallback(
- JoinType: JoinType,
+ joinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
condition: Option[Expression]): Boolean = {
- CHJoinValidateUtil.shouldFallback(JoinType, leftOutputSet, rightOutputSet,
condition)
+ CHJoinValidateUtil.shouldFallback(
+ UnknownJoinStrategy(joinType),
+ leftOutputSet,
+ rightOutputSet,
+ condition)
}
/** Generate window function node */
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index c3ab89df5..6004f7f86 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.execution
import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi
import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil,
ShuffleHashJoinStrategy}
import org.apache.spark.{broadcast, SparkContext}
import org.apache.spark.rdd.RDD
@@ -55,7 +55,11 @@ case class CHShuffledHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
- CHJoinValidateUtil.shouldFallback(joinType, left.outputSet,
right.outputSet, condition)
+ CHJoinValidateUtil.shouldFallback(
+ ShuffleHashJoinStrategy(joinType),
+ left.outputSet,
+ right.outputSet,
+ condition)
if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
}
@@ -107,7 +111,11 @@ case class CHBroadcastHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
- CHJoinValidateUtil.shouldFallback(joinType, left.outputSet,
right.outputSet, condition)
+ CHJoinValidateUtil.shouldFallback(
+ BroadcastHashJoinStrategy(joinType),
+ left.outputSet,
+ right.outputSet,
+ condition)
if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
index a5ac5f658..e2b586551 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.execution
import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.utils.CHJoinValidateUtil
+import org.apache.gluten.utils.{CHJoinValidateUtil, SortMergeJoinStrategy}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -44,7 +44,11 @@ case class CHSortMergeJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
- CHJoinValidateUtil.shouldFallback(joinType, left.outputSet,
right.outputSet, condition, true)
+ CHJoinValidateUtil.shouldFallback(
+ SortMergeJoinStrategy(joinType),
+ left.outputSet,
+ right.outputSet,
+ condition)
if (shouldFallback) {
return ValidationResult.notOk("ch join validate fail")
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index 06b2445af..dae8e6e07 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -17,9 +17,17 @@
package org.apache.gluten.utils
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{AttributeSet, EqualTo,
Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual,
Not, Or}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
import org.apache.spark.sql.catalyst.plans.JoinType
+trait JoinStrategy {
+ val joinType: JoinType
+}
+case class UnknownJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+case class ShuffleHashJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+case class BroadcastHashJoinStrategy(joinType: JoinType) extends JoinStrategy
{}
+case class SortMergeJoinStrategy(joinType: JoinType) extends JoinStrategy {}
+
/**
* The logic here is that if it is not an equi-join spark will create BNLJ,
which will fallback, if
* it is an equi-join, spark will create BroadcastHashJoin or ShuffleHashJoin,
for these join types,
@@ -34,78 +42,40 @@ object CHJoinValidateUtil extends Logging {
def hasTwoTableColumn(
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
- l: Expression,
- r: Expression): Boolean = {
- val allReferences = l.references ++ r.references
+ expr: Expression): Boolean = {
+ val allReferences = expr.references
!(allReferences.subsetOf(leftOutputSet) ||
allReferences.subsetOf(rightOutputSet))
}
def shouldFallback(
- joinType: JoinType,
+ joinStrategy: JoinStrategy,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
- condition: Option[Expression],
- isSMJ: Boolean = false): Boolean = {
+ condition: Option[Expression]): Boolean = {
var shouldFallback = false
+ val joinType = joinStrategy.joinType
if (joinType.toString.contains("ExistenceJoin")) {
return true
}
- if (joinType.sql.equals("INNER")) {
- return shouldFallback
- }
- if (isSMJ) {
- if (
- joinType.sql.contains("SEMI")
- || joinType.sql.contains("ANTI")
- ) {
- return true
+ if (joinType.sql.contains("INNER")) {
+ shouldFallback = false;
+ } else if (
+ condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet,
condition.get)
+ ) {
+ shouldFallback = joinStrategy match {
+ case BroadcastHashJoinStrategy(joinTy) =>
+ joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+ case SortMergeJoinStrategy(_) => true
+ case ShuffleHashJoinStrategy(joinTy) =>
+ joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+ case UnknownJoinStrategy(joinTy) =>
+ joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
}
- }
- if (condition.isDefined) {
- condition.get.transform {
- case Or(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- Or(l, r)
- case Not(EqualTo(l, r)) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- Not(EqualTo(l, r))
- case LessThan(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- LessThan(l, r)
- case LessThanOrEqual(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- LessThanOrEqual(l, r)
- case GreaterThan(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- GreaterThan(l, r)
- case GreaterThanOrEqual(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- GreaterThanOrEqual(l, r)
- case In(l, r) =>
- r.foreach(
- e => {
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, e)) {
- shouldFallback = true
- }
- })
- In(l, r)
- case EqualTo(l, r) =>
- if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
- shouldFallback = true
- }
- EqualTo(l, r)
+ } else {
+ shouldFallback = joinStrategy match {
+ case SortMergeJoinStrategy(joinTy) =>
+ joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+ case _ => false
}
}
shouldFallback
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
index 0b7ad9a6d..04ccda29b 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetGraceHashJoinSuite.scala
@@ -16,13 +16,9 @@
*/
package org.apache.gluten.execution
-import org.apache.gluten.test.FallbackUtil
-
import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression,
Not}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
class GlutenClickHouseTPCDSParquetGraceHashJoinSuite extends
GlutenClickHouseTPCDSAbstractSuite {
@@ -39,105 +35,11 @@ class GlutenClickHouseTPCDSParquetGraceHashJoinSuite
extends GlutenClickHouseTPC
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.memory.offHeap.size", "8g")
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm",
"grace_hash")
-
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join",
"3145728")
+
.set("spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join",
"314572800")
}
executeTPCDSTest(false);
- test(
- "test fallback operations not supported by ch backend " +
- "in CHHashJoinExecTransformer && CHBroadcastHashJoinExecTransformer") {
- val testSql =
- """
- | SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id,
i_manufact,
- | sum(ss_ext_sales_price) AS ext_price
- | FROM date_dim
- | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
- | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
- | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
- | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
- | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5)
<> substr(s_zip,1,5)
- | WHERE d_moy = 11
- | AND d_year = 1999
- | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
- | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id,
i_manufact
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- val operateWithCondition = df.queryExecution.executedPlan.collect {
- case f: BroadcastHashJoinExec if f.condition.get.isInstanceOf[Not] => f
- }
- assert(
- operateWithCondition(0).left
- .asInstanceOf[InputAdapter]
- .child
- .isInstanceOf[CHColumnarToRowExec])
- }
-
- test("test fallbackutils") {
- val testSql =
- """
- | SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id,
i_manufact,
- | sum(ss_ext_sales_price) AS ext_price
- | FROM date_dim
- | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
- | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
- | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
- | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
- | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5)
<> substr(s_zip,1,5)
- | WHERE d_moy = 11
- | AND d_year = 1999
- | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
- | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id,
i_manufact
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
- }
-
- test("Gluten-4458: test clickhouse not support join with IN condition") {
- val testSql =
- """
- | SELECT *
- | FROM date_dim t1
- | LEFT JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk
- | AND datediff(t1.d_day_name, t2.d_day_name) IN (1, 3)
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
- }
-
- test("Gluten-4458: test join with Equal computing two table in one side") {
- val testSql =
- """
- | SELECT *
- | FROM date_dim t1
- | LEFT JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk AND t1.d_year -
t2.d_year = 1
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
- }
-
- test("Gluten-4458: test inner join can support join with IN condition") {
- val testSql =
- """
- | SELECT *
- | FROM date_dim t1
- | INNER JOIN date_dim t2 ON t1.d_date_sk = t2.d_date_sk
- | AND datediff(t1.d_day_name, t2.d_day_name) IN (1, 3)
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
- }
-
test("Gluten-1235: Fix missing reading from the broadcasted value when
executing DPP") {
val testSql =
"""
@@ -198,55 +100,4 @@ class GlutenClickHouseTPCDSParquetGraceHashJoinSuite
extends GlutenClickHouseTPC
}
}
}
-
- test("TPCDS Q21 with non-separated scan rdd") {
- withSQLConf(("spark.gluten.sql.columnar.separate.scan.rdd.for.ch",
"false")) {
- runTPCDSQuery("q21") {
- df =>
- val foundDynamicPruningExpr = df.queryExecution.executedPlan.find {
- case f: FileSourceScanExecTransformer =>
- f.partitionFilters.exists {
- case _: DynamicPruningExpression => true
- case _ => false
- }
- case _ => false
- }
- assert(foundDynamicPruningExpr.nonEmpty == true)
-
- val reuseExchange = df.queryExecution.executedPlan.find {
- case r: ReusedExchangeExec => true
- case _ => false
- }
- assert(reuseExchange.nonEmpty == true)
- }
- }
- }
-
- test("Gluten-4452: Fix get wrong hash table when multi joins in a task") {
- val testSql =
- """
- | SELECT ws_item_sk, ws_sold_date_sk, ws_ship_date_sk,
- | t3.d_date_id as sold_date_id, t2.d_date_id as ship_date_id
- | FROM (
- | SELECT ws_item_sk, ws_sold_date_sk, ws_ship_date_sk, t1.d_date_id
- | FROM web_sales
- | LEFT JOIN
- | (SELECT d_date_id, d_date_sk from date_dim GROUP BY d_date_id,
d_date_sk) t1
- | ON ws_sold_date_sk == t1.d_date_sk) t3
- | INNER JOIN
- | (SELECT d_date_id, d_date_sk from date_dim GROUP BY d_date_id,
d_date_sk) t2
- | ON ws_ship_date_sk == t2.d_date_sk
- | LIMIT 100;
- |""".stripMargin
- compareResultsAgainstVanillaSpark(
- testSql,
- true,
- df => {
- val foundBroadcastHashJoinExpr =
df.queryExecution.executedPlan.collect {
- case f: CHBroadcastHashJoinExecTransformer => f
- }
- assert(foundBroadcastHashJoinExpr.size == 2)
- }
- )
- }
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
index a63e47888..e9c27437b 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSuite.scala
@@ -16,13 +16,10 @@
*/
package org.apache.gluten.execution
-import org.apache.gluten.test.FallbackUtil
-
import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression,
Not}
+import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
SortMergeJoinExec}
// Some sqls' line length exceeds 100
// scalastyle:off line.size.limit
@@ -121,38 +118,7 @@ class GlutenClickHouseTPCDSParquetSuite extends
GlutenClickHouseTPCDSAbstractSui
assert(result(0).getLong(0) == 73049)
}
- test(
- "test fallback operations not supported by ch backend " +
- "in CHHashJoinExecTransformer && CHBroadcastHashJoinExecTransformer") {
- val testSql =
- """
- |SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id,
i_manufact,
- | sum(ss_ext_sales_price) AS ext_price
- | FROM date_dim
- | LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
- | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
- | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
- | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
- | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5)
<> substr(s_zip,1,5)
- | WHERE d_moy = 11
- | AND d_year = 1999
- | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
- | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id,
i_manufact
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- val operateWithCondition = df.queryExecution.executedPlan.collect {
- case f: BroadcastHashJoinExec if f.condition.get.isInstanceOf[Not] => f
- }
- assert(
- operateWithCondition(0).left
- .asInstanceOf[InputAdapter]
- .child
- .isInstanceOf[CHColumnarToRowExec])
- }
-
- test("test fallbackutils") {
+ test("Test join with mixed condition 1") {
val testSql =
"""
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id,
i_manufact,
@@ -169,36 +135,7 @@ class GlutenClickHouseTPCDSParquetSuite extends
GlutenClickHouseTPCDSAbstractSui
| ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id,
i_manufact
| LIMIT 100;
|""".stripMargin
-
- val df = spark.sql(testSql)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
- }
-
- test(
- "Test avoid forceShuffledHashJoin when the join condition" +
- " does not supported by the backend") {
- val testSql =
- """
- |SELECT /*+ merge(date_dim)*/ i_brand_id AS brand_id, i_brand AS
brand, i_manufact_id, i_manufact,
- | sum(ss_ext_sales_price) AS ext_price
- | FROM date_dim
- | LEFT JOIN store_sales ON d_date_sk == ss_sold_date_sk AND (d_date_sk
= 213232 OR ss_sold_date_sk = 3232)
- | LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
- | LEFT JOIN customer ON ss_customer_sk = c_customer_sk
- | LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
- | LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5)
<> substr(s_zip,1,5)
- | WHERE d_moy = 11
- | AND d_year = 1999
- | GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
- | ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id,
i_manufact
- | LIMIT 100;
- |""".stripMargin
-
- val df = spark.sql(testSql)
- val sortMergeJoinExec = df.queryExecution.executedPlan.collect {
- case s: SortMergeJoinExec => s
- }
- assert(sortMergeJoinExec.nonEmpty)
+ compareResultsAgainstVanillaSpark(testSql, true, _ => {})
}
test("Gluten-1235: Fix missing reading from the broadcasted value when
executing DPP") {
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 748bd5a7f..038b170df 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2563,5 +2563,23 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
spark.sql("drop table test_tbl_5896")
}
+
+ test("Inequal join support") {
+ withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+ spark.sql("create table ineq_join_t1 (key bigint, value bigint) using
parquet");
+ spark.sql("create table ineq_join_t2 (key bigint, value bigint) using
parquet");
+ spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4,
4), (5, 5)");
+ spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4,
6), (5, 3)");
+ val sql =
+ """
+ | select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
+ | left join ineq_join_t2 as t2
+ | on t1.key = t2.key and t1.value > t2.value
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ spark.sql("drop table ineq_join_t1")
+ spark.sql("drop table ineq_join_t2")
+ }
+ }
}
// scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index d90951241..f1b3ac2fb 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -103,7 +103,8 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
row_count,
key_names,
true,
- std::make_shared<DB::TableJoin>(SizeLimits(), true, kind, strictness,
key_names),
+ kind,
+ strictness,
columns_description,
ConstraintsDescription(),
key,
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index 6d0021adb..f0aec6af6 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -15,7 +15,9 @@
* limitations under the License.
*/
#include "StorageJoinFromReadBuffer.h"
+#include <algorithm>
+#include <DataTypes/DataTypeNullable.h>
#include <Interpreters/Context.h>
#include <Interpreters/HashJoin.h>
#include <Interpreters/TableJoin.h>
@@ -23,6 +25,9 @@
#include <Storages/IO/NativeReader.h>
#include <Common/Exception.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
namespace DB
{
class HashJoin;
@@ -40,25 +45,23 @@ extern const int DEADLOCK_AVOIDED;
using namespace DB;
-void restore(DB::ReadBuffer & in, IJoin & join, const Block & sample_block)
-{
- local_engine::NativeReader block_stream(in);
- ProfileInfo info;
- while (Block block = block_stream.read())
- {
- auto final_block =
sample_block.cloneWithColumns(block.mutateColumns());
- info.update(final_block);
- join.addBlockToJoin(final_block, true);
- }
-}
+constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_";
DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata &
storage_metadata_, JoinKind kind)
{
+ DB::ColumnsWithTypeAndName new_cols;
DB::Block block = storage_metadata_.getSampleBlock();
- if (use_nulls && isLeftOrFull(kind))
- for (auto & col : block)
- DB::JoinCommon::convertColumnToNullable(col);
- return block;
+ for (const auto & col : block)
+ {
+ // Add a prefix to avoid column name conflicts with left table.
+ new_cols.emplace_back(col.column, col.type, RIHGT_COLUMN_PREFIX +
col.name);
+ if (use_nulls && isLeftOrFull(kind))
+ {
+ auto & new_col = new_cols.back();
+ DB::JoinCommon::convertColumnToNullable(new_col);
+ }
+ }
+ return DB::Block(new_cols);
}
namespace local_engine
@@ -67,46 +70,88 @@ namespace local_engine
StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
DB::ReadBuffer & in,
size_t row_count_,
- const Names & key_names,
- bool use_nulls,
- std::shared_ptr<DB::TableJoin> table_join,
+ const Names & key_names_,
+ bool use_nulls_,
+ DB::JoinKind kind,
+ DB::JoinStrictness strictness,
const ColumnsDescription & columns,
const ConstraintsDescription & constraints,
const String & comment,
- const bool overwrite)
- : key_names_(key_names), use_nulls_(use_nulls)
+ const bool overwrite_)
+ : key_names({}), use_nulls(use_nulls_), row_count(row_count_),
overwrite(overwrite_)
{
- storage_metadata_.setColumns(columns);
- storage_metadata_.setConstraints(constraints);
- storage_metadata_.setComment(comment);
+ storage_metadata.setColumns(columns);
+ storage_metadata.setConstraints(constraints);
+ storage_metadata.setComment(comment);
- for (const auto & key : key_names)
- if (!storage_metadata_.getColumns().hasPhysical(key))
+ for (const auto & key : key_names_)
+ if (!storage_metadata.getColumns().hasPhysical(key))
throw Exception(ErrorCodes::NO_SUCH_COLUMN_IN_TABLE, "Key column
({}) does not exist in table declaration.", key);
- right_sample_block_ = rightSampleBlock(use_nulls, storage_metadata_,
table_join->kind());
- join_ = std::make_shared<HashJoin>(table_join, right_sample_block_,
overwrite, row_count_);
- restore(in, *join_, storage_metadata_.getSampleBlock());
+ for (const auto & name : key_names_)
+ key_names.push_back(RIHGT_COLUMN_PREFIX + name);
+ auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true,
kind, strictness, key_names);
+ right_sample_block = rightSampleBlock(use_nulls, storage_metadata,
table_join->kind());
+ buildJoin(in, right_sample_block, table_join);
+}
+
+/// The column names may be different in the two blocks,
+/// and the nullability could also be different, with TPCDS-Q1 as an example.
+static DB::ColumnWithTypeAndName convertColumnAsNecessary(const
DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName &
sample_column)
+{
+ if (sample_column.type->equals(*column.type))
+ return {column.column, column.type, sample_column.name};
+ else if (
+ sample_column.type->isNullable() && !column.type->isNullable()
+ && DB::removeNullable(sample_column.type)->equals(*column.type))
+ {
+ auto nullable_column = column;
+ DB::JoinCommon::convertColumnToNullable(nullable_column);
+ return {nullable_column.column, sample_column.type,
sample_column.name};
+ }
+ else
+ throw DB::Exception(
+ DB::ErrorCodes::LOGICAL_ERROR,
+ "Columns have different types. original:{} expected:{}",
+ column.dumpStructure(),
+ sample_column.dumpStructure());
+}
+
+void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block
header, std::shared_ptr<DB::TableJoin> analyzed_join)
+{
+ local_engine::NativeReader block_stream(in);
+ ProfileInfo info;
+ join = std::make_shared<HashJoin>(analyzed_join, header, overwrite,
row_count);
+ while (Block block = block_stream.read())
+ {
+ DB::ColumnsWithTypeAndName columns;
+ for (size_t i = 0; i < block.columns(); ++i)
+ {
+ const auto & column = block.getByPosition(i);
+ columns.emplace_back(convertColumnAsNecessary(column,
header.getByPosition(i)));
+ }
+ DB::Block final_block(columns);
+ info.update(final_block);
+ join->addBlockToJoin(final_block, true);
+ }
}
-DB::JoinPtr
StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJoin>
analyzed_join, DB::ContextPtr /*context*/) const
+/// The column names of 'right_header' could be different from the ones in
`input_blocks`, and we must
+/// use 'right_header' to build the HashJoin. Otherwise, it will cause
exceptions with name mismatches.
+///
+/// In most cases, 'getJoinLocked' is called only once, and the input_blocks
should not be too large.
+/// This will be OK.
+DB::JoinPtr
StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJoin>
analyzed_join, DB::ContextPtr /*context*/)
{
- if ((analyzed_join->forceNullableRight() && !use_nulls_)
- || (!analyzed_join->forceNullableRight() &&
isLeftOrFull(analyzed_join->kind()) && use_nulls_))
+ if ((analyzed_join->forceNullableRight() && !use_nulls)
+ || (!analyzed_join->forceNullableRight() &&
isLeftOrFull(analyzed_join->kind()) && use_nulls))
throw Exception(
ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN,
"Table {} needs the same join_use_nulls setting as present in LEFT
or FULL JOIN",
- storage_metadata_.comment);
-
- /// TODO: check key columns
-
- /// Set names qualifiers: table.column -> column
- /// It's required because storage join stores non-qualified names
- /// Qualifies will be added by join implementation (HashJoin)
- analyzed_join->setRightKeys(key_names_);
-
- HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join,
right_sample_block_);
- join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join_));
+ storage_metadata.comment);
+ HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join,
right_sample_block);
+ /// reuseJoinedData will set the flag `HashJoin::from_storage_join` which
is required by `FilledStep`
+ join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
return join_clone;
}
}
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
index 2e949fa87..af623c0cd 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
@@ -23,6 +23,8 @@ namespace DB
class TableJoin;
class IJoin;
using JoinPtr = std::shared_ptr<IJoin>;
+class HashJoin;
+class ReadBuffer;
}
namespace local_engine
@@ -33,23 +35,32 @@ class StorageJoinFromReadBuffer
public:
StorageJoinFromReadBuffer(
DB::ReadBuffer & in_,
- size_t row_count_,
+ size_t row_count,
const DB::Names & key_names_,
bool use_nulls_,
- std::shared_ptr<DB::TableJoin> table_join_,
+ DB::JoinKind kind,
+ DB::JoinStrictness strictness,
const DB::ColumnsDescription & columns_,
const DB::ConstraintsDescription & constraints_,
const String & comment,
bool overwrite_);
- DB::JoinPtr getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join,
DB::ContextPtr context) const;
- const DB::Block & getRightSampleBlock() const { return
right_sample_block_; }
+ /// The columns' names in right_header may be different from the names in
the ColumnsDescription
+ /// in the constructor.
+ /// This should be called once.
+ DB::JoinPtr getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join,
DB::ContextPtr context);
+ const DB::Block & getRightSampleBlock() const { return right_sample_block;
}
private:
- DB::StorageInMemoryMetadata storage_metadata_;
- const DB::Names key_names_;
- bool use_nulls_;
- DB::JoinPtr join_;
- DB::Block right_sample_block_;
+ DB::StorageInMemoryMetadata storage_metadata;
+ DB::Names key_names;
+ bool use_nulls;
+ size_t row_count;
+ bool overwrite;
+ DB::Block right_sample_block;
+ std::shared_ptr<DB::HashJoin> join = nullptr;
+
+ void readAllBlocksFromInput(DB::ReadBuffer & in);
+ void buildJoin(DB::ReadBuffer & in, const DB::Block header,
std::shared_ptr<DB::TableJoin> analyzed_join);
};
}
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index 8f7f35d5e..937e449b0 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -31,6 +31,10 @@
#include <Processors/QueryPlan/JoinStep.h>
#include <google/protobuf/wrappers.pb.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
+
namespace DB
{
namespace ErrorCodes
@@ -179,40 +183,167 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const
substrait::Rel & rel, std::list<co
return parseJoin(join, std::move(left_plan), std::move(right_plan));
}
+std::unordered_set<DB::JoinTableSide>
JoinRelParser::extractTableSidesFromExpression(const substrait::Expression &
expr, const DB::Block & left_header, const DB::Block & right_header)
+{
+ std::unordered_set<DB::JoinTableSide> table_sides;
+ if (expr.has_scalar_function())
+ {
+ for (const auto & arg : expr.scalar_function().arguments())
+ {
+ auto table_sides_from_arg =
extractTableSidesFromExpression(arg.value(), left_header, right_header);
+ table_sides.insert(table_sides_from_arg.begin(),
table_sides_from_arg.end());
+ }
+ }
+ else if (expr.has_selection() && expr.selection().has_direct_reference()
&& expr.selection().direct_reference().has_struct_field())
+ {
+ auto pos = expr.selection().direct_reference().struct_field().field();
+ if (pos < left_header.columns())
+ {
+ table_sides.insert(DB::JoinTableSide::Left);
+ }
+ else
+ {
+ table_sides.insert(DB::JoinTableSide::Right);
+ }
+ }
+ else if (expr.has_singular_or_list())
+ {
+ auto child_table_sides =
extractTableSidesFromExpression(expr.singular_or_list().value(), left_header,
right_header);
+ table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+ for (const auto & option : expr.singular_or_list().options())
+ {
+ child_table_sides = extractTableSidesFromExpression(option,
left_header, right_header);
+ table_sides.insert(child_table_sides.begin(),
child_table_sides.end());
+ }
+ }
+ else if (expr.has_cast())
+ {
+ auto child_table_sides =
extractTableSidesFromExpression(expr.cast().input(), left_header, right_header);
+ table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+ }
+ else if (expr.has_if_then())
+ {
+ for (const auto & if_child : expr.if_then().ifs())
+ {
+ auto child_table_sides =
extractTableSidesFromExpression(if_child.if_(), left_header, right_header);
+ table_sides.insert(child_table_sides.begin(),
child_table_sides.end());
+ child_table_sides =
extractTableSidesFromExpression(if_child.then(), left_header, right_header);
+ table_sides.insert(child_table_sides.begin(),
child_table_sides.end());
+ }
+ auto child_table_sides =
extractTableSidesFromExpression(expr.if_then().else_(), left_header,
right_header);
+ table_sides.insert(child_table_sides.begin(), child_table_sides.end());
+ }
+ else if (expr.has_literal())
+ {
+ // nothing
+ }
+ else
+ {
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression
'{}'", expr.DebugString());
+ }
+ return table_sides;
+}
+
+
+void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan &
right, const StorageJoinFromReadBuffer & storage_join)
+{
+ /// To support mixed join conditions, we must make sure that the column
names in the right be the same as
+ /// storage_join's right sample block.
+ ActionsDAGPtr project = ActionsDAG::makeConvertingActions(
+ right.getCurrentDataStream().header.getColumnsWithTypeAndName(),
+ storage_join.getRightSampleBlock().getColumnsWithTypeAndName(),
+ ActionsDAG::MatchColumnsMode::Position);
+
+ if (project)
+ {
+ QueryPlanStepPtr project_step =
std::make_unique<ExpressionStep>(right.getCurrentDataStream(), project);
+ project_step->setStepDescription("Rename Broadcast Table Name");
+ steps.emplace_back(project_step.get());
+ right.addStep(std::move(project_step));
+ }
+
+ /// If the columns name in right table is duplicated with left table, we
need to rename the left table's columns,
+ /// avoid the columns name in the right table be changed in
`addConvertStep`.
+ /// This could happen in tpc-ds q44.
+ DB::ColumnsWithTypeAndName new_left_cols;
+ const auto & right_header = right.getCurrentDataStream().header;
+ auto left_prefix = getUniqueName("left");
+ for (const auto & col : left.getCurrentDataStream().header)
+ {
+ if (right_header.has(col.name))
+ {
+ new_left_cols.emplace_back(col.column, col.type, left_prefix +
col.name);
+ }
+ else
+ {
+ new_left_cols.emplace_back(col.column, col.type, col.name);
+ }
+ }
+ project = ActionsDAG::makeConvertingActions(
+ left.getCurrentDataStream().header.getColumnsWithTypeAndName(),
+ new_left_cols,
+ ActionsDAG::MatchColumnsMode::Position);
+
+ if (project)
+ {
+ QueryPlanStepPtr project_step =
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), project);
+ project_step->setStepDescription("Rename Left Table Name for broadcast
join");
+ steps.emplace_back(project_step.get());
+ left.addStep(std::move(project_step));
+ }
+}
+
DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join,
DB::QueryPlanPtr left, DB::QueryPlanPtr right)
{
auto join_opt_info = parseJoinOptimizationInfo(join);
auto storage_join = join_opt_info.is_broadcast ?
BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
-
if (storage_join)
{
- ActionsDAGPtr project = ActionsDAG::makeConvertingActions(
- right->getCurrentDataStream().header.getColumnsWithTypeAndName(),
- storage_join->getRightSampleBlock().getColumnsWithTypeAndName(),
- ActionsDAG::MatchColumnsMode::Position);
+ renamePlanColumns(*left, *right, *storage_join);
+ }
+
+ auto table_join = createDefaultTableJoin(join.type());
+ DB::Block right_header_before_convert_step =
right->getCurrentDataStream().header;
+ addConvertStep(*table_join, *left, *right);
- if (project)
+ // Add a check to find error easily.
+ if (storage_join)
+ {
+ if(!blocksHaveEqualStructure(right_header_before_convert_step,
right->getCurrentDataStream().header))
{
- QueryPlanStepPtr project_step =
std::make_unique<ExpressionStep>(right->getCurrentDataStream(), project);
- project_step->setStepDescription("Rename Broadcast Table Name");
- steps.emplace_back(project_step.get());
- right->addStep(std::move(project_step));
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast
join, we must not change the columns name in the right table.\nleft
header:{},\nright header: {} -> {}",
+ left->getCurrentDataStream().header.dumpNames(),
+ right_header_before_convert_step.dumpNames(),
+ right->getCurrentDataStream().header.dumpNames());
}
}
- auto table_join = createDefaultTableJoin(join.type());
- addConvertStep(*table_join, *left, *right);
Names after_join_names;
auto left_names = left->getCurrentDataStream().header.getNames();
after_join_names.insert(after_join_names.end(), left_names.begin(),
left_names.end());
auto right_name = table_join->columnsFromJoinedTable().getNames();
after_join_names.insert(after_join_names.end(), right_name.begin(),
right_name.end());
- bool add_filter_step = tryAddPushDownFilter(*table_join, join, *left,
*right, table_join->columnsFromJoinedTable(), after_join_names);
+
+ auto left_header = left->getCurrentDataStream().header;
+ auto right_header = right->getCurrentDataStream().header;
QueryPlanPtr query_plan;
+
+ /// Support only one join clause.
+ table_join->addDisjunct();
+ /// some examples to explain when the post_join_filter is not empty
+ /// - on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1, 't1.v1> 1' is in the
post filter. but 't2.v1 > 1'
+ /// will be pushed down into right table by spark and is not in the post
filter. 't1.key = t2.key ' is
+ /// in JoinRel::expression.
+ /// - on t1.key = t2. key and t1.v1 > t2.v2, 't1.v1 > t2.v2' is in the
post filter.
+ collectJoinKeys(*table_join, join, left_header, right_header);
+
if (storage_join)
{
+
+ applyJoinFilter(*table_join, join, *left, *right, true);
auto broadcast_hash_join = storage_join->getJoinLocked(table_join,
context);
+
QueryPlanStepPtr join_step =
std::make_unique<FilledJoinStep>(left->getCurrentDataStream(),
broadcast_hash_join, 8192);
join_step->setStepDescription("STORAGE_JOIN");
@@ -224,6 +355,18 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
}
else if (join_opt_info.is_smj)
{
+ bool need_post_filter = !applyJoinFilter(*table_join, join, *left,
*right, false);
+
+ /// If applyJoinFilter returns false, it means there are mixed
conditions in the post_join_filter.
+ /// It should be an inner join.
+ /// TODO: make smj support mixed conditions
+ if (need_post_filter && table_join->kind() != DB::JoinKind::Inner)
+ {
+ throw DB::Exception(
+ DB::ErrorCodes::LOGICAL_ERROR,
+ "Sort merge join doesn't support mixed join conditions, except
inner join.");
+ }
+
JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join,
right->getCurrentDataStream().header.cloneEmpty(), -1);
MultiEnum<DB::JoinAlgorithm> join_algorithm =
context->getSettingsRef().join_algorithm;
QueryPlanStepPtr join_step
@@ -237,12 +380,14 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
query_plan = std::make_unique<QueryPlan>();
query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+ if (need_post_filter)
+ addPostFilter(*query_plan, join);
}
else
{
- /// TODO: make grace hash join be the default hash join algorithm.
- ///
- /// Following is some configuration for grace hash join.
+ applyJoinFilter(*table_join, join, *left, *right, true);
+
+ /// Following are some configurations for grace hash join.
/// -
spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash.
This will
/// enable grace hash join.
/// -
spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728.
This setup
@@ -278,28 +423,15 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
}
reorderJoinOutput(*query_plan, after_join_names);
- if (add_filter_step)
- {
- addPostFilter(*query_plan, join);
- }
return query_plan;
}
void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan &
left, DB::QueryPlan & right)
{
-
- /// After https://github.com/ClickHouse/ClickHouse/pull/61216, We will
failed at tryPushDownFilter() in filterPushDown.cpp
- /// Here is a workaround, refer to chooseJoinAlgorithm() in
PlannerJoins.cpp, it always call TableJoin::setRename to
- /// create aliases for columns in the right table
- /// By using right table header name sets, so
TableJoin::deduplicateAndQualifyColumnNames can do same thing as
chooseJoinAlgorithm()
- ///
- /// Affected UT fixed bh this workaround:
- /// GlutenClickHouseTPCHParquetRFSuite:TPCH Q17, Q19, Q20, Q21
+ /// If the columns name in right table is duplicated with left table, we
need to rename the right table's columns.
NameSet left_columns_set;
- for (const auto & col : right.getCurrentDataStream().header.getNames())
- {
+ for (const auto & col : left.getCurrentDataStream().header.getNames())
left_columns_set.emplace(col);
- }
table_join.setColumnsFromJoinedTable(
right.getCurrentDataStream().header.getNamesAndTypesList(),
left_columns_set, getUniqueName("right") + ".");
@@ -360,117 +492,179 @@ void JoinRelParser::addConvertStep(TableJoin &
table_join, DB::QueryPlan & left,
}
}
-void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const
substrait::JoinRel & join)
+/// Join keys are collected from substrait::JoinRel::expression() which only
contains the equal join conditions.
+void JoinRelParser::collectJoinKeys(
+ TableJoin & table_join, const substrait::JoinRel & join_rel, const
DB::Block & left_header, const DB::Block & right_header)
{
- std::string filter_name;
- auto actions_dag =
std::make_shared<ActionsDAG>(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
- if (!join.post_join_filter().has_scalar_function())
+ if (!join_rel.has_expression())
+ return;
+ const auto & expr = join_rel.expression();
+ auto & join_clause = table_join.getClauses().back();
+ std::list<const substrait::Expression *> expressions_stack;
+ expressions_stack.push_back(&expr);
+ while (!expressions_stack.empty())
{
- // It may be singular_or_list
- auto * in_node = getPlanParser()->parseExpression(actions_dag,
join.post_join_filter());
- filter_name = in_node->result_name;
- }
- else
- {
-
getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header,
join.post_join_filter(), filter_name, actions_dag, true);
+ /// Must handle the expressions in DFS (depth-first) order. It matters in sort merge
join.
+ const auto * current_expr = expressions_stack.back();
+ expressions_stack.pop_back();
+ if (!current_expr->has_scalar_function())
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function
expression is expected");
+ auto function_name =
parseFunctionName(current_expr->scalar_function().function_reference(),
current_expr->scalar_function());
+ if (!function_name)
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid
function expression");
+ if (*function_name == "equals")
+ {
+ String left_key, right_key;
+ size_t left_pos = 0, right_pos = 0;
+ for (const auto & arg :
current_expr->scalar_function().arguments())
+ {
+ if (!arg.value().has_selection() ||
!arg.value().selection().has_direct_reference()
+ ||
!arg.value().selection().direct_reference().has_struct_field())
+ {
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A
column reference is expected");
+ }
+ auto col_pos_ref =
arg.value().selection().direct_reference().struct_field().field();
+ if (col_pos_ref < left_header.columns())
+ {
+ left_pos = col_pos_ref;
+ left_key = left_header.getByPosition(col_pos_ref).name;
+ }
+ else
+ {
+ right_pos = col_pos_ref - left_header.columns();
+ right_key = right_header.getByPosition(col_pos_ref -
left_header.columns()).name;
+ }
+ }
+ if (left_key.empty() || right_key.empty())
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid
key equal join condition");
+ join_clause.addKey(left_key, right_key, false);
+ }
+ else if (*function_name == "and")
+ {
+
expressions_stack.push_back(&current_expr->scalar_function().arguments().at(1).value());
+
expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value());
+ }
+ else
+ {
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow
function: {}", *function_name);
+ }
}
- auto filter_step =
std::make_unique<FilterStep>(query_plan.getCurrentDataStream(), actions_dag,
filter_name, true);
- filter_step->setStepDescription("Post Join Filter");
- steps.emplace_back(filter_step.get());
- query_plan.addStep(std::move(filter_step));
}
-bool JoinRelParser::tryAddPushDownFilter(
- TableJoin & table_join,
- const substrait::JoinRel & join,
- DB::QueryPlan & left,
- DB::QueryPlan & right,
- const NamesAndTypesList & alias_right,
- const Names & names)
+bool JoinRelParser::applyJoinFilter(
+ DB::TableJoin & table_join, const substrait::JoinRel & join_rel,
DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition)
{
- try
+ if (!join_rel.has_post_join_filter())
+ return true;
+ const auto & expr = join_rel.post_join_filter();
+
+ const auto & left_header = left.getCurrentDataStream().header;
+ const auto & right_header = right.getCurrentDataStream().header;
+ ColumnsWithTypeAndName mixed_columns;
+ std::unordered_set<String> added_column_name;
+ for (const auto & col : left_header.getColumnsWithTypeAndName())
+ {
+ mixed_columns.emplace_back(col);
+ added_column_name.insert(col.name);
+ }
+ for (const auto & col : right_header.getColumnsWithTypeAndName())
{
- ASTParser astParser(context, function_mapping, getPlanParser());
- ASTs args;
+ const auto & renamed_col_name =
table_join.renamedRightColumnNameWithAlias(col.name);
+ if (added_column_name.find(col.name) != added_column_name.end())
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's
name conflict with left column: {}", col.name);
+ mixed_columns.emplace_back(col);
+ added_column_name.insert(col.name);
+ }
+ DB::Block mixed_header(mixed_columns);
- if (join.has_expression())
- {
- args.emplace_back(astParser.parseToAST(names, join.expression()));
- }
+ auto table_sides = extractTableSidesFromExpression(expr, left_header,
right_header);
- if (join.has_post_join_filter())
+ auto get_input_expressions = [](const DB::Block & header)
+ {
+ std::vector<substrait::Expression> exprs;
+ for (size_t i = 0; i < header.columns(); ++i)
{
- args.emplace_back(astParser.parseToAST(names,
join.post_join_filter()));
+ substrait::Expression expr;
+
expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i);
+ exprs.emplace_back(expr);
}
-
- if (args.empty())
- return false;
-
- ASTPtr ast = args.size() == 1 ? args.back() : makeASTFunction("and",
args);
-
- bool is_asof = (table_join.strictness() == JoinStrictness::Asof);
-
- Aliases aliases;
- DatabaseAndTableWithAlias left_table_name;
- DatabaseAndTableWithAlias right_table_name;
- TableWithColumnNamesAndTypes left_table(left_table_name,
left.getCurrentDataStream().header.getNamesAndTypesList());
- TableWithColumnNamesAndTypes right_table(right_table_name,
alias_right);
-
- CollectJoinOnKeysVisitor::Data data{table_join, left_table,
right_table, aliases, is_asof};
- if (auto * or_func = ast->as<ASTFunction>(); or_func && or_func->name
== "or")
+ return exprs;
+ };
+
+ /// If the columns in the expression are all from one table, use
analyzer_left_filter_condition_column_name
+ and analyzer_right_filter_condition_column_name to filter the join result
data. It requires building the filter
+ column first.
+ /// If the columns in the expression are from both tables, use
mixed_join_expression to filter the join result data.
+ /// the filter columns will be built inside the join step.
+ if (table_sides.size() == 1)
+ {
+ auto table_side = *table_sides.begin();
+ if (table_side == DB::JoinTableSide::Left)
{
- for (auto & disjunct : or_func->arguments->children)
- {
- table_join.addDisjunct();
- CollectJoinOnKeysVisitor(data).visit(disjunct);
- }
- assert(table_join.getClauses().size() ==
or_func->arguments->children.size());
+ auto input_exprs = get_input_expressions(left_header);
+ input_exprs.push_back(expr);
+ auto actions_dag = expressionsToActionsDAG(input_exprs,
left_header);
+
table_join.getClauses().back().analyzer_left_filter_condition_column_name =
actions_dag->getOutputs().back()->result_name;
+ QueryPlanStepPtr before_join_step =
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), actions_dag);
+ before_join_step->setStepDescription("Before JOIN LEFT");
+ steps.emplace_back(before_join_step.get());
+ left.addStep(std::move(before_join_step));
}
else
{
- table_join.addDisjunct();
- CollectJoinOnKeysVisitor(data).visit(ast);
- assert(table_join.oneDisjunct());
- }
-
- if (join.has_post_join_filter())
- {
- auto left_keys = table_join.leftKeysList();
- auto right_keys = table_join.rightKeysList();
- if (!left_keys->children.empty())
+ /// since the field reference in expr is the index of left_header
++ right_header, so we use
+ /// mixed_header to build the actions_dag
+ auto input_exprs = get_input_expressions(mixed_header);
+ input_exprs.push_back(expr);
+ auto actions_dag = expressionsToActionsDAG(input_exprs,
mixed_header);
+
+ /// clear unused columns in actions_dag
+ for (const auto & col : left_header.getColumnsWithTypeAndName())
{
- auto actions =
astParser.convertToActions(left.getCurrentDataStream().header.getNamesAndTypesList(),
left_keys);
- QueryPlanStepPtr before_join_step =
std::make_unique<ExpressionStep>(left.getCurrentDataStream(), actions);
- before_join_step->setStepDescription("Before JOIN LEFT");
- steps.emplace_back(before_join_step.get());
- left.addStep(std::move(before_join_step));
+ actions_dag->removeUnusedResult(col.name);
}
+ actions_dag->removeUnusedActions();
- if (!right_keys->children.empty())
- {
- auto actions =
astParser.convertToActions(right.getCurrentDataStream().header.getNamesAndTypesList(),
right_keys);
- QueryPlanStepPtr before_join_step =
std::make_unique<ExpressionStep>(right.getCurrentDataStream(), actions);
- before_join_step->setStepDescription("Before JOIN RIGHT");
- steps.emplace_back(before_join_step.get());
- right.addStep(std::move(before_join_step));
- }
+
table_join.getClauses().back().analyzer_right_filter_condition_column_name =
actions_dag->getOutputs().back()->result_name;
+ QueryPlanStepPtr before_join_step =
std::make_unique<ExpressionStep>(right.getCurrentDataStream(), actions_dag);
+ before_join_step->setStepDescription("Before JOIN RIGHT");
+ steps.emplace_back(before_join_step.get());
+ right.addStep(std::move(before_join_step));
}
}
- // if ch does not support the join type or join conditions, it will throw
an exception like 'not support'.
- catch (Poco::Exception & e)
+ else if (table_sides.size() == 2)
{
- // CH not support join condition has 'or' and has different table in
each side.
- // But in inner join, we could execute join condition after join. so
we have add filter step
- if (e.code() == ErrorCodes::INVALID_JOIN_ON_EXPRESSION &&
table_join.kind() == DB::JoinKind::Inner)
- {
- return true;
- }
- else
- {
- throw;
- }
+ if (!allow_mixed_condition)
+ return false;
+ auto mixed_join_expressions_actions = expressionsToActionsDAG({expr},
mixed_header);
+ table_join.getMixedJoinExpression()
+ =
std::make_shared<DB::ExpressionActions>(mixed_join_expressions_actions,
ExpressionActionsSettings::fromContext(context));
}
- return false;
+ else
+ {
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table
column is used in the join condition.\n{}", join_rel.DebugString());
+ }
+ return true;
+}
+
+void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const
substrait::JoinRel & join)
+{
+ std::string filter_name;
+ auto actions_dag =
std::make_shared<ActionsDAG>(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
+ if (!join.post_join_filter().has_scalar_function())
+ {
+ // It may be singular_or_list
+ auto * in_node = getPlanParser()->parseExpression(actions_dag,
join.post_join_filter());
+ filter_name = in_node->result_name;
+ }
+ else
+ {
+
getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header,
join.post_join_filter(), filter_name, actions_dag, true);
+ }
+ auto filter_step =
std::make_unique<FilterStep>(query_plan.getCurrentDataStream(), actions_dag,
filter_name, true);
+ filter_step->setStepDescription("Post Join Filter");
+ steps.emplace_back(filter_step.get());
+ query_plan.addStep(std::move(filter_step));
}
void registerJoinRelParser(RelParserFactory & factory)
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h
b/cpp-ch/local-engine/Parser/JoinRelParser.h
index 445b7e683..c423f4390 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.h
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.h
@@ -17,6 +17,7 @@
#pragma once
#include <memory>
+#include <unordered_set>
#include <Parser/RelParser.h>
#include <substrait/algebra.pb.h>
@@ -28,6 +29,8 @@ class TableJoin;
namespace local_engine
{
+class StorageJoinFromReadBuffer;
+
std::pair<DB::JoinKind, DB::JoinStrictness>
getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type);
class JoinRelParser : public RelParser
@@ -50,15 +53,22 @@ private:
DB::QueryPlanPtr parseJoin(const substrait::JoinRel & join,
DB::QueryPlanPtr left, DB::QueryPlanPtr right);
+ void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const
StorageJoinFromReadBuffer & storage_join);
void addConvertStep(TableJoin & table_join, DB::QueryPlan & left,
DB::QueryPlan & right);
- bool tryAddPushDownFilter(
- TableJoin & table_join,
- const substrait::JoinRel & join,
- DB::QueryPlan & left,
- DB::QueryPlan & right,
- const NamesAndTypesList & alias_right,
- const Names & names);
+ void collectJoinKeys(
+ TableJoin & table_join, const substrait::JoinRel & join_rel, const
DB::Block & left_header, const DB::Block & right_header);
+
+ bool applyJoinFilter(
+ DB::TableJoin & table_join,
+ const substrait::JoinRel & join_rel,
+ DB::QueryPlan & left_plan,
+ DB::QueryPlan & right_plan,
+ bool allow_mixed_condition);
+
void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join);
+
+ static std::unordered_set<DB::JoinTableSide>
extractTableSidesFromExpression(
+ const substrait::Expression & expr, const DB::Block & left_header,
const DB::Block & right_header);
};
}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index a26f78699..b0d3bbeca 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -214,7 +214,7 @@ std::shared_ptr<ActionsDAG>
SerializedPlanParser::expressionsToActionsDAG(
}
}
}
- else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
+ else if (expr.has_cast() || expr.has_if_then() || expr.has_literal()
|| expr.has_singular_or_list())
{
const auto * node = parseExpression(actions_dag, expr);
actions_dag->addOrReplaceInOutputs(*node);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]