This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new abc51e3d [SEDONA-280] Add ST_GeometricMedian (#831)
abc51e3d is described below
commit abc51e3db675db643418c2e03f1be3f63d2a1b37
Author: Artem <[email protected]>
AuthorDate: Wed May 17 21:57:27 2023 -0700
[SEDONA-280] Add ST_GeometricMedian (#831)
---
.../java/org/apache/sedona/common/Functions.java | 160 ++++++++++++++++++++-
.../org/apache/sedona/common/FunctionsTest.java | 58 +++++++-
docs/api/flink/Function.md | 31 +++-
docs/api/sql/Function.md | 31 ++++
.../main/java/org/apache/sedona/flink/Catalog.java | 3 +-
.../apache/sedona/flink/expressions/Functions.java | 33 +++++
.../java/org/apache/sedona/flink/FunctionTest.java | 32 +++++
.../java/org/apache/sedona/flink/TestBase.java | 17 +++
python/sedona/sql/st_functions.py | 26 ++++
python/tests/sql/test_dataframe_api.py | 1 +
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../sql/sedona_sql/expressions/Functions.scala | 15 ++
.../sql/sedona_sql/expressions/st_functions.scala | 10 ++
.../org/apache/sedona/sql/TestBaseScala.scala | 16 +++
.../org/apache/sedona/sql/functionTestScala.scala | 23 ++-
15 files changed, 446 insertions(+), 11 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java
b/common/src/main/java/org/apache/sedona/common/Functions.java
index c86572f8..ad7af119 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -25,6 +25,7 @@ import org.geotools.referencing.CRS;
import org.locationtech.jts.algorithm.MinimumBoundingCircle;
import org.locationtech.jts.algorithm.hull.ConcaveHull;
import org.locationtech.jts.geom.*;
+import org.locationtech.jts.geom.impl.CoordinateArraySequence;
import org.locationtech.jts.geom.util.GeometryFixer;
import org.locationtech.jts.io.gml2.GMLWriter;
import org.locationtech.jts.io.kml.KMLWriter;
@@ -42,18 +43,18 @@ import org.opengis.referencing.operation.MathTransform;
import org.opengis.referencing.operation.TransformException;
import org.wololo.jts2geojson.GeoJSONWriter;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.List;
+import java.util.*;
import java.util.stream.Collectors;
-import java.util.stream.Stream;
+
+import static com.google.common.geometry.S2.DBL_EPSILON;
public class Functions {
private static final GeometryFactory GEOMETRY_FACTORY = new
GeometryFactory();
private static Geometry EMPTY_POLYGON =
GEOMETRY_FACTORY.createPolygon(null, null);
private static GeometryCollection EMPTY_GEOMETRY_COLLECTION =
GEOMETRY_FACTORY.createGeometryCollection(null);
+ private static final double DEFAULT_TOLERANCE = 1e-6;
+ private static final int DEFAULT_MAX_ITER = 1000;
public static double area(Geometry geometry) {
return geometry.getArea();
@@ -730,4 +731,153 @@ public class Functions {
return GEOMETRY_FACTORY.createGeometryCollection();
}
+
+ // ported from
https://github.com/postgis/postgis/blob/f6ed58d1fdc865d55d348212d02c11a10aeb2b30/liblwgeom/lwgeom_median.c
+ // geometry ST_GeometricMedian ( geometry g , float8 tolerance , int
max_iter , boolean fail_if_not_converged );
+
+ private static double distance3d(Coordinate p1, Coordinate p2) {
+ double dx = p2.x - p1.x;
+ double dy = p2.y - p1.y;
+ double dz = p2.z - p1.z;
+ return Math.sqrt(dx * dx + dy * dy + dz * dz);
+ }
+
+ private static void distances(Coordinate curr, Coordinate[] points,
double[] distances) {
+ for(int i = 0; i < points.length; i++) {
+ distances[i] = distance3d(curr, points[i]);
+ }
+ }
+
+ private static double iteratePoints(Coordinate curr, Coordinate[] points,
double[] distances) {
+ Coordinate next = new Coordinate(0, 0, 0);
+ double delta = 0;
+ double denom = 0;
+ boolean hit = false;
+ distances(curr, points, distances);
+
+ for (int i = 0; i < points.length; i++) {
+ /* we need to use lower epsilon than in FP_IS_ZERO in the loop for
calculation to converge */
+ double distance = distances[i];
+ if (distance > DBL_EPSILON) {
+ Coordinate coordinate = points[i];
+ next.x += coordinate.x / distance;
+ next.y += coordinate.y / distance;
+ next.z += coordinate.z / distance;
+ denom += 1.0 / distance;
+ } else {
+ hit = true;
+ }
+ }
+ /* negative weight shouldn't get here */
+ //assert(denom >= 0);
+
+ /* denom is zero in case of multipoint of single point when we've
converged perfectly */
+ if (denom > DBL_EPSILON) {
+ next.x /= denom;
+ next.y /= denom;
+ next.z /= denom;
+
+ /* If any of the intermediate points in the calculation is found
in the
+ * set of input points, the standard Weiszfeld method gets stuck
with a
+ * divide-by-zero.
+ *
+ * To get ourselves out of the hole, we follow an alternate
procedure to
+ * get the next iteration, as described in:
+ *
+ * Vardi, Y. and Zhang, C. (2001) "A modified Weiszfeld algorithm
for the
+ * Fermat-Weber location problem." Math. Program., Ser. A 90:
559-566.
+ * DOI 10.1007/s101070100222
+ *
+ * Available online at the time of this writing at
+ * http://www.stat.rutgers.edu/home/cunhui/papers/43.pdf
+ */
+ if (hit) {
+ double dx = 0;
+ double dy = 0;
+ double dz = 0;
+ for (int i = 0; i < points.length; i++) {
+ double distance = distances[i];
+ if (distance > DBL_EPSILON) {
+ Coordinate coordinate = points[i];
+ dx += (coordinate.x - curr.x) / distance;
+ dy += (coordinate.y - curr.y) / distance;
+ dz += (coordinate.z - curr.z) / distance;
+ }
+ }
+ double dSqr = Math.sqrt(dx*dx + dy*dy + dz*dz);
+ /* Avoid division by zero if the intermediate point is the
median */
+ if (dSqr > DBL_EPSILON) {
+ double rInv = Math.max(0, 1.0 / dSqr);
+ next.x = (1.0 - rInv)*next.x + rInv*curr.x;
+ next.y = (1.0 - rInv)*next.y + rInv*curr.y;
+ next.z = (1.0 - rInv)*next.z + rInv*curr.z;
+ }
+ }
+ delta = distance3d(curr, next);
+ curr.x = next.x;
+ curr.y = next.y;
+ curr.z = next.z;
+ }
+ return delta;
+ }
+
+ private static Coordinate initGuess(Coordinate[] points) {
+ Coordinate guess = new Coordinate(0, 0, 0);
+ for (Coordinate point : points) {
+ guess.x += point.x / points.length;
+ guess.y += point.y / points.length;
+ guess.z += point.z / points.length;
+ }
+ return guess;
+ }
+
+ private static Coordinate[] extractCoordinates(Geometry geometry) {
+ Coordinate[] points = geometry.getCoordinates();
+ if(points.length == 0)
+ return points;
+ boolean is3d = !Double.isNaN(points[0].z);
+ Coordinate[] coordinates = new Coordinate[points.length];
+ for(int i = 0; i < points.length; i++) {
+ coordinates[i] = points[i].copy();
+ if(!is3d)
+ coordinates[i].z = 0.0;
+ }
+ return coordinates;
+ }
+
+ public static Geometry geometricMedian(Geometry geometry, double
tolerance, int maxIter, boolean failIfNotConverged) throws Exception {
+ String geometryType = geometry.getGeometryType();
+ if(!(Geometry.TYPENAME_POINT.equals(geometryType) ||
Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) {
+ throw new Exception("Unsupported geometry type: " + geometryType);
+ }
+ Coordinate[] coordinates = extractCoordinates(geometry);
+ if(coordinates.length == 0)
+ return new Point(null, GEOMETRY_FACTORY);
+ Coordinate median = initGuess(coordinates);
+ double delta = Double.MAX_VALUE;
+ double[] distances = new double[coordinates.length]; // preallocate to
reduce gc pressure for large iterations
+ for(int i = 0; i < maxIter && delta > tolerance; i++)
+ delta = iteratePoints(median, coordinates, distances);
+ if (failIfNotConverged && delta > tolerance)
+ throw new Exception(String.format("Median failed to converge
within %.1E after %d iterations.", tolerance, maxIter));
+ boolean is3d = !Double.isNaN(geometry.getCoordinate().z);
+ if(!is3d)
+ median.z = Double.NaN;
+ Point point = new Point(new CoordinateArraySequence(new
Coordinate[]{median}), GEOMETRY_FACTORY);
+ point.setSRID(geometry.getSRID());
+ return point;
+ }
+
+ public static Geometry geometricMedian(Geometry geometry, double
tolerance, int maxIter) throws Exception {
+ return geometricMedian(geometry, tolerance, maxIter, false);
+ }
+
+ public static Geometry geometricMedian(Geometry geometry, double
tolerance) throws Exception {
+ return geometricMedian(geometry, tolerance, DEFAULT_MAX_ITER, false);
+ }
+
+ public static Geometry geometricMedian(Geometry geometry) throws Exception
{
+ return geometricMedian(geometry, DEFAULT_TOLERANCE, DEFAULT_MAX_ITER,
false);
+ }
+
}
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index eabaaf36..78a26e19 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -14,22 +14,38 @@
package org.apache.sedona.common;
import com.google.common.geometry.S2CellId;
-import org.apache.sedona.common.utils.GeomUtils;
+import com.google.common.math.DoubleMath;
import org.apache.sedona.common.utils.S2Utils;
import org.junit.Test;
import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.WKTReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.*;
public class FunctionsTest {
public static final GeometryFactory GEOMETRY_FACTORY = new
GeometryFactory();
+ protected static final double FP_TOLERANCE = 1e-12;
+ protected static final CoordinateSequenceComparator
COORDINATE_SEQUENCE_COMPARATOR = new CoordinateSequenceComparator(2){
+ @Override
+ protected int compareCoordinate(CoordinateSequence s1,
CoordinateSequence s2, int i, int dimension) {
+ for (int d = 0; d < dimension; d++) {
+ double ord1 = s1.getOrdinate(i, d);
+ double ord2 = s2.getOrdinate(i, d);
+ int comp = DoubleMath.fuzzyCompare(ord1, ord2, FP_TOLERANCE);
+ if (comp != 0) return comp;
+ }
+ return 0;
+ }
+ };
+
+ private final WKTReader wktReader = new WKTReader();
+
private Coordinate[] coordArray(double... coordValues) {
Coordinate[] coords = new Coordinate[(int)(coordValues.length / 2)];
for (int i = 0; i < coordValues.length; i += 2) {
@@ -389,4 +405,40 @@ public class FunctionsTest {
expects.add(10);
assertEquals(expects, levels);
}
+
+ @Test
+ public void geometricMedian() throws Exception {
+ MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
+ coordArray(1480,0, 620,0));
+ Geometry actual = Functions.geometricMedian(multiPoint);
+ Geometry expected = wktReader.read("POINT (1050 0)");
+ assertEquals(0, expected.compareTo(actual,
COORDINATE_SEQUENCE_COMPARATOR));
+ }
+
+ @Test
+ public void geometricMedianTolerance() throws Exception {
+ MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
+ coordArray(0,0, 10,1, 5,1, 20,20));
+ Geometry actual = Functions.geometricMedian(multiPoint, 1e-15);
+ Geometry expected = wktReader.read("POINT (5 1)");
+ assertEquals(0, expected.compareTo(actual,
COORDINATE_SEQUENCE_COMPARATOR));
+ }
+
+ @Test
+ public void geometricMedianUnsupported() {
+ LineString lineString = GEOMETRY_FACTORY.createLineString(
+ coordArray(1480,0, 620,0));
+ Exception e = assertThrows(Exception.class, () ->
Functions.geometricMedian(lineString));
+ assertEquals("Unsupported geometry type: LineString", e.getMessage());
+ }
+
+ @Test
+ public void geometricMedianFailConverge() {
+ MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
+ coordArray(12,5, 62,7, 100,-1, 100,-5, 10,20, 105,-5));
+ Exception e = assertThrows(Exception.class,
+ () -> Functions.geometricMedian(multiPoint, 1e-6, 5, true));
+ assertEquals("Median failed to converge within 1.0E-06 after 5
iterations.", e.getMessage());
+ }
+
}
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index ee2a4f60..603609d8 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -355,6 +355,36 @@ Result:
+-----------------------------+
```
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint
geometry using the Weiszfeld algorithm. The geometric median provides a
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm will iterate until the distance change between successive
iterations is less than the supplied `tolerance` parameter. If this condition
has not been met after `maxIter` iterations, the function will produce an error
and exit, unless `failIfNotConverged` is set to `false` (the default), in which
case the last computed estimate is returned.
+
+If a `tolerance` value is not provided, the default value of `1e-6` is used.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: `1.4.1`
+
+Example:
+```sql
+SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200
200))'))
+```
+
+Output:
+```
+POINT (1.9761550281255005 1.9761550281255005)
+```
+
## ST_GeometryN
Introduction: Return the 0-based Nth geometry if the geometry is a
GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or
(MULTI)POLYGON. Otherwise, return null
@@ -927,4 +957,3 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
```
Output: `4.0`
-
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index d15bc735..f9a8d82b 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -542,6 +542,36 @@ Result:
+-----------------------------+
```
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint
geometry using the Weiszfeld algorithm. The geometric median provides a
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm will iterate until the distance change between successive
iterations is less than the supplied `tolerance` parameter. If this condition
has not been met after `maxIter` iterations, the function will produce an error
and exit, unless `failIfNotConverged` is set to `false` (the default), in which
case the last computed estimate is returned.
+
+If a `tolerance` value is not provided, the default value of `1e-6` is used.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: `1.4.1`
+
+Example:
+```sql
+SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200
200))'))
+```
+
+Output:
+```
+POINT (1.9761550281255005 1.9761550281255005)
+```
+
## ST_GeometryN
Introduction: Return the 0-based Nth geometry if the geometry is a
GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or
(MULTI)POLYGON. Otherwise, return null
@@ -1560,3 +1590,4 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6
7)'))
```
Output: `4.0`
+
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index 83a99029..1075710a 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -89,7 +89,8 @@ public class Catalog {
new Functions.ST_SetPoint(),
new Functions.ST_LineFromMultiPoint(),
new Functions.ST_Split(),
- new Functions.ST_S2CellIDs()
+ new Functions.ST_S2CellIDs(),
+ new Functions.ST_GeometricMedian()
};
}
diff --git
a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index a22b91c2..11f5e9d8 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -497,4 +497,37 @@ public class Functions {
return org.apache.sedona.common.Functions.s2CellIDs(geom, level);
}
}
+
+ public static class ST_GeometricMedian extends ScalarFunction {
+ @DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class)
+ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class) Object o) throws Exception {
+ Geometry geometry = (Geometry) o;
+ return
org.apache.sedona.common.Functions.geometricMedian(geometry);
+ }
+
+ @DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class)
+ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class) Object o,
+ @DataTypeHint("Double") Double tolerance) throws
Exception {
+ Geometry geometry = (Geometry) o;
+ return
org.apache.sedona.common.Functions.geometricMedian(geometry, tolerance);
+ }
+
+ @DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class)
+ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class) Object o,
+ @DataTypeHint("Double") Double tolerance,
+ int maxIter) throws Exception {
+ Geometry geometry = (Geometry) o;
+ return
org.apache.sedona.common.Functions.geometricMedian(geometry, tolerance,
maxIter);
+ }
+
+ @DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class)
+ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo =
org.locationtech.jts.geom.Geometry.class) Object o,
+ @DataTypeHint("Double") Double tolerance,
+ int maxIter, @DataTypeHint("Boolean") Boolean
failIfNotConverged) throws Exception {
+ Geometry geometry = (Geometry) o;
+ return
org.apache.sedona.common.Functions.geometricMedian(geometry, tolerance,
maxIter, failIfNotConverged);
+ }
+
+ }
+
}
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index 94546edd..02486e39 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -24,6 +24,7 @@ import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.LineString;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.geom.Polygon;
+import org.locationtech.jts.io.ParseException;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
@@ -37,6 +38,7 @@ import static org.apache.flink.table.api.Expressions.call;
import static org.junit.Assert.*;
public class FunctionTest extends TestBase{
+
@BeforeClass
public static void onceExecutedBeforeAll() {
initialize();
@@ -614,4 +616,34 @@ public class FunctionTest extends TestBase{
assertEquals(1, count(joinCleanedTable));
assertEquals(2, first(joinCleanedTable).getField(1));
}
+
+ @Test
+ public void testGeometricMedian() throws ParseException {
+ Table pointTable = tableEnv.sqlQuery("SELECT
ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200
200))'))");
+ Geometry expected = wktReader.read("POINT (1.9761550281255005
1.9761550281255005)");
+ Geometry actual = (Geometry) first(pointTable).getField(0);
+ assertEquals(String.format("expected: %s was %s", expected.toText(),
actual != null ? actual.toText() : "null"),
+ 0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
+ }
+
+ @Test
+ public void testGeometricMedianParamsTolerance() throws ParseException {
+ Table pointTable = tableEnv.sqlQuery(
+ "SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT ((0 0),
(1 1), (0 1), (2 2))'), 1e-5)");
+ Geometry expected = wktReader.read("POINT (0.996230268436779
0.9999899629155288)");
+ Geometry actual = (Geometry) first(pointTable).getField(0);
+ assertEquals(String.format("expected: %s was %s", expected.toText(),
actual != null ? actual.toText() : "null"),
+ 0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
+ }
+
+ @Test
+ public void testGeometricMedianParamsFull() throws ParseException {
+ Table pointTable = tableEnv.sqlQuery(
+ "SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT ((0 0),
(1 1), (0 1), (2 2))'), 1e-5, 10, false)");
+ Geometry expected = wktReader.read("POINT (0.8844442206215307
0.9912184073718183)");
+ Geometry actual = (Geometry) first(pointTable).getField(0);
+ assertEquals(String.format("expected: %s was %s", expected.toText(),
actual != null ? actual.toText() : "null"),
+ 0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
+ }
+
}
diff --git a/flink/src/test/java/org/apache/sedona/flink/TestBase.java
b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
index be88c5fb..1148f715 100644
--- a/flink/src/test/java/org/apache/sedona/flink/TestBase.java
+++ b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
@@ -13,6 +13,7 @@
*/
package org.apache.sedona.flink;
+import com.google.common.math.DoubleMath;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -27,6 +28,7 @@ import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.apache.sedona.flink.expressions.Constructors;
import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.WKTReader;
import org.wololo.jts2geojson.GeoJSONWriter;
import java.sql.Timestamp;
@@ -50,6 +52,21 @@ public class TestBase {
static Long timestamp_base = new
Timestamp(System.currentTimeMillis()).getTime();
static Long time_interval = 1L; // Generate a record per this interval.
Unit is second
+ static final double FP_TOLERANCE = 1e-12;
+ static final CoordinateSequenceComparator COORDINATE_SEQUENCE_COMPARATOR =
new CoordinateSequenceComparator(2){
+ @Override
+ protected int compareCoordinate(CoordinateSequence s1,
CoordinateSequence s2, int i, int dimension) {
+ for (int d = 0; d < dimension; d++) {
+ double ord1 = s1.getOrdinate(i, d);
+ double ord2 = s2.getOrdinate(i, d);
+ int comp = DoubleMath.fuzzyCompare(ord1, ord2, FP_TOLERANCE);
+ if (comp != 0) return comp;
+ }
+ return 0;
+ }
+ };
+ final WKTReader wktReader = new WKTReader();
+
public void setTestDataSize(int testDataSize) {
this.testDataSize = testDataSize;
}
diff --git a/python/sedona/sql/st_functions.py
b/python/sedona/sql/st_functions.py
index d2fd8edc..4bb7df5b 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -53,6 +53,7 @@ __all__ = [
"ST_FlipCoordinates",
"ST_Force_2D",
"ST_GeoHash",
+ "ST_GeometricMedian",
"ST_GeometryN",
"ST_GeometryType",
"ST_InteriorRingN",
@@ -496,6 +497,31 @@ def ST_GeoHash(geometry: ColumnOrName, precision:
Union[ColumnOrName, int]) -> C
"""
return _call_st_function("ST_GeoHash", (geometry, precision))
+@validate_argument_types
+def ST_GeometricMedian(geometry: ColumnOrName, tolerance:
Optional[Union[ColumnOrName, float]] = 1e-6,
+ max_iter: Optional[Union[ColumnOrName, int]] = 1000,
+ fail_if_not_converged: Optional[Union[ColumnOrName,
bool]] = False) -> Column:
+ """Computes the approximate geometric median of a MultiPoint geometry
using the Weiszfeld algorithm.
+ The geometric median provides a centrality measure that is less sensitive
to outlier points than the centroid.
+ The algorithm will iterate until the distance change between successive
iterations is less than the
+ supplied `tolerance` parameter. If this condition has not been met after
`max_iter` iterations, the function will
+ produce an error and exit, unless `fail_if_not_converged` is set to `False`.
If a `tolerance` value is not provided,
+ the default value of `1e-6` is used.
+
+ :param geometry: MultiPoint or Point geometry.
+ :type geometry: ColumnOrName
+ :param tolerance: Distance limit change between successive iterations,
defaults to 1e-6.
+ :type tolerance: Optional[Union[ColumnOrName, float]], optional
+ :param max_iter: Max number of iterations, defaults to 1000.
+ :type max_iter: Optional[Union[ColumnOrName, int]], optional
+ :param fail_if_not_converged: Generate error if not converged within given
tolerance and number of iterations, defaults to False
+ :type fail_if_not_converged: Optional[Union[ColumnOrName, bool]],
optional
+ :return: Point geometry column.
+ :rtype: Column
+ """
+ args = (geometry, tolerance, max_iter, fail_if_not_converged)
+ return _call_st_function("ST_GeometricMedian", args)
+
@validate_argument_types
def ST_GeometryN(multi_geometry: ColumnOrName, n: Union[ColumnOrName, int]) ->
Column:
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index dd5f275a..1ad052bd 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -81,6 +81,7 @@ test_configurations = [
(stf.ST_ExteriorRing, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1
0, 1 1, 0 0)"),
(stf.ST_FlipCoordinates, ("point",), "point_geom", "", "POINT (1 0)"),
(stf.ST_Force_2D, ("point",), "point_geom", "", "POINT (0 1)"),
+ (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT
(22.500002656424286 21.250001168173426)"),
(stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"),
(stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"),
(stf.ST_InteriorRingN, ("geom", 0), "geom_with_hole", "", "LINESTRING (1
1, 2 2, 2 1, 1 1)"),
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 9364eae5..9b85f931 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -142,6 +142,7 @@ object Catalog {
function[ST_MLineFromText](0),
function[ST_Split](),
function[ST_S2CellIDs](),
+ function[ST_GeometricMedian](1e-6, 1000, false),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 9d065a78..7298e1af 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -933,3 +933,18 @@ case class ST_CollectionExtract(inputExpressions:
Seq[Expression])
override def allowRightNull: Boolean = true
}
+
+/**
+ * Returns a POINT that is the approximate geometric median of a MultiPoint
geometry, computed using the Weiszfeld algorithm.
+ * The geometric median provides a centrality measure that is less sensitive
to outlier points than the centroid.
+ *
+ * @param inputExpressions Geometry
+ */
+case class ST_GeometricMedian(inputExpressions: Seq[Expression])
+ extends InferredQuarternaryExpression(Functions.geometricMedian) with
FoldableExpression {
+
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
+ copy(inputExpressions = newChildren)
+ }
+}
+
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index ce077d99..3e5fa706 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -275,4 +275,14 @@ object st_functions extends DataFrameAPI {
def ST_ZMin(geometry: Column): Column = wrapExpression[ST_ZMin](geometry)
def ST_ZMin(geometry: String): Column = wrapExpression[ST_ZMin](geometry)
+
+ def ST_GeometricMedian(geometry: Column): Column =
wrapExpression[ST_GeometricMedian](geometry, 1e-6, 1000, false)
+ def ST_GeometricMedian(geometry: String): Column =
wrapExpression[ST_GeometricMedian](geometry, 1e-6, 1000, false)
+ def ST_GeometricMedian(geometry: Column, tolerance: Column): Column =
wrapExpression[ST_GeometricMedian](geometry, tolerance, 1000, false)
+ def ST_GeometricMedian(geometry: String, tolerance: Double): Column =
wrapExpression[ST_GeometricMedian](geometry, tolerance, 1000, false)
+ def ST_GeometricMedian(geometry: Column, tolerance: Column, maxIter:
Column): Column = wrapExpression[ST_GeometricMedian](geometry, tolerance,
maxIter, false)
+ def ST_GeometricMedian(geometry: String, tolerance: Double, maxIter: Int):
Column = wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter, false)
+ def ST_GeometricMedian(geometry: Column, tolerance: Column, maxIter: Column,
failIfNotConverged: Column): Column =
wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter,
failIfNotConverged)
+ def ST_GeometricMedian(geometry: String, tolerance: Double, maxIter: Int,
failIfNotConverged: Boolean): Column =
wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter,
failIfNotConverged)
+
}
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 47c72e7c..f12b5874 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -18,11 +18,13 @@
*/
package org.apache.sedona.sql
+import com.google.common.math.DoubleMath
import org.apache.log4j.{Level, Logger}
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.utils.SedonaSQLRegistrator
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.locationtech.jts.geom.{CoordinateSequence,
CoordinateSequenceComparator}
import org.scalatest.{BeforeAndAfterAll, FunSpec}
trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
@@ -81,4 +83,18 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
lazy val buildPointDf =
loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as
Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
lazy val buildPolygonDf =
loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as
Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)),
cast(_c3 as Decimal(24,20))) as polygonshape")
+
+ protected final val FP_TOLERANCE: Double = 1e-12
+ protected final val COORDINATE_SEQUENCE_COMPARATOR:
CoordinateSequenceComparator = new CoordinateSequenceComparator(2) {
+ override protected def compareCoordinate(s1: CoordinateSequence, s2:
CoordinateSequence, i: Int, dimension: Int): Int = {
+ for (d <- 0 until dimension) {
+ val ord1: Double = s1.getOrdinate(i, d)
+ val ord2: Double = s2.getOrdinate(i, d)
+ val comp: Int = DoubleMath.fuzzyCompare(ord1, ord2, FP_TOLERANCE)
+ if (comp != 0) return comp
+ }
+ 0
+ }
+ }
+
}
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 0b3d9d42..a66a9be2 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.geotools.referencing.CRS
import org.locationtech.jts.algorithm.MinimumBoundingCircle
-import org.locationtech.jts.geom.{Geometry, Polygon}
+import org.locationtech.jts.geom.{CoordinateSequenceComparator, Geometry,
Polygon}
import org.locationtech.jts.io.WKTWriter
import org.locationtech.jts.linearref.LengthIndexedLine
import org.locationtech.jts.operation.distance3d.Distance3DOp
@@ -1762,6 +1762,8 @@ class functionTestScala extends TestBaseScala with
Matchers with GeometrySample
assert(functionDf.first().get(0) == null)
functionDf = sparkSession.sql("select ST_LineFromMultiPoint(null)")
assert(functionDf.first().get(0) == null)
+ functionDf = sparkSession.sql("select ST_GeometricMedian(null)")
+ assert(functionDf.first().get(0) == null)
}
it ("Should pass St_CollectionExtract") {
@@ -1818,4 +1820,23 @@ class functionTestScala extends TestBaseScala with
Matchers with GeometrySample
assert(textResult==expected)
}
}
+
+ it ("Should pass ST_GeometricMedian") {
+ val geomTestCases = Map(
+ ("'MULTIPOINT((10 40), (40 30), (20 20), (30 10))'", 1e-15) ->
"'POINT(22.5 21.25)'",
+ ("'MULTIPOINT((0 0), (1 1), (2 2), (200 200))'", 1e-6) -> "'POINT
(1.9761550281255005 1.9761550281255005)'",
+ ("'MULTIPOINT ((0 0), (10 1), (5 1), (20 20))'", 1e-15) -> "'POINT (5
1)'",
+ ("'MULTIPOINT ((0 -1), (0 0), (0 0), (0 1))'", 1e-6) -> "'POINT (0 0)'",
+ ("'POINT (7 6)'", 1e-6) -> "'POINT (7 6)'",
+ ("'MULTIPOINT ((12 5),(62 7),(100 -1),(100 -5),(10 20),(105 -5))'",
1e-15) -> "'POINT(84.21672412761632 0.1351485929395439)'"
+ )
+ for(((targetWkt, tolerance), expectedWkt) <- geomTestCases) {
+ val df = sparkSession.sql(s"SELECT
ST_GeometricMedian(ST_GeomFromWKT($targetWkt), $tolerance), " +
+ s"ST_GeomFromWKT($expectedWkt)")
+ val actual = df.take(1)(0).get(0).asInstanceOf[Geometry]
+ val expected = df.take(1)(0).get(1).asInstanceOf[Geometry]
+ assert(expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR) == 0)
+ }
+ }
+
}