This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new bde919479e01 [SPARK-54176][GEO][PYTHON] Introduce Geography and Geometry data types to PySpark Connect
bde919479e01 is described below

commit bde919479e0143b2f44cff4c1b7debc312e838d2
Author: Uros Bojanic <[email protected]>
AuthorDate: Wed Nov 5 10:49:50 2025 -0800

    [SPARK-54176][GEO][PYTHON] Introduce Geography and Geometry data types to PySpark Connect
    
    ### What changes were proposed in this pull request?
    Introduce `GeographyType` and `GeometryType` to PySpark Connect. Note that the geospatial data types have already been introduced in PySpark as part of: https://github.com/apache/spark/pull/52627.
    
    Also, introduce classes to represent a `Geography` and `Geometry` value in Python. Note that the corresponding classes have already been introduced on Scala side as part of: https://github.com/apache/spark/pull/52804.
    
    ### Why are the changes needed?
    Enabling geospatial types in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `GeographyType` and `GeometryType` are now available in PySpark Connect.
    
    ### How was this patch tested?
    Added new Python Connect tests:
    - `test_parity_geographytype`
    - `test_parity_geometrytype`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52871 from uros-db/geo-spark-connect.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/errors/error-conditions.json        |  10 ++
 python/pyspark/sql/connect/types.py                |  18 +++
 python/pyspark/sql/conversion.py                   |  74 +++++++++
 .../sql/tests/connect/test_parity_geographytype.py |  38 +++++
 .../sql/tests/connect/test_parity_geometrytype.py  |  38 +++++
 python/pyspark/sql/types.py                        | 168 +++++++++++++++++++++
 7 files changed, 348 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 07ac4c76b91a..aa8ca58a5a75 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1114,6 +1114,8 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_connect_retry",
         "pyspark.sql.tests.connect.test_connect_session",
         "pyspark.sql.tests.connect.test_connect_stat",
+        "pyspark.sql.tests.connect.test_parity_geographytype",
+        "pyspark.sql.tests.connect.test_parity_geometrytype",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index d169e6293a1b..51bbdd862516 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -549,6 +549,16 @@
       "<arg1> and <arg2> should be of the same length, got <arg1_length> and 
<arg2_length>."
     ]
   },
