This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 44984f196 [SEDONA-663] Support spark connect python api (#1639)
44984f196 is described below
commit 44984f196f7911156b1b32603276c81990087a3c
Author: Sebastian Eckweiler <[email protected]>
AuthorDate: Tue Oct 15 22:28:48 2024 +0200
[SEDONA-663] Support spark connect python api (#1639)
* initial successful test
* try add docker-compose based tests
* 3.5 only
* comment classic tests
* try fix yaml
* skip other workflows
* skip other workflows
* try fix if check
* fix path
* cd to python folder
* skip sparkContext with SPARK_REMOTE
* fix type check
* refactor somewhat
* Revert "skip other workflows"
This reverts commit 7eb9b6ea
* back to full matrix
* add license header, fix missing whitespace
* Add a simple docstring to SedonaFunction
* uncomment build step
* need sql extensions
* run pre-commit
* fix lint/pre-commit
* Update .github/workflows/python.yml
Co-authored-by: John Bampton <[email protected]>
* adjust spelling
* use UnresolvedFunction instead of CallFunction
* revert Pipfile to master rev
---------
Co-authored-by: John Bampton <[email protected]>
---
.github/workflows/python.yml | 17 +++++++++++
python/sedona/spark/SedonaContext.py | 14 +++++++--
.../{tests/test_base.py => sedona/sql/connect.py} | 35 +++++++++++-----------
python/sedona/sql/dataframe_api.py | 35 +++++++++++++++++-----
python/tests/test_base.py | 31 +++++++++++++++----
5 files changed, 100 insertions(+), 32 deletions(-)
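For context, this change lets Sedona's Python DataFrame API run against a Spark Connect session. The sketch below is illustrative only and not part of the commit: it mirrors the new test fixture in python/tests/test_base.py, assumes pyspark[connect] >= 3.4 is installed and that the Spark distribution pointed to by SPARK_HOME already has the Sedona jars, and uses the same "local" remote that the new CI step uses.

    # Illustrative sketch, not part of the commit: exercise the Sedona
    # DataFrame API over Spark Connect. Assumes `pip install "pyspark[connect]"`
    # and SPARK_HOME set to a Spark distribution with the Sedona jars.
    import pyspark
    from sedona.spark import SedonaContext
    from sedona.sql import st_constructors as stc

    builder = (
        SedonaContext.builder()
        .remote("local")  # as in CI; an sc://host:port URL would also work
        .config(
            "spark.jars.packages",
            f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
        )
        .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions")
    )
    spark = SedonaContext.create(builder.getOrCreate())

    # The call below goes through the new connect code path; no local JVM is used.
    spark.range(1).select(stc.ST_Point(0.0, 1.0).alias("geom")).show()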
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index e7d1002d9..04fa4f7fc 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -153,3 +153,20 @@ jobs:
          SPARK_VERSION: ${{ matrix.spark }}
          HADOOP_VERSION: ${{ matrix.hadoop }}
        run: (export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION};export PYTHONPATH=$SPARK_HOME/python;cd python;pipenv run pytest tests)
+      - env:
+          SPARK_VERSION: ${{ matrix.spark }}
+          HADOOP_VERSION: ${{ matrix.hadoop }}
+        run: |
+          if [ ! -f "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/sbin/start-connect-server.sh" ]
+          then
+            echo "Skipping connect tests for Spark $SPARK_VERSION"
+            exit
+          fi
+
+          export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}
+          export PYTHONPATH=$SPARK_HOME/python
+          export SPARK_REMOTE=local
+
+          cd python
+          pipenv install "pyspark[connect]==${SPARK_VERSION}"
+          pipenv run pytest tests/sql/test_dataframe_api.py
diff --git a/python/sedona/spark/SedonaContext.py b/python/sedona/spark/SedonaContext.py
index 5cba5df62..49db2e47a 100644
--- a/python/sedona/spark/SedonaContext.py
+++ b/python/sedona/spark/SedonaContext.py
@@ -21,6 +21,13 @@ from pyspark.sql import SparkSession
from sedona.register.geo_registrator import PackageImporter
from sedona.utils import KryoSerializer, SedonaKryoRegistrator
+try:
+    from pyspark.sql.utils import is_remote
+except ImportError:
+
+    def is_remote():
+        return False
+
@attr.s
class SedonaContext:
@@ -34,8 +41,11 @@ class SedonaContext:
        :return: SedonaContext which is an instance of SparkSession
        """
        spark.sql("SELECT 1 as geom").count()
-        PackageImporter.import_jvm_lib(spark._jvm)
-        spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
+
+        # with Spark Connect there is no local JVM
+        if not is_remote():
+            PackageImporter.import_jvm_lib(spark._jvm)
+            spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
        return spark
@classmethod
diff --git a/python/tests/test_base.py b/python/sedona/sql/connect.py
similarity index 50%
copy from python/tests/test_base.py
copy to python/sedona/sql/connect.py
index e45a6e9f6..347099630 100644
--- a/python/tests/test_base.py
+++ b/python/sedona/sql/connect.py
@@ -15,27 +15,26 @@
# specific language governing permissions and limitations
# under the License.
-from tempfile import mkdtemp
+from typing import Any, Iterable, List
-from sedona.spark import *
-from sedona.utils.decorators import classproperty
+import pyspark.sql.connect.functions as f
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.expressions import UnresolvedFunction
-class TestBase:
+# mimic semantics of _convert_argument_to_java_column
+def _convert_argument_to_connect_column(arg: Any) -> Column:
+    if isinstance(arg, Column):
+        return arg
+    elif isinstance(arg, str):
+        return f.col(arg)
+    elif isinstance(arg, Iterable):
+        return f.array(*[_convert_argument_to_connect_column(x) for x in arg])
+    else:
+        return f.lit(arg)
-    @classproperty
-    def spark(self):
-        if not hasattr(self, "__spark"):
-            spark = SedonaContext.create(
-                SedonaContext.builder().master("local[*]").getOrCreate()
-            )
-            spark.sparkContext.setCheckpointDir(mkdtemp())
-            setattr(self, "__spark", spark)
-        return getattr(self, "__spark")
+def call_sedona_function_connect(function_name: str, args: List[Any]) -> Column:
-    @classproperty
-    def sc(self):
-        if not hasattr(self, "__spark"):
-            setattr(self, "__sc", self.spark._sc)
-        return getattr(self, "__sc")
+    expressions = [_convert_argument_to_connect_column(arg)._expr for arg in args]
+    return Column(UnresolvedFunction(function_name, expressions))
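A brief illustration (not part of the commit) of what the new helper builds: it assembles a pure Connect expression tree, so nothing touches a local JVM and the function name is only resolved by the server when the plan is executed. ST_Buffer is used purely as an example name; any Sedona SQL function registered on the server would behave the same way.

    # Illustration only: construct a Sedona function call as a Connect Column.
    # Requires pyspark[connect]; no running session is needed just to build it.
    import pyspark.sql.connect.functions as f
    from sedona.sql.connect import call_sedona_function_connect

    col = call_sedona_function_connect("ST_Buffer", [f.col("geom"), 1.0])
    print(col)  # Column wrapping UnresolvedFunction("ST_Buffer", geom, 1.0)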
diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py
index 4f79878ba..2f56dfffa 100644
--- a/python/sedona/sql/dataframe_api.py
+++ b/python/sedona/sql/dataframe_api.py
@@ -24,8 +24,23 @@ from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f
-ColumnOrName = Union[Column, str]
-ColumnOrNameOrNumber = Union[Column, str, float, int]
+try:
+    from pyspark.sql.connect.column import Column as ConnectColumn
+    from pyspark.sql.utils import is_remote
+except ImportError:
+    # be backwards compatible with Spark < 3.4
+    def is_remote():
+        return False
+
+    class ConnectColumn:
+        pass
+
+else:
+    from sedona.sql.connect import call_sedona_function_connect
+
+
+ColumnOrName = Union[Column, ConnectColumn, str]
+ColumnOrNameOrNumber = Union[Column, ConnectColumn, str, float, int]
def _convert_argument_to_java_column(arg: Any) -> Column:
@@ -49,13 +64,15 @@ def call_sedona_function(
)
    # apparently a Column is an Iterable so we need to check for it explicitly
-    if (
-        (not isinstance(args, Iterable))
-        or isinstance(args, str)
-        or isinstance(args, Column)
+    if (not isinstance(args, Iterable)) or isinstance(
+        args, (str, Column, ConnectColumn)
     ):
         args = [args]
+    # in spark-connect environments use connect API
+    if is_remote():
+        return call_sedona_function_connect(function_name, args)
+
    args = map(_convert_argument_to_java_column, args)
    jobject = getattr(spark._jvm, object_name)
@@ -86,6 +103,10 @@ def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
    else:
        valid_types = (annotated_type,)
+    # functions accepting a Column should also accept the Spark Connect sort of Column
+    if Column in valid_types:
+        valid_types = valid_types + (ConnectColumn,)
+
    return valid_types
@@ -159,7 +180,7 @@ def validate_argument_types(f: Callable) -> Callable:
        # all arguments are Columns or strings are always legal, so only check types when one of the arguments is not a column
        if not all(
            [
-                isinstance(x, Column) or isinstance(x, str)
+                isinstance(x, (Column, ConnectColumn)) or isinstance(x, str)
                for x in itertools.chain(args, kwargs.values())
            ]
        ):
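The try/except/else at the top of dataframe_api.py is what keeps the module importable on PySpark releases older than 3.4, where neither pyspark.sql.connect nor is_remote exists. Shown in isolation below (a sketch, not additional commit code), the pattern degrades to harmless stand-ins: is_remote() always returns False and the placeholder ConnectColumn never matches an isinstance check, so the classic JVM path is used unchanged.

    # Sketch of the compatibility pattern used above; runs on any PySpark version.
    try:
        from pyspark.sql.connect.column import Column as ConnectColumn
        from pyspark.sql.utils import is_remote
    except ImportError:
        # Spark < 3.4: no Connect support available.
        def is_remote():
            return False

        class ConnectColumn:  # never instantiated; only used in isinstance checks
            pass

    print("remote session:", is_remote())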
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index e45a6e9f6..4bfbb86b0 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -14,22 +14,43 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import os
from tempfile import mkdtemp
+import pyspark
+
from sedona.spark import *
from sedona.utils.decorators import classproperty
+SPARK_REMOTE = os.getenv("SPARK_REMOTE")
+
class TestBase:
    @classproperty
    def spark(self):
        if not hasattr(self, "__spark"):
-            spark = SedonaContext.create(
-                SedonaContext.builder().master("local[*]").getOrCreate()
-            )
-            spark.sparkContext.setCheckpointDir(mkdtemp())
+
+            builder = SedonaContext.builder()
+            if SPARK_REMOTE:
+                builder = (
+                    builder.remote(SPARK_REMOTE)
+                    .config(
+                        "spark.jars.packages",
+                        f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
+                    )
+                    .config(
+                        "spark.sql.extensions",
+                        "org.apache.sedona.sql.SedonaSqlExtensions",
+                    )
+                )
+            else:
+                builder = builder.master("local[*]")
+
+            spark = SedonaContext.create(builder.getOrCreate())
+
+            if not SPARK_REMOTE:
+                spark.sparkContext.setCheckpointDir(mkdtemp())
            setattr(self, "__spark", spark)
        return getattr(self, "__spark")