This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 51c869fa [SEDONA-327] Refactored raster UDFs to extend InferredExpression (#909)
51c869fa is described below

commit 51c869fa688d076b3758a17b3dcee5176b528670
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jul 20 08:39:50 2023 +0800

    [SEDONA-327] Refactored raster UDFs to extend InferredExpression (#909)
---
 .../sedona/common/raster/RasterConstructors.java   |  8 +-
 .../apache/sedona/common/raster/RasterOutputs.java | 12 ++-
 python-adapter/pom.xml                             |  4 +
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  2 +-
 .../expressions/InferredExpression.scala           | 26 +++++-
 .../sedona_sql/expressions/raster/MapAlgebra.scala | 63 ++------------
 .../expressions/raster/PixelFunctions.scala        | 27 +-----
 .../expressions/raster/RasterAccessors.scala       | 89 ++-----------------
 .../expressions/raster/RasterConstructors.scala    | 99 ++++------------------
 .../expressions/raster/RasterEditors.scala         | 39 ++-------
 .../expressions/raster/RasterOutputs.scala         | 56 ++----------
 .../expressions/raster/RasterPredicates.scala      | 28 +-----
 viz/pom.xml                                        |  4 +
 13 files changed, 98 insertions(+), 359 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
index 0a2aab00..9111cf3e 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
@@ -93,15 +93,9 @@ public class RasterConstructors
         } else {
             crs = CRS.decode("EPSG:" + srid);
         }
-        // If scaleY is not defined, use scaleX
-        // MAX_VALUE is used to indicate that the scaleY is not defined
-        double actualScaleY = scaleY;
-        if (scaleY == Integer.MAX_VALUE) {
-            actualScaleY = scaleX;
-        }
         // Create a new empty raster
         WritableRaster raster = RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, widthInPixel, heightInPixel, numBand, null);
-        MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, -actualScaleY, upperLeftX + scaleX / 2, upperLeftY - actualScaleY / 2);
+        MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, -scaleY, upperLeftX + scaleX / 2, upperLeftY - scaleY / 2);
         GridGeometry2D gridGeometry = new GridGeometry2D(new GridEnvelope2D(0, 0, widthInPixel, heightInPixel), transform, crs);
         ReferencedEnvelope referencedEnvelope = new ReferencedEnvelope(gridGeometry.getEnvelope2D());
         // Create a new coverage
diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterOutputs.java b/common/src/main/java/org/apache/sedona/common/raster/RasterOutputs.java
index 8fbf8217..dcbdce6d 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/RasterOutputs.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/RasterOutputs.java
@@ -35,7 +35,7 @@ import java.io.IOException;
 
 public class RasterOutputs
 {
-    public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, float compressionQuality) {
+    public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, double compressionQuality) {
         ByteArrayOutputStream out = new ByteArrayOutputStream();
         GridCoverageWriter writer;
         try {
@@ -52,7 +52,7 @@ public class RasterOutputs
             params.setCompressionType(compressionType);
             // Should be a value between 0 and 1
             // 0 means max compression, 1 means no compression
-            params.setCompressionQuality(compressionQuality);
+            params.setCompressionQuality((float) compressionQuality);
             defaultParams.parameter(AbstractGridFormat.GEOTOOLS_WRITE_PARAMS.getName().toString()).setValue(params);
         }
         GeneralParameterValue[] wps = defaultParams.values().toArray(new GeneralParameterValue[0]);
@@ -67,6 +67,10 @@ public class RasterOutputs
         return out.toByteArray();
     }
 
+    public static byte[] asGeoTiff(GridCoverage2D raster) {
+        return asGeoTiff(raster, null, -1);
+    }
+
     public static byte[] asArcGrid(GridCoverage2D raster, int sourceBand) {
         ByteArrayOutputStream out = new ByteArrayOutputStream();
         GridCoverageWriter writer;
@@ -93,4 +97,8 @@ public class RasterOutputs
         }
         return out.toByteArray();
     }
+
+    public static byte[] asArcGrid(GridCoverage2D raster) {
+        return asArcGrid(raster, -1);
+    }
 }
diff --git a/python-adapter/pom.xml b/python-adapter/pom.xml
index 9e494c72..b534f8a2 100644
--- a/python-adapter/pom.xml
+++ b/python-adapter/pom.xml
@@ -111,6 +111,10 @@
             <groupId>org.geotools</groupId>
             <artifactId>gt-epsg-hsql</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.geotools</groupId>
