This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4177292 [SPARK-27435][SQL] Support schema pruning in ORC V2
4177292 is described below
commit 4177292dcd65f77679e74365dc627a506812ddbb
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu Apr 11 20:03:32 2019 +0800
[SPARK-27435][SQL] Support schema pruning in ORC V2
## What changes were proposed in this pull request?
Currently, the optimization rule `SchemaPruning` only works for Parquet/ORC V1.
We should have the same optimization in ORC V2.
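For illustration only (not part of this patch), a minimal sketch of the intended effect, assuming a hypothetical local directory of ORC files with a nested `name` struct column:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Route ORC reads through the V2 code path by clearing the V1 fallback list,
// mirroring the configuration used by the new OrcV2SchemaPruningSuite.
val spark = SparkSession.builder()
  .master("local[*]")
  .config(SQLConf.USE_V1_SOURCE_READER_LIST.key, "")
  .getOrCreate()

// `contactsPath` is a placeholder for a directory of ORC files whose schema
// contains a nested column such as name: struct<first: string, last: string>.
val contactsPath = "/tmp/contacts_orc"
val df = spark.read.orc(contactsPath).select("name.first")

// With the pruning rule applied, the ORC scan's read schema should contain only
// the requested leaf field (struct<name:struct<first:string>>) instead of the
// full file schema.
df.explain()
```

Without the rule, the same query on the V2 path reads the full `name` struct.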
## How was this patch tested?
Unit tests.
Closes #24338 from gengliangwang/schemaPruningForV2.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/datasources/SchemaPruning.scala | 104 +++++++++++++++++----
.../execution/datasources/SchemaPruningSuite.scala | 4 +-
...ngSuite.scala => OrcV1SchemaPruningSuite.scala} | 2 +-
...ngSuite.scala => OrcV2SchemaPruningSuite.scala} | 26 +++++-
4 files changed, 109 insertions(+), 27 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 3a37ca7..15fdf65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.execution.datasources
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -48,7 +51,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
if canPruneRelation(hadoopFsRelation) =>
val (normalizedProjects, normalizedFilters) =
- normalizeAttributeRefNames(l, projects, filters)
+ normalizeAttributeRefNames(l.output, projects, filters)
val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
// If requestedRootFields includes a nested field, continue. Otherwise,
@@ -76,6 +79,43 @@ object SchemaPruning extends Rule[LogicalPlan] {
} else {
op
}
+
+ case op @ PhysicalOperation(projects, filters,
+ d @ DataSourceV2Relation(table: FileTable, output, _)) if canPruneTable(table) =>
+ val (normalizedProjects, normalizedFilters) =
+ normalizeAttributeRefNames(output, projects, filters)
+ val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
+
+ // If requestedRootFields includes a nested field, continue. Otherwise,
+ // return op
+ if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
+ val dataSchema = table.dataSchema
+ val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
+
+ // If the data schema is different from the pruned data schema, continue. Otherwise,
+ // return op. We effect this comparison by counting the number of "leaf" fields in
+ // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
+ // in dataSchema.
+ if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
+ val prunedFileTable = table match {
+ case o: OrcTable => o.copy(userSpecifiedSchema = Some(prunedDataSchema))
+ case _ =>
+ val message = s"${table.formatName} data source doesn't support schema pruning."
+ throw new AnalysisException(message)
+ }
+
+
+ val prunedRelationV2 = buildPrunedRelationV2(d, prunedFileTable)
+ val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+
+ buildNewProjection(normalizedProjects, normalizedFilters, prunedRelationV2,
+ projectionOverSchema)
+ } else {
+ op
+ }
+ } else {
+ op
+ }
}
/**
@@ -86,15 +126,21 @@ object SchemaPruning extends Rule[LogicalPlan] {
fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
/**
+ * Checks to see if the given [[FileTable]] can be pruned. Currently we support ORC v2.
+ */
+ private def canPruneTable(table: FileTable) =
+ table.isInstanceOf[OrcTable]
+
+ /**
* Normalizes the names of the attribute references in the given projects and filters to reflect
* the names in the given logical relation. This makes it possible to compare attributes and
* fields by name. Returns a tuple with the normalized projects and filters, respectively.
*/
private def normalizeAttributeRefNames(
- logicalRelation: LogicalRelation,
+ output: Seq[AttributeReference],
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
- val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap
+ val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
val normalizedProjects = projects.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
@@ -107,11 +153,13 @@ object SchemaPruning extends Rule[LogicalPlan] {
}
/**
- * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
+ * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
*/
private def buildNewProjection(
- projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation,
- projectionOverSchema: ProjectionOverSchema) = {
+ projects: Seq[NamedExpression],
+ filters: Seq[Expression],
+ leafNode: LeafNode,
+ projectionOverSchema: ProjectionOverSchema): Project = {
// Construct a new target for our projection by rewriting and
// including the original filters where available
val projectionChild =
@@ -120,9 +168,9 @@ object SchemaPruning extends Rule[LogicalPlan] {
case projectionOverSchema(expr) => expr
})
val newFilterCondition = projectedFilters.reduce(And)
- Filter(newFilterCondition, prunedRelation)
+ Filter(newFilterCondition, leafNode)
} else {
- prunedRelation
+ leafNode
}
// Construct the new projections of our Project by
@@ -145,20 +193,36 @@ object SchemaPruning extends Rule[LogicalPlan] {
private def buildPrunedRelation(
outputRelation: LogicalRelation,
prunedBaseRelation: HadoopFsRelation) = {
+ val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
+ outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
+ }
+
+ /**
+ * Builds a pruned data source V2 relation from the output of the relation and the schema
+ * of the pruned [[FileTable]].
+ */
+ private def buildPrunedRelationV2(
+ outputRelation: DataSourceV2Relation,
+ prunedFileTable: FileTable) = {
+ val prunedOutput = getPrunedOutput(outputRelation.output, prunedFileTable.schema)
+ outputRelation.copy(table = prunedFileTable, output = prunedOutput)
+ }
+
+ // Prune the given output to make it consistent with `requiredSchema`.
+ private def getPrunedOutput(
+ output: Seq[AttributeReference],
+ requiredSchema: StructType): Seq[AttributeReference] = {
// We need to replace the expression ids of the pruned relation output attributes
// with the expression ids of the original relation output attributes so that
// references to the original relation's output are not broken
- val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap
- val prunedRelationOutput =
- prunedBaseRelation
- .schema
- .toAttributes
- .map {
- case att if outputIdMap.contains(att.name) =>
- att.withExprId(outputIdMap(att.name))
- case att => att
- }
- outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput)
+ val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+ requiredSchema
+ .toAttributes
+ .map {
+ case att if outputIdMap.contains(att.name) =>
+ att.withExprId(outputIdMap(att.name))
+ case att => att
+ }
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 22317fe..09ca428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -407,7 +407,7 @@ abstract class SchemaPruningSuite
}
}
- private val schemaEquality = new Equality[StructType] {
+ protected val schemaEquality = new Equality[StructType] {
override def areEqual(a: StructType, b: Any): Boolean =
b match {
case otherType: StructType => a.sameType(otherType)
@@ -422,7 +422,7 @@ abstract class SchemaPruningSuite
df.collect()
}
- private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+ protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
val fileSourceScanSchemata =
df.queryExecution.executedPlan.collect {
case scan: FileSourceScanExec => scan.requiredSchema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
similarity index 95%
copy from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
index 2623bf9..832da59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
import org.apache.spark.sql.internal.SQLConf
-class OrcSchemaPruningSuite extends SchemaPruningSuite {
+class OrcV1SchemaPruningSuite extends SchemaPruningSuite {
override protected val dataSourceName: String = "orc"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.ORC_VECTORIZED_READER_ENABLED.key
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
similarity index 52%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 2623bf9..b042f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -14,14 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.execution.datasources.orc
import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.internal.SQLConf
-class OrcSchemaPruningSuite extends SchemaPruningSuite {
+class OrcV2SchemaPruningSuite extends SchemaPruningSuite {
override protected val dataSourceName: String = "orc"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.ORC_VECTORIZED_READER_ENABLED.key
@@ -29,6 +32,21 @@ class OrcSchemaPruningSuite extends SchemaPruningSuite {
override protected def sparkConf: SparkConf =
super
.sparkConf
- .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc")
- .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc")
+ .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+
+ override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+ val fileSourceScanSchemata =
+ df.queryExecution.executedPlan.collect {
+ case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema
+ }
+ assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+ s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+ s"but expected $expectedSchemaCatalogStrings")
+ fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+ case (scanSchema, expectedScanSchemaCatalogString) =>
+ val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+ implicit val equality = schemaEquality
+ assert(scanSchema === expectedScanSchema)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]