This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dff52d649d1 [SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py
dff52d649d1 is described below
commit dff52d649d1e27baf3b107f75636624e0cfe780f
Author: dch nguyen <[email protected]>
AuthorDate: Mon Apr 18 17:38:32 2022 +0200
[SPARK-37015][PYTHON] Inline type hints for python/pyspark/streaming/dstream.py
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/streaming/dstream.py
### Why are the changes needed?
We can take advantage of static type checking within the functions by
inlining the type hints.
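For illustration, a minimal sketch of the pattern (simplified names, not the
actual DStream code): annotations that previously lived only in the
dstream.pyi stub are now written directly on the definitions in dstream.py,
so a checker such as mypy can validate the method bodies as well as the
signatures.

    from typing import Callable, Iterable, Iterator, TypeVar

    T = TypeVar("T")
    U = TypeVar("U")

    # Before: `def map_values(f, xs): ...` with hints kept only in a .pyi
    # stub, so the body was not type checked. With the hints inlined, it is.
    def map_values(f: Callable[[T], U], xs: Iterable[T]) -> Iterator[U]:
        return (f(x) for x in xs)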
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
Closes #34324 from dchvn/SPARK-37015.
Lead-authored-by: dch nguyen <[email protected]>
Co-authored-by: dch nguyen <[email protected]>
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/streaming/context.py | 22 +--
python/pyspark/streaming/dstream.py | 369 +++++++++++++++++++++++++++--------
python/pyspark/streaming/dstream.pyi | 211 --------------------
3 files changed, 296 insertions(+), 306 deletions(-)
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 52e5efed063..0be0c7b034a 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -397,12 +397,12 @@ class StreamingContext:
the transform function parameter will be the same as the order
of corresponding DStreams in the list.
"""
- jdstreams = [d._jdstream for d in dstreams] # type: ignore[attr-defined]
+ jdstreams = [d._jdstream for d in dstreams]
# change the final serializer to sc.serializer
func = TransformFunction(
self._sc,
lambda t, *rdds: transformFunc(rdds),
- *[d._jrdd_deserializer for d in dstreams], # type: ignore[attr-defined]
+ *[d._jrdd_deserializer for d in dstreams],
)
assert self._jvm is not None
@@ -419,35 +419,31 @@ class StreamingContext:
raise ValueError("should have at least one DStream to union")
if len(dstreams) == 1:
return dstreams[0]
- if len(set(s._jrdd_deserializer for s in dstreams)) > 1: # type: ignore[attr-defined]
+ if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
raise ValueError("All DStreams should have same serializer")
- if len(set(s._slideDuration for s in dstreams)) > 1: # type: ignore[attr-defined]
+ if len(set(s._slideDuration for s in dstreams)) > 1:
raise ValueError("All DStreams should have same slide duration")
assert SparkContext._jvm is not None
jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
gw = SparkContext._gateway
- if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls): # type: ignore[attr-defined]
+ if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
cls = jdstream_cls
- elif is_instance_of(
- gw, dstreams[0]._jdstream, jpair_dstream_cls # type: ignore[attr-defined]
- ):
+ elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
cls = jpair_dstream_cls
else:
- cls_name = (
- dstreams[0]._jdstream.getClass().getCanonicalName() # type: ignore[attr-defined]
- )
+ cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
raise TypeError("Unsupported Java DStream class %s" % cls_name)
assert gw is not None
jdstreams = gw.new_array(cls, len(dstreams))
for i in range(0, len(dstreams)):
- jdstreams[i] = dstreams[i]._jdstream # type: ignore[attr-defined]
+ jdstreams[i] = dstreams[i]._jdstream
return DStream(
self._jssc.union(jdstreams),
self,
- dstreams[0]._jrdd_deserializer, # type: ignore[attr-defined]
+ dstreams[0]._jrdd_deserializer,
)
def addStreamingListener(self, streamingListener: StreamingListener) -> None:
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index f445a78bd95..934b3ae5783 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -19,19 +19,45 @@ import operator
import time
from itertools import chain
from datetime import datetime
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ TYPE_CHECKING,
+ cast,
+ overload,
+)
from py4j.protocol import Py4JJavaError
-from pyspark import RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.util import rddToFileName, TransformFunction
-from pyspark.rdd import portable_hash
+from pyspark.rdd import portable_hash, RDD
from pyspark.resultiterable import ResultIterable
+from py4j.java_gateway import JavaObject
+
+if TYPE_CHECKING:
+ from pyspark.serializers import Serializer
+ from pyspark.streaming.context import StreamingContext
__all__ = ["DStream"]
+S = TypeVar("S")
+T = TypeVar("T")
+T_co = TypeVar("T_co", covariant=True)
+U = TypeVar("U")
+K = TypeVar("K", bound=Hashable)
+V = TypeVar("V")
+
-class DStream:
+class DStream(Generic[T_co]):
"""
A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
is a continuous sequence of RDDs (of the same type) representing a
@@ -51,7 +77,12 @@ class DStream:
- A function that is used to generate an RDD after each time interval
"""
- def __init__(self, jdstream, ssc, jrdd_deserializer):
+ def __init__(
+ self,
+ jdstream: JavaObject,
+ ssc: "StreamingContext",
+ jrdd_deserializer: "Serializer",
+ ):
self._jdstream = jdstream
self._ssc = ssc
self._sc = ssc._sc
@@ -59,76 +90,94 @@ class DStream:
self.is_cached = False
self.is_checkpointed = False
- def context(self):
+ def context(self) -> "StreamingContext":
"""
Return the StreamingContext associated with this DStream
"""
return self._ssc
- def count(self):
+ def count(self) -> "DStream[int]":
"""
Return a new DStream in which each RDD has a single element
generated by counting each RDD of this DStream.
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add)
- def filter(self, f):
+ def filter(self: "DStream[T]", f: Callable[[T], bool]) -> "DStream[T]":
"""
Return a new DStream containing only the elements that satisfy
predicate.
"""
- def func(iterator):
+ def func(iterator: Iterable[T]) -> Iterable[T]:
return filter(f, iterator)
return self.mapPartitions(func, True)
- def flatMap(self, f, preservesPartitioning=False):
+ def flatMap(
+ self: "DStream[T]",
+ f: Callable[[T], Iterable[U]],
+ preservesPartitioning: bool = False,
+ ) -> "DStream[U]":
"""
Return a new DStream by applying a function to all elements of
this DStream, and then flattening the results
"""
- def func(s, iterator):
+ def func(s: int, iterator: Iterable[T]) -> Iterable[U]:
return chain.from_iterable(map(f, iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
- def map(self, f, preservesPartitioning=False):
+ def map(
+ self: "DStream[T]", f: Callable[[T], U], preservesPartitioning: bool =
False
+ ) -> "DStream[U]":
"""
Return a new DStream by applying a function to each element of DStream.
"""
- def func(iterator):
+ def func(iterator: Iterable[T]) -> Iterable[U]:
return map(f, iterator)
return self.mapPartitions(func, preservesPartitioning)
- def mapPartitions(self, f, preservesPartitioning=False):
+ def mapPartitions(
+ self: "DStream[T]",
+ f: Callable[[Iterable[T]], Iterable[U]],
+ preservesPartitioning: bool = False,
+ ) -> "DStream[U]":
"""
Return a new DStream in which each RDD is generated by applying
mapPartitions() to each RDDs of this DStream.
"""
- def func(s, iterator):
+ def func(s: int, iterator: Iterable[T]) -> Iterable[U]:
return f(iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ def mapPartitionsWithIndex(
+ self: "DStream[T]",
+ f: Callable[[int, Iterable[T]], Iterable[U]],
+ preservesPartitioning: bool = False,
+ ) -> "DStream[U]":
"""
Return a new DStream in which each RDD is generated by applying
mapPartitionsWithIndex() to each RDDs of this DStream.
"""
return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning))
- def reduce(self, func):
+ def reduce(self: "DStream[T]", func: Callable[[T, T], T]) -> "DStream[T]":
"""
Return a new DStream in which each RDD has a single element
generated by reducing each RDD of this DStream.
"""
return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1])
- def reduceByKey(self, func, numPartitions=None):
+ def reduceByKey(
+ self: "DStream[Tuple[K, V]]",
+ func: Callable[[V, V], V],
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, V]]":
"""
Return a new DStream by applying reduceByKey to each RDD.
"""
@@ -136,40 +185,62 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.combineByKey(lambda x: x, func, func, numPartitions)
- def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions=None):
+ def combineByKey(
+ self: "DStream[Tuple[K, V]]",
+ createCombiner: Callable[[V], U],
+ mergeValue: Callable[[U, V], U],
+ mergeCombiners: Callable[[U, U], U],
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, U]]":
"""
Return a new DStream by applying combineByKey to each RDD.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
- def func(rdd):
+ def func(rdd: RDD[Tuple[K, V]]) -> RDD[Tuple[K, U]]:
return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions)
return self.transform(func)
- def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+ def partitionBy(
+ self: "DStream[Tuple[K, V]]",
+ numPartitions: int,
+ partitionFunc: Callable[[K], int] = portable_hash,
+ ) -> "DStream[Tuple[K, V]]":
"""
Return a copy of the DStream in which each RDD are partitioned
using the specified partitioner.
"""
return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
- def foreachRDD(self, func):
+ @overload
+ def foreachRDD(self: "DStream[T]", func: Callable[[RDD[T]], None]) -> None:
+ ...
+
+ @overload
+ def foreachRDD(self: "DStream[T]", func: Callable[[datetime, RDD[T]],
None]) -> None:
+ ...
+
+ def foreachRDD(
+ self: "DStream[T]",
+ func: Union[Callable[[RDD[T]], None], Callable[[datetime, RDD[T]], None]],
+ ) -> None:
"""
Apply a function to each RDD in this DStream.
"""
if func.__code__.co_argcount == 1:
old_func = func
- def func(_, rdd):
- return old_func(rdd)
+ def func(_: datetime, rdd: "RDD[T]") -> None:
+ return old_func(rdd) # type: ignore[call-arg, arg-type]
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
+ assert self._ssc._jvm is not None
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)
- def pprint(self, num=10):
+ def pprint(self, num: int = 10) -> None:
"""
Print the first num elements of each RDD generated in this DStream.
@@ -179,7 +250,7 @@ class DStream:
the number of elements from the first will be printed.
"""
- def takeAndPrint(time, rdd):
+ def takeAndPrint(time: datetime, rdd: RDD[T]) -> None:
taken = rdd.take(num + 1)
print("-------------------------------------------")
print("Time: %s" % time)
@@ -192,40 +263,42 @@ class DStream:
self.foreachRDD(takeAndPrint)
- def mapValues(self, f):
+ def mapValues(self: "DStream[Tuple[K, V]]", f: Callable[[V], U]) -> "DStream[Tuple[K, U]]":
"""
Return a new DStream by applying a map function to the value of
each key-value pairs in this DStream without changing the key.
"""
- def map_values_fn(kv):
+ def map_values_fn(kv: Tuple[K, V]) -> Tuple[K, U]:
return kv[0], f(kv[1])
return self.map(map_values_fn, preservesPartitioning=True)
- def flatMapValues(self, f):
+ def flatMapValues(
+ self: "DStream[Tuple[K, V]]", f: Callable[[V], Iterable[U]]
+ ) -> "DStream[Tuple[K, U]]":
"""
Return a new DStream by applying a flatmap function to the value
of each key-value pairs in this DStream without changing the key.
"""
- def flat_map_fn(kv):
+ def flat_map_fn(kv: Tuple[K, V]) -> Iterable[Tuple[K, U]]:
return ((kv[0], x) for x in f(kv[1]))
return self.flatMap(flat_map_fn, preservesPartitioning=True)
- def glom(self):
+ def glom(self: "DStream[T]") -> "DStream[List[T]]":
"""
Return a new DStream in which RDD is generated by applying glom()
to RDD of this DStream.
"""
- def func(iterator):
+ def func(iterator: Iterable[T]) -> Iterable[List[T]]:
yield list(iterator)
return self.mapPartitions(func)
- def cache(self):
+ def cache(self: "DStream[T]") -> "DStream[T]":
"""
Persist the RDDs of this DStream with the default storage level
(`MEMORY_ONLY`).
@@ -234,7 +307,7 @@ class DStream:
self.persist(StorageLevel.MEMORY_ONLY)
return self
- def persist(self, storageLevel):
+ def persist(self: "DStream[T]", storageLevel: StorageLevel) -> "DStream[T]":
"""
Persist the RDDs of this DStream with the given storage level
"""
@@ -243,7 +316,7 @@ class DStream:
self._jdstream.persist(javaStorageLevel)
return self
- def checkpoint(self, interval):
+ def checkpoint(self: "DStream[T]", interval: int) -> "DStream[T]":
"""
Enable periodic checkpointing of RDDs of this DStream
@@ -257,7 +330,9 @@ class DStream:
self._jdstream.checkpoint(self._ssc._jduration(interval))
return self
- def groupByKey(self, numPartitions=None):
+ def groupByKey(
+ self: "DStream[Tuple[K, V]]", numPartitions: Optional[int] = None
+ ) -> "DStream[Tuple[K, Iterable[V]]]":
"""
Return a new DStream by applying groupByKey on each RDD.
"""
@@ -265,20 +340,20 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
- def countByValue(self):
+ def countByValue(self: "DStream[K]") -> "DStream[Tuple[K, int]]":
"""
Return a new DStream in which each RDD contains the counts of each
distinct value in each RDD of this DStream.
"""
return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
- def saveAsTextFiles(self, prefix, suffix=None):
+ def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = None) -> None:
"""
Save each RDD in this DStream as at text file, using string
representation of elements.
"""
- def saveAsTextFile(t, rdd):
+ def saveAsTextFile(t: Optional[datetime], rdd: RDD[T]) -> None:
path = rddToFileName(prefix, suffix, t)
try:
rdd.saveAsTextFile(path)
@@ -307,7 +382,20 @@ class DStream:
# raise
# return self.foreachRDD(saveAsPickleFile)
- def transform(self, func):
+ @overload
+ def transform(self: "DStream[T]", func: Callable[[RDD[T]], RDD[U]]) -> "TransformedDStream[U]":
+ ...
+
+ @overload
+ def transform(
+ self: "DStream[T]", func: Callable[[datetime, RDD[T]], RDD[U]]
+ ) -> "TransformedDStream[U]":
+ ...
+
+ def transform(
+ self: "DStream[T]",
+ func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]]],
+ ) -> "TransformedDStream[U]":
"""
Return a new DStream in which each RDD is generated by applying a function
on each RDD of this DStream.
@@ -318,13 +406,39 @@ class DStream:
if func.__code__.co_argcount == 1:
oldfunc = func
- def func(_, rdd):
- return oldfunc(rdd)
+ def func(_: datetime, rdd: RDD[T]) -> RDD[U]:
+ return oldfunc(rdd) # type: ignore[arg-type, call-arg]
assert func.__code__.co_argcount == 2, "func should take one or two arguments"
return TransformedDStream(self, func)
- def transformWith(self, func, other, keepSerializer=False):
+ @overload
+ def transformWith(
+ self: "DStream[T]",
+ func: Callable[[RDD[T], RDD[U]], RDD[V]],
+ other: "DStream[U]",
+ keepSerializer: bool = ...,
+ ) -> "DStream[V]":
+ ...
+
+ @overload
+ def transformWith(
+ self: "DStream[T]",
+ func: Callable[[datetime, RDD[T], RDD[U]], RDD[V]],
+ other: "DStream[U]",
+ keepSerializer: bool = ...,
+ ) -> "DStream[V]":
+ ...
+
+ def transformWith(
+ self: "DStream[T]",
+ func: Union[
+ Callable[[RDD[T], RDD[U]], RDD[V]],
+ Callable[[datetime, RDD[T], RDD[U]], RDD[V]],
+ ],
+ other: "DStream[U]",
+ keepSerializer: bool = False,
+ ) -> "DStream[V]":
"""
Return a new DStream in which each RDD is generated by applying a function
on each RDD of this DStream and 'other' DStream.
@@ -335,31 +449,37 @@ class DStream:
if func.__code__.co_argcount == 2:
oldfunc = func
- def func(_, a, b):
- return oldfunc(a, b)
+ def func(_: datetime, a: RDD[T], b: RDD[U]) -> RDD[V]:
+ return oldfunc(a, b) # type: ignore[call-arg, arg-type]
assert func.__code__.co_argcount == 3, "func should take two or three arguments"
- jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
+ jfunc = TransformFunction(
+ self._sc,
+ func,
+ self._jrdd_deserializer,
+ other._jrdd_deserializer,
+ )
+ assert self._sc._jvm is not None
dstream = self._sc._jvm.PythonTransformed2DStream(
self._jdstream.dstream(), other._jdstream.dstream(), jfunc
)
jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer
return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
- def repartition(self, numPartitions):
+ def repartition(self: "DStream[T]", numPartitions: int) -> "DStream[T]":
"""
Return a new DStream with an increased or decreased level of
parallelism.
"""
return self.transform(lambda rdd: rdd.repartition(numPartitions))
@property
- def _slideDuration(self):
+ def _slideDuration(self) -> None:
"""
Return the slideDuration in seconds of this DStream
"""
return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
- def union(self, other):
+ def union(self: "DStream[T]", other: "DStream[U]") -> "DStream[Union[T, U]]":
"""
Return a new DStream by unifying data of another DStream with this
DStream.
@@ -373,7 +493,11 @@ class DStream:
raise ValueError("the two DStream should have same slide duration")
return self.transformWith(lambda a, b: a.union(b), other, True)
- def cogroup(self, other, numPartitions=None):
+ def cogroup(
+ self: "DStream[Tuple[K, V]]",
+ other: "DStream[Tuple[K, U]]",
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]":
"""
Return a new DStream by applying 'cogroup' between RDDs of this
DStream and `other` DStream.
@@ -382,9 +506,16 @@ class DStream:
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
- return self.transformWith(lambda a, b: a.cogroup(b, numPartitions),
other)
+ return self.transformWith(
+ lambda a, b: a.cogroup(b, numPartitions),
+ other,
+ )
- def join(self, other, numPartitions=None):
+ def join(
+ self: "DStream[Tuple[K, V]]",
+ other: "DStream[Tuple[K, U]]",
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Tuple[V, U]]]":
"""
Return a new DStream by applying 'join' between RDDs of this DStream and
`other` DStream.
@@ -396,7 +527,11 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
- def leftOuterJoin(self, other, numPartitions=None):
+ def leftOuterJoin(
+ self: "DStream[Tuple[K, V]]",
+ other: "DStream[Tuple[K, U]]",
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Tuple[V, Optional[U]]]]":
"""
Return a new DStream by applying 'left outer join' between RDDs of this DStream and
`other` DStream.
@@ -408,7 +543,11 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
- def rightOuterJoin(self, other, numPartitions=None):
+ def rightOuterJoin(
+ self: "DStream[Tuple[K, V]]",
+ other: "DStream[Tuple[K, U]]",
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Tuple[Optional[V], U]]]":
"""
Return a new DStream by applying 'right outer join' between RDDs of this DStream and
`other` DStream.
@@ -420,7 +559,11 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
- def fullOuterJoin(self, other, numPartitions=None):
+ def fullOuterJoin(
+ self: "DStream[Tuple[K, V]]",
+ other: "DStream[Tuple[K, U]]",
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]":
"""
Return a new DStream by applying 'full outer join' between RDDs of this DStream and
`other` DStream.
@@ -432,13 +575,14 @@ class DStream:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
- def _jtime(self, timestamp):
+ def _jtime(self, timestamp: Union[datetime, int, float]) -> JavaObject:
"""Convert datetime or unix_timestamp into Time"""
if isinstance(timestamp, datetime):
timestamp = time.mktime(timestamp.timetuple())
+ assert self._sc._jvm is not None
return self._sc._jvm.Time(int(timestamp * 1000))
- def slice(self, begin, end):
+ def slice(self, begin: Union[datetime, int], end: Union[datetime, int]) -> List[RDD[T]]:
"""
Return all the RDDs between 'begin' to 'end' (both included)
@@ -447,7 +591,7 @@ class DStream:
jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
- def _validate_window_param(self, window, slide):
+ def _validate_window_param(self, window: int, slide: Optional[int]) -> None:
duration = self._jdstream.dstream().slideDuration().milliseconds()
if int(window * 1000) % duration != 0:
raise ValueError(
@@ -460,7 +604,7 @@ class DStream:
"dstream's slide (batch) duration (%d ms)" % duration
)
- def window(self, windowDuration, slideDuration=None):
+ def window(self, windowDuration: int, slideDuration: Optional[int] = None) -> "DStream[T]":
"""
Return a new DStream in which each RDD contains all the elements in seen in a
sliding window of time over this DStream.
@@ -482,7 +626,13 @@ class DStream:
s = self._ssc._jduration(slideDuration)
return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
- def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+ def reduceByWindow(
+ self: "DStream[T]",
+ reduceFunc: Callable[[T, T], T],
+ invReduceFunc: Optional[Callable[[T, T], T]],
+ windowDuration: int,
+ slideDuration: int,
+ ) -> "DStream[T]":
"""
Return a new DStream in which each RDD has a single element generated by reducing all
elements in a sliding window over this DStream.
@@ -517,7 +667,9 @@ class DStream:
)
return reduced.map(lambda kv: kv[1])
- def countByWindow(self, windowDuration, slideDuration):
+ def countByWindow(
+ self: "DStream[T]", windowDuration: int, slideDuration: int
+ ) -> "DStream[int]":
"""
Return a new DStream in which each RDD has a single element generated
by counting the number of elements in a window over this DStream.
@@ -530,7 +682,12 @@ class DStream:
operator.add, operator.sub, windowDuration, slideDuration
)
- def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ def countByValueAndWindow(
+ self: "DStream[T]",
+ windowDuration: int,
+ slideDuration: int,
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[T, int]]":
"""
Return a new DStream in which each RDD contains the count of distinct elements in
RDDs in a sliding window over this DStream.
@@ -553,7 +710,12 @@ class DStream:
)
return counted.filter(lambda kv: kv[1] > 0)
- def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ def groupByKeyAndWindow(
+ self: "DStream[Tuple[K, V]]",
+ windowDuration: int,
+ slideDuration: int,
+ numPartitions: Optional[int] = None,
+ ) -> "DStream[Tuple[K, Iterable[V]]]":
"""
Return a new DStream by applying `groupByKey` over a sliding window.
Similar to `DStream.groupByKey()`, but applies it over a sliding window.
@@ -572,7 +734,7 @@ class DStream:
"""
ls = self.mapValues(lambda x: [x])
grouped = ls.reduceByKeyAndWindow(
- lambda a, b: a.extend(b) or a,
+ lambda a, b: a.extend(b) or a, # type: ignore[func-returns-value]
lambda a, b: a[len(b) :],
windowDuration,
slideDuration,
@@ -581,8 +743,14 @@ class DStream:
return grouped.mapValues(ResultIterable)
def reduceByKeyAndWindow(
- self, func, invFunc, windowDuration, slideDuration=None, numPartitions=None, filterFunc=None
- ):
+ self: "DStream[Tuple[K, V]]",
+ func: Callable[[V, V], V],
+ invFunc: Optional[Callable[[V, V], V]],
+ windowDuration: int,
+ slideDuration: Optional[int] = None,
+ numPartitions: Optional[int] = None,
+ filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = None,
+ ) -> "DStream[Tuple[K, V]]":
"""
Return a new DStream by applying incremental `reduceByKey` over a sliding window.
@@ -621,36 +789,46 @@ class DStream:
if invFunc:
- def reduceFunc(t, a, b):
+ def reduceFunc(t: datetime, a: Any, b: Any) -> Any:
b = b.reduceByKey(func, numPartitions)
r = a.union(b).reduceByKey(func, numPartitions) if a else b
if filterFunc:
r = r.filter(filterFunc)
return r
- def invReduceFunc(t, a, b):
+ def invReduceFunc(t: datetime, a: Any, b: Any) -> Any:
b = b.reduceByKey(func, numPartitions)
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(
- lambda kv: invFunc(kv[0], kv[1]) if kv[1] is not None else kv[0]
+ lambda kv: invFunc(kv[0], kv[1]) # type: ignore[misc]
+ if kv[1] is not None
+ else kv[0]
)
jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
if slideDuration is None:
slideDuration = self._slideDuration
+ assert self._sc._jvm is not None
dstream = self._sc._jvm.PythonReducedWindowedDStream(
reduced._jdstream.dstream(),
jreduceFunc,
jinvReduceFunc,
self._ssc._jduration(windowDuration),
- self._ssc._jduration(slideDuration),
+ self._ssc._jduration(slideDuration), # type: ignore[arg-type]
)
return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
else:
- return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions)
+ return reduced.window(windowDuration, slideDuration).reduceByKey(
+ func, numPartitions # type: ignore[arg-type]
+ )
- def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None):
+ def updateStateByKey(
+ self: "DStream[Tuple[K, V]]",
+ updateFunc: Callable[[Iterable[V], Optional[S]], S],
+ numPartitions: Optional[int] = None,
+ initialRDD: Optional[Union[RDD[Tuple[K, S]], Iterable[Tuple[K, S]]]] = None,
+ ) -> "DStream[Tuple[K, S]]":
"""
Return a new "state" DStream where the state for each key is updated
by applying
the given function on the previous state of the key and the new values
of the key.
@@ -667,30 +845,37 @@ class DStream:
if initialRDD and not isinstance(initialRDD, RDD):
initialRDD = self._sc.parallelize(initialRDD)
- def reduceFunc(t, a, b):
+ def reduceFunc(t: datetime, a: Any, b: Any) -> Any:
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
- g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
+ g = a.cogroup(b.partitionBy(cast(int, numPartitions)), numPartitions)
g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None))
state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1]))
return state.filter(lambda k_v: k_v[1] is not None)
jreduceFunc = TransformFunction(
- self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer
+ self._sc,
+ reduceFunc,
+ self._sc.serializer,
+ self._jrdd_deserializer,
)
if initialRDD:
- initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+ initialRDD = cast(RDD[Tuple[K, S]], initialRDD)._reserialize(self._jrdd_deserializer)
+ assert self._sc._jvm is not None
dstream = self._sc._jvm.PythonStateDStream(
- self._jdstream.dstream(), jreduceFunc, initialRDD._jrdd
+ self._jdstream.dstream(),
+ jreduceFunc,
+ initialRDD._jrdd,
)
else:
+ assert self._sc._jvm is not None
dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
-class TransformedDStream(DStream):
+class TransformedDStream(DStream[U]):
"""
TransformedDStream is a DStream generated by an Python function
transforming each RDD of a DStream to another RDDs.
@@ -699,7 +884,23 @@ class TransformedDStream(DStream):
one transformation.
"""
- def __init__(self, prev, func):
+ @overload
+ def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], RDD[U]]):
+ ...
+
+ @overload
+ def __init__(
+ self: DStream[U],
+ prev: DStream[T],
+ func: Callable[[datetime, RDD[T]], RDD[U]],
+ ):
+ ...
+
+ def __init__(
+ self,
+ prev: DStream[T],
+ func: Union[Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]]],
+ ):
self._ssc = prev._ssc
self._sc = self._ssc._sc
self._jrdd_deserializer = self._sc.serializer
@@ -710,19 +911,23 @@ class TransformedDStream(DStream):
# Using type() to avoid folding the functions and compacting the DStreams which is not
# not strictly an object of TransformedDStream.
if type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed:
- prev_func = prev.func
- self.func = lambda t, rdd: func(t, prev_func(t, rdd))
- self.prev = prev.prev
+ prev_func: Callable = prev.func
+ func = cast(Callable[[datetime, RDD[T]], RDD[U]], func)
+ self.func: Union[
+ Callable[[RDD[T]], RDD[U]], Callable[[datetime, RDD[T]], RDD[U]]
+ ] = lambda t, rdd: func(t, prev_func(t, rdd))
+ self.prev: DStream[T] = prev.prev
else:
self.prev = prev
self.func = func
@property
- def _jdstream(self):
+ def _jdstream(self) -> JavaObject:
if self._jdstream_val is not None:
return self._jdstream_val
jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer)
+ assert self._sc._jvm is not None
dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
self._jdstream_val = dstream.asJavaDStream()
return self._jdstream_val
diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi
deleted file mode 100644
index c9f31b37f04..00000000000
--- a/python/pyspark/streaming/dstream.pyi
+++ /dev/null
@@ -1,211 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
- Callable,
- Generic,
- Hashable,
- Iterable,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
-)
-import datetime
-from pyspark.rdd import RDD
-import pyspark.serializers
-from pyspark.storagelevel import StorageLevel
-import pyspark.streaming.context
-
-from py4j.java_gateway import JavaObject
-
-S = TypeVar("S")
-T = TypeVar("T")
-T_co = TypeVar("T_co", covariant=True)
-U = TypeVar("U")
-K = TypeVar("K", bound=Hashable)
-V = TypeVar("V")
-
-class DStream(Generic[T_co]):
- is_cached: bool
- is_checkpointed: bool
- def __init__(
- self,
- jdstream: JavaObject,
- ssc: pyspark.streaming.context.StreamingContext,
- jrdd_deserializer: pyspark.serializers.Serializer,
- ) -> None: ...
- def context(self) -> pyspark.streaming.context.StreamingContext: ...
- def count(self) -> DStream[int]: ...
- def filter(self, f: Callable[[T_co], bool]) -> DStream[T_co]: ...
- def flatMap(
- self: DStream[T_co],
- f: Callable[[T_co], Iterable[U]],
- preservesPartitioning: bool = ...,
- ) -> DStream[U]: ...
- def map(
- self: DStream[T_co], f: Callable[[T_co], U], preservesPartitioning: bool = ...
- ) -> DStream[U]: ...
- def mapPartitions(
- self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ...
- ) -> DStream[U]: ...
- def mapPartitionsWithIndex(
- self,
- f: Callable[[int, Iterable[T_co]], Iterable[U]],
- preservesPartitioning: bool = ...,
- ) -> DStream[U]: ...
- def reduce(self, func: Callable[[T_co, T_co], T_co]) -> DStream[T_co]: ...
- def reduceByKey(
- self: DStream[Tuple[K, V]],
- func: Callable[[V, V], V],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, V]]: ...
- def combineByKey(
- self: DStream[Tuple[K, V]],
- createCombiner: Callable[[V], U],
- mergeValue: Callable[[U, V], U],
- mergeCombiners: Callable[[U, U], U],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, U]]: ...
- def partitionBy(
- self: DStream[Tuple[K, V]],
- numPartitions: int,
- partitionFunc: Callable[[K], int] = ...,
- ) -> DStream[Tuple[K, V]]: ...
- @overload
- def foreachRDD(self, func: Callable[[RDD[T_co]], None]) -> None: ...
- @overload
- def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T_co]], None]) -> None: ...
- def pprint(self, num: int = ...) -> None: ...
- def mapValues(self: DStream[Tuple[K, V]], f: Callable[[V], U]) -> DStream[Tuple[K, U]]: ...
- def flatMapValues(
- self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]]
- ) -> DStream[Tuple[K, U]]: ...
- def glom(self) -> DStream[List[T_co]]: ...
- def cache(self) -> DStream[T_co]: ...
- def persist(self, storageLevel: StorageLevel) -> DStream[T_co]: ...
- def checkpoint(self, interval: int) -> DStream[T_co]: ...
- def groupByKey(
- self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ...
- ) -> DStream[Tuple[K, Iterable[V]]]: ...
- def countByValue(self) -> DStream[Tuple[T_co, int]]: ...
- def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> None: ...
- @overload
- def transform(self, func: Callable[[RDD[T_co]], RDD[U]]) -> TransformedDStream[U]: ...
- @overload
- def transform(
- self, func: Callable[[datetime.datetime, RDD[T_co]], RDD[U]]
- ) -> TransformedDStream[U]: ...
- @overload
- def transformWith(
- self,
- func: Callable[[RDD[T_co], RDD[U]], RDD[V]],
- other: RDD[U],
- keepSerializer: bool = ...,
- ) -> DStream[V]: ...
- @overload
- def transformWith(
- self,
- func: Callable[[datetime.datetime, RDD[T_co], RDD[U]], RDD[V]],
- other: RDD[U],
- keepSerializer: bool = ...,
- ) -> DStream[V]: ...
- def repartition(self, numPartitions: int) -> DStream[T_co]: ...
- def union(self, other: DStream[U]) -> DStream[Union[T_co, U]]: ...
- def cogroup(
- self: DStream[Tuple[K, V]],
- other: DStream[Tuple[K, U]],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Tuple[List[V], List[U]]]]: ...
- def join(
- self: DStream[Tuple[K, V]],
- other: DStream[Tuple[K, U]],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Tuple[V, U]]]: ...
- def leftOuterJoin(
- self: DStream[Tuple[K, V]],
- other: DStream[Tuple[K, U]],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Tuple[V, Optional[U]]]]: ...
- def rightOuterJoin(
- self: DStream[Tuple[K, V]],
- other: DStream[Tuple[K, U]],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Tuple[Optional[V], U]]]: ...
- def fullOuterJoin(
- self: DStream[Tuple[K, V]],
- other: DStream[Tuple[K, U]],
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ...
- def slice(
- self, begin: Union[datetime.datetime, int], end: Union[datetime.datetime, int]
- ) -> List[RDD[T_co]]: ...
- def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T_co]: ...
- def reduceByWindow(
- self,
- reduceFunc: Callable[[T_co, T_co], T_co],
- invReduceFunc: Optional[Callable[[T_co, T_co], T_co]],
- windowDuration: int,
- slideDuration: int,
- ) -> DStream[T_co]: ...
- def countByWindow(
- self, windowDuration: int, slideDuration: int
- ) -> DStream[Tuple[T_co, int]]: ...
- def countByValueAndWindow(
- self,
- windowDuration: int,
- slideDuration: int,
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[T_co, int]]: ...
- def groupByKeyAndWindow(
- self: DStream[Tuple[K, V]],
- windowDuration: int,
- slideDuration: int,
- numPartitions: Optional[int] = ...,
- ) -> DStream[Tuple[K, Iterable[V]]]: ...
- def reduceByKeyAndWindow(
- self: DStream[Tuple[K, V]],
- func: Callable[[V, V], V],
- invFunc: Optional[Callable[[V, V], V]],
- windowDuration: int,
- slideDuration: Optional[int] = ...,
- numPartitions: Optional[int] = ...,
- filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = ...,
- ) -> DStream[Tuple[K, V]]: ...
- def updateStateByKey(
- self: DStream[Tuple[K, V]],
- updateFunc: Callable[[Iterable[V], Optional[S]], S],
- numPartitions: Optional[int] = ...,
- initialRDD: Optional[RDD[Tuple[K, S]]] = ...,
- ) -> DStream[Tuple[K, S]]: ...
-
-class TransformedDStream(DStream[U]):
- is_cached: bool
- is_checkpointed: bool
- func: Callable
- prev: DStream
- @overload
- def __init__(self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], RDD[U]]) -> None: ...
- @overload
- def __init__(
- self: DStream[U],
- prev: DStream[T],
- func: Callable[[datetime.datetime, RDD[T]], RDD[U]],
- ) -> None: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]