This is an automated email from the ASF dual-hosted git repository.

yma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 4533c7251 [V] Remove complex type fallback for parquet (#6712)
4533c7251 is described below

commit 4533c725129c1c89096beadcdb540306a7c259a8
Author: Yan Ma <[email protected]>
AuthorDate: Thu Sep 5 08:57:17 2024 +0800

    [V] Remove complex type fallback for parquet (#6712)
    
    * disable complex type fallback for parquet
    
    * disable parquet file reading as Velox does not support it yet
    
    * fall back timestamp scan for parquet if necessary
---
 .../gluten/backendsapi/velox/VeloxBackend.scala    | 50 +---------------
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |  2 +-
 .../gluten/expression/ExpressionTransformer.scala  |  7 ++-
 .../gluten/execution/MiscOperatorSuite.scala       |  2 +-
 .../VeloxParquetDataTypeValidationSuite.scala      | 16 -----
 .../gluten/expression/ExpressionConverter.scala    | 70 ++++++++++++++++++++--
 6 files changed, 73 insertions(+), 74 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 065adf338..611e9c15b 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -98,55 +98,15 @@ object VeloxBackendSettings extends BackendSettingsApi {
       }
     }
 
-    val parquetTypeValidatorWithComplexTypeFallback: 
PartialFunction[StructField, String] = {
-      case StructField(_, arrayType: ArrayType, _, _) =>
-        arrayType.simpleString + " is forced to fallback."
-      case StructField(_, mapType: MapType, _, _) =>
-        mapType.simpleString + " is forced to fallback."
-      case StructField(_, structType: StructType, _, _) =>
-        structType.simpleString + " is forced to fallback."
-      case StructField(_, timestampType: TimestampType, _, _)
-          if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled 
=>
-        timestampType.simpleString + " is forced to fallback."
-    }
-    val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, 
String] = {
-      case StructField(_, arrayType: ArrayType, _, _) =>
-        arrayType.simpleString + " is forced to fallback."
-      case StructField(_, mapType: MapType, _, _) =>
-        mapType.simpleString + " is forced to fallback."
-      case StructField(_, structType: StructType, _, _) =>
-        structType.simpleString + " is forced to fallback."
-      case StructField(_, stringType: StringType, _, metadata)
-          if isCharType(stringType, metadata) =>
-        CharVarcharUtils.getRawTypeString(metadata) + " not support"
-      case StructField(_, TimestampType, _, _) => "TimestampType not support"
-    }
     format match {
       case ParquetReadFormat =>
         val typeValidator: PartialFunction[StructField, String] = {
-          // Parquet scan of nested array with struct/array as element type is 
unsupported in Velox.
-          case StructField(_, arrayType: ArrayType, _, _)
-              if arrayType.elementType.isInstanceOf[StructType] =>
-            "StructType as element in ArrayType"
-          case StructField(_, arrayType: ArrayType, _, _)
-              if arrayType.elementType.isInstanceOf[ArrayType] =>
-            "ArrayType as element in ArrayType"
-          // Parquet scan of nested map with struct as key type,
-          // or array type as value type is not supported in Velox.
-          case StructField(_, mapType: MapType, _, _) if 
mapType.keyType.isInstanceOf[StructType] =>
-            "StructType as Key in MapType"
-          case StructField(_, mapType: MapType, _, _)
-              if mapType.valueType.isInstanceOf[ArrayType] =>
-            "ArrayType as Value in MapType"
+          // Parquet timestamp is not fully supported yet
           case StructField(_, TimestampType, _, _)
               if 
GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
             "TimestampType"
         }
-        if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
-          validateTypes(typeValidator)
-        } else {
-          validateTypes(parquetTypeValidatorWithComplexTypeFallback)
-        }
+        validateTypes(typeValidator)
       case DwrfReadFormat => ValidationResult.succeeded
       case OrcReadFormat =>
         if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
@@ -170,11 +130,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
               CharVarcharUtils.getRawTypeString(metadata) + " not support"
             case StructField(_, TimestampType, _, _) => "TimestampType not 
support"
           }
