This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new b9cd411d9d [SEDONA-704] Add support to load items directly from catalog without using collection id and fix limit issues (#1839)
b9cd411d9d is described below
commit b9cd411d9dcb9fb03cc3dcdeb3a630823deb04df
Author: Feng Zhang <[email protected]>
AuthorDate: Tue Mar 4 17:48:36 2025 -0800
[SEDONA-704] Add support to load items directly from catalog without using collection id and fix limit issues (#1839)
* [SEDONA-704] Add support to load items directly from catalog without using collection id and fix limit issues
- Add support to load items directly from catalog without using collection id (see the usage sketch after the change notes below). Also fixed some minor issues with the maximum number of items returned.
- This commit fixes the STAC reader issue that occurs when the collection endpoint limits the number of items returned.
For some STAC endpoints, when a single items JSON file contains too many items, the server limits the number of items returned per call and adds a "next" link so users can load the next "chunk" of items if they want.
* update document
* fix compiling errors
* two more fixes
* revert StacPartitionReader
* revert StacBatchTest
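For context, the catalog-level loading added here is exercised from Python in the new `test_search_with_catalog_url` test. Below is a minimal usage sketch, not part of the patch; it assumes `Client` is importable from `sedona.stac.client` (the module this commit modifies) and reuses the Earthview catalog URL added to the test suite:

```python
from sedona.stac.client import Client  # module updated by this commit

# Open a STAC catalog directly; no collection id is required anymore.
client = Client.open(
    "https://satellogic-earthview.s3.us-west-2.amazonaws.com/stac/catalog.json"
)

# With collection_id omitted, search() now falls back to
# get_collection_from_catalog() and loads items from the whole catalog.
df = client.search(return_dataframe=True)
df.show()
```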
---
docs/api/sql/Stac.md | 42 ++++++++
python/sedona/stac/client.py | 21 +++-
python/tests/stac/test_client.py | 15 ++-
python/tests/stac/test_collection_client.py | 8 +-
python/tests/test_base.py | 11 +-
.../spark/sql/sedona_sql/io/stac/StacBatch.scala | 113 ++++++++++++++++++++-
.../sql/sedona_sql/io/stac/StacDataSource.scala | 11 +-
.../spark/sql/sedona_sql/io/stac/StacUtils.scala | 29 +++++-
.../sedona_sql/io/stac/StacDataSourceTest.scala | 8 +-
.../sql/sedona_sql/io/stac/StacTableTest.scala | 2 +-
10 files changed, 230 insertions(+), 30 deletions(-)
diff --git a/docs/api/sql/Stac.md b/docs/api/sql/Stac.md
index bf91d4c965..8d56644e5e 100644
--- a/docs/api/sql/Stac.md
+++ b/docs/api/sql/Stac.md
@@ -150,6 +150,48 @@ In this example, the data source will push down the temporal filter to the under
In this example, the data source will push down the spatial filter to the underlying data source.
+## Sedona Configuration for STAC Reader
+
+When using the STAC reader in Sedona, several configuration options can be set to control the behavior of the reader. These configurations are typically set in a `Map[String, String]` and passed to the reader. Below are the key Sedona configuration options:
+
+- **spark.sedona.stac.load.maxPartitionItemFiles**: This option specifies the maximum number of item files that can be included in a single partition. It helps in controlling the size of partitions. The default value is set to -1, meaning the system will automatically determine the number of item files per partition.
+
+- **spark.sedona.stac.load.numPartitions**: This option sets the number of partitions to be created for the STAC data. It allows for better control over data distribution and parallel processing. The default value is set to -1, meaning the system will automatically determine the number of partitions.
+
+Below are reader options that can be set to control the behavior of the STAC reader:
+
+- **itemsLimitMax**: This option specifies the maximum number of items to be loaded from the STAC collection. It helps in limiting the amount of data processed. The default value is set to -1, meaning all items will be loaded.
+
+- **itemsLoadProcessReportThreshold**: This option specifies the threshold for reporting the progress of item loading. It helps in monitoring the progress of the loading process. The default value is set to 1000000, meaning the progress will be reported every 1,000,000 items loaded.
+
+- **itemsLimitPerRequest**: This option specifies the maximum number of items to be requested in a single API call. It helps in controlling the size of each request. The default value is set to 10.
+
+These configurations can be combined into a single `Map[String, String]` and passed to the STAC reader as shown below:
+
+```scala
+ def defaultSparkConfig: Map[String, String] = Map(
+ "spark.sedona.stac.load.maxPartitionItemFiles" -> "100",
+ "spark.sedona.stac.load.numPartitions" -> "10",
+ "spark.sedona.stac.load.itemsLimitMax" -> "20")
+
+ val sparkSession: SparkSession = {
+ val builder = SedonaContext
+ .builder()
+ .master("local[*]")
+    defaultSparkConfig.foreach { case (key, value) => builder.config(key, value) }
+ builder.getOrCreate()
+ }
+
+  val df = sparkSession.read
+ .format("stac")
+ .option("itemsLimitMax", "100")
+ .option("itemsLoadProcessReportThreshold", "2000000")
+ .option("itemsLimitPerRequest", "100")
+    .load("https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a")
+```
+
+The options above provide fine-grained control over how the STAC data is read and processed in Sedona.
+
# Python API
The Python API allows you to interact with a SpatioTemporal Asset Catalog (STAC) API using the Client class. This class provides methods to open a connection to a STAC API, retrieve collections, and search for items with various filters.
diff --git a/python/sedona/stac/client.py b/python/sedona/stac/client.py
index 3e8eeacefa..50ce0afc6c 100644
--- a/python/sedona/stac/client.py
+++ b/python/sedona/stac/client.py
@@ -60,10 +60,22 @@ class Client:
"""
return CollectionClient(self.url, collection_id)
+ def get_collection_from_catalog(self):
+ """
+ Retrieves the catalog from the STAC API.
+
+        This method fetches the root catalog from the STAC API, providing access to all collections and items.
+
+ Returns:
+        - CollectionClient: A collection client bound to the root catalog (no collection id).
+ """
+        # A CollectionClient with no collection id loads items from the entire catalog
+ return CollectionClient(self.url, None)
+
def search(
self,
*ids: Union[str, list],
- collection_id: str,
+ collection_id: Optional[str] = None,
bbox: Optional[list] = None,
datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
max_items: Optional[int] = None,
@@ -76,7 +88,7 @@ class Client:
- ids (Union[str, list]): A variable number of item IDs to filter the items.
Example: "item_id1" or ["item_id1", "item_id2"]
- - collection_id (str): The ID of the collection to search in.
+ - collection_id (Optional[str]): The ID of the collection to search in.
Example: "aster-l1t"
- bbox (Optional[list]): A list of bounding boxes for filtering the items.
@@ -101,7 +113,10 @@ class Client:
Returns:
- Union[Iterator[PyStacItem], DataFrame]: An iterator of PyStacItem objects or a Spark DataFrame that match the specified filters.
"""
- client = self.get_collection(collection_id)
+ if collection_id:
+ client = self.get_collection(collection_id)
+ else:
+ client = self.get_collection_from_catalog()
if return_dataframe:
return client.get_dataframe(
*ids, bbox=bbox, datetime=datetime, max_items=max_items
diff --git a/python/tests/stac/test_client.py b/python/tests/stac/test_client.py
index 5c6192258a..4f9919e1c5 100644
--- a/python/tests/stac/test_client.py
+++ b/python/tests/stac/test_client.py
@@ -21,7 +21,8 @@ from pyspark.sql import DataFrame
from tests.test_base import TestBase
STAC_URLS = {
- "PLANETARY-COMPUTER": "https://planetarycomputer.microsoft.com/api/stac/v1"
+    "PLANETARY-COMPUTER": "https://planetarycomputer.microsoft.com/api/stac/v1",
+    "EARTHVIEW-CATALOG": "https://satellogic-earthview.s3.us-west-2.amazonaws.com/stac/catalog.json",
}
@@ -95,7 +96,7 @@ class TestStacClient(TestBase):
return_dataframe=False,
)
assert items is not None
- assert len(list(items)) == 10
+ assert len(list(items)) == 20
def test_search_with_max_items(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -133,7 +134,7 @@ class TestStacClient(TestBase):
return_dataframe=False,
)
assert items is not None
- assert len(list(items)) == 10
+ assert len(list(items)) == 20
def test_search_with_return_dataframe(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -144,3 +145,11 @@ class TestStacClient(TestBase):
)
assert df is not None
assert isinstance(df, DataFrame)
+
+ def test_search_with_catalog_url(self) -> None:
+ client = Client.open(STAC_URLS["EARTHVIEW-CATALOG"])
+ df = client.search(
+ return_dataframe=True,
+ )
+ assert df is not None
+ assert isinstance(df, DataFrame)
diff --git a/python/tests/stac/test_collection_client.py b/python/tests/stac/test_collection_client.py
index c30105a4eb..24226f86ca 100644
--- a/python/tests/stac/test_collection_client.py
+++ b/python/tests/stac/test_collection_client.py
@@ -38,7 +38,7 @@ class TestStacReader(TestBase):
collection = client.get_collection("aster-l1t")
df = collection.get_dataframe()
assert df is not None
- assert df.count() == 10
+ assert df.count() == 20
def test_get_dataframe_with_spatial_extent(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -79,7 +79,7 @@ class TestStacReader(TestBase):
datetime = [["2006-12-01T00:00:00Z", "2006-12-27T02:00:00Z"]]
items = list(collection.get_items(datetime=datetime))
assert items is not None
- assert len(items) == 6
+ assert len(items) == 16
def test_get_items_with_both_extents(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -140,7 +140,7 @@ class TestStacReader(TestBase):
]
items = list(collection.get_items(bbox=bbox, datetime=datetime))
assert items is not None
- assert len(items) == 10
+ assert len(items) == 20
def test_get_items_with_bbox_and_interval(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -186,4 +186,4 @@ class TestStacReader(TestBase):
# Optionally, you can load the file back and check its contents
df_loaded = collection.spark.read.format("geoparquet").load(output_path)
- assert df_loaded.count() == 10, "Loaded GeoParquet file is empty"
+ assert df_loaded.count() == 20, "Loaded GeoParquet file is empty"
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index 2769a93cdd..2582b9d4bf 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -42,7 +42,7 @@ class TestBase:
if "SPARK_HOME" in os.environ and not os.environ["SPARK_HOME"]:
del os.environ["SPARK_HOME"]
- builder = SedonaContext.builder()
+ builder = SedonaContext.builder().appName("SedonaSparkTest")
if SPARK_REMOTE:
builder = (
builder.remote(SPARK_REMOTE)
@@ -54,9 +54,16 @@ class TestBase:
"spark.sql.extensions",
"org.apache.sedona.sql.SedonaSqlExtensions",
)
+ .config(
+ "spark.sedona.stac.load.itemsLimitMax",
+ "20",
+ )
)
else:
- builder = builder.master("local[*]")
+ builder = builder.master("local[*]").config(
+ "spark.sedona.stac.load.itemsLimitMax",
+ "20",
+ )
# Allows the Sedona .jar to be explicitly set by the caller (e.g, to run
# pytest against a freshly-built development version of Sedona)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
index 98cb35ee07..1de8518de7 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
@@ -29,6 +29,7 @@ import java.time.LocalDateTime
import java.time.format.DateTimeFormatterBuilder
import java.time.temporal.ChronoField
import scala.jdk.CollectionConverters._
+import scala.util.control.Breaks.breakable
/**
* The `StacBatch` class represents a batch of partitions for reading data in the SpatioTemporal
@@ -47,8 +48,24 @@ case class StacBatch(
temporalFilter: Option[TemporalFilter])
extends Batch {
+  private val defaultItemsLimitPerRequest = opts.getOrElse("itemsLimitPerRequest", "10").toInt
+ private val itemsLoadProcessReportThreshold =
+ opts.getOrElse("itemsLoadProcessReportThreshold", "1000000").toInt
+ private var itemMaxLeft: Int = -1
+ private var lastReportCount: Int = 0
+
val mapper = new ObjectMapper()
+ /**
+ * Sets the maximum number of items left to process.
+ *
+ * @param value
+ * The maximum number of items left.
+ */
+ def setItemMaxLeft(value: Int): Unit = {
+ itemMaxLeft = value
+ }
+
/**
* Plans the input partitions for reading data from the STAC data source.
*
@@ -62,7 +79,10 @@ case class StacBatch(
val itemLinks = scala.collection.mutable.ArrayBuffer[String]()
// Start the recursive collection of item links
- collectItemLinks(stacCollectionBasePath, stacCollectionJson, itemLinks)
+ val itemsLimitMax = opts.getOrElse("itemsLimitMax", "-1").toInt
+ val checkItemsLimitMax = itemsLimitMax > 0
+ setItemMaxLeft(itemsLimitMax)
+    collectItemLinks(stacCollectionBasePath, stacCollectionJson, itemLinks, checkItemsLimitMax)
// Handle when the number of items is less than 1
if (itemLinks.isEmpty) {
@@ -106,16 +126,78 @@ case class StacBatch(
* @param itemLinks
* The list of item links to populate.
*/
- private def collectItemLinks(
+ def collectItemLinks(
collectionBasePath: String,
collectionJson: String,
- itemLinks: scala.collection.mutable.ArrayBuffer[String]): Unit = {
+ itemLinks: scala.collection.mutable.ArrayBuffer[String],
+ needCountNextItems: Boolean): Unit = {
+
+ // end early if there are no more items to process
+ if (needCountNextItems && itemMaxLeft <= 0) return
+
+ if (itemLinks.size - lastReportCount >= itemsLoadProcessReportThreshold) {
+      Console.out.println(s"Searched or partitioned ${itemLinks.size} items so far.")
+ lastReportCount = itemLinks.size
+ }
+
// Parse the JSON string into a JsonNode (tree representation of JSON)
val rootNode: JsonNode = mapper.readTree(collectionJson)
// Extract item links from the "links" array
val linksNode = rootNode.get("links")
val iterator = linksNode.elements()
+
+    def iterateItemsWithLimit(itemUrl: String, needCountNextItems: Boolean): Boolean = {
+ // Load the item URL and process the response
+ var nextUrl: Option[String] = Some(itemUrl)
+ breakable {
+ while (nextUrl.isDefined) {
+ val itemJson = StacUtils.loadStacCollectionToJson(nextUrl.get)
+ val itemRootNode = mapper.readTree(itemJson)
+ // Check if there exists a "next" link
+ val itemLinksNode = itemRootNode.get("links")
+ if (itemLinksNode == null) {
+ return true
+ }
+ val itemIterator = itemLinksNode.elements()
+ nextUrl = None
+ while (itemIterator.hasNext) {
+ val itemLinkNode = itemIterator.next()
+ val itemRel = itemLinkNode.get("rel").asText()
+ val itemHref = itemLinkNode.get("href").asText()
+ if (itemRel == "next") {
+              // Only check the number of items returned if there are more items to process
+ val numberReturnedNode = itemRootNode.get("numberReturned")
+ val numberReturned = if (numberReturnedNode == null) {
+ // From STAC API Spec:
+ // The optional limit parameter limits the number of
+ // items that are presented in the response document.
+ // The default value is 10.
+ defaultItemsLimitPerRequest
+ } else {
+ numberReturnedNode.asInt()
+ }
+ // count the number of items returned and left to be processed
+ itemMaxLeft = itemMaxLeft - numberReturned
+ // early exit if there are no more items to process
+ if (needCountNextItems && itemMaxLeft <= 0) {
+ return true
+ }
+              nextUrl = Some(if (itemHref.startsWith("http") || itemHref.startsWith("file")) {
+ itemHref
+ } else {
+ collectionBasePath + itemHref
+ })
+ }
+ }
+ if (nextUrl.isDefined) {
+ itemLinks += nextUrl.get
+ }
+ }
+ }
+ false
+ }
+
while (iterator.hasNext) {
val linkNode = iterator.next()
val rel = linkNode.get("rel").asText()
@@ -129,7 +211,24 @@ case class StacBatch(
} else {
collectionBasePath + href
}
- itemLinks += itemUrl // Add the item URL to the list
+ if (rel == "items" && href.startsWith("http")) {
+ itemLinks += (itemUrl + "?limit=" + defaultItemsLimitPerRequest)
+ } else {
+ itemLinks += itemUrl
+ }
+ if (needCountNextItems && itemMaxLeft <= 0) {
+ return
+ } else {
+ if (rel == "item" && needCountNextItems) {
+ // count the number of items returned and left to be processed
+ itemMaxLeft = itemMaxLeft - 1
+ } else if (rel == "items" && href.startsWith("http")) {
+          // iterate through the items and check if the limit is reached (if needed)
+ if (iterateItemsWithLimit(
+ itemUrl + "?limit=" + defaultItemsLimitPerRequest,
+ needCountNextItems)) return
+ }
+ }
} else if (rel == "child") {
val childUrl = if (href.startsWith("http") || href.startsWith("file")) {
href
@@ -143,7 +242,11 @@ case class StacBatch(
filterCollection(linkedCollectionJson, spatialFilter, temporalFilter)
if (!collectionFiltered) {
- collectItemLinks(nestedCollectionBasePath, linkedCollectionJson, itemLinks)
+ collectItemLinks(
+ nestedCollectionBasePath,
+ linkedCollectionJson,
+ itemLinks,
+ needCountNextItems)
}
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
index ac64b8393b..dc2dc3e6dd 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
@@ -106,9 +106,16 @@ class StacDataSource() extends TableProvider with DataSourceRegister {
"columnNameOfCorruptRecord" ->
SparkSession.active.sessionState.conf.columnNameOfCorruptRecord,
"defaultParallelism" ->
SparkSession.active.sparkContext.defaultParallelism.toString,
"maxPartitionItemFiles" -> SparkSession.active.conf
- .get("spark.wherobots.stac.load.maxPartitionItemFiles", "0"),
+ .get("spark.sedona.stac.load.maxPartitionItemFiles", "0"),
"numPartitions" -> SparkSession.active.conf
- .get("spark.wherobots.stac.load.numPartitions", "-1"))
+ .get("spark.sedona.stac.load.numPartitions", "-1"),
+ "itemsLimitMax" -> opts
+ .asCaseSensitiveMap()
+ .asScala
+ .toMap
+ .get("itemsLimitMax")
+ .filter(_.toInt > 0)
+        .getOrElse(SparkSession.active.conf.get("spark.sedona.stac.load.itemsLimitMax", "-1")))
val stacCollectionJsonString = StacUtils.loadStacCollectionToJson(optsMap)
new StacTable(stacCollectionJson = stacCollectionJsonString, opts = optsMap)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 4e148422bf..508d6986b5 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -47,12 +47,31 @@ object StacUtils {
}
// Function to load JSON from URL or service
- def loadStacCollectionToJson(url: String): String = {
- if (url.startsWith("s3://") || url.startsWith("s3a://")) {
- SparkSession.active.read.textFile(url).collect().mkString("\n")
- } else {
- Source.fromURL(url).mkString
+ def loadStacCollectionToJson(url: String, maxRetries: Int = 3): String = {
+ var retries = 0
+ var success = false
+ var result: String = ""
+
+ while (retries < maxRetries && !success) {
+ try {
+ result = if (url.startsWith("s3://") || url.startsWith("s3a://")) {
+ SparkSession.active.read.textFile(url).collect().mkString("\n")
+ } else {
+ Source.fromURL(url).mkString
+ }
+ success = true
+ } catch {
+ case e: Exception =>
+ retries += 1
+ if (retries >= maxRetries) {
+ throw new RuntimeException(
+              s"Failed to load STAC collection from $url after $maxRetries attempts",
+ e)
+ }
+ }
}
+
+ result
}
// Function to get the base URL from the collection URL or service
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index a1234ffa11..2cab18b694 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -19,11 +19,8 @@
package org.apache.spark.sql.sedona_sql.io.stac
import org.apache.sedona.sql.TestBaseScala
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType, TimestampType}
-import org.scalatest.BeforeAndAfterAll
-
-import java.util.TimeZone
class StacDataSourceTest extends TestBaseScala {
@@ -35,7 +32,8 @@ class StacDataSourceTest extends TestBaseScala {
"https://storage.googleapis.com/cfo-public/vegetation/collection.json",
"https://storage.googleapis.com/cfo-public/wildfire/collection.json",
"https://earthdatahub.destine.eu/api/stac/v1/collections/copernicus-dem",
- "https://planetarycomputer.microsoft.com/api/stac/v1/collections/naip")
+ "https://planetarycomputer.microsoft.com/api/stac/v1/collections/naip",
+    "https://satellogic-earthview.s3.us-west-2.amazonaws.com/stac/catalog.json")
it("basic df load from local file should work") {
val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTableTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTableTest.scala
index eca7768733..9ca5f424a4 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTableTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTableTest.scala
@@ -19,7 +19,7 @@
package org.apache.spark.sql.sedona_sql.io.stac
import org.apache.spark.sql.sedona_sql.io.stac.StacTable.{SCHEMA_GEOPARQUET, addAssetStruct, addAssetsStruct}
-import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import org.scalatest.funsuite.AnyFunSuite
class StacTableTest extends AnyFunSuite {