This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b09d9a0e2 [SEDONA-706] Fix Python dataframe api for multi-threaded environment (#1785)
3b09d9a0e2 is described below

commit 3b09d9a0e2c113fd364b6212be507ccff6bd9041
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Feb 4 00:29:00 2025 +0800

    [SEDONA-706] Fix Python dataframe api for multi-threaded environment (#1785)
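
    Context for the fix: PySpark tracks the active session per thread (the
    lookup goes through a JVM-side thread-local), so calling
    SparkSession.getActiveSession() from a worker thread can return None even
    though the driver is running. A minimal repro sketch of the old failure
    mode, assuming a local session started on the main thread:

        import threading

        from pyspark.sql import SparkSession

        spark = SparkSession.builder.master("local[2]").getOrCreate()

        def worker():
            # Thread-local lookup: may be None in a freshly spawned thread,
            # which made the old guard in call_sedona_function raise.
            print(SparkSession.getActiveSession())

        t = threading.Thread(target=worker)
        t.start()
        t.join()

    The patch instead resolves the JVM handle through SparkContext._jvm, a
    process-wide class attribute that every Python thread can see once the
    context exists.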
---
 .github/workflows/python.yml           |  8 ++------
 python/sedona/sql/dataframe_api.py     | 14 +++++++-------
 python/tests/sql/test_dataframe_api.py | 27 +++++++++++++++++++++++++++
 3 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index aaca28df05..6aad4a97b7 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -163,13 +163,9 @@ jobs:
       - name: Run Spark Connect tests
         env:
           PYTHON_VERSION: ${{ matrix.python }}
+          SPARK_VERSION: ${{ matrix.spark }}
+        if: ${{ matrix.spark >= '3.4.0' }}
         run: |
-          if [ ! -f "${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark/sbin/start-connect-server.sh" ]
-          then
-            echo "Skipping connect tests for Spark $SPARK_VERSION"
-            exit
-          fi
-
           export SPARK_HOME=${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark
           export SPARK_REMOTE=local
 
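
The workflow change replaces a run-time probe with a declarative gate: rather
than checking for pyspark's start-connect-server.sh inside the script and
exiting early, the whole step is now skipped when the matrix Spark version
predates Spark Connect (introduced in Spark 3.4.0). A rough Python rendering
of the two guards, using hypothetical helper names that are not part of the
patch:

    import os

    def old_guard(venv_path: str, python_version: str) -> bool:
        # Old approach: probe the installed pyspark for the connect-server
        # launcher script and skip the tests when it is missing.
        script = (f"{venv_path}/lib/python{python_version}"
                  "/site-packages/pyspark/sbin/start-connect-server.sh")
        return os.path.isfile(script)

    def new_guard(spark_version: str) -> bool:
        # New approach: gate on the version directly; Spark Connect ships
        # with Spark 3.4.0 and later.
        return tuple(map(int, spark_version.split("."))) >= (3, 4, 0)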
diff --git a/python/sedona/sql/dataframe_api.py b/python/sedona/sql/dataframe_api.py
index 2f56dfffa5..b1639a97bf 100644
--- a/python/sedona/sql/dataframe_api.py
+++ b/python/sedona/sql/dataframe_api.py
@@ -21,6 +21,7 @@ import itertools
 import typing
 from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union
 
+from pyspark import SparkContext
 from pyspark.sql import Column, SparkSession
 from pyspark.sql import functions as f
 
@@ -57,12 +58,6 @@ def _convert_argument_to_java_column(arg: Any) -> Column:
 def call_sedona_function(
     object_name: str, function_name: str, args: Union[Any, Tuple[Any]]
 ) -> Column:
-    spark = SparkSession.getActiveSession()
-    if spark is None:
-        raise ValueError(
-            "No active spark session was detected. Unable to call sedona function."
-        )
-
     # apparently a Column is an Iterable so we need to check for it explicitly
     if (not isinstance(args, Iterable)) or isinstance(
         args, (str, Column, ConnectColumn)
@@ -75,7 +70,12 @@ def call_sedona_function(
 
     args = map(_convert_argument_to_java_column, args)
 
-    jobject = getattr(spark._jvm, object_name)
+    jvm = SparkContext._jvm
+    if jvm is None:
+        raise ValueError(
+            "No active spark context was detected. Unable to call sedona function."
+        )
+    jobject = getattr(jvm, object_name)
     jfunc = getattr(jobject, function_name)
 
     jc = jfunc(*args)
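
The replacement lookup works from any thread because SparkContext._jvm is a
class attribute that is set once, when the driver JVM is launched, and shared
across the whole process. A condensed sketch of the pattern, mirroring the
patched call_sedona_function (the example object name is hypothetical):

    from pyspark import SparkContext
    from pyspark.sql import SparkSession

    SparkSession.builder.master("local[2]").getOrCreate()

    def resolve_jvm_object(object_name: str):
        jvm = SparkContext._jvm  # process-wide, not thread-local
        if jvm is None:
            raise ValueError(
                "No active spark context was detected. "
                "Unable to call sedona function."
            )
        return getattr(jvm, object_name)

    # e.g. resolve_jvm_object("java.lang.System") resolves the same JVM
    # class from the main thread or from any worker thread.

The test added below exercises exactly this path from four concurrent threads.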
diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py
index 7f64750190..de65f6f0f4 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -15,6 +15,9 @@
 #  specific language governing permissions and limitations
 #  under the License.
 from math import radians
+import os
+import threading
+import concurrent.futures
 from typing import Callable, Tuple
 
 import pytest
@@ -1732,6 +1735,26 @@ class TestDataFrameAPI(TestBase):
         ):
             func(*args)
 
+    def test_multi_thread(self):
+        df = self.spark.range(0, 100)
+
+        def run_spatial_query():
+            result = df.select(
+                stf.ST_Buffer(stc.ST_Point("id", f.col("id") + 1), 1.0).alias("geom")
+            ).collect()
+            assert len(result) == 100
+
+        # Create and run 4 threads
+        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+            futures = [executor.submit(run_spatial_query) for _ in range(4)]
+            concurrent.futures.wait(futures)
+        for future in futures:
+            future.result()
+
+    @pytest.mark.skipif(
+        os.getenv("SPARK_REMOTE") is not None,
+        reason="Checkpoint dir is not available in Spark Connect",
+    )
     def test_dbscan(self):
         df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
             "geometry", f.expr("ST_Point(x, y)")
@@ -1739,6 +1762,10 @@ class TestDataFrameAPI(TestBase):
 
         df.withColumn("dbscan", ST_DBSCAN("geometry", 1.0, 2, False)).collect()
 
+    @pytest.mark.skipif(
+        os.getenv("SPARK_REMOTE") is not None,
+        reason="Checkpoint dir is not available in Spark Connect",
+    )
     def test_lof(self):
         df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
             "geometry", f.expr("ST_Point(x, y)")
