This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new d5405cc29 [SEDONA-371] Add Optimized join support for Raster-Vector 
Joins (#979)
d5405cc29 is described below

commit d5405cc29e2ff8f27a8d2daa304a8b23c5fa4273
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Mon Aug 21 22:31:07 2023 -0400

    [SEDONA-371] Add Optimized join support for Raster-Vector Joins (#979)
    
    Co-authored-by: Jia Yu <[email protected]>
---
 docs/api/sql/Raster-operators.md                   |  2 +-
 .../apache/sedona/sql/utils/RasterSerializer.scala | 45 +++++++++++++
 .../expressions/raster/RasterPredicates.scala      | 74 ++++++++++++++++++++--
 .../strategy/join/BroadcastIndexJoinExec.scala     | 21 ++++--
 .../strategy/join/DistanceJoinExec.scala           |  2 +-
 .../strategy/join/JoinQueryDetector.scala          | 45 ++++++++++---
 .../strategy/join/SpatialIndexExec.scala           |  6 +-
 .../strategy/join/TraitJoinQueryBase.scala         | 25 ++++++--
 .../strategy/join/TraitJoinQueryExec.scala         | 11 ++--
 .../sedona/sql/BroadcastIndexJoinSuite.scala       | 43 +++++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      | 14 ++++
 .../apache/sedona/sql/predicateJoinTestScala.scala | 49 +++++++++++++-
 12 files changed, 300 insertions(+), 37 deletions(-)

diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md
index 32b91178c..0e2acb9e8 100644
--- a/docs/api/sql/Raster-operators.md
+++ b/docs/api/sql/Raster-operators.md
@@ -98,7 +98,7 @@ POLYGON ((0 0,20 0,20 60,0 60,0 0))
 ### RS_ConvexHull
 
 Introduction: Return the convex hull geometry of the raster including the 
NoDataBandValue band pixels. 
-For regular shaped and non-skewed rasters, this gives more or less the same 
result as RS_ConvexHull and hence is only useful for irregularly shaped or 
skewed rasters.
+For regular shaped and non-skewed rasters, this gives more or less the same 
result as RS_Envelope and hence is only useful for irregularly shaped or skewed 
rasters.
 
 Format: `RS_ConvexHull(raster: Raster)`
 
diff --git 
a/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala 
b/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
new file mode 100644
index 000000000..3ffc41c6c
--- /dev/null
+++ 
b/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql.utils
+
+import org.apache.sedona.common.raster.Serde
+import org.geotools.coverage.grid.GridCoverage2D
+
+object RasterSerializer {
+  /**
+   * Given a raster returns array of bytes
+   *
+   * @param raster the GridCoverage2D raster to serialize
+   * @return Array of bytes representing this raster
+   */
+  def serialize(raster: GridCoverage2D): Array[Byte] = {
+    Serde.serialize(raster);
+  }
+
+  /**
+   * Given an array of bytes, returns a GridCoverage2D raster
+   *
+   * @param value serialized raster as Array[Byte]
+   * @return GridCoverage2D
+   */
+  def deserialize(value: Array[Byte]): GridCoverage2D = {
+    Serde.deserialize(value);
+  }
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
index 95d746d8f..1880088ff 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
@@ -16,27 +16,87 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.expressions.raster
+
+package org.apache.spark.sql.sedona_sql.expressions
 
 import org.apache.sedona.common.raster.RasterPredicates
-import org.apache.spark.sql.catalyst.expressions.Expression
-import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
-import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, NullIntolerant}
+import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
+import org.geotools.coverage.grid.GridCoverage2D
+import org.locationtech.jts.geom.Geometry
+
+abstract class RS_Predicate extends Expression
+  with FoldableExpression
+  with ExpectsInputTypes
+  with NullIntolerant {
+  def inputExpressions: Seq[Expression]
+
+  override def toString: String = s" **${this.getClass.getName}**  "
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, GeometryUDT)
+
+  override def dataType: DataType = BooleanType
+
+  override def children: Seq[Expression] = inputExpressions
+
+  override final def eval(inputRow: InternalRow): Any = {
+    val leftArray = 
inputExpressions.head.eval(inputRow).asInstanceOf[Array[Byte]]
+    if (leftArray == null) {
+      null
+    } else {
+      val rightArray = 
inputExpressions(1).eval(inputRow).asInstanceOf[Array[Byte]]
+      if (rightArray == null) {
+        null
+      } else {
+        val leftGeometry = RasterSerializer.deserialize(leftArray)
+        val rightGeometry = GeometrySerializer.deserialize(rightArray)
+        evalGeom(leftGeometry, rightGeometry)
+      }
+    }
+  }
+
+  def evalGeom(leftGeometry: GridCoverage2D, rightGeometry: Geometry): Boolean
+}
+
+case class RS_Intersects(inputExpressions: Seq[Expression])
+  extends RS_Predicate with CodegenFallback {
+
+  override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry: 
Geometry): Boolean = {
+    RasterPredicates.rsIntersects(leftGeometry, rightGeometry)
+  }
 
-case class RS_Intersects(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterPredicates.rsIntersects _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
-case class RS_Within(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterPredicates.rsWithin _) {
+case class RS_Contains(inputExpressions: Seq[Expression])
+  extends RS_Predicate with CodegenFallback {
+
+  override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry: 
Geometry): Boolean = {
+    RasterPredicates.rsContains(leftGeometry, rightGeometry)
+  }
+
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
-case class RS_Contains(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterPredicates.rsContains _) {
+case class RS_Within(inputExpressions: Seq[Expression])
+  extends RS_Predicate with CodegenFallback {
+
+  override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry: 
Geometry): Boolean = {
+    RasterPredicates.rsWithin(leftGeometry, rightGeometry)
+  }
+
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
+
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 693f4d1c4..35b58f3b0 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -18,9 +18,10 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
+import org.apache.sedona.common.raster.GeometryFunctions
 import org.apache.sedona.core.spatialOperator.{SpatialPredicate, 
SpatialPredicateEvaluators}
 import 
org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
-import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
 
 import scala.collection.JavaConverters._
 import org.apache.spark.broadcast.Broadcast
@@ -32,6 +33,7 @@ import 
org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.prep.{PreparedGeometry, 
PreparedGeometryFactory}
@@ -107,10 +109,13 @@ case class BroadcastIndexJoinExec(
     (streamShape, broadcast.shape)
   }
 
-  private val spatialExpression = (distance, spatialPredicate) match {
-    case (Some(r), SpatialPredicate.INTERSECTS) => 
s"ST_Distance($windowExpression, $objectExpression) <= $r"
-    case (Some(r), _) => s"ST_Distance($windowExpression, $objectExpression) < 
$r"
-    case (None, _) => s"ST_$spatialPredicate($windowExpression, 
$objectExpression)"
+  private val isRaster = windowExpression.dataType.isInstanceOf[RasterUDT] || 
objectExpression.dataType.isInstanceOf[RasterUDT]
+
+  private val spatialExpression = (distance, spatialPredicate, isRaster) match 
{
+    case (Some(r), SpatialPredicate.INTERSECTS, false) => 
s"ST_Distance($windowExpression, $objectExpression) <= $r"
+    case (Some(r), _, false) => s"ST_Distance($windowExpression, 
$objectExpression) < $r"
+    case (None, _, false) => s"ST_$spatialPredicate($windowExpression, 
$objectExpression)"
+    case (None, _, true) => s"RS_$spatialPredicate($windowExpression, 
$objectExpression)"
   }
 
   override def simpleString(maxFields: Int): String = 
super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
@@ -260,11 +265,12 @@ case class BroadcastIndexJoinExec(
     distance match {
       case Some(distanceExpression) =>
         streamResultsRaw.map(row => {
+          val isRaster = boundStreamShape.dataType.isInstanceOf[RasterUDT]
           val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
           if (geom == null) {
             (null, row)
           } else {
-            val geometry = GeometrySerializer.deserialize(geom)
+            val geometry = if (isRaster) 
GeometryFunctions.convexHull(RasterSerializer.deserialize(geom)) else 
GeometrySerializer.deserialize(geom)
             val radius = BindReferences.bindReference(distanceExpression, 
streamed.output).eval(row).asInstanceOf[Double]
             val envelope = geometry.getEnvelopeInternal
             envelope.expandBy(radius)
@@ -273,11 +279,12 @@ case class BroadcastIndexJoinExec(
         })
       case _ =>
         streamResultsRaw.map(row => {
+          val isRaster = boundStreamShape.dataType.isInstanceOf[RasterUDT]
           val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
           if (geom == null) {
             (null, row)
           } else {
-            (GeometrySerializer.deserialize(geom), row)
+            (if (isRaster) 
GeometryFunctions.convexHull(RasterSerializer.deserialize(geom)) else 
GeometrySerializer.deserialize(geom), row)
           }
         })
     }
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 615f88a21..91ba539cd 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -69,7 +69,7 @@ case class DistanceJoinExec(left: SparkPlan,
   override def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
                                 leftShapeExpr: Expression,
                                 rightRdd: RDD[UnsafeRow],
-                                rightShapeExpr: Expression): 
(SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
+                                rightShapeExpr: Expression, isLeftRaster: 
Boolean, isRightRaster: Boolean): (SpatialRDD[Geometry], SpatialRDD[Geometry]) 
= {
     if (distanceBoundToLeft) {
       (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius, 
isGeography), toSpatialRDD(rightRdd, rightShapeExpr))
     } else {
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 2715e950d..1c02b27e8 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, 
EqualNullSafe, EqualTo, E
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.apache.spark.sql.sedona_sql.expressions._
 import 
org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
 import org.apache.spark.sql.{SparkSession, Strategy}
@@ -79,6 +80,22 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
       }
     }
 
+  private def getRasterJoinDetection(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    predicate: RS_Predicate,
+    extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
+    predicate match {
+      case RS_Intersects(Seq(leftShape, rightShape)) =>
+        Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, extraCondition))
+      case RS_Contains(Seq(leftShape, rightShape)) =>
+        Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.CONTAINS, false, extraCondition))
+      case RS_Within(Seq(leftShape, rightShape)) =>
+        Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.WITHIN, false, extraCondition))
+      case _ => None
+    }
+  }
+
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) 
if optimizationEnabled(left, right, condition) => {
       var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
@@ -94,7 +111,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
         val canAutoBroadCastLeft = canAutoBroadcastBySize(left)
         val canAutoBroadCastRight = canAutoBroadcastBySize(right)
         if (canAutoBroadCastLeft && canAutoBroadCastRight) {
-          // Both sides can be broadcasted. Choose the smallest side.
+          // Both sides can be broadcast. Choose the smallest side.
           broadcastLeft = left.stats.sizeInBytes <= right.stats.sizeInBytes
           broadcastRight = !broadcastLeft
         } else {
@@ -104,12 +121,20 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
       }
 
       val queryDetection: Option[JoinQueryDetection] = condition match {
+        //For vector only joins
         case Some(predicate: ST_Predicate) =>
           getJoinDetection(left, right, predicate)
         case Some(And(predicate: ST_Predicate, extraCondition)) =>
           getJoinDetection(left, right, predicate, Some(extraCondition))
         case Some(And(extraCondition, predicate: ST_Predicate)) =>
           getJoinDetection(left, right, predicate, Some(extraCondition))
+          //For raster-vector joins
+        case Some(predicate: RS_Predicate) =>
+          getRasterJoinDetection(left, right, predicate)
+        case Some(And(predicate: RS_Predicate, extraCondition)) =>
+          getRasterJoinDetection(left, right, predicate, Some(extraCondition))
+        case Some(And(extraCondition, predicate: RS_Predicate)) =>
+          getRasterJoinDetection(left, right, predicate, Some(extraCondition))
         // For distance joins we execute the actual predicate (condition) and 
not only extraConditions.
         case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), 
distance)) =>
           Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
