This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new cf3a98e17 [VL] [Core] Spark Input_file_name Support (#6021)
cf3a98e17 is described below

commit cf3a98e17ba48bd4fb40343990f5221ea399c82b
Author: 高阳阳 <[email protected]>
AuthorDate: Wed Jun 19 09:33:41 2024 +0800

    [VL] [Core] Spark Input_file_name Support (#6021)
---
 .../backendsapi/velox/VeloxIteratorApi.scala       |  10 +-
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |   3 +-
 .../extension/InputFileNameReplaceRule.scala       | 155 +++++++++++++++++++++
 .../execution/ScalarFunctionsValidateSuite.scala   |   7 +
 .../columnar/heuristic/HeuristicApplier.scala      |   2 +-
 5 files changed, 174 insertions(+), 3 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
index 459a7886e..880e1e56b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.velox
 import org.apache.gluten.GlutenNumaBindingInfo
 import org.apache.gluten.backendsapi.IteratorApi
 import org.apache.gluten.execution._
+import org.apache.gluten.extension.InputFileNameReplaceRule
 import org.apache.gluten.metrics.IMetrics
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.plan.PlanNode
@@ -112,7 +113,7 @@ class VeloxIteratorApi extends IteratorApi with Logging {
     val fileSizes = new JArrayList[JLong]()
     val modificationTimes = new JArrayList[JLong]()
     val partitionColumns = new JArrayList[JMap[String, String]]
-    var metadataColumns = new JArrayList[JMap[String, String]]
+    val metadataColumns = new JArrayList[JMap[String, String]]
     files.foreach {
       file =>
         // The "file.filePath" in PartitionedFile is not the original encoded 
path, so the decoded
@@ -132,6 +133,13 @@ class VeloxIteratorApi extends IteratorApi with Logging {
         }
         val metadataColumn =
           SparkShimLoader.getSparkShims.generateMetadataColumns(file, 
metadataColumnNames)
+        metadataColumn.put(InputFileNameReplaceRule.replacedInputFileName, 
file.filePath.toString)
+        metadataColumn.put(
+          InputFileNameReplaceRule.replacedInputFileBlockStart,
+          file.start.toString)
+        metadataColumn.put(
+          InputFileNameReplaceRule.replacedInputFileBlockLength,
+          file.length.toString)
         metadataColumns.add(metadataColumn)
         val partitionColumn = new JHashMap[String, String]()
         for (i <- 0 until file.partitionValues.numFields) {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index ebf82ea76..71930d7e0 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -807,7 +807,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
    */
   override def genExtendedColumnarValidationRules(): List[SparkSession => 
Rule[SparkPlan]] = List(
     BloomFilterMightContainJointRewriteRule.apply,
-    ArrowScanReplaceRule.apply
+    ArrowScanReplaceRule.apply,
+    InputFileNameReplaceRule.apply
   )
 
   /**
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala
new file mode 100644
index 000000000..cd3f50d8e
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/InputFileNameReplaceRule.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, 
NamedExpression}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec, 
SparkPlan}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.types.{LongType, StringType}
+
+/**
+ * Names of the synthetic columns that [[InputFileNameReplaceRule]] substitutes for Spark's
+ * input-file functions. The same names are used as the metadata-column keys that carry each
+ * scanned file's path, block start and block length.
+ */
+object InputFileNameReplaceRule {
+
+  /** Column that stands in for `input_file_name()`. */
+  val replacedInputFileName: String = "$input_file_name$"
+
+  /** Column that stands in for `input_file_block_start()`. */
+  val replacedInputFileBlockStart: String = "$input_file_block_start$"
+
+  /** Column that stands in for `input_file_block_length()`. */
+  val replacedInputFileBlockLength: String = "$input_file_block_length$"
+}
+
+/**
+ * Rewrites `input_file_name()`, `input_file_block_start()` and `input_file_block_length()` so they
+ * can be evaluated for Parquet scans.
+ *
+ * For every [[ProjectExec]] that uses one of these functions and has a Parquet scan (DSv1
+ * [[FileSourceScanExec]] or DSv2 [[BatchScanExec]]) somewhere below it:
+ *
+ *  - each function call is replaced with a reference to a synthetic attribute named after the
+ *    constants in the companion object;
+ *  - those attributes are appended to the outputs of every [[ProjectExec]], [[FileSourceScanExec]]
+ *    and [[BatchScanExec]] beneath the rewritten project, so the reference can be resolved.
+ *
+ * The per-file values are put into each split's metadata columns under the same names by the
+ * Velox iterator API.
+ *
+ * NOTE(review): operators other than projects and scans are only rebuilt with their new children.
+ * The synthetic attributes reach the rewritten project only if those operators forward their
+ * child's output. TODO confirm the behavior when an aggregate or exchange sits in between.
+ */
+case class InputFileNameReplaceRule(spark: SparkSession) extends Rule[SparkPlan] {
+  import InputFileNameReplaceRule._
+
+  /** True for the three input-file expressions this rule rewrites. */
+  private def isInputFileExpr(expr: Expression): Boolean = expr match {
+    case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true
+    case _ => false
+  }
+
+  /** True if `expr` contains an input-file expression anywhere in its tree. */
+  private def mayNeedConvert(expr: Expression): Boolean = expr.find(isInputFileExpr).isDefined
+
+  /** True if `plan` or any of its descendants is a DSv1 or DSv2 Parquet scan. */
+  private def hasParquetScan(plan: SparkPlan): Boolean =
+    plan.find {
+      case fileScan: FileSourceScanExec =>
+        fileScan.relation.fileFormat.isInstanceOf[ParquetFileFormat]
+      case batchScan: BatchScanExec => batchScan.scan.isInstanceOf[ParquetScan]
+      case _ => false
+    }.isDefined
+
+  /** Appends each attribute of `extra` whose exprId is not already present in `existing`. */
+  private def appendMissing[T <: NamedExpression](existing: Seq[T], extra: Seq[T]): Seq[T] =
+    existing ++ extra.filterNot(attr => existing.exists(_.exprId == attr.exprId))
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    // One synthetic attribute per column name, shared by every rewrite in this invocation, so
+    // that repeated uses of the same function resolve to the same exprId.
+    val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()
+
+    /** Replaces every input-file expression in `expr` with its synthetic attribute. */
+    def doConvert(expr: Expression): Expression = expr.transformDown {
+      case _: InputFileName =>
+        replacedExprs.getOrElseUpdate(
+          replacedInputFileName,
+          AttributeReference(replacedInputFileName, StringType, nullable = true)())
+      case _: InputFileBlockStart =>
+        replacedExprs.getOrElseUpdate(
+          replacedInputFileBlockStart,
+          AttributeReference(replacedInputFileBlockStart, LongType, nullable = true)())
+      case _: InputFileBlockLength =>
+        replacedExprs.getOrElseUpdate(
+          replacedInputFileBlockLength,
+          AttributeReference(replacedInputFileBlockLength, LongType, nullable = true)())
+    }
+
+    /** Adds every synthetic attribute registered so far to the outputs of projects and scans. */
+    def ensureChildOutputHasNewAttrs(plan: SparkPlan): SparkPlan = {
+      val newAttrs = replacedExprs.values.toList
+      plan match {
+        case ProjectExec(projectList, child) =>
+          ProjectExec(appendMissing(projectList, newAttrs), ensureChildOutputHasNewAttrs(child))
+        case fileScan: FileSourceScanExec =>
+          fileScan.copy(output = appendMissing(fileScan.output, newAttrs))
+        case batchScan: BatchScanExec =>
+          batchScan.copy(output = appendMissing(batchScan.output, newAttrs))
+        case other =>
+          other.withNewChildren(other.children.map(ensureChildOutputHasNewAttrs))
+      }
+    }
+
+    def replaceInputFileNameInProject(plan: SparkPlan): SparkPlan = plan match {
+      case ProjectExec(projectList, child)
+          if projectList.exists(mayNeedConvert) && hasParquetScan(plan) =>
+        // Convert the project list first. doConvert registers the synthetic attributes that
+        // ensureChildOutputHasNewAttrs then pushes down into the child's output.
+        // A project-list entry is never itself an input-file expression, so its top-level node
+        // survives the rewrite and remains a NamedExpression.
+        val newProjectList = projectList.map(doConvert(_).asInstanceOf[NamedExpression])
+        val newChild = replaceInputFileNameInProject(ensureChildOutputHasNewAttrs(child))
+        ProjectExec(newProjectList, newChild)
+      case other =>
+        other.withNewChildren(other.children.map(replaceInputFileNameInProject))
+    }
+
+    replaceInputFileNameInProject(plan)
+  }
+}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index 9718b8e73..d08ba11ee 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -623,6 +623,13 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }
 
+  // Expects input_file_name() over the Parquet lineitem table to stay offloaded: the plan must
+  // contain a ProjectExecTransformer, and the rows are compared by runQueryAndCompare.
+  test("Test input_file_name function") {
+    runQueryAndCompare("""SELECT input_file_name(), l_orderkey
+                         | from lineitem limit 100""".stripMargin) {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+  }
+
+  // input_file_block_start() and input_file_block_length() are rewritten by the same
+  // InputFileNameReplaceRule as input_file_name(), so they get the same offload check.
+  test("Test input_file_block_start and input_file_block_length functions") {
+    runQueryAndCompare("""SELECT input_file_block_start(), input_file_block_length(),
+                         | l_orderkey from lineitem limit 100""".stripMargin) {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+  }
+
   test("Test spark_partition_id function") {
     runQueryAndCompare("""SELECT spark_partition_id(), l_orderkey
                          | from lineitem limit 100""".stripMargin) {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
index ad68786e6..d925bc231 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
@@ -96,11 +96,11 @@ class HeuristicApplier(session: SparkSession)
       (spark: SparkSession) => FallbackOnANSIMode(spark),
       (spark: SparkSession) => FallbackMultiCodegens(spark),
       (spark: SparkSession) => PlanOneRowRelation(spark),
-      (_: SparkSession) => FallbackEmptySchemaRelation(),
       (_: SparkSession) => RewriteSubqueryBroadcast()
     ) :::
       
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
 :::
       List(
+        (_: SparkSession) => FallbackEmptySchemaRelation(),
         (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
         (_: SparkSession) => RewriteSparkPlanRulesManager(),
         (_: SparkSession) => AddTransformHintRule()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to