Copilot commented on code in PR #2673: URL: https://github.com/apache/sedona/pull/2673#discussion_r2845919143
########## spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala: ########## @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.raster.RasterConstructors +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.sedona_sql.UDT.RasterUDT +import org.apache.spark.sql.sedona_sql.io.raster.RasterPartitionReader.rasterToInternalRows +import org.apache.spark.sql.sedona_sql.io.raster.RasterTable.{MAX_AUTO_TILE_SIZE, RASTER, RASTER_NAME, TILE_X, TILE_Y} +import org.apache.spark.sql.types.StructType +import org.geotools.coverage.grid.GridCoverage2D + +import java.net.URI +import scala.collection.JavaConverters._ + +class RasterPartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + dataSchema: StructType, + rasterOptions: RasterOptions) + extends PartitionReader[InternalRow] { + + // Track the current file index we're processing + private var currentFileIndex = 0 + + // Current raster being processed + private var currentRaster: GridCoverage2D = _ + + // Current row + private var currentRow: InternalRow = _ + + // Iterator for the current file's tiles + private var currentIterator: Iterator[InternalRow] = Iterator.empty + + override def next(): Boolean = { + // If current iterator has more elements, return true + if (currentIterator.hasNext) { + currentRow = currentIterator.next() + return true + } + + // If current iterator is exhausted, but we have more files, load the next file + if (currentFileIndex < partitionedFiles.length) { + loadNextFile() + if (currentIterator.hasNext) { + currentRow = currentIterator.next() + return true + } + } + + // No more data + false + } + + override def get(): InternalRow = { + currentRow + } + + override def close(): Unit = { + if (currentRaster != null) { + currentRaster.dispose(true) + currentRaster = null + } + } + + private def loadNextFile(): Unit = { + // Clean up previous raster if exists + if (currentRaster != null) { + currentRaster.dispose(true) + currentRaster = null + } + + if (currentFileIndex >= partitionedFiles.length) { + currentIterator = Iterator.empty + return + } + + val partition = partitionedFiles(currentFileIndex) + val path = new Path(new URI(partition.filePath.toString())) + + try { + // Read file bytes from Hadoop FS + val fs = path.getFileSystem(configuration) + val fileStatus = fs.getFileStatus(path) + val fileLength = fileStatus.getLen.toInt + val bytes = new Array[Byte](fileLength) + val inputStream = fs.open(path) + try { + org.apache.hadoop.io.IOUtils.readFully(inputStream, bytes, 0, fileLength) + } finally { + inputStream.close() + } + + // Create in-db GridCoverage2D from GeoTiff bytes. The RenderedImage is lazy - + // pixel data will only be decoded when accessed via image.getData(Rectangle). + currentRaster = RasterConstructors.fromGeoTiff(bytes) + currentIterator = rasterToInternalRows(currentRaster, dataSchema, rasterOptions, path) Review Comment: This reader materializes the entire file into a single `Array[Byte]` and casts `getLen` to `Int`. For GeoTIFFs >2GB this will overflow / fail (and still hits Java array size limits), which undermines the stated goal of bypassing Spark’s 2GB record limit. Consider streaming the GeoTIFF via an `InputStream`/`ImageInputStream`-backed reader (or at least detect `getLen > Int.MaxValue` and fail with a clear error). ########## spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala: ########## @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.Batch +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.connector.read.SupportsPushDownLimit +import org.apache.spark.sql.connector.read.SupportsPushDownTableSample +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random + +case class RasterScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + rasterOptions: RasterOptions) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownTableSample + with SupportsPushDownLimit { + + private var pushedTableSample: Option[TableSampleInfo] = None + private var pushedLimit: Option[Int] = None + + override def build(): Scan = { + RasterScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + rasterOptions, + pushedDataFilters, + partitionFilters, + dataFilters, + pushedTableSample, + pushedLimit) + } + + override def pushTableSample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long): Boolean = { + if (withReplacement || rasterOptions.retile) { + false + } else { + pushedTableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed)) + true + } + } + + override def pushLimit(limit: Int): Boolean = { + pushedLimit = Some(limit) + true + } + + override def isPartiallyPushed: Boolean = rasterOptions.retile +} + +case class RasterScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + rasterOptions: RasterOptions, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty, + pushedTableSample: Option[TableSampleInfo] = None, + pushedLimit: Option[Int] = None) + extends FileScan + with Batch { + + private lazy val inputPartitions = { + var partitions = super.planInputPartitions() + + // Sample the files based on the table sample + pushedTableSample.foreach { tableSample => + val r = new Random(tableSample.seed) + var partitionIndex = 0 + partitions = partitions.flatMap { + case filePartition: FilePartition => + val files = filePartition.files + val sampledFiles = files.filter(_ => r.nextDouble() < tableSample.upperBound) + if (sampledFiles.nonEmpty) { + val index = partitionIndex + partitionIndex += 1 + Some(FilePartition(index, sampledFiles)) + } else { + None + } + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") Review Comment: `pushTableSample` receives both `lowerBound` and `upperBound`, but the sampling logic only uses `upperBound`. This changes Spark’s table-sample semantics for non-zero lower bounds. The filter should accept a file iff the random value is in `[lowerBound, upperBound)` (and still respect `withReplacement`). ```suggestion // Only push down sampling for without-replacement semantics. if (!tableSample.withReplacement) { val r = new Random(tableSample.seed) var partitionIndex = 0 partitions = partitions.flatMap { case filePartition: FilePartition => val files = filePartition.files val sampledFiles = files.filter { _ => val v = r.nextDouble() v >= tableSample.lowerBound && v < tableSample.upperBound } if (sampledFiles.nonEmpty) { val index = partitionIndex partitionIndex += 1 Some(FilePartition(index, sampledFiles)) } else { None } case partition => throw new IllegalArgumentException( s"Unexpected partition type: ${partition.getClass.getCanonicalName}") } ``` ########## spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala: ########## @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.hadoop.fs.FileStatus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.SupportsRead +import org.apache.spark.sql.connector.catalog.SupportsWrite +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.LogicalWriteInfo +import org.apache.spark.sql.connector.write.WriteBuilder +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sedona_sql.UDT.RasterUDT +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.{Set => JSet} + +case class RasterTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + rasterOptions: RasterOptions, + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) + with SupportsRead + with SupportsWrite { + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = + Some(userSpecifiedSchema.getOrElse(RasterTable.inferSchema(rasterOptions))) + + override def formatName: String = "Raster" + + override def capabilities(): JSet[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + RasterScanBuilder(sparkSession, fileIndex, schema, dataSchema, options, rasterOptions) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + // Note: this code path will never be taken, since Spark will always fall back to V1 + // data source when writing to File source v2. See SPARK-28396: File source v2 write + // path is currently broken. + null Review Comment: `newWriteBuilder` returns `null` while the table mixes in `SupportsWrite`. If Spark ever calls this code path, this will cause an NPE. Prefer either (a) removing `SupportsWrite` entirely since `capabilities()` only advertises `BATCH_READ`, or (b) returning a valid `WriteBuilder` / throwing a clear `UnsupportedOperationException`. ########## spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala: ########## @@ -107,39 +109,35 @@ private[apache] case class RS_TileExplode(children: Seq[Expression]) override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val raster = arguments.rasterExpr.toRaster(input) - try { - val bandIndices = arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match { - case null => null - case value: Any => value.toIntArray - } - val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int] - val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int] - val padWithNoDataValue = arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean] - val noDataValue = arguments.noDataValExpr.eval(input) match { - case null => Double.NaN - case value: Integer => value.toDouble - case value: Decimal => value.toDouble - case value: Float => value.toDouble - case value: Double => value - case value: Any => - throw new IllegalArgumentException( - "Unsupported class for noDataValue: " + value.getClass) - } - val tiles = RasterConstructors.generateTiles( - raster, - bandIndices, - tileWidth, - tileHeight, - padWithNoDataValue, - noDataValue) - tiles.map { tile => - val gridCoverage2D = tile.getCoverage - val row = InternalRow(tile.getTileX, tile.getTileY, gridCoverage2D.serialize) - gridCoverage2D.dispose(true) - row - } - } finally { - raster.dispose(true) + val bandIndices = arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match { + case null => null + case value: Any => value.toIntArray + } + val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int] + val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int] + val padWithNoDataValue = arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean] + val noDataValue = arguments.noDataValExpr.eval(input) match { + case null => Double.NaN + case value: Integer => value.toDouble + case value: Decimal => value.toDouble + case value: Float => value.toDouble + case value: Double => value + case value: Any => + throw new IllegalArgumentException("Unsupported class for noDataValue: " + value.getClass) + } + val tileIterator = RasterConstructors.generateTiles( + raster, + bandIndices, + tileWidth, + tileHeight, + padWithNoDataValue, + noDataValue) + tileIterator.setAutoDisposeSource(true) + tileIterator.asScala.map { tile => + val gridCoverage2D = tile.getCoverage + val row = InternalRow(tile.getTileX, tile.getTileY, gridCoverage2D.serialize) + gridCoverage2D.dispose(true) + row Review Comment: `RS_TileExplode` no longer disposes the source raster in a `finally`. With the new lazy iterator, the source is only disposed when the iterator reaches the end; if iteration is short-circuited (e.g., by `limit`/cancellation) the raster can leak. Consider wiring a disposal hook that runs when Spark stops consuming the iterator (e.g., using a `CompletionIterator`/task completion listener), or otherwise ensure disposal even on early termination. ########## common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java: ########## @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.raster; + +import java.awt.Rectangle; +import java.awt.image.Raster; +import java.awt.image.RenderedImage; +import java.awt.image.WritableRaster; +import java.util.Iterator; +import java.util.NoSuchElementException; +import javax.media.jai.RasterFactory; +import org.apache.sedona.common.utils.ImageUtils; +import org.apache.sedona.common.utils.RasterUtils; +import org.geotools.api.metadata.spatial.PixelOrientation; +import org.geotools.api.referencing.datum.PixelInCell; +import org.geotools.coverage.GridSampleDimension; +import org.geotools.coverage.grid.GridCoverage2D; +import org.geotools.coverage.grid.GridEnvelope2D; +import org.geotools.coverage.grid.GridGeometry2D; +import org.geotools.referencing.operation.transform.AffineTransform2D; + +public class TileGenerator { + + public static class Tile { + + private final int tileX; + private final int tileY; + private final GridCoverage2D coverage; + + public Tile(int tileX, int tileY, GridCoverage2D coverage) { + this.tileX = tileX; + this.tileY = tileY; + this.coverage = coverage; + } + + public int getTileX() { + return tileX; + } + + public int getTileY() { + return tileY; + } + + public GridCoverage2D getCoverage() { + return coverage; + } + } + + /** + * Generate tiles from an in-db grid coverage. The generated tiles are also in-db grid coverages. + * Pixel data will be copied into the tiles one tile at a time via a lazy iterator. + * + * @param gridCoverage2D the in-db grid coverage + * @param bandIndices the indices of the bands to select (1-based) + * @param tileWidth the width of the tiles + * @param tileHeight the height of the tiles + * @param padWithNoData whether to pad the tiles with no data value + * @param padNoDataValue the no data value for padded tiles, only used when padWithNoData is true. + * If the value is NaN, the no data value of the original band will be used. + * @return a lazy iterator of tiles + */ + public static InDbTileIterator generateInDbTiles( + GridCoverage2D gridCoverage2D, + int[] bandIndices, + int tileWidth, + int tileHeight, + boolean padWithNoData, + double padNoDataValue) { + return new InDbTileIterator( + gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData, padNoDataValue); + } + + public abstract static class TileIterator implements Iterator<Tile> { + protected final GridCoverage2D gridCoverage2D; + protected int numTileX; + protected int numTileY; + protected int tileX; + protected int tileY; + protected boolean autoDisposeSource = false; + protected Runnable disposeFunction = null; + + TileIterator(GridCoverage2D gridCoverage2D) { + this.gridCoverage2D = gridCoverage2D; + } + + protected void initialize(int numTileX, int numTileY) { + this.numTileX = numTileX; + this.numTileY = numTileY; + this.tileX = 0; + this.tileY = 0; + } + + /** + * Set whether to dispose the grid coverage when the iterator reaches the end. Default is false. + * + * @param autoDisposeSource whether to dispose the grid coverage + */ + public void setAutoDisposeSource(boolean autoDisposeSource) { + this.autoDisposeSource = autoDisposeSource; + } + + public void setDisposeFunction(Runnable disposeFunction) { + this.disposeFunction = disposeFunction; + } + + public int getNumTileX() { + return numTileX; + } + + public int getNumTileY() { + return numTileY; + } + + public int getNumTiles() { + return numTileX * numTileY; + } + + protected abstract Tile generateTile(); + + @Override + public boolean hasNext() { + if (numTileX == 0 || numTileY == 0) { + return false; + } + return tileY < numTileY; + } + + @Override + public Tile next() { + // Check if current tile coordinate is valid + if (tileX >= numTileX || tileY >= numTileY) { + throw new NoSuchElementException(); + } + + Tile tile = generateTile(); + + // Advance to the next tile + tileX += 1; + if (tileX >= numTileX) { + tileX = 0; + tileY += 1; + + // Dispose the grid coverage if we are at the end + if (tileY >= numTileY) { + if (autoDisposeSource) { + gridCoverage2D.dispose(true); + } + if (disposeFunction != null) { + disposeFunction.run(); + } + } + } + + return tile; + } + } + + public static class InDbTileIterator extends TileIterator { + private final int[] bandIndices; + private final int tileWidth; + private final int tileHeight; + private final boolean padWithNoData; + private final double padNoDataValue; + private final AffineTransform2D affine; + private final RenderedImage image; + private final double[] noDataValues; + private final int imageWidth; + private final int imageHeight; + + public InDbTileIterator( + GridCoverage2D gridCoverage2D, + int[] bandIndices, + int tileWidth, + int tileHeight, + boolean padWithNoData, + double padNoDataValue) { + super(gridCoverage2D); + this.bandIndices = bandIndices; + this.tileWidth = tileWidth; + this.tileHeight = tileHeight; + this.padWithNoData = padWithNoData; + this.padNoDataValue = padNoDataValue; + + affine = RasterUtils.getAffineTransform(gridCoverage2D, PixelOrientation.CENTER); + image = gridCoverage2D.getRenderedImage(); + noDataValues = new double[bandIndices.length]; + for (int i = 0; i < bandIndices.length; i++) { + noDataValues[i] = + RasterUtils.getNoDataValue(gridCoverage2D.getSampleDimension(bandIndices[i] - 1)); + } + imageWidth = image.getWidth(); + imageHeight = image.getHeight(); + int numTileX = (int) Math.ceil((double) imageWidth / tileWidth); + int numTileY = (int) Math.ceil((double) imageHeight / tileHeight); + initialize(numTileX, numTileY); + } + + @Override + protected Tile generateTile() { + // Process the current tile + int x0 = tileX * tileWidth + image.getMinX(); + int y0 = tileY * tileHeight + image.getMinY(); + + // Rect to copy from the original image + int rectWidth = Math.min(tileWidth, imageWidth - x0); + int rectHeight = Math.min(tileHeight, imageHeight - y0); Review Comment: `x0`/`y0` are offset by `image.getMinX/getMinY`, but `rectWidth`/`rectHeight` are computed using `imageWidth - x0` / `imageHeight - y0`, which is incorrect when `minX/minY` are non-zero and can truncate tiles. Compute remaining width/height using `(image.getMinX + imageWidth) - x0` (and similarly for Y), or compute `x0/y0` in the same coordinate system as `imageWidth/imageHeight`. ```suggestion int rectWidth = Math.min(tileWidth, (image.getMinX() + imageWidth) - x0); int rectHeight = Math.min(tileHeight, (image.getMinY() + imageHeight) - y0); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