@@ -276,8 +301,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
     val a = children.head
     val b = children.tail.head
 
-    val relationship = s"ST_$spatialPredicate"
-
+    val isRaster = a.dataType.isInstanceOf[RasterUDT] || 
b.dataType.isInstanceOf[RasterUDT]
+    val relationship = if (isRaster) s"RS_$spatialPredicate" else 
s"ST_$spatialPredicate"
     matchExpressionsToPlans(a, b, left, right) match {
       case Some((_, _, false)) =>
         logInfo(s"Planning spatial join for $relationship relationship")
@@ -366,13 +391,15 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
 
     val a = children.head
     val b = children.tail.head
+    val isRaster = a.dataType.isInstanceOf[RasterUDT] || 
b.dataType.isInstanceOf[RasterUDT]
 
-    val relationship = (distance, spatialPredicate, isGeography) match {
-      case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <="
-      case (Some(_), _, false) => "ST_Distance <"
-      case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance 
(Geography) <="
-      case (Some(_), _, true) => "ST_Distance (Geography) <"
-      case (None, _, false) => s"ST_$spatialPredicate"
+    val relationship = (distance, spatialPredicate, isGeography, isRaster) 
match {
+      case (Some(_), SpatialPredicate.INTERSECTS, false, false) => 
"ST_Distance <="
+      case (Some(_), _, false, false) => "ST_Distance <"
+      case (Some(_), SpatialPredicate.INTERSECTS, true, false) => "ST_Distance 
(Geography) <="
+      case (Some(_), _, true, false) => "ST_Distance (Geography) <"
+      case (None, _, false, false) => s"ST_$spatialPredicate"
+      case (None, _, false, true) => s"RS_$spatialPredicate"
     }
     val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { 
distanceExpr =>
       matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 2c24a34e7..448ae744b 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -19,7 +19,6 @@
 package org.apache.spark.sql.sedona_sql.strategy.join
 
 import scala.jdk.CollectionConverters._
-
 import org.apache.sedona.core.enums.IndexType
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
@@ -27,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
 
 
@@ -48,12 +48,12 @@ case class SpatialIndexExec(child: SparkPlan,
 
   override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
     val boundShape = BindReferences.bindReference(shape, child.output)
-
+    val isRaster = boundShape.dataType.isInstanceOf[RasterUDT]
     val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
 
     val spatialRDD = distance match {
       case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, 
boundShape, BindReferences.bindReference(distanceExpression, child.output), 
isGeography)
-      case None => toSpatialRDD(resultRaw, boundShape)
+      case None => if (isRaster) toSpatialRDDRaster(resultRaw, boundShape) 
else toSpatialRDD(resultRaw, boundShape)
     }
 
     spatialRDD.buildIndex(indexType, false)
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 0bdeeb836..9ea665397 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -18,12 +18,14 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
+import org.apache.sedona.common.raster.GeometryFunctions
 import org.apache.sedona.core.spatialRDD.SpatialRDD
 import org.apache.sedona.core.utils.SedonaConf
-import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.locationtech.jts.geom.{Envelope, Geometry}
 
 trait TraitJoinQueryBase {
@@ -32,8 +34,9 @@ trait TraitJoinQueryBase {
   def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
                        leftShapeExpr: Expression,
                        rightRdd: RDD[UnsafeRow],
-                       rightShapeExpr: Expression): (SpatialRDD[Geometry], 
SpatialRDD[Geometry]) =
-    (toSpatialRDD(leftRdd, leftShapeExpr), toSpatialRDD(rightRdd, 
rightShapeExpr))
+                       rightShapeExpr: Expression, isLeftRaster: Boolean, 
isRightRaster: Boolean): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
+    (if (isLeftRaster) toSpatialRDDRaster(leftRdd, leftShapeExpr) else 
toSpatialRDD(leftRdd, leftShapeExpr),
+      if (isRightRaster) toSpatialRDDRaster(rightRdd, rightShapeExpr) else 
toSpatialRDD(rightRdd, rightShapeExpr))
 
   def toSpatialRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): 
SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
@@ -48,12 +51,26 @@ trait TraitJoinQueryBase {
     spatialRdd
   }
 
+  def toSpatialRDDRaster(rdd: RDD[UnsafeRow], shapeExpression: Expression): 
SpatialRDD[Geometry] = {
+    val spatialRdd = new SpatialRDD[Geometry]
+    spatialRdd.setRawSpatialRDD(
+      rdd
+        .map { x =>
+          val shape = 
GeometryFunctions.convexHull(RasterSerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]]))
+          shape.setUserData(x.copy)
+          shape
+        }
+        .toJavaRDD())
+    spatialRdd
+  }
+
   def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, 
boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
+    val isRaster = shapeExpression.dataType.isInstanceOf[RasterUDT]
     spatialRdd.setRawSpatialRDD(
       rdd
         .map { x =>
-          val shape = 
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
+          val shape = if (isRaster) 
GeometryFunctions.convexHull(RasterSerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]]))
 else 
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
           val envelope = shape.getEnvelopeInternal.copy()
           expandEnvelope(envelope, boundRadius.eval(x).asInstanceOf[Double], 
6357000.0, isGeography)
 
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index c3a5a3b11..c377b3547 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -19,15 +19,15 @@
 package org.apache.spark.sql.sedona_sql.strategy.join
 
 import org.apache.sedona.core.enums.JoinSparitionDominantSide
-import org.apache.sedona.core.spatialOperator.JoinQuery
 import org.apache.sedona.core.spatialOperator.JoinQuery.JoinParams
-import org.apache.sedona.core.spatialOperator.SpatialPredicate
+import org.apache.sedona.core.spatialOperator.{JoinQuery, SpatialPredicate}
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, Predicate, UnsafeRow}
 import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, Predicate, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.locationtech.jts.geom.Geometry
 
 trait TraitJoinQueryExec extends TraitJoinQueryBase {
@@ -50,8 +50,11 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
     val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
     val sedonaConf = SedonaConf.fromActiveSession
+    val isLeftRaster = leftShape.dataType.isInstanceOf[RasterUDT]
+    val isRightRaster = rightShape.dataType.isInstanceOf[RasterUDT]
+
     val (leftShapes, rightShapes) =
-      toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, 
boundRightShape)
+      toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, 
boundRightShape, isLeftRaster, isRightRaster)
 
     // Only do SpatialRDD analyze when the user doesn't know approximate total 
