This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 754836e  [SEDONA-82] Fixes in ST_Difference and ST_SymDifference (#584)
754836e is described below

commit 754836e2da847ab5a187546fafe8f5e46980a7dd
Author: Magdalena <[email protected]>
AuthorDate: Wed Feb 23 07:39:10 2022 +0100

    [SEDONA-82] Fixes in ST_Difference and ST_SymDifference (#584)
---
 .../sql/sedona_sql/expressions/Functions.scala     | 26 +++++++++-------------
 .../org/apache/sedona/sql/functionTestScala.scala  | 12 ++++++++--
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 08916f9..4f29dbf 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1444,6 +1444,11 @@ case class ST_GeoHash(inputExpressions: Seq[Expression])
   }
 }
 
+/**
+ * Return the difference between geometry A and B
+ *
+ * @param inputExpressions
+ */
 case class ST_Difference(inputExpressions: Seq[Expression])
   extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
@@ -1456,14 +1461,14 @@ case class ST_Difference(inputExpressions: Seq[Expression])
     lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
 
     if (!isIntersects) {
-      return new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
+      new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
     }
 
     if (isIntersects && isRightContainsLeft) {
-      return new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
+      new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
     }
 
-    return new 
GenericArrayData(GeometrySerializer.serialize(leftGeometry.difference(rightGeometry)))
+    new 
GenericArrayData(GeometrySerializer.serialize(leftGeometry.difference(rightGeometry)))
   }
 
   override def dataType: DataType = GeometryUDT
@@ -1481,20 +1486,11 @@ case class ST_Difference(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_SymDifference(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val leftGeometry = inputExpressions(0).toGeometry(input)
-    val rightGeometry = inputExpressions(1).toGeometry(input)
-
-    (leftGeometry, rightGeometry) match {
-      case (leftGeometry: Geometry, rightGeometry: Geometry)
-      => new 
GenericArrayData(GeometrySerializer.serialize(leftGeometry.symDifference(rightGeometry)))
-      case _ => null
-    }
+  override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: 
Geometry): Any = {
+    new 
GenericArrayData(GeometrySerializer.serialize(leftGeometry.symDifference(rightGeometry)))
   }
 
   override def dataType: DataType = GeometryUDT
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 779b022..109628c 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -357,10 +357,10 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
 
     it("Passed ST_Difference - right not overlaps left") {
 
-      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 
-3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 
5 -1, 5 -3))') as b")
+      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 
-3, -3 3, 3 3, 3 -3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 
5 -1, 5 -3))') as b")
       testtable.createOrReplaceTempView("testtable")
       val diff = sparkSession.sql("select ST_Difference(a,b) from testtable")
-      
assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON 
((-3 -3, 3 -3, 3 3, -3 3, -3 -3))"))
+      
assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON 
((-3 -3, -3 3, 3 3, 3 -3, -3 -3))"))
     }
 
     it("Passed ST_Difference - left contains right") {
@@ -379,6 +379,14 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       
assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON 
EMPTY"))
     }
 
+    it("Passed ST_Difference - one null") {
+
+      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 
-3, 3 -3, 3 3, -3 3, -3 -3))') as a")
+      testtable.createOrReplaceTempView("testtable")
+      val diff = sparkSession.sql("select ST_Difference(a,null) from 
testtable")
+      assert(diff.first().get(0) == null)
+    }
+
     it("Passed ST_SymDifference - part of right overlaps left") {
 
       val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-1 
-1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((0 -2, 2 -2, 2 0, 
0 0, 0 -2))') as b")

Reply via email to