This is an automated email from the ASF dual-hosted git repository.
yma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4533c7251 [V] Remove complex type fallback for parquet (#6712)
4533c7251 is described below
commit 4533c725129c1c89096beadcdb540306a7c259a8
Author: Yan Ma <[email protected]>
AuthorDate: Thu Sep 5 08:57:17 2024 +0800
[V] Remove complex type fallback for parquet (#6712)
* disable complex type fallback for Parquet
* disable Parquet file reading, as Velox does not support it yet
* fall back the timestamp scan for Parquet if necessary
---
.../gluten/backendsapi/velox/VeloxBackend.scala | 50 +---------------
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 2 +-
.../gluten/expression/ExpressionTransformer.scala | 7 ++-
.../gluten/execution/MiscOperatorSuite.scala | 2 +-
.../VeloxParquetDataTypeValidationSuite.scala | 16 -----
.../gluten/expression/ExpressionConverter.scala | 70 ++++++++++++++++++++--
6 files changed, 73 insertions(+), 74 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 065adf338..611e9c15b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -98,55 +98,15 @@ object VeloxBackendSettings extends BackendSettingsApi {
}
}
- val parquetTypeValidatorWithComplexTypeFallback:
PartialFunction[StructField, String] = {
- case StructField(_, arrayType: ArrayType, _, _) =>
- arrayType.simpleString + " is forced to fallback."
- case StructField(_, mapType: MapType, _, _) =>
- mapType.simpleString + " is forced to fallback."
- case StructField(_, structType: StructType, _, _) =>
- structType.simpleString + " is forced to fallback."
- case StructField(_, timestampType: TimestampType, _, _)
- if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled
=>
- timestampType.simpleString + " is forced to fallback."
- }
- val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField,
String] = {
- case StructField(_, arrayType: ArrayType, _, _) =>
- arrayType.simpleString + " is forced to fallback."
- case StructField(_, mapType: MapType, _, _) =>
- mapType.simpleString + " is forced to fallback."
- case StructField(_, structType: StructType, _, _) =>
- structType.simpleString + " is forced to fallback."
- case StructField(_, stringType: StringType, _, metadata)
- if isCharType(stringType, metadata) =>
- CharVarcharUtils.getRawTypeString(metadata) + " not support"
- case StructField(_, TimestampType, _, _) => "TimestampType not support"
- }
format match {
case ParquetReadFormat =>
val typeValidator: PartialFunction[StructField, String] = {
- // Parquet scan of nested array with struct/array as element type is
unsupported in Velox.
- case StructField(_, arrayType: ArrayType, _, _)
- if arrayType.elementType.isInstanceOf[StructType] =>
- "StructType as element in ArrayType"
- case StructField(_, arrayType: ArrayType, _, _)
- if arrayType.elementType.isInstanceOf[ArrayType] =>
- "ArrayType as element in ArrayType"
- // Parquet scan of nested map with struct as key type,
- // or array type as value type is not supported in Velox.
- case StructField(_, mapType: MapType, _, _) if
mapType.keyType.isInstanceOf[StructType] =>
- "StructType as Key in MapType"
- case StructField(_, mapType: MapType, _, _)
- if mapType.valueType.isInstanceOf[ArrayType] =>
- "ArrayType as Value in MapType"
+ // Parquet timestamp is not fully supported yet
case StructField(_, TimestampType, _, _)
if
GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
"TimestampType"
}
- if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
- validateTypes(typeValidator)
- } else {
- validateTypes(parquetTypeValidatorWithComplexTypeFallback)
- }
+ validateTypes(typeValidator)
case DwrfReadFormat => ValidationResult.succeeded
case OrcReadFormat =>
if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
@@ -170,11 +130,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
CharVarcharUtils.getRawTypeString(metadata) + " not support"
case StructField(_, TimestampType, _, _) => "TimestampType not
support"
}
- if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
- validateTypes(typeValidator)
- } else {
- validateTypes(orcTypeValidatorWithComplexTypeFallback)
- }
+ validateTypes(typeValidator)
}
case _ => ValidationResult.failed(s"Unsupported file format for
$format.")
}
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 4755adc91..4a9bfef55 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -713,7 +713,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField): ExpressionTransformer = {
- VeloxGetStructFieldTransformer(substraitExprName, childTransformer,
original)
+ VeloxGetStructFieldTransformer(substraitExprName, childTransformer,
ordinal, original)
}
/**
diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 71e58f124..4a40d9410 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -54,19 +54,20 @@ case class VeloxNamedStructTransformer(
case class VeloxGetStructFieldTransformer(
substraitExprName: String,
child: ExpressionTransformer,
+ ordinal: Int,
original: GetStructField)
extends UnaryExpressionTransformer {
override def doTransform(args: Object): ExpressionNode = {
val childNode = child.doTransform(args)
childNode match {
case node: StructLiteralNode =>
- node.getFieldLiteral(original.ordinal)
+ node.getFieldLiteral(ordinal)
case node: SelectionNode =>
// Append the nested index to selection node.
- node.addNestedChildIdx(JInteger.valueOf(original.ordinal))
+ node.addNestedChildIdx(JInteger.valueOf(ordinal))
case node: NullLiteralNode =>
val nodeType =
-
node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(original.ordinal)
+ node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal)
ExpressionBuilder.makeNullLiteral(nodeType)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 296e98ca9..fc56d049f 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -1713,7 +1713,7 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
sql("CREATE TABLE t2(id INT, l ARRAY<STRUCT<k: INT, v: INT>>) USING
PARQUET")
sql("INSERT INTO t2 VALUES(1, ARRAY(STRUCT(1, 100))), (2,
ARRAY(STRUCT(2, 200)))")
- runQueryAndCompare("SELECT first(l) FROM t2")(df =>
checkFallbackOperators(df, 1))
+ runQueryAndCompare("SELECT first(l) FROM t2")(df =>
checkFallbackOperators(df, 0))
}
}
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
index 85b3f32a7..8b6cc63c9 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
@@ -427,22 +427,6 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit
}
}
- test("Force complex type scan fallback") {
- withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled",
"true")) {
- val df = spark.sql("select struct from type1")
- val executedPlan = getExecutedPlan(df)
- assert(!executedPlan.exists(plan =>
plan.isInstanceOf[BatchScanExecTransformer]))
- }
- }
-
- test("Force timestamp type scan fallback") {
-
withSQLConf(("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled",
"true")) {
- val df = spark.sql("select timestamp from type1")
- val executedPlan = getExecutedPlan(df)
- assert(!executedPlan.exists(plan =>
plan.isInstanceOf[BatchScanExecTransformer]))
- }
- }
-
test("Decimal type") {
// Validation: BatchScan Project Aggregate Expand Sort Limit
runQueryAndCompare(
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index c5ba3a8a7..6f6e2cf12 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -33,6 +33,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import scala.collection.mutable.ArrayBuffer
+
trait Transformable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]):
ExpressionTransformer
}
@@ -345,12 +347,23 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr => replaceWithExpressionTransformer0(expr, attributeSeq,
expressionsMap)),
m)
case getStructField: GetStructField =>
- // Different backends may have different result.
-
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
- substraitExprName,
- replaceWithExpressionTransformer0(getStructField.child,
attributeSeq, expressionsMap),
- getStructField.ordinal,
- getStructField)
+ try {
+ val bindRef =
+ bindGetStructField(getStructField, attributeSeq)
+ // Different backends may have different result.
+
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
+ substraitExprName,
+ replaceWithExpressionTransformer0(getStructField.child,
attributeSeq, expressionsMap),
+ bindRef.ordinal,
+ getStructField)
+ } catch {
+ case e: IllegalStateException =>
+ // This situation may need developers to fix, although we just
throw the below
+ // exception to let the corresponding operator fall back.
+ throw new UnsupportedOperationException(
+ s"Failed to bind reference for $getStructField: ${e.getMessage}")
+ }
+
case getArrayStructFields: GetArrayStructFields =>
GenericExpressionTransformer(
substraitExprName,
@@ -693,4 +706,49 @@ object ExpressionConverter extends SQLConfHelper with Logging {
}
substraitExprName
}
+
+ private def bindGetStructField(
+ structField: GetStructField,
+ input: AttributeSeq): BoundReference = {
+ // get the new ordinal base input
+ var newOrdinal: Int = -1
+ val names = new ArrayBuffer[String]
+ var root: Expression = structField
+ while (root.isInstanceOf[GetStructField]) {
+ val curField = root.asInstanceOf[GetStructField]
+ val name = curField.childSchema.fields(curField.ordinal).name
+ names += name
+ root = root.asInstanceOf[GetStructField].child
+ }
+ // For map/array type, the reference is correct no matter
NESTED_SCHEMA_PRUNING_ENABLED or not
+ if (!root.isInstanceOf[AttributeReference]) {
+ return BoundReference(structField.ordinal, structField.dataType,
structField.nullable)
+ }
+ names += root.asInstanceOf[AttributeReference].name
+ input.attrs.foreach(
+ attribute => {
+ var level = names.size - 1
+ if (names(level) == attribute.name) {
+ var candidateFields: Array[StructField] = null
+ var dtType = attribute.dataType
+ while (dtType.isInstanceOf[StructType] && level >= 1) {
+ candidateFields = dtType.asInstanceOf[StructType].fields
+ level -= 1
+ val curName = names(level)
+ for (i <- 0 until candidateFields.length) {
+ if (candidateFields(i).name == curName) {
+ dtType = candidateFields(i).dataType
+ newOrdinal = i
+ }
+ }
+ }
+ }
+ })
+ if (newOrdinal == -1) {
+ throw new IllegalStateException(
+ s"Couldn't find $structField in ${input.attrs.mkString("[", ",",
"]")}")
+ } else {
+ BoundReference(newOrdinal, structField.dataType, structField.nullable)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]