This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new baaa3bbecd9 [SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py
baaa3bbecd9 is described below
commit baaa3bbecd9f63aa0a71cf76de4b53d3c1dcf7a4
Author: dch nguyen <[email protected]>
AuthorDate: Thu Apr 14 02:03:24 2022 +0200
[SPARK-37014][PYTHON] Inline type hints for python/pyspark/streaming/context.py
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/streaming/context.py from
python/pyspark/streaming/context.pyi.
### Why are the changes needed?
Currently, there is a type hint stub file,
python/pyspark/streaming/context.pyi, that shows the expected types for functions,
but we can also take advantage of static type checking within the functions by
inlining the type hints.
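For illustration, a minimal before/after sketch of what inlining means, adapted from one of the methods touched by this diff (simplified, not the full class):

```python
# Before: context.py carries no annotations; the expected types live only in
# the separate context.pyi stub, so mypy does not check this function body.
def remember(self, duration):
    self._jssc.remember(self._jduration(duration))

# After: the hints are inlined in context.py itself, so mypy can check both
# callers and the body of the function.
def remember(self, duration: int) -> None:
    self._jssc.remember(self._jduration(duration))
```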
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing test.
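Beyond the existing tests, the inlined hints can also be exercised directly with mypy. A minimal sketch, assuming mypy is installed and the command is run from the repository root (this check is illustrative only, not part of the patch):

```python
# Hypothetical local check: run mypy programmatically over the module whose
# stub was inlined and print its report.
from mypy import api

stdout, stderr, exit_status = api.run(["python/pyspark/streaming/context.py"])
print(stdout)                 # type checking report
print("exit:", exit_status)   # 0 means no type errors were found
```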
Closes #34293 from dchvn/SPARK-37014.
Authored-by: dch nguyen <[email protected]>
Signed-off-by: zero323 <[email protected]>
(cherry picked from commit c0c1f35cd9279bc1a7a50119be72a297162a9b55)
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/streaming/context.py | 123 ++++++++++++++++++++++++-----------
python/pyspark/streaming/context.pyi | 71 --------------------
python/pyspark/streaming/kinesis.py | 9 +--
3 files changed, 91 insertions(+), 112 deletions(-)
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index cc9875d6575..52e5efed063 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -14,18 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from typing import Any, Callable, List, Optional, TypeVar
-from py4j.java_gateway import java_import, is_instance_of
+from py4j.java_gateway import java_import, is_instance_of, JavaObject
from pyspark import RDD, SparkConf
from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
+from pyspark.streaming.listener import StreamingListener
from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
__all__ = ["StreamingContext"]
+T = TypeVar("T")
+
class StreamingContext:
"""
@@ -51,27 +55,35 @@ class StreamingContext:
# Reference to a currently active StreamingContext
_activeContext = None
- def __init__(self, sparkContext, batchDuration=None, jssc=None):
-
+ def __init__(
+ self,
+ sparkContext: SparkContext,
+ batchDuration: Optional[int] = None,
+ jssc: Optional[JavaObject] = None,
+ ):
self._sc = sparkContext
self._jvm = self._sc._jvm
self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
- def _initialize_context(self, sc, duration):
+ def _initialize_context(self, sc: SparkContext, duration: Optional[int]) -> JavaObject:
self._ensure_initialized()
+ assert self._jvm is not None and duration is not None
return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
- def _jduration(self, seconds):
+ def _jduration(self, seconds: int) -> JavaObject:
"""
Create Duration object given number of seconds
"""
+ assert self._jvm is not None
return self._jvm.Duration(int(seconds * 1000))
@classmethod
- def _ensure_initialized(cls):
+ def _ensure_initialized(cls) -> None:
SparkContext._ensure_initialized()
gw = SparkContext._gateway
+ assert gw is not None
+
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
@@ -83,11 +95,15 @@ class StreamingContext:
# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
- SparkContext._active_spark_context, CloudPickleSerializer(), gw
+ SparkContext._active_spark_context,
+ CloudPickleSerializer(),
+ gw,
)
@classmethod
- def getOrCreate(cls, checkpointPath, setupFunc):
+ def getOrCreate(
+ cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+ ) -> "StreamingContext":
"""
Either recreate a StreamingContext from checkpoint data or create a
new StreamingContext.
If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
@@ -104,6 +120,8 @@ class StreamingContext:
cls._ensure_initialized()
gw = SparkContext._gateway
+ assert gw is not None
+
# Check whether valid checkpoint information exists in the given path
ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
if ssc_option.isEmpty():
@@ -121,12 +139,15 @@ class StreamingContext:
sc = SparkContext._active_spark_context
+ assert sc is not None
+
# update ctx in serializer
+ assert cls._transformerSerializer is not None
cls._transformerSerializer.ctx = sc
return StreamingContext(sc, None, jssc)
@classmethod
- def getActive(cls):
+ def getActive(cls) -> Optional["StreamingContext"]:
"""
Return either the currently active StreamingContext (i.e., if there is a context started
but not stopped) or None.
@@ -149,7 +170,9 @@ class StreamingContext:
return cls._activeContext
@classmethod
- def getActiveOrCreate(cls, checkpointPath, setupFunc):
+ def getActiveOrCreate(
+ cls, checkpointPath: str, setupFunc: Callable[[], "StreamingContext"]
+ ) -> "StreamingContext":
"""
Either return the active StreamingContext (i.e. currently started but not stopped),
or recreate a StreamingContext from checkpoint data or create a new StreamingContext
@@ -178,20 +201,20 @@ class StreamingContext:
return setupFunc()
@property
- def sparkContext(self):
+ def sparkContext(self) -> SparkContext:
"""
Return SparkContext which is associated with this StreamingContext.
"""
return self._sc
- def start(self):
+ def start(self) -> None:
"""
Start the execution of the streams.
"""
self._jssc.start()
StreamingContext._activeContext = self
- def awaitTermination(self, timeout=None):
+ def awaitTermination(self, timeout: Optional[int] = None) -> None:
"""
Wait for the execution to stop.
@@ -205,7 +228,7 @@ class StreamingContext:
else:
self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
- def awaitTerminationOrTimeout(self, timeout):
+ def awaitTerminationOrTimeout(self, timeout: int) -> None:
"""
Wait for the execution to stop. Return `true` if it's stopped; or
throw the reported error during the execution; or `false` if the
@@ -218,7 +241,7 @@ class StreamingContext:
"""
return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
- def stop(self, stopSparkContext=True, stopGraceFully=False):
+ def stop(self, stopSparkContext: bool = True, stopGraceFully: bool = False) -> None:
"""
Stop the execution of the streams, with option of ensuring all
received data has been processed.
@@ -236,7 +259,7 @@ class StreamingContext:
if stopSparkContext:
self._sc.stop()
- def remember(self, duration):
+ def remember(self, duration: int) -> None:
"""
Set each DStreams in this context to remember RDDs it generated
in the last given duration. DStreams remember RDDs only for a
@@ -252,7 +275,7 @@ class StreamingContext:
"""
self._jssc.remember(self._jduration(duration))
- def checkpoint(self, directory):
+ def checkpoint(self, directory: str) -> None:
"""
Sets the context to periodically checkpoint the DStream operations for master
fault-tolerance. The graph will be checkpointed every batch interval.
@@ -264,7 +287,9 @@ class StreamingContext:
"""
self._jssc.checkpoint(directory)
- def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2):
+ def socketTextStream(
+ self, hostname: str, port: int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_2
+ ) -> "DStream[str]":
"""
Create an input from TCP source hostname:port. Data is received using
a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited
@@ -284,7 +309,7 @@ class StreamingContext:
self._jssc.socketTextStream(hostname, port, jlevel), self, UTF8Deserializer()
)
- def textFileStream(self, directory):
+ def textFileStream(self, directory: str) -> "DStream[str]":
"""
Create an input stream that monitors a Hadoop-compatible file system
for new files and reads them as text files. Files must be written to the
@@ -294,7 +319,7 @@ class StreamingContext:
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
- def binaryRecordsStream(self, directory, recordLength):
+ def binaryRecordsStream(self, directory: str, recordLength: int) -> "DStream[bytes]":
"""
Create an input stream that monitors a Hadoop-compatible file system
for new files and reads them as flat binary files with records of
@@ -313,14 +338,19 @@ class StreamingContext:
self._jssc.binaryRecordsStream(directory, recordLength), self, NoOpSerializer()
)
- def _check_serializers(self, rdds):
+ def _check_serializers(self, rdds: List[RDD[T]]) -> None:
# make sure they have same serializer
if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
for i in range(len(rdds)):
# reset them to sc.serializer
rdds[i] = rdds[i]._reserialize()
- def queueStream(self, rdds, oneAtATime=True, default=None):
+ def queueStream(
+ self,
+ rdds: List[RDD[T]],
+ oneAtATime: bool = True,
+ default: Optional[RDD[T]] = None,
+ ) -> "DStream[T]":
"""
Create an input stream from a queue of RDDs or list. In each batch,
it will process either one or all of the RDDs returned by the queue.
@@ -339,42 +369,48 @@ class StreamingContext:
Changes to the queue after the stream is created will not be recognized.
"""
if default and not isinstance(default, RDD):
- default = self._sc.parallelize(default)
+ default = self._sc.parallelize(default) # type: ignore[arg-type]
if not rdds and default:
- rdds = [rdds]
+ rdds = [rdds] # type: ignore[list-item]
if rdds and not isinstance(rdds[0], RDD):
- rdds = [self._sc.parallelize(input) for input in rdds]
+ rdds = [self._sc.parallelize(input) for input in rdds] # type: ignore[arg-type]
self._check_serializers(rdds)
+ assert self._jvm is not None
queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
if default:
default = default._reserialize(rdds[0]._jrdd_deserializer)
+ assert default is not None
jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
else:
jdstream = self._jssc.queueStream(queue, oneAtATime)
return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
- def transform(self, dstreams, transformFunc):
+ def transform(
+ self, dstreams: List["DStream[Any]"], transformFunc: Callable[...,
RDD[T]]
+ ) -> "DStream[T]":
"""
Create a new DStream in which each RDD is generated by applying
a function on RDDs of the DStreams. The order of the JavaRDDs in
the transform function parameter will be the same as the order
of corresponding DStreams in the list.
"""
- jdstreams = [d._jdstream for d in dstreams]
+ jdstreams = [d._jdstream for d in dstreams] # type: ignore[attr-defined]
# change the final serializer to sc.serializer
func = TransformFunction(
self._sc,
lambda t, *rdds: transformFunc(rdds),
- *[d._jrdd_deserializer for d in dstreams],
+ *[d._jrdd_deserializer for d in dstreams], # type: ignore[attr-defined]
)
+
+ assert self._jvm is not None
jfunc = self._jvm.TransformFunction(func)
jdstream = self._jssc.transform(jdstreams, jfunc)
return DStream(jdstream, self, self._sc.serializer)
- def union(self, *dstreams):
+ def union(self, *dstreams: "DStream[T]") -> "DStream[T]":
"""
Create a unified DStream from multiple DStreams of the same
type and same slide duration.
@@ -383,30 +419,43 @@ class StreamingContext:
raise ValueError("should have at least one DStream to union")
if len(dstreams) == 1:
return dstreams[0]
- if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+ if len(set(s._jrdd_deserializer for s in dstreams)) > 1: # type: ignore[attr-defined]
raise ValueError("All DStreams should have same serializer")
- if len(set(s._slideDuration for s in dstreams)) > 1:
+ if len(set(s._slideDuration for s in dstreams)) > 1: # type: ignore[attr-defined]
raise ValueError("All DStreams should have same slide duration")
+
+ assert SparkContext._jvm is not None
jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
gw = SparkContext._gateway
- if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+ if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls): # type: ignore[attr-defined]
cls = jdstream_cls
- elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+ elif is_instance_of(
+ gw, dstreams[0]._jdstream, jpair_dstream_cls # type: ignore[attr-defined]
+ ):
cls = jpair_dstream_cls
else:
- cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+ cls_name = (
+ dstreams[0]._jdstream.getClass().getCanonicalName() # type: ignore[attr-defined]
+ )
raise TypeError("Unsupported Java DStream class %s" % cls_name)
+
+ assert gw is not None
jdstreams = gw.new_array(cls, len(dstreams))
for i in range(0, len(dstreams)):
- jdstreams[i] = dstreams[i]._jdstream
- return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
+ jdstreams[i] = dstreams[i]._jdstream # type: ignore[attr-defined]
+ return DStream(
+ self._jssc.union(jdstreams),
+ self,
+ dstreams[0]._jrdd_deserializer, # type: ignore[attr-defined]
+ )
- def addStreamingListener(self, streamingListener):
+ def addStreamingListener(self, streamingListener: StreamingListener) -> None:
"""
Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
receiving system events related to streaming.
"""
+ assert self._jvm is not None
self._jssc.addStreamingListener(
self._jvm.JavaStreamingListenerWrapper(
self._jvm.PythonStreamingListenerWrapper(streamingListener)
diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi
deleted file mode 100644
index 0d1b2aca739..00000000000
--- a/python/pyspark/streaming/context.pyi
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Any, Callable, List, Optional, TypeVar
-
-from py4j.java_gateway import JavaObject
-
-from pyspark.context import SparkContext
-from pyspark.rdd import RDD
-from pyspark.storagelevel import StorageLevel
-from pyspark.streaming.dstream import DStream
-from pyspark.streaming.listener import StreamingListener
-
-T = TypeVar("T")
-
-class StreamingContext:
- def __init__(
- self,
- sparkContext: SparkContext,
- batchDuration: int = ...,
- jssc: Optional[JavaObject] = ...,
- ) -> None: ...
- @classmethod
- def getOrCreate(
- cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
- ) -> StreamingContext: ...
- @classmethod
- def getActive(cls) -> StreamingContext: ...
- @classmethod
- def getActiveOrCreate(
- cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext]
- ) -> StreamingContext: ...
- @property
- def sparkContext(self) -> SparkContext: ...
- def start(self) -> None: ...
- def awaitTermination(self, timeout: Optional[int] = ...) -> None: ...
- def awaitTerminationOrTimeout(self, timeout: int) -> None: ...
- def stop(self, stopSparkContext: bool = ..., stopGraceFully: bool = ...) -> None: ...
- def remember(self, duration: int) -> None: ...
- def checkpoint(self, directory: str) -> None: ...
- def socketTextStream(
- self, hostname: str, port: int, storageLevel: StorageLevel = ...
- ) -> DStream[str]: ...
- def textFileStream(self, directory: str) -> DStream[str]: ...
- def binaryRecordsStream(self, directory: str, recordLength: int) -> DStream[bytes]: ...
- def queueStream(
- self,
- rdds: List[RDD[T]],
- oneAtATime: bool = ...,
- default: Optional[RDD[T]] = ...,
- ) -> DStream[T]: ...
- def transform(
- self, dstreams: List[DStream[Any]], transformFunc: Callable[..., RDD[T]]
- ) -> DStream[T]: ...
- def union(self, *dstreams: DStream[T]) -> DStream[T]: ...
- def addStreamingListener(self, streamingListener: StreamingListener) -> None: ...
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index 26d66c394ab..150fb79f572 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -153,10 +153,11 @@ class KinesisUtils:
The given AWS credentials will get saved in DStream checkpoints if checkpointing
is enabled. Make sure that your checkpoint directory is secure.
"""
- jlevel = ssc._sc._getJavaStorageLevel(storageLevel) # type: ignore[attr-defined]
- jduration = ssc._jduration(checkpointInterval) # type: ignore[attr-defined]
+ jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+ jduration = ssc._jduration(checkpointInterval)
- jvm = ssc._jvm # type: ignore[attr-defined]
+ jvm = ssc._jvm
+ assert jvm is not None
try:
helper = jvm.org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper()
@@ -170,7 +171,7 @@ class KinesisUtils:
)
raise
jstream = helper.createStream(
- ssc._jssc, # type: ignore[attr-defined]
+ ssc._jssc,
kinesisAppName,
streamName,
endpointUrl,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]