This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 9cdb052c7 [SEDONA-607] Return geometry with ST functions with exceptions (#1525)
9cdb052c7 is described below
commit 9cdb052c7455dba2adf6f6e2e956eddb5f7ac39e
Author: Feng Zhang <[email protected]>
AuthorDate: Mon Jul 22 19:55:00 2024 -0700
[SEDONA-607] Return geometry with ST functions with exceptions (#1525)
* [SEDONA-607] New function ST_SafeGeom to return geometry with exceptions
* [SEDONA-607] Include Geometry in ST Function Exceptions
* throw exception on serialization of geometry
* refactor the code
* fix unit tests
* fix additional tests
* refactor to add exception handling in InferredExpression
* add tests and comments
* Refactor InferredExpression to include the input row in the exception
message
* reformat
---
.../sql/sedona_sql/expressions/Constructors.scala | 167 ++++++++++++---------
.../sql/sedona_sql/expressions/Functions.scala | 80 ++++++----
.../expressions/InferredExpression.scala | 66 +++++++-
.../sql/sedona_sql/expressions/Predicates.scala | 11 +-
.../expressions/collect/ST_Collect.scala | 41 ++---
.../apache/sedona/sql/constructorTestScala.scala | 6 +-
.../apache/sedona/sql/dataFrameAPITestScala.scala | 16 ++
7 files changed, 263 insertions(+), 124 deletions(-)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index b1a82ba10..6c3212e6c 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -160,16 +160,21 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
- (inputExpressions.head.eval(inputRow)) match {
- case (geomString: UTF8String) => {
- // Parse UTF-8 encoded wkb string
- Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
- }
- case (wkb: Array[Byte]) => {
- // convert raw wkb byte array to geometry
- Constructors.geomFromWKB(wkb).toGenericArrayData
+ try {
+ (inputExpressions.head.eval(inputRow)) match {
+ case (geomString: UTF8String) => {
+ // Parse UTF-8 encoded wkb string
+ Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
+ }
+ case (wkb: Array[Byte]) => {
+ // convert raw wkb byte array to geometry
+ Constructors.geomFromWKB(wkb).toGenericArrayData
+ }
+ case null => null
}
- case null => null
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
}
@@ -196,16 +201,21 @@ case class ST_GeomFromEWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
- (inputExpressions.head.eval(inputRow)) match {
- case (geomString: UTF8String) => {
- // Parse UTF-8 encoded wkb string
- Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
- }
- case (wkb: Array[Byte]) => {
- // convert raw wkb byte array to geometry
- Constructors.geomFromWKB(wkb).toGenericArrayData
+ try {
+ (inputExpressions.head.eval(inputRow)) match {
+ case (geomString: UTF8String) => {
+ // Parse UTF-8 encoded wkb string
+ Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
+ }
+ case (wkb: Array[Byte]) => {
+ // convert raw wkb byte array to geometry
+ Constructors.geomFromWKB(wkb).toGenericArrayData
+ }
+ case null => null
}
- case null => null
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
}
@@ -238,21 +248,26 @@ case class ST_LineFromWKB(inputExpressions: Seq[Expression])
if (inputExpressions.length > 1)
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
else -1
- wkb match {
- case geomString: UTF8String =>
- // Parse UTF-8 encoded WKB string
- val geom = Constructors.lineStringFromText(geomString.toString, "wkb")
- if (geom.getGeometryType == "LineString") {
- (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
- } else {
- null
- }
-
- case wkbArray: Array[Byte] =>
- // Convert raw WKB byte array to geometry
- Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
-
- case _ => null
+ try {
+ wkb match {
+ case geomString: UTF8String =>
+ // Parse UTF-8 encoded WKB string
+ val geom = Constructors.lineStringFromText(geomString.toString,
"wkb")
+ if (geom.getGeometryType == "LineString") {
+ (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
+ } else {
+ null
+ }
+
+ case wkbArray: Array[Byte] =>
+ // Convert raw WKB byte array to geometry
+ Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
+
+ case _ => null
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
}
@@ -287,21 +302,26 @@ case class ST_LinestringFromWKB(inputExpressions: Seq[Expression])
if (inputExpressions.length > 1)
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
else -1
- wkb match {
- case geomString: UTF8String =>
- // Parse UTF-8 encoded WKB string
- val geom = Constructors.lineStringFromText(geomString.toString, "wkb")
- if (geom.getGeometryType == "LineString") {
- (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
- } else {
- null
- }
-
- case wkbArray: Array[Byte] =>
- // Convert raw WKB byte array to geometry
- Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
-
- case _ => null
+ try {
+ wkb match {
+ case geomString: UTF8String =>
+ // Parse UTF-8 encoded WKB string
+ val geom = Constructors.lineStringFromText(geomString.toString,
"wkb")
+ if (geom.getGeometryType == "LineString") {
+ (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
+ } else {
+ null
+ }
+
+ case wkbArray: Array[Byte] =>
+ // Convert raw WKB byte array to geometry
+ Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
+
+ case _ => null
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
}
@@ -336,21 +356,26 @@ case class ST_PointFromWKB(inputExpressions: Seq[Expression])
if (inputExpressions.length > 1)
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
else -1
- wkb match {
- case geomString: UTF8String =>
- // Parse UTF-8 encoded WKB string
- val geom = Constructors.pointFromText(geomString.toString, "wkb")
- if (geom.getGeometryType == "Point") {
- (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
- } else {
- null
- }
-
- case wkbArray: Array[Byte] =>
- // Convert raw WKB byte array to geometry
- Constructors.pointFromWKB(wkbArray, srid).toGenericArrayData
-
- case _ => null
+ try {
+ wkb match {
+ case geomString: UTF8String =>
+ // Parse UTF-8 encoded WKB string
+ val geom = Constructors.pointFromText(geomString.toString, "wkb")
+ if (geom.getGeometryType == "Point") {
+ (if (srid != -1) Functions.setSRID(geom, srid) else
geom).toGenericArrayData
+ } else {
+ null
+ }
+
+ case wkbArray: Array[Byte] =>
+ // Convert raw WKB byte array to geometry
+ Constructors.pointFromWKB(wkbArray, srid).toGenericArrayData
+
+ case _ => null
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
}
@@ -387,12 +412,18 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
override def eval(inputRow: InternalRow): Any = {
val geomString =
inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
- val geometry = Constructors.geomFromText(geomString,
FileDataSplitter.GEOJSON)
- // If the user specify a bunch of attributes to go with each geometry, we
need to store all of them in this geometry
- if (inputExpressions.length > 1) {
- geometry.setUserData(generateUserData(minInputLength, inputExpressions,
inputRow))
+ try {
+
+ val geometry = Constructors.geomFromText(geomString,
FileDataSplitter.GEOJSON)
+ // If the user specifies a bunch of attributes to go with each geometry,
we need to store all of them in this geometry
+ if (inputExpressions.length > 1) {
+ geometry.setUserData(generateUserData(minInputLength,
inputExpressions, inputRow))
+ }
+ GeometrySerializer.serialize(geometry)
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
}
- GeometrySerializer.serialize(geometry)
}
override def dataType: DataType = GeometryUDT
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 70e3582f2..cf52d5bc7 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -332,13 +332,18 @@ case class ST_IsValidDetail(children: Seq[Expression])
throw new IllegalArgumentException(s"Invalid number of arguments:
$nArgs")
}
- if (validDetail.location == null) {
- return InternalRow.fromSeq(Seq(validDetail.valid, null, null))
- }
+ try {
+ if (validDetail.location == null) {
+ return InternalRow.fromSeq(Seq(validDetail.valid, null, null))
+ }
- val serLocation = GeometrySerializer.serialize(validDetail.location)
- InternalRow.fromSeq(
- Seq(validDetail.valid, UTF8String.fromString(validDetail.reason),
serLocation))
+ val serLocation = GeometrySerializer.serialize(validDetail.location)
+ InternalRow.fromSeq(
+ Seq(validDetail.valid, UTF8String.fromString(validDetail.reason),
serLocation))
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input, children,
e)
+ }
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression = {
@@ -609,14 +614,20 @@ case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
override def eval(input: InternalRow): Any = {
val expr = inputExpressions(0)
- val geometry = expr match {
- case s: SerdeAware => s.evalWithoutSerialization(input)
- case _ => expr.toGeometry(input)
- }
- geometry match {
- case geometry: Geometry => getMinimumBoundingRadius(geometry)
- case _ => null
+ try {
+ val geometry = expr match {
+ case s: SerdeAware => s.evalWithoutSerialization(input)
+ case _ => expr.toGeometry(input)
+ }
+
+ geometry match {
+ case geometry: Geometry => getMinimumBoundingRadius(geometry)
+ case _ => null
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
}
}
@@ -910,17 +921,23 @@ case class ST_SubDivideExplode(children: Seq[Expression]) extends Generator with
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val geometryRaw = children.head
val maxVerticesRaw = children(1)
- geometryRaw.toGeometry(input) match {
- case geom: Geometry =>
- ArrayData.toArrayData(
- Functions.subDivide(geom,
maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
- Functions
- .subDivide(geom, maxVerticesRaw.toInt(input))
- .map(_.toGenericArrayData)
- .map(InternalRow(_))
- case _ => new Array[InternalRow](0)
+ try {
+ geometryRaw.toGeometry(input) match {
+ case geom: Geometry =>
+ ArrayData.toArrayData(
+ Functions.subDivide(geom,
maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
+ Functions
+ .subDivide(geom, maxVerticesRaw.toInt(input))
+ .map(_.toGenericArrayData)
+ .map(InternalRow(_))
+ case _ => new Array[InternalRow](0)
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input, children,
e)
}
}
+
override def elementSchema: StructType = {
new StructType()
.add("geom", GeometryUDT, true)
@@ -978,13 +995,18 @@ case class ST_MaximumInscribedCircle(children: Seq[Expression])
with CodegenFallback {
override def eval(input: InternalRow): Any = {
- val geometry = children.head.toGeometry(input)
- var inscribedCircle: InscribedCircle = null
- inscribedCircle = Functions.maximumInscribedCircle(geometry)
-
- val serCenter = GeometrySerializer.serialize(inscribedCircle.center)
- val serNearest = GeometrySerializer.serialize(inscribedCircle.nearest)
- InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
+ try {
+ val geometry = children.head.toGeometry(input)
+ var inscribedCircle: InscribedCircle = null
+ inscribedCircle = Functions.maximumInscribedCircle(geometry)
+
+ val serCenter = GeometrySerializer.serialize(inscribedCircle.center)
+ val serNearest = GeometrySerializer.serialize(inscribedCircle.nearest)
+ InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input, children,
e)
+ }
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression = {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index c068b546a..44d16b08a 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -19,11 +19,11 @@
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType,
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType,
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType,
StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
@@ -33,6 +33,12 @@ import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
+/**
+ * Custom exception whose message combines a caller-supplied context message
with the original exception's message; the original exception is kept as the cause.
+ */
+class InferredExpressionException(message: String, cause: Throwable)
+ extends Exception(s"$message, cause: " + cause.getMessage, cause)
+
/**
* This is the base class for wrapping Java/Scala functions as a catalyst
expression in Spark SQL.
* @param fSeq
@@ -74,8 +80,60 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
private lazy val argExtractors: Array[InternalRow => Any] =
f.buildExtractors(inputExpressions)
private lazy val evaluator: InternalRow => Any =
f.evaluatorBuilder(argExtractors)
- override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
- override def evalWithoutSerialization(input: InternalRow): Any =
evaluator(input)
+ private def findAllLiterals(expression: Expression): Seq[Literal] = {
+ expression match {
+ case lit: Literal => Seq(lit)
+ case _ => expression.children.flatMap(findAllLiterals)
+ }
+ }
+
+ private def findAllLiteralsInExpressions(expressions: Seq[Expression]):
Seq[String] = {
+ expressions.flatMap(findAllLiterals).map(_.value.toString)
+ }
+
+ override def eval(input: InternalRow): Any = {
+
+ try {
+ f.serializer(evaluator(input))
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ }
+ }
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
+ try {
+ evaluator(input)
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ }
+ }
+}
+
+object InferredExpression {
+ def throwExpressionInferenceException(
+ input: InternalRow,
+ inputExpressions: Seq[Expression],
+ e: Exception): Nothing = {
+ val literalsAsStrings = if (input == null) {
+ // When no input row is provided, report the literal values found in the
input expressions; otherwise the reported source list is left empty.
+ inputExpressions.flatMap(findAllLiterals).map(_.value.toString)
+ } else {
+ Seq.empty[String]
+ }
+ val literalsOrInputString = literalsAsStrings.mkString(", ")
+ throw new InferredExpressionException(
+ s"Exception occurred while evaluating expression - source:
[$literalsOrInputString]",
+ e)
+ }
+
+ def findAllLiterals(expression: Expression): Seq[Literal] = {
+ expression match {
+ case lit: Literal => Seq(lit)
+ case _ => expression.children.flatMap(findAllLiterals)
+ }
+ }
}
// This is a compile time type shield for the types we are able to infer.
Anything
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index d2e7c2ccb..abdc066b0 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -55,9 +55,14 @@ abstract class ST_Predicate
if (rightArray == null) {
null
} else {
- val leftGeometry = GeometrySerializer.deserialize(leftArray)
- val rightGeometry = GeometrySerializer.deserialize(rightArray)
- evalGeom(leftGeometry, rightGeometry)
+ try {
+ val leftGeometry = GeometrySerializer.deserialize(leftArray)
+ val rightGeometry = GeometrySerializer.deserialize(rightArray)
+ evalGeom(leftGeometry, rightGeometry)
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ }
}
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index bae125990..95846526b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.implicits._
-import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
+import org.apache.spark.sql.sedona_sql.expressions.{InferredExpression,
SerdeAware}
import org.apache.spark.sql.types.{ArrayType, _}
import org.locationtech.jts.geom.Geometry
@@ -44,24 +44,29 @@ case class ST_Collect(inputExpressions: Seq[Expression])
override def evalWithoutSerialization(input: InternalRow): Any = {
val firstElement = inputExpressions.head
- firstElement.dataType match {
- case ArrayType(elementType, _) =>
- elementType match {
- case _: GeometryUDT =>
- val data = firstElement.eval(input).asInstanceOf[ArrayData]
- val numElements = data.numElements()
- val geomElements = (0 until numElements)
- .map(element => data.getBinary(element))
- .filter(_ != null)
- .map(_.toGeometry)
+ try {
+ firstElement.dataType match {
+ case ArrayType(elementType, _) =>
+ elementType match {
+ case _: GeometryUDT =>
+ val data = firstElement.eval(input).asInstanceOf[ArrayData]
+ val numElements = data.numElements()
+ val geomElements = (0 until numElements)
+ .map(element => data.getBinary(element))
+ .filter(_ != null)
+ .map(_.toGeometry)
- Functions.createMultiGeometry(geomElements.toArray)
- case _ => Functions.createMultiGeometry(Array())
- }
- case _ =>
- val geomElements =
- inputExpressions.map(_.toGeometry(input)).filter(_ != null)
- Functions.createMultiGeometry(geomElements.toArray)
+ Functions.createMultiGeometry(geomElements.toArray)
+ case _ => Functions.createMultiGeometry(Array())
+ }
+ case _ =>
+ val geomElements =
+ inputExpressions.map(_.toGeometry(input)).filter(_ != null)
+ Functions.createMultiGeometry(geomElements.toArray)
+ }
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
}
}
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 8bf503927..2c9564a3c 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -216,7 +216,8 @@ class constructorTestScala extends TestBaseScala {
val thrown = intercept[Exception] {
sparkSession.sql("SELECT ST_GeomFromWKT('not wkt')").collect()
}
- assert(thrown.getMessage == "Unknown geometry type: NOT (line 1)")
+ assert(
+ thrown.getMessage == "Exception occurred while evaluating expression -
source: [not wkt, 0], cause: Unknown geometry type: NOT (line 1)")
}
it("Passed ST_GeomFromEWKT") {
@@ -245,7 +246,8 @@ class constructorTestScala extends TestBaseScala {
val thrown = intercept[Exception] {
sparkSession.sql("SELECT ST_GeomFromEWKT('not wkt')").collect()
}
- assert(thrown.getMessage == "Unknown geometry type: NOT (line 1)")
+ assert(
+ thrown.getMessage == "Exception occurred while evaluating expression -
source: [not wkt], cause: Unknown geometry type: NOT (line 1)")
}
it("Passed ST_LineFromText") {
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index bb83932c6..79859e427 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
import org.apache.commons.codec.binary.Hex
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{array, col, element_at, lit}
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpressionException
import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
import org.apache.spark.sql.sedona_sql.expressions.st_functions._
@@ -2215,5 +2216,20 @@ class dataFrameAPITestScala extends TestBaseScala {
"SRID=4326;POLYGON ((-0.4546817643920842 0.5948176504236309,
1.4752502925921425 0.0700679430157733, 2 2, 0.0700679430157733
2.5247497074078575, 0.7726591178039579 1.2974088252118154, -0.4546817643920842
0.5948176504236309))"
assert(expected.equals(actual))
}
+
+ it("Passed returning exception with geometry when exception is thrown") {
+ val baseDf = sparkSession.sql(
+ "SELECT ST_GeomFromEWKT('SRID=4326;POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1,
0 0))') AS geom1, ST_GeomFromEWKT('SRID=4326;POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1,
0 0))') AS geom2")
+
+ // Use intercept to assert that an exception is thrown
+ val exception = intercept[InferredExpressionException] {
+ baseDf.select(ST_Rotate("geom1", 50, "geom2")).take(1)
+ }
+
+ // Check the exception message
+ assert(exception.getMessage.contains(
+ "[SRID=4326;POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 0 0)), 50.0,
SRID=4326;POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1, 0 0))]"))
+ assert(exception.getMessage.contains("The origin must be a non-empty
Point geometry."))
+ }
}
}