Github user liancheng commented on a diff in the pull request:
https://github.com/apache/spark/pull/9104#discussion_r41938703
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala ---
@@ -178,52 +179,26 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
     sparkPlan
   }
 
-  // TODO: refactor this thing. It is very complicated because it does projection internally.
-  // We should just put a project on top of this.
   private def mergeWithPartitionValues(
-      schema: StructType,
-      requiredColumns: Array[String],
-      partitionColumns: Array[String],
+      requiredColumns: Seq[Attribute],
+      dataColumns: Seq[Attribute],
+      partitionColumnSchema: StructType,
       partitionValues: InternalRow,
       dataRows: RDD[InternalRow]): RDD[InternalRow] = {
-    val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains)
-
     // If output columns contain any partition column(s), we need to merge scanned data
     // columns and requested partition columns to form the final result.
-    if (!requiredColumns.sameElements(nonPartitionColumns)) {
-      val mergers = requiredColumns.zipWithIndex.map { case (name, index) =>
-        // To see whether the `index`-th column is a partition column...
-        val i = partitionColumns.indexOf(name)
-        if (i != -1) {
-          val dt = schema(partitionColumns(i)).dataType
-          // If yes, gets column value from partition values.
-          (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = partitionValues.get(i, dt)
-          }
-        } else {
-          // Otherwise, inherits the value from scanned data.
-          val i = nonPartitionColumns.indexOf(name)
-          val dt = schema(nonPartitionColumns(i)).dataType
-          (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = dataRow.get(i, dt)
-          }
-        }
+    if (requiredColumns != dataColumns) {
+      // Builds `AttributeReference`s for all partition columns so that we can use them to project
+      // required partition columns. Note that if a partition column appears in `requiredColumns`,
+      // we should use the `AttributeReference` in `requiredColumns`.
+      val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap
+      val partitionColumns = partitionColumnSchema.toAttributes.map { a =>
+        requiredColumnMap.getOrElse(a.name, a)
       }
 
-      // Since we know for sure that this closure is serializable, we can avoid the overhead
-      // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally
-      // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718).
       val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => {
-        val dataTypes = requiredColumns.map(schema(_).dataType)
-        val mutableRow = new SpecificMutableRow(dataTypes)
-        iterator.map { dataRow =>
-          var i = 0
-          while (i < mutableRow.numFields) {
-            mergers(i)(mutableRow, dataRow, i)
-            i += 1
-          }
-          mutableRow.asInstanceOf[InternalRow]
-        }
+        val projection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)
+        iterator.map(dataRow => projection(new JoinedRow(dataRow, partitionValues)))
--- End diff ---
Updated, although reusing `JoinedRow` doesn't bring noticeable speedup in
my micro-benchmark.
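
For reference, a minimal standalone sketch of the projection-based merge shown above, written against 1.5/1.6-era Catalyst internals. The object name, toy schema, and values are made up for illustration; this is not the PR code. It joins each data row with the shared partition values, lets `UnsafeProjection` put the required columns in order, and reuses a single mutable `JoinedRow` for the whole partition, as discussed:

```scala
// Standalone sketch (illustrative only, not the PR code): merge scanned data rows with
// partition-directory values by projecting a JoinedRow, using 1.5/1.6-era Catalyst internals.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object MergePartitionValuesSketch {
  def main(args: Array[String]): Unit = {
    // Columns physically stored in the data files.
    val dataColumns = Seq(
      AttributeReference("id", IntegerType)(),
      AttributeReference("name", StringType)())

    // Partition column inferred from the directory layout, e.g. .../date=2015-10-14/...
    val partitionColumns = StructType(StructField("date", StringType) :: Nil).toAttributes

    // Requested output order interleaves data and partition columns.
    val requiredColumns = Seq(dataColumns(0), partitionColumns(0), dataColumns(1))

    // A single partition-value row shared by every data row in this partition.
    val partitionValues = InternalRow(UTF8String.fromString("2015-10-14"))

    // The projection reads from `dataColumns ++ partitionColumns`, so applying it to
    // JoinedRow(dataRow, partitionValues) yields the required columns in the required order.
    val projection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)

    // Reuse one mutable JoinedRow for the whole partition instead of allocating one per row.
    val joinedRow = new JoinedRow()

    val dataRows = Iterator(
      InternalRow(1, UTF8String.fromString("a")),
      InternalRow(2, UTF8String.fromString("b")))

    dataRows
      .map(dataRow => projection(joinedRow(dataRow, partitionValues)))
      .foreach { row =>
        // Prints (id, date, name), matching `requiredColumns`.
        println((row.getInt(0), row.getUTF8String(1), row.getUTF8String(2)))
      }
  }
}
```

Note that the projection also reuses its output `UnsafeRow`, so rows would need to be copied before being buffered downstream.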