This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 71792411083 [SPARK-40027][PYTHON][SS][DOCS] Add self-contained 
examples for pyspark.sql.streaming.readwriter
71792411083 is described below

commit 71792411083a71bcfd7a0d94ddf754bf09a27054
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Aug 11 20:19:24 2022 +0900

    [SPARK-40027][PYTHON][SS][DOCS] Add self-contained examples for 
pyspark.sql.streaming.readwriter
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to improve the examples in 
`pyspark.sql.streaming.readwriter` by making each example self-contained with a 
brief explanation and a bit more realistic example.
    
    ### Why are the changes needed?
    
    To make the documentation more readable and able to copy and paste directly 
in PySpark shell.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it changes the documentation
    
    ### How was this patch tested?
    
    Manually ran each doctest.
    
    Closes #37461 from HyukjinKwon/SPARK-40027.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/streaming/readwriter.py | 441 +++++++++++++++++++++++------
 1 file changed, 357 insertions(+), 84 deletions(-)

diff --git a/python/pyspark/sql/streaming/readwriter.py 
b/python/pyspark/sql/streaming/readwriter.py
index 74b89dbe46c..ef3b7e525e3 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -24,7 +24,7 @@ from py4j.java_gateway import java_import, JavaObject
 from pyspark.sql.column import _to_seq
 from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.streaming.query import StreamingQuery
-from pyspark.sql.types import Row, StructType, StructField, StringType
+from pyspark.sql.types import Row, StructType
 from pyspark.sql.utils import ForeachBatchFunction
 
 if TYPE_CHECKING:
@@ -46,6 +46,22 @@ class DataStreamReader(OptionUtils):
     Notes
     -----
     This API is evolving.
+
+    Examples
+    --------
+    >>> spark.readStream
+    <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+
+    The example below uses Rate source that generates rows continuously.
+    After that, we apply modulo 3 to each value, and then write the stream out to the console.
+    The streaming query stops in 3 seconds.
+
+    >>> import time
+    >>> df = spark.readStream.format("rate").load()
+    >>> df = df.selectExpr("value % 3 as v")
+    >>> q = df.writeStream.format("console").start()
+    >>> time.sleep(3)
+    >>> q.stop()
     """
 
     def __init__(self, spark: "SparkSession") -> None:
@@ -73,7 +89,23 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> s = spark.readStream.format("text")
+        >>> spark.readStream.format("text")
+        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+
+        This API allows configuring other sources to read. The example below writes a small text
+        file and reads it back via Text source.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary text file to read it.
+        ...     spark.createDataFrame(
+        ...         [("hello",), 
("this",)]).write.mode("overwrite").format("text").save(d)
+        ...
+        ...     # Start a streaming query to read the text file.
+        ...     q = 
spark.readStream.format("text").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._jreader = self._jreader.format(source)
         return self
@@ -99,8 +131,22 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> s = spark.readStream.schema(sdf_schema)
-        >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
+        >>> from pyspark.sql.types import StructField, StructType, StringType
+        >>> spark.readStream.schema(StructType([StructField("data", 
StringType(), True)]))
+        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+        >>> spark.readStream.schema("col0 INT, col1 DOUBLE")
+        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+
+        The example below specifies a schema for reading CSV files.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Start a streaming query to read the CSV file.
+        ...     spark.readStream.schema("col0 INT, col1 
STRING").format("csv").load(d).printSchema()
+        root
+         |-- col0: integer (nullable = true)
+         |-- col1: string (nullable = true)
         """
         from pyspark.sql import SparkSession
 
@@ -125,7 +171,17 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> s = spark.readStream.option("x", 1)
+        >>> spark.readStream.option("x", 1)
+        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+
+        The example below sets the 'rowsPerSecond' option of Rate source to generate
+        10 rows every second.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 
10).load().writeStream.format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         self._jreader = self._jreader.option(key, to_str(value))
         return self
@@ -141,7 +197,18 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> s = spark.readStream.options(x="1", y=2)
+        >>> spark.readStream.options(x="1", y=2)
+        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+
+        The example below sets the 'rowsPerSecond' and 'numPartitions' options of Rate source
+        to generate 10 rows with 10 partitions every second.
+
+        >>> import time
+        >>> q = spark.readStream.format("rate").options(
+        ...    rowsPerSecond=10, numPartitions=10
+        ... ).load().writeStream.format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         for k in options:
             self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -177,13 +244,22 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> json_sdf = spark.readStream.format("json") \\