+            <artifactId>gt-coverage</artifactId>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index b5f7b0df..06b7750e 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -192,7 +192,7 @@ object Catalog {
     function[RS_BandAsArray](),
     function[RS_FromArcInfoAsciiGrid](),
     function[RS_FromGeoTiff](),
-    function[RS_MakeEmptyRaster](java.lang.Integer.MAX_VALUE, 0.0, 0.0, 0),
+    function[RS_MakeEmptyRaster](),
     function[RS_Envelope](),
     function[RS_NumBands](),
     function[RS_Metadata](),
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index 3d8ade3b..aed843a6 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -22,11 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
 import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.locationtech.jts.geom.Geometry
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
+import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
+import org.geotools.coverage.grid.GridCoverage2D
 
 import scala.reflect.runtime.universe.TypeTag
 import scala.reflect.runtime.universe.Type
@@ -76,6 +78,8 @@ sealed class InferrableType[T: TypeTag]
 object InferrableType {
   implicit val geometryInstance: InferrableType[Geometry] =
     new InferrableType[Geometry] {}
+  implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
+    new InferrableType[GridCoverage2D] {}
   implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
     new InferrableType[Array[Geometry]] {}
   implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
@@ -96,6 +100,8 @@ object InferrableType {
     new InferrableType[Array[Byte]] {}
   implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
     new InferrableType[Array[java.lang.Long]] {}
+  implicit val doubleArrayInstance: InferrableType[Array[Double]] =
+    new InferrableType[Array[Double]] {}
 }
 
 object InferredTypes {
@@ -104,6 +110,10 @@ object InferredTypes {
       expr => input => expr.toGeometry(input)
     } else if (t =:= typeOf[Array[Geometry]]) {
       expr => input => expr.toGeometryArray(input)
+    } else if (t =:= typeOf[GridCoverage2D]) {
+      expr => input => expr.toRaster(input)
+    } else if (t =:= typeOf[Array[Double]]) {
+      expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray()
     } else if (t =:= typeOf[String]) {
       expr => input => expr.asString(input)
     } else {
@@ -119,6 +129,14 @@ object InferredTypes {
         } else {
           null
         }
+    } else if (t =:= typeOf[GridCoverage2D]) {
+      output => {
+        if (output != null) {
+          output.asInstanceOf[GridCoverage2D].serialize
+        } else {
+          null
+        }
+      }
     } else if (t =:= typeOf[String]) {
       output =>
         if (output != null) {
@@ -126,7 +144,7 @@ object InferredTypes {
         } else {
           null
         }
-    } else if (t =:= typeOf[Array[java.lang.Long]]) {
+    } else if (t =:= typeOf[Array[java.lang.Long]] || t =:= typeOf[Array[Double]]) {
       output =>
         if (output != null) {
           ArrayData.toArrayData(output)
@@ -157,6 +175,8 @@ object InferredTypes {
       GeometryUDT
     } else if (t =:= typeOf[Array[Geometry]]) {
       DataTypes.createArrayType(GeometryUDT)
+    } else if (t =:= typeOf[GridCoverage2D]) {
+      RasterUDT
     } else if (t =:= typeOf[java.lang.Double]) {
       DoubleType
     } else if (t =:= typeOf[java.lang.Integer]) {
@@ -171,6 +191,8 @@ object InferredTypes {
       BinaryType
     } else if (t =:= typeOf[Array[java.lang.Long]]) {
       DataTypes.createArrayType(LongType)
+    } else if (t =:= typeOf[Array[Double]]) {
+      DataTypes.createArrayType(DoubleType)
     } else if (t =:= typeOf[Option[Boolean]]) {
       BooleanType
     } else {
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
index 157bb941..a56954e3 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
@@ -18,17 +18,14 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.{MapAlgebra, Serde}
+import org.apache.sedona.common.raster.MapAlgebra
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, 
ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression}
-import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
-import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
-import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
-import org.apache.spark.sql.sedona_sql.expressions.{SerdeAware, 
UserDataGeneratator}
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.{InferredExpression, 
UserDataGeneratator}
 import org.apache.spark.sql.types._
-import org.geotools.coverage.grid.GridCoverage2D
 
 /// Calculate Normalized Difference between two bands
 case class RS_NormalizedDifference(inputExpressions: Seq[Expression])
@@ -807,61 +804,15 @@ case class RS_Append(inputExpressions: Seq[Expression])
   }
 }
 
-case class RS_AddBandFromArray(inputExpressions: Seq[Expression]) extends 
Expression with CodegenFallback
-  with ExpectsInputTypes with SerdeAware {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
-  }
-
-  override def dataType: DataType = RasterUDT
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_AddBandFromArray(inputExpressions: Seq[Expression])
+  extends InferredExpression(inferrableFunction3(MapAlgebra.addBandFromArray)) 
{
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, ArrayType, 
IntegerType)
-
-  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      return null
-    }
-    val band = 
inputExpressions(1).eval(input).asInstanceOf[ArrayData].toDoubleArray()
-    val bandIndex = inputExpressions(2).eval(input).asInstanceOf[Int]
-    MapAlgebra.addBandFromArray(raster, band, bandIndex)
-  }
 }
 
-case class RS_BandAsArray(inputExpressions: Seq[Expression]) extends 
Expression with CodegenFallback
-  with ExpectsInputTypes {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      return null
-    }
-    val bandIndex = inputExpressions(1).eval(input).asInstanceOf[Int]
-    val band = MapAlgebra.bandAsArray(raster, bandIndex)
-    if (band == null) {
-      return null
-    }
-    new GenericArrayData(band)
-  }
-
-  override def dataType: DataType = ArrayType(DoubleType)
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_BandAsArray(inputExpressions: Seq[Expression]) extends 
InferredExpression(MapAlgebra.bandAsArray _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType)
 }
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
index d6001651..97b9e673 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/PixelFunctions.scala
@@ -25,34 +25,15 @@ import 
org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
-import 
org.apache.spark.sql.sedona_sql.expressions.implicits.InputExpressionEnhancer
 import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
 import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, 
