This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 44984f196 [SEDONA-663] Support spark connect python api (#1639)
44984f196 is described below

commit 44984f196f7911156b1b32603276c81990087a3c
Author: Sebastian Eckweiler <[email protected]>
AuthorDate: Tue Oct 15 22:28:48 2024 +0200

    [SEDONA-663] Support spark connect python api (#1639)
    
    * initial successful test
    
    * try add docker-compose based tests
    
    * 3.5 only
    
    * comment classic tests
    
    * try fix yaml
    
    * skip other workflows
    
    * skip other workflows
    
    * try fix if check
    
    * fix path
    
    * cd to python folder
    
    * skip sparkContext with SPARK_REMOTE
    
    * fix type check
    
    * refactor somewhat
    
    * Revert "skip other workflows"
    
    This reverts commit 7eb9b6ea
    
    * back to full matrix
    
    * add license header, fix missing whitespace
    
    * Add a simple docstring to SedonaFunction
    
    * uncomment build step
    
    * need sql extensions
    
    * run pre-commit
    
    * fix lint/pre-commit
    
    * Update .github/workflows/python.yml
    
    Co-authored-by: John Bampton <[email protected]>
    
    * adjust spelling
    
    * use UnresolvedFunction instead of CallFunction
    
    * revert Pipfile to master rev
    
    ---------
    
    Co-authored-by: John Bampton <[email protected]>
---
 .github/workflows/python.yml                       | 17 +++++++++++
 python/sedona/spark/SedonaContext.py               | 14 +++++++--
 .../{tests/test_base.py => sedona/sql/connect.py}  | 35 +++++++++++-----------
 python/sedona/sql/dataframe_api.py                 | 35 +++++++++++++++++-----
 python/tests/test_base.py                          | 31 +++++++++++++++----
 5 files changed, 100 insertions(+), 32 deletions(-)

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index e7d1002d9..04fa4f7fc 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -153,3 +153,20 @@ jobs:
           SPARK_VERSION: ${{ matrix.spark }}
           HADOOP_VERSION: ${{ matrix.hadoop }}
         run: (export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION};export PYTHONPATH=$SPARK_HOME/python;cd python;pipenv run pytest tests)
+      - env:
+          SPARK_VERSION: ${{ matrix.spark }}
+          HADOOP_VERSION: ${{ matrix.hadoop }}
+        run: |
+          if [ ! -f "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/sbin/start-connect-server.sh" ]
+          then
+            echo "Skipping connect tests for Spark $SPARK_VERSION"
+            exit
+          fi
+
+          export SPARK_HOME=$PWD/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}
+          export PYTHONPATH=$SPARK_HOME/python
+          export SPARK_REMOTE=local
+
+          cd python
+          pipenv install "pyspark[connect]==${SPARK_VERSION}"
+          pipenv run pytest tests/sql/test_dataframe_api.py
diff --git a/python/sedona/spark/SedonaContext.py b/python/sedona/spark/SedonaContext.py
index 5cba5df62..49db2e47a 100644
--- a/python/sedona/spark/SedonaContext.py
+++ b/python/sedona/spark/SedonaContext.py
@@ -21,6 +21,13 @@ from pyspark.sql import SparkSession
 from sedona.register.geo_registrator import PackageImporter
 from sedona.utils import KryoSerializer, SedonaKryoRegistrator
 
+try:
+    from pyspark.sql.utils import is_remote
+except ImportError:
+
+    def is_remote():
+        return False
+
 
 @attr.s
 class SedonaContext:
