This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new a7ad5f011 [SEDONA-543] Fixes RS_Union_aggr throwing referenceRaster is
null error when run on cluster (#1364)
a7ad5f011 is described below
commit a7ad5f01158996fc047b80c1baf87353842d932c
Author: Pranav Toggi <[email protected]>
AuthorDate: Fri Apr 26 16:19:25 2024 -0400
[SEDONA-543] Fixes RS_Union_aggr throwing referenceRaster is null error
when run on cluster (#1364)
* Init: move class level members to data buffer
* move sampleDimension serde to Serde.java
* update serde for sampleDimensions
* Add checks for index
* Undo typo
* add custom GridSampleDimensionSerializer
---
.../apache/sedona/common/raster/serde/Serde.java | 20 ++++
.../expressions/raster/AggregateFunctions.scala | 126 ++++++++++++---------
2 files changed, 90 insertions(+), 56 deletions(-)
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index 616ded015..848c00b3f 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,6 +32,8 @@ import org.opengis.referencing.operation.MathTransform;
import javax.media.jai.RenderedImageAdapter;
import java.awt.image.RenderedImage;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
@@ -176,4 +178,22 @@ public class Serde {
return state.restore();
}
}
+
+ public static byte[] serializeGridSampleDimension(GridSampleDimension
sampleDimension) {
+ Kryo kryo = kryos.get();
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ Output output = new Output(baos);
+ GridSampleDimensionSerializer serializer = new
GridSampleDimensionSerializer();
+ serializer.write(kryo, output, sampleDimension);
+ output.close();
+ return baos.toByteArray();
+ }
+
+ public static GridSampleDimension deserializeGridSampleDimension(byte[]
data) {
+ Kryo kryo = kryos.get();
+ Input input = new Input(new ByteArrayInputStream(data));
+ GridSampleDimensionSerializer serializer = new
GridSampleDimensionSerializer();
+ return serializer.read(kryo, input, GridSampleDimension.class);
+ }
+
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 1fa1cb6e7..b76841638 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,6 +19,7 @@
package org.apache.spark.sql.sedona_sql.expressions.raster
+import org.apache.sedona.common.raster.serde.Serde
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.Encoder
@@ -29,93 +30,106 @@ import org.geotools.coverage.grid.GridCoverage2D
import java.awt.image.WritableRaster
import javax.media.jai.RasterFactory
-import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-case class BandData(var bandInt: Array[Int], var bandDouble: Array[Double],
var index: Int, var isIntegral: Boolean)
+case class BandData(
+ var bandInt: Array[Int],
+ var bandDouble: Array[Double],
+ var index: Int,
+ var isIntegral: Boolean,
+ var serializedRaster: Array[Byte],
+ var serializedSampleDimension: Array[Byte]
+ )
+
/**
* Return a raster containing bands at given indexes from all rasters in a
given column
*/
class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int),
ArrayBuffer[BandData], GridCoverage2D] {
- var width: Int = -1
-
- var height: Int = -1
-
- var referenceRaster: GridCoverage2D = _
-
- var gridSampleDimension: mutable.Map[Int, GridSampleDimension] = new
mutable.HashMap()
-
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
- /**
- * Valid raster shape to be the same in the given column
- */
- def checkRasterShape(raster: GridCoverage2D): Boolean = {
- // first iteration
- if (width == -1 && height == -1) {
- width = RasterAccessors.getWidth(raster)
- height = RasterAccessors.getHeight(raster)
- referenceRaster = raster
- true
- } else {
- val widthNewRaster = RasterAccessors.getWidth(raster)
- val heightNewRaster = RasterAccessors.getHeight(raster)
-
- width == widthNewRaster && height == heightNewRaster
- }
- }
-
def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)):
ArrayBuffer[BandData] = {
val raster = input._1
- if (!checkRasterShape(raster)) {
- throw new IllegalArgumentException("Rasters provides should be of the
same shape.")
- }
- if (gridSampleDimension.contains(input._2)) {
- throw new IllegalArgumentException("Indexes shouldn't be repeated. Index
should be in an arithmetic sequence.")
- }
-
val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral =
RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)
- val bandData = if (isIntegral) {
- val band = rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Int]])
- BandData(band, null, input._2, isIntegral)
+ // Serializing GridSampleDimension
+ val serializedBytes =
Serde.serializeGridSampleDimension(raster.getSampleDimension(0))
+
+ // Check and set dimensions based on the first raster in the buffer
+ if (buffer.isEmpty) {
+ val width = RasterAccessors.getWidth(raster)
+ val height = RasterAccessors.getHeight(raster)
+ val referenceSerializedRaster = Serde.serialize(raster)
+
+ buffer += BandData(
+ if (isIntegral) rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Int]]) else null,
+ if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Double]]) else null,
+ input._2,
+ isIntegral,
+ referenceSerializedRaster,
+ serializedBytes
+ )
} else {
- val band = rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Double]])
- BandData(null, band, input._2, isIntegral)
+ val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
+ val width = RasterAccessors.getWidth(referenceRaster)
+ val height = RasterAccessors.getHeight(referenceRaster)
+
+ if (width != RasterAccessors.getWidth(raster) || height !=
RasterAccessors.getHeight(raster)) {
+ throw new IllegalArgumentException("All rasters must have the same
dimensions")
+ }
+
+ buffer += BandData(
+ if (isIntegral) rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Int]]) else null,
+ if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0,
null.asInstanceOf[Array[Double]]) else null,
+ input._2,
+ isIntegral,
+ Serde.serialize(raster),
+ serializedBytes
+ )
}
- gridSampleDimension = gridSampleDimension + (input._2 ->
raster.getSampleDimension(0))
- buffer += bandData
+ buffer
}
+
def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]):
ArrayBuffer[BandData] = {
- ArrayBuffer.concat(buffer1, buffer2)
+ val combined = ArrayBuffer.concat(buffer1, buffer2)
+ if (combined.map(_.index).distinct.length != combined.length) {
+ throw new IllegalArgumentException("Indexes shouldn't be repeated.")
+ }
+ combined
}
+
def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
+ if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
+ if (idx > 0) (band.index - sortedMerged(idx - 1).index) !=
(sortedMerged(1).index - sortedMerged(0).index)
+ else false
+ }) {
+ throw new IllegalArgumentException("Index should be in an arithmetic
sequence.")
+ }
+
val numBands = sortedMerged.length
- val rasterData = RasterUtils.getRaster(referenceRaster.getRenderedImage)
- val dataTypeCode = rasterData.getDataBuffer.getDataType
+ val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
+ val width = RasterAccessors.getWidth(referenceRaster)
+ val height = RasterAccessors.getHeight(referenceRaster)
+ val dataTypeCode =
RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster =
RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new
Array[GridSampleDimension](numBands)
- var indexCheck = 1
- for (bandData: BandData <- sortedMerged) {
- if (bandData.index != indexCheck) {
- throw new IllegalArgumentException("Indexes should be in a valid
arithmetic sequence.")
- }
- indexCheck += 1
- gridSampleDimensions(bandData.index - 1) =
gridSampleDimension(bandData.index)
- if(RasterUtils.isDataTypeIntegral(dataTypeCode))
- resultRaster.setSamples(0, 0, width, height, (bandData.index - 1),
bandData.bandInt)
- else
- resultRaster.setSamples(0, 0, width, height, bandData.index - 1,
bandData.bandDouble)
+ for ((bandData, idx) <- sortedMerged.zipWithIndex) {
+ // Deserializing GridSampleDimension
+ gridSampleDimensions(idx) =
Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)
+ if(bandData.isIntegral)
+ resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
+ else
+ resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
}
+
val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry,
gridSampleDimensions, referenceRaster, noDataValue, true)
}