This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new d1f8b0ee [SEDONA-326] Improve raster algebra functions: RS_Array and RS_MultiplyFactor (#907)
d1f8b0ee is described below

commit d1f8b0eee4e1b146fd384d975457d182b3293940
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jul 19 00:37:15 2023 +0800

    [SEDONA-326] Improve raster algebra functions: RS_Array and RS_MultiplyFactor (#907)
---
 .../main/java/org/apache/sedona/common/Constructors.java   |  2 --
 docs/api/sql/Raster-loader.md                              |  2 +-
 docs/api/sql/Raster-operators.md                           |  4 +++-
 .../spark/sql/sedona_sql/expressions/raster/IO.scala       |  7 +++++--
 .../sql/sedona_sql/expressions/raster/MapAlgebra.scala     | 14 ++++++++------
 .../scala/org/apache/sedona/sql/rasteralgebraTest.scala    | 13 +++++++++++++
 6 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java
index 793738de..c391d77f 100644
--- a/common/src/main/java/org/apache/sedona/common/Constructors.java
+++ b/common/src/main/java/org/apache/sedona/common/Constructors.java
@@ -148,8 +148,6 @@ public class Constructors {
     }
 
     public static Geometry geomFromGeoHash(String geoHash, Integer precision) {
-        System.out.println(geoHash);
-        System.out.println(precision);
         try {
             return GeoHashDecoder.decode(geoHash, precision);
         } catch (GeoHashDecoder.InvalidGeoHashException e) {
diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md
index d3b33ecb..c0a15640 100644
--- a/docs/api/sql/Raster-loader.md
+++ b/docs/api/sql/Raster-loader.md
@@ -191,7 +191,7 @@ Output:
 
 Introduction: Create an array that is filled by the given value
 
-Format: `RS_Array(length:Int, value: Decimal)`
+Format: `RS_Array(length:Int, value: Double)`
 
 Since: `v1.1.0`
 
diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md
index 6cbff57d..8a67125e 100644
--- a/docs/api/sql/Raster-operators.md
+++ b/docs/api/sql/Raster-operators.md
@@ -517,7 +517,7 @@ val multiplyDF = spark.sql("select RS_Multiply(band1, band2) as multiplyBands fr
 
 Introduction: Multiply a factor to a spectral band in a geotiff image
 
-Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Int)`
+Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Double)`
 
 Since: `v1.1.0`
 
@@ -528,6 +528,8 @@ val multiplyFactorDF = spark.sql("select RS_MultiplyFactor(band1, 2) as multiply
 
 ```
 
+This function only accepts integer as factor before `v1.5.0`.
+
 ### RS_Normalize
 
 Introduction: Normalize the value in the array to [0, 255]
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
index 1575ae81..1dc2d45a 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
 import org.apache.spark.sql.types._
@@ -126,14 +127,14 @@ case class RS_GetBand(inputExpressions: Seq[Expression])
 }
 
 case class RS_Array(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback with UserDataGeneratator {
+  extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
   override def nullable: Boolean = false
 
   override def eval(inputRow: InternalRow): Any = {
     // This is an expression which takes one input expressions
     assert(inputExpressions.length == 2)
     val len =inputExpressions(0).eval(inputRow).asInstanceOf[Int]
-    val num = inputExpressions(1).eval(inputRow).asInstanceOf[Decimal].toDouble
+    val num = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
     val result = createarray(len, num)
     new GenericArrayData(result)
   }
@@ -148,6 +149,8 @@ case class RS_Array(inputExpressions: Seq[Expression])
     result
   }
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, DoubleType)
+
   override def dataType: DataType = ArrayType(DoubleType)
 
   override def children: Seq[Expression] = inputExpressions
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
index e08e9bff..157bb941 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
@@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.{MapAlgebra, Serde}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
 import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
@@ -352,30 +353,31 @@ case class RS_Count(inputExpressions: Seq[Expression])
 
 // Multiply a factor to all values of a band
 case class RS_MultiplyFactor(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback with UserDataGeneratator {
+  extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
   assert(inputExpressions.length == 2)
 
   override def nullable: Boolean = false
 
   override def eval(inputRow: InternalRow): Any = {
     val band = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData].toDoubleArray()
-    val target = inputExpressions(1).eval(inputRow).asInstanceOf[Int]
-    new GenericArrayData(multiply(band, target))
+    val factor = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
+    new GenericArrayData(multiply(band, factor))
 
   }
 
-  private def multiply(band: Array[Double], target: Int):Array[Double] = {
+  private def multiply(band: Array[Double], factor: Double):Array[Double] = {
 
     var result = new Array[Double](band.length)
     for(i<-0 until band.length) {
 
-      result(i) = band(i)*target
+      result(i) = band(i) * factor
 
     }
     result
-
   }
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType(DoubleType), DoubleType)
+
   override def dataType: DataType = ArrayType(DoubleType)
 
   override def children: Seq[Expression] = inputExpressions
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index 7728b772..5d3bba69 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -71,6 +71,12 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
       assert(inputDf.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
     }
 
+    it("Passed RS_MultiplyFactor with double factor") {
+      val inputDf = Seq((Seq(200.0, 400.0, 600.0))).toDF("Band")
+      val expectedDF = Seq((Seq(20.0, 40.0, 60.0))).toDF("multiply")
+      val actualDF = inputDf.selectExpr("RS_MultiplyFactor(Band, 0.1) as multiply")
+      assert(actualDF.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
+    }
   }
 
   describe("Should pass basic statistical tests") {
@@ -202,6 +208,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
       df = df.selectExpr("RS_Normalize(Band) as normalizedBand")
       assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
     }
+
+    it("should pass RS_Array") {
+      val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
+      val result = df.first().getAs[mutable.WrappedArray[Double]](0)
+      assert(result.length == 6)
+      assert(result sameElements Array.fill[Double](6)(1e-6))
+    }
   }
 
   describe("Should pass all transformation tests") {

Reply via email to