-        ...     .schema(sdf_schema) \\
-        ...     .load(tempfile.mkdtemp())
-        >>> json_sdf.isStreaming
-        True
-        >>> json_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary JSON file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary JSON file to read it.
+        ...     spark.createDataFrame(
+        ...         [(100, "Hyukjin Kwon"),], ["age", "name"]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Start a streaming query to read the JSON file.
+        ...     q = spark.readStream.schema(
+        ...         "age INT, name STRING"
+        ...     ).format("json").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         if format is not None:
             self.format(format)
@@ -260,11 +336,22 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = 
sdf_schema)
-        >>> json_sdf.isStreaming
-        True
-        >>> json_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary JSON file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary JSON file to read it.
+        ...     spark.createDataFrame(
+        ...         [(100, "Hyukjin Kwon"),], ["age", "name"]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Start a streaming query to read the JSON file.
+        ...     q = spark.readStream.schema(
+        ...         "age INT, name STRING"
+        ...     ).json(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             schema=schema,
@@ -316,11 +403,18 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> orc_sdf = 
spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp())
-        >>> orc_sdf.isStreaming
-        True
-        >>> orc_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary ORC file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary ORC file to read it.
+        ...     spark.range(10).write.mode("overwrite").format("orc").save(d)
+        ...
+        ...     # Start a streaming query to read the ORC file.
+        ...     q = spark.readStream.schema("id 
LONG").orc(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             mergeSchema=mergeSchema,
@@ -362,11 +456,19 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> parquet_sdf = 
spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp())
-        >>> parquet_sdf.isStreaming
-        True
-        >>> parquet_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary Parquet file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary Parquet file to read it.
+        ...     
spark.range(10).write.mode("overwrite").format("parquet").save(d)
+        ...
+        ...     # Start a streaming query to read the Parquet file.
+        ...     q = spark.readStream.schema(
+        ...         "id LONG").parquet(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             mergeSchema=mergeSchema,
@@ -418,11 +520,19 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
-        >>> text_sdf.isStreaming
-        True
-        >>> "value" in str(text_sdf.schema)
-        True
+        Load a data stream from a temporary text file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary text file to read it.
+        ...     spark.createDataFrame(
+        ...         [("hello",), 
("this",)]).write.mode("overwrite").format("text").save(d)
+        ...
+        ...     # Start a streaming query to read the text file.
+        ...     q = 
spark.readStream.text(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             wholetext=wholetext,
@@ -500,11 +610,20 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = 
sdf_schema)
-        >>> csv_sdf.isStreaming
-        True
-        >>> csv_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary CSV file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary CSV file to read it.
+        ...     spark.createDataFrame([(1, 
"2"),]).write.mode("overwrite").format("csv").save(d)
+        ...
+        ...     # Start a streaming query to read the CSV file.
+        ...     q = spark.readStream.schema(
+        ...         "col0 INT, col1 STRING"
+        ...     ).format("csv").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             schema=schema,
@@ -564,7 +683,22 @@ class DataStreamReader(OptionUtils):
 
         Examples
         --------
-        >>> spark.readStream.table('input_table') # doctest: +SKIP
+        Load a data stream from a table.
+
+        >>> import tempfile
+        >>> import time
+        >>> _ = spark.sql("DROP TABLE IF EXISTS my_table")
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Create a table with Rate source.
+        ...     q1 = 
spark.readStream.format("rate").load().writeStream.toTable(
+        ...         "my_table", checkpointLocation=d)
+        ...
+        ...     # Read the table back and print out in the console.
+        ...     q2 = 
spark.readStream.table("my_table").writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q1.stop()
+        ...     q2.stop()
+        ...     _ = spark.sql("DROP TABLE my_table")
         """
         if isinstance(tableName, str):
             return self._df(self._jreader.table(tableName))
@@ -584,6 +718,19 @@ class DataStreamWriter:
     Notes
     -----
     This API is evolving.
+
+    Examples
+    --------
+    The example below uses Rate source that generates rows continuously.
+    After that, we apply modulo 3 to each value, and then write the stream out to the console.
+    The streaming query stops in 3 seconds.
+
+    >>> import time
+    >>> df = spark.readStream.format("rate").load()
+    >>> df = df.selectExpr("value % 3 as v")
+    >>> q = df.writeStream.format("console").start()
+    >>> time.sleep(3)
+    >>> q.stop()
     """
 
     def __init__(self, df: "DataFrame") -> None:
@@ -615,7 +762,18 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> writer = sdf.writeStream.outputMode('append')
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.outputMode('append')
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        The example below uses Complete mode, in which the entire aggregated counts are printed out.
+
+        >>> import time
+        >>> df = spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
+        >>> df = df.groupby().count()
+        >>> q = df.writeStream.outputMode("complete").format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         if not outputMode or type(outputMode) != str or 
len(outputMode.strip()) == 0:
             raise ValueError("The output mode must be a non-empty string. Got: 
%s" % outputMode)
@@ -638,7 +796,25 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> writer = sdf.writeStream.format('json')
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.format("text")
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        This API allows configuring the sink to write to. The example below writes CSV files
+        from Rate source in a streaming manner.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d, 
tempfile.TemporaryDirectory() as cp:
+        ...     df = spark.readStream.format("rate").load()
+        ...     q = df.writeStream.format("csv").option("checkpointLocation", 
cp).start(d)
+        ...     time.sleep(5)
+        ...     q.stop()
+        ...     spark.read.schema("timestamp TIMESTAMP, value 
STRING").csv(d).show()
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         self._jwrite = self._jwrite.format(source)
         return self
@@ -651,6 +827,22 @@ class DataStreamWriter:
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.option("x", 1)
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        The example below sets the 'numRows' option of Console sink in order to print
+        3 rows for every batch.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 10).load().writeStream.format(
+        ...         "console").option("numRows", 3).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         self._jwrite = self._jwrite.option(key, to_str(value))
         return self
@@ -663,6 +855,22 @@ class DataStreamWriter:
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.options(x=1)
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        The example below sets the 'numRows' and 'truncate' options of Console sink in order
+        to print 3 rows for every batch without truncating the results.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 10).load().writeStream.format(
+        ...         "console").options(numRows=3, truncate=False).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         for k in options:
             self._jwrite = self._jwrite.option(k, to_str(options[k]))
@@ -692,6 +900,28 @@ class DataStreamWriter:
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.partitionBy("value")
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        Partition the output by the timestamp column from Rate source.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d, 
tempfile.TemporaryDirectory() as cp:
+        ...     df = spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
+        ...     q = df.writeStream.partitionBy(
+        ...         
"timestamp").format("parquet").option("checkpointLocation", cp).start(d)
+        ...     time.sleep(5)
+        ...     q.stop()
+        ...     spark.read.schema(df.schema).parquet(d).show()
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
@@ -716,7 +946,12 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> writer = sdf.writeStream.queryName('streaming_query')
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
+        >>> q = 
df.writeStream.queryName("streaming_query").format("console").start()
+        >>> q.stop()
+        >>> q.name
+        'streaming_query'
         """
         if not queryName or type(queryName) != str or len(queryName.strip()) 
== 0:
             raise ValueError("The queryName must be a non-empty string. Got: 
%s" % queryName)
@@ -775,14 +1010,22 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> # trigger the query for execution every 5 seconds
-        >>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
-        >>> # trigger the query for just once batch of data
-        >>> writer = sdf.writeStream.trigger(once=True)
-        >>> # trigger the query for execution every 5 seconds
-        >>> writer = sdf.writeStream.trigger(continuous='5 seconds')
-        >>> # trigger the query for reading all available data with multiple 
batches
-        >>> writer = sdf.writeStream.trigger(availableNow=True)
+        >>> df = spark.readStream.format("rate").load()
+
+        Trigger the query for execution every 5 seconds
+
+        >>> df.writeStream.trigger(processingTime='5 seconds')
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        Trigger the query for continuous processing with a 5-second checkpoint interval
+
+        >>> df.writeStream.trigger(continuous='5 seconds')
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
+
+        Trigger the query for reading all available data with multiple batches
+
+        >>> df.writeStream.trigger(availableNow=True)
+        <pyspark.sql.streaming.readwriter.DataStreamWriter object ...>
         """
         params = [processingTime, once, continuous, availableNow]
 
@@ -908,22 +1151,34 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> # Print every row using a function
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
+
+        Print every row using a function
+
         >>> def print_row(row):
         ...     print(row)
         ...
-        >>> writer = sdf.writeStream.foreach(print_row)
-        >>> # Print every row using a object with process() method
+        >>> q = df.writeStream.foreach(print_row).start()
+        >>> time.sleep(3)
+        >>> q.stop()
+
+        Print every row using an object with process() method
+
         >>> class RowPrinter:
         ...     def open(self, partition_id, epoch_id):
         ...         print("Opened %d, %d" % (partition_id, epoch_id))
         ...         return True
+        ...
         ...     def process(self, row):
         ...         print(row)
+        ...
         ...     def close(self, error):
         ...         print("Closed with error: %s" % str(error))
         ...
-        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        >>> q = df.writeStream.foreach(RowPrinter()).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
 
         from pyspark.rdd import _wrap_function
@@ -1025,10 +1280,14 @@ class DataStreamWriter:
 
         Examples
         --------
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
         >>> def func(batch_df, batch_id):
         ...     batch_df.collect()
         ...
-        >>> writer = sdf.writeStream.foreachBatch(func)
+        >>> q = df.writeStream.foreachBatch(func).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
 
         from pyspark.java_gateway import ensure_callback_server_started
@@ -1090,21 +1349,28 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> sq = 
sdf.writeStream.format('memory').queryName('this_query').start()
-        >>> sq.isActive
+        >>> df = spark.readStream.format("rate").load()
+
+        Basic example.
+
+        >>> q = df.writeStream.format('memory').queryName('this_query').start()
+        >>> q.isActive
         True
-        >>> sq.name
+        >>> q.name
         'this_query'
-        >>> sq.stop()
-        >>> sq.isActive
+        >>> q.stop()
+        >>> q.isActive
         False
-        >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
+
+        Example using other parameters with a trigger.
+
+        >>> q = df.writeStream.trigger(processingTime='5 seconds').start(
         ...     queryName='that_query', outputMode="append", format='memory')
-        >>> sq.name
+        >>> q.name
         'that_query'
-        >>> sq.isActive
+        >>> q.isActive
         True
-        >>> sq.stop()
+        >>> q.stop()
         """
         self.options(**options)
         if outputMode is not None:
@@ -1176,15 +1442,28 @@ class DataStreamWriter:
 
         Examples
         --------
-        >>> 
sdf.writeStream.format('parquet').queryName('query').toTable('output_table')
-        ... # doctest: +SKIP
-
-        >>> sdf.writeStream.trigger(processingTime='5 seconds').toTable(
-        ...     'output_table',
-        ...     queryName='that_query',
-        ...     outputMode="append",
-        ...     format='parquet',
-        ...     checkpointLocation='/tmp/checkpoint') # doctest: +SKIP
+        Save a data stream to a table.
+
+        >>> import tempfile
+        >>> import time
+        >>> _ = spark.sql("DROP TABLE IF EXISTS my_table2")
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Create a table with Rate source.
+        ...     q = spark.readStream.format("rate").option(
+        ...         "rowsPerSecond", 10).load().writeStream.toTable(
+        ...             "my_table2",
+        ...             queryName='that_query',
+        ...             outputMode="append",
+        ...             format='parquet',
+        ...             checkpointLocation=d)
+        ...     time.sleep(3)
+        ...     q.stop()
+        ...     spark.read.table("my_table2").show()
+        ...     _ = spark.sql("DROP TABLE my_table2")
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         self.options(**options)
         if outputMode is not None:
@@ -1201,23 +1480,17 @@ class DataStreamWriter:
 def _test() -> None:
     import doctest
     import os
-    import tempfile
     from pyspark.sql import SparkSession
     import pyspark.sql.streaming.readwriter
-    from py4j.protocol import Py4JError
 
     os.chdir(os.environ["SPARK_HOME"])
 
     globs = pyspark.sql.streaming.readwriter.__dict__.copy()
-    try:
-        spark = SparkSession._getActiveSessionOrCreate()
-    except Py4JError:  # noqa: F821
-        spark = SparkSession(sc)  # type: ignore[name-defined] # noqa: F821
-
-    globs["tempfile"] = tempfile
-    globs["spark"] = spark
-    globs["sdf"] = 
spark.readStream.format("text").load("python/test_support/sql/streaming")
-    globs["sdf_schema"] = StructType([StructField("data", StringType(), True)])
+    globs["spark"] = (
+        SparkSession.builder.master("local[4]")
+        .appName("sql.streaming.readwriter tests")
+        .getOrCreate()
+    )
 
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.streaming.readwriter,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to