This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch SEDONA-587
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit c11a9c76706c28cd9e2ada36602b30a8f1484c0e
Author: Furqaan Khan <[email protected]>
AuthorDate: Wed Apr 24 00:45:46 2024 -0400

    [TASK-271] Add ST_Force4D (#168)
    
    * feat: add ST_Force4D
    
    * remove print
---
 .../java/org/apache/sedona/common/Functions.java   |  8 +++
 .../common/utils/GeometryForce4DTransformer.java   | 54 +++++++++++++++++++
 .../org/apache/sedona/common/FunctionsTest.java    | 61 ++++++++++++++++++++++
 docs/api/flink/Function.md                         | 39 ++++++++++++++
 docs/api/sql/Function.md                           | 39 ++++++++++++++
 .../main/java/org/apache/sedona/flink/Catalog.java |  1 +
 .../apache/sedona/flink/expressions/Functions.java | 16 ++++++
 .../java/org/apache/sedona/flink/FunctionTest.java | 13 +++++
 python/sedona/sql/st_functions.py                  | 14 +++++
 python/tests/sql/test_dataframe_api.py             |  2 +
 python/tests/sql/test_function.py                  |  6 +++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  1 +
 .../sql/sedona_sql/expressions/Functions.scala     |  8 +++
 .../sql/sedona_sql/expressions/st_functions.scala  |  5 ++
 .../apache/sedona/sql/dataFrameAPITestScala.scala  | 11 ++++
 .../org/apache/sedona/sql/functionTestScala.scala  | 18 +++++++
 16 files changed, 296 insertions(+)

diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java 
b/common/src/main/java/org/apache/sedona/common/Functions.java
index 95747364e..ad85bea36 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -1549,6 +1549,14 @@ public class Functions {
         return force3DM(geom, 0.0);
     }
 
+    /**
+     * Forces a geometry into a 4D (XYZM) representation.
+     * The work is delegated to {@link GeometryForce4DTransformer#transform}.
+     *
+     * @param geom   geometry to convert
+     * @param zValue value handed to the transformer for the Z ordinate
+     * @param mValue value handed to the transformer for the M ordinate
+     * @return the geometry produced by the transformer
+     */
+    public static Geometry force4D(Geometry geom, double zValue, double mValue) {
+        final Geometry forced = GeometryForce4DTransformer.transform(geom, zValue, mValue);
+        return forced;
+    }
+
+    /**
+     * Forces a geometry into a 4D (XYZM) representation, using 0.0 for both the
+     * Z and the M value passed to the three-argument overload.
+     *
+     * @param geom geometry to convert
+     * @return result of {@code force4D(geom, 0.0, 0.0)}
+     */
+    public static Geometry force4D(Geometry geom) {
+        final double defaultZ = 0.0;
+        final double defaultM = 0.0;
+        return force4D(geom, defaultZ, defaultM);
+    }
+
     public static Geometry force3D(Geometry geometry, double zValue) {
         return GeomUtils.get3DGeom(geometry, zValue);
     }
diff --git 
a/common/src/main/java/org/apache/sedona/common/utils/GeometryForce4DTransformer.java
 
b/common/src/main/java/org/apache/sedona/common/utils/GeometryForce4DTransformer.java
new file mode 100644
index 000000000..f61f8b0a2
--- /dev/null
+++ 
b/common/src/main/java/org/apache/sedona/common/utils/GeometryForce4DTransformer.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.utils;
+
+import org.apache.sedona.common.Functions;
+import org.locationtech.jts.geom.*;
+import org.locationtech.jts.geom.util.GeometryTransformer;
+
+/**
+ * Geometry transformer that rewrites every coordinate of a geometry as a 4D
+ * {@link CoordinateXYZM}.
+ * <p>
+ * An ordinate (Z or M) that the input geometry already carries is kept; a missing one is
+ * filled with the configured {@code zValue} / {@code mValue}. Whether the input carries
+ * Z or M is decided once per geometry in {@link #transform(Geometry, double, double)}
+ * and stored in final instance fields, so concurrent calls never share mutable state.
+ */
+public class GeometryForce4DTransformer extends GeometryTransformer {
+
+    /** True when the Z ordinate of the input coordinates must be preserved. */
+    private final boolean hasZ;
+    /** True when the M ordinate of the input coordinates must be preserved. */
+    private final boolean hasM;
+    private final double mValue;
+    private final double zValue;
+
+    /**
+     * Creates a transformer that writes {@code zValue} and {@code mValue} into every
+     * coordinate, without preserving any existing Z or M ordinate.
+     *
+     * @param zValue Z ordinate to assign
+     * @param mValue M ordinate to assign
+     */
+    public GeometryForce4DTransformer(double zValue, double mValue) {
+        this(zValue, mValue, false, false);
+    }
+
+    /**
+     * @param zValue Z ordinate used when {@code hasZ} is false
+     * @param mValue M ordinate used when {@code hasM} is false
+     * @param hasZ   keep the Z ordinate found on each input coordinate
+     * @param hasM   keep the M ordinate found on each input coordinate
+     */
+    private GeometryForce4DTransformer(double zValue, double mValue, boolean hasZ, boolean hasM) {
+        this.zValue = zValue;
+        this.mValue = mValue;
+        this.hasZ = hasZ;
+        this.hasM = hasM;
+    }
+
+    /**
+     * Copies each coordinate of {@code coords} into a {@link CoordinateXYZM}, keeping or
+     * replacing Z and M according to this transformer's flags.
+     */
+    @Override
+    protected CoordinateSequence transformCoordinates(CoordinateSequence coords, Geometry parent) {
+        final int size = coords.size();
+        CoordinateXYZM[] newCoords = new CoordinateXYZM[size];
+        for (int i = 0; i < size; i++) {
+            Coordinate coordinate = coords.getCoordinate(i);
+            double z = hasZ ? coordinate.getZ() : zValue;
+            double m = hasM ? coordinate.getM() : mValue;
+            newCoords[i] = new CoordinateXYZM(coordinate.getX(), coordinate.getY(), z, m);
+        }
+
+        return createCoordinateSequence(newCoords);
+    }
+
+    /**
+     * Converts {@code geometry} to XYZM.
+     *
+     * @param geometry geometry to convert
+     * @param zValue   Z ordinate to use when the geometry has no Z
+     * @param mValue   M ordinate to use when the geometry has no M
+     * @return the input itself when it is empty or already has both Z and M;
+     *         otherwise a transformed copy with four ordinates per coordinate
+     */
+    public static Geometry transform(Geometry geometry, double zValue, double mValue) {
+        // Nothing to rewrite; this guard also runs before the hasZ/hasM calls below.
+        if (geometry.isEmpty()) return geometry;
+        // Local flags instead of static fields: each call owns its own state.
+        final boolean hasZ = Functions.hasZ(geometry);
+        final boolean hasM = Functions.hasM(geometry);
+        if (hasZ && hasM) return geometry;
+
+        GeometryForce4DTransformer transformer = new GeometryForce4DTransformer(zValue, mValue, hasZ, hasM);
+        return transformer.transform(geometry);
+    }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java 
b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index 08589d09b..9ce130930 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -1380,6 +1380,67 @@ public class FunctionsTest extends TestBase {
         assertTrue(Predicates.equals(geom, actualGeom));
     }
 
+    /**
+     * Checks Functions.force4D on Point, MultiPoint, LineString, MultiLineString,
+     * LinearRing, Polygon, MultiPolygon and GeometryCollection inputs, using both the
+     * explicit zValue/mValue overload and the default overload. Also checks that an
+     * input with both Z and M comes back equal to itself, and that an existing Z or M
+     * ordinate is kept while only the missing one is filled in.
+     */
+    @Test
+    public void force4D() throws ParseException {
+        // testing all geom types
+        Geometry geom = Constructors.geomFromWKT("POINT (1 2)",0);
+        String actual = Functions.asWKT(Functions.force4D(geom,2, 5));
+        String expected = "POINT ZM(1 2 2 5)";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("MULTIPOINT ((1 2), (2 3))",0);
+        actual = Functions.asWKT(Functions.force4D(geom, 2, 5));
+        expected = "MULTIPOINT ZM((1 2 2 5), (2 3 2 5))";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("LINESTRING (1 2, 2 3, 3 4)",0);
+        actual = Functions.asWKT(Functions.force4D(geom));
+        expected = "LINESTRING ZM(1 2 0 0, 2 3 0 0, 3 4 0 0)";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("MULTILINESTRING ((10 10, 20 20, 30 
30), (15 15, 25 25, 35 35))",0);
+        actual = Functions.asWKT(Functions.force4D(geom, 3, 5));
+        expected = "MULTILINESTRING ZM((10 10 3 5, 20 20 3 5, 30 30 3 5), (15 
15 3 5, 25 25 3 5, 35 35 3 5))";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("LINEARRING (30 10, 40 40, 20 40, 10 
20, 30 10)",0);
+        actual = Functions.asWKT(Functions.force4D(geom, 5, 5));
+        expected = "LINEARRING ZM(30 10 5 5, 40 40 5 5, 20 40 5 5, 10 20 5 5, 
30 10 5 5)";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 
0), (4 4, 4 6, 6 6, 6 4, 4 4))",0);
+        actual = Functions.asWKT(Functions.force4D(geom));
+        expected = "POLYGON ZM((0 0 0 0, 10 0 0 0, 10 10 0 0, 0 10 0 0, 0 0 0 
0), (4 4 0 0, 4 6 0 0, 6 6 0 0, 6 4 0 0, 4 4 0 0))";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("MULTIPOLYGON (((30 10, 40 40, 20 40, 
10 20, 30 10)), ((15 5, 10 20, 20 30, 15 5)))",0);
+        actual = Functions.asWKT(Functions.force4D(geom, 2, 5));
+        expected = "MULTIPOLYGON ZM(((30 10 2 5, 40 40 2 5, 20 40 2 5, 10 20 2 
5, 30 10 2 5)), ((15 5 2 5, 10 20 2 5, 20 30 2 5, 15 5 2 5)))";
+        assertEquals(expected, actual);
+
+        geom = Constructors.geomFromWKT("GEOMETRYCOLLECTION (POINT (10 10), 
LINESTRING (15 15, 25 25, 35 35), POLYGON ((30 10, 40 40, 20 40, 10 20, 30 
10)))",0);
+        actual = Functions.asWKT(Functions.force4D(geom, 2, 5));
+        expected = "GEOMETRYCOLLECTION ZM(POINT ZM(10 10 2 5), LINESTRING 
ZM(15 15 2 5, 25 25 2 5, 35 35 2 5), POLYGON ZM((30 10 2 5, 40 40 2 5, 20 40 2 
5, 10 20 2 5, 30 10 2 5)))";
+        assertEquals(expected, actual);
+
+        // return 4D input geom as is
+        geom = Constructors.geomFromWKT("POLYGON ZM ((30 10 5 1, 40 40 10 2, 
20 40 15 3, 10 20 20 4, 30 10 5 1))", 0);
+        Geometry actualGeom = Functions.force4D(geom, 10, 10);
+        assertTrue(Predicates.equals(geom, actualGeom));
+
+        // if input geom has z value, keep it and add m
+        geom = Constructors.geomFromWKT("LINESTRING Z(0 1 3, 1 0 3, 2 0 3)", 
0);
+        actual = Functions.asWKT(Functions.force4D(geom, 10, 10));
+        expected = "LINESTRING ZM(0 1 3 10, 1 0 3 10, 2 0 3 10)";
+        assertEquals(expected, actual);
+
+        // if input geom has m value, keep it and add z
+        geom = Constructors.geomFromWKT("LINESTRING M(0 1 3, 1 0 3, 2 0 3)", 
0);
+        actual = Functions.asWKT(Functions.force4D(geom, 10, 10));
+        expected = "LINESTRING ZM(0 1 10 3, 1 0 10 3, 2 0 10 3)";
+        assertEquals(expected, actual);
+    }
+
     @Test
     public void force3DObject3DDefaultValue() {
         int expectedDims = 3;
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index 36cb4c7c0..0288436de 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -1351,6 +1351,45 @@ Output:
 LINESTRING Z(0 1 2.3, 1 0 2.3, 2 0 2.3)
 ```
 
+## ST_Force4D
+
+Introduction: Converts the input geometry to a 4D XYZM representation. Original Z and M values are retained if present. If `zValue` and `mValue` aren't specified, missing Z and M values default to 0.0. The output contains X, Y, Z, and M coordinates. For geometries already in 4D form, the function returns the original geometry unmodified.
+
+!!!Note
+    Example output is after calling ST_AsText() on the returned geometry, which adds ZM in the WKT for 4D geometries
+
+Format:
+
+`ST_Force4D(geom: Geometry, zValue: Double, mValue: Double)`
+
+`ST_Force4D(geom: Geometry)`
+
+Since: `vTBD`
+
+SQL Example
+
+```sql
+SELECT ST_AsText(ST_Force4D(ST_GeomFromText('POLYGON((0 0 2,0 5 2,5 0 2,0 0 
2),(1 1 2,3 1 2,1 3 2,1 1 2))'), 5, 10))
+```
+
+Output:
+
+```
+POLYGON ZM((0 0 2 10, 0 5 2 10, 5 0 2 10, 0 0 2 10), (1 1 2 10, 3 1 2 10, 1 3 
2 10, 1 1 2 10))
+```
+
+SQL Example
+
+```sql
+SELECT ST_AsText(ST_Force4D(ST_GeomFromText('LINESTRING(0 1,1 0,2 0)'), 3, 1))
+```
+
+Output:
+
+```
+LINESTRING ZM(0 1 3 1, 1 0 3 1, 2 0 3 1)
+```
+
 ## ST_ForceCollection
 
 Introduction: This function converts the input geometry into a 
GeometryCollection, regardless of the original geometry type. If the input is a 
multipart geometry, such as a MultiPolygon or MultiLineString, it will be 
decomposed into a GeometryCollection containing each individual Polygon or 
LineString element from the original multipart geometry.
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index fae471101..447b351da 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -1356,6 +1356,45 @@ Output:
 LINESTRING Z(0 1 2.3, 1 0 2.3, 2 0 2.3)
 ```
 
+## ST_Force4D
+
+Introduction: Converts the input geometry to a 4D XYZM representation. Original Z and M values are retained if present. If `zValue` and `mValue` aren't specified, missing Z and M values default to 0.0. The output contains X, Y, Z, and M coordinates. For geometries already in 4D form, the function returns the original geometry unmodified.
+
+!!!Note
+    Example output is after calling ST_AsText() on the returned geometry, which adds ZM in the WKT for 4D geometries
+
+Format:
+
+`ST_Force4D(geom: Geometry, zValue: Double, mValue: Double)`
+
+`ST_Force4D(geom: Geometry)`
+
+Since: `vTBD`
+
+SQL Example
+
+```sql
+SELECT ST_AsText(ST_Force4D(ST_GeomFromText('POLYGON((0 0 2,0 5 2,5 0 2,0 0 
2),(1 1 2,3 1 2,1 3 2,1 1 2))'), 5, 10))
+```
+
+Output:
+
+```
+POLYGON ZM((0 0 2 10, 0 5 2 10, 5 0 2 10, 0 0 2 10), (1 1 2 10, 3 1 2 10, 1 3 
2 10, 1 1 2 10))
+```
+
+SQL Example
+
+```sql
+SELECT ST_AsText(ST_Force4D(ST_GeomFromText('LINESTRING(0 1,1 0,2 0)'), 3, 1))
+```
+
+Output:
+
+```
+LINESTRING ZM(0 1 3 1, 1 0 3 1, 2 0 3 1)
+```
+
 ## ST_ForceCollection
 
 Introduction: This function converts the input geometry into a 
GeometryCollection, regardless of the original geometry type. If the input is a 
multipart geometry, such as a MultiPolygon or MultiLineString, it will be 
decomposed into a GeometryCollection containing each individual Polygon or 
LineString element from the original multipart geometry.
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java 
b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index a569d4a99..102289ed7 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -160,6 +160,7 @@ public class Catalog {
                 new Functions.ST_Force3D(),
                 new Functions.ST_Force3DM(),
                 new Functions.ST_Force3DZ(),
+                new Functions.ST_Force4D(),
                 new Functions.ST_ForceCollection(),
                 new Functions.ST_ForcePolygonCW(),
                 new Functions.ST_ForceRHR(),
diff --git 
a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java 
b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index 214ba767e..217d86cfc 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -1133,6 +1133,22 @@ public class Functions {
         }
     }
 
+    /**
+     * Flink scalar function exposing ST_Force4D. Both overloads cast the raw
+     * argument to a JTS Geometry and forward to
+     * {@code org.apache.sedona.common.Functions.force4D}.
+     */
+    public static class ST_Force4D extends ScalarFunction {
+
+        /** Three-argument form: geometry plus explicit Z and M values. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(
+                @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                @DataTypeHint("Double") Double zValue,
+                @DataTypeHint("Double") Double mValue) {
+            return org.apache.sedona.common.Functions.force4D((Geometry) o, zValue, mValue);
+        }
+
+        /** One-argument form: geometry only, forwarded to the single-argument force4D. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(
+                @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o) {
+            return org.apache.sedona.common.Functions.force4D((Geometry) o);
+        }
+    }
+
     public static class ST_ForceCollection extends ScalarFunction {
         @DataTypeHint(value = "RAW", bridgedTo = 
org.locationtech.jts.geom.Geometry.class)
         public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = 
org.locationtech.jts.geom.Geometry.class) Object o) {
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java 
b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index a7ffb37ad..ab4e0e522 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -1337,6 +1337,19 @@ public class FunctionTest extends TestBase{
         assertEquals(Boolean.TRUE, actual);
     }
 
+    /**
+     * Runs the one-argument ST_Force4D through Flink SQL on a 2D LINESTRING and on a
+     * 2D POLYGON with a hole, and compares the ST_AsText output with the expected
+     * ZM WKT in which Z and M are 0.
+     */
+    @Test
+    public void testForce4D() {
+        Table geomTable = tableEnv.sqlQuery("SELECT 
ST_Force4D(ST_GeomFromText('LINESTRING (1 2, 2 3, 3 4)')) AS geom");
+        String actual = (String) 
first(geomTable.select(call(Functions.ST_AsText.class.getSimpleName(), 
$("geom")))).getField(0);
+        String expected = "LINESTRING ZM(1 2 0 0, 2 3 0 0, 3 4 0 0)";
+        assertEquals(expected, actual);
+
+        geomTable = tableEnv.sqlQuery("SELECT 
ST_Force4D(ST_GeomFromText('POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0), (4 4, 4 6, 
6 6, 6 4, 4 4))')) AS geom");
+        actual = (String) 
first(geomTable.select(call(Functions.ST_AsText.class.getSimpleName(), 
$("geom")))).getField(0);
+        expected = "POLYGON ZM((0 0 0 0, 10 0 0 0, 10 10 0 0, 0 10 0 0, 0 0 0 
0), (4 4 0 0, 4 6 0 0, 6 6 0 0, 6 4 0 0, 4 4 0 0))";
+        assertEquals(expected, actual);
+    }
+
     @Test
     public void testForceCollection() {
         int actual = (int) first(
diff --git a/python/sedona/sql/st_functions.py 
b/python/sedona/sql/st_functions.py
index efd721977..d93408916 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -1596,6 +1596,20 @@ def ST_Force3DZ(geometry: ColumnOrName, zValue: 
Optional[Union[ColumnOrName, flo
     args = (geometry, zValue)
     return _call_st_function("ST_Force3DZ", args)
 
+@validate_argument_types
+def ST_Force4D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, 
float]] = 0.0,
+               mValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column:
+    """
+    Return a geometry with a 4D coordinate of value 'zValue' and mValue forced 
upon it. No change happens if the
+    geometry is already 4D, if geometry either has z or m, it will not change 
the existing z or m value.
+
+    :param zValue: Optional value of z coordinate to be potentially added, 
default value is 0.0
+    :param geometry: Geometry column to make 4D
+    :return: 4D geometry with either already 4D geom or z and m component 
provided by zValue and mValue respectively
+    """
+    args = (geometry, zValue, mValue)
+    return _call_st_function("ST_Force4D", args)
+
 @validate_argument_types
 def ST_ForceCollection(geometry: ColumnOrName) -> Column:
     """
diff --git a/python/tests/sql/test_dataframe_api.py 
b/python/tests/sql/test_dataframe_api.py
index de104bc33..35d92dee8 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -126,6 +126,7 @@ test_configurations = [
     (stf.ST_Force3D, ("point", 1.0), "point_geom", "", "POINT Z (0 1 1)"),
     (stf.ST_Force3DM, ("point", 1.0), "point_geom", "ST_AsText(geom)", "POINT 
M(0 1 1)"),
     (stf.ST_Force3DZ, ("point", 1.0), "point_geom", "", "POINT Z (0 1 1)"),
+    (stf.ST_Force4D, ("point", 1.0, 1.0), "point_geom", "ST_AsText(geom)", 
"POINT ZM(0 1 1 1)"),
     (stf.ST_ForceCollection, ("multipoint",), "multipoint_geom", 
"ST_NumGeometries(geom)", 4),
     (stf.ST_ForcePolygonCW, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 3 
3, 3 0, 0 0), (1 1, 2 1, 2 2, 1 1))"),
     (stf.ST_ForcePolygonCCW, ("geom",), "geom_with_hole", "", "POLYGON ((0 0, 
3 0, 3 3, 0 0), (1 1, 2 2, 2 1, 1 1))"),
@@ -307,6 +308,7 @@ wrong_type_configurations = [
     (stf.ST_Force_2D, (None,)),
     (stf.ST_Force3DM, (None,)),
     (stf.ST_Force3DZ, (None,)),
+    (stf.ST_Force4D, (None,)),
     (stf.ST_ForceCollection, (None,)),
     (stf.ST_ForcePolygonCW, (None,)),
     (stf.ST_ForcePolygonCCW, (None,)),
diff --git a/python/tests/sql/test_function.py 
b/python/tests/sql/test_function.py
index 3ea889d64..c50dda6ab 100644
--- a/python/tests/sql/test_function.py
+++ b/python/tests/sql/test_function.py
@@ -1479,6 +1479,12 @@ class TestPredicateJoin(TestBase):
         actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0]
         assert expected == actual
 
+    def test_force4D(self):
+        # Force a 2D LINESTRING to 4D with explicit z/m values (1.1, 1.1) and
+        # check that ST_NDims of the result is 4.
+        expected = 4
+        actualDf = self.spark.sql("SELECT 
ST_Force4D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1, 1.1) AS geom")
+        actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0]
+        assert expected == actual
+
     def test_st_force_collection(self):
         basedf = self.spark.sql("SELECT ST_GeomFromWKT('MULTIPOINT (30 10, 40 
40, 20 20, 10 30, 10 10, 20 50)') AS mpoint, ST_GeomFromWKT('POLYGON ((30 10, 
40 40, 20 40, 10 20, 30 10))') AS poly")
         actual = 
basedf.selectExpr("ST_NumGeometries(ST_ForceCollection(mpoint))").take(1)[0][0]
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 58a287266..7a2f79176 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -193,6 +193,7 @@ object Catalog {
     function[ST_Force3D](0.0),
     function[ST_Force3DM](0.0),
     function[ST_Force3DZ](0.0),
+    function[ST_Force4D](),
     function[ST_ForceCollection](),
     function[ST_NRings](),
     function[ST_Translate](0.0),
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index d8af1de13..0f91ecd08 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1216,6 +1216,14 @@ case class ST_Force3DM(inputExpressions: Seq[Expression])
   }
 }
 
+/**
+  * Spark SQL expression for ST_Force4D. It is built from two inferred functions, one
+  * wrapping the three-argument and one wrapping the one-argument overload of
+  * Functions.force4D.
+  * NOTE(review): presumably the overload is picked from the number of input
+  * expressions - confirm against InferredExpression.
+  *
+  * @param inputExpressions the argument expressions of the SQL call
+  */
+case class ST_Force4D(inputExpressions: Seq[Expression])
+  extends InferredExpression(inferrableFunction3(Functions.force4D), 
inferrableFunction1(Functions.force4D)) {
+
+  // Rebuilds this expression with a new set of child expressions.
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
+    copy(inputExpressions = newChildren)
+  }
+}
+
 case class ST_ForceCollection(inputExpressions: Seq[Expression])
   extends InferredExpression(Functions.forceCollection _) {
 
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index b8ceae4d1..02b0cb17b 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -437,6 +437,11 @@ object st_functions extends DataFrameAPI {
   def ST_Force3DZ(geometry: Column, zValue: Column): Column = 
wrapExpression[ST_Force3DZ](geometry, zValue)
   def ST_Force3DZ(geometry: String, zValue: Double): Column = 
wrapExpression[ST_Force3DZ](geometry, zValue)
 
+  // DataFrame API wrappers for ST_Force4D. The geometry-only overloads pass 0.0 for
+  // both zValue and mValue explicitly; the other two forward the caller's values.
+  def ST_Force4D(geometry: Column): Column = 
wrapExpression[ST_Force4D](geometry, 0.0, 0.0)
+  def ST_Force4D(geometry: String): Column = 
wrapExpression[ST_Force4D](geometry, 0.0, 0.0)
+  def ST_Force4D(geometry: Column, zValue: Column, mValue: Column): Column = 
wrapExpression[ST_Force4D](geometry, zValue, mValue)
+  def ST_Force4D(geometry: String, zValue: Double, mValue: Double): Column = 
wrapExpression[ST_Force4D](geometry, zValue, mValue)
+
   def ST_ForceCollection(geometry: Column): Column = 
wrapExpression[ST_ForceCollection](geometry)
 
   def ST_ForceCollection(geometry: String): Column = 
wrapExpression[ST_ForceCollection](geometry)
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index 49860651e..7c6e05b7e 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -1371,6 +1371,17 @@ class dataFrameAPITestScala extends TestBaseScala {
       assertEquals(expectedGeomDefaultValue, 
wktWriter.write(actualGeomDefaultValue))
     }
 
+    it("Passed ST_Force4D") {
+      // Explicit z/m values (4, 4) on a 2D LINESTRING.
+      val lineDf = sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (0 1, 1 
0, 2 0)') AS geom")
+      val expectedGeom = "LINESTRING ZM(0 1 4 4, 1 0 4 4, 2 0 4 4)"
+      val expectedGeomDefaultValue = "LINESTRING ZM(0 1 0 0, 1 0 0 0, 2 0 0 0)"
+      val forcedGeom = lineDf.select(ST_AsText(ST_Force4D("geom", 4, 
4))).take(1)(0).get(0).asInstanceOf[String]
+      assertEquals(expectedGeom, forcedGeom)
+      // Geometry-only overload: z and m are expected to be 0.
+      val lineDfDefaultValue = sparkSession.sql("SELECT 
ST_GeomFromWKT('LINESTRING (0 1, 1 0, 2 0)') AS geom")
+      val actualGeomDefaultValue = 
lineDfDefaultValue.select(ST_AsText(ST_Force4D("geom"))).take(1)(0).get(0).asInstanceOf[String]
+      assertEquals(expectedGeomDefaultValue, actualGeomDefaultValue)
+    }
+
     it("Passed ST_ForceCollection") {
       val baseDf = sparkSession.sql("SELECT ST_GeomFromWKT('MULTIPOINT (30 10, 
40 40, 20 20, 10 30, 10 10, 20 50)') AS mpoint, ST_GeomFromWKT('POLYGON ((30 
10, 40 40, 20 40, 10 20, 30 10))') AS poly")
       var actual = 
baseDf.select(ST_NumGeometries(ST_ForceCollection("mpoint"))).first().get(0)
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 0bff17b5f..dd76d3bdd 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -2310,6 +2310,24 @@ class functionTestScala extends TestBaseScala with 
Matchers with GeometrySample
     }
   }
 
+  it("Should pass ST_Force4D") {
+    // Maps an input WKT literal to a pair of expected WKT literals:
+    // (result with zValue = 1 and mValue = 1, result with the default values).
+    val geomTestCases = Map(
+      ("'LINESTRING (0 1, 1 0, 2 0)'") -> ("'LINESTRING ZM(0 1 1 1, 1 0 1 1, 2 
0 1 1)'", "'LINESTRING ZM(0 1 0 0, 1 0 0 0, 2 0 0 0)'"),
+      ("'LINESTRING ZM(0 1 3 2, 1 0 3 2, 2 0 3 2)'") -> ("'LINESTRING ZM(0 1 3 
2, 1 0 3 2, 2 0 3 2)'", "'LINESTRING ZM(0 1 3 2, 1 0 3 2, 2 0 3 2)'"),
+      ("'LINESTRING EMPTY'") -> ("'LINESTRING EMPTY'", "'LINESTRING EMPTY'")
+    )
+    for (((geom), expectedResult) <- geomTestCases) {
+      // The expected pair is interpolated into the SQL text as a second column and read
+      // back below as a GenericRowWithSchema (field 0: explicit values, field 1: defaults).
+      val df = sparkSession.sql(s"SELECT 
ST_AsText(ST_Force4D(ST_GeomFromWKT($geom), 1, 1)) AS geom, " + 
s"$expectedResult")
+      val dfDefaultValue = sparkSession.sql(s"SELECT 
ST_AsText(ST_Force4D(ST_GeomFromWKT($geom))) AS geom, " + s"$expectedResult")
+      val actual = df.take(1)(0).get(0).asInstanceOf[String]
+      val expected = 
df.take(1)(0).get(1).asInstanceOf[GenericRowWithSchema].get(0).asInstanceOf[String]
+      val actualDefaultValue = 
dfDefaultValue.take(1)(0).get(0).asInstanceOf[String]
+      val expectedDefaultValue = 
dfDefaultValue.take(1)(0).get(1).asInstanceOf[GenericRowWithSchema].get(1).asInstanceOf[String]
+      assertEquals(expected, actual)
+      assertEquals(expectedDefaultValue, actualDefaultValue);
+    }
+  }
+
   it("Passed ST_ForceCollection") {
     var actual = sparkSession.sql("SELECT 
ST_NumGeometries(ST_ForceCollection(ST_GeomFromWKT('MULTIPOINT (30 10, 40 40, 
20 20, 10 30, 10 10, 20 50)')))").first().get(0)
     assert(actual == 6)

Reply via email to