This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01d61a5fc963 [SPARK-46253][PYTHON] Plan Python data source read using 
MapInArrow
01d61a5fc963 is described below

commit 01d61a5fc963013fcf55bbfb384e06d1c5ec7e3d
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Tue Dec 12 14:05:29 2023 -0800

    [SPARK-46253][PYTHON] Plan Python data source read using MapInArrow
    
    ### What changes were proposed in this pull request?
    
    This PR changes how Python data source reads are planned. Instead of wrapping
    `DataSourceReader.read` in a regular Python UDTF, the read is now wrapped in an
    Arrow UDF (`SQL_MAP_ARROW_ITER_UDF`) and planned with the MapInArrow operator.
    
    ### Why are the changes needed?
    
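    To improve performance: the rows returned by `DataSourceReader.read` are now
    converted into Arrow record batches (at most
    `spark.sql.execution.arrow.maxRecordsPerBatch` rows per batch) instead of being
    returned through a regular, non-Arrow Python UDTF.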
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New and updated unit tests in `test_python_datasource.py` and
    `PythonDataSourceSuite.scala`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44170 from allisonwang-db/spark-46253-arrow-read.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py             |  10 ++
 python/pyspark/sql/tests/test_python_datasource.py | 105 +++++++++++++++-
 python/pyspark/sql/worker/plan_data_source_read.py | 132 +++++++++++++++++++--
 .../plans/logical/pythonLogicalOperators.scala     |   6 +-
 .../datasources/PlanPythonDataSourceScan.scala     |  44 ++++---
 .../python/UserDefinedPythonDataSource.scala       |  21 +++-
 .../execution/python/PythonDataSourceSuite.scala   |  23 ++--
 7 files changed, 287 insertions(+), 54 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index d2d7f3148f4c..ffe5d692001c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -752,6 +752,16 @@ ERROR_CLASSES_JSON = """
         "Unable to create the Python data source <type> because the '<method>' 
method hasn't been implemented."
     ]
   },
+  "PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE" : {
+    "message" : [
+        "The data type of the returned value ('<type>') from the Python data 
source '<name>' is not supported. Supported types: <supported_types>."
+    ]
+  },
+  "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH" : {
+    "message" : [
+      "The number of columns in the result does not match the required schema. 
Expected column count: <expected>, Actual column count: <actual>. Please make 
sure the values returned by the 'read' method have the same number of columns 
as required by the output schema."
+    ]
+  },
   "PYTHON_DATA_SOURCE_TYPE_MISMATCH" : {
     "message" : [
       "Expected <expected>, but got <actual>."
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index 8c7074c72a64..74ef6a874589 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -16,9 +16,11 @@
 #
 import os
 import unittest
+from typing import Callable, Union
 
+from pyspark.errors import PythonException
 from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
-from pyspark.sql.types import Row
+from pyspark.sql.types import Row, StructType
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 from pyspark.testing.utils import SPARK_HOME
@@ -78,6 +80,107 @@ class BasePythonDataSourceTestsMixin:
         df = self.spark.read.format("TestDataSource").load()
         assertDataFrameEqual(df, [Row(c=0, d=1)])
 
+    def register_data_source(
+        self,
+        read_func: Callable,
+        partition_func: Callable = None,
+        output: Union[str, StructType] = "i int, j int",
+        name: str = "test",
+    ):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self, schema):
+                self.schema = schema
+
+            def partitions(self):
+                if partition_func is not None:
+                    return partition_func()
+                else:
+                    raise NotImplementedError
+
+            def read(self, partition):
+                return read_func(self.schema, partition)
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return name
+
+            def schema(self):
+                return output
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader(schema)
+
+        self.spark.dataSource.register(TestDataSource)
+
+    def test_data_source_read_output_tuple(self):
+        self.register_data_source(read_func=lambda schema, partition: 
iter([(0, 1)]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_list(self):
+        self.register_data_source(read_func=lambda schema, partition: 
iter([[0, 1]]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_row(self):
+        self.register_data_source(read_func=lambda schema, partition: 
iter([Row(0, 1)]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_none(self):
+        self.register_data_source(read_func=lambda schema, partition: None)
+        df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, 
"PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+            assertDataFrameEqual(df, [])
+
+    def test_data_source_read_output_empty_iter(self):
+        self.register_data_source(read_func=lambda schema, partition: iter([]))
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [])
+
+    def test_data_source_read_cast_output_schema(self):
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i 
long, j string"
+        )
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(i=0, j="1")])
+
+    def test_data_source_read_output_with_partition(self):
+        def partition_func():
+            return [InputPartition(0), InputPartition(1)]
+
+        def read_func(schema, partition):
+            if partition.value == 0:
+                return iter([])
+            elif partition.value == 1:
+                yield (0, 1)
+
+        self.register_data_source(read_func=read_func, 
partition_func=partition_func)
+        df = self.spark.read.format("test").load()
+        assertDataFrameEqual(df, [Row(0, 1)])
+
+    def test_data_source_read_output_with_schema_mismatch(self):
+        self.register_data_source(read_func=lambda schema, partition: 
iter([(0, 1)]))
+        df = self.spark.read.format("test").schema("i int").load()
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            df.collect()
+        self.register_data_source(
+            read_func=lambda schema, partition: iter([(0, 1)]), output="i int, 
j int, k int"
+        )
+        with self.assertRaisesRegex(
+            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+        ):
+            df.collect()
+
+    def test_read_with_invalid_return_row_type(self):
+        self.register_data_source(read_func=lambda schema, partition: 
iter([1]))
+        df = self.spark.read.format("test").load()
+        with self.assertRaisesRegex(PythonException, 
"PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE"):
+            df.collect()
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index c6abef4509d5..d2fcb5096ae2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -17,7 +17,9 @@
 
 import os
 import sys
-from typing import Any, IO, Iterator
+import functools
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -26,10 +28,15 @@ from pyspark.serializers import (
     read_int,
     write_int,
     SpecialLengths,
-    CloudPickleSerializer,
 )
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
 from pyspark.sql.datasource import DataSource, InputPartition
-from pyspark.sql.types import _parse_datatype_json_string, StructType
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
 from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
     check_python_version,
@@ -84,6 +91,22 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )
 
+        # Receive the output schema from its child plan.
+        input_schema_json = utf8_deserializer.loads(infile)
+        input_schema = _parse_datatype_json_string(input_schema_json)
+        if not isinstance(input_schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an input schema of type 'StructType'",
+                    "actual": f"'{type(input_schema).__name__}'",
+                },
+            )
+        assert len(input_schema) == 1 and isinstance(input_schema[0].dataType, 
BinaryType), (
+            "The input schema of Python data source read should contain only 
one column of type "
+            f"'BinaryType', but got '{input_schema}'"
+        )
+
         # Receive the data source output schema.
         schema_json = utf8_deserializer.loads(infile)
         schema = _parse_datatype_json_string(schema_json)
@@ -91,11 +114,18 @@ def main(infile: IO, outfile: IO) -> None:
             raise PySparkAssertionError(
                 error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
-                    "expected": "a Python data source schema of type 
'StructType'",
+                    "expected": "an output schema of type 'StructType'",
                     "actual": f"'{type(schema).__name__}'",
                 },
             )
 
+        # Receive the configuration values.
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
         # Instantiate data source reader.
         try:
             reader = data_source.reader(schema=schema)
@@ -146,16 +176,94 @@ def main(infile: IO, outfile: IO) -> None:
                 message_parameters={"type": "reader", "error": str(e)},
             )
 
-        # Construct a UDTF.
-        class PythonDataSourceReaderUDTF:
-            def __init__(self) -> None:
-                self.ser = CloudPickleSerializer()
+        # Wrap the data source read logic in an mapInArrow UDF.
+        import pyarrow as pa
+
+        # Create input converter.
+        converter = ArrowTableToRowsConversion._create_converter(BinaryType())
+
+        # Create output converter.
+        return_type = schema
+        pa_schema = to_arrow_schema(return_type)
+        column_names = return_type.fieldNames()
+        column_converters = [
+            LocalDataToArrowConversion._create_converter(field.dataType)
+            for field in return_type.fields
+        ]
+
+        def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> 
Iterable[pa.RecordBatch]:
+            partition_bytes = None
+
+            # Get the partition value from the input iterator.
+            for batch in iterator:
+                # There should be only one row/column in the batch.
+                assert batch.num_columns == 1 and batch.num_rows == 1, (
+                    "Expected each batch to have exactly 1 column and 1 row, "
+                    f"but found {batch.num_columns} columns and 
{batch.num_rows} rows."
+                )
+                columns = [column.to_pylist() for column in batch.columns]
+                partition_bytes = converter(columns[0][0])
+
+            assert (
+                partition_bytes is not None
+            ), "The input iterator for Python data source read function is 
empty."
+
+            # Deserialize the partition value.
+            partition = pickleSer.loads(partition_bytes)
+
+            assert partition is None or isinstance(partition, InputPartition), 
(
+                "Expected the partition value to be of type 'InputPartition', "
+                f"but found '{type(partition).__name__}'."
+            )
+
+            output_iter = reader.read(partition)  # type: ignore[arg-type]
+
+            # Validate the output iterator.
+            if not isinstance(output_iter, Iterator):
+                raise PySparkRuntimeError(
+                    error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+                    message_parameters={
+                        "type": type(output_iter).__name__,
+                        "name": data_source.name(),
+                        "supported_types": "iterator",
+                    },
+                )
+
+            def batched(iterator: Iterator, n: int) -> Iterator:
+                return iter(functools.partial(lambda it: list(islice(it, n)), 
iterator), [])
+
+            # Convert the results from the `reader.read` method to an iterator 
of arrow batches.
+            num_cols = len(column_names)
+            for batch in batched(output_iter, max_arrow_batch_size):
+                pylist: List[List] = [[] for _ in range(num_cols)]
+                for result in batch:
+                    # Validate the output row schema.
+                    if hasattr(result, "__len__") and len(result) != num_cols:
+                        raise PySparkRuntimeError(
+                            
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                            message_parameters={
+                                "expected": str(num_cols),
+                                "actual": str(len(result)),
+                            },
+                        )
+
+                    # Validate the output row type.
+                    if not isinstance(result, (list, tuple)):
+                        raise PySparkRuntimeError(
+                            
error_class="PYTHON_DATA_SOURCE_READ_INVALID_RETURN_TYPE",
+                            message_parameters={
+                                "type": type(result).__name__,
+                                "name": data_source.name(),
+                                "supported_types": "tuple, list, 
`pyspark.sql.types.Row`",
+                            },
+                        )
+
+                    for col in range(num_cols):
+                        pylist[col].append(column_converters[col](result[col]))
 
-            def eval(self, partition_bytes: Any) -> Iterator:
-                partition = self.ser.loads(partition_bytes)
-                yield from reader.read(partition)
+                yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
 
-        command = PythonDataSourceReaderUDTF
+        command = (data_source_read_func, return_type)
         pickleSer._write_with_length(command, outfile)
 
         # Return the serialized partition values.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index d4ed673c3513..fb8b06eb41bc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -134,9 +134,9 @@ case class PythonDataSourcePartitions(
 }
 
 object PythonDataSourcePartitions {
-  def getOutputAttrs: Seq[Attribute] = {
-    toAttributes(new StructType().add("partition", BinaryType))
-  }
+  def schema: StructType = new StructType().add("partition", BinaryType)
+
+  def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
 }
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
index ec4c7c188fa0..7ffd61a4a266 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.api.python.{PythonEvalType, PythonFunction, 
SimplePythonFunction}
-import org.apache.spark.sql.catalyst.expressions.PythonUDTF
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, 
Project, PythonDataSource, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, 
PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE
 import 
org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner
@@ -40,18 +40,21 @@ import org.apache.spark.util.ArrayImplicits._
  * class. Post this rule, the plan is transformed into:
  *
  *  Project [output]
- *  +- Generate [python_data_source_read_udtf, ...]
+ *  +- PythonMapInArrow [read_from_data_source, ...]
  *     +- PythonDataSourcePartitions [partition_bytes]
  *
  * The PythonDataSourcePartitions contains a list of serialized partition 
values for the data
- * source. The `DataSourceReader.read` method will be planned as a UDTF that 
accepts a partition
- * value and yields the scanning output.
+ * source. The `DataSourceReader.read` method will be planned as a MapInArrow 
operator that
+ * accepts a partition value and yields the scanning output.
  */
 object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
     _.containsPattern(PYTHON_DATA_SOURCE)) {
     case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
-      val info = new UserDefinedPythonDataSourceReadRunner(dataSource, 
schema).runInPython()
+      val inputSchema = PythonDataSourcePartitions.schema
+
+      val info = new UserDefinedPythonDataSourceReadRunner(
+        dataSource, inputSchema, schema).runInPython()
 
       val readerFunc = SimplePythonFunction(
         command = info.func.toImmutableArraySeq,
@@ -65,27 +68,22 @@ object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
       val partitionPlan = PythonDataSourcePartitions(
         PythonDataSourcePartitions.getOutputAttrs, info.partitions)
 
-      // Construct a Python UDTF for the reader function.
-      val pythonUDTF = PythonUDTF(
-        name = "python_data_source_read",
+      val pythonUDF = PythonUDF(
+        name = "read_from_data_source",
         func = readerFunc,
-        elementSchema = schema,
+        dataType = schema,
         children = partitionPlan.output,
-        evalType = PythonEvalType.SQL_TABLE_UDF,
-        udfDeterministic = false,
-        pickledAnalyzeResult = None)
+        evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+        udfDeterministic = false)
 
-      // Later the rule `ExtractPythonUDTFs` will turn this Generate
-      // into a evaluable Python UDTF node.
-      val generate = Generate(
-        generator = pythonUDTF,
-        unrequiredChildIndex = Nil,
-        outer = false,
-        qualifier = None,
-        generatorOutput = ds.output,
-        child = partitionPlan)
+      // Construct the plan.
+      val plan = PythonMapInArrow(
+        pythonUDF,
+        ds.output,
+        partitionPlan,
+        isBarrier = false)
 
       // Project out partition values.
-      Project(ds.output, generate)
+      Project(ds.output, plan)
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7044ef65c638..2c8e1b942727 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSourc
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
@@ -147,9 +148,17 @@ case class PythonDataSourceReadInfo(
     func: Array[Byte],
     partitions: Seq[Array[Byte]])
 
+/**
+ * Send information to a Python process to plan a Python data source read.
+ *
+ * @param func an Python data source instance
+ * @param inputSchema input schema to the data source read from its child plan
+ * @param outputSchema output schema of the Python data source
+ */
 class UserDefinedPythonDataSourceReadRunner(
     func: PythonFunction,
-    schema: StructType) extends 
PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+    inputSchema: StructType,
+    outputSchema: StructType) extends 
PythonPlannerRunner[PythonDataSourceReadInfo](func) {
 
   // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
   override val workerModule = "pyspark.sql.worker.plan_data_source_read"
@@ -158,8 +167,14 @@ class UserDefinedPythonDataSourceReadRunner(
     // Send Python data source
     PythonWorkerUtils.writePythonFunction(func, dataOut)
 
-    // Send schema
-    PythonWorkerUtils.writeUTF(schema.json, dataOut)
+    // Send input schema
+    PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
   }
 
   override protected def receiveFromPython(dataIn: DataInputStream): 
PythonDataSourceReadInfo = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index ec2f8c19b02b..6bc9166117f2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, 
QueryTest, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, 
PythonDataSourcePartitions}
+import 
org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions, 
PythonMapInArrow}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
@@ -40,7 +40,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
       |""".stripMargin
 
   test("simple data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource
@@ -58,15 +58,14 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
     assert(df.rdd.getNumPartitions == 2)
     val plan = df.queryExecution.optimizedPlan
     plan match {
-      case BatchEvalPythonUDTF(pythonUDTF, _, _, _: PythonDataSourcePartitions)
-        if pythonUDTF.name == "python_data_source_read" =>
+      case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) =>
       case _ => fail(s"Plan did not match the expected pattern. Actual 
plan:\n$plan")
     }
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), 
Row(2, 1)))
   }
 
   test("simple data source with string schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -85,7 +84,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("simple data source with StructType schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -108,7 +107,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("data source with invalid schema") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -129,7 +128,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("register data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -177,7 +176,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("load data source") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader, 
InputPartition
@@ -222,7 +221,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("reader not implemented") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
        s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -240,7 +239,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("error creating reader") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource
@@ -260,7 +259,7 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
   }
 
   test("data source assertion error") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |class $dataSourceName:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to