This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch geotiff-enhance
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 2beeee1071f896c1e64ce07f20349bf9e7a5761f
Author: Jia Yu <[email protected]>
AuthorDate: Thu May 11 00:36:11 2023 -0700

    Add a working solution
---
 .../spark/sql/sedona_sql/io/HadoopUtils.scala      | 107 -------------
 .../sedona_sql/io/{ => raster}/GeotiffSchema.scala |  90 ++++++-----
 .../io/{ => raster}/ImageReadOptions.scala         |   2 +-
 .../io/{ => raster}/ImageWriteOptions.scala        |   2 +-
 .../sedona_sql/io/raster/RasterFileFormat.scala    | 166 +++++++++++++++++++++
 .../RasterOptions.scala}                           |  18 +--
 ...org.apache.spark.sql.sources.DataSourceRegister |   5 +-
 .../io/{ => raster}/GeotiffFileFormat.scala        |   5 +-
 .../scala/org/apache/sedona/sql/rasterIOTest.scala |  53 ++++++-
 ...org.apache.spark.sql.sources.DataSourceRegister |   5 +-
 .../io/{ => raster}/GeotiffFileFormat.scala        |   0
 .../scala/org/apache/sedona/sql/rasterIOTest.scala |  53 ++++++-
 12 files changed, 318 insertions(+), 188 deletions(-)

diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala
deleted file mode 100644
index 54c5377f..00000000
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/HadoopUtils.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.sql.sedona_sql.io
-
-import org.apache.commons.io.FilenameUtils
-import org.apache.hadoop.conf.{Configuration, Configured}
-import org.apache.hadoop.fs.{Path, PathFilter}
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.spark.sql.SparkSession
-
-import scala.language.existentials
-import scala.util.Random
-
-object RecursiveFlag {
-
-  /** Sets a value of spark recursive flag
-   *
-   * @param value value to set
-   * @param spark existing spark session
-   * @return previous value of this flag
-   */
-  def setRecursiveFlag(value: Option[String], spark: SparkSession): 
Option[String] = {
-    val flagName = FileInputFormat.INPUT_DIR_RECURSIVE
-    val hadoopConf = spark.sparkContext.hadoopConfiguration
-    val old = Option(hadoopConf.get(flagName))
-
-    value match {
-      case Some(v) => hadoopConf.set(flagName, v)
-      case None => hadoopConf.unset(flagName)
-    }
-
-    old
-  }
-}
-
-
-/** Filter that allows loading a fraction of HDFS files. */
-class SamplePathFilter extends Configured with PathFilter {
-  val random = {
-    val rd = new Random()
-    rd.setSeed(0)
-    rd
-  }
-
-  // Ratio of files to be read from disk
-  var sampleRatio: Double = 1
-
-  override def setConf(conf: Configuration): Unit = {
-    if (conf != null) {
-      sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1)
-    }
-  }
-
-  override def accept(path: Path): Boolean = {
-    // Note: checking fileSystem.isDirectory is very slow here, so we use 
basic rules instead
-    !SamplePathFilter.isFile(path) ||
-      random.nextDouble() < sampleRatio
-  }
-}
-
-object SamplePathFilter {
-  val ratioParam = "sampleRatio"
-
-  def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) 
!= ""
-
-  /** Set/unset  hdfs PathFilter
-   *
-   * @param value       Filter class that is passed to HDFS
-   * @param sampleRatio Fraction of the files that the filter picks
-   * @param spark       Existing Spark session
-   * @return
-   */
-  def setPathFilter(value: Option[Class[_]], sampleRatio: Option[Double] = 
None, spark: SparkSession)
-  : Option[Class[_]] = {
-    val flagName = FileInputFormat.PATHFILTER_CLASS
-    val hadoopConf = spark.sparkContext.hadoopConfiguration
-    val old = Option(hadoopConf.getClass(flagName, null))
-    if (sampleRatio.isDefined) {
-      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio.get)
-    } else {
-      hadoopConf.unset(SamplePathFilter.ratioParam)
-      None
-    }
-
-    value match {
-      case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
-      case None => hadoopConf.unset(flagName)
-    }
-    old
-  }
-}
\ No newline at end of file
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala
similarity index 85%
rename from 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
rename to 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala
index 5a3a3595..90c0ec55 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffSchema.scala
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.io
+package org.apache.spark.sql.sedona_sql.io.raster
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
@@ -38,8 +38,8 @@ object GeotiffSchema {
   val undefinedImageType = "Undefined"
 
   /**
-   * Schema for the image column: Row(String,Geometry, Int, Int, Int, 
Array[Double])
-   */
+    * Schema for the image column: Row(String,Geometry, Int, Int, Int, 
Array[Double])
+    */
   val columnSchema = StructType(
     StructField("origin", StringType, true) ::
       StructField("geometry", StringType, true) ::
@@ -51,73 +51,72 @@ object GeotiffSchema {
   val imageFields: Array[String] = columnSchema.fieldNames
 
   /**
-   * DataFrame with a single column of images named "image" (nullable)
-   */
+    * DataFrame with a single column of images named "image" (nullable)
+    */
   val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil)
 
   /**
-   * Gets the origin of the image
-   *
-   * @return The origin of the image
-   */
+    * Gets the origin of the image
+    *
+    * @return The origin of the image
+    */
   def getOrigin(row: Row): String = row.getString(0)
 
   /**
-   * Gets the origin of the image
-   *
-   * @return The origin of the image
-   */
+    * Gets the origin of the image
+    *
+    * @return The origin of the image
+    */
   def getGeometry(row: Row): GeometryUDT = row.getAs[GeometryUDT](1)
 
 
   /**
-   * Gets the height of the image
-   *
-   * @return The height of the image
-   */
+    * Gets the height of the image
+    *
+    * @return The height of the image
+    */
   def getHeight(row: Row): Int = row.getInt(2)
 
   /**
-   * Gets the width of the image
-   *
-   * @return The width of the image
-   */
+    * Gets the width of the image
+    *
+    * @return The width of the image
+    */
   def getWidth(row: Row): Int = row.getInt(3)
 
   /**
-   * Gets the number of channels in the image
-   *
-   * @return The number of bands in the image
-   */
+    * Gets the number of channels in the image
+    *
+    * @return The number of bands in the image
+    */
   def getNBands(row: Row): Int = row.getInt(4)
 
 
   /**
-   * Gets the image data
-   *
-   * @return The image data
-   */
+    * Gets the image data
+    *
+    * @return The image data
+    */
   def getData(row: Row): Array[Double] = row.getAs[Array[Double]](5)
 
   /**
-   * Default values for the invalid image
-   *
-   * @param origin Origin of the invalid image
-   * @return Row with the default values
-   */
+    * Default values for the invalid image
+    *
+    * @param origin Origin of the invalid image
+    * @return Row with the default values
+    */
   private[io] def invalidImageRow(origin: String): Row =
     Row(Row(origin, -1, -1, -1, Array.ofDim[Byte](0)))
 
   /**
-   *
-   * Convert a GeoTiff image into a dataframe row
-   *
-   *
-   * @param origin Arbitrary string that identifies the image
-   * @param bytes  Image bytes (for example, jpeg)
-   * @return DataFrame Row or None (if the decompression fails)
-   *
-   */
+    *
+    * Convert a GeoTiff image into a dataframe row
+    *
+    * @param origin Arbitrary string that identifies the image
+    * @param bytes  Image bytes (for example, jpeg)
+    * @return DataFrame Row or None (if the decompression fails)
+    *
+    */
 
   private[io] def decode(origin: String, bytes: Array[Byte], 
imageSourceOptions: ImageReadOptions): Option[Row] = {
 
@@ -215,8 +214,3 @@ object GeotiffSchema {
     Some(Row(Row(origin, polygon.toText, height, width, nBands, decoded)))
   }
 }
-
-
-
-
-
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala
similarity index 97%
rename from 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
rename to 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala
index f73fc7cf..552b8f8e 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageReadOptions.scala
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.io
+package org.apache.spark.sql.sedona_sql.io.raster
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala
similarity index 96%
copy from 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
copy to 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala
index 8653c93a..6a730faa 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/ImageWriteOptions.scala
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.io
+package org.apache.spark.sql.sedona_sql.io.raster
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala
new file mode 100644
index 00000000..54bf05d4
--- /dev/null
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.spark.sql.sedona_sql.io.raster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FSDataOutputStream, FileStatus, Path}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.sedona.common.raster.Serde
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, 
OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.types.StructType
+import org.geotools.gce.arcgrid.ArcGridWriter
+import org.geotools.gce.geotiff.GeoTiffWriter
+import org.opengis.coverage.grid.GridCoverageWriter
+
+import java.io.IOException
+import java.nio.file.Paths
+import java.util.UUID
+
+private[spark] class RasterFileFormat extends FileFormat with 
DataSourceRegister {
+
+  override def inferSchema(
+                            sparkSession: SparkSession,
+                            options: Map[String, String],
+                            files: Seq[FileStatus]): Option[StructType] = None
+
+  override def prepareWrite(
+                             sparkSession: SparkSession,
+                             job: Job,
+                             options: Map[String, String],
+                             dataSchema: StructType): OutputWriterFactory = {
+    val rasterOptions = new RasterOptions(options)
+    if (!isValidRasterSchema(dataSchema)) {
+      throw new IllegalArgumentException("Invalid GeoTiff Schema")
+    }
+
+    new OutputWriterFactory {
+      override def getFileExtension(context: TaskAttemptContext): String = ""
+
+      override def newInstance(path: String, dataSchema: StructType, context: 
TaskAttemptContext): OutputWriter = {
+        new RasterFileWriter(path, rasterOptions, dataSchema, context)
+      }
+    }
+  }
+
+  override def shortName(): String = "raster"
+
+  override protected def buildReader(
+                                      sparkSession: SparkSession,
+                                      dataSchema: StructType,
+                                      partitionSchema: StructType,
+                                      requiredSchema: StructType,
+                                      filters: Seq[Filter],
+                                      options: Map[String, String],
+                                      hadoopConf: Configuration): 
(PartitionedFile) => Iterator[InternalRow] = {
+    throw new UnsupportedOperationException("Please use Binary data source to 
read raster files")
+  }
+
+  private def isValidRasterSchema(dataSchema: StructType): Boolean = {
+    var imageColExist: Boolean = false
+    val fields = dataSchema.fields
+    fields.foreach(field => {
+      if (field.dataType.typeName.equals("raster")) {
+        imageColExist = true
+      }
+    })
+    imageColExist
+  }
+
+}
+
+// class for writing raster images
+private class RasterFileWriter(savePath: String,
+                               rasterOptions: RasterOptions,
+                                dataSchema: StructType,
+                                context: TaskAttemptContext) extends 
OutputWriter {
+
+  private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
+
+  override def write(row: InternalRow): Unit = {
+    val rowFields: InternalRow = row
+    val schemaFields: StructType = dataSchema
+    var imageColIndex = -1
+    for (i <- schemaFields.indices) {
+      if (schemaFields.fields(i).dataType.typeName.equals("raster")) {
+        imageColIndex = i
+      }
+    }
+    // Get grid coverage 2D from the row
+    val rasterRaw = rowFields.getBinary(imageColIndex)
+    // If the raster is null, return
+    if (rasterRaw == null) return
+    // If the raster is not null, deserialize it
+    val gridCoverage2D = Serde.deserialize(rasterRaw)
+    var writer:GridCoverageWriter = null
+    var out:FSDataOutputStream = null
+    if (rasterOptions.rasterFormat.equalsIgnoreCase("geotiff")) {
+      // If the output path is not provided, generate a random UUID as the 
file name
+      val fileExtension = ".tiff"
+      val rasterFilePath = getRasterFilePath(fileExtension, rowFields, 
schemaFields, rasterOptions)
+      // create the write path
+      out = hfs.create(new Path(Paths.get(savePath, new 
Path(rasterFilePath).getName).toString))
+      writer = new GeoTiffWriter(out)
+    } else if (rasterOptions.rasterFormat.equalsIgnoreCase("arcgrid")) {
+      val fileExtension = ".asc"
+      val rasterFilePath = getRasterFilePath(fileExtension, rowFields, 
schemaFields, rasterOptions)
+      out = hfs.create(new Path(Paths.get(savePath, new 
Path(rasterFilePath).getName).toString))
+      writer = new ArcGridWriter(out)
+    } else
+      throw new IllegalArgumentException("Invalid raster format")
+
+    // write the image to file
+    try {
+      writer.write(gridCoverage2D)
+      writer.dispose()
+      out.close()
+    } catch {
+      case e@(_: IllegalArgumentException | _: IOException) =>
+        // TODO Auto-generated catch block
+        e.printStackTrace()
+    }
+  }
+
+  override def close(): Unit = {
+    hfs.close()
+  }
+
+  def path(): String = {
+    savePath
+  }
+
+  private def getRasterFilePath(fileExtension: String, row: InternalRow, 
schema: StructType, rasterOptions: RasterOptions): String = {
+    // If the output path is not provided, generate a random UUID as the file 
name
+    var rasterFilePath = UUID.randomUUID().toString
+    if (rasterOptions.rasterPathField.isDefined) {
+      val rasterFilePathRaw = 
row.getString(schema.fieldIndex(rasterOptions.rasterPathField.get))
+      // If the output path field is provided, but the value is null, generate 
a random UUID as the file name
+      if (rasterFilePathRaw != null) {
+        // remove the extension if exists
+        if (rasterFilePathRaw.contains(".")) rasterFilePath = 
rasterFilePathRaw.substring(0, rasterFilePathRaw.lastIndexOf("."))
+        else rasterFilePath = rasterFilePathRaw
+      }
+    }
+    rasterFilePath + fileExtension
+  }
+}
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
similarity index 57%
rename from 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
rename to 
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
index 8653c93a..518dca65 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterOptions.scala
@@ -16,21 +16,15 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.io
+package org.apache.spark.sql.sedona_sql.io.raster
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
-private[io] class ImageWriteOptions(@transient private val parameters: 
CaseInsensitiveMap[String]) extends Serializable {
+private[io] class RasterOptions(@transient private val parameters: 
CaseInsensitiveMap[String]) extends Serializable {
   def this(parameters: Map[String, String]) = 
this(CaseInsensitiveMap(parameters))
 
-  // Optional parameters for writing GeoTiff
-  val writeToCRS = parameters.getOrElse("writeToCRS", "EPSG:4326")
-  val colImage = parameters.getOrElse("fieldImage", "image")
-  val colOrigin = parameters.getOrElse("fieldOrigin", "origin")
-  val colBands = parameters.getOrElse("fieldNBands", "nBands")
-  val colWidth = parameters.getOrElse("fieldWidth", "width")
-  val colHeight = parameters.getOrElse("fieldHeight", "height")
-  val colGeometry = parameters.getOrElse("fieldGeometry", "geometry")
-  val colData = parameters.getOrElse("fieldData", "data")
-
+  // Optional parameters for writing rasters to different image formats
+  val rasterFormat = parameters.getOrElse("rasterType", "geotiff")
+  // Column of the raster image name
+  val rasterPathField = parameters.get("pathField")
 }
\ No newline at end of file
diff --git 
a/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
 
b/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 68ea723a..4352e818 100644
--- 
a/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ 
b/sql/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,2 +1,3 @@
-org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat
-org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
\ No newline at end of file
+org.apache.spark.sql.sedona_sql.io.raster.GeotiffFileFormat
+org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
+org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat
\ No newline at end of file
diff --git 
a/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
 
b/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala
similarity index 99%
rename from 
sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
rename to 
sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala
index 842e28f3..f3360ae3 100644
--- 
a/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
+++ 
b/sql/spark-3.0/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala
@@ -18,21 +18,20 @@
  */
 
 
-package org.apache.spark.sql.sedona_sql.io
+package org.apache.spark.sql.sedona_sql.io.raster
 
 import com.google.common.io.{ByteStreams, Closeables}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.sedona.sql.utils.GeometrySerializer
-import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, 
OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
 import org.geotools.coverage.CoverageFactoryFinder
diff --git 
a/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala 
b/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 7206ac38..6b52f05e 100644
--- a/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/sql/spark-3.0/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -19,15 +19,19 @@
 
 package org.apache.sedona.sql
 
+import org.apache.commons.io.FileUtils
+import org.apache.spark.sql.SaveMode
 import org.locationtech.jts.geom.Geometry
 import org.scalatest.{BeforeAndAfter, GivenWhenThen}
 
 import java.io.File
+import java.nio.file.Files
 import scala.collection.mutable
 
 class rasterIOTest extends TestBaseScala with BeforeAndAfter with 
GivenWhenThen {
 
   var rasterdatalocation: String = resourceFolder + "raster/"
+  val tempDir: String = 
Files.createTempDirectory("sedona_raster_io_test_").toFile.getAbsolutePath
 
   describe("Raster IO test") {
     it("Should Pass geotiff loading without readFromCRS and readToCRS") {
@@ -158,7 +162,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing with coalesce") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath)
 
       var loadPath = savePath
@@ -185,7 +189,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing with writeToCRS") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       
df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", 
"EPSG:4499").save(savePath)
 
       var loadPath = savePath
@@ -212,7 +216,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing without coalesce") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       df.write.mode("overwrite").format("geotiff").save(savePath)
 
       var imageCount = 0
@@ -347,11 +351,48 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
         }
       }
     }
-    
-  }
-}
 