DoubleType, IntegerType}
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
-case class RS_Value(inputExpressions: Seq[Expression]) extends Expression with 
CodegenFallback with ExpectsInputTypes {
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = DoubleType
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions.head.toRaster(input)
-    val geom = inputExpressions(1).toGeometry(input)
-    val band = inputExpressions(2).eval(input).asInstanceOf[Int]
-    if (raster == null || geom == null) {
-      null
-    } else {
-      PixelFunctions.value(raster, geom, band)
-    }
-  }
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_Value(inputExpressions: Seq[Expression]) extends 
InferredExpression(PixelFunctions.value _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, GeometryUDT, 
IntegerType)
 }
 
 case class RS_Values(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes {
@@ -82,4 +63,4 @@ case class RS_Values(inputExpressions: Seq[Expression]) 
extends Expression with
   }
 
   override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, 
ArrayType(GeometryUDT), IntegerType)
-}
\ No newline at end of file
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
index 1f19979e..57ea3423 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterAccessors.scala
@@ -19,103 +19,30 @@
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
 import org.apache.sedona.common.raster.RasterAccessors
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression}
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
-import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer
-import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
-import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, 
DoubleType, IntegerType}
-
-case class RS_Envelope(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes {
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      null
-    } else {
-      RasterAccessors.envelope(raster).toGenericArrayData
-    }
-  }
-
-  override def dataType: DataType = GeometryUDT
-
-  override def children: Seq[Expression] = inputExpressions
+import org.apache.spark.sql.catalyst.expressions.Expression
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
+case class RS_Envelope(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterAccessors.envelope _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
 }
 
-case class RS_NumBands(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes {
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      null
-    } else {
-      RasterAccessors.numBands(raster)
-    }
-  }
-
-  override def dataType: DataType = IntegerType
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_NumBands(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterAccessors.numBands _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
 }
 
-case class RS_SRID(inputExpressions: Seq[Expression]) extends Expression with 
CodegenFallback with ExpectsInputTypes {
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      null
-    } else {
-      RasterAccessors.srid(raster)
-    }
-  }
-
-  override def dataType: DataType = IntegerType
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_SRID(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterAccessors.srid _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
 }
 
