Repository: spark Updated Branches: refs/heads/master b85eb946a -> 75db14864
[SPARK-22392][SQL] data source v2 columnar batch reader ## What changes were proposed in this pull request? a new Data Source V2 interface to allow the data source to return `ColumnarBatch` during the scan. ## How was this patch tested? new tests Author: Wenchen Fan <[email protected]> Closes #20153 from cloud-fan/columnar-reader. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75db1486 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75db1486 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75db1486 Branch: refs/heads/master Commit: 75db14864d2bd9b8e13154226e94d466e3a7e0a0 Parents: b85eb94 Author: Wenchen Fan <[email protected]> Authored: Tue Jan 16 22:41:30 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Tue Jan 16 22:41:30 2018 +0800 ---------------------------------------------------------------------- .../sources/v2/reader/DataSourceV2Reader.java | 5 +- .../v2/reader/SupportsScanColumnarBatch.java | 52 +++++++++ .../v2/reader/SupportsScanUnsafeRow.java | 2 +- .../spark/sql/execution/ColumnarBatchScan.scala | 37 +++++- .../sql/execution/DataSourceScanExec.scala | 39 ++----- .../columnar/InMemoryTableScanExec.scala | 101 +++++++++-------- .../datasources/v2/DataSourceRDD.scala | 20 ++-- .../datasources/v2/DataSourceV2ScanExec.scala | 72 +++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 112 +++++++++++++++++++ .../sql/execution/WholeStageCodegenSuite.scala | 28 ++--- .../sql/sources/v2/DataSourceV2Suite.scala | 72 +++++++++++- 12 files changed, 400 insertions(+), 144 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 95ee4a8..f23c384 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -38,7 +38,10 @@ import org.apache.spark.sql.types.StructType; * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should + * implement at most one of the special scans. If more than one special scan is implemented, + * only one of them is respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 0000000..27cf3a7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceV2Reader { + @Override + default List<ReadTask<Row>> createReadTasks() { + throw new IllegalStateException( + "createReadTasks not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + */ + List<ReadTask<ColumnarBatch>> createBatchReadTasks(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushed filters, etc. It's possible that the implementation + * can only support certain columns with certain types. Users can override this method and + * {@link #createReadTasks()} to fall back to the normal read path under some conditions.
+ */ + default boolean enableBatchRead() { + return true; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec88..2d3ad0e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override default List<ReadTask<Row>> createReadTasks() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createReadTasks not supported by default within SupportsScanUnsafeRow"); } /** http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5617046..dd68df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType @@ -25,13 
+25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") @@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. 
+ val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c..7c7d79c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -164,13 +164,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -346,33 +348,6 @@ case class FileSourceScanExec( override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if 
(supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 933b975..3565ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -49,9 +49,9 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. 
- * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow built from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields relation.schema.fields.forall(f => f.dataType match { @@ -61,6 +61,8 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -90,14 +92,56 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the mapPartitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns.
+ val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -185,7 +229,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -230,43 +274,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. 
- val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + inputRDD } } } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5..ac104d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.v2.reader.ReadTask -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) extends 
Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readTasks: java.util.List[ReadTask[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { readTasks.asScala.zipWithIndex.map { @@ -38,10 +38,10 @@ class DataSourceRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() } } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506b..8c64df0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,10 +24,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + @transient reader: DataSourceV2Reader) + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def references: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, 
reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + private lazy val inputRDD: RDD[InternalRow] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + } - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + override protected def needsUnsafeRowConversion: Boolean = false - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this).execute() 
+ } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd..b3f1a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,7 +52,7 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) @@ -132,7 +132,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() } } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000..44e5146 --- /dev/null +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List<ReadTask<ColumnarBatch>> createBatchReadTasks() { + return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); 
+ } + } + + static class JavaBatchReadTask implements ReadTask<ColumnarBatch>, DataReader<ColumnarBatch> { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader<ColumnarBatch> createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca..22ca128 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -121,31 +121,23 
@@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - ) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e49..a89f7c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], 
classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { @@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class BatchReadTask(start: Int, end: Int) + extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch( + new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = 
batch.close() +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
