This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 993e96afe [CORE][VL] RAS: Refactor memo cache to look up on
cluster-canonical node rather than on group-canonical node (#5305)
993e96afe is described below
commit 993e96afe81ff85e2928151b4a9b9d45baf4b79f
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Apr 8 11:42:51 2024 +0800
[CORE][VL] RAS: Refactor memo cache to look up on cluster-canonical node
rather than on group-canonical node (#5305)
---
.../planner/property/GlutenPropertyModel.scala | 2 +
.../org/apache/gluten/ras/MetadataModel.scala | 3 +-
.../org/apache/gluten/ras/PropertyModel.scala | 1 +
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 48 ++++-
.../scala/org/apache/gluten/ras/RasCluster.scala | 16 +-
.../main/scala/org/apache/gluten/ras/RasNode.scala | 53 ++++-
.../scala/org/apache/gluten/ras/RasPlanner.scala | 23 ++-
.../org/apache/gluten/ras/best/BestFinder.scala | 21 +-
.../gluten/ras/best/GroupBasedBestFinder.scala | 23 ++-
.../org/apache/gluten/ras/dp/DpClusterAlgo.scala | 2 +-
.../org/apache/gluten/ras/dp/DpGroupAlgo.scala | 2 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 4 +-
.../org/apache/gluten/ras/dp/DpZipperAlgo.scala | 2 -
.../gluten/ras/exaustive/ExhaustivePlanner.scala | 24 ++-
.../apache/gluten/ras/memo/ForwardMemoTable.scala | 19 +-
.../scala/org/apache/gluten/ras/memo/Memo.scala | 214 +++++++++++++--------
.../org/apache/gluten/ras/memo/MemoTable.scala | 1 +
.../scala/org/apache/gluten/ras/path/RasPath.scala | 43 ++---
.../org/apache/gluten/ras/rule/RuleApplier.scala | 51 ++---
.../scala/org/apache/gluten/ras/util/NodeMap.scala | 60 ------
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 4 +-
.../org/apache/gluten/ras/OperationSuite.scala | 8 +-
.../org/apache/gluten/ras/PropertySuite.scala | 34 +++-
.../scala/org/apache/gluten/ras/RasSuite.scala | 9 +-
.../org/apache/gluten/ras/path/RasPathSuite.scala | 60 ++++--
.../gluten/ras/specific/DistributedSuite.scala | 4 +
26 files changed, 453 insertions(+), 278 deletions(-)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
index 54f4e3b84..07dd3fe02 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
@@ -53,6 +53,8 @@ object GlutenProperties {
val conv = getProperty(plan)
plan.children.map(_ => conv)
}
+
+ override def any(): Convention = Conventions.ANY
}
case class ConventionEnforcerRule(reqConv: Convention) extends
RasRule[SparkPlan] {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
index d2056746c..a81ac31cb 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala
@@ -21,8 +21,9 @@ package org.apache.gluten.ras
*/
trait MetadataModel[T <: AnyRef] {
def metadataOf(node: T): Metadata
- def dummy(): Metadata
def verify(one: Metadata, other: Metadata): Unit
+
+ def dummy(): Metadata
}
trait Metadata {}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
index e2ba99136..e764631e7 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala
@@ -26,6 +26,7 @@ trait Property[T <: AnyRef] {
}
trait PropertyDef[T <: AnyRef, P <: Property[T]] {
+ def any(): P
def getProperty(plan: T): P
def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P]
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 6832d07c5..9910fab6f 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -79,9 +79,10 @@ class Ras[T <: AnyRef] private (
ruleFactory)
}
- // Normal groups start with ID 0, so it's safe to use -1 to do validation.
+ private val propSetFactory: PropertySetFactory[T] =
PropertySetFactory(propertyModel, planModel)
+ // Normal groups start with ID 0, so it's safe to use Int.MinValue to do
validation.
private val dummyGroup: T =
- planModel.newGroupLeaf(-1, metadataModel.dummy(), PropertySet(Seq.empty))
+ planModel.newGroupLeaf(Int.MinValue, metadataModel.dummy(),
propSetFactory.any())
private val infCost: Cost = costModel.makeInfCost()
validateModels()
@@ -123,8 +124,6 @@ class Ras[T <: AnyRef] private (
}
}
- private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(this)
-
override def newPlanner(
plan: T,
constraintSet: PropertySet[T],
@@ -171,6 +170,8 @@ class Ras[T <: AnyRef] private (
private[ras] def getInfCost(): Cost = infCost
private[ras] def isInfCost(cost: Cost) =
costModel.costComparator().equiv(cost, infCost)
+
+ private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
}
object Ras {
@@ -192,16 +193,29 @@ object Ras {
}
trait PropertySetFactory[T <: AnyRef] {
+ def any(): PropertySet[T]
def get(node: T): PropertySet[T]
def childrenConstraintSets(constraintSet: PropertySet[T], node: T):
Seq[PropertySet[T]]
}
private object PropertySetFactory {
- def apply[T <: AnyRef](ras: Ras[T]): PropertySetFactory[T] = new
PropertySetFactoryImpl[T](ras)
-
- private class PropertySetFactoryImpl[T <: AnyRef](val ras: Ras[T])
+ def apply[T <: AnyRef](
+ propertyModel: PropertyModel[T],
+ planModel: PlanModel[T]): PropertySetFactory[T] =
+ new PropertySetFactoryImpl[T](propertyModel, planModel)
+
+ private class PropertySetFactoryImpl[T <: AnyRef](
+ propertyModel: PropertyModel[T],
+ planModel: PlanModel[T])
extends PropertySetFactory[T] {
- private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] =
ras.propertyModel.propertyDefs
+ private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] =
propertyModel.propertyDefs
+ private val anyConstraint = {
+ val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+ propDefs.map(propDef => (propDef, propDef.any())).toMap
+ PropertySet[T](m)
+ }
+
+ override def any(): PropertySet[T] = anyConstraint
override def get(node: T): PropertySet[T] = {
val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
@@ -213,7 +227,7 @@ object Ras {
constraintSet: PropertySet[T],
node: T): Seq[PropertySet[T]] = {
val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]] =
- ras.planModel
+ planModel
.childrenOf(node)
.map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]],
Property[T]]())
@@ -236,4 +250,20 @@ object Ras {
}
}
}
+
+ trait UnsafeKey[T]
+
+ private object UnsafeKey {
+ def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new
UnsafeKeyImpl(ras, self)
+ private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends
UnsafeKey[T] {
+ override def hashCode(): Int = ras.planModel.hashCode(self)
+ override def equals(other: Any): Boolean = {
+ other match {
+ case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
+ case _ => false
+ }
+ }
+ override def toString: String = ras.explain.describeNode(self)
+ }
+ }
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 63b8b1e68..1b30e1242 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.ras
+import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.memo.MemoTable
import org.apache.gluten.ras.property.PropertySet
@@ -54,16 +55,19 @@ object RasCluster {
override val ras: Ras[T],
metadata: Metadata)
extends MutableRasCluster[T] {
- private val buffer: mutable.Set[CanonicalNode[T]] =
- mutable.Set()
+ private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
+ private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
+ mutable.ListBuffer()
override def contains(t: CanonicalNode[T]): Boolean = {
- buffer.contains(t)
+ deDup.contains(t.toUnsafeKey())
}
override def add(t: CanonicalNode[T]): Unit = {
+ val key = t.toUnsafeKey()
+ assert(!deDup.contains(key))
ras.metadataModel.verify(metadata,
ras.metadataModel.metadataOf(t.self()))
- assert(!buffer.contains(t))
+ deDup += key
buffer += t
}
@@ -75,12 +79,12 @@ object RasCluster {
case class ImmutableRasCluster[T <: AnyRef] private (
ras: Ras[T],
- override val nodes: Set[CanonicalNode[T]])
+ override val nodes: Seq[CanonicalNode[T]])
extends RasCluster[T]
object ImmutableRasCluster {
def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]):
ImmutableRasCluster[T] = {
- ImmutableRasCluster(ras, cluster.nodes().toSet)
+ ImmutableRasCluster(ras, cluster.nodes().toVector)
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 5f18f96a7..878020391 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.ras
+import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.property.PropertySet
trait RasNode[T <: AnyRef] {
@@ -41,6 +42,8 @@ object RasNode {
def asGroup(): GroupNode[T] = {
node.asInstanceOf[GroupNode[T]]
}
+
+ def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
}
}
@@ -53,7 +56,7 @@ object CanonicalNode {
assert(ras.isCanonical(canonical))
val propSet = ras.propSetsOf(canonical)
val children = ras.planModel.childrenOf(canonical)
- CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
+ new CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
}
// We put RasNode's API methods that accept mutable input in implicit
definition.
@@ -74,12 +77,16 @@ object CanonicalNode {
}
}
- private case class CanonicalNodeImpl[T <: AnyRef](
- ras: Ras[T],
+ private class CanonicalNodeImpl[T <: AnyRef](
+ override val ras: Ras[T],
override val self: T,
override val propSet: PropertySet[T],
override val childrenCount: Int)
- extends CanonicalNode[T]
+ extends CanonicalNode[T] {
+ override def toString: String = ras.explain.describeNode(self)
+ override def hashCode(): Int = throw new UnsupportedOperationException()
+ override def equals(obj: Any): Boolean = throw new
UnsupportedOperationException()
+ }
}
trait GroupNode[T <: AnyRef] extends RasNode[T] {
@@ -88,15 +95,19 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {
object GroupNode {
def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
- GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
+ new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
}
- private case class GroupNodeImpl[T <: AnyRef](
- ras: Ras[T],
+ private class GroupNodeImpl[T <: AnyRef](
+ override val ras: Ras[T],
override val self: T,
override val propSet: PropertySet[T],
override val groupId: Int)
- extends GroupNode[T] {}
+ extends GroupNode[T] {
+ override def toString: String = ras.explain.describeNode(self)
+ override def hashCode(): Int = throw new UnsupportedOperationException()
+ override def equals(obj: Any): Boolean = throw new
UnsupportedOperationException()
+ }
// We put RasNode's API methods that accept mutable input in implicit
definition.
// Do not break this rule during further development.
@@ -116,8 +127,21 @@ object InGroupNode {
def apply[T <: AnyRef](groupId: Int, node: CanonicalNode[T]): InGroupNode[T]
= {
InGroupNodeImpl(groupId, node)
}
+
private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can:
CanonicalNode[T])
extends InGroupNode[T]
+
+ trait HashKey extends Any
+
+ implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
+ import InGroupNodeImplicits._
+ def toHashKey: HashKey =
+ InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
+ }
+
+ private object InGroupNodeImplicits {
+ private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends
HashKey
+ }
}
trait InClusterNode[T <: AnyRef] {
@@ -129,8 +153,21 @@ object InClusterNode {
def apply[T <: AnyRef](clusterId: RasClusterKey, node: CanonicalNode[T]):
InClusterNode[T] = {
InClusterNodeImpl(clusterId, node)
}
+
private case class InClusterNodeImpl[T <: AnyRef](
clusterKey: RasClusterKey,
can: CanonicalNode[T])
extends InClusterNode[T]
+
+ trait HashKey extends Any
+
+ implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
+ import InClusterNodeImplicits._
+ def toHashKey: HashKey =
+ InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
+ }
+
+ private object InClusterNodeImplicits {
+ private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey,
cid: Int) extends HashKey
+ }
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
index 0665d3661..74793a3d0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
@@ -49,8 +49,8 @@ object RasPlanner {
trait Best[T <: AnyRef] {
import Best._
def rootGroupId(): Int
- def bestNodes(): Set[InGroupNode[T]]
- def winnerNodes(): Set[InGroupNode[T]]
+ def bestNodes(): InGroupNode[T] => Boolean
+ def winnerNodes(): InGroupNode[T] => Boolean
def costs(): InGroupNode[T] => Option[Cost]
def path(): KnownCostPath[T]
}
@@ -62,11 +62,11 @@ object Best {
bestPath: KnownCostPath[T],
winnerNodes: Seq[InGroupNode[T]],
costs: InGroupNode[T] => Option[Cost]): Best[T] = {
- val bestNodes = mutable.Set[InGroupNode[T]]()
+ val bestNodes = mutable.Set[InGroupNode.HashKey]()
def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
val can = cursor.self().asCanonical()
- bestNodes += InGroupNode(groupId, can)
+ bestNodes += InGroupNode(groupId, can).toHashKey
cursor.zipChildrenWithGroupIds().foreach {
case (childPathNode, childGroupId) =>
dfs(childGroupId, childPathNode)
@@ -75,17 +75,24 @@ object Best {
dfs(rootGroupId, bestPath.rasPath.node())
- val winnerNodeSet = winnerNodes.toSet
+ val bestNodeSet = bestNodes.toSet
+ val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet
- BestImpl(ras, rootGroupId, bestPath, bestNodes.toSet, winnerNodeSet, costs)
+ BestImpl(
+ ras,
+ rootGroupId,
+ bestPath,
+ n => bestNodeSet.contains(n.toHashKey),
+ n => winnerNodeSet.contains(n.toHashKey),
+ costs)
}
private case class BestImpl[T <: AnyRef](
ras: Ras[T],
override val rootGroupId: Int,
override val path: KnownCostPath[T],
- override val bestNodes: Set[InGroupNode[T]],
- override val winnerNodes: Set[InGroupNode[T]],
+ override val bestNodes: InGroupNode[T] => Boolean,
+ override val winnerNodes: InGroupNode[T] => Boolean,
override val costs: InGroupNode[T] => Option[Cost])
extends Best[T]
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 4ec7e09f5..0912ab536 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -40,9 +40,12 @@ object BestFinder {
}
case class KnownCostGroup[T <: AnyRef](
- nodeToCost: Map[CanonicalNode[T], KnownCostPath[T]],
+ nodes: Iterable[CanonicalNode[T]],
+ nodeToCost: CanonicalNode[T] => Option[KnownCostPath[T]],
bestNode: CanonicalNode[T]) {
- def best(): KnownCostPath[T] = nodeToCost(bestNode)
+ def best(): KnownCostPath[T] = {
+ nodeToCost(bestNode).get
+ }
}
case class KnownCostCluster[T <: AnyRef](groupToCost: Map[Int,
KnownCostGroup[T]])
@@ -52,17 +55,21 @@ object BestFinder {
allGroups: Seq[RasGroup[T]],
group: RasGroup[T],
groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {
+
val bestPath = groupToCosts(group.id()).best()
val bestRoot = bestPath.rasPath.node()
val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id,
g.bestNode) }.toSeq
- val costsMap = mutable.Map[InGroupNode[T], Cost]()
+ val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
groupToCosts.foreach {
case (gid, g) =>
- g.nodeToCost.foreach {
- case (n, c) =>
- costsMap += (InGroupNode(gid, n) -> c.cost)
+ g.nodes.foreach {
+ n =>
+ val c = g.nodeToCost(n)
+ if (c.nonEmpty) {
+ costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
+ }
}
}
- Best(ras, group.id(), bestPath, winnerNodes, costsMap.get)
+ Best(ras, group.id(), bestPath, winnerNodes, ign =>
costsMap.get(ign.toHashKey))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
index 7d2d807ff..6db3600de 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
@@ -23,6 +23,8 @@ import org.apache.gluten.ras.dp.{DpGroupAlgo, DpGroupAlgoDef}
import org.apache.gluten.ras.memo.MemoState
import org.apache.gluten.ras.path.{PathKeySet, RasPath}
+import java.util
+
// The best path's each sub-path is considered optimal in its own group.
private class GroupBasedBestFinder[T <: AnyRef](
ras: Ras[T],
@@ -94,21 +96,34 @@ private object GroupBasedBestFinder {
override def solveGroup(
group: RasGroup[T],
nodesOutput: InGroupNode[T] => Option[KnownCostPath[T]]):
Option[KnownCostGroup[T]] = {
+ import scala.collection.JavaConverters._
+
val nodes = group.nodes(memoState)
// Allow unsolved children nodes while solving group.
- val flatNodesOutput =
- nodes.flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp =>
n -> kcp)).toMap
+ val flatNodesOutput = new util.IdentityHashMap[CanonicalNode[T],
KnownCostPath[T]]()
+
+ nodes
+ .flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => n ->
kcp))
+ .foreach {
+ case (n, kcp) =>
+ assert(!flatNodesOutput.containsKey(n))
+ flatNodesOutput.put(n, kcp)
+ }
if (flatNodesOutput.isEmpty) {
return None
}
- val bestPath = flatNodesOutput.values.reduce {
+ val bestPath = flatNodesOutput.values.asScala.reduce {
(left, right) =>
Ordering
.by((cp: KnownCostPath[T]) => cp.cost)(costComparator)
.min(left, right)
}
- Some(KnownCostGroup(flatNodesOutput,
bestPath.rasPath.node().self().asCanonical()))
+ Some(
+ KnownCostGroup(
+ nodes,
+ n => Option(flatNodesOutput.get(n)),
+ bestPath.rasPath.node().self().asCanonical()))
}
override def solveNodeOnCycle(node: InGroupNode[T]):
Option[KnownCostPath[T]] =
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index 95f453f47..e90ba448b 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -73,7 +73,7 @@ object DpClusterAlgo {
clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput,
ClusterOutput] {
override def idOfX(x: InClusterNode[T]): Any = {
- x
+ x.toHashKey
}
override def idOfY(y: RasClusterKey): Any = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index 6c1e998b6..c824fda8e 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -66,7 +66,7 @@ object DpGroupAlgo {
groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput,
GroupOutput] {
override def idOfX(x: InGroupNode[T]): Any = {
- x
+ x.toHashKey
}
override def idOfY(y: RasGroup[T]): Any = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 391e7f196..1be728ae6 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath
import org.apache.gluten.ras.best.BestFinder
import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel
import org.apache.gluten.ras.memo.{Memo, MemoTable}
-import org.apache.gluten.ras.path.{PathFinder, RasPath}
+import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath}
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}
@@ -172,7 +172,7 @@ object DpPlanner {
rule: RuleApplier[T],
path: RasPath[T]): Unit = {
val probe = memoTable.probe()
- rule.apply(path)
+ rule.apply(InClusterPath(thisClusterKey, path))
val diff = probe.toDiff()
val changedClusters = diff.changedClusters()
if (changedClusters.isEmpty) {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
index 821009982..f28edd0dc 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala
@@ -608,7 +608,6 @@ object DpZipperAlgo {
}
private object XKey {
- // Keep argument "ele" although it is unused. To give compiler type hint.
def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
x: X): XKey[X, Y, XOutput, YOutput] = {
@@ -631,7 +630,6 @@ object DpZipperAlgo {
}
private object YKey {
- // Keep argument "ele" although it is unused. To give compiler type hint.
def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
y: Y): YKey[X, Y, XOutput, YOutput] = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index a9737eb02..47945fc14 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -107,8 +107,8 @@ object ExhaustivePlanner {
finder.find(canonical).foreach(path => onFound(path))
}
- private def applyRule(rule: RuleApplier[T], path: RasPath[T]): Unit = {
- rule.apply(path)
+ private def applyRule(rule: RuleApplier[T], icp: InClusterPath[T]): Unit =
{
+ rule.apply(icp)
}
private def applyRules(): Unit = {
@@ -116,10 +116,17 @@ object ExhaustivePlanner {
return
}
val shapes = rules.map(_.shape())
- allClusters
- .flatMap(c => c.nodes())
- .foreach(
- node => findPaths(node, shapes)(path => rules.foreach(rule =>
applyRule(rule, path))))
+ memoState
+ .clusterLookup()
+ .foreach {
+ case (cKey, cluster) =>
+ cluster
+ .nodes()
+ .foreach(
+ node =>
+ findPaths(node, shapes)(
+ path => rules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))))
+ }
}
private def applyEnforcerRules(): Unit = {
@@ -129,10 +136,11 @@ object ExhaustivePlanner {
val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
if (enforcerRules.nonEmpty) {
val shapes = enforcerRules.map(_.shape())
- memoState.clusterLookup()(group.clusterKey()).nodes().foreach {
+ val cKey = group.clusterKey()
+ memoState.clusterLookup()(cKey).nodes().foreach {
node =>
findPaths(node, shapes)(
- path => enforcerRules.foreach(rule => applyRule(rule, path)))
+ path => enforcerRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path))))
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index 945e653eb..e3ae03ebf 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -33,6 +33,8 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] =
mutable.ArrayBuffer()
private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] =
mutable.ArrayBuffer()
private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet()
+ private val clusterDummyGroupBuffer = mutable.ArrayBuffer[RasGroup[T]]()
+
private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T],
RasGroup[T]]] =
mutable.ArrayBuffer()
@@ -46,14 +48,22 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
override def newCluster(metadata: Metadata): RasClusterKey = {
checkBufferSizes()
- val key = IntClusterKey(clusterBuffer.size, metadata)
+ val clusterId = clusterBuffer.size
+ val key = IntClusterKey(clusterId, metadata)
clusterKeyBuffer += key
clusterBuffer += MutableRasCluster(ras, metadata)
clusterDisjointSet.grow()
groupLookup += mutable.Map()
+ // Normal groups start with ID 0, so it's safe to use negative IDs for
dummy groups.
+ clusterDummyGroupBuffer += RasGroup(ras, key, -clusterId,
ras.propertySetFactory().any())
key
}
+ override def dummyGroupOf(key: RasClusterKey): RasGroup[T] = {
+ val ancestor = ancestorClusterIdOf(key)
+ clusterDummyGroupBuffer(ancestor)
+ }
+
override def groupOf(key: RasClusterKey, propSet: PropertySet[T]):
RasGroup[T] = {
val ancestor = ancestorClusterIdOf(key)
val lookup = groupLookup(ancestor)
@@ -75,7 +85,11 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
}
override def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
= {
- getCluster(key).add(node)
+ val cluster = getCluster(key)
+ if (cluster.contains(node)) {
+ return
+ }
+ cluster.add(node)
memoWriteCount += 1
}
@@ -142,6 +156,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
assert(clusterKeyBuffer.size == clusterBuffer.size)
assert(clusterKeyBuffer.size == clusterDisjointSet.size)
assert(clusterKeyBuffer.size == groupLookup.size)
+ assert(clusterKeyBuffer.size == clusterDummyGroupBuffer.size)
}
override def probe(): MemoTable.Probe[T] = new
ForwardMemoTable.Probe[T](this)
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index a77293586..66626b756 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -17,17 +17,19 @@
package org.apache.gluten.ras.memo
import org.apache.gluten.ras._
+import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.RasCluster.ImmutableRasCluster
import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.util.CanonicalNodeMap
import org.apache.gluten.ras.vis.GraphvizVisualizer
+import scala.collection.mutable
+
trait MemoLike[T <: AnyRef] {
def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T]
}
trait Closure[T <: AnyRef] {
- def openFor(node: CanonicalNode[T]): MemoLike[T]
+ def openFor(cKey: RasClusterKey): MemoLike[T]
}
trait Memo[T <: AnyRef] extends Closure[T] with MemoLike[T] {
@@ -51,82 +53,61 @@ object Memo {
private class RasMemo[T <: AnyRef](val ras: Ras[T]) extends UnsafeMemo[T] {
import RasMemo._
private val memoTable: MemoTable.Writable[T] = MemoTable.create(ras)
- private val cache: NodeToClusterMap[T] = new NodeToClusterMap(ras)
+ private val cache = mutable.Map[MemoCacheKey[T], RasClusterKey]()
private def newCluster(metadata: Metadata): RasClusterKey = {
memoTable.newCluster(metadata)
}
private def addToCluster(clusterKey: RasClusterKey, can:
CanonicalNode[T]): Unit = {
- assert(!cache.contains(can))
- cache.put(can, clusterKey)
memoTable.addToCluster(clusterKey, can)
}
- // Replace node's children with node groups. When a group doesn't exist,
create it.
- private def canonizeUnsafe(node: T, constraintSet: PropertySet[T], depth:
Int): T = {
- assert(depth >= 1)
- if (depth > 1) {
- return ras.withNewChildren(
- node,
- ras.planModel
- .childrenOf(node)
-
.zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
- .map {
- case (child, constraintSet) =>
- canonizeUnsafe(child, constraintSet, depth - 1)
- }
- )
+ private def clusterOfUnsafe(metadata: Metadata, cacheKey:
MemoCacheKey[T]): RasClusterKey = {
+ if (cache.contains(cacheKey)) {
+ cache(cacheKey)
+ } else {
+ // Node not yet added to cluster.
+ val cluster = newCluster(metadata)
+ cache += (cacheKey -> cluster)
+ cluster
}
- assert(depth == 1)
- val childrenGroups: Seq[RasGroup[T]] = ras.planModel
- .childrenOf(node)
- .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet,
node))
- .map {
- case (child, childConstraintSet) =>
- memorize(child, childConstraintSet)
- }
- val newNode =
- ras.withNewChildren(node, childrenGroups.map(group => group.self()))
- newNode
}
- private def canonize(node: T, constraintSet: PropertySet[T]):
CanonicalNode[T] = {
- CanonicalNode(ras, canonizeUnsafe(node, constraintSet, 1))
+ private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
+ memoTable.dummyGroupOf(clusterKey)
+ }
+
+ private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
+ MemoCacheKey(ras, n)
}
- private def insert(n: T, constraintSet: PropertySet[T]): RasClusterKey = {
- if (ras.planModel.isGroupLeaf(n)) {
- val plainGroup = memoTable.allGroups()(ras.planModel.getGroupId(n))
- return plainGroup.clusterKey()
+ private def prepareInsert(n: T): Prepare[T] = {
+ if (ras.isGroupLeaf(n)) {
+ val group = memoTable.allGroups()(ras.planModel.getGroupId(n))
+ return Prepare.cluster(this, group.clusterKey())
}
- val node = canonize(n, constraintSet)
+ val childrenPrepares = ras.planModel.childrenOf(n).map(child =>
prepareInsert(child))
- if (cache.contains(node)) {
- cache.get(node)
- } else {
- // Node not yet added to cluster.
- val meta = ras.metadataModel.metadataOf(node.self())
- val clusterKey = newCluster(meta)
- addToCluster(clusterKey, node)
- clusterKey
- }
+ val canUnsafe = ras.withNewChildren(
+ n,
+ childrenPrepares.map(childPrepare =>
dummyGroupOf(childPrepare.clusterKey()).self()))
+
+ val cacheKey = toCacheKeyUnsafe(canUnsafe)
+
+ val clusterKey = clusterOfUnsafe(ras.metadataModel.metadataOf(n),
cacheKey)
+
+ Prepare.tree(this, clusterKey, childrenPrepares)
}
override def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T]
= {
- val clusterKey = insert(node, constraintSet)
- val prevGroupCount = memoTable.allGroups().size
- val out = memoTable.groupOf(clusterKey, constraintSet)
- val newGroupCount = memoTable.allGroups().size
- assert(newGroupCount >= prevGroupCount)
- out
+ val prepare = prepareInsert(node)
+ prepare.doInsert(node, constraintSet)
}
- override def openFor(node: CanonicalNode[T]): MemoLike[T] = {
- assert(cache.contains(node))
- val targetCluster = cache.get(node)
- new InCusterMemo[T](this, targetCluster)
+ override def openFor(cKey: RasClusterKey): MemoLike[T] = {
+ new InCusterMemo[T](this, cKey)
}
override def newState(): MemoState[T] = {
@@ -141,37 +122,116 @@ object Memo {
}
private object RasMemo {
- private class InCusterMemo[T <: AnyRef](parent: RasMemo[T],
preparedCluster: RasClusterKey)
+ private class InCusterMemo[T <: AnyRef](parent: RasMemo[T], targetCluster:
RasClusterKey)
extends MemoLike[T] {
+ private val ras = parent.ras
+
+ private def prepareInsert(node: T): Prepare[T] = {
+ assert(!ras.isGroupLeaf(node))
+
+ val childrenPrepares =
+ ras.planModel.childrenOf(node).map(child =>
parent.prepareInsert(child))
+
+ val canUnsafe = ras.withNewChildren(
+ node,
+ childrenPrepares.map {
+ childPrepare =>
parent.dummyGroupOf(childPrepare.clusterKey()).self()
+ })
+
+ val cacheKey = parent.toCacheKeyUnsafe(canUnsafe)
+
+ if (!parent.cache.contains(cacheKey)) {
+ // The new node was not added to memo yet. Add it to the target
cluster.
+ parent.cache += (cacheKey -> targetCluster)
+ return Prepare.tree(parent, targetCluster, childrenPrepares)
+ }
+
+ // The new node already memorized to memo.
- private def insert(node: T, constraintSet: PropertySet[T]): Unit = {
- val can = parent.canonize(node, constraintSet)
- if (parent.cache.contains(can)) {
- val cachedCluster = parent.cache.get(can)
- if (cachedCluster == preparedCluster) {
- return
- }
- // The new node already memorized to memo, but in the different
cluster
- // with the input node. Merge the two clusters.
- //
- // TODO: Traversal up the tree to do more merges.
- parent.memoTable.mergeClusters(cachedCluster, preparedCluster)
- // Since new node already memorized, we don't have to add it to
either of the clusters
- // anymore.
- return
+ val cachedCluster = parent.cache(cacheKey)
+ if (cachedCluster == targetCluster) {
+ // The new node already memorized to memo and in the target cluster.
+ return Prepare.tree(parent, targetCluster, childrenPrepares)
}
- parent.addToCluster(preparedCluster, can)
+ // The new node already memorized to memo, but in the different
cluster.
+ // Merge the two clusters.
+ //
+ // TODO: Traverse up the tree to do more merges.
+ parent.memoTable.mergeClusters(cachedCluster, targetCluster)
+ Prepare.tree(parent, targetCluster, childrenPrepares)
}
override def memorize(node: T, constraintSet: PropertySet[T]):
RasGroup[T] = {
- insert(node, constraintSet)
- parent.memoTable.groupOf(preparedCluster, constraintSet)
+ val prepare = prepareInsert(node)
+ prepare.doInsert(node, constraintSet)
}
}
+
+ private trait Prepare[T <: AnyRef] {
+ def clusterKey(): RasClusterKey
+ def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T]
+ }
+
+ private object Prepare {
+ def tree[T <: AnyRef](
+ memo: RasMemo[T],
+ cKey: RasClusterKey,
+ children: Seq[Prepare[T]]): Prepare[T] = {
+ new TreePrepare[T](memo, cKey, children)
+ }
+
+ def cluster[T <: AnyRef](memo: RasMemo[T], cKey: RasClusterKey):
Prepare[T] = {
+ new ClusterPrepare[T](memo, cKey)
+ }
+
+ private class TreePrepare[T <: AnyRef](
+ memo: RasMemo[T],
+ override val clusterKey: RasClusterKey,
+ children: Seq[Prepare[T]])
+ extends Prepare[T] {
+ private val ras = memo.ras
+
+ override def doInsert(node: T, constraintSet: PropertySet[T]):
RasGroup[T] = {
+ assert(!ras.isGroupLeaf(node))
+ val childrenGroups = children
+ .zip(ras.planModel.childrenOf(node))
+
.zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node))
+ .map {
+ case ((childPrepare, child), childConstraintSet) =>
+ childPrepare.doInsert(child, childConstraintSet)
+ }
+
+ val canUnsafe = ras.withNewChildren(node, childrenGroups.map(group
=> group.self()))
+ val can = CanonicalNode(ras, canUnsafe)
+
+ memo.addToCluster(clusterKey, can)
+
+ val group = memo.memoTable.groupOf(clusterKey, constraintSet)
+ group
+ }
+ }
+
+ private class ClusterPrepare[T <: AnyRef](memo: RasMemo[T], cKey:
RasClusterKey)
+ extends Prepare[T] {
+ private val ras = memo.ras
+ override def doInsert(node: T, constraintSet: PropertySet[T]):
RasGroup[T] = {
+ assert(ras.isGroupLeaf(node))
+ memo.memoTable.groupOf(cKey, constraintSet)
+ }
+
+ override def clusterKey(): RasClusterKey = cKey
+ }
+ }
+ }
+
+ private object MemoCacheKey {
+ def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
+ assert(ras.isCanonical(self))
+ MemoCacheKey[T](ras.toUnsafeKey(self))
+ }
}
- private class NodeToClusterMap[T <: AnyRef](ras: Ras[T])
- extends CanonicalNodeMap[T, RasClusterKey](ras)
+ private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T])
}
trait MemoStore[T <: AnyRef] {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index b54bd8811..3baba8eae 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -44,6 +44,7 @@ object MemoTable {
trait Writable[T <: AnyRef] extends MemoTable[T] {
def newCluster(metadata: Metadata): RasClusterKey
def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T]
+ def dummyGroupOf(key: RasClusterKey): RasGroup[T]
def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
index ca712cec4..61fa22e5e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala
@@ -37,7 +37,7 @@ object RasPath {
object PathNode {
def apply[T <: AnyRef](node: RasNode[T], children: Seq[PathNode[T]]):
PathNode[T] = {
- PathNodeImpl(node, children)
+ new PathNodeImpl(node, children)
}
}
@@ -61,7 +61,7 @@ object RasPath {
keys: PathKeySet,
height: Int,
node: RasPath.PathNode[T]): RasPath[T] = {
- RasPathImpl(ras, keys, height, node)
+ new RasPathImpl(ras, keys, height, node)
}
// Returns none if children doesn't share at least one path key.
@@ -103,25 +103,6 @@ object RasPath {
PathNode(canonical, canonical.getChildrenGroups(allGroups).map(g =>
PathNode(g, List.empty))))
}
- // Aggregates paths that have same shape but different keys together.
- // Currently not in use because of bad performance.
- def aggregate[T <: AnyRef](ras: Ras[T], paths: Iterable[RasPath[T]]):
Iterable[RasPath[T]] = {
- // Scala has specialized optimization against small set of input of
group-by.
- // So it's better only to pass small inputs to this method if possible.
- val grouped = paths.groupBy(_.node())
- grouped.map {
- case (node, paths) =>
- val heights = paths.map(_.height()).toSeq.distinct
- assert(heights.size == 1)
- val height = heights.head
- val keys = paths.map(_.keys().keys()).reduce[Set[PathKey]] {
- case (one, other) =>
- one.union(other)
- }
- RasPath(ras, PathKeySet(keys), height, node)
- }
- }
-
def cartesianProduct[T <: AnyRef](
ras: Ras[T],
canonical: CanonicalNode[T],
@@ -171,12 +152,12 @@ object RasPath {
}
}
- private case class PathNodeImpl[T <: AnyRef](
+ private class PathNodeImpl[T <: AnyRef](
override val self: RasNode[T],
override val children: Seq[PathNode[T]])
extends PathNode[T]
- private case class RasPathImpl[T <: AnyRef](
+ private class RasPathImpl[T <: AnyRef](
override val ras: Ras[T],
override val keys: PathKeySet,
override val height: Int,
@@ -193,3 +174,19 @@ object RasPath {
override def plan(): T = built
}
}
+
+trait InClusterPath[T <: AnyRef] {
+ def cluster(): RasClusterKey
+ def path(): RasPath[T]
+}
+
+object InClusterPath {
+ def apply[T <: AnyRef](cluster: RasClusterKey, path: RasPath[T]):
InClusterPath[T] = {
+ new InClusterPathImpl(cluster, path)
+ }
+
+ private class InClusterPathImpl[T <: AnyRef](
+ override val cluster: RasClusterKey,
+ override val path: RasPath[T])
+ extends InClusterPath[T]
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index b99001e93..0a7bf0c76 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -17,14 +17,14 @@
package org.apache.gluten.ras.rule
import org.apache.gluten.ras._
+import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.memo.Closure
-import org.apache.gluten.ras.path.RasPath
-import org.apache.gluten.ras.util.CanonicalNodeMap
+import org.apache.gluten.ras.path.InClusterPath
import scala.collection.mutable
trait RuleApplier[T <: AnyRef] {
- def apply(path: RasPath[T]): Unit
+ def apply(icp: InClusterPath[T]): Unit
def shape(): Shape[T]
}
@@ -42,25 +42,27 @@ object RuleApplier {
private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure:
Closure[T], rule: RasRule[T])
extends RuleApplier[T] {
- private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras)
+ private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
- override def apply(path: RasPath[T]): Unit = {
- val can = path.node().self().asCanonical()
+ override def apply(icp: InClusterPath[T]): Unit = {
+ val cKey = icp.cluster()
+ val path = icp.path()
val plan = path.plan()
- val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
- if (appliedPlans.contains(plan)) {
+ val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
+ val pKey = ras.toUnsafeKey(plan)
+ if (appliedPlans.contains(pKey)) {
return
}
- apply0(can, plan)
- appliedPlans += plan
+ apply0(cKey, plan)
+ appliedPlans += pKey
}
- private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+ private def apply0(cKey: RasClusterKey, plan: T): Unit = {
val equivalents = rule.shift(plan)
equivalents.foreach {
equiv =>
closure
- .openFor(can)
+ .openFor(cKey)
.memorize(equiv, ras.propertySetFactory().get(equiv))
}
}
@@ -73,32 +75,35 @@ object RuleApplier {
closure: Closure[T],
rule: EnforcerRule[T])
extends RuleApplier[T] {
- private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras)
+ private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
private val constraint = rule.constraint()
private val constraintDef = constraint.definition()
- override def apply(path: RasPath[T]): Unit = {
+ override def apply(icp: InClusterPath[T]): Unit = {
+ val cKey = icp.cluster()
+ val path = icp.path()
val can = path.node().self().asCanonical()
if (can.propSet().get(constraintDef).satisfies(constraint)) {
return
}
val plan = path.plan()
- val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
- if (appliedPlans.contains(plan)) {
+ val pKey = ras.toUnsafeKey(plan)
+ val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
+ if (appliedPlans.contains(pKey)) {
return
}
- apply0(can, plan)
- appliedPlans += plan
+ apply0(cKey, plan)
+ appliedPlans += pKey
}
- private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+ private def apply0(cKey: RasClusterKey, plan: T): Unit = {
val propSet = ras.propertySetFactory().get(plan)
val constraintSet = propSet.withProp(constraint)
val equivalents = rule.shift(plan)
equivalents.foreach {
equiv =>
closure
- .openFor(can)
+ .openFor(cKey)
.memorize(equiv, constraintSet)
}
}
@@ -110,11 +115,11 @@ object RuleApplier {
extends RuleApplier[T] {
private val ruleShape = rule.shape()
- override def apply(path: RasPath[T]): Unit = {
- if (!ruleShape.identify(path)) {
+ override def apply(icp: InClusterPath[T]): Unit = {
+ if (!ruleShape.identify(icp.path())) {
return
}
- rule.apply(path)
+ rule.apply(icp)
}
override def shape(): Shape[T] = ruleShape
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala
deleted file mode 100644
index 887e00bdc..000000000
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.ras.util
-
-import org.apache.gluten.ras.{CanonicalNode, Ras}
-
-import scala.collection.mutable
-
-// Arbitrary node key.
-class NodeKey[T <: AnyRef](ras: Ras[T], val node: T) {
- override def hashCode(): Int = ras.planModel.hashCode(node)
-
- override def equals(obj: Any): Boolean = {
- obj match {
- case other: NodeKey[T] => ras.planModel.equals(node, other.node)
- case _ => false
- }
- }
-
- override def toString(): String = s"NodeKey($node)"
-}
-
-// Canonical node map.
-class CanonicalNodeMap[T <: AnyRef, V](ras: Ras[T]) {
- private val map: mutable.Map[NodeKey[T], V] = mutable.Map()
-
- def contains(node: CanonicalNode[T]): Boolean = {
- map.contains(keyOf(node))
- }
-
- def put(node: CanonicalNode[T], value: V): Unit = {
- map.put(keyOf(node), value)
- }
-
- def get(node: CanonicalNode[T]): V = {
- map(keyOf(node))
- }
-
- def getOrElseUpdate(node: CanonicalNode[T], op: => V): V = {
- map.getOrElseUpdate(keyOf(node), op)
- }
-
- private def keyOf(node: CanonicalNode[T]): NodeKey[T] = {
- new NodeKey(ras, node.self())
- }
-}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 600a61edc..11f6051b0 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -43,13 +43,13 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T],
memoState: MemoState[T], best
object IsBestNode {
def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])):
Boolean = {
- bestNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(),
nodeAndGroupToTest._1))
+ bestNodes(InGroupNode(nodeAndGroupToTest._2.id(),
nodeAndGroupToTest._1))
}
}
object IsWinnerNode {
def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])):
Boolean = {
- winnerNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(),
nodeAndGroupToTest._1))
+ winnerNodes(InGroupNode(nodeAndGroupToTest._2.id(),
nodeAndGroupToTest._1))
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index acd96442c..f1c319873 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -97,7 +97,7 @@ class OperationSuite extends AnyFunSuite {
val ras =
Ras[TestNode](
- PlanModelImpl,
+ planModel,
CostModelImpl,
MetadataModelImpl,
PropertyModelImpl,
@@ -108,7 +108,7 @@ class OperationSuite extends AnyFunSuite {
val optimized = planner.plan()
assert(optimized == Unary2(49, Leaf2(29)))
- planModel.assertPlanOpsLte((200, 50, 50, 50))
+ planModel.assertPlanOpsLte((200, 50, 100, 50))
val state = planner.newState()
val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq
@@ -127,7 +127,7 @@ class OperationSuite extends AnyFunSuite {
val ras =
Ras[TestNode](
- PlanModelImpl,
+ planModel,
CostModelImpl,
MetadataModelImpl,
PropertyModelImpl,
@@ -163,7 +163,7 @@ class OperationSuite extends AnyFunSuite {
val ras =
Ras[TestNode](
- PlanModelImpl,
+ planModel,
CostModelImpl,
MetadataModelImpl,
PropertyModelImpl,
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index 8a68bbba8..e48604116 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.ras
import org.apache.gluten.ras.Best.BestNotFoundException
import org.apache.gluten.ras.RasConfig.PlannerType
import org.apache.gluten.ras.RasSuiteBase._
+import org.apache.gluten.ras.memo.Memo
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
@@ -37,6 +38,30 @@ abstract class PropertySuite extends AnyFunSuite {
protected def conf: RasConfig
+ test("Group memo - cache") {
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ NodeTypePropertyModelWithOutEnforcerRules,
+ ExplainImpl,
+ RasRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val memo = Memo(ras)
+
+ memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1,
TypedLeaf(TypeA, 1)))))
+ val leafGroup = memo.memorize(ras, TypedLeaf(TypeA, 1))
+ memo
+ .openFor(leafGroup.clusterKey())
+ .memorize(ras, TypedLeaf(TypeB, 1))
+ memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1,
TypedLeaf(TypeB, 1)))))
+ val state = memo.newState()
+ assert(state.allClusters().size == 4)
+ assert(state.getGroupCount() == 8)
+ }
+
test(s"Get property") {
val leaf = PLeaf(10, DummyProperty(0))
val unary = PUnary(5, DummyProperty(0), leaf)
@@ -112,7 +137,7 @@ abstract class PropertySuite extends AnyFunSuite {
TypedLeaf(TypeB, 10)))
}
- ignore(s"Memo cache hit - (A, B)") {
+ test(s"Memo cache hit - (A, B)") {
object ReplaceLeafAByLeafBRule extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = {
node match {
@@ -163,8 +188,8 @@ abstract class PropertySuite extends AnyFunSuite {
val out = planner.plan()
assert(out == TypedLeaf(TypeA, 1))
- // FIXME: Cluster 2 and 1 are currently able to merge but it's better to
- // have them identified as the same right after HitCacheOp is applied
+ // Cluster 2 and 1 are able to merge but we'd make sure
+ // they are identified as the same right after HitCacheOp is applied
val clusterCount = planner.newState().memoState().allClusters().size
assert(clusterCount == 2)
}
@@ -531,6 +556,7 @@ object PropertySuite {
}
object DummyPropertyDef extends PropertyDef[TestNode, DummyProperty] {
+ override def any(): DummyProperty = DummyProperty(Int.MinValue)
override def getProperty(plan: TestNode): DummyProperty = {
plan match {
case Group(_, _, _) => throw new IllegalStateException()
@@ -669,6 +695,8 @@ object PropertySuite {
}
override def toString: String = "NodeTypeDef"
+
+ override def any(): NodeType = TypeAny
}
trait NodeType extends Property[TestNode] {
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index abb8bdecd..0ad825181 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -66,8 +66,7 @@ abstract class RasSuite extends AnyFunSuite {
val group = memo.memorize(ras, Unary(50, Unary(50, Leaf(30))))
val state = memo.newState()
assert(group.nodes(state).size == 1)
- val can = group.nodes(state).head.asCanonical()
- memo.openFor(can).memorize(ras, Unary(30, Leaf(90)))
+ memo.openFor(group.clusterKey()).memorize(ras, Unary(30, Leaf(90)))
assert(memo.newState().allGroups().size == 4)
}
@@ -87,8 +86,7 @@ abstract class RasSuite extends AnyFunSuite {
assert(group.nodes(state).size == 1)
val leaf40Group = memo.memorize(ras, Leaf(40))
assert(leaf40Group.nodes(state).size == 1)
- val can = leaf40Group.nodes(state).head.asCanonical()
- memo.openFor(can).memorize(ras, Leaf(30))
+ memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
assert(memo.newState().allGroups().size == 3)
}
@@ -108,8 +106,7 @@ abstract class RasSuite extends AnyFunSuite {
assert(group.nodes(state).size == 1)
val leaf40Group = memo.memorize(ras, Leaf(40))
assert(leaf40Group.nodes(state).size == 1)
- val can = leaf40Group.nodes(state).head.asCanonical()
- memo.openFor(can).memorize(ras, Leaf(30))
+ memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30))
assert(memo.newState().allGroups().size == 5)
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
index e092ea4f2..8158aec16 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras.path
-import org.apache.gluten.ras.Ras
+import org.apache.gluten.ras.{CanonicalNode, Ras}
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.mock.MockRasPath
import org.apache.gluten.ras.rule.RasRule
@@ -26,7 +26,7 @@ import org.scalatest.funsuite.AnyFunSuite
class RasPathSuite extends AnyFunSuite {
import RasPathSuite._
- test("Path aggregate - empty") {
+ test("Cartesian product - empty") {
val ras =
Ras[TestNode](
PlanModelImpl,
@@ -35,10 +35,21 @@ class RasPathSuite extends AnyFunSuite {
PropertyModelImpl,
ExplainImpl,
RasRule.Factory.reuse(List.empty))
- assert(RasPath.aggregate(ras, List.empty) == List.empty)
+ assert(
+ RasPath.cartesianProduct(
+ ras,
+ CanonicalNode(ras, Binary("b", ras.dummyGroupLeaf(),
ras.dummyGroupLeaf())),
+ List(
+ List.empty,
+ List(
+ MockRasPath.mock(
+ ras,
+ Leaf("l", 1),
+ PathKeySet(Set(DummyPathKey(3)))
+ )))
+ ) == List.empty)
}
-
- test("Path aggregate") {
+ test("Cartesian product") {
val ras =
Ras[TestNode](
PlanModelImpl,
@@ -54,6 +65,7 @@ class RasPathSuite extends AnyFunSuite {
val n4 = "n4"
val n5 = "n5"
val n6 = "n6"
+
val path1 = MockRasPath.mock(
ras,
Unary(n5, Leaf(n6, 1)),
@@ -66,31 +78,37 @@ class RasPathSuite extends AnyFunSuite {
)
val path3 = MockRasPath.mock(
ras,
- Unary(n1, Unary(n2, Leaf(n3, 1))),
- PathKeySet(Set(DummyPathKey(1), DummyPathKey(2)))
+ Leaf(n6, 1),
+ PathKeySet(Set(DummyPathKey(1)))
)
val path4 = MockRasPath.mock(
ras,
- Unary(n1, Unary(n2, Leaf(n3, 1))),
- PathKeySet(Set(DummyPathKey(4)))
+ Leaf(n3, 1),
+ PathKeySet(Set(DummyPathKey(3)))
)
+
val path5 = MockRasPath.mock(
ras,
- Unary(n5, Leaf(n6, 1)),
+ Unary(n2, Leaf(n3, 1)),
PathKeySet(Set(DummyPathKey(4)))
)
- val out = RasPath
- .aggregate(ras, List(path1, path2, path3, path4, path5))
- .toSeq
- .sortBy(_.height())
- assert(out.size == 2)
- assert(out.head.height() == 2)
- assert(out.head.plan() == Unary(n5, Leaf(n6, 1)))
- assert(out.head.keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(3),
DummyPathKey(4))))
- assert(out(1).height() == 3)
- assert(out(1).plan() == Unary(n1, Unary(n2, Leaf(n3, 1))))
- assert(out(1).keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(2),
DummyPathKey(4))))
+ val product = RasPath.cartesianProduct(
+ ras,
+ CanonicalNode(ras, Binary(n4, ras.dummyGroupLeaf(),
ras.dummyGroupLeaf())),
+ List(
+ List(path1, path2),
+ List(path3, path4, path5)
+ ))
+
+ val out = product.toList
+ assert(out.size == 3)
+
+ assert(
+ out.map(_.plan()) == List(
+ Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n6, 1)),
+ Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n3, 1)),
+ Binary(n4, Unary(n1, Unary(n2, Leaf(n3, 1))), Leaf(n6, 1))))
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index cab3d1818..de71cba5b 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -262,6 +262,8 @@ object DistributedSuite {
case (d: Distribution, p: DNode) => p.getDistributionConstraints(d)
case _ => throw new UnsupportedOperationException()
}
+
+ override def any(): Distribution = AnyDistribution
}
trait Ordering extends Property[TestNode]
@@ -315,6 +317,8 @@ object DistributedSuite {
case (o: Ordering, p: DNode) => p.getOrderingConstraints(o)
case _ => throw new UnsupportedOperationException()
}
+
+ override def any(): Ordering = AnyOrdering
}
private class EnforceDistribution(distribution: Distribution) extends
RasRule[TestNode] {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]