-case class RS_Metadata(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes {
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) {
-      null
-    } else {
-      new GenericArrayData(RasterAccessors.metadata(raster))
-    }
-  }
-
-  override def dataType: DataType = ArrayType(DoubleType)
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_Metadata(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterAccessors.metadata _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
-}
\ No newline at end of file
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
index 7a385b78..48fe8f6a 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
@@ -18,105 +18,38 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.{RasterConstructors, Serde}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ImplicitCastInputTypes}
-import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
-import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
-import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
-import org.apache.spark.sql.types._
-import org.geotools.coverage.grid.GridCoverage2D
+import org.apache.sedona.common.raster.RasterConstructors
+import org.apache.spark.sql.catalyst.expressions.Expression
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
-
-case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression]) extends 
Expression with CodegenFallback
-  with ExpectsInputTypes with SerdeAware {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
-  }
-
-  override def dataType: DataType = RasterUDT
-
-  override def children: Seq[Expression] = inputExpressions
+case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression])
+  extends InferredExpression(RasterConstructors.fromArcInfoAsciiGrid _) {
+  override def foldable: Boolean = false
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
-
-  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
-    val bytes = inputExpressions(0).eval(input).asInstanceOf[Array[Byte]]
-    if (bytes == null) {
-      null
-    } else {
-      RasterConstructors.fromArcInfoAsciiGrid(bytes)
-    }
-  }
 }
 
-case class RS_FromGeoTiff(inputExpressions: Seq[Expression]) extends 
Expression with CodegenFallback
-  with ExpectsInputTypes with SerdeAware {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
-  }
-
-  override def dataType: DataType = RasterUDT
+case class RS_FromGeoTiff(inputExpressions: Seq[Expression])
+  extends InferredExpression(RasterConstructors.fromGeoTiff _) {
 
-  override def children: Seq[Expression] = inputExpressions
+  override def foldable: Boolean = false
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
-
-  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
-    val bytes = inputExpressions(0).eval(input).asInstanceOf[Array[Byte]]
-    if (bytes == null) {
-      null
-    } else {
-      RasterConstructors.fromGeoTiff(bytes)
-    }
-  }
 }
 
-case class RS_MakeEmptyRaster(inputExpressions: Seq[Expression]) extends 
Expression with CodegenFallback
-  with ImplicitCastInputTypes with SerdeAware {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
-  }
-
-  override def dataType: DataType = RasterUDT
+case class RS_MakeEmptyRaster(inputExpressions: Seq[Expression])
+  extends InferredExpression(
+    inferrableFunction6(RasterConstructors.makeEmptyRaster),
+    inferrableFunction10(RasterConstructors.makeEmptyRaster)) {
 
-  override def children: Seq[Expression] = inputExpressions
+  override def foldable: Boolean = false
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, 
IntegerType, IntegerType, DecimalType, DecimalType, DecimalType, DecimalType, 
DecimalType, DecimalType, IntegerType)
-
-  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
-    val numBands = inputExpressions(0).eval(input).asInstanceOf[Int]
-    val widthInPixels = inputExpressions(1).eval(input).asInstanceOf[Int]
-    val heightInPixel = inputExpressions(2).eval(input).asInstanceOf[Int]
-    val upperLeftX = 
inputExpressions(3).eval(input).asInstanceOf[Decimal].toDouble
-    val upperLeftY = 
inputExpressions(4).eval(input).asInstanceOf[Decimal].toDouble
-    val pixelSizeX = 
inputExpressions(5).eval(input).asInstanceOf[Decimal].toDouble
-    val pixelSizeY = 
inputExpressions(6).eval(input).asInstanceOf[Decimal].toDouble
-    val skewX = inputExpressions(7).eval(input).asInstanceOf[Decimal].toDouble
-    val skewY = inputExpressions(8).eval(input).asInstanceOf[Decimal].toDouble
-    val srid = inputExpressions(9).eval(input).asInstanceOf[Int]
-    RasterConstructors.makeEmptyRaster(numBands, widthInPixels, heightInPixel, 
upperLeftX, upperLeftY, pixelSizeX, pixelSizeY, skewX, skewY, srid)
-  }
-}
\ No newline at end of file
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
index db63fb6c..3b8083b6 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala
@@ -18,40 +18,13 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.{RasterEditors, Serde}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression}
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
-import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
-import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
-import org.geotools.coverage.grid.GridCoverage2D
-
-case class RS_SetSRID(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes with SerdeAware {
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
-  }
-
-  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
-    val raster = inputExpressions(0).toRaster(input)
-    val srid = inputExpressions(1).eval(input).asInstanceOf[Int]
-    if (raster == null) {
-      null
-    } else {
-      RasterEditors.setSrid(raster, srid)
-    }
-  }
-
-  override def dataType: DataType = RasterUDT
-
-  override def children: Seq[Expression] = inputExpressions
+import org.apache.sedona.common.raster.RasterEditors
+import org.apache.spark.sql.catalyst.expressions.Expression
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
+case class RS_SetSRID(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterEditors.setSrid _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType)
-}
\ No newline at end of file
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
index 195ed0c0..4026f8e0 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterOutputs.scala
@@ -19,60 +19,22 @@
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
 import org.apache.sedona.common.raster.RasterOutputs
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-// Expected Types (RasterUDT, StringType, IntegerType) or (RasterUDT, 
StringType, DecimalType)
-case class RS_AsGeoTiff(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) return null
-    // If there are more than one input expressions, the additional ones are 
used as parameters
-    if (inputExpressions.length > 1) {
-      RasterOutputs.asGeoTiff(raster, 
inputExpressions(1).eval(input).asInstanceOf[UTF8String].toString, 
inputExpressions(2).eval(input).toString.toFloat)
-    }
-    else {
-      RasterOutputs.asGeoTiff(raster, null, -1)
-    }
-  }
-
-  override def dataType: DataType = BinaryType
-
-  override def children: Seq[Expression] = inputExpressions
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
+case class RS_AsGeoTiff(inputExpressions: Seq[Expression])
+  extends InferredExpression(inferrableFunction3(RasterOutputs.asGeoTiff),
+    inferrableFunction1(RasterOutputs.asGeoTiff)) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
 }
 
