Github user rdblue commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21029#discussion_r182596392

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala ---
    @@ -17,52 +17,85 @@ package org.apache.spark.sql.execution.datasources.v2
     
    -import scala.collection.JavaConverters._
    -import scala.reflect.ClassTag
    -
     import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
    +import org.apache.spark.sql.Row
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    +import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    +import org.apache.spark.sql.sources.v2.DataFormat
    +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
    +import org.apache.spark.sql.types.StructType
     
    -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
    +class DataSourceRDDPartition(val index: Int, val factory: DataReaderFactory)
       extends Partition with Serializable
     
    -class DataSourceRDD[T: ClassTag](
    +class DataSourceRDD(
         sc: SparkContext,
    -    @transient private val readerFactories: Seq[DataReaderFactory[T]])
    -  extends RDD[T](sc, Nil) {
    +    @transient private val readerFactories: Seq[DataReaderFactory],
    +    schema: StructType)
    +  extends RDD[InternalRow](sc, Nil) {
     
       override protected def getPartitions: Array[Partition] = {
         readerFactories.zipWithIndex.map {
           case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
         }.toArray
       }
     
    -  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    -    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
    -    context.addTaskCompletionListener(_ => reader.close())
    -    val iter = new Iterator[T] {
    -      private[this] var valuePrepared = false
    -
    -      override def hasNext: Boolean = {
    -        if (!valuePrepared) {
    -          valuePrepared = reader.next()
    -        }
    -        valuePrepared
    -      }
    -
    -      override def next(): T = {
    -        if (!hasNext) {
    -          throw new java.util.NoSuchElementException("End of stream")
    -        }
    -        valuePrepared = false
    -        reader.get()
    -      }
    +  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    +    val factory = split.asInstanceOf[DataSourceRDDPartition].factory
    +    val iter: DataReaderIterator[UnsafeRow] = factory.dataFormat() match {
    +      case DataFormat.ROW =>
    +        val reader = new RowToUnsafeDataReader(
    +          factory.createRowDataReader(), RowEncoder.apply(schema).resolveAndBind())
    +        new DataReaderIterator(reader)
    +
    +      case DataFormat.UNSAFE_ROW =>
    +        new DataReaderIterator(factory.createUnsafeRowDataReader())
    +
    +      case DataFormat.COLUMNAR_BATCH =>
    +        new DataReaderIterator(factory.createColumnarBatchDataReader())
    +          // TODO: remove this type erase hack.
    +          .asInstanceOf[DataReaderIterator[UnsafeRow]]
    --- End diff --
    
    Isn't this change intended to avoid these casts?
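To make the question concrete: the `asInstanceOf` on the columnar-batch branch only compiles and runs because JVM generics are erased, so the mismatched cast is accepted silently and any failure is deferred until an element is consumed at the wrong type. A minimal, self-contained Scala sketch of that failure mode (illustrative only, not code from this PR; `BoxedIterator` is a hypothetical stand-in for `DataReaderIterator`):

    // Hypothetical stand-in for DataReaderIterator: a generic wrapper
    // whose element type exists only at compile time.
    class BoxedIterator[T](elems: T*) extends Iterator[T] {
      private val underlying = elems.iterator
      override def hasNext: Boolean = underlying.hasNext
      override def next(): T = underlying.next()
    }

    val batches = new BoxedIterator[Array[Int]](Array(1, 2, 3))
    // Compiles and succeeds at runtime: the type argument is erased,
    // so the JVM cannot check it here (only an unchecked warning).
    val rows = batches.asInstanceOf[BoxedIterator[String]]
    // Fails only now, with a ClassCastException at the point of use.
    val s: String = rows.next()

That deferred `ClassCastException` is the kind of bug the comment suggests the untyped `DataReaderFactory` redesign was meant to rule out, which is why the cast left behind a TODO looks at odds with the change.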