This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch geotiff-enhance in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 2beeee1071f896c1e64ce07f20349bf9e7a5761f Author: Jia Yu <[email protected]> AuthorDate: Thu May 11 00:36:11 2023 -0700 Add a working solution --- .../spark/sql/sedona_sql/io/HadoopUtils.scala | 107 ------------- .../sedona_sql/io/{ => raster}/GeotiffSchema.scala | 90 ++++++----- .../io/{ => raster}/ImageReadOptions.scala | 2 +- .../io/{ => raster}/ImageWriteOptions.scala | 2 +- .../sedona_sql/io/raster/RasterFileFormat.scala | 166 +++++++++++++++++++++ .../RasterOptions.scala} | 18 +-- ...org.apache.spark.sql.sources.DataSourceRegister | 5 +- .../io/{ => raster}/GeotiffFileFormat.scala | 5 +- .../scala/org/apache/sedona/sql/rasterIOTest.scala | 53 ++++++- ...org.apache.spark.sql.sources.DataSourceRegister | 5 +- .../io/{ => raster}/GeotiffFileFormat.scala | 0 .../scala/org/apache/sedona/sql/rasterIOTest.scala | 53 ++++++- 12 files changed, 318 insertions(+), 188 deletions(-) diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala deleted file mode 100644 index 54c5377f..00000000 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.sedona_sql.io - -import org.apache.commons.io.FilenameUtils -import org.apache.hadoop.conf.{Configuration, Configured} -import org.apache.hadoop.fs.{Path, PathFilter} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.sql.SparkSession - -import scala.language.existentials -import scala.util.Random - -object RecursiveFlag { - - /** Sets a value of spark recursive flag - * - * @param value value to set - * @param spark existing spark session - * @return previous value of this flag - */ - def setRecursiveFlag(value: Option[String], spark: SparkSession): Option[String] = { - val flagName = FileInputFormat.INPUT_DIR_RECURSIVE - val hadoopConf = spark.sparkContext.hadoopConfiguration - val old = Option(hadoopConf.get(flagName)) - - value match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) - } - - old - } -} - - -/** Filter that allows loading a fraction of HDFS files. 
*/ -class SamplePathFilter extends Configured with PathFilter { - val random = { - val rd = new Random() - rd.setSeed(0) - rd - } - - // Ratio of files to be read from disk - var sampleRatio: Double = 1 - - override def setConf(conf: Configuration): Unit = { - if (conf != null) { - sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1) - } - } - - override def accept(path: Path): Boolean = { - // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead - !SamplePathFilter.isFile(path) || - random.nextDouble() < sampleRatio - } -} - -object SamplePathFilter { - val ratioParam = "sampleRatio" - - def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != "" - - /** Set/unset hdfs PathFilter - * - * @param value Filter class that is passed to HDFS - * @param sampleRatio Fraction of the files that the filter picks - * @param spark Existing Spark session - * @return - */ - def setPathFilter(value: Option[Class[_]], sampleRatio: Option[Double] = None, spark: SparkSession) - : Option[Class[_]] = { - val flagName = FileInputFormat.PATHFILTER_CLASS - val hadoopConf = spark.sparkContext.hadoopConfiguration - val old = Option(hadoopConf.getClass(flagName, null)) - if (sampleRatio.isDefined) { - hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio.get) - } else { - hadoopConf.unset(SamplePathFilter.ratioParam) - None - } - - value match { - case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter]) - case None => hadoopConf.unset(flagName) - } - old - } -} \ No newline at end of file diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala similarity index 85% rename from sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala rename to sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala index 5a3a3595..90c0ec55 
100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.spark.sql.sedona_sql.io +package org.apache.spark.sql.sedona_sql.io.raster import org.apache.spark.sql.Row import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT @@ -38,8 +38,8 @@ object GeotiffSchema { val undefinedImageType = "Undefined" /** - * Schema for the image column: Row(String,Geometry, Int, Int, Int, Array[Double]) - */ + * Schema for the image column: Row(String,Geometry, Int, Int, Int, Array[Double]) + */ val columnSchema = StructType( StructField("origin", StringType, true) :: StructField("geometry", StringType, true) :: @@ -51,73 +51,72 @@ object GeotiffSchema { val imageFields: Array[String] = columnSchema.fieldNames /** - * DataFrame with a single column of images named "image" (nullable) - */ + * DataFrame with a single column of images named "image" (nullable) + */ val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil) /** - * Gets the origin of the image - * - * @return The origin of the image - */ + * Gets the origin of the image + * + * @return The origin of the image + */ def getOrigin(row: Row): String = row.getString(0) /** - * Gets the origin of the image - * - * @return The origin of the image - */ + * Gets the origin of the image + * + * @return The origin of the image + */ def getGeometry(row: Row): GeometryUDT = row.getAs[GeometryUDT](1) /** - * Gets the height of the image - * - * @return The height of the image - */ + * Gets the height of the image + * + * @return The height of the image + */ def getHeight(row: Row): Int = row.getInt(2) /** - * Gets the width of the image - * - * @return The width of the image - */ + * Gets the width of the image + * + * @return The width of the image + */ def getWidth(row: 
Row): Int = row.getInt(3) /** - * Gets the number of channels in the image - * - * @return The number of bands in the image - */ + * Gets the number of channels in the image + * + * @return The number of bands in the image + */ def getNBands(row: Row): Int = row.getInt(4) /** - * Gets the image data - * - * @return The image data - */ + * Gets the image data + * + * @return The image data + */ def getData(row: Row): Array[Double] = row.getAs[Array[Double]](5) /** - * Default values for the invalid image - * - * @param origin Origin of the invalid image - * @return Row with the default values - */ + * Default values for the invalid image + * + * @param origin Origin of the invalid image + * @return Row with the default values + */ private[io] def invalidImageRow(origin: String): Row = Row(Row(origin, -1, -1, -1, Array.ofDim[Byte](0))) /** - * - * Convert a GeoTiff image into a dataframe row - * - * - * @param origin Arbitrary string that identifies the image - * @param bytes Image bytes (for example, jpeg) - * @return DataFrame Row or None (if the decompression fails) - * - */ + * + * Convert a GeoTiff image into a dataframe row + * + * @param origin Arbitrary string that identifies the image + * @param bytes Image bytes (for example, jpeg) + * @return DataFrame Row or None (if the decompression fails) + * + */ private[io] def decode(origin: String, bytes: Array[Byte], imageSourceOptions: ImageReadOptions): Option[Row] = { @@ -215,8 +214,3 @@ object GeotiffSchema { Some(Row(Row(origin, polygon.toText, height, width, nBands, decoded))) } } - - - - - diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala similarity index 97% rename from sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala rename to sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala index 
f73fc7cf..552b8f8e 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.spark.sql.sedona_sql.io +package org.apache.spark.sql.sedona_sql.io.raster import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala similarity index 96% copy from sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala copy to sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala index 8653c93a..6a730faa 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.spark.sql.sedona_sql.io +package org.apache.spark.sql.sedona_sql.io.raster import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala new file mode 100644 index 00000000..54bf05d4 --- /dev/null +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FSDataOutputStream, FileStatus, Path}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.sedona.common.raster.Serde
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.types.StructType
+import org.geotools.gce.arcgrid.ArcGridWriter
+import org.geotools.gce.geotiff.GeoTiffWriter
+import org.opengis.coverage.grid.GridCoverageWriter
+
+import java.util.UUID
+
+/**
+ * Write-only Spark data source (short name "raster") that saves every non-null raster of a
+ * DataFrame as its own image file. Reading is not supported; load files with binaryFile instead.
+ */
+private[spark] class RasterFileFormat extends FileFormat with DataSourceRegister {
+
+  override def inferSchema(
+    sparkSession: SparkSession,
+    options: Map[String, String],
+    files: Seq[FileStatus]): Option[StructType] = None
+
+  override def prepareWrite(
+    sparkSession: SparkSession,
+    job: Job,
+    options: Map[String, String],
+    dataSchema: StructType): OutputWriterFactory = {
+    val rasterOptions = new RasterOptions(options)
+    if (!isValidRasterSchema(dataSchema)) {
+      throw new IllegalArgumentException("Invalid GeoTiff Schema")
+    }
+
+    new OutputWriterFactory {
+      // Empty: RasterFileWriter picks each image file's name and extension per row.
+      override def getFileExtension(context: TaskAttemptContext): String = ""
+
+      override def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+        new RasterFileWriter(path, rasterOptions, dataSchema, context)
+      }
+    }
+  }
+
+  override def shortName(): String = "raster"
+
+  override protected def buildReader(
+    sparkSession: SparkSession,
+    dataSchema: StructType,
+    partitionSchema: StructType,
+    requiredSchema: StructType,
+    filters: Seq[Filter],
+    options: Map[String, String],
+    hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+    throw new UnsupportedOperationException("Please use the binaryFile data source to read raster files")
+  }
+
+  // A schema is writable only if at least one column has the "raster" UDT.
+  private def isValidRasterSchema(dataSchema: StructType): Boolean =
+    dataSchema.fields.exists(_.dataType.typeName.equals("raster"))
+
+}
+
+/**
+ * Writes every non-null raster it receives to its own file inside the directory `savePath`.
+ *
+ * @param savePath      path passed to OutputWriterFactory.newInstance, used here as a directory
+ * @param rasterOptions output format (rasterType) and optional file-name column (pathField)
+ * @param dataSchema    schema of the rows passed to `write`
+ * @param context       task context that supplies the Hadoop configuration
+ */
+private class RasterFileWriter(savePath: String,
+                               rasterOptions: RasterOptions,
+                               dataSchema: StructType,
+                               context: TaskAttemptContext) extends OutputWriter {
+
+  private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
+
+  // Resolved once per task instead of once per row; as before, the last raster column wins.
+  private val rasterColIndex: Int = dataSchema.fields.lastIndexWhere(_.dataType.typeName.equals("raster"))
+
+  override def write(row: InternalRow): Unit = {
+    val rasterRaw = row.getBinary(rasterColIndex)
+    // A null raster produces no file.
+    if (rasterRaw != null) {
+      val gridCoverage2D = Serde.deserialize(rasterRaw)
+      val extension = fileExtension(rasterOptions.rasterFormat)
+      // Join with Hadoop's Path: java.nio.file.Paths collapses "scheme://authority" URIs such as
+      // hdfs://nn:8020/out or s3a://bucket/out into invalid paths.
+      val out = hfs.create(new Path(savePath, rasterFileName(row) + extension))
+      // Errors propagate so Spark fails the task instead of committing a partial file; the
+      // GeoTools writer and the stream are released on every path.
+      try {
+        val writer = newCoverageWriter(out)
+        try writer.write(gridCoverage2D)
+        finally writer.dispose()
+      } finally {
+        out.close()
+      }
+    }
+  }
+
+  // Intentionally leaves `hfs` open: getFileSystem returns a JVM-wide cached instance shared with
+  // other tasks and the driver, which would then fail with "Filesystem closed".
+  override def close(): Unit = {}
+
+  def path(): String = {
+    savePath
+  }
+
+  // Maps the rasterType option to a file extension; unsupported formats are rejected.
+  private def fileExtension(format: String): String =
+    if (format.equalsIgnoreCase("geotiff")) ".tiff"
+    else if (format.equalsIgnoreCase("arcgrid")) ".asc"
+    else throw new IllegalArgumentException("Invalid raster format")
+
+  // GeoTools writer for the rasterType option; callers validate the format via fileExtension first.
+  private def newCoverageWriter(out: FSDataOutputStream): GridCoverageWriter =
+    if (rasterOptions.rasterFormat.equalsIgnoreCase("geotiff")) new GeoTiffWriter(out)
+    else new ArcGridWriter(out)
+
+  /**
+   * Base file name (no extension) for `row`: the last segment of the pathField value with its
+   * extension removed, or a random UUID when pathField is unset or the value is null.
+   */
+  private def rasterFileName(row: InternalRow): String =
+    rasterOptions.rasterPathField
+      .map(field => dataSchema.fieldIndex(field))
+      // Check for null first: InternalRow.getString throws a NullPointerException on null values.
+      .filterNot(i => row.isNullAt(i))
+      .map(i => stripExtension(lastSegment(row.getString(i))))
+      .getOrElse(UUID.randomUUID().toString)
+
+  // Text after the last '/', e.g. "file:/data/a.b/img.tif" -> "img.tif".
+  private def lastSegment(path: String): String = path.substring(path.lastIndexOf('/') + 1)
+
+  // Text before the last '.', e.g. "img.v2.tif" -> "img.v2"; names without '.' are unchanged.
+  private def stripExtension(name: String): String = {
+    val dot = name.lastIndexOf('.')
+    if (dot >= 0) name.substring(0, dot) else name
+  }
+}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala similarity index 57% rename from sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala rename to sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala index 8653c93a..518dca65 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala @@ -16,21 +16,15 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.spark.sql.sedona_sql.io +package org.apache.spark.sql.sedona_sql.io.raster import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -private[io] class ImageWriteOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable { +private[io] class RasterOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable { def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - // Optional parameters for writing GeoTiff - val writeToCRS = parameters.getOrElse("writeToCRS", "EPSG:4326") - val colImage = parameters.getOrElse("fieldImage", "image") - val colOrigin = parameters.getOrElse("fieldOrigin", "origin") - val colBands = parameters.getOrElse("fieldNBands", "nBands") - val colWidth = parameters.getOrElse("fieldWidth", "width") - val colHeight = parameters.getOrElse("fieldHeight", "height") - val colGeometry = parameters.getOrElse("fieldGeometry", "geometry") - val colData = parameters.getOrElse("fieldData", "data") - + // Output image format from the "rasterType" option (default "geotiff"); option keys are case-insensitive + val rasterFormat = parameters.getOrElse("rasterType", "geotiff") + // Optional "pathField" option: name of the column whose value is used to name each output file + val rasterPathField = parameters.get("pathField") } \ No newline at end of file diff --git
a/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 68ea723a..4352e818 100644 --- a/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ -org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat -org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat \ No newline at end of file +org.apache.spark.sql.sedona_sql.io.raster.GeotiffFileFormat +org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat \ No newline at end of file diff --git a/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala b/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala similarity index 99% rename from sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala rename to sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala index 842e28f3..f3360ae3 100644 --- a/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala +++ b/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala @@ -18,21 +18,20 @@ */ -package org.apache.spark.sql.sedona_sql.io +package org.apache.spark.sql.sedona_sql.io.raster import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.sedona.sql.utils.GeometrySerializer -import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration import org.geotools.coverage.CoverageFactoryFinder diff --git a/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala index 7206ac38..6b52f05e 100644 --- a/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala +++ b/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala @@ -19,15 +19,19 @@ package org.apache.sedona.sql +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.SaveMode import org.locationtech.jts.geom.Geometry import org.scalatest.{BeforeAndAfter, GivenWhenThen} import java.io.File +import java.nio.file.Files import scala.collection.mutable class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen { var rasterdatalocation: String = resourceFolder + "raster/" + val tempDir: String = Files.createTempDirectory("sedona_raster_io_test_").toFile.getAbsolutePath describe("Raster IO test") { it("Should Pass geotiff loading without readFromCRS and readToCRS") { @@ -158,7 +162,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing with coalesce") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands") - val savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath) var loadPath = savePath @@ -185,7 +189,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing with writeToCRS") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands") - val savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4499").save(savePath) var loadPath = savePath @@ -212,7 +216,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing without coalesce") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands") - val savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.write.mode("overwrite").format("geotiff").save(savePath) var imageCount = 0 @@ -347,11 +351,48 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen } } } - - } -} + it("should read geotiff using binary source and write geotiff back to disk using raster source") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "length") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = 
sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterDf.count() == rasterCount) + } + it("should read and write geotiff using given options") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").option("rasterType", "geotiff").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterDf.count() == rasterCount) + } + it("should read geotiff and write asc") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").option("rasterType", "arcgrid").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + "/asc-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/asc-written/*") + rasterDf = df.selectExpr("RS_FromArcInfoAsciiGrid(content)") + assert(rasterDf.count() == rasterCount) + } + it("should handle null") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(null)", "length") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterCount == 3) + assert(rasterDf.count() == 0) + } + } + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) +} \ No newline at end of file diff --git 
a/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 68ea723a..4352e818 100644 --- a/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ -org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat -org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat \ No newline at end of file +org.apache.spark.sql.sedona_sql.io.raster.GeotiffFileFormat +org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat +org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat \ No newline at end of file diff --git a/sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala b/sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala similarity index 100% rename from sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala rename to sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala diff --git a/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala index 7206ac38..6b52f05e 100644 --- a/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala +++ b/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala @@ -19,15 +19,19 @@ package org.apache.sedona.sql +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.SaveMode import org.locationtech.jts.geom.Geometry import org.scalatest.{BeforeAndAfter, GivenWhenThen} import java.io.File +import java.nio.file.Files import scala.collection.mutable class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen { var 
rasterdatalocation: String = resourceFolder + "raster/" + val tempDir: String = Files.createTempDirectory("sedona_raster_io_test_").toFile.getAbsolutePath describe("Raster IO test") { it("Should Pass geotiff loading without readFromCRS and readToCRS") { @@ -158,7 +162,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing with coalesce") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands") - val savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath) var loadPath = savePath @@ -185,7 +189,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing with writeToCRS") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands") - val savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4499").save(savePath) var loadPath = savePath @@ -212,7 +216,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen it("Should Pass geotiff file writing without coalesce") { var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation) df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands") - val 
savePath = resourceFolder + "raster-written/" + val savePath = tempDir + "/raster-written/" df.write.mode("overwrite").format("geotiff").save(savePath) var imageCount = 0 @@ -347,11 +351,48 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen } } } - - } -} + it("should read geotiff using binary source and write geotiff back to disk using raster source") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "length") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterDf.count() == rasterCount) + } + it("should read and write geotiff using given options") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").option("rasterType", "geotiff").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterDf.count() == rasterCount) + } + it("should read geotiff and write asc") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").option("rasterType", "arcgrid").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + "/asc-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/asc-written/*") + rasterDf = df.selectExpr("RS_FromArcInfoAsciiGrid(content)") + assert(rasterDf.count() == rasterCount) + } + it("should 
handle null") { + var df = sparkSession.read.format("binaryFile").load(rasterdatalocation) + var rasterDf = df.selectExpr("RS_FromGeoTiff(null)", "length") + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/geotiff-written") + df = sparkSession.read.format("binaryFile").load(tempDir + "/geotiff-written/*") + rasterDf = df.selectExpr("RS_FromGeoTiff(content)") + assert(rasterCount == 3) + assert(rasterDf.count() == 0) + } + } + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir)) +} \ No newline at end of file
