This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new bab1f773e [SEDONA-630] Improve ST_Union_Aggr performance (#1526)
bab1f773e is described below
commit bab1f773e16f96d43933ee14ca6538c3ff671fac
Author: Feng Zhang <[email protected]>
AuthorDate: Mon Jul 22 13:05:08 2024 -0700
[SEDONA-630] Improve ST_Union_Aggr performance (#1526)
* [SEDONA-630] Improve ST_Union_Aggr performance
Switch to the JTS `OverlayNGRobust.union` function to perform geometry unions,
and buffer input geometries so they are unioned in batches rather than one at
a time.
* fix python test
* add unit test to measure the ST_Union_Aggr wall time
* address review comments by refactoring unit tests
* rename test table
---
python/tests/sql/test_dataframe_api.py | 2 +-
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 8 ++-
.../org/apache/sedona/sql/UDF/UdfRegistrator.scala | 4 +-
.../expressions/AggregateFunctions.scala | 48 +++++++++++++-----
.../sedona/sql/aggregateFunctionTestScala.scala | 57 +++++++++++++++++++++-
.../apache/sedona/sql/dataFrameAPITestScala.scala | 2 +-
6 files changed, 105 insertions(+), 16 deletions(-)
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index 49a2e68f8..16c000a34 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -262,7 +262,7 @@ test_configurations = [
# aggregates
(sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0
1, 1 1, 1 0, 0 0))"),
(sta.ST_Intersection_Aggr, ("geom",), "exploded_polys", "", "LINESTRING (1
0, 1 1)"),
- (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((1 0, 0 0,
0 1, 1 1, 2 1, 2 0, 1 0))"),
+ (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((0 0, 0 1,
1 1, 2 1, 2 0, 1 0, 0 0))"),
]
wrong_type_configurations = [
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 4fb0aa044..1742324d0 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters
+import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag
object Catalog {
@@ -327,8 +328,13 @@ object Catalog {
function[RS_FromNetCDF](),
function[RS_NetCDFInfo]())
+ // Aggregate functions with Geometry as buffer
val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
- Seq(new ST_Union_Aggr, new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+ Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+
+ // Aggregate functions with List as buffer
+ val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry],
Geometry]] =
+ Seq(new ST_Union_Aggr())
private def function[T <: Expression: ClassTag](defaultArgs: Any*):
FunctionDescription = {
val classTag = implicitly[ClassTag[T]]
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
index fae76b2c5..30c3cb2e3 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
@@ -36,7 +36,9 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
// SPARK3 anchor
-//Catalog.aggregateExpressions_UDAF.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
+
+ Catalog.aggregateExpressions2.foreach(f =>
+ sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
// SPARK3 anchor
}
def dropAll(sparkSession: SparkSession): Unit = {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index 43c488a80..0579967dc 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -21,6 +21,10 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.operation.overlayng.OverlayNGRobust
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
/**
* traits for creating Aggregate Function
@@ -50,22 +54,44 @@ trait TraitSTAggregateExec {
def finish(out: Geometry): Geometry = out
}
-/**
- * Return the polygon union of all Polygon in the given column
- */
-class ST_Union_Aggr extends Aggregator[Geometry, Geometry, Geometry] with
TraitSTAggregateExec {
+class ST_Union_Aggr(bufferSize: Int = 1000)
+ extends Aggregator[Geometry, ListBuffer[Geometry], Geometry]
+ with Serializable {
+
+ override def reduce(buffer: ListBuffer[Geometry], input: Geometry):
ListBuffer[Geometry] = {
+ buffer += input
+ if (buffer.size >= bufferSize) {
+ // Perform the union when buffer size is reached
+ val unionGeometry = OverlayNGRobust.union(buffer.asJava)
+ buffer.clear()
+ buffer += unionGeometry
+ }
+ buffer
+ }
- def reduce(buffer: Geometry, input: Geometry): Geometry = {
- if (buffer.equalsExact(initialGeometry)) input
- else buffer.union(input)
+ override def merge(
+ buffer1: ListBuffer[Geometry],
+ buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
+ buffer1 ++= buffer2
+ if (buffer1.size >= bufferSize) {
+ // Perform the union when buffer size is reached
+ val unionGeometry = OverlayNGRobust.union(buffer1.asJava)
+ buffer1.clear()
+ buffer1 += unionGeometry
+ }
+ buffer1
}
- def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
- if (buffer1.equals(initialGeometry)) buffer2
- else if (buffer2.equals(initialGeometry)) buffer1
- else buffer1.union(buffer2)
+ override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+ OverlayNGRobust.union(reduction.asJava)
}
+ def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] =
+ ExpressionEncoder[ListBuffer[Geometry]]()
+
+ def outputEncoder: ExpressionEncoder[Geometry] =
ExpressionEncoder[Geometry]()
+
+ override def zero: ListBuffer[Geometry] = ListBuffer.empty
}
/**
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 843555e65..25de5c8ea 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -18,7 +18,13 @@
*/
package org.apache.sedona.sql
-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.expressions.javalang.typed
+import org.apache.spark.sql.sedona_sql.expressions.ST_Union_Aggr
+import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory,
Polygon}
+import org.locationtech.jts.io.WKTReader
+
+import scala.util.Random
class aggregateFunctionTestScala extends TestBaseScala {
@@ -62,6 +68,35 @@ class aggregateFunctionTestScala extends TestBaseScala {
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
}
+ it("Measured ST_Union_aggr wall time") {
+ // number of random polygons to generate
+ val numPolygons = 1000
+ val df = createPolygonDataFrame(numPolygons)
+
+ df.createOrReplaceTempView("geometry_table_for_measuring_union_aggr")
+
+ // cache the table to eliminate the time of table scan
+ df.cache()
+ sparkSession
+ .sql("select count(*) from geometry_table_for_measuring_union_aggr")
+ .take(1)(0)
+ .get(0)
+
+ // measure time for optimized ST_Union_Aggr
+ val startTimeOptimized = System.currentTimeMillis()
+ val unionOptimized =
+ sparkSession.sql(
+ "SELECT ST_Union_Aggr(geom) AS union_geom FROM
geometry_table_for_measuring_union_aggr")
+ assert(unionOptimized.take(1)(0).get(0).asInstanceOf[Geometry].getArea >
0)
+ val endTimeOptimized = System.currentTimeMillis()
+ val durationOptimized = endTimeOptimized - startTimeOptimized
+
+ assert(durationOptimized > 0, "Duration of optimized ST_Union_Aggr
should be positive")
+
+ // clear cache
+ df.unpersist()
+ }
+
it("Passed ST_Intersection_aggr") {
val twoPolygonsAsWktDf =
@@ -97,4 +132,24 @@ class aggregateFunctionTestScala extends TestBaseScala {
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
}
}
+
+ def generateRandomPolygon(index: Int): String = {
+ val random = new Random()
+ val x = random.nextDouble() * index
+ val y = random.nextDouble() * index
+ s"POLYGON (($x $y, ${x + 1} $y, ${x + 1} ${y + 1}, $x ${y + 1}, $x $y))"
+ }
+
+ def createPolygonDataFrame(numPolygons: Int): DataFrame = {
+ val polygons = (1 to numPolygons).map(generateRandomPolygon).toArray
+ val polygonArray = polygons.map(polygon => s"ST_GeomFromWKT('$polygon')")
+ val polygonArrayStr = polygonArray.mkString(", ")
+
+ val sqlQuery =
+ s"""
+ |SELECT explode(array($polygonArrayStr)) AS geom
+ """.stripMargin
+
+ sparkSession.sql(sqlQuery)
+ }
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index d01cb6799..bb83932c6 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -1585,7 +1585,7 @@ class dataFrameAPITestScala extends TestBaseScala {
"SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0
0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom")
val df = baseDf.select(ST_Union_Aggr("geom"))
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
- val expectedResult = "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"
+ val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"
assert(actualResult == expectedResult)
}