This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 681871e1d2 [CORE] Exclude dependent components of incompatible components (#11009)
681871e1d2 is described below

commit 681871e1d28eaeef4830a54aa144b498b5df69a0
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Nov 5 14:33:26 2025 +0000

    [CORE] Exclude dependent components of incompatible components (#11009)
---
 .../org/apache/gluten/component/Component.scala    | 60 ++++++++++-------
 .../org/apache/gluten/component/package.scala      | 10 ++-
 .../apache/gluten/component/ComponentSuite.scala   | 76 +++++++++++++++++++++-
 3 files changed, 113 insertions(+), 33 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
index d90a4af58f..29139436a1 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
@@ -23,6 +23,7 @@ import org.apache.gluten.extension.injector.Injector
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.plugin.PluginContext
+import org.apache.spark.internal.Logging
 
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 
@@ -91,7 +92,7 @@ trait Component {
   def injectRules(injector: Injector): Unit
 }
 
-object Component {
+object Component extends Logging {
   private val nextUid = new AtomicInteger()
   private val graph: Graph = new Graph()
 
@@ -134,8 +135,8 @@ object Component {
       require(
         !lookupByClass.contains(clazz),
         s"Component class $clazz already registered: ${comp.name()}")
-      lookupByUid += uid -> comp
-      lookupByClass += clazz -> comp
+      lookupByUid(uid) = comp
+      lookupByClass(clazz) = comp
     }
 
     def isUidRegistered(uid: Int): Boolean = synchronized {
@@ -164,7 +165,8 @@ object Component {
   private class Graph {
     import Graph._
     private val registry: Registry = new Registry()
-    private val dependencies: mutable.Buffer[(Int, Class[_ <: Component])] = 
mutable.Buffer()
+    private val uidAndDependencyPairs: mutable.Buffer[(Int, Class[_ <: 
Component])] =
+      mutable.Buffer()
 
     private var sortedComponents: Option[Seq[Component]] = None
 
@@ -183,38 +185,38 @@ object Component {
       synchronized {
         require(registry.isUidRegistered(comp.uid))
         require(registry.isClassRegistered(comp.getClass))
-        dependencies += comp.uid -> dependencyCompClass
+        uidAndDependencyPairs += comp.uid -> dependencyCompClass
         sortedComponents = None
       }
 
-    private def newLookup(): mutable.Map[Int, Node] = {
-      val lookup: mutable.Map[Int, Node] = mutable.Map()
+    private def newLookup(): Map[Int, Node] = {
+      val uidToNodeLookup: mutable.Map[Int, Node] = mutable.Map()
 
       registry.allUids().foreach {
         uid =>
-          require(!lookup.contains(uid))
+          require(!uidToNodeLookup.contains(uid))
           val n = new Node(uid)
-          lookup += uid -> n
+          uidToNodeLookup(uid) = n
       }
 
-      dependencies.foreach {
+      uidAndDependencyPairs.foreach {
         case (uid, dependencyCompClass) =>
           require(
             registry.isClassRegistered(dependencyCompClass),
             s"Dependency class not registered yet: ${dependencyCompClass.getName}")
           val dependencyUid = registry.findByClass(dependencyCompClass).uid
           require(uid != dependencyUid)
-          require(lookup.contains(uid))
-          require(lookup.contains(dependencyUid))
-          val n = lookup(uid)
-          val r = lookup(dependencyUid)
+          require(uidToNodeLookup.contains(uid))
+          require(uidToNodeLookup.contains(dependencyUid))
+          val n = uidToNodeLookup(uid)
+          val r = uidToNodeLookup(dependencyUid)
           require(!n.parents.contains(r.uid))
           require(!r.children.contains(n.uid))
-          n.parents += r.uid -> r
-          r.children += n.uid -> n
+          n.parents(r.uid) = r
+          r.children(n.uid) = n
       }
 
-      lookup
+      uidToNodeLookup.toMap
     }
 
     def sorted(): Seq[Component] = synchronized {
@@ -222,10 +224,11 @@ object Component {
         return sortedComponents.get
       }
 
-      val lookup: mutable.Map[Int, Node] = newLookup()
+      val lookup: Map[Int, Node] = newLookup()
 
-      val out = mutable.Buffer[Component]()
-      val uidToNumParents = lookup.map { case (uid, node) => uid -> 
node.parents.size }
+      val sortedComponentsBuffer = mutable.Buffer[Component]()
+      val uidToNumParents = mutable.Map[Int, Int]()
+      uidToNumParents ++= lookup.map { case (uid, node) => uid -> 
node.parents.size }
       val removalQueue = mutable.Queue[Int]()
 
       // 1. Find out all nodes with zero parents then enqueue them.
@@ -235,10 +238,10 @@ object Component {
       while (removalQueue.nonEmpty) {
         val parentUid = removalQueue.dequeue()
         val node = lookup(parentUid)
-        out += registry.findByUid(parentUid)
+        sortedComponentsBuffer += registry.findByUid(parentUid)
         node.children.keys.foreach {
           childUid =>
-            uidToNumParents += childUid -> (uidToNumParents(childUid) - 1)
+            uidToNumParents(childUid) = uidToNumParents(childUid) - 1
             val updatedNumParents = uidToNumParents(childUid)
             assert(updatedNumParents >= 0)
             if (updatedNumParents == 0) {
@@ -256,8 +259,17 @@ object Component {
           s"Cycle detected in the component graph: $cycleNodeNames")
       }
 
-      // 4. Return the ordered nodes.
-      sortedComponents = Some(out.toSeq)
+      // 4. Return the ordered components, with the incompatible ones excluded.
+      def isRuntimeCompatible(component: Component): Boolean = {
+        val parents = 
lookup(component.uid).parents.keys.map(registry.findByUid)
+        component.isRuntimeCompatible && parents.forall(isRuntimeCompatible)
+      }
+      val (compatibleComponents, incompatibleComponents) =
+        sortedComponentsBuffer.partition(isRuntimeCompatible)
+      incompatibleComponents.foreach {
+        component => logWarning(s"Excluding runtime-incompatible component: 
${component.name()}.")
+      }
+      sortedComponents = Some(compatibleComponents.toSeq)
       sortedComponents.get
     }
   }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/component/package.scala b/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
index cb74a0b3c8..cf0181c39c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/package.scala
@@ -28,13 +28,11 @@ package object component extends Logging {
       return
     }
 
-    // Discover all components available in the classpath.
+    // Load all components in classpath.
     val all = Discovery.discoverAll()
-    val (compatibleComponents, incompatibleComponents) = 
all.partition(_.isRuntimeCompatible)
-    incompatibleComponents.foreach(
-      c => logWarning(s"Excluding runtime-incompatible component: ${c.name}"))
-    // Register all runtime-compatible components.
-    compatibleComponents.foreach(_.ensureRegistered())
+
+    // Register all components.
+    all.foreach(_.ensureRegistered())
 
     // Output log so user could view the component loading order.
     // Call #sortedUnsafe than on #sorted to avoid unnecessary recursion.
diff --git a/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
index 9abdc6b093..79bbb37bba 100644
--- a/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
+++ b/gluten-core/src/test/scala/org/apache/gluten/component/ComponentSuite.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable
 class ComponentSuite extends AnyFunSuite with BeforeAndAfterAll {
   import ComponentSuite._
 
-  test("Load order - sanity") {
+  test("Load order") {
     val a = new DummyBackend("A") {}
     val b = new DummyBackend("B") {}
     val c = new DummyComponent("C") {}
@@ -63,6 +63,61 @@ class ComponentSuite extends AnyFunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("Incompatible component") {
+    val a = new DummyBackend("A") {}
+    val b = new DummyBackend("B") {}
+    val c = new DummyComponent("C") {}
+    val d = new DummyComponent("D") {}
+    val e = new DummyComponent("E") {}
+
+    c.dependsOn(a)
+    d.dependsOn(a, b)
+    e.dependsOn(a, d)
+
+    d.setIncompatible() // D is excluded, and so is E because it depends on D.
+
+    a.ensureRegistered()
+    b.ensureRegistered()
+    c.ensureRegistered()
+    d.ensureRegistered()
+    e.ensureRegistered()
+
+    val possibleOrders: Set[Seq[Component]] =
+      Set(
+        Seq(a, b, c),
+        Seq(b, a, c)
+      )
+
+    assert(possibleOrders.contains(Component.sorted().filter(Seq(a, b, c, d, e).contains(_))))
+  }
+
+  test("Incompatible backend") {
+    val a = new DummyBackend("A") {}
+    val b = new DummyBackend("B") {}
+    val c = new DummyComponent("C") {}
+    val d = new DummyComponent("D") {}
+    val e = new DummyComponent("E") {}
+
+    c.dependsOn(a)
+    d.dependsOn(a, b)
+    e.dependsOn(a, d)
+
+    b.setIncompatible() // B is excluded, as are D (depends on B) and E (depends on D).
+
+    a.ensureRegistered()
+    b.ensureRegistered()
+    c.ensureRegistered()
+    d.ensureRegistered()
+    e.ensureRegistered()
+
+    val possibleOrders: Set[Seq[Component]] =
+      Set(
+        Seq(a, c)
+      )
+
+    assert(possibleOrders.contains(Component.sorted().filter(Seq(a, b, c, d, e).contains(_))))
+  }
+
   test("Dependencies not registered") {
     val a = new DummyBackend("A") {}
     val c = new DummyComponent("C") {}
@@ -115,9 +170,21 @@ object ComponentSuite {
     }
   }
 
+  private trait CompatibilityHelper extends Component {
+    private var _isRuntimeCompatible: Boolean = true
+
+    override def isRuntimeCompatible: Boolean = _isRuntimeCompatible
+
+    def setIncompatible(): Unit = { // Makes isRuntimeCompatible return false from now on.
+      _isRuntimeCompatible = false
+    }
+  }
+
   abstract private class DummyComponent(override val name: String)
     extends Component
-    with DependencyBuilder {
+    with DependencyBuilder
+    with CompatibilityHelper {
+
     override def buildInfo(): Component.BuildInfo =
       Component.BuildInfo(name, "N/A", "N/A", "N/A")
 
@@ -125,7 +192,10 @@ object ComponentSuite {
     override def injectRules(injector: Injector): Unit = {}
   }
 
-  abstract private class DummyBackend(override val name: String) extends 
Backend {
+  abstract private class DummyBackend(override val name: String)
+    extends Backend
+    with CompatibilityHelper {
+
     override def buildInfo(): Component.BuildInfo =
       Component.BuildInfo(name, "N/A", "N/A", "N/A")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to