jiayuasu commented on code in PR #831:
URL: https://github.com/apache/sedona/pull/831#discussion_r1191721424
##########
docs/api/sql/Function.md:
##########
@@ -1560,3 +1560,33 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6
7)'))
```
Output: `4.0`
+
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint
geometry using the Weiszfeld algorithm. The geometric median provides a
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm will iterate until the distance change between successive
iterations is less than the supplied `tolerance` parameter. If this condition
has not been met after `maxIter` iterations, the function will produce an error
and exit, unless `failIfNotConverged` is set to `false`.
+
+If a `tolerance` value is not provided, a default `tolerance` value is `1e-6`.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: ``
Review Comment:
Should be Since `1.4.1`
##########
docs/api/sql/Function.md:
##########
@@ -1560,3 +1560,33 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6
7)'))
```
Output: `4.0`
+
+## ST_GeometricMedian
Review Comment:
Docs are sorted alphabetically. Please move this section to its correct alphabetical position.
##########
docs/api/flink/Function.md:
##########
@@ -928,3 +928,33 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
Output: `4.0`
+## ST_GeometricMedian
Review Comment:
Please move this section to its correct position, because functions are sorted
alphabetically.
##########
sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala:
##########
@@ -275,4 +275,14 @@ object st_functions extends DataFrameAPI {
def ST_ZMin(geometry: Column): Column = wrapExpression[ST_ZMin](geometry)
def ST_ZMin(geometry: String): Column = wrapExpression[ST_ZMin](geometry)
+
+ def ST_GeometricMedian(geometry: Column): Column =
wrapExpression[ST_GeometricMedian](geometry, 1e-6, 1000, false)
Review Comment:
Please also add the DataFrame-style API to Sedona Python. See this example:
https://github.com/apache/sedona/blob/master/python/sedona/sql/st_functions.py#L882
You don't need to add tests in Sedona Python if you don't want to.
##########
common/src/main/java/org/apache/sedona/common/Functions.java:
##########
@@ -730,4 +731,149 @@ public static Geometry collectionExtract(Geometry
geometry) {
return GEOMETRY_FACTORY.createGeometryCollection();
}
+
+ // ported from
https://github.com/postgis/postgis/blob/f6ed58d1fdc865d55d348212d02c11a10aeb2b30/liblwgeom/lwgeom_median.c
+ // geometry ST_GeometricMedian ( geometry g , float8 tolerance , int
max_iter , boolean fail_if_not_converged );
+
+ private static double distance3d(Coordinate p1, Coordinate p2) {
+ double dx = p2.x - p1.x;
+ double dy = p2.y - p1.y;
+ double dz = p2.z - p1.z;
+ return Math.sqrt(dx * dx + dy * dy + dz * dz);
+ }
+
+ private static void distances(Coordinate curr, Coordinate[] points,
double[] distances) {
+ for(int i = 0; i < points.length; i++) {
+ distances[i] = distance3d(curr, points[i]);
+ }
+ }
+
+ private static double iteratePoints(Coordinate curr, Coordinate[] points,
double[] distances) {
+ Coordinate next = new Coordinate(0, 0, 0);
+ double delta = 0;
+ double denom = 0;
+ boolean hit = false;
+ distances(curr, points, distances);
+
+ for (int i = 0; i < points.length; i++) {
+ /* we need to use lower epsilon than in FP_IS_ZERO in the loop for
calculation to converge */
+ double distance = distances[i];
+ if (distance > DBL_EPSILON) {
+ Coordinate coordinate = points[i];
+ next.x += coordinate.x / distance;
+ next.y += coordinate.y / distance;
+ next.z += coordinate.z / distance;
+ denom += 1.0 / distance;
+ } else {
+ hit = true;
+ }
+ }
+ /* negative weight shouldn't get here */
+ //assert(denom >= 0);
+
+ /* denom is zero in case of multipoint of single point when we've
converged perfectly */
+ if (denom > DBL_EPSILON) {
+ next.x /= denom;
+ next.y /= denom;
+ next.z /= denom;
+
+ /* If any of the intermediate points in the calculation is found
in the
+ * set of input points, the standard Weiszfeld method gets stuck
with a
+ * divide-by-zero.
+ *
+ * To get ourselves out of the hole, we follow an alternate
procedure to
+ * get the next iteration, as described in:
+ *
+ * Vardi, Y. and Zhang, C. (2011) "A modified Weiszfeld algorithm
for the
+ * Fermat-Weber location problem." Math. Program., Ser. A 90:
559-566.
+ * DOI 10.1007/s101070100222
+ *
+ * Available online at the time of this writing at
+ * http://www.stat.rutgers.edu/home/cunhui/papers/43.pdf
+ */
+ if (hit) {
+ double dx = 0;
+ double dy = 0;
+ double dz = 0;
+ for (int i = 0; i < points.length; i++) {
+ double distance = distances[i];
+ if (distance > DBL_EPSILON) {
+ Coordinate coordinate = points[i];
+ dx += (coordinate.x - curr.x) / distance;
+ dy += (coordinate.y - curr.y) / distance;
+ dz += (coordinate.z - curr.z) / distance;
+ }
+ }
+ double dSqr = Math.sqrt(dx*dx + dy*dy + dz*dz);
+ /* Avoid division by zero if the intermediate point is the
median */
+ if (dSqr > DBL_EPSILON) {
+ double rInv = Math.max(0, 1.0 / dSqr);
+ next.x = (1.0 - rInv)*next.x + rInv*curr.x;
+ next.y = (1.0 - rInv)*next.y + rInv*curr.y;
+ next.z = (1.0 - rInv)*next.z + rInv*curr.z;
+ }
+ }
+ delta = distance3d(curr, next);
+ curr.x = next.x;
+ curr.y = next.y;
+ curr.z = next.z;
+ }
+ return delta;
+ }
+
+ private static Coordinate initGuess(Coordinate[] points) {
+ Coordinate guess = new Coordinate(0, 0, 0);
+ for (Coordinate point : points) {
+ guess.x += point.x / points.length;
+ guess.y += point.y / points.length;
+ guess.z += point.z / points.length;
+ }
+ return guess;
+ }
+
+ private static Coordinate[] extractCoordinates(Geometry geometry) {
+ Coordinate[] points = geometry.getCoordinates();
+ if(points.length == 0)
+ return points;
+ boolean is3d = !Double.isNaN(points[0].z);
+ Coordinate[] coordinates = new Coordinate[points.length];
+ for(int i = 0; i < points.length; i++) {
+ coordinates[i] = points[i].copy();
+ if(!is3d)
+ coordinates[i].z = 0.0;
+ }
+ return coordinates;
+ }
+
+ public static Point geometricMedian(Geometry geometry, double tolerance,
int maxIter, boolean failIfNotConverged) throws Exception {
Review Comment:
If you want, you can add more comprehensive tests in sedona-common to check
the correctness of this function.
##########
sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala:
##########
@@ -142,6 +142,7 @@ object Catalog {
function[ST_MLineFromText](0),
function[ST_Split](),
function[ST_S2CellIDs](),
+ function[ST_GeometricMedian](1e-6, 1000, false),
Review Comment:
Please add test cases for Spark as well.
##########
docs/api/flink/Function.md:
##########
@@ -928,3 +928,33 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
Output: `4.0`
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint
geometry using the Weiszfeld algorithm. The geometric median provides a
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm will iterate until the distance change between successive
iterations is less than the supplied `tolerance` parameter. If this condition
has not been met after `maxIter` iterations, the function will produce an error
and exit, unless `failIfNotConverged` is set to `false`.
+
+If a `tolerance` value is not provided, a default `tolerance` value is `1e-6`.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter:
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: ``
Review Comment:
Since `1.4.1`
##########
flink/src/test/java/org/apache/sedona/flink/FunctionTest.java:
##########
@@ -614,4 +614,11 @@ assert take(joinTable, 2).stream().map(
assertEquals(1, count(joinCleanedTable));
assertEquals(2, first(joinCleanedTable).getField(1));
}
+
+ @Test
+ public void testGeometricMedian() {
+ Table pointTable = tableEnv.sqlQuery("SELECT
ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200
200))'))");
+ assertEquals("POINT (1.9761550281255005 1.9761550281255005)",
first(pointTable).getField(0).toString());
+ }
Review Comment:
Please add another test case that passes all of the parameters explicitly.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org