This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new abc51e3d [SEDONA-280] Add ST_GeometricMedian (#831)
abc51e3d is described below

commit abc51e3db675db643418c2e03f1be3f63d2a1b37
Author: Artem <[email protected]>
AuthorDate: Wed May 17 21:57:27 2023 -0700

    [SEDONA-280] Add ST_GeometricMedian (#831)
---
 .../java/org/apache/sedona/common/Functions.java   | 160 ++++++++++++++++++++-
 .../org/apache/sedona/common/FunctionsTest.java    |  58 +++++++-
 docs/api/flink/Function.md                         |  31 +++-
 docs/api/sql/Function.md                           |  31 ++++
 .../main/java/org/apache/sedona/flink/Catalog.java |   3 +-
 .../apache/sedona/flink/expressions/Functions.java |  33 +++++
 .../java/org/apache/sedona/flink/FunctionTest.java |  32 +++++
 .../java/org/apache/sedona/flink/TestBase.java     |  17 +++
 python/sedona/sql/st_functions.py                  |  26 ++++
 python/tests/sql/test_dataframe_api.py             |   1 +
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |   1 +
 .../sql/sedona_sql/expressions/Functions.scala     |  15 ++
 .../sql/sedona_sql/expressions/st_functions.scala  |  10 ++
 .../org/apache/sedona/sql/TestBaseScala.scala      |  16 +++
 .../org/apache/sedona/sql/functionTestScala.scala  |  23 ++-
 15 files changed, 446 insertions(+), 11 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java 
b/common/src/main/java/org/apache/sedona/common/Functions.java
index c86572f8..ad7af119 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -25,6 +25,7 @@ import org.geotools.referencing.CRS;
 import org.locationtech.jts.algorithm.MinimumBoundingCircle;
 import org.locationtech.jts.algorithm.hull.ConcaveHull;
 import org.locationtech.jts.geom.*;
+import org.locationtech.jts.geom.impl.CoordinateArraySequence;
 import org.locationtech.jts.geom.util.GeometryFixer;
 import org.locationtech.jts.io.gml2.GMLWriter;
 import org.locationtech.jts.io.kml.KMLWriter;
@@ -42,18 +43,18 @@ import org.opengis.referencing.operation.MathTransform;
 import org.opengis.referencing.operation.TransformException;
 import org.wololo.jts2geojson.GeoJSONWriter;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.List;
+import java.util.*;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
+
+import static com.google.common.geometry.S2.DBL_EPSILON;
 
 
 public class Functions {
     private static final GeometryFactory GEOMETRY_FACTORY = new 
GeometryFactory();
     private static Geometry EMPTY_POLYGON = 
GEOMETRY_FACTORY.createPolygon(null, null);
     private static GeometryCollection EMPTY_GEOMETRY_COLLECTION = 
GEOMETRY_FACTORY.createGeometryCollection(null);
+    private static final double DEFAULT_TOLERANCE = 1e-6;
+    private static final int DEFAULT_MAX_ITER = 1000;
 
     public static double area(Geometry geometry) {
         return geometry.getArea();
@@ -730,4 +731,153 @@ public class Functions {
         return GEOMETRY_FACTORY.createGeometryCollection();
     }
 
+
+    // ported from https://github.com/postgis/postgis/blob/f6ed58d1fdc865d55d348212d02c11a10aeb2b30/liblwgeom/lwgeom_median.c
+    // geometry ST_GeometricMedian ( geometry g , float8 tolerance , int max_iter , boolean fail_if_not_converged );
+
+    private static double distance3d(Coordinate p1, Coordinate p2) {
+        double dx = p2.x - p1.x;
+        double dy = p2.y - p1.y;
+        double dz = p2.z - p1.z;
+        return Math.sqrt(dx * dx + dy * dy + dz * dz);
+    }
+
+    /**
+     * Writes into {@code distances[i]} the 3D distance from {@code curr} to {@code points[i]}.
+     * The buffer is supplied by the caller and must hold at least {@code points.length} slots.
+     */
+    private static void distances(Coordinate curr, Coordinate[] points, double[] distances) {
+        int slot = 0;
+        for (Coordinate point : points) {
+            distances[slot++] = distance3d(curr, point);
+        }
+    }
+
+    /**
+     * Performs one step of the modified Weiszfeld iteration, moving {@code curr} in place to the next estimate.
+     *
+     * <p>Input points within {@code DBL_EPSILON} of {@code curr} ("coincident" points) are left out of the plain
+     * Weiszfeld average {@code T(curr)}, because they would cause a division by zero. When there are any, the step
+     * follows Vardi and Zhang's modification:
+     * <pre>
+     *   next = max(0, 1 - eta / r) * T(curr) + min(1, eta / r) * curr
+     * </pre>
+     * where {@code eta} is the number of coincident points and {@code r} is the length of the sum of unit vectors
+     * from {@code curr} towards the remaining points. When {@code r <= eta}, {@code curr} already is the geometric
+     * median and is kept unchanged.
+     *
+     * @param curr      current estimate; overwritten with the next estimate
+     * @param points    input points
+     * @param distances scratch buffer with one slot per input point; overwritten
+     * @return the distance moved by this step, or 0 when no step was taken
+     */
+    private static double iteratePoints(Coordinate curr, Coordinate[] points, double[] distances) {
+        Coordinate next = new Coordinate(0, 0, 0);
+        double delta = 0;
+        double denom = 0;
+        // Number of input points coinciding with curr ("eta" in Vardi & Zhang).
+        int coincident = 0;
+        distances(curr, points, distances);
+
+        for (int i = 0; i < points.length; i++) {
+            /* we need to use lower epsilon than in FP_IS_ZERO in the loop for calculation to converge */
+            double distance = distances[i];
+            if (distance > DBL_EPSILON) {
+                Coordinate coordinate = points[i];
+                next.x += coordinate.x / distance;
+                next.y += coordinate.y / distance;
+                next.z += coordinate.z / distance;
+                denom += 1.0 / distance;
+            } else {
+                coincident++;
+            }
+        }
+
+        /* denom is zero in case of multipoint of single point when we've converged perfectly */
+        if (denom > DBL_EPSILON) {
+            next.x /= denom;
+            next.y /= denom;
+            next.z /= denom;
+
+            /* If the current estimate coincides with input points, the standard Weiszfeld method gets stuck
+             * with a divide-by-zero. To get out of the hole we use the modified step described in:
+             *
+             * Vardi, Y. and Zhang, C.-H. (2001) "A modified Weiszfeld algorithm for the
+             * Fermat-Weber location problem."  Math. Program., Ser. A 90: 559-566.
+             * DOI 10.1007/s101070100222
+             *
+             * Available online at the time of this writing at
+             * http://www.stat.rutgers.edu/home/cunhui/papers/43.pdf
+             */
+            if (coincident > 0) {
+                double rx = 0;
+                double ry = 0;
+                double rz = 0;
+                for (int i = 0; i < points.length; i++) {
+                    double distance = distances[i];
+                    if (distance > DBL_EPSILON) {
+                        Coordinate coordinate = points[i];
+                        rx += (coordinate.x - curr.x) / distance;
+                        ry += (coordinate.y - curr.y) / distance;
+                        rz += (coordinate.z - curr.z) / distance;
+                    }
+                }
+                double r = Math.sqrt(rx * rx + ry * ry + rz * rz);
+                // Weight kept on curr is min(1, eta / r). It must never exceed 1: a weight above 1 (as produced by
+                // max(0, 1 / r) when r < 1) extrapolates away from a median that lies on an input point.
+                // r <= eta (including r == 0) means curr is already the median, so it stays put; this also
+                // avoids dividing by a zero r.
+                double stay = r <= coincident ? 1.0 : coincident / r;
+                next.x = (1.0 - stay) * next.x + stay * curr.x;
+                next.y = (1.0 - stay) * next.y + stay * curr.y;
+                next.z = (1.0 - stay) * next.z + stay * curr.z;
+            }
+            delta = distance3d(curr, next);
+            curr.x = next.x;
+            curr.y = next.y;
+            curr.z = next.z;
+        }
+        return delta;
+    }
+
+    private static Coordinate initGuess(Coordinate[] points) {
+        Coordinate guess = new Coordinate(0, 0, 0);
+        for (Coordinate point : points) {
+            guess.x += point.x / points.length;
+            guess.y += point.y / points.length;
+            guess.z += point.z / points.length;
+        }
+        return guess;
+    }
+
+    private static Coordinate[] extractCoordinates(Geometry geometry) {
+        Coordinate[] points = geometry.getCoordinates();
+        if(points.length == 0)
+            return points;
+        boolean is3d = !Double.isNaN(points[0].z);
+        Coordinate[] coordinates = new Coordinate[points.length];
+        for(int i = 0; i < points.length; i++) {
+            coordinates[i] = points[i].copy();
+            if(!is3d)
+                coordinates[i].z = 0.0;
+        }
+        return coordinates;
+    }
+
+    /**
+     * Computes the approximate geometric median of a Point or MultiPoint with the (modified) Weiszfeld algorithm,
+     * ported from PostGIS ST_GeometricMedian.
+     *
+     * @param geometry           input Point or MultiPoint
+     * @param tolerance          iteration stops once a step moves the estimate by no more than this distance
+     * @param maxIter            maximum number of iterations
+     * @param failIfNotConverged when true, throw if {@code tolerance} was not reached within {@code maxIter}
+     *                           iterations; when false, return the last estimate
+     * @return the median as a Point carrying the input's SRID, with a Z ordinate only when the input is 3D;
+     *         an empty input yields an empty Point
+     * @throws Exception if the geometry is neither a Point nor a MultiPoint, or if {@code failIfNotConverged}
+     *                   is set and the iteration did not converge
+     */
+    public static Geometry geometricMedian(Geometry geometry, double tolerance, int maxIter, boolean failIfNotConverged) throws Exception {
+        String geometryType = geometry.getGeometryType();
+        if (!(Geometry.TYPENAME_POINT.equals(geometryType) || Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) {
+            throw new Exception("Unsupported geometry type: " + geometryType);
+        }
+        Coordinate[] coordinates = extractCoordinates(geometry);
+        if (coordinates.length == 0) {
+            // Keep the SRID on the empty result too, consistent with the non-empty result below.
+            Point empty = GEOMETRY_FACTORY.createPoint();
+            empty.setSRID(geometry.getSRID());
+            return empty;
+        }
+        // Decide 2D vs 3D from the first actual coordinate (the one extractCoordinates inspects).
+        // geometry.getCoordinate() is unsafe here: it returns null for a MultiPoint whose first member is EMPTY.
+        boolean is3d = !Double.isNaN(geometry.getCoordinates()[0].z);
+        Coordinate median = initGuess(coordinates);
+        double delta = Double.MAX_VALUE;
+        double[] distances = new double[coordinates.length]; // preallocate to reduce gc pressure for large iterations
+        for (int i = 0; i < maxIter && delta > tolerance; i++) {
+            delta = iteratePoints(median, coordinates, distances);
+        }
+        if (failIfNotConverged && delta > tolerance) {
+            // Locale.ROOT keeps the message ("1.0E-06") independent of the JVM default locale.
+            throw new Exception(String.format(Locale.ROOT,
+                    "Median failed to converge within %.1E after %d iterations.", tolerance, maxIter));
+        }
+        if (!is3d) {
+            median.z = Double.NaN;
+        }
+        Point point = new Point(new CoordinateArraySequence(new Coordinate[]{median}), GEOMETRY_FACTORY);
+        point.setSRID(geometry.getSRID());
+        return point;
+    }
+
+    /**
+     * Geometric median with an explicit tolerance and iteration cap; returns the last estimate instead of
+     * failing when the iteration does not converge.
+     */
+    public static Geometry geometricMedian(Geometry geometry, double tolerance, int maxIter) throws Exception {
+        final boolean failIfNotConverged = false;
+        return geometricMedian(geometry, tolerance, maxIter, failIfNotConverged);
+    }
+
+    /**
+     * Geometric median with an explicit tolerance, at most {@code DEFAULT_MAX_ITER} iterations and no failure
+     * on non-convergence.
+     */
+    public static Geometry geometricMedian(Geometry geometry, double tolerance) throws Exception {
+        return geometricMedian(geometry, tolerance, DEFAULT_MAX_ITER);
+    }
+
+    /**
+     * Geometric median using {@code DEFAULT_TOLERANCE} and {@code DEFAULT_MAX_ITER}, with no failure on
+     * non-convergence.
+     */
+    public static Geometry geometricMedian(Geometry geometry) throws Exception {
+        return geometricMedian(geometry, DEFAULT_TOLERANCE);
+    }
+
 }
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java 
b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index eabaaf36..78a26e19 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -14,22 +14,38 @@
 package org.apache.sedona.common;
 
 import com.google.common.geometry.S2CellId;
-import org.apache.sedona.common.utils.GeomUtils;
+import com.google.common.math.DoubleMath;
 import org.apache.sedona.common.utils.S2Utils;
 import org.junit.Test;
 import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.WKTReader;
 
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.stream.Collectors;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.*;
 
 public class FunctionsTest {
     public static final GeometryFactory GEOMETRY_FACTORY = new 
GeometryFactory();
 
+    protected static final double FP_TOLERANCE = 1e-12;
+
+    /** Compares ordinates in order (at most the first two), treating differences within {@link #FP_TOLERANCE} as equal. */
+    protected static final CoordinateSequenceComparator COORDINATE_SEQUENCE_COMPARATOR =
+            new CoordinateSequenceComparator(2) {
+                @Override
+                protected int compareCoordinate(CoordinateSequence s1, CoordinateSequence s2, int i, int dimension) {
+                    int result = 0;
+                    for (int d = 0; d < dimension && result == 0; d++) {
+                        result = DoubleMath.fuzzyCompare(s1.getOrdinate(i, d), s2.getOrdinate(i, d), FP_TOLERANCE);
+                    }
+                    return result;
+                }
+            };
+
+    private final WKTReader wktReader = new WKTReader();
+
     private Coordinate[] coordArray(double... coordValues) {
         Coordinate[] coords = new Coordinate[(int)(coordValues.length / 2)];
         for (int i = 0; i < coordValues.length; i += 2) {
@@ -389,4 +405,40 @@ public class FunctionsTest {
         expects.add(10);
         assertEquals(expects, levels);
     }
+
+    /** Parses {@code expectedWkt} and asserts it equals {@code actual} within FP_TOLERANCE per ordinate. */
+    private void assertMedianEquals(String expectedWkt, Geometry actual) throws Exception {
+        Geometry expected = wktReader.read(expectedWkt);
+        assertEquals(0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
+    }
+
+    @Test
+    public void geometricMedian() throws Exception {
+        MultiPoint twoPoints = GEOMETRY_FACTORY.createMultiPointFromCoords(coordArray(1480, 0, 620, 0));
+        assertMedianEquals("POINT (1050 0)", Functions.geometricMedian(twoPoints));
+    }
+
+    @Test
+    public void geometricMedianTolerance() throws Exception {
+        MultiPoint fourPoints = GEOMETRY_FACTORY.createMultiPointFromCoords(coordArray(0, 0, 10, 1, 5, 1, 20, 20));
+        // With a very tight tolerance the estimate is expected to land on the input point (5 1).
+        assertMedianEquals("POINT (5 1)", Functions.geometricMedian(fourPoints, 1e-15));
+    }
+
+    @Test
+    public void geometricMedianUnsupported() {
+        LineString line = GEOMETRY_FACTORY.createLineString(coordArray(1480, 0, 620, 0));
+        Exception thrown = assertThrows(Exception.class, () -> Functions.geometricMedian(line));
+        assertEquals("Unsupported geometry type: LineString", thrown.getMessage());
+    }
+
+    @Test
+    public void geometricMedianFailConverge() {
+        MultiPoint scattered = GEOMETRY_FACTORY.createMultiPointFromCoords(
+                coordArray(12, 5, 62, 7, 100, -1, 100, -5, 10, 20, 105, -5));
+        // Five iterations are not enough to reach 1e-6, and failIfNotConverged is set.
+        Exception thrown = assertThrows(Exception.class,
+                () -> Functions.geometricMedian(scattered, 1e-6, 5, true));
+        assertEquals("Median failed to converge within 1.0E-06 after 5 iterations.", thrown.getMessage());
+    }
+
 }
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index ee2a4f60..603609d8 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -355,6 +355,36 @@ Result:
 +-----------------------------+
 ```
 
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint 
geometry using the Weiszfeld algorithm. The geometric median provides a 
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm iterates until the distance the estimate moves between successive iterations no longer exceeds the supplied `tolerance` parameter. If this condition has not been met after `maxIter` iterations, the function returns the current approximation, unless `failIfNotConverged` is set to `true`, in which case it produces an error.
+
+If a `tolerance` value is not provided, it defaults to `1e-6`.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: 
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: 
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: `1.4.1`
+
+Example:
+```sql
+SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200 
200))'))
+```
+
+Output:
+```
+POINT (1.9761550281255005 1.9761550281255005)
+```
+
 ## ST_GeometryN
 
 Introduction: Return the 0-based Nth geometry if the geometry is a 
GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or 
(MULTI)POLYGON. Otherwise, return null
@@ -927,4 +957,3 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
 ```
 
 Output: `4.0`
-
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index d15bc735..f9a8d82b 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -542,6 +542,36 @@ Result:
 +-----------------------------+
 ```
 
+## ST_GeometricMedian
+
+Introduction: Computes the approximate geometric median of a MultiPoint 
geometry using the Weiszfeld algorithm. The geometric median provides a 
centrality measure that is less sensitive to outlier points than the centroid.
+
+The algorithm iterates until the distance the estimate moves between successive iterations no longer exceeds the supplied `tolerance` parameter. If this condition has not been met after `maxIter` iterations, the function returns the current approximation, unless `failIfNotConverged` is set to `true`, in which case it produces an error.
+
+If a `tolerance` value is not provided, it defaults to `1e-6`.
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: 
integer, failIfNotConverged: boolean)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: 
integer)`
+
+Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`
+
+Format: `ST_GeometricMedian(geom: geometry)`
+
+Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`
+
+Since: `1.4.1`
+
+Example:
+```sql
+SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200 
200))'))
+```
+
+Output:
+```
+POINT (1.9761550281255005 1.9761550281255005)
+```
+
 ## ST_GeometryN
 
 Introduction: Return the 0-based Nth geometry if the geometry is a 
GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or 
(MULTI)POLYGON. Otherwise, return null
@@ -1560,3 +1590,4 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 
7)'))
 ```
 
 Output: `4.0`
+
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java 
b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index 83a99029..1075710a 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -89,7 +89,8 @@ public class Catalog {
                 new Functions.ST_SetPoint(),
                 new Functions.ST_LineFromMultiPoint(),
                 new Functions.ST_Split(),
-                new Functions.ST_S2CellIDs()
+                new Functions.ST_S2CellIDs(),
+                new Functions.ST_GeometricMedian()
         };
     }
 
diff --git 
a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java 
b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index a22b91c2..11f5e9d8 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -497,4 +497,37 @@ public class Functions {
             return org.apache.sedona.common.Functions.s2CellIDs(geom, level);
         }
     }
+
+    /**
+     * Flink SQL function ST_GeometricMedian. Each {@code eval} overload casts the RAW geometry argument to a JTS
+     * {@link Geometry} and forwards it, together with the optional tolerance, maximum iteration count and
+     * fail-if-not-converged flag, to {@code org.apache.sedona.common.Functions.geometricMedian}.
+     */
+    public static class ST_GeometricMedian extends ScalarFunction {
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o)
+                throws Exception {
+            return org.apache.sedona.common.Functions.geometricMedian((Geometry) o);
+        }
+
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                             @DataTypeHint("Double") Double tolerance) throws Exception {
+            return org.apache.sedona.common.Functions.geometricMedian((Geometry) o, tolerance);
+        }
+
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                             @DataTypeHint("Double") Double tolerance,
+                             int maxIter) throws Exception {
+            return org.apache.sedona.common.Functions.geometricMedian((Geometry) o, tolerance, maxIter);
+        }
+
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                             @DataTypeHint("Double") Double tolerance,
+                             int maxIter,
+                             @DataTypeHint("Boolean") Boolean failIfNotConverged) throws Exception {
+            return org.apache.sedona.common.Functions.geometricMedian((Geometry) o, tolerance, maxIter, failIfNotConverged);
+        }
+
+    }
+
 }
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java 
b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index 94546edd..02486e39 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -24,6 +24,7 @@ import org.locationtech.jts.geom.Geometry;
 import org.locationtech.jts.geom.LineString;
 import org.locationtech.jts.geom.Point;
 import org.locationtech.jts.geom.Polygon;
+import org.locationtech.jts.io.ParseException;
 import org.opengis.referencing.FactoryException;
 import org.opengis.referencing.crs.CoordinateReferenceSystem;
 
@@ -37,6 +38,7 @@ import static org.apache.flink.table.api.Expressions.call;
 import static org.junit.Assert.*;
 
 public class FunctionTest extends TestBase{
+
     @BeforeClass
     public static void onceExecutedBeforeAll() {
         initialize();
@@ -614,4 +616,34 @@ public class FunctionTest extends TestBase{
         assertEquals(1, count(joinCleanedTable));
         assertEquals(2, first(joinCleanedTable).getField(1));
     }
+
+    /** Runs {@code sql} and asserts its first result row's geometry equals {@code expectedWkt} within FP_TOLERANCE. */
+    private void assertGeometricMedian(String sql, String expectedWkt) throws ParseException {
+        Table medianTable = tableEnv.sqlQuery(sql);
+        Geometry expected = wktReader.read(expectedWkt);
+        Geometry actual = (Geometry) first(medianTable).getField(0);
+        String message = String.format("expected: %s was %s", expected.toText(),
+                actual == null ? "null" : actual.toText());
+        assertEquals(message, 0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
+    }
+
+    @Test
+    public void testGeometricMedian() throws ParseException {
+        assertGeometricMedian(
+                "SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200 200))'))",
+                "POINT (1.9761550281255005 1.9761550281255005)");
+    }
+
+    @Test
+    public void testGeometricMedianParamsTolerance() throws ParseException {
+        assertGeometricMedian(
+                "SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1), (0 1), (2 2))'), 1e-5)",
+                "POINT (0.996230268436779 0.9999899629155288)");
+    }
+
+    @Test
+    public void testGeometricMedianParamsFull() throws ParseException {
+        assertGeometricMedian(
+                "SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1), (0 1), (2 2))'), 1e-5, 10, false)",
+                "POINT (0.8844442206215307 0.9912184073718183)");
+    }
+
 }
diff --git a/flink/src/test/java/org/apache/sedona/flink/TestBase.java 
b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
index be88c5fb..1148f715 100644
--- a/flink/src/test/java/org/apache/sedona/flink/TestBase.java
+++ b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
@@ -13,6 +13,7 @@
  */
 package org.apache.sedona.flink;
 
+import com.google.common.math.DoubleMath;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -27,6 +28,7 @@ import org.apache.flink.types.Row;
 import org.apache.flink.util.CloseableIterator;
 import org.apache.sedona.flink.expressions.Constructors;
 import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.WKTReader;
 import org.wololo.jts2geojson.GeoJSONWriter;
 
 import java.sql.Timestamp;
@@ -50,6 +52,21 @@ public class TestBase {
     static Long timestamp_base = new 
Timestamp(System.currentTimeMillis()).getTime();
     static Long time_interval = 1L; // Generate a record per this interval. 
Unit is second
 
+    static final double FP_TOLERANCE = 1e-12;
+
+    /**
+     * Coordinate-sequence comparator for tests: compares ordinates in order (at most the first two) and treats
+     * values within {@link #FP_TOLERANCE} of each other as equal.
+     */
+    static final CoordinateSequenceComparator COORDINATE_SEQUENCE_COMPARATOR = new CoordinateSequenceComparator(2) {
+        @Override
+        protected int compareCoordinate(CoordinateSequence s1, CoordinateSequence s2, int i, int dimension) {
+            int d = 0;
+            while (d < dimension) {
+                int cmp = DoubleMath.fuzzyCompare(s1.getOrdinate(i, d), s2.getOrdinate(i, d), FP_TOLERANCE);
+                if (cmp != 0) {
+                    return cmp;
+                }
+                d++;
+            }
+            return 0;
+        }
+    };
+
+    /** WKT reader used by the tests to build expected geometries. */
+    final WKTReader wktReader = new WKTReader();
+
     public void setTestDataSize(int testDataSize) {
         this.testDataSize = testDataSize;
     }
diff --git a/python/sedona/sql/st_functions.py 
b/python/sedona/sql/st_functions.py
index d2fd8edc..4bb7df5b 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -53,6 +53,7 @@ __all__ = [
     "ST_FlipCoordinates",
     "ST_Force_2D",
     "ST_GeoHash",
+    "ST_GeometricMedian",
     "ST_GeometryN",
     "ST_GeometryType",
     "ST_InteriorRingN",
@@ -496,6 +497,31 @@ def ST_GeoHash(geometry: ColumnOrName, precision: 
Union[ColumnOrName, int]) -> C
     """
     return _call_st_function("ST_GeoHash", (geometry, precision))
 
+@validate_argument_types
+def ST_GeometricMedian(geometry: ColumnOrName, tolerance: 
Optional[Union[ColumnOrName, float]] = 1e-6,
+                       max_iter: Optional[Union[ColumnOrName, int]] = 1000,
+                       fail_if_not_converged: Optional[Union[ColumnOrName, 
bool]] = False) -> Column:
+    """Computes the approximate geometric median of a MultiPoint geometry 
using the Weiszfeld algorithm.
+    The geometric median provides a centrality measure that is less sensitive 
to outlier points than the centroid.
+    The algorithm will iterate until the distance change between successive 
iterations is less than the
+    supplied `tolerance` parameter. If this condition has not been met after 
`maxIter` iterations, the function will
+    produce an error and exit, unless `failIfNotConverged` is set to `false`. 
If a `tolerance` value is not provided,
+    a default `tolerance` value is `1e-6`.
+
+    :param geometry: MultiPoint or Point geometry.
+    :type geometry: ColumnOrName
+    :param tolerance: Distance limit change between successive iterations, 
defaults to 1e-6.
+    :type tolerance: Optional[Union[ColumnOrName, float]], optional
+    :param max_iter: Max number of iterations, defaults to 1000.
+    :type max_iter: Optional[Union[ColumnOrName, int]], optional
+    :param fail_if_not_converged: Generate error if not converged within given 
tolerance and number of iterations, defaults to False
+    :type fail_if_not_converged: Optional[Union[ColumnOrName, boolean]], 
optional
+    :return: Point geometry column.
+    :rtype: Column
+    """
+    args = (geometry, tolerance, max_iter, fail_if_not_converged)
+    return _call_st_function("ST_GeometricMedian", args)
+
 
 @validate_argument_types
 def ST_GeometryN(multi_geometry: ColumnOrName, n: Union[ColumnOrName, int]) -> 
Column:
diff --git a/python/tests/sql/test_dataframe_api.py 
b/python/tests/sql/test_dataframe_api.py
index dd5f275a..1ad052bd 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -81,6 +81,7 @@ test_configurations = [
     (stf.ST_ExteriorRing, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1 
0, 1 1, 0 0)"),
     (stf.ST_FlipCoordinates, ("point",), "point_geom", "", "POINT (1 0)"),
     (stf.ST_Force_2D, ("point",), "point_geom", "", "POINT (0 1)"),
+    (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT 
(22.500002656424286 21.250001168173426)"),
     (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"),
     (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"),
     (stf.ST_InteriorRingN, ("geom", 0), "geom_with_hole", "", "LINESTRING (1 
1, 2 2, 2 1, 1 1)"),
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 9364eae5..9b85f931 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -142,6 +142,7 @@ object Catalog {
     function[ST_MLineFromText](0),
     function[ST_Split](),
     function[ST_S2CellIDs](),
+    function[ST_GeometricMedian](1e-6, 1000, false),
     // Expression for rasters
     function[RS_NormalizedDifference](),
     function[RS_Mean](),
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 9d065a78..7298e1af 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -933,3 +933,18 @@ case class ST_CollectionExtract(inputExpressions: 
Seq[Expression])
 
   override def allowRightNull: Boolean = true
 }
+
+/**
+ * Returns a POINT holding the approximate geometric median of a Point or MultiPoint geometry, computed with the
+ * Weiszfeld algorithm. The geometric median provides a centrality measure that is less sensitive to outlier
+ * points than the centroid.
+ *
+ * @param inputExpressions the four arguments forwarded to Functions.geometricMedian, in order: geometry,
+ *                         tolerance, maximum number of iterations and whether to fail if the iteration does
+ *                         not converge
+ */
+case class ST_GeometricMedian(inputExpressions: Seq[Expression])
+  extends InferredQuarternaryExpression(Functions.geometricMedian) with FoldableExpression {
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+    copy(inputExpressions = newChildren)
+  }
+}
+
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index ce077d99..3e5fa706 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -275,4 +275,14 @@ object st_functions extends DataFrameAPI {
 
   def ST_ZMin(geometry: Column): Column = wrapExpression[ST_ZMin](geometry)
   def ST_ZMin(geometry: String): Column = wrapExpression[ST_ZMin](geometry)
+
+  /**
+   * DataFrame API for ST_GeometricMedian. Omitted arguments are filled in with the literals
+   * tolerance = 1e-6, maxIter = 1000 and failIfNotConverged = false before the expression is wrapped.
+   * NOTE(review): String `geometry` arguments are presumably resolved as column names by wrapExpression — confirm.
+   */
+  def ST_GeometricMedian(geometry: Column): Column = 
wrapExpression[ST_GeometricMedian](geometry, 1e-6, 1000, false)
+  def ST_GeometricMedian(geometry: String): Column = 
wrapExpression[ST_GeometricMedian](geometry, 1e-6, 1000, false)
+  def ST_GeometricMedian(geometry: Column, tolerance: Column): Column = 
wrapExpression[ST_GeometricMedian](geometry, tolerance, 1000, false)
+  def ST_GeometricMedian(geometry: String, tolerance: Double): Column = 
wrapExpression[ST_GeometricMedian](geometry, tolerance, 1000, false)
+  def ST_GeometricMedian(geometry: Column, tolerance: Column, maxIter: 
Column): Column = wrapExpression[ST_GeometricMedian](geometry, tolerance, 
maxIter, false)
+  def ST_GeometricMedian(geometry: String, tolerance: Double, maxIter: Int): 
Column = wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter, false)
+  def ST_GeometricMedian(geometry: Column, tolerance: Column, maxIter: Column, 
failIfNotConverged: Column): Column = 
wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter, 
failIfNotConverged)
+  def ST_GeometricMedian(geometry: String, tolerance: Double, maxIter: Int, 
failIfNotConverged: Boolean): Column = 
wrapExpression[ST_GeometricMedian](geometry, tolerance, maxIter, 
failIfNotConverged)
+
 }
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 47c72e7c..f12b5874 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -18,11 +18,13 @@
  */
 package org.apache.sedona.sql
 
+import com.google.common.math.DoubleMath
 import org.apache.log4j.{Level, Logger}
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
 import org.apache.sedona.sql.utils.SedonaSQLRegistrator
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.locationtech.jts.geom.{CoordinateSequence, 
CoordinateSequenceComparator}
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
@@ -81,4 +83,25 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
 
   lazy val buildPointDf = loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
   lazy val buildPolygonDf = loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)), cast(_c3 as Decimal(24,20))) as polygonshape")
+
+  // Absolute tolerance for comparing individual ordinates in tests.
+  protected final val FP_TOLERANCE: Double = 1e-12
+
+  // Test comparator: compares at most the first two ordinates (dimension limit 2) of each coordinate
+  // and treats ordinates within FP_TOLERANCE of each other as equal. Fuzzy comparison is not
+  // transitive, so use it for equality checks (compareTo(...) == 0), not for sorting.
+  protected final val COORDINATE_SEQUENCE_COMPARATOR: CoordinateSequenceComparator = new CoordinateSequenceComparator(2) {
+    override protected def compareCoordinate(s1: CoordinateSequence, s2: CoordinateSequence, i: Int, dimension: Int): Int = {
+      // Explicit loop rather than `return` inside a `for` closure, which Scala compiles to a
+      // non-local return (a thrown NonLocalReturnControl) and which Scala 3 deprecates.
+      var d = 0
+      var comp = 0
+      while (comp == 0 && d < dimension) {
+        comp = DoubleMath.fuzzyCompare(s1.getOrdinate(i, d), s2.getOrdinate(i, d), FP_TOLERANCE)
+        d += 1
+      }
+      comp
+    }
+  }
+
 }
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 0b3d9d42..a66a9be2 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.geotools.referencing.CRS
 import org.locationtech.jts.algorithm.MinimumBoundingCircle
-import org.locationtech.jts.geom.{Geometry, Polygon}
+import org.locationtech.jts.geom.{CoordinateSequenceComparator, Geometry, 
Polygon}
 import org.locationtech.jts.io.WKTWriter
 import org.locationtech.jts.linearref.LengthIndexedLine
 import org.locationtech.jts.operation.distance3d.Distance3DOp
@@ -1762,6 +1762,8 @@ class functionTestScala extends TestBaseScala with 
Matchers with GeometrySample
     assert(functionDf.first().get(0) == null)
     functionDf = sparkSession.sql("select ST_LineFromMultiPoint(null)")
     assert(functionDf.first().get(0) == null)
+    functionDf = sparkSession.sql("select ST_GeometricMedian(null)")
+    assert(functionDf.first().get(0) == null)
   }
 
   it ("Should pass St_CollectionExtract") {
@@ -1818,4 +1820,27 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       assert(textResult==expected)
     }
   }
+
+  it ("Should pass ST_GeometricMedian") {
+    // (input WKT, tolerance) -> expected median WKT; the WKT strings are already quoted as SQL literals.
+    val geomTestCases = Map(
+      ("'MULTIPOINT((10 40), (40 30), (20 20), (30 10))'", 1e-15) -> "'POINT(22.5 21.25)'",
+      ("'MULTIPOINT((0 0), (1 1), (2 2), (200 200))'", 1e-6) -> "'POINT (1.9761550281255005 1.9761550281255005)'",
+      ("'MULTIPOINT ((0 0), (10 1), (5 1), (20 20))'", 1e-15) -> "'POINT (5 1)'",
+      ("'MULTIPOINT ((0 -1), (0 0), (0 0), (0 1))'", 1e-6) -> "'POINT (0 0)'",
+      ("'POINT (7 6)'", 1e-6) -> "'POINT (7 6)'",
+      ("'MULTIPOINT ((12 5),(62 7),(100 -1),(100 -5),(10 20),(105 -5))'", 1e-15) -> "'POINT(84.21672412761632 0.1351485929395439)'"
+    )
+    for(((targetWkt, tolerance), expectedWkt) <- geomTestCases) {
+      val df = sparkSession.sql(s"SELECT ST_GeometricMedian(ST_GeomFromWKT($targetWkt), $tolerance), " +
+        s"ST_GeomFromWKT($expectedWkt)")
+      // Fetch the row once: every take()/first() call launches a separate Spark job.
+      val row = df.first()
+      val actual = row.get(0).asInstanceOf[Geometry]
+      val expected = row.get(1).asInstanceOf[Geometry]
+      assert(expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR) == 0,
+        s"ST_GeometricMedian($targetWkt, $tolerance): expected $expected but got $actual")
+    }
+  }
+
 }


Reply via email to