This is an automated email from the ASF dual-hosted git repository.

rui pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new df6fe9065 [VL] Support row type and fix subfield in filter push-down (#6618)
df6fe9065 is described below

commit df6fe9065e0204c4f3e30dcb59d6d98df8b943e6
Author: Rui Mo <[email protected]>
AuthorDate: Wed Jul 31 15:55:14 2024 +0800

    [VL] Support row type and fix subfield in filter push-down (#6618)
---
 .../org/apache/gluten/execution/TestOperator.scala | 56 ++++++++++++++++-
 cpp/velox/substrait/SubstraitParser.cc             | 14 ++++-
 cpp/velox/substrait/SubstraitParser.h              |  5 +-
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        | 73 ++++++++++++++--------
 4 files changed, 116 insertions(+), 32 deletions(-)

diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index dcae4920d..0fb5fb549 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -29,7 +29,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, 
StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DecimalType, IntegerType, 
StringType, StructField, StructType}
 
 import java.util.concurrent.TimeUnit
 
@@ -102,6 +102,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
with AdaptiveSparkPla
         "where l_comment is null") { _ => }
     assert(df.isEmpty)
     checkLengthAndPlan(df, 0)
+
+    // Struct of array.
+    val data =
+      Row(Row(Array("a", "b", "c"), null)) ::
+        Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) ::
+        Row(Row(null, null)) :: Nil
+
+    val schema = new StructType()
+      .add(
+        "struct",
+        new StructType()
+          .add("a0", ArrayType(StringType))
+          .add("a1", ArrayType(IntegerType)))
+
+    val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), 
schema)
+
+    withTempPath {
+      path =>
+        dataFrame.write.parquet(path.getCanonicalPath)
+        
spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view")
+        runQueryAndCompare("select * from view where struct is null") {
+          checkGlutenOperatorMatch[FileSourceScanExecTransformer]
+        }
+        runQueryAndCompare("select * from view where struct.a0 is null") {
+          checkGlutenOperatorMatch[FileSourceScanExecTransformer]
+        }
+    }
   }
 
   test("is_null_has_null") {
@@ -119,6 +146,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
with AdaptiveSparkPla
       "select l_orderkey from lineitem where l_comment is not null " +
         "and l_orderkey = 1") { _ => }
     checkLengthAndPlan(df, 6)
