This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 4dd212f61f [SEDONA-689] Geostats SQL (#1736)
4dd212f61f is described below
commit 4dd212f61fc770bca435f538ba2b3b461e5046e9
Author: James Willis <[email protected]>
AuthorDate: Thu Jan 2 10:54:37 2025 -0800
[SEDONA-689] Geostats SQL (#1736)
* Geostats SQL
* add missing import statement
---------
Co-authored-by: jameswillis <[email protected]>
---
docs/api/sql/Function.md | 123 ++++++++++
python/sedona/sql/st_functions.py | 183 +++++++++++++++
python/sedona/stats/clustering/dbscan.py | 6 +
.../outlier_detection/local_outlier_factor.py | 3 +
python/sedona/stats/weighting.py | 69 +++++-
python/tests/sql/test_dataframe_api.py | 16 ++
spark/common/pom.xml | 6 +
.../org/apache/sedona/core/utils/SedonaConf.java | 11 +
.../org/apache/sedona/spark/SedonaContext.scala | 15 ++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 8 +-
.../scala/org/apache/sedona/stats/Weighting.scala | 116 +++++++--
.../apache/sedona/stats/clustering/DBSCAN.scala | 21 +-
.../outlierDetection/LocalOutlierFactor.scala | 17 +-
.../sedona_sql/expressions/GeoStatsFunctions.scala | 259 +++++++++++++++++++++
.../sql/sedona_sql/expressions/st_functions.scala | 42 ++++
.../optimization/ExtractGeoStatsFunctions.scala | 120 ++++++++++
.../plans/logical/EvalGeoStatsFunction.scala | 39 ++++
.../geostats/EvalGeoStatsFunctionExec.scala | 41 ++++
.../geostats/EvalGeoStatsFunctionStrategy.scala | 37 +++
.../org/apache/sedona/sql/GeoStatsSuite.scala | 215 +++++++++++++++++
20 files changed, 1311 insertions(+), 36 deletions(-)
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index f1708d927d..8aa8aa4404 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -708,6 +708,29 @@ Output:
32618
```
+## ST_BinaryDistanceBandColumn
+
+Introduction: Returns a `weights` column containing every record in a dataframe within a specified `threshold` distance.
+
+The `weights` column is an array of structs containing the `attributes` from
each neighbor and that neighbor's weight. Since this is a binary distance band
function, weights of neighbors within the threshold will always be
+`1.0`.
+
+Format: `ST_BinaryDistanceBandColumn(geometry: Geometry, threshold: Double, includeZeroDistanceNeighbors: Boolean, includeSelf: Boolean, useSpheroid: Boolean, attributes: Struct)`
+
+Since: `v1.7.1`
+
+SQL Example
+
+```sql
+ST_BinaryDistanceBandColumn(geometry, 1.0, true, true, false, struct(id,
geometry))
+```
+
+Output:
+
+```
+[{{15, POINT (3 1.9)}, 1.0}, {{16, POINT (3 2)}, 1.0}, {{17, POINT (3 2.1)},
1.0}, {{18, POINT (3 2.2)}, 1.0}]
+```
+
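+A minimal DataFrame-API sketch of the same call, assuming a dataframe `df` with `id` and `geometry` columns (it mirrors the Python binding added in this commit, not an official doc example):
+
+```python
+from pyspark.sql import functions as f
+from sedona.sql.st_functions import ST_BinaryDistanceBandColumn
+
+# weights holds each neighbor's (id, geometry) struct with a constant weight of 1.0
+df = df.withColumn(
+    "weights",
+    ST_BinaryDistanceBandColumn(
+        "geometry", f.lit(1.0), True, True, False, f.struct("id", "geometry")
+    ),
+)
+```
+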
## ST_Boundary
Introduction: Returns the closure of the combinatorial boundary of this
Geometry.
@@ -1107,6 +1130,31 @@ true
!!!Warning
For geometries that span more than 180 degrees in longitude without
actually crossing the Date Line, this function may still return true,
indicating a crossing.
+## ST_DBSCAN
+
+Introduction: Performs a DBSCAN clustering across the entire dataframe.
+
+Returns a struct containing a boolean indicating if the record is a core point in the cluster, followed by the cluster ID.
+
+- `epsilon` is the maximum distance between two points for them to be
considered as part of the same cluster.
+- `minPoints` is the minimum number of neighbors a single record must have to
form a cluster.
+
+Format: `ST_DBSCAN(geom: Geometry, epsilon: Double, minPoints: Integer)`
+
+Since: `v1.7.1`
+
+SQL Example
+
+```sql
+SELECT ST_DBSCAN(geom, 1.0, 2)
+```
+
+Output:
+
+```
+{true, 85899345920}
+```
+
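+A minimal DataFrame-API sketch of the same call, assuming a dataframe `df` with a `geometry` column (it mirrors the Python binding and test added in this commit):
+
+```python
+from sedona.sql.st_functions import ST_DBSCAN
+
+# adds a struct column holding the isCore flag and the cluster id
+df = df.withColumn("dbscan", ST_DBSCAN("geometry", 1.0, 2, False))
+```
+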
## ST_Degrees
Introduction: Convert an angle in radian to degrees.
@@ -1874,6 +1922,31 @@ Output:
ST_LINESTRING
```
+## ST_GLocal
+
+Introduction: Runs Getis and Ord's G Local (Gi or Gi*) statistic on the given variable using the provided `weights`; set `star` to true to compute Gi* instead of Gi.
+
+Getis and Ord's Gi and Gi* statistics are used to identify data points with
locally high values (hot spots) and low
+values (cold spots) in a spatial dataset.
+
+The `ST_WeightedDistanceBandColumn` and `ST_BinaryDistanceBandColumn` functions can be used to generate the `weights` column.
+
+Format: `ST_GLocal(x: Numeric, weights: Array, star: Boolean)`
+
+Since: `v1.7.1`
+
+SQL Example
+
+```sql
+ST_GLocal(myVariable, ST_BinaryDistanceBandColumn(geometry, 1.0, true, true,
false, struct(myVariable, geometry)), true)
+```
+
+Output:
+
+```
+{0.5238095238095238, 0.4444444444444444, 0.001049802637104223,
2.4494897427831814, 0.00715293921771476}
+```
+
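+A minimal DataFrame-API sketch of the same pattern, assuming a dataframe `df` with columns `myVariable` and `geometry` (hypothetical names, mirroring the Python bindings added in this commit):
+
+```python
+from pyspark.sql import functions as f
+from sedona.sql.st_functions import ST_BinaryDistanceBandColumn, ST_GLocal
+
+# build the weights column, then compute the Gi* statistics from it
+df = df.withColumn(
+    "weights",
+    ST_BinaryDistanceBandColumn(
+        "geometry", f.lit(1.0), True, True, False, f.struct("myVariable", "geometry")
+    ),
+).withColumn("gi", ST_GLocal("myVariable", "weights", True))
+```
+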
## ST_H3CellDistance
Introduction: return result of h3 function [gridDistance(cel1,
cell2)](https://h3geo.org/docs/api/traversal#griddistance).
@@ -2657,6 +2730,34 @@ Output:
LINESTRING (69.28469348539744 94.28469348539744, 100 125, 111.70035626068274
140.21046313888758)
```
+## ST_LocalOutlierFactor
+
+Introduction: Computes the Local Outlier Factor (LOF) for each point in the
input dataset.
+
+Local Outlier Factor is an algorithm for determining the degree to which a
single record is an inlier or outlier. It is
+based on how close a record is to its `k` nearest neighbors vs how close those
neighbors are to their `k` nearest
+neighbors. Values substantially less than `1` imply that the record is an
inlier, while values greater than `1` imply that
+the record is an outlier.
+
+!!!Note
+    ST_LocalOutlierFactor has a useSphere parameter rather than a useSpheroid parameter. When enabled, this function uses a spherical model of the earth rather than an ellipsoidal model when calculating distances.
+
+Format: `ST_LocalOutlierFactor(geometry: Geometry, k: Int, useSphere: Boolean)`
+
+Since: `v1.7.1`
+
+SQL Example
+
+```sql
+SELECT ST_LocalOutlierFactor(geometry, 5, true)
+```
+
+Output:
+
+```
+1.0009256283408587
+```
+
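+A minimal DataFrame-API sketch of the same call, assuming a dataframe `df` with a `geometry` column (it mirrors the Python binding and test added in this commit):
+
+```python
+from sedona.sql.st_functions import ST_LocalOutlierFactor
+
+# LOF from the 5 nearest neighbors, using the spherical distance model
+df = df.withColumn("lof", ST_LocalOutlierFactor("geometry", 5, True))
+```
+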
## ST_LocateAlong
Introduction: This function computes Point or MultiPoint geometries
representing locations along a measured input geometry (LineString or
MultiLineString) corresponding to the provided measure value(s). Polygonal
geometry inputs are not supported. The output points lie directly on the input
line at the specified measure positions.
@@ -4416,6 +4517,28 @@ Output:
GEOMETRYCOLLECTION(POLYGON((-1 2,2 -1,-1 -1,-1 2)),POLYGON((-1 2,2 2,2 -1,-1
2)))
```
+## ST_WeightedDistanceBandColumn
+
+Introduction: Returns a `weights` column containing every record in a dataframe within a specified `threshold` distance.
+
+The `weights` column is an array of structs containing the `attributes` from each neighbor and that neighbor's weight. Since this is a weighted distance band function, each neighbor's weight is distance^alpha.
+
+Format: `ST_WeightedDistanceBandColumn(geometry: Geometry, threshold: Double, alpha: Double, includeZeroDistanceNeighbors: Boolean, includeSelf: Boolean, selfWeight: Double, useSpheroid: Boolean, attributes: Struct)`
+
+Since: `v1.7.1`
+
+SQL Example
+
+```sql
+ST_WeightedDistanceBandColumn(geometry, 1.0, -1.0, true, true, 1.0, false,
struct(id, geometry))
+```
+
+Output:
+
+```
+[{{15, POINT (3 1.9)}, 1.0}, {{16, POINT (3 2)}, 9.999999999999991}, {{17,
POINT (3 2.1)}, 4.999999999999996}, {{18, POINT (3 2.2)}, 3.3333333333333304}]
+```
+
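+A minimal DataFrame-API sketch of the same call, assuming a dataframe `df` with `id` and `geometry` columns (it mirrors the Python binding added in this commit):
+
+```python
+from pyspark.sql import functions as f
+from sedona.sql.st_functions import ST_WeightedDistanceBandColumn
+
+# inverse-distance weights (alpha = -1.0); self is included with weight 1.0
+df = df.withColumn(
+    "weights",
+    ST_WeightedDistanceBandColumn(
+        "geometry", f.lit(1.0), -1.0, True, True, 1.0, False, f.struct("id", "geometry")
+    ),
+)
+```
+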
## ST_X
Introduction: Returns X Coordinate of given Point null otherwise.
diff --git a/python/sedona/sql/st_functions.py
b/python/sedona/sql/st_functions.py
index f4924c8de0..b3f67fd522 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -20,6 +20,7 @@ from functools import partial
from typing import Optional, Union
from pyspark.sql import Column
+from pyspark.sql.functions import lit
from sedona.sql.dataframe_api import (
ColumnOrName,
@@ -2462,6 +2463,188 @@ def ST_InterpolatePoint(geom1: ColumnOrName, geom2:
ColumnOrName) -> Column:
return _call_st_function("ST_InterpolatePoint", args)
+@validate_argument_types
+def ST_DBSCAN(
+ geometry: ColumnOrName,
+ epsilon: Union[ColumnOrName, float],
+ min_pts: Union[ColumnOrName, int],
+ use_spheroid: Optional[Union[ColumnOrName, bool]] = False,
+) -> Column:
+ """Perform DBSCAN clustering on the given geometry column.
+
+ @param geometry: Geometry column or name
+ :type geometry: ColumnOrName
+    @param epsilon: the maximum distance between two points for them to be considered neighbors
+    :type epsilon: Union[ColumnOrName, float]
+    @param min_pts: the minimum number of neighbors a point must have to form a cluster
+    :type min_pts: Union[ColumnOrName, int]
+    @param use_spheroid: whether to use spheroid for distance calculation
+    :type use_spheroid: Union[ColumnOrName, bool]
+ @return: A struct indicating the cluster to which the point belongs and
whether it is a core point
+ """
+
+ if isinstance(epsilon, float):
+ epsilon = lit(epsilon)
+
+ if isinstance(min_pts, int):
+ min_pts = lit(min_pts)
+
+ if isinstance(use_spheroid, bool):
+ use_spheroid = lit(use_spheroid)
+
+ return _call_st_function("ST_DBSCAN", (geometry, epsilon, min_pts,
use_spheroid))
+
+
+@validate_argument_types
+def ST_LocalOutlierFactor(
+ geometry: ColumnOrName,
+ k: Union[ColumnOrName, int],
+ use_spheroid: Optional[Union[ColumnOrName, bool]] = False,
+) -> Column:
+ """Calculate the local outlier factor on the given geometry column.
+
+ @param geometry: Geometry column or name
+ :type geometry: ColumnOrName
+ @param k: the number of neighbors to use for LOF calculation
+    :type k: Union[ColumnOrName, int]
+    @param use_spheroid: whether to use spheroid for distance calculation
+    :type use_spheroid: Union[ColumnOrName, bool]
+ @return: A Double indicating the local outlier factor of the point
+ """
+
+ if isinstance(k, int):
+ k = lit(k)
+
+ if isinstance(use_spheroid, bool):
+ use_spheroid = lit(use_spheroid)
+
+ return _call_st_function("ST_LocalOutlierFactor", (geometry, k,
use_spheroid))
+
+
+@validate_argument_types
+def ST_GLocal(
+ x: ColumnOrName,
+ weights: ColumnOrName,
+ star: Optional[Union[ColumnOrName, bool]] = False,
+) -> Column:
+ """Calculate Getis Ord Gi(*) statistics on the given column.
+
+ @param x: The variable we want to compute Gi statistics for
+ :type x: ColumnOrName
+ @param weights: the weights array containing the neighbors, their weights,
and their values of x
+ :type weights: ColumnOrName
+ @param star: whether to use the focal observation in the calculations
+    :type star: Union[ColumnOrName, bool]
+ @return: A struct containing the Gi statistics including a p value
+ """
+
+ if isinstance(star, bool):
+ star = lit(star)
+
+ return _call_st_function("ST_GLocal", (x, weights, star))
+
+
+@validate_argument_types
+def ST_BinaryDistanceBandColumn(
+ geometry: ColumnOrName,
+ threshold: ColumnOrName,
+ include_zero_distance_neighbors: Union[ColumnOrName, bool] = True,
+ include_self: Union[ColumnOrName, bool] = False,
+ use_spheroid: Union[ColumnOrName, bool] = False,
+    attributes: Optional[ColumnOrName] = None,
+) -> Column:
+ """Creates a weights column containing the other records within the
threshold and their weight.
+
+ Weights will always be 1.0.
+
+
+ @param geometry: name of the geometry column
+ @param threshold: Distance threshold for considering neighbors
+ @param include_zero_distance_neighbors: whether to include neighbors that
are 0 distance.
+ @param include_self: whether to include self in the list of neighbors
+ @param use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ @param attributes: the attributes to save in the neighbor column.
+
+ """
+ if isinstance(include_zero_distance_neighbors, bool):
+ include_zero_distance_neighbors = lit(include_zero_distance_neighbors)
+
+ if isinstance(include_self, bool):
+ include_self = lit(include_self)
+
+ if isinstance(use_spheroid, bool):
+ use_spheroid = lit(use_spheroid)
+
+ return _call_st_function(
+ "ST_BinaryDistanceBandColumn",
+ (
+ geometry,
+ threshold,
+ include_zero_distance_neighbors,
+ include_self,
+ use_spheroid,
+ attributes,
+ ),
+ )
+
+
+@validate_argument_types
+def ST_WeightedDistanceBandColumn(
+ geometry: ColumnOrName,
+ threshold: ColumnOrName,
+ alpha: Union[ColumnOrName, float],
+ include_zero_distance_neighbors: Union[ColumnOrName, bool] = True,
+ include_self: Union[ColumnOrName, bool] = False,
+ self_weight: Union[ColumnOrName, float] = 1.0,
+ use_spheroid: Union[ColumnOrName, bool] = False,
+    attributes: Optional[ColumnOrName] = None,
+) -> Column:
+ """Creates a weights column containing the other records within the
threshold and their weight.
+
+ Weights will be distance^alpha.
+
+
+ @param geometry: name of the geometry column
+ @param threshold: Distance threshold for considering neighbors
+ @param alpha: alpha to use for inverse distance weights. Computation is
dist^alpha. Default is -1.0
+ @param include_zero_distance_neighbors: whether to include neighbors that
are 0 distance. If 0 distance neighbors are
+ included, values are infinity as per the floating point spec (divide
by 0)
+ @param include_self: whether to include self in the list of neighbors
+ @param self_weight: the value to use for the self weight. Default is 1.0
+ @param use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ @param attributes: the attributes to save in the neighbor column.
+
+ """
+ if isinstance(alpha, float):
+ alpha = lit(alpha)
+
+ if isinstance(include_zero_distance_neighbors, bool):
+ include_zero_distance_neighbors = lit(include_zero_distance_neighbors)
+
+ if isinstance(include_self, bool):
+ include_self = lit(include_self)
+
+ if isinstance(self_weight, float):
+ self_weight = lit(self_weight)
+
+ if isinstance(use_spheroid, bool):
+ use_spheroid = lit(use_spheroid)
+
+ return _call_st_function(
+ "ST_WeightedDistanceBandColumn",
+ (
+ geometry,
+ threshold,
+ alpha,
+ include_zero_distance_neighbors,
+ include_self,
+ self_weight,
+ use_spheroid,
+ attributes,
+ ),
+ )
+
+
# Automatically populate __all__
__all__ = [
name
diff --git a/python/sedona/stats/clustering/dbscan.py
b/python/sedona/stats/clustering/dbscan.py
index f1501963db..28b37d8bdc 100644
--- a/python/sedona/stats/clustering/dbscan.py
+++ b/python/sedona/stats/clustering/dbscan.py
@@ -34,6 +34,8 @@ def dbscan(
geometry: Optional[str] = None,
include_outliers: bool = True,
use_spheroid=False,
+ is_core_column_name="isCore",
+ cluster_column_name="cluster",
):
"""Annotates a dataframe with a cluster label for each data record using
the DBSCAN algorithm.
@@ -49,6 +51,8 @@ def dbscan(
include_outliers: whether to return outlier points. If True, outliers
are returned with a cluster value of -1.
Default is False
use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ is_core_column_name: what the name of the column indicating if this is
a core point should be. Default is "isCore"
+ cluster_column_name: what the name of the column indicating the
cluster id should be. Default is "cluster"
Returns:
A PySpark DataFrame containing the cluster label for each row
@@ -62,6 +66,8 @@ def dbscan(
geometry,
include_outliers,
use_spheroid,
+ is_core_column_name,
+ cluster_column_name,
)
return DataFrame(result_df, sedona)
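A short usage sketch of the new column-name parameters, assuming a Sedona-enabled session and a dataframe `df` with a `geometry` column (the names "core" and "cid" are illustrative):

```python
from sedona.stats.clustering.dbscan import dbscan

# core-point flags land in "core", cluster ids land in "cid"
result = dbscan(df, 0.5, 4, is_core_column_name="core", cluster_column_name="cid")
```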
diff --git a/python/sedona/stats/outlier_detection/local_outlier_factor.py
b/python/sedona/stats/outlier_detection/local_outlier_factor.py
index 3050d216b7..7a29f5c508 100644
--- a/python/sedona/stats/outlier_detection/local_outlier_factor.py
+++ b/python/sedona/stats/outlier_detection/local_outlier_factor.py
@@ -30,6 +30,7 @@ def local_outlier_factor(
geometry: Optional[str] = None,
handle_ties: bool = False,
use_spheroid=False,
+ result_column_name: str = "lof",
):
"""Annotates a dataframe with a column containing the local outlier factor
for each data record.
@@ -43,6 +44,7 @@ def local_outlier_factor(
geometry: name of the geometry column
handle_ties: whether to handle ties in the k-distance calculation.
Default is false
use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ result_column_name: the name of the column containing the lof for each
row. Default is "lof"
Returns:
A PySpark DataFrame containing the lof for each row
@@ -55,6 +57,7 @@ def local_outlier_factor(
geometry,
handle_ties,
use_spheroid,
+ result_column_name,
)
return DataFrame(result_df, sedona)
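A short usage sketch of the new result_column_name parameter, assuming a Sedona-enabled session and a dataframe `df` with a `geometry` column (the name "outlier_factor" is illustrative):

```python
from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor

# store the factor under a custom column name instead of the default "lof"
result = local_outlier_factor(df, 20, result_column_name="outlier_factor")
```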
diff --git a/python/sedona/stats/weighting.py b/python/sedona/stats/weighting.py
index 8a5fc7e07a..7b5eb9be9d 100644
--- a/python/sedona/stats/weighting.py
+++ b/python/sedona/stats/weighting.py
@@ -17,7 +17,7 @@
"""Weighting functions for spatial data."""
-from typing import Optional
+from typing import Optional, List
from pyspark.sql import DataFrame, SparkSession
@@ -32,6 +32,8 @@ def add_distance_band_column(
self_weight: float = 1.0,
geometry: Optional[str] = None,
use_spheroid: bool = False,
+    saved_attributes: Optional[List[str]] = None,
+ result_name: str = "weights",
) -> DataFrame:
"""Annotates a dataframe with a weights column containing the other
records within the threshold and their weight.
@@ -51,7 +53,8 @@ def add_distance_band_column(
self_weight: the value to use for the self weight
geometry: name of the geometry column
use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
-
+ saved_attributes: the attributes to save in the neighbor column.
Default is all columns.
+ result_name: the name of the resulting column. Default is 'weights'.
Returns:
The input DataFrame with a weight column added containing neighbors
and their weights added to each row.
@@ -67,6 +70,8 @@ def add_distance_band_column(
float(self_weight),
geometry,
use_spheroid,
+ saved_attributes,
+ result_name,
)
@@ -77,6 +82,8 @@ def add_binary_distance_band_column(
include_self: bool = False,
geometry: Optional[str] = None,
use_spheroid: bool = False,
+    saved_attributes: Optional[List[str]] = None,
+ result_name: str = "weights",
) -> DataFrame:
"""Annotates a dataframe with a weights column containing the other
records within the threshold and their weight.
@@ -93,6 +100,59 @@ def add_binary_distance_band_column(
include_self: whether to include self in the list of neighbors
geometry: name of the geometry column
use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ saved_attributes: the attributes to save in the neighbor column.
Default is all columns.
+ result_name: the name of the resulting column. Default is 'weights'.
+
+ Returns:
+ The input DataFrame with a weight column added containing neighbors
and their weights (always 1) added to each
+ row.
+
+ """
+ sedona = SparkSession.getActiveSession()
+
+ return
sedona._jvm.org.apache.sedona.stats.Weighting.addBinaryDistanceBandColumn(
+ dataframe._jdf,
+ float(threshold),
+ include_zero_distance_neighbors,
+ include_self,
+ geometry,
+ use_spheroid,
+ saved_attributes,
+ result_name,
+ )
+
+
+def add_weighted_distance_band_column(
+ dataframe: DataFrame,
+ threshold: float,
+ alpha: float,
+ include_zero_distance_neighbors: bool = True,
+ include_self: bool = False,
+ self_weight: float = 1.0,
+ geometry: Optional[str] = None,
+ use_spheroid: bool = False,
+    saved_attributes: Optional[List[str]] = None,
+ result_name: str = "weights",
+) -> DataFrame:
+ """Annotates a dataframe with a weights column containing the other
records within the threshold and their weight.
+
+ Weights will be distance^alpha. The dataframe should contain at least one
GeometryType column. Rows must be unique. If
+ one geometry column is present it will be used automatically. If two are
present, the one named 'geometry' will be
+ used. If more than one are present and neither is named 'geometry', the
column name must be provided. The new column
+    will be named according to result_name (default 'weights').
+
+ Args:
+ dataframe: DataFrame with geometry column
+ threshold: Distance threshold for considering neighbors
+ alpha: alpha to use for inverse distance weights. Computation is
dist^alpha. Default is -1.0
+ include_zero_distance_neighbors: whether to include neighbors that are
0 distance. If 0 distance neighbors are
+ included and binary is false, values are infinity as per the
floating point spec (divide by 0)
+ include_self: whether to include self in the list of neighbors
+ self_weight: the value to use for the self weight. Default is 1.0
+ geometry: name of the geometry column
+ use_spheroid: whether to use a cartesian or spheroidal distance
calculation. Default is false
+ saved_attributes: the attributes to save in the neighbor column.
Default is all columns.
+ result_name: the name of the resulting column. Default is 'weights'.
Returns:
The input DataFrame with a weight column added containing neighbors and their weights (dist^alpha) added to each
@@ -100,11 +160,16 @@ def add_binary_distance_band_column(
"""
sedona = SparkSession.getActiveSession()
+
    return sedona._jvm.org.apache.sedona.stats.Weighting.addWeightedDistanceBandColumn(
dataframe._jdf,
float(threshold),
+ float(alpha),
include_zero_distance_neighbors,
include_self,
+ float(self_weight),
geometry,
use_spheroid,
+ saved_attributes,
+ result_name,
)
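A short usage sketch of the new add_weighted_distance_band_column function, assuming a Sedona-enabled session and a dataframe `df` with `id` and `geometry` columns (illustrative values):

```python
from sedona.stats.weighting import add_weighted_distance_band_column

# inverse-distance weights within 1.0 units, keeping only id and geometry per neighbor
result = add_weighted_distance_band_column(
    df, 1.0, -1.0, saved_attributes=["id", "geometry"], result_name="weights"
)
```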
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index f839b762bb..8299f5b45f 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -1714,3 +1714,19 @@ class TestDataFrameAPI(TestBase):
match=f"Incorrect argument type: [A-Za-z_0-9]+ for {func.__name__}
should be [A-Za-z0-9\\[\\]_, ]+ but received [A-Za-z0-9_]+.",
):
func(*args)
+
+ def test_dbscan(self):
+ df = self.spark.createDataFrame([{"id": 1, "x": 2, "y":
3}]).withColumn(
+ "geometry", f.expr("ST_Point(x, y)")
+ )
+
+ df.withColumn("dbscan", ST_DBSCAN("geometry", 1.0, 2, False)).collect()
+
+ def test_lof(self):
+ df = self.spark.createDataFrame([{"id": 1, "x": 2, "y":
3}]).withColumn(
+ "geometry", f.expr("ST_Point(x, y)")
+ )
+
+ df.withColumn(
+ "localOutlierFactor", ST_LocalOutlierFactor("geometry", 2, False)
+ ).collect()
diff --git a/spark/common/pom.xml b/spark/common/pom.xml
index 7803a93275..9014c9d7cc 100644
--- a/spark/common/pom.xml
+++ b/spark/common/pom.xml
@@ -220,6 +220,12 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency> <!-- Generally this will be provided by the runtime's
spark install -->
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-graphx_${scala.compat.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<sourceDirectory>src/main/java</sourceDirectory>
diff --git
a/spark/common/src/main/java/org/apache/sedona/core/utils/SedonaConf.java
b/spark/common/src/main/java/org/apache/sedona/core/utils/SedonaConf.java
index 28685c6a03..d02e96df93 100644
--- a/spark/common/src/main/java/org/apache/sedona/core/utils/SedonaConf.java
+++ b/spark/common/src/main/java/org/apache/sedona/core/utils/SedonaConf.java
@@ -59,6 +59,9 @@ public class SedonaConf implements Serializable {
// Parameters for knn joins
private boolean includeTieBreakersInKNNJoins = false;
+ // Parameters for geostats
+ private Boolean DBSCANIncludeOutliers = true;
+
public static SedonaConf fromActiveSession() {
return new SedonaConf(SparkSession.active().conf());
}
@@ -98,6 +101,10 @@ public class SedonaConf implements Serializable {
// Parameters for knn joins
this.includeTieBreakersInKNNJoins =
Boolean.parseBoolean(getConfigValue(runtimeConfig,
"join.knn.includeTieBreakers", "false"));
+
+ // Parameters for geostats
+ this.DBSCANIncludeOutliers =
+
Boolean.parseBoolean(runtimeConfig.get("spark.sedona.dbscan.includeOutliers",
"true"));
}
// Helper method to prioritize `sedona.*` over `spark.sedona.*`
@@ -182,4 +189,8 @@ public class SedonaConf implements Serializable {
public SpatialJoinOptimizationMode getSpatialJoinOptimizationMode() {
return spatialJoinOptimizationMode;
}
+
+ public Boolean getDBSCANIncludeOutliers() {
+ return DBSCANIncludeOutliers;
+ }
}
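A quick sketch of toggling the new setting from a Sedona-enabled PySpark session (the config key comes from the SedonaConf change above):

```python
# drop outliers from ST_DBSCAN results instead of labeling them with cluster -1
spark.conf.set("spark.sedona.dbscan.includeOutliers", "false")
```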
diff --git
a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index db266c38fb..000f1beaa4 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -25,6 +25,10 @@ import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
import
org.apache.spark.sql.sedona_sql.optimization.SpatialFilterPushDownForGeoParquet
+
+import org.apache.spark.sql.sedona_sql.optimization.ExtractGeoStatsFunctions
+import
org.apache.spark.sql.sedona_sql.strategy.geostats.EvalGeoStatsFunctionStrategy
+
import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
import org.apache.spark.sql.{SQLContext, SparkSession}
@@ -61,6 +65,17 @@ object SedonaContext {
sparkSession.experimental.extraOptimizations ++= Seq(
new SpatialFilterPushDownForGeoParquet(sparkSession))
}
+
+ // Support geostats functions
+ if
(!sparkSession.experimental.extraOptimizations.contains(ExtractGeoStatsFunctions))
{
+ sparkSession.experimental.extraOptimizations ++=
Seq(ExtractGeoStatsFunctions)
+ }
+ if (!sparkSession.experimental.extraStrategies.exists(
+ _.isInstanceOf[EvalGeoStatsFunctionStrategy])) {
+ sparkSession.experimental.extraStrategies ++= Seq(
+ new EvalGeoStatsFunctionStrategy(sparkSession))
+ }
+
addGeoParquetToSupportNestedFilterSources(sparkSession)
RasterRegistrator.registerAll(sparkSession)
UdtRegistrator.registerAll()
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index b616ecd790..371fe41913 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -339,7 +339,13 @@ object Catalog {
function[RS_Resample](),
function[RS_ReprojectMatch]("nearestneighbor"),
function[RS_FromNetCDF](),
- function[RS_NetCDFInfo]())
+ function[RS_NetCDFInfo](),
+ // geostats functions
+ function[ST_DBSCAN](),
+ function[ST_LocalOutlierFactor](),
+ function[ST_GLocal](),
+ function[ST_BinaryDistanceBandColumn](),
+ function[ST_WeightedDistanceBandColumn]())
// Aggregate functions with Geometry as buffer
val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
diff --git
a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
index 6d5a273854..1b252cc432 100644
--- a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
@@ -54,6 +54,10 @@ object Weighting {
* name of the geometry column
* @param useSpheroid
* whether to use a cartesian or spheroidal distance calculation. Default
is false
+ * @param savedAttributes
+ * the attributes to save in the neighbor column. Default is all columns.
+ * @param resultName
+ * the name of the resulting column. Default is 'weights'.
* @return
* The input DataFrame with a weight column added containing neighbors and
their weights added
* to each row.
@@ -67,7 +71,9 @@ object Weighting {
includeSelf: Boolean = false,
selfWeight: Double = 1.0,
geometry: String = null,
- useSpheroid: Boolean = false): DataFrame = {
+ useSpheroid: Boolean = false,
+ savedAttributes: Seq[String] = null,
+ resultName: String = "weights"): DataFrame = {
require(threshold >= 0, "Threshold must be greater than or equal to 0")
require(alpha < 0, "Alpha must be less than 0")
@@ -81,6 +87,12 @@ object Weighting {
geometry
}
+ // Always include the geometry column in the saved attributes
+ val savedAttributesWithGeom =
+ if (savedAttributes == null) null
+ else if (!savedAttributes.contains(geometryColumn)) savedAttributes :+
geometryColumn
+ else savedAttributes
+
val distanceFunction: (Column, Column) => Column =
if (useSpheroid) ST_DistanceSpheroid else ST_Distance
@@ -96,14 +108,6 @@ object Weighting {
val formattedDataFrame = dataframe.withColumn(ID_COLUMN,
sha2(to_json(struct("*")), 256))
- // Since spark 3.0 doesn't support dropFields, we need a work around
- val withoutId = (prefix: String, colFunc: String => Column) => {
- formattedDataFrame.schema.fields
- .map(_.name)
- .filter(name => name != ID_COLUMN)
- .map(x => colFunc(prefix + "." + x).alias(x))
- }
-
formattedDataFrame
.alias("l")
.join(
@@ -116,7 +120,13 @@ object Weighting {
col(s"l.$ID_COLUMN"),
struct("l.*").alias("left_contents"),
struct(
- struct(withoutId("r", col): _*).alias("neighbor"),
+ (
+ savedAttributesWithGeom match {
+ case null => struct(col("r.*")).dropFields(ID_COLUMN)
+ case _ =>
+ struct(savedAttributesWithGeom.map(c => col(s"r.$c")): _*)
+ }
+ ).alias("neighbor"),
if (!binary)
pow(distanceFunction(col(s"l.$geometryColumn"),
col(s"r.$geometryColumn")), alpha)
.alias("value")
@@ -127,14 +137,18 @@ object Weighting {
concat(
collect_list(col("weight")),
if (includeSelf)
- array(
- struct(
- struct(withoutId("left_contents", first):
_*).alias("neighbor"),
- lit(selfWeight).alias("value")))
- else array()).alias("weights"))
- .select("left_contents.*", "weights")
+ array(struct(
+ (savedAttributesWithGeom match {
+ case null => first("left_contents").dropFields(ID_COLUMN)
+ case _ =>
+ struct(
+ savedAttributesWithGeom.map(c =>
first(s"left_contents.$c").alias(c)): _*)
+ }).alias("neighbor"),
+ lit(selfWeight).alias("value")))
+ else array()).alias(resultName))
+ .select("left_contents.*", resultName)
.drop(ID_COLUMN)
- .withColumn("weights", filter(col("weights"),
_(f"neighbor")(geometryColumn).isNotNull))
+ .withColumn(resultName, filter(col(resultName),
_(f"neighbor")(geometryColumn).isNotNull))
}
/**
@@ -158,6 +172,10 @@ object Weighting {
* name of the geometry column
* @param useSpheroid
* whether to use a cartesian or spheroidal distance calculation. Default
is false
+ * @param savedAttributes
+ * the attributes to save in the neighbor column. Default is all columns.
+ * @param resultName
+ * the name of the resulting column. Default is 'weights'.
* @return
* The input DataFrame with a weight column added containing neighbors and
their weights
* (always 1) added to each row.
@@ -168,13 +186,73 @@ object Weighting {
includeZeroDistanceNeighbors: Boolean = true,
includeSelf: Boolean = false,
geometry: String = null,
- useSpheroid: Boolean = false): DataFrame = addDistanceBandColumn(
+ useSpheroid: Boolean = false,
+ savedAttributes: Seq[String] = null,
+ resultName: String = "weights"): DataFrame = addDistanceBandColumn(
dataframe,
threshold,
binary = true,
includeZeroDistanceNeighbors = includeZeroDistanceNeighbors,
includeSelf = includeSelf,
geometry = geometry,
- useSpheroid = useSpheroid)
+ useSpheroid = useSpheroid,
+ savedAttributes = savedAttributes,
+ resultName = resultName)
+
+ /**
+ * Annotates a dataframe with a weights column for each data record
containing the other members
+ * within the threshold and their weight. Weights will be dist^alpha. The
dataframe should
+ * contain at least one GeometryType column. Rows must be unique. If one
geometry column is
+ * present it will be used automatically. If two are present, the one named
'geometry' will be
+ * used. If more than one are present and neither is named 'geometry', the
column name must be
+ * provided. The new column will be named according to resultName (default 'weights').
+ *
+ * @param dataframe
+ * DataFrame with geometry column
+ * @param threshold
+ * Distance threshold for considering neighbors
+ * @param alpha
+ * alpha to use for inverse distance weights. Computation is dist^alpha.
Default is -1.0
+ * @param includeZeroDistanceNeighbors
+ * whether to include neighbors that are 0 distance. If 0 distance
neighbors are included and
+ * binary is false, values are infinity as per the floating point spec
(divide by 0)
+ * @param includeSelf
+ * whether to include self in the list of neighbors
+ * @param selfWeight
+ * the weight to provide for the self as its own neighbor. Default is 1.0
+ * @param geometry
+ * name of the geometry column
+ * @param useSpheroid
+ * whether to use a cartesian or spheroidal distance calculation. Default
is false
+ * @param savedAttributes
+ * the attributes to save in the neighbor column. Default is all columns.
+ * @param resultName
+ * the name of the resulting column. Default is 'weights'.
+ * @return
+ * The input DataFrame with a weight column added containing neighbors and
their weights
+ * (dist^alpha) added to each row.
+ */
+ def addWeightedDistanceBandColumn(
+ dataframe: DataFrame,
+ threshold: Double,
+ alpha: Double = -1.0,
+ includeZeroDistanceNeighbors: Boolean = false,
+ includeSelf: Boolean = false,
+ selfWeight: Double = 1.0,
+ geometry: String = null,
+ useSpheroid: Boolean = false,
+ savedAttributes: Seq[String] = null,
+ resultName: String = "weights"): DataFrame = addDistanceBandColumn(
+ dataframe,
+ threshold,
+ alpha = alpha,
+ binary = false,
+ includeZeroDistanceNeighbors = includeZeroDistanceNeighbors,
+ includeSelf = includeSelf,
+ selfWeight = selfWeight,
+ geometry = geometry,
+ useSpheroid = useSpheroid,
+ savedAttributes = savedAttributes,
+ resultName = resultName)
}
diff --git
a/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
index e4cd1f90b4..96a4e8c317 100644
---
a/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
+++
b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
@@ -48,6 +48,11 @@ object DBSCAN {
* whether to include outliers in the output. Default is false
* @param useSpheroid
* whether to use a cartesian or spheroidal distance calculation. Default
is false
+ * @param isCoreColumnName
+ * what the name of the column indicating if this is a core point should
be. Default is
+ * "isCore"
+ * @param clusterColumnName
+ * what the name of the column indicating the cluster id should be.
Default is "cluster"
* @return
* The input DataFrame with the cluster label added to each row. Outlier
will have a cluster
* value of -1 if included.
@@ -58,7 +63,9 @@ object DBSCAN {
minPts: Int,
geometry: String = null,
includeOutliers: Boolean = true,
- useSpheroid: Boolean = false): DataFrame = {
+ useSpheroid: Boolean = false,
+ isCoreColumnName: String = "isCore",
+ clusterColumnName: String = "cluster"): DataFrame = {
val geometryCol = geometry match {
case null => getGeometryColumnName(dataframe)
@@ -89,12 +96,12 @@ object DBSCAN {
first(struct("left.*")).alias("leftContents"),
count(col(s"right.id")).alias("neighbors_count"),
collect_list(col(s"right.id")).alias("neighbors"))
- .withColumn("isCore", col("neighbors_count") >= lit(minPts))
- .select("leftContents.*", "neighbors", "isCore")
+ .withColumn(isCoreColumnName, col("neighbors_count") >= lit(minPts))
+ .select("leftContents.*", "neighbors", isCoreColumnName)
.checkpoint()
- val corePointsDF = isCorePointsDF.filter(col("isCore"))
- val borderPointsDF = isCorePointsDF.filter(!col("isCore"))
+ val corePointsDF = isCorePointsDF.filter(col(isCoreColumnName))
+ val borderPointsDF = isCorePointsDF.filter(!col(isCoreColumnName))
val coreEdgesDf = corePointsDF
.select(col("id").alias("src"), explode(col("neighbors")).alias("dst"))
@@ -117,14 +124,14 @@ object DBSCAN {
val outliersDf = idDataframe
.join(clusteredPointsDf, Seq("id"), "left_anti")
- .withColumn("isCore", lit(false))
+ .withColumn(isCoreColumnName, lit(false))
.withColumn("component", lit(-1))
.withColumn("neighbors", array().cast("array<string>"))
val completedDf = (
if (includeOutliers) clusteredPointsDf.unionByName(outliersDf)
else clusteredPointsDf
- ).withColumnRenamed("component", "cluster")
+ ).withColumnRenamed("component", clusterColumnName)
val returnDf = if (hasIdColumn) {
completedDf.drop("neighbors", "id").withColumnRenamed(ID_COLUMN, "id")
diff --git
a/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
index b98919de25..1d2689917d 100644
---
a/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
+++
b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
@@ -19,7 +19,7 @@
package org.apache.sedona.stats.outlierDetection
import org.apache.sedona.stats.Util.getGeometryColumnName
-import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance,
ST_DistanceSpheroid}
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance,
ST_DistanceSphere}
import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
object LocalOutlierFactor {
@@ -42,8 +42,10 @@ object LocalOutlierFactor {
* name of the geometry column
* @param handleTies
* whether to handle ties in the k-distance calculation. Default is false
- * @param useSpheroid
+ * @param useSphere
* whether to use a cartesian or spheroidal distance calculation. Default
is false
+ * @param resultColumnName
+ * the name of the column containing the lof for each row. Default is "lof"
*
* @return
* A DataFrame containing the lof for each row
@@ -53,7 +55,8 @@ object LocalOutlierFactor {
k: Int = 20,
geometry: String = null,
handleTies: Boolean = false,
- useSpheroid: Boolean = false): DataFrame = {
+ useSphere: Boolean = false,
+ resultColumnName: String = "lof"): DataFrame = {
if (k < 1)
throw new IllegalArgumentException("k must be a positive integer")
@@ -67,8 +70,8 @@ object LocalOutlierFactor {
} else "false" // else case to make compiler happy
val distanceFunction: (Column, Column) => Column =
- if (useSpheroid) ST_DistanceSpheroid else ST_Distance
- val useSpheroidString = if (useSpheroid) "True" else "False" // for the
SQL expression
+ if (useSphere) ST_DistanceSphere else ST_Distance
+ val useSpheroidString = if (useSphere) "True" else "False" // for the SQL
expression
val geometryColumn = if (geometry == null)
getGeometryColumnName(dataframe) else geometry
@@ -136,8 +139,8 @@ object LocalOutlierFactor {
.groupBy("a_id")
.agg(
f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME),
- (f.sum("b_lrd") / (f.count("b_lrd") * f.first("a_lrd"))).alias("lof"))
- .select(f.col(f"$CONTENTS_COLUMN_NAME.*"), f.col("lof"))
+ (f.sum("b_lrd") / (f.count("b_lrd") *
f.first("a_lrd"))).alias(resultColumnName))
+ .select(f.col(f"$CONTENTS_COLUMN_NAME.*"), f.col(resultColumnName))
if (handleTies)
SparkSession.getActiveSession.get.conf
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
new file mode 100644
index 0000000000..8c6b645daf
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.sedona.core.utils.SedonaConf
+import org.apache.sedona.stats.Weighting.{addBinaryDistanceBandColumn,
addWeightedDistanceBandColumn}
+import org.apache.sedona.stats.clustering.DBSCAN.dbscan
+import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
+import
org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, ImplicitCastInputTypes, Literal,
ScalarSubquery, Unevaluable}
+import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+
+import scala.reflect.ClassTag
+
+// We mark ST_GeoStatsFunction as non-deterministic to prevent the filter push-down
+// optimization pass from duplicating the ST_GeoStatsFunction when pushing an aliased
+// ST_GeoStatsFunction down through a Project operator, which would cause the
+// ST_GeoStatsFunction to be evaluated twice.
+trait ST_GeoStatsFunction
+ extends Expression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with Serializable {
+
+ final override lazy val deterministic: Boolean = false
+
+ override def nullable: Boolean = true
+
+ private final lazy val sparkSession = SparkSession.getActiveSession.get
+
+ protected final lazy val geometryColumnName = getInputName(0, "geometry")
+
+ protected def getInputName(i: Int, fieldName: String): String = children(i)
match {
+ case ref: AttributeReference => ref.name
+ case _ =>
+ throw new IllegalArgumentException(
+ f"$fieldName argument must be a named reference to an existing column")
+ }
+
+ protected def getInputNames(i: Int, fieldName: String): Seq[String] =
children(
+ i).dataType match {
+ case StructType(fields) => fields.map(_.name)
+ case _ => throw new IllegalArgumentException(f"$fieldName argument must be
a struct")
+ }
+
+ protected def getResultName(resultAttrs: Seq[Attribute]): String =
resultAttrs match {
+ case Seq(attr) => attr.name
+ case _ => throw new IllegalArgumentException("resultAttrs must have
exactly one attribute")
+ }
+
+ protected def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame
+
+ protected def getScalarValue[T](i: Int, name: String)(implicit ct:
ClassTag[T]): T = {
+ children(i) match {
+ case Literal(l: T, _) => l
+ case _: Literal =>
+ throw new IllegalArgumentException(f"$name must be an instance of
${ct.runtimeClass}")
+ case s: ScalarSubquery =>
+ s.eval() match {
+ case t: T => t
+ case _ =>
+ throw new IllegalArgumentException(
+ f"$name must be an instance of ${ct.runtimeClass}")
+ }
+ case _ => throw new IllegalArgumentException(f"$name must be a scalar
value")
+ }
+ }
+
+ def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow]
= {
+ val df = doExecute(
+ Dataset.ofRows(sparkSession, LogicalRDD(plan.output,
plan.execute())(sparkSession)),
+ resultAttrs)
+ df.queryExecution.toRdd
+ }
+
+}
+
+case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
+
+ override def dataType: DataType = StructType(
+ Seq(StructField("isCore", BooleanType), StructField("cluster", LongType)))
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(GeometryUDT, DoubleType, IntegerType, BooleanType)
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression =
+ copy(children = newChildren)
+
+ override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame = {
+ require(
+ !dataframe.columns.contains("__isCore"),
+ "__isCore is a reserved name by the dbscan algorithm. Please rename the
columns before calling the ST_DBSCAN function.")
+ require(
+ !dataframe.columns.contains("__cluster"),
+ "__cluster is a reserved name by the dbscan algorithm. Please rename
the columns before calling the ST_DBSCAN function.")
+
+ dbscan(
+ dataframe,
+ getScalarValue[Double](1, "epsilon"),
+ getScalarValue[Int](2, "minPts"),
+ geometryColumnName,
+ SedonaConf.fromActiveSession().getDBSCANIncludeOutliers,
+ getScalarValue[Boolean](3, "useSpheroid"),
+ "__isCore",
+ "__cluster")
+ .withColumn(getResultName(resultAttrs), struct(col("__isCore"),
col("__cluster")))
+ .drop("__isCore", "__cluster")
+ }
+}
+
+case class ST_LocalOutlierFactor(children: Seq[Expression]) extends
ST_GeoStatsFunction {
+
+ override def dataType: DataType = DoubleType
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(GeometryUDT, IntegerType, BooleanType)
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression =
+ copy(children = newChildren)
+
+ override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame = {
+ localOutlierFactor(
+ dataframe,
+ getScalarValue[Int](1, "k"),
+ geometryColumnName,
+ SedonaConf.fromActiveSession().isIncludeTieBreakersInKNNJoins,
+ getScalarValue[Boolean](2, "useSphere"),
+ getResultName(resultAttrs))
+ }
+}
+
+case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
+
+ override def dataType: DataType = StructType(
+ Seq(
+ StructField("G", DoubleType),
+ StructField("EG", DoubleType),
+ StructField("VG", DoubleType),
+ StructField("Z", DoubleType),
+ StructField("P", DoubleType)))
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ val xDataType = children(0).dataType
+ require(xDataType == DoubleType || xDataType == IntegerType, "x must be a
numeric value")
+ Seq(
+ xDataType,
+ children(1).dataType, // Array of the weights
+ BooleanType)
+ }
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression =
+ copy(children = newChildren)
+
+ override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame = {
+ gLocal(
+ dataframe,
+ getInputName(0, "x"),
+ getInputName(1, "weights"),
+ 0,
+ getScalarValue[Boolean](2, "star"),
+ 0.0)
+ .withColumn(
+ getResultName(resultAttrs),
+ struct(col("G"), col("EG"), col("VG"), col("Z"), col("P")))
+ .drop("G", "EG", "VG", "Z", "P")
+ }
+}
+
+case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends
ST_GeoStatsFunction {
+ override def dataType: DataType = ArrayType(
+ StructType(
+ Seq(StructField("neighbor", children(5).dataType), StructField("value",
DoubleType))))
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(GeometryUDT, DoubleType, BooleanType, BooleanType, BooleanType,
children(5).dataType)
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression =
+ copy(children = newChildren)
+
+ override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame = {
+ val attributeNames = getInputNames(5, "attributes")
+ require(attributeNames.nonEmpty, "attributes must have at least one
column")
+ require(
+ attributeNames.contains(geometryColumnName),
+ "attributes must contain the geometry column")
+
+ addBinaryDistanceBandColumn(
+ dataframe,
+ getScalarValue[Double](1, "threshold"),
+ getScalarValue[Boolean](2, "includeZeroDistanceNeighbors"),
+ getScalarValue[Boolean](3, "includeSelf"),
+ geometryColumnName,
+ getScalarValue[Boolean](4, "useSpheroid"),
+ attributeNames,
+ getResultName(resultAttrs))
+ }
+}
+
+case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends
ST_GeoStatsFunction {
+
+ override def dataType: DataType = ArrayType(
+ StructType(
+ Seq(StructField("neighbor", children(7).dataType), StructField("value",
DoubleType))))
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(
+ GeometryUDT,
+ DoubleType,
+ DoubleType,
+ BooleanType,
+ BooleanType,
+ DoubleType,
+ BooleanType,
+ children(7).dataType)
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression =
+ copy(children = newChildren)
+
+ override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]):
DataFrame = {
+ val attributeNames = getInputNames(7, "attributes")
+ require(attributeNames.nonEmpty, "attributes must have at least one
column")
+ require(
+ attributeNames.contains(geometryColumnName),
+ "attributes must contain the geometry column")
+
+ addWeightedDistanceBandColumn(
+ dataframe,
+ getScalarValue[Double](1, "threshold"),
+ getScalarValue[Double](2, "alpha"),
+ getScalarValue[Boolean](3, "includeZeroDistanceNeighbors"),
+ getScalarValue[Boolean](4, "includeSelf"),
+ getScalarValue[Double](5, "selfWeight"),
+ geometryColumnName,
+ getScalarValue[Boolean](6, "useSpheroid"),
+ attributeNames,
+ getResultName(resultAttrs))
+ }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 0defb89a2f..0c332b10b9 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -959,4 +959,46 @@ object st_functions extends DataFrameAPI {
def ST_InterpolatePoint(geom1: String, geom2: String): Column =
wrapExpression[ST_InterpolatePoint](geom1, geom2)
+ def ST_DBSCAN(geom: Column, epsilon: Column, minPoints: Column, useSpheroid:
Column): Column =
+ wrapExpression[ST_DBSCAN](geom, epsilon, minPoints, useSpheroid)
+
+ def ST_LocalOutlierFactor(geom: Column, k: Column, useSpheroid: Column):
Column =
+ wrapExpression[ST_LocalOutlierFactor](geom, k, useSpheroid)
+
+ def ST_GLocal(x: Column, weights: Column, star: Column): Column =
+ wrapExpression[ST_GLocal](x, weights, star)
+
+ def ST_BinaryDistanceBandColumn(
+ geometry: Column,
+ threshold: Column,
+ includeZeroDistanceNeighbors: Column,
+ includeSelf: Column,
+ useSpheroid: Column,
+ attributes: Column): Column =
+ wrapExpression[ST_BinaryDistanceBandColumn](
+ geometry,
+ threshold,
+ includeZeroDistanceNeighbors,
+ includeSelf,
+ useSpheroid,
+ attributes)
+
+ def ST_WeightedDistanceBandColumn(
+ geometry: Column,
+ threshold: Column,
+ alpha: Column,
+ includeZeroDistanceNeighbors: Column,
+ includeSelf: Column,
+ selfWeight: Column,
+ useSpheroid: Column,
+ attributes: Column): Column =
+    wrapExpression[ST_WeightedDistanceBandColumn](
+ geometry,
+ threshold,
+ alpha,
+ includeZeroDistanceNeighbors,
+ includeSelf,
+ selfWeight,
+ useSpheroid,
+ attributes)
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
new file mode 100644
index 0000000000..6b4cf9ccea
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.optimization
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
+
+import scala.collection.mutable
+
+/**
+ * Extracts GeoStats functions from operators, rewriting the query plan so that the geo-stats
+ * functions can each be evaluated alone in their own physical operator.
+ */
+object ExtractGeoStatsFunctions extends Rule[LogicalPlan] {
+ var geoStatsResultCount = 0
+
+ private def collectGeoStatsFunctionsFromExpressions(
+ expressions: Seq[Expression]): Seq[ST_GeoStatsFunction] = {
+ def collectGeoStatsFunctions(expr: Expression): Seq[ST_GeoStatsFunction] =
expr match {
+ case expr: ST_GeoStatsFunction => Seq(expr)
+ case e => e.children.flatMap(collectGeoStatsFunctions)
+ }
+ expressions.flatMap(collectGeoStatsFunctions)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ // SPARK-26293: A subquery will be rewritten into join later, and will go
through this rule
+  // eventually. Here we skip subqueries, as geo-stats functions only need to be extracted once.
+ case s: Subquery if s.correlated => plan
+ case _ =>
+ plan.transformUp {
+ case p: EvalGeoStatsFunction => p
+ case plan: LogicalPlan => extract(plan)
+ }
+ }
+
+ private def canonicalizeDeterministic(u: ST_GeoStatsFunction) = {
+ if (u.deterministic) {
+ u.canonicalized.asInstanceOf[ST_GeoStatsFunction]
+ } else {
+ u
+ }
+ }
+
+ /**
+ * Extract all the geo-stats functions from the current operator and
evaluate them before the
+ * operator.
+ */
+ private def extract(plan: LogicalPlan): LogicalPlan = {
+ val geoStatsFuncs = plan match {
+ case e: EvalGeoStatsFunction =>
+ collectGeoStatsFunctionsFromExpressions(e.function.children)
+ case _ =>
+
ExpressionSet(collectGeoStatsFunctionsFromExpressions(plan.expressions))
+          // ignore ST_GeoStatsFunctions that come from the second/third aggregate, which are not used
+ .filter(func => func.references.subsetOf(plan.inputSet))
+ .filter(func =>
+ plan.children.exists(child =>
func.references.subsetOf(child.outputSet)))
+ .toSeq
+ .asInstanceOf[Seq[ST_GeoStatsFunction]]
+ }
+
+ if (geoStatsFuncs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ // Transform the first geo-stats function we have found. We'll call
extract recursively later
+ // to transform the rest.
+ val geoStatsFunc = geoStatsFuncs.head
+
+ val attributeMap = mutable.HashMap[ST_GeoStatsFunction, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ if (geoStatsFunc.references.subsetOf(child.outputSet)) {
+ geoStatsResultCount += 1
+ val resultAttr =
+ AttributeReference(f"geoStatsResult$geoStatsResultCount",
geoStatsFunc.dataType)()
+ val evaluation = EvalGeoStatsFunction(geoStatsFunc, Seq(resultAttr),
child)
+ attributeMap += (canonicalizeDeterministic(geoStatsFunc) ->
resultAttr)
+ extract(evaluation) // handle nested geo-stats functions
+ } else {
+ child
+ }
+ }
+
+ // Replace the geo stats function call with the newly created
geoStatsResult attribute
+ val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+ case p: ST_GeoStatsFunction =>
attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+ }
+
+ // extract remaining geo-stats functions recursively
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
+ // Trim away the new UDF value if it was only used for filtering or
something.
+ Project(plan.output, newPlan)
+ } else {
+ newPlan
+ }
+ }
+ }
+}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
new file mode 100644
index 0000000000..8daeb0c304
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.UnaryNode
+
+case class EvalGeoStatsFunction(
+ function: Expression,
+ resultAttrs: Seq[Attribute],
+ child: LogicalPlan)
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+ override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+ override protected def withNewChildInternal(newChild: LogicalPlan):
LogicalPlan =
+ copy(child = newChild)
+}
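The logical node above simply appends its result attributes to the child's output, which is what lets downstream operators reference `geoStatsResultN` like any other column. A toy illustration (hypothetical: a literal stands in for a real geo-stats expression, since the node accepts any `Expression`):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
import org.apache.spark.sql.types.{DoubleType, IntegerType}

// Child relation with a single `id` column.
val child = LocalRelation(AttributeReference("id", IntegerType)())
// The attribute the rewrite rule would mint for the function's result.
val resultAttr = AttributeReference("geoStatsResult1", DoubleType)()
val node = EvalGeoStatsFunction(Literal(1.0), Seq(resultAttr), child)

// Output is child.output ++ resultAttrs; only resultAttr is produced here.
assert(node.output.map(_.name) == Seq("id", "geoStatsResult1"))
```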
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
new file mode 100644
index 0000000000..fbecb69ec4
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.geostats
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+
+case class EvalGeoStatsFunctionExec(
+ function: ST_GeoStatsFunction,
+ child: SparkPlan,
+ resultAttrs: Seq[Attribute])
+ extends UnaryExecNode {
+
+ override protected def doExecute(): RDD[InternalRow] = function.execute(child, resultAttrs)
+
+ override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+ override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
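Execution is fully delegated: `doExecute` hands the child plan and the minted result attributes to the expression itself, so each geo-stats function owns its whole-relation evaluation. Inferred from that call (the real trait lives in the commit's new GeoStatsFunctions.scala, which is not shown in this hunk), the contract looks roughly like:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan

// Sketch only: member name and signature are inferred from the doExecute
// call above, not copied from GeoStatsFunctions.scala.
trait GeoStatsFunctionContract {
  def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow]
}
```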
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
new file mode 100644
index 0000000000..4c10b747a6
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.geostats
+
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+
+class EvalGeoStatsFunctionStrategy(spark: SparkSession) extends Strategy {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+ plan match {
+ case EvalGeoStatsFunction(function: ST_GeoStatsFunction, resultAttrs, child) =>
+ EvalGeoStatsFunctionExec(function, planLater(child), resultAttrs) :: Nil
+ case _ => Nil
+ }
+ }
+}
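The strategy only plans the single logical node above; every other plan falls through with `Nil`. The SedonaContext.scala changes in this commit presumably register it for you; for a bare session, a hedged sketch of the standard Spark extension point would be:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sedona_sql.strategy.geostats.EvalGeoStatsFunctionStrategy

// Hypothetical manual registration via Spark's experimental strategy hook;
// not the commit's wiring code.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.experimental.extraStrategies =
  spark.experimental.extraStrategies :+ new EvalGeoStatsFunctionStrategy(spark)
```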
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/GeoStatsSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/GeoStatsSuite.scala
new file mode 100644
index 0000000000..9567dcc95f
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/GeoStatsSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.sedona.stats.Weighting.{addBinaryDistanceBandColumn, addWeightedDistanceBandColumn}
+import org.apache.sedona.stats.clustering.DBSCAN.dbscan
+import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
+import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, expr, lit}
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_DBSCAN, ST_LocalOutlierFactor}
+
+class GeoStatsSuite extends TestBaseScala {
+ private val spark = sparkSession
+
+ case class Record(id: Int, x: Double, y: Double)
+
+ def getData: DataFrame = {
+ spark
+ .createDataFrame(
+ Seq(
+ Record(10, 1.0, 1.8),
+ Record(11, 1.0, 1.9),
+ Record(12, 1.0, 2.0),
+ Record(13, 1.0, 2.1),
+ Record(14, 2.0, 2.0),
+ Record(15, 3.0, 1.9),
+ Record(16, 3.0, 2.0),
+ Record(17, 3.0, 2.1),
+ Record(18, 3.0, 2.2)))
+ .withColumn("geometry", expr("ST_Point(x, y)"))
+ }
+
+ it("test dbscan function") {
+ dbscan(getData.withColumn("sql_results", expr("ST_DBSCAN(geometry, 1.0, 4,
false)")), 1.0, 4)
+ .where("sql_results.cluster = cluster and sql_results.isCore = isCore")
+ .count() == getData.count()
+ }
+
+ it("test dbscan function df method") {
+ dbscan(
+ getData.withColumn("sql_results", ST_DBSCAN(col("geometry"), lit(1.0),
lit(4), lit(false))),
+ 1.0,
+ 4)
+ .where("sql_results.cluster = cluster and sql_results.isCore = isCore")
+ .count() == getData.count()
+ }
+
+ it("test dbscan function with distance column") {
+ dbscan(
+ getData.withColumn("sql_results", expr("ST_DBSCAN(geometry, 1.0, 4,
true)")),
+ 1.0,
+ 4,
+ useSpheroid = true)
+ .where("sql_results.cluster = cluster and sql_results.isCore = isCore")
+ .count() == getData.count()
+ }
+
+ it("test dbscan function with scalar subquery") {
+ dbscan(
+ getData.withColumn(
+ "sql_results",
+ expr("ST_DBSCAN(geometry, (SELECT ARRAY(1.0, 2.0)[0]), 4, false)")),
+ 1.0,
+ 4)
+ .where("sql_results.cluster = cluster and sql_results.isCore = isCore")
+ .count() == getData.count()
+ }
+
+ it("test dbscan with geom literal") {
+ val error = intercept[IllegalArgumentException] {
+ spark.sql("SELECT ST_DBSCAN(ST_GeomFromWKT('POINT(0.0 1.1)'), 1.0, 4,
false)").collect()
+ }
+ assert(
+ error.getMessage == "geometry argument must be a named reference to an existing column")
+ }
+
+ it("test dbscan with minPts variable") {
+ val error = intercept[IllegalArgumentException] {
+ getData
+ .withColumn("result", ST_DBSCAN(col("geometry"), lit(1.0), col("id"),
lit(false)))
+ .collect()
+ }
+
+ assert(error.getMessage.contains("minPts must be a scalar value"))
+ }
+
+ it("test lof") {
+ localOutlierFactor(
+ getData.withColumn("sql_result", expr("ST_LocalOutlierFactor(geometry,
4, false)")),
+ 4)
+ .where("sql_result = lof")
+ .count() == getData.count()
+ }
+
+ it("test lof with dataframe method") {
+ localOutlierFactor(
+ getData.withColumn(
+ "sql_result",
+ ST_LocalOutlierFactor(col("geometry"), lit(4), lit(false))),
+ 4)
+ .where("sql_result = lof")
+ .count() == getData.count()
+ }
+
+ it("test geostats function in another function") {
+ getData
+ .withColumn("sql_result", expr("SQRT(ST_LocalOutlierFactor(geometry, 4,
false))"))
+ .collect()
+ }
+
+ it("test DBSCAN with a column named __isCore in input df") {
+ val exception = intercept[IllegalArgumentException] {
+ getData
+ .withColumn("__isCore", lit(1))
+ .withColumn("sql_result", expr("ST_DBSCAN(geometry, 0.1, 4, false)"))
+ .collect()
+ }
+ assert(
+ exception.getMessage == "requirement failed: __isCore is a reserved name by the dbscan algorithm. Please rename the columns before calling the ST_DBSCAN function.")
+ }
+
+ it("test ST_BinaryDistanceBandColumn") {
+ val weightedDf = getData
+ .withColumn(
+ "someWeights",
+ expr(
+ "array_sort(ST_BinaryDistanceBandColumn(geometry, 1.0, true, true,
false, struct(id, geometry)))"))
+
+ val resultsDf = addBinaryDistanceBandColumn(
+ weightedDf,
+ 1.0,
+ true,
+ true,
+ savedAttributes = Seq("id", "geometry"))
+ .withColumn("weights", expr("array_sort(weights)"))
+ .where("someWeights = weights")
+
+ assert(resultsDf.count == weightedDf.count())
+ }
+
+ it("test ST_WeightedDistanceBandColumn") {
+ val weightedDf = getData
+ .withColumn(
+ "someWeights",
+ expr(
+ "array_sort(ST_WeightedDistanceBandColumn(geometry, 1.0, -1.0, true,
true, 1.0, false, struct(id, geometry)))"))
+
+ val resultsDf = addWeightedDistanceBandColumn(
+ weightedDf,
+ 1.0,
+ -1.0,
+ true,
+ true,
+ savedAttributes = Seq("id", "geometry"),
+ selfWeight = 1.0)
+ .withColumn("weights", expr("array_sort(weights)"))
+ .where("someWeights = weights")
+
+ assert(resultsDf.count == weightedDf.count())
+ }
+
+ it("test GI with ST_BinaryDistanceBandColumn") {
+ val weightedDf = getData
+ .withColumn(
+ "someWeights",
+ expr(
+ "ST_BinaryDistanceBandColumn(geometry, 1.0, true, true, false,
struct(id, geometry))"))
+
+ val giDf = weightedDf
+ .withColumn("gi", expr("ST_GLocal(id, someWeights, true)"))
+ assert(
+ gLocal(giDf, "id", weights = "someWeights", star = true)
+ .where("G = gi.G")
+ .count() == weightedDf.count())
+ }
+
+ it("test nested ST_Geostats calls with getis ord") {
+ getData
+ .withColumn(
+ "GI",
+ expr(
+ "ST_GLocal(id, ST_BinaryDistanceBandColumn(geometry, 1.0, true,
true, false, struct(id, geometry)), true)"))
+ .collect()
+ }
+
+ it("test ST_Geostats with string column") {
+ getData
+ .withColumn("someString", lit("test"))
+ .withColumn("sql_results", expr("ST_DBSCAN(geometry, 1.0, 4, false)"))
+ .collect()
+ }
+}