+  "MALFORMED_GEOGRAPHY": {
+    "message": [
+      "Geography binary is malformed. Please check the data source is valid."
+    ]
+  },
+  "MALFORMED_GEOMETRY": {
+    "message": [
+      "Geometry binary is malformed. Please check the data source is valid."
+    ]
+  },
   "MALFORMED_VARIANT": {
     "message": [
       "Variant binary is malformed. Please check the data source is valid."
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 7e8f76861079..d3352b618d7c 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -50,6 +50,8 @@ from pyspark.sql.types import (
     NullType,
     NumericType,
     VariantType,
+    GeographyType,
+    GeometryType,
     UserDefinedType,
 )
 from pyspark.errors import PySparkAssertionError, PySparkValueError
@@ -191,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
         ret.array.contains_null = data_type.containsNull
     elif isinstance(data_type, VariantType):
         ret.variant.CopyFrom(pb2.DataType.Variant())
+    elif isinstance(data_type, GeometryType):
+        ret.geometry.srid = data_type.srid
+    elif isinstance(data_type, GeographyType):
+        ret.geography.srid = data_type.srid
     elif isinstance(data_type, UserDefinedType):
         json_value = data_type.jsonValue()
         ret.udt.type = "udt"
@@ -303,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         )
     elif schema.HasField("variant"):
         return VariantType()
+    elif schema.HasField("geometry"):
+        srid = schema.geometry.srid
+        if srid == GeometryType.MIXED_SRID:
+            return GeometryType("ANY")
+        else:
+            return GeometryType(srid)
+    elif schema.HasField("geography"):
+        srid = schema.geography.srid
+        if srid == GeographyType.MIXED_SRID:
+            return GeographyType("ANY")
+        else:
+            return GeographyType(srid)
     elif schema.HasField("udt"):
         assert schema.udt.type == "udt"
         json_value = {}
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index a8f621277a0a..f73727d1d534 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -28,6 +28,10 @@ from pyspark.sql.types import (
     BinaryType,
     DataType,
     DecimalType,
+    GeographyType,
+    Geography,
+    GeometryType,
+    Geometry,
     MapType,
     NullType,
     Row,
@@ -89,6 +93,10 @@ class LocalDataToArrowConversion:
             return True
         elif isinstance(dataType, VariantType):
             return True
+        elif isinstance(dataType, GeometryType):
+            return True
+        elif isinstance(dataType, GeographyType):
+            return True
         else:
             return False
 
@@ -392,6 +400,34 @@ class LocalDataToArrowConversion:
 
             return convert_variant
 
+        elif isinstance(dataType, GeographyType):
+
+            def convert_geography(value: Any) -> Any:
+                if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
+                    return None
+                elif isinstance(value, Geography):
+                    return dataType.toInternal(value)
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+            return convert_geography
+
+        elif isinstance(dataType, GeometryType):
+
+            def convert_geometry(value: Any) -> Any:
+                if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
+                    return None
+                elif isinstance(value, Geometry):
+                    return dataType.toInternal(value)
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+            return convert_geometry
+
         elif not nullable:
 
             def convert_other(value: Any) -> Any:
@@ -511,6 +547,10 @@ class ArrowTableToRowsConversion:
             return True
         elif isinstance(dataType, VariantType):
             return True
+        elif isinstance(dataType, GeographyType):
+            return True
+        elif isinstance(dataType, GeometryType):
+            return True
         else:
             return False
 
@@ -719,6 +759,40 @@ class ArrowTableToRowsConversion:
 
             return convert_variant
 
+        elif isinstance(dataType, GeographyType):
+
+            def convert_geography(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["wkb", "srid"])
+                    and isinstance(value["wkb"], bytes)
+                    and isinstance(value["srid"], int)
+                ):
+                    return Geography.fromWKB(value["wkb"], value["srid"])
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+            return convert_geography
+
+        elif isinstance(dataType, GeometryType):
+
+            def convert_geometry(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["wkb", "srid"])
+                    and isinstance(value["wkb"], bytes)
+                    and isinstance(value["srid"], int)
+                ):
+                    return Geometry.fromWKB(value["wkb"], value["srid"])
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+            return convert_geometry
+
         else:
             if none_on_identity:
                 return None