-          if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
-            validateTypes(typeValidator)
-          } else {
-            validateTypes(orcTypeValidatorWithComplexTypeFallback)
-          }
+          validateTypes(typeValidator)
         }
       case _ => ValidationResult.failed(s"Unsupported file format for 
$format.")
     }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 4755adc91..4a9bfef55 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -713,7 +713,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       childTransformer: ExpressionTransformer,
       ordinal: Int,
       original: GetStructField): ExpressionTransformer = {
-    VeloxGetStructFieldTransformer(substraitExprName, childTransformer, 
original)
+    VeloxGetStructFieldTransformer(substraitExprName, childTransformer, 
ordinal, original)
   }
 
   /**
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 71e58f124..4a40d9410 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -54,19 +54,20 @@ case class VeloxNamedStructTransformer(
 case class VeloxGetStructFieldTransformer(
     substraitExprName: String,
     child: ExpressionTransformer,
+    ordinal: Int,
     original: GetStructField)
   extends UnaryExpressionTransformer {
   override def doTransform(args: Object): ExpressionNode = {
     val childNode = child.doTransform(args)
     childNode match {
       case node: StructLiteralNode =>
-        node.getFieldLiteral(original.ordinal)
+        node.getFieldLiteral(ordinal)
       case node: SelectionNode =>
         // Append the nested index to selection node.
-        node.addNestedChildIdx(JInteger.valueOf(original.ordinal))
+        node.addNestedChildIdx(JInteger.valueOf(ordinal))
       case node: NullLiteralNode =>
         val nodeType =
-          
node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(original.ordinal)
+          node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal)
         ExpressionBuilder.makeNullLiteral(nodeType)
       case other =>
         throw new GlutenNotSupportException(s"$other is not supported.")
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 296e98ca9..fc56d049f 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -1713,7 +1713,7 @@ class MiscOperatorSuite extends 
VeloxWholeStageTransformerSuite with AdaptiveSpa
 
       sql("CREATE TABLE t2(id INT, l ARRAY<STRUCT<k: INT, v: INT>>) USING 
PARQUET")
       sql("INSERT INTO t2 VALUES(1, ARRAY(STRUCT(1, 100))), (2, 
ARRAY(STRUCT(2, 200)))")
-      runQueryAndCompare("SELECT first(l) FROM t2")(df => 
checkFallbackOperators(df, 1))
+      runQueryAndCompare("SELECT first(l) FROM t2")(df => 
checkFallbackOperators(df, 0))
     }
   }
 
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
index 85b3f32a7..8b6cc63c9 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
@@ -427,22 +427,6 @@ class VeloxParquetDataTypeValidationSuite extends 
VeloxWholeStageTransformerSuit
     }
   }
 
-  test("Force complex type scan fallback") {
-    withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", 
"true")) {
-      val df = spark.sql("select struct from type1")
-      val executedPlan = getExecutedPlan(df)
-      assert(!executedPlan.exists(plan => 
plan.isInstanceOf[BatchScanExecTransformer]))
-    }
-  }
-
-  test("Force timestamp type scan fallback") {
-    
withSQLConf(("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled", 
"true")) {
-      val df = spark.sql("select timestamp from type1")
-      val executedPlan = getExecutedPlan(df)
-      assert(!executedPlan.exists(plan => 
plan.isInstanceOf[BatchScanExecTransformer]))
-    }
-  }
-
   test("Decimal type") {
     // Validation: BatchScan Project Aggregate Expand Sort Limit
     runQueryAndCompare(
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index c5ba3a8a7..6f6e2cf12 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -33,6 +33,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
+import scala.collection.mutable.ArrayBuffer
+
 trait Transformable {
   def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): 
ExpressionTransformer
 }
@@ -345,12 +347,23 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
             expr => replaceWithExpressionTransformer0(expr, attributeSeq, 
expressionsMap)),
           m)
       case getStructField: GetStructField =>
-        // Different backends may have different result.
-        
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
-          substraitExprName,
-          replaceWithExpressionTransformer0(getStructField.child, 
attributeSeq, expressionsMap),
-          getStructField.ordinal,
-          getStructField)
+        try {
+          val bindRef =
+            bindGetStructField(getStructField, attributeSeq)
+          // Different backends may have different result.
+          
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
+            substraitExprName,
+            replaceWithExpressionTransformer0(getStructField.child, 
attributeSeq, expressionsMap),
+            bindRef.ordinal,
+            getStructField)
+        } catch {
+          case e: IllegalStateException =>
+            // This situation may need developers to fix, although we just 
throw the below
+            // exception to let the corresponding operator fall back.
+            throw new UnsupportedOperationException(
+              s"Failed to bind reference for $getStructField: ${e.getMessage}")
+        }
+
       case getArrayStructFields: GetArrayStructFields =>
         GenericExpressionTransformer(
           substraitExprName,
@@ -693,4 +706,49 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
     }
     substraitExprName
   }
+
+  private def bindGetStructField(
+      structField: GetStructField,
+      input: AttributeSeq): BoundReference = {
+    // get the new ordinal base input
+    var newOrdinal: Int = -1
+    val names = new ArrayBuffer[String]
+    var root: Expression = structField
+    while (root.isInstanceOf[GetStructField]) {
+      val curField = root.asInstanceOf[GetStructField]
+      val name = curField.childSchema.fields(curField.ordinal).name
+      names += name
+      root = root.asInstanceOf[GetStructField].child
+    }
+    // For map/array type, the reference is correct no matter 
NESTED_SCHEMA_PRUNING_ENABLED or not
+    if (!root.isInstanceOf[AttributeReference]) {
+      return BoundReference(structField.ordinal, structField.dataType, 
structField.nullable)
+    }
+    names += root.asInstanceOf[AttributeReference].name
+    input.attrs.foreach(
+      attribute => {
+        var level = names.size - 1
+        if (names(level) == attribute.name) {
+          var candidateFields: Array[StructField] = null
+          var dtType = attribute.dataType
+          while (dtType.isInstanceOf[StructType] && level >= 1) {
+            candidateFields = dtType.asInstanceOf[StructType].fields
+            level -= 1
+            val curName = names(level)
+            for (i <- 0 until candidateFields.length) {
+              if (candidateFields(i).name == curName) {
+                dtType = candidateFields(i).dataType
+                newOrdinal = i
+              }
+            }
+          }
+        }
+      })
+    if (newOrdinal == -1) {
+      throw new IllegalStateException(
+        s"Couldn't find $structField in ${input.attrs.mkString("[", ",", 
"]")}")
+    } else {
+      BoundReference(newOrdinal, structField.dataType, structField.nullable)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to