This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4177292  [SPARK-27435][SQL] Support schema pruning in ORC V2
4177292 is described below

commit 4177292dcd65f77679e74365dc627a506812ddbb
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu Apr 11 20:03:32 2019 +0800

    [SPARK-27435][SQL] Support schema pruning in ORC V2
    
    ## What changes were proposed in this pull request?
    
    Currently, the optimization rule `SchemaPruning` only works for Parquet/ORC V1.
    We should have the same optimization in ORC V2.
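
    As a concrete illustration of what this buys us (not part of the patch; the
    paths, column names and exact conf constants below are assumptions for a local
    sketch), a query that touches a single nested field should end up with a pruned
    read schema once ORC goes through the V2 code path:

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    // Illustrative nested schema.
    case class Name(first: String, last: String)
    case class Contact(id: Int, name: Name)

    object OrcV2PruningSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          // Nested schema pruning is gated behind this optimizer flag.
          .config(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, "true")
          // An empty V1 reader list routes ORC reads through the V2 path,
          // mirroring what OrcV2SchemaPruningSuite does below.
          .config(SQLConf.USE_V1_SOURCE_READER_LIST.key, "")
          .getOrCreate()
        import spark.implicits._

        // Write a small nested dataset as ORC (illustrative path).
        Seq(Contact(0, Name("Jane", "Doe"))).toDF()
          .write.mode("overwrite").orc("/tmp/contacts_orc")

        // Only `name.first` is referenced, so the ORC scan's read schema should
        // shrink to struct<name:struct<first:string>> rather than the full schema.
        val df = spark.read.orc("/tmp/contacts_orc").select($"name.first")
        df.explain()
        df.show()
      }
    }
    ```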
    
    ## How was this patch tested?
    
    Unit test
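
    For reference, the check performed by the new OrcV2SchemaPruningSuite can also
    be reproduced by hand; a rough spark-shell style sketch (assuming the internal
    classes keep the shape used in this patch, and continuing from the illustrative
    /tmp/contacts_orc dataset above):

    ```scala
    import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
    import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
    import spark.implicits._

    val df = spark.read.orc("/tmp/contacts_orc").select($"name.first")

    // Collect the read schema of every ORC V2 scan in the executed plan.
    val readSchemas = df.queryExecution.executedPlan.collect {
      case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema
    }

    // With pruning applied, each schema should hold only the requested nested
    // fields, e.g. struct<name:struct<first:string>>, not the full file schema.
    readSchemas.foreach(schema => println(schema.catalogString))
    ```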
    
    Closes #24338 from gengliangwang/schemaPruningForV2.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/execution/datasources/SchemaPruning.scala  | 104 +++++++++++++++++----
 .../execution/datasources/SchemaPruningSuite.scala |   4 +-
 ...ngSuite.scala => OrcV1SchemaPruningSuite.scala} |   2 +-
 ...ngSuite.scala => OrcV2SchemaPruningSuite.scala} |  26 +++++-
 4 files changed, 109 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 3a37ca7..15fdf65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
@@ -48,7 +51,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
           l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
         if canPruneRelation(hadoopFsRelation) =>
         val (normalizedProjects, normalizedFilters) =
-          normalizeAttributeRefNames(l, projects, filters)
+          normalizeAttributeRefNames(l.output, projects, filters)
         val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
 
         // If requestedRootFields includes a nested field, continue. Otherwise,
@@ -76,6 +79,43 @@ object SchemaPruning extends Rule[LogicalPlan] {
         } else {
           op
         }
+
+      case op @ PhysicalOperation(projects, filters,
+          d @ DataSourceV2Relation(table: FileTable, output, _)) if canPruneTable(table) =>
+        val (normalizedProjects, normalizedFilters) =
+          normalizeAttributeRefNames(output, projects, filters)
+        val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
+
+        // If requestedRootFields includes a nested field, continue. Otherwise,
+        // return op
+        if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
+          val dataSchema = table.dataSchema
+          val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
+
+          // If the data schema is different from the pruned data schema, continue. Otherwise,
+          // return op. We effect this comparison by counting the number of "leaf" fields in
+          // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
+          // in dataSchema.
+          if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
+            val prunedFileTable = table match {
+              case o: OrcTable => o.copy(userSpecifiedSchema = Some(prunedDataSchema))
+              case _ =>
+                val message = s"${table.formatName} data source doesn't support schema pruning."
+                throw new AnalysisException(message)
+            }
+
+
+            val prunedRelationV2 = buildPrunedRelationV2(d, prunedFileTable)
+            val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+
+            buildNewProjection(normalizedProjects, normalizedFilters, prunedRelationV2,
+              projectionOverSchema)
+          } else {
+            op
+          }
+        } else {
+          op
+        }
     }
 
   /**
@@ -86,15 +126,21 @@ object SchemaPruning extends Rule[LogicalPlan] {
       fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
 
   /**
+   * Checks to see if the given [[FileTable]] can be pruned. Currently we support ORC v2.
+   */
+  private def canPruneTable(table: FileTable) =
+    table.isInstanceOf[OrcTable]
+
+  /**
    * Normalizes the names of the attribute references in the given projects and filters to reflect
    * the names in the given logical relation. This makes it possible to compare attributes and
    * fields by name. Returns a tuple with the normalized projects and filters, respectively.
    */
   private def normalizeAttributeRefNames(
-      logicalRelation: LogicalRelation,
+      output: Seq[AttributeReference],
       projects: Seq[NamedExpression],
       filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
-    val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap
+    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
     val normalizedProjects = projects.map(_.transform {
       case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
         att.withName(normalizedAttNameMap(att.exprId))
@@ -107,11 +153,13 @@ object SchemaPruning extends Rule[LogicalPlan] {
   }
 
   /**
-   * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
+   * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
    */
   private def buildNewProjection(
-      projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation,
-      projectionOverSchema: ProjectionOverSchema) = {
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression],
+      leafNode: LeafNode,
+      projectionOverSchema: ProjectionOverSchema): Project = {
     // Construct a new target for our projection by rewriting and
     // including the original filters where available
     val projectionChild =
@@ -120,9 +168,9 @@ object SchemaPruning extends Rule[LogicalPlan] {
           case projectionOverSchema(expr) => expr
         })
         val newFilterCondition = projectedFilters.reduce(And)
-        Filter(newFilterCondition, prunedRelation)
+        Filter(newFilterCondition, leafNode)
       } else {
-        prunedRelation
+        leafNode
       }
 
     // Construct the new projections of our Project by
@@ -145,20 +193,36 @@ object SchemaPruning extends Rule[LogicalPlan] {
   private def buildPrunedRelation(
       outputRelation: LogicalRelation,
       prunedBaseRelation: HadoopFsRelation) = {
+    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
+    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
+  }
+
+  /**
+   * Builds a pruned data source V2 relation from the output of the relation and the schema
+   * of the pruned [[FileTable]].
+   */
+  private def buildPrunedRelationV2(
+      outputRelation: DataSourceV2Relation,
+      prunedFileTable: FileTable) = {
+    val prunedOutput = getPrunedOutput(outputRelation.output, prunedFileTable.schema)
+    outputRelation.copy(table = prunedFileTable, output = prunedOutput)
+  }
+
+  // Prune the given output to make it consistent with `requiredSchema`.
+  private def getPrunedOutput(
+      output: Seq[AttributeReference],
+      requiredSchema: StructType): Seq[AttributeReference] = {
     // We need to replace the expression ids of the pruned relation output attributes
     // with the expression ids of the original relation output attributes so that
     // references to the original relation's output are not broken
-    val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap
-    val prunedRelationOutput =
-      prunedBaseRelation
-        .schema
-        .toAttributes
-        .map {
-          case att if outputIdMap.contains(att.name) =>
-            att.withExprId(outputIdMap(att.name))
-          case att => att
-        }
-    outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput)
+    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+    requiredSchema
+      .toAttributes
+      .map {
+        case att if outputIdMap.contains(att.name) =>
+          att.withExprId(outputIdMap(att.name))
+        case att => att
+      }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 22317fe..09ca428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -407,7 +407,7 @@ abstract class SchemaPruningSuite
     }
   }
 
-  private val schemaEquality = new Equality[StructType] {
+  protected val schemaEquality = new Equality[StructType] {
     override def areEqual(a: StructType, b: Any): Boolean =
       b match {
         case otherType: StructType => a.sameType(otherType)
@@ -422,7 +422,7 @@ abstract class SchemaPruningSuite
     df.collect()
   }
 
-  private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+  protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       df.queryExecution.executedPlan.collect {
         case scan: FileSourceScanExec => scan.requiredSchema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
similarity index 95%
copy from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
index 2623bf9..832da59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
 import org.apache.spark.sql.internal.SQLConf
 
-class OrcSchemaPruningSuite extends SchemaPruningSuite {
+class OrcV1SchemaPruningSuite extends SchemaPruningSuite {
   override protected val dataSourceName: String = "orc"
   override protected val vectorizedReaderEnabledKey: String =
     SQLConf.ORC_VECTORIZED_READER_ENABLED.key
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
similarity index 52%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 2623bf9..b042f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -14,14 +14,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.sql.execution.datasources.orc
 
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.internal.SQLConf
 
-class OrcSchemaPruningSuite extends SchemaPruningSuite {
+class OrcV2SchemaPruningSuite extends SchemaPruningSuite {
   override protected val dataSourceName: String = "orc"
   override protected val vectorizedReaderEnabledKey: String =
     SQLConf.ORC_VECTORIZED_READER_ENABLED.key
@@ -29,6 +32,21 @@ class OrcSchemaPruningSuite extends SchemaPruningSuite {
   override protected def sparkConf: SparkConf =
     super
       .sparkConf
-      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc")
-      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc")
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+
+  override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+    val fileSourceScanSchemata =
+      df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema
+      }
+    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+        s"but expected $expectedSchemaCatalogStrings")
+    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+      case (scanSchema, expectedScanSchemaCatalogString) =>
+        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+        implicit val equality = schemaEquality
+        assert(scanSchema === expectedScanSchema)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
