This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new c6639efe0 [CORE][VL] ACBO: Add GlutenMetadataModel, move Gluten schema
def from property model to metadata model (#5159)
c6639efe0 is described below
commit c6639efe0504a2193db1e1ed49a324bc1c7b91d7
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Mar 29 12:55:46 2024 +0800
[CORE][VL] ACBO: Add GlutenMetadataModel, move Gluten schema def from
property model to metadata model (#5159)
---
.../io/glutenproject/planner/VeloxCboSuite.scala | 15 +-
.../src/main/scala/io/glutenproject/cbo/Cbo.scala | 37 ++++-
.../scala/io/glutenproject/cbo/CboCluster.scala | 11 +-
.../main/scala/io/glutenproject/cbo/CboGroup.scala | 2 +-
.../cbo/{PlanModel.scala => MetadataModel.scala} | 22 +--
.../scala/io/glutenproject/cbo/PlanModel.scala | 2 +-
.../glutenproject/cbo/memo/ForwardMemoTable.scala | 34 ++--
.../scala/io/glutenproject/cbo/memo/Memo.scala | 7 +-
.../io/glutenproject/cbo/memo/MemoTable.scala | 2 +-
.../io/glutenproject/cbo/CboMetadataSuite.scala | 183 +++++++++++++++++++++
.../io/glutenproject/cbo/CboOperationSuite.scala | 28 ++--
.../io/glutenproject/cbo/CboPropertySuite.scala | 44 +++--
.../test/scala/io/glutenproject/cbo/CboSuite.scala | 45 +++--
.../scala/io/glutenproject/cbo/CboSuiteBase.scala | 27 ++-
.../io/glutenproject/cbo/mock/MockMemoState.scala | 14 +-
.../io/glutenproject/cbo/path/CboPathSuite.scala | 11 +-
.../glutenproject/cbo/path/PathFinderSuite.scala | 17 +-
.../io/glutenproject/cbo/path/PathMaskSuite.scala | 4 +-
.../io/glutenproject/cbo/path/WizardSuite.scala | 17 +-
.../io/glutenproject/cbo/rule/PatternSuite.scala | 22 ++-
.../cbo/specific/CyclicSearchSpaceSuite.scala | 15 +-
.../cbo/specific/DistributedSuite.scala | 27 ++-
.../cbo/specific/JoinReorderSuite.scala | 41 +++--
.../extension/columnar/EnumeratedTransform.scala | 4 +-
.../glutenproject/planner/GlutenOptimization.scala | 7 +-
.../planner/metadata/GlutenMetadata.scala | 27 +--
.../planner/metadata/GlutenMetadataModel.scala | 48 ++++++
.../planner/plan/GlutenPlanModel.scala | 18 +-
.../planner/property/GlutenPropertyModel.scala | 35 +---
29 files changed, 555 insertions(+), 211 deletions(-)
diff --git
a/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
b/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
index 83ce7c69a..2fbc1e642 100644
--- a/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
@@ -16,11 +16,12 @@
*/
package io.glutenproject.planner
-import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.Cbo
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.path.CboPath
import io.glutenproject.cbo.property.PropertySet
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
-import io.glutenproject.planner.property.GlutenProperties.{Conventions,
Schemas}
+import io.glutenproject.planner.property.GlutenProperties.Conventions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -41,7 +42,7 @@ class VeloxCboSuite extends SharedSparkSession {
test("C2R, R2C - explicitly requires any properties") {
val in = RowUnary(RowLeaf())
val planner =
- newCbo().newPlanner(in, PropertySet(List(Conventions.ANY, Schemas.ANY)))
+ newCbo().newPlanner(in, PropertySet(List(Conventions.ANY)))
val out = planner.plan()
assert(out == RowUnary(RowLeaf()))
}
@@ -49,7 +50,7 @@ class VeloxCboSuite extends SharedSparkSession {
test("C2R, R2C - requires columnar output") {
val in = RowUnary(RowLeaf())
val planner =
- newCbo().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR,
Schemas.ANY)))
+ newCbo().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR)))
val out = planner.plan()
assert(out == RowToColumnarExec(RowUnary(RowLeaf())))
}
@@ -58,7 +59,7 @@ class VeloxCboSuite extends SharedSparkSession {
val in =
ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf())))))))
val planner =
- newCbo().newPlanner(in, PropertySet(List(Conventions.ROW_BASED,
Schemas.ANY)))
+ newCbo().newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
val out = planner.plan()
assert(out == ColumnarToRowExec(ColumnarUnary(
RowToColumnarExec(RowUnary(RowUnary(ColumnarToRowExec(ColumnarUnary(RowToColumnarExec(
@@ -82,7 +83,7 @@ class VeloxCboSuite extends SharedSparkSession {
ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf())))))))
val planner =
newCbo(List(ConvertRowUnaryToColumnar))
- .newPlanner(in, PropertySet(List(Conventions.ROW_BASED, Schemas.ANY)))
+ .newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
val out = planner.plan()
assert(out == ColumnarToRowExec(ColumnarUnary(ColumnarUnary(ColumnarUnary(
ColumnarUnary(ColumnarUnary(ColumnarUnary(ColumnarUnary(RowToColumnarExec(RowLeaf()))))))))))
@@ -92,7 +93,7 @@ class VeloxCboSuite extends SharedSparkSession {
}
}
-object VeloxCboSuite extends CboSuiteBase {
+object VeloxCboSuite {
def newCbo(): Cbo[SparkPlan] = {
GlutenOptimization().asInstanceOf[Cbo[SparkPlan]]
}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
index fa735ec2f..0a723dbf3 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
@@ -38,12 +38,13 @@ trait Optimization[T <: AnyRef] {
object Optimization {
def apply[T <: AnyRef](
- costModel: CostModel[T],
planModel: PlanModel[T],
+ costModel: CostModel[T],
+ metadataModel: MetadataModel[T],
propertyModel: PropertyModel[T],
explain: CboExplain[T],
ruleFactory: CboRule.Factory[T]): Optimization[T] = {
- Cbo(costModel, planModel, propertyModel, explain, ruleFactory)
+ Cbo(planModel, costModel, metadataModel, propertyModel, explain,
ruleFactory)
}
implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
@@ -58,8 +59,9 @@ object Optimization {
class Cbo[T <: AnyRef] private (
val config: CboConfig,
- val costModel: CostModel[T],
val planModel: PlanModel[T],
+ val costModel: CostModel[T],
+ val metadataModel: MetadataModel[T],
val propertyModel: PropertyModel[T],
val explain: CboExplain[T],
val ruleFactory: CboRule.Factory[T])
@@ -67,12 +69,19 @@ class Cbo[T <: AnyRef] private (
import Cbo._
override def withNewConfig(confFunc: CboConfig => CboConfig): Cbo[T] = {
- new Cbo(confFunc(config), costModel, planModel, propertyModel, explain,
ruleFactory)
+ new Cbo(
+ confFunc(config),
+ planModel,
+ costModel,
+ metadataModel,
+ propertyModel,
+ explain,
+ ruleFactory)
}
// Normal groups start with ID 0, so it's safe to use -1 to do validation.
private val dummyGroup: T =
- planModel.newGroupLeaf(-1, PropertySet(Seq.empty))
+ planModel.newGroupLeaf(-1, metadataModel.dummy(), PropertySet(Seq.empty))
private val infCost: Cost = costModel.makeInfCost()
validateModels()
@@ -97,6 +106,12 @@ class Cbo[T <: AnyRef] private (
// Node groups don't have user-defined cost, expect exception here.
costModel.costOf(dummyGroup)
}
+ assertThrows(
+ "Group is not allowed to return its metadata directly to optimizer
(optimizer already" +
+ " knew that). It's expected to throw an exception when getting its
metadata but not") {
+ // Node groups don't have user-defined metadata, expect exception here.
+ metadataModel.metadataOf(dummyGroup)
+ }
propertyModel.propertyDefs.foreach {
propDef =>
// Node groups don't have user-defined property, expect exception here.
@@ -160,12 +175,20 @@ class Cbo[T <: AnyRef] private (
object Cbo {
private[cbo] def apply[T <: AnyRef](
- costModel: CostModel[T],
planModel: PlanModel[T],
+ costModel: CostModel[T],
+ metadataModel: MetadataModel[T],
propertyModel: PropertyModel[T],
explain: CboExplain[T],
ruleFactory: CboRule.Factory[T]): Cbo[T] = {
- new Cbo[T](CboConfig(), costModel, planModel, propertyModel, explain,
ruleFactory)
+ new Cbo[T](
+ CboConfig(),
+ planModel,
+ costModel,
+ metadataModel,
+ propertyModel,
+ explain,
+ ruleFactory)
}
trait PropertySetFactory[T <: AnyRef] {
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
index 153050cef..d55aa3513 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
@@ -21,7 +21,9 @@ import io.glutenproject.cbo.property.PropertySet
import scala.collection.mutable
-trait CboClusterKey
+trait CboClusterKey {
+ def metadata: Metadata
+}
object CboClusterKey {
implicit class CboClusterKeyImplicits[T <: AnyRef](key: CboClusterKey) {
@@ -44,11 +46,11 @@ object CboCluster {
}
object MutableCboCluster {
- def apply[T <: AnyRef](cbo: Cbo[T]): MutableCboCluster[T] = {
- new RegularMutableCboCluster(cbo)
+ def apply[T <: AnyRef](cbo: Cbo[T], metadata: Metadata):
MutableCboCluster[T] = {
+ new RegularMutableCboCluster(cbo, metadata)
}
- private class RegularMutableCboCluster[T <: AnyRef](val cbo: Cbo[T])
+ private class RegularMutableCboCluster[T <: AnyRef](val cbo: Cbo[T],
metadata: Metadata)
extends MutableCboCluster[T] {
private val buffer: mutable.Set[CanonicalNode[T]] =
mutable.Set()
@@ -58,6 +60,7 @@ object CboCluster {
}
override def add(t: CanonicalNode[T]): Unit = {
+ cbo.metadataModel.verify(metadata,
cbo.metadataModel.metadataOf(t.self()))
assert(!buffer.contains(t))
buffer += t
}
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
index 025e664ec..c83475da9 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
@@ -42,7 +42,7 @@ object CboGroup {
override val id: Int,
override val propSet: PropertySet[T])
extends CboGroup[T] {
- private val groupLeaf: T = cbo.planModel.newGroupLeaf(id, propSet)
+ private val groupLeaf: T = cbo.planModel.newGroupLeaf(id,
clusterKey.metadata, propSet)
override def clusterKey(): CboClusterKey = clusterKey
override def self(): T = groupLeaf
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/MetadataModel.scala
similarity index 66%
copy from gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
copy to
gluten-cbo/common/src/main/scala/io/glutenproject/cbo/MetadataModel.scala
index 366d1575f..aea5cb154 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/MetadataModel.scala
@@ -16,17 +16,13 @@
*/
package io.glutenproject.cbo
-import io.glutenproject.cbo.property.PropertySet
-
-trait PlanModel[T <: AnyRef] {
- // Trivial tree operations.
- def childrenOf(node: T): Seq[T]
- def withNewChildren(node: T, children: Seq[T]): T
- def hashCode(node: T): Int
- def equals(one: T, other: T): Boolean
-
- // Group operations.
- def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T
- def isGroupLeaf(node: T): Boolean
- def getGroupId(node: T): Int
+/**
+ * Metadata defines the common traits among nodes in one single cluster. E.g. Schema, statistics.
+ */
+trait MetadataModel[T <: AnyRef] {
+ def metadataOf(node: T): Metadata
+ def dummy(): Metadata
+ def verify(one: Metadata, other: Metadata): Unit
}
+
+trait Metadata {}
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
index 366d1575f..e51a644df 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
@@ -26,7 +26,7 @@ trait PlanModel[T <: AnyRef] {
def equals(one: T, other: T): Boolean
// Group operations.
- def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T
+ def newGroupLeaf(groupId: Int, meta: Metadata, propSet: PropertySet[T]): T
def isGroupLeaf(node: T): Boolean
def getGroupId(node: T): Int
}
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
index 50d27a679..a5186150a 100644
---
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
+++
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
@@ -44,11 +44,11 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
clusterBuffer(ancestor)
}
- override def newCluster(): CboClusterKey = {
+ override def newCluster(metadata: Metadata): CboClusterKey = {
checkBufferSizes()
- val key = IntClusterKey(clusterBuffer.size)
+ val key = IntClusterKey(clusterBuffer.size, metadata)
clusterKeyBuffer += key
- clusterBuffer += MutableCboCluster(cbo)
+ clusterBuffer += MutableCboCluster(cbo, metadata)
clusterDisjointSet.grow()
groupLookup += mutable.Map()
key
@@ -62,7 +62,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
}
val gid = groupBuffer.size
val newGroup =
- CboGroup(cbo, IntClusterKey(ancestor), gid, propSet)
+ CboGroup(cbo, IntClusterKey(ancestor, key.metadata), gid, propSet)
lookup += propSet -> newGroup
groupBuffer += newGroup
memoWriteCount += 1
@@ -88,19 +88,21 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
return
}
- case class Merge(from: Int, to: Int)
+ case class Merge(from: CboClusterKey, to: CboClusterKey) {
+ cbo.metadataModel.verify(from.metadata, to.metadata)
+ }
val merge = if (oneAncestor > otherAncestor) {
- Merge(oneAncestor, otherAncestor)
+ Merge(clusterKeyBuffer(oneAncestor), clusterKeyBuffer(otherAncestor))
} else {
- Merge(otherAncestor, oneAncestor)
+ Merge(clusterKeyBuffer(otherAncestor), clusterKeyBuffer(oneAncestor))
}
- val fromKey = IntClusterKey(merge.from)
- val toKey = IntClusterKey(merge.to)
+ val fromKey = merge.from
+ val toKey = merge.to
- val fromCluster = clusterBuffer(merge.from)
- val toCluster = clusterBuffer(merge.to)
+ val fromCluster = clusterBuffer(fromKey.id())
+ val toCluster = clusterBuffer(toKey.id())
// Add absent nodes.
fromCluster.nodes().foreach {
@@ -111,8 +113,8 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
}
// Add absent groups.
- val fromGroups = groupLookup(merge.from)
- val toGroups = groupLookup(merge.to)
+ val fromGroups = groupLookup(fromKey.id())
+ val toGroups = groupLookup(toKey.id())
fromGroups.foreach {
case (fromPropSet, _) =>
if (!toGroups.contains(fromPropSet)) {
@@ -121,8 +123,8 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
}
// Forward the element in disjoint set.
- clusterDisjointSet.forward(merge.from, merge.to)
- clusterMergeLog += (merge.from -> merge.to)
+ clusterDisjointSet.forward(fromKey.id(), toKey.id())
+ clusterMergeLog += (fromKey.id() -> toKey.id())
memoWriteCount += 1
}
@@ -150,7 +152,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val
cbo: Cbo[T])
object ForwardMemoTable {
def apply[T <: AnyRef](cbo: Cbo[T]): MemoTable.Writable[T] = new
ForwardMemoTable[T](cbo)
- private case class IntClusterKey(id: Int) extends CboClusterKey
+ private case class IntClusterKey(id: Int, metadata: Metadata) extends
CboClusterKey
private class Probe[T <: AnyRef](table: ForwardMemoTable[T]) extends
MemoTable.Probe[T] {
private val probedClusterCount: Int = table.clusterKeyBuffer.size
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
index 7cbee7c39..8907ab84a 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
@@ -53,8 +53,8 @@ object Memo {
private val memoTable: MemoTable.Writable[T] = MemoTable.create(cbo)
private val cache: NodeToClusterMap[T] = new NodeToClusterMap(cbo)
- private def newCluster(): CboClusterKey = {
- memoTable.newCluster()
+ private def newCluster(metadata: Metadata): CboClusterKey = {
+ memoTable.newCluster(metadata)
}
private def addToCluster(clusterKey: CboClusterKey, can:
CanonicalNode[T]): Unit = {
@@ -107,7 +107,8 @@ object Memo {
cache.get(node)
} else {
// Node not yet added to cluster.
- val clusterKey = newCluster()
+ val meta = cbo.metadataModel.metadataOf(node.self())
+ val clusterKey = newCluster(meta)
addToCluster(clusterKey, node)
clusterKey
}
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
index 755d2c15d..894abfba2 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
@@ -42,7 +42,7 @@ object MemoTable {
def create[T <: AnyRef](cbo: Cbo[T]): Writable[T] = ForwardMemoTable(cbo)
trait Writable[T <: AnyRef] extends MemoTable[T] {
- def newCluster(): CboClusterKey
+ def newCluster(metadata: Metadata): CboClusterKey
def groupOf(key: CboClusterKey, propertySet: PropertySet[T]): CboGroup[T]
def addToCluster(key: CboClusterKey, node: CanonicalNode[T]): Unit
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboMetadataSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboMetadataSuite.scala
new file mode 100644
index 000000000..8d2d49b44
--- /dev/null
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboMetadataSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.CboSuiteBase._
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class ExhaustiveCboMetadataSuite extends CboMetadataSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType =
PlannerType.Exhaustive)
+}
+
+class DpCboMetadataSuite extends CboMetadataSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType =
PlannerType.Dp)
+}
+
+abstract class CboMetadataSuite extends AnyFunSuite {
+ import CboMetadataSuite._
+ protected def conf: CboConfig
+
+ test("Dry run") {
+ val cbo =
+ Cbo[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ RowCountMetadataModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val planner = cbo.newPlanner(KnownRowCountUnary(0.5,
KnownRowCountLeaf(2000)))
+ val out = planner.plan()
+ assert(out == KnownRowCountUnary(0.5, KnownRowCountLeaf(2000)))
+ }
+
+ test("Trivial planning") {
+ object CombineUnaryNodes extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case KnownRowCountUnary(0.25d, KnownRowCountUnary(2.0d, child)) =>
+ assert(child.isInstanceOf[Group])
+ assert(child.asInstanceOf[Group].meta.isInstanceOf[IntRowCount])
+
assert(child.asInstanceOf[Group].meta.asInstanceOf[IntRowCount].value == 2000)
+ List(KnownRowCountUnary(0.5d, child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ RowCountMetadataModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(CombineUnaryNodes)))
+ .withNewConfig(_ => conf)
+
+ val planner =
+ cbo.newPlanner(KnownRowCountUnary(0.25d, KnownRowCountUnary(2.0d,
KnownRowCountLeaf(2000))))
+ val out = planner.plan()
+ assert(out == KnownRowCountUnary(0.5d, KnownRowCountLeaf(2000)))
+ }
+
+ test("Cluster merge") {
+ object CombineUnaryNodes extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case KnownRowCountUnary(0.25d, KnownRowCountUnary(2.0d, child)) =>
+ assert(child.isInstanceOf[Group])
+ assert(child.asInstanceOf[Group].meta.isInstanceOf[IntRowCount])
+
assert(child.asInstanceOf[Group].meta.asInstanceOf[IntRowCount].value == 2000)
+ List(KnownRowCountUnary(0.5d, child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ RowCountMetadataModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(CombineUnaryNodes)))
+ .withNewConfig(_ => conf)
+
+ val in = KnownRowCountBinary(
+ KnownRowCountUnary(0.5d, KnownRowCountLeaf(2000)),
+ KnownRowCountUnary(0.25d, KnownRowCountUnary(2.0d,
KnownRowCountLeaf(2000))))
+ val planner = cbo.newPlanner(in)
+ val out = planner.plan()
+ assert(
+ out == KnownRowCountBinary(
+ KnownRowCountUnary(0.5, KnownRowCountLeaf(2000)),
+ KnownRowCountUnary(0.5, KnownRowCountLeaf(2000))))
+ }
+}
+
+object CboMetadataSuite {
+ private object RowCountMetadataModel extends MetadataModel[TestNode] {
+ override def metadataOf(node: TestNode): Metadata = node match {
+ case n: KnownRowCountNode =>
+ IntRowCount(n.rowCount())
+ case other =>
+ throw new UnsupportedOperationException()
+ }
+
+ override def dummy(): Metadata = IntRowCount(0)
+ override def verify(one: Metadata, other: Metadata): Unit = (one, other)
match {
+ case (IntRowCount(a), IntRowCount(b)) =>
+ assert(a == b)
+ case other =>
+ throw new UnsupportedOperationException()
+ }
+ }
+
+ trait RowCount extends Metadata
+
+ case class IntRowCount(value: Int) extends RowCount
+
+ trait KnownRowCountNode extends TestNode {
+ def rowCount(): Int
+ }
+
+ case class KnownRowCountUnary(selectivity: Double, override val child:
TestNode)
+ extends UnaryLike
+ with KnownRowCountNode {
+ private val childRowCount = child match {
+ case n: KnownRowCountNode => n.rowCount()
+ case g: Group => g.meta.asInstanceOf[IntRowCount].value
+ case other => throw new UnsupportedOperationException()
+ }
+
+ override def withNewChildren(child: TestNode): UnaryLike =
copy(selectivity, child)
+ override def rowCount(): Int = (childRowCount * selectivity).toInt
+ override def selfCost(): Long = childRowCount
+ }
+
+ case class KnownRowCountLeaf(rowCount: Int) extends LeafLike with
KnownRowCountNode {
+ override def makeCopy(): LeafLike = this
+ override def selfCost(): Long = rowCount
+ }
+
+ case class KnownRowCountBinary(override val left: TestNode, override val
right: TestNode)
+ extends BinaryLike
+ with KnownRowCountNode {
+ private val leftRowCount = left match {
+ case n: KnownRowCountNode => n.rowCount()
+ case g: Group => g.meta.asInstanceOf[IntRowCount].value
+ case other => throw new UnsupportedOperationException()
+ }
+
+ private val rightRowCount = right match {
+ case n: KnownRowCountNode => n.rowCount()
+ case g: Group => g.meta.asInstanceOf[IntRowCount].value
+ case other => throw new UnsupportedOperationException()
+ }
+
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike
= copy(left, right)
+ override def rowCount(): Int = leftRowCount * rightRowCount
+ override def selfCost(): Long = leftRowCount + rightRowCount
+ }
+}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
index 2da8a1b98..497b88e3b 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
@@ -16,6 +16,7 @@
*/
package io.glutenproject.cbo
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.path.CboPath
import io.glutenproject.cbo.property.PropertySet
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
@@ -47,8 +48,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2, Unary2Unary2ToUnary3)))
@@ -68,8 +70,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2, u2u22u3)))
@@ -94,8 +97,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
+ PlanModelImpl,
CostModelImpl,
- planModel,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2)))
@@ -123,8 +127,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
+ PlanModelImpl,
CostModelImpl,
- planModel,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2, u2u22u3)))
@@ -158,8 +163,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
+ PlanModelImpl,
CostModelImpl,
- planModel,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(new UnaryToUnary2(), new LeafToLeaf2(),
rule)))
@@ -198,8 +204,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- costModel,
PlanModelImpl,
+ costModel,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(new UnaryToUnary2, new Unary2ToUnary3)))
@@ -259,8 +266,9 @@ class CboOperationSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- costModel,
PlanModelImpl,
+ costModel,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(new UnaryToUnary2, new Unary2ToUnary3)))
@@ -288,7 +296,7 @@ class CboOperationSuite extends AnyFunSuite {
}
}
-object CboOperationSuite extends CboSuiteBase {
+object CboOperationSuite {
case class Unary(override val selfCost: Long, override val child: TestNode)
extends UnaryLike {
override def withNewChildren(child: TestNode): UnaryLike = copy(child =
child)
@@ -403,9 +411,9 @@ object CboOperationSuite extends CboSuiteBase {
equalsCount += 1
delegated.equals(one, other)
}
- override def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T = {
+ override def newGroupLeaf(groupId: Int, metadata: Metadata, propSet:
PropertySet[T]): T = {
newGroupLeafCount += 1
- delegated.newGroupLeaf(groupId, propSet)
+ delegated.newGroupLeaf(groupId, metadata, propSet)
}
override def isGroupLeaf(node: T): Boolean = {
isGroupLeafCount += 1
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
index 0225601a2..4f25ff2d2 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
@@ -18,6 +18,7 @@ package io.glutenproject.cbo
import io.glutenproject.cbo.Best.BestNotFoundException
import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.property.PropertySet
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
@@ -57,8 +58,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
test(s"Cannot enforce property") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModelWithOutEnforcerRules,
ExplainImpl,
CboRule.Factory.none())
@@ -73,8 +75,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
test(s"Property enforcement - A to B") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -90,8 +93,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
test(s"Property convert - (A, B)") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceByTypeBRule)))
@@ -145,8 +149,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModelWithOutEnforcerRules,
ExplainImpl,
CboRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, HitCacheOp,
FinalOp))
@@ -190,8 +195,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModelWithOutEnforcerRules,
ExplainImpl,
CboRule.Factory.reuse(List(ReplaceLeafAByLeafBRule,
ReplaceUnaryBByUnaryARule))
@@ -218,8 +224,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(ConvertEnforcerAndTypeAToTypeB)))
@@ -255,11 +262,13 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
- CboRule.Factory.reuse(List(ReplaceByTypeARule,
ReplaceNonUnaryByTypeBRule)))
+ CboRule.Factory.reuse(List(ReplaceByTypeARule,
ReplaceNonUnaryByTypeBRule))
+ )
.withNewConfig(_ => conf)
val plan =
TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)),
TypedLeaf(TypeA, 10))
@@ -305,8 +314,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC)))
@@ -356,8 +366,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(LeftOp, RightOp))
@@ -402,8 +413,9 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(ConvertTypeBEnforcerAndLeafToTypeC,
ConvertTypeATypeCToTypeC))
@@ -454,11 +466,13 @@ abstract class CboPropertySuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
NodeTypePropertyModel,
ExplainImpl,
- CboRule.Factory.reuse(List(ReplaceNonUnaryByTypeBRule,
ReduceTypeBCost)))
+ CboRule.Factory.reuse(List(ReplaceNonUnaryByTypeBRule,
ReduceTypeBCost))
+ )
.withNewConfig(_ => conf)
val plan =
TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)),
TypedLeaf(TypeA, 10))
@@ -473,7 +487,7 @@ abstract class CboPropertySuite extends AnyFunSuite {
}
}
-object CboPropertySuite extends CboSuiteBase {
+object CboPropertySuite {
case class NoopEnforcerRule[T <: AnyRef]() extends CboRule[T] {
override def shift(node: T): Iterable[T] = List.empty
@@ -519,7 +533,7 @@ object CboPropertySuite extends CboSuiteBase {
object DummyPropertyDef extends PropertyDef[TestNode, DummyProperty] {
override def getProperty(plan: TestNode): DummyProperty = {
plan match {
- case Group(_, _) => throw new IllegalStateException()
+ case Group(_, _, _) => throw new IllegalStateException()
case PUnary(_, prop, _) => prop
case PLeaf(_, prop) => prop
case PBinary(_, prop, _, _) => prop
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
index a930ac356..b1a9fc11d 100644
--- a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
@@ -17,6 +17,7 @@
package io.glutenproject.cbo
import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.memo.Memo
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
@@ -38,8 +39,9 @@ abstract class CboSuite extends AnyFunSuite {
test("Group memo - re-memorize") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -53,8 +55,9 @@ abstract class CboSuite extends AnyFunSuite {
test("Group memo - define equivalence") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -71,8 +74,9 @@ abstract class CboSuite extends AnyFunSuite {
test("Group memo - define equivalence: binary with similar children, 1") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -91,8 +95,9 @@ abstract class CboSuite extends AnyFunSuite {
test("Group memo - define equivalence: binary with similar children, 2") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -111,8 +116,9 @@ abstract class CboSuite extends AnyFunSuite {
test("Group memo - partial canonical") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -155,8 +161,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(DivideUnaryCost, DecreaseUnaryCost)))
@@ -181,8 +188,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(InsertUnary2)))
@@ -213,8 +221,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(DivideBinaryCost)))
@@ -240,8 +249,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(SymmetricRule)))
@@ -270,8 +280,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(BinarySwap)))
@@ -297,8 +308,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(BinarySwap)))
@@ -324,8 +336,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(Unary2Unary3)))
@@ -345,8 +358,9 @@ abstract class CboSuite extends AnyFunSuite {
val u2u2 = new UnaryToUnary2()
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2)))
@@ -381,8 +395,9 @@ abstract class CboSuite extends AnyFunSuite {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(l2l2, u2u2, Unary2Unary2ToUnary3)))
@@ -398,7 +413,7 @@ abstract class CboSuite extends AnyFunSuite {
}
}
-object CboSuite extends CboSuiteBase {
+object CboSuite {
case class Binary(
override val selfCost: Long,
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
index d718682f6..292e97da0 100644
--- a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
@@ -20,7 +20,7 @@ import io.glutenproject.cbo.memo.{MemoLike, MemoState}
import io.glutenproject.cbo.path.{CboPath, PathFinder}
import io.glutenproject.cbo.property.PropertySet
-trait CboSuiteBase {
+object CboSuiteBase {
trait TestNode {
def selfCost(): Long
def children(): Seq[TestNode]
@@ -49,14 +49,14 @@ trait CboSuiteBase {
def withNewChildren(children: Seq[TestNode]): TestNode = this
}
- case class Group(id: Int, propSet: PropertySet[TestNode]) extends LeafLike {
+ case class Group(id: Int, meta: Metadata, propSet: PropertySet[TestNode])
extends LeafLike {
override def selfCost(): Long = Long.MaxValue
override def makeCopy(): LeafLike = copy()
}
object Group {
def apply(id: Int): Group = {
- Group(id, PropertySet(List.empty))
+ Group(id, MetadataModelImpl.DummyMetadata, PropertySet(List.empty))
}
}
@@ -110,8 +110,11 @@ trait CboSuiteBase {
java.util.Objects.equals(one, other)
}
- override def newGroupLeaf(groupId: Int, propSet: PropertySet[TestNode]):
TestNode =
- Group(groupId, propSet)
+ // Builds the group placeholder node: a Group carrying the group id, the given metadata
+ // and the given property set.
+ override def newGroupLeaf(
+ groupId: Int,
+ meta: Metadata,
+ propSet: PropertySet[TestNode]): TestNode =
+ Group(groupId, meta, propSet)
override def getGroupId(node: TestNode): Int = node match {
case ngl: Group => ngl.id
@@ -130,6 +133,20 @@ trait CboSuiteBase {
}
}
+  /**
+   * [[MetadataModel]] used by the test suites. Every ordinary test node is mapped to the
+   * single `DummyMetadata` value.
+   */
+  object MetadataModelImpl extends MetadataModel[TestNode] {
+    case object DummyMetadata extends Metadata
+
+    // Group leaves are rejected; any other TestNode yields the singleton. The trailing
+    // wildcard only fires for input that matches neither type pattern.
+    override def metadataOf(node: TestNode): Metadata = node match {
+      case _: Group => throw new UnsupportedOperationException()
+      case _: TestNode => DummyMetadata
+      case _ => throw new UnsupportedOperationException()
+    }
+
+    override def dummy(): Metadata = DummyMetadata
+
+    // Both arguments must be the singleton; `one` is checked before `other`.
+    override def verify(one: Metadata, other: Metadata): Unit = {
+      assert(one == DummyMetadata)
+      assert(other == DummyMetadata)
+    }
+  }
+
object PropertyModelImpl extends PropertyModel[TestNode] {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = List.empty
override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
index 6e95dceb6..638f256db 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
@@ -17,6 +17,7 @@
package io.glutenproject.cbo.mock
import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.memo.{MemoState, MemoStore}
import io.glutenproject.cbo.property.PropertySet
import io.glutenproject.cbo.vis.GraphvizVisualizer
@@ -64,7 +65,7 @@ object MockMemoState {
def newCluster(): MockMutableCluster[T] = {
val id = clusterBuffer.size
val key = MockMutableCluster.DummyIntClusterKey(id)
- val cluster = MockMutableCluster[T](cbo, key, propSet, groupFactory)
+ val cluster = MockMutableCluster[T](cbo, key, groupFactory)
clusterBuffer += (key -> cluster)
cluster
}
@@ -103,12 +104,13 @@ object MockMemoState {
def apply[T <: AnyRef](
cbo: Cbo[T],
key: CboClusterKey,
- propSet: PropertySet[T],
groupFactory: MockMutableGroup.Factory[T]): MockMutableCluster[T] = {
new MockMutableCluster[T](cbo, key, groupFactory)
}
- case class DummyIntClusterKey(id: Int) extends CboClusterKey
+ // Cluster key identified by an integer id; its metadata is always
+ // MetadataModelImpl.DummyMetadata.
+ case class DummyIntClusterKey(id: Int) extends CboClusterKey {
+ override def metadata: Metadata = MetadataModelImpl.DummyMetadata
+ }
}
class MockMutableGroup[T <: AnyRef] private (
@@ -137,7 +139,11 @@ object MockMemoState {
def newGroup(clusterKey: CboClusterKey): MockMutableGroup[T] = {
val id = groupBuffer.size
val group =
- new MockMutableGroup[T](id, clusterKey, propSet,
cbo.planModel.newGroupLeaf(id, propSet))
+ new MockMutableGroup[T](
+ id,
+ clusterKey,
+ propSet,
+ cbo.planModel.newGroupLeaf(id, clusterKey.metadata, propSet))
groupBuffer += group
group
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
index fff583874..4bc1014de 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
@@ -16,7 +16,8 @@
*/
package io.glutenproject.cbo.path
-import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.Cbo
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.mock.MockCboPath
import io.glutenproject.cbo.rule.CboRule
@@ -28,8 +29,9 @@ class CboPathSuite extends AnyFunSuite {
test("Path aggregate - empty") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List.empty))
@@ -39,8 +41,9 @@ class CboPathSuite extends AnyFunSuite {
test("Path aggregate") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List.empty))
@@ -91,7 +94,7 @@ class CboPathSuite extends AnyFunSuite {
}
}
-object CboPathSuite extends CboSuiteBase {
+object CboPathSuite {
case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
override def makeCopy(): LeafLike = this
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
index a25cc2dda..8ed2ef58e 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
@@ -16,7 +16,8 @@
*/
package io.glutenproject.cbo.path
-import io.glutenproject.cbo.{CanonicalNode, Cbo, CboSuiteBase}
+import io.glutenproject.cbo.{CanonicalNode, Cbo}
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.mock.MockMemoState
import io.glutenproject.cbo.rule.CboRule
@@ -28,8 +29,9 @@ class PathFinderSuite extends AnyFunSuite {
test("Base") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -82,8 +84,9 @@ class PathFinderSuite extends AnyFunSuite {
test("Find - multiple depths") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -161,8 +164,9 @@ class PathFinderSuite extends AnyFunSuite {
test("Dive - basic") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -219,8 +223,9 @@ class PathFinderSuite extends AnyFunSuite {
test("Find/Dive - binary with different children heights") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -286,7 +291,7 @@ class PathFinderSuite extends AnyFunSuite {
}
}
-object PathFinderSuite extends CboSuiteBase {
+object PathFinderSuite {
case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
override def makeCopy(): LeafLike = this
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
index ef784e928..310435daa 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
@@ -16,8 +16,6 @@
*/
package io.glutenproject.cbo.path
-import io.glutenproject.cbo.CboSuiteBase
-
import org.scalatest.funsuite.AnyFunSuite
class PathMaskSuite extends AnyFunSuite {
@@ -110,4 +108,4 @@ class PathMaskSuite extends AnyFunSuite {
}
}
-object PathMaskSuite extends CboSuiteBase {}
+object PathMaskSuite {}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
index 3bcb6d55e..e0f6df835 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
@@ -16,7 +16,8 @@
*/
package io.glutenproject.cbo.path
-import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.Cbo
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.mock.MockMemoState
import io.glutenproject.cbo.rule.CboRule
@@ -28,8 +29,9 @@ class WizardSuite extends AnyFunSuite {
test("None") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -52,8 +54,9 @@ class WizardSuite extends AnyFunSuite {
test("Prune by maximum depth") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -128,8 +131,9 @@ class WizardSuite extends AnyFunSuite {
test("Prune by pattern") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -218,8 +222,9 @@ class WizardSuite extends AnyFunSuite {
test("Prune by mask") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -286,7 +291,7 @@ class WizardSuite extends AnyFunSuite {
}
}
-object WizardSuite extends CboSuiteBase {
+object WizardSuite {
case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
override def makeCopy(): LeafLike = this
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
index c45eea162..f787aa18a 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
@@ -16,7 +16,8 @@
*/
package io.glutenproject.cbo.rule
-import io.glutenproject.cbo.{rule, Cbo, CboSuiteBase}
+import io.glutenproject.cbo.Cbo
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.mock.MockCboPath
import io.glutenproject.cbo.path.{CboPath, Pattern}
@@ -27,8 +28,9 @@ class PatternSuite extends AnyFunSuite {
test("Match any") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -43,8 +45,9 @@ class PatternSuite extends AnyFunSuite {
test("Match ignore") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -59,8 +62,9 @@ class PatternSuite extends AnyFunSuite {
test("Match unary") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -81,8 +85,9 @@ class PatternSuite extends AnyFunSuite {
test("Match binary") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -113,8 +118,9 @@ class PatternSuite extends AnyFunSuite {
test("Matches above a certain depth") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -169,7 +175,7 @@ class PatternSuite extends AnyFunSuite {
}
}
-object PatternSuite extends CboSuiteBase {
+object PatternSuite {
case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
override def makeCopy(): LeafLike = this
}
@@ -188,7 +194,7 @@ object PatternSuite extends CboSuiteBase {
}
case class DummyGroup() extends LeafLike {
- override def makeCopy(): rule.PatternSuite.LeafLike = throw new
UnsupportedOperationException()
+ override def makeCopy(): LeafLike = throw new
UnsupportedOperationException()
override def selfCost(): Long = throw new UnsupportedOperationException()
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
index 80faa09db..572b6e2d7 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
@@ -17,6 +17,7 @@
package io.glutenproject.cbo.specific
import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.best.BestFinder
import io.glutenproject.cbo.memo.MemoState
import io.glutenproject.cbo.mock.MockMemoState
@@ -39,8 +40,9 @@ abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
test("Cyclic - find paths, simple self cycle") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -73,8 +75,9 @@ abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
test("Cyclic - find best, simple self cycle") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -101,8 +104,9 @@ abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
test("Cyclic - find best, case 1") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -164,8 +168,9 @@ abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
test("Cyclic - find best, case 2") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -208,7 +213,7 @@ abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
}
}
-object CyclicSearchSpaceSuite extends CboSuiteBase {
+object CyclicSearchSpaceSuite {
case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
override def makeCopy(): LeafLike = this
}
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
index 6910ebcc1..733c04a5d 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
@@ -18,6 +18,7 @@ package io.glutenproject.cbo.specific
import io.glutenproject.cbo._
import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.property.PropertySet
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
@@ -39,8 +40,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Project - dry run") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -55,8 +57,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Project - required distribution") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -72,8 +75,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Aggregate - none-distribution constraint") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -92,8 +96,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Project - required ordering") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -109,8 +114,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Project - required distribution and ordering") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -128,8 +134,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Aggregate - avoid re-exchange") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -147,8 +154,9 @@ abstract class DistributedSuite extends AnyFunSuite {
test("Aggregate - avoid re-exchange, required ordering") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.none())
@@ -167,8 +175,9 @@ abstract class DistributedSuite extends AnyFunSuite {
ignore("Aggregate - avoid re-exchange, partial") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
DistributedPropertyModel,
ExplainImpl,
CboRule.Factory.reuse(List(PartialAggregateRule)))
@@ -191,7 +200,7 @@ abstract class DistributedSuite extends AnyFunSuite {
}
}
-object DistributedSuite extends CboSuiteBase {
+object DistributedSuite {
object PartialAggregateRule extends CboRule[TestNode] {
diff --git
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
index e8284649b..456004f4c 100644
---
a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
+++
b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
@@ -16,8 +16,9 @@
*/
package io.glutenproject.cbo.specific
-import io.glutenproject.cbo.{Cbo, CboConfig, CboSuiteBase}
+import io.glutenproject.cbo.{Cbo, CboConfig}
import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.CboSuiteBase._
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -38,8 +39,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("3 way join - dry run") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.none())
@@ -53,8 +55,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("3 way join - reorder") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
@@ -68,8 +71,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("5 way join - reorder") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
@@ -88,8 +92,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
ignore("7 way join - reorder") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
@@ -108,8 +113,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
ignore("9 way join - reorder") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
@@ -129,8 +135,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
ignore("12 way join - reorder") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
@@ -152,8 +159,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("2 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(2)))
@@ -167,8 +175,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("3 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(3)))
@@ -182,8 +191,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("5 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(5)))
@@ -201,8 +211,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("7 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(7)))
@@ -227,8 +238,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
test("9 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(9)))
@@ -263,8 +275,9 @@ abstract class JoinReorderSuite extends AnyFunSuite {
ignore("12 way join - reorder, left deep only") {
val cbo =
Cbo[TestNode](
- CostModelImpl,
PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
CboRule.Factory.reuse(leftDeepJoinRules(12)))
@@ -284,7 +297,7 @@ abstract class JoinReorderSuite extends AnyFunSuite {
}
}
-object JoinReorderSuite extends CboSuiteBase {
+object JoinReorderSuite {
object JoinAssociateRule extends CboRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = node match {
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
index 42ffec155..5cc091fbe 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
@@ -45,9 +45,9 @@ case class EnumeratedTransform(session: SparkSession,
outputsColumnar: Boolean)
Seq(GlutenProperties.Conventions.GLUTEN_COLUMNAR,
GlutenProperties.Conventions.ROW_BASED)
override def apply(plan: SparkPlan): SparkPlan = {
- val constraintSet = PropertySet(List(GlutenProperties.Schemas.ANY,
reqConvention))
+ val constraintSet = PropertySet(List(reqConvention))
val altConstraintSets =
- altConventions.map(altConv =>
PropertySet(List(GlutenProperties.Schemas.ANY, altConv)))
+ altConventions.map(altConv => PropertySet(List(altConv)))
val planner = optimization.newPlanner(plan, constraintSet,
altConstraintSets)
val out = planner.plan()
out
diff --git
a/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
b/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
index e1540e553..b0972bafa 100644
---
a/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
@@ -19,6 +19,7 @@ package io.glutenproject.planner
import io.glutenproject.cbo.{CboExplain, Optimization}
import io.glutenproject.cbo.rule.CboRule
import io.glutenproject.planner.cost.GlutenCostModel
+import io.glutenproject.planner.metadata.GlutenMetadataModel
import io.glutenproject.planner.plan.GlutenPlanModel
import io.glutenproject.planner.property.GlutenPropertyModel
import io.glutenproject.planner.rule.GlutenRules
@@ -32,8 +33,9 @@ object GlutenOptimization {
def apply(): Optimization[SparkPlan] = {
Optimization[SparkPlan](
- GlutenCostModel(),
GlutenPlanModel(),
+ GlutenCostModel(),
+ GlutenMetadataModel(),
GlutenPropertyModel(),
GlutenExplain,
CboRule.Factory.reuse(GlutenRules()))
@@ -41,8 +43,9 @@ object GlutenOptimization {
def apply(rules: Seq[CboRule[SparkPlan]]): Optimization[SparkPlan] = {
Optimization[SparkPlan](
- GlutenCostModel(),
GlutenPlanModel(),
+ GlutenCostModel(),
+ GlutenMetadataModel(),
GlutenPropertyModel(),
GlutenExplain,
CboRule.Factory.reuse(rules))
diff --git
a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
b/gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadata.scala
similarity index 64%
copy from gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
copy to
gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadata.scala
index 366d1575f..458b37309 100644
--- a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadata.scala
@@ -14,19 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package io.glutenproject.cbo
+package io.glutenproject.planner.metadata
-import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.Metadata
-trait PlanModel[T <: AnyRef] {
- // Trivial tree operations.
- def childrenOf(node: T): Seq[T]
- def withNewChildren(node: T, children: Seq[T]): T
- def hashCode(node: T): Int
- def equals(one: T, other: T): Boolean
+import org.apache.spark.sql.catalyst.expressions.Attribute
- // Group operations.
- def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T
- def isGroupLeaf(node: T): Boolean
- def getGroupId(node: T): Int
+/** Gluten's [[Metadata]] flavor; exposes a [[GlutenMetadata.Schema]]. */
+sealed trait GlutenMetadata extends Metadata {
+  def schema(): GlutenMetadata.Schema
+}
+
+object GlutenMetadata {
+
+  /** Wrapper around a sequence of output attributes. */
+  case class Schema(output: Seq[Attribute])
+
+  /** Creates a metadata value holding `schema`; the concrete class stays private. */
+  def apply(schema: Schema): Metadata = Impl(schema)
+
+  private case class Impl(schema: Schema) extends GlutenMetadata
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadataModel.scala
b/gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadataModel.scala
new file mode 100644
index 000000000..f7f06d942
--- /dev/null
+++
b/gluten-core/src/main/scala/io/glutenproject/planner/metadata/GlutenMetadataModel.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.metadata
+
+import io.glutenproject.cbo.{Metadata, MetadataModel}
+import io.glutenproject.planner.metadata.GlutenMetadata.Schema
+import io.glutenproject.planner.plan.GlutenPlanModel.GroupLeafExec
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Factory for the [[MetadataModel]] Gluten uses for [[SparkPlan]] trees. The metadata
+ * tracked per node is its output schema, wrapped in [[GlutenMetadata]].
+ */
+object GlutenMetadataModel extends Logging {
+  def apply(): MetadataModel[SparkPlan] = {
+    MetadataModelImpl
+  }
+
+  private object MetadataModelImpl extends MetadataModel[SparkPlan] {
+
+    /** Derives metadata from a plan node's output; group leaves are rejected. */
+    override def metadataOf(node: SparkPlan): Metadata = node match {
+      case _: GroupLeafExec => throw new UnsupportedOperationException()
+      case other => GlutenMetadata(Schema(other.output))
+    }
+
+    /** Placeholder metadata with an empty schema. */
+    override def dummy(): Metadata = GlutenMetadata(Schema(List()))
+
+    /**
+     * Compares two metadata values. A schema difference between two [[GlutenMetadata]]
+     * values is only logged; any other mismatch fails with an IllegalStateException.
+     */
+    override def verify(one: Metadata, other: Metadata): Unit = (one, other) match {
+      case (left: GlutenMetadata, right: GlutenMetadata) if left.schema() != right.schema() =>
+        // The restriction on schema is deliberately loose: Gluten still has some customized
+        // logic that changes an operator's schema after it is transformed.
+        // For example: https://github.com/apache/incubator-gluten/pull/5171
+        logWarning(s"Schema mismatch: one: ${left.schema()}, other: ${right.schema()}")
+      case (left: GlutenMetadata, right: GlutenMetadata) if left == right =>
+      case _ =>
+        throw new IllegalStateException(s"Metadata mismatch: one: $one, other: $other")
+    }
+  }
+}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
b/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
index 3e9b4777a..619d832fc 100644
---
a/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
@@ -16,8 +16,9 @@
*/
package io.glutenproject.planner.plan
-import io.glutenproject.cbo.PlanModel
+import io.glutenproject.cbo.{Metadata, PlanModel}
import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.planner.metadata.GlutenMetadata
import io.glutenproject.planner.property.GlutenProperties
import io.glutenproject.planner.property.GlutenProperties.Conventions
@@ -33,9 +34,13 @@ object GlutenPlanModel {
PlanModelImpl
}
- case class GroupLeafExec(groupId: Int, propertySet: PropertySet[SparkPlan])
extends LeafExecNode {
+ case class GroupLeafExec(
+ groupId: Int,
+ metadata: GlutenMetadata,
+ propertySet: PropertySet[SparkPlan])
+ extends LeafExecNode {
override protected def doExecute(): RDD[InternalRow] = throw new
IllegalStateException()
- override def output: Seq[Attribute] =
propertySet.get(GlutenProperties.SCHEMA_DEF).output
+ override def output: Seq[Attribute] = metadata.schema().output
override def supportsColumnar: Boolean =
propertySet.get(GlutenProperties.CONVENTION_DEF) match {
case Conventions.ROW_BASED => false
@@ -56,8 +61,11 @@ object GlutenPlanModel {
override def equals(one: SparkPlan, other: SparkPlan): Boolean =
Objects.equals(one, other)
- override def newGroupLeaf(groupId: Int, propSet: PropertySet[SparkPlan]):
SparkPlan =
- GroupLeafExec(groupId, propSet)
+ override def newGroupLeaf(
+ groupId: Int,
+ metadata: Metadata,
+ propSet: PropertySet[SparkPlan]): SparkPlan =
+ GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata], propSet)
override def isGroupLeaf(node: SparkPlan): Boolean = node match {
case _: GroupLeafExec => true
diff --git
a/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
b/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
index 3b88b89bc..bea5c3a0b 100644
---
a/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
@@ -21,26 +21,13 @@ import io.glutenproject.cbo._
import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
import io.glutenproject.extension.columnar.ColumnarTransitions
import io.glutenproject.planner.plan.GlutenPlanModel.GroupLeafExec
-import io.glutenproject.planner.property.GlutenProperties.{Convention,
CONVENTION_DEF, ConventionEnforcerRule, SCHEMA_DEF}
+import io.glutenproject.planner.property.GlutenProperties.{Convention,
CONVENTION_DEF, ConventionEnforcerRule}
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.utils.PlanUtil
-import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution._
object GlutenProperties {
- val SCHEMA_DEF: PropertyDef[SparkPlan, Schema] = new PropertyDef[SparkPlan,
Schema] {
- override def getProperty(plan: SparkPlan): Schema = plan match {
- case _: GroupLeafExec => throw new IllegalStateException()
- case _ => Schema(plan.output)
- }
- override def getChildrenConstraints(
- constraint: Property[SparkPlan],
- plan: SparkPlan): Seq[Schema] = {
- plan.children.map(c => Schema(c.output))
- }
- }
-
val CONVENTION_DEF: PropertyDef[SparkPlan, Convention] = new
PropertyDef[SparkPlan, Convention] {
// TODO: Should the convention-transparent ops (e.g., aqe shuffle read)
support
// convention-propagation. Probably need to refactor
getChildrenPropertyRequirements.
@@ -103,22 +90,6 @@ object GlutenProperties {
override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}
- case class Schema(output: Seq[Attribute]) extends Property[SparkPlan] {
- override def satisfies(other: Property[SparkPlan]): Boolean = other match {
- case Schemas.ANY => true
- case Schema(otherOutput) => output == otherOutput
- case _ => throw new IllegalStateException()
- }
-
- override def definition(): PropertyDef[SparkPlan, _ <:
Property[SparkPlan]] = {
- SCHEMA_DEF
- }
- }
-
- object Schemas {
- val ANY: Property[SparkPlan] = Schema(List())
- }
-
sealed trait Convention extends Property[SparkPlan] {
override def definition(): PropertyDef[SparkPlan, _ <:
Property[SparkPlan]] = {
CONVENTION_DEF
@@ -148,14 +119,12 @@ object GlutenPropertyModel {
private object PropertyModelImpl extends PropertyModel[SparkPlan] {
override def propertyDefs: Seq[PropertyDef[SparkPlan, _ <:
Property[SparkPlan]]] =
- Seq(SCHEMA_DEF, CONVENTION_DEF)
+ Seq(CONVENTION_DEF)
override def newEnforcerRuleFactory(
propertyDef: PropertyDef[SparkPlan, _ <: Property[SparkPlan]])
: EnforcerRuleFactory[SparkPlan] = (reqProp: Property[SparkPlan]) => {
propertyDef match {
- case SCHEMA_DEF =>
- Seq()
case CONVENTION_DEF =>
Seq(ConventionEnforcerRule(reqProp.asInstanceOf[Convention]))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]