This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 7b7bca875ca9 [SPARK-54232][GEO][CONNECT] Enable Arrow serialization for Geography and Geometry types
7b7bca875ca9 is described below
commit 7b7bca875ca9c53169eac65645ca8299a83d14a1
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Nov 7 13:06:28 2025 -0800
[SPARK-54232][GEO][CONNECT] Enable Arrow serialization for Geography and Geometry types
### What changes were proposed in this pull request?
Introduce Arrow serialization and deserialization support for the `Geography` and `Geometry` types.
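For illustration, a minimal sketch of what this enables on a Spark Connect client (API shapes taken from the test suites added below; the hex literal is the WKB for `POINT (17 7)`, and `spark` is assumed to be a Connect session):
```scala
import org.apache.spark.sql.types.Geography

// Decode hex-encoded WKB for POINT (17 7) into raw bytes.
val wkb = "010100000000000000000031400000000000001C40"
  .grouped(2)
  .map(Integer.parseInt(_, 16).toByte)
  .toArray

// With this change, Geography values survive the Arrow round trip
// between client and server.
val df = spark.createDataFrame(Seq((Geography.fromWKB(wkb, 4326), 1)))
df.collect() // Array(Row(Geography.fromWKB(wkb, 4326), 1))
```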
### Why are the changes needed?
Enable serialization of geospatial result sets in Arrow format for Spark Connect and the Thrift Server.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests:
- `GeographyConnectDataFrameSuite`
- `GeometryConnectDataFrameSuite`
- `ArrowEncoderSuite`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52930 from uros-db/geo-arrow.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit f9191248900c5933a2934e938931f4fdc8b76f50)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/GeographyConnectDataFrameSuite.scala | 106 +++++++++++++++++++
.../spark/sql/GeometryConnectDataFrameSuite.scala | 112 +++++++++++++++++++++
.../connect/client/arrow/ArrowEncoderSuite.scala | 98 +++++++++++++++++-
.../connect/client/arrow/ArrowDeserializer.scala | 8 ++
.../connect/client/arrow/ArrowEncoderUtils.scala | 16 +++
.../sql/connect/client/arrow/ArrowSerializer.scala | 14 ++-
.../client/arrow/GeospatialArrowSerDe.scala | 101 +++++++++++++++++++
.../connect/common/DataTypeProtoConverter.scala | 35 +++++++
8 files changed, 487 insertions(+), 3 deletions(-)
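Reading the diff below: each geospatial value crosses the wire as a two-field Arrow struct, an `srid` int child plus a `wkb` binary child whose field metadata carries `geography`/`geometry` -> `"true"` (see `GeospatialArrowSerDe`). As a hedged sketch, the equivalent layout built by hand with the standard arrow-vector API (this construction is not part of the patch, and the nullability choices are assumptions):
```scala
import scala.jdk.CollectionConverters._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

// Children of the geospatial struct: a 32-bit signed "srid" and a
// binary "wkb" tagged with {"geometry": "true"} field metadata
// ({"geography": "true"} for the Geography variant).
val srid = new Field("srid", FieldType.nullable(new ArrowType.Int(32, true)), null)
val wkb = new Field(
  "wkb",
  new FieldType(true, new ArrowType.Binary(), null, Map("geometry" -> "true").asJava),
  null)

// The geospatial column itself is a struct of the two children above.
val geometryColumn =
  new Field("g", FieldType.nullable(new ArrowType.Struct()), List(srid, wkb).asJava)
```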
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..2016a84ac5a3
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeographyConnectDataFrameSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeographyConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+ private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+ private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+
+ test("decode geography value: SRID schema does not match input SRID data
schema") {
+ val geography = Geography.fromWKB(point1, 0)
+
+ val seq = Seq((geography, 1))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(seq).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+
+ import testImplicits._
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ Seq(geography).toDF().collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326"))
+ }
+
+ test("decode geography value: mixed SRID schema is provided") {
+ val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
+ val expectedResult =
+ Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+
+ val javaList = java.util.Arrays
+ .asList(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+ val resultJavaListDF = spark.createDataFrame(javaList, schema)
+ checkAnswer(resultJavaListDF, expectedResult)
+
+ // Test that unsupported SRID with mixed schema will throw an error.
+ val invalidData =
+ java.util.Arrays
+ .asList(Row(Geography.fromWKB(point1, 1)), Row(Geography.fromWKB(point2, 4326)))
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ spark.createDataFrame(invalidData, schema).collect()
+ },
+ condition = "ST_INVALID_SRID_VALUE",
+ parameters = Map("srid" -> "1"))
+ }
+
+ test("createDataFrame APIs with Geography.fromWKB") {
+ val geography1 = Geography.fromWKB(point1, 4326)
+ val geography2 = Geography.fromWKB(point2)
+
+ val seq = Seq((geography1, 1), (geography2, 2), (null, 3))
+ val dfFromSeq = spark.createDataFrame(seq)
+ checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+ val schema = StructType(Seq(StructField("geography", GeographyType(4326), nullable = true)))
+
+ val javaList = java.util.Arrays.asList(Row(geography1), Row(geography2), Row(null))
+ val dfFromJavaList = spark.createDataFrame(javaList, schema)
+ checkAnswer(dfFromJavaList, Seq(Row(geography1), Row(geography2), Row(null)))
+
+ import testImplicits._
+ val implicitDf = Seq(geography1, geography2, null).toDF()
+ checkAnswer(implicitDf, Seq(Row(geography1), Row(geography2), Row(null)))
+ }
+
+ test("encode geography type") {
+ // POINT (17 7)
+ val wkb = "010100000000000000000031400000000000001C40"
+ val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$wkb')")
+ val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val expectedGeog = Geography.fromWKB(point, 4326)
+ checkAnswer(df, Seq(Row(expectedGeog)))
+ }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
new file mode 100644
index 000000000000..1450ac54184b
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/GeometryConnectDataFrameSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.types._
+
+class GeometryConnectDataFrameSuite extends QueryTest with RemoteSparkSession {
+
+ private val point1: Array[Byte] = "010100000000000000000031400000000000001C40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+ private val point2: Array[Byte] = "010100000000000000000035400000000000001E40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+
+ test("decode geometry value: SRID schema does not match input SRID data
schema") {
+ val geometry = Geometry.fromWKB(point1, 4326)
+
+ val seq = Seq((geometry, 1))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ spark.createDataFrame(seq).collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+
+ import testImplicits._
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ Seq(geometry).toDF().collect()
+ },
+ condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0"))
+ }
+
+ test("decode geometry value: mixed SRID schema is provided") {
+ val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
+ val expectedResult =
+ Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+
+ val javaList = java.util.Arrays
+ .asList(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+ val resultJavaListDF = spark.createDataFrame(javaList, schema)
+ checkAnswer(resultJavaListDF, expectedResult)
+
+ // Test that unsupported SRID with mixed schema will throw an error.
+ val invalidData =
+ java.util.Arrays
+ .asList(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326)))
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ spark.createDataFrame(invalidData, schema).collect()
+ },
+ condition = "ST_INVALID_SRID_VALUE",
+ parameters = Map("srid" -> "1"))
+ }
+
+ test("createDataFrame APIs with Geometry.fromWKB") {
+ val geometry1 = Geometry.fromWKB(point1, 0)
+ val geometry2 = Geometry.fromWKB(point2, 0)
+
+ // 1. Test createDataFrame with Seq of Geometry objects
+ val seq = Seq((geometry1, 1), (geometry2, 2), (null, 3))
+ val dfFromSeq = spark.createDataFrame(seq)
+ checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+ // 2. Prepare rows with an explicit SRID and the matching StructType schema
+ val geometry3 = Geometry.fromWKB(point1, 4326)
+ val geometry4 = Geometry.fromWKB(point2, 4326)
+ val schema = StructType(Seq(StructField("geometry", GeometryType(4326), nullable = true)))
+
+ // 3. Test createDataFrame with Java List of Rows and StructType schema
+ val javaList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
+ val dfFromJavaList = spark.createDataFrame(javaList, schema)
+ checkAnswer(dfFromJavaList, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+ // 4. Implicit conversion from Seq to DF
+ import testImplicits._
+ val implicitDf = Seq(geometry1, geometry2, null).toDF()
+ checkAnswer(implicitDf, Seq(Row(geometry1), Row(geometry2), Row(null)))
+ }
+
+ test("encode geometry type") {
+ // POINT (17 7)
+ val wkb = "010100000000000000000031400000000000001C40"
+ val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$wkb')")
+ val point = wkb.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+ val expectedGeom = Geometry.fromWKB(point, 0)
+ checkAnswer(df, Seq(Row(expectedGeom)))
+ }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index bc840df5c3fa..d24369ff5fc7 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
import org.apache.spark.sql.connect.test.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.{MaybeNull, SparkStringUtils}
@@ -263,6 +263,102 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
assert(inspector.numBatches == 1)
}
+ test("geography round trip") {
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+
+ val geographyEncoder = toRowEncoder(new StructType().add("g", "geography(4326)"))
+ roundTripAndCheckIdentical(geographyEncoder) { () =>
+ val maybeNull = MaybeNull(7)
+ Iterator.tabulate(101)(i => Row(maybeNull(Geography.fromWKB(point1, 4326))))
+ }
+
+ val nestedGeographyEncoder = toRowEncoder(
+ new StructType()
+ .add(
+ "s",
+ new StructType()
+ .add("i1", "int")
+ .add("g0", "geography(4326)")
+ .add("i2", "int")
+ .add("g4326", "geography(4326)"))
+ .add("a", "array<geography(4326)>")
+ .add("m", "map<string, geography(ANY)>"))
+
+ roundTripAndCheckIdentical(nestedGeographyEncoder) { () =>
+ val maybeNull5 = MaybeNull(5)
+ val maybeNull7 = MaybeNull(7)
+ val maybeNull11 = MaybeNull(11)
+ val maybeNull13 = MaybeNull(13)
+ val maybeNull17 = MaybeNull(17)
+ Iterator
+ .tabulate(100)(i =>
+ Row(
+ maybeNull5(
+ Row(
+ i,
+ maybeNull7(Geography.fromWKB(point1)),
+ i + 1,
+ maybeNull11(Geography.fromWKB(point2, 4326)))),
+ maybeNull7((0 until 10).map(j => Geography.fromWKB(point2, 0))),
+ maybeNull13(Map((i.toString, maybeNull17(Geography.fromWKB(point1, 4326)))))))
+ }
+ }
+
+ test("geometry round trip") {
+ val point1 = "010100000000000000000031400000000000001C40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+ val point2 = "010100000000000000000035400000000000001E40"
+ .grouped(2)
+ .map(Integer.parseInt(_, 16).toByte)
+ .toArray
+
+ val geometryEncoder = toRowEncoder(new StructType().add("g", "geometry(0)"))
+ roundTripAndCheckIdentical(geometryEncoder) { () =>
+ val maybeNull = MaybeNull(7)
+ Iterator.tabulate(101)(i => Row(maybeNull(Geometry.fromWKB(point1, 0))))
+ }
+
+ val nestedGeometryEncoder = toRowEncoder(
+ new StructType()
+ .add(
+ "s",
+ new StructType()
+ .add("i1", "int")
+ .add("g0", "geometry(0)")
+ .add("i2", "int")
+ .add("g4326", "geometry(4326)"))
+ .add("a", "array<geometry(0)>")
+ .add("m", "map<string, geometry(ANY)>"))
+
+ roundTripAndCheckIdentical(nestedGeometryEncoder) { () =>
+ val maybeNull5 = MaybeNull(5)
+ val maybeNull7 = MaybeNull(7)
+ val maybeNull11 = MaybeNull(11)
+ val maybeNull13 = MaybeNull(13)
+ val maybeNull17 = MaybeNull(17)
+ Iterator
+ .tabulate(100)(i =>
+ Row(
+ maybeNull5(
+ Row(
+ i,
+ maybeNull7(Geometry.fromWKB(point1, 0)),
+ i + 1,
+ maybeNull11(Geometry.fromWKB(point2, 4326)))),
+ maybeNull7((0 until 10).map(j => Geometry.fromWKB(point2, 0))),
+ maybeNull13(Map((i.toString, maybeNull17(Geometry.fromWKB(point1, 4326)))))))
+ }
+ }
+
test("variant round trip") {
val variantEncoder = toRowEncoder(new StructType().add("v", "variant"))
roundTripAndCheckIdentical(variantEncoder) { () =>
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 7597a0ceeb8c..8d5811dda8f3 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -341,6 +341,14 @@ object ArrowDeserializers {
}
}
+ case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+ val gdser = new GeometryArrowSerDe
+ gdser.createDeserializer(struct, vectors, timeZoneId)
+
+ case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+ val gdser = new GeographyArrowSerDe
+ gdser.createDeserializer(struct, vectors, timeZoneId)
+
case (VariantEncoder, StructVectors(struct, vectors)) =>
assert(vectors.exists(_.getName == "value"))
assert(
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 5b1539e39f4f..2430c2bbc86f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
def unsupportedCollectionType(cls: Class[_]): Nothing = {
throw new RuntimeException(s"Unsupported collection type: $cls")
}
+
+ def assertMetadataPresent(
+ vectors: Seq[FieldVector],
+ expectedVectors: Seq[String],
+ expectedMetadata: Seq[(String, String)]): Unit = {
+ expectedVectors.foreach { vectorName =>
+ assert(vectors.exists(_.getName == vectorName))
+ }
+
+ expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
+ assert(
+ vectors.exists(field =>
+ field.getName == vectorName && field.getField.getMetadata
+ .containsKey(key) && field.getField.getMetadata.get(key) == value))
+ }
+ }
}
private[arrow] object StructVectors {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4acb11f014d1..73c9a991ab6a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -487,6 +487,14 @@ object ArrowSerializer {
extractor = (v: Any) => v.asInstanceOf[VariantVal].getMetadata,
serializerFor(BinaryEncoder, struct.getChild("metadata")))))
+ case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+ val gser = new GeographyArrowSerDe
+ gser.createSerializer(struct, vectors)
+
+ case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+ val gser = new GeometryArrowSerDe
+ gser.createSerializer(struct, vectors)
+
case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
structSerializerFor(fields, struct, vectors) { (field, _) =>
val getter = methodLookup.findVirtual(
@@ -585,12 +593,14 @@ object ArrowSerializer {
}
}
- private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) {
+ private[arrow] class StructFieldSerializer(
+ val extractor: Any => Any,
+ val serializer: Serializer) {
def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value))
def writeNull(index: Int): Unit = serializer.write(index, null)
}
- private class StructSerializer(
+ private[arrow] class StructSerializer(
struct: StructVector,
fieldSerializers: Seq[StructFieldSerializer])
extends Serializer {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
new file mode 100644
index 000000000000..443523ef02cd
--- /dev/null
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client.arrow
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.complex.StructVector
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, PrimitiveIntEncoder}
+import org.apache.spark.sql.errors.CompilationErrors
+import org.apache.spark.sql.types.{Geography, Geometry}
+
+abstract class GeospatialArrowSerDe[T](typeName: String) {
+
+ def createDeserializer(
+ struct: StructVector,
+ vectors: Seq[FieldVector],
+ timeZoneId: String): ArrowDeserializers.StructFieldSerializer[T] = {
+ assertMetadataPresent(vectors)
+ val wkbDecoder = ArrowDeserializers.deserializerFor(
+ BinaryEncoder,
+ vectors
+ .find(_.getName == "wkb")
+ .getOrElse(throw CompilationErrors.columnNotFoundError("wkb")),
+ timeZoneId)
+ val sridDecoder = ArrowDeserializers.deserializerFor(
+ PrimitiveIntEncoder,
+ vectors
+ .find(_.getName == "srid")
+ .getOrElse(throw CompilationErrors.columnNotFoundError("srid")),
+ timeZoneId)
+ new ArrowDeserializers.StructFieldSerializer[T](struct) {
+ override def value(i: Int): T = createInstance(wkbDecoder.get(i), sridDecoder.get(i))
+ }
+ }
+
+ def createSerializer(
+ struct: StructVector,
+ vectors: Seq[FieldVector]): ArrowSerializer.StructSerializer = {
+ assertMetadataPresent(vectors)
+ new ArrowSerializer.StructSerializer(
+ struct,
+ Seq(
+ new ArrowSerializer.StructFieldSerializer(
+ extractor = (v: Any) => extractSrid(v),
+ ArrowSerializer.serializerFor(PrimitiveIntEncoder, struct.getChild("srid"))),
+ new ArrowSerializer.StructFieldSerializer(
+ extractor = (v: Any) => extractBytes(v),
+ ArrowSerializer.serializerFor(BinaryEncoder, struct.getChild("wkb")))))
+ }
+
+ private def assertMetadataPresent(vectors: Seq[FieldVector]): Unit = {
+ assert(vectors.exists(_.getName == "srid"))
+ assert(
+ vectors.exists(field =>
+ field.getName == "wkb" && field.getField.getMetadata
+ .containsKey(typeName) && field.getField.getMetadata.get(typeName) == "true"))
+ }
+
+ protected def createInstance(wkb: Any, srid: Any): T
+ protected def extractSrid(value: Any): Int
+ protected def extractBytes(value: Any): Array[Byte]
+}
+
+// Geography-specific implementation
+class GeographyArrowSerDe extends GeospatialArrowSerDe[Geography]("geography") {
+ override protected def createInstance(wkb: Any, srid: Any): Geography =
+ Geography.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+ override protected def extractSrid(value: Any): Int =
+ value.asInstanceOf[Geography].getSrid
+
+ override protected def extractBytes(value: Any): Array[Byte] =
+ value.asInstanceOf[Geography].getBytes
+}
+
+// Geometry-specific implementation
+class GeometryArrowSerDe extends GeospatialArrowSerDe[Geometry]("geometry") {
+ override protected def createInstance(wkb: Any, srid: Any): Geometry =
+ Geometry.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+ override protected def extractSrid(value: Any): Int =
+ value.asInstanceOf[Geometry].getSrid
+
+ override protected def extractBytes(value: Any): Array[Byte] =
+ value.asInstanceOf[Geometry].getBytes
+}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 419cc8e082af..ac69f084c307 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -71,6 +71,21 @@ object DataTypeProtoConverter {
case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
case proto.DataType.KindCase.VARIANT => VariantType
+ case proto.DataType.KindCase.GEOMETRY =>
+ val srid = t.getGeometry.getSrid
+ if (srid == GeometryType.MIXED_SRID) {
+ GeometryType("ANY")
+ } else {
+ GeometryType(srid)
+ }
+ case proto.DataType.KindCase.GEOGRAPHY =>
+ val srid = t.getGeography.getSrid
+ if (srid == GeographyType.MIXED_SRID) {
+ GeographyType("ANY")
+ } else {
+ GeographyType(srid)
+ }
+
case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt)
case _ =>
@@ -307,6 +322,26 @@ object DataTypeProtoConverter {
.build())
.build()
+ case g: GeographyType =>
+ proto.DataType
+ .newBuilder()
+ .setGeography(
+ proto.DataType.Geography
+ .newBuilder()
+ .setSrid(g.srid)
+ .build())
+ .build()
+
+ case g: GeometryType =>
+ proto.DataType
+ .newBuilder()
+ .setGeometry(
+ proto.DataType.Geometry
+ .newBuilder()
+ .setSrid(g.srid)
+ .build())
+ .build()
+
case VariantType => ProtoDataTypes.VariantType
case pyudt: PythonUserDefinedType =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]