This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 01d61a5fc963 [SPARK-46253][PYTHON] Plan Python data source read using MapInArrow
01d61a5fc963 is described below
commit 01d61a5fc963013fcf55bbfb384e06d1c5ec7e3d
Author: allisonwang-db <[email protected]>
AuthorDate: Tue Dec 12 14:05:29 2023 -0800
[SPARK-46253][PYTHON] Plan Python data source read using MapInArrow
### What changes were proposed in this pull request?
This PR changes how Python data source reads are planned. Instead of using a
regular Python UDTF, we use an Arrow UDF and plan the data source read
using the MapInArrow operator.
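For reference, `MapInArrow` evaluates a Python function over iterators of Arrow record batches, the same contract as `DataFrame.mapInArrow`. The sketch below only illustrates that contract and is not the actual worker code; the function body and column names are made up, though the `read_from_data_source` name and the single binary `partition` column mirror what this PR generates. Each input batch carries one serialized partition value, and the function yields Arrow batches with the data source's output rows.

```python
from typing import Iterator

import pyarrow as pa


def read_from_data_source(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    for batch in batches:
        # Each input batch is expected to hold exactly one binary partition value.
        assert batch.num_columns == 1 and batch.num_rows == 1
        # ... deserialize the partition, call DataSourceReader.read(partition),
        # and convert the returned rows into Arrow batches ...
        yield pa.RecordBatch.from_arrays([pa.array([0]), pa.array([1])], names=["i", "j"])


# Push one dummy "partition" batch through the function.
dummy = pa.RecordBatch.from_arrays([pa.array([b"partition-0"])], names=["partition"])
for out in read_from_data_source(iter([dummy])):
    print(out.to_pydict())  # {'i': [0], 'j': [1]}
```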
### Why are the changes needed?
To improve the performance of Python data source reads.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests in `test_python_datasource.py` and `PythonDataSourceSuite`.
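The added tests register a small Python data source and read it back. A condensed sketch of that pattern (adapted from `test_python_datasource.py` in this diff; runnable only on a build that includes this change):

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader


class TestDataSourceReader(DataSourceReader):
    def read(self, partition):
        # Rows may be returned as tuples, lists, or Rows matching the schema.
        yield (0, 1)


class TestDataSource(DataSource):
    @classmethod
    def name(cls):
        return "test"

    def schema(self):
        return "i int, j int"

    def reader(self, schema) -> DataSourceReader:
        return TestDataSourceReader()


spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(TestDataSource)
df = spark.read.format("test").load()
assert df.collect() == [Row(i=0, j=1)]
```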
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44170 from allisonwang-db/spark-46253-arrow-read.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/error_classes.py | 10 ++
python/pyspark/sql/tests/test_python_datasource.py | 105 +++++++++++++++-
python/pyspark/sql/worker/plan_data_source_read.py | 132 +++++++++++++++++++--
.../plans/logical/pythonLogicalOperators.scala | 6 +-
.../datasources/PlanPythonDataSourceScan.scala | 44 ++++---
.../python/UserDefinedPythonDataSource.scala | 21 +++-
.../execution/python/PythonDataSourceSuite.scala | 23 ++--
7 files changed, 287 insertions(+), 54 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index d2d7f3148f4c..ffe5d692001c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -752,6 +752,16 @@ ERROR_CLASSES_JSON = """
"Unable to create the Python data source <type> because the '<method>'
method hasn't been implemented."
]
},
+ "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE" : {
+ "message" : [
+      "The data type of the returned value ('<type>') from the Python data source '<name>' is not supported. Supported types: <supported_types>."
+ ]
+ },
+ "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH" : {
+ "message" : [
+      "The number of columns in the result does not match the required schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the 'read' method have the same number of columns as required by the output schema."
+ ]
+ },
"PYTHON_DATA_SOURCE_TYPE_MISMATCH" : {
"message" : [
"Expected <expected>, but got <actual>."
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 8c7074c72a64..74ef6a874589 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -16,9 +16,11 @@
#
import os
import unittest
+from typing import Callable, Union
+from pyspark.errors import PythonException
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
-from pyspark.sql.types import Row
+from pyspark.sql.types import Row, StructType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import SPARK_HOME
@@ -78,6 +80,107 @@ class BasePythonDataSourceTestsMixin:
df = self.spark.read.format("TestDataSource").load()
assertDataFrameEqual(df, [Row(c=0, d=1)])
+ def register_data_source(
+ self,
+ read_func: Callable,
+ partition_func: Callable = None,
+ output: Union[str, StructType] = "i int, j int",
+ name: str = "test",
+ ):
+ class TestDataSourceReader(DataSourceReader):
+ def __init__(self, schema):
+ self.schema = schema
+
+ def partitions(self):
+ if partition_func is not None:
+ return partition_func()
+ else:
+ raise NotImplementedError
+
+ def read(self, partition):
+ return read_func(self.schema, partition)
+
+ class TestDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return name
+
+ def schema(self):
+ return output
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader(schema)
+
+ self.spark.dataSource.register(TestDataSource)
+
+ def test_data_source_read_output_tuple(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [Row(0, 1)])
+
+ def test_data_source_read_output_list(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([[0, 1]]))
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [Row(0, 1)])
+
+ def test_data_source_read_output_row(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([Row(0, 1)]))
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [Row(0, 1)])
+
+ def test_data_source_read_output_none(self):
+ self.register_data_source(read_func=lambda schema, partition: None)
+ df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+ assertDataFrameEqual(df, [])
+
+ def test_data_source_read_output_empty_iter(self):
+ self.register_data_source(read_func=lambda schema, partition: iter([]))
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [])
+
+ def test_data_source_read_cast_output_schema(self):
+ self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i long, j string"
+ )
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [Row(i=0, j="1")])
+
+ def test_data_source_read_output_with_partition(self):
+ def partition_func():
+ return [InputPartition(0), InputPartition(1)]
+
+ def read_func(schema, partition):
+ if partition.value == 0:
+ return iter([])
+ elif partition.value == 1:
+ yield (0, 1)
+
+        self.register_data_source(read_func=read_func, partition_func=partition_func)
+ df = self.spark.read.format("test").load()
+ assertDataFrameEqual(df, [Row(0, 1)])
+
+ def test_data_source_read_output_with_schema_mismatch(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
+ df = self.spark.read.format("test").schema("i int").load()
+ with self.assertRaisesRegex(
+ PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+ ):
+ df.collect()
+ self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i int, j int, k int"
+ )
+ with self.assertRaisesRegex(
+ PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+ ):
+ df.collect()
+
+ def test_read_with_invalid_return_row_type(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([1]))
+ df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+ df.collect()
+
def test_in_memory_data_source(self):
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index c6abef4509d5..d2fcb5096ae2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -17,7 +17,9 @@
import os
import sys
-from typing import Any, IO, Iterator
+import functools
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -26,10 +28,15 @@ from pyspark.serializers import (
read_int,
write_int,
SpecialLengths,
- CloudPickleSerializer,
)
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
from pyspark.sql.datasource import DataSource, InputPartition
-from pyspark.sql.types import _parse_datatype_json_string, StructType
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+ _parse_datatype_json_string,
+ BinaryType,
+ StructType,
+)
from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
check_python_version,
@@ -84,6 +91,22 @@ def main(infile: IO, outfile: IO) -> None:
},
)
+ # Receive the output schema from its child plan.
+ input_schema_json = utf8_deserializer.loads(infile)
+ input_schema = _parse_datatype_json_string(input_schema_json)
+ if not isinstance(input_schema, StructType):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "an input schema of type 'StructType'",
+ "actual": f"'{type(input_schema).__name__}'",
+ },
+ )
+    assert len(input_schema) == 1 and isinstance(input_schema[0].dataType, BinaryType), (
+        "The input schema of Python data source read should contain only one column of type "
+ f"'BinaryType', but got '{input_schema}'"
+ )
+
# Receive the data source output schema.
schema_json = utf8_deserializer.loads(infile)
schema = _parse_datatype_json_string(schema_json)
@@ -91,11 +114,18 @@ def main(infile: IO, outfile: IO) -> None:
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
-                "expected": "a Python data source schema of type 'StructType'",
+ "expected": "an output schema of type 'StructType'",
"actual": f"'{type(schema).__name__}'",
},
)
+ # Receive the configuration values.
+ max_arrow_batch_size = read_int(infile)
+ assert max_arrow_batch_size > 0, (
+ "The maximum arrow batch size should be greater than 0, but got "
+ f"'{max_arrow_batch_size}'"
+ )
+
# Instantiate data source reader.
try:
reader = data_source.reader(schema=schema)
@@ -146,16 +176,94 @@ def main(infile: IO, outfile: IO) -> None:
message_parameters={"type": "reader", "error": str(e)},
)
- # Construct a UDTF.
- class PythonDataSourceReaderUDTF:
- def __init__(self) -> None:
- self.ser = CloudPickleSerializer()
+    # Wrap the data source read logic in a mapInArrow UDF.
+ import pyarrow as pa
+
+ # Create input converter.
+ converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+ # Create output converter.
+ return_type = schema
+ pa_schema = to_arrow_schema(return_type)
+ column_names = return_type.fieldNames()
+ column_converters = [
+ LocalDataToArrowConversion._create_converter(field.dataType)
+ for field in return_type.fields
+ ]
+
+    def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+ partition_bytes = None
+
+ # Get the partition value from the input iterator.
+ for batch in iterator:
+ # There should be only one row/column in the batch.
+ assert batch.num_columns == 1 and batch.num_rows == 1, (
+ "Expected each batch to have exactly 1 column and 1 row, "
+                f"but found {batch.num_columns} columns and {batch.num_rows} rows."
+ )
+ columns = [column.to_pylist() for column in batch.columns]
+ partition_bytes = converter(columns[0][0])
+
+ assert (
+ partition_bytes is not None
+        ), "The input iterator for Python data source read function is empty."
+
+ # Deserialize the partition value.
+ partition = pickleSer.loads(partition_bytes)
+
+        assert partition is None or isinstance(partition, InputPartition), (
+ "Expected the partition value to be of type 'InputPartition', "
+ f"but found '{type(partition).__name__}'."
+ )
+
+ output_iter = reader.read(partition) # type: ignore[arg-type]
+
+ # Validate the output iterator.
+ if not isinstance(output_iter, Iterator):
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+ message_parameters={
+ "type": type(output_iter).__name__,
+ "name": data_source.name(),
+ "supported_types": "iterator",
+ },
+ )
+
+ def batched(iterator: Iterator, n: int) -> Iterator:
+            return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
+
+        # Convert the results from the `reader.read` method to an iterator of arrow batches.
+ num_cols = len(column_names)
+ for batch in batched(output_iter, max_arrow_batch_size):
+ pylist: List[List] = [[] for _ in range(num_cols)]
+ for result in batch:
+ # Validate the output row schema.
+ if hasattr(result, "__len__") and len(result) != num_cols:
+ raise PySparkRuntimeError(
+                        error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+ message_parameters={
+ "expected": str(num_cols),
+ "actual": str(len(result)),
+ },
+ )
+
+ # Validate the output row type.
+ if not isinstance(result, (list, tuple)):
+ raise PySparkRuntimeError(
+                        error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+ message_parameters={
+ "type": type(result).__name__,
+ "name": data_source.name(),
+                            "supported_types": "tuple, list, `pyspark.sql.types.Row`",
+ },
+ )
+
+ for col in range(num_cols):
+ pylist[col].append(column_converters[col](result[col]))
- def eval(self, partition_bytes: Any) -> Iterator:
- partition = self.ser.loads(partition_bytes)
- yield from reader.read(partition)
+ yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
- command = PythonDataSourceReaderUDTF
+ command = (data_source_read_func, return_type)
pickleSer._write_with_length(command, outfile)
# Return the serialized partition values.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index d4ed673c3513..fb8b06eb41bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -134,9 +134,9 @@ case class PythonDataSourcePartitions(
}
object PythonDataSourcePartitions {
- def getOutputAttrs: Seq[Attribute] = {
- toAttributes(new StructType().add("partition", BinaryType))
- }
+ def schema: StructType = new StructType().add("partition", BinaryType)
+
+ def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
index ec4c7c188fa0..7ffd61a4a266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.datasources
 import org.apache.spark.api.python.{PythonEvalType, PythonFunction, SimplePythonFunction}
-import org.apache.spark.sql.catalyst.expressions.PythonUDTF
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project, PythonDataSource, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE
 import org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner
@@ -40,18 +40,21 @@ import org.apache.spark.util.ArrayImplicits._
* class. Post this rule, the plan is transformed into:
*
* Project [output]
- * +- Generate [python_data_source_read_udtf, ...]
+ * +- PythonMapInArrow [read_from_data_source, ...]
* +- PythonDataSourcePartitions [partition_bytes]
*
 * The PythonDataSourcePartitions contains a list of serialized partition values for the data
- * source. The `DataSourceReader.read` method will be planned as a UDTF that accepts a partition
- * value and yields the scanning output.
+ * source. The `DataSourceReader.read` method will be planned as a MapInArrow operator that
+ * accepts a partition value and yields the scanning output.
*/
object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsPattern(PYTHON_DATA_SOURCE)) {
case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
-      val info = new UserDefinedPythonDataSourceReadRunner(dataSource, schema).runInPython()
+ val inputSchema = PythonDataSourcePartitions.schema
+
+      val info = new UserDefinedPythonDataSourceReadRunner(
+        dataSource, inputSchema, schema).runInPython()
val readerFunc = SimplePythonFunction(
command = info.func.toImmutableArraySeq,
@@ -65,27 +68,22 @@ object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
val partitionPlan = PythonDataSourcePartitions(
PythonDataSourcePartitions.getOutputAttrs, info.partitions)
- // Construct a Python UDTF for the reader function.
- val pythonUDTF = PythonUDTF(
- name = "python_data_source_read",
+ val pythonUDF = PythonUDF(
+ name = "read_from_data_source",
func = readerFunc,
- elementSchema = schema,
+ dataType = schema,
children = partitionPlan.output,
- evalType = PythonEvalType.SQL_TABLE_UDF,
- udfDeterministic = false,
- pickledAnalyzeResult = None)
+ evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+ udfDeterministic = false)
- // Later the rule `ExtractPythonUDTFs` will turn this Generate
- // into a evaluable Python UDTF node.
- val generate = Generate(
- generator = pythonUDTF,
- unrequiredChildIndex = Nil,
- outer = false,
- qualifier = None,
- generatorOutput = ds.output,
- child = partitionPlan)
+ // Construct the plan.
+ val plan = PythonMapInArrow(
+ pythonUDF,
+ ds.output,
+ partitionPlan,
+ isBarrier = false)
// Project out partition values.
- Project(ds.output, generate)
+ Project(ds.output, plan)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7044ef65c638..2c8e1b942727 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSourc
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -147,9 +148,17 @@ case class PythonDataSourceReadInfo(
func: Array[Byte],
partitions: Seq[Array[Byte]])
+/**
+ * Send information to a Python process to plan a Python data source read.
+ *
+ * @param func a Python data source instance
+ * @param inputSchema input schema to the data source read from its child plan
+ * @param outputSchema output schema of the Python data source
+ */
class UserDefinedPythonDataSourceReadRunner(
func: PythonFunction,
-    schema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+ inputSchema: StructType,
+    outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
// See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
override val workerModule = "pyspark.sql.worker.plan_data_source_read"
@@ -158,8 +167,14 @@ class UserDefinedPythonDataSourceReadRunner(
// Send Python data source
PythonWorkerUtils.writePythonFunction(func, dataOut)
- // Send schema
- PythonWorkerUtils.writeUTF(schema.json, dataOut)
+ // Send input schema
+ PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+ // Send output schema
+ PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+ // Send configurations
+ dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
}
   override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index ec2f8c19b02b..6bc9166117f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions, PythonMapInArrow}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -40,7 +40,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
|""".stripMargin
test("simple data source") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource
@@ -58,15 +58,14 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
assert(df.rdd.getNumPartitions == 2)
val plan = df.queryExecution.optimizedPlan
plan match {
- case BatchEvalPythonUDTF(pythonUDTF, _, _, _: PythonDataSourcePartitions)
- if pythonUDTF.name == "python_data_source_read" =>
+ case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) =>
      case _ => fail(s"Plan did not match the expected pattern. Actual plan:\n$plan")
}
    checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
}
test("simple data source with string schema") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -85,7 +84,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("simple data source with StructType schema") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -108,7 +107,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("data source with invalid schema") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -129,7 +128,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("register data source") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -177,7 +176,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("load data source") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
@@ -222,7 +221,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("reader not implemented") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -240,7 +239,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("error creating reader") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource
@@ -260,7 +259,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("data source assertion error") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|class $dataSourceName: