This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 550a7fb1b7 [KYUUBI #7183][LINEAGE] Support collecting all input tables by the plan
550a7fb1b7 is described below

commit 550a7fb1b7f47d7577f5c7e95a913f9ac82834cd
Author: chenliang.lu <[email protected]>
AuthorDate: Thu Sep 11 10:54:50 2025 +0800

    [KYUUBI #7183][LINEAGE] Support collecting all input tables by the plan
    
    ### Why are the changes needed?
    
    Currently, the input tables are derived from the columns that have lineage
    results, which causes some input tables to be missed. We need a way to
    collect the complete set of input tables from the plan.
    For example:
    insert overwrite v2_catalog.db.tb3 select t1.col1, t1.col2, t1.col3 from v2_catalog.db.tb1 t1 join v2_catalog.db.tb2 t2 on t1.col1 = t2.col1
    
    The input tables for v2_catalog.db.tb3 should be v2_catalog.db.tb1 and v2_catalog.db.tb2.
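    
    To make the gap concrete, here is a minimal sketch in plain Scala (table
    names taken from the example above) of the legacy derivation: it keeps only
    tables that feed some output column, so tb2, which is only joined against,
    never appears.
    
        // Column lineage of the example insert: output column -> source columns.
        val columnsLineage: List[(String, Set[String])] = List(
          "v2_catalog.db.tb3.col1" -> Set("v2_catalog.db.tb1.col1"),
          "v2_catalog.db.tb3.col2" -> Set("v2_catalog.db.tb1.col2"),
          "v2_catalog.db.tb3.col3" -> Set("v2_catalog.db.tb1.col3"))
    
        // Legacy derivation: strip the column segment off every source column.
        val inputTablesByColumn = columnsLineage
          .flatMap { case (_, in) => in.map(_.split('.').init.mkString(".")) }
          .distinct
    
        println(inputTablesByColumn) // List(v2_catalog.db.tb1) -- tb2 is missing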
    
    ### How was this patch tested?
    
    add new UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #7184 from yabola/linage-inputTables.
    
    Closes #7183
    
    5c4f6291f [chenliang.lu] add input tables
    3a57872c4 [Cheng Pan] Update extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
    460e73e76 [chenliang.lu] default collect input tables by plan
    ed96f1055 [chenliang.lu] ut fix
    90b02c507 [chenliang.lu] fix code style
    b3ef5e413 [chenliang.lu] add ut
    b566c0a45 [chenliang.lu] [KYUUBI #7183][LINEAGE] Support collecting all the input tables by the plan
    00b6a3cdf [chenliang.lu] [KYUUBI #7183][LINEAGE] Support collecting all the input tables by the plan
    
    Lead-authored-by: chenliang.lu <[email protected]>
    Co-authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../helper/SparkSQLLineageParseHelper.scala        | 161 +++++++++++++++------
 .../apache/spark/kyuubi/lineage/LineageConf.scala  |   9 ++
 .../helper/RowLevelCatalogLineageParserSuite.scala |   6 +-
 .../helper/SparkSQLLineageParserHelperSuite.scala  |  47 ++++++
 .../helper/TableCatalogLineageParserSuite.scala    |   6 +-
 5 files changed, 178 insertions(+), 51 deletions(-)

diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
index bcab9b74fe..27d74aa173 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.plugin.lineage.helper
 
 import scala.collection.immutable.ListMap
+import scala.collection.mutable
 import scala.util.Try
 
 import org.apache.spark.internal.Logging
@@ -53,18 +54,26 @@ trait LineageParser {
   type AttributeMap[A] = ListMap[Attribute, A]
 
   def parse(plan: LogicalPlan): Lineage = {
-    val columnsLineage =
-      extractColumnsLineage(plan, ListMap[Attribute, AttributeSet]()).toList.collect {
-        case (k, attrs) =>
-          k.name -> attrs.map(attr => (attr.qualifier :+ attr.name).mkString(".")).toSet
+    val inputTablesByPlan = mutable.HashSet[String]()
+    val columnsLineage = extractColumnsLineage(
+      plan,
+      ListMap[Attribute, AttributeSet](),
+      inputTablesByPlan).toList.collect {
+      case (k, attrs) =>
+        k.name -> attrs.map(attr => (attr.qualifier :+ attr.name).mkString(".")).toSet
+    }
+    val (inputTablesByColumn, outputTables) = columnsLineage
+      .foldLeft((List[String](), List[String]())) {
+        case ((inputs, outputs), (out, in)) =>
+          val x = (inputs ++ in.map(_.split('.').init.mkString("."))).filter(_.nonEmpty)
+          val y = outputs ++ List(out.split('.').init.mkString(".")).filter(_.nonEmpty)
+          (x, y)
       }
-    val (inputTables, outputTables) = columnsLineage.foldLeft((List[String](), List[String]())) {
-      case ((inputs, outputs), (out, in)) =>
-        val x = (inputs ++ in.map(_.split('.').init.mkString("."))).filter(_.nonEmpty)
-        val y = outputs ++ List(out.split('.').init.mkString(".")).filter(_.nonEmpty)
-        (x, y)
+    if (SparkContextHelper.getConf(LineageConf.LEGACY_COLLECT_INPUT_TABLES_ENABLED)) {
+      Lineage(inputTablesByColumn.distinct, outputTables.distinct, columnsLineage)
+    } else {
+      Lineage(inputTablesByPlan.toList, outputTables.distinct, columnsLineage)
     }
-    Lineage(inputTables.distinct, outputTables.distinct, columnsLineage)
   }
 
   private def mergeColumnsLineage(
@@ -115,14 +124,15 @@ trait LineageParser {
   }
 
   private def getSelectColumnLineage(
-      named: Seq[NamedExpression]): AttributeMap[AttributeSet] = {
+      named: Seq[NamedExpression],
+      inputTablesByPlan: mutable.HashSet[String]): AttributeMap[AttributeSet] = {
     val exps = named.map {
       case exp: Alias =>
         val references =
           if (exp.references.nonEmpty) exp.references
           else {
             val attrRefs = getExpressionSubqueryPlans(exp.child)
-              .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]()))
+              .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet](), inputTablesByPlan))
               .foldLeft(ListMap[Attribute, AttributeSet]())(mergeColumnsLineage).values
               .foldLeft(AttributeSet.empty)(_ ++ _)
               .map(attr => attr.withQualifier(attr.qualifier :+ SUBQUERY_COLUMN_IDENTIFIER))
@@ -190,13 +200,14 @@ trait LineageParser {
 
   private def extractColumnsLineage(
       plan: LogicalPlan,
-      parentColumnsLineage: AttributeMap[AttributeSet]): AttributeMap[AttributeSet] = {
+      parentColumnsLineage: AttributeMap[AttributeSet],
+      inputTablesByPlan: mutable.HashSet[String]): AttributeMap[AttributeSet] = {
 
     plan match {
       // For command
       case p if p.nodeName == "CommandResult" =>
         val commandPlan = getField[LogicalPlan](plan, "commandLogicalPlan")
-        extractColumnsLineage(commandPlan, parentColumnsLineage)
+        extractColumnsLineage(commandPlan, parentColumnsLineage, inputTablesByPlan)
       case p if p.nodeName == "AlterViewAsCommand" =>
         val query =
           if (SPARK_RUNTIME_VERSION <= "3.1") {
@@ -205,7 +216,7 @@ trait LineageParser {
             getQuery(plan)
           }
         val view = getV1TableName(getField[TableIdentifier](plan, "name").unquotedString)
-        extractColumnsLineage(query, parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(query, parentColumnsLineage, inputTablesByPlan).map { case (k, v) =>
           k.withName(s"$view.${k.name}") -> v
         }
 
@@ -222,7 +233,10 @@ trait LineageParser {
             getField[LogicalPlan](plan, "plan")
           }
 
-        val lineages = extractColumnsLineage(query, parentColumnsLineage).zipWithIndex.map {
+        val lineages = extractColumnsLineage(
+          query,
+          parentColumnsLineage,
+          inputTablesByPlan).zipWithIndex.map {
           case ((k, v), i) if outputCols.nonEmpty => k.withName(s"$view.${outputCols(i)}") -> v
           case ((k, v), _) => k.withName(s"$view.${k.name}") -> v
         }.toSeq
@@ -230,7 +244,10 @@ trait LineageParser {
 
       case p if p.nodeName == "CreateDataSourceTableAsSelectCommand" =>
         val table = getV1TableName(getField[CatalogTable](plan, "table").qualifiedName)
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(
+          getQuery(plan),
+          parentColumnsLineage,
+          inputTablesByPlan).map { case (k, v) =>
           k.withName(s"$table.${k.name}") -> v
         }
 
@@ -238,7 +255,10 @@ trait LineageParser {
           if p.nodeName == "CreateHiveTableAsSelectCommand" ||
             p.nodeName == "OptimizedCreateHiveTableAsSelectCommand" =>
-        val table = getV1TableName(getField[CatalogTable](plan, "tableDesc").qualifiedName)
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(
+          getQuery(plan),
+          parentColumnsLineage,
+          inputTablesByPlan).map { case (k, v) =>
           k.withName(s"$table.${k.name}") -> v
         }
 
@@ -259,7 +279,10 @@ trait LineageParser {
                 invokeAs[LogicalPlan](plan, "name"),
                 "catalog").name())
           }
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(
+          getQuery(plan),
+          parentColumnsLineage,
+          inputTablesByPlan).map { case (k, v) =>
           k.withName(Seq(catalog, namespace, table, k.name).filter(_.nonEmpty).mkString(".")) -> v
         }
 
@@ -267,7 +290,7 @@ trait LineageParser {
         val logicalRelation = getField[LogicalRelation](plan, "logicalRelation")
         val table = logicalRelation
           .catalogTable.map(t => getV1TableName(t.qualifiedName)).getOrElse("")
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map {
+        extractColumnsLineage(getQuery(plan), parentColumnsLineage, inputTablesByPlan).map {
           case (k, v) if table.nonEmpty =>
             k.withName(s"$table.${k.name}") -> v
         }
@@ -277,7 +300,7 @@ trait LineageParser {
           getField[Option[CatalogTable]](plan, "catalogTable")
             .map(t => getV1TableName(t.qualifiedName))
             .getOrElse("")
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map {
+        extractColumnsLineage(getQuery(plan), parentColumnsLineage, inputTablesByPlan).map {
           case (k, v) if table.nonEmpty =>
             k.withName(s"$table.${k.name}") -> v
         }
@@ -288,26 +311,32 @@ trait LineageParser {
         val dir =
           getField[CatalogStorageFormat](plan, "storage").locationUri.map(_.toString)
             .getOrElse("")
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map {
+        extractColumnsLineage(getQuery(plan), parentColumnsLineage, inputTablesByPlan).map {
           case (k, v) if dir.nonEmpty =>
             k.withName(s"`$dir`.${k.name}") -> v
         }
 
       case p if p.nodeName == "InsertIntoHiveTable" =>
         val table = getV1TableName(getField[CatalogTable](plan, "table").qualifiedName)
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(
+          getQuery(plan),
+          parentColumnsLineage,
+          inputTablesByPlan).map { case (k, v) =>
           k.withName(s"$table.${k.name}") -> v
         }
 
       case p if p.nodeName == "SaveIntoDataSourceCommand" =>
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage)
+        extractColumnsLineage(getQuery(plan), parentColumnsLineage, inputTablesByPlan)
 
       case p
           if p.nodeName == "AppendData"
             || p.nodeName == "OverwriteByExpression"
             || p.nodeName == "OverwritePartitionsDynamic" =>
         val table = getV2TableName(getField[NamedRelation](plan, "table"))
-        extractColumnsLineage(getQuery(plan), parentColumnsLineage).map { case (k, v) =>
+        extractColumnsLineage(
+          getQuery(plan),
+          parentColumnsLineage,
+          inputTablesByPlan).map { case (k, v) =>
           k.withName(s"$table.${k.name}") -> v
         }
       case p if p.nodeName == "MergeRows" =>
@@ -330,12 +359,15 @@ trait LineageParser {
             keyAttr -> attributeSet
         }: _*)
         p.children.map(
-          extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+          extractColumnsLineage(
+            _,
+            nextColumnsLineage,
+            inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p if p.nodeName == "WriteDelta" || p.nodeName == "ReplaceData" =>
         val table = getV2TableName(getField[NamedRelation](plan, "table"))
         val query = getQuery(plan)
-        val columnsLineage = extractColumnsLineage(query, parentColumnsLineage)
+        val columnsLineage = extractColumnsLineage(query, parentColumnsLineage, inputTablesByPlan)
         columnsLineage
           .filter { case (k, _) => !isMetadataAttr(k) }
           .map { case (k, v) =>
@@ -357,8 +389,12 @@ trait LineageParser {
         val sourceTable = getField[LogicalPlan](plan, "sourceTable")
         val targetColumnsLineage = extractColumnsLineage(
           targetTable,
-          nextColumnsLlineage.map { case (k, _) => (k, AttributeSet(k)) })
-        val sourceColumnsLineage = extractColumnsLineage(sourceTable, nextColumnsLlineage)
+          nextColumnsLlineage.map { case (k, _) => (k, AttributeSet(k)) },
+          inputTablesByPlan)
+        val sourceColumnsLineage = extractColumnsLineage(
+          sourceTable,
+          nextColumnsLlineage,
+          inputTablesByPlan)
         val targetColumnsWithTargetTable = targetColumnsLineage.values.flatten.map { column =>
           val unquotedQualifiedName = (column.qualifier :+ column.name).mkString(".")
           column.withName(unquotedQualifiedName)
@@ -367,18 +403,28 @@ trait LineageParser {
 
       case p if p.nodeName == "WithCTE" =>
         val optimized = sparkSession.sessionState.optimizer.execute(p)
-        extractColumnsLineage(optimized, parentColumnsLineage)
+        extractColumnsLineage(optimized, parentColumnsLineage, inputTablesByPlan)
 
       // For query
       case p: Project =>
         val nextColumnsLineage =
-          joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.projectList))
-        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+          joinColumnsLineage(
+            parentColumnsLineage,
+            getSelectColumnLineage(p.projectList, inputTablesByPlan))
+        p.children.map(extractColumnsLineage(
+          _,
+          nextColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p: Aggregate =>
         val nextColumnsLineage =
-          joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
-        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+          joinColumnsLineage(
+            parentColumnsLineage,
+            getSelectColumnLineage(p.aggregateExpressions, inputTablesByPlan))
+        p.children.map(extractColumnsLineage(
+          _,
+          nextColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p: Expand =>
         val references =
@@ -387,7 +433,10 @@ trait LineageParser {
         val childColumnsLineage = ListMap(p.output.zip(references): _*)
         val nextColumnsLineage =
           joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
-        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+        p.children.map(extractColumnsLineage(
+          _,
+          nextColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p: Generate =>
         val generateColumnsLineageWithId =
@@ -400,7 +449,10 @@ trait LineageParser {
                 attr.exprId,
                 AttributeSet(attr))))
         }
-        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+        p.children.map(extractColumnsLineage(
+          _,
+          nextColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p: Window =>
         val windowColumnsLineage =
@@ -417,14 +469,18 @@ trait LineageParser {
                 windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
           }
         }
-        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+        p.children.map(extractColumnsLineage(
+          _,
+          nextColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
 
       case p: Join =>
         p.joinType match {
           case LeftSemi | LeftAnti =>
-            extractColumnsLineage(p.left, parentColumnsLineage)
+            extractColumnsLineage(p.right, ListMap[Attribute, AttributeSet](), inputTablesByPlan)
+            extractColumnsLineage(p.left, parentColumnsLineage, inputTablesByPlan)
           case _ =>
-            p.children.map(extractColumnsLineage(_, parentColumnsLineage))
+            p.children.map(extractColumnsLineage(_, parentColumnsLineage, inputTablesByPlan))
               .reduce(mergeColumnsLineage)
         }
 
@@ -433,12 +489,15 @@ trait LineageParser {
           // support for the multi-insert statement
           if (p.output.isEmpty) {
             p.children
-              .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]()))
+              .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet](), inputTablesByPlan))
               .reduce(mergeColumnsLineage)
           } else {
             // merge all children in to one derivedColumns
             val childrenUnion =
-              p.children.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())).map(
+              p.children.map(extractColumnsLineage(
+                _,
+                ListMap[Attribute, AttributeSet](),
+                inputTablesByPlan)).map(
                 _.values).reduce {
                 (left, right) =>
                   left.zip(right).map(attr => attr._1 ++ attr._2)
@@ -449,14 +508,17 @@ trait LineageParser {
 
       case p: LogicalRelation if p.catalogTable.nonEmpty =>
         val tableName = getV1TableName(p.catalogTable.get.qualifiedName)
+        inputTablesByPlan += tableName
         joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(tableName))
 
       case p: HiveTableRelation =>
         val tableName = getV1TableName(p.tableMeta.qualifiedName)
+        inputTablesByPlan += tableName
         joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(tableName))
 
       case p: DataSourceV2ScanRelation =>
         val tableName = getV2TableName(p)
+        inputTablesByPlan += tableName
         joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(tableName))
 
       // For creating the view from v2 table, the logical plan of table will
@@ -464,9 +526,11 @@ trait LineageParser {
       // because the view from the table is not going to read it.
       case p: DataSourceV2Relation =>
         val tableName = getV2TableName(p)
+        inputTablesByPlan += tableName
         joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(tableName))
 
       case p: LocalRelation =>
+        inputTablesByPlan += LOCAL_TABLE_IDENTIFIER
         joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(LOCAL_TABLE_IDENTIFIER))
 
       case _: OneRowRelation =>
@@ -485,17 +549,18 @@ trait LineageParser {
       // so we just extract the columns lineage from its inner children (original view)
       case pvm if pvm.nodeName == "PermanentViewMarker" =>
         pvm.innerChildren.asInstanceOf[Seq[LogicalPlan]]
-          .map(extractColumnsLineage(_, parentColumnsLineage))
+          .map(extractColumnsLineage(_, parentColumnsLineage, inputTablesByPlan))
           .reduce(mergeColumnsLineage)
 
       case p: View =>
         if (!p.isTempView && SparkContextHelper.getConf(
             LineageConf.SKIP_PARSING_PERMANENT_VIEW_ENABLED)) {
           val viewName = getV1TableName(p.desc.qualifiedName)
+          inputTablesByPlan += viewName
           joinRelationColumnLineage(parentColumnsLineage, p.output, Seq(viewName))
         } else {
           val viewColumnsLineage =
-            extractColumnsLineage(p.child, ListMap[Attribute, AttributeSet]())
+            extractColumnsLineage(p.child, ListMap[Attribute, AttributeSet](), inputTablesByPlan)
           mergeRelationColumnLineage(parentColumnsLineage, p.output, viewColumnsLineage)
         }
 
@@ -505,7 +570,10 @@ trait LineageParser {
         cachedTableLogical match {
           case Some(logicPlan) =>
             val relationColumnLineage =
-              extractColumnsLineage(logicPlan, ListMap[Attribute, AttributeSet]())
+              extractColumnsLineage(
+                logicPlan,
+                ListMap[Attribute, AttributeSet](),
+                inputTablesByPlan)
             mergeRelationColumnLineage(parentColumnsLineage, p.output, relationColumnLineage)
           case _ =>
             joinRelationColumnLineage(
@@ -517,7 +585,10 @@ trait LineageParser {
       case p if p.children.isEmpty => ListMap[Attribute, AttributeSet]()
 
       case p =>
-        p.children.map(extractColumnsLineage(_, parentColumnsLineage)).reduce(mergeColumnsLineage)
+        p.children.map(extractColumnsLineage(
+          _,
+          parentColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
     }
   }
 
diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
index e264b1f359..afffb5f578 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
@@ -46,6 +46,15 @@ object LineageConf {
       "Unsupported lineage dispatchers")
     .createWithDefault(Seq(LineageDispatcherType.SPARK_EVENT.toString))
 
+  val LEGACY_COLLECT_INPUT_TABLES_ENABLED =
+    ConfigBuilder("spark.kyuubi.plugin.lineage.legacy.collectInputTablesByColumn")
+      .internal
+      .doc("When true, collect input tables by column lineage. " +
+        "When false, collect all the input tables by the plan.")
+      .version("1.11.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val DEFAULT_CATALOG: String = SQLConf.get.getConf(SQLConf.DEFAULT_CATALOG)
 
 }
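
For reference, a deployment that needs the old column-derived input tables can opt
back in through the flag added above. A minimal sketch (the config key comes from
this patch; the SparkSession setup around it is illustrative only):

    import org.apache.spark.sql.SparkSession

    // Illustrative only: restore the legacy, column-derived input-table collection.
    val spark = SparkSession.builder()
      .appName("lineage-legacy-input-tables")
      .config("spark.kyuubi.plugin.lineage.legacy.collectInputTablesByColumn", "true")
      .getOrCreate()
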
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/RowLevelCatalogLineageParserSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/RowLevelCatalogLineageParserSuite.scala
index 81b6d85e25..8af5b0f179 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/RowLevelCatalogLineageParserSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/RowLevelCatalogLineageParserSuite.scala
@@ -72,7 +72,7 @@ class RowLevelCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite
           "WHEN NOT MATCHED THEN " +
           "  INSERT *")
       assert(ret1 == Lineage(
-        List("v2_catalog.db.source_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           ("v2_catalog.db.target_t.pk", Set("v2_catalog.db.source_t.pk")),
@@ -91,7 +91,7 @@ class RowLevelCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite
           "  INSERT *")
 
       assert(ret2 == Lineage(
-        List("v2_catalog.db.source_t", "v2_catalog.db.pivot_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.pivot_t", 
"v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           ("v2_catalog.db.target_t.pk", Set("v2_catalog.db.source_t.pk")),
@@ -180,7 +180,7 @@ class RowLevelCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite
         "  INSERT *")
 
       assert(ret2 == Lineage(
-        List("v2_catalog.db.source_t", "v2_catalog.db.target_t", 
"v2_catalog.db.pivot_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.pivot_t", 
"v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           (
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
index e3cda6959f..380b3eee4f 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
@@ -1422,6 +1422,53 @@ abstract class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }
 
+  test("columns lineage extract - collect input tables by plan") {
+    val ddls =
+      """
+        |create table v2_catalog.db.tb1(col1 string, col2 string, col3 string)
+        |create table v2_catalog.db.tb2(col1 string, col2 string, col3 string)
+        |create table v2_catalog.db.tb3(col1 string, col2 string, col3 string)
+        |""".stripMargin
+    ddls.split("\n").filter(_.nonEmpty).foreach(spark.sql(_).collect())
+    withTable("v2_catalog.db.tb1", "v2_catalog.db.tb2", "v2_catalog.db.tb3") { _ =>
+      val sql0 =
+        """
+          |insert overwrite v2_catalog.db.tb3
+          |select t1.col1, t1.col2 , t1.col3
+          |from v2_catalog.db.tb1 t1 join v2_catalog.db.tb2 t2
+          |on t1.col1 = t2.col1
+          |""".stripMargin
+
+      val ret0 = extractLineage(sql0)
+      assert(
+        ret0 == Lineage(
+          List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+          List("v2_catalog.db.tb3"),
+          List(
+            ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+            ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+            ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+
+      val sql1 =
+        """
+          |insert overwrite v2_catalog.db.tb3
+          |select t1.col1, t1.col2 , t1.col3
+          |from v2_catalog.db.tb1 t1 left semi join v2_catalog.db.tb2 t2
+          |on t1.col1 = t2.col1
+          |""".stripMargin
+
+      val ret1 = extractLineage(sql1)
+      assert(
+        ret1 == Lineage(
+          List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+          List("v2_catalog.db.tb3"),
+          List(
+            ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+            ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+            ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+    }
+  }
+
   protected def extractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/TableCatalogLineageParserSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/TableCatalogLineageParserSuite.scala
index ea607452aa..c9724c3fec 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/TableCatalogLineageParserSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/TableCatalogLineageParserSuite.scala
@@ -42,7 +42,7 @@ class TableCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite {
         "WHEN NOT MATCHED THEN " +
         "  INSERT (id, name, price) VALUES (cast(source.id as int), 
source.name, source.price)")
       assert(ret0 == Lineage(
-        List("v2_catalog.db.source_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           ("v2_catalog.db.target_t.id", Set("v2_catalog.db.source_t.id")),
@@ -57,7 +57,7 @@ class TableCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite {
         "WHEN NOT MATCHED THEN " +
         "  INSERT *")
       assert(ret1 == Lineage(
-        List("v2_catalog.db.source_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           ("v2_catalog.db.target_t.id", Set("v2_catalog.db.source_t.id")),
@@ -74,7 +74,7 @@ class TableCatalogLineageParserSuite extends SparkSQLLineageParserHelperSuite {
         "  INSERT *")
 
       assert(ret2 == Lineage(
-        List("v2_catalog.db.source_t", "v2_catalog.db.pivot_t"),
+        List("v2_catalog.db.source_t", "v2_catalog.db.pivot_t", 
"v2_catalog.db.target_t"),
         List("v2_catalog.db.target_t"),
         List(
           ("v2_catalog.db.target_t.id", Set("v2_catalog.db.source_t.id")),