diff --git a/python/pyspark/sql/tests/connect/test_parity_geographytype.py b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
new file mode 100644
index 000000000000..501bbed20ff1
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_geographytype import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_geometrytype.py b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
new file mode 100644
index 000000000000..b95321b3c61b
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_geometrytype import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 440100dba931..8aae39880072 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -90,6 +90,8 @@ __all__ = [
     "TimestampNTZType",
     "DecimalType",
     "DoubleType",
+    "Geography",
+    "Geometry",
     "FloatType",
     "ByteType",
     "IntegerType",
@@ -616,6 +618,20 @@ class GeographyType(SpatialType):
         # The JSON representation always uses the CRS and algorithm value.
         return f"geography({self._crs}, {self._alg})"
 
+    def needConversion(self) -> bool:
+        return True
+
+    def fromInternal(self, obj: Dict) -> Optional["Geography"]:
+        if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+            return None
+        return Geography(obj["bytes"], obj["srid"])
+
+    def toInternal(self, geography: Any) -> Any:
+        if geography is None:
+            return None
+        assert isinstance(geography, Geography)
+        return {"srid": geography.srid, "wkb": geography.wkb}
+
 
 class GeometryType(SpatialType):
     """
@@ -700,6 +716,20 @@ class GeometryType(SpatialType):
         # The JSON representation always uses the CRS value.
         return f"geometry({self._crs})"
 
+    def needConversion(self) -> bool:
+        return True
+
+    def fromInternal(self, obj: Dict) -> Optional["Geometry"]:
+        if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+            return None
+        return Geometry(obj["bytes"], obj["srid"])
+
+    def toInternal(self, geometry: Any) -> Any:
+        if geometry is None:
+            return None
+        assert isinstance(geometry, Geometry)
+        return {"srid": geometry.srid, "wkb": geometry.wkb}
+
 
 class ByteType(IntegralType):
     """Byte data type, representing signed 8-bit integers."""
@@ -2039,6 +2069,144 @@ class VariantVal:
         return VariantVal(value, metadata)
 
 
+class Geography:
+    """
+    A class to represent a Geography value in Python.
+
+    .. versionadded:: 4.1.0
+
+    Parameters
+    ----------
+    wkb : bytes
+        The bytes representing the WKB of Geography.
+
+    srid : integer
+        The integer value representing SRID of Geography.
+
+    Methods
+    -------
+    getBytes()
+        Returns the WKB of Geography.
+
+    getSrid()
+        Returns the SRID of Geography.
+
+    Examples
+    --------
+    >>> g = Geography.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 4326)
+    >>> g.getBytes().hex()
+    '010100000000000000000031400000000000001c40'
+    >>> g.getSrid()
+    4326
+    """
+
+    def __init__(self, wkb: bytes, srid: int):
+        self.wkb = wkb
+        self.srid = srid
+
+    def __str__(self) -> str:
+        return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+    def __repr__(self) -> str:
+        return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+    def getSrid(self) -> int:
+        """
+        Returns the SRID of Geography.
+        """
+        return self.srid
+
+    def getBytes(self) -> bytes:
+        """
+        Returns the WKB of Geography.
+        """
+        return self.wkb
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Geography):
+            # Don't attempt to compare against unrelated types.
+            return NotImplemented
+
+        return self.wkb == other.wkb and self.srid == other.srid
+
+    @classmethod
+    def fromWKB(cls, wkb: bytes, srid: int) -> "Geography":
+        """
+        Construct Python Geography object from WKB.
+        :return: Python representation of the Geography type value.
+        """
+        return Geography(wkb, srid)
+
+
+class Geometry:
+    """
+    A class to represent a Geometry value in Python.
+
+    .. versionadded:: 4.1.0
+
+    Parameters
+    ----------
+    wkb : bytes
+        The bytes representing the WKB of Geometry.
+
+    srid : integer
+        The integer value representing SRID of Geometry.
+
+    Methods
+    -------
+    getBytes()
+        Returns the WKB of Geometry.
+
+    getSrid()
+        Returns the SRID of Geometry.
+
+    Examples
+    --------
+    >>> g = Geometry.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 0)
+    >>> g.getBytes().hex()
+    '010100000000000000000031400000000000001c40'
+    >>> g.getSrid()
+    0
+    """
+
+    def __init__(self, wkb: bytes, srid: int):
+        self.wkb = wkb
+        self.srid = srid
+
+    def __str__(self) -> str:
+        return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+    def __repr__(self) -> str:
+        return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+    def getSrid(self) -> int:
+        """
+        Returns the SRID of Geometry.
+        """
+        return self.srid
+
+    def getBytes(self) -> bytes:
+        """
+        Returns the WKB of Geometry.
+        """
+        return self.wkb
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Geometry):
+            # Don't attempt to compare against unrelated types.
+            return NotImplemented
+
+        return self.wkb == other.wkb and self.srid == other.srid
+
+    @classmethod
+    def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry":
+        """
+        Construct Python Geometry object from WKB.
+        :return: Python representation of the Geometry type value.
+        """
+        return Geometry(wkb, srid)
+
+
 _atomic_types: List[Type[DataType]] = [
     StringType,
     CharType,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to