+    it("should read geotiff using binary source and write geotiff back to disk 
using raster source") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "length")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should read and write geotiff using given options") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").option("rasterType", 
"geotiff").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should read geotiff and write asc") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").option("rasterType", 
"arcgrid").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + 
"/asc-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/asc-written/*")
+      rasterDf = df.selectExpr("RS_FromArcInfoAsciiGrid(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should handle null") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(null)", "length")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterCount == 3)
+      assert(rasterDf.count() == 0)
+    }
+  }
 
+  override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir))
+}
\ No newline at end of file
diff --git 
a/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
 
b/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 68ea723a..4352e818 100644
--- 
a/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ 
b/sql/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,2 +1,3 @@
-org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat
-org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
\ No newline at end of file
+org.apache.spark.sql.sedona_sql.io.raster.GeotiffFileFormat
+org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
+org.apache.spark.sql.sedona_sql.io.raster.RasterFileFormat
\ No newline at end of file
diff --git 
a/sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
 
b/sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala
similarity index 100%
rename from 
sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
rename to 
sql/spark-3.4/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/GeotiffFileFormat.scala
diff --git 
a/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala 
b/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 7206ac38..6b52f05e 100644
--- a/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/sql/spark-3.4/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -19,15 +19,19 @@
 
 package org.apache.sedona.sql
 
