This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new d5405cc29 [SEDONA-371] Add Optimized join support for Raster-Vector
Joins (#979)
d5405cc29 is described below
commit d5405cc29e2ff8f27a8d2daa304a8b23c5fa4273
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Mon Aug 21 22:31:07 2023 -0400
[SEDONA-371] Add Optimized join support for Raster-Vector Joins (#979)
Co-authored-by: Jia Yu <[email protected]>
---
docs/api/sql/Raster-operators.md | 2 +-
.../apache/sedona/sql/utils/RasterSerializer.scala | 45 +++++++++++++
.../expressions/raster/RasterPredicates.scala | 74 ++++++++++++++++++++--
.../strategy/join/BroadcastIndexJoinExec.scala | 21 ++++--
.../strategy/join/DistanceJoinExec.scala | 2 +-
.../strategy/join/JoinQueryDetector.scala | 45 ++++++++++---
.../strategy/join/SpatialIndexExec.scala | 6 +-
.../strategy/join/TraitJoinQueryBase.scala | 25 ++++++--
.../strategy/join/TraitJoinQueryExec.scala | 11 ++--
.../sedona/sql/BroadcastIndexJoinSuite.scala | 43 +++++++++++++
.../org/apache/sedona/sql/TestBaseScala.scala | 14 ++++
.../apache/sedona/sql/predicateJoinTestScala.scala | 49 +++++++++++++-
12 files changed, 300 insertions(+), 37 deletions(-)
diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md
index 32b91178c..0e2acb9e8 100644
--- a/docs/api/sql/Raster-operators.md
+++ b/docs/api/sql/Raster-operators.md
@@ -98,7 +98,7 @@ POLYGON ((0 0,20 0,20 60,0 60,0 0))
### RS_ConvexHull
Introduction: Return the convex hull geometry of the raster including the
NoDataBandValue band pixels.
-For regular shaped and non-skewed rasters, this gives more or less the same
result as RS_ConvexHull and hence is only useful for irregularly shaped or
skewed rasters.
+For regular shaped and non-skewed rasters, this gives more or less the same
result as RS_Envelope and hence is only useful for irregularly shaped or skewed
rasters.
Format: `RS_ConvexHull(raster: Raster)`
diff --git
a/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
b/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
new file mode 100644
index 000000000..3ffc41c6c
--- /dev/null
+++
b/sql/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql.utils
+
+import org.apache.sedona.common.raster.Serde
+import org.geotools.coverage.grid.GridCoverage2D
+
+object RasterSerializer {
+ /**
+ * Given a raster returns an array of bytes
+ *
+ * @param raster the GridCoverage2D raster to serialize
+ * @return Array of bytes representing this raster
+ */
+ def serialize(raster: GridCoverage2D): Array[Byte] = {
+ Serde.serialize(raster);
+ }
+
+ /**
+ * Given an array of bytes returns a GridCoverage2D raster
+ *
+ * @param value serialized raster as Array[Byte]
+ * @return GridCoverage2D
+ */
+ def deserialize(value: Array[Byte]): GridCoverage2D = {
+ Serde.deserialize(value);
+ }
+}
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
index 95d746d8f..1880088ff 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
@@ -16,27 +16,87 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.spark.sql.sedona_sql.expressions.raster
+
+package org.apache.spark.sql.sedona_sql.expressions
import org.apache.sedona.common.raster.RasterPredicates
-import org.apache.spark.sql.catalyst.expressions.Expression
-import
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
-import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes,
Expression, NullIntolerant}
+import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
+import org.geotools.coverage.grid.GridCoverage2D
+import org.locationtech.jts.geom.Geometry
+
+abstract class RS_Predicate extends Expression
+ with FoldableExpression
+ with ExpectsInputTypes
+ with NullIntolerant {
+ def inputExpressions: Seq[Expression]
+
+ override def toString: String = s" **${this.getClass.getName}** "
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, GeometryUDT)
+
+ override def dataType: DataType = BooleanType
+
+ override def children: Seq[Expression] = inputExpressions
+
+ override final def eval(inputRow: InternalRow): Any = {
+ val leftArray =
inputExpressions.head.eval(inputRow).asInstanceOf[Array[Byte]]
+ if (leftArray == null) {
+ null
+ } else {
+ val rightArray =
inputExpressions(1).eval(inputRow).asInstanceOf[Array[Byte]]
+ if (rightArray == null) {
+ null
+ } else {
+ val leftGeometry = RasterSerializer.deserialize(leftArray)
+ val rightGeometry = GeometrySerializer.deserialize(rightArray)
+ evalGeom(leftGeometry, rightGeometry)
+ }
+ }
+ }
+
+ def evalGeom(leftGeometry: GridCoverage2D, rightGeometry: Geometry): Boolean
+}
+
+case class RS_Intersects(inputExpressions: Seq[Expression])
+ extends RS_Predicate with CodegenFallback {
+
+ override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry:
Geometry): Boolean = {
+ RasterPredicates.rsIntersects(leftGeometry, rightGeometry)
+ }
-case class RS_Intersects(inputExpressions: Seq[Expression]) extends
InferredExpression(RasterPredicates.rsIntersects _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
}
}
-case class RS_Within(inputExpressions: Seq[Expression]) extends
InferredExpression(RasterPredicates.rsWithin _) {
+case class RS_Contains(inputExpressions: Seq[Expression])
+ extends RS_Predicate with CodegenFallback {
+
+ override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry:
Geometry): Boolean = {
+ RasterPredicates.rsContains(leftGeometry, rightGeometry)
+ }
+
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
}
}
-case class RS_Contains(inputExpressions: Seq[Expression]) extends
InferredExpression(RasterPredicates.rsContains _) {
+case class RS_Within(inputExpressions: Seq[Expression])
+ extends RS_Predicate with CodegenFallback {
+
+ override def evalGeom(leftGeometry: GridCoverage2D, rightGeometry:
Geometry): Boolean = {
+ RasterPredicates.rsWithin(leftGeometry, rightGeometry)
+ }
+
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
}
}
+
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 693f4d1c4..35b58f3b0 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -18,9 +18,10 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
+import org.apache.sedona.common.raster.GeometryFunctions
import org.apache.sedona.core.spatialOperator.{SpatialPredicate,
SpatialPredicateEvaluators}
import
org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
-import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
import scala.collection.JavaConverters._
import org.apache.spark.broadcast.Broadcast
@@ -32,6 +33,7 @@ import
org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.geom.prep.{PreparedGeometry,
PreparedGeometryFactory}
@@ -107,10 +109,13 @@ case class BroadcastIndexJoinExec(
(streamShape, broadcast.shape)
}
- private val spatialExpression = (distance, spatialPredicate) match {
- case (Some(r), SpatialPredicate.INTERSECTS) =>
s"ST_Distance($windowExpression, $objectExpression) <= $r"
- case (Some(r), _) => s"ST_Distance($windowExpression, $objectExpression) <
$r"
- case (None, _) => s"ST_$spatialPredicate($windowExpression,
$objectExpression)"
+ private val isRaster = windowExpression.dataType.isInstanceOf[RasterUDT] ||
objectExpression.dataType.isInstanceOf[RasterUDT]
+
+ private val spatialExpression = (distance, spatialPredicate, isRaster) match
{
+ case (Some(r), SpatialPredicate.INTERSECTS, false) =>
s"ST_Distance($windowExpression, $objectExpression) <= $r"
+ case (Some(r), _, false) => s"ST_Distance($windowExpression,
$objectExpression) < $r"
+ case (None, _, false) => s"ST_$spatialPredicate($windowExpression,
$objectExpression)"
+ case (None, _, true) => s"RS_$spatialPredicate($windowExpression,
$objectExpression)"
}
override def simpleString(maxFields: Int): String =
super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
@@ -260,11 +265,12 @@ case class BroadcastIndexJoinExec(
distance match {
case Some(distanceExpression) =>
streamResultsRaw.map(row => {
+ val isRaster = boundStreamShape.dataType.isInstanceOf[RasterUDT]
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
if (geom == null) {
(null, row)
} else {
- val geometry = GeometrySerializer.deserialize(geom)
+ val geometry = if (isRaster)
GeometryFunctions.convexHull(RasterSerializer.deserialize(geom)) else
GeometrySerializer.deserialize(geom)
val radius = BindReferences.bindReference(distanceExpression,
streamed.output).eval(row).asInstanceOf[Double]
val envelope = geometry.getEnvelopeInternal
envelope.expandBy(radius)
@@ -273,11 +279,12 @@ case class BroadcastIndexJoinExec(
})
case _ =>
streamResultsRaw.map(row => {
+ val isRaster = boundStreamShape.dataType.isInstanceOf[RasterUDT]
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
if (geom == null) {
(null, row)
} else {
- (GeometrySerializer.deserialize(geom), row)
+ (if (isRaster)
GeometryFunctions.convexHull(RasterSerializer.deserialize(geom)) else
GeometrySerializer.deserialize(geom), row)
}
})
}
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 615f88a21..91ba539cd 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -69,7 +69,7 @@ case class DistanceJoinExec(left: SparkPlan,
override def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
leftShapeExpr: Expression,
rightRdd: RDD[UnsafeRow],
- rightShapeExpr: Expression):
(SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
+ rightShapeExpr: Expression, isLeftRaster:
Boolean, isRightRaster: Boolean): (SpatialRDD[Geometry], SpatialRDD[Geometry])
= {
if (distanceBoundToLeft) {
(toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius,
isGeography), toSpatialRDD(rightRdd, rightShapeExpr))
} else {
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 2715e950d..1c02b27e8 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And,
EqualNullSafe, EqualTo, E
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions._
import
org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
import org.apache.spark.sql.{SparkSession, Strategy}
@@ -79,6 +80,22 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
}
}
+ private def getRasterJoinDetection(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ predicate: RS_Predicate,
+ extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
+ predicate match {
+ case RS_Intersects(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, extraCondition))
+ case RS_Contains(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.CONTAINS, false, extraCondition))
+ case RS_Within(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.WITHIN, false, extraCondition))
+ case _ => None
+ }
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint))
if optimizationEnabled(left, right, condition) => {
var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
@@ -94,7 +111,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
val canAutoBroadCastLeft = canAutoBroadcastBySize(left)
val canAutoBroadCastRight = canAutoBroadcastBySize(right)
if (canAutoBroadCastLeft && canAutoBroadCastRight) {
- // Both sides can be broadcasted. Choose the smallest side.
+ // Both sides can be broadcast. Choose the smallest side.
broadcastLeft = left.stats.sizeInBytes <= right.stats.sizeInBytes
broadcastRight = !broadcastLeft
} else {
@@ -104,12 +121,20 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
}
val queryDetection: Option[JoinQueryDetection] = condition match {
+ //For vector only joins
case Some(predicate: ST_Predicate) =>
getJoinDetection(left, right, predicate)
case Some(And(predicate: ST_Predicate, extraCondition)) =>
getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(And(extraCondition, predicate: ST_Predicate)) =>
getJoinDetection(left, right, predicate, Some(extraCondition))
+ //For raster-vector joins
+ case Some(predicate: RS_Predicate) =>
+ getRasterJoinDetection(left, right, predicate)
+ case Some(And(predicate: RS_Predicate, extraCondition)) =>
+ getRasterJoinDetection(left, right, predicate, Some(extraCondition))
+ case Some(And(extraCondition, predicate: RS_Predicate)) =>
+ getRasterJoinDetection(left, right, predicate, Some(extraCondition))
// For distance joins we execute the actual predicate (condition) and
not only extraConditions.
case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance)) =>
Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
@@ -276,8 +301,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
val a = children.head
val b = children.tail.head
- val relationship = s"ST_$spatialPredicate"
-
+ val isRaster = a.dataType.isInstanceOf[RasterUDT] ||
b.dataType.isInstanceOf[RasterUDT]
+ val relationship = if (isRaster) s"RS_$spatialPredicate" else
s"ST_$spatialPredicate"
matchExpressionsToPlans(a, b, left, right) match {
case Some((_, _, false)) =>
logInfo(s"Planning spatial join for $relationship relationship")
@@ -366,13 +391,15 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
val a = children.head
val b = children.tail.head
+ val isRaster = a.dataType.isInstanceOf[RasterUDT] ||
b.dataType.isInstanceOf[RasterUDT]
- val relationship = (distance, spatialPredicate, isGeography) match {
- case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <="
- case (Some(_), _, false) => "ST_Distance <"
- case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance
(Geography) <="
- case (Some(_), _, true) => "ST_Distance (Geography) <"
- case (None, _, false) => s"ST_$spatialPredicate"
+ val relationship = (distance, spatialPredicate, isGeography, isRaster)
match {
+ case (Some(_), SpatialPredicate.INTERSECTS, false, false) =>
"ST_Distance <="
+ case (Some(_), _, false, false) => "ST_Distance <"
+ case (Some(_), SpatialPredicate.INTERSECTS, true, false) => "ST_Distance
(Geography) <="
+ case (Some(_), _, true, false) => "ST_Distance (Geography) <"
+ case (None, _, false, false) => s"ST_$spatialPredicate"
+ case (None, _, false, true) => s"RS_$spatialPredicate"
}
val (distanceOnIndexSide, distanceOnStreamSide) = distance.map {
distanceExpr =>
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 2c24a34e7..448ae744b 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -19,7 +19,6 @@
package org.apache.spark.sql.sedona_sql.strategy.join
import scala.jdk.CollectionConverters._
-
import org.apache.sedona.core.enums.IndexType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
@@ -27,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, UnsafeRow}
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
@@ -48,12 +48,12 @@ case class SpatialIndexExec(child: SparkPlan,
override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
val boundShape = BindReferences.bindReference(shape, child.output)
-
+ val isRaster = boundShape.dataType.isInstanceOf[RasterUDT]
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
val spatialRDD = distance match {
case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw,
boundShape, BindReferences.bindReference(distanceExpression, child.output),
isGeography)
- case None => toSpatialRDD(resultRaw, boundShape)
+ case None => if (isRaster) toSpatialRDDRaster(resultRaw, boundShape)
else toSpatialRDD(resultRaw, boundShape)
}
spatialRDD.buildIndex(indexType, false)
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 0bdeeb836..9ea665397 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -18,12 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
+import org.apache.sedona.common.raster.GeometryFunctions
import org.apache.sedona.core.spatialRDD.SpatialRDD
import org.apache.sedona.core.utils.SedonaConf
-import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.locationtech.jts.geom.{Envelope, Geometry}
trait TraitJoinQueryBase {
@@ -32,8 +34,9 @@ trait TraitJoinQueryBase {
def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
leftShapeExpr: Expression,
rightRdd: RDD[UnsafeRow],
- rightShapeExpr: Expression): (SpatialRDD[Geometry],
SpatialRDD[Geometry]) =
- (toSpatialRDD(leftRdd, leftShapeExpr), toSpatialRDD(rightRdd,
rightShapeExpr))
+ rightShapeExpr: Expression, isLeftRaster: Boolean,
isRightRaster: Boolean): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
+ (if (isLeftRaster) toSpatialRDDRaster(leftRdd, leftShapeExpr) else
toSpatialRDD(leftRdd, leftShapeExpr),
+ if (isRightRaster) toSpatialRDDRaster(rightRdd, rightShapeExpr) else
toSpatialRDD(rightRdd, rightShapeExpr))
def toSpatialRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression):
SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
@@ -48,12 +51,26 @@ trait TraitJoinQueryBase {
spatialRdd
}
+ def toSpatialRDDRaster(rdd: RDD[UnsafeRow], shapeExpression: Expression):
SpatialRDD[Geometry] = {
+ val spatialRdd = new SpatialRDD[Geometry]
+ spatialRdd.setRawSpatialRDD(
+ rdd
+ .map { x =>
+ val shape =
GeometryFunctions.convexHull(RasterSerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]]))
+ shape.setUserData(x.copy)
+ shape
+ }
+ .toJavaRDD())
+ spatialRdd
+ }
+
def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression,
boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
+ val isRaster = shapeExpression.dataType.isInstanceOf[RasterUDT]
spatialRdd.setRawSpatialRDD(
rdd
.map { x =>
- val shape =
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
+ val shape = if (isRaster)
GeometryFunctions.convexHull(RasterSerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]]))
else
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
val envelope = shape.getEnvelopeInternal.copy()
expandEnvelope(envelope, boundRadius.eval(x).asInstanceOf[Double],
6357000.0, isGeography)
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index c3a5a3b11..c377b3547 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -19,15 +19,15 @@
package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.enums.JoinSparitionDominantSide
-import org.apache.sedona.core.spatialOperator.JoinQuery
import org.apache.sedona.core.spatialOperator.JoinQuery.JoinParams
-import org.apache.sedona.core.spatialOperator.SpatialPredicate
+import org.apache.sedona.core.spatialOperator.{JoinQuery, SpatialPredicate}
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, Predicate, UnsafeRow}
import
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, Predicate, UnsafeRow}
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.locationtech.jts.geom.Geometry
trait TraitJoinQueryExec extends TraitJoinQueryBase {
@@ -50,8 +50,11 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]
val sedonaConf = SedonaConf.fromActiveSession
+ val isLeftRaster = leftShape.dataType.isInstanceOf[RasterUDT]
+ val isRightRaster = rightShape.dataType.isInstanceOf[RasterUDT]
+
val (leftShapes, rightShapes) =
- toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw,
boundRightShape)
+ toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw,
boundRightShape, isLeftRaster, isRightRaster)
// Only do SpatialRDD analyze when the user doesn't know approximate total
count of the spatial partitioning
// dominant side rdd
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index c22ba601e..38ea24662 100644
---
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -405,6 +405,49 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(distanceJoinDF.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size == 1)
assert(distanceJoinDF.count() == expected)
}
+
+ it("Passed RS_Intersects") {
+ val rasterDf = buildRasterDf.repartition(3)
+ val buildingsDf = buildBuildingsDf.repartition(5)
+ val joinDfRightBroadcast =
rasterDf.alias("rasterDf").join(broadcast(buildingsDf).alias("buildingsDf"),
expr("RS_Intersects(rasterDf.raster, buildingsDf.building)"))
+ assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect{case p:
BroadcastIndexJoinExec => p}.size == 1)
+ val resultRightBroadcast = joinDfRightBroadcast.count()
+ assert(buildingsDf.count() == resultRightBroadcast) // raster is of
entire world, all buildings should intersect
+
+ //ideally the raster should not be broadcast here, testing it out
nevertheless
+ val joinDfLeftBroadcast =
broadcast(rasterDf.alias("rasterDf")).join(buildingsDf.alias("buildingsDf"),
expr("RS_Intersects(rasterDf.raster, buildingsDf.building)"))
+ assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect {case p:
BroadcastIndexJoinExec => p}.size == 1)
+ val resultLeftBroadcast = joinDfLeftBroadcast.count()
+ assert(buildingsDf.count() == resultLeftBroadcast)
+ }
+
+ it("Passed RS_Contains") {
+ val rasterDf = buildRasterDf.repartition(3)
+ val buildingsDf = buildBuildingsDf.limit(300).repartition(5)
+ val joinDfRightBroadcast =
rasterDf.alias("rasterDf").join(broadcast(buildingsDf).alias("buildingsDf"),
expr("RS_Contains(rasterDf.raster, buildingsDf.building)"))
+ assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size == 1)
+ val resultRightBroadcast = joinDfRightBroadcast.count()
+ assert(buildingsDf.count() == resultRightBroadcast) // raster is of
entire world, should contain all buildings
+
+ //ideally the raster should not be broadcast here, testing it out
nevertheless
+ val joinDfLeftBroadcast =
broadcast(rasterDf.alias("rasterDf")).join(buildingsDf.alias("buildingsDf"),
expr("RS_Contains(rasterDf.raster, buildingsDf.building)"))
+ assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size == 1)
+ val resultLeftBroadcast = joinDfLeftBroadcast.count()
+ assert(buildingsDf.count() == resultLeftBroadcast)
+ }
+
+ it("Passed RS_Within") {
+ val smallRasterDf1 = buildSmallRasterDf.repartition(3)
+ val smallRasterDf2 =
buildSmallRasterDf.selectExpr("RS_ConvexHull(raster) as geom").repartition(5)
+ val joinDfRightBroadcast =
smallRasterDf1.alias("rasterDf").join(broadcast(smallRasterDf2.alias("geomDf")),
expr("RS_Within(rasterDf.raster, geomDf.geom)"))
+ assert(joinDfRightBroadcast.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size == 1)
+ assert(1 == joinDfRightBroadcast.count()) // raster within its own
convexHull
+
+ val joinDfLeftBroadcast =
broadcast(smallRasterDf1.alias("rasterDf")).join(smallRasterDf2.alias("geomDf"),
expr("RS_Within(rasterDf.raster, geomDf.geom)"))
+ assert(joinDfLeftBroadcast.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size == 1)
+ assert(1 == joinDfLeftBroadcast.count()) // raster within its own
convexHull
+
+ }
}
describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 46978a212..53aa239a2 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -66,6 +66,9 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
val smallPointsLocation: String = resourceFolder + "small/points.csv"
val spatialJoinLeftInputLocation: String = resourceFolder +
"spatial-predicates-test-data.tsv"
val spatialJoinRightInputLocation: String = resourceFolder +
"spatial-join-query-window.tsv"
+ val rasterDataLocation: String = resourceFolder +
"raster/raster_with_no_data/test5.tiff"
+ val buildingDataLocation: String = resourceFolder + "813_buildings_test.csv"
+ val smallRasterDataLocation: String = resourceFolder + "raster/test1.tiff"
override def beforeAll(): Unit = {
SedonaContext.create(sparkSession)
@@ -80,8 +83,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
sparkSession.read.format("csv").option("delimiter", ",").option("header",
"false").load(path)
}
+ def loadCsvWithHeader(path: String): DataFrame = {
+ sparkSession.read.format("csv").option("delimiter", ",").option("header",
"true").load(path)
+ }
+
+ def loadGeoTiff(path: String): DataFrame = {
+ sparkSession.read.format("binaryFile").load(path)
+ }
+
lazy val buildPointDf =
loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as
Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
lazy val buildPolygonDf =
loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as
Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)),
cast(_c3 as Decimal(24,20))) as polygonshape")
+ lazy val buildRasterDf =
loadGeoTiff(rasterDataLocation).selectExpr("RS_FromGeoTiff(content) as raster")
+ lazy val buildBuildingsDf =
loadCsvWithHeader(buildingDataLocation).selectExpr("ST_GeomFromWKT(geometry) as
building")
+ lazy val buildSmallRasterDf =
loadGeoTiff(smallRasterDataLocation).selectExpr("RS_FromGeoTiff(content) as
raster")
protected final val FP_TOLERANCE: Double = 1e-12
protected final val COORDINATE_SEQUENCE_COMPARATOR:
CoordinateSequenceComparator = new CoordinateSequenceComparator(2) {
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 749399add..088aaf10d 100644
---
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -22,7 +22,7 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec
+import org.apache.spark.sql.sedona_sql.strategy.join.{DistanceJoinExec,
RangeJoinExec}
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Geometry
@@ -30,6 +30,53 @@ class predicateJoinTestScala extends TestBaseScala {
describe("Sedona-SQL Predicate Join Test") {
+ //raster-vector predicates
+
+ it("Passed RS_Intersects in a join") {
+ val sedonaConf = new SedonaConf(sparkSession.conf)
+ println(sedonaConf)
+
+ val polygonCsvDf = sparkSession.read.format("csv").option("delimiter",
",").option("header", "true").load(buildingDataLocation)
+ polygonCsvDf.createOrReplaceTempView("polygontable")
+ val polygonDf = sparkSession.sql("SELECT ST_GeomFromWKT(geometry) as
building from polygontable")
+ polygonDf.createOrReplaceTempView("polygondf")
+
+ val rasterDf =
sparkSession.read.format("binaryFile").load(rasterDataLocation).selectExpr("RS_FromGeoTiff(content)
as raster")
+ rasterDf.createOrReplaceTempView("rasterDf")
+ //
assert(distanceDefaultNoIntersectsDF.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ val rangeJoinDf = sparkSession.sql("select * from polygondf, rasterDf
where RS_Intersects(rasterDf.raster, polygondf.building)")
+ assert(rangeJoinDf.queryExecution.sparkPlan.collect{case p:
RangeJoinExec => p}.size === 1)
+ assert(rangeJoinDf.count() == 999)
+ }
+
+ it("Passed RS_Contains in a join") {
+ val sedonaConf = new SedonaConf(sparkSession.conf)
+ println(sedonaConf)
+
+ val polygonCsvDf = sparkSession.read.format("csv").option("delimiter",
",").option("header", "true").load(buildingDataLocation)
+ polygonCsvDf.createOrReplaceTempView("polygontable")
+ val polygonDf = sparkSession.sql("SELECT ST_GeomFromWKT(geometry) as
building from polygontable where confidence > 0.85")
+ polygonDf.createOrReplaceTempView("polygondf")
+
+ val rasterDf =
sparkSession.read.format("binaryFile").load(rasterDataLocation).selectExpr("RS_FromGeoTiff(content)
as raster")
+ rasterDf.createOrReplaceTempView("rasterDf")
+ val rangeJoinDf = sparkSession.sql("select * from rasterDf, polygondf
where RS_Contains(rasterDf.raster, polygondf.building)")
+ assert(rangeJoinDf.queryExecution.sparkPlan.collect{case p:
RangeJoinExec => p}.size === 1)
+ assert(rangeJoinDf.count() == 210)
+ }
+
+ it("Passed RS_Within in a join") {
+ val sedonaConf = new SedonaConf(sparkSession.conf)
+ println(sedonaConf)
+
+ val smallRasterDf =
sparkSession.read.format("binaryFile").load(resourceFolder +
"raster/test1.tiff").selectExpr("RS_FromGeoTiff(content) as raster")
+ smallRasterDf.createOrReplaceTempView("smallRaster")
+
+ val rangeJoinDf = sparkSession.sql("select * from smallRaster r1,
smallRaster r2 where RS_Within(r1.raster, RS_ConvexHull(r2.raster))")
+ assert(rangeJoinDf.queryExecution.sparkPlan.collect{case p:
RangeJoinExec => p}.size === 1)
+ assert(rangeJoinDf.count() == 1)
+ }
+
it("Passed ST_Contains in a join") {
val sedonaConf = new SedonaConf(sparkSession.conf)
println(sedonaConf)