This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 7b7bca875ca9 [SPARK-54232][GEO][CONNECT] Enable Arrow serialization for Geography and Geometry types
7b7bca875ca9 is described below

commit 7b7bca875ca9c53169eac65645ca8299a83d14a1
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Nov 7 13:06:28 2025 -0800

    [SPARK-54232][GEO][CONNECT] Enable Arrow serialization for Geography and Geometry types
    
    ### What changes were proposed in this pull request?
    Introduce Arrow serialization/deserialization for `Geography` and `Geometry`.
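    
    For orientation, a minimal sketch of the wire shape this introduces (derived from `GeospatialArrowSerDe` in the diff below; the case class is illustrative only, not part of the patch):
    
    ```scala
    // Each GEOGRAPHY/GEOMETRY value travels as an Arrow struct with an
    // `srid` int child and a `wkb` binary child (see GeospatialArrowSerDe).
    case class GeoArrowValue(srid: Int, wkb: Array[Byte]) // hypothetical
    
    // Serializer side:   Geography.fromWKB(bytes, 4326) -> GeoArrowValue(4326, bytes)
    // Deserializer side: GeoArrowValue(srid, wkb)       -> Geography.fromWKB(wkb, srid)
    ```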
    
    ### Why are the changes needed?
    Enable geospatial result set serialization in Arrow format for Spark Connect and Thrift Server.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added unit tests:
    - `GeographyConnectDataFrameSuite`
    - `GeometryConnectDataFrameSuite`
    - `ArrowEncoderSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52930 from uros-db/geo-arrow.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit f9191248900c5933a2934e938931f4fdc8b76f50)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/GeographyConnectDataFrameSuite.scala | 106 +++++++++++++++++++
 .../spark/sql/GeometryConnectDataFrameSuite.scala  | 112 +++++++++++++++++++++
 .../connect/client/arrow/ArrowEncoderSuite.scala   |  98 +++++++++++++++++-
 .../connect/client/arrow/ArrowDeserializer.scala   |   8 ++
 .../connect/client/arrow/ArrowEncoderUtils.scala   |  16 +++
 .../sql/connect/client/arrow/ArrowSerializer.scala |  14 ++-
 .../client/arrow/GeospatialArrowSerDe.scala        | 101 +++++++++++++++++++
 .../connect/common/DataTypeProtoConverter.scala    |  35 +++++++
 8 files changed, 487 insertions(+), 3 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..2016a84ac5a3
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+  private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+
+  test("decode geography value: SRID schema does not match input SRID data 
schema") {
+    val geography = Geography.fromWKB(point1, 0)
+
+    val seq = Seq((geography, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geography).toDF().collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+  }
+
+  test("decode geography value: mixed SRID schema is provided") {
+    val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
+    val expectedResult =
+      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+
+    val javaList = java.util.Arrays
+      .asList(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that unsupported SRID with mixed schema will throw an error.
+    val invalidData =
+      java.util.Arrays
+        .asList(Row(Geography.fromWKB(point1, 1)), Row(Geography.fromWKB(point2, 4326)))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        spark.createDataFrame(invalidData, schema).collect()
+      },
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "1"))
+  }
+
+  test("createDataFrame APIs with Geography.fromWKB") {
+    val geography1 = Geography.fromWKB(point1, 4326)
+    val geography2 = Geography.fromWKB(point2)
+
+    val seq = Seq((geography1, 1), (geography2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+    val schema = StructType(Seq(StructField("geography", GeographyType(4326), nullable = true)))
+
+    val javaList = java.util.Arrays.asList(Row(geography1), Row(geography2), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geography1), Row(geography2), Row(null)))
+
+    import testImplicits._
+    val implicitDf = Seq(geography1, geography2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geography1), Row(geography2), Row(null)))
+  }
+
+  test("encode geography type") {
+    // POINT (17 7)
+    val wkb = "010100000000000000000031400000000000001C40"
+    val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$wkb')")
+    val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val expectedGeog = Geography.fromWKB(point, 4326)
+    checkAnswer(df, Seq(Row(expectedGeog)))
+  }
+}
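
Both new suites decode WKB hex literals with the same inline pipeline; equivalently, as a standalone helper (a sketch with a hypothetical `hexToBytes` name, not part of the patch):

```scala
// Equivalent to the inline `.grouped(2).map(Integer.parseInt(_, 16).toByte)`
// pipeline used throughout these suites.
def hexToBytes(hex: String): Array[Byte] =
  hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

// Example: WKB for POINT (17 7), as used in the tests above.
val point1: Array[Byte] = hexToBytes("010100000000000000000031400000000000001C40")
```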
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..1450ac54184b
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+  private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+    .grouped(2)
+    .map(Integer.parseInt(_, 16).toByte)
+    .toArray
+
+  test("decode geometry value: SRID schema does not match input SRID data 
schema") {
+    val geometry = Geometry.fromWKB(point1, 4326)
+
+    val seq = Seq((geometry, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geometry).toDF().collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+  }
+
+  test("decode geometry value: mixed SRID schema is provided") {
+    val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
+    val expectedResult =
+      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+
+    val javaList = java.util.Arrays
+      .asList(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that unsupported SRID with mixed schema will throw an error.
+    val invalidData =
+      java.util.Arrays
+        .asList(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326)))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        spark.createDataFrame(invalidData, schema).collect()
+      },
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "1"))
+  }
+
+  test("createDataFrame APIs with Geometry.fromWKB") {
+    val geometry1 = Geometry.fromWKB(point1, 0)
+    val geometry2 = Geometry.fromWKB(point2, 0)
+
+    // 1. Test createDataFrame with Seq of Geometry objects
+    val seq = Seq((geometry1, 1), (geometry2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+    // 2. Prepare Geometry values and an explicit StructType schema
+    val geometry3 = Geometry.fromWKB(point1, 4326)
+    val geometry4 = Geometry.fromWKB(point2, 4326)
+    val schema = StructType(Seq(StructField("geometry", GeometryType(4326), nullable = true)))
+
+    // 3. Test createDataFrame with Java List of Rows and StructType schema
+    val javaList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+    // 4. Implicit conversion from Seq to DF
+    import testImplicits._
+    val implicitDf = Seq(geometry1, geometry2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geometry1), Row(geometry2), Row(null)))
+  }
+
+  test("encode geometry type") {
+    // POINT (17 7)
+    val wkb = "010100000000000000000031400000000000001C40"
+    val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$wkb')")
+    val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val expectedGeom = Geometry.fromWKB(point, 0)
+    checkAnswer(df, Seq(Row(expectedGeom)))
+  }
+}
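
The two SRID-mismatch tests above hinge on the defaults these types carry: a GEOGRAPHY column defaults to SRID 4326 while a GEOMETRY column defaults to SRID 0, which is exactly what the `typeSrid` error parameters assert. A sketch of the failing combinations:

```scala
// Mismatched value/type SRIDs, per the GEO_ENCODER_SRID_MISMATCH_ERROR tests:
// GEOGRAPHY's default type SRID is 4326, GEOMETRY's is 0.
val geography = Geography.fromWKB(point1, 0)   // valueSrid 0    vs typeSrid 4326 -> error
val geometry = Geometry.fromWKB(point1, 4326)  // valueSrid 4326 vs typeSrid 0    -> error
```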
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index bc840df5c3fa..d24369ff5fc7 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
 import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.test.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
 import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.{MaybeNull, SparkStringUtils}
 
@@ -263,6 +263,102 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     assert(inspector.numBatches == 1)
   }
 
+  test("geography round trip") {
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+
+    val geographyEncoder = toRowEncoder(new StructType().add("g", "geography(4326)"))
+    roundTripAndCheckIdentical(geographyEncoder) { () =>
+      val maybeNull = MaybeNull(7)
+      Iterator.tabulate(101)(i => Row(maybeNull(Geography.fromWKB(point1, 4326))))
+    }
+
+    val nestedGeographyEncoder = toRowEncoder(
+      new StructType()
+        .add(
+          "s",
+          new StructType()
+            .add("i1", "int")
+            .add("g0", "geography(4326)")
+            .add("i2", "int")
+            .add("g4326", "geography(4326)"))
+        .add("a", "array<geography(4326)>")
+        .add("m", "map<string, geography(ANY)>"))
+
+    roundTripAndCheckIdentical(nestedGeographyEncoder) { () =>
+      val maybeNull5 = MaybeNull(5)
+      val maybeNull7 = MaybeNull(7)
+      val maybeNull11 = MaybeNull(11)
+      val maybeNull13 = MaybeNull(13)
+      val maybeNull17 = MaybeNull(17)
+      Iterator
+        .tabulate(100)(i =>
+          Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(Geography.fromWKB(point1)),
+                i + 1,
+                maybeNull11(Geography.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j => Geography.fromWKB(point2, 0))),
+            maybeNull13(Map((i.toString, maybeNull17(Geography.fromWKB(point1, 4326)))))))
+    }
+  }
+
+  test("geometry round trip") {
+    val point1 = "010100000000000000000031400000000000001C40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+    val point2 = "010100000000000000000035400000000000001E40"
+      .grouped(2)
+      .map(Integer.parseInt(_, 16).toByte)
+      .toArray
+
+    val geometryEncoder = toRowEncoder(new StructType().add("g", "geometry(0)"))
+    roundTripAndCheckIdentical(geometryEncoder) { () =>
+      val maybeNull = MaybeNull(7)
+      Iterator.tabulate(101)(i => Row(maybeNull(Geometry.fromWKB(point1, 0))))
+    }
+
+    val nestedGeometryEncoder = toRowEncoder(
+      new StructType()
+        .add(
+          "s",
+          new StructType()
+            .add("i1", "int")
+            .add("g0", "geometry(0)")
+            .add("i2", "int")
+            .add("g4326", "geometry(4326)"))
+        .add("a", "array<geometry(0)>")
+        .add("m", "map<string, geometry(ANY)>"))
+
+    roundTripAndCheckIdentical(nestedGeometryEncoder) { () =>
+      val maybeNull5 = MaybeNull(5)
+      val maybeNull7 = MaybeNull(7)
+      val maybeNull11 = MaybeNull(11)
+      val maybeNull13 = MaybeNull(13)
+      val maybeNull17 = MaybeNull(17)
+      Iterator
+        .tabulate(100)(i =>
+          Row(
+            maybeNull5(
+              Row(
+                i,
+                maybeNull7(Geometry.fromWKB(point1, 0)),
+                i + 1,
+                maybeNull11(Geometry.fromWKB(point2, 4326)))),
+            maybeNull7((0 until 10).map(j => Geometry.fromWKB(point2, 0))),
+            maybeNull13(Map((i.toString, maybeNull17(Geometry.fromWKB(point1, 4326)))))))
+    }
+  }
+
   test("variant round trip") {
     val variantEncoder = toRowEncoder(new StructType().add("v", "variant"))
     roundTripAndCheckIdentical(variantEncoder) { () =>
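
The round-trip tests above interleave nulls via `org.apache.spark.util.MaybeNull`. A minimal stand-in showing the behavior the tests rely on (an assumption based on usage here, not the actual implementation):

```scala
// Illustrative stand-in: nulls out one value in every `interval` applications,
// so the pairwise co-prime intervals above (5, 7, 11, 13, 17) exercise many
// distinct null combinations across nested fields over 100 rows.
case class MaybeNullSketch(interval: Int) {
  private var count = 0
  def apply[T](value: T): T = {
    val result = if (count % interval == 0) null.asInstanceOf[T] else value
    count += 1
    result
  }
}
```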
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 7597a0ceeb8c..8d5811dda8f3 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -341,6 +341,14 @@ object ArrowDeserializers {
           }
         }
 
+      case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+        val gdser = new GeometryArrowSerDe
+        gdser.createDeserializer(struct, vectors, timeZoneId)
+
+      case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+        val gdser = new GeographyArrowSerDe
+        gdser.createDeserializer(struct, vectors, timeZoneId)
+
       case (VariantEncoder, StructVectors(struct, vectors)) =>
         assert(vectors.exists(_.getName == "value"))
         assert(
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 5b1539e39f4f..2430c2bbc86f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
   def unsupportedCollectionType(cls: Class[_]): Nothing = {
     throw new RuntimeException(s"Unsupported collection type: $cls")
   }
+
+  def assertMetadataPresent(
+      vectors: Seq[FieldVector],
+      expectedVectors: Seq[String],
+      expectedMetadata: Seq[(String, String)]): Unit = {
+    expectedVectors.foreach { vectorName =>
+      assert(vectors.exists(_.getName == vectorName))
+    }
+
+    expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
+      assert(
+        vectors.exists(field =>
+          field.getName == vectorName && field.getField.getMetadata
+            .containsKey(key) && field.getField.getMetadata.get(key) == value))
+    }
+  }
 }
 
 private[arrow] object StructVectors {
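
A hypothetical call shape for the new `assertMetadataPresent` helper, mirroring the check the geospatial SerDe (added below) performs on its struct children; the function name and wiring are illustrative:

```scala
import scala.jdk.CollectionConverters._
import org.apache.arrow.vector.complex.StructVector

// Hypothetical usage: the `wkb` child of a geography struct must carry the
// metadata entry {"geography": "true"}; names mirror GeospatialArrowSerDe.
def validateGeographyStruct(struct: StructVector): Unit =
  ArrowEncoderUtils.assertMetadataPresent(
    vectors = struct.getChildrenFromFields.asScala.toSeq,
    expectedVectors = Seq("wkb"),
    expectedMetadata = Seq("geography" -> "true"))
```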
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4acb11f014d1..73c9a991ab6a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -487,6 +487,14 @@ object ArrowSerializer {
               extractor = (v: Any) => v.asInstanceOf[VariantVal].getMetadata,
               serializerFor(BinaryEncoder, struct.getChild("metadata")))))
 
+      case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+        val gser = new GeographyArrowSerDe
+        gser.createSerializer(struct, vectors)
+
+      case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+        val gser = new GeometryArrowSerDe
+        gser.createSerializer(struct, vectors)
+
       case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
         structSerializerFor(fields, struct, vectors) { (field, _) =>
           val getter = methodLookup.findVirtual(
@@ -585,12 +593,14 @@ object ArrowSerializer {
     }
   }
 
-  private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) {
+  private[arrow] class StructFieldSerializer(
+      val extractor: Any => Any,
+      val serializer: Serializer) {
     def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value))
     def writeNull(index: Int): Unit = serializer.write(index, null)
   }
 
-  private class StructSerializer(
+  private[arrow] class StructSerializer(
       struct: StructVector,
       fieldSerializers: Seq[StructFieldSerializer])
       extends Serializer {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
new file mode 100644
index 000000000000..443523ef02cd
--- /dev/null
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client.arrow
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.complex.StructVector
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, PrimitiveIntEncoder}
+import org.apache.spark.sql.errors.CompilationErrors
+import org.apache.spark.sql.types.{Geography, Geometry}
+
+abstract class GeospatialArrowSerDe[T](typeName: String) {
+
+  def createDeserializer(
+      struct: StructVector,
+      vectors: Seq[FieldVector],
+      timeZoneId: String): ArrowDeserializers.StructFieldSerializer[T] = {
+    assertMetadataPresent(vectors)
+    val wkbDecoder = ArrowDeserializers.deserializerFor(
+      BinaryEncoder,
+      vectors
+        .find(_.getName == "wkb")
+        .getOrElse(throw CompilationErrors.columnNotFoundError("wkb")),
+      timeZoneId)
+    val sridDecoder = ArrowDeserializers.deserializerFor(
+      PrimitiveIntEncoder,
+      vectors
+        .find(_.getName == "srid")
+        .getOrElse(throw CompilationErrors.columnNotFoundError("srid")),
+      timeZoneId)
+    new ArrowDeserializers.StructFieldSerializer[T](struct) {
+      override def value(i: Int): T = createInstance(wkbDecoder.get(i), sridDecoder.get(i))
+    }
+  }
+
+  def createSerializer(
+      struct: StructVector,
+      vectors: Seq[FieldVector]): ArrowSerializer.StructSerializer = {
+    assertMetadataPresent(vectors)
+    new ArrowSerializer.StructSerializer(
+      struct,
+      Seq(
+        new ArrowSerializer.StructFieldSerializer(
+          extractor = (v: Any) => extractSrid(v),
+          ArrowSerializer.serializerFor(PrimitiveIntEncoder, struct.getChild("srid"))),
+        new ArrowSerializer.StructFieldSerializer(
+          extractor = (v: Any) => extractBytes(v),
+          ArrowSerializer.serializerFor(BinaryEncoder, struct.getChild("wkb")))))
+  }
+
+  private def assertMetadataPresent(vectors: Seq[FieldVector]): Unit = {
+    assert(vectors.exists(_.getName == "srid"))
+    assert(
+      vectors.exists(field =>
+        field.getName == "wkb" && field.getField.getMetadata
+          .containsKey(typeName) && field.getField.getMetadata.get(typeName) == "true"))
+  }
+
+  protected def createInstance(wkb: Any, srid: Any): T
+  protected def extractSrid(value: Any): Int
+  protected def extractBytes(value: Any): Array[Byte]
+}
+
+// Geography-specific implementation
+class GeographyArrowSerDe extends GeospatialArrowSerDe[Geography]("geography") {
+  override protected def createInstance(wkb: Any, srid: Any): Geography =
+    Geography.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+  override protected def extractSrid(value: Any): Int =
+    value.asInstanceOf[Geography].getSrid
+
+  override protected def extractBytes(value: Any): Array[Byte] =
+    value.asInstanceOf[Geography].getBytes
+}
+
+// Geometry-specific implementation
+class GeometryArrowSerDe extends GeospatialArrowSerDe[Geometry]("geometry") {
+  override protected def createInstance(wkb: Any, srid: Any): Geometry =
+    Geometry.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+  override protected def extractSrid(value: Any): Int =
+    value.asInstanceOf[Geometry].getSrid
+
+  override protected def extractBytes(value: Any): Array[Byte] =
+    value.asInstanceOf[Geometry].getBytes
+}
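
For reference, the Arrow field layout this SerDe expects can be spelled out with the Arrow Java pojo API. A sketch under the assumption that the column arrives as struct<srid: int, wkb: binary> with the type-name marker on `wkb` (schema construction is not part of this file; names are illustrative):

```scala
import scala.jdk.CollectionConverters._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

// Sketch: a GEOGRAPHY column as struct<srid: int, wkb: binary>, with the
// `wkb` child tagged {"geography": "true"} so assertMetadataPresent passes.
val sridField = new Field("srid", FieldType.nullable(new ArrowType.Int(32, true)), null)
val wkbField = new Field(
  "wkb",
  new FieldType(true, new ArrowType.Binary(), null, Map("geography" -> "true").asJava),
  null)
val geographyField = new Field(
  "g",
  FieldType.nullable(ArrowType.Struct.INSTANCE),
  Seq(sridField, wkbField).asJava)
```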
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 419cc8e082af..ac69f084c307 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -71,6 +71,21 @@ object DataTypeProtoConverter {
       case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
       case proto.DataType.KindCase.VARIANT => VariantType
 
+      case proto.DataType.KindCase.GEOMETRY =>
+        val srid = t.getGeometry.getSrid
+        if (srid == GeometryType.MIXED_SRID) {
+          GeometryType("ANY")
+        } else {
+          GeometryType(srid)
+        }
+      case proto.DataType.KindCase.GEOGRAPHY =>
+        val srid = t.getGeography.getSrid
+        if (srid == GeographyType.MIXED_SRID) {
+          GeographyType("ANY")
+        } else {
+          GeographyType(srid)
+        }
+
       case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt)
 
       case _ =>
@@ -307,6 +322,26 @@ object DataTypeProtoConverter {
               .build())
           .build()
 
+      case g: GeographyType =>
+        proto.DataType
+          .newBuilder()
+          .setGeography(
+            proto.DataType.Geography
+              .newBuilder()
+              .setSrid(g.srid)
+              .build())
+          .build()
+
+      case g: GeometryType =>
+        proto.DataType
+          .newBuilder()
+          .setGeometry(
+            proto.DataType.Geometry
+              .newBuilder()
+              .setSrid(g.srid)
+              .build())
+          .build()
+
       case VariantType => ProtoDataTypes.VariantType
 
       case pyudt: PythonUserDefinedType =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
