GitHub user liancheng commented on a diff in the pull request:
https://github.com/apache/spark/pull/4308#discussion_r24192566
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala ---
@@ -228,66 +347,397 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
val cacheMetadata = useCache
@transient
- val cachedStatus = selectedPartitions.flatMap(_.files)
+ val cachedStatus = selectedFiles
// Overridden so we can inject our own cached files statuses.
override def getPartitions: Array[SparkPartition] = {
- val inputFormat =
- if (cacheMetadata) {
- new FilteringParquetRowInputFormat {
- override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus
- }
- } else {
- new FilteringParquetRowInputFormat
+ val inputFormat = if (cacheMetadata) {
+ new FilteringParquetRowInputFormat {
+ override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus
}
-
- inputFormat match {
- case configurable: Configurable =>
- configurable.setConf(getConf)
- case _ =>
+ } else {
+ new FilteringParquetRowInputFormat
}
+
val jobContext = newJobContext(getConf, jobId)
- val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[SparkPartition](rawSplits.size)
- for (i <- 0 until rawSplits.size) {
- result(i) =
- new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ val rawSplits = inputFormat.getSplits(jobContext)
+
+ Array.tabulate[SparkPartition](rawSplits.size) { i =>
+ new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
- result
}
}
- // The ordinal for the partition key in the result row, if requested.
- val partitionKeyLocation =
- partitionKeys
- .headOption
- .map(requiredColumns.indexOf(_))
- .getOrElse(-1)
+ // The ordinals for partition keys in the result row, if requested.
+ val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map {
+ case (name, index) => index -> requiredColumns.indexOf(name)
+ }.toMap.filter {
+ case (_, index) => index >= 0
+ }
// When the data does not include the key and the key is requested then we must fill it in
// based on information from the input split.
- if (!dataIncludesKey && partitionKeyLocation != -1) {
- baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
- val partValue = "([^=]+)=([^=]+)".r
- val partValues =
- split.asInstanceOf[parquet.hadoop.ParquetInputSplit]
- .getPath
- .toString
- .split("/")
- .flatMap {
- case partValue(key, value) => Some(key -> value)
- case _ => None
- }.toMap
-
- val currentValue = partValues.values.head.toInt
- iter.map { pair =>
- val res = pair._2.asInstanceOf[SpecificMutableRow]
- res.setInt(partitionKeyLocation, currentValue)
- res
+ if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) {
+ baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
+ val partValues = selectedPartitions.collectFirst {
+ case p if split.getPath.getParent.toString == p.path => p.values
+ }.get
+
+ iterator.map { pair =>
+ val row = pair._2.asInstanceOf[SpecificMutableRow]
+ var i = 0
+ while (i < partValues.size) {
+ // TODO Avoids boxing cost here!
+ row.update(partitionKeyLocations(i), partValues(i))
+ i += 1
+ }
+ row
}
}
} else {
baseRDD.map(_._2)
}
}
+
+ private def prunePartitions(
+ predicates: Seq[Expression],
+ partitions: Seq[Partition]): Seq[Partition] = {
+ val partitionColumnNames = partitionColumns.map(_.name).toSet
+ val partitionPruningPredicates = predicates.filter {
+ _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+ }
+
+ val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true))
+ val boundPredicate = InterpretedPredicate(rawPredicate transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ })
+
+ if (isPartitioned && partitionPruningPredicates.nonEmpty) {
+ partitions.filter(p => boundPredicate(p.values))
+ } else {
+ partitions
+ }
+ }
+
+ override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ // TODO: currently we do not check whether the "schema"s are compatible
+ // That means if one first creates a table and then INSERTs data with
+ // and incompatible schema the execution will fail. It would be nice
+ // to catch this early one, maybe having the planner validate the schema
+ // before calling execute().
+
+ val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
+ val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) {
+ log.debug("Initializing MutableRowWriteSupport")
+ classOf[MutableRowWriteSupport]
+ } else {
+ classOf[RowWriteSupport]
+ }
+
+ ParquetOutputFormat.setWriteSupportClass(job, writeSupport)
+
+ val conf = ContextUtil.getConfiguration(job)
+ RowWriteSupport.setSchema(schema.toAttributes, conf)
+
+ val destinationPath = new Path(paths.head)
+
+ if (overwrite) {
+ try {
+ destinationPath.getFileSystem(conf).delete(destinationPath, true)
+ } catch {
+ case e: IOException =>
+ throw new IOException(
+ s"Unable to clear output directory ${destinationPath.toString} prior" +
+ s" to writing to Parquet file:\n${e.toString}")
+ }
+ }
+
+ job.setOutputKeyClass(classOf[Void])
+ job.setOutputValueClass(classOf[Row])
+ FileOutputFormat.setOutputPath(job, destinationPath)
+
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date())
+ val stageId = sqlContext.sparkContext.newRddId()
+
+ val taskIdOffset = if (overwrite) {
+ 1
+ } else {
+ FileSystemHelper.findMaxTaskId(
+ FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1
+ }
+
+ def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = {
+ /* "reduce task" <split #> <attempt # = spark task #> */
+ val attemptId = newTaskAttemptID(
+ jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber())
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val format = new AppendingParquetOutputFormat(taskIdOffset)
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext)
+ try {
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ writer.write(null, row)
+ }
+ } finally {
+ writer.close(hadoopContext)
+ }
+ committer.commitTask(hadoopContext)
+ }
+ val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
+ /* apparently we need a TaskAttemptID to construct an OutputCommitter;
+ * however we're only going to use this local OutputCommitter for
+ * setupJob/commitJob, so we just use a dummy "map" task.
+ */
+ val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+
+ jobCommitter.setupJob(jobTaskContext)
+ sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _)
+ jobCommitter.commitJob(jobTaskContext)
+
+ metadataCache.refresh()
+ }
+}
+
+object ParquetRelation2 {
+ // Whether we should merge schemas collected from all Parquet part-files.
+ val MERGE_SCHEMA = "parquet.mergeSchema"
--- End diff --
Thanks. Will address these in follow-up PR(s).
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like to enable it, or if the feature is enabled but not
working, please contact infrastructure at [email protected] or file a JIRA
ticket with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]