This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch port-raster-datasource in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 7b0e1bc59231260615132d2b05118c0dce71612c Author: Jia Yu <[email protected]> AuthorDate: Mon Feb 23 23:44:08 2026 -0800 Port in-db raster data source from enterprise to OSS - Add TileGenerator.java with lazy InDbTileIterator for memory-efficient tiling - Update RasterConstructors.generateTiles() to return lazy TileIterator - Update RS_TileExplode to use lazy iterator with autoDispose - Add Spark DataSourceV2 raster reader: RasterDataSource, RasterTable, RasterScanBuilder, RasterInputPartition, RasterPartitionReaderFactory, RasterPartitionReader - Support GeoTiff, AsciiGrid, and NetCDF format detection by extension - Add read options: retile, tileWidth, tileHeight, padWithNoData - Support limit/sample pushdown, glob path rewriting, recursive directory loading - Register V2 RasterDataSource (remove V1 RasterFileFormat from META-INF) - Port 11 read tests from enterprise, adapted for in-db (no out-db references) - All 42 rasterIOTest, 314 rasteralgebraTest, 19 RasterConstructorsTest pass --- .../sedona/common/raster/RasterConstructors.java | 140 +--------- .../apache/sedona/common/raster/TileGenerator.java | 270 +++++++++++++++++++ .../common/raster/RasterConstructorsTest.java | 40 +-- docs/api/sql/Raster-loader.md | 72 ++++- ...org.apache.spark.sql.sources.DataSourceRegister | 2 +- .../expressions/raster/RasterConstructors.scala | 64 +++-- .../sedona_sql/io/raster/RasterDataSource.scala | 105 ++++++++ ...terOptions.scala => RasterInputPartition.scala} | 20 +- .../sql/sedona_sql/io/raster/RasterOptions.scala | 46 +++- .../io/raster/RasterPartitionReader.scala | 213 +++++++++++++++ .../io/raster/RasterPartitionReaderFactory.scala | 65 +++++ .../sedona_sql/io/raster/RasterScanBuilder.scala | 180 +++++++++++++ .../sql/sedona_sql/io/raster/RasterTable.scala | 94 +++++++ .../scala/org/apache/sedona/sql/rasterIOTest.scala | 289 ++++++++++++++++++++- 14 files changed, 1399 insertions(+), 201 deletions(-) diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java index e379b24648..f60f0f4b22 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java @@ -19,8 +19,6 @@ package org.apache.sedona.common.raster; import java.awt.*; -import java.awt.image.Raster; -import java.awt.image.RenderedImage; import java.awt.image.WritableRaster; import java.io.IOException; import java.util.Arrays; @@ -30,16 +28,13 @@ import javax.media.jai.RasterFactory; import org.apache.sedona.common.FunctionsGeoTools; import org.apache.sedona.common.raster.inputstream.ByteArrayImageInputStream; import org.apache.sedona.common.raster.netcdf.NetCdfReader; -import org.apache.sedona.common.utils.ImageUtils; import org.apache.sedona.common.utils.RasterUtils; import org.geotools.api.feature.simple.SimpleFeature; import org.geotools.api.feature.simple.SimpleFeatureType; -import org.geotools.api.metadata.spatial.PixelOrientation; import org.geotools.api.referencing.FactoryException; import org.geotools.api.referencing.crs.CoordinateReferenceSystem; import org.geotools.api.referencing.datum.PixelInCell; import org.geotools.api.referencing.operation.MathTransform; -import org.geotools.coverage.GridSampleDimension; import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.coverage.grid.GridEnvelope2D; import org.geotools.coverage.grid.GridGeometry2D; @@ -560,32 +555,9 @@ public class RasterConstructors { return RasterUtils.create(raster, gridGeometry, null); } - public static class Tile { - private final int tileX; - private final int tileY; - private final GridCoverage2D coverage; - - public Tile(int tileX, int tileY, GridCoverage2D coverage) { - this.tileX = tileX; - this.tileY = tileY; - this.coverage = coverage; - } - - public int getTileX() { - return tileX; - } - - public int 
getTileY() { - return tileY; - } - - public GridCoverage2D getCoverage() { - return coverage; - } - } - /** - * Generate tiles from a grid coverage + * Generate tiles from a grid coverage. Returns a lazy iterator that generates tiles one at a + * time, reading only the necessary pixel data for each tile from the source image. * * @param gridCoverage2D the grid coverage * @param bandIndices the indices of the bands to select (1-based), can be null or empty to @@ -595,9 +567,9 @@ public class RasterConstructors { * @param padWithNoData whether to pad the tiles with no data value * @param padNoDataValue the no data value for padded tiles, only used when padWithNoData is true. * If the value is NaN, the no data value of the original band will be used. - * @return the tiles + * @return a lazy iterator of tiles */ - public static Tile[] generateTiles( + public static TileGenerator.TileIterator generateTiles( GridCoverage2D gridCoverage2D, int[] bandIndices, int tileWidth, @@ -620,102 +592,10 @@ public class RasterConstructors { } } } - return doGenerateTiles( + return TileGenerator.generateInDbTiles( gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData, padNoDataValue); } - /** - * Generate tiles from an in-db grid coverage. The generated tiles are also in-db grid coverages. - * Pixel data will be copied into the tiles. - * - * @param gridCoverage2D the in-db grid coverage - * @param bandIndices the indices of the bands to select (1-based) - * @param tileWidth the width of the tiles - * @param tileHeight the height of the tiles - * @param padWithNoData whether to pad the tiles with no data value - * @param padNoDataValue the no data value for padded tiles, only used when padWithNoData is true. - * If the value is NaN, the no data value of the original band will be used. 
- * @return the tiles - */ - private static Tile[] doGenerateTiles( - GridCoverage2D gridCoverage2D, - int[] bandIndices, - int tileWidth, - int tileHeight, - boolean padWithNoData, - double padNoDataValue) { - AffineTransform2D affine = - RasterUtils.getAffineTransform(gridCoverage2D, PixelOrientation.CENTER); - RenderedImage image = gridCoverage2D.getRenderedImage(); - double[] noDataValues = new double[bandIndices.length]; - for (int i = 0; i < bandIndices.length; i++) { - noDataValues[i] = - RasterUtils.getNoDataValue(gridCoverage2D.getSampleDimension(bandIndices[i] - 1)); - } - int width = image.getWidth(); - int height = image.getHeight(); - int numTileX = (int) Math.ceil((double) width / tileWidth); - int numTileY = (int) Math.ceil((double) height / tileHeight); - Tile[] tiles = new Tile[numTileX * numTileY]; - for (int tileY = 0; tileY < numTileY; tileY++) { - for (int tileX = 0; tileX < numTileX; tileX++) { - int x0 = tileX * tileWidth; - int y0 = tileY * tileHeight; - - // Rect to copy from the original image - int rectWidth = Math.min(tileWidth, width - x0); - int rectHeight = Math.min(tileHeight, height - y0); - - // If we don't pad with no data, the tiles on the boundary may have a different size - int currentTileWidth = padWithNoData ? tileWidth : rectWidth; - int currentTileHeight = padWithNoData ? 
tileHeight : rectHeight; - boolean needPadding = padWithNoData && (rectWidth < tileWidth || rectHeight < tileHeight); - - // Create a new affine transformation for this tile - AffineTransform2D tileAffine = RasterUtils.translateAffineTransform(affine, x0, y0); - GridGeometry2D gridGeometry2D = - new GridGeometry2D( - new GridEnvelope2D(0, 0, currentTileWidth, currentTileHeight), - PixelInCell.CELL_CENTER, - tileAffine, - gridCoverage2D.getCoordinateReferenceSystem(), - null); - - // Prepare a new image for this tile, and copy the data from the original image - WritableRaster raster = - RasterFactory.createBandedRaster( - image.getSampleModel().getDataType(), - currentTileWidth, - currentTileHeight, - bandIndices.length, - null); - GridSampleDimension[] sampleDimensions = new GridSampleDimension[bandIndices.length]; - Raster sourceRaster = image.getData(new Rectangle(x0, y0, rectWidth, rectHeight)); - for (int k = 0; k < bandIndices.length; k++) { - int bandIndex = bandIndices[k] - 1; - - // Copy sample dimensions from source bands, and pad with no data value if necessary - GridSampleDimension sampleDimension = gridCoverage2D.getSampleDimension(bandIndex); - double noDataValue = noDataValues[k]; - if (needPadding && !Double.isNaN(padNoDataValue)) { - sampleDimension = - RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension, padNoDataValue); - noDataValue = padNoDataValue; - } - sampleDimensions[k] = sampleDimension; - - // Copy data from original image to tile image - ImageUtils.copyRasterWithPadding(sourceRaster, bandIndex, raster, k, noDataValue); - } - - GridCoverage2D tile = RasterUtils.create(raster, gridGeometry2D, sampleDimensions); - tiles[tileY * numTileX + tileX] = new Tile(tileX, tileY, tile); - } - } - - return tiles; - } - public static GridCoverage2D[] rsTile( GridCoverage2D gridCoverage2D, int[] bandIndices, @@ -729,12 +609,14 @@ public class RasterConstructors { if (padNoDataValue == null) { padNoDataValue = Double.NaN; } - Tile[] tiles = 
+ TileGenerator.TileIterator tileIterator = generateTiles( gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData, padNoDataValue); - GridCoverage2D[] result = new GridCoverage2D[tiles.length]; - for (int i = 0; i < tiles.length; i++) { - result[i] = tiles[i].getCoverage(); + GridCoverage2D[] result = new GridCoverage2D[tileIterator.getNumTiles()]; + int i = 0; + while (tileIterator.hasNext()) { + TileGenerator.Tile tile = tileIterator.next(); + result[i++] = tile.getCoverage(); } return result; } diff --git a/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java b/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java new file mode 100644 index 0000000000..2e04129d11 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/TileGenerator.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.common.raster; + +import java.awt.Rectangle; +import java.awt.image.Raster; +import java.awt.image.RenderedImage; +import java.awt.image.WritableRaster; +import java.util.Iterator; +import java.util.NoSuchElementException; +import javax.media.jai.RasterFactory; +import org.apache.sedona.common.utils.ImageUtils; +import org.apache.sedona.common.utils.RasterUtils; +import org.geotools.api.metadata.spatial.PixelOrientation; +import org.geotools.api.referencing.datum.PixelInCell; +import org.geotools.coverage.GridSampleDimension; +import org.geotools.coverage.grid.GridCoverage2D; +import org.geotools.coverage.grid.GridEnvelope2D; +import org.geotools.coverage.grid.GridGeometry2D; +import org.geotools.referencing.operation.transform.AffineTransform2D; + +public class TileGenerator { + + public static class Tile { + + private final int tileX; + private final int tileY; + private final GridCoverage2D coverage; + + public Tile(int tileX, int tileY, GridCoverage2D coverage) { + this.tileX = tileX; + this.tileY = tileY; + this.coverage = coverage; + } + + public int getTileX() { + return tileX; + } + + public int getTileY() { + return tileY; + } + + public GridCoverage2D getCoverage() { + return coverage; + } + } + + /** + * Generate tiles from an in-db grid coverage. The generated tiles are also in-db grid coverages. + * Pixel data will be copied into the tiles one tile at a time via a lazy iterator. + * + * @param gridCoverage2D the in-db grid coverage + * @param bandIndices the indices of the bands to select (1-based) + * @param tileWidth the width of the tiles + * @param tileHeight the height of the tiles + * @param padWithNoData whether to pad the tiles with no data value + * @param padNoDataValue the no data value for padded tiles, only used when padWithNoData is true. + * If the value is NaN, the no data value of the original band will be used. 
+ * @return a lazy iterator of tiles + */ + public static InDbTileIterator generateInDbTiles( + GridCoverage2D gridCoverage2D, + int[] bandIndices, + int tileWidth, + int tileHeight, + boolean padWithNoData, + double padNoDataValue) { + return new InDbTileIterator( + gridCoverage2D, bandIndices, tileWidth, tileHeight, padWithNoData, padNoDataValue); + } + + public abstract static class TileIterator implements Iterator<Tile> { + protected final GridCoverage2D gridCoverage2D; + protected int numTileX; + protected int numTileY; + protected int tileX; + protected int tileY; + protected boolean autoDisposeSource = false; + protected Runnable disposeFunction = null; + + TileIterator(GridCoverage2D gridCoverage2D) { + this.gridCoverage2D = gridCoverage2D; + } + + protected void initialize(int numTileX, int numTileY) { + this.numTileX = numTileX; + this.numTileY = numTileY; + this.tileX = 0; + this.tileY = 0; + } + + /** + * Set whether to dispose the grid coverage when the iterator reaches the end. Default is false. 
+ * + * @param autoDisposeSource whether to dispose the grid coverage + */ + public void setAutoDisposeSource(boolean autoDisposeSource) { + this.autoDisposeSource = autoDisposeSource; + } + + public void setDisposeFunction(Runnable disposeFunction) { + this.disposeFunction = disposeFunction; + } + + public int getNumTileX() { + return numTileX; + } + + public int getNumTileY() { + return numTileY; + } + + public int getNumTiles() { + return numTileX * numTileY; + } + + protected abstract Tile generateTile(); + + @Override + public boolean hasNext() { + if (numTileX == 0 || numTileY == 0) { + return false; + } + return tileY < numTileY; + } + + @Override + public Tile next() { + // Check if current tile coordinate is valid + if (tileX >= numTileX || tileY >= numTileY) { + throw new NoSuchElementException(); + } + + Tile tile = generateTile(); + + // Advance to the next tile + tileX += 1; + if (tileX >= numTileX) { + tileX = 0; + tileY += 1; + + // Dispose the grid coverage if we are at the end + if (tileY >= numTileY) { + if (autoDisposeSource) { + gridCoverage2D.dispose(true); + } + if (disposeFunction != null) { + disposeFunction.run(); + } + } + } + + return tile; + } + } + + public static class InDbTileIterator extends TileIterator { + private final int[] bandIndices; + private final int tileWidth; + private final int tileHeight; + private final boolean padWithNoData; + private final double padNoDataValue; + private final AffineTransform2D affine; + private final RenderedImage image; + private final double[] noDataValues; + private final int imageWidth; + private final int imageHeight; + + public InDbTileIterator( + GridCoverage2D gridCoverage2D, + int[] bandIndices, + int tileWidth, + int tileHeight, + boolean padWithNoData, + double padNoDataValue) { + super(gridCoverage2D); + this.bandIndices = bandIndices; + this.tileWidth = tileWidth; + this.tileHeight = tileHeight; + this.padWithNoData = padWithNoData; + this.padNoDataValue = padNoDataValue; + + affine = 
RasterUtils.getAffineTransform(gridCoverage2D, PixelOrientation.CENTER); + image = gridCoverage2D.getRenderedImage(); + noDataValues = new double[bandIndices.length]; + for (int i = 0; i < bandIndices.length; i++) { + noDataValues[i] = + RasterUtils.getNoDataValue(gridCoverage2D.getSampleDimension(bandIndices[i] - 1)); + } + imageWidth = image.getWidth(); + imageHeight = image.getHeight(); + int numTileX = (int) Math.ceil((double) imageWidth / tileWidth); + int numTileY = (int) Math.ceil((double) imageHeight / tileHeight); + initialize(numTileX, numTileY); + } + + @Override + protected Tile generateTile() { + // Process the current tile + int x0 = tileX * tileWidth + image.getMinX(); + int y0 = tileY * tileHeight + image.getMinY(); + + // Rect to copy from the original image + int rectWidth = Math.min(tileWidth, imageWidth - x0); + int rectHeight = Math.min(tileHeight, imageHeight - y0); + + // If we don't pad with no data, the tiles on the boundary may have a different size + int currentTileWidth = padWithNoData ? tileWidth : rectWidth; + int currentTileHeight = padWithNoData ? 
tileHeight : rectHeight; + boolean needPadding = padWithNoData && (rectWidth < tileWidth || rectHeight < tileHeight); + + // Create a new affine transformation for this tile + AffineTransform2D tileAffine = RasterUtils.translateAffineTransform(affine, x0, y0); + GridGeometry2D gridGeometry2D = + new GridGeometry2D( + new GridEnvelope2D(0, 0, currentTileWidth, currentTileHeight), + PixelInCell.CELL_CENTER, + tileAffine, + gridCoverage2D.getCoordinateReferenceSystem(), + null); + + // Prepare a new image for this tile, and copy the data from the original image + WritableRaster raster = + RasterFactory.createBandedRaster( + image.getSampleModel().getDataType(), + currentTileWidth, + currentTileHeight, + bandIndices.length, + null); + GridSampleDimension[] sampleDimensions = new GridSampleDimension[bandIndices.length]; + Raster sourceRaster = image.getData(new Rectangle(x0, y0, rectWidth, rectHeight)); + for (int k = 0; k < bandIndices.length; k++) { + int bandIndex = bandIndices[k] - 1; + + // Copy sample dimensions from source bands, and pad with no data value if necessary + GridSampleDimension sampleDimension = gridCoverage2D.getSampleDimension(bandIndex); + double noDataValue = noDataValues[k]; + if (needPadding && !Double.isNaN(padNoDataValue)) { + sampleDimension = + RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension, padNoDataValue); + noDataValue = padNoDataValue; + } + sampleDimensions[k] = sampleDimension; + + // Copy data from original image to tile image + ImageUtils.copyRasterWithPadding(sourceRaster, bandIndex, raster, k, noDataValue); + } + + GridCoverage2D tile = RasterUtils.create(raster, gridGeometry2D, sampleDimensions); + return new Tile(tileX, tileY, tile); + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java index a5cc02aada..7e8b7643b3 100644 --- 
a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java @@ -498,8 +498,8 @@ public class RasterConstructorsTest extends RasterTestBase { public void testInDbTileWithoutPadding() { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1, "EPSG:3857"); - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, null, 10, 10, false, Double.NaN); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, null, 10, 10, false, Double.NaN)); assertTilesSameWithGridCoverage(tiles, raster, null, 10, 10, Double.NaN); } @@ -507,8 +507,8 @@ public class RasterConstructorsTest extends RasterTestBase { public void testInDbTileWithoutPadding2() { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1, "EPSG:3857"); - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, null, 9, 9, false, Double.NaN); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9, false, Double.NaN)); assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, Double.NaN); } @@ -516,8 +516,8 @@ public class RasterConstructorsTest extends RasterTestBase { public void testInDbTileWithPadding() { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 2, "EPSG:3857"); - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, null, 9, 9, true, 100); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9, true, 100)); assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 100); } @@ -526,8 +526,8 @@ public class RasterConstructorsTest extends RasterTestBase { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 2, "EPSG:3857"); int[] bandIndices = {2}; - 
RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, bandIndices, 9, 9, true, 100); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, bandIndices, 9, 9, true, 100)); assertTilesSameWithGridCoverage(tiles, raster, bandIndices, 9, 9, 100); } @@ -536,8 +536,8 @@ public class RasterConstructorsTest extends RasterTestBase { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 4, "EPSG:3857"); int[] bandIndices = {3, 1}; - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, bandIndices, 8, 7, true, 100); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, bandIndices, 8, 7, true, 100)); assertTilesSameWithGridCoverage(tiles, raster, bandIndices, 8, 7, 100); } @@ -546,8 +546,8 @@ public class RasterConstructorsTest extends RasterTestBase { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1, "EPSG:3857"); raster = MapAlgebra.addBandFromArray(raster, MapAlgebra.bandAsArray(raster, 1), 1, 13.0); - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, null, 9, 9, true, Double.NaN); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9, true, Double.NaN)); assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 13); } @@ -556,13 +556,21 @@ public class RasterConstructorsTest extends RasterTestBase { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1, "EPSG:3857"); raster = MapAlgebra.addBandFromArray(raster, MapAlgebra.bandAsArray(raster, 1), 1, 13.0); - RasterConstructors.Tile[] tiles = - RasterConstructors.generateTiles(raster, null, 9, 9, true, 42); + TileGenerator.Tile[] tiles = + collectTiles(RasterConstructors.generateTiles(raster, null, 9, 9, true, 42)); assertTilesSameWithGridCoverage(tiles, raster, null, 9, 9, 42); } + private TileGenerator.Tile[] 
collectTiles(TileGenerator.TileIterator iter) { + java.util.List<TileGenerator.Tile> list = new java.util.ArrayList<>(); + while (iter.hasNext()) { + list.add(iter.next()); + } + return list.toArray(new TileGenerator.Tile[0]); + } + private void assertTilesSameWithGridCoverage( - RasterConstructors.Tile[] tiles, + TileGenerator.Tile[] tiles, GridCoverage2D gridCoverage2D, int[] bandIndices, int tileWidth, @@ -581,7 +589,7 @@ public class RasterConstructorsTest extends RasterTestBase { // in the grid // coverage Set<Pair<Integer, Integer>> visitedTiles = new HashSet<>(); - for (RasterConstructors.Tile tile : tiles) { + for (TileGenerator.Tile tile : tiles) { int tileX = tile.getTileX(); int tileY = tile.getTileY(); Pair<Integer, Integer> tilePosition = Pair.of(tileX, tileY); diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md index d0df8c2231..acc60f6815 100644 --- a/docs/api/sql/Raster-loader.md +++ b/docs/api/sql/Raster-loader.md @@ -20,6 +20,74 @@ !!!note Sedona loader are available in Scala, Java and Python and have the same APIs. +## Loading raster using the raster data source + +The `raster` data source loads GeoTiff files and automatically splits them into smaller tiles. Each tile is a row in the resulting DataFrame stored in `Raster` format. 
+ +=== "Scala" + ```scala + var rawDf = sedona.read.format("raster").load("/some/path/*.tif") + rawDf.createOrReplaceTempView("rawdf") + rawDf.show() + ``` + +=== "Java" + ```java + Dataset<Row> rawDf = sedona.read().format("raster").load("/some/path/*.tif"); + rawDf.createOrReplaceTempView("rawdf"); + rawDf.show(); + ``` + +=== "Python" + ```python + rawDf = sedona.read.format("raster").load("/some/path/*.tif") + rawDf.createOrReplaceTempView("rawdf") + rawDf.show() + ``` + +The output will look like this: + +``` ++--------------------+---+---+----+ +| rast| x| y|name| ++--------------------+---+---+----+ +|GridCoverage2D["g...| 0| 0| ...| +|GridCoverage2D["g...| 1| 0| ...| +|GridCoverage2D["g...| 2| 0| ...| +... +``` + +The output contains the following columns: + +- `rast`: The raster data in `Raster` format. +- `x`: The 0-based x-coordinate of the tile. This column is only present when retile is not disabled. +- `y`: The 0-based y-coordinate of the tile. This column is only present when retile is not disabled. +- `name`: The name of the raster file. + +The size of the tile is determined by the internal tiling scheme of the raster data. It is recommended to use [Cloud Optimized GeoTIFF (COG)](https://www.cogeo.org/) format for raster data since they usually organize pixel data as square tiles. You can also disable automatic tiling using `option("retile", "false")`, or specify the tile size manually using options such as `option("tileWidth", "256")` and `option("tileHeight", "256")`. + +The options for the `raster` data source are as follows: + +- `retile`: Whether to enable tiling. Default is `true`. +- `tileWidth`: The width of the tile. If not specified, the size of internal tiles will be used. +- `tileHeight`: The height of the tile. If not specified, will use `tileWidth` if `tileWidth` is explicitly set, otherwise the size of internal tiles will be used. 
+- `padWithNoData`: Pad the right and bottom of the tile with NODATA values if the tile is smaller than the specified tile size. Default is `false`. + +!!!note + If the internal tiling scheme of raster data is not friendly for tiling, the `raster` data source will throw an error, and you can disable automatic tiling using `option("retile", "false")`, or specify the tile size manually to workaround this issue. A better solution is to translate the raster data into COG format using `gdal_translate` or other tools. + +The `raster` data source also works with Spark generic file source options, such as `option("pathGlobFilter", "*.tif*")` and `option("recursiveFileLookup", "true")`. For instance, you can load all the `.tif` files recursively in a directory using + +```python +sedona.read.format("raster").option("recursiveFileLookup", "true").option( + "pathGlobFilter", "*.tif*" +).load(path_to_raster_data_folder) +``` + +One difference from other file source loaders is that when the loaded path ends with `/`, the `raster` data source will look up raster files in the directory and all its subdirectories recursively. This is equivalent to specifying a path without trailing `/` and setting `option("recursiveFileLookup", "true")`. + +## Loading raster using binaryFile loader (Deprecated) + The raster loader of Sedona leverages Spark built-in binary data source and works with several RS constructors to produce Raster type. Each raster is a row in the resulting DataFrame and stored in a `Raster` format. !!!tip @@ -27,7 +95,7 @@ The raster loader of Sedona leverages Spark built-in binary data source and work By default, these functions uses lon/lat order since `v1.5.0`. Before, it used lat/lon order. -## Step 1: Load raster to a binary DataFrame +### Step 1: Load raster to a binary DataFrame You can load any type of raster data using the code below. Then use the RS constructors below to create a Raster DataFrame. 
@@ -35,7 +103,7 @@ You can load any type of raster data using the code below. Then use the RS const sedona.read.format("binaryFile").load("/some/path/*.asc") ``` -## Step 2: Create a raster type column +### Step 2: Create a raster type column ### RS_FromArcInfoAsciiGrid diff --git a/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e420915439..10c362405c 100644 --- a/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/common/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,4 +1,4 @@ -org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat +org.apache.spark.sql.sedona_sql.io.raster.RasterDataSource org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONFileFormat org.apache.sedona.sql.datasources.spider.SpiderDataSource org.apache.spark.sql.sedona_sql.io.stac.StacDataSource diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index 2523e89343..e2ec9010a1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.sedona_sql.expressions.InferredExpression import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer} import org.apache.spark.sql.types._ +import scala.collection.JavaConverters._ + private[apache] case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression]) extends InferredExpression(RasterConstructors.fromArcInfoAsciiGrid _) { override def foldable: Boolean = false @@ 
-107,39 +109,35 @@ private[apache] case class RS_TileExplode(children: Seq[Expression]) override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val raster = arguments.rasterExpr.toRaster(input) - try { - val bandIndices = arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match { - case null => null - case value: Any => value.toIntArray - } - val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int] - val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int] - val padWithNoDataValue = arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean] - val noDataValue = arguments.noDataValExpr.eval(input) match { - case null => Double.NaN - case value: Integer => value.toDouble - case value: Decimal => value.toDouble - case value: Float => value.toDouble - case value: Double => value - case value: Any => - throw new IllegalArgumentException( - "Unsupported class for noDataValue: " + value.getClass) - } - val tiles = RasterConstructors.generateTiles( - raster, - bandIndices, - tileWidth, - tileHeight, - padWithNoDataValue, - noDataValue) - tiles.map { tile => - val gridCoverage2D = tile.getCoverage - val row = InternalRow(tile.getTileX, tile.getTileY, gridCoverage2D.serialize) - gridCoverage2D.dispose(true) - row - } - } finally { - raster.dispose(true) + val bandIndices = arguments.bandIndicesExpr.eval(input).asInstanceOf[ArrayData] match { + case null => null + case value: Any => value.toIntArray + } + val tileWidth = arguments.tileWidthExpr.eval(input).asInstanceOf[Int] + val tileHeight = arguments.tileHeightExpr.eval(input).asInstanceOf[Int] + val padWithNoDataValue = arguments.padWithNoDataExpr.eval(input).asInstanceOf[Boolean] + val noDataValue = arguments.noDataValExpr.eval(input) match { + case null => Double.NaN + case value: Integer => value.toDouble + case value: Decimal => value.toDouble + case value: Float => value.toDouble + case value: Double => value + case value: Any => + throw new 
IllegalArgumentException("Unsupported class for noDataValue: " + value.getClass) + } + val tileIterator = RasterConstructors.generateTiles( + raster, + bandIndices, + tileWidth, + tileHeight, + padWithNoDataValue, + noDataValue) + tileIterator.setAutoDisposeSource(true) + tileIterator.asScala.map { tile => + val gridCoverage2D = tile.getCoverage + val row = InternalRow(tile.getTileX, tile.getTileY, gridCoverage2D.serialize) + gridCoverage2D.dispose(true) + row } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala new file mode 100644 index 0000000000..c99ea0e6a7 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterDataSource.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.TableProvider +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import scala.collection.JavaConverters._ + +/** + * A Spark SQL data source for reading and writing raster images. This data source supports + * reading various raster formats as in-db rasters. Write support is implemented by falling back + * to the V1 data source [[RasterFileFormat]]. + */ +class RasterDataSource extends FileDataSourceV2 with TableProvider with DataSourceRegister { + + override def shortName(): String = "raster" + + private def createRasterTable( + options: CaseInsensitiveStringMap, + userSchema: Option[StructType] = None): Table = { + var paths = getPaths(options) + var optionsWithoutPaths = getOptionsWithoutPaths(options) + val tableName = getTableName(options, paths) + val rasterOptions = new RasterOptions(optionsWithoutPaths.asScala.toMap) + + if (paths.size == 1) { + if (paths.head.endsWith("/")) { + // Paths ending with / will be loaded recursively + val newOptions = + new java.util.HashMap[String, String](optionsWithoutPaths.asCaseSensitiveMap()) + newOptions.put("recursiveFileLookup", "true") + if (!newOptions.containsKey("pathGlobFilter")) { + newOptions.put("pathGlobFilter", "*.{tif,tiff,TIF,TIFF}") + } + optionsWithoutPaths = new CaseInsensitiveStringMap(newOptions) + } else { + // Rewrite paths such as /path/to/some*glob*.tif into /path/to with + // pathGlobFilter="some*glob*.tif". This is for avoiding listing .tif + // files as directories when discovering files to load. Globs ending with + // .tif or .tiff should be files in the context of raster data loading. 
+ val loadTifPattern = "(.*)/([^/]*\\*[^/]*\\.(?:tif|tiff))$".r + paths.head match { + case loadTifPattern(prefix, glob) => + paths = Seq(prefix) + val newOptions = + new java.util.HashMap[String, String](optionsWithoutPaths.asCaseSensitiveMap()) + newOptions.put("pathGlobFilter", glob) + optionsWithoutPaths = new CaseInsensitiveStringMap(newOptions) + case _ => + } + } + } + + new RasterTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + userSchema, + rasterOptions, + fallbackFileFormat) + } + + override def getTable(options: CaseInsensitiveStringMap): Table = { + createRasterTable(options) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + createRasterTable(options, Some(schema)) + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + val paths = getPaths(options) + if (paths.isEmpty) { + throw new IllegalArgumentException("No paths specified for raster data source") + } + + val rasterOptions = new RasterOptions(options.asScala.toMap) + RasterTable.inferSchema(rasterOptions) + } + + override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[RasterFileFormat] +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala similarity index 54% copy from spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala copy to spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala index 4f1e501f04..c5a1ac2a69 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterInputPartition.scala @@ -18,18 +18,10 @@ */ package org.apache.spark.sql.sedona_sql.io.raster -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import 
org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.Partition +import org.apache.spark.sql.execution.datasources.PartitionedFile -private[io] class RasterOptions(@transient private val parameters: CaseInsensitiveMap[String]) - extends Serializable { - def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - - // The file format of the raster image - val fileExtension = parameters.getOrElse("fileExtension", ".tiff") - // Column of the raster image name - val rasterPathField = parameters.get("pathField") - // Column of the raster image itself - val rasterField = parameters.get("rasterField") - // Use direct committer to directly write to the final destination - val useDirectCommitter = parameters.getOrElse("useDirectCommitter", "true").toBoolean -} +case class RasterInputPartition(index: Int, files: Array[PartitionedFile]) + extends Partition + with InputPartition diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala index 4f1e501f04..95d7ec1936 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala @@ -20,16 +20,52 @@ package org.apache.spark.sql.sedona_sql.io.raster import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -private[io] class RasterOptions(@transient private val parameters: CaseInsensitiveMap[String]) +class RasterOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable { def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + // The following options are used to read raster data + + /** + * Whether to retile the raster data. If true, the raster data will be retiled into smaller + * tiles. If false, the raster data will be read as a single tile. 
+ */ + val retile: Boolean = parameters.getOrElse("retile", "true").toBoolean + + /** + * The width of the tile. This is only effective when retile is true. If retile is true and + * tileWidth is not set, the default value is the width of the internal tiles in the raster + * files. Each raster file may have different internal tile sizes. + */ + val tileWidth: Option[Int] = parameters.get("tileWidth").map(_.toInt) + + /** + * The height of the tile. This is only effective when retile is true. If retile is true and + * tileHeight is not set, the default value is the same as tileWidth. If tileHeight is set, + * tileWidth must be set as well. + */ + val tileHeight: Option[Int] = parameters + .get("tileHeight") + .map { value => + require(tileWidth.isDefined, "tileWidth must be set when tileHeight is set") + value.toInt + } + .orElse(tileWidth) + + /** + * Whether to pad the right and bottom of the tile with NoData values if the tile is smaller + * than the specified tile size. Default is `false`. 
+ */ + val padWithNoData: Boolean = parameters.getOrElse("padWithNoData", "false").toBoolean + + // The following options are used to write raster data + // The file format of the raster image - val fileExtension = parameters.getOrElse("fileExtension", ".tiff") + val fileExtension: String = parameters.getOrElse("fileExtension", ".tiff") // Column of the raster image name - val rasterPathField = parameters.get("pathField") + val rasterPathField: Option[String] = parameters.get("pathField") // Column of the raster image itself - val rasterField = parameters.get("rasterField") + val rasterField: Option[String] = parameters.get("rasterField") // Use direct committer to directly write to the final destination - val useDirectCommitter = parameters.getOrElse("useDirectCommitter", "true").toBoolean + val useDirectCommitter: Boolean = parameters.getOrElse("useDirectCommitter", "true").toBoolean } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala new file mode 100644 index 0000000000..fa4cd5544c --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReader.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.raster.RasterConstructors +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.sedona_sql.UDT.RasterUDT +import org.apache.spark.sql.sedona_sql.io.raster.RasterPartitionReader.rasterToInternalRows +import org.apache.spark.sql.sedona_sql.io.raster.RasterTable.{MAX_AUTO_TILE_SIZE, RASTER, RASTER_NAME, TILE_X, TILE_Y} +import org.apache.spark.sql.types.StructType +import org.geotools.coverage.grid.GridCoverage2D + +import java.net.URI +import scala.collection.JavaConverters._ + +class RasterPartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + dataSchema: StructType, + rasterOptions: RasterOptions) + extends PartitionReader[InternalRow] { + + // Track the current file index we're processing + private var currentFileIndex = 0 + + // Current raster being processed + private var currentRaster: GridCoverage2D = _ + + // Current row + private var currentRow: InternalRow = _ + + // Iterator for the current file's tiles + private var currentIterator: Iterator[InternalRow] = Iterator.empty + + override def next(): Boolean = { + // If current iterator has more elements, return true + if (currentIterator.hasNext) { + currentRow = 
currentIterator.next() + return true + } + + // If current iterator is exhausted, but we have more files, load the next file + if (currentFileIndex < partitionedFiles.length) { + loadNextFile() + if (currentIterator.hasNext) { + currentRow = currentIterator.next() + return true + } + } + + // No more data + false + } + + override def get(): InternalRow = { + currentRow + } + + override def close(): Unit = { + if (currentRaster != null) { + currentRaster.dispose(true) + currentRaster = null + } + } + + private def loadNextFile(): Unit = { + // Clean up previous raster if exists + if (currentRaster != null) { + currentRaster.dispose(true) + currentRaster = null + } + + if (currentFileIndex >= partitionedFiles.length) { + currentIterator = Iterator.empty + return + } + + val partition = partitionedFiles(currentFileIndex) + val path = new Path(new URI(partition.filePath.toString())) + + try { + // Read file bytes from Hadoop FS + val fs = path.getFileSystem(configuration) + val fileStatus = fs.getFileStatus(path) + val fileLength = fileStatus.getLen.toInt + val bytes = new Array[Byte](fileLength) + val inputStream = fs.open(path) + try { + org.apache.hadoop.io.IOUtils.readFully(inputStream, bytes, 0, fileLength) + } finally { + inputStream.close() + } + + // Create in-db GridCoverage2D from GeoTiff bytes. The RenderedImage is lazy - + // pixel data will only be decoded when accessed via image.getData(Rectangle). 
+ currentRaster = RasterConstructors.fromGeoTiff(bytes) + currentIterator = rasterToInternalRows(currentRaster, dataSchema, rasterOptions, path) + currentFileIndex += 1 + } catch { + case e: Exception => + if (currentRaster != null) { + currentRaster.dispose(true) + currentRaster = null + } + throw e + } + } +} + +object RasterPartitionReader { + def rasterToInternalRows( + currentRaster: GridCoverage2D, + dataSchema: StructType, + rasterOptions: RasterOptions, + path: Path): Iterator[InternalRow] = { + val retile = rasterOptions.retile + val tileWidth = rasterOptions.tileWidth + val tileHeight = rasterOptions.tileHeight + val padWithNoData = rasterOptions.padWithNoData + + val writer = new UnsafeRowWriter(dataSchema.length) + writer.resetRowWriter() + + // Extract the file name from the path + val fileName = path.getName + + if (retile) { + val (tw, th) = (tileWidth, tileHeight) match { + case (Some(tw), Some(th)) => (tw, th) + case (None, None) => + // Use the internal tile size of the input raster + val tw = currentRaster.getRenderedImage.getTileWidth + val th = currentRaster.getRenderedImage.getTileHeight + val tileSizeError = { + """To resolve this issue, you can try one of the following methods: + | 1. Disable retile by setting `.option("retile", "false")`. + | 2. Explicitly set `tileWidth` and `tileHeight`. + | 3. Convert the raster to a Cloud Optimized GeoTIFF (COG) using tools like `gdal_translate`. + |""".stripMargin + } + if (tw >= MAX_AUTO_TILE_SIZE || th >= MAX_AUTO_TILE_SIZE) { + throw new IllegalArgumentException( + s"Internal tile size of $path is too large ($tw x $th). " + tileSizeError) + } + if (tw == 0 || th == 0) { + throw new IllegalArgumentException( + s"Internal tile size of $path contains zero ($tw x $th). " + tileSizeError) + } + if (tw / th > 10 || th / tw > 10) { + throw new IllegalArgumentException( + s"Internal tile shape of $path is too thin ($tw x $th). 
" + tileSizeError) + } + (tw, th) + case _ => + throw new IllegalArgumentException("Both tileWidth and tileHeight must be set") + } + + val iter = + RasterConstructors.generateTiles(currentRaster, null, tw, th, padWithNoData, Double.NaN) + iter.asScala.map { tile => + val tileRaster = tile.getCoverage + writer.reset() + writeRaster(writer, dataSchema, tileRaster, tile.getTileX, tile.getTileY, fileName) + tileRaster.dispose(true) + writer.getRow + } + } else { + writeRaster(writer, dataSchema, currentRaster, 0, 0, fileName) + Iterator.single(writer.getRow) + } + } + + private def writeRaster( + writer: UnsafeRowWriter, + dataSchema: StructType, + raster: GridCoverage2D, + x: Int, + y: Int, + fileName: String): Unit = { + dataSchema.fieldNames.zipWithIndex.foreach { + case (RASTER, i) => writer.write(i, RasterUDT.serialize(raster)) + case (TILE_X, i) => writer.write(i, x) + case (TILE_Y, i) => writer.write(i, y) + case (RASTER_NAME, i) => + if (fileName != null) + writer.write(i, org.apache.spark.unsafe.types.UTF8String.fromString(fileName)) + else writer.setNullAt(i) + case (other, _) => + throw new IllegalArgumentException(s"Unsupported field name: $other") + } + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala new file mode 100644 index 0000000000..1c5bc2491f --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterPartitionReaderFactory.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class RasterPartitionReaderFactory( + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + rasterOptions: RasterOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + + val fileReader = new RasterPartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + rasterOptions) + + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: RasterInputPartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: 
${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala new file mode 100644 index 0000000000..a9b132f6ab --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterScanBuilder.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.sedona_sql.io.raster + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.Batch +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.connector.read.SupportsPushDownLimit +import org.apache.spark.sql.connector.read.SupportsPushDownTableSample +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random + +case class RasterScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap, + rasterOptions: RasterOptions) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownTableSample + with SupportsPushDownLimit { + + private var pushedTableSample: Option[TableSampleInfo] = None + private var pushedLimit: Option[Int] = None + + override def build(): Scan = { + RasterScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + rasterOptions, + pushedDataFilters, + partitionFilters, + dataFilters, + pushedTableSample, + pushedLimit) + } + + override def pushTableSample( + lowerBound: Double, + upperBound: Double, + withReplacement: 
Boolean, + seed: Long): Boolean = { + if (withReplacement || rasterOptions.retile) { + false + } else { + pushedTableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed)) + true + } + } + + override def pushLimit(limit: Int): Boolean = { + pushedLimit = Some(limit) + true + } + + override def isPartiallyPushed: Boolean = rasterOptions.retile +} + +case class RasterScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + rasterOptions: RasterOptions, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty, + pushedTableSample: Option[TableSampleInfo] = None, + pushedLimit: Option[Int] = None) + extends FileScan + with Batch { + + private lazy val inputPartitions = { + var partitions = super.planInputPartitions() + + // Sample the files based on the table sample + pushedTableSample.foreach { tableSample => + val r = new Random(tableSample.seed) + var partitionIndex = 0 + partitions = partitions.flatMap { + case filePartition: FilePartition => + val files = filePartition.files + val sampledFiles = files.filter(_ => r.nextDouble() < tableSample.upperBound) + if (sampledFiles.nonEmpty) { + val index = partitionIndex + partitionIndex += 1 + Some(FilePartition(index, sampledFiles)) + } else { + None + } + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } + + // Limit the number of files to read + pushedLimit.foreach { limit => + var remaining = limit + partitions = partitions.iterator + .takeWhile(_ => remaining > 0) + .map { partition => + val filePartition = partition.asInstanceOf[FilePartition] + val files = filePartition.files + if (files.length <= remaining) { + remaining -= files.length + filePartition + } else { + val selectedFiles = 
files.take(remaining) + remaining = 0 + FilePartition(filePartition.index, selectedFiles) + } + } + .toArray + } + + partitions + } + + override def planInputPartitions(): Array[InputPartition] = { + inputPartitions.map { + case filePartition: FilePartition => + RasterInputPartition(filePartition.index, filePartition.files) + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } + + override def createReaderFactory(): PartitionReaderFactory = { + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + RasterPartitionReaderFactory( + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + rasterOptions, + pushedFilters) + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala new file mode 100644 index 0000000000..2a1eb6cdb3 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterTable.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
package org.apache.spark.sql.sedona_sql.io.raster

import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.SupportsWrite
import org.apache.spark.sql.connector.catalog.TableCapability
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.LogicalWriteInfo
import org.apache.spark.sql.connector.write.WriteBuilder
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util.{Set => JSet}

/**
 * DataSource V2 [[FileTable]] backing the "raster" format.
 *
 * The table is read-only in practice: it advertises only [[TableCapability.BATCH_READ]].
 * Its schema is either the user-specified one or inferred from the read options (see
 * [[RasterTable.inferSchema]]), which decides whether tile-coordinate columns are present.
 */
case class RasterTable(
    name: String,
    sparkSession: SparkSession,
    options: CaseInsensitiveStringMap,
    paths: Seq[String],
    userSpecifiedSchema: Option[StructType],
    rasterOptions: RasterOptions,
    fallbackFileFormat: Class[_ <: FileFormat])
    extends FileTable(sparkSession, options, paths, userSpecifiedSchema)
    with SupportsRead
    with SupportsWrite {

  /** Schema does not depend on file contents; it is fully determined by the options. */
  override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
    val schema = userSpecifiedSchema.getOrElse(RasterTable.inferSchema(rasterOptions))
    Some(schema)
  }

  override def formatName: String = "Raster"

  override def capabilities(): JSet[TableCapability] =
    java.util.EnumSet.of(TableCapability.BATCH_READ)

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
    RasterScanBuilder(sparkSession, fileIndex, schema, dataSchema, options, rasterOptions)

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
    // Note: this code path will never be taken, since Spark will always fall back to V1
    // data source when writing to File source v2. See SPARK-28396: File source v2 write
    // path is currently broken.
    null
}

object RasterTable {
  // Names of the fields in the read schema
  val RASTER = "rast"
  val TILE_X = "x"
  val TILE_Y = "y"
  val RASTER_NAME = "name"

  val MAX_AUTO_TILE_SIZE = 4096

  /**
   * Builds the read schema: always a raster column, plus (x, y) tile coordinates when
   * retiling is enabled, plus a trailing nullable source-file name column.
   */
  def inferSchema(options: RasterOptions): StructType = {
    val rasterField = StructField(RASTER, RasterUDT(), nullable = false)

    val tileFields =
      if (options.retile) {
        Seq(
          StructField(TILE_X, IntegerType, nullable = false),
          StructField(TILE_Y, IntegerType, nullable = false))
      } else {
        Seq.empty[StructField]
      }

    val nameField =
      StructField(RASTER_NAME, org.apache.spark.sql.types.StringType, nullable = true)

    StructType(rasterField +: tileFields :+ nameField)
  }
}
BeforeAndAfter with GivenWhenThen } } - override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) + describe("Raster read test") { + it("should read geotiff using raster source with explicit tiling") { + val rasterDf = sparkSession.read + .format("raster") + .options(Map("retile" -> "true", "tileWidth" -> "64")) + .load(rasterdatalocation) + assert(rasterDf.count() > 100) + rasterDf.collect().foreach { row => + val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D] + assert(raster.getGridGeometry.getGridRange2D.width <= 64) + assert(raster.getGridGeometry.getGridRange2D.height <= 64) + val x = row.getInt(1) + val y = row.getInt(2) + assert(x >= 0 && y >= 0) + raster.dispose(true) + } + + // Test projection push-down + rasterDf.selectExpr("y", "rast as r").collect().foreach { row => + val raster = row.getAs[Object](1).asInstanceOf[GridCoverage2D] + assert(raster.getGridGeometry.getGridRange2D.width <= 64) + assert(raster.getGridGeometry.getGridRange2D.height <= 64) + val y = row.getInt(0) + assert(y >= 0) + raster.dispose(true) + } + } + + it("should tile geotiff using raster source with padding enabled") { + val rasterDf = sparkSession.read + .format("raster") + .options(Map("retile" -> "true", "tileWidth" -> "64", "padWithNoData" -> "true")) + .load(rasterdatalocation) + assert(rasterDf.count() > 100) + rasterDf.collect().foreach { row => + val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D] + assert(raster.getGridGeometry.getGridRange2D.width == 64) + assert(raster.getGridGeometry.getGridRange2D.height == 64) + val x = row.getInt(1) + val y = row.getInt(2) + assert(x >= 0 && y >= 0) + raster.dispose(true) + } + } + + it("should push down limit and sample to data source") { + FileUtils.cleanDirectory(new File(tempDir)) + + val sourceDir = new File(rasterdatalocation) + val files = sourceDir.listFiles().filter(_.isFile) + var numUniqueFiles = 0 + var numTotalFiles = 0 + files.foreach { file => + if 
(file.getPath.endsWith(".tif") || file.getPath.endsWith(".tiff")) { + // Create 4 copies for each file + for (i <- 0 until 4) { + val destFile = new File(tempDir + "/" + file.getName + "_" + i) + FileUtils.copyFile(file, destFile) + numTotalFiles += 1 + } + numUniqueFiles += 1 + } + } + + val df = sparkSession.read + .format("raster") + .options(Map("retile" -> "false")) + .load(tempDir) + .withColumn("width", expr("RS_Width(rast)")) + + val dfWithLimit = df.limit(numUniqueFiles) + val plan = queryPlan(dfWithLimit) + // Global/local limits are all pushed down to data source + assert(plan.collect { case e: LimitExec => e }.isEmpty) + assert(dfWithLimit.count() == numUniqueFiles) + + val dfWithSample = df.sample(0.3, seed = 42) + val planSample = queryPlan(dfWithSample) + // Sample is pushed down to data source + assert(planSample.collect { case e: SampleExec => e }.isEmpty) + val count = dfWithSample.count() + assert(count >= numTotalFiles * 0.1 && count <= numTotalFiles * 0.5) + + val dfWithSampleAndLimit = df.sample(0.5, seed = 42).limit(numUniqueFiles) + val planBoth = queryPlan(dfWithSampleAndLimit) + assert(planBoth.collect { case e: LimitExec => e }.isEmpty) + assert(planBoth.collect { case e: SampleExec => e }.isEmpty) + assert(dfWithSampleAndLimit.count() == numUniqueFiles) + + // Limit and sample cannot be fully pushed down when retile is enabled + val dfReTiledWithSampleAndLimit = sparkSession.read + .format("raster") + .options(Map("retile" -> "true")) + .load(tempDir) + .sample(0.5, seed = 42) + .limit(numUniqueFiles) + dfReTiledWithSampleAndLimit.explain(true) + val planRetiled = queryPlan(dfReTiledWithSampleAndLimit) + assert(planRetiled.collect { case e: LimitExec => e }.nonEmpty) + assert(planRetiled.collect { case e: SampleExec => e }.nonEmpty) + } + + it("should read geotiff using raster source without tiling") { + val rasterDf = sparkSession.read + .format("raster") + .options(Map("retile" -> "false")) + .load(rasterdatalocation) + 
assert(rasterDf.schema.fields.length == 2) + rasterDf.collect().foreach { row => + val raster = row.getAs[Object](0).asInstanceOf[GridCoverage2D] + assert(raster != null) + raster.dispose(true) + // Should load name correctly + val name = row.getString(1) + assert(name != null) + } + } + + it("should read geotiff using raster source with auto-tiling") { + val rasterDf = sparkSession.read + .format("raster") + .options(Map("retile" -> "true")) + .load(rasterdatalocation) + val rasterDfNoTiling = sparkSession.read + .format("raster") + .options(Map("retile" -> "false")) + .load(rasterdatalocation) + assert(rasterDf.count() > rasterDfNoTiling.count()) + } + + it("should throw exception when only tileHeight is specified") { + assertThrows[IllegalArgumentException] { + val df = sparkSession.read + .format("raster") + .options(Map("retile" -> "true", "tileHeight" -> "64")) + .load(rasterdatalocation) + df.collect() + } + } + + it("should throw exception when the geotiff is badly tiled") { + val exception = intercept[Exception] { + val rasterDf = sparkSession.read + .format("raster") + .options(Map("retile" -> "true")) + .load(resourceFolder + "raster_geotiff_color/*") + rasterDf.collect() + } + assert( + exception.getMessage.contains( + "To resolve this issue, you can try one of the following methods")) + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(tempDir)) + Files.createDirectory(new File(tempDir + "/part=1").toPath) + Files.createDirectory(new File(tempDir + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "raster/test1.tiff"), + new File(tempDir + "/part=1/test1.tiff")) + FileUtils.copyFile( + new File(resourceFolder + "raster/test2.tiff"), + new File(tempDir + "/part=1/test2.tiff")) + FileUtils.copyFile( + new File(resourceFolder + "raster/test4.tiff"), + new File(tempDir + "/part=2/test4.tiff")) + FileUtils.copyFile( + new File(resourceFolder + "raster/test4.tiff"), + new File(tempDir + "/part=2/test5.tiff")) + + 
val rasterDf = sparkSession.read + .format("raster") + .load(tempDir) + val rows = rasterDf.collect() + assert(rows.length >= 4) + rows.foreach { row => + val name = row.getAs[String]("name") + if (name.startsWith("test1") || name.startsWith("test2")) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + } + } + + it("read directory recursively from a temp directory with subdirectories") { + // Create temp subdirectories in tempDir + FileUtils.cleanDirectory(new File(tempDir)) + val subDir1 = tempDir + "/subdir1" + val subDir2 = tempDir + "/nested/subdir2" + new File(subDir1).mkdirs() + new File(subDir2).mkdirs() + + // Copy raster files from resourceFolder/raster to the temp subdirectories + val sourceDir = new File(resourceFolder + "raster") + val files = sourceDir.listFiles().filter(_.isFile) + files.zipWithIndex.foreach { case (file, idx) => + idx % 3 match { + case 0 => FileUtils.copyFile(file, new File(tempDir, file.getName)) + case 1 => FileUtils.copyFile(file, new File(subDir1, file.getName)) + case 2 => FileUtils.copyFile(file, new File(subDir2, file.getName)) + } + } + + val rasterDfNonRecursive = sparkSession.read + .format("raster") + .option("retile", "false") + .load(sourceDir.getPath) + + val rasterDfRecursive = sparkSession.read + .format("raster") + .option("retile", "false") + .load(tempDir + "/") + + val rowsNonRecursive = rasterDfNonRecursive.collect() + val rowsRecursive = rasterDfRecursive.collect() + assert(rowsRecursive.length == rowsNonRecursive.length) + } + + it("read directory suffixed by /*.tif") { + val df = sparkSession.read + .format("raster") + .option("retile", "false") + .load(resourceFolder + "raster") + + val dfTif = sparkSession.read + .format("raster") + .option("retile", "false") + .load(resourceFolder + "raster/*.tif") + + val dfTiff = sparkSession.read + .format("raster") + .option("retile", "false") + .load(resourceFolder + "raster/*.tiff") + + assert(df.count() == dfTif.count() + 
dfTiff.count()) + queryPlan(dfTif).collect { case scan: BatchScanExec => scan }.foreach { scan => + val table = scan.table.asInstanceOf[RasterTable] + assert(!table.paths.head.endsWith("*.tif")) + assert(table.options.get("pathGlobFilter") == "*.tif") + } + queryPlan(dfTiff).collect { case scan: BatchScanExec => scan }.foreach { scan => + val table = scan.table.asInstanceOf[RasterTable] + assert(!table.paths.head.endsWith("*.tiff")) + assert(table.options.get("pathGlobFilter") == "*.tiff") + } + + var dfComplexGlob = sparkSession.read + .format("raster") + .option("retile", "false") + .load(resourceFolder + "raster/test*.tiff") + queryPlan(dfComplexGlob).collect { case scan: BatchScanExec => scan }.foreach { scan => + val table = scan.table.asInstanceOf[RasterTable] + assert(!table.paths.head.endsWith("*.tiff")) + assert(table.options.get("pathGlobFilter") == "test*.tiff") + } + dfComplexGlob = sparkSession.read + .format("raster") + .option("retile", "false") + .load(resourceFolder + "raster/*1.tiff") + queryPlan(dfComplexGlob).collect { case scan: BatchScanExec => scan }.foreach { scan => + val table = scan.table.asInstanceOf[RasterTable] + assert(!table.paths.head.endsWith("*1.tiff")) + assert(table.options.get("pathGlobFilter") == "*1.tiff") + } + } + } + + override def afterAll(): Unit = { + FileUtils.deleteDirectory(new File(tempDir)) + super.afterAll() + } + + private def queryPlan(df: DataFrame): SparkPlan = { + df.queryExecution.executedPlan match { + case adaptive: AdaptiveSparkPlanExec => adaptive.initialPlan + case plan: SparkPlan => plan + } + } }