-case class RS_AsArcGrid(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions(0).toRaster(input)
-    if (raster == null) return null
-    // If there are more than one input expressions, the additional ones are 
used as parameters
-    if (inputExpressions.length > 1) {
-      RasterOutputs.asArcGrid(raster, 
inputExpressions(1).eval(input).asInstanceOf[Int])
-    }
-    else {
-      RasterOutputs.asArcGrid(raster, -1)
-    }
-  }
-
-  override def dataType: DataType = BinaryType
-
-  override def children: Seq[Expression] = inputExpressions
-
+case class RS_AsArcGrid(inputExpressions: Seq[Expression])
+  extends InferredExpression(inferrableFunction2(RasterOutputs.asArcGrid),
+    inferrableFunction1(RasterOutputs.asArcGrid)) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
-}
\ No newline at end of file
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
index d35a0256..54cfd6d7 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterPredicates.scala
@@ -19,31 +19,11 @@
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
 import org.apache.sedona.common.raster.RasterPredicates
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression}
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
-import 
org.apache.spark.sql.sedona_sql.expressions.implicits.InputExpressionEnhancer
-import 
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
-import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
-
-case class RS_Intersects(inputExpressions: Seq[Expression]) extends Expression 
with CodegenFallback with ExpectsInputTypes {
-
-  override def eval(input: InternalRow): Any = {
-    val raster = inputExpressions.head.toRaster(input)
-    val geom = inputExpressions(1).toGeometry(input)
-    if (raster == null || geom == null) {
-      null
-    } else {
-      RasterPredicates.rsIntersects(raster, geom)
-    }
-  }
-
-  override def nullable: Boolean = true
-  override def dataType: DataType = BooleanType
-  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, GeometryUDT)
-  override def children: Seq[Expression] = inputExpressions
+import org.apache.spark.sql.catalyst.expressions.Expression
+import 
org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
 
+case class RS_Intersects(inputExpressions: Seq[Expression]) extends 
InferredExpression(RasterPredicates.rsIntersects _) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
diff --git a/viz/pom.xml b/viz/pom.xml
index fbeac182..2ea8c3c7 100644
--- a/viz/pom.xml
+++ b/viz/pom.xml
@@ -115,6 +115,10 @@
             <groupId>org.geotools</groupId>
             <artifactId>gt-epsg-hsql</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.geotools</groupId>
+            <artifactId>gt-coverage</artifactId>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>


Reply via email to