count of the spatial partitioning
     // dominant side rdd
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index c22ba601e..38ea24662 100644
--- 
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ 
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -405,6 +405,49 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(distanceJoinDF.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size == 1)
       assert(distanceJoinDF.count() == expected)
     }
+
+    it("Passed RS_Intersects") {
+      val rasterDf = buildRasterDf.repartition(3)
+      val buildingsDf = buildBuildingsDf.repartition(5)
+      val joinDfRightBroadcast = 
rasterDf.alias("rasterDf").join(broadcast(buildingsDf).alias("buildingsDf"), 
expr("RS_Intersects(rasterDf.raster, buildingsDf.building)"))
+      assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect{case p: 
BroadcastIndexJoinExec => p}.size == 1)
+      val resultRightBroadcast = joinDfRightBroadcast.count()
+      assert(buildingsDf.count() == resultRightBroadcast) // raster is of 
entire world, all buildings should intersect
+
+      //ideally the raster should not be broadcast here, testing it out 
nevertheless
+      val joinDfLeftBroadcast = 
broadcast(rasterDf.alias("rasterDf")).join(buildingsDf.alias("buildingsDf"), 
expr("RS_Intersects(rasterDf.raster, buildingsDf.building)"))
+      assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect {case p: 
BroadcastIndexJoinExec => p}.size == 1)
+      val resultLeftBroadcast = joinDfLeftBroadcast.count()
+      assert(buildingsDf.count() == resultLeftBroadcast)
+    }
+
+    it("Passed RS_Contains") {
+      val rasterDf = buildRasterDf.repartition(3)
+      val buildingsDf = buildBuildingsDf.limit(300).repartition(5)
+      val joinDfRightBroadcast = 
rasterDf.alias("rasterDf").join(broadcast(buildingsDf).alias("buildingsDf"), 
expr("RS_Contains(rasterDf.raster, buildingsDf.building)"))
+      assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size == 1)
+      val resultRightBroadcast = joinDfRightBroadcast.count()
+      assert(buildingsDf.count() == resultRightBroadcast) // raster is of 
entire world, should contain all buildings
+
+      //ideally the raster should not be broadcast here, testing it out 
nevertheless
+      val joinDfLeftBroadcast = 
broadcast(rasterDf.alias("rasterDf")).join(buildingsDf.alias("buildingsDf"), 
expr("RS_Contains(rasterDf.raster, buildingsDf.building)"))
+      assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size == 1)
+      val resultLeftBroadcast = joinDfLeftBroadcast.count()
+      assert(buildingsDf.count() == resultLeftBroadcast)
+    }
+
+    it("Passed RS_Within") {
+      val smallRasterDf1 = buildSmallRasterDf.repartition(3)
+      val smallRasterDf2 = 
buildSmallRasterDf.selectExpr("RS_ConvexHull(raster) as geom").repartition(5)
+      val joinDfRightBroadcast = 
smallRasterDf1.alias("rasterDf").join(broadcast(smallRasterDf2.alias("geomDf")),
 expr("RS_Within(rasterDf.raster, geomDf.geom)"))
+      assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size == 1)
+      assert(1 == joinDfRightBroadcast.count()) // raster within its own 
convexHull
+
+      val joinDfLeftBroadcast = 
broadcast(smallRasterDf1.alias("rasterDf")).join(smallRasterDf2.alias("geomDf"),
 expr("RS_Within(rasterDf.raster, geomDf.geom)"))
