This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new d1f8b0ee [SEDONA-326] Improve raster algebra functions: RS_Array and
RS_MultiplyFactor (#907)
d1f8b0ee is described below
commit d1f8b0eee4e1b146fd384d975457d182b3293940
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jul 19 00:37:15 2023 +0800
[SEDONA-326] Improve raster algebra functions: RS_Array and
RS_MultiplyFactor (#907)
---
.../main/java/org/apache/sedona/common/Constructors.java | 2 --
docs/api/sql/Raster-loader.md | 2 +-
docs/api/sql/Raster-operators.md | 4 +++-
.../spark/sql/sedona_sql/expressions/raster/IO.scala | 7 +++++--
.../sql/sedona_sql/expressions/raster/MapAlgebra.scala | 14 ++++++++------
.../scala/org/apache/sedona/sql/rasteralgebraTest.scala | 13 +++++++++++++
6 files changed, 30 insertions(+), 12 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java
index 793738de..c391d77f 100644
--- a/common/src/main/java/org/apache/sedona/common/Constructors.java
+++ b/common/src/main/java/org/apache/sedona/common/Constructors.java
@@ -148,8 +148,6 @@ public class Constructors {
}
public static Geometry geomFromGeoHash(String geoHash, Integer precision) {
- System.out.println(geoHash);
- System.out.println(precision);
try {
return GeoHashDecoder.decode(geoHash, precision);
} catch (GeoHashDecoder.InvalidGeoHashException e) {
diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md
index d3b33ecb..c0a15640 100644
--- a/docs/api/sql/Raster-loader.md
+++ b/docs/api/sql/Raster-loader.md
@@ -191,7 +191,7 @@ Output:
Introduction: Create an array that is filled by the given value
-Format: `RS_Array(length:Int, value: Decimal)`
+Format: `RS_Array(length:Int, value: Double)`
Since: `v1.1.0`
diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md
index 6cbff57d..8a67125e 100644
--- a/docs/api/sql/Raster-operators.md
+++ b/docs/api/sql/Raster-operators.md
@@ -517,7 +517,7 @@ val multiplyDF = spark.sql("select RS_Multiply(band1, band2) as multiplyBands fr
Introduction: Multiply a factor to a spectral band in a geotiff image
-Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Int)`
+Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Double)`
Since: `v1.1.0`
@@ -528,6 +528,8 @@ val multiplyFactorDF = spark.sql("select RS_MultiplyFactor(band1, 2) as multiply
```
+This function only accepts integer as factor before `v1.5.0`.
+
### RS_Normalize
Introduction: Normalize the value in the array to [0, 255]
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
index 1575ae81..1dc2d45a 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/IO.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
import org.apache.spark.sql.types._
@@ -126,14 +127,14 @@ case class RS_GetBand(inputExpressions: Seq[Expression])
}
case class RS_Array(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback with UserDataGeneratator {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
override def nullable: Boolean = false
override def eval(inputRow: InternalRow): Any = {
// This is an expression which takes one input expressions
assert(inputExpressions.length == 2)
val len =inputExpressions(0).eval(inputRow).asInstanceOf[Int]
- val num = inputExpressions(1).eval(inputRow).asInstanceOf[Decimal].toDouble
+ val num = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
val result = createarray(len, num)
new GenericArrayData(result)
}
@@ -148,6 +149,8 @@ case class RS_Array(inputExpressions: Seq[Expression])
result
}
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, DoubleType)
+
override def dataType: DataType = ArrayType(DoubleType)
override def children: Seq[Expression] = inputExpressions
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
index e08e9bff..157bb941 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/MapAlgebra.scala
@@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.{MapAlgebra, Serde}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
@@ -352,30 +353,31 @@ case class RS_Count(inputExpressions: Seq[Expression])
// Multiply a factor to all values of a band
case class RS_MultiplyFactor(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback with UserDataGeneratator {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
assert(inputExpressions.length == 2)
override def nullable: Boolean = false
override def eval(inputRow: InternalRow): Any = {
val band = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData].toDoubleArray()
- val target = inputExpressions(1).eval(inputRow).asInstanceOf[Int]
- new GenericArrayData(multiply(band, target))
+ val factor = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
+ new GenericArrayData(multiply(band, factor))
}
- private def multiply(band: Array[Double], target: Int):Array[Double] = {
+ private def multiply(band: Array[Double], factor: Double):Array[Double] = {
var result = new Array[Double](band.length)
for(i<-0 until band.length) {
- result(i) = band(i)*target
+ result(i) = band(i) * factor
}
result
-
}
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType(DoubleType), DoubleType)
+
override def dataType: DataType = ArrayType(DoubleType)
override def children: Seq[Expression] = inputExpressions
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index 7728b772..5d3bba69 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -71,6 +71,12 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(inputDf.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
}
+ it("Passed RS_MultiplyFactor with double factor") {
+ val inputDf = Seq((Seq(200.0, 400.0, 600.0))).toDF("Band")
+ val expectedDF = Seq((Seq(20.0, 40.0, 60.0))).toDF("multiply")
+ val actualDF = inputDf.selectExpr("RS_MultiplyFactor(Band, 0.1) as multiply")
+ assert(actualDF.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
+ }
}
describe("Should pass basic statistical tests") {
@@ -202,6 +208,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
df = df.selectExpr("RS_Normalize(Band) as normalizedBand")
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
}
+
+ it("should pass RS_Array") {
+ val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
+ val result = df.first().getAs[mutable.WrappedArray[Double]](0)
+ assert(result.length == 6)
+ assert(result sameElements Array.fill[Double](6)(1e-6))
+ }
}
describe("Should pass all transformation tests") {