This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 88696eb [SPARK-38121][PYTHON][SQL] Use SparkSession instead of SQLContext inside PySpark
88696eb is described below
commit 88696ebcb72fd3057b1546831f653b29b7e0abb2
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Feb 15 09:42:31 2022 +0900
[SPARK-38121][PYTHON][SQL] Use SparkSession instead of SQLContext inside PySpark
### What changes were proposed in this pull request?
This PR proposes to use `SparkSession` within PySpark. This is base work for
respecting runtime configurations, etc. Currently, we rely on the old, deprecated
`SQLContext` internally, which does not respect the Spark session's runtime
configurations correctly.
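
For example, runtime configurations live on the session itself; a minimal sketch (the session and config key below are standard PySpark API used for illustration, not part of this PR):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Runtime configurations are tracked per SparkSession; code routed through
    # the deprecated SQLContext may not observe them correctly.
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    print(spark.conf.get("spark.sql.session.timeZone"))
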
This PR also contains related changes (and a bit of refactoring in the code
this PR touches) as below:
- Expose `DataFrame.sparkSession` like Scala API does.
- Move `SQLContext._conf` -> `SparkSession._jconf`.
- Rename `rdd_array` to `df_array` at `DataFrame.randomSplit`.
- Issue warnings discouraging use of `DataFrame.sql_ctx` and
`DataFrame(..., sql_ctx)` (see the sketch after this list).
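
As a rough sketch of the resulting surface (shapes follow this PR's doctests):

    df = spark.range(1)
    type(df.sparkSession)  # <class 'pyspark.sql.session.SparkSession'>, added in 3.3.0
    df.sql_ctx             # still works, but now warns and points to sparkSession
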
### Why are the changes needed?
- This is base work for making PySpark respect runtime configurations.
- To expose the same API surface as the Scala API (`df.sparkSession`).
- To avoid relying on the old `SQLContext`.
### Does this PR introduce _any_ user-facing change?
Yes.
- Issue warnings discouraging use of `DataFrame.sql_ctx` and
`DataFrame(..., sql_ctx)` (see the example below).
- New API `DataFrame.sparkSession`
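
For example, the discouraged path now warns (warning text as added in this PR; `df` is assumed to come from an active session):

    df.sql_ctx
    # UserWarning: DataFrame.sql_ctx is an internal property, and will be removed
    # in future releases. Use DataFrame.sparkSession instead.
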
### How was this patch tested?
Existing test cases should cover them.
Closes #35410 from HyukjinKwon/SPARK-38121.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/docs/source/reference/pyspark.sql.rst | 1 +
python/pyspark/ml/clustering.py | 2 +-
python/pyspark/ml/common.py | 2 +-
python/pyspark/ml/fpm.py | 2 +-
python/pyspark/ml/wrapper.py | 2 +-
python/pyspark/mllib/common.py | 2 +-
python/pyspark/pandas/internal.py | 2 +-
python/pyspark/shell.py | 3 +-
python/pyspark/sql/catalog.py | 4 +-
python/pyspark/sql/context.py | 42 +++--
python/pyspark/sql/dataframe.py | 185 ++++++++++++++-------
python/pyspark/sql/functions.py | 2 +-
python/pyspark/sql/group.py | 14 +-
python/pyspark/sql/observation.py | 2 +-
python/pyspark/sql/pandas/conversion.py | 33 ++--
python/pyspark/sql/pandas/group_ops.py | 5 +-
python/pyspark/sql/pandas/map_ops.py | 4 +-
python/pyspark/sql/readwriter.py | 12 +-
python/pyspark/sql/session.py | 53 +++---
python/pyspark/sql/streaming.py | 8 +-
python/pyspark/sql/tests/test_session.py | 16 +-
python/pyspark/sql/tests/test_streaming.py | 16 +-
python/pyspark/sql/tests/test_udf.py | 8 +-
python/pyspark/sql/tests/test_udf_profiler.py | 11 +-
python/pyspark/sql/utils.py | 8 +-
.../spark/sql/api/python/PythonSQLUtils.scala | 14 +-
.../org/apache/spark/sql/api/r/SQLUtils.scala | 4 +-
.../sql/execution/arrow/ArrowConverters.scala | 12 +-
28 files changed, 266 insertions(+), 203 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 818814c..1d34961 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -201,6 +201,7 @@ DataFrame APIs
DataFrame.show
DataFrame.sort
DataFrame.sortWithinPartitions
+ DataFrame.sparkSession
DataFrame.stat
DataFrame.storageLevel
DataFrame.subtract
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index a66d6e3..c8b4c93 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -2107,7 +2107,7 @@ class PowerIterationClustering(
assert self._java_obj is not None
jdf = self._java_obj.assignClusters(dataset._jdf)
- return DataFrame(jdf, dataset.sql_ctx)
+ return DataFrame(jdf, dataset.sparkSession)
if __name__ == "__main__":
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index 2329421..32829c4 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -108,7 +108,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt
return RDD(jrdd, sc)
if clsName == "Dataset":
- return DataFrame(r, SparkSession(sc)._wrapped)
+ return DataFrame(r, SparkSession._getActiveSessionOrCreate())
if clsName in _picklable_classes:
r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 0795ec2..b748a7d 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -510,7 +510,7 @@ class PrefixSpan(JavaParams):
self._transfer_params_to_java()
assert self._java_obj is not None
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
- return DataFrame(jdf, dataset.sql_ctx)
+ return DataFrame(jdf, dataset.sparkSession)
if __name__ == "__main__":
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7f03f64..385a243 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -393,7 +393,7 @@ class JavaTransformer(JavaParams, Transformer, metaclass=ABCMeta):
assert self._java_obj is not None
self._transfer_params_to_java()
- return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
+ return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)
@inherit_doc
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 24a3f41..00653aa 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -110,7 +110,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt
return RDD(jrdd, sc)
if clsName == "Dataset":
- return DataFrame(r, SparkSession(sc)._wrapped)
+ return DataFrame(r, SparkSession._getActiveSessionOrCreate())
if clsName in _picklable_classes:
r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 1c32c43..71f8f6e 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -906,7 +906,7 @@ class InternalFrame:
if len(sdf.columns) > 0:
return SparkDataFrame(
sdf._jdf.toDF().withSequenceColumn(column_name), # type: ignore[operator]
- sdf.sql_ctx,
+ sdf.sparkSession,
)
else:
cnt = sdf.count()
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 4164e3a..e0a8c06 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -28,6 +28,7 @@ import warnings
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
+from pyspark.sql.context import SQLContext
if os.environ.get("SPARK_EXECUTOR_URI"):
SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
@@ -49,7 +50,7 @@ sql = spark.sql
atexit.register((lambda sc: lambda: sc.stop())(sc))
# for compatibility
-sqlContext = spark._wrapped
+sqlContext = SQLContext._get_or_create(sc)
sqlCtx = sqlContext
print(
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index ea8bb97..3ececfa 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -345,7 +345,7 @@ class Catalog:
if path is not None:
options["path"] = path
if source is None:
- c = self._sparkSession._wrapped._conf
+ c = self._sparkSession._jconf
source = c.defaultDataSourceName() # type: ignore[attr-defined]
if description is None:
description = ""
@@ -356,7 +356,7 @@ class Catalog:
raise TypeError("schema should be StructType")
scala_datatype = self._jsparkSession.parseDataType(schema.json())
df = self._jcatalog.createTable(tableName, source, scala_datatype, description, options)
- return DataFrame(df, self._sparkSession._wrapped)
+ return DataFrame(df, self._sparkSession)
def dropTempView(self, viewName: str) -> None:
"""Drops the local temporary view with the given view name in the catalog.
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 6ab70ee..6f94e9a 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -46,7 +46,6 @@ from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.sql.types import AtomicType, DataType, StructType
from pyspark.sql.streaming import StreamingQueryManager
-from pyspark.conf import SparkConf
if TYPE_CHECKING:
from pyspark.sql._typing import (
@@ -121,7 +120,7 @@ class SQLContext:
if sparkSession is None:
sparkSession = SparkSession._getActiveSessionOrCreate()
if jsqlContext is None:
- jsqlContext = sparkSession._jwrapped
+ jsqlContext = sparkSession._jsparkSession.sqlContext()
self.sparkSession = sparkSession
self._jsqlContext = jsqlContext
_monkey_patch_RDD(self.sparkSession)
@@ -141,11 +140,6 @@ class SQLContext:
"""
return self._jsqlContext
- @property
- def _conf(self) -> SparkConf:
- """Accessor for the JVM SQL-specific configurations"""
- return self.sparkSession._jsparkSession.sessionState().conf()
-
@classmethod
def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext":
"""
@@ -164,17 +158,20 @@ class SQLContext:
"Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.",
FutureWarning,
)
+ return cls._get_or_create(sc)
+
+ @classmethod
+ def _get_or_create(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext":
if (
cls._instantiatedContext is None
or SQLContext._instantiatedContext._sc._jsc is None # type: ignore[union-attr]
):
assert sc._jvm is not None
- jsqlContext = (
- sc._jvm.SparkSession.builder().sparkContext(sc._jsc.sc()).getOrCreate().sqlContext()
- )
- sparkSession = SparkSession(sc, jsqlContext.sparkSession())
- cls(sc, sparkSession, jsqlContext)
+ # There can be only one running Spark context. That will automatically
+ # be used in the Spark session internally.
+ session = SparkSession._getActiveSessionOrCreate()
+ cls(sc, session, session._jsparkSession.sqlContext())
return cast(SQLContext, cls._instantiatedContext)
def newSession(self) -> "SQLContext":
@@ -590,9 +587,9 @@ class SQLContext:
Row(namespace='', tableName='table1', isTemporary=True)
"""
if dbName is None:
- return DataFrame(self._ssql_ctx.tables(), self)
+ return DataFrame(self._ssql_ctx.tables(), self.sparkSession)
else:
- return DataFrame(self._ssql_ctx.tables(dbName), self)
+ return DataFrame(self._ssql_ctx.tables(dbName), self.sparkSession)
def tableNames(self, dbName: Optional[str] = None) -> List[str]:
"""Returns a list of names of tables in the database ``dbName``.
@@ -647,7 +644,7 @@ class SQLContext:
-------
:class:`DataFrameReader`
"""
- return DataFrameReader(self)
+ return DataFrameReader(self.sparkSession)
@property
def readStream(self) -> DataStreamReader:
@@ -669,7 +666,7 @@ class SQLContext:
>>> text_sdf.isStreaming
True
"""
- return DataStreamReader(self)
+ return DataStreamReader(self.sparkSession)
@property
def streams(self) -> StreamingQueryManager:
@@ -714,14 +711,13 @@ class HiveContext(SQLContext):
+ "SparkSession.builder.enableHiveSupport().getOrCreate() instead.",
FutureWarning,
)
+ static_conf = {}
if jhiveContext is None:
- sparkContext._conf.set( # type: ignore[attr-defined]
- "spark.sql.catalogImplementation", "hive"
- )
- sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()
- else:
- sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
- SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
+ static_conf = {"spark.sql.catalogImplementation": "in-memory"}
+ # There can be only one running Spark context. That will automatically
+ # be used in the Spark session internally.
+ session = SparkSession._getActiveSessionOrCreate(**static_conf)
+ SQLContext.__init__(self, sparkContext, session, jhiveContext)
@classmethod
def _createForTesting(cls, sparkContext: SparkContext) -> "HiveContext":
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0372527..610b8d6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -16,6 +16,7 @@
#
import json
+import os
import sys
import random
import warnings
@@ -70,6 +71,7 @@ if TYPE_CHECKING:
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
from pyspark.sql.group import GroupedData
from pyspark.sql.observation import Observation
@@ -102,12 +104,34 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
.groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
.. versionadded:: 1.3.0
+
+ .. note: A DataFrame should only be created as described above. It should not be directly
+ created via using the constructor.
"""
- def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
- self._jdf = jdf
- self.sql_ctx = sql_ctx
- self._sc: SparkContext = cast(SparkContext, sql_ctx and sql_ctx._sc)
+ def __init__(
+ self,
+ jdf: JavaObject,
+ sql_ctx: Union["SQLContext", "SparkSession"],
+ ):
+ from pyspark.sql.context import SQLContext
+
+ self._session: Optional["SparkSession"] = None
+ self._sql_ctx: Optional["SQLContext"] = None
+
+ if isinstance(sql_ctx, SQLContext):
+ assert not os.environ.get("SPARK_TESTING") # Sanity check for our internal usage.
+ assert isinstance(sql_ctx, SQLContext)
+ # We should remove this if-else branch in the future release, and rename
+ # sql_ctx to session in the constructor. This is an internal code path but
+ # was kept with an warning because it's used intensively by third-party libraries.
+ warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+ self._sql_ctx = sql_ctx
+ else:
+ self._session = sql_ctx
+
+ self._sc: SparkContext = sql_ctx._sc
+ self._jdf: JavaObject = jdf
self.is_cached = False
# initialized lazily
self._schema: Optional[StructType] = None
@@ -116,13 +140,45 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
# by __repr__ and _repr_html_ while eager evaluation opened.
self._support_repr_html = False
+ @property
+ def sql_ctx(self) -> "SQLContext":
+ from pyspark.sql.context import SQLContext
+
+ warnings.warn(
+ "DataFrame.sql_ctx is an internal property, and will be removed "
+ "in future releases. Use DataFrame.sparkSession instead."
+ )
+ if self._sql_ctx is None:
+ self._sql_ctx = SQLContext._get_or_create(self._sc)
+ return self._sql_ctx
+
+ @property # type: ignore[misc]
+ def sparkSession(self) -> "SparkSession":
+ """Returns Spark session that created this :class:`DataFrame`.
+
+ .. versionadded:: 3.3.0
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> type(df.sparkSession)
+ <class 'pyspark.sql.session.SparkSession'>
+ """
+ from pyspark.sql.session import SparkSession
+
+ if self._session is None:
+ self._session = SparkSession._getActiveSessionOrCreate()
+ return self._session
+
@property # type: ignore[misc]
@since(1.3)
def rdd(self) -> "RDD[Row]":
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`."""
if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
- self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(CPickleSerializer()))
+ self._lazy_rdd = RDD(
+ jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+ )
return self._lazy_rdd
@property # type: ignore[misc]
@@ -456,7 +512,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
+---+---+
"""
- return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
@since(1.3)
def isLocal(self) -> bool:
@@ -563,12 +619,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
def __repr__(self) -> str:
if (
not self._support_repr_html
- and self.sql_ctx._conf.isReplEagerEvalEnabled() # type: ignore[attr-defined]
+ and self.sparkSession._jconf.isReplEagerEvalEnabled() # type: ignore[attr-defined]
):
vertical = False
return self._jdf.showString(
- self.sql_ctx._conf.replEagerEvalMaxNumRows(), # type: ignore[attr-defined]
- self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined]
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(), # type: ignore[attr-defined]
+ self.sparkSession._jconf.replEagerEvalTruncate(), # type: ignore[attr-defined]
vertical,
) # type: ignore[attr-defined]
else:
@@ -581,13 +637,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if not self._support_repr_html:
self._support_repr_html = True
- if self.sql_ctx._conf.isReplEagerEvalEnabled(): # type: ignore[attr-defined]
+ if self.sparkSession._jconf.isReplEagerEvalEnabled(): # type: ignore[attr-defined]
max_num_rows = max(
- self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0 # type: ignore[attr-defined]
+ self.sparkSession._jconf.replEagerEvalMaxNumRows(), 0 # type: ignore[attr-defined]
)
sock_info = self._jdf.getRowsToPython(
max_num_rows,
- self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined]
+ self.sparkSession._jconf.replEagerEvalTruncate(), # type: ignore[attr-defined]
)
rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
head = rows[0]
@@ -631,7 +687,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
This API is experimental.
"""
jdf = self._jdf.checkpoint(eager)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def localCheckpoint(self, eager: bool = True) -> "DataFrame":
"""Returns a locally checkpointed version of this :class:`DataFrame`. Checkpointing can be
@@ -651,7 +707,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
This API is experimental.
"""
jdf = self._jdf.localCheckpoint(eager)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame":
"""Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
@@ -695,7 +751,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if not delayThreshold or type(delayThreshold) is not str:
raise TypeError("delayThreshold should be provided as a string interval")
jdf = self._jdf.withWatermark(eventTime, delayThreshold)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def hint(
self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
@@ -740,7 +796,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
)
jdf = self._jdf.hint(name, self._jseq(parameters))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def count(self) -> int:
"""Returns the number of rows in this :class:`DataFrame`.
@@ -804,7 +860,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[]
"""
jdf = self._jdf.limit(num)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def take(self, num: int) -> List[Row]:
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
@@ -970,7 +1026,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
>>> df.coalesce(1).rdd.getNumPartitions()
1
"""
- return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
+ return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
@@ -1041,14 +1097,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if isinstance(numPartitions, int):
if len(cols) == 0:
- return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
+ return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession)
else:
return DataFrame(
- self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx
+ self._jdf.repartition(numPartitions, self._jcols(*cols)),
+ self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
- return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
+ return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession)
else:
raise TypeError("numPartitions should be an int or Column")
@@ -1115,11 +1172,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise ValueError("At least one partition-by expression must be specified.")
else:
return DataFrame(
- self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx
+ self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)),
+ self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
- return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx)
+ return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession)
else:
raise TypeError("numPartitions should be an int, string or Column")
@@ -1133,7 +1191,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
>>> df.distinct().count()
2
"""
- return DataFrame(self._jdf.distinct(), self.sql_ctx)
+ return DataFrame(self._jdf.distinct(), self.sparkSession)
@overload
def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame":
@@ -1228,7 +1286,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
seed = int(seed) if seed is not None else None
args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
jdf = self._jdf.sample(*args)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
@@ -1283,7 +1341,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
fractions[k] = float(v)
col = col._jc
seed = seed if seed is not None else random.randint(0, sys.maxsize)
- return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+ return DataFrame(
+ self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession
+ )
def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List["DataFrame"]:
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1311,10 +1371,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if w < 0.0:
raise ValueError("Weights must be positive. Found weight value: %s" % w)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
- rdd_array = self._jdf.randomSplit(
- _to_list(self.sql_ctx._sc, cast(List["ColumnOrName"], weights)), int(seed)
+ df_array = self._jdf.randomSplit(
+ _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed)
)
- return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+ return [DataFrame(df, self.sparkSession) for df in df_array]
@property
def dtypes(self) -> List[Tuple[str, str]]:
@@ -1392,7 +1452,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)]
"""
assert isinstance(alias, str), "alias should be a string"
- return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
+ return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession)
def crossJoin(self, other: "DataFrame") -> "DataFrame":
"""Returns the cartesian product with another :class:`DataFrame`.
@@ -1416,7 +1476,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
jdf = self._jdf.crossJoin(other._jdf)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def join(
self,
@@ -1486,7 +1546,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
on = self._jseq([])
assert isinstance(how, str), "how should be a string"
jdf = self._jdf.join(other._jdf, on, how)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
@@ -1607,7 +1667,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
allowExactMatches,
direction,
)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def sortWithinPartitions(
self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
@@ -1639,7 +1699,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
+---+-----+
"""
jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def sort(
self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
@@ -1677,7 +1737,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[Row(age=5, name='Bob'), Row(age=2, name='Alice')]
"""
jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
orderBy = sort
@@ -1687,11 +1747,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None,
) -> JavaObject:
"""Return a JVM Seq of Columns from a list of Column or names"""
- return _to_seq(self.sql_ctx._sc, cols, converter)
+ return _to_seq(self.sparkSession._sc, cols, converter)
def _jmap(self, jm: Dict) -> JavaObject:
"""Return a JVM Scala Map from a dict"""
- return _to_scala_map(self.sql_ctx._sc, jm)
+ return _to_scala_map(self.sparkSession._sc, jm)
def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
"""Return a JVM Seq of Columns from a list of Column or column names
@@ -1767,7 +1827,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
jdf = self._jdf.describe(self._jseq(cols))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def summary(self, *statistics: str) -> "DataFrame":
"""Computes specified statistics for numeric and string columns. Available statistics are:
@@ -1832,7 +1892,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if len(statistics) == 1 and isinstance(statistics[0], list):
statistics = statistics[0]
jdf = self._jdf.summary(self._jseq(statistics))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
@overload
def head(self) -> Optional[Row]:
@@ -1970,7 +2030,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[Row(name='Alice', age=12), Row(name='Bob', age=15)]
"""
jdf = self._jdf.select(self._jcols(*cols))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
@overload
def selectExpr(self, *expr: str) -> "DataFrame":
@@ -1995,7 +2055,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0] # type: ignore[assignment]
jdf = self._jdf.selectExpr(self._jseq(expr))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def filter(self, condition: "ColumnOrName") -> "DataFrame":
"""Filters rows using the given condition.
@@ -2028,7 +2088,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.filter(condition._jc)
else:
raise TypeError("condition should be string or Column")
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
@overload
def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
@@ -2203,7 +2263,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
Also as standard in SQL, this function resolves columns by position (not by name).
"""
- return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.union(other._jdf), self.sparkSession)
@since(1.3)
def unionAll(self, other: "DataFrame") -> "DataFrame":
@@ -2260,7 +2320,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
Added optional argument `allowMissingColumns` to specify whether to allow
missing columns.
"""
- return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sql_ctx)
+ return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession)
@since(1.3)
def intersect(self, other: "DataFrame") -> "DataFrame":
@@ -2269,7 +2329,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
This is equivalent to `INTERSECT` in SQL.
"""
- return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession)
def intersectAll(self, other: "DataFrame") -> "DataFrame":
"""Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame`
@@ -2295,7 +2355,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
+---+---+
"""
- return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
+ return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession)
@since(1.3)
def subtract(self, other: "DataFrame") -> "DataFrame":
@@ -2305,7 +2365,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
This is equivalent to `EXCEPT DISTINCT` in SQL.
"""
- return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession)
def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
"""Return a new :class:`DataFrame` with duplicate rows removed,
@@ -2350,7 +2410,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.dropDuplicates()
else:
jdf = self._jdf.dropDuplicates(self._jseq(subset))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def dropna(
self,
@@ -2398,7 +2458,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if thresh is None:
thresh = len(subset) if how == "any" else 1
- return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
+ return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)
@overload
def fillna(
@@ -2476,16 +2536,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
value = float(value)
if isinstance(value, dict):
- return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+ return DataFrame(self._jdf.na().fill(value), self.sparkSession)
elif subset is None:
- return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+ return DataFrame(self._jdf.na().fill(value), self.sparkSession)
else:
if isinstance(subset, str):
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise TypeError("subset should be a list or tuple of column names")
- return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
+ return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession)
@overload
def replace(
@@ -2686,10 +2746,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise ValueError("Mixed type replacements are not supported")
if subset is None:
- return DataFrame(self._jdf.na().replace("*", rep_dict), self.sql_ctx)
+ return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession)
else:
return DataFrame(
- self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx
+ self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
+ self.sparkSession,
)
@overload
@@ -2875,7 +2936,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("col1 should be a string.")
if not isinstance(col2, str):
raise TypeError("col2 should be a string.")
- return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
+ return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession)
def freqItems(
self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
@@ -2909,7 +2970,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("cols must be a list or tuple of column names as strings.")
if not support:
support = 0.01
- return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
+ return DataFrame(
+ self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sparkSession
+ )
def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
"""
@@ -2978,7 +3041,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if not isinstance(col, Column):
raise TypeError("col should be Column")
- return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
+ return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
"""Returns a new :class:`DataFrame` by renaming an existing column.
@@ -2998,7 +3061,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name='Alice'), Row(age2=5, name='Bob')]
"""
- return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
+ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)
def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
"""Returns a new :class:`DataFrame` by updating an existing column with metadata.
@@ -3023,7 +3086,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
sc = SparkContext._active_spark_context
assert sc is not None and sc._jvm is not None
jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
- return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sql_ctx)
+ return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
@overload
def drop(self, cols: "ColumnOrName") -> "DataFrame":
@@ -3075,7 +3138,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("each col in the param list should be a string")
jdf = self._jdf.drop(self._jseq(cols))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def toDF(self, *cols: "ColumnOrName") -> "DataFrame":
"""Returns a new :class:`DataFrame` that with new specified column names
@@ -3091,7 +3154,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')]
"""
jdf = self._jdf.toDF(self._jseq(cols))
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
"""Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2dfaec8..d79da47 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1105,7 +1105,7 @@ def broadcast(df: DataFrame) -> DataFrame:
sc = SparkContext._active_spark_context
assert sc is not None and sc._jvm is not None
- return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx)
+ return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession)
def coalesce(*cols: "ColumnOrName") -> Column:
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 485e017..802d34d0 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -22,7 +22,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Unio
from py4j.java_gateway import JavaObject # type: ignore[import]
from pyspark.sql.column import Column, _to_seq
-from pyspark.sql.context import SQLContext
+from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
@@ -37,7 +37,7 @@ def dfapi(f: Callable) -> Callable:
def _api(self: "GroupedData") -> DataFrame:
name = f.__name__
jdf = getattr(self._jgd, name)()
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.session)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@@ -47,8 +47,8 @@ def dfapi(f: Callable) -> Callable:
def df_varargs_api(f: Callable) -> Callable:
def _api(self: "GroupedData", *cols: str) -> DataFrame:
name = f.__name__
- jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
- return DataFrame(jdf, self.sql_ctx)
+ jdf = getattr(self._jgd, name)(_to_seq(self.session._sc, cols))
+ return DataFrame(jdf, self.session)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@@ -66,7 +66,7 @@ class GroupedData(PandasGroupedOpsMixin):
def __init__(self, jgd: JavaObject, df: DataFrame):
self._jgd = jgd
self._df = df
- self.sql_ctx: SQLContext = df.sql_ctx
+ self.session: SparkSession = df.sparkSession
@overload
def agg(self, *exprs: Column) -> DataFrame:
@@ -134,8 +134,8 @@ class GroupedData(PandasGroupedOpsMixin):
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
exprs = cast(Tuple[Column, ...], exprs)
- jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
- return DataFrame(jdf, self.sql_ctx)
+ jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.session._sc, [c._jc for c in exprs[1:]]))
+ return DataFrame(jdf, self.session)
@dfapi
def count(self) -> DataFrame:
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index e5d426a..951b0f4 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -109,7 +109,7 @@ class Observation:
observed_df = self._jo.on(
df._jdf, exprs[0]._jc, column._to_seq(df._sc, [c._jc for c in exprs[1:]])
)
- return DataFrame(observed_df, df.sql_ctx)
+ return DataFrame(observed_df, df.sparkSession)
@property
def get(self) -> Dict[str, Any]:
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 33a4058..fbb5183 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -43,6 +43,7 @@ from pyspark.traceback_utils import SCCallSiteSync
if TYPE_CHECKING:
import numpy as np
import pyarrow as pa
+ from py4j.java_gateway import JavaObject
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
from pyspark.sql import DataFrame
@@ -88,9 +89,10 @@ class PandasConversionMixin:
import pandas as pd
from pandas.core.dtypes.common import is_timedelta64_dtype
- timezone = self.sql_ctx._conf.sessionLocalTimeZone() # type: ignore[attr-defined]
+ jconf = self.sparkSession._jconf
+ timezone = jconf.sessionLocalTimeZone()
- if self.sql_ctx._conf.arrowPySparkEnabled(): # type: ignore[attr-defined]
+ if jconf.arrowPySparkEnabled(): # type: ignore[attr-defined]
use_arrow = True
try:
from pyspark.sql.pandas.types import to_arrow_schema
@@ -100,7 +102,7 @@ class PandasConversionMixin:
to_arrow_schema(self.schema)
except Exception as e:
- if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
+ if jconf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -134,7 +136,7 @@ class PandasConversionMixin:
# Rename columns to avoid duplicated column names.
tmp_column_names = ["col_{}".format(i) for i in range(len(self.columns))]
- c = self.sql_ctx._conf
+ c = self.sparkSession._jconf
self_destruct = (
c.arrowPySparkSelfDestructEnabled() # type: ignore[attr-defined]
)
@@ -368,6 +370,8 @@ class SparkConversionMixin:
can use this class.
"""
+ _jsparkSession: "JavaObject"
+
@overload
def createDataFrame(
self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
@@ -398,20 +402,17 @@ class SparkConversionMixin:
require_minimum_pandas_version()
- timezone = self._wrapped._conf.sessionLocalTimeZone() # type: ignore[attr-defined]
+ timezone = self._jconf.sessionLocalTimeZone() # type: ignore[attr-defined]
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) if not isinstance(x, str) else x for x in data.columns]
- if (
- self._wrapped._conf.arrowPySparkEnabled() # type: ignore[attr-defined]
- and len(data) > 0
- ):
+ if self._jconf.arrowPySparkEnabled() and len(data) > 0: # type: ignore[attr-defined]
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
- if self._wrapped._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
+ if self._jconf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -603,25 +604,25 @@ class SparkConversionMixin:
for pdf_slice in pdf_slices
]
- jsqlContext = self._wrapped._jsqlContext # type: ignore[attr-defined]
+ jsparkSession = self._jsparkSession
- safecheck = self._wrapped._conf.arrowSafeTypeConversion() # type: ignore[attr-defined]
+ safecheck = self._jconf.arrowSafeTypeConversion() # type: ignore[attr-defined]
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
@no_type_check
def reader_func(temp_filename):
- return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
+ return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
@no_type_check
def create_RDD_server():
- return self._jvm.ArrowRDDServer(jsqlContext)
+ return self._jvm.ArrowRDDServer(jsparkSession)
# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
assert self._jvm is not None
- jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
- df = DataFrame(jdf, self._wrapped)
+ jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
+ df = DataFrame(jdf, self)
df._schema = schema
return df
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 35f531f..e7599b1 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -214,7 +214,7 @@ class PandasGroupedOpsMixin:
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) # type: ignore[attr-defined]
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.session)
def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
@@ -246,7 +246,6 @@ class PandasCogroupedOps:
def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
self._gd1 = gd1
self._gd2 = gd2
- self.sql_ctx = gd1.sql_ctx
def applyInPandas(
self, func: "PandasCogroupedMapFunction", schema: Union[StructType, str]
@@ -345,7 +344,7 @@ class PandasCogroupedOps:
jdf = self._gd1._jgd.flatMapCoGroupsInPandas( # type: ignore[attr-defined]
self._gd2._jgd, udf_column._jc.expr() # type: ignore[attr-defined]
)
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self._gd1.session)
@staticmethod
def _extract_cols(gd: "GroupedData") -> List[Column]:
diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py
index c1c29ec..c1bf6aa 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -90,7 +90,7 @@ class PandasMapOpsMixin:
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr()) # type: ignore[operator]
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
@@ -153,7 +153,7 @@ class PandasMapOpsMixin:
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
- return DataFrame(jdf, self.sql_ctx)
+ return DataFrame(jdf, self.sparkSession)
def _test() -> None:
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index df4a089..8c729c6 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,7 +27,7 @@ from pyspark.sql.utils import to_str
if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType, ColumnOrName
- from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.streaming import StreamingQuery
@@ -62,8 +62,8 @@ class DataFrameReader(OptionUtils):
.. versionadded:: 1.4
"""
- def __init__(self, spark: "SQLContext"):
- self._jreader = spark._ssql_ctx.read() # type: ignore[attr-defined]
+ def __init__(self, spark: "SparkSession"):
+ self._jreader = spark._jsparkSession.read() # type: ignore[attr-defined]
self._spark = spark
def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -560,7 +560,7 @@ class DataFrameReader(OptionUtils):
# There aren't any jvm api for creating a dataframe from rdd storing csv.
# We can do it through creating a jvm dataset firstly and using the jvm api
# for creating a dataframe from dataset storing csv.
- jdataset = self._spark._ssql_ctx.createDataset(
+ jdataset = self._spark._jsparkSession.createDataset(
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.csv(jdataset))
@@ -737,7 +737,7 @@ class DataFrameWriter(OptionUtils):
def __init__(self, df: "DataFrame"):
self._df = df
- self._spark = df.sql_ctx
+ self._spark = df.sparkSession
self._jwrite = df._jdf.write() # type: ignore[operator]
def _sq(self, jsq: JavaObject) -> "StreamingQuery":
@@ -1360,7 +1360,7 @@ class DataFrameWriterV2:
def __init__(self, df: "DataFrame", table: str):
self._df = df
- self._spark = df.sql_ctx
+ self._spark = df.sparkSession
self._jwriter = df._jdf.writeTo(table) # type: ignore[operator]
@since(3.1)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 233d529..a41ad15 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -230,11 +230,6 @@ class SparkSession(SparkConversionMixin):
"""
return self.config("spark.sql.catalogImplementation", "hive")
- def _sparkContext(self, sc: SparkContext) -> "SparkSession.Builder":
- with self._lock:
- self._sc = sc
- return self
-
def getOrCreate(self) -> "SparkSession":
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
new one based on the options set in this builder.
@@ -267,14 +262,11 @@ class SparkSession(SparkConversionMixin):
session = SparkSession._instantiatedSession
if session is None or session._sc._jsc is None: # type: ignore[attr-defined]
- if self._sc is not None:
- sc = self._sc
- else:
- sparkConf = SparkConf()
- for key, value in self._options.items():
- sparkConf.set(key, value)
- # This SparkContext may be an existing one.
- sc = SparkContext.getOrCreate(sparkConf)
+ sparkConf = SparkConf()
+ for key, value in self._options.items():
+ sparkConf.set(key, value)
+ # This SparkContext may be an existing one.
+ sc = SparkContext.getOrCreate(sparkConf)
# Do not update `SparkConf` for existing `SparkContext`, as it's shared
# by all sessions.
session = SparkSession(sc, options=self._options)
@@ -296,8 +288,6 @@ class SparkSession(SparkConversionMixin):
jsparkSession: Optional[JavaObject] = None,
options: Dict[str, Any] = {},
):
- from pyspark.sql.context import SQLContext
-
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
@@ -320,8 +310,6 @@ class SparkSession(SparkConversionMixin):
jsparkSession, options
)
self._jsparkSession = jsparkSession
- self._jwrapped = self._jsparkSession.sqlContext()
- self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()
# If we had an instantiated SparkSession attached with a SparkContext
@@ -348,6 +336,11 @@ class SparkSession(SparkConversionMixin):
sc_HTML=self.sparkContext._repr_html_(), # type: ignore[attr-defined]
)
+ @property
+ def _jconf(self) -> "JavaObject":
+ """Accessor for the JVM SQL-specific configurations"""
+ return self._jsparkSession.sessionState().conf()
+
@since(2.0)
def newSession(self) -> "SparkSession":
"""
@@ -498,7 +491,7 @@ class SparkSession(SparkConversionMixin):
else:
jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))
- return DataFrame(jdf, self._wrapped)
+ return DataFrame(jdf, self)
def _inferSchemaFromList(
self, data: Iterable[Any], names: Optional[List[str]] = None
@@ -519,7 +512,7 @@ class SparkSession(SparkConversionMixin):
"""
if not data:
raise ValueError("can not infer schema from empty dataset")
- infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct() # type: ignore[attr-defined]
+ infer_dict_as_struct = self._jconf.inferDictAsStruct() # type: ignore[attr-defined]
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
schema = reduce(
_merge_type,
@@ -554,7 +547,7 @@ class SparkSession(SparkConversionMixin):
if not first:
raise ValueError("The first row in RDD is empty, " "can not infer schema")
- infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct() # type: ignore[attr-defined]
+ infer_dict_as_struct = self._jconf.inferDictAsStruct() # type: ignore[attr-defined]
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
if samplingRatio is None:
schema = _infer_schema(
@@ -684,14 +677,20 @@ class SparkSession(SparkConversionMixin):
return SparkSession._getActiveSessionOrCreate()
@staticmethod
- def _getActiveSessionOrCreate() -> "SparkSession":
+ def _getActiveSessionOrCreate(**static_conf: Any) -> "SparkSession":
"""
Returns the active :class:`SparkSession` for the current thread, returned by the builder,
or if there is no existing one, creates a new one based on the options set in the builder.
+
+ NOTE that 'static_conf' might not be set if there's an active or default Spark session
+ running.
"""
spark = SparkSession.getActiveSession()
if spark is None:
- spark = SparkSession.builder.getOrCreate()
+ builder = SparkSession.builder
+ for k, v in static_conf.items():
+ builder = builder.config(k, v)
+ spark = builder.getOrCreate()
return spark
@overload
@@ -940,7 +939,7 @@ class SparkSession(SparkConversionMixin):
rdd._to_java_object_rdd() # type: ignore[attr-defined]
)
jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json())
- df = DataFrame(jdf, self._wrapped)
+ df = DataFrame(jdf, self)
df._schema = struct
return df
@@ -1034,7 +1033,7 @@ class SparkSession(SparkConversionMixin):
if len(kwargs) > 0:
sqlQuery = formatter.format(sqlQuery, **kwargs)
try:
- return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+ return DataFrame(self._jsparkSession.sql(sqlQuery), self)
finally:
if len(kwargs) > 0:
formatter.clear()
@@ -1055,7 +1054,7 @@ class SparkSession(SparkConversionMixin):
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return DataFrame(self._jsparkSession.table(tableName), self._wrapped)
+ return DataFrame(self._jsparkSession.table(tableName), self)
@property
def read(self) -> DataFrameReader:
@@ -1069,7 +1068,7 @@ class SparkSession(SparkConversionMixin):
-------
:class:`DataFrameReader`
"""
- return DataFrameReader(self._wrapped)
+ return DataFrameReader(self)
@property
def readStream(self) -> DataStreamReader:
@@ -1087,7 +1086,7 @@ class SparkSession(SparkConversionMixin):
-------
:class:`DataStreamReader`
"""
- return DataStreamReader(self._wrapped)
+ return DataStreamReader(self)
@property
def streams(self) -> "StreamingQueryManager":
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index de68ccc..7cff8d0 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -29,7 +29,7 @@ from pyspark.sql.types import Row, StructType, StructField, StringType
from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
if TYPE_CHECKING:
- from pyspark.sql import SQLContext
+ from pyspark.sql.session import SparkSession
from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
from pyspark.sql.dataframe import DataFrame
@@ -316,8 +316,8 @@ class DataStreamReader(OptionUtils):
This API is evolving.
"""
- def __init__(self, spark: "SQLContext") -> None:
- self._jreader = spark._ssql_ctx.readStream()
+ def __init__(self, spark: "SparkSession") -> None:
+ self._jreader = spark._jsparkSession.readStream()
self._spark = spark
def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -856,7 +856,7 @@ class DataStreamWriter:
def __init__(self, df: "DataFrame") -> None:
self._df = df
- self._spark = df.sql_ctx
+ self._spark = df.sparkSession
self._jwrite = df._jdf.writeStream()
def _sq(self, jsq: JavaObject) -> StreamingQuery:
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 1262e52..91aa923 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -224,28 +224,26 @@ class SparkSessionTests5(unittest.TestCase):
def test_sqlcontext_with_stopped_sparksession(self):
# SPARK-30856: test that SQLContext.getOrCreate() returns a usable instance after
# the SparkSession is restarted.
- sql_context = self.spark._wrapped
+ sql_context = SQLContext.getOrCreate(self.spark.sparkContext)
self.spark.stop()
- sc = SparkContext("local[4]", self.sc.appName)
- spark = SparkSession(sc) # Instantiate the underlying SQLContext
- new_sql_context = spark._wrapped
+ spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate()
+ new_sql_context = SQLContext.getOrCreate(spark.sparkContext)
self.assertIsNot(new_sql_context, sql_context)
- self.assertIs(SQLContext.getOrCreate(sc).sparkSession, spark)
+ self.assertIs(SQLContext.getOrCreate(spark.sparkContext).sparkSession, spark)
try:
df = spark.createDataFrame([(1, 2)], ["c", "c"])
df.collect()
finally:
spark.stop()
self.assertIsNone(SQLContext._instantiatedContext)
- sc.stop()
def test_sqlcontext_with_stopped_sparkcontext(self):
# SPARK-30856: test initialization via SparkSession when only the SparkContext is stopped
self.sc.stop()
- self.sc = SparkContext("local[4]", self.sc.appName)
- self.spark = SparkSession(self.sc)
- self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, self.spark)
+ spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate()
+ self.sc = spark.sparkContext
+ self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, spark)
def test_get_sqlcontext_with_stopped_sparkcontext(self):
# SPARK-30856: test initialization via SQLContext.getOrCreate() when only the SparkContext
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index 87e3564..4920423 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -86,7 +86,7 @@ class StreamingTests(ReusedSQLTestCase):
.load("python/test_support/sql/streaming")
.withColumn("id", lit(1))
)
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -117,7 +117,7 @@ class StreamingTests(ReusedSQLTestCase):
def test_stream_save_options_overwrite(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -154,7 +154,7 @@ class StreamingTests(ReusedSQLTestCase):
def test_stream_status_and_progress(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -198,7 +198,7 @@ class StreamingTests(ReusedSQLTestCase):
def test_stream_await_termination(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -267,7 +267,7 @@ class StreamingTests(ReusedSQLTestCase):
def test_query_manager_await_termination(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -280,13 +280,13 @@ class StreamingTests(ReusedSQLTestCase):
try:
self.assertTrue(q.isActive)
try:
- self.spark._wrapped.streams.awaitAnyTermination("hello")
+ self.spark.streams.awaitAnyTermination("hello")
self.fail("Expected a value exception")
except ValueError:
pass
now = time.time()
# test should take at least 2 seconds
- res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
+ res = self.spark.streams.awaitAnyTermination(2.6)
duration = time.time() - now
self.assertTrue(duration >= 2)
self.assertFalse(res)
@@ -347,7 +347,7 @@ class StreamingTests(ReusedSQLTestCase):
self.stop_all()
def stop_all(self):
- for q in self.spark._wrapped.streams.active:
+ for q in self.spark.streams.active:
q.stop()
def _reset(self):
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index a092d67..0e9d766 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -22,7 +22,7 @@ import tempfile
import unittest
import datetime
-from pyspark import SparkContext
+from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import udf, assert_true, lit
from pyspark.sql.udf import UserDefinedFunction
@@ -79,7 +79,7 @@ class UDFTests(ReusedSQLTestCase):
self.assertEqual(row[0], 5)
# This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
- sqlContext = self.spark._wrapped
+ sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
[row] = sqlContext.sql("SELECT oneArg('test')").collect()
self.assertEqual(row[0], 4)
@@ -372,7 +372,7 @@ class UDFTests(ReusedSQLTestCase):
)
# This is to check if a 'SQLContext.udf' can call its alias.
- sqlContext = self.spark._wrapped
+ sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
self.assertListEqual(
@@ -419,7 +419,7 @@ class UDFTests(ReusedSQLTestCase):
)
# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
- sqlContext = spark._wrapped
+ sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
self.assertRaisesRegex(
AnalysisException,
"Can not load class non_existed_udf",
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index 27d9458..136f423 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -21,7 +21,7 @@ import os
import sys
from io import StringIO
-from pyspark import SparkConf, SparkContext
+from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.profiler import UDFBasicProfiler
@@ -32,8 +32,13 @@ class UDFProfilerTests(unittest.TestCase):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
- self.sc = SparkContext("local[4]", class_name, conf=conf)
- self.spark = SparkSession.builder._sparkContext(self.sc).getOrCreate()
+ self.spark = (
+ SparkSession.builder.master("local[4]")
+ .config(conf=conf)
+ .appName(class_name)
+ .getOrCreate()
+ )
+ self.sc = self.spark.sparkContext
def tearDown(self):
self.spark.stop()
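The profiler test now builds the session first and derives the SparkContext from it, passing the whole SparkConf through the builder. Roughly (a sketch; the config key comes from the hunk above):

    from pyspark import SparkConf
    from pyspark.sql import SparkSession

    conf = SparkConf().set("spark.python.profile", "true")
    spark = (
        SparkSession.builder.master("local[4]")
        .config(conf=conf)  # the builder accepts a full SparkConf
        .appName("profiler-demo")
        .getOrCreate()
    )
    sc = spark.sparkContext  # derived from the session, not created separately
    spark.stop()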
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 15645d0..b5abe68 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -30,7 +30,7 @@ from pyspark import SparkContext
from pyspark.find_spark_home import _find_spark_home
if TYPE_CHECKING:
- from pyspark.sql.context import SQLContext
+ from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
@@ -258,15 +258,15 @@ class ForeachBatchFunction:
the query is active.
"""
- def __init__(self, sql_ctx: "SQLContext", func: Callable[["DataFrame", int], None]):
- self.sql_ctx = sql_ctx
+ def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
self.func = func
+ self.session = session
def call(self, jdf: JavaObject, batch_id: int) -> None:
from pyspark.sql.dataframe import DataFrame
try:
- self.func(DataFrame(jdf, self.sql_ctx), batch_id)
+ self.func(DataFrame(jdf, self.session), batch_id)
except Exception as e:
self.error = e
raise e
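ForeachBatchFunction is the bridge behind DataStreamWriter.foreachBatch, so the batch DataFrames handed to user callbacks are now session-backed; the user-facing API is unchanged. A minimal sketch of what this wrapper serves (the rate source and count are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("feb-demo").getOrCreate()

    def handle_batch(df, batch_id):
        # df arrives as DataFrame(jdf, session) after this change.
        print(batch_id, df.count())

    query = (
        spark.readStream.format("rate").load()
        .writeStream.foreachBatch(handle_batch).start()
    )
    query.stop()
    spark.stop()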
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 490ab9f..ab43aa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDDServer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, SQLContext}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{CastTimestampNTZToLong, ExpressionInfo}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -59,8 +59,8 @@ private[sql] object PythonSQLUtils extends Logging {
* Python callable function to read a file in Arrow stream format and create a [[RDD]]
* using each serialized ArrowRecordBatch as a partition.
*/
- def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = {
- ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+ def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
+ ArrowConverters.readArrowStreamFromFile(session, filename)
}
/**
@@ -70,8 +70,8 @@ private[sql] object PythonSQLUtils extends Logging {
def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
- sqlContext: SQLContext): DataFrame = {
- ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext)
+ session: SparkSession): DataFrame = {
+ ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
}
def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -85,13 +85,13 @@ private[sql] object PythonSQLUtils extends Logging {
* Helper for making a dataframe from arrow data from data sent from python over a socket. This is
* used when encryption is enabled, and we don't want to write data to a file.
*/
-private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer {
+private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
// Create array to consume iterator so that we can safely close the inputStream
val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
// Parallelize the record batches to create an RDD
- JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+ JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index befaea2..7831dde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -230,7 +230,7 @@ private[sql] object SQLUtils extends Logging {
def readArrowStreamFromFile(
sparkSession: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
- ArrowConverters.readArrowStreamFromFile(sparkSession.sqlContext, filename)
+ ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
}
/**
@@ -241,6 +241,6 @@ private[sql] object SQLUtils extends Logging {
arrowBatchRDD: JavaRDD[Array[Byte]],
schema: StructType,
sparkSession: SparkSession): DataFrame = {
- ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession.sqlContext)
+ ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 8e22c42..93ff276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -31,7 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, Message
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
@@ -195,27 +195,27 @@ private[sql] object ArrowConverters {
private[sql] def toDataFrame(
arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
- sqlContext: SQLContext): DataFrame = {
+ session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
- val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+ val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
- sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
+ session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
/**
* Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
*/
private[sql] def readArrowStreamFromFile(
- sqlContext: SQLContext,
+ session: SparkSession,
filename: String): JavaRDD[Array[Byte]] = {
Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
val batches = getBatchesFromStream(fileStream.getChannel).toArray
// Parallelize the record batches to create an RDD
- JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+ JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
}
}
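These Scala entry points sit under PySpark's Arrow-based pandas interchange, which is now keyed on the SparkSession throughout. From Python the path can be exercised as below (a sketch; spark.sql.execution.arrow.pyspark.enabled is the standard Arrow toggle):

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("arrow-demo").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    # createDataFrame from pandas goes through readArrowStreamFromFile /
    # toDataFrame above when Arrow is enabled.
    df = spark.createDataFrame(pd.DataFrame({"id": [1, 2, 3]}))
    df.show()
    spark.stop()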
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]