+import org.apache.commons.io.FileUtils
+import org.apache.spark.sql.SaveMode
 import org.locationtech.jts.geom.Geometry
 import org.scalatest.{BeforeAndAfter, GivenWhenThen}
 
 import java.io.File
+import java.nio.file.Files
 import scala.collection.mutable
 
 class rasterIOTest extends TestBaseScala with BeforeAndAfter with 
GivenWhenThen {
 
   var rasterdatalocation: String = resourceFolder + "raster/"
+  val tempDir: String = 
Files.createTempDirectory("sedona_raster_io_test_").toFile.getAbsolutePath
 
   describe("Raster IO test") {
     it("Should Pass geotiff loading without readFromCRS and readToCRS") {
@@ -158,7 +162,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing with coalesce") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath)
 
       var loadPath = savePath
@@ -185,7 +189,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing with writeToCRS") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       
df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", 
"EPSG:4499").save(savePath)
 
       var loadPath = savePath
@@ -212,7 +216,7 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
     it("Should Pass geotiff file writing without coalesce") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", 
true).load(rasterdatalocation)
       df = df.selectExpr("image.origin as origin","image.geometry as 
geometry", "image.height as height", "image.width as width", "image.data as 
data", "image.nBands as nBands")
-      val savePath = resourceFolder + "raster-written/"
+      val savePath = tempDir + "/raster-written/"
       df.write.mode("overwrite").format("geotiff").save(savePath)
 
       var imageCount = 0
@@ -347,11 +351,48 @@ class rasterIOTest extends TestBaseScala with 
BeforeAndAfter with GivenWhenThen
         }
       }
     }
-    
-  }
-}
 
+    it("should read geotiff using binary source and write geotiff back to disk 
using raster source") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "length")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should read and write geotiff using given options") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").option("rasterType", 
"geotiff").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should read geotiff and write asc") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(content)", "path")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").option("rasterType", 
"arcgrid").option("pathField", "path").mode(SaveMode.Overwrite).save(tempDir + 
"/asc-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/asc-written/*")
+      rasterDf = df.selectExpr("RS_FromArcInfoAsciiGrid(content)")
+      assert(rasterDf.count() == rasterCount)
+    }
 
+    it("should handle null") {
+      var df = sparkSession.read.format("binaryFile").load(rasterdatalocation)
+      var rasterDf = df.selectExpr("RS_FromGeoTiff(null)", "length")
+      val rasterCount = rasterDf.count()
+      rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + 
"/geotiff-written")
+      df = sparkSession.read.format("binaryFile").load(tempDir + 
"/geotiff-written/*")
+      rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
+      assert(rasterCount == 3)
+      assert(rasterDf.count() == 0)
+    }
+  }
 
+  override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir))
+}
\ No newline at end of file


Reply via email to