This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 800a3f54 [SEDONA-196] Add ST_Force3D (#856)
800a3f54 is described below

commit 800a3f5415577b6deec58d2c5e8d11c6658c9bce
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Mon Jun 12 04:12:02 2023 -0700

    [SEDONA-196] Add ST_Force3D (#856)
---
 .../java/org/apache/sedona/common/Functions.java   |  8 +++
 .../org/apache/sedona/common/utils/GeomUtils.java  | 14 +++++
 .../org/apache/sedona/common/FunctionsTest.java    | 67 ++++++++++++++++++++++
 docs/api/flink/Function.md                         | 45 +++++++++++++++
 docs/api/sql/Function.md                           | 43 ++++++++++++++
 .../main/java/org/apache/sedona/flink/Catalog.java |  3 +-
 .../apache/sedona/flink/expressions/Functions.java | 17 ++++++
 .../java/org/apache/sedona/flink/FunctionTest.java | 20 +++++++
 python/sedona/sql/st_functions.py                  | 14 ++++-
 python/tests/sql/test_dataframe_api.py             |  2 +
 python/tests/sql/test_function.py                  |  7 +++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  1 +
 .../sql/sedona_sql/expressions/Functions.scala     |  8 +++
 .../sql/sedona_sql/expressions/st_functions.scala  |  8 +++
 .../apache/sedona/sql/dataFrameAPITestScala.scala  | 13 +++++
 .../org/apache/sedona/sql/functionTestScala.scala  | 19 ++++++
 16 files changed, 287 insertions(+), 2 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java 
b/common/src/main/java/org/apache/sedona/common/Functions.java
index 7fa3d802..7a0e469f 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -853,6 +853,14 @@ public class Functions {
         return geometry.getNumPoints();
     }
 
+    /**
+     * Forces a geometry into three dimensions by handing it to
+     * {@link GeomUtils#get3DGeom(Geometry, double)}.
+     *
+     * @param geometry geometry to convert
+     * @param zValue Z ordinate passed on to the helper
+     * @return whatever geometry the helper returns
+     */
+    public static Geometry force3D(Geometry geometry, double zValue) {
+        final Geometry forced = GeomUtils.get3DGeom(geometry, zValue);
+        return forced;
+    }
+
+    /**
+     * Forces a geometry into three dimensions using a Z ordinate of {@code 0.0}.
+     *
+     * @param geometry geometry to convert
+     * @return the result of {@link #force3D(Geometry, double)} called with a {@code zValue} of {@code 0.0}
+     */
+    public static Geometry force3D(Geometry geometry) {
+        // Delegate so the default-value variant can never drift from the explicit one.
+        return force3D(geometry, 0.0);
+    }
+
     public static Geometry geometricMedian(Geometry geometry, double 
tolerance, int maxIter, boolean failIfNotConverged) throws Exception {
         String geometryType = geometry.getGeometryType();
         if(!(Geometry.TYPENAME_POINT.equals(geometryType) || 
Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) {
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java 
b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
index 635c8cd4..29ddf572 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
@@ -419,4 +419,18 @@ public class GeomUtils {
         }
         return geometries;
     }
+
+
+    /**
+     * Assigns a Z ordinate to the coordinates of a geometry that has none.
+     *
+     * <p>The geometry is treated as already 3D when the Z ordinate of its first coordinate is
+     * not NaN. In that case, and when the geometry has no coordinates at all, it is returned
+     * untouched. Otherwise {@code zValue} is set on every {@link Coordinate} obtained from
+     * {@code geometry.getCoordinates()}; no new geometry is created, and the instance that
+     * was passed in is the one returned.
+     *
+     * @param geometry geometry to convert
+     * @param zValue Z ordinate assigned to every coordinate of a geometry without Z
+     * @return the {@code geometry} argument itself
+     */
+    public static Geometry get3DGeom(Geometry geometry, double zValue) {
+        Coordinate[] coordinates = geometry.getCoordinates();
+        // Dimensionality is decided once, from the first coordinate, so empty and already-3D
+        // inputs skip both the loop and the change notification below.
+        if (coordinates.length == 0 || !Double.isNaN(coordinates[0].z)) {
+            return geometry;
+        }
+        for (Coordinate coordinate : coordinates) {
+            coordinate.setZ(zValue);
+        }
+        // Coordinates were edited behind the geometry's back; let it know.
+        geometry.geometryChanged();
+        return geometry;
+    }
 }
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java 
b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index 93d20b07..f54f2711 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -17,10 +17,12 @@ import com.google.common.geometry.S2CellId;
 import com.google.common.math.DoubleMath;
 import org.apache.sedona.common.sphere.Haversine;
 import org.apache.sedona.common.sphere.Spheroid;
+import org.apache.sedona.common.utils.GeomUtils;
 import org.apache.sedona.common.utils.S2Utils;
 import org.junit.Test;
 import org.locationtech.jts.geom.*;
 import org.locationtech.jts.io.WKTReader;
+import org.locationtech.jts.io.WKTWriter;
 
 import java.util.Arrays;
 import java.util.HashSet;
@@ -57,6 +59,14 @@ public class FunctionsTest {
         return coords;
     }
 
+    /** Builds coordinates from a flat sequence of x, y, z values, three per coordinate. */
+    private Coordinate[] coordArray3d(double... coordValues) {
+        Coordinate[] triples = new Coordinate[coordValues.length / 3];
+        int next = 0;
+        int offset = 0;
+        while (offset < coordValues.length) {
+            double x = coordValues[offset];
+            double y = coordValues[offset + 1];
+            double z = coordValues[offset + 2];
+            triples[next++] = new Coordinate(x, y, z);
+            offset += 3;
+        }
+        return triples;
+    }
+
     @Test
     public void splitLineStringByMultipoint() {
         LineString lineString = 
GEOMETRY_FACTORY.createLineString(coordArray(0.0, 0.0, 1.5, 1.5, 2.0, 2.0));
@@ -581,4 +591,61 @@ public class FunctionsTest {
         Exception e = assertThrows(IllegalArgumentException.class, () -> 
Functions.numPoints(polygon));
         assertEquals(expected, e.getMessage());
     }
+
+    @Test
+    public void force3DObject2D() {
+        // A 2D line forced with zValue 1.1 should serialize like the same line built with Z = 1.1.
+        LineString flatLine = GEOMETRY_FACTORY.createLineString(coordArray(0, 1, 1, 0, 2, 0));
+        LineString liftedLine = GEOMETRY_FACTORY.createLineString(coordArray3d(0, 1, 1.1, 1, 0, 1.1, 2, 0, 1.1));
+        Geometry actual = Functions.force3D(flatLine, 1.1);
+        WKTWriter writer = new WKTWriter(GeomUtils.getDimension(liftedLine));
+        String expectedWkt = writer.write(liftedLine);
+        assertEquals(expectedWkt, writer.write(actual));
+        assertEquals(3, Functions.nDims(actual));
+    }
+
+    @Test
+    public void force3DObject2DDefaultValue() {
+        // Without an explicit zValue, the 2D polygon should serialize like one built with Z = 0.
+        Polygon flatPolygon = GEOMETRY_FACTORY.createPolygon(coordArray(0, 0, 0, 90, 0, 0));
+        Polygon liftedPolygon = GEOMETRY_FACTORY.createPolygon(coordArray3d(0, 0, 0, 0, 90, 0, 0, 0, 0));
+        Geometry actual = Functions.force3D(flatPolygon);
+        WKTWriter writer = new WKTWriter(GeomUtils.getDimension(liftedPolygon));
+        String expectedWkt = writer.write(liftedPolygon);
+        assertEquals(expectedWkt, writer.write(actual));
+        assertEquals(3, Functions.nDims(actual));
+    }
+
+    @Test
+    public void force3DObject3D() {
+        // For an input that already has Z values, the result must serialize to the same WKT as the input.
+        LineString source = GEOMETRY_FACTORY.createLineString(coordArray3d(0, 1, 1, 1, 2, 1, 1, 2, 2));
+        Geometry actual = Functions.force3D(source, 2.0);
+        WKTWriter writer = new WKTWriter(GeomUtils.getDimension(source));
+        String sourceWkt = writer.write(source);
+        assertEquals(sourceWkt, writer.write(actual));
+        assertEquals(3, Functions.nDims(actual));
+    }
+
+    @Test
+    public void force3DObject3DDefaultValue() {
+        // Same check as the explicit-zValue 3D case, but through the single-argument overload.
+        Polygon source = GEOMETRY_FACTORY.createPolygon(coordArray3d(0, 0, 0, 90, 0, 0, 0, 0, 0));
+        Geometry actual = Functions.force3D(source);
+        WKTWriter writer = new WKTWriter(GeomUtils.getDimension(source));
+        String sourceWkt = writer.write(source);
+        assertEquals(sourceWkt, writer.write(actual));
+        assertEquals(3, Functions.nDims(actual));
+    }
+
+    @Test
+    public void force3DEmptyObject() {
+        // The result must report the same emptiness as the empty input.
+        LineString source = GEOMETRY_FACTORY.createLineString();
+        Geometry actual = Functions.force3D(source, 1.2);
+        assertEquals(source.isEmpty(), actual.isEmpty());
+    }
+
+    @Test
+    public void force3DEmptyObjectDefaultValue() {
+        // Emptiness check again, this time through the single-argument overload.
+        LineString source = GEOMETRY_FACTORY.createLineString();
+        Geometry actual = Functions.force3D(source);
+        assertEquals(source.isEmpty(), actual.isEmpty());
+    }
+
 }
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index 0b79c9b9..11d86ff6 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -390,6 +390,51 @@ Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 
2,1 1 2))`
 
 Output: `POLYGON((0 0,0 5,5 0,0 0),(1 1,3 1,1 3,1 1))`
 
+## ST_Force3D
+Introduction: Forces the geometry into a 3-dimensional model so that all 
output representations will have X, Y and Z coordinates.
+An optionally given zValue is tacked onto the geometry if the geometry is 2-dimensional. The default value of zValue is 0.0.
+If the given geometry is 3-dimensional, no change is performed on it.
+If the given geometry is empty, no change is performed on it.
+
+!!!Note
+    Example output is after calling ST_AsText() on the returned geometry, which adds Z to the WKT for 3D geometries
+
+Format: `ST_Force3D(geometry, zValue)`
+
+Since: `1.4.1`
+
+Example: 
+
+```sql
+SELECT ST_Force3D(df.geometry) AS geom
+from df
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 0, 1 2 0, 2 1 0)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+```sql
+SELECT ST_Force3D(df.geometry, 2.3) AS geom
+from df
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 2.3, 1 2 2.3, 2 1 2.3)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Input: `LINESTRING EMPTY`
+
+Output: `LINESTRING EMPTY`
+
 ## ST_GeoHash
 
 Introduction: Returns GeoHash of the geometry with given precision
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index 9935f0f0..05265b53 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -576,6 +576,49 @@ Result:
 +---------------------------------------------------------------+
 ```
 
