Kontinuation commented on code in PR #704:
URL: https://github.com/apache/incubator-sedona/pull/704#discussion_r1007565524
##########
sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala:
##########
@@ -25,39 +25,47 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes
+import org.apache.spark.sql.types.AbstractDataType
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-abstract class ST_Predicate extends Expression
+abstract class ST_Predicate extends Expression with FoldableExpression with ExpectsInputTypes {
-/**
- * Test if leftGeometry full contains rightGeometry
- *
- * @param inputExpressions
- */
-case class ST_Contains(inputExpressions: Seq[Expression])
- extends ST_Predicate with CodegenFallback {
-
- // This is a binary expression
- assert(inputExpressions.length == 2)
+ def inputExpressions: Seq[Expression]
override def nullable: Boolean = false
- override def toString: String = s" **${ST_Contains.getClass.getName}** "
+ override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT)
+
+ override def dataType: DataType = BooleanType
override def children: Seq[Expression] = inputExpressions
override def eval(inputRow: InternalRow): Any = {
val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData]
val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData]
-
val leftGeometry = GeometrySerializer.deserialize(leftArray)
-
val rightGeometry = GeometrySerializer.deserialize(rightArray)
+ evalGeom(leftGeometry, rightGeometry)
+ }
+
+ def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean
+}
+/**
+ * Test if leftGeometry full contains rightGeometry
+ *
+ * @param inputExpressions
+ */
+case class ST_Contains(inputExpressions: Seq[Expression])
+ extends ST_Predicate with CodegenFallback {
+
+ override def toString: String = s" **${ST_Contains.getClass.getName}** "
Review Comment:
Moved `toString` to `ST_Predicate`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]