This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 681871e1d2 [CORE] Exclude dependent components of incompatible
components (#11009)
681871e1d2 is described below
commit 681871e1d28eaeef4830a54aa144b498b5df69a0
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Nov 5 14:33:26 2025 +0000
[CORE] Exclude dependent components of incompatible components (#11009)
---
.../org/apache/gluten/component/Component.scala | 60 ++++++++++-------
.../org/apache/gluten/component/package.scala | 10 ++-
.../apache/gluten/component/ComponentSuite.scala | 76 +++++++++++++++++++++-
3 files changed, 113 insertions(+), 33 deletions(-)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
index d90a4af58f..29139436a1 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
@@ -23,6 +23,7 @@ import org.apache.gluten.extension.injector.Injector
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.plugin.PluginContext
+import org.apache.spark.internal.Logging
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
@@ -91,7 +92,7 @@ trait Component {
def injectRules(injector: Injector): Unit
}
-object Component {
+object Component extends Logging {
private val nextUid = new AtomicInteger()
private val graph: Graph = new Graph()
@@ -134,8 +135,8 @@ object Component {
require(
!lookupByClass.contains(clazz),
s"Component class $clazz already registered: ${comp.name()}")
- lookupByUid += uid -> comp
- lookupByClass += clazz -> comp
+ lookupByUid(uid) = comp
+ lookupByClass(clazz) = comp
}
def isUidRegistered(uid: Int): Boolean = synchronized {
@@ -164,7 +165,8 @@ object Component {
private class Graph {
import Graph._
private val registry: Registry = new Registry()
- private val dependencies: mutable.Buffer[(Int, Class[_ <: Component])] =
mutable.Buffer()
+ private val uidAndDependencyPairs: mutable.Buffer[(Int, Class[_ <:
Component])] =
+ mutable.Buffer()
private var sortedComponents: Option[Seq[Component]] = None
@@ -183,38 +185,38 @@ object Component {
synchronized {
require(registry.isUidRegistered(comp.uid))
require(registry.isClassRegistered(comp.getClass))
- dependencies += comp.uid -> dependencyCompClass
+ uidAndDependencyPairs += comp.uid -> dependencyCompClass
sortedComponents = None
}
- private def newLookup(): mutable.Map[Int, Node] = {
- val lookup: mutable.Map[Int, Node] = mutable.Map()
+ private def newLookup(): Map[Int, Node] = {
+ val uidToNodeLookup: mutable.Map[Int, Node] = mutable.Map()
registry.allUids().foreach {
uid =>
- require(!lookup.contains(uid))
+ require(!uidToNodeLookup.contains(uid))
val n = new Node(uid)
- lookup += uid -> n
+ uidToNodeLookup(uid) = n
}
- dependencies.foreach {
+ uidAndDependencyPairs.foreach {
case (uid, dependencyCompClass) =>
require(
registry.isClassRegistered(dependencyCompClass),
s"Dependency class not registered yet:
${dependencyCompClass.getName}")
val dependencyUid = registry.findByClass(dependencyCompClass).uid
require(uid != dependencyUid)
- require(lookup.contains(uid))
- require(lookup.contains(dependencyUid))
- val n = lookup(uid)
- val r = lookup(dependencyUid)
+ require(uidToNodeLookup.contains(uid))
+ require(uidToNodeLookup.contains(dependencyUid))
+ val n = uidToNodeLookup(uid)
+ val r = uidToNodeLookup(dependencyUid)
require(!n.parents.contains(r.uid))
require(!r.children.contains(n.uid))
- n.parents += r.uid -> r
- r.children += n.uid -> n
+ n.parents(r.uid) = r
+ r.children(n.uid) = n
}
- lookup
+ uidToNodeLookup.toMap
}
def sorted(): Seq[Component] = synchronized {
@@ -222,10 +224,11 @@ object Component {
return sortedComponents.get
}
- val lookup: mutable.Map[Int, Node] = newLookup()
+ val lookup: Map[Int, Node] = newLookup()
- val out = mutable.Buffer[Component]()
- val uidToNumParents = lookup.map { case (uid, node) => uid ->
node.parents.size }
+ val sortedComponentsBuffer = mutable.Buffer[Component]()
+ val uidToNumParents = mutable.Map[Int, Int]()
+ uidToNumParents ++= lookup.map { case (uid, node) => uid ->
node.parents.size }
val removalQueue = mutable.Queue[Int]()
// 1. Find out all nodes with zero parents then enqueue them.
@@ -235,10 +238,10 @@ object Component {
while (removalQueue.nonEmpty) {
val parentUid = removalQueue.dequeue()
val node = lookup(parentUid)
- out += registry.findByUid(parentUid)
+ sortedComponentsBuffer += registry.findByUid(parentUid)
node.children.keys.foreach {
childUid =>
- uidToNumParents += childUid -> (uidToNumParents(childUid) - 1)
+ uidToNumParents(childUid) = uidToNumParents(childUid) - 1
val updatedNumParents = uidToNumParents(childUid)
assert(updatedNumParents >= 0)
if (updatedNumParents == 0) {
@@ -256,8 +259,17 @@ object Component {
s"Cycle detected in the component graph: $cycleNodeNames")
}
- // 4. Return the ordered nodes.
- sortedComponents = Some(out.toSeq)
+ // 4. Return the ordered components, with the incompatible ones excluded.
+ def isRuntimeCompatible(component: Component): Boolean = {
+ val parents =
lookup(component.uid).parents.keys.map(registry.findByUid)
+ component.isRuntimeCompatible && parents.forall(isRuntimeCompatible)
+ }
+ val (compatibleComponents, incompatibleComponents) =
+ sortedComponentsBuffer.partition(isRuntimeCompatible)
+ incompatibleComponents.foreach {
+ component => logWarning(s"Excluding runtime-incompatible component:
${component.name()}.")
+ }
+ sortedComponents = Some(compatibleComponents.toSeq)
sortedComponents.get
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
b/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
index cb74a0b3c8..cf0181c39c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
@@ -28,13 +28,11 @@ package object component extends Logging {
return
}
- // Discover all components available in the classpath.
+ // Load all components in classpath.
val all = Discovery.discoverAll()
- val (compatibleComponents, incompatibleComponents) =
all.partition(_.isRuntimeCompatible)
- incompatibleComponents.foreach(
- c => logWarning(s"Excluding runtime-incompatible component: ${c.name}"))
- // Register all runtime-compatible components.
- compatibleComponents.foreach(_.ensureRegistered())
+
+ // Register all components.
+ all.foreach(_.ensureRegistered())
// Output log so user could view the component loading order.
// Call #sortedUnsafe rather than #sorted to avoid unnecessary recursion.
diff --git
a/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
b/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
index 9abdc6b093..79bbb37bba 100644
---
a/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
+++
b/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable
class ComponentSuite extends AnyFunSuite with BeforeAndAfterAll {
import ComponentSuite._
- test("Load order - sanity") {
+ test("Load order") {
val a = new DummyBackend("A") {}
val b = new DummyBackend("B") {}
val c = new DummyComponent("C") {}
@@ -63,6 +63,61 @@ class ComponentSuite extends AnyFunSuite with
BeforeAndAfterAll {
}
}
+ test("Incompatible component") {
+ val a = new DummyBackend("A") {}
+ val b = new DummyBackend("B") {}
+ val c = new DummyComponent("C") {}
+ val d = new DummyComponent("D") {}
+ val e = new DummyComponent("E") {}
+
+ c.dependsOn(a)
+ d.dependsOn(a, b)
+ e.dependsOn(a, d)
+
+ d.setIncompatible()
+
+ a.ensureRegistered()
+ b.ensureRegistered()
+ c.ensureRegistered()
+ d.ensureRegistered()
+ e.ensureRegistered()
+
+ val possibleOrders: Set[Seq[Component]] =
+ Set(
+ Seq(a, b, c),
+ Seq(b, a, c)
+ )
+
+ assert(possibleOrders.contains(Component.sorted().filter(Seq(a, b, c, d,
e).contains(_))))
+ }
+
+ test("Incompatible backend") {
+ val a = new DummyBackend("A") {}
+ val b = new DummyBackend("B") {}
+ val c = new DummyComponent("C") {}
+ val d = new DummyComponent("D") {}
+ val e = new DummyComponent("E") {}
+
+ c.dependsOn(a)
+ d.dependsOn(a, b)
+ e.dependsOn(a, d)
+
+ b.setIncompatible()
+
+ a.ensureRegistered()
+ b.ensureRegistered()
+ c.ensureRegistered()
+ d.ensureRegistered()
+ e.ensureRegistered()
+
+ val possibleOrders: Set[Seq[Component]] =
+ Set(
+ Seq(a, c)
+ )
+
+ assert(possibleOrders.contains(Component.sorted().filter(Seq(a, b, c, d,
e).contains(_))))
+ }
+
test("Dependencies not registered") {
val a = new DummyBackend("A") {}
val c = new DummyComponent("C") {}
@@ -115,9 +170,21 @@ object ComponentSuite {
}
}
+ private trait CompatibilityHelper extends Component {
+ private var _isRuntimeCompatible: Boolean = true;
+
+ override def isRuntimeCompatible: Boolean = _isRuntimeCompatible
+
+ def setIncompatible(): Unit = {
+ _isRuntimeCompatible = false
+ }
+ }
+
abstract private class DummyComponent(override val name: String)
extends Component
- with DependencyBuilder {
+ with DependencyBuilder
+ with CompatibilityHelper {
+
override def buildInfo(): Component.BuildInfo =
Component.BuildInfo(name, "N/A", "N/A", "N/A")
@@ -125,7 +192,10 @@ object ComponentSuite {
override def injectRules(injector: Injector): Unit = {}
}
- abstract private class DummyBackend(override val name: String) extends
Backend {
+ abstract private class DummyBackend(override val name: String)
+ extends Backend
+ with CompatibilityHelper {
+
override def buildInfo(): Component.BuildInfo =
Component.BuildInfo(name, "N/A", "N/A", "N/A")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]