jiayuasu commented on code in PR #704:
URL: https://github.com/apache/incubator-sedona/pull/704#discussion_r1007491140
##########
sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala:
##########
@@ -25,39 +25,47 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes
+import org.apache.spark.sql.types.AbstractDataType
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-abstract class ST_Predicate extends Expression
+abstract class ST_Predicate extends Expression with FoldableExpression with ExpectsInputTypes {
-/**
- * Test if leftGeometry full contains rightGeometry
- *
- * @param inputExpressions
- */
-case class ST_Contains(inputExpressions: Seq[Expression])
- extends ST_Predicate with CodegenFallback {
-
- // This is a binary expression
- assert(inputExpressions.length == 2)
+ def inputExpressions: Seq[Expression]
override def nullable: Boolean = false
- override def toString: String = s" **${ST_Contains.getClass.getName}** "
+ override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT)
+
+ override def dataType: DataType = BooleanType
override def children: Seq[Expression] = inputExpressions
override def eval(inputRow: InternalRow): Any = {
val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData]
val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData]
-
val leftGeometry = GeometrySerializer.deserialize(leftArray)
-
val rightGeometry = GeometrySerializer.deserialize(rightArray)
+ evalGeom(leftGeometry, rightGeometry)
+ }
+
+ def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean
+}
+/**
+ * Test if leftGeometry full contains rightGeometry
+ *
+ * @param inputExpressions
+ */
+case class ST_Contains(inputExpressions: Seq[Expression])
+ extends ST_Predicate with CodegenFallback {
+
+ override def toString: String = s" **${ST_Contains.getClass.getName}** "
Review Comment:
The toString() function is only provided in this predicate. Can you move the
toString() function up to ST_Predicate (e.g., toString(): getClass.getName)?
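A minimal sketch of the suggested refactor, assuming the base class from the
diff above (the string format is copied from the removed ST_Contains override;
getClass on the instance resolves to the concrete subclass at runtime, so no
per-predicate override is needed):

```scala
abstract class ST_Predicate extends Expression with FoldableExpression with ExpectsInputTypes {
  // getClass is the runtime class of the concrete predicate, so ST_Contains,
  // ST_Intersects, etc. each print their own name from this single override.
  override def toString: String = s" **${getClass.getName}** "

  // ... remaining members as in the diff above ...
}
```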
##########
sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala:
##########
@@ -305,41 +325,29 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
/**
* Return a Point from X and Y
*
- * @param inputExpressions This function takes 2 parameter which are point x and y.
+ * @param inputExpressions This function takes 3 parameter which are point x, y and z.
*/
case class ST_Point(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback with UserDataGeneratator {
- inputExpressions.betweenLength(2, 3)
Review Comment:
Is there a reason why you removed the inputExpressions.betweenLength(2, 3)
requirement? I saw similar inputExpressions length assertions kept in some
places and removed in others. Is there a reason?
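For reference, a hedged sketch of what the removed inputExpressions.betweenLength(2, 3)
call presumably enforced (betweenLength is a helper defined elsewhere in the
codebase; the inline assertion below is an assumed equivalent, not the PR's code):

```scala
// ST_Point accepts x and y, with z optional: fail fast on any other arity.
assert(inputExpressions.length >= 2 && inputExpressions.length <= 3,
  s"ST_Point expects 2 or 3 arguments (x, y[, z]), got ${inputExpressions.length}")
```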
##########
sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala:
##########
@@ -18,140 +18,149 @@
*/
package org.apache.sedona.sql.UDF
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction}
import org.apache.spark.sql.sedona_sql.expressions.{ST_YMax, ST_YMin, _}
import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
import org.apache.spark.sql.sedona_sql.expressions.raster.{RS_Add, RS_Append, RS_Array, RS_Base64, RS_BitwiseAnd, RS_BitwiseOr, RS_Count, RS_Divide, RS_FetchRegion, RS_GetBand, RS_GreaterThan, RS_GreaterThanEqual, RS_HTML, RS_LessThan, RS_LessThanEqual, RS_LogicalDifference, RS_LogicalOver, RS_Mean, RS_Mode, RS_Modulo, RS_Multiply, RS_MultiplyFactor, RS_Normalize, RS_NormalizedDifference, RS_SquareRoot, RS_Subtract}
import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.operation.buffer.BufferParameters
+
+import scala.reflect.ClassTag
object Catalog {
- val expressions: Seq[FunctionBuilder] = Seq(
+
+ type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
+
+ val expressions: Seq[FunctionDescription] = Seq(
// Expression for vectors
- ST_PointFromText,
Review Comment:
Is there a specific reason to change the way we define functions in the
Catalog?
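For illustration, a sketch of what a single entry could look like under the new
FunctionDescription scheme, written by hand against Spark's FunctionIdentifier
and ExpressionInfo (the PR presumably builds these tuples through a generic
helper that this hunk does not show, so the entry below is an assumption):

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.sedona_sql.expressions.ST_Contains

// FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder),
// matching the type alias declared in the diff above.
val stContainsEntry: (FunctionIdentifier, ExpressionInfo, FunctionBuilder) = (
  FunctionIdentifier("ST_Contains"),
  new ExpressionInfo(classOf[ST_Contains].getCanonicalName, "ST_Contains"),
  (children: Seq[Expression]) => ST_Contains(children)
)
```

Compared with a bare Seq[FunctionBuilder], each tuple carries the identifier
and ExpressionInfo that Spark's FunctionRegistry needs, so registration can
supply them explicitly instead of deriving them by reflection.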