Repository: spark
Updated Branches:
  refs/heads/master 9cfc3ee62 -> 2b2c94a3e
http://git-wip-us.apache.org/repos/asf/spark/blob/2b2c94a3/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
new file mode 100644
index 0000000..cb5954d
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.sources.v2.SupportsBatchRead;
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead {
+
+  @Override
+  public StructType schema() {
+    return new StructType().add("i", "int").add("j", "int");
+  }
+
+  @Override
+  public String name() {
+    return this.getClass().toString();
+  }
+}
+
+abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch {
+
+  @Override
+  public Scan build() {
+    return this;
+  }
+
+  @Override
+  public Batch toBatch() {
+    return this;
+  }
+
+  @Override
+  public StructType readSchema() {
+    return new StructType().add("i", "int").add("j", "int");
+  }
+
+  @Override
+  public PartitionReaderFactory createReaderFactory() {
+    return new JavaSimpleReaderFactory();
+  }
+}
+
+class JavaSimpleReaderFactory implements PartitionReaderFactory {
+
+  @Override
+  public PartitionReader<InternalRow> createReader(InputPartition partition) {
+    JavaRangeInputPartition p = (JavaRangeInputPartition) partition;
+    return new PartitionReader<InternalRow>() {
+      private int current = p.start - 1;
+
+      @Override
+      public boolean next() throws IOException {
+        current += 1;
+        return current < p.end;
+      }
+
+      @Override
+      public InternalRow get() {
+        return new GenericInternalRow(new Object[] {current, -current});
+      }
+
+      @Override
+      public void close() throws IOException {
+
+      }
+    };
+  }
+}
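A note for readers skimming the new Java file above: PartitionReader is a cursor-style contract, i.e. next() advances, get() returns the current row, and close() releases resources. A minimal sketch of the loop a caller runs per input partition, using the Scala test helpers introduced later in this patch (SimpleReaderFactory, RangeInputPartition); Spark's scan execution drives readers the same way:

    // Drive one partition of the test source by hand.
    val reader = SimpleReaderFactory.createReader(RangeInputPartition(0, 5))
    try {
      while (reader.next()) {
        val row = reader.get()                  // InternalRow holding (i, -i)
        println(s"i=${row.getInt(0)} j=${row.getInt(1)}")
      }
    } finally {
      reader.close()
    }
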
http://git-wip-us.apache.org/repos/asf/spark/blob/2b2c94a3/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 2cdbba8..852c454 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -17,17 +17,17 @@
 
 package test.org.apache.spark.sql.sources.v2;
 
-import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.TableProvider;
 import org.apache.spark.sql.sources.v2.reader.*;
 
-public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider {
+public class JavaSimpleDataSourceV2 implements TableProvider {
 
-  class ReadSupport extends JavaSimpleReadSupport {
+  class MyScanBuilder extends JavaSimpleScanBuilder {
     @Override
-    public InputPartition[] planInputPartitions(ScanConfig config) {
+    public InputPartition[] planInputPartitions() {
       InputPartition[] partitions = new InputPartition[2];
       partitions[0] = new JavaRangeInputPartition(0, 5);
       partitions[1] = new JavaRangeInputPartition(5, 10);
@@ -36,7 +36,12 @@ public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider {
   }
 
   @Override
-  public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
-    return new ReadSupport();
+  public Table getTable(DataSourceOptions options) {
+    return new JavaSimpleBatchTable() {
+      @Override
+      public ScanBuilder newScanBuilder(DataSourceOptions options) {
+        return new MyScanBuilder();
+      }
+    };
   }
 }
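The entry point moves from createBatchReadSupport(options) to getTable(options), but registration and loading through the public API are unchanged. A minimal sketch of how the suites load this source (the only assumption is an active SparkSession named spark):

    // Load the Java test source through the ordinary reader API.
    val df = spark.read
      .format("test.org.apache.spark.sql.sources.v2.JavaSimpleDataSourceV2")
      .load()
    df.show()   // ten rows of (i, -i), planned as two partitions: [0, 5) and [5, 10)
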
http://git-wip-us.apache.org/repos/asf/spark/blob/2b2c94a3/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
deleted file mode 100644
index ced51dd..0000000
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package test.org.apache.spark.sql.sources.v2;
-
-import java.io.IOException;
-
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.sources.v2.reader.*;
-import org.apache.spark.sql.types.StructType;
-
-abstract class JavaSimpleReadSupport implements BatchReadSupport {
-
-  @Override
-  public StructType fullSchema() {
-    return new StructType().add("i", "int").add("j", "int");
-  }
-
-  @Override
-  public ScanConfigBuilder newScanConfigBuilder() {
-    return new JavaNoopScanConfigBuilder(fullSchema());
-  }
-
-  @Override
-  public PartitionReaderFactory createReaderFactory(ScanConfig config) {
-    return new JavaSimpleReaderFactory();
-  }
-}
-
-class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig {
-
-  private StructType schema;
-
-  JavaNoopScanConfigBuilder(StructType schema) {
-    this.schema = schema;
-  }
-
-  @Override
-  public ScanConfig build() {
-    return this;
-  }
-
-  @Override
-  public StructType readSchema() {
-    return schema;
-  }
-}
-
-class JavaSimpleReaderFactory implements PartitionReaderFactory {
-
-  @Override
-  public PartitionReader<InternalRow> createReader(InputPartition partition) {
-    JavaRangeInputPartition p = (JavaRangeInputPartition) partition;
-    return new PartitionReader<InternalRow>() {
-      private int current = p.start - 1;
-
-      @Override
-      public boolean next() throws IOException {
-        current += 1;
-        return current < p.end;
-      }
-
-      @Override
-      public InternalRow get() {
-        return new GenericInternalRow(new Object[] {current, -current});
-      }
-
-      @Override
-      public void close() throws IOException {
-
-      }
-    };
-  }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/2b2c94a3/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index e8f291a..d282193 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -38,18 +38,17 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
-  private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = {
+  private def getBatch(query: DataFrame): AdvancedBatch = {
     query.queryExecution.executedPlan.collect {
       case d: DataSourceV2ScanExec =>
-        d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder]
+        d.batch.asInstanceOf[AdvancedBatch]
     }.head
   }
 
-  private def getJavaScanConfig(
-      query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = {
+  private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = {
     query.queryExecution.executedPlan.collect {
       case d: DataSourceV2ScanExec =>
-        d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder]
+        d.batch.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedBatch]
     }.head
   }
 
@@ -73,51 +72,51 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
       val q1 = df.select('j)
       checkAnswer(q1, (0 until 10).map(i => Row(-i)))
       if (cls == classOf[AdvancedDataSourceV2]) {
-        val config = getScanConfig(q1)
-        assert(config.filters.isEmpty)
-        assert(config.requiredSchema.fieldNames === Seq("j"))
+        val batch = getBatch(q1)
+        assert(batch.filters.isEmpty)
+        assert(batch.requiredSchema.fieldNames === Seq("j"))
       } else {
-        val config = getJavaScanConfig(q1)
-        assert(config.filters.isEmpty)
-        assert(config.requiredSchema.fieldNames === Seq("j"))
+        val batch = getJavaBatch(q1)
+        assert(batch.filters.isEmpty)
+        assert(batch.requiredSchema.fieldNames === Seq("j"))
       }
 
       val q2 = df.filter('i > 3)
       checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
       if (cls == classOf[AdvancedDataSourceV2]) {
-        val config = getScanConfig(q2)
-        assert(config.filters.flatMap(_.references).toSet == Set("i"))
-        assert(config.requiredSchema.fieldNames === Seq("i", "j"))
+        val batch = getBatch(q2)
+        assert(batch.filters.flatMap(_.references).toSet == Set("i"))
+        assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
       } else {
-        val config = getJavaScanConfig(q2)
-        assert(config.filters.flatMap(_.references).toSet == Set("i"))
-        assert(config.requiredSchema.fieldNames === Seq("i", "j"))
+        val batch = getJavaBatch(q2)
+        assert(batch.filters.flatMap(_.references).toSet == Set("i"))
+        assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
       }
 
       val q3 = df.select('i).filter('i > 6)
       checkAnswer(q3, (7 until 10).map(i => Row(i)))
       if (cls == classOf[AdvancedDataSourceV2]) {
-        val config = getScanConfig(q3)
-        assert(config.filters.flatMap(_.references).toSet == Set("i"))
-        assert(config.requiredSchema.fieldNames === Seq("i"))
+        val batch = getBatch(q3)
+        assert(batch.filters.flatMap(_.references).toSet == Set("i"))
+        assert(batch.requiredSchema.fieldNames === Seq("i"))
      } else {
-        val config = getJavaScanConfig(q3)
-        assert(config.filters.flatMap(_.references).toSet == Set("i"))
-        assert(config.requiredSchema.fieldNames === Seq("i"))
+        val batch = getJavaBatch(q3)
+        assert(batch.filters.flatMap(_.references).toSet == Set("i"))
+        assert(batch.requiredSchema.fieldNames === Seq("i"))
       }
 
       val q4 = df.select('j).filter('j < -10)
       checkAnswer(q4, Nil)
       if (cls == classOf[AdvancedDataSourceV2]) {
-        val config = getScanConfig(q4)
+        val batch = getBatch(q4)
         // 'j < 10 is not supported by the testing data source.
-        assert(config.filters.isEmpty)
-        assert(config.requiredSchema.fieldNames === Seq("j"))
+        assert(batch.filters.isEmpty)
+        assert(batch.requiredSchema.fieldNames === Seq("j"))
       } else {
-        val config = getJavaScanConfig(q4)
+        val batch = getJavaBatch(q4)
         // 'j < 10 is not supported by the testing data source.
-        assert(config.filters.isEmpty)
-        assert(config.requiredSchema.fieldNames === Seq("j"))
+        assert(batch.filters.isEmpty)
+        assert(batch.requiredSchema.fieldNames === Seq("j"))
       }
     }
   }
@@ -279,26 +278,26 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     val q1 = df.select('i + 1)
     checkAnswer(q1, (1 until 11).map(i => Row(i)))
-    val config1 = getScanConfig(q1)
-    assert(config1.requiredSchema.fieldNames === Seq("i"))
+    val batch1 = getBatch(q1)
+    assert(batch1.requiredSchema.fieldNames === Seq("i"))
 
     val q2 = df.select(lit(1))
     checkAnswer(q2, (0 until 10).map(i => Row(1)))
-    val config2 = getScanConfig(q2)
-    assert(config2.requiredSchema.isEmpty)
+    val batch2 = getBatch(q2)
+    assert(batch2.requiredSchema.isEmpty)
 
     // 'j === 1 can't be pushed down, but we should still be able do column pruning
     val q3 = df.filter('j === -1).select('j * 2)
     checkAnswer(q3, Row(-2))
-    val config3 = getScanConfig(q3)
-    assert(config3.filters.isEmpty)
-    assert(config3.requiredSchema.fieldNames === Seq("j"))
+    val batch3 = getBatch(q3)
+    assert(batch3.filters.isEmpty)
+    assert(batch3.requiredSchema.fieldNames === Seq("j"))
 
     // column pruning should work with other operators.
     val q4 = df.sort('i).limit(1).select('i + 1)
     checkAnswer(q4, Row(1))
-    val config4 = getScanConfig(q4)
-    assert(config4.requiredSchema.fieldNames === Seq("i"))
+    val batch4 = getBatch(q4)
+    assert(batch4.requiredSchema.fieldNames === Seq("i"))
   }
 
   test("SPARK-23315: get output from canonicalized data source v2 related plans") {
@@ -374,10 +373,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
 
-case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig {
-  override def build(): ScanConfig = this
-}
-
 object SimpleReaderFactory extends PartitionReaderFactory {
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val RangeInputPartition(start, end) = partition
@@ -396,87 +391,68 @@ object SimpleReaderFactory extends PartitionReaderFactory {
   }
 }
 
-abstract class SimpleReadSupport extends BatchReadSupport {
-  override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+abstract class SimpleBatchTable extends Table with SupportsBatchRead {
 
-  override def newScanConfigBuilder(): ScanConfigBuilder = {
-    NoopScanConfigBuilder(fullSchema())
-  }
+  override def schema(): StructType = new StructType().add("i", "int").add("j", "int")
 
-  override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
-    SimpleReaderFactory
-  }
+  override def name(): String = this.getClass.toString
 }
 
+abstract class SimpleScanBuilder extends ScanBuilder
+  with Batch with Scan {
+
+  override def build(): Scan = this
+
+  override def toBatch: Batch = this
 
-class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider {
+  override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
 
-  class ReadSupport extends SimpleReadSupport {
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+  override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
+}
+
+class SimpleSinglePartitionSource extends TableProvider {
+
+  class MyScanBuilder extends SimpleScanBuilder {
+    override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5))
     }
   }
 
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport
+  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+      new MyScanBuilder()
+    }
   }
 }
 
 // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
 // tests still pass.
-class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class SimpleDataSourceV2 extends TableProvider {
 
-  class ReadSupport extends SimpleReadSupport {
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+  class MyScanBuilder extends SimpleScanBuilder {
+    override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
   }
 
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport
+  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+      new MyScanBuilder()
+    }
   }
 }
 
-class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
-
-  class ReadSupport extends SimpleReadSupport {
-    override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder()
+class AdvancedDataSourceV2 extends TableProvider {
 
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-      val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters
-
-      val lowerBound = filters.collectFirst {
-        case GreaterThan("i", v: Int) => v
-      }
-
-      val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
-
-      if (lowerBound.isEmpty) {
-        res.append(RangeInputPartition(0, 5))
-        res.append(RangeInputPartition(5, 10))
-      } else if (lowerBound.get < 4) {
-        res.append(RangeInputPartition(lowerBound.get + 1, 5))
-        res.append(RangeInputPartition(5, 10))
-      } else if (lowerBound.get < 9) {
-        res.append(RangeInputPartition(lowerBound.get + 1, 10))
-      }
-
-      res.toArray
-    }
-
-    override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
-      val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema
-      new AdvancedReaderFactory(requiredSchema)
+  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+      new AdvancedScanBuilder()
     }
   }
-
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport
-  }
 }
 
-class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig
-  with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
+class AdvancedScanBuilder extends ScanBuilder
+  with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
 
   var requiredSchema = new StructType().add("i", "int").add("j", "int")
   var filters = Array.empty[Filter]
@@ -498,10 +474,40 @@ class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig
 
   override def pushedFilters(): Array[Filter] = filters
 
-  override def build(): ScanConfig = this
+  override def build(): Scan = this
+
+  override def toBatch: Batch = new AdvancedBatch(filters, requiredSchema)
+}
+
+class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) extends Batch {
+
+  override def planInputPartitions(): Array[InputPartition] = {
+    val lowerBound = filters.collectFirst {
+      case GreaterThan("i", v: Int) => v
+    }
+
+    val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+
+    if (lowerBound.isEmpty) {
+      res.append(RangeInputPartition(0, 5))
+      res.append(RangeInputPartition(5, 10))
+    } else if (lowerBound.get < 4) {
+      res.append(RangeInputPartition(lowerBound.get + 1, 5))
+      res.append(RangeInputPartition(5, 10))
+    } else if (lowerBound.get < 9) {
+      res.append(RangeInputPartition(lowerBound.get + 1, 10))
+    }
+
+    res.toArray
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    new AdvancedReaderFactory(requiredSchema)
+  }
 }
 
 class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
+
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val RangeInputPartition(start, end) = partition
     new PartitionReader[InternalRow] {
@@ -526,39 +532,47 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
   }
 }
 
-class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider {
+class SchemaRequiredDataSource extends TableProvider {
 
-  class ReadSupport(val schema: StructType) extends SimpleReadSupport {
-    override def fullSchema(): StructType = schema
+  class MyScanBuilder(schema: StructType) extends SimpleScanBuilder {
+    override def planInputPartitions(): Array[InputPartition] = Array.empty
 
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] =
-      Array.empty
+    override def readSchema(): StructType = schema
   }
 
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
+  override def getTable(options: DataSourceOptions): Table = {
     throw new IllegalArgumentException("requires a user-supplied schema")
   }
 
-  override def createBatchReadSupport(
-      schema: StructType, options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport(schema)
+  override def getTable(options: DataSourceOptions, schema: StructType): Table = {
+    val userGivenSchema = schema
+    new SimpleBatchTable {
+      override def schema(): StructType = userGivenSchema
+
+      override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+        new MyScanBuilder(userGivenSchema)
+      }
+    }
   }
 }
 
-class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class ColumnarDataSourceV2 extends TableProvider {
 
-  class ReadSupport extends SimpleReadSupport {
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+  class MyScanBuilder extends SimpleScanBuilder {
+
+    override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90))
     }
 
-    override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+    override def createReaderFactory(): PartitionReaderFactory = {
      ColumnarReaderFactory
    }
   }
 
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport
+  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+      new MyScanBuilder()
+    }
   }
 }
@@ -608,21 +622,29 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
   }
 }
 
-class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider {
+class PartitionAwareDataSource extends TableProvider {
+
+  class MyScanBuilder extends SimpleScanBuilder
+    with SupportsReportPartitioning{
 
-  class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning {
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+    override def planInputPartitions(): Array[InputPartition] = {
       // Note that we don't have same value of column `a` across partitions.
       Array(
        SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)),
        SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2)))
     }
 
-    override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+    override def createReaderFactory(): PartitionReaderFactory = {
       SpecificReaderFactory
     }
 
-    override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning
+    override def outputPartitioning(): Partitioning = new MyPartitioning
+  }
+
+  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
+    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+      new MyScanBuilder()
+    }
   }
 
   class MyPartitioning extends Partitioning {
@@ -633,10 +655,6 @@ class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider {
       case _ => false
     }
   }
-
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    new ReadSupport
-  }
 }
 
 case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition
@@ -662,7 +680,7 @@ object SpecificReaderFactory extends PartitionReaderFactory {
 class SchemaReadAttemptException(m: String) extends RuntimeException(m)
 
 class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
-  override def fullSchema(): StructType = {
+  override def writeSchema(): StructType = {
     // This is a bit hacky since this source implements read support but throws
     // during schema retrieval. Might have to rewrite but it's done
     // such so for minimised changes.
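To summarize the lifecycle these tests now assert: Spark asks the TableProvider for a Table, the Table for a ScanBuilder, applies operator pushdown to the builder, and then build() produces a Scan whose toBatch exposes partition planning and the reader factory. A rough sketch of that call sequence using the test classes above (DataSourceOptions.empty() is assumed here for brevity; Spark performs these steps internally):

    import org.apache.spark.sql.sources.GreaterThan
    import org.apache.spark.sql.sources.v2.DataSourceOptions
    import org.apache.spark.sql.types.StructType

    val table = new AdvancedDataSourceV2().getTable(DataSourceOptions.empty())
    val builder = table.newScanBuilder(DataSourceOptions.empty())
      .asInstanceOf[AdvancedScanBuilder]

    // Pushdown happens on the builder, before the scan is built.
    builder.pushFilters(Array(GreaterThan("i", 3)))        // returns the unsupported rest
    builder.pruneColumns(new StructType().add("i", "int"))

    // build() freezes the pushed state into a Scan; toBatch exposes the batch path.
    val batch = builder.build().toBatch
    val partitions = batch.planInputPartitions()           // here: [4, 5) and [5, 10)
    val readerFactory = batch.createReaderFactory()
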
http://git-wip-us.apache.org/repos/asf/spark/blob/2b2c94a3/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index a7dfc2d..82bb4fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -39,19 +39,16 @@ import org.apache.spark.util.SerializableConfiguration
  * Each job moves files from `target/_temporary/queryId/` to `target`.
  */
 class SimpleWritableDataSource extends DataSourceV2
-  with BatchReadSupportProvider
+  with TableProvider
   with BatchWriteSupportProvider
   with SessionConfigSupport {
 
-  protected def fullSchema(): StructType = new StructType().add("i", "long").add("j", "long")
+  protected def writeSchema(): StructType = new StructType().add("i", "long").add("j", "long")
 
   override def keyPrefix: String = "simpleWritableDataSource"
 
-  class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport {
-
-    override def fullSchema(): StructType = SimpleWritableDataSource.this.fullSchema()
-
-    override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+  class MyScanBuilder(path: String, conf: Configuration) extends SimpleScanBuilder {
+    override def planInputPartitions(): Array[InputPartition] = {
       val dataPath = new Path(path)
       val fs = dataPath.getFileSystem(conf)
       if (fs.exists(dataPath)) {
@@ -66,10 +63,24 @@ class SimpleWritableDataSource extends DataSourceV2
       }
     }
 
-    override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
+    override def createReaderFactory(): PartitionReaderFactory = {
       val serializableConf = new SerializableConfiguration(conf)
       new CSVReaderFactory(serializableConf)
     }
+
+    override def readSchema(): StructType = writeSchema
+  }
+
+  override def getTable(options: DataSourceOptions): Table = {
+    val path = new Path(options.get("path").get())
+    val conf = SparkContext.getActive.get.hadoopConfiguration
+    new SimpleBatchTable {
+      override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+        new MyScanBuilder(path.toUri.toString, conf)
+      }
+
+      override def schema(): StructType = writeSchema
+    }
+  }
 
   class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport {
@@ -105,12 +116,6 @@ class SimpleWritableDataSource extends DataSourceV2
     }
   }
 
-  override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
-    val path = new Path(options.get("path").get())
-    val conf = SparkContext.getActive.get.hadoopConfiguration
-    new ReadSupport(path.toUri.toString, conf)
-  }
-
   override def createBatchWriteSupport(
       queryId: String,
       schema: StructType,
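For orientation, this writable test source is still driven through the public writer and reader APIs; getTable above pulls the required "path" option. A rough sketch of the round trip (the output directory is hypothetical, and the select mirrors the source's (i, j) long schema):

    // Write ten (i, -i) pairs, then read them back through the same source.
    val cls = classOf[SimpleWritableDataSource].getName
    spark.range(10).select('id as 'i, -'id as 'j)
      .write.format(cls)
      .option("path", "/tmp/simple-sink")   // hypothetical output directory
      .save()
    val df = spark.read.format(cls).option("path", "/tmp/simple-sink").load()
    df.show()
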
is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + case DataSourceV2StreamingScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org