Copilot commented on code in PR #2673:
URL: https://github.com/apache/sedona/pull/2673#discussion_r2851315423
##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala:
##########
@@ -20,16 +20,52 @@ package org.apache.spark.sql.sedona_sql.io.raster
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-private[io] class RasterOptions(@transient private val parameters:
CaseInsensitiveMap[String])
+class RasterOptions(@transient private val parameters:
CaseInsensitiveMap[String])
extends Serializable {
def this(parameters: Map[String, String]) =
this(CaseInsensitiveMap(parameters))
+ // The following options are used to read raster data
+
+ /**
+ * Whether to retile the raster data. If true, the raster data will be
retiled into smaller
+ * tiles. If false, the raster data will be read as a single tile.
+ */
+ val retile: Boolean = parameters.getOrElse("retile", "true").toBoolean
+
+ /**
+ * The width of the tile. This is only effective when retile is true. If
retile is true and
+ * tileWidth is not set, the default value is the width of the internal
tiles in the raster
+ * files. Each raster file may have different internal tile sizes.
+ */
+ val tileWidth: Option[Int] = parameters.get("tileWidth").map(_.toInt)
+
+ /**
+ * The height of the tile. This is only effective when retile is true. If
retile is true and
+ * tileHeight is not set, the default value is the same as tileWidth. If
tileHeight is set,
+ * tileWidth must be set as well.
+ */
+ val tileHeight: Option[Int] = parameters
+ .get("tileHeight")
+ .map { value =>
+ require(tileWidth.isDefined, "tileWidth must be set when tileHeight is
set")
+ value.toInt
+ }
+ .orElse(tileWidth)
+
Review Comment:
`tileHeight` validation currently throws when `tileHeight` is set without
`tileWidth`, even if `retile=false` (where tiling options are documented as
having no effect). Because `tileHeight` is an eagerly initialized `val`, its
`require` runs as soon as `RasterOptions` is constructed, regardless of the
value of `retile`. Consider enforcing the `tileWidth`/`tileHeight` pairing only
when `retile=true`, or ignoring tiling-specific options entirely when
`retile=false`, to avoid surprising failures.
##########
common/src/main/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStream.java:
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.inputstream;
+
+import java.io.IOException;
+import javax.imageio.stream.ImageInputStreamImpl;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+/** An ImageInputStream that reads image data from a Hadoop FileSystem. */
+public class HadoopImageInputStream extends ImageInputStreamImpl {
+
+ private final FSDataInputStream stream;
+ private final Path path;
+ private final Configuration conf;
+
+ public HadoopImageInputStream(Path path, Configuration conf) throws
IOException {
+ FileSystem fs = path.getFileSystem(conf);
+ stream = fs.open(path);
+ this.path = path;
+ this.conf = conf;
+ }
+
+ public HadoopImageInputStream(Path path) throws IOException {
+ this(path, new Configuration());
+ }
+
+ public HadoopImageInputStream(FSDataInputStream stream) {
+ this.stream = stream;
+ this.path = null;
+ this.conf = null;
+ }
+
+ public Path getPath() {
+ return path;
+ }
+
+ public Configuration getConf() {
+ return conf;
+ }
+
+ @Override
+ public void close() throws IOException {
+ super.close();
+ stream.close();
+ }
+
+ @Override
+ public int read() throws IOException {
+ byte[] buf = new byte[1];
+ int ret_len = read(buf, 0, 1);
+ if (ret_len < 0) {
+ return ret_len;
+ }
+ return buf[0] & 0xFF;
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ checkClosed();
+ bitOffset = 0;
+
+ if (len == 0) {
+ return 0;
+ }
+
+ // stream.read may return fewer data than requested, so we need to loop
until we get all the
+ // data, or hit the end of the stream. We can not simply perform an
incomplete read and return
+ // the number of bytes actually read, since the methods in
ImageInputStreamImpl such as
+ // readInt() relies on this method and assumes that partial reads only
happens when reading
+ // EOF. This might be a bug of imageio since they should invoke
readFully() in such cases.
+ int remaining = len;
+ while (remaining > 0) {
+ int ret_len = stream.read(b, off, remaining);
+ if (ret_len < 0) {
+ // Hit EOF, no more data to read.
Review Comment:
`read(byte[], off, len)` loops until `remaining == 0`, but it never handles
the case where `stream.read(...)` returns `0` for a positive `remaining`.
`InputStream.read` is not supposed to return `0` when `len > 0`. However, a
misbehaving wrapped stream that does so would make this loop spin forever
without making progress, i.e. hang instead of failing. Add a guard for
`ret_len == 0` (e.g., retry a bounded number of times, or throw an
`IOException` indicating a non-progressing stream).
##########
spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala:
##########
@@ -239,5 +248,283 @@ class rasterIOTest extends TestBaseScala with
BeforeAndAfter with GivenWhenThen
}
}
- override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir))
+ describe("Raster read test") {
+ it("should read geotiff using raster source with explicit tiling") {
+ val rasterDf = sparkSession.read
+ .format("raster")
+ .options(Map("retile" -> "true", "tileWidth" -> "64"))
+ .load(rasterdatalocation)
+ assert(rasterDf.count() > 100)
+ rasterDf.collect().foreach { row =>
+ val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D]
+ assert(raster.getGridGeometry.getGridRange2D.width <= 64)
+ assert(raster.getGridGeometry.getGridRange2D.height <= 64)
+ val x = row.getInt(1)
+ val y = row.getInt(2)
+ assert(x >= 0 && y >= 0)
+ raster.dispose(true)
+ }
+
+ // Test projection push-down
+ rasterDf.selectExpr("y", "rast as r").collect().foreach { row =>
+ val raster = row.getAs[Object](1).asInstanceOf[GridCoverage2D]
+ assert(raster.getGridGeometry.getGridRange2D.width <= 64)
+ assert(raster.getGridGeometry.getGridRange2D.height <= 64)
+ val y = row.getInt(0)
+ assert(y >= 0)
+ raster.dispose(true)
+ }
+ }
+
+ it("should tile geotiff using raster source with padding enabled") {
+ val rasterDf = sparkSession.read
+ .format("raster")
+ .options(Map("retile" -> "true", "tileWidth" -> "64", "padWithNoData"
-> "true"))
+ .load(rasterdatalocation)
+ assert(rasterDf.count() > 100)
+ rasterDf.collect().foreach { row =>
+ val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D]
+ assert(raster.getGridGeometry.getGridRange2D.width == 64)
+ assert(raster.getGridGeometry.getGridRange2D.height == 64)
+ val x = row.getInt(1)
+ val y = row.getInt(2)
+ assert(x >= 0 && y >= 0)
+ raster.dispose(true)
+ }
+ }
+
+ it("should push down limit and sample to data source") {
+ FileUtils.cleanDirectory(new File(tempDir))
+
+ val sourceDir = new File(rasterdatalocation)
+ val files = sourceDir.listFiles().filter(_.isFile)
+ var numUniqueFiles = 0
+ var numTotalFiles = 0
+ files.foreach { file =>
+ if (file.getPath.endsWith(".tif") || file.getPath.endsWith(".tiff")) {
+ // Create 4 copies for each file
+ for (i <- 0 until 4) {
+ val destFile = new File(tempDir + "/" + file.getName + "_" + i)
+ FileUtils.copyFile(file, destFile)
+ numTotalFiles += 1
+ }
+ numUniqueFiles += 1
+ }
+ }
+
+ val df = sparkSession.read
+ .format("raster")
+ .options(Map("retile" -> "false"))
+ .load(tempDir)
+ .withColumn("width", expr("RS_Width(rast)"))
+
+ val dfWithLimit = df.limit(numUniqueFiles)
+ val plan = queryPlan(dfWithLimit)
+ // Global/local limits are all pushed down to data source
+ assert(plan.collect { case e: LimitExec => e }.isEmpty)
+ assert(dfWithLimit.count() == numUniqueFiles)
+
+ val dfWithSample = df.sample(0.3, seed = 42)
+ val planSample = queryPlan(dfWithSample)
+ // Sample is pushed down to data source
+ assert(planSample.collect { case e: SampleExec => e }.isEmpty)
+ val count = dfWithSample.count()
+ assert(count >= numTotalFiles * 0.1 && count <= numTotalFiles * 0.5)
+
+ val dfWithSampleAndLimit = df.sample(0.5, seed =
42).limit(numUniqueFiles)
+ val planBoth = queryPlan(dfWithSampleAndLimit)
+ assert(planBoth.collect { case e: LimitExec => e }.isEmpty)
+ assert(planBoth.collect { case e: SampleExec => e }.isEmpty)
+ assert(dfWithSampleAndLimit.count() == numUniqueFiles)
+
+ // Limit and sample cannot be fully pushed down when retile is enabled
+ val dfReTiledWithSampleAndLimit = sparkSession.read
+ .format("raster")
+ .options(Map("retile" -> "true"))
+ .load(tempDir)
+ .sample(0.5, seed = 42)
+ .limit(numUniqueFiles)
+ dfReTiledWithSampleAndLimit.explain(true)
Review Comment:
`explain(true)` in tests prints verbose plans to stdout and can
create noisy CI logs. Consider removing it; the test already inspects plans
programmatically via `queryPlan(...)`.
```suggestion
```
##########
common/src/test/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStreamTest.java:
##########
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.inputstream;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.RandomAccessFile;
+import java.nio.file.Files;
+import java.util.Random;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PositionedReadable;
+import org.apache.hadoop.fs.Seekable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class HadoopImageInputStreamTest {
+ @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+ private static final int TEST_FILE_SIZE = 1000;
+ private final Random random = new Random();
+ private File testFile;
+
+ @Before
+ public void setup() throws IOException {
+ testFile = temp.newFile();
+ prepareTestData(testFile);
+ }
+
+ @Test
+ public void testReadSequentially() throws IOException {
+ Path path = new Path(testFile.getPath());
+ try (HadoopImageInputStream stream = new HadoopImageInputStream(path);
+ InputStream in = new
BufferedInputStream(Files.newInputStream(testFile.toPath()))) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ while (true) {
+ int len = random.nextInt(bActual.length + 1);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = in.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ break;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+ }
+ }
+
+ @Test
+ public void testReadRandomly() throws IOException {
+ Path path = new Path(testFile.getPath());
+ try (HadoopImageInputStream stream = new HadoopImageInputStream(path);
+ RandomAccessFile raf = new RandomAccessFile(testFile, "r")) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ for (int k = 0; k < 1000; k++) {
+ int offset = random.nextInt(TEST_FILE_SIZE + 1);
+ int len = random.nextInt(bActual.length + 1);
+ stream.seek(offset);
+ raf.seek(offset);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = raf.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ continue;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+
+ // Test seek to EOF.
+ stream.seek(TEST_FILE_SIZE);
+ int len = stream.read(bActual, 0, bActual.length);
+ Assert.assertEquals(-1, len);
+ }
+ }
+
+ @Test
+ public void testFromUnstableStream() throws IOException {
+ Path path = new Path(testFile.getPath());
+ FileSystem fs = path.getFileSystem(new Configuration());
+ try (FSDataInputStream unstable = new
UnstableFSDataInputStream(fs.open(path));
+ HadoopImageInputStream stream = new HadoopImageInputStream(unstable);
+ InputStream in = new
BufferedInputStream(Files.newInputStream(testFile.toPath()))) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ while (true) {
+ int len = random.nextInt(bActual.length);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = in.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ break;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+ }
+ }
+
+ private void prepareTestData(File testFile) throws IOException {
+ try (OutputStream out = new
BufferedOutputStream(Files.newOutputStream(testFile.toPath()))) {
+ for (int k = 0; k < TEST_FILE_SIZE; k++) {
+ out.write(random.nextInt());
+ }
+ }
+ }
+
+ /**
+ * An FSDataInputStream that sometimes return less data than requested when
calling read(byte[],
+ * int, int).
+ */
+ private static class UnstableFSDataInputStream extends FSDataInputStream {
+ public UnstableFSDataInputStream(FSDataInputStream in) {
+ super(new UnstableInputStream(in.getWrappedStream()));
+ }
+
+ private static class UnstableInputStream extends InputStream
+ implements Seekable, PositionedReadable {
+ private final InputStream wrapped;
+ private final Random random = new Random();
+
+ UnstableInputStream(InputStream in) {
+ wrapped = in;
+ }
+
+ @Override
+ public void close() throws IOException {
+ wrapped.close();
+ }
+
+ @Override
+ public int read() throws IOException {
+ return wrapped.read();
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ // Make this read unstable, i.e. sometimes return less data than
requested.
+ int unstableLen = random.nextInt(len + 1);
Review Comment:
`UnstableInputStream.read(byte[], off, len)` can choose `unstableLen = 0`
when `len > 0`, which violates `InputStream`'s contract (must block until at
least 1 byte or EOF) and will cause `HadoopImageInputStream.read(...)` to
spin/hang. Ensure `unstableLen` is at least 1 when `len > 0` so this test
reliably simulates partial (but progressing) reads.
```suggestion
// Make this read unstable, i.e. sometimes return less data than
requested,
// but still obey InputStream's contract by not returning 0 when len
> 0.
if (len == 0) {
return 0;
}
int unstableLen = 1 + random.nextInt(len);
```
##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala:
##########
@@ -107,39 +109,35 @@ private[apache] case class RS_TileExplode(children:
Seq[Expression])
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val raster = arguments.rasterExpr.toRaster(input)
- try {
- val bandIndices =
arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match {
- case null => null
- case value: Any => value.toIntArray
- }
- val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int]
- val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int]
- val padWithNoDataValue =
arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean]
- val noDataValue = arguments.noDataValExpr.eval(input) match {
- case null => Double.NaN
- case value: Integer => value.toDouble
- case value: Decimal => value.toDouble
- case value: Float => value.toDouble
- case value: Double => value
- case value: Any =>
- throw new IllegalArgumentException(
- "Unsupported class for noDataValue: " + value.getClass)
- }
- val tiles = RasterConstructors.generateTiles(
- raster,
- bandIndices,
- tileWidth,
- tileHeight,
- padWithNoDataValue,
- noDataValue)
- tiles.map { tile =>
- val gridCoverage2D = tile.getCoverage
- val row = InternalRow(tile.getTileX, tile.getTileY,
gridCoverage2D.serialize)
- gridCoverage2D.dispose(true)
- row
- }
- } finally {
- raster.dispose(true)
+ val bandIndices =
arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match {
+ case null => null
+ case value: Any => value.toIntArray
+ }
+ val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int]
+ val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int]
+ val padWithNoDataValue =
arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean]
+ val noDataValue = arguments.noDataValExpr.eval(input) match {
+ case null => Double.NaN
+ case value: Integer => value.toDouble
+ case value: Decimal => value.toDouble
+ case value: Float => value.toDouble
+ case value: Double => value
+ case value: Any =>
+ throw new IllegalArgumentException("Unsupported class for noDataValue:
" + value.getClass)
+ }
+ val tileIterator = RasterConstructors.generateTiles(
+ raster,
+ bandIndices,
+ tileWidth,
+ tileHeight,
+ padWithNoDataValue,
+ noDataValue)
+ tileIterator.setAutoDisposeSource(true)
+ tileIterator.asScala.map { tile =>
+ val gridCoverage2D = tile.getCoverage
+ val row = InternalRow(tile.getTileX, tile.getTileY,
gridCoverage2D.serialize)
+ gridCoverage2D.dispose(true)
+ row
}
Review Comment:
`toRaster` can return `null` (e.g., when the input column is null).
`RS_TileExplode.eval` currently passes that null into
`RasterConstructors.generateTiles`, which will NPE. Add an explicit null check
and return `Nil`/`Iterator.empty` (consistent with other raster expressions)
when `raster` is null.
```suggestion
if (raster == null) {
Iterator.empty
} else {
val bandIndices =
arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match {
case null => null
case value: Any => value.toIntArray
}
val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int]
val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int]
val padWithNoDataValue =
arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean]
val noDataValue = arguments.noDataValExpr.eval(input) match {
case null => Double.NaN
case value: Integer => value.toDouble
case value: Decimal => value.toDouble
case value: Float => value.toDouble
case value: Double => value
case value: Any =>
throw new IllegalArgumentException("Unsupported class for
noDataValue: " + value.getClass)
}
val tileIterator = RasterConstructors.generateTiles(
raster,
bandIndices,
tileWidth,
tileHeight,
padWithNoDataValue,
noDataValue)
tileIterator.setAutoDisposeSource(true)
tileIterator.asScala.map { tile =>
val gridCoverage2D = tile.getCoverage
val row = InternalRow(tile.getTileX, tile.getTileY,
gridCoverage2D.serialize)
gridCoverage2D.dispose(true)
row
}
}
```
##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala:
##########
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.sedona.common.raster.RasterConstructors
+import org.apache.sedona.common.raster.inputstream.HadoopImageInputStream
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
+import
org.apache.spark.sql.sedona_sql.io.raster.RasterPartitionReader.rasterToInternalRows
+import
org.apache.spark.sql.sedona_sql.io.raster.RasterTable.{MAX_AUTO_TILE_SIZE,
RASTER, RASTER_NAME, TILE_X, TILE_Y}
+import org.apache.spark.sql.types.StructType
+import org.geotools.coverage.grid.GridCoverage2D
+
+import java.net.URI
+import scala.collection.JavaConverters._
+
+class RasterPartitionReader(
+ configuration: Configuration,
+ partitionedFiles: Array[PartitionedFile],
+ dataSchema: StructType,
+ rasterOptions: RasterOptions)
+ extends PartitionReader[InternalRow] {
+
+ // Track the current file index we're processing
+ private var currentFileIndex = 0
+
+ // Current raster being processed
+ private var currentRaster: GridCoverage2D = _
+
+ // Current image input stream (must be kept open while the raster is in use)
+ private var currentImageStream: HadoopImageInputStream = _
+
+ // Current row
+ private var currentRow: InternalRow = _
+
+ // Iterator for the current file's tiles
+ private var currentIterator: Iterator[InternalRow] = Iterator.empty
+
+ override def next(): Boolean = {
+ // If current iterator has more elements, return true
+ if (currentIterator.hasNext) {
+ currentRow = currentIterator.next()
+ return true
+ }
+
+ // If current iterator is exhausted, but we have more files, load the next
file
+ if (currentFileIndex < partitionedFiles.length) {
+ loadNextFile()
+ if (currentIterator.hasNext) {
+ currentRow = currentIterator.next()
+ return true
+ }
+ }
+
+ // No more data
+ false
+ }
+
+ override def get(): InternalRow = {
+ currentRow
+ }
+
+ override def close(): Unit = {
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+ }
+
+ private def loadNextFile(): Unit = {
+ // Clean up previous raster and stream if exists
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+
+ if (currentFileIndex >= partitionedFiles.length) {
+ currentIterator = Iterator.empty
+ return
+ }
+
+ val partition = partitionedFiles(currentFileIndex)
+ val path = new Path(new URI(partition.filePath.toString()))
+
+ try {
+ // Open a stream-based reader instead of materializing the entire file
as byte[].
+ // This avoids the 2 GB byte[] limit and reduces memory pressure for
large files.
+ currentImageStream = new HadoopImageInputStream(path, configuration)
+
+ // Create in-db GridCoverage2D from GeoTiff stream. The RenderedImage is
lazy -
+ // pixel data will only be decoded when accessed via
image.getData(Rectangle).
+ currentRaster = RasterConstructors.fromGeoTiff(currentImageStream)
+ currentIterator = rasterToInternalRows(currentRaster, dataSchema,
rasterOptions, path)
+ currentFileIndex += 1
+ } catch {
+ case e: Exception =>
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+ throw e
+ }
+ }
+}
+
+object RasterPartitionReader {
+ def rasterToInternalRows(
+ currentRaster: GridCoverage2D,
+ dataSchema: StructType,
+ rasterOptions: RasterOptions,
+ path: Path): Iterator[InternalRow] = {
+ val retile = rasterOptions.retile
+ val tileWidth = rasterOptions.tileWidth
+ val tileHeight = rasterOptions.tileHeight
+ val padWithNoData = rasterOptions.padWithNoData
+
+ val writer = new UnsafeRowWriter(dataSchema.length)
+ writer.resetRowWriter()
+
+ // Extract the file name from the path
+ val fileName = path.getName
+
+ if (retile) {
+ val (tw, th) = (tileWidth, tileHeight) match {
+ case (Some(tw), Some(th)) => (tw, th)
+ case (None, None) =>
+ // Use the internal tile size of the input raster
+ val tw = currentRaster.getRenderedImage.getTileWidth
+ val th = currentRaster.getRenderedImage.getTileHeight
+ val tileSizeError = {
+ """To resolve this issue, you can try one of the following methods:
+ | 1. Disable retile by setting `.option("retile", "false")`.
+ | 2. Explicitly set `tileWidth` and `tileHeight`.
+ | 3. Convert the raster to a Cloud Optimized GeoTIFF (COG) using
tools like `gdal_translate`.
+ |""".stripMargin
+ }
+ if (tw >= MAX_AUTO_TILE_SIZE || th >= MAX_AUTO_TILE_SIZE) {
+ throw new IllegalArgumentException(
+ s"Internal tile size of $path is too large ($tw x $th). " +
tileSizeError)
+ }
+ if (tw == 0 || th == 0) {
+ throw new IllegalArgumentException(
+ s"Internal tile size of $path contains zero ($tw x $th). " +
tileSizeError)
+ }
+ if (tw / th > 10 || th / tw > 10) {
Review Comment:
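The “too thin” check uses integer division (`tw / th`), which truncates and
can miss thin tiles. For example, with a 4095 x 400 tile, `4095 / 400`
evaluates to 10 rather than ≈10.24, so `10 > 10` is false and the tile is
accepted. Every aspect ratio strictly between 10 and 11 slips through this
way. Use floating-point division (`tw.toDouble / th` and vice versa) to
correctly enforce the aspect ratio constraint.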
```suggestion
if (tw.toDouble / th > 10.0 || th.toDouble / tw > 10.0) {
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]