+
+    // Struct of array.
+    val data =
+      Row(Row(Array("a", "b", "c"), null)) ::
+        Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) ::
+        Row(Row(null, null)) :: Nil
+
+    val schema = new StructType()
+      .add(
+        "struct",
+        new StructType()
+          .add("a0", ArrayType(StringType))
+          .add("a1", ArrayType(IntegerType)))
+
+    val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), 
schema)
+
+    withTempPath {
+      path =>
+        dataFrame.write.parquet(path.getCanonicalPath)
+        
spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view")
+        runQueryAndCompare("select * from view where struct is not null") {
+          checkGlutenOperatorMatch[FileSourceScanExecTransformer]
+        }
+        runQueryAndCompare("select * from view where struct.a0 is not null") {
+          checkGlutenOperatorMatch[FileSourceScanExecTransformer]
+        }
+    }
   }
 
   test("is_null and is_not_null coexist") {
diff --git a/cpp/velox/substrait/SubstraitParser.cc 
b/cpp/velox/substrait/SubstraitParser.cc
index 6eb62f854..006a20c23 100644
--- a/cpp/velox/substrait/SubstraitParser.cc
+++ b/cpp/velox/substrait/SubstraitParser.cc
@@ -141,11 +141,21 @@ void SubstraitParser::parseColumnTypes(
   return;
 }
 
-int32_t SubstraitParser::parseReferenceSegment(const 
::substrait::Expression::ReferenceSegment& refSegment) {
+bool SubstraitParser::parseReferenceSegment(
+    const ::substrait::Expression::ReferenceSegment& refSegment,
+    uint32_t& fieldIndex) {
   auto typeCase = refSegment.reference_type_case();
   switch (typeCase) {
     case 
::substrait::Expression::ReferenceSegment::ReferenceTypeCase::kStructField: {
-      return refSegment.struct_field().field();
+      if (refSegment.struct_field().has_child()) {
+        // To parse subfield index is not supported.
+        return false;
+      }
+      fieldIndex = refSegment.struct_field().field();
+      if (fieldIndex < 0) {
+        return false;
+      }
+      return true;
     }
     default:
       VELOX_NYI("Substrait conversion not supported for ReferenceSegment 
'{}'", std::to_string(typeCase));
diff --git a/cpp/velox/substrait/SubstraitParser.h 
b/cpp/velox/substrait/SubstraitParser.h
index 1f766b91c..f42d05b4a 100644
--- a/cpp/velox/substrait/SubstraitParser.h
+++ b/cpp/velox/substrait/SubstraitParser.h
@@ -50,8 +50,9 @@ class SubstraitParser {
   /// Parse Substrait Type to Velox type.
   static facebook::velox::TypePtr parseType(const ::substrait::Type& 
substraitType, bool asLowerCase = false);
 
-  /// Parse Substrait ReferenceSegment.
-  static int32_t parseReferenceSegment(const 
::substrait::Expression::ReferenceSegment& refSegment);
+  /// Parse Substrait ReferenceSegment and extract the field index. Return 
false if the segment is not a valid unnested
+  /// field.
+  static bool parseReferenceSegment(const 
::substrait::Expression::ReferenceSegment& refSegment, uint32_t& fieldIndex);
 
   /// Make names in the format of {prefix}_{index}.
   static std::vector<std::string> makeNames(const std::string& prefix, int 
size);
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 7b41f7071..d7de84119 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1530,8 +1530,7 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
   if (arguments.size() == 1) {
     if (arguments[0].value().has_selection()) {
       // Only field exists.
-      fieldIndex = 
SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference());
-      return true;
+      return 
SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference(),
 fieldIndex);
     } else {
       return false;
     }
@@ -1546,13 +1545,17 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
   for (const auto& param : arguments) {
     auto typeCase = param.value().rex_type_case();
     switch (typeCase) {
-      case ::substrait::Expression::RexTypeCase::kSelection:
-        fieldIndex = 
SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference());
+      case ::substrait::Expression::RexTypeCase::kSelection: {
+        if 
(!SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(),
 fieldIndex)) {
+          return false;
+        }
         fieldExists = true;
         break;
-      case ::substrait::Expression::RexTypeCase::kLiteral:
+      }
+      case ::substrait::Expression::RexTypeCase::kLiteral: {
         literalExists = true;
         break;
+      }
       default:
         break;
     }
@@ -1564,7 +1567,7 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
 bool SubstraitToVeloxPlanConverter::childrenFunctionsOnSameField(
     const ::substrait::Expression_ScalarFunction& function) {
   // Get the column indices of the children functions.
-  std::vector<int32_t> colIndices;
+  std::vector<uint32_t> colIndices;
   for (const auto& arg : function.arguments()) {
     if (arg.value().has_scalar_function()) {
       const auto& scalarFunction = arg.value().scalar_function();
@@ -1572,14 +1575,16 @@ bool 
SubstraitToVeloxPlanConverter::childrenFunctionsOnSameField(
         if (param.value().has_selection()) {
           const auto& field = param.value().selection();
           VELOX_CHECK(field.has_direct_reference());
-          int32_t colIdx = 
SubstraitParser::parseReferenceSegment(field.direct_reference());
+          uint32_t colIdx;
+          if 
(!SubstraitParser::parseReferenceSegment(field.direct_reference(), colIdx)) {
+            return false;
+          }
           colIndices.emplace_back(colIdx);
         }
       }
     } else if (arg.value().has_singular_or_list()) {
       const auto& singularOrList = arg.value().singular_or_list();
-      int32_t colIdx = getColumnIndexFromSingularOrList(singularOrList);
-      colIndices.emplace_back(colIdx);
+      
colIndices.emplace_back(getColumnIndexFromSingularOrList(singularOrList));
     } else {
       return false;
     }
@@ -1711,8 +1716,9 @@ void SubstraitToVeloxPlanConverter::separateFilters(
     if (format == dwio::common::FileFormat::ORC && 
scalarFunction.arguments().size() > 0) {
       auto value = scalarFunction.arguments().at(0).value();
       if (value.has_selection()) {
-        uint32_t fieldIndex = 
SubstraitParser::parseReferenceSegment(value.selection().direct_reference());
-        if (!veloxTypeList.empty() && 
veloxTypeList.at(fieldIndex)->isDecimal()) {
+        uint32_t fieldIndex;
+        bool parsed = 
SubstraitParser::parseReferenceSegment(value.selection().direct_reference(), 
fieldIndex);
+        if (!parsed || (!veloxTypeList.empty() && 
veloxTypeList.at(fieldIndex)->isDecimal())) {
           remainingFunctions.emplace_back(scalarFunction);
           continue;
         }
@@ -1870,14 +1876,20 @@ void SubstraitToVeloxPlanConverter::setFilterInfo(
   for (const auto& param : scalarFunction.arguments()) {
     auto typeCase = param.value().rex_type_case();
     switch (typeCase) {
-      case ::substrait::Expression::RexTypeCase::kSelection:
+      case ::substrait::Expression::RexTypeCase::kSelection: {
         typeCases.emplace_back("kSelection");
-        colIdx = 
SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference());
+        uint32_t index;
+        VELOX_CHECK(
+            
SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(),
 index),
+            "Failed to parse the column index from the selection.");
+        colIdx = index;
         break;
-      case ::substrait::Expression::RexTypeCase::kLiteral:
+      }
+      case ::substrait::Expression::RexTypeCase::kLiteral: {
         typeCases.emplace_back("kLiteral");
         substraitLit = param.value().literal();
         break;
+      }
       default:
         VELOX_NYI("Substrait conversion not supported for arg type '{}'", 
std::to_string(typeCase));
     }
@@ -2177,18 +2189,17 @@ void 
SubstraitToVeloxPlanConverter::constructSubfieldFilters(
       VELOX_CHECK(value == filterInfo.upperBounds_[0].value().value<bool>(), 
"invalid state of bool equal");
       filters[common::Subfield(inputName)] = 
std::make_unique<common::BoolValue>(value, nullAllowed);
     }
-  } else if constexpr (KIND == facebook::velox::TypeKind::ARRAY || KIND == 
facebook::velox::TypeKind::MAP) {
-    // Only IsNotNull and IsNull are supported for array and map types.
-    if (rangeSize == 0) {
-      if (!nullAllowed) {
-        filters[common::Subfield(inputName)] = 
std::make_unique<common::IsNotNull>();
-      } else if (isNull) {
-        filters[common::Subfield(inputName)] = 
std::make_unique<common::IsNull>();
-      } else {
-        VELOX_NYI(
-            "Only IsNotNull and IsNull are supported in 
constructSubfieldFilters for input type '{}'.",
-            inputType->toString());
-      }
+  } else if constexpr (
+      KIND == facebook::velox::TypeKind::ARRAY || KIND == 
facebook::velox::TypeKind::MAP ||
+      KIND == facebook::velox::TypeKind::ROW) {
+    // Only IsNotNull and IsNull are supported for complex types.
+    VELOX_CHECK_EQ(rangeSize, 0, "Only IsNotNull and IsNull are supported for 
complex type.");
+    if (!nullAllowed) {
+      filters[common::Subfield(inputName)] = 
std::make_unique<common::IsNotNull>();
+    } else if (isNull) {
+      filters[common::Subfield(inputName)] = 
std::make_unique<common::IsNull>();
+    } else {
+      VELOX_NYI("Only IsNotNull and IsNull are supported for input type 
'{}'.", inputType->toString());
     }
   } else {
     using NativeType = typename RangeTraits<KIND>::NativeType;
@@ -2393,6 +2404,10 @@ connector::hive::SubfieldFilters 
SubstraitToVeloxPlanConverter::mapToFilters(
           constructSubfieldFilters<TypeKind::MAP, common::Filter>(
               colIdx, inputNameList[colIdx], inputType, 
columnToFilterInfo[colIdx], filters);
           break;
+        case TypeKind::ROW:
+          constructSubfieldFilters<TypeKind::ROW, common::Filter>(
+              colIdx, inputNameList[colIdx], inputType, 
columnToFilterInfo[colIdx], filters);
+          break;
         default:
           VELOX_NYI(
               "Subfield filters creation not supported for input type '{}' in 
mapToFilters", inputType->toString());
@@ -2494,7 +2509,11 @@ uint32_t 
SubstraitToVeloxPlanConverter::getColumnIndexFromSingularOrList(
   } else {
     VELOX_FAIL("Unsupported type in IN pushdown.");
   }
-  return SubstraitParser::parseReferenceSegment(selection.direct_reference());
+  uint32_t index;
+  VELOX_CHECK(
+      SubstraitParser::parseReferenceSegment(selection.direct_reference(), 
index),
+      "Failed to parse column index from SingularOrList.");
+  return index;
 }
 
 void SubstraitToVeloxPlanConverter::setFilterInfo(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to