This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 01d61a5fc963 [SPARK-46253][PYTHON] Plan Python data source read using MapInArrow
01d61a5fc963 is described below

commit 01d61a5fc963013fcf55bbfb384e06d1c5ec7e3d
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Tue Dec 12 14:05:29 2023 -0800

    [SPARK-46253][PYTHON] Plan Python data source read using MapInArrow

    ### What changes were proposed in this pull request?

    This PR changes how we plan Python data source reads. Instead of using a regular Python UDTF, we can use an Arrow UDF and plan the data source read using the MapInArrow operator.

    ### Why are the changes needed?

    To improve performance.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Unit tests

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #44170 from allisonwang-db/spark-46253-arrow-read.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py             |  10 ++
 python/pyspark/sql/tests/test_python_datasource.py | 105 +++++++++++++++-
 python/pyspark/sql/worker/plan_data_source_read.py | 132 +++++++++++++++++++--
 .../plans/logical/pythonLogicalOperators.scala     |   6 +-
 .../datasources/PlanPythonDataSourceScan.scala     |  44 ++++---
 .../python/UserDefinedPythonDataSource.scala       |  21 +++-
 .../execution/python/PythonDataSourceSuite.scala   |  23 ++--
 7 files changed, 287 insertions(+), 54 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index d2d7f3148f4c..ffe5d692001c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -752,6 +752,16 @@ ERROR_CLASSES_JSON = """
       "Unable to create the Python data source <type> because the '<method>' method hasn't been implemented."
     ]
   },
+  "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE" : {
+    "message" : [
+      "The data type of the returned value ('<type>') from the Python data source '<name>' is not supported. Supported types: <supported_types>."
+    ]
+  },
+  "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH" : {
+    "message" : [
+      "The number of columns in the result does not match the required schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the 'read' method have the same number of columns as required by the output schema."
+    ]
+  },
   "PYTHON_DATA_SOURCE_TYPE_MISMATCH" : {
     "message" : [
       "Expected <expected>, but got <actual>."
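For context on the operator this patch switches to: `DataFrame.mapInArrow` applies a Python function to an iterator of `pyarrow.RecordBatch` objects per partition and expects it to yield record batches that match a declared output schema; the planner change below reuses the same eval type (`SQL_MAP_ARROW_ITER_UDF`) for the data source read function. The following is a minimal, standalone sketch of that contract, not part of this patch; the demo DataFrame and function names are illustrative only.

    import pyarrow as pa
    import pyarrow.compute as pc
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(10)

    def double_values(batches):
        # Each partition arrives as an iterator of Arrow record batches.
        for batch in batches:
            doubled = pc.multiply(batch.column(0), 2)
            # Yielded batches must match the schema declared in mapInArrow().
            yield pa.RecordBatch.from_arrays([doubled], names=["id"])

    df.mapInArrow(double_values, schema="id long").show()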
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 8c7074c72a64..74ef6a874589 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -16,9 +16,11 @@
 #
 import os
 import unittest
+from typing import Callable, Union
 
+from pyspark.errors import PythonException
 from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
-from pyspark.sql.types import Row
+from pyspark.sql.types import Row, StructType
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 from pyspark.testing.utils import SPARK_HOME
@@ -78,6 +80,107 @@ class BasePythonDataSourceTestsMixin:
         df = self.spark.read.format("TestDataSource").load()
         assertDataFrameEqual(df, [Row(c=0, d=1)])
 
+    def register_data_source(
+        self,
+        read_func: Callable,
+        partition_func: Callable = None,
+        output: Union[str, StructType] = "i int, j int",
+        name: str = "test",
+    ):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self, schema):
+                self.schema = schema
+
+            def partitions(self):
+                if partition_func is not None:
+                    return partition_func()
+                else:
+                    raise NotImplementedError
+
+            def read(self, partition):
+                return read_func(self.schema, partition)
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return name
+
+            def schema(self):
+                return output
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader(schema)
+
+        self.spark.dataSource.register(TestDataSource)
+
+    def test_data_source_read_output_tuple(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_list(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([[0, 1]]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_row(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([Row(0, 1)]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_none(self):
+        self.register_data_source(read_func=lambda schema, partition: None)
+        df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+            assertDataFrameEqual(df, [])
+
+    def test_data_source_read_output_empty_iter(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [])
+
+    def test_data_source_read_cast_output_schema(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i long, j string"
+        )
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(i=0, j="1")])
+
+    def test_data_source_read_output_with_partition(self):
+        def partition_func():
+            return [InputPartition(0), InputPartition(1)]
+
+        def read_func(schema, partition):
+            if partition.value == 0:
+                return iter([])
+            elif partition.value == 1:
+                yield (0, 1)
+
+        self.register_data_source(read_func=read_func, partition_func=partition_func)
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_with_schema_mismatch(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([(0, 1)]))
+        df = self.spark.read.format("test").schema("i int").load()
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            df.collect()
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i int, j int, k int"
+        )
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            df.collect()
+
+    def test_read_with_invalid_return_row_type(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([1]))
+        df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+            df.collect()
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index c6abef4509d5..d2fcb5096ae2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -17,7 +17,9 @@
 import os
 import sys
-from typing import Any, IO, Iterator
+import functools
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -26,10 +28,15 @@ from pyspark.serializers import (
     read_int,
     write_int,
     SpecialLengths,
-    CloudPickleSerializer,
 )
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
-from pyspark.sql.types import _parse_datatype_json_string, StructType
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
 from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
     check_python_version,
@@ -84,6 +91,22 @@ def main(infile: IO, outfile: IO) -> None:
             },
         )
 
+    # Receive the output schema from its child plan.
+    input_schema_json = utf8_deserializer.loads(infile)
+    input_schema = _parse_datatype_json_string(input_schema_json)
+    if not isinstance(input_schema, StructType):
+        raise PySparkAssertionError(
+            error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+            message_parameters={
+                "expected": "an input schema of type 'StructType'",
+                "actual": f"'{type(input_schema).__name__}'",
+            },
+        )
+    assert len(input_schema) == 1 and isinstance(input_schema[0].dataType, BinaryType), (
+        "The input schema of Python data source read should contain only one column of type "
+        f"'BinaryType', but got '{input_schema}'"
+    )
+
     # Receive the data source output schema.
     schema_json = utf8_deserializer.loads(infile)
     schema = _parse_datatype_json_string(schema_json)
@@ -91,11 +114,18 @@ def main(infile: IO, outfile: IO) -> None:
         raise PySparkAssertionError(
             error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
             message_parameters={
-                "expected": "a Python data source schema of type 'StructType'",
+                "expected": "an output schema of type 'StructType'",
                 "actual": f"'{type(schema).__name__}'",
            },
         )
 
+    # Receive the configuration values.
+    max_arrow_batch_size = read_int(infile)
+    assert max_arrow_batch_size > 0, (
+        "The maximum arrow batch size should be greater than 0, but got "
+        f"'{max_arrow_batch_size}'"
+    )
+
     # Instantiate data source reader.
     try:
         reader = data_source.reader(schema=schema)
@@ -146,16 +176,94 @@ def main(infile: IO, outfile: IO) -> None:
             message_parameters={"type": "reader", "error": str(e)},
         )
 
-    # Construct a UDTF.
-    class PythonDataSourceReaderUDTF:
-        def __init__(self) -> None:
-            self.ser = CloudPickleSerializer()
+    # Wrap the data source read logic in an mapInArrow UDF.
+    import pyarrow as pa
+
+    # Create input converter.
+    converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+    # Create output converter.
+    return_type = schema
+    pa_schema = to_arrow_schema(return_type)
+    column_names = return_type.fieldNames()
+    column_converters = [
+        LocalDataToArrowConversion._create_converter(field.dataType)
+        for field in return_type.fields
+    ]
+
+    def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+        partition_bytes = None
+
+        # Get the partition value from the input iterator.
+        for batch in iterator:
+            # There should be only one row/column in the batch.
+            assert batch.num_columns == 1 and batch.num_rows == 1, (
+                "Expected each batch to have exactly 1 column and 1 row, "
+                f"but found {batch.num_columns} columns and {batch.num_rows} rows."
+            )
+            columns = [column.to_pylist() for column in batch.columns]
+            partition_bytes = converter(columns[0][0])
+
+        assert (
+            partition_bytes is not None
+        ), "The input iterator for Python data source read function is empty."
+
+        # Deserialize the partition value.
+        partition = pickleSer.loads(partition_bytes)
+
+        assert partition is None or isinstance(partition, InputPartition), (
+            "Expected the partition value to be of type 'InputPartition', "
+            f"but found '{type(partition).__name__}'."
+        )
+
+        output_iter = reader.read(partition)  # type: ignore[arg-type]
+
+        # Validate the output iterator.
+        if not isinstance(output_iter, Iterator):
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+                message_parameters={
+                    "type": type(output_iter).__name__,
+                    "name": data_source.name(),
+                    "supported_types": "iterator",
+                },
+            )
+
+        def batched(iterator: Iterator, n: int) -> Iterator:
+            return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
+
+        # Convert the results from the `reader.read` method to an iterator of arrow batches.
+        num_cols = len(column_names)
+        for batch in batched(output_iter, max_arrow_batch_size):
+            pylist: List[List] = [[] for _ in range(num_cols)]
+            for result in batch:
+                # Validate the output row schema.
+                if hasattr(result, "__len__") and len(result) != num_cols:
+                    raise PySparkRuntimeError(
+                        error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                        message_parameters={
+                            "expected": str(num_cols),
+                            "actual": str(len(result)),
+                        },
+                    )
+
+                # Validate the output row type.
+                if not isinstance(result, (list, tuple)):
+                    raise PySparkRuntimeError(
+                        error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+                        message_parameters={
+                            "type": type(result).__name__,
+                            "name": data_source.name(),
+                            "supported_types": "tuple, list, `pyspark.sql.types.Row`",
+                        },
+                    )
+
+                for col in range(num_cols):
+                    pylist[col].append(column_converters[col](result[col]))
 
-        def eval(self, partition_bytes: Any) -> Iterator:
-            partition = self.ser.loads(partition_bytes)
-            yield from reader.read(partition)
+            yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
 
-    command = PythonDataSourceReaderUDTF
+    command = (data_source_read_func, return_type)
     pickleSer._write_with_length(command, outfile)
 
     # Return the serialized partition values.
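The `batched` helper in the worker change above leans on the two-argument form of `iter(callable, sentinel)`: the partial keeps pulling up to `max_arrow_batch_size` rows with `islice`, and the empty list returned once the source is exhausted acts as the sentinel that stops iteration. Below is a self-contained sketch of the same idiom, not part of the patch; the sample data is illustrative only.

    import functools
    from itertools import islice
    from typing import Iterator, List

    def batched(iterator: Iterator, n: int) -> Iterator[List]:
        # iter(callable, sentinel) calls the partial until it returns [],
        # so the source iterator is consumed in chunks of at most n items.
        return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])

    rows = iter([(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)])
    print(list(batched(rows, 2)))
    # [[(0, 1), (2, 3)], [(4, 5), (6, 7)], [(8, 9)]]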
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index d4ed673c3513..fb8b06eb41bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -134,9 +134,9 @@ case class PythonDataSourcePartitions(
 }
 
 object PythonDataSourcePartitions {
-  def getOutputAttrs: Seq[Attribute] = {
-    toAttributes(new StructType().add("partition", BinaryType))
-  }
+  def schema: StructType = new StructType().add("partition", BinaryType)
+
+  def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
index ec4c7c188fa0..7ffd61a4a266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.api.python.{PythonEvalType, PythonFunction, SimplePythonFunction}
-import org.apache.spark.sql.catalyst.expressions.PythonUDTF
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project, PythonDataSource, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE
 import org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner
@@ -40,18 +40,21 @@ import org.apache.spark.util.ArrayImplicits._
 * class. Post this rule, the plan is transformed into:
 *
 * Project [output]
- * +- Generate [python_data_source_read_udtf, ...]
+ * +- PythonMapInArrow [read_from_data_source, ...]
+ *    +- PythonDataSourcePartitions [partition_bytes]
 *
 * The PythonDataSourcePartitions contains a list of serialized partition values for the data
- * source. The `DataSourceReader.read` method will be planned as a UDTF that accepts a partition
- * value and yields the scanning output.
+ * source. The `DataSourceReader.read` method will be planned as a MapInArrow operator that
+ * accepts a partition value and yields the scanning output.
 */
 object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
     _.containsPattern(PYTHON_DATA_SOURCE)) {
     case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
-      val info = new UserDefinedPythonDataSourceReadRunner(dataSource, schema).runInPython()
+      val inputSchema = PythonDataSourcePartitions.schema
+
+      val info = new UserDefinedPythonDataSourceReadRunner(
+        dataSource, inputSchema, schema).runInPython()
 
       val readerFunc = SimplePythonFunction(
         command = info.func.toImmutableArraySeq,
@@ -65,27 +68,22 @@ object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
       val partitionPlan = PythonDataSourcePartitions(
         PythonDataSourcePartitions.getOutputAttrs, info.partitions)
 
-      // Construct a Python UDTF for the reader function.
-      val pythonUDTF = PythonUDTF(
-        name = "python_data_source_read",
+      val pythonUDF = PythonUDF(
+        name = "read_from_data_source",
         func = readerFunc,
-        elementSchema = schema,
+        dataType = schema,
         children = partitionPlan.output,
-        evalType = PythonEvalType.SQL_TABLE_UDF,
-        udfDeterministic = false,
-        pickledAnalyzeResult = None)
+        evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+        udfDeterministic = false)
 
-      // Later the rule `ExtractPythonUDTFs` will turn this Generate
-      // into a evaluable Python UDTF node.
-      val generate = Generate(
-        generator = pythonUDTF,
-        unrequiredChildIndex = Nil,
-        outer = false,
-        qualifier = None,
-        generatorOutput = ds.output,
-        child = partitionPlan)
+      // Construct the plan.
+      val plan = PythonMapInArrow(
+        pythonUDF,
+        ds.output,
+        partitionPlan,
+        isBarrier = false)
 
       // Project out partition values.
-      Project(ds.output, generate)
+      Project(ds.output, plan)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7044ef65c638..2c8e1b942727 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSourc
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
@@ -147,9 +148,17 @@ case class PythonDataSourceReadInfo(
     func: Array[Byte],
     partitions: Seq[Array[Byte]])
 
+/**
+ * Send information to a Python process to plan a Python data source read.
+ *
+ * @param func an Python data source instance
+ * @param inputSchema input schema to the data source read from its child plan
+ * @param outputSchema output schema of the Python data source
+ */
 class UserDefinedPythonDataSourceReadRunner(
     func: PythonFunction,
-    schema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+    inputSchema: StructType,
+    outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
 
   // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
   override val workerModule = "pyspark.sql.worker.plan_data_source_read"
 
@@ -158,8 +167,14 @@ class UserDefinedPythonDataSourceReadRunner(
     // Send Python data source
     PythonWorkerUtils.writePythonFunction(func, dataOut)
 
-    // Send schema
-    PythonWorkerUtils.writeUTF(schema.json, dataOut)
+    // Send input schema
+    PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
   }
 
   override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index ec2f8c19b02b..6bc9166117f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions, PythonMapInArrow}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
@@ -40,7 +40,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       |""".stripMargin
 
   test("simple data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource
@@ -58,15 +58,14 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     assert(df.rdd.getNumPartitions == 2)
     val plan = df.queryExecution.optimizedPlan
     plan match {
-      case BatchEvalPythonUDTF(pythonUDTF, _, _, _: PythonDataSourcePartitions)
-        if pythonUDTF.name == "python_data_source_read" =>
+      case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) =>
       case _ => fail(s"Plan did not match the expected pattern. Actual plan:\n$plan")
     }
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
   }
 
   test("simple data source with string schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -85,7 +84,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("simple data source with StructType schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -108,7 +107,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source with invalid schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -129,7 +128,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("register data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -177,7 +176,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("load data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
@@ -222,7 +221,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("reader not implemented") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -240,7 +239,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("error creating reader") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |from pyspark.sql.datasource import DataSource
@@ -260,7 +259,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source assertion error") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
        |class $dataSourceName:
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org