This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new fa32f44ab6 [GH-2565] Fix NULL handling for various aggregation
functions in SedonaSpark (#2563)
fa32f44ab6 is described below
commit fa32f44ab6cd3580a871798ff0aab7a12374d0aa
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Dec 18 02:28:21 2025 +0800
[GH-2565] Fix NULL handling for various aggregation functions in
SedonaSpark (#2563)
---
.../expressions/AggregateFunctions.scala | 170 ++++++++++-----------
.../sedona/sql/aggregateFunctionTestScala.scala | 149 ++++++++++++++++++
2 files changed, 234 insertions(+), 85 deletions(-)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index fc0cab6260..ca169a2598 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -19,9 +19,10 @@
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.sedona.common.Functions
+import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.geom.{Coordinate, Envelope, Geometry, GeometryFactory}
import org.locationtech.jts.operation.overlayng.OverlayNGRobust
import scala.collection.JavaConverters._
@@ -32,18 +33,7 @@ import scala.collection.mutable.ListBuffer
*/
trait TraitSTAggregateExec {
- val initialGeometry: Geometry = {
- // dummy value for initial value(polygon but )
- // any other value is ok.
- val coordinates: Array[Coordinate] = new Array[Coordinate](5)
- coordinates(0) = new Coordinate(-999999999, -999999999)
- coordinates(1) = new Coordinate(-999999999, -999999999)
- coordinates(2) = new Coordinate(-999999999, -999999999)
- coordinates(3) = new Coordinate(-999999999, -999999999)
- coordinates(4) = coordinates(0)
- val geometryFactory = new GeometryFactory()
- geometryFactory.createPolygon(coordinates)
- }
+ val initialGeometry: Geometry = null
val serde = ExpressionEncoder[Geometry]()
def zero: Geometry = initialGeometry
@@ -62,7 +52,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
- buffer += input
+ if (input != null) {
+ buffer += input
+ }
if (buffer.size >= bufferSize) {
// Perform the union when buffer size is reached
val unionGeometry = OverlayNGRobust.union(buffer.asJava)
@@ -86,6 +78,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
}
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+ if (reduction.isEmpty) {
+ return null
+ }
OverlayNGRobust.union(reduction.asJava)
}
@@ -97,81 +92,76 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
}
/**
- * Return the envelope boundary of the entire column
+ * A helper class to store envelope boundary during aggregation. We use this custom case class
+ * instead of JTS Envelope to work with the Spark Encoder.
*/
-private[apache] class ST_Envelope_Aggr
- extends Aggregator[Geometry, Geometry, Geometry]
- with TraitSTAggregateExec {
+case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
+ def isNull: Boolean = minX > maxX
- def reduce(buffer: Geometry, input: Geometry): Geometry = {
- val accumulateEnvelope = buffer.getEnvelopeInternal
- val newEnvelope = input.getEnvelopeInternal
- val coordinates: Array[Coordinate] = new Array[Coordinate](5)
- var minX = 0.0
- var minY = 0.0
- var maxX = 0.0
- var maxY = 0.0
- if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
- // Found the accumulateEnvelope is the initial value
- minX = newEnvelope.getMinX
- minY = newEnvelope.getMinY
- maxX = newEnvelope.getMaxX
- maxY = newEnvelope.getMaxY
- } else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
- minX = accumulateEnvelope.getMinX
- minY = accumulateEnvelope.getMinY
- maxX = accumulateEnvelope.getMaxX
- maxY = accumulateEnvelope.getMaxY
+ def toEnvelope: Envelope = {
+ if (isNull) {
+ new Envelope()
} else {
- minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
- minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
- maxX = Math.max(accumulateEnvelope.getMaxX, newEnvelope.getMaxX)
- maxY = Math.max(accumulateEnvelope.getMaxY, newEnvelope.getMaxY)
+ new Envelope(minX, maxX, minY, maxY)
}
- coordinates(0) = new Coordinate(minX, minY)
- coordinates(1) = new Coordinate(minX, maxY)
- coordinates(2) = new Coordinate(maxX, maxY)
- coordinates(3) = new Coordinate(maxX, minY)
- coordinates(4) = coordinates(0)
- val geometryFactory = new GeometryFactory()
- geometryFactory.createPolygon(coordinates)
-
}
- def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
- val leftEnvelope = buffer1.getEnvelopeInternal
- val rightEnvelope = buffer2.getEnvelopeInternal
- val coordinates: Array[Coordinate] = new Array[Coordinate](5)
- var minX = 0.0
- var minY = 0.0
- var maxX = 0.0
- var maxY = 0.0
- if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
- minX = rightEnvelope.getMinX
- minY = rightEnvelope.getMinY
- maxX = rightEnvelope.getMaxX
- maxY = rightEnvelope.getMaxY
- } else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
- minX = leftEnvelope.getMinX
- minY = leftEnvelope.getMinY
- maxX = leftEnvelope.getMaxX
- maxY = leftEnvelope.getMaxY
+ def merge(other: EnvelopeBuffer): EnvelopeBuffer = {
+ if (this.isNull) {
+ other
+ } else if (other.isNull) {
+ this
} else {
- minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
- minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
- maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
- maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
+ EnvelopeBuffer(
+ math.min(this.minX, other.minX),
+ math.max(this.maxX, other.maxX),
+ math.min(this.minY, other.minY),
+ math.max(this.maxY, other.maxY))
}
+ }
+}
- coordinates(0) = new Coordinate(minX, minY)
- coordinates(1) = new Coordinate(minX, maxY)
- coordinates(2) = new Coordinate(maxX, maxY)
- coordinates(3) = new Coordinate(maxX, minY)
- coordinates(4) = coordinates(0)
- val geometryFactory = new GeometryFactory()
- geometryFactory.createPolygon(coordinates)
+/**
+ * Return the envelope boundary of the entire column
+ */
+private[apache] class ST_Envelope_Aggr
+ extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {
+
+ val serde = ExpressionEncoder[Geometry]()
+
+ def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
+ if (input == null) return buffer
+ val env = input.getEnvelopeInternal
+ val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
+ buffer match {
+ case Some(b) => Some(b.merge(envBuffer))
+ case None => Some(envBuffer)
+ }
+ }
+
+ def merge(
+ buffer1: Option[EnvelopeBuffer],
+ buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
+ (buffer1, buffer2) match {
+ case (Some(b1), Some(b2)) => Some(b1.merge(b2))
+ case (Some(_), None) => buffer1
+ case (None, Some(_)) => buffer2
+ case (None, None) => None
+ }
+ }
+
+ def finish(reduction: Option[EnvelopeBuffer]): Geometry = {
+ reduction match {
+ case Some(b) => new GeometryFactory().toGeometry(b.toEnvelope)
+ case None => null
+ }
}
+ def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]
+
+ def outputEncoder: ExpressionEncoder[Geometry] = serde
+
+ def zero: Option[EnvelopeBuffer] = None
}
/**
@@ -181,16 +171,26 @@ private[apache] class ST_Intersection_Aggr
extends Aggregator[Geometry, Geometry, Geometry]
with TraitSTAggregateExec {
def reduce(buffer: Geometry, input: Geometry): Geometry = {
- if (buffer.isEmpty) input
- else if (buffer.equalsExact(initialGeometry)) input
- else buffer.intersection(input)
+ if (input == null) {
+ return buffer
+ }
+ if (buffer == null) {
+ return input
+ }
+ buffer.intersection(input)
}
def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
- if (buffer1.equalsExact(initialGeometry)) buffer2
- else if (buffer2.equalsExact(initialGeometry)) buffer1
- else buffer1.intersection(buffer2)
+ if (buffer1 == null) {
+ return buffer2
+ }
+ if (buffer2 == null) {
+ return buffer1
+ }
+ buffer1.intersection(buffer2)
}
+
+ override def finish(out: Geometry): Geometry = out
}
/**
@@ -219,7 +219,7 @@ private[apache] class ST_Collect_Agg
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
if (reduction.isEmpty) {
- new GeometryFactory().createGeometryCollection()
+ null
} else {
Functions.createMultiGeometry(reduction.toArray)
}
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 911769e2ac..4485f9fcfe 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -245,6 +245,155 @@ class aggregateFunctionTestScala extends TestBaseScala {
// Should only have 2 points (nulls are skipped)
assert(result.getNumGeometries == 2)
}
+
+ it("ST_Union_Aggr should handle null values") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("polygons_with_null_for_union")
+
+ val unionDF =
+ sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM polygons_with_null_for_union")
+ val result = unionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ // Should union the 2 non-null polygons (total area = 2.0)
+ assert(result.getArea == 2.0)
+ }
+
+ it("ST_Envelope_Aggr should handle null values") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POINT(1 2)'),
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT('POINT(3 4)')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("points_with_null_for_envelope")
+
+ val envelopeDF =
+ sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM points_with_null_for_envelope")
+ val result = envelopeDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ // Should create envelope from the 2 non-null points
+ assert(result.getGeometryType == "Polygon")
+ val envelope = result.getEnvelopeInternal
+ assert(envelope.getMinX == 1.0)
+ assert(envelope.getMinY == 2.0)
+ assert(envelope.getMaxX == 3.0)
+ assert(envelope.getMaxY == 4.0)
+ }
+
+ it("ST_Intersection_Aggr should handle null values") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT('POLYGON((0 0, 4 0, 4 4, 0 4, 0 0))'),
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT('POLYGON((2 2, 6 2, 6 6, 2 6, 2 2))')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("polygons_with_null_for_intersection")
+
+ val intersectionDF = sparkSession.sql(
+ "SELECT ST_Intersection_Aggr(geom) FROM polygons_with_null_for_intersection")
+ val result = intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+ // Should intersect the 2 non-null polygons (intersection area = 4.0)
+ assert(result.getArea == 4.0)
+ }
+
+ it("ST_Union_Aggr should return null if all inputs are null") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT(NULL)
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("all_null_union")
+
+ val unionDF = sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM all_null_union")
+ val result = unionDF.take(1)(0).get(0)
+
+ assert(result == null)
+ }
+
+ it("ST_Envelope_Aggr should return null if all inputs are null") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT(NULL)
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("all_null_envelope")
+
+ val envelopeDF = sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM all_null_envelope")
+ val result = envelopeDF.take(1)(0).get(0)
+
+ assert(result == null)
+ }
+
+ it("ST_Intersection_Aggr should return null if all inputs are null") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT(NULL)
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("all_null_intersection")
+
+ val intersectionDF =
+ sparkSession.sql("SELECT ST_Intersection_Aggr(geom) FROM all_null_intersection")
+ val result = intersectionDF.take(1)(0).get(0)
+
+ assert(result == null)
+ }
+
+ it("ST_Collect_Agg should return null if all inputs are null") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | ST_GeomFromWKT(NULL),
+ | ST_GeomFromWKT(NULL)
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("all_null_collect")
+
+ val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM all_null_collect")
+ val result = collectDF.take(1)(0).get(0)
+
+ assert(result == null)
+ }
+
+ it(
+ "ST_Envelope_Aggr should return empty geometry if inputs are mixed with null and empty geometries") {
+ sparkSession
+ .sql("""
+ |SELECT explode(array(
+ | NULL,
+ | NULL,
+ | ST_GeomFromWKT('POINT EMPTY'),
+ | NULL,
+ | ST_GeomFromWKT('POLYGON EMPTY')
+ |)) AS geom
+ """.stripMargin)
+ .createOrReplaceTempView("mixed_null_empty_envelope")
+
+ val envelopeDF =
+ sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM mixed_null_empty_envelope")
+ val result = envelopeDF.take(1)(0).get(0)
+
+ assert(result != null)
+ assert(result.asInstanceOf[Geometry].isEmpty)
+ }
}
def generateRandomPolygon(index: Int): String = {