+## ST_Force3D
+Introduction: Forces the geometry into a 3-dimensional model so that all 
output representations will have X, Y and Z coordinates.
+An optionally given zValue is tacked onto the geometry if the geometry is 2-dimensional. The default value of zValue is 0.0.
+If the given geometry is 3-dimensional, no change is performed on it.
+If the given geometry is empty, no change is performed on it.
+
+!!!Note
+    Example output is after calling ST_AsText() on the returned geometry, which adds Z to the WKT for 3D geometries
+
+Format: `ST_Force3D(geometry, zValue)`
+
+Since: `1.4.1`
+
+Spark SQL Example:
+
+```sql
+SELECT ST_Force3D(geometry) AS geom
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 0, 1 2 0, 2 1 0)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+```sql
+SELECT ST_Force3D(geometry, 2.3) AS geom
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 2.3, 1 2 2.3, 2 1 2.3)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Input: `LINESTRING EMPTY`
+
+Output: `LINESTRING EMPTY`
+
 ## ST_GeoHash
 
 Introduction: Returns GeoHash of the geometry with given precision
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java 
b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index 884126c0..f72f0793 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -95,7 +95,8 @@ public class Catalog {
                 new Functions.ST_Split(),
                 new Functions.ST_S2CellIDs(),
                 new Functions.ST_GeometricMedian(),
-                new Functions.ST_NumPoints()
+                new Functions.ST_NumPoints(),
+                new Functions.ST_Force3D()
         };
     }
 