@@ -34,8 +41,11 @@ class SedonaContext:
         :return: SedonaContext which is an instance of SparkSession
         """
         spark.sql("SELECT 1 as geom").count()
-        PackageImporter.import_jvm_lib(spark._jvm)
-        spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
+
+        # with Spark Connect there is no local JVM
+        if not is_remote():
+            PackageImporter.import_jvm_lib(spark._jvm)
+            spark._jvm.SedonaContext.create(spark._jsparkSession, "python")
         return spark
 
     @classmethod
diff --git a/python/tests/test_base.py b/python/sedona/sql/connect.py
similarity index 50%
copy from python/tests/test_base.py
copy to python/sedona/sql/connect.py
index e45a6e9f6..347099630 100644
--- a/python/tests/test_base.py
+++ b/python/sedona/sql/connect.py
@@ -15,27 +15,26 @@
 #  specific language governing permissions and limitations
 #  under the License.
 
-from tempfile import mkdtemp
+from typing import Any, Iterable, List
 
-from sedona.spark import *
-from sedona.utils.decorators import classproperty
+import pyspark.sql.connect.functions as f
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.expressions import UnresolvedFunction
 
 
-class TestBase:
+# mimic semantics of _convert_argument_to_java_column
+def _convert_argument_to_connect_column(arg: Any) -> Column:
+    if isinstance(arg, Column):
+        return arg
+    elif isinstance(arg, str):
+        return f.col(arg)
+    elif isinstance(arg, Iterable):
+        return f.array(*[_convert_argument_to_connect_column(x) for x in arg])
+    else:
+        return f.lit(arg)
 
-    @classproperty
-    def spark(self):
-        if not hasattr(self, "__spark"):
-            spark = SedonaContext.create(
-                SedonaContext.builder().master("local[*]").getOrCreate()
-            )
-            spark.sparkContext.setCheckpointDir(mkdtemp())
 
-            setattr(self, "__spark", spark)
-        return getattr(self, "__spark")
+def call_sedona_function_connect(function_name: str, args: List[Any]) -> Column:
 
-    @classproperty
-    def sc(self):
-        if not hasattr(self, "__spark"):
-            setattr(self, "__sc", self.spark._sc)
-        return getattr(self, "__sc")
+    expressions = [_convert_argument_to_connect_column(arg)._expr for arg in args]
+    return Column(UnresolvedFunction(function_name, expressions))
diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py
index 4f79878ba..2f56dfffa 100644
--- a/python/sedona/sql/dataframe_api.py
+++ b/python/sedona/sql/dataframe_api.py
@@ -24,8 +24,23 @@ from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union
 from pyspark.sql import Column, SparkSession
 from pyspark.sql import functions as f
 
-ColumnOrName = Union[Column, str]
-ColumnOrNameOrNumber = Union[Column, str, float, int]
+try:
+    from pyspark.sql.connect.column import Column as ConnectColumn
+    from pyspark.sql.utils import is_remote
+except ImportError:
+    # be backwards compatible with Spark < 3.4
+    def is_remote():
+        return False
+
+    class ConnectColumn:
+        pass
+
+else:
+    from sedona.sql.connect import call_sedona_function_connect
+
+
+ColumnOrName = Union[Column, ConnectColumn, str]
+ColumnOrNameOrNumber = Union[Column, ConnectColumn, str, float, int]
 
 
 def _convert_argument_to_java_column(arg: Any) -> Column:
@@ -49,13 +64,15 @@ def call_sedona_function(
         )
 
     # apparently a Column is an Iterable so we need to check for it explicitly
-    if (
-        (not isinstance(args, Iterable))
-        or isinstance(args, str)
-        or isinstance(args, Column)
+    if (not isinstance(args, Iterable)) or isinstance(
+        args, (str, Column, ConnectColumn)
     ):
         args = [args]
 
+    # in spark-connect environments use connect API
+    if is_remote():
+        return call_sedona_function_connect(function_name, args)
+
     args = map(_convert_argument_to_java_column, args)
 
     jobject = getattr(spark._jvm, object_name)
@@ -86,6 +103,10 @@ def _get_type_list(annotated_type: Type) -> Tuple[Type, ...]:
     else:
         valid_types = (annotated_type,)
 
+    # functions accepting a Column should also accept the Spark Connect sort of Column
+    if Column in valid_types:
+        valid_types = valid_types + (ConnectColumn,)
+
     return valid_types
 
 
@@ -159,7 +180,7 @@ def validate_argument_types(f: Callable) -> Callable:
         # all arguments are Columns or strings are always legal, so only check types when one of the arguments is not a column
         if not all(
             [
-                isinstance(x, Column) or isinstance(x, str)
+                isinstance(x, (Column, ConnectColumn)) or isinstance(x, str)
                 for x in itertools.chain(args, kwargs.values())
             ]
         ):
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index e45a6e9f6..4bfbb86b0 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -14,22 +14,43 @@
 #  KIND, either express or implied.  See the License for the
 #  specific language governing permissions and limitations
 #  under the License.
-
+import os
 from tempfile import mkdtemp
 
+import pyspark
+
 from sedona.spark import *
 from sedona.utils.decorators import classproperty
 
+SPARK_REMOTE = os.getenv("SPARK_REMOTE")
+
 
 class TestBase:
 
     @classproperty
     def spark(self):
         if not hasattr(self, "__spark"):
-            spark = SedonaContext.create(
-                SedonaContext.builder().master("local[*]").getOrCreate()
-            )
-            spark.sparkContext.setCheckpointDir(mkdtemp())
+
+            builder = SedonaContext.builder()
+            if SPARK_REMOTE:
+                builder = (
+                    builder.remote(SPARK_REMOTE)
+                    .config(
+                        "spark.jars.packages",
+                        f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
+                    )
+                    .config(
+                        "spark.sql.extensions",
+                        "org.apache.sedona.sql.SedonaSqlExtensions",
+                    )
+                )
+            else:
+                builder = builder.master("local[*]")
+
+            spark = SedonaContext.create(builder.getOrCreate())
+
+            if not SPARK_REMOTE:
+                spark.sparkContext.setCheckpointDir(mkdtemp())
 
             setattr(self, "__spark", spark)
         return getattr(self, "__spark")

Reply via email to