This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b720ab2cbd [VL] RAS: Add internal property `MemoRole` to reduce
duplications in plan enumeration for rule applications (#9749)
b720ab2cbd is described below
commit b720ab2cbdf3b28865b66105be47d4ddc844ed25
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue May 27 14:02:10 2025 +0100
[VL] RAS: Add internal property `MemoRole` to reduce duplications in plan
enumeration for rule applications (#9749)
---
.../enumerated/planner/VeloxRasSuite.scala | 22 ++-
.../extension/columnar/cost/LongCostModel.scala | 2 +-
.../planner/metadata/GlutenMetadataModel.scala | 10 +-
.../enumerated/planner/plan/GlutenPlanModel.scala | 93 +--------
.../enumerated/planner/plan/GroupLeafExec.scala | 127 ++++++++++++
.../enumerated/planner/property/Conv.scala | 89 +++++----
.../planner/property/GlutenPropertyModel.scala | 20 +-
.../org/apache/gluten/ras/MetadataModel.scala | 4 +-
.../scala/org/apache/gluten/ras/PlanModel.scala | 17 +-
.../org/apache/gluten/ras/PropertyModel.scala | 21 +-
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 122 ++++--------
.../scala/org/apache/gluten/ras/RasCluster.scala | 4 +-
.../scala/org/apache/gluten/ras/RasGroup.scala | 2 +-
.../main/scala/org/apache/gluten/ras/RasNode.scala | 6 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 48 +++--
.../org/apache/gluten/ras/dp/DpZipperAlgo.scala | 100 ++++++----
.../gluten/ras/exaustive/ExhaustivePlanner.scala | 50 +++--
.../apache/gluten/ras/memo/ForwardMemoTable.scala | 54 +++--
.../scala/org/apache/gluten/ras/memo/Memo.scala | 23 ++-
.../org/apache/gluten/ras/memo/MemoTable.scala | 26 ++-
.../org/apache/gluten/ras/path/OutputFilter.scala | 30 ++-
.../org/apache/gluten/ras/path/PathFinder.scala | 2 +-
.../scala/org/apache/gluten/ras/path/RasPath.scala | 2 +
.../org/apache/gluten/ras/property/MemoRole.scala | 218 +++++++++++++++++++++
.../apache/gluten/ras/property/PropertySet.scala | 22 +--
.../gluten/ras/property/PropertySetFactory.scala | 102 ++++++++++
.../org/apache/gluten/ras/rule/EnforcerRule.scala | 109 -----------
.../gluten/ras/rule/EnforcerRuleFactory.scala | 68 +++++++
.../apache/gluten/ras/rule/EnforcerRuleSet.scala | 99 ++++++++++
.../scala/org/apache/gluten/ras/rule/RasRule.scala | 1 -
.../org/apache/gluten/ras/rule/RuleApplier.scala | 85 +++++---
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 2 +-
.../org/apache/gluten/ras/MetadataSuite.scala | 7 +
.../org/apache/gluten/ras/OperationSuite.scala | 14 +-
.../org/apache/gluten/ras/PropertySuite.scala | 157 ++++++++-------
.../scala/org/apache/gluten/ras/RasSuite.scala | 25 +--
.../scala/org/apache/gluten/ras/RasSuiteBase.scala | 78 +++++---
.../org/apache/gluten/ras/mock/MockMemoState.scala | 11 +-
.../org/apache/gluten/ras/mock/MockRasPath.scala | 2 +-
.../gluten/ras/specific/DistributedSuite.scala | 212 +++++++++++---------
40 files changed, 1357 insertions(+), 729 deletions(-)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
index 0e8b056fcf..987e8ce70b 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
@@ -22,8 +22,6 @@ import
org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq}
import org.apache.gluten.ras.Ras
-import org.apache.gluten.ras.RasSuiteBase._
-import org.apache.gluten.ras.path.RasPath
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
@@ -78,9 +76,13 @@ class VeloxRasSuite extends SharedSparkSession {
ColumnarUnary(RowToColumnarExec(
RowUnary(RowUnary(ColumnarToRowExec(ColumnarUnary(RowToColumnarExec(RowUnary(RowUnary(
ColumnarToRowExec(ColumnarUnary(RowToColumnarExec(RowLeaf(TRIVIAL_SCHEMA)))))))))))))))
- val paths =
planner.newState().memoState().collectAllPaths(RasPath.INF_DEPTH).toList
- val pathCount = paths.size
- assert(pathCount == 165)
+ val memoState = planner.newState().memoState()
+ val numClusters = memoState.allClusters().size
+ val numGroups = memoState.allGroups().size
+ val numNodes = memoState.allClusters().flatMap(_.nodes()).size
+ assert(numClusters == 8)
+ assert(numGroups == 30)
+ assert(numNodes == 39)
}
test("C2R, R2C - Row unary convertible to Columnar") {
@@ -103,9 +105,13 @@ class VeloxRasSuite extends SharedSparkSession {
val out = planner.plan()
assert(out ==
ColumnarToRowExec(ColumnarUnary(ColumnarUnary(ColumnarUnary(ColumnarUnary(
ColumnarUnary(ColumnarUnary(ColumnarUnary(RowToColumnarExec(RowLeaf(TRIVIAL_SCHEMA)))))))))))
- val paths =
planner.newState().memoState().collectAllPaths(RasPath.INF_DEPTH).toList
- val pathCount = paths.size
- assert(pathCount == 1094)
+ val memoState = planner.newState().memoState()
+ val numClusters = memoState.allClusters().size
+ val numGroups = memoState.allGroups().size
+ val numNodes = memoState.allClusters().flatMap(_.nodes()).size
+ assert(numClusters == 8)
+ assert(numGroups == 32)
+ assert(numNodes == 55)
}
test("C2R, R2C - empty schema") {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
index 2cdf86e6af..954d711d08 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.cost
import org.apache.gluten.exception.GlutenException
-import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
+import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GroupLeafExec
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/GlutenMetadataModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/GlutenMetadataModel.scala
index 690964daa6..61f1081b87 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/GlutenMetadataModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/GlutenMetadataModel.scala
@@ -16,8 +16,8 @@
*/
package org.apache.gluten.extension.columnar.enumerated.planner.metadata
-import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.{Metadata, MetadataModel}
+import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GroupLeafExec
+import org.apache.gluten.ras.{GroupLeafBuilder, Metadata, MetadataModel}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
@@ -43,6 +43,12 @@ object GlutenMetadataModel extends Logging {
implicitly[Verifier[LogicalLink]].verify(left.logicalLink(),
right.logicalLink())
case _ => throw new IllegalStateException(s"Metadata mismatch: one:
$one, other $other")
}
+
+ override def assignToGroup(group: GroupLeafBuilder[SparkPlan], meta:
Metadata): Unit =
+ (group, meta) match {
+ case (builder: GroupLeafExec.Builder, metadata: GlutenMetadata) =>
+ builder.withMetadata(metadata)
+ }
}
trait Verifier[T <: Any] {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala
index a4058e5c7b..fed7bca248 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala
@@ -16,103 +16,21 @@
*/
package org.apache.gluten.extension.columnar.enumerated.planner.plan
-import org.apache.gluten.execution.GlutenPlan
-import
org.apache.gluten.extension.columnar.enumerated.planner.metadata.{GlutenMetadata,
LogicalLink}
-import org.apache.gluten.extension.columnar.enumerated.planner.property.{Conv,
ConvDef}
-import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq}
-import org.apache.gluten.ras.{Metadata, PlanModel}
-import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.ras.PlanModel
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode,
SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
import org.apache.spark.task.{SparkTaskUtil, TaskResources}
import java.util.{Objects, Properties}
-import java.util.concurrent.atomic.AtomicBoolean
object GlutenPlanModel {
def apply(): PlanModel[SparkPlan] = {
PlanModelImpl
}
- // TODO: Make this inherit from GlutenPlan.
- case class GroupLeafExec(
- groupId: Int,
- metadata: GlutenMetadata,
- constraintSet: PropertySet[SparkPlan])
- extends LeafExecNode
- with Convention.KnownBatchType
- with Convention.KnownRowTypeForSpark33OrLater
- with GlutenPlan.SupportsRowBasedCompatible {
-
- private val frozen = new AtomicBoolean(false)
- private val req: Conv.Req =
constraintSet.get(ConvDef).asInstanceOf[Conv.Req]
-
- // Set the logical link then make the plan node immutable. All future
- // mutable operations related to tagging will be aborted.
- if (metadata.logicalLink() != LogicalLink.notFound) {
- setLogicalLink(metadata.logicalLink().plan)
- }
- frozen.set(true)
-
- override protected def doExecute(): RDD[InternalRow] = throw new
IllegalStateException()
- override def output: Seq[Attribute] = metadata.schema().output
-
- override val batchType: Convention.BatchType = {
- val out = req.req.requiredBatchType match {
- case ConventionReq.BatchType.Any => Convention.BatchType.None
- case ConventionReq.BatchType.Is(b) => b
- }
- out
- }
-
- final override val supportsColumnar: Boolean = {
- batchType != Convention.BatchType.None
- }
-
- override val rowType0: Convention.RowType = {
- val out = req.req.requiredRowType match {
- case ConventionReq.RowType.Any => Convention.RowType.VanillaRowType
- case ConventionReq.RowType.Is(r) => r
- }
- out
- }
-
- final override val supportsRowBased: Boolean = {
- rowType() != Convention.RowType.None
- }
-
- private def ensureNotFrozen(): Unit = {
- if (frozen.get()) {
- throw new UnsupportedOperationException()
- }
- }
-
- // Enclose mutable APIs.
- override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
- ensureNotFrozen()
- super.setLogicalLink(logicalPlan)
- }
- override def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
- ensureNotFrozen()
- super.setTagValue(tag, value)
- }
- override def unsetTagValue[T](tag: TreeNodeTag[T]): Unit = {
- ensureNotFrozen()
- super.unsetTagValue(tag)
- }
- override def copyTagsFrom(other: SparkPlan): Unit = {
- ensureNotFrozen()
- super.copyTagsFrom(other)
- }
- }
-
private object PlanModelImpl extends PlanModel[SparkPlan] {
private val fakeTc =
SparkShimLoader.getSparkShims.createTestTaskContext(new Properties())
private def fakeTc[T](body: => T): T = {
@@ -144,11 +62,8 @@ object GlutenPlanModel {
override def equals(one: SparkPlan, other: SparkPlan): Boolean =
Objects.equals(withEqualityWrapper(one), withEqualityWrapper(other))
- override def newGroupLeaf(
- groupId: Int,
- metadata: Metadata,
- constraintSet: PropertySet[SparkPlan]): SparkPlan =
- GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata],
constraintSet)
+ override def newGroupLeaf(groupId: Int): GroupLeafExec.Builder =
+ GroupLeafExec.newBuilder(groupId)
override def isGroupLeaf(node: SparkPlan): Boolean = node match {
case _: GroupLeafExec => true
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GroupLeafExec.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GroupLeafExec.scala
new file mode 100644
index 0000000000..02746c4534
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GroupLeafExec.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.enumerated.planner.plan
+
+import org.apache.gluten.execution.GlutenPlan
+import
org.apache.gluten.extension.columnar.enumerated.planner.metadata.{GlutenMetadata,
LogicalLink}
+import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
+import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq}
+import org.apache.gluten.ras.GroupLeafBuilder
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+// TODO: Make this inherit from GlutenPlan.
+case class GroupLeafExec(groupId: Int, metadata: GlutenMetadata, convReq:
Conv.Req)
+ extends LeafExecNode
+ with Convention.KnownBatchType
+ with Convention.KnownRowTypeForSpark33OrLater
+ with GlutenPlan.SupportsRowBasedCompatible {
+
+ private val frozen = new AtomicBoolean(false)
+
+ // Set the logical link then make the plan node immutable. All future
+ // mutable operations related to tagging will be aborted.
+ if (metadata.logicalLink() != LogicalLink.notFound) {
+ setLogicalLink(metadata.logicalLink().plan)
+ }
+ frozen.set(true)
+
+ override protected def doExecute(): RDD[InternalRow] = throw new
IllegalStateException()
+ override def output: Seq[Attribute] = metadata.schema().output
+
+ override val batchType: Convention.BatchType = {
+ val out = convReq.req.requiredBatchType match {
+ case ConventionReq.BatchType.Any => Convention.BatchType.None
+ case ConventionReq.BatchType.Is(b) => b
+ }
+ out
+ }
+
+ final override val supportsColumnar: Boolean = {
+ batchType != Convention.BatchType.None
+ }
+
+ override val rowType0: Convention.RowType = {
+ val out = convReq.req.requiredRowType match {
+ case ConventionReq.RowType.Any => Convention.RowType.VanillaRowType
+ case ConventionReq.RowType.Is(r) => r
+ }
+ out
+ }
+
+ final override val supportsRowBased: Boolean = {
+ rowType() != Convention.RowType.None
+ }
+
+ private def ensureNotFrozen(): Unit = {
+ if (frozen.get()) {
+ throw new UnsupportedOperationException()
+ }
+ }
+
+ // Enclose mutable APIs.
+ override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
+ ensureNotFrozen()
+ super.setLogicalLink(logicalPlan)
+ }
+ override def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+ ensureNotFrozen()
+ super.setTagValue(tag, value)
+ }
+ override def unsetTagValue[T](tag: TreeNodeTag[T]): Unit = {
+ ensureNotFrozen()
+ super.unsetTagValue(tag)
+ }
+ override def copyTagsFrom(other: SparkPlan): Unit = {
+ ensureNotFrozen()
+ super.copyTagsFrom(other)
+ }
+}
+
+object GroupLeafExec {
+ class Builder private[GroupLeafExec] (override val id: Int) extends
GroupLeafBuilder[SparkPlan] {
+ private var convReq: Conv.Req = _
+ private var metadata: GlutenMetadata = _
+
+ def withMetadata(metadata: GlutenMetadata): Builder = {
+ this.metadata = metadata
+ this
+ }
+
+ def withConvReq(convReq: Conv.Req): Builder = {
+ this.convReq = convReq
+ this
+ }
+
+ override def build(): SparkPlan = {
+ require(metadata != null)
+ require(convReq != null)
+ GroupLeafExec(id, metadata, convReq)
+ }
+ }
+
+ def newBuilder(groupId: Int): Builder = {
+ new Builder(groupId)
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
index ff530d49bc..868791ac33 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
@@ -16,55 +16,32 @@
*/
package org.apache.gluten.extension.columnar.enumerated.planner.property
+import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GroupLeafExec
+import
org.apache.gluten.extension.columnar.enumerated.planner.property.Conv.{Prop,
Req}
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq, Transition}
-import org.apache.gluten.ras.{Property, PropertyDef}
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+import org.apache.gluten.ras.{GroupLeafBuilder, Property, PropertyDef}
+import org.apache.gluten.ras.rule.EnforcerRuleFactory
import org.apache.spark.sql.execution._
sealed trait Conv extends Property[SparkPlan] {
- import Conv._
override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]]
= {
ConvDef
}
-
- override def satisfies(other: Property[SparkPlan]): Boolean = {
- // The following enforces strict type checking against `this` and `other`
- // to make sure:
- //
- // 1. `this`, which came from user implementation of
PropertyDef.getProperty, must be a `Prop`
- // 2. `other` which came from user implementation of
PropertyDef.getChildrenConstraints,
- // must be a `Req`
- //
- // If the user implementation doesn't follow the criteria, cast error will
be thrown.
- //
- // This can be a common practice to implement a safe Property for RAS.
- //
- // TODO: Add a similar case to RAS UTs.
- val req = other.asInstanceOf[Req]
- if (req.isAny) {
- return true
- }
- val prop = this.asInstanceOf[Prop]
- val out = Transition.factory.satisfies(prop.prop, req.req)
- out
- }
}
object Conv {
val any: Conv = Req(ConventionReq.any)
- def of(conv: Convention): Conv = Prop(conv)
- def req(req: ConventionReq): Conv = Req(req)
+ def of(conv: Convention): Prop = Prop(conv)
+ def req(req: ConventionReq): Req = Req(req)
- def get(plan: SparkPlan): Conv = {
+ def get(plan: SparkPlan): Prop = {
Conv.of(Convention.get(plan))
}
- def findTransition(from: Conv, to: Conv): Transition = {
- val prop = from.asInstanceOf[Prop]
- val req = to.asInstanceOf[Req]
- val out = Transition.factory.findTransition(prop.prop, req.req, new
IllegalStateException())
+ def findTransition(from: Prop, to: Req): Transition = {
+ val out = Transition.factory.findTransition(from.prop, to.req, new
IllegalStateException())
out
}
@@ -79,7 +56,7 @@ object Conv {
object ConvDef extends PropertyDef[SparkPlan, Conv] {
// TODO: Should the convention-transparent ops (e.g., aqe shuffle read)
support
- // convention-propagation. Probably need to refactor
getChildrenPropertyRequirements.
+ // convention-propagation. Probably need to refactor getChildrenConstraints.
override def getProperty(plan: SparkPlan): Conv = {
conventionOf(plan)
}
@@ -90,25 +67,57 @@ object ConvDef extends PropertyDef[SparkPlan, Conv] {
}
override def getChildrenConstraints(
- constraint: Property[SparkPlan],
- plan: SparkPlan): Seq[Conv] = {
+ plan: SparkPlan,
+ constraint: Property[SparkPlan]): Seq[Conv] = {
val out = ConventionReq.get(plan).map(Conv.req)
out
}
override def any(): Conv = Conv.any
+
+ override def satisfies(
+ property: Property[SparkPlan],
+ constraint: Property[SparkPlan]): Boolean = {
+ // The following enforces strict type checking against `property` and
`constraint`
+ // to make sure:
+ //
+ // 1. `property`, which came from user implementation of
PropertyDef.getProperty, must be a
+ // `Prop`
+ // 2. `constraint` which came from user implementation of
PropertyDef.getChildrenConstraints,
+ // must be a `Req`
+ //
+ // If the user implementation doesn't follow the criteria, cast error will
be thrown.
+ //
+ // This can be a common practice to implement a safe Property for RAS.
+ //
+ // TODO: Add a similar case to RAS UTs.
+ (property, constraint) match {
+ case (prop: Prop, req: Req) =>
+ if (req.isAny) {
+ return true
+ }
+ val out = Transition.factory.satisfies(prop.prop, req.req)
+ out
+ }
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[SparkPlan],
+ constraint: Property[SparkPlan]): GroupLeafBuilder[SparkPlan] = (group,
constraint) match {
+ case (builder: GroupLeafExec.Builder, req: Req) =>
+ builder.withConvReq(req)
+ }
}
-case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+case class ConvEnforcerRule() extends EnforcerRuleFactory.SubRule[SparkPlan] {
+ override def enforce(node: SparkPlan, constraint: Property[SparkPlan]):
Iterable[SparkPlan] = {
+ val reqConv = constraint.asInstanceOf[Req]
val conv = Conv.get(node)
- if (conv.satisfies(reqConv)) {
+ if (ConvDef.satisfies(conv, reqConv)) {
return List.empty
}
val transition = Conv.findTransition(conv, reqConv)
val after = transition.apply(node)
List(after)
}
-
- override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/GlutenPropertyModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/GlutenPropertyModel.scala
index bc7014f0fa..24b5c85122 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/GlutenPropertyModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/GlutenPropertyModel.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.extension.columnar.enumerated.planner.property
import org.apache.gluten.ras._
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, Shape, Shapes}
import org.apache.spark.sql.execution._
@@ -30,13 +31,16 @@ object GlutenPropertyModel {
override def propertyDefs: Seq[PropertyDef[SparkPlan, _ <:
Property[SparkPlan]]] =
Seq(ConvDef)
- override def newEnforcerRuleFactory(
- propertyDef: PropertyDef[SparkPlan, _ <: Property[SparkPlan]])
- : EnforcerRuleFactory[SparkPlan] = (reqProp: Property[SparkPlan]) => {
- propertyDef match {
- case ConvDef =>
- Seq(ConvEnforcerRule(reqProp.asInstanceOf[Conv]))
- }
- }
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[SparkPlan] =
+ EnforcerRuleFactory.fromSubRules(Seq(new
EnforcerRuleFactory.SubRuleFactory[SparkPlan] {
+ override def newSubRule(constraintDef: PropertyDef[SparkPlan, _ <:
Property[SparkPlan]])
+ : EnforcerRuleFactory.SubRule[SparkPlan] = {
+ constraintDef match {
+ case ConvDef => ConvEnforcerRule()
+ }
+ }
+
+ override def ruleShape: Shape[SparkPlan] = Shapes.fixedHeight(1)
+ }))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
index a81ac31cba..7103ef11b3 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
@@ -20,10 +20,10 @@ package org.apache.gluten.ras
* Metadata defines the common traits among nodes in one single cluster. E.g.
Schema, statistics.
*/
trait MetadataModel[T <: AnyRef] {
+ def dummy(): Metadata
def metadataOf(node: T): Metadata
def verify(one: Metadata, other: Metadata): Unit
-
- def dummy(): Metadata
+ def assignToGroup(group: GroupLeafBuilder[T], meta: Metadata): Unit
}
trait Metadata {}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
index bac9d0b646..946d14a42b 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
@@ -16,8 +16,6 @@
*/
package org.apache.gluten.ras
-import org.apache.gluten.ras.property.PropertySet
-
trait PlanModel[T <: AnyRef] {
// Trivial tree operations.
def childrenOf(node: T): Seq[T]
@@ -26,7 +24,20 @@ trait PlanModel[T <: AnyRef] {
def equals(one: T, other: T): Boolean
// Group operations.
- def newGroupLeaf(groupId: Int, meta: Metadata, constraintSet:
PropertySet[T]): T
+ def newGroupLeaf(groupId: Int): GroupLeafBuilder[T]
def isGroupLeaf(node: T): Boolean
def getGroupId(node: T): Int
}
+
+object PlanModel {
+ implicit class PlanModelImplicits[T <: AnyRef](model: PlanModel[T]) {
+ def isLeaf(node: T): Boolean = {
+ model.childrenOf(node).isEmpty
+ }
+ }
+}
+
+trait GroupLeafBuilder[T <: AnyRef] {
+ def id(): Int
+ def build(): T
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
index e764631e77..765c5a3809 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
@@ -16,28 +16,33 @@
*/
package org.apache.gluten.ras
-import org.apache.gluten.ras.rule.RasRule
+import org.apache.gluten.ras.rule.EnforcerRuleFactory
// TODO Use class tags to restrict runtime user-defined class types.
trait Property[T <: AnyRef] {
- def satisfies(other: Property[T]): Boolean
def definition(): PropertyDef[T, _ <: Property[T]]
}
+object Property {
+ implicit class PropertyImplicits[T <: AnyRef](property: Property[T]) {
+ def satisfies(constraint: Property[T]): Boolean = {
+ property.definition().satisfies(property, constraint)
+ }
+ }
+}
+
trait PropertyDef[T <: AnyRef, P <: Property[T]] {
def any(): P
def getProperty(plan: T): P
- def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P]
-}
-
-trait EnforcerRuleFactory[T <: AnyRef] {
- def newEnforcerRules(constraint: Property[T]): Seq[RasRule[T]]
+ def getChildrenConstraints(plan: T, constraint: Property[T]): Seq[P]
+ def satisfies(property: Property[T], constraint: Property[T]): Boolean
+ def assignToGroup(group: GroupLeafBuilder[T], constraint: Property[T]):
GroupLeafBuilder[T]
}
trait PropertyModel[T <: AnyRef] {
def propertyDefs: Seq[PropertyDef[T, _ <: Property[T]]]
- def newEnforcerRuleFactory(propertyDef: PropertyDef[T, _ <: Property[T]]):
EnforcerRuleFactory[T]
+ def newEnforcerRuleFactory(): EnforcerRuleFactory[T]
}
object PropertyModel {}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 6cf15b0e8d..785afe5ebc 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -16,19 +16,15 @@
*/
package org.apache.gluten.ras
-import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.ras.property.{MemoRole, PropertySet,
PropertySetFactory}
import org.apache.gluten.ras.rule.RasRule
-import scala.collection.mutable
-
/**
* Entrypoint of RAS (relational algebra selector)'s search engine. See basic
introduction of RAS:
* https://github.com/apache/incubator-gluten/issues/5057.
*/
trait Optimization[T <: AnyRef] {
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T]
- def anyPropSet(): PropertySet[T]
- def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
}
object Optimization {
@@ -41,12 +37,6 @@ object Optimization {
ruleFactory: RasRule.Factory[T]): Optimization[T] = {
Ras(planModel, costModel, metadataModel, propertyModel, explain,
ruleFactory)
}
-
- implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
- def newPlanner(plan: T): RasPlanner[T] = {
- opt.newPlanner(plan, opt.anyPropSet())
- }
- }
}
class Ras[T <: AnyRef] private (
@@ -54,27 +44,20 @@ class Ras[T <: AnyRef] private (
val planModel: PlanModel[T],
val costModel: CostModel[T],
val metadataModel: MetadataModel[T],
- val propertyModel: PropertyModel[T],
+ private val propertyModel: PropertyModel[T],
val explain: RasExplain[T],
val ruleFactory: RasRule.Factory[T])
extends Optimization[T] {
import Ras._
- override def withNewConfig(confFunc: RasConfig => RasConfig): Ras[T] = {
- new Ras(
- confFunc(config),
- planModel,
- costModel,
- metadataModel,
- propertyModel,
- explain,
- ruleFactory)
- }
-
- private val propSetFactory: PropertySetFactory[T] =
PropertySetFactory(propertyModel, planModel)
+ private[ras] val memoRoleDef: MemoRole.Def[T] = MemoRole.newDef(planModel)
+ private val userPropertySetFactory: PropertySetFactory[T] =
+ PropertySetFactory(propertyModel, planModel)
+ private val propSetFactory: PropertySetFactory[T] =
+ MemoRole.wrapPropertySetFactory(userPropertySetFactory, memoRoleDef)
// Normal groups start with ID 0, so it's safe to use Int.MinValue to do
validation.
private val dummyGroup: T =
- planModel.newGroupLeaf(Int.MinValue, metadataModel.dummy(),
propSetFactory.any())
+ newGroupLeaf(Int.MinValue, metadataModel.dummy(), propSetFactory.any())
private val infCost: Cost = costModel.makeInfCost()
validateModels()
@@ -111,7 +94,26 @@ class Ras[T <: AnyRef] private (
RasPlanner(this, constraintSet, plan)
}
- override def anyPropSet(): PropertySet[T] = propertySetFactory().any()
+ def newPlanner(plan: T): RasPlanner[T] = {
+ RasPlanner(this, userPropertySetFactory.any(), plan)
+ }
+
+ def withNewConfig(confFunc: RasConfig => RasConfig): Ras[T] = {
+ new Ras(
+ confFunc(config),
+ planModel,
+ costModel,
+ metadataModel,
+ propertyModel,
+ explain,
+ ruleFactory)
+ }
+
+ private[ras] def userConstraintSet(): PropertySet[T] =
+ userPropertySetFactory.any() +: memoRoleDef.reqUser
+
+ private[ras] def hubConstraintSet(): PropertySet[T] =
+ userPropertySetFactory.any() +: memoRoleDef.reqHub
private[ras] def propSetOf(plan: T): PropertySet[T] = {
val out = propertySetFactory().get(plan)
@@ -131,7 +133,7 @@ class Ras[T <: AnyRef] private (
}
private[ras] def isLeaf(node: T): Boolean = {
- planModel.childrenOf(node).isEmpty
+ planModel.isLeaf(node)
}
private[ras] def isCanonical(node: T): Boolean = {
@@ -157,6 +159,13 @@ class Ras[T <: AnyRef] private (
private[ras] def isInfCost(cost: Cost) =
costModel.costComparator().equiv(cost, infCost)
private[ras] def toHashKey(node: T): UnsafeHashKey[T] = UnsafeHashKey(this,
node)
+
+ private[ras] def newGroupLeaf(groupId: Int, meta: Metadata, constraintSet:
PropertySet[T]): T = {
+ val builder = planModel.newGroupLeaf(groupId)
+ metadataModel.assignToGroup(builder, meta)
+ propSetFactory.assignToGroup(builder, constraintSet)
+ builder.build()
+ }
}
object Ras {
@@ -177,65 +186,6 @@ object Ras {
ruleFactory)
}
- trait PropertySetFactory[T <: AnyRef] {
- def any(): PropertySet[T]
- def get(node: T): PropertySet[T]
- def childrenConstraintSets(constraintSet: PropertySet[T], node: T):
Seq[PropertySet[T]]
- }
-
- private object PropertySetFactory {
- def apply[T <: AnyRef](
- propertyModel: PropertyModel[T],
- planModel: PlanModel[T]): PropertySetFactory[T] =
- new PropertySetFactoryImpl[T](propertyModel, planModel)
-
- private class PropertySetFactoryImpl[T <: AnyRef](
- propertyModel: PropertyModel[T],
- planModel: PlanModel[T])
- extends PropertySetFactory[T] {
- private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] =
propertyModel.propertyDefs
- private val anyConstraint = {
- val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
- propDefs.map(propDef => (propDef, propDef.any())).toMap
- PropertySet[T](m)
- }
-
- override def any(): PropertySet[T] = anyConstraint
-
- override def get(node: T): PropertySet[T] = {
- val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
- propDefs.map(propDef => (propDef, propDef.getProperty(node))).toMap
- PropertySet[T](m)
- }
-
- override def childrenConstraintSets(
- constraintSet: PropertySet[T],
- node: T): Seq[PropertySet[T]] = {
- val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]] =
- planModel
- .childrenOf(node)
- .map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]())
-
- propDefs
- .foldLeft(builder) {
- (
- builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]],
- propDef: PropertyDef[T, _ <: Property[T]]) =>
- val constraint = constraintSet.get(propDef)
- val childrenConstraints =
propDef.getChildrenConstraints(constraint, node)
- builder.zip(childrenConstraints).map {
- case (childBuilder, childConstraint) =>
- childBuilder += (propDef -> childConstraint)
- }
- }
- .map {
- builder: mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]] =>
- PropertySet[T](builder.toMap)
- }
- }
- }
- }
-
trait UnsafeHashKey[T]
private object UnsafeHashKey {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index e01ee053ef..98f03eb961 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -60,11 +60,11 @@ object RasCluster {
mutable.ListBuffer()
override def contains(t: CanonicalNode[T]): Boolean = {
- deDup.contains(t.toHashKey())
+ deDup.contains(t.toHashKey)
}
override def add(t: CanonicalNode[T]): Unit = {
- val key = t.toHashKey()
+ val key = t.toHashKey
assert(!deDup.contains(key))
ras.metadataModel.verify(metadata,
ras.metadataModel.metadataOf(t.self()))
deDup += key
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
index 9591fbb225..10087e7e1d 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
@@ -42,7 +42,7 @@ object RasGroup {
override val id: Int,
override val constraintSet: PropertySet[T])
extends RasGroup[T] {
- private val groupLeaf: T = ras.planModel.newGroupLeaf(id,
clusterKey.metadata, constraintSet)
+ private val groupLeaf: T = ras.newGroupLeaf(id, clusterKey.metadata,
constraintSet)
override def clusterKey(): RasClusterKey = clusterKey
override def self(): T = groupLeaf
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 8c9b526056..beb61475fc 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -43,7 +43,7 @@ object RasNode {
node.asInstanceOf[GroupNode[T]]
}
- def toHashKey(): UnsafeHashKey[T] = node.ras().toHashKey(node.self())
+ def toHashKey: UnsafeHashKey[T] = node.ras().toHashKey(node.self())
}
}
@@ -52,6 +52,8 @@ trait CanonicalNode[T <: AnyRef] extends RasNode[T] {
}
object CanonicalNode {
+ trait UniqueKey extends Any
+
def apply[T <: AnyRef](ras: Ras[T], canonical: T): CanonicalNode[T] = {
assert(ras.isCanonical(canonical))
val propSet = ras.propSetOf(canonical)
@@ -96,7 +98,7 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {
object GroupNode {
def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
val self = group.self()
- // Re-derive property set of group leaf. User should define an appropriate
conversion
+ // Re-derive a property set of group leaf. User should define an
appropriate conversion
// from group constraints to its output properties in property model or
plan model.
val propSet = ras.propSetOf(self)
new GroupNodeImpl[T](ras, self, propSet, group.id())
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index c681cfbc47..8f9456df34 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -31,11 +31,12 @@ private class DpPlanner[T <: AnyRef] private (ras: Ras[T],
constraintSet: Proper
import DpPlanner._
private val memo = Memo.unsafe(ras)
- private val rules = ras.ruleFactory.create().map(rule => RuleApplier(ras,
memo, rule))
- private val enforcerRuleSet = EnforcerRuleSet[T](ras, memo)
+ private val rules = ras.ruleFactory.create().map(rule =>
RuleApplier.regular(ras, memo, rule))
+ private val enforcerRuleSetFactory = EnforcerRuleSet.Factory.regular(ras,
memo)
+ private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
private lazy val rootGroupId: Int = {
- memo.memorize(plan, constraintSet).id()
+ memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
}
private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -57,7 +58,8 @@ private class DpPlanner[T <: AnyRef] private (ras: Ras[T],
constraintSet: Proper
private def findBest(memoTable: MemoTable[T], groupId: Int): Best[T] = {
val cKey = memoTable.asGroupSupplier()(groupId).clusterKey()
val algoDef = new DpExploreAlgoDef[T]
- val adjustment = new ExploreAdjustment(ras, memoTable, rules,
enforcerRuleSet)
+ val adjustment =
+ new ExploreAdjustment(ras, memoTable, rules, enforcerRuleSetFactory,
deriverRuleSetFactory)
DpClusterAlgo.resolve(memoTable, algoDef, adjustment, cKey)
val finder = BestFinder(ras, memoTable.newState())
finder.bestOf(groupId)
@@ -88,7 +90,8 @@ object DpPlanner {
ras: Ras[T],
memoTable: MemoTable[T],
rules: Seq[RuleApplier[T]],
- enforcerRuleSet: EnforcerRuleSet[T])
+ enforcerRuleSetFactory: EnforcerRuleSet.Factory[T],
+ deriverRuleSetFactory: EnforcerRuleSet.Factory[T])
extends DpClusterAlgo.Adjustment[T] {
import ExploreAdjustment._
@@ -97,12 +100,14 @@ object DpPlanner {
override def exploreChildX(
panel: Panel[InClusterNode[T], RasClusterKey],
x: InClusterNode[T]): Unit = {
+ applyHubRulesOnNode(panel, x.clusterKey, x.can)
applyRulesOnNode(panel, x.clusterKey, x.can)
}
override def exploreChildY(
panel: Panel[InClusterNode[T], RasClusterKey],
y: RasClusterKey): Unit = {}
+
override def exploreParentX(
panel: Panel[InClusterNode[T], RasClusterKey],
x: InClusterNode[T]): Unit = {}
@@ -122,8 +127,8 @@ object DpPlanner {
if (rules.isEmpty) {
return
}
- val dummyGroup = memoTable.getDummyGroup(cKey)
- findPaths(GroupNode(ras, dummyGroup), ruleShapes, List(new
FromSingleNode[T](can))) {
+ val hubGroup = memoTable.getHubGroup(cKey)
+ findPaths(GroupNode(ras, hubGroup), ruleShapes, List(new
FromSingleNode[T](can))) {
path =>
val rootNode = path.node().self()
if (rootNode.isCanonical) {
@@ -133,16 +138,37 @@ object DpPlanner {
}
}
+ private def applyHubRulesOnNode(
+ panel: Panel[InClusterNode[T], RasClusterKey],
+ cKey: RasClusterKey,
+ can: CanonicalNode[T]): Unit = {
+ val hubConstraint = ras.hubConstraintSet()
+ val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
+ val hubDeriverRules = hubDeriverRuleSet.rules()
+ if (hubDeriverRules.nonEmpty) {
+ val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
+ val userGroup = memoTable.getUserGroup(cKey)
+ findPaths(
+ GroupNode(ras, userGroup),
+ hubDeriverRuleShapes,
+ List(new FromSingleNode[T](can))) {
+ path => hubDeriverRules.foreach(rule => applyRule(panel, cKey, rule,
path))
+ }
+ }
+ }
+
private def applyEnforcerRules(
panel: Panel[InClusterNode[T], RasClusterKey],
cKey: RasClusterKey): Unit = {
- val dummyGroup = memoTable.getDummyGroup(cKey)
+ val hubGroup = memoTable.getHubGroup(cKey)
cKey.propSets(memoTable).foreach {
constraintSet: PropertySet[T] =>
- val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
+ val enforcerRuleSet = deriverRuleSetFactory.ruleSetOf(
+ constraintSet) ++ enforcerRuleSetFactory.ruleSetOf(constraintSet)
+ val enforcerRules = enforcerRuleSet.rules()
if (enforcerRules.nonEmpty) {
- val shapes = enforcerRuleSet.ruleShapesOf(constraintSet)
- findPaths(GroupNode(ras, dummyGroup), shapes, List.empty) {
+ val enforcerRuleShapes = enforcerRuleSet.shapes()
+ findPaths(GroupNode(ras, hubGroup), enforcerRuleShapes,
List.empty) {
path => enforcerRules.foreach(rule => applyRule(panel, cKey,
rule, path))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
index 746cce8983..d8c7388669 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
@@ -188,32 +188,48 @@ object DpZipperAlgo {
val xSolutions: mutable.Map[XKey[X, Y, XOutput, YOutput], XOutput] =
mutable.Map()
def loop(): Unit = {
- while (true) {
- val xKeys: Set[XKey[X, Y, XOutput, YOutput]] =
- algoDef.browseY(thisY).map(algoDef.keyOfX(_)).toSet
+ def browseXKeys(): Set[XKey[X, Y, XOutput, YOutput]] = {
+ algoDef.browseY(thisY).map(algoDef.keyOfX(_)).toSet
+ }
+
+ def innerLoop(): Unit = {
+ while (true) {
+ val xKeys = browseXKeys()
+ val xCount = xKeys.size
+ if (xCount == xSolutions.size) {
+ // We got enough children solutions.
+ return
+ }
+
+ xKeys.filterNot(xKey => xSolutions.contains(xKey)).foreach {
+ childXKey =>
+ val xOutputs = solveXRec(childXKey.x, xCycleDetector,
newYCycleDetector)
+ val cm = xOutputs.cycleMemory()
+ cyclicXs ++= cm.cyclicXs
+ cyclicYs ++= cm.cyclicYs
+ sBuilder.addYAsBackDependencyOfX(thisY, childXKey.x)
+ xSolutions += childXKey -> xOutputs.output()
+ // Try applying adjustment
+ // to see if algo caller likes to add some Xs or to invalidate
+ // some of the registered solutions.
+ adjustment.exploreChildX(adjustmentPanel, childXKey.x)
+ }
+ }
+ }
+ while (true) {
+ val xKeys = browseXKeys()
val xCount = xKeys.size
if (xCount == xSolutions.size) {
// We got enough children solutions.
return
}
- xKeys.filterNot(xKey => xSolutions.contains(xKey)).foreach {
- childXKey =>
- val xOutputs = solveXRec(childXKey.x, xCycleDetector,
newYCycleDetector)
- val cm = xOutputs.cycleMemory()
- cyclicXs ++= cm.cyclicXs
- cyclicYs ++= cm.cyclicYs
- sBuilder.addYAsBackDependencyOfX(thisY, childXKey.x)
- xSolutions += childXKey -> xOutputs.output()
- // Try applying adjustment
- // to see if algo caller likes to add some Xs or to invalidate
- // some of the registered solutions.
- adjustment.exploreChildX(adjustmentPanel, childXKey.x)
- }
+ innerLoop()
+
adjustment.exploreParentY(adjustmentPanel, thisY)
// If an adjustment (this adjustment or children's) just invalidated
one or more
- // children of this element's solutions, the children's keys would
be removed from
+ // children of this element's solutions, the children's keys would
be removed from the
// back-dependency list. We do a test here to trigger re-computation
if some children
// do get invalidated.
xSolutions.keySet.foreach {
@@ -263,32 +279,48 @@ object DpZipperAlgo {
val ySolutions: mutable.Map[YKey[X, Y, XOutput, YOutput], YOutput] =
mutable.Map()
def loop(): Unit = {
- while (true) {
- val yKeys: Set[YKey[X, Y, XOutput, YOutput]] =
- algoDef.browseX(thisX).map(algoDef.keyOfY(_)).toSet
+ def browseYKeys(): Set[YKey[X, Y, XOutput, YOutput]] = {
+ algoDef.browseX(thisX).map(algoDef.keyOfY(_)).toSet
+ }
+
+ def innerLoop(): Unit = {
+ while (true) {
+ val yKeys = browseYKeys()
+ val yCount = yKeys.size
+ if (yCount == ySolutions.size) {
+ // We got enough children solutions.
+ return
+ }
+
+ yKeys.filterNot(yKey => ySolutions.contains(yKey)).foreach {
+ childYKey =>
+ val yOutputs = solveYRec(childYKey.y, newXCycleDetector,
yCycleDetector)
+ val cm = yOutputs.cycleMemory()
+ cyclicXs ++= cm.cyclicXs
+ cyclicYs ++= cm.cyclicYs
+ sBuilder.addXAsBackDependencyOfY(thisX, childYKey.y)
+ ySolutions += childYKey -> yOutputs.output()
+ // Try applying adjustment
+ // to see if algo caller likes to add some Ys or to invalidate
+ // some of the registered solutions.
+ adjustment.exploreChildY(adjustmentPanel, childYKey.y)
+ }
+ }
+ }
+ while (true) {
+ val yKeys = browseYKeys()
val yCount = yKeys.size
if (yCount == ySolutions.size) {
// We got enough children solutions.
return
}
- yKeys.filterNot(yKey => ySolutions.contains(yKey)).foreach {
- childYKey =>
- val yOutputs = solveYRec(childYKey.y, newXCycleDetector,
yCycleDetector)
- val cm = yOutputs.cycleMemory()
- cyclicXs ++= cm.cyclicXs
- cyclicYs ++= cm.cyclicYs
- sBuilder.addXAsBackDependencyOfY(thisX, childYKey.y)
- ySolutions += childYKey -> yOutputs.output()
- // Try applying adjustment
- // to see if algo caller likes to add some Ys or to invalidate
- // some of the registered solutions.
- adjustment.exploreChildY(adjustmentPanel, childYKey.y)
- }
+ innerLoop()
+
adjustment.exploreParentX(adjustmentPanel, thisX)
// If an adjustment (this adjustment or children's) just invalidated
one or more
- // children of this element's solutions, the children's keys would
be removed from
+ // children of this element's solutions, the children's keys would
be removed from the
// back-dependency list. We do a test here to trigger re-computation
if some children
// do get invalidated.
ySolutions.keySet.foreach {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index f0cb42cf66..58a37afa47 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -31,11 +31,12 @@ private class ExhaustivePlanner[T <: AnyRef] private (
plan: T)
extends RasPlanner[T] {
private val memo = Memo(ras)
- private val rules = ras.ruleFactory.create().map(rule => RuleApplier(ras,
memo, rule))
- private val enforcerRuleSet = EnforcerRuleSet[T](ras, memo)
+ private val rules = ras.ruleFactory.create().map(rule =>
RuleApplier.regular(ras, memo, rule))
+ private val enforcerRuleSetFactory = EnforcerRuleSet.Factory.regular(ras,
memo)
+ private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
private lazy val rootGroupId: Int = {
- memo.memorize(plan, constraintSet).id()
+ memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
}
private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -59,7 +60,12 @@ private class ExhaustivePlanner[T <: AnyRef] private (
// TODO1: Prune paths within cost threshold
// ~~ TODO2: Use partial-canonical paths to reduce search space ~~
memo.doExhaustively {
- val explorer = new ExhaustiveExplorer(ras, memo.newState(), rules,
enforcerRuleSet)
+ val explorer = new ExhaustiveExplorer(
+ ras,
+ memo.newState(),
+ rules,
+ enforcerRuleSetFactory,
+ deriverRuleSetFactory)
explorer.explore()
}
}
@@ -78,12 +84,14 @@ object ExhaustivePlanner {
ras: Ras[T],
memoState: MemoState[T],
rules: Seq[RuleApplier[T]],
- enforcerRuleSet: EnforcerRuleSet[T]) {
+ enforcerRuleSetFactory: EnforcerRuleSet.Factory[T],
+ deriverRuleSetFactory: EnforcerRuleSet.Factory[T]) {
private val allClusters = memoState.allClusters()
private val allGroups = memoState.allGroups()
def explore(): Unit = {
// TODO: ONLY APPLY RULES ON ALTERED GROUPS (and close parents)
+ applyHubRules()
applyEnforcerRules()
applyRules()
}
@@ -114,23 +122,43 @@ object ExhaustivePlanner {
.clusterLookup()
.foreach {
case (cKey, cluster) =>
- val dummyGroup = memoState.getDummyGroup(cKey)
- findPaths(GroupNode(ras, dummyGroup), shapes) {
+ val hubGroup = memoState.getHubGroup(cKey)
+ findPaths(GroupNode(ras, hubGroup), shapes) {
path => rules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
}
}
}
+ private def applyHubRules(): Unit = {
+ val hubConstraint = ras.hubConstraintSet()
+ val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
+ val hubDeriverRules = hubDeriverRuleSet.rules()
+ if (hubDeriverRules.nonEmpty) {
+ memoState
+ .clusterLookup()
+ .foreach {
+ case (cKey, cluster) =>
+ val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
+ val userGroup = memoState.getUserGroup(cKey)
+ findPaths(GroupNode(ras, userGroup), hubDeriverRuleShapes) {
+ path => hubDeriverRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
+ }
+ }
+ }
+ }
+
private def applyEnforcerRules(): Unit = {
allGroups.foreach {
group =>
val constraintSet = group.constraintSet()
- val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
+ val enforcerRuleSet = deriverRuleSetFactory.ruleSetOf(
+ constraintSet) ++ enforcerRuleSetFactory.ruleSetOf(constraintSet)
+ val enforcerRules = enforcerRuleSet.rules()
if (enforcerRules.nonEmpty) {
- val shapes = enforcerRules.map(_.shape())
+ val enforcerRuleShapes = enforcerRuleSet.shapes()
val cKey = group.clusterKey()
- val dummyGroup = memoState.getDummyGroup(cKey)
- findPaths(GroupNode(ras, dummyGroup), shapes) {
+ val hubGroup = memoState.getHubGroup(cKey)
+ findPaths(GroupNode(ras, hubGroup), enforcerRuleShapes) {
path => enforcerRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index c2ebccd405..cbd43026c6 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -29,8 +29,6 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
import ForwardMemoTable._
private val groupBuffer: mutable.ArrayBuffer[RasGroup[T]] =
mutable.ArrayBuffer()
- private val dummyGroupBuffer: mutable.ArrayBuffer[RasGroup[T]] =
- mutable.ArrayBuffer[RasGroup[T]]()
private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] =
mutable.ArrayBuffer()
private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] =
mutable.ArrayBuffer()
@@ -55,29 +53,22 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
clusterBuffer += MutableRasCluster(ras, metadata)
clusterDisjointSet.grow()
groupLookup += mutable.Map()
- // Normal groups start with ID 0, so it's safe to use negative IDs for
dummy groups.
- // Dummy group ID starts from -1.
- dummyGroupBuffer += RasGroup(ras, key, -(clusterId + 1),
ras.propertySetFactory().any())
+ groupOf(key, ras.hubConstraintSet())
+ groupOf(key, ras.userConstraintSet())
+ memoWriteCount += 1
key
}
- override def getDummyGroup(key: RasClusterKey): RasGroup[T] = {
- val ancestor = ancestorClusterIdOf(key)
- val out = dummyGroupBuffer(ancestor)
- assert(out.id() == -(ancestor + 1))
- out
- }
-
- override def groupOf(key: RasClusterKey, propSet: PropertySet[T]):
RasGroup[T] = {
+ override def groupOf(key: RasClusterKey, constraintSet: PropertySet[T]):
RasGroup[T] = {
val ancestor = ancestorClusterIdOf(key)
val lookup = groupLookup(ancestor)
- if (lookup.contains(propSet)) {
- return lookup(propSet)
+ if (lookup.contains(constraintSet)) {
+ return lookup(constraintSet)
}
val gid = groupBuffer.size
val newGroup =
- RasGroup(ras, IntClusterKey(ancestor, key.metadata), gid, propSet)
- lookup += propSet -> newGroup
+ RasGroup(ras, IntClusterKey(ancestor, key.metadata), gid, constraintSet)
+ lookup += constraintSet -> newGroup
groupBuffer += newGroup
memoWriteCount += 1
newGroup
@@ -134,9 +125,9 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
val fromGroups = groupLookup(fromKey.id())
val toGroups = groupLookup(toKey.id())
fromGroups.foreach {
- case (fromPropSet, _) =>
- if (!toGroups.contains(fromPropSet)) {
- groupOf(toKey, fromPropSet)
+ case (fromConstraintSet, _) =>
+ if (!toGroups.contains(fromConstraintSet)) {
+ groupOf(toKey, fromConstraintSet)
}
}
@@ -147,18 +138,14 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
}
override def getGroup(id: Int): RasGroup[T] = {
- if (id < 0) {
- val out = dummyGroupBuffer(-id - 1)
- assert(out.id() == id)
- return out
- }
+ assert(id >= 0)
groupBuffer(id)
}
override def allClusterKeys(): Seq[RasClusterKey] = clusterKeyBuffer.toSeq
override def allGroupIds(): Seq[Int] = {
- val from = -dummyGroupBuffer.size
+ val from = 0
val to = groupBuffer.size
(from until to).toVector
}
@@ -171,12 +158,23 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
assert(clusterKeyBuffer.size == clusterBuffer.size)
assert(clusterKeyBuffer.size == clusterDisjointSet.size)
assert(clusterKeyBuffer.size == groupLookup.size)
- assert(clusterKeyBuffer.size == dummyGroupBuffer.size)
}
override def probe(): MemoTable.Probe[T] = new
ForwardMemoTable.Probe[T](this)
override def writeCount(): Int = memoWriteCount
+
+ override def getHubGroup(key: RasClusterKey): RasGroup[T] = {
+ val ancestor = ancestorClusterIdOf(key)
+ val lookup = groupLookup(ancestor)
+ lookup(ras.hubConstraintSet())
+ }
+
+ override def getUserGroup(key: RasClusterKey): RasGroup[T] = {
+ val ancestor = ancestorClusterIdOf(key)
+ val lookup = groupLookup(ancestor)
+ lookup(ras.userConstraintSet())
+ }
}
object ForwardMemoTable {
@@ -216,7 +214,7 @@ object ForwardMemoTable {
val changedClusters =
(clustersOfNewGroups.toSet ++ affectedClustersDuringMerging) --
newClusters
- // We consider a existing cluster with new groups changed.
+ // We consider an existing cluster with new groups changed.
Probe.Diff(changedClusters)
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index 68a8ca7457..db3e90d7ab 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -73,8 +73,8 @@ object Memo {
}
}
- private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
- memoTable.getDummyGroup(clusterKey)
+ private def hubGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
+ memoTable.getHubGroup(clusterKey)
}
private def toCacheKey(n: T): MemoCacheKey[T] = {
@@ -91,7 +91,7 @@ object Memo {
val keyUnsafe = ras.withNewChildren(
n,
- childrenPrepares.map(childPrepare =>
dummyGroupOf(childPrepare.clusterKey()).self()))
+ childrenPrepares.map(childPrepare =>
hubGroupOf(childPrepare.clusterKey()).self()))
val cacheKey = toCacheKey(keyUnsafe)
@@ -128,6 +128,7 @@ object Memo {
// TODO: Traverse up the tree to do more merges.
private def prepareInsert(node: T): Prepare[T] = {
if (ras.isGroupLeaf(node)) {
+ // This mainly serves the group reduction case.
val group = parent.memoTable.getGroup(ras.planModel.getGroupId(node))
val residentCluster = group.clusterKey()
@@ -146,7 +147,7 @@ object Memo {
val keyUnsafe = ras.withNewChildren(
node,
childrenPrepares.map {
- childPrepare =>
parent.dummyGroupOf(childPrepare.clusterKey()).self()
+ childPrepare => parent.hubGroupOf(childPrepare.clusterKey()).self()
})
val cacheKey = parent.toCacheKey(keyUnsafe)
@@ -203,7 +204,7 @@ object Memo {
assert(!ras.isGroupLeaf(node))
val childrenGroups = children
.zip(ras.planModel.childrenOf(node))
-
.zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
+ .zip(ras.propertySetFactory().childrenConstraintSets(node,
constraintSet))
.map {
case ((childPrepare, child), childConstraintSet) =>
childPrepare.doInsert(child, childConstraintSet)
@@ -233,9 +234,13 @@ object Memo {
}
private object MemoCacheKey {
+ private def apply[T <: AnyRef](delegate: UnsafeHashKey[T]):
MemoCacheKey[T] = {
+ throw new UnsupportedOperationException()
+ }
+
def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
assert(ras.isCanonical(self))
- MemoCacheKey[T](ras.toHashKey(self))
+ new MemoCacheKey[T](ras.toHashKey(self))
}
}
@@ -244,7 +249,8 @@ object Memo {
trait MemoStore[T <: AnyRef] {
def getCluster(key: RasClusterKey): RasCluster[T]
- def getDummyGroup(key: RasClusterKey): RasGroup[T]
+ def getHubGroup(key: RasClusterKey): RasGroup[T]
+ def getUserGroup(key: RasClusterKey): RasGroup[T]
def getGroup(id: Int): RasGroup[T]
}
@@ -259,7 +265,8 @@ object MemoStore {
trait MemoState[T <: AnyRef] extends MemoStore[T] {
def ras(): Ras[T]
def clusterLookup(): Map[RasClusterKey, RasCluster[T]]
- def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]]
+ def clusterHubGroupLookup(): Map[RasClusterKey, RasGroup[T]]
+ def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]]
def allClusters(): Iterable[RasCluster[T]]
def allGroups(): Seq[RasGroup[T]]
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 2e2323a1e5..3bdf7b794e 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -40,7 +40,7 @@ object MemoTable {
trait Writable[T <: AnyRef] extends MemoTable[T] {
def newCluster(metadata: Metadata): RasClusterKey
- def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T]
+ def groupOf(key: RasClusterKey, constraintSet: PropertySet[T]): RasGroup[T]
def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit
@@ -67,14 +67,16 @@ object MemoTable {
private case class MemoStateImpl[T <: AnyRef](
override val ras: Ras[T],
override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
- override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+ override val clusterHubGroupLookup: Map[RasClusterKey, RasGroup[T]],
+ override val clusterUserGroupLookup: Map[RasClusterKey, RasGroup[T]],
override val allGroups: Seq[RasGroup[T]],
idToGroup: Map[Int, RasGroup[T]])
extends MemoState[T] {
private val allClustersCopy = clusterLookup.values
override def getCluster(key: RasClusterKey): RasCluster[T] =
clusterLookup(key)
- override def getDummyGroup(key: RasClusterKey): RasGroup[T] =
clusterDummyGroupLookup(key)
+ override def getHubGroup(key: RasClusterKey): RasGroup[T] =
clusterHubGroupLookup(key)
+ override def getUserGroup(key: RasClusterKey): RasGroup[T] =
clusterUserGroupLookup(key)
override def getGroup(id: Int): RasGroup[T] = idToGroup(id)
override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
}
@@ -85,9 +87,15 @@ object MemoTable {
.allClusterKeys()
.map(key => key -> ImmutableRasCluster(table.ras,
table.getCluster(key)))
.toMap
- val immutableDummyGroups = table
+
+ val immutableHubGroups = table
+ .allClusterKeys()
+ .map(key => key -> table.getHubGroup(key))
+ .toMap
+
+ val immutableUserGroups = table
.allClusterKeys()
- .map(key => key -> table.getDummyGroup(key))
+ .map(key => key -> table.getUserGroup(key))
.toMap
var maxGroupId = Int.MinValue
@@ -107,7 +115,13 @@ object MemoTable {
val allGroups = (0 to maxGroupId).map(table.getGroup).toVector
- MemoStateImpl(table.ras, immutableClusters, immutableDummyGroups,
allGroups, groupMap)
+ MemoStateImpl(
+ table.ras,
+ immutableClusters,
+ immutableHubGroups,
+ immutableUserGroups,
+ allGroups,
+ groupMap)
}
def doExhaustively(func: => Unit): Unit = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
index 253e9ec84d..a098c1d113 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
@@ -49,9 +49,14 @@ object FilterWizard {
}
object FilterWizards {
- def omitCycles[T <: AnyRef](): FilterWizard[T] = {
+ def omitNodeCycles[T <: AnyRef](): FilterWizard[T] = {
// Compares against node hash keys to identify cycles.
- OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() ==
other.groupId()))
+ OmitCycles.onNodes(CycleDetector((one, other) => one.toHashKey ==
other.toHashKey))
+ }
+
+ def omitGroupCycles[T <: AnyRef](): FilterWizard[T] = {
+ // Compares against group ID to identify cycles.
+ OmitCycles.onGroups(CycleDetector((one, other) => one.groupId() ==
other.groupId()))
}
def none[T <: AnyRef](): FilterWizard[T] = {
@@ -70,17 +75,22 @@ object FilterWizards {
}
// Cycle detection starts from the first visited group or node in the input path.
- private class OmitCycles[T <: AnyRef] private (detector:
CycleDetector[GroupNode[T]])
+ private class OmitCycles[T <: AnyRef] private (
+ detectorOnNodes: CycleDetector[CanonicalNode[T]],
+ detectorOnGroups: CycleDetector[GroupNode[T]])
extends FilterWizard[T] {
override def omit(can: CanonicalNode[T]): FilterAction[T] = {
- FilterAction.Continue(this)
+ if (detectorOnNodes.contains(can)) {
+ return FilterAction.omit
+ }
+ FilterAction.Continue(new OmitCycles(detectorOnNodes.append(can),
detectorOnGroups))
}
override def omit(group: GroupNode[T]): FilterAction[T] = {
- if (detector.contains(group)) {
+ if (detectorOnGroups.contains(group)) {
return FilterAction.omit
}
- FilterAction.Continue(new OmitCycles(detector.append(group)))
+ FilterAction.Continue(new OmitCycles(detectorOnNodes,
detectorOnGroups.append(group)))
}
override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
@@ -88,8 +98,12 @@ object FilterWizards {
}
private object OmitCycles {
- def apply[T <: AnyRef](detector: CycleDetector[GroupNode[T]]):
OmitCycles[T] = {
- new OmitCycles(detector)
+ def onNodes[T <: AnyRef](detector: CycleDetector[CanonicalNode[T]]):
OmitCycles[T] = {
+ new OmitCycles(detector, CycleDetector.noop())
+ }
+
+ def onGroups[T <: AnyRef](detector: CycleDetector[GroupNode[T]]):
OmitCycles[T] = {
+ new OmitCycles(CycleDetector.noop(), detector)
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
index 78aed142a7..609433dc1b 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
@@ -38,7 +38,7 @@ object PathFinder {
}
class Builder[T <: AnyRef] private (ras: Ras[T], memoStore: MemoStore[T]) {
- private val filterWizards =
mutable.ListBuffer[FilterWizard[T]](FilterWizards.omitCycles())
+ private val filterWizards =
mutable.ListBuffer[FilterWizard[T]](FilterWizards.omitNodeCycles())
private val outputWizards = mutable.ListBuffer[OutputWizard[T]]()
def depth(depth: Int): Builder[T] = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
index 61fa22e5ea..3e1dba0c7d 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
@@ -172,6 +172,8 @@ object RasPath {
}
override def plan(): T = built
+
+ override def toString: String =
s"RasPathImpl(${ras.explain.describeNode(built)})"
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
new file mode 100644
index 0000000000..523597eddb
--- /dev/null
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.property
+
+import org.apache.gluten.ras._
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, RasRule}
+
+import scala.collection.mutable
+
+sealed trait MemoRole[T <: AnyRef] extends Property[T] {
+ override def toString: String = this.getClass.getSimpleName
+}
+
+object MemoRole {
+ implicit class MemoRoleImplicits[T <: AnyRef](role: MemoRole[T]) {
+ def asReq(): Req[T] = role.asInstanceOf[Req[T]]
+ def asProp(): Prop[T] = role.asInstanceOf[Prop[T]]
+
+ def +:(base: PropertySet[T]): PropertySet[T] = {
+ require(!base.asMap.contains(role.definition()))
+ val map: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = {
+ base.asMap + (role.definition() -> role)
+ }
+ PropertySet(map)
+ }
+ }
+
+ trait Req[T <: AnyRef] extends MemoRole[T]
+ trait Prop[T <: AnyRef] extends MemoRole[T]
+
+ // Constraints.
+ class ReqHub[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Req[T]
+ class ReqUser[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Req[T]
+ private class ReqAny[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Req[T]
+
+ // Props.
+ class Leaf[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Prop[T]
+ class Hub[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Prop[T]
+ class User[T <: AnyRef] private[MemoRole] (
+ override val definition: PropertyDef[T, _ <: Property[T]])
+ extends Prop[T]
+
+ class Def[T <: AnyRef] private[MemoRole] (val planModel: PlanModel[T])
+ extends PropertyDef[T, MemoRole[T]] {
+ private val groupRoleLookup = mutable.Map[Int, Prop[T]]()
+
+ private val reqAny = new ReqAny[T](this)
+ val reqHub = new ReqHub[T](this)
+ val reqUser = new ReqUser[T](this)
+
+ val leaf = new Leaf[T](this)
+ val hub = new Hub[T](this)
+ val user = new User[T](this)
+
+ override def any(): MemoRole[T] = reqAny
+
+ override def getProperty(plan: T): MemoRole[T] = {
+ getProperty0(plan)
+ }
+
+ private def getProperty0(plan: T): MemoRole[T] = {
+ if (planModel.isGroupLeaf(plan)) {
+ val groupId = planModel.getGroupId(plan)
+ return groupRoleLookup(groupId)
+ }
+ val children = planModel.childrenOf(plan)
+ if (children.isEmpty) {
+ return leaf
+ }
+ val childrenRoles = children.map(getProperty0).distinct
+ assert(childrenRoles.size == 1, s"Unidentical children memo roles:
$childrenRoles")
+ childrenRoles.head
+ }
+
+ override def getChildrenConstraints(plan: T, constraint: Property[T]):
Seq[MemoRole[T]] = {
+ throw new UnsupportedOperationException("Not implemented for MemoRole")
+ }
+
+ override def satisfies(property: Property[T], constraint: Property[T]):
Boolean =
+ (property, constraint) match {
+ case (_: Prop[T], _: ReqAny[T]) => true
+ case (_: Leaf[T], _: Req[T]) => true
+ case (_: User[T], _: ReqUser[T]) => true
+ case (_: Hub[T], _: ReqHub[T]) => true
+ case _ => false
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[T],
+ constraint: Property[T]): GroupLeafBuilder[T] = {
+ val role: Prop[T] = constraint.asInstanceOf[Req[T]] match {
+ case _: ReqAny[T] =>
+ hub
+ case _: ReqHub[T] =>
+ hub
+ case _: ReqUser[T] =>
+ user
+ case _ =>
+ throw new IllegalStateException(s"Unexpected req: $constraint")
+ }
+ groupRoleLookup(group.id()) = role
+ group
+ }
+ }
+
+ implicit class DefImplicits[T <: AnyRef](roleDef: Def[T]) {
+ def -:(base: PropertySet[T]): PropertySet[T] = {
+ require(base.asMap.contains(roleDef))
+ val map: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = {
+ base.asMap - roleDef
+ }
+ PropertySet(map)
+ }
+ }
+
+ def newDef[T <: AnyRef](planModel: PlanModel[T]): Def[T] = {
+ new Def[T](planModel)
+ }
+
+ def wrapPropertySetFactory[T <: AnyRef](
+ factory: PropertySetFactory[T],
+ roleDef: Def[T]): PropertySetFactory[T] = {
+ new PropertySetFactoryWithMemoRole[T](factory, roleDef)
+ }
+
+ private class PropertySetFactoryWithMemoRole[T <: AnyRef](
+ delegate: PropertySetFactory[T],
+ roleDef: Def[T])
+ extends PropertySetFactory[T] {
+
+ override val any: PropertySet[T] = compose(roleDef.any(), delegate.any())
+
+ override def get(node: T): PropertySet[T] =
+ compose(roleDef.getProperty(node), delegate.get(node))
+
+ override def childrenConstraintSets(
+ node: T,
+ constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+ assert(!roleDef.planModel.isGroupLeaf(node))
+
+ if (roleDef.planModel.isLeaf(node)) {
+ return Nil
+ }
+
+ val numChildren = roleDef.planModel.childrenOf(node).size
+
+ def delegateChildrenConstraintSets(): Seq[PropertySet[T]] = {
+ val roleRemoved = PropertySet(constraintSet.asMap - roleDef)
+ val out = delegate.childrenConstraintSets(node, roleRemoved)
+ out
+ }
+
+ def delegateConstraintSetAny(): PropertySet[T] = {
+ val properties: Seq[Property[T]] = constraintSet.asMap.keys.flatMap {
+ case _: Def[T] => Nil
+ case other => Seq(other.any())
+ }.toSeq
+ PropertySet(properties)
+ }
+
+ val constraintSets = constraintSet.get(roleDef).asReq() match {
+ case _: ReqAny[T] =>
+ delegateChildrenConstraintSets().map(
+ delegateConstraint => compose(roleDef.any(), delegateConstraint))
+ case _: ReqHub[T] =>
+ Seq.tabulate(numChildren)(_ => compose(roleDef.reqHub,
delegateConstraintSetAny()))
+ case _: ReqUser[T] =>
+ delegateChildrenConstraintSets().map(
+ delegateConstraint => compose(roleDef.reqUser, delegateConstraint))
+ }
+
+ constraintSets
+ }
+
+ override def assignToGroup(group: GroupLeafBuilder[T], constraintSet:
PropertySet[T]): Unit = {
+ roleDef.assignToGroup(group, constraintSet.asMap(roleDef))
+ delegate.assignToGroup(group, PropertySet(constraintSet.asMap - roleDef))
+ }
+
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+ new EnforcerRuleFactory[T] {
+ private val delegateFactory: EnforcerRuleFactory[T] =
delegate.newEnforcerRuleFactory()
+
+ override def newEnforcerRules(constraintSet: PropertySet[T]):
Seq[RasRule[T]] = {
+ delegateFactory.newEnforcerRules(constraintSet -: roleDef)
+ }
+ }
+ }
+
+ private def compose(memoRole: MemoRole[T], base: PropertySet[T]):
PropertySet[T] = {
+ base +: memoRole
+ }
+ }
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySet.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySet.scala
index e28718993c..3191df77ec 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySet.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySet.scala
@@ -19,8 +19,8 @@ package org.apache.gluten.ras.property
import org.apache.gluten.ras.{Property, PropertyDef}
trait PropertySet[T <: AnyRef] {
- def get[P <: Property[T]](property: PropertyDef[T, P]): P
- def getMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]]
+ def get[P <: Property[T]](propertyDef: PropertyDef[T, P]): P
+ def asMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]]
def satisfies(other: PropertySet[T]): Boolean
}
@@ -37,29 +37,17 @@ object PropertySet {
ImmutablePropertySet[T](map)
}
- implicit class PropertySetImplicits[T <: AnyRef](propSet: PropertySet[T]) {
- def withProp(property: Property[T]): PropertySet[T] = {
- val before = propSet.getMap
- val after = before + (property.definition() -> property)
- assert(after.size == before.size)
- ImmutablePropertySet[T](after)
- }
- }
-
private case class ImmutablePropertySet[T <: AnyRef](
map: Map[PropertyDef[T, _ <: Property[T]], Property[T]])
extends PropertySet[T] {
- assert(
- map.values.forall(p => p.satisfies(p.definition().any())),
- s"Property set $this doesn't satisfy its ${'"'}any${'"'} variant")
+ override def asMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
map
- override def getMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
map
override def satisfies(other: PropertySet[T]): Boolean = {
- assert(map.size == other.getMap.size)
+ assert(map.size == other.asMap.size)
map.forall {
case (propDef, prop) =>
- prop.satisfies(other.getMap(propDef))
+ prop.satisfies(other.asMap(propDef))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySetFactory.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySetFactory.scala
new file mode 100644
index 0000000000..cf41c8ace1
--- /dev/null
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/PropertySetFactory.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.property
+
+import org.apache.gluten.ras.{GroupLeafBuilder, PlanModel, Property,
PropertyDef, PropertyModel}
+import org.apache.gluten.ras.rule.EnforcerRuleFactory
+
+import scala.collection.mutable
+
+trait PropertySetFactory[T <: AnyRef] {
+ def any(): PropertySet[T]
+ def get(node: T): PropertySet[T]
+ def childrenConstraintSets(node: T, constraintSet: PropertySet[T]):
Seq[PropertySet[T]]
+ def assignToGroup(group: GroupLeafBuilder[T], constraintSet:
PropertySet[T]): Unit
+ def newEnforcerRuleFactory(): EnforcerRuleFactory[T]
+}
+
+object PropertySetFactory {
+ def apply[T <: AnyRef](
+ propertyModel: PropertyModel[T],
+ planModel: PlanModel[T]): PropertySetFactory[T] =
+ new Impl[T](propertyModel, planModel)
+
+ private class Impl[T <: AnyRef](propertyModel: PropertyModel[T], planModel:
PlanModel[T])
+ extends PropertySetFactory[T] {
+ private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] =
propertyModel.propertyDefs
+ private val anyConstraint = {
+ val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+ propDefs.map(propDef => (propDef, propDef.any())).toMap
+ PropertySet[T](m)
+ }
+
+ override def any(): PropertySet[T] = anyConstraint
+
+ override def get(node: T): PropertySet[T] = {
+ val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+ propDefs
+ .map(
+ propDef => {
+ val prop = propDef.getProperty(node)
+ (propDef, prop)
+ })
+ .toMap
+ val propSet = PropertySet[T](m)
+ assert(
+ propSet.satisfies(anyConstraint),
+ s"Property set $propSet doesn't satisfy its ${'\"'}any${'\"'} variant")
+ propSet
+ }
+
+ override def childrenConstraintSets(
+ node: T,
+ constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+ val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]] =
+ planModel
+ .childrenOf(node)
+ .map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]())
+
+ propDefs
+ .foldLeft(builder) {
+ (
+ builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]],
+ propDef: PropertyDef[T, _ <: Property[T]]) =>
+ val constraint = constraintSet.get(propDef)
+ val childrenConstraints = propDef.getChildrenConstraints(node,
constraint)
+ builder.zip(childrenConstraints).map {
+ case (childBuilder, childConstraint) =>
+ childBuilder += (propDef -> childConstraint)
+ }
+ }
+ .map {
+ builder: mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]
=>
+ PropertySet[T](builder.toMap)
+ }
+ }
+
+ override def assignToGroup(group: GroupLeafBuilder[T], constraintSet:
PropertySet[T]): Unit = {
+ constraintSet.asMap.foreach {
+ case (constraintDef, constraint) =>
+ constraintDef.assignToGroup(group, constraint)
+ }
+ }
+
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+ propertyModel.newEnforcerRuleFactory()
+ }
+ }
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
deleted file mode 100644
index c82deb7138..0000000000
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.ras.rule
-
-import org.apache.gluten.ras.{EnforcerRuleFactory, Property, PropertyDef, Ras}
-import org.apache.gluten.ras.memo.Closure
-import org.apache.gluten.ras.property.PropertySet
-
-import scala.collection.mutable
-
-trait EnforcerRule[T <: AnyRef] {
- def shift(node: T): Iterable[T]
- def shape(): Shape[T]
- def constraint(): Property[T]
-}
-
-object EnforcerRule {
- def apply[T <: AnyRef](rule: RasRule[T], constraint: Property[T]):
EnforcerRule[T] = {
- new EnforcerRuleImpl(rule, constraint)
- }
-
- def builtin[T <: AnyRef](constraint: Property[T]): EnforcerRule[T] = {
- new BuiltinEnforcerRule(constraint)
- }
-
- private class EnforcerRuleImpl[T <: AnyRef](
- rule: RasRule[T],
- override val constraint: Property[T])
- extends EnforcerRule[T] {
- override def shift(node: T): Iterable[T] = rule.shift(node)
- override def shape(): Shape[T] = rule.shape()
- }
-
- // A built-in enforcer rule that does constraint propagation. The rule
directly outputs
- // whatever passed in, and memo will copy the output node in with the
desired constraint.
- // During witch children constraints will be derived through
PropertyDef#getChildrenConstraints.
- // When the children constraints do changed, the new node with changed
children constraints will
- // be persisted into memo.
- private class BuiltinEnforcerRule[T <: AnyRef](override val constraint:
Property[T])
- extends EnforcerRule[T] {
- override def shift(node: T): Iterable[T] = List(node)
- override def shape(): Shape[T] = Shapes.fixedHeight(1)
- }
-}
-
-trait EnforcerRuleSet[T <: AnyRef] {
- def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]]
- def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]]
-}
-
-object EnforcerRuleSet {
- def apply[T <: AnyRef](ras: Ras[T], closure: Closure[T]): EnforcerRuleSet[T]
= {
- new EnforcerRuleSetImpl(ras, closure)
- }
-
- private def newEnforcerRuleFactory[T <: AnyRef](
- ras: Ras[T],
- propertyDef: PropertyDef[T, _ <: Property[T]]): EnforcerRuleFactory[T] =
{
- ras.propertyModel.newEnforcerRuleFactory(propertyDef)
- }
-
- private class EnforcerRuleSetImpl[T <: AnyRef](ras: Ras[T], closure:
Closure[T])
- extends EnforcerRuleSet[T] {
- private val factoryBuffer =
- mutable.Map[PropertyDef[T, _ <: Property[T]], EnforcerRuleFactory[T]]()
- private val buffer = mutable.Map[Property[T], Seq[RuleApplier[T]]]()
-
- private val rulesBuffer = mutable.Map[PropertySet[T],
Seq[RuleApplier[T]]]()
- private val shapesBuffer = mutable.Map[PropertySet[T], Seq[Shape[T]]]()
-
- override def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] =
{
- rulesBuffer.getOrElseUpdate(
- constraintSet,
- constraintSet.getMap.flatMap {
- case (constraintDef, constraint) =>
- buffer.getOrElseUpdate(
- constraint, {
- val factory =
- factoryBuffer.getOrElseUpdate(
- constraintDef,
- newEnforcerRuleFactory(ras, constraintDef))
- RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +:
factory
- .newEnforcerRules(constraint)
- .map(rule => RuleApplier(ras, closure, EnforcerRule(rule,
constraint)))
- }
- )
- }.toSeq
- )
- }
-
- override def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]] = {
- shapesBuffer.getOrElseUpdate(constraintSet,
rulesBuffer(constraintSet).map(_.shape()))
- }
- }
-}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleFactory.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleFactory.scala
new file mode 100644
index 0000000000..9239620c57
--- /dev/null
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleFactory.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.rule
+
+import org.apache.gluten.ras.{Property, PropertyDef}
+import org.apache.gluten.ras.property.PropertySet
+
+trait EnforcerRuleFactory[T <: AnyRef] {
+ def newEnforcerRules(constraintSet: PropertySet[T]): Seq[RasRule[T]]
+}
+
+object EnforcerRuleFactory {
+ def fromSubRules[T <: AnyRef](
+ subRuleFactories: Seq[SubRuleFactory[T]]): EnforcerRuleFactory[T] = {
+ new FromSubRules[T](subRuleFactories)
+ }
+
+ trait SubRule[T <: AnyRef] {
+ def enforce(node: T, constraint: Property[T]): Iterable[T]
+ }
+
+ trait SubRuleFactory[T <: AnyRef] {
+ def newSubRule(constraintDef: PropertyDef[T, _ <: Property[T]]): SubRule[T]
+ def ruleShape: Shape[T]
+ }
+
+ private class FromSubRules[T <: AnyRef](subRuleFactories:
Seq[SubRuleFactory[T]])
+ extends EnforcerRuleFactory[T] {
+ override def newEnforcerRules(constraintSet: PropertySet[T]):
Seq[RasRule[T]] = {
+ subRuleFactories.map {
+ subRuleFactory =>
+ new RasRule[T] {
+ override def shift(node: T): Iterable[T] = {
+ val out = constraintSet.asMap
+ .scanLeft(Seq(node)) {
+ case (nodes, (constraintDef, constraint)) =>
+ val subRule = subRuleFactory.newSubRule(constraintDef)
+ val intermediate = nodes.flatMap(
+ n => {
+ val after = subRule.enforce(n, constraint)
+ after
+ })
+ intermediate
+ }
+ .flatten
+ out
+ }
+
+ override def shape(): Shape[T] = subRuleFactory.ruleShape
+ }
+ }
+ }
+ }
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
new file mode 100644
index 0000000000..a254851f43
--- /dev/null
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.rule
+
+import org.apache.gluten.ras.Ras
+import org.apache.gluten.ras.memo.Closure
+import org.apache.gluten.ras.property.PropertySet
+
+import scala.collection.mutable
+
+trait EnforcerRuleSet[T <: AnyRef] {
+ def rules(): Seq[RuleApplier[T]]
+ def shapes(): Seq[Shape[T]]
+}
+
+object EnforcerRuleSet {
+ implicit class EnforcerRuleSetImplicits[T <: AnyRef](ruleSet:
EnforcerRuleSet[T]) {
+ def ++(other: EnforcerRuleSet[T]): EnforcerRuleSet[T] = {
+ EnforcerRuleSet(ruleSet.rules() ++ other.rules())
+ }
+ }
+
+ private def apply[T <: AnyRef](rules: Seq[RuleApplier[T]]):
EnforcerRuleSet[T] = {
+ new Impl(rules)
+ }
+
+ private class Impl[T <: AnyRef](rules: Seq[RuleApplier[T]]) extends
EnforcerRuleSet[T] {
+ private val ruleShapes: Seq[Shape[T]] = rules.map(_.shape())
+
+ override def rules(): Seq[RuleApplier[T]] = rules
+ override def shapes(): Seq[Shape[T]] = ruleShapes
+ }
+
+ trait Factory[T <: AnyRef] {
+ def ruleSetOf(constraintSet: PropertySet[T]): EnforcerRuleSet[T]
+ }
+
+ object Factory {
+ def regular[T <: AnyRef](ras: Ras[T], closure: Closure[T]): Factory[T] = {
+ new Regular(ras, closure)
+ }
+
+ def derive[T <: AnyRef](ras: Ras[T], closure: Closure[T]): Factory[T] = {
+ new Derive(ras, closure)
+ }
+
+ private class Regular[T <: AnyRef](ras: Ras[T], closure: Closure[T])
extends Factory[T] {
+ private val factory = ras.propertySetFactory().newEnforcerRuleFactory()
+ private val ruleSetBuffer = mutable.Map[PropertySet[T],
EnforcerRuleSet[T]]()
+
+ override def ruleSetOf(constraintSet: PropertySet[T]):
EnforcerRuleSet[T] = {
+ ruleSetBuffer.getOrElseUpdate(
+ constraintSet, {
+ val rules =
+ factory.newEnforcerRules(constraintSet).map {
+ rule: RasRule[T] => RuleApplier.enforcer(ras, closure,
constraintSet, rule)
+ }
+ EnforcerRuleSet(rules)
+ }
+ )
+ }
+ }
+
+ private class Derive[T <: AnyRef](ras: Ras[T], closure: Closure[T])
extends Factory[T] {
+ import Derive._
+ private val ruleSetBuffer = mutable.Map[PropertySet[T],
EnforcerRuleSet[T]]()
+ override def ruleSetOf(constraintSet: PropertySet[T]):
EnforcerRuleSet[T] = {
+ val rule = RuleApplier.enforcer(ras, closure, constraintSet, new
DeriveEnforcerRule[T]())
+ ruleSetBuffer.getOrElseUpdate(constraintSet,
EnforcerRuleSet(Seq(rule)))
+ }
+ }
+
+ private object Derive {
+  // A built-in enforcer rule set that does constraint propagation. The
rule directly outputs
+  // whatever is passed in, and memo will copy the output node in with the
desired constraint.
+  // During which, children constraints will be derived through
+  // PropertyDef#getChildrenConstraints. When the children constraints are
changed, the
+  // new node with changed children constraints will be persisted into the
memo.
+ private class DeriveEnforcerRule[T <: AnyRef]() extends RasRule[T] {
+ override def shift(node: T): Iterable[T] = Seq(node)
+ override def shape(): Shape[T] = Shapes.fixedHeight(1)
+ }
+ }
+ }
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RasRule.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RasRule.scala
index 0c3e7558ef..8a997375cd 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RasRule.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RasRule.scala
@@ -35,5 +35,4 @@ object RasRule {
override def create(): Seq[RasRule[T]] = rules
}
}
-
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index ed686c6ef7..3a090d6dd6 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -22,6 +22,8 @@ import org.apache.gluten.ras.memo.Closure
import org.apache.gluten.ras.path.InClusterPath
import org.apache.gluten.ras.property.PropertySet
+import java.util
+
import scala.collection.mutable
trait RuleApplier[T <: AnyRef] {
@@ -30,20 +32,21 @@ trait RuleApplier[T <: AnyRef] {
}
object RuleApplier {
- def apply[T <: AnyRef](ras: Ras[T], closure: Closure[T], rule: RasRule[T]):
RuleApplier[T] = {
+ def regular[T <: AnyRef](ras: Ras[T], closure: Closure[T], rule:
RasRule[T]): RuleApplier[T] = {
new RegularRuleApplier(ras, closure, rule)
}
- def apply[T <: AnyRef](
+ def enforcer[T <: AnyRef](
ras: Ras[T],
closure: Closure[T],
- rule: EnforcerRule[T]): RuleApplier[T] = {
- new EnforcerRuleApplier[T](ras, closure, rule)
+ constraintSet: PropertySet[T],
+ rule: RasRule[T]): RuleApplier[T] = {
+ new EnforcerRuleApplier[T](ras, closure, constraintSet, rule)
}
private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure:
Closure[T], rule: RasRule[T])
extends RuleApplier[T] {
- private val deDup = mutable.Map[RasClusterKey,
mutable.Set[UnsafeHashKey[T]]]()
+ private val deDup = DeDup(ras)
override def apply(icp: InClusterPath[T]): Unit = {
if (!shape.identify(icp.path())) {
@@ -52,13 +55,9 @@ object RuleApplier {
val cKey = icp.cluster()
val path = icp.path()
val plan = path.plan()
- val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
- val pKey = ras.toHashKey(plan)
- if (appliedPlans.contains(pKey)) {
- return
+ deDup.run(cKey, plan) {
+ apply0(cKey, plan)
}
- apply0(cKey, plan)
- appliedPlans += pKey
}
private def apply0(cKey: RasClusterKey, plan: T): Unit = {
@@ -67,7 +66,7 @@ object RuleApplier {
equiv =>
closure
.openFor(cKey)
- .memorize(equiv, ras.anyPropSet())
+ .memorize(equiv, ras.userConstraintSet())
}
}
@@ -77,11 +76,10 @@ object RuleApplier {
private class EnforcerRuleApplier[T <: AnyRef](
ras: Ras[T],
closure: Closure[T],
- rule: EnforcerRule[T])
+ constraintSet: PropertySet[T],
+ rule: RasRule[T])
extends RuleApplier[T] {
- private val deDup = mutable.Map[RasClusterKey,
mutable.Set[UnsafeHashKey[T]]]()
- private val constraint = rule.constraint()
- private val constraintDef = constraint.definition()
+ private val deDup = DeDup(ras)
override def apply(icp: InClusterPath[T]): Unit = {
if (!shape.identify(icp.path())) {
@@ -90,18 +88,13 @@ object RuleApplier {
val cKey = icp.cluster()
val path = icp.path()
val propSet = path.node().self().propSet()
- if (propSet.get(constraintDef).satisfies(constraint)) {
+ if (propSet.satisfies(constraintSet)) {
return
}
val plan = path.plan()
- val pKey = ras.toHashKey(plan)
- val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
- if (appliedPlans.contains(pKey)) {
- return
+ deDup.run(cKey, plan) {
+ apply0(cKey, constraintSet, plan)
}
- val constraintSet = propSet.withProp(constraint)
- apply0(cKey, constraintSet, plan)
- appliedPlans += pKey
}
private def apply0(cKey: RasClusterKey, constraintSet: PropertySet[T],
plan: T): Unit = {
@@ -116,4 +109,48 @@ object RuleApplier {
override val shape: Shape[T] = rule.shape()
}
+
+ private trait DeDup[T <: AnyRef] {
+ def run(cKey: RasClusterKey, plan: T)(computation: => Unit): Unit
+ }
+
+ private object DeDup {
+ def apply[T <: AnyRef](ras: Ras[T]): DeDup[T] = {
+ new Impl[T](ras)
+ }
+
+ private class Impl[T <: AnyRef](ras: Ras[T]) extends DeDup[T] {
+ private val layerOne = mutable.Map[RasClusterKey,
java.util.IdentityHashMap[T, Object]]()
+ private val layerTwo = mutable.Map[RasClusterKey,
mutable.Set[UnsafeHashKey[T]]]()
+
+ override def run(cKey: RasClusterKey, plan: T)(computation: => Unit):
Unit = {
+ // L1 cache is built on the identity hash codes of the input query
plans. If
+ // the cache is hit, which means the same plan object in this JVM was
+ // once applied for the computation. Return fast in that case.
+ val l1Plans = layerOne.getOrElseUpdate(cKey, new
util.IdentityHashMap())
+ if (l1Plans.containsKey(plan)) {
+ // The L1 cache is hit.
+ return
+ }
+ // Add the plan object into L1 cache.
+ l1Plans.put(plan, new Object)
+
+        // L2 cache is built on the equalities of the input query plans. It
internally
+        // compares plans through the RAS API PlanModel#equals. If the cache is
hit, which
+        // means an identical plan (but not necessarily the same one) was once
applied
+        // for the computation. Return fast in that case.
+ val l2Plans = layerTwo.getOrElseUpdate(cKey, mutable.Set())
+ val pKey = ras.toHashKey(plan)
+ if (l2Plans.contains(pKey)) {
+ // The L2 cache is hit.
+ return
+ }
+ // Add the plan object into L2 cache.
+ l2Plans += pKey
+
+ // All cache missed, apply the computation on the plan.
+ computation
+ }
+ }
+ }
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index d7d14cf3a7..4b4e0d45da 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -148,7 +148,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T],
memoState: MemoState[T], best
}
private def describeGroupVerbose(group: RasGroup[T]): String = {
- s"[Group ${group.id()}:
${group.constraintSet().getMap.values.toIndexedSeq}]"
+ s"[Group ${group.id()}:
${group.constraintSet().asMap.values.toIndexedSeq}]"
}
private def describeNode(
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/MetadataSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/MetadataSuite.scala
index 50c37ca9d8..bd77d05a01 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/MetadataSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/MetadataSuite.scala
@@ -132,6 +132,13 @@ object MetadataSuite {
case other =>
throw new UnsupportedOperationException()
}
+
+ override def assignToGroup(group: GroupLeafBuilder[TestNode], meta:
Metadata): Unit = {
+ (group, meta) match {
+ case (builder: Group.Builder, m: Metadata) =>
+ builder.withMetadata(m)
+ }
+ }
}
trait RowCount extends Metadata
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index e1ccfa1f44..db5f73299a 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -18,7 +18,6 @@ package org.apache.gluten.ras
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.path.RasPath
-import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -108,7 +107,7 @@ class OperationSuite extends AnyFunSuite {
val optimized = planner.plan()
assert(optimized == Unary2(49, Leaf2(29)))
- planModel.assertPlanOpsLte((200, 50, 100, 50))
+ planModel.assertPlanOpsLte((400, 100, 100, 50))
val state = planner.newState()
val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq
@@ -138,7 +137,7 @@ class OperationSuite extends AnyFunSuite {
val optimized = planner.plan()
assert(optimized == Unary3(98, Unary3(99, Leaf2(29))))
- planModel.assertPlanOpsLte((800, 300, 300, 200))
+ planModel.assertPlanOpsLte((1300, 300, 300, 200))
val state = planner.newState()
val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq
@@ -189,7 +188,7 @@ class OperationSuite extends AnyFunSuite {
50,
Unary2(50, Unary2(50, Unary2(50, Unary2(44, Unary2(50,
Unary2(50, Leaf2(29))))))))))))
- planModel.assertPlanOpsLte((20000, 10000, 3000, 3000))
+ planModel.assertPlanOpsLte((20000, 10000, 3000, 4000))
val state = planner.newState()
val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq
@@ -411,12 +410,9 @@ object OperationSuite {
equalsCount += 1
delegated.equals(one, other)
}
- override def newGroupLeaf(
- groupId: Int,
- metadata: Metadata,
- constraintSet: PropertySet[T]): T = {
+ override def newGroupLeaf(groupId: Int): GroupLeafBuilder[T] = {
newGroupLeafCount += 1
- delegated.newGroupLeaf(groupId, metadata, constraintSet)
+ delegated.newGroupLeaf(groupId)
}
override def isGroupLeaf(node: T): Boolean = {
isGroupLeafCount += 1
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index 06bb806f7d..0f3daabc67 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.RasConfig.PlannerType
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.memo.Memo
import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, RasRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -72,7 +72,7 @@ abstract class PropertySuite extends AnyFunSuite {
memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1,
TypedLeaf(TypeB, 1)))))
val state = memo.newState()
assert(state.allClusters().size == 4)
- assert(state.getGroupCount() == 4)
+ assert(state.getGroupCount() == 8)
}
test(s"Get property") {
@@ -87,10 +87,10 @@ abstract class PropertySuite extends AnyFunSuite {
assert(propDefs.head.getProperty(leaf) === DummyProperty(0))
assert(propDefs.head.getProperty(unary) === DummyProperty(0))
assert(propDefs.head.getProperty(binary) === DummyProperty(0))
- assert(propDefs.head.getChildrenConstraints(DummyProperty(0), leaf) ===
Seq.empty)
- assert(propDefs.head.getChildrenConstraints(DummyProperty(0), unary) ===
Seq(DummyProperty(0)))
+ assert(propDefs.head.getChildrenConstraints(leaf, DummyProperty(0)) ===
Seq.empty)
+ assert(propDefs.head.getChildrenConstraints(unary, DummyProperty(0)) ===
Seq(DummyProperty(0)))
assert(propDefs.head
- .getChildrenConstraints(DummyProperty(0), binary) ===
Seq(DummyProperty(0), DummyProperty(0)))
+ .getChildrenConstraints(binary, DummyProperty(0)) ===
Seq(DummyProperty(0), DummyProperty(0)))
}
test(s"Cannot enforce property") {
@@ -201,7 +201,7 @@ abstract class PropertySuite extends AnyFunSuite {
val out = planner.plan()
assert(out == TypedLeaf(TypeA, 1))
- // Cluster 2 and 1 are able to merge but we'd make sure
+ // Cluster 2 and 1 are able to merge, but we'd make sure
// they are identified as the same right after HitCacheOp is applied
val clusterCount = planner.newState().memoState().allClusters().size
assert(clusterCount == 2)
@@ -501,15 +501,6 @@ object PropertySuite {
// Dummy property model
case class DummyProperty(id: Int) extends Property[TestNode] {
- override def satisfies(other: Property[TestNode]): Boolean = {
- other match {
- case DummyProperty(otherId) =>
- // Higher ID satisfies lower IDs.
- id >= otherId
- case _ => throw new IllegalStateException()
- }
- }
-
override def definition(): PropertyDef[TestNode, DummyProperty] = {
DummyPropertyDef
}
@@ -547,8 +538,8 @@ object PropertySuite {
}
override def getChildrenConstraints(
- constraint: Property[TestNode],
- plan: TestNode): Seq[DummyProperty] = {
+ plan: TestNode,
+ constraint: Property[TestNode]): Seq[DummyProperty] = {
plan match {
case PUnary(_, _, _) => Seq(DummyProperty(0))
case PLeaf(_, _) => Seq.empty
@@ -556,14 +547,31 @@ object PropertySuite {
case _ => throw new IllegalStateException()
}
}
+
+ override def satisfies(
+ property: Property[TestNode],
+ constraint: Property[TestNode]): Boolean = {
+ (property, constraint) match {
+ case (DummyProperty(id), DummyProperty(otherId)) =>
+ id >= otherId
+ }
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[TestNode],
+ constraint: Property[TestNode]): GroupLeafBuilder[TestNode] =
+ (group, constraint) match {
+ case (builder: Group.Builder, c: DummyProperty) =>
+ builder.withConstraint(c)
+ }
}
object DummyPropertyModel extends PropertyModel[TestNode] {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
DummyPropertyDef)
- override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = (constraint: Property[TestNode]) =>
List.empty
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] =
+ (_: PropertySet[TestNode]) => Seq.empty
}
// Node type property model
@@ -618,20 +626,20 @@ object PropertySuite {
override def withNewChildren(children: Seq[TestNode]): TestNode =
copy(selfCost, children.head)
}
- case class NodeTypeEnforcerRule(reqType: NodeType) extends RasRule[TestNode]
{
- override def shift(node: TestNode): Iterable[TestNode] = {
+ case class NodeTypeEnforcerRule() extends
EnforcerRuleFactory.SubRule[TestNode] {
+ override def enforce(node: TestNode, constraint: Property[TestNode]):
Iterable[TestNode] = {
+ val reqType = constraint.asInstanceOf[NodeType]
node match {
case typed: TypedNode if typed.nodeType.satisfies(reqType) =>
List(typed)
case typed: TypedNode => List(TypeEnforcer(reqType, 1, typed))
case _ => throw new IllegalStateException()
}
}
-
- override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
}
- case class ZeroDepthNodeTypeEnforcerRule(reqType: NodeType) extends
RasRule[TestNode] {
- override def shift(node: TestNode): Iterable[TestNode] = {
+ case class ZeroDepthNodeTypeEnforcerRule() extends
EnforcerRuleFactory.SubRule[TestNode] {
+ override def enforce(node: TestNode, constraint: Property[TestNode]):
Iterable[TestNode] = {
+ val reqType = constraint.asInstanceOf[NodeType]
node match {
case group: Group =>
val groupType = group.constraintSet.get(NodeTypeDef)
@@ -643,8 +651,6 @@ object PropertySuite {
case _ => throw new IllegalStateException()
}
}
-
- override def shape(): Shape[TestNode] = Shapes.fixedHeight(0)
}
object ReplaceByTypeARule extends RasRule[TestNode] {
@@ -681,8 +687,8 @@ object PropertySuite {
}
override def getChildrenConstraints(
- constraint: Property[TestNode],
- plan: TestNode): Seq[NodeType] = plan match {
+ plan: TestNode,
+ constraint: Property[TestNode]): Seq[NodeType] = plan match {
case TypedLeaf(_, _) => Seq.empty
case TypedUnary(t, _, _) => Seq(t)
case TypedBinary(t, _, _, _) => Seq(t, t)
@@ -694,6 +700,25 @@ object PropertySuite {
override def toString: String = "NodeTypeDef"
override def any(): NodeType = TypeAny
+
+ override def satisfies(
+ property: Property[TestNode],
+ constraint: Property[TestNode]): Boolean = {
+ (property, constraint) match {
+ case (_, TypeAny) => true
+ case (one, other) if one == other => true
+ case _ => false
+ }
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[TestNode],
+ constraint: Property[TestNode]): GroupLeafBuilder[TestNode] = {
+ (group, constraint) match {
+ case (builder: Group.Builder, c: NodeType) =>
+ builder.withConstraint(c)
+ }
+ }
}
trait NodeType extends Property[TestNode] {
@@ -701,40 +726,13 @@ object PropertySuite {
override def toString: String = getClass.getSimpleName
}
- object TypeAny extends NodeType {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case TypeAny => true
- case _: NodeType => false
- case _ => throw new IllegalStateException()
- }
- }
+ object TypeAny extends NodeType
- object TypeA extends NodeType {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case TypeA => true
- case TypeAny => true
- case _: NodeType => false
- case _ => throw new IllegalStateException()
- }
- }
+ object TypeA extends NodeType
- object TypeB extends NodeType {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case TypeB => true
- case TypeAny => true
- case _: NodeType => false
- case _ => throw new IllegalStateException()
- }
- }
+ object TypeB extends NodeType
- object TypeC extends NodeType {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case TypeC => true
- case TypeAny => true
- case _: NodeType => false
- case _ => throw new IllegalStateException()
- }
- }
+ object TypeC extends NodeType
private def propertyModel(zeroDepth: Boolean): PropertyModel[TestNode] = {
if (zeroDepth) {
@@ -752,13 +750,18 @@ object PropertySuite {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
NodeTypeDef)
- override def newEnforcerRuleFactory(
- propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = {
- (constraint: Property[TestNode]) =>
- {
- List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] = {
+ EnforcerRuleFactory.fromSubRules(Seq(new
EnforcerRuleFactory.SubRuleFactory[TestNode] {
+ override def newSubRule(constraintDef: PropertyDef[TestNode, _ <:
Property[TestNode]])
+ : EnforcerRuleFactory.SubRule[TestNode] = {
+ constraintDef match {
+ case NodeTypeDef =>
+ NodeTypeEnforcerRule()
+ }
}
+
+ override def ruleShape: Shape[TestNode] = Shapes.fixedHeight(1)
+ }))
}
}
@@ -766,13 +769,18 @@ object PropertySuite {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
NodeTypeDef)
- override def newEnforcerRuleFactory(
- propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = {
- (constraint: Property[TestNode]) =>
- {
-
List(ZeroDepthNodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] = {
+ EnforcerRuleFactory.fromSubRules(Seq(new
EnforcerRuleFactory.SubRuleFactory[TestNode] {
+ override def newSubRule(constraintDef: PropertyDef[TestNode, _ <:
Property[TestNode]])
+ : EnforcerRuleFactory.SubRule[TestNode] = {
+ constraintDef match {
+ case NodeTypeDef =>
+ ZeroDepthNodeTypeEnforcerRule()
+ }
}
+
+ override def ruleShape: Shape[TestNode] = Shapes.fixedHeight(0)
+ }))
}
}
@@ -780,9 +788,8 @@ object PropertySuite {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
NodeTypeDef)
- override def newEnforcerRuleFactory(
- propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) =>
List.empty
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] =
+ (_: PropertySet[TestNode]) => Seq.empty
}
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index 2f3ef348cb..064be47617 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -69,7 +69,7 @@ abstract class RasSuite extends AnyFunSuite {
val state = memo.newState()
assert(group.nodes(state).size == 1)
memo.openFor(group.clusterKey()).memorize(ras, Unary(30, Leaf(90)))
- assert(memo.newState().allGroups().size == 4)
+ assert(memo.newState().allGroups().size == 8)
}
test("Group memo - define equivalence: binary with similar children, 1") {
@@ -89,7 +89,7 @@ abstract class RasSuite extends AnyFunSuite {
val leaf40Group = memo.memorize(ras, Leaf(40))
assert(leaf40Group.nodes(state).size == 1)
memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
- assert(memo.newState().allGroups().size == 3)
+ assert(memo.newState().allGroups().size == 6)
}
test("Group memo - define equivalence: binary with similar children, 2") {
@@ -109,7 +109,7 @@ abstract class RasSuite extends AnyFunSuite {
val leaf40Group = memo.memorize(ras, Leaf(40))
assert(leaf40Group.nodes(state).size == 1)
memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
- assert(memo.newState().allGroups().size == 5)
+ assert(memo.newState().allGroups().size == 10)
}
test("Group memo - partial canonical") {
@@ -124,8 +124,9 @@ abstract class RasSuite extends AnyFunSuite {
.withNewConfig(_ => conf)
val memo = Memo(ras)
val group1 = memo.memorize(ras, Unary(50, Unary(50, Leaf(30))))
- val group2 = memo.memorize(ras, Unary(50, Group(1)))
- assert(group2 eq group1)
+ val group2 = memo.memorize(ras, Unary(50, Leaf(30)))
+ val group3 = memo.memorize(ras, Unary(50, Group(group2.id())))
+ assert(group3 eq group1)
}
test(s"Unary node") {
@@ -229,8 +230,8 @@ abstract class RasSuite extends AnyFunSuite {
val allPaths = state.collectAllPaths(Int.MaxValue)
assert(state.allClusters().size == 3)
- assert(state.allGroups().size == 3)
- assert(allPaths.size == 15)
+ assert(state.allGroups().size == 6)
+ assert(allPaths.size == 33)
}
test(s"Group expansion - pattern") {
@@ -262,8 +263,8 @@ abstract class RasSuite extends AnyFunSuite {
val allPaths = state.collectAllPaths(Int.MaxValue)
assert(state.allClusters().size == 3)
- assert(state.allGroups().size == 3)
- assert(allPaths.size == 15)
+ assert(state.allGroups().size == 6)
+ assert(allPaths.size == 33)
}
test(s"Rule dependency") {
@@ -405,9 +406,9 @@ abstract class RasSuite extends AnyFunSuite {
val optimized = planner.plan()
val state = planner.newState()
- // The 2 plans have same cost
+ // The 2 plans have the same cost.
assert(optimized == Unary(90, Leaf(70)) || optimized == Unary2(90,
Leaf(70)))
- assert(state.memoState().getGroupCount() == 2)
+ assert(state.memoState().getGroupCount() == 4)
}
test(s"Binary swap") {
@@ -492,7 +493,7 @@ abstract class RasSuite extends AnyFunSuite {
val optimized = planner.plan()
val state = planner.newState()
- assert(state.memoState().getGroupCount() == 3)
+ assert(state.memoState().getGroupCount() == 6)
assert(optimized == Unary(50, Unary3(49, Leaf(30))))
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
index 65c4d5a073..5c13444c97 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
@@ -19,6 +19,9 @@ package org.apache.gluten.ras
import org.apache.gluten.ras.memo.{MemoLike, MemoState}
import org.apache.gluten.ras.path.{PathFinder, RasPath}
import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.ras.rule.EnforcerRuleFactory
+
+import scala.collection.mutable.ArrayBuffer
object RasSuiteBase {
trait TestNode {
@@ -49,7 +52,8 @@ object RasSuiteBase {
def withNewChildren(children: Seq[TestNode]): TestNode = this
}
- case class Group(id: Int, meta: Metadata, constraintSet:
PropertySet[TestNode]) extends LeafLike {
+ case class Group private (id: Int, meta: Metadata, constraintSet:
PropertySet[TestNode])
+ extends LeafLike {
override def selfCost(): Long = Long.MaxValue
override def makeCopy(): LeafLike = copy()
}
@@ -58,6 +62,29 @@ object RasSuiteBase {
def apply(id: Int): Group = {
Group(id, MetadataModelImpl.DummyMetadata, PropertySet(List.empty))
}
+
+ def newBuilder(id: Int): Builder = {
+ new Builder(id)
+ }
+
+ class Builder private[Group] (override val id: Int) extends
GroupLeafBuilder[TestNode] {
+ private var meta: Metadata = _
+ private val constraints: ArrayBuffer[Property[TestNode]] = ArrayBuffer()
+
+ def withMetadata(meta: Metadata): Builder = {
+ this.meta = meta
+ this
+ }
+
+ def withConstraint(constraint: Property[TestNode]): Builder = {
+ this.constraints += constraint
+ this
+ }
+
+ override def build(): TestNode = {
+ Group(id, meta, PropertySet(constraints.toSeq))
+ }
+ }
}
case class LongCost(value: Long) extends Cost
@@ -110,11 +137,7 @@ object RasSuiteBase {
java.util.Objects.equals(one, other)
}
- override def newGroupLeaf(
- groupId: Int,
- meta: Metadata,
- constraintSet: PropertySet[TestNode]): TestNode =
- Group(groupId, meta, constraintSet)
+ override def newGroupLeaf(groupId: Int): Group.Builder =
Group.newBuilder(groupId)
override def getGroupId(node: TestNode): Int = node match {
case ngl: Group => ngl.id
@@ -145,12 +168,19 @@ object RasSuiteBase {
assert(one == DummyMetadata)
assert(other == DummyMetadata)
}
+
+ override def assignToGroup(group: GroupLeafBuilder[TestNode], meta:
Metadata): Unit = {
+ (group, meta) match {
+ case (builder: Group.Builder, m: Metadata) =>
+ builder.withMetadata(m)
+ }
+ }
}
object PropertyModelImpl extends PropertyModel[TestNode] {
- override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = List.empty
- override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq.empty
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] =
+ (_: PropertySet[TestNode]) => Seq.empty
}
implicit class GraphvizPrinter[T <: AnyRef](val planner: RasPlanner[T]) {
@@ -163,7 +193,7 @@ object RasSuiteBase {
implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
def memorize(ras: Ras[T], node: T): RasGroup[T] = {
- memo.memorize(node, ras.anyPropSet())
+ memo.memorize(node, ras.userConstraintSet())
}
}
@@ -175,26 +205,26 @@ object RasSuiteBase {
def collectAllPaths(depth: Int): Iterable[RasPath[T]] = {
val allClusters = state.allClusters()
val allGroups = state.allGroups()
+ val hubGroupLookup = state.clusterHubGroupLookup()
val highestFinder = PathFinder
.builder(state.ras(), state)
.depth(depth)
.build()
- allClusters
- .flatMap(c => c.nodes())
- .flatMap(
- node => {
- val highest = highestFinder.find(node).maxBy(c => c.height())
- val finder = (1 to highest.height())
- .foldLeft(PathFinder
- .builder(state.ras(), state)) {
- case (builder, d) =>
- builder.depth(d)
- }
- .build()
- finder.find(node)
- })
+ hubGroupLookup.flatMap {
+ case (cKey, hubGroup) =>
+ val hubGroupNode = GroupNode(state.ras(), hubGroup)
+ val highest = highestFinder.find(hubGroupNode).maxBy(c => c.height())
+ val finder = (1 to highest.height())
+ .foldLeft(PathFinder
+ .builder(state.ras(), state)) {
+ case (builder, d) =>
+ builder.depth(d)
+ }
+ .build()
+ finder.find(hubGroupNode)
+ }
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index 1c8458af3c..b33073cd3c 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -49,9 +49,14 @@ case class MockMemoState[T <: AnyRef] private (
override def getGroup(id: Int): RasGroup[T] = allGroups(id)
- override def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]] =
Map.empty
+ override def clusterHubGroupLookup(): Map[RasClusterKey, RasGroup[T]] =
Map.empty
- override def getDummyGroup(key: RasClusterKey): RasGroup[T] =
+ override def getHubGroup(key: RasClusterKey): RasGroup[T] =
+ throw new UnsupportedOperationException()
+
+ override def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]] =
Map.empty
+
+ override def getUserGroup(key: RasClusterKey): RasGroup[T] =
throw new UnsupportedOperationException()
}
@@ -148,7 +153,7 @@ object MockMemoState {
id,
clusterKey,
propSet,
- ras.planModel.newGroupLeaf(id, clusterKey.metadata, propSet))
+ ras.newGroupLeaf(id, clusterKey.metadata, propSet +:
ras.memoRoleDef.reqUser))
groupBuffer += group
group
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
index bf267a4b68..b8336e64ad 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
@@ -27,7 +27,7 @@ object MockRasPath {
def mock[T <: AnyRef](ras: Ras[T], node: T, keys: PathKeySet): RasPath[T] = {
val memo = Memo(ras)
- val g = memo.memorize(node, ras.anyPropSet())
+ val g = memo.memorize(node, ras.userConstraintSet())
val state = memo.newState()
val groupSupplier = state.asGroupSupplier()
assert(g.nodes(state).size == 1)
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index e930e4da22..e4ab0687a6 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.ras._
import org.apache.gluten.ras.RasConfig.PlannerType
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, RasRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -69,7 +69,11 @@ abstract class DistributedSuite extends AnyFunSuite {
val planner =
ras.newPlanner(plan, PropertySet(List(HashDistribution(List("a", "b")),
AnyOrdering)))
val out = planner.plan()
- assert(out == DProject(DExchange(List("a", "b"), DLeaf())))
+ val alternatives = Set[TestNode](
+ DProject(DExchange(List("a", "b"), DLeaf())),
+ DExchange(List("a", "b"), DProject(DLeaf()))
+ )
+ assert(alternatives.contains(out))
}
test("Aggregate - none-distribution constraint") {
@@ -108,7 +112,11 @@ abstract class DistributedSuite extends AnyFunSuite {
val planner =
ras.newPlanner(plan, PropertySet(List(AnyDistribution,
SimpleOrdering(List("a", "b")))))
val out = planner.plan()
- assert(out == DProject(DSort(List("a", "b"), DLeaf())))
+ val alternatives = Set[TestNode](
+ DProject(DSort(List("a", "b"), DLeaf())),
+ DSort(List("a", "b"), DProject(DLeaf()))
+ )
+ assert(alternatives.contains(out))
}
test("Project - required distribution and ordering") {
@@ -128,7 +136,17 @@ abstract class DistributedSuite extends AnyFunSuite {
plan,
PropertySet(List(HashDistribution(List("a", "b")),
SimpleOrdering(List("b", "c")))))
val out = planner.plan()
- assert(out == DProject(DSort(List("b", "c"), DExchange(List("a", "b"),
DLeaf()))))
+
+ val alternatives = Set[TestNode](
+ DProject(DSort(List("b", "c"), DExchange(List("a", "b"), DLeaf()))),
+ DProject(DExchange(List("a", "b"), DSort(List("b", "c"), DLeaf()))),
+ DSort(List("b", "c"), DProject(DExchange(List("a", "b"), DLeaf()))),
+ DSort(List("b", "c"), DExchange(List("a", "b"), DProject(DLeaf()))),
+ DExchange(List("a", "b"), DSort(List("b", "c"), DProject(DLeaf()))),
+ DExchange(List("a", "b"), DProject(DSort(List("b", "c"), DLeaf())))
+ )
+
+ assert(alternatives.contains(out))
}
test("Aggregate - avoid re-exchange") {
@@ -212,42 +230,15 @@ object DistributedSuite {
override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
}
- trait Distribution extends Property[TestNode]
-
- case class HashDistribution(keys: Seq[String]) extends Distribution {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case HashDistribution(otherKeys) if keys.size > otherKeys.size => false
- case HashDistribution(otherKeys) =>
- // (a) satisfies (a, b)
- keys.zipWithIndex.forall {
- case (key, index) =>
- key == otherKeys(index)
- }
- case AnyDistribution => true
- case NoneDistribution => false
- case _ => throw new UnsupportedOperationException()
- }
+ trait Distribution extends Property[TestNode] {
override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= DistributionDef
}
- case object AnyDistribution extends Distribution {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case HashDistribution(_) => false
- case AnyDistribution => true
- case NoneDistribution => false
- case _ => throw new UnsupportedOperationException()
- }
- override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= DistributionDef
- }
+ case class HashDistribution(keys: Seq[String]) extends Distribution
- case object NoneDistribution extends Distribution {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case AnyDistribution => true
- case _: Distribution => false
- case _ => throw new UnsupportedOperationException()
- }
- override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= DistributionDef
- }
+ case object AnyDistribution extends Distribution
+
+ case object NoneDistribution extends Distribution
private object DistributionDef extends PropertyDef[TestNode, Distribution] {
override def getProperty(plan: TestNode): Distribution = plan match {
@@ -258,53 +249,52 @@ object DistributedSuite {
}
override def getChildrenConstraints(
- constraint: Property[TestNode],
- plan: TestNode): Seq[Distribution] = (constraint, plan) match {
+ plan: TestNode,
+ constraint: Property[TestNode]): Seq[Distribution] = (constraint,
plan) match {
case (NoneDistribution, p: DNode) => p.children().map(_ =>
NoneDistribution)
case (d: Distribution, p: DNode) => p.getDistributionConstraints(d)
case _ => throw new UnsupportedOperationException()
}
override def any(): Distribution = AnyDistribution
- }
- trait Ordering extends Property[TestNode]
+ override def satisfies(
+ property: Property[TestNode],
+ constraint: Property[TestNode]): Boolean = {
+ (property, constraint) match {
+ case (_, NoneDistribution) => false
+ case (_, AnyDistribution) => true
+ case (HashDistribution(keys), HashDistribution(otherKeys)) if
keys.size > otherKeys.size =>
+ false
+ case (HashDistribution(keys), HashDistribution(otherKeys)) =>
+ // (a) satisfies (a, b)
+ keys.zipWithIndex.forall {
+ case (key, index) =>
+ key == otherKeys(index)
+ }
+ case _ => false
+ }
+ }
- case class SimpleOrdering(keys: Seq[String]) extends Ordering {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case SimpleOrdering(otherKeys) if keys.size < otherKeys.size => false
- case SimpleOrdering(otherKeys) =>
- // (a, b) satisfies (a)
- otherKeys.zipWithIndex.forall {
- case (otherKey, index) =>
- otherKey == keys(index)
- }
- case AnyOrdering => true
- case NoneOrdering => false
- case _ => throw new UnsupportedOperationException()
+ override def assignToGroup(
+ group: GroupLeafBuilder[TestNode],
+ constraint: Property[TestNode]): GroupLeafBuilder[TestNode] = {
+ (group, constraint) match {
+ case (builder: Group.Builder, c: Distribution) =>
+ builder.withConstraint(c)
+ }
}
- override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= OrderingDef
}
- case object AnyOrdering extends Ordering {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case SimpleOrdering(_) => false
- case AnyOrdering => true
- case NoneOrdering => false
- case _ => throw new UnsupportedOperationException()
- }
+ trait Ordering extends Property[TestNode] {
override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= OrderingDef
}
- case object NoneOrdering extends Ordering {
- override def satisfies(other: Property[TestNode]): Boolean = other match {
- case AnyOrdering => true
- case _: Ordering => false
- case _ => throw new UnsupportedOperationException()
- }
- override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]]
= OrderingDef
+ case class SimpleOrdering(keys: Seq[String]) extends Ordering
- }
+ case object AnyOrdering extends Ordering
+
+ case object NoneOrdering extends Ordering
// FIXME: Handle non-ordering as well as non-distribution
private object OrderingDef extends PropertyDef[TestNode, Ordering] {
@@ -314,8 +304,8 @@ object DistributedSuite {
case _ => throw new UnsupportedOperationException()
}
override def getChildrenConstraints(
- constraint: Property[TestNode],
- plan: TestNode): Seq[Ordering] =
+ plan: TestNode,
+ constraint: Property[TestNode]): Seq[Ordering] =
(constraint, plan) match {
case (NoneOrdering, p: DNode) => p.children().map(_ => NoneOrdering)
case (o: Ordering, p: DNode) => p.getOrderingConstraints(o)
@@ -323,43 +313,77 @@ object DistributedSuite {
}
override def any(): Ordering = AnyOrdering
+
+ override def satisfies(
+ property: Property[TestNode],
+ constraint: Property[TestNode]): Boolean = {
+ (property, constraint) match {
+ case (_, NoneOrdering) => false
+ case (_, AnyOrdering) => true
+ case (SimpleOrdering(keys), SimpleOrdering(otherKeys)) if keys.size >
otherKeys.size =>
+ false
+ case (SimpleOrdering(keys), SimpleOrdering(otherKeys)) =>
+ // (a, b) satisfies (a)
+ otherKeys.zipWithIndex.forall {
+ case (otherKey, index) =>
+ otherKey == keys(index)
+ }
+ case _ => false
+ }
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[TestNode],
+ constraint: Property[TestNode]): GroupLeafBuilder[TestNode] = {
+ (group, constraint) match {
+ case (builder: Group.Builder, c: Ordering) =>
+ builder.withConstraint(c)
+ }
+ }
}
- private class EnforceDistribution(distribution: Distribution) extends
RasRule[TestNode] {
- override def shift(node: TestNode): Iterable[TestNode] = (node,
distribution) match {
- case (d: DNode, HashDistribution(keys)) => List(DExchange(keys, d))
- case (d: DNode, AnyDistribution) => List(d)
- case (d: DNode, NoneDistribution) => List.empty
- case _ =>
- throw new UnsupportedOperationException()
+ private class EnforceDistribution() extends
EnforcerRuleFactory.SubRule[TestNode] {
+ override def enforce(node: TestNode, constraint: Property[TestNode]):
Iterable[TestNode] = {
+ val distribution = constraint.asInstanceOf[Distribution]
+ (node, distribution) match {
+ case (d: DNode, HashDistribution(keys)) => List(DExchange(keys, d))
+ case (d: DNode, AnyDistribution) => List(d)
+ case (d: DNode, NoneDistribution) => List.empty
+ case _ =>
+ throw new UnsupportedOperationException()
+ }
}
- override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
}
- private class EnforceOrdering(ordering: Ordering) extends RasRule[TestNode] {
- override def shift(node: TestNode): Iterable[TestNode] = (node, ordering)
match {
- case (d: DNode, SimpleOrdering(keys)) => List(DSort(keys, d))
- case (d: DNode, AnyOrdering) => List(d)
- case (d: DNode, NoneOrdering) => List.empty
- case _ => throw new UnsupportedOperationException()
+ private class EnforceOrdering() extends
EnforcerRuleFactory.SubRule[TestNode] {
+ override def enforce(node: TestNode, constraint: Property[TestNode]):
Iterable[TestNode] = {
+ val ordering = constraint.asInstanceOf[Ordering]
+ (node, ordering) match {
+ case (d: DNode, SimpleOrdering(keys)) => List(DSort(keys, d))
+ case (d: DNode, AnyOrdering) => List(d)
+ case (d: DNode, NoneOrdering) => List.empty
+ case _ => throw new UnsupportedOperationException()
+ }
}
- override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
}
private object DistributedPropertyModel extends PropertyModel[TestNode] {
override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] =
List(DistributionDef, OrderingDef)
- override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = new EnforcerRuleFactory[TestNode] {
- override def newEnforcerRules(constraint: Property[TestNode]):
Seq[RasRule[TestNode]] = {
- constraint match {
- case distribution: Distribution => List(new
EnforceDistribution(distribution))
- case ordering: Ordering => List(new EnforceOrdering(ordering))
- case _ => throw new UnsupportedOperationException()
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[TestNode] =
+ EnforcerRuleFactory.fromSubRules(Seq(new
EnforcerRuleFactory.SubRuleFactory[TestNode] {
+ override def newSubRule(constraintDef: PropertyDef[TestNode, _ <:
Property[TestNode]])
+ : EnforcerRuleFactory.SubRule[TestNode] = {
+ constraintDef match {
+ case DistributionDef => new EnforceDistribution()
+ case OrderingDef => new EnforceOrdering()
+ case _ => throw new UnsupportedOperationException()
+ }
}
- }
- }
+
+ override def ruleShape: Shape[TestNode] = Shapes.fixedHeight(1)
+ }))
}
trait DNode extends TestNode {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]