diff --git 
a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java 
b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index 7001345a..33f1b6c0 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -13,6 +13,7 @@
  */
 package org.apache.sedona.flink.expressions;
 
+import org.apache.calcite.runtime.Geometries;
 import org.apache.flink.table.annotation.DataTypeHint;
 import org.apache.flink.table.functions.ScalarFunction;
 import org.locationtech.jts.geom.Geometry;
@@ -582,4 +583,20 @@ public class Functions {
         }
     }
 
+    /**
+     * Flink scalar function ST_Force3D. Each {@code eval} overload casts its first argument to
+     * {@link Geometry} and forwards to the matching {@code force3D} overload in
+     * {@code org.apache.sedona.common.Functions}.
+     */
+    public static class ST_Force3D extends ScalarFunction {
+
+        /** Variant that takes the Z value explicitly. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                             @DataTypeHint("Double") Double zValue) {
+            return org.apache.sedona.common.Functions.force3D((Geometry) o, zValue);
+        }
+
+        /** Variant without a Z value; uses the single-argument {@code force3D}. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o) {
+            return org.apache.sedona.common.Functions.force3D((Geometry) o);
+        }
+    }
+
 }
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java 
b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index eac04d2f..219970c5 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -699,4 +699,24 @@ public class FunctionTest extends TestBase{
         assertEquals(expected, actual);
     }
 
+    @Test
+    public void testForce3D() {
+        // Force a 2D linestring to 3D with an explicit Z value, then ask ST_NDims for its dimension count.
+        String column = polygonColNames[0];
+        Table forced = tableEnv.sqlQuery("SELECT ST_Force3D(ST_GeomFromWKT('LINESTRING(0 1, 1 0, 2 0)'), 1.2) " +
+                "AS " + column);
+        Table dims = forced.select(call(Functions.ST_NDims.class.getSimpleName(), $(column)));
+        Integer actual = (Integer) first(dims).getField(0);
+        assertEquals(Integer.valueOf(3), actual);
+    }
+
+    @Test
+    public void testForce3DDefaultValue() {
+        // Same as testForce3D, but calling ST_Force3D without the Z value argument.
+        String column = polygonColNames[0];
+        Table forced = tableEnv.sqlQuery("SELECT ST_Force3D(ST_GeomFromWKT('LINESTRING(0 1, 1 0, 2 0)')) " +
+                "AS " + column);
+        Table dims = forced.select(call(Functions.ST_NDims.class.getSimpleName(), $(column)));
+        Integer actual = (Integer) first(dims).getField(0);
+        assertEquals(Integer.valueOf(3), actual);
+    }
+
 }
diff --git a/python/sedona/sql/st_functions.py 
b/python/sedona/sql/st_functions.py
index d5c7602b..66cf149a 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -108,7 +108,8 @@ __all__ = [
     "ST_Z",
     "ST_ZMax",
     "ST_ZMin",
-    "ST_NumPoints"
+    "ST_NumPoints",
+    "ST_Force3D"
 ]
 
 
@@ -1241,3 +1242,14 @@ def ST_NumPoints(geometry: ColumnOrName) -> Column:
     :rtype: Column
     """
     return _call_st_function("ST_NumPoints", geometry)
+
+
+def ST_Force3D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName, float]] = 0.0) -> Column:
+    """
+    Call the ST_Force3D SQL function with the given geometry and z value.
+
+    :param geometry: Geometry column to make 3D
+    :param zValue: Optional value of z coordinate to be potentially added, default value is 0.0
+    :return: 3D geometry with either already present z coordinate if any, or z coordinate with given zValue
+    """
+    return _call_st_function("ST_Force3D", (geometry, zValue))
diff --git a/python/tests/sql/test_dataframe_api.py 
b/python/tests/sql/test_dataframe_api.py
index 1ece9f69..f9fbfc82 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -85,6 +85,7 @@ test_configurations = [
     (stf.ST_ExteriorRing, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1 
0, 1 1, 0 0)"),
     (stf.ST_FlipCoordinates, ("point",), "point_geom", "", "POINT (1 0)"),
     (stf.ST_Force_2D, ("point",), "point_geom", "", "POINT (0 1)"),
+    (stf.ST_Force3D, ("point", 1), "point_geom", "", "POINT Z (0 1 1)"),
     (stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT 
(22.500002656424286 21.250001168173426)"),
     (stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"),
     (stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"),
@@ -111,6 +112,7 @@ test_configurations = [
     (stf.ST_NPoints, ("line",), "linestring_geom", "", 6),
     (stf.ST_NumGeometries, ("geom",), "multipoint", "", 2),
     (stf.ST_NumInteriorRings, ("geom",), "geom_with_hole", "", 1),
+    (stf.ST_NumPoints, ("line",), "linestring_geom", "", 6),
     (stf.ST_PointN, ("line", 2), "linestring_geom", "", "POINT (1 0)"),
     (stf.ST_PointOnSurface, ("line",), "linestring_geom", "", "POINT (2 0)"),
     (stf.ST_PrecisionReduce, ("geom", 1), "precision_reduce_point", "", "POINT 
(0.1 0.2)"),
diff --git a/python/tests/sql/test_function.py 
b/python/tests/sql/test_function.py
index ba665774..18b38c11 100644
--- a/python/tests/sql/test_function.py
+++ b/python/tests/sql/test_function.py
@@ -1079,3 +1079,10 @@ class TestPredicateJoin(TestBase):
         actual = self.spark.sql("SELECT 
ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'))").take(1)[0][0]
         expected = 3
         assert expected == actual
+
+    def test_force3D(self):
+        # A 2D linestring passed through ST_Force3D should report three dimensions.
+        forced_df = self.spark.sql(
+            "SELECT ST_Force3D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom"
+        )
+        first_row = forced_df.selectExpr("ST_NDims(geom)").take(1)[0]
+        assert 3 == first_row[0]
+
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index eede0a13..9edc563f 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -148,6 +148,7 @@ object Catalog {
     function[ST_AreaSpheroid](),
     function[ST_LengthSpheroid](),
     function[ST_NumPoints](),
+    function[ST_Force3D](0.0),
     // Expression for rasters
     function[RS_NormalizedDifference](),
     function[RS_Mean](),
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2a6dfde3..e70f1473 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -988,3 +988,11 @@ case class ST_NumPoints(inputExpressions: Seq[Expression])
   }
 }
 
+/** Spark SQL expression for ST_Force3D, built from `Functions.force3D`. */
+case class ST_Force3D(inputExpressions: Seq[Expression])
+  extends InferredBinaryExpression(Functions.force3D) with FoldableExpression {
+
+  // Rebuild the expression around the replacement children.
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+    copy(inputExpressions = newChildren)
+}
+
diff --git 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 94e0874d..6101222b 100644
--- 
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ 
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -305,4 +305,12 @@ object st_functions extends DataFrameAPI {
   def ST_NumPoints(geometry: Column): Column = 
wrapExpression[ST_NumPoints](geometry)
 
   def ST_NumPoints(geometry: String): Column = 
wrapExpression[ST_NumPoints](geometry)
+
+  // ST_Force3D: the single-argument overloads supply 0.0 as the Z value.
+  def ST_Force3D(geometry: Column): Column = wrapExpression[ST_Force3D](geometry, 0.0)
+
+  def ST_Force3D(geometry: String): Column = ST_Force3D(geometry, 0.0)
+
+  def ST_Force3D(geometry: Column, zValue: Column): Column = wrapExpression[ST_Force3D](geometry, zValue)
+
+  def ST_Force3D(geometry: String, zValue: Double): Column = wrapExpression[ST_Force3D](geometry, zValue)
 }
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index e3eaf8ff..70b5aa69 100644
--- 
a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ 
b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -27,6 +27,7 @@ import 
org.apache.spark.sql.sedona_sql.expressions.st_functions._
 import org.apache.spark.sql.sedona_sql.expressions.st_predicates._
 import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
 import org.junit.Assert.assertEquals
+import org.locationtech.jts.io.WKTWriter
 
 import scala.collection.mutable
 
@@ -957,5 +958,17 @@ class dataFrameAPITestScala extends TestBaseScala {
       val expectedResult = 3
       assert(actualResult == expectedResult)
     }
+
+    it("Passed ST_Force3D") {
+      // The same 2D linestring is forced to 3D twice: once with Z = 2.3, once with the default Z.
+      val sourceQuery = "SELECT ST_GeomFromWKT('LINESTRING (0 1, 1 0, 2 0)') AS geom"
+      val writer = new WKTWriter(3)
+
+      val explicitDf = sparkSession.sql(sourceQuery)
+      val explicitGeom = explicitDf.select(ST_Force3D("geom", 2.3)).take(1)(0).get(0).asInstanceOf[Geometry]
+      assertEquals("LINESTRING Z(0 1 2.3, 1 0 2.3, 2 0 2.3)", writer.write(explicitGeom))
+
+      val defaultDf = sparkSession.sql(sourceQuery)
+      val defaultGeom = defaultDf.select(ST_Force3D("geom")).take(1)(0).get(0).asInstanceOf[Geometry]
+      assertEquals("LINESTRING Z(0 1 0, 1 0 0, 2 0 0)", writer.write(defaultGeom))
+    }
   }
 }
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 21e897a5..b2ad34e0 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
 
 import org.apache.commons.codec.binary.Hex
 import org.apache.sedona.sql.implicits._
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, 
GenericRowWithSchema}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.geotools.referencing.CRS
@@ -1921,4 +1922,22 @@ class functionTestScala extends TestBaseScala with 
Matchers with GeometrySample
       assertEquals(expected, actual)
     }
   }