+      assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size == 1)
+      assert(1 == joinDfLeftBroadcast.count()) // raster within its own 
convexHull
+
+    }
   }
 
   describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 46978a212..53aa239a2 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -66,6 +66,9 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   val smallPointsLocation: String = resourceFolder + "small/points.csv"
   val spatialJoinLeftInputLocation: String = resourceFolder + 
"spatial-predicates-test-data.tsv"
   val spatialJoinRightInputLocation: String = resourceFolder + 
"spatial-join-query-window.tsv"
+  val rasterDataLocation: String = resourceFolder + 
"raster/raster_with_no_data/test5.tiff"
+  val buildingDataLocation: String = resourceFolder + "813_buildings_test.csv"
+  val smallRasterDataLocation: String = resourceFolder + "raster/test1.tiff"
 
   override def beforeAll(): Unit = {
     SedonaContext.create(sparkSession)
@@ -80,8 +83,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", 
"false").load(path)
   }
 
+  /** Reads a comma-delimited CSV file whose first row is a header into a DataFrame. */
+  def loadCsvWithHeader(path: String): DataFrame = {
+    sparkSession.read
+      .format("csv")
+      .option("delimiter", ",")
+      .option("header", "true")
+      .load(path)
+  }
+
+  /** Loads a GeoTIFF via Spark's binaryFile source; the raw bytes land in the `content` column. */
+  def loadGeoTiff(path: String): DataFrame =
+    sparkSession.read.format("binaryFile").load(path)
+
   lazy val buildPointDf = 
loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as 
Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
   lazy val buildPolygonDf = 
loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as 
Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)), 
cast(_c3 as Decimal(24,20))) as polygonshape")
+  // Lazily-built raster/vector DataFrames shared by the raster join test
+  // suites; `lazy` so suites that never touch them pay no I/O cost.
+  lazy val buildRasterDf = 
loadGeoTiff(rasterDataLocation).selectExpr("RS_FromGeoTiff(content) as raster")
+  lazy val buildBuildingsDf = 
loadCsvWithHeader(buildingDataLocation).selectExpr("ST_GeomFromWKT(geometry) as 
building")
+  lazy val buildSmallRasterDf = 
loadGeoTiff(smallRasterDataLocation).selectExpr("RS_FromGeoTiff(content) as 
raster")
 
   protected final val FP_TOLERANCE: Double = 1e-12
   protected final val COORDINATE_SEQUENCE_COMPARATOR: 
CoordinateSequenceComparator = new CoordinateSequenceComparator(2) {
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 749399add..088aaf10d 100644
--- 
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ 
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -22,7 +22,7 @@ package org.apache.sedona.sql
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec
+import org.apache.spark.sql.sedona_sql.strategy.join.{DistanceJoinExec, 
RangeJoinExec}
 import org.apache.spark.sql.types._
 import org.locationtech.jts.geom.Geometry
 
@@ -30,6 +30,53 @@ class predicateJoinTestScala extends TestBaseScala {
 
   describe("Sedona-SQL Predicate Join Test") {
 
+    //raster-vector predicates
+
+    it("Passed RS_Intersects in a join") {
+      // A raster-vector RS_Intersects predicate must be planned as the
+      // optimized RangeJoinExec, not a cartesian product (SEDONA-371).
+      val sedonaConf = new SedonaConf(sparkSession.conf)
+      println(sedonaConf)
+
+      // Reuse the TestBaseScala helpers instead of duplicating reader setup.
+      val polygonCsvDf = loadCsvWithHeader(buildingDataLocation)
+      polygonCsvDf.createOrReplaceTempView("polygontable")
+      val polygonDf = sparkSession.sql("SELECT ST_GeomFromWKT(geometry) as building from polygontable")
+      polygonDf.createOrReplaceTempView("polygondf")
+
+      val rasterDf = loadGeoTiff(rasterDataLocation).selectExpr("RS_FromGeoTiff(content) as raster")
+      rasterDf.createOrReplaceTempView("rasterDf")
+
+      val rangeJoinDf = sparkSession.sql("select * from polygondf, rasterDf where RS_Intersects(rasterDf.raster, polygondf.building)")
+      assert(rangeJoinDf.queryExecution.sparkPlan.collect { case p: RangeJoinExec => p }.size === 1)
+      // 999 building footprints intersect the raster's extent.
+      assert(rangeJoinDf.count() == 999)
+    }
+
+    it("Passed RS_Contains in a join") {
+      // RS_Contains(raster, geometry) between a raster and high-confidence
+      // building footprints should also be planned as a RangeJoinExec.
+      val sedonaConf = new SedonaConf(sparkSession.conf)
+      println(sedonaConf)
+
+      loadCsvWithHeader(buildingDataLocation).createOrReplaceTempView("polygontable")
+      sparkSession
+        .sql("SELECT ST_GeomFromWKT(geometry) as building from polygontable where confidence > 0.85")
+        .createOrReplaceTempView("polygondf")
+
+      loadGeoTiff(rasterDataLocation)
+        .selectExpr("RS_FromGeoTiff(content) as raster")
+        .createOrReplaceTempView("rasterDf")
+
+      val rangeJoinDf = sparkSession.sql("select * from rasterDf, polygondf where RS_Contains(rasterDf.raster, polygondf.building)")
+      assert(rangeJoinDf.queryExecution.sparkPlan.collect { case p: RangeJoinExec => p }.size === 1)
+      assert(rangeJoinDf.count() == 210)
+    }
+
+    it("Passed RS_Within in a join") {
+      // Self-join of a single raster against its own convex hull: exactly one
+      // row satisfies RS_Within, and the plan must use RangeJoinExec.
+      val sedonaConf = new SedonaConf(sparkSession.conf)
+      println(sedonaConf)
+
+      // Use the shared smallRasterDataLocation fixture instead of re-deriving
+      // the path from resourceFolder, keeping tests consistent with TestBaseScala.
+      val smallRasterDf = loadGeoTiff(smallRasterDataLocation).selectExpr("RS_FromGeoTiff(content) as raster")
+      smallRasterDf.createOrReplaceTempView("smallRaster")
+
+      val rangeJoinDf = sparkSession.sql("select * from smallRaster r1, smallRaster r2 where RS_Within(r1.raster, RS_ConvexHull(r2.raster))")
+      assert(rangeJoinDf.queryExecution.sparkPlan.collect { case p: RangeJoinExec => p }.size === 1)
+      assert(rangeJoinDf.count() == 1)
+    }
+
     it("Passed ST_Contains in a join") {
       val sedonaConf = new SedonaConf(sparkSession.conf)
       println(sedonaConf)


Reply via email to