This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 189cfd0cb1 [GH-2672] Add a new raster data source reader that can
automatically tile GeoTiffs and bypass the Spark record limit (#2673)
189cfd0cb1 is described below
commit 189cfd0cb1925ad4832d0c8ca23e0f804a39290d
Author: Jia Yu <[email protected]>
AuthorDate: Wed Feb 25 02:13:28 2026 -0700
[GH-2672] Add a new raster data source reader that can automatically tile
GeoTiffs and bypass the Spark record limit (#2673)
---
common/pom.xml | 5 +
.../sedona/common/raster/RasterConstructors.java | 156 ++---------
.../apache/sedona/common/raster/TileGenerator.java | 270 +++++++++++++++++++
.../raster/inputstream/HadoopImageInputStream.java | 124 +++++++++
.../common/raster/RasterConstructorsTest.java | 40 +--
.../inputstream/HadoopImageInputStreamTest.java | 204 +++++++++++++++
docs/api/sql/Raster-loader.md | 72 +++++-
...org.apache.spark.sql.sources.DataSourceRegister | 2 +-
.../expressions/raster/RasterConstructors.scala | 67 ++---
.../sedona_sql/io/raster/RasterDataSource.scala | 105 ++++++++
...terOptions.scala => RasterInputPartition.scala} | 20 +-
.../sql/sedona_sql/io/raster/RasterOptions.scala | 48 +++-
.../io/raster/RasterPartitionReader.scala | 221 ++++++++++++++++
.../io/raster/RasterPartitionReaderFactory.scala | 65 +++++
.../sedona_sql/io/raster/RasterScanBuilder.scala | 183 +++++++++++++
.../sql/sedona_sql/io/raster/RasterTable.scala | 94 +++++++
.../scala/org/apache/sedona/sql/rasterIOTest.scala | 288 ++++++++++++++++++++-
17 files changed, 1763 insertions(+), 201 deletions(-)
diff --git a/common/pom.xml b/common/pom.xml
index 67898b6be1..195d8cba5c 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -129,6 +129,11 @@
<groupId>org.datasyslab</groupId>
<artifactId>proj4sedona</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
<sourceDirectory>src/main/java</sourceDirectory>
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
index e379b24648..b12241ab8a 100644
---
a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
+++
b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java
@@ -19,27 +19,23 @@
package org.apache.sedona.common.raster;
import java.awt.*;
-import java.awt.image.Raster;
-import java.awt.image.RenderedImage;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import javax.imageio.stream.ImageInputStream;
import javax.media.jai.RasterFactory;
import org.apache.sedona.common.FunctionsGeoTools;
import org.apache.sedona.common.raster.inputstream.ByteArrayImageInputStream;
import org.apache.sedona.common.raster.netcdf.NetCdfReader;
-import org.apache.sedona.common.utils.ImageUtils;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.api.feature.simple.SimpleFeature;
import org.geotools.api.feature.simple.SimpleFeatureType;
-import org.geotools.api.metadata.spatial.PixelOrientation;
import org.geotools.api.referencing.FactoryException;
import org.geotools.api.referencing.crs.CoordinateReferenceSystem;
import org.geotools.api.referencing.datum.PixelInCell;
import org.geotools.api.referencing.operation.MathTransform;
-import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridEnvelope2D;
import org.geotools.coverage.grid.GridGeometry2D;
@@ -73,6 +69,21 @@ public class RasterConstructors {
return geoTiffReader.read(null);
}
+ /**
+ * Creates a GridCoverage2D from a GeoTIFF via an ImageInputStream. This
avoids materializing the
+ * entire file as a byte[], which is critical for files larger than 2 GB.
+ *
+ * @param inputStream an ImageInputStream positioned at the start of the
GeoTIFF data
+ * @return a GridCoverage2D with a lazily-decoded RenderedImage
+ * @throws IOException if the GeoTIFF cannot be read
+ */
+ public static GridCoverage2D fromGeoTiff(ImageInputStream inputStream)
throws IOException {
+ GeoTiffReader geoTiffReader =
+ new GeoTiffReader(
+ inputStream, new Hints(Hints.FORCE_LONGITUDE_FIRST_AXIS_ORDER,
Boolean.TRUE));
+ return geoTiffReader.read(null);
+ }
+
public static GridCoverage2D fromNetCDF(
byte[] bytes, String variableName, String lonDimensionName, String
latDimensionName)
throws IOException, FactoryException {
@@ -560,32 +571,9 @@ public class RasterConstructors {
return RasterUtils.create(raster, gridGeometry, null);
}
- public static class Tile {
- private final int tileX;
- private final int tileY;
- private final GridCoverage2D coverage;
-
- public Tile(int tileX, int tileY, GridCoverage2D coverage) {
- this.tileX = tileX;
- this.tileY = tileY;
- this.coverage = coverage;
- }
-
- public int getTileX() {
- return tileX;
- }
-
- public int getTileY() {
- return tileY;
- }
-
- public GridCoverage2D getCoverage() {
- return coverage;
- }
- }
-
/**
- * Generate tiles from a grid coverage
+ * Generate tiles from a grid coverage. Returns a lazy iterator that
generates tiles one at a
+ * time, reading only the necessary pixel data for each tile from the source
image.
*
* @param gridCoverage2D the grid coverage
* @param bandIndices the indices of the bands to select (1-based), can be
null or empty to
@@ -595,9 +583,9 @@ public class RasterConstructors {
* @param padWithNoData whether to pad the tiles with no data value
* @param padNoDataValue the no data value for padded tiles, only used when
padWithNoData is true.
* If the value is NaN, the no data value of the original band will be
used.
- * @return the tiles
+ * @return a lazy iterator of tiles
*/
- public static Tile[] generateTiles(
+ public static TileGenerator.TileIterator generateTiles(
GridCoverage2D gridCoverage2D,
int[] bandIndices,
int tileWidth,
@@ -620,102 +608,10 @@ public class RasterConstructors {
}
}
}
- return doGenerateTiles(
+ return TileGenerator.generateInDbTiles(
gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData,
padNoDataValue);
}
- /**
- * Generate tiles from an in-db grid coverage. The generated tiles are also
in-db grid coverages.
- * Pixel data will be copied into the tiles.
- *
- * @param gridCoverage2D the in-db grid coverage
- * @param bandIndices the indices of the bands to select (1-based)
- * @param tileWidth the width of the tiles
- * @param tileHeight the height of the tiles
- * @param padWithNoData whether to pad the tiles with no data value
- * @param padNoDataValue the no data value for padded tiles, only used when
padWithNoData is true.
- * If the value is NaN, the no data value of the original band will be
used.
- * @return the tiles
- */
- private static Tile[] doGenerateTiles(
- GridCoverage2D gridCoverage2D,
- int[] bandIndices,
- int tileWidth,
- int tileHeight,
- boolean padWithNoData,
- double padNoDataValue) {
- AffineTransform2D affine =
- RasterUtils.getAffineTransform(gridCoverage2D,
PixelOrientation.CENTER);
- RenderedImage image = gridCoverage2D.getRenderedImage();
- double[] noDataValues = new double[bandIndices.length];
- for (int i = 0; i < bandIndices.length; i++) {
- noDataValues[i] =
-
RasterUtils.getNoDataValue(gridCoverage2D.getSampleDimension(bandIndices[i] -
1));
- }
- int width = image.getWidth();
- int height = image.getHeight();
- int numTileX = (int) Math.ceil((double) width / tileWidth);
- int numTileY = (int) Math.ceil((double) height / tileHeight);
- Tile[] tiles = new Tile[numTileX * numTileY];
- for (int tileY = 0; tileY < numTileY; tileY++) {
- for (int tileX = 0; tileX < numTileX; tileX++) {
- int x0 = tileX * tileWidth;
- int y0 = tileY * tileHeight;
-
- // Rect to copy from the original image
- int rectWidth = Math.min(tileWidth, width - x0);
- int rectHeight = Math.min(tileHeight, height - y0);
-
- // If we don't pad with no data, the tiles on the boundary may have a
different size
- int currentTileWidth = padWithNoData ? tileWidth : rectWidth;
- int currentTileHeight = padWithNoData ? tileHeight : rectHeight;
- boolean needPadding = padWithNoData && (rectWidth < tileWidth ||
rectHeight < tileHeight);
-
- // Create a new affine transformation for this tile
- AffineTransform2D tileAffine =
RasterUtils.translateAffineTransform(affine, x0, y0);
- GridGeometry2D gridGeometry2D =
- new GridGeometry2D(
- new GridEnvelope2D(0, 0, currentTileWidth, currentTileHeight),
- PixelInCell.CELL_CENTER,
- tileAffine,
- gridCoverage2D.getCoordinateReferenceSystem(),
- null);
-
- // Prepare a new image for this tile, and copy the data from the
original image
- WritableRaster raster =
- RasterFactory.createBandedRaster(
- image.getSampleModel().getDataType(),
- currentTileWidth,
- currentTileHeight,
- bandIndices.length,
- null);
- GridSampleDimension[] sampleDimensions = new
GridSampleDimension[bandIndices.length];
- Raster sourceRaster = image.getData(new Rectangle(x0, y0, rectWidth,
rectHeight));
- for (int k = 0; k < bandIndices.length; k++) {
- int bandIndex = bandIndices[k] - 1;
-
- // Copy sample dimensions from source bands, and pad with no data
value if necessary
- GridSampleDimension sampleDimension =
gridCoverage2D.getSampleDimension(bandIndex);
- double noDataValue = noDataValues[k];
- if (needPadding && !Double.isNaN(padNoDataValue)) {
- sampleDimension =
-
RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension,
padNoDataValue);
- noDataValue = padNoDataValue;
- }
- sampleDimensions[k] = sampleDimension;
-
- // Copy data from original image to tile image
- ImageUtils.copyRasterWithPadding(sourceRaster, bandIndex, raster, k,
noDataValue);
- }
-
- GridCoverage2D tile = RasterUtils.create(raster, gridGeometry2D,
sampleDimensions);
- tiles[tileY * numTileX + tileX] = new Tile(tileX, tileY, tile);
- }
- }
-
- return tiles;
- }
-
public static GridCoverage2D[] rsTile(
GridCoverage2D gridCoverage2D,
int[] bandIndices,
@@ -729,12 +625,14 @@ public class RasterConstructors {
if (padNoDataValue == null) {
padNoDataValue = Double.NaN;
}
- Tile[] tiles =
+ TileGenerator.TileIterator tileIterator =
generateTiles(
gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData,
padNoDataValue);
- GridCoverage2D[] result = new GridCoverage2D[tiles.length];
- for (int i = 0; i < tiles.length; i++) {
- result[i] = tiles[i].getCoverage();
+ GridCoverage2D[] result = new GridCoverage2D[tileIterator.getNumTiles()];
+ int i = 0;
+ while (tileIterator.hasNext()) {
+ TileGenerator.Tile tile = tileIterator.next();
+ result[i++] = tile.getCoverage();
}
return result;
}
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java
b/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java
new file mode 100644
index 0000000000..6f7b3818af
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster;
+
+import java.awt.Rectangle;
+import java.awt.image.Raster;
+import java.awt.image.RenderedImage;
+import java.awt.image.WritableRaster;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import javax.media.jai.RasterFactory;
+import org.apache.sedona.common.utils.ImageUtils;
+import org.apache.sedona.common.utils.RasterUtils;
+import org.geotools.api.metadata.spatial.PixelOrientation;
+import org.geotools.api.referencing.datum.PixelInCell;
+import org.geotools.coverage.GridSampleDimension;
+import org.geotools.coverage.grid.GridCoverage2D;
+import org.geotools.coverage.grid.GridEnvelope2D;
+import org.geotools.coverage.grid.GridGeometry2D;
+import org.geotools.referencing.operation.transform.AffineTransform2D;
+
+public class TileGenerator {
+
+ public static class Tile {
+
+ private final int tileX;
+ private final int tileY;
+ private final GridCoverage2D coverage;
+
+ public Tile(int tileX, int tileY, GridCoverage2D coverage) {
+ this.tileX = tileX;
+ this.tileY = tileY;
+ this.coverage = coverage;
+ }
+
+ public int getTileX() {
+ return tileX;
+ }
+
+ public int getTileY() {
+ return tileY;
+ }
+
+ public GridCoverage2D getCoverage() {
+ return coverage;
+ }
+ }
+
+ /**
+ * Generate tiles from an in-db grid coverage. The generated tiles are also
in-db grid coverages.
+ * Pixel data will be copied into the tiles one tile at a time via a lazy
iterator.
+ *
+ * @param gridCoverage2D the in-db grid coverage
+ * @param bandIndices the indices of the bands to select (1-based)
+ * @param tileWidth the width of the tiles
+ * @param tileHeight the height of the tiles
+ * @param padWithNoData whether to pad the tiles with no data value
+ * @param padNoDataValue the no data value for padded tiles, only used when
padWithNoData is true.
+ * If the value is NaN, the no data value of the original band will be
used.
+ * @return a lazy iterator of tiles
+ */
+ public static InDbTileIterator generateInDbTiles(
+ GridCoverage2D gridCoverage2D,
+ int[] bandIndices,
+ int tileWidth,
+ int tileHeight,
+ boolean padWithNoData,
+ double padNoDataValue) {
+ return new InDbTileIterator(
+ gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData,
padNoDataValue);
+ }
+
+ public abstract static class TileIterator implements Iterator<Tile> {
+ protected final GridCoverage2D gridCoverage2D;
+ protected int numTileX;
+ protected int numTileY;
+ protected int tileX;
+ protected int tileY;
+ protected boolean autoDisposeSource = false;
+ protected Runnable disposeFunction = null;
+
+ TileIterator(GridCoverage2D gridCoverage2D) {
+ this.gridCoverage2D = gridCoverage2D;
+ }
+
+ protected void initialize(int numTileX, int numTileY) {
+ this.numTileX = numTileX;
+ this.numTileY = numTileY;
+ this.tileX = 0;
+ this.tileY = 0;
+ }
+
+ /**
+ * Set whether to dispose the grid coverage when the iterator reaches the
end. Default is false.
+ *
+ * @param autoDisposeSource whether to dispose the grid coverage
+ */
+ public void setAutoDisposeSource(boolean autoDisposeSource) {
+ this.autoDisposeSource = autoDisposeSource;
+ }
+
+ public void setDisposeFunction(Runnable disposeFunction) {
+ this.disposeFunction = disposeFunction;
+ }
+
+ public int getNumTileX() {
+ return numTileX;
+ }
+
+ public int getNumTileY() {
+ return numTileY;
+ }
+
+ public int getNumTiles() {
+ return numTileX * numTileY;
+ }
+
+ protected abstract Tile generateTile();
+
+ @Override
+ public boolean hasNext() {
+ if (numTileX == 0 || numTileY == 0) {
+ return false;
+ }
+ return tileY < numTileY;
+ }
+
+ @Override
+ public Tile next() {
+ // Check if current tile coordinate is valid
+ if (tileX >= numTileX || tileY >= numTileY) {
+ throw new NoSuchElementException();
+ }
+
+ Tile tile = generateTile();
+
+ // Advance to the next tile
+ tileX += 1;
+ if (tileX >= numTileX) {
+ tileX = 0;
+ tileY += 1;
+
+ // Dispose the grid coverage if we are at the end
+ if (tileY >= numTileY) {
+ if (autoDisposeSource) {
+ gridCoverage2D.dispose(true);
+ }
+ if (disposeFunction != null) {
+ disposeFunction.run();
+ }
+ }
+ }
+
+ return tile;
+ }
+ }
+
+ public static class InDbTileIterator extends TileIterator {
+ private final int[] bandIndices;
+ private final int tileWidth;
+ private final int tileHeight;
+ private final boolean padWithNoData;
+ private final double padNoDataValue;
+ private final AffineTransform2D affine;
+ private final RenderedImage image;
+ private final double[] noDataValues;
+ private final int imageWidth;
+ private final int imageHeight;
+
+ public InDbTileIterator(
+ GridCoverage2D gridCoverage2D,
+ int[] bandIndices,
+ int tileWidth,
+ int tileHeight,
+ boolean padWithNoData,
+ double padNoDataValue) {
+ super(gridCoverage2D);
+ this.bandIndices = bandIndices;
+ this.tileWidth = tileWidth;
+ this.tileHeight = tileHeight;
+ this.padWithNoData = padWithNoData;
+ this.padNoDataValue = padNoDataValue;
+
+ affine = RasterUtils.getAffineTransform(gridCoverage2D,
PixelOrientation.CENTER);
+ image = gridCoverage2D.getRenderedImage();
+ noDataValues = new double[bandIndices.length];
+ for (int i = 0; i < bandIndices.length; i++) {
+ noDataValues[i] =
+
RasterUtils.getNoDataValue(gridCoverage2D.getSampleDimension(bandIndices[i] -
1));
+ }
+ imageWidth = image.getWidth();
+ imageHeight = image.getHeight();
+ int numTileX = (int) Math.ceil((double) imageWidth / tileWidth);
+ int numTileY = (int) Math.ceil((double) imageHeight / tileHeight);
+ initialize(numTileX, numTileY);
+ }
+
+ @Override
+ protected Tile generateTile() {
+ // Process the current tile
+ int x0 = tileX * tileWidth + image.getMinX();
+ int y0 = tileY * tileHeight + image.getMinY();
+
+ // Rect to copy from the original image
+ int rectWidth = Math.min(tileWidth, image.getMinX() + imageWidth - x0);
+ int rectHeight = Math.min(tileHeight, image.getMinY() + imageHeight -
y0);
+
+ // If we don't pad with no data, the tiles on the boundary may have a
different size
+ int currentTileWidth = padWithNoData ? tileWidth : rectWidth;
+ int currentTileHeight = padWithNoData ? tileHeight : rectHeight;
+ boolean needPadding = padWithNoData && (rectWidth < tileWidth ||
rectHeight < tileHeight);
+
+ // Create a new affine transformation for this tile
+ AffineTransform2D tileAffine =
RasterUtils.translateAffineTransform(affine, x0, y0);
+ GridGeometry2D gridGeometry2D =
+ new GridGeometry2D(
+ new GridEnvelope2D(0, 0, currentTileWidth, currentTileHeight),
+ PixelInCell.CELL_CENTER,
+ tileAffine,
+ gridCoverage2D.getCoordinateReferenceSystem(),
+ null);
+
+ // Prepare a new image for this tile, and copy the data from the
original image
+ WritableRaster raster =
+ RasterFactory.createBandedRaster(
+ image.getSampleModel().getDataType(),
+ currentTileWidth,
+ currentTileHeight,
+ bandIndices.length,
+ null);
+ GridSampleDimension[] sampleDimensions = new
GridSampleDimension[bandIndices.length];
+ Raster sourceRaster = image.getData(new Rectangle(x0, y0, rectWidth,
rectHeight));
+ for (int k = 0; k < bandIndices.length; k++) {
+ int bandIndex = bandIndices[k] - 1;
+
+ // Copy sample dimensions from source bands, and pad with no data
value if necessary
+ GridSampleDimension sampleDimension =
gridCoverage2D.getSampleDimension(bandIndex);
+ double noDataValue = noDataValues[k];
+ if (needPadding && !Double.isNaN(padNoDataValue)) {
+ sampleDimension =
+
RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension,
padNoDataValue);
+ noDataValue = padNoDataValue;
+ }
+ sampleDimensions[k] = sampleDimension;
+
+ // Copy data from original image to tile image
+ ImageUtils.copyRasterWithPadding(sourceRaster, bandIndex, raster, k,
noDataValue);
+ }
+
+ GridCoverage2D tile = RasterUtils.create(raster, gridGeometry2D,
sampleDimensions);
+ return new Tile(tileX, tileY, tile);
+ }
+ }
+}
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStream.java
b/common/src/main/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStream.java
new file mode 100644
index 0000000000..e4055cad6b
--- /dev/null
+++
b/common/src/main/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStream.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.inputstream;
+
+import java.io.IOException;
+import javax.imageio.stream.ImageInputStreamImpl;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+/** An ImageInputStream that reads image data from a Hadoop FileSystem. */
+public class HadoopImageInputStream extends ImageInputStreamImpl {
+
+ private final FSDataInputStream stream;
+ private final Path path;
+ private final Configuration conf;
+
+ public HadoopImageInputStream(Path path, Configuration conf) throws
IOException {
+ FileSystem fs = path.getFileSystem(conf);
+ stream = fs.open(path);
+ this.path = path;
+ this.conf = conf;
+ }
+
+ public HadoopImageInputStream(Path path) throws IOException {
+ this(path, new Configuration());
+ }
+
+ public HadoopImageInputStream(FSDataInputStream stream) {
+ this.stream = stream;
+ this.path = null;
+ this.conf = null;
+ }
+
+ public Path getPath() {
+ return path;
+ }
+
+ public Configuration getConf() {
+ return conf;
+ }
+
+ @Override
+ public void close() throws IOException {
+ super.close();
+ stream.close();
+ }
+
+ @Override
+ public int read() throws IOException {
+ byte[] buf = new byte[1];
+ int ret_len = read(buf, 0, 1);
+ if (ret_len < 0) {
+ return ret_len;
+ }
+ return buf[0] & 0xFF;
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ checkClosed();
+ bitOffset = 0;
+
+ if (len == 0) {
+ return 0;
+ }
+
// stream.read may return less data than requested, so we need to loop
until we get all the
// data or hit the end of the stream. We cannot simply perform an
incomplete read and return
// the number of bytes actually read, since methods in
ImageInputStreamImpl such as
// readInt() rely on this method and assume that partial reads only
happen when reaching
// EOF. This might be a bug in ImageIO, since it should invoke
readFully() in such cases.
+ int remaining = len;
+ while (remaining > 0) {
+ int ret_len = stream.read(b, off, remaining);
+ if (ret_len == 0) {
+ // This should not happen per InputStream contract, but guard against
non-progressing
+ // streams to avoid an infinite loop.
+ throw new IOException("Stream returned 0 bytes for a non-zero read
request");
+ }
+ if (ret_len < 0) {
+ // Hit EOF, no more data to read.
+ if (len - remaining > 0) {
+ // We have read some data, but that may not be all the data we can
read from the
+ // stream. The partial read may happen when reading from S3
(S3AInputStream), and it
+ // may also happen on other remote file systems.
+ return len - remaining;
+ } else {
+ // We have not read any data, return EOF.
+ return ret_len;
+ }
+ }
+ off += ret_len;
+ remaining -= ret_len;
+ streamPos += ret_len;
+ }
+
+ return len - remaining;
+ }
+
+ @Override
+ public void seek(long pos) throws IOException {
+ checkClosed();
+ stream.seek(pos);
+ super.seek(pos);
+ }
+}
diff --git
a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java
b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java
index a5cc02aada..7e8b7643b3 100644
---
a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java
+++
b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java
@@ -498,8 +498,8 @@ public class RasterConstructorsTest extends RasterTestBase {
public void testInDbTileWithoutPadding() {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1,
"EPSG:3857");
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, null, 10, 10, false,
Double.NaN);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, null, 10, 10,
false, Double.NaN));
assertTilesSameWithGridCoverage(tiles, raster, null, 10, 10, Double.NaN);
}
@@ -507,8 +507,8 @@ public class RasterConstructorsTest extends RasterTestBase {
public void testInDbTileWithoutPadding2() {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1,
"EPSG:3857");
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, null, 9, 9, false,
Double.NaN);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9,
false, Double.NaN));
assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, Double.NaN);
}
@@ -516,8 +516,8 @@ public class RasterConstructorsTest extends RasterTestBase {
public void testInDbTileWithPadding() {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 2,
"EPSG:3857");
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, null, 9, 9, true, 100);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9,
true, 100));
assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 100);
}
@@ -526,8 +526,8 @@ public class RasterConstructorsTest extends RasterTestBase {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 2,
"EPSG:3857");
int[] bandIndices = {2};
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, bandIndices, 9, 9, true, 100);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, bandIndices, 9,
9, true, 100));
assertTilesSameWithGridCoverage(tiles, raster, bandIndices, 9, 9, 100);
}
@@ -536,8 +536,8 @@ public class RasterConstructorsTest extends RasterTestBase {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 4,
"EPSG:3857");
int[] bandIndices = {3, 1};
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, bandIndices, 8, 7, true, 100);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, bandIndices, 8,
7, true, 100));
assertTilesSameWithGridCoverage(tiles, raster, bandIndices, 8, 7, 100);
}
@@ -546,8 +546,8 @@ public class RasterConstructorsTest extends RasterTestBase {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1,
"EPSG:3857");
raster = MapAlgebra.addBandFromArray(raster,
MapAlgebra.bandAsArray(raster, 1), 1, 13.0);
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, null, 9, 9, true, Double.NaN);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9,
true, Double.NaN));
assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 13);
}
@@ -556,13 +556,21 @@ public class RasterConstructorsTest extends
RasterTestBase {
GridCoverage2D raster =
createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1,
"EPSG:3857");
raster = MapAlgebra.addBandFromArray(raster,
MapAlgebra.bandAsArray(raster, 1), 1, 13.0);
- RasterConstructors.Tile[] tiles =
- RasterConstructors.generateTiles(raster, null, 9, 9, true, 42);
+ TileGenerator.Tile[] tiles =
+ collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9,
true, 42));
assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 42);
}
+ private TileGenerator.Tile[] collectTiles(TileGenerator.TileIterator iter) {
+ java.util.List<TileGenerator.Tile> list = new java.util.ArrayList<>();
+ while (iter.hasNext()) {
+ list.add(iter.next());
+ }
+ return list.toArray(new TileGenerator.Tile[0]);
+ }
+
private void assertTilesSameWithGridCoverage(
- RasterConstructors.Tile[] tiles,
+ TileGenerator.Tile[] tiles,
GridCoverage2D gridCoverage2D,
int[] bandIndices,
int tileWidth,
@@ -581,7 +589,7 @@ public class RasterConstructorsTest extends RasterTestBase {
// in the grid
// coverage
Set<Pair<Integer, Integer>> visitedTiles = new HashSet<>();
- for (RasterConstructors.Tile tile : tiles) {
+ for (TileGenerator.Tile tile : tiles) {
int tileX = tile.getTileX();
int tileY = tile.getTileY();
Pair<Integer, Integer> tilePosition = Pair.of(tileX, tileY);
diff --git
a/common/src/test/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStreamTest.java
b/common/src/test/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStreamTest.java
new file mode 100644
index 0000000000..76c7380333
--- /dev/null
+++
b/common/src/test/java/org/apache/sedona/common/raster/inputstream/HadoopImageInputStreamTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.inputstream;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.RandomAccessFile;
+import java.nio.file.Files;
+import java.util.Random;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PositionedReadable;
+import org.apache.hadoop.fs.Seekable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class HadoopImageInputStreamTest {
+ @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+ private static final int TEST_FILE_SIZE = 1000;
+ private final Random random = new Random();
+ private File testFile;
+
+ @Before
+ public void setup() throws IOException {
+ testFile = temp.newFile();
+ prepareTestData(testFile);
+ }
+
+ @Test
+ public void testReadSequentially() throws IOException {
+ Path path = new Path(testFile.getPath());
+ try (HadoopImageInputStream stream = new HadoopImageInputStream(path);
+ InputStream in = new
BufferedInputStream(Files.newInputStream(testFile.toPath()))) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ while (true) {
+ int len = random.nextInt(bActual.length + 1);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = in.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ break;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+ }
+ }
+
+ @Test
+ public void testReadRandomly() throws IOException {
+ Path path = new Path(testFile.getPath());
+ try (HadoopImageInputStream stream = new HadoopImageInputStream(path);
+ RandomAccessFile raf = new RandomAccessFile(testFile, "r")) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ for (int k = 0; k < 1000; k++) {
+ int offset = random.nextInt(TEST_FILE_SIZE + 1);
+ int len = random.nextInt(bActual.length + 1);
+ stream.seek(offset);
+ raf.seek(offset);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = raf.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ continue;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+
+ // Test seek to EOF.
+ stream.seek(TEST_FILE_SIZE);
+ int len = stream.read(bActual, 0, bActual.length);
+ Assert.assertEquals(-1, len);
+ }
+ }
+
+ @Test
+ public void testFromUnstableStream() throws IOException {
+ Path path = new Path(testFile.getPath());
+ FileSystem fs = path.getFileSystem(new Configuration());
+ try (FSDataInputStream unstable = new
UnstableFSDataInputStream(fs.open(path));
+ HadoopImageInputStream stream = new HadoopImageInputStream(unstable);
+ InputStream in = new
BufferedInputStream(Files.newInputStream(testFile.toPath()))) {
+ byte[] bActual = new byte[8];
+ byte[] bExpected = new byte[bActual.length];
+ while (true) {
+ int len = random.nextInt(bActual.length);
+ int lenActual = stream.read(bActual, 0, len);
+ int lenExpected = in.read(bExpected, 0, len);
+ Assert.assertEquals(lenExpected, lenActual);
+ if (lenActual < 0) {
+ break;
+ }
+ Assert.assertArrayEquals(bExpected, bActual);
+ }
+ }
+ }
+
+ private void prepareTestData(File testFile) throws IOException {
+ try (OutputStream out = new
BufferedOutputStream(Files.newOutputStream(testFile.toPath()))) {
+ for (int k = 0; k < TEST_FILE_SIZE; k++) {
+ out.write(random.nextInt());
+ }
+ }
+ }
+
+ /**
+ * An FSDataInputStream that sometimes returns less data than requested when
calling read(byte[],
+ * int, int).
+ */
+ private static class UnstableFSDataInputStream extends FSDataInputStream {
+ public UnstableFSDataInputStream(FSDataInputStream in) {
+ super(new UnstableInputStream(in.getWrappedStream()));
+ }
+
+ private static class UnstableInputStream extends InputStream
+ implements Seekable, PositionedReadable {
+ private final InputStream wrapped;
+ private final Random random = new Random();
+
+ UnstableInputStream(InputStream in) {
+ wrapped = in;
+ }
+
+ @Override
+ public void close() throws IOException {
+ wrapped.close();
+ }
+
+ @Override
+ public int read() throws IOException {
+ return wrapped.read();
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ // Make this read unstable, i.e. sometimes return less data than
requested,
+ // but still obey InputStream's contract by not returning 0 when len >
0.
+ if (len == 0) {
+ return 0;
+ }
+ int unstableLen = 1 + random.nextInt(len);
+ return wrapped.read(b, off, unstableLen);
+ }
+
+ @Override
+ public void seek(long pos) throws IOException {
+ ((Seekable) wrapped).seek(pos);
+ }
+
+ @Override
+ public long getPos() throws IOException {
+ return ((Seekable) wrapped).getPos();
+ }
+
+ @Override
+ public boolean seekToNewSource(long targetPos) throws IOException {
+ return ((Seekable) wrapped).seekToNewSource(targetPos);
+ }
+
+ @Override
+ public int read(long position, byte[] buffer, int offset, int length)
throws IOException {
+ return ((PositionedReadable) wrapped).read(position, buffer, offset,
length);
+ }
+
+ @Override
+ public void readFully(long position, byte[] buffer, int offset, int
length)
+ throws IOException {
+ ((PositionedReadable) wrapped).readFully(position, buffer, offset,
length);
+ }
+
+ @Override
+ public void readFully(long position, byte[] buffer) throws IOException {
+ ((PositionedReadable) wrapped).readFully(position, buffer);
+ }
+ }
+ }
+}
diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md
index d0df8c2231..acc60f6815 100644
--- a/docs/api/sql/Raster-loader.md
+++ b/docs/api/sql/Raster-loader.md
@@ -20,6 +20,74 @@
!!!note
Sedona loaders are available in Scala, Java and Python and have the same
APIs.
+## Loading raster using the raster data source
+
+The `raster` data source loads GeoTiff files and automatically splits them
into smaller tiles. Each tile is a row in the resulting DataFrame stored in
`Raster` format.
+
+=== "Scala"
+ ```scala
+ var rawDf = sedona.read.format("raster").load("/some/path/*.tif")
+ rawDf.createOrReplaceTempView("rawdf")
+ rawDf.show()
+ ```
+
+=== "Java"
+ ```java
+ Dataset<Row> rawDf =
sedona.read().format("raster").load("/some/path/*.tif");
+ rawDf.createOrReplaceTempView("rawdf");
+ rawDf.show();
+ ```
+
+=== "Python"
+ ```python
+ rawDf = sedona.read.format("raster").load("/some/path/*.tif")
+ rawDf.createOrReplaceTempView("rawdf")
+ rawDf.show()
+ ```
+
+The output will look like this:
+
+```
++--------------------+---+---+----+
+| rast| x| y|name|
++--------------------+---+---+----+
+|GridCoverage2D["g...| 0| 0| ...|
+|GridCoverage2D["g...| 1| 0| ...|
+|GridCoverage2D["g...| 2| 0| ...|
+...
+```
+
+The output contains the following columns:
+
+- `rast`: The raster data in `Raster` format.
+- `x`: The 0-based x-coordinate of the tile. This column is only present when
retiling is enabled (the default).
+- `y`: The 0-based y-coordinate of the tile. This column is only present when
retiling is enabled (the default).
+- `name`: The name of the raster file.
+
+The size of the tile is determined by the internal tiling scheme of the raster
data. It is recommended to use [Cloud Optimized GeoTIFF
(COG)](https://www.cogeo.org/) format for raster data since they usually
organize pixel data as square tiles. You can also disable automatic tiling
using `option("retile", "false")`, or specify the tile size manually using
options such as `option("tileWidth", "256")` and `option("tileHeight", "256")`.
+
+The options for the `raster` data source are as follows:
+
+- `retile`: Whether to enable tiling. Default is `true`.
+- `tileWidth`: The width of the tile. If not specified, the size of internal
tiles will be used.
+- `tileHeight`: The height of the tile. If not specified, will use `tileWidth`
if `tileWidth` is explicitly set, otherwise the size of internal tiles will be
used.
+- `padWithNoData`: Pad the right and bottom of the tile with NODATA values if
the tile is smaller than the specified tile size. Default is `false`.
+
+!!!note
+ If the internal tiling scheme of raster data is not friendly for tiling,
the `raster` data source will throw an error, and you can disable automatic
tiling using `option("retile", "false")`, or specify the tile size manually to
work around this issue. A better solution is to translate the raster data into
COG format using `gdal_translate` or other tools.
+
+The `raster` data source also works with Spark generic file source options,
such as `option("pathGlobFilter", "*.tif*")` and `option("recursiveFileLookup",
"true")`. For instance, you can load all the `.tif` files recursively in a
directory using
+
+```python
+sedona.read.format("raster").option("recursiveFileLookup", "true").option(
+ "pathGlobFilter", "*.tif*"
+).load(path_to_raster_data_folder)
+```
+
+One difference from other file source loaders is that when the loaded path
ends with `/`, the `raster` data source will look up raster files in the
directory and all its subdirectories recursively. This is equivalent to
specifying a path without trailing `/` and setting
`option("recursiveFileLookup", "true")`.
+
+## Loading raster using binaryFile loader (Deprecated)
+
The raster loader of Sedona leverages Spark's built-in binary data source and
works with several RS constructors to produce the Raster type. Each raster is a
row in the resulting DataFrame, stored in a `Raster` format.
!!!tip
@@ -27,7 +95,7 @@ The raster loader of Sedona leverages Spark built-in binary
data source and work
By default, these functions use lon/lat order since `v1.5.0`. Before, they used
lat/lon order.
-## Step 1: Load raster to a binary DataFrame
+### Step 1: Load raster to a binary DataFrame
You can load any type of raster data using the code below. Then use the RS
constructors below to create a Raster DataFrame.
@@ -35,7 +103,7 @@ You can load any type of raster data using the code below.
Then use the RS const
sedona.read.format("binaryFile").load("/some/path/*.asc")
```
-## Step 2: Create a raster type column
+### Step 2: Create a raster type column
### RS_FromArcInfoAsciiGrid
diff --git
a/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
b/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index e420915439..10c362405c 100644
---
a/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++
b/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,4 +1,4 @@
-org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat
+org.apache.spark.sql.sedona_sql.io.raster.RasterDataSource
org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONFileFormat
org.apache.sedona.sql.datasources.spider.SpiderDataSource
org.apache.spark.sql.sedona_sql.io.stac.StacDataSource
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
index 2523e89343..b0409dd1f3 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
@@ -30,6 +30,8 @@ import
org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import
org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer,
RasterInputExpressionEnhancer}
import org.apache.spark.sql.types._
+import scala.collection.JavaConverters._
+
private[apache] case class RS_FromArcInfoAsciiGrid(inputExpressions:
Seq[Expression])
extends InferredExpression(RasterConstructors.fromArcInfoAsciiGrid _) {
override def foldable: Boolean = false
@@ -107,39 +109,38 @@ private[apache] case class RS_TileExplode(children:
Seq[Expression])
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val raster = arguments.rasterExpr.toRaster(input)
- try {
- val bandIndices =
arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match {
- case null => null
- case value: Any => value.toIntArray
- }
- val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int]
- val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int]
- val padWithNoDataValue =
arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean]
- val noDataValue = arguments.noDataValExpr.eval(input) match {
- case null => Double.NaN
- case value: Integer => value.toDouble
- case value: Decimal => value.toDouble
- case value: Float => value.toDouble
- case value: Double => value
- case value: Any =>
- throw new IllegalArgumentException(
- "Unsupported class for noDataValue: " + value.getClass)
- }
- val tiles = RasterConstructors.generateTiles(
- raster,
- bandIndices,
- tileWidth,
- tileHeight,
- padWithNoDataValue,
- noDataValue)
- tiles.map { tile =>
- val gridCoverage2D = tile.getCoverage
- val row = InternalRow(tile.getTileX, tile.getTileY,
gridCoverage2D.serialize)
- gridCoverage2D.dispose(true)
- row
- }
- } finally {
- raster.dispose(true)
+ if (raster == null) {
+ return Iterator.empty
+ }
+ val bandIndices =
arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match {
+ case null => null
+ case value: Any => value.toIntArray
+ }
+ val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int]
+ val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int]
+ val padWithNoDataValue =
arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean]
+ val noDataValue = arguments.noDataValExpr.eval(input) match {
+ case null => Double.NaN
+ case value: Integer => value.toDouble
+ case value: Decimal => value.toDouble
+ case value: Float => value.toDouble
+ case value: Double => value
+ case value: Any =>
+ throw new IllegalArgumentException("Unsupported class for noDataValue:
" + value.getClass)
+ }
+ val tileIterator = RasterConstructors.generateTiles(
+ raster,
+ bandIndices,
+ tileWidth,
+ tileHeight,
+ padWithNoDataValue,
+ noDataValue)
+ tileIterator.setAutoDisposeSource(true)
+ tileIterator.asScala.map { tile =>
+ val gridCoverage2D = tile.getCoverage
+ val row = InternalRow(tile.getTileX, tile.getTileY,
gridCoverage2D.serialize)
+ gridCoverage2D.dispose(true)
+ row
}
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala
new file mode 100644
index 0000000000..2141debda3
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.TableProvider
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import scala.collection.JavaConverters._
+
+/**
+ * A Spark SQL data source for reading and writing raster images. This data
source supports
+ * reading various raster formats as in-db rasters. Write support is
implemented by falling back
+ * to the V1 data source [[RasterFileFormat]].
+ */
+class RasterDataSource extends FileDataSourceV2 with TableProvider with
DataSourceRegister {
+
+ override def shortName(): String = "raster"
+
+ private def createRasterTable(
+ options: CaseInsensitiveStringMap,
+ userSchema: Option[StructType] = None): Table = {
+ var paths = getPaths(options)
+ var optionsWithoutPaths = getOptionsWithoutPaths(options)
+ val tableName = getTableName(options, paths)
+ val rasterOptions = new RasterOptions(optionsWithoutPaths.asScala.toMap)
+
+ if (paths.size == 1) {
+ if (paths.head.endsWith("/")) {
+ // Paths ending with / will be recursively loaded
+ val newOptions =
+ new java.util.HashMap[String,
String](optionsWithoutPaths.asCaseSensitiveMap())
+ newOptions.put("recursiveFileLookup", "true")
+ if (!newOptions.containsKey("pathGlobFilter")) {
+ newOptions.put("pathGlobFilter", "*.{tif,tiff,TIF,TIFF}")
+ }
+ optionsWithoutPaths = new CaseInsensitiveStringMap(newOptions)
+ } else {
+ // Rewrite paths such as /path/to/some*glob*.tif into /path/to with
+ // pathGlobFilter="some*glob*.tif". This is for avoiding listing .tif
+ // files as directories when discovering files to load. Globs ending with
+ // .tif or .tiff should be files in the context of raster data loading.
+ val loadTifPattern = "(.*)/([^/]*\\*[^/]*\\.(?i:tif|tiff))$".r
+ paths.head match {
+ case loadTifPattern(prefix, glob) =>
+ paths = Seq(prefix)
+ val newOptions =
+ new java.util.HashMap[String,
String](optionsWithoutPaths.asCaseSensitiveMap())
+ newOptions.put("pathGlobFilter", glob)
+ optionsWithoutPaths = new CaseInsensitiveStringMap(newOptions)
+ case _ =>
+ }
+ }
+ }
+
+ new RasterTable(
+ tableName,
+ sparkSession,
+ optionsWithoutPaths,
+ paths,
+ userSchema,
+ rasterOptions,
+ fallbackFileFormat)
+ }
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ createRasterTable(options)
+ }
+
+ override def getTable(options: CaseInsensitiveStringMap, schema:
StructType): Table = {
+ createRasterTable(options, Some(schema))
+ }
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+ val paths = getPaths(options)
+ if (paths.isEmpty) {
+ throw new IllegalArgumentException("No paths specified for raster data
source")
+ }
+
+ val rasterOptions = new RasterOptions(options.asScala.toMap)
+ RasterTable.inferSchema(rasterOptions)
+ }
+
+ override def fallbackFileFormat: Class[_ <: FileFormat] =
classOf[RasterFileFormat]
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala
similarity index 54%
copy from
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
copy to
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala
index 4f1e501f04..c5a1ac2a69 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala
@@ -18,18 +18,10 @@
*/
package org.apache.spark.sql.sedona_sql.io.raster
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.Partition
+import org.apache.spark.sql.execution.datasources.PartitionedFile
-private[io] class RasterOptions(@transient private val parameters:
CaseInsensitiveMap[String])
- extends Serializable {
- def this(parameters: Map[String, String]) =
this(CaseInsensitiveMap(parameters))
-
- // The file format of the raster image
- val fileExtension = parameters.getOrElse("fileExtension", ".tiff")
- // Column of the raster image name
- val rasterPathField = parameters.get("pathField")
- // Column of the raster image itself
- val rasterField = parameters.get("rasterField")
- // Use direct committer to directly write to the final destination
- val useDirectCommitter = parameters.getOrElse("useDirectCommitter",
"true").toBoolean
-}
+case class RasterInputPartition(index: Int, files: Array[PartitionedFile])
+ extends Partition
+ with InputPartition
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
index 4f1e501f04..7c806bcd45 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
@@ -20,16 +20,54 @@ package org.apache.spark.sql.sedona_sql.io.raster
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-private[io] class RasterOptions(@transient private val parameters:
CaseInsensitiveMap[String])
+class RasterOptions(@transient private val parameters:
CaseInsensitiveMap[String])
extends Serializable {
def this(parameters: Map[String, String]) =
this(CaseInsensitiveMap(parameters))
+ // The following options are used to read raster data
+
+ /**
+ * Whether to retile the raster data. If true, the raster data will be
retiled into smaller
+ * tiles. If false, the raster data will be read as a single tile.
+ */
+ val retile: Boolean = parameters.getOrElse("retile", "true").toBoolean
+
+ /**
+ * The width of the tile. This is only effective when retile is true. If
retile is true and
+ * tileWidth is not set, the default value is the width of the internal
tiles in the raster
+ * files. Each raster file may have different internal tile sizes.
+ */
+ val tileWidth: Option[Int] = parameters.get("tileWidth").map(_.toInt)
+
+ /**
+ * The height of the tile. This is only effective when retile is true. If
retile is true and
+ * tileHeight is not set, the default value is the same as tileWidth. If
tileHeight is set,
+ * tileWidth must be set as well.
+ */
+ val tileHeight: Option[Int] = parameters
+ .get("tileHeight")
+ .map { value =>
+ if (retile) {
+ require(tileWidth.isDefined, "tileWidth must be set when tileHeight is
set")
+ }
+ value.toInt
+ }
+ .orElse(tileWidth)
+
+ /**
+ * Whether to pad the right and bottom of the tile with NoData values if the
tile is smaller
+ * than the specified tile size. Default is `false`.
+ */
+ val padWithNoData: Boolean = parameters.getOrElse("padWithNoData",
"false").toBoolean
+
+ // The following options are used to write raster data
+
// The file format of the raster image
- val fileExtension = parameters.getOrElse("fileExtension", ".tiff")
+ val fileExtension: String = parameters.getOrElse("fileExtension", ".tiff")
// Column of the raster image name
- val rasterPathField = parameters.get("pathField")
+ val rasterPathField: Option[String] = parameters.get("pathField")
// Column of the raster image itself
- val rasterField = parameters.get("rasterField")
+ val rasterField: Option[String] = parameters.get("rasterField")
// Use direct committer to directly write to the final destination
- val useDirectCommitter = parameters.getOrElse("useDirectCommitter",
"true").toBoolean
+ val useDirectCommitter: Boolean = parameters.getOrElse("useDirectCommitter",
"true").toBoolean
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala
new file mode 100644
index 0000000000..0022f7c05f
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.sedona.common.raster.RasterConstructors
+import org.apache.sedona.common.raster.inputstream.HadoopImageInputStream
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
+import
org.apache.spark.sql.sedona_sql.io.raster.RasterPartitionReader.rasterToInternalRows
+import
org.apache.spark.sql.sedona_sql.io.raster.RasterTable.{MAX_AUTO_TILE_SIZE,
RASTER, RASTER_NAME, TILE_X, TILE_Y}
+import org.apache.spark.sql.types.StructType
+import org.geotools.coverage.grid.GridCoverage2D
+
+import java.net.URI
+import scala.collection.JavaConverters._
+
+class RasterPartitionReader(
+ configuration: Configuration,
+ partitionedFiles: Array[PartitionedFile],
+ dataSchema: StructType,
+ rasterOptions: RasterOptions)
+ extends PartitionReader[InternalRow] {
+
+ // Track the current file index we're processing
+ private var currentFileIndex = 0
+
+ // Current raster being processed
+ private var currentRaster: GridCoverage2D = _
+
+ // Current image input stream (must be kept open while the raster is in use)
+ private var currentImageStream: HadoopImageInputStream = _
+
+ // Current row
+ private var currentRow: InternalRow = _
+
+ // Iterator for the current file's tiles
+ private var currentIterator: Iterator[InternalRow] = Iterator.empty
+
+ override def next(): Boolean = {
+ // If current iterator has more elements, return true
+ if (currentIterator.hasNext) {
+ currentRow = currentIterator.next()
+ return true
+ }
+
+ // If current iterator is exhausted, but we have more files, load the next
file
+ if (currentFileIndex < partitionedFiles.length) {
+ loadNextFile()
+ if (currentIterator.hasNext) {
+ currentRow = currentIterator.next()
+ return true
+ }
+ }
+
+ // No more data
+ false
+ }
+
+ override def get(): InternalRow = {
+ currentRow
+ }
+
+ override def close(): Unit = {
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+ }
+
+ private def loadNextFile(): Unit = {
+ // Clean up previous raster and stream if exists
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+
+ if (currentFileIndex >= partitionedFiles.length) {
+ currentIterator = Iterator.empty
+ return
+ }
+
+ val partition = partitionedFiles(currentFileIndex)
+ val path = new Path(new URI(partition.filePath.toString()))
+
+ try {
+ // Open a stream-based reader instead of materializing the entire file
as byte[].
+ // This avoids the 2 GB byte[] limit and reduces memory pressure for
large files.
+ currentImageStream = new HadoopImageInputStream(path, configuration)
+
+ // Create in-db GridCoverage2D from GeoTiff stream. The RenderedImage is
lazy -
+ // pixel data will only be decoded when accessed via
image.getData(Rectangle).
+ currentRaster = RasterConstructors.fromGeoTiff(currentImageStream)
+ currentIterator = rasterToInternalRows(currentRaster, dataSchema,
rasterOptions, path)
+ currentFileIndex += 1
+ } catch {
+ case e: Exception =>
+ if (currentRaster != null) {
+ currentRaster.dispose(true)
+ currentRaster = null
+ }
+ if (currentImageStream != null) {
+ currentImageStream.close()
+ currentImageStream = null
+ }
+ throw e
+ }
+ }
+}
+
+object RasterPartitionReader {
+ def rasterToInternalRows(
+ currentRaster: GridCoverage2D,
+ dataSchema: StructType,
+ rasterOptions: RasterOptions,
+ path: Path): Iterator[InternalRow] = {
+ val retile = rasterOptions.retile
+ val tileWidth = rasterOptions.tileWidth
+ val tileHeight = rasterOptions.tileHeight
+ val padWithNoData = rasterOptions.padWithNoData
+
+ val writer = new UnsafeRowWriter(dataSchema.length)
+ writer.resetRowWriter()
+
+ // Extract the file name from the path
+ val fileName = path.getName
+
+ if (retile) {
+ val (tw, th) = (tileWidth, tileHeight) match {
+ case (Some(tw), Some(th)) => (tw, th)
+ case (None, None) =>
+ // Use the internal tile size of the input raster
+ val tw = currentRaster.getRenderedImage.getTileWidth
+ val th = currentRaster.getRenderedImage.getTileHeight
+ val tileSizeError = {
+ """To resolve this issue, you can try one of the following methods:
+ | 1. Disable retile by setting `.option("retile", "false")`.
+ | 2. Explicitly set `tileWidth` and `tileHeight`.
+ | 3. Convert the raster to a Cloud Optimized GeoTIFF (COG) using
tools like `gdal_translate`.
+ |""".stripMargin
+ }
+ if (tw >= MAX_AUTO_TILE_SIZE || th >= MAX_AUTO_TILE_SIZE) {
+ throw new IllegalArgumentException(
+ s"Internal tile size of $path is too large ($tw x $th). " +
tileSizeError)
+ }
+ if (tw == 0 || th == 0) {
+ throw new IllegalArgumentException(
+ s"Internal tile size of $path contains zero ($tw x $th). " +
tileSizeError)
+ }
+ if (tw.toDouble / th > 10.0 || th.toDouble / tw > 10.0) {
+ throw new IllegalArgumentException(
+ s"Internal tile shape of $path is too thin ($tw x $th). " +
tileSizeError)
+ }
+ (tw, th)
+ case _ =>
+ throw new IllegalArgumentException("Both tileWidth and tileHeight
must be set")
+ }
+
+ val iter =
+ RasterConstructors.generateTiles(currentRaster, null, tw, th,
padWithNoData, Double.NaN)
+ iter.asScala.map { tile =>
+ val tileRaster = tile.getCoverage
+ writer.reset()
+ writeRaster(writer, dataSchema, tileRaster, tile.getTileX,
tile.getTileY, fileName)
+ tileRaster.dispose(true)
+ writer.getRow
+ }
+ } else {
+ writeRaster(writer, dataSchema, currentRaster, 0, 0, fileName)
+ Iterator.single(writer.getRow)
+ }
+ }
+
+ private def writeRaster(
+ writer: UnsafeRowWriter,
+ dataSchema: StructType,
+ raster: GridCoverage2D,
+ x: Int,
+ y: Int,
+ fileName: String): Unit = {
+ dataSchema.fieldNames.zipWithIndex.foreach {
+ case (RASTER, i) => writer.write(i, RasterUDT.serialize(raster))
+ case (TILE_X, i) => writer.write(i, x)
+ case (TILE_Y, i) => writer.write(i, y)
+ case (RASTER_NAME, i) =>
+ if (fileName != null)
+ writer.write(i,
org.apache.spark.unsafe.types.UTF8String.fromString(fileName))
+ else writer.setNullAt(i)
+ case (other, _) =>
+ throw new IllegalArgumentException(s"Unsupported field name: $other")
+ }
+ }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala
new file mode 100644
index 0000000000..1c5bc2491f
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.connector.read.PartitionReaderFactory
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import
org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
/**
 * Creates per-partition readers for the raster data source on executors.
 *
 * Each [[RasterInputPartition]] carries a list of files; the factory wraps the
 * raw [[RasterPartitionReader]] so that partition-column values are appended
 * to every produced row.
 */
case class RasterPartitionReaderFactory(
    broadcastedConf: Broadcast[SerializableConfiguration],
    dataSchema: StructType,
    readDataSchema: StructType,
    partitionSchema: StructType,
    rasterOptions: RasterOptions,
    filters: Seq[Filter])
    extends PartitionReaderFactory {

  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    partition match {
      case rasterPartition: RasterInputPartition => readerFor(rasterPartition.files)
      case other =>
        throw new IllegalArgumentException(
          s"Unexpected partition type: ${other.getClass.getCanonicalName}")
    }

  // Builds the file reader and decorates it with the partition-column values of
  // the partition's files (all files in a partition share the same values).
  private def readerFor(files: Array[PartitionedFile]): PartitionReader[InternalRow] = {
    val fileReader =
      new RasterPartitionReader(broadcastedConf.value.value, files, readDataSchema, rasterOptions)
    new PartitionReaderWithPartitionValues(
      fileReader,
      readDataSchema,
      partitionSchema,
      files.head.partitionValues)
  }
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala
new file mode 100644
index 0000000000..ca83f40d2e
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.read.Batch
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.connector.read.PartitionReaderFactory
+import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.connector.read.SupportsPushDownLimit
+import org.apache.spark.sql.connector.read.SupportsPushDownTableSample
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.execution.datasources.FilePartition
+import org.apache.spark.sql.execution.datasources.v2.FileScan
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
+
/**
 * Scan builder for the raster data source. Besides the standard file-scan
 * capabilities it accepts LIMIT and TABLESAMPLE push-down, which [[RasterScan]]
 * then applies at file granularity.
 */
case class RasterScanBuilder(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    schema: StructType,
    dataSchema: StructType,
    options: CaseInsensitiveStringMap,
    rasterOptions: RasterOptions)
    extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
    with SupportsPushDownTableSample
    with SupportsPushDownLimit {

  // Pushed-down operators captured here and handed to the scan in build().
  private var pushedTableSample: Option[TableSampleInfo] = None
  private var pushedLimit: Option[Int] = None

  override def build(): Scan =
    RasterScan(
      sparkSession,
      fileIndex,
      dataSchema,
      readDataSchema(),
      readPartitionSchema(),
      options,
      rasterOptions,
      pushedDataFilters,
      partitionFilters,
      dataFilters,
      pushedTableSample,
      pushedLimit)

  override def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long): Boolean = {
    // Sampling with replacement cannot be expressed as a per-file keep/drop
    // decision, and in retile mode one file expands into many rows, so file-level
    // sampling would be incorrect; refuse the push-down in both cases.
    val supported = !withReplacement && !rasterOptions.retile
    if (supported) {
      pushedTableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed))
    }
    supported
  }

  override def pushLimit(limit: Int): Boolean = {
    pushedLimit = Some(limit)
    true
  }

  // In retile mode a file-level limit is only a partial limit: each kept file may
  // still emit multiple tile rows, so Spark must re-apply the limit afterwards.
  override def isPartiallyPushed: Boolean = rasterOptions.retile
}
+
/**
 * V2 scan over raster files. Pushed-down TABLESAMPLE and LIMIT operators are
 * applied here at file granularity, before partitions are shipped to the
 * executors, so sampled-out or past-the-limit files are never opened.
 */
case class RasterScan(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    dataSchema: StructType,
    readDataSchema: StructType,
    readPartitionSchema: StructType,
    options: CaseInsensitiveStringMap,
    rasterOptions: RasterOptions,
    pushedFilters: Array[Filter],
    partitionFilters: Seq[Expression] = Seq.empty,
    dataFilters: Seq[Expression] = Seq.empty,
    pushedTableSample: Option[TableSampleInfo] = None,
    pushedLimit: Option[Int] = None)
    extends FileScan
    with Batch {

  // File partitions with the pushed-down sample and limit applied; computed once.
  private lazy val inputPartitions: Array[InputPartition] = {
    var partitions = super.planInputPartitions()
    pushedTableSample.foreach { sample =>
      partitions = sampleFiles(partitions, sample)
    }
    pushedLimit.foreach { limit =>
      partitions = limitFiles(partitions, limit)
    }
    partitions
  }

  // Bernoulli-samples individual files using the sample's seed. Partitions that
  // end up empty are dropped and the survivors are renumbered from zero.
  private def sampleFiles(
      partitions: Array[InputPartition],
      sample: TableSampleInfo): Array[InputPartition] = {
    val rng = new Random(sample.seed)
    var nextIndex = 0
    partitions.flatMap {
      case filePartition: FilePartition =>
        // Draw exactly once per file, in order, so a given seed is deterministic.
        val kept = filePartition.files.filter { _ =>
          val draw = rng.nextDouble()
          draw >= sample.lowerBound && draw < sample.upperBound
        }
        if (kept.isEmpty) {
          None
        } else {
          val renumbered = FilePartition(nextIndex, kept)
          nextIndex += 1
          Some(renumbered)
        }
      case other =>
        throw new IllegalArgumentException(
          s"Unexpected partition type: ${other.getClass.getCanonicalName}")
    }
  }

  // Keeps files from the front until `limit` files have been taken; the partition
  // holding the boundary is truncated and everything after it is dropped.
  private def limitFiles(
      partitions: Array[InputPartition],
      limit: Int): Array[InputPartition] = {
    val kept = mutable.ArrayBuffer.empty[InputPartition]
    var remaining = limit
    var i = 0
    while (remaining > 0 && i < partitions.length) {
      val filePartition = partitions(i).asInstanceOf[FilePartition]
      val files = filePartition.files
      if (files.length <= remaining) {
        remaining -= files.length
        kept += filePartition
      } else {
        kept += FilePartition(filePartition.index, files.take(remaining))
        remaining = 0
      }
      i += 1
    }
    kept.toArray
  }

  override def planInputPartitions(): Array[InputPartition] =
    inputPartitions.map {
      case filePartition: FilePartition =>
        RasterInputPartition(filePartition.index, filePartition.files)
      case other =>
        throw new IllegalArgumentException(
          s"Unexpected partition type: ${other.getClass.getCanonicalName}")
    }

  override def createReaderFactory(): PartitionReaderFactory = {
    val hadoopConf =
      sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)
    val serializableConf = new SerializableConfiguration(hadoopConf)
    RasterPartitionReaderFactory(
      sparkSession.sparkContext.broadcast(serializableConf),
      dataSchema,
      readDataSchema,
      readPartitionSchema,
      rasterOptions,
      pushedFilters)
  }
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala
new file mode 100644
index 0000000000..2a1eb6cdb3
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog.SupportsRead
+import org.apache.spark.sql.connector.catalog.SupportsWrite
+import org.apache.spark.sql.connector.catalog.TableCapability
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.LogicalWriteInfo
+import org.apache.spark.sql.connector.write.WriteBuilder
+import org.apache.spark.sql.execution.datasources.v2.FileTable
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util.{Set => JSet}
+
/**
 * DataSource V2 table for the "raster" format. Only batch reads are supported;
 * the schema is fixed by the reader options rather than inferred from file
 * contents.
 */
case class RasterTable(
    name: String,
    sparkSession: SparkSession,
    options: CaseInsensitiveStringMap,
    paths: Seq[String],
    userSpecifiedSchema: Option[StructType],
    rasterOptions: RasterOptions,
    fallbackFileFormat: Class[_ <: FileFormat])
    extends FileTable(sparkSession, options, paths, userSpecifiedSchema)
    with SupportsRead
    with SupportsWrite {

  // The file list is ignored: the schema depends only on the reader options
  // (or on an explicit user-specified schema).
  override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
    val resolved = userSpecifiedSchema.getOrElse(RasterTable.inferSchema(rasterOptions))
    Some(resolved)
  }

  override def formatName: String = "Raster"

  override def capabilities(): JSet[TableCapability] =
    java.util.EnumSet.of(TableCapability.BATCH_READ)

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
    RasterScanBuilder(sparkSession, fileIndex, schema, dataSchema, options, rasterOptions)

  // Note: this code path will never be taken, since Spark will always fall back
  // to V1 data source when writing to File source v2. See SPARK-28396: File
  // source v2 write path is currently broken.
  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null
}
+
object RasterTable {
  // Names of the fields in the read schema.
  /** Column holding the raster payload. */
  val RASTER = "rast"
  /** Column holding the tile's X index (present only when retiling). */
  val TILE_X = "x"
  /** Column holding the tile's Y index (present only when retiling). */
  val TILE_Y = "y"
  /** Column holding the source file name (nullable). */
  val RASTER_NAME = "name"

  // NOTE(review): presumably the cap applied when a tile size is chosen
  // automatically — the constant is consumed outside this file.
  val MAX_AUTO_TILE_SIZE = 4096

  /**
   * Derives the read schema from the reader options. With retiling enabled the
   * schema is (rast, x, y, name); otherwise it is just (rast, name). The name
   * column is always last and is the only nullable field.
   */
  def inferSchema(options: RasterOptions): StructType = {
    val rasterField = StructField(RASTER, RasterUDT(), nullable = false)
    val tileFields =
      if (options.retile)
        Seq(
          StructField(TILE_X, IntegerType, nullable = false),
          StructField(TILE_Y, IntegerType, nullable = false))
      else Seq.empty
    val nameField =
      StructField(RASTER_NAME, org.apache.spark.sql.types.StringType, nullable = true)
    StructType(rasterField +: tileFields :+ nameField)
  }
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 8b36f32272..e0128ce3f1 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -21,6 +21,15 @@ package org.apache.sedona.sql
import org.apache.commons.io.FileUtils
import org.apache.hadoop.hdfs.MiniDFSCluster
import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.LimitExec
+import org.apache.spark.sql.execution.SampleExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.sedona_sql.io.raster.RasterTable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.expr
+import org.geotools.coverage.grid.GridCoverage2D
import org.junit.Assert.assertEquals
import org.scalatest.{BeforeAndAfter, GivenWhenThen}
@@ -239,5 +248,282 @@ class rasterIOTest extends TestBaseScala with
BeforeAndAfter with GivenWhenThen
}
}
- override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir))
  describe("Raster read test") {
    it("should read geotiff using raster source with explicit tiling") {
      // tileWidth without tileHeight: square 64x64 tiles (edge tiles may be smaller).
      val rasterDf = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "true", "tileWidth" -> "64"))
        .load(rasterdatalocation)
      assert(rasterDf.count() > 100)
      rasterDf.collect().foreach { row =>
        val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D]
        // Without padding, tiles are at most 64x64 and carry non-negative indices.
        assert(raster.getGridGeometry.getGridRange2D.width <= 64)
        assert(raster.getGridGeometry.getGridRange2D.height <= 64)
        val x = row.getInt(1)
        val y = row.getInt(2)
        assert(x >= 0 && y >= 0)
        raster.dispose(true)
      }

      // Test projection push-down
      rasterDf.selectExpr("y", "rast as r").collect().foreach { row =>
        val raster = row.getAs[Object](1).asInstanceOf[GridCoverage2D]
        assert(raster.getGridGeometry.getGridRange2D.width <= 64)
        assert(raster.getGridGeometry.getGridRange2D.height <= 64)
        val y = row.getInt(0)
        assert(y >= 0)
        raster.dispose(true)
      }
    }

    it("should tile geotiff using raster source with padding enabled") {
      val rasterDf = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "true", "tileWidth" -> "64", "padWithNoData" -> "true"))
        .load(rasterdatalocation)
      assert(rasterDf.count() > 100)
      rasterDf.collect().foreach { row =>
        val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D]
        // With padding, even edge tiles are exactly 64x64.
        assert(raster.getGridGeometry.getGridRange2D.width == 64)
        assert(raster.getGridGeometry.getGridRange2D.height == 64)
        val x = row.getInt(1)
        val y = row.getInt(2)
        assert(x >= 0 && y >= 0)
        raster.dispose(true)
      }
    }

    it("should push down limit and sample to data source") {
      FileUtils.cleanDirectory(new File(tempDir))

      // Duplicate each source tiff 4 times so sampling/limit have files to drop.
      val sourceDir = new File(rasterdatalocation)
      val files = sourceDir.listFiles().filter(_.isFile)
      var numUniqueFiles = 0
      var numTotalFiles = 0
      files.foreach { file =>
        if (file.getPath.endsWith(".tif") || file.getPath.endsWith(".tiff")) {
          // Create 4 copies for each file
          for (i <- 0 until 4) {
            val destFile = new File(tempDir + "/" + file.getName + "_" + i)
            FileUtils.copyFile(file, destFile)
            numTotalFiles += 1
          }
          numUniqueFiles += 1
        }
      }

      val df = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "false"))
        .load(tempDir)
        .withColumn("width", expr("RS_Width(rast)"))

      val dfWithLimit = df.limit(numUniqueFiles)
      val plan = queryPlan(dfWithLimit)
      // Global/local limits are all pushed down to data source
      assert(plan.collect { case e: LimitExec => e }.isEmpty)
      assert(dfWithLimit.count() == numUniqueFiles)

      val dfWithSample = df.sample(0.3, seed = 42)
      val planSample = queryPlan(dfWithSample)
      // Sample is pushed down to data source
      assert(planSample.collect { case e: SampleExec => e }.isEmpty)
      // Seeded Bernoulli sampling at 0.3 should land well within [0.1, 0.5].
      val count = dfWithSample.count()
      assert(count >= numTotalFiles * 0.1 && count <= numTotalFiles * 0.5)

      val dfWithSampleAndLimit = df.sample(0.5, seed = 42).limit(numUniqueFiles)
      val planBoth = queryPlan(dfWithSampleAndLimit)
      assert(planBoth.collect { case e: LimitExec => e }.isEmpty)
      assert(planBoth.collect { case e: SampleExec => e }.isEmpty)
      assert(dfWithSampleAndLimit.count() == numUniqueFiles)

      // Limit and sample cannot be fully pushed down when retile is enabled
      val dfReTiledWithSampleAndLimit = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "true"))
        .load(tempDir)
        .sample(0.5, seed = 42)
        .limit(numUniqueFiles)
      val planRetiled = queryPlan(dfReTiledWithSampleAndLimit)
      // Spark must keep its own Limit/Sample operators in retile mode.
      assert(planRetiled.collect { case e: LimitExec => e }.nonEmpty)
      assert(planRetiled.collect { case e: SampleExec => e }.nonEmpty)
    }

    it("should read geotiff using raster source without tiling") {
      val rasterDf = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "false"))
        .load(rasterdatalocation)
      // Without retiling the schema is just (rast, name).
      assert(rasterDf.schema.fields.length == 2)
      rasterDf.collect().foreach { row =>
        val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D]
        assert(raster != null)
        raster.dispose(true)
        // Should load name correctly
        val name = row.getString(1)
        assert(name != null)
      }
    }

    it("should read geotiff using raster source with auto-tiling") {
      // Auto-tiling (retile without an explicit size) must produce more rows
      // than reading each file as a single raster.
      val rasterDf = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "true"))
        .load(rasterdatalocation)
      val rasterDfNoTiling = sparkSession.read
        .format("raster")
        .options(Map("retile" -> "false"))
        .load(rasterdatalocation)
      assert(rasterDf.count() > rasterDfNoTiling.count())
    }

    it("should throw exception when only tileHeight is specified") {
      // tileWidth and tileHeight must be set together (or both omitted).
      assertThrows[IllegalArgumentException] {
        val df = sparkSession.read
          .format("raster")
          .options(Map("retile" -> "true", "tileHeight" -> "64"))
          .load(rasterdatalocation)
        df.collect()
      }
    }

    it("should throw exception when the geotiff is badly tiled") {
      val exception = intercept[Exception] {
        val rasterDf = sparkSession.read
          .format("raster")
          .options(Map("retile" -> "true"))
          .load(resourceFolder + "raster_geotiff_color/*")
        rasterDf.collect()
      }
      // The error message must point users at remediation options.
      assert(
        exception.getMessage.contains(
          "To resolve this issue, you can try one of the following methods"))
    }

    it("read partitioned directory") {
      // Hive-style part=N directories must surface as a "part" column.
      FileUtils.cleanDirectory(new File(tempDir))
      Files.createDirectory(new File(tempDir + "/part=1").toPath)
      Files.createDirectory(new File(tempDir + "/part=2").toPath)
      FileUtils.copyFile(
        new File(resourceFolder + "raster/test1.tiff"),
        new File(tempDir + "/part=1/test1.tiff"))
      FileUtils.copyFile(
        new File(resourceFolder + "raster/test2.tiff"),
        new File(tempDir + "/part=1/test2.tiff"))
      FileUtils.copyFile(
        new File(resourceFolder + "raster/test4.tiff"),
        new File(tempDir + "/part=2/test4.tiff"))
      FileUtils.copyFile(
        new File(resourceFolder + "raster/test4.tiff"),
        new File(tempDir + "/part=2/test5.tiff"))

      val rasterDf = sparkSession.read
        .format("raster")
        .load(tempDir)
      val rows = rasterDf.collect()
      assert(rows.length >= 4)
      rows.foreach { row =>
        val name = row.getAs[String]("name")
        if (name.startsWith("test1") || name.startsWith("test2")) {
          assert(row.getAs[Int]("part") == 1)
        } else {
          assert(row.getAs[Int]("part") == 2)
        }
      }
    }

    it("read directory recursively from a temp directory with subdirectories") {
      // Create temp subdirectories in tempDir
      FileUtils.cleanDirectory(new File(tempDir))
      val subDir1 = tempDir + "/subdir1"
      val subDir2 = tempDir + "/nested/subdir2"
      new File(subDir1).mkdirs()
      new File(subDir2).mkdirs()

      // Copy raster files from resourceFolder/raster to the temp subdirectories,
      // round-robin across the root and the two nested directories.
      val sourceDir = new File(resourceFolder + "raster")
      val files = sourceDir.listFiles().filter(_.isFile)
      files.zipWithIndex.foreach { case (file, idx) =>
        idx % 3 match {
          case 0 => FileUtils.copyFile(file, new File(tempDir, file.getName))
          case 1 => FileUtils.copyFile(file, new File(subDir1, file.getName))
          case 2 => FileUtils.copyFile(file, new File(subDir2, file.getName))
        }
      }

      val rasterDfNonRecursive = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(sourceDir.getPath)

      val rasterDfRecursive = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(tempDir + "/")

      // A recursive read of the scattered copies must find every original file.
      val rowsNonRecursive = rasterDfNonRecursive.collect()
      val rowsRecursive = rasterDfRecursive.collect()
      assert(rowsRecursive.length == rowsNonRecursive.length)
    }

    it("read directory suffixed by /*.tif") {
      // Glob suffixes must be split off the path and become pathGlobFilter.
      val df = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(resourceFolder + "raster")

      val dfTif = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(resourceFolder + "raster/*.tif")

      val dfTiff = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(resourceFolder + "raster/*.tiff")

      assert(df.count() == dfTif.count() + dfTiff.count())
      queryPlan(dfTif).collect { case scan: BatchScanExec => scan }.foreach { scan =>
        val table = scan.table.asInstanceOf[RasterTable]
        assert(!table.paths.head.endsWith("*.tif"))
        assert(table.options.get("pathGlobFilter") == "*.tif")
      }
      queryPlan(dfTiff).collect { case scan: BatchScanExec => scan }.foreach { scan =>
        val table = scan.table.asInstanceOf[RasterTable]
        assert(!table.paths.head.endsWith("*.tiff"))
        assert(table.options.get("pathGlobFilter") == "*.tiff")
      }

      // Globs with a non-wildcard prefix/suffix must be preserved verbatim.
      var dfComplexGlob = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(resourceFolder + "raster/test*.tiff")
      queryPlan(dfComplexGlob).collect { case scan: BatchScanExec => scan }.foreach { scan =>
        val table = scan.table.asInstanceOf[RasterTable]
        assert(!table.paths.head.endsWith("*.tiff"))
        assert(table.options.get("pathGlobFilter") == "test*.tiff")
      }
      dfComplexGlob = sparkSession.read
        .format("raster")
        .option("retile", "false")
        .load(resourceFolder + "raster/*1.tiff")
      queryPlan(dfComplexGlob).collect { case scan: BatchScanExec => scan }.foreach { scan =>
        val table = scan.table.asInstanceOf[RasterTable]
        assert(!table.paths.head.endsWith("*1.tiff"))
        assert(table.options.get("pathGlobFilter") == "*1.tiff")
      }
    }
  }
+
  // Remove the scratch directory used by the tests, then let the base suite
  // run its own teardown.
  override def afterAll(): Unit = {
    FileUtils.deleteDirectory(new File(tempDir))
    super.afterAll()
  }
+
+ private def queryPlan(df: DataFrame): SparkPlan = {
+ df.queryExecution.executedPlan match {
+ case adaptive: AdaptiveSparkPlanExec => adaptive.initialPlan
+ case plan: SparkPlan => plan
+ }
+ }
}