+
+  it("should pass ST_Force3D") {
+    // input WKT literal -> (expected WKT literal with zValue 1, expected WKT literal with no zValue given)
+    val geomTestCases = Map(
+      "'LINESTRING (0 1, 1 0, 2 0)'" -> ("'LINESTRING Z(0 1 1, 1 0 1, 2 0 1)'", "'LINESTRING Z(0 1 0, 1 0 0, 2 0 0)'"),
+      "'LINESTRING Z(0 1 3, 1 0 3, 2 0 3)'" -> ("'LINESTRING Z(0 1 3, 1 0 3, 2 0 3)'", "'LINESTRING Z(0 1 3, 1 0 3, 2 0 3)'"),
+      "'LINESTRING EMPTY'" -> ("'LINESTRING EMPTY'", "'LINESTRING EMPTY'")
+    )
+    geomTestCases.foreach { case (geom, expectedResult) =>
+      // The expected pair is interpolated into the query as a second column and read back
+      // below as a row with two string fields.
+      val explicitRow = sparkSession.sql(s"SELECT ST_AsText(ST_Force3D(ST_GeomFromWKT($geom), 1)) AS geom, " + s"$expectedResult").take(1)(0)
+      val defaultRow = sparkSession.sql(s"SELECT ST_AsText(ST_Force3D(ST_GeomFromWKT($geom))) AS geom, " + s"$expectedResult").take(1)(0)
+
+      val actual = explicitRow.get(0).asInstanceOf[String]
+      val expected = explicitRow.get(1).asInstanceOf[GenericRowWithSchema].get(0).asInstanceOf[String]
+      assertEquals(expected, actual)
+
+      val actualDefaultValue = defaultRow.get(0).asInstanceOf[String]
+      val expectedDefaultValue = defaultRow.get(1).asInstanceOf[GenericRowWithSchema].get(1).asInstanceOf[String]
+      assertEquals(expectedDefaultValue, actualDefaultValue)
+    }
+  }
 }

Reply via email to