Github user rdblue commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21029#discussion_r182596392
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala ---
    @@ -17,52 +17,85 @@
     
     package org.apache.spark.sql.execution.datasources.v2
     
    -import scala.collection.JavaConverters._
    -import scala.reflect.ClassTag
    -
     import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
    +import org.apache.spark.sql.Row
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    +import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    +import org.apache.spark.sql.sources.v2.DataFormat
    +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
    +import org.apache.spark.sql.types.StructType
     
    -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
    +class DataSourceRDDPartition(val index: Int, val factory: DataReaderFactory)
       extends Partition with Serializable
     
    -class DataSourceRDD[T: ClassTag](
    +class DataSourceRDD(
         sc: SparkContext,
    -    @transient private val readerFactories: Seq[DataReaderFactory[T]])
    -  extends RDD[T](sc, Nil) {
    +    @transient private val readerFactories: Seq[DataReaderFactory],
    +    schema: StructType)
    +  extends RDD[InternalRow](sc, Nil) {
     
       override protected def getPartitions: Array[Partition] = {
         readerFactories.zipWithIndex.map {
           case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
         }.toArray
       }
     
    -  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    -    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
    -    context.addTaskCompletionListener(_ => reader.close())
    -    val iter = new Iterator[T] {
    -      private[this] var valuePrepared = false
    -
    -      override def hasNext: Boolean = {
    -        if (!valuePrepared) {
    -          valuePrepared = reader.next()
    -        }
    -        valuePrepared
    -      }
    -
    -      override def next(): T = {
    -        if (!hasNext) {
    -          throw new java.util.NoSuchElementException("End of stream")
    -        }
    -        valuePrepared = false
    -        reader.get()
    -      }
    +  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    +    val factory = split.asInstanceOf[DataSourceRDDPartition].factory
    +    val iter: DataReaderIterator[UnsafeRow] = factory.dataFormat() match {
    +      case DataFormat.ROW =>
    +        val reader = new RowToUnsafeDataReader(
    +          factory.createRowDataReader(), RowEncoder.apply(schema).resolveAndBind())
    +        new DataReaderIterator(reader)
    +
    +      case DataFormat.UNSAFE_ROW =>
    +        new DataReaderIterator(factory.createUnsafeRowDataReader())
    +
    +      case DataFormat.COLUMNAR_BATCH =>
    +        new DataReaderIterator(factory.createColumnarBatchDataReader())
    +          // TODO: remove this type erase hack.
    +          .asInstanceOf[DataReaderIterator[UnsafeRow]]
    --- End diff --
    
    Isn't this change intended to avoid these casts?
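
    For context, `DataReaderIterator` is presumably just the old inline iterator factored out into its own class, roughly like the sketch below (the names come from this diff and the body from the removed `compute()`; the actual class in the PR may differ). If it only forwards `reader.next()` and `reader.get()`, then the `asInstanceOf[DataReaderIterator[UnsafeRow]]` on the columnar branch is purely a compile-time erase, which matches the TODO above:

    ```scala
    import org.apache.spark.sql.sources.v2.reader.DataReader

    // Sketch only: reconstructed from the inline iterator removed above;
    // the real DataReaderIterator in this PR may also handle task-completion
    // cleanup or other concerns not shown here.
    class DataReaderIterator[T](reader: DataReader[T]) extends Iterator[T] {
      private[this] var valuePrepared = false

      override def hasNext: Boolean = {
        if (!valuePrepared) {
          // advance the underlying reader only when a new value is needed
          valuePrepared = reader.next()
        }
        valuePrepared
      }

      override def next(): T = {
        if (!hasNext) {
          throw new java.util.NoSuchElementException("End of stream")
        }
        valuePrepared = false
        reader.get()
      }
    }
    ```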

