This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new bab1f773e [SEDONA-630] Improve ST_Union_Aggr performance (#1526)
bab1f773e is described below

commit bab1f773e16f96d43933ee14ca6538c3ff671fac
Author: Feng Zhang <[email protected]>
AuthorDate: Mon Jul 22 13:05:08 2024 -0700

    [SEDONA-630] Improve ST_Union_Aggr performance (#1526)
    
    * [SEDONA-630] Improve ST_Union_Aggr performance
    
    Switch to the JTS `OverlayNGRobust.union` function to perform geometry union
    and add geometry cache capability.
    
    * fix python test
    
    * add unit test to measure the ST_Union_aggr time
    
    * address review comments by refactoring unit tests
    
    * rename test table
---
 python/tests/sql/test_dataframe_api.py             |  2 +-
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  8 ++-
 .../org/apache/sedona/sql/UDF/UdfRegistrator.scala |  4 +-
 .../expressions/AggregateFunctions.scala           | 48 +++++++++++++-----
 .../sedona/sql/aggregateFunctionTestScala.scala    | 57 +++++++++++++++++++++-
 .../apache/sedona/sql/dataFrameAPITestScala.scala  |  2 +-
 6 files changed, 105 insertions(+), 16 deletions(-)

diff --git a/python/tests/sql/test_dataframe_api.py 
b/python/tests/sql/test_dataframe_api.py
index 49a2e68f8..16c000a34 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -262,7 +262,7 @@ test_configurations = [
     # aggregates
     (sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0 
1, 1 1, 1 0, 0 0))"),
     (sta.ST_Intersection_Aggr, ("geom",), "exploded_polys", "", "LINESTRING (1 
0, 1 1)"),
-    (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((1 0, 0 0, 
0 1, 1 1, 2 1, 2 0, 1 0))"),
+    (sta.ST_Union_Aggr, ("geom",), "exploded_polys", "", "POLYGON ((0 0, 0 1, 
1 1, 2 1, 2 0, 1 0, 0 0))"),
 ]
 
 wrong_type_configurations = [
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 4fb0aa044..1742324d0 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.sedona_sql.expressions.raster._
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.operation.buffer.BufferParameters
 
+import scala.collection.mutable.ListBuffer
 import scala.reflect.ClassTag
 
 object Catalog {
@@ -327,8 +328,13 @@ object Catalog {
     function[RS_FromNetCDF](),
     function[RS_NetCDFInfo]())
 
+  // Aggregate functions with Geometry as buffer
   val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
-    Seq(new ST_Union_Aggr, new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
+
+  // Aggregate functions with List as buffer
+  val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], 
Geometry]] =
+    Seq(new ST_Union_Aggr())
 
   private def function[T <: Expression: ClassTag](defaultArgs: Any*): 
FunctionDescription = {
     val classTag = implicitly[ClassTag[T]]
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
index fae76b2c5..30c3cb2e3 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
@@ -36,7 +36,9 @@ object UdfRegistrator {
     }
     Catalog.aggregateExpressions.foreach(f =>
       sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) 
// SPARK3 anchor
-//Catalog.aggregateExpressions_UDAF.foreach(f => 
sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
+
+    Catalog.aggregateExpressions2.foreach(f =>
+      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) 
// SPARK3 anchor
   }
 
   def dropAll(sparkSession: SparkSession): Unit = {
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index 43c488a80..0579967dc 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -21,6 +21,10 @@ package org.apache.spark.sql.sedona_sql.expressions
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.operation.overlayng.OverlayNGRobust
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
 
 /**
  * traits for creating Aggregate Function
@@ -50,22 +54,44 @@ trait TraitSTAggregateExec {
   def finish(out: Geometry): Geometry = out
 }
 
-/**
- * Return the polygon union of all Polygon in the given column
- */
-class ST_Union_Aggr extends Aggregator[Geometry, Geometry, Geometry] with 
TraitSTAggregateExec {
+class ST_Union_Aggr(bufferSize: Int = 1000)
+    extends Aggregator[Geometry, ListBuffer[Geometry], Geometry]
+    with Serializable {
+
+  override def reduce(buffer: ListBuffer[Geometry], input: Geometry): 
ListBuffer[Geometry] = {
+    buffer += input
+    if (buffer.size >= bufferSize) {
+      // Perform the union when buffer size is reached
+      val unionGeometry = OverlayNGRobust.union(buffer.asJava)
+      buffer.clear()
+      buffer += unionGeometry
+    }
+    buffer
+  }
 
-  def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    if (buffer.equalsExact(initialGeometry)) input
-    else buffer.union(input)
+  override def merge(
+      buffer1: ListBuffer[Geometry],
+      buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
+    buffer1 ++= buffer2
+    if (buffer1.size >= bufferSize) {
+      // Perform the union when buffer size is reached
+      val unionGeometry = OverlayNGRobust.union(buffer1.asJava)
+      buffer1.clear()
+      buffer1 += unionGeometry
+    }
+    buffer1
   }
 
-  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    if (buffer1.equals(initialGeometry)) buffer2
-    else if (buffer2.equals(initialGeometry)) buffer1
-    else buffer1.union(buffer2)
+  override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    OverlayNGRobust.union(reduction.asJava)
   }
 
+  def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] =
+    ExpressionEncoder[ListBuffer[Geometry]]()
+
+  def outputEncoder: ExpressionEncoder[Geometry] = 
ExpressionEncoder[Geometry]()
+
+  override def zero: ListBuffer[Geometry] = ListBuffer.empty
 }
 
 /**
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
 
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 843555e65..25de5c8ea 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -18,7 +18,13 @@
  */
 package org.apache.sedona.sql
 
-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.expressions.javalang.typed
+import org.apache.spark.sql.sedona_sql.expressions.ST_Union_Aggr
+import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory, 
Polygon}
+import org.locationtech.jts.io.WKTReader
+
+import scala.util.Random
 
 class aggregateFunctionTestScala extends TestBaseScala {
 
@@ -62,6 +68,35 @@ class aggregateFunctionTestScala extends TestBaseScala {
       assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
     }
 
+    it("Measured ST_Union_aggr wall time") {
+      // number of random polygons to generate
+      val numPolygons = 1000
+      val df = createPolygonDataFrame(numPolygons)
+
+      df.createOrReplaceTempView("geometry_table_for_measuring_union_aggr")
+
+      // cache the table to eliminate the time of table scan
+      df.cache()
+      sparkSession
+        .sql("select count(*) from geometry_table_for_measuring_union_aggr")
+        .take(1)(0)
+        .get(0)
+
+      // measure time for optimized ST_Union_Aggr
+      val startTimeOptimized = System.currentTimeMillis()
+      val unionOptimized =
+        sparkSession.sql(
+          "SELECT ST_Union_Aggr(geom) AS union_geom FROM 
geometry_table_for_measuring_union_aggr")
+      assert(unionOptimized.take(1)(0).get(0).asInstanceOf[Geometry].getArea > 
0)
+      val endTimeOptimized = System.currentTimeMillis()
+      val durationOptimized = endTimeOptimized - startTimeOptimized
+
+      assert(durationOptimized > 0, "Duration of optimized ST_Union_Aggr 
should be positive")
+
+      // clear cache
+      df.unpersist()
+    }
+
     it("Passed ST_Intersection_aggr") {
 
       val twoPolygonsAsWktDf =
@@ -97,4 +132,24 @@ class aggregateFunctionTestScala extends TestBaseScala {
       
assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
     }
   }
+
+  def generateRandomPolygon(index: Int): String = {
+    val random = new Random()
+    val x = random.nextDouble() * index
+    val y = random.nextDouble() * index
+    s"POLYGON (($x $y, ${x + 1} $y, ${x + 1} ${y + 1}, $x ${y + 1}, $x $y))"
+  }
+
+  def createPolygonDataFrame(numPolygons: Int): DataFrame = {
+    val polygons = (1 to numPolygons).map(generateRandomPolygon).toArray
+    val polygonArray = polygons.map(polygon => s"ST_GeomFromWKT('$polygon')")
+    val polygonArrayStr = polygonArray.mkString(", ")
+
+    val sqlQuery =
+      s"""
+         |SELECT explode(array($polygonArrayStr)) AS geom
+     """.stripMargin
+
+    sparkSession.sql(sqlQuery)
+  }
 }
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index d01cb6799..bb83932c6 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -1585,7 +1585,7 @@ class dataFrameAPITestScala extends TestBaseScala {
         "SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 
0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom")
       val df = baseDf.select(ST_Union_Aggr("geom"))
       val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
-      val expectedResult = "POLYGON ((1 0, 0 0, 0 1, 1 1, 2 1, 2 0, 1 0))"
+      val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))"
       assert(actualResult == expectedResult)
     }
 

Reply via email to