This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 800a3f54 [SEDONA-196] Add ST_Force3D (#856)
800a3f54 is described below
commit 800a3f5415577b6deec58d2c5e8d11c6658c9bce
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Mon Jun 12 04:12:02 2023 -0700
[SEDONA-196] Add ST_Force3D (#856)
---
.../java/org/apache/sedona/common/Functions.java | 8 +++
.../org/apache/sedona/common/utils/GeomUtils.java | 14 +++++
.../org/apache/sedona/common/FunctionsTest.java | 67 ++++++++++++++++++++++
docs/api/flink/Function.md | 45 +++++++++++++++
docs/api/sql/Function.md | 43 ++++++++++++++
.../main/java/org/apache/sedona/flink/Catalog.java | 3 +-
.../apache/sedona/flink/expressions/Functions.java | 17 ++++++
.../java/org/apache/sedona/flink/FunctionTest.java | 20 +++++++
python/sedona/sql/st_functions.py | 14 ++++-
python/tests/sql/test_dataframe_api.py | 2 +
python/tests/sql/test_function.py | 7 +++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../sql/sedona_sql/expressions/Functions.scala | 8 +++
.../sql/sedona_sql/expressions/st_functions.scala | 8 +++
.../apache/sedona/sql/dataFrameAPITestScala.scala | 13 +++++
.../org/apache/sedona/sql/functionTestScala.scala | 19 ++++++
16 files changed, 287 insertions(+), 2 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java
b/common/src/main/java/org/apache/sedona/common/Functions.java
index 7fa3d802..7a0e469f 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -853,6 +853,14 @@ public class Functions {
return geometry.getNumPoints();
}
+    /**
+     * Forces {@code geometry} into three dimensions.
+     *
+     * @param geometry geometry to convert
+     * @param zValue   Z value to use for coordinates that have no Z ordinate
+     * @return the geometry produced by {@link GeomUtils#get3DGeom(Geometry, double)}
+     */
+    public static Geometry force3D(Geometry geometry, double zValue) {
+        return GeomUtils.get3DGeom(geometry, zValue);
+    }
+
+    /**
+     * Forces {@code geometry} into three dimensions using {@code 0.0} as the Z value.
+     *
+     * @param geometry geometry to convert
+     * @return the result of {@link #force3D(Geometry, double)} with a Z value of {@code 0.0}
+     */
+    public static Geometry force3D(Geometry geometry) {
+        final double defaultZ = 0.0;
+        return force3D(geometry, defaultZ);
+    }
+
public static Geometry geometricMedian(Geometry geometry, double
tolerance, int maxIter, boolean failIfNotConverged) throws Exception {
String geometryType = geometry.getGeometryType();
if(!(Geometry.TYPENAME_POINT.equals(geometryType) ||
Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) {
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
index 635c8cd4..29ddf572 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
@@ -419,4 +419,18 @@ public class GeomUtils {
}
return geometries;
}
+
+
+ public static Geometry get3DGeom(Geometry geometry, double zValue) {
+ Coordinate[] coordinates = geometry.getCoordinates();
+ if (coordinates.length == 0) return geometry;
+ boolean is3d = !Double.isNaN(coordinates[0].z);
+ for(int i = 0; i < coordinates.length; i++) {
+ if(!is3d) {
+ coordinates[i].setZ(zValue);
+ }
+ }
+ geometry.geometryChanged();
+ return geometry;
+ }
}
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index 93d20b07..f54f2711 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -17,10 +17,12 @@ import com.google.common.geometry.S2CellId;
import com.google.common.math.DoubleMath;
import org.apache.sedona.common.sphere.Haversine;
import org.apache.sedona.common.sphere.Spheroid;
+import org.apache.sedona.common.utils.GeomUtils;
import org.apache.sedona.common.utils.S2Utils;
import org.junit.Test;
import org.locationtech.jts.geom.*;
import org.locationtech.jts.io.WKTReader;
+import org.locationtech.jts.io.WKTWriter;
import java.util.Arrays;
import java.util.HashSet;
@@ -57,6 +59,14 @@ public class FunctionsTest {
return coords;
}
+    /**
+     * Builds 3D coordinates from a flat list of (x, y, z) triples.
+     *
+     * @param coordValues ordinates laid out as x0, y0, z0, x1, y1, z1, ...
+     * @return one {@link Coordinate} per triple
+     */
+    private Coordinate[] coordArray3d(double... coordValues) {
+        Coordinate[] result = new Coordinate[coordValues.length / 3];
+        int slot = 0;
+        for (int offset = 0; offset < coordValues.length; offset += 3) {
+            double x = coordValues[offset];
+            double y = coordValues[offset + 1];
+            double z = coordValues[offset + 2];
+            result[slot++] = new Coordinate(x, y, z);
+        }
+        return result;
+    }
+
@Test
public void splitLineStringByMultipoint() {
LineString lineString =
GEOMETRY_FACTORY.createLineString(coordArray(0.0, 0.0, 1.5, 1.5, 2.0, 2.0));
@@ -581,4 +591,61 @@ public class FunctionsTest {
Exception e = assertThrows(IllegalArgumentException.class, () ->
Functions.numPoints(polygon));
assertEquals(expected, e.getMessage());
}
+
+    @Test
+    public void force3DObject2D() {
+        int expectedDims = 3;
+        LineString line = GEOMETRY_FACTORY.createLineString(coordArray(0, 1, 1, 0, 2, 0));
+        LineString expectedLine = GEOMETRY_FACTORY.createLineString(coordArray3d(0, 1, 1.1, 1, 0, 1.1, 2, 0, 1.1));
+        Geometry forcedLine = Functions.force3D(line, 1.1);
+        WKTWriter wktWriter = new WKTWriter(GeomUtils.getDimension(expectedLine));
+        assertEquals(wktWriter.write(expectedLine), wktWriter.write(forcedLine));
+        assertEquals(expectedDims, Functions.nDims(forcedLine));
+    }
+
+    @Test
+    public void force3DObject2DDefaultValue() {
+        int expectedDims = 3;
+        // Closed, non-degenerate ring (the previous fixture had only 3 vertices).
+        Polygon polygon = GEOMETRY_FACTORY.createPolygon(coordArray(0, 0, 0, 90, 90, 0, 0, 0));
+        Polygon expectedPolygon = GEOMETRY_FACTORY.createPolygon(coordArray3d(0, 0, 0, 0, 90, 0, 90, 0, 0, 0, 0, 0));
+        Geometry forcedPolygon = Functions.force3D(polygon);
+        WKTWriter wktWriter = new WKTWriter(GeomUtils.getDimension(expectedPolygon));
+        assertEquals(wktWriter.write(expectedPolygon), wktWriter.write(forcedPolygon));
+        assertEquals(expectedDims, Functions.nDims(forcedPolygon));
+    }
+
+    @Test
+    public void force3DObject3D() {
+        int expectedDims = 3;
+        LineString line3D = GEOMETRY_FACTORY.createLineString(coordArray3d(0, 1, 1, 1, 2, 1, 1, 2, 2));
+        // Build the expectation independently: comparing line3D with the result would be vacuous
+        // if force3D returned (or mutated) the input instance.
+        LineString expectedLine = GEOMETRY_FACTORY.createLineString(coordArray3d(0, 1, 1, 1, 2, 1, 1, 2, 2));
+        Geometry forcedLine3D = Functions.force3D(line3D, 2.0);
+        WKTWriter wktWriter = new WKTWriter(GeomUtils.getDimension(expectedLine));
+        assertEquals(wktWriter.write(expectedLine), wktWriter.write(forcedLine3D));
+        assertEquals(expectedDims, Functions.nDims(forcedLine3D));
+    }
+
+    @Test
+    public void force3DObject3DDefaultValue() {
+        int expectedDims = 3;
+        // Closed, non-degenerate ring with distinct Z values so that preservation of Z is observable.
+        Polygon polygon = GEOMETRY_FACTORY.createPolygon(coordArray3d(0, 0, 1, 0, 90, 2, 90, 0, 3, 0, 0, 1));
+        Polygon expectedPolygon = GEOMETRY_FACTORY.createPolygon(coordArray3d(0, 0, 1, 0, 90, 2, 90, 0, 3, 0, 0, 1));
+        Geometry forcedPolygon = Functions.force3D(polygon);
+        WKTWriter wktWriter = new WKTWriter(GeomUtils.getDimension(expectedPolygon));
+        assertEquals(wktWriter.write(expectedPolygon), wktWriter.write(forcedPolygon));
+        assertEquals(expectedDims, Functions.nDims(forcedPolygon));
+    }
+
+ @Test
+ public void force3DEmptyObject() {
+ LineString emptyLine = GEOMETRY_FACTORY.createLineString();
+ Geometry forcedEmptyLine = Functions.force3D(emptyLine, 1.2);
+ assertEquals(emptyLine.isEmpty(), forcedEmptyLine.isEmpty());
+ }
+
+ @Test
+ public void force3DEmptyObjectDefaultValue() {
+ LineString emptyLine = GEOMETRY_FACTORY.createLineString();
+ Geometry forcedEmptyLine = Functions.force3D(emptyLine);
+ assertEquals(emptyLine.isEmpty(), forcedEmptyLine.isEmpty());
+ }
+
}
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index 0b79c9b9..11d86ff6 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -390,6 +390,51 @@ Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3
2,1 1 2))`
Output: `POLYGON((0 0,0 5,5 0,0 0),(1 1,3 1,1 3,1 1))`
+## ST_Force3D
+Introduction: Forces the geometry into a 3-dimensional model so that all output representations will have X, Y and Z coordinates.
+If the geometry is 2-dimensional, the optional zValue is used as the Z coordinate of every point. The default zValue is 0.0.
+If the given geometry is 3-dimensional, no change is performed on it.
+If the given geometry is empty, no change is performed on it.
+
+!!!Note
+    Example outputs show the result of calling ST_AsText() on the returned geometry, which adds a Z marker to the WKT of 3D geometries.
+
+Format: `ST_Force3D(geometry)` or `ST_Force3D(geometry, zValue)`
+
+Since: `1.4.1`
+
+Example:
+
+```sql
+SELECT ST_Force3D(df.geometry) AS geom
+from df
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 0, 1 2 0, 2 1 0)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+```sql
+SELECT ST_Force3D(df.geometry, 2.3) AS geom
+from df
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 2.3, 1 2 2.3, 2 1 2.3)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Input: `LINESTRING EMPTY`
+
+Output: `LINESTRING EMPTY`
+
## ST_GeoHash
Introduction: Returns GeoHash of the geometry with given precision
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index 9935f0f0..05265b53 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -576,6 +576,49 @@ Result:
+---------------------------------------------------------------+
```
+## ST_Force3D
+Introduction: Forces the geometry into a 3-dimensional model so that all output representations will have X, Y and Z coordinates.
+If the geometry is 2-dimensional, the optional zValue is used as the Z coordinate of every point. The default zValue is 0.0.
+If the given geometry is 3-dimensional, no change is performed on it.
+If the given geometry is empty, no change is performed on it.
+
+!!!Note
+    Example outputs show the result of calling ST_AsText() on the returned geometry, which adds a Z marker to the WKT of 3D geometries.
+
+Format: `ST_Force3D(geometry)` or `ST_Force3D(geometry, zValue)`
+
+Since: `1.4.1`
+
+Spark SQL Example:
+
+```sql
+SELECT ST_Force3D(geometry) AS geom
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 0, 1 2 0, 2 1 0)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+```sql
+SELECT ST_Force3D(geometry, 2.3) AS geom
+```
+
+Input: `LINESTRING(0 1, 1 2, 2 1)`
+
+Output: `LINESTRING Z(0 1 2.3, 1 2 2.3, 2 1 2.3)`
+
+Input: `POLYGON((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Output: `POLYGON Z((0 0 2,0 5 2,5 0 2,0 0 2),(1 1 2,3 1 2,1 3 2,1 1 2))`
+
+Input: `LINESTRING EMPTY`
+
+Output: `LINESTRING EMPTY`
+
## ST_GeoHash
Introduction: Returns GeoHash of the geometry with given precision
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index 884126c0..f72f0793 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -95,7 +95,8 @@ public class Catalog {
new Functions.ST_Split(),
new Functions.ST_S2CellIDs(),
new Functions.ST_GeometricMedian(),
- new Functions.ST_NumPoints()
+ new Functions.ST_NumPoints(),
+ new Functions.ST_Force3D()
};
}
diff --git
a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index 7001345a..33f1b6c0 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -13,6 +13,7 @@
*/
package org.apache.sedona.flink.expressions;
+import org.apache.calcite.runtime.Geometries;
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.functions.ScalarFunction;
import org.locationtech.jts.geom.Geometry;
@@ -582,4 +583,20 @@ public class Functions {
}
}
+    /**
+     * Flink SQL wrapper for {@code ST_Force3D}. Follows SQL NULL semantics: a NULL argument yields a
+     * NULL result instead of a {@link NullPointerException} (previously thrown when unboxing a NULL
+     * {@code zValue} or dereferencing a NULL geometry).
+     */
+    public static class ST_Force3D extends ScalarFunction {
+
+        /** Forces the geometry into 3D, using {@code zValue} for coordinates without a Z ordinate. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+                             @DataTypeHint("Double") Double zValue) {
+            if (o == null || zValue == null) {
+                return null;
+            }
+            Geometry geometry = (Geometry) o;
+            return org.apache.sedona.common.Functions.force3D(geometry, zValue);
+        }
+
+        /** Forces the geometry into 3D using the default Z value of the one-argument common function. */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o) {
+            if (o == null) {
+                return null;
+            }
+            Geometry geometry = (Geometry) o;
+            return org.apache.sedona.common.Functions.force3D(geometry);
+        }
+    }
+
}
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index eac04d2f..219970c5 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -699,4 +699,24 @@ public class FunctionTest extends TestBase{
assertEquals(expected, actual);
}
+    @Test
+    public void testForce3D() {
+        // Forcing a 2D line with an explicit Z value must yield a 3-dimensional geometry.
+        String geomCol = polygonColNames[0];
+        Table forcedTable = tableEnv.sqlQuery(
+                "SELECT ST_Force3D(ST_GeomFromWKT('LINESTRING(0 1, 1 0, 2 0)'), 1.2) " + "AS " + geomCol);
+        Table dimsTable = forcedTable.select(call(Functions.ST_NDims.class.getSimpleName(), $(geomCol)));
+        Integer actualDims = (Integer) first(dimsTable).getField(0);
+        assertEquals(Integer.valueOf(3), actualDims);
+    }
+
+    @Test
+    public void testForce3DDefaultValue() {
+        // Forcing a 2D line with the default Z value must yield a 3-dimensional geometry.
+        String geomCol = polygonColNames[0];
+        Table forcedTable = tableEnv.sqlQuery(
+                "SELECT ST_Force3D(ST_GeomFromWKT('LINESTRING(0 1, 1 0, 2 0)')) " + "AS " + geomCol);
+        Table dimsTable = forcedTable.select(call(Functions.ST_NDims.class.getSimpleName(), $(geomCol)));
+        Integer actualDims = (Integer) first(dimsTable).getField(0);
+        assertEquals(Integer.valueOf(3), actualDims);
+    }
+
}
diff --git a/python/sedona/sql/st_functions.py
b/python/sedona/sql/st_functions.py
index d5c7602b..66cf149a 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -108,7 +108,8 @@ __all__ = [
"ST_Z",
"ST_ZMax",
"ST_ZMin",
- "ST_NumPoints"
+ "ST_NumPoints",
+ "ST_Force3D"
]
@@ -1241,3 +1242,14 @@ def ST_NumPoints(geometry: ColumnOrName) -> Column:
:rtype: Column
"""
return _call_st_function("ST_NumPoints", geometry)
+
+
+def ST_Force3D(geometry: ColumnOrName, zValue: Optional[Union[ColumnOrName,
float]] = 0.0) -> Column:
+ """
+ Return a geometry with a 3D coordinate of value 'zValue' forced upon it.
No change happens if the geometry is already 3D
+ :param zValue: Optional value of z coordinate to be potentially added,
default value is 0.0
+ :param geometry: Geometry column to make 3D
+ :return: 3D geometry with either already present z coordinate if any, or
zcoordinate with given zValue
+ """
+ args = (geometry, zValue)
+ return _call_st_function("ST_Force3D", args)
diff --git a/python/tests/sql/test_dataframe_api.py
b/python/tests/sql/test_dataframe_api.py
index 1ece9f69..f9fbfc82 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -85,6 +85,7 @@ test_configurations = [
(stf.ST_ExteriorRing, ("geom",), "triangle_geom", "", "LINESTRING (0 0, 1
0, 1 1, 0 0)"),
(stf.ST_FlipCoordinates, ("point",), "point_geom", "", "POINT (1 0)"),
(stf.ST_Force_2D, ("point",), "point_geom", "", "POINT (0 1)"),
+ (stf.ST_Force3D, ("point", 1), "point_geom", "", "POINT Z (0 1 1)"),
(stf.ST_GeometricMedian, ("multipoint",), "multipoint_geom", "", "POINT
(22.500002656424286 21.250001168173426)"),
(stf.ST_GeometryN, ("geom", 0), "multipoint", "", "POINT (0 0)"),
(stf.ST_GeometryType, ("point",), "point_geom", "", "ST_Point"),
@@ -111,6 +112,7 @@ test_configurations = [
(stf.ST_NPoints, ("line",), "linestring_geom", "", 6),
(stf.ST_NumGeometries, ("geom",), "multipoint", "", 2),
(stf.ST_NumInteriorRings, ("geom",), "geom_with_hole", "", 1),
+ (stf.ST_NumPoints, ("line",), "linestring_geom", "", 6),
(stf.ST_PointN, ("line", 2), "linestring_geom", "", "POINT (1 0)"),
(stf.ST_PointOnSurface, ("line",), "linestring_geom", "", "POINT (2 0)"),
(stf.ST_PrecisionReduce, ("geom", 1), "precision_reduce_point", "", "POINT
(0.1 0.2)"),
diff --git a/python/tests/sql/test_function.py
b/python/tests/sql/test_function.py
index ba665774..18b38c11 100644
--- a/python/tests/sql/test_function.py
+++ b/python/tests/sql/test_function.py
@@ -1079,3 +1079,10 @@ class TestPredicateJoin(TestBase):
actual = self.spark.sql("SELECT
ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'))").take(1)[0][0]
expected = 3
assert expected == actual
+
+ def test_force3D(self):
+ expected = 3
+ actualDf = self.spark.sql("SELECT
ST_Force3D(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'), 1.1) AS geom")
+ actual = actualDf.selectExpr("ST_NDims(geom)").take(1)[0][0]
+ assert expected == actual
+
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index eede0a13..9edc563f 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -148,6 +148,7 @@ object Catalog {
function[ST_AreaSpheroid](),
function[ST_LengthSpheroid](),
function[ST_NumPoints](),
+ function[ST_Force3D](0.0),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2a6dfde3..e70f1473 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -988,3 +988,11 @@ case class ST_NumPoints(inputExpressions: Seq[Expression])
}
}
+/**
+ * Forces a geometry into 3D. Takes (geometry, zValue) and delegates to `Functions.force3D`.
+ *
+ * @param inputExpressions the geometry expression followed by the Z value expression
+ */
+case class ST_Force3D(inputExpressions: Seq[Expression])
+  extends InferredBinaryExpression(Functions.force3D) with FoldableExpression {
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ST_Force3D =
+    copy(inputExpressions = newChildren)
+}
+
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 94e0874d..6101222b 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -305,4 +305,12 @@ object st_functions extends DataFrameAPI {
def ST_NumPoints(geometry: Column): Column =
wrapExpression[ST_NumPoints](geometry)
def ST_NumPoints(geometry: String): Column =
wrapExpression[ST_NumPoints](geometry)
+
+  /** Forces `geometry` into 3D, using 0.0 as the Z value. */
+  def ST_Force3D(geometry: Column): Column = {
+    val defaultZ = 0.0
+    wrapExpression[ST_Force3D](geometry, defaultZ)
+  }
+
+  /** Forces the geometry column named `geometry` into 3D, using 0.0 as the Z value. */
+  def ST_Force3D(geometry: String): Column = {
+    val defaultZ = 0.0
+    wrapExpression[ST_Force3D](geometry, defaultZ)
+  }
+
+  /** Forces `geometry` into 3D, using `zValue` as the Z value. */
+  def ST_Force3D(geometry: Column, zValue: Column): Column = {
+    wrapExpression[ST_Force3D](geometry, zValue)
+  }
+
+  /** Forces the geometry column named `geometry` into 3D, using `zValue` as the Z value. */
+  def ST_Force3D(geometry: String, zValue: Double): Column = {
+    wrapExpression[ST_Force3D](geometry, zValue)
+  }
}
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index e3eaf8ff..70b5aa69 100644
---
a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++
b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -27,6 +27,7 @@ import
org.apache.spark.sql.sedona_sql.expressions.st_functions._
import org.apache.spark.sql.sedona_sql.expressions.st_predicates._
import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
import org.junit.Assert.assertEquals
+import org.locationtech.jts.io.WKTWriter
import scala.collection.mutable
@@ -957,5 +958,17 @@ class dataFrameAPITestScala extends TestBaseScala {
val expectedResult = 3
assert(actualResult == expectedResult)
}
+
+    it("Passed ST_Force3D") {
+      val wktWriter = new WKTWriter(3)
+      val lineDf = sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (0 1, 1 0, 2 0)') AS geom")
+
+      // Explicit Z value.
+      val forcedGeom = lineDf.select(ST_Force3D("geom", 2.3)).take(1)(0).get(0).asInstanceOf[Geometry]
+      assertEquals("LINESTRING Z(0 1 2.3, 1 0 2.3, 2 0 2.3)", wktWriter.write(forcedGeom))
+
+      // Default Z value.
+      val forcedGeomDefault = lineDf.select(ST_Force3D("geom")).take(1)(0).get(0).asInstanceOf[Geometry]
+      assertEquals("LINESTRING Z(0 1 0, 1 0 0, 2 0 0)", wktWriter.write(forcedGeomDefault))
+    }
}
}
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 21e897a5..b2ad34e0 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
import org.apache.commons.codec.binary.Hex
import org.apache.sedona.sql.implicits._
+import org.apache.spark.sql.catalyst.expressions.{GenericRow,
GenericRowWithSchema}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.geotools.referencing.CRS
@@ -1921,4 +1922,22 @@ class functionTestScala extends TestBaseScala with
Matchers with GeometrySample
assertEquals(expected, actual)
}
}
+
+    it("should pass ST_Force3D") {
+      // input WKT -> (expected WKT with zValue = 1, expected WKT with the default zValue)
+      val geomTestCases = Map(
+        "LINESTRING (0 1, 1 0, 2 0)" -> ("LINESTRING Z(0 1 1, 1 0 1, 2 0 1)", "LINESTRING Z(0 1 0, 1 0 0, 2 0 0)"),
+        "LINESTRING Z(0 1 3, 1 0 3, 2 0 3)" -> ("LINESTRING Z(0 1 3, 1 0 3, 2 0 3)", "LINESTRING Z(0 1 3, 1 0 3, 2 0 3)"),
+        "LINESTRING EMPTY" -> ("LINESTRING EMPTY", "LINESTRING EMPTY")
+      )
+      geomTestCases.foreach { case (inputWkt, (expectedWithZ, expectedDefault)) =>
+        val actualWithZ = sparkSession
+          .sql(s"SELECT ST_AsText(ST_Force3D(ST_GeomFromWKT('$inputWkt'), 1)) AS geom")
+          .take(1)(0).getString(0)
+        val actualDefault = sparkSession
+          .sql(s"SELECT ST_AsText(ST_Force3D(ST_GeomFromWKT('$inputWkt'))) AS geom")
+          .take(1)(0).getString(0)
+        assertEquals(expectedWithZ, actualWithZ)
+        assertEquals(expectedDefault, actualDefault)
+      }
+    }
}