This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 9cdb052c7 [SEDONA-607] Return geometry with ST functions with 
exceptions (#1525)
9cdb052c7 is described below

commit 9cdb052c7455dba2adf6f6e2e956eddb5f7ac39e
Author: Feng Zhang <[email protected]>
AuthorDate: Mon Jul 22 19:55:00 2024 -0700

    [SEDONA-607] Return geometry with ST functions with exceptions (#1525)
    
    * [SEDONA-607] New function ST_SafeGeom to return geometry with exceptions
    
    * [SEDONA-607] Include Geometry in ST Function Exceptions
    
    * throw exception on serialization of geometry
    
    * refactor the code
    
    * fix unit tests
    
    * fix additional tests
    
    * refactor to add exception handling in InferredExpression
    
    * add tests and comments
    
    * Refactor InferredExpression to include the input row in the exception 
message
    
    * reformat
---
 .../sql/sedona_sql/expressions/Constructors.scala  | 167 ++++++++++++---------
 .../sql/sedona_sql/expressions/Functions.scala     |  80 ++++++----
 .../expressions/InferredExpression.scala           |  66 +++++++-
 .../sql/sedona_sql/expressions/Predicates.scala    |  11 +-
 .../expressions/collect/ST_Collect.scala           |  41 ++---
 .../apache/sedona/sql/constructorTestScala.scala   |   6 +-
 .../apache/sedona/sql/dataFrameAPITestScala.scala  |  16 ++
 7 files changed, 263 insertions(+), 124 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index b1a82ba10..6c3212e6c 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -160,16 +160,21 @@ case class ST_GeomFromWKB(inputExpressions: 
Seq[Expression])
   override def nullable: Boolean = true
 
   override def eval(inputRow: InternalRow): Any = {
-    (inputExpressions.head.eval(inputRow)) match {
-      case (geomString: UTF8String) => {
-        // Parse UTF-8 encoded wkb string
-        Constructors.geomFromText(geomString.toString, 
FileDataSplitter.WKB).toGenericArrayData
-      }
-      case (wkb: Array[Byte]) => {
-        // convert raw wkb byte array to geometry
-        Constructors.geomFromWKB(wkb).toGenericArrayData
+    try {
+      (inputExpressions.head.eval(inputRow)) match {
+        case (geomString: UTF8String) => {
+          // Parse UTF-8 encoded wkb string
+          Constructors.geomFromText(geomString.toString, 
FileDataSplitter.WKB).toGenericArrayData
+        }
+        case (wkb: Array[Byte]) => {
+          // convert raw wkb byte array to geometry
+          Constructors.geomFromWKB(wkb).toGenericArrayData
+        }
+        case null => null
       }
-      case null => null
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
   }
 
@@ -196,16 +201,21 @@ case class ST_GeomFromEWKB(inputExpressions: 
Seq[Expression])
   override def nullable: Boolean = true
 
   override def eval(inputRow: InternalRow): Any = {
-    (inputExpressions.head.eval(inputRow)) match {
-      case (geomString: UTF8String) => {
-        // Parse UTF-8 encoded wkb string
-        Constructors.geomFromText(geomString.toString, 
FileDataSplitter.WKB).toGenericArrayData
-      }
-      case (wkb: Array[Byte]) => {
-        // convert raw wkb byte array to geometry
-        Constructors.geomFromWKB(wkb).toGenericArrayData
+    try {
+      (inputExpressions.head.eval(inputRow)) match {
+        case (geomString: UTF8String) => {
+          // Parse UTF-8 encoded wkb string
+          Constructors.geomFromText(geomString.toString, 
FileDataSplitter.WKB).toGenericArrayData
+        }
+        case (wkb: Array[Byte]) => {
+          // convert raw wkb byte array to geometry
+          Constructors.geomFromWKB(wkb).toGenericArrayData
+        }
+        case null => null
       }
-      case null => null
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
   }
 
@@ -238,21 +248,26 @@ case class ST_LineFromWKB(inputExpressions: 
Seq[Expression])
       if (inputExpressions.length > 1) 
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
       else -1
 
-    wkb match {
-      case geomString: UTF8String =>
-        // Parse UTF-8 encoded WKB string
-        val geom = Constructors.lineStringFromText(geomString.toString, "wkb")
-        if (geom.getGeometryType == "LineString") {
-          (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
-        } else {
-          null
-        }
-
-      case wkbArray: Array[Byte] =>
-        // Convert raw WKB byte array to geometry
-        Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
-
-      case _ => null
+    try {
+      wkb match {
+        case geomString: UTF8String =>
+          // Parse UTF-8 encoded WKB string
+          val geom = Constructors.lineStringFromText(geomString.toString, 
"wkb")
+          if (geom.getGeometryType == "LineString") {
+            (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
+          } else {
+            null
+          }
+
+        case wkbArray: Array[Byte] =>
+          // Convert raw WKB byte array to geometry
+          Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
+
+        case _ => null
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
   }
 
@@ -287,21 +302,26 @@ case class ST_LinestringFromWKB(inputExpressions: 
Seq[Expression])
       if (inputExpressions.length > 1) 
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
       else -1
 
-    wkb match {
-      case geomString: UTF8String =>
-        // Parse UTF-8 encoded WKB string
-        val geom = Constructors.lineStringFromText(geomString.toString, "wkb")
-        if (geom.getGeometryType == "LineString") {
-          (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
-        } else {
-          null
-        }
-
-      case wkbArray: Array[Byte] =>
-        // Convert raw WKB byte array to geometry
-        Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
-
-      case _ => null
+    try {
+      wkb match {
+        case geomString: UTF8String =>
+          // Parse UTF-8 encoded WKB string
+          val geom = Constructors.lineStringFromText(geomString.toString, 
"wkb")
+          if (geom.getGeometryType == "LineString") {
+            (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
+          } else {
+            null
+          }
+
+        case wkbArray: Array[Byte] =>
+          // Convert raw WKB byte array to geometry
+          Constructors.lineFromWKB(wkbArray, srid).toGenericArrayData
+
+        case _ => null
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
   }
 
@@ -336,21 +356,26 @@ case class ST_PointFromWKB(inputExpressions: 
Seq[Expression])
       if (inputExpressions.length > 1) 
inputExpressions(1).eval(inputRow).asInstanceOf[Int]
       else -1
 
-    wkb match {
-      case geomString: UTF8String =>
-        // Parse UTF-8 encoded WKB string
-        val geom = Constructors.pointFromText(geomString.toString, "wkb")
-        if (geom.getGeometryType == "Point") {
-          (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
-        } else {
-          null
-        }
-
-      case wkbArray: Array[Byte] =>
-        // Convert raw WKB byte array to geometry
-        Constructors.pointFromWKB(wkbArray, srid).toGenericArrayData
-
-      case _ => null
+    try {
+      wkb match {
+        case geomString: UTF8String =>
+          // Parse UTF-8 encoded WKB string
+          val geom = Constructors.pointFromText(geomString.toString, "wkb")
+          if (geom.getGeometryType == "Point") {
+            (if (srid != -1) Functions.setSRID(geom, srid) else 
geom).toGenericArrayData
+          } else {
+            null
+          }
+
+        case wkbArray: Array[Byte] =>
+          // Convert raw WKB byte array to geometry
+          Constructors.pointFromWKB(wkbArray, srid).toGenericArrayData
+
+        case _ => null
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
   }
 
@@ -387,12 +412,18 @@ case class ST_GeomFromGeoJSON(inputExpressions: 
Seq[Expression])
 
   override def eval(inputRow: InternalRow): Any = {
     val geomString = 
inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
-    val geometry = Constructors.geomFromText(geomString, 
FileDataSplitter.GEOJSON)
-    // If the user specify a bunch of attributes to go with each geometry, we 
need to store all of them in this geometry
-    if (inputExpressions.length > 1) {
-      geometry.setUserData(generateUserData(minInputLength, inputExpressions, 
inputRow))
+    try {
+
+      val geometry = Constructors.geomFromText(geomString, 
FileDataSplitter.GEOJSON)
+      // If the user specifies a set of attributes to go with each geometry, 
we need to store all of them in this geometry
+      if (inputExpressions.length > 1) {
+        geometry.setUserData(generateUserData(minInputLength, 
inputExpressions, inputRow))
+      }
+      GeometrySerializer.serialize(geometry)
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
     }
-    GeometrySerializer.serialize(geometry)
   }
 
   override def dataType: DataType = GeometryUDT
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 70e3582f2..cf52d5bc7 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -332,13 +332,18 @@ case class ST_IsValidDetail(children: Seq[Expression])
       throw new IllegalArgumentException(s"Invalid number of arguments: 
$nArgs")
     }
 
-    if (validDetail.location == null) {
-      return InternalRow.fromSeq(Seq(validDetail.valid, null, null))
-    }
+    try {
+      if (validDetail.location == null) {
+        return InternalRow.fromSeq(Seq(validDetail.valid, null, null))
+      }
 
-    val serLocation = GeometrySerializer.serialize(validDetail.location)
-    InternalRow.fromSeq(
-      Seq(validDetail.valid, UTF8String.fromString(validDetail.reason), 
serLocation))
+      val serLocation = GeometrySerializer.serialize(validDetail.location)
+      InternalRow.fromSeq(
+        Seq(validDetail.valid, UTF8String.fromString(validDetail.reason), 
serLocation))
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, children, 
e)
+    }
   }
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
@@ -609,14 +614,20 @@ case class ST_MinimumBoundingRadius(inputExpressions: 
Seq[Expression])
 
   override def eval(input: InternalRow): Any = {
     val expr = inputExpressions(0)
-    val geometry = expr match {
-      case s: SerdeAware => s.evalWithoutSerialization(input)
-      case _ => expr.toGeometry(input)
-    }
 
-    geometry match {
-      case geometry: Geometry => getMinimumBoundingRadius(geometry)
-      case _ => null
+    try {
+      val geometry = expr match {
+        case s: SerdeAware => s.evalWithoutSerialization(input)
+        case _ => expr.toGeometry(input)
+      }
+
+      geometry match {
+        case geometry: Geometry => getMinimumBoundingRadius(geometry)
+        case _ => null
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, 
inputExpressions, e)
     }
   }
 
@@ -910,17 +921,23 @@ case class ST_SubDivideExplode(children: Seq[Expression]) 
extends Generator with
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     val geometryRaw = children.head
     val maxVerticesRaw = children(1)
-    geometryRaw.toGeometry(input) match {
-      case geom: Geometry =>
-        ArrayData.toArrayData(
-          Functions.subDivide(geom, 
maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
-        Functions
-          .subDivide(geom, maxVerticesRaw.toInt(input))
-          .map(_.toGenericArrayData)
-          .map(InternalRow(_))
-      case _ => new Array[InternalRow](0)
+    try {
+      geometryRaw.toGeometry(input) match {
+        case geom: Geometry =>
+          ArrayData.toArrayData(
+            Functions.subDivide(geom, 
maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
+          Functions
+            .subDivide(geom, maxVerticesRaw.toInt(input))
+            .map(_.toGenericArrayData)
+            .map(InternalRow(_))
+        case _ => new Array[InternalRow](0)
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, children, 
e)
     }
   }
+
   override def elementSchema: StructType = {
     new StructType()
       .add("geom", GeometryUDT, true)
@@ -978,13 +995,18 @@ case class ST_MaximumInscribedCircle(children: 
Seq[Expression])
     with CodegenFallback {
 
   override def eval(input: InternalRow): Any = {
-    val geometry = children.head.toGeometry(input)
-    var inscribedCircle: InscribedCircle = null
-    inscribedCircle = Functions.maximumInscribedCircle(geometry)
-
-    val serCenter = GeometrySerializer.serialize(inscribedCircle.center)
-    val serNearest = GeometrySerializer.serialize(inscribedCircle.nearest)
-    InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
+    try {
+      val geometry = children.head.toGeometry(input)
+      var inscribedCircle: InscribedCircle = null
+      inscribedCircle = Functions.maximumInscribedCircle(geometry)
+
+      val serCenter = GeometrySerializer.serialize(inscribedCircle.center)
+      val serNearest = GeometrySerializer.serialize(inscribedCircle.nearest)
+      InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, children, 
e)
+    }
   }
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index c068b546a..44d16b08a 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -19,11 +19,11 @@
 package org.apache.spark.sql.sedona_sql.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, 
ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.{Expression, 
ImplicitCastInputTypes, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, 
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, 
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType, 
StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.locationtech.jts.geom.Geometry
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
@@ -33,6 +33,12 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.reflect.runtime.universe.Type
 import scala.reflect.runtime.universe.typeOf
 
+/**
+ * Custom exception to include the input row and the original exception 
message.
+ */
+class InferredExpressionException(message: String, cause: Throwable)
+    extends Exception(s"$message, cause: " + cause.getMessage, cause)
+
 /**
  * This is the base class for wrapping Java/Scala functions as a catalyst 
expression in Spark SQL.
  * @param fSeq
@@ -74,8 +80,60 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
   private lazy val argExtractors: Array[InternalRow => Any] = 
f.buildExtractors(inputExpressions)
   private lazy val evaluator: InternalRow => Any = 
f.evaluatorBuilder(argExtractors)
 
-  override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
-  override def evalWithoutSerialization(input: InternalRow): Any = 
evaluator(input)
+  private def findAllLiterals(expression: Expression): Seq[Literal] = {
+    expression match {
+      case lit: Literal => Seq(lit)
+      case _ => expression.children.flatMap(findAllLiterals)
+    }
+  }
+
+  private def findAllLiteralsInExpressions(expressions: Seq[Expression]): 
Seq[String] = {
+    expressions.flatMap(findAllLiterals).map(_.value.toString)
+  }
+
+  override def eval(input: InternalRow): Any = {
+
+    try {
+      f.serializer(evaluator(input))
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, 
inputExpressions, e)
+    }
+  }
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
+    try {
+      evaluator(input)
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, 
inputExpressions, e)
+    }
+  }
+}
+
+object InferredExpression {
+  def throwExpressionInferenceException(
+      input: InternalRow,
+      inputExpressions: Seq[Expression],
+      e: Exception): Nothing = {
+    val literalsAsStrings = if (input == null) {
+      // When no input row is provided, fall back to extracting literal 
values from the input expressions for the error message.
+      inputExpressions.flatMap(findAllLiterals).map(_.value.toString)
+    } else {
+      Seq.empty[String]
+    }
+    val literalsOrInputString = literalsAsStrings.mkString(", ")
+    throw new InferredExpressionException(
+      s"Exception occurred while evaluating expression - source: 
[$literalsOrInputString]",
+      e)
+  }
+
+  def findAllLiterals(expression: Expression): Seq[Literal] = {
+    expression match {
+      case lit: Literal => Seq(lit)
+      case _ => expression.children.flatMap(findAllLiterals)
+    }
+  }
 }
 
 // This is a compile time type shield for the types we are able to infer. 
Anything
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index d2e7c2ccb..abdc066b0 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -55,9 +55,14 @@ abstract class ST_Predicate
       if (rightArray == null) {
         null
       } else {
-        val leftGeometry = GeometrySerializer.deserialize(leftArray)
-        val rightGeometry = GeometrySerializer.deserialize(rightArray)
-        evalGeom(leftGeometry, rightGeometry)
+        try {
+          val leftGeometry = GeometrySerializer.deserialize(leftArray)
+          val rightGeometry = GeometrySerializer.deserialize(rightArray)
+          evalGeom(leftGeometry, rightGeometry)
+        } catch {
+          case e: Exception =>
+            InferredExpression.throwExpressionInferenceException(inputRow, 
inputExpressions, e)
+        }
       }
     }
   }
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index bae125990..95846526b 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
-import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
+import org.apache.spark.sql.sedona_sql.expressions.{InferredExpression, 
SerdeAware}
 import org.apache.spark.sql.types.{ArrayType, _}
 import org.locationtech.jts.geom.Geometry
 
@@ -44,24 +44,29 @@ case class ST_Collect(inputExpressions: Seq[Expression])
   override def evalWithoutSerialization(input: InternalRow): Any = {
     val firstElement = inputExpressions.head
 
-    firstElement.dataType match {
-      case ArrayType(elementType, _) =>
-        elementType match {
-          case _: GeometryUDT =>
-            val data = firstElement.eval(input).asInstanceOf[ArrayData]
-            val numElements = data.numElements()
-            val geomElements = (0 until numElements)
-              .map(element => data.getBinary(element))
-              .filter(_ != null)
-              .map(_.toGeometry)
+    try {
+      firstElement.dataType match {
+        case ArrayType(elementType, _) =>
+          elementType match {
+            case _: GeometryUDT =>
+              val data = firstElement.eval(input).asInstanceOf[ArrayData]
+              val numElements = data.numElements()
+              val geomElements = (0 until numElements)
+                .map(element => data.getBinary(element))
+                .filter(_ != null)
+                .map(_.toGeometry)
 
-            Functions.createMultiGeometry(geomElements.toArray)
-          case _ => Functions.createMultiGeometry(Array())
-        }
-      case _ =>
-        val geomElements =
-          inputExpressions.map(_.toGeometry(input)).filter(_ != null)
-        Functions.createMultiGeometry(geomElements.toArray)
+              Functions.createMultiGeometry(geomElements.toArray)
+            case _ => Functions.createMultiGeometry(Array())
+          }
+        case _ =>
+          val geomElements =
+            inputExpressions.map(_.toGeometry(input)).filter(_ != null)
+          Functions.createMultiGeometry(geomElements.toArray)
+      }
+    } catch {
+      case e: Exception =>
+        InferredExpression.throwExpressionInferenceException(input, 
inputExpressions, e)
     }
   }
 
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 8bf503927..2c9564a3c 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -216,7 +216,8 @@ class constructorTestScala extends TestBaseScala {
       val thrown = intercept[Exception] {
         sparkSession.sql("SELECT ST_GeomFromWKT('not wkt')").collect()
       }
-      assert(thrown.getMessage == "Unknown geometry type: NOT (line 1)")
+      assert(
+        thrown.getMessage == "Exception occurred while evaluating expression - 
source: [not wkt, 0], cause: Unknown geometry type: NOT (line 1)")
     }
 
     it("Passed ST_GeomFromEWKT") {
@@ -245,7 +246,8 @@ class constructorTestScala extends TestBaseScala {
       val thrown = intercept[Exception] {
         sparkSession.sql("SELECT ST_GeomFromEWKT('not wkt')").collect()
       }
-      assert(thrown.getMessage == "Unknown geometry type: NOT (line 1)")
+      assert(
+        thrown.getMessage == "Exception occurred while evaluating expression - 
source: [not wkt], cause: Unknown geometry type: NOT (line 1)")
     }
 
     it("Passed ST_LineFromText") {
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index bb83932c6..79859e427 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
 import org.apache.commons.codec.binary.Hex
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.{array, col, element_at, lit}
+import org.apache.spark.sql.sedona_sql.expressions.InferredExpressionException
 import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
 import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
 import org.apache.spark.sql.sedona_sql.expressions.st_functions._
@@ -2215,5 +2216,20 @@ class dataFrameAPITestScala extends TestBaseScala {
         "SRID=4326;POLYGON ((-0.4546817643920842 0.5948176504236309, 
1.4752502925921425 0.0700679430157733, 2 2, 0.0700679430157733 
2.5247497074078575, 0.7726591178039579 1.2974088252118154, -0.4546817643920842 
0.5948176504236309))"
       assert(expected.equals(actual))
     }
+
+    it("Passed returning exception with geometry when exception is thrown") {
+      val baseDf = sparkSession.sql(
+        "SELECT ST_GeomFromEWKT('SRID=4326;POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 
0 0))') AS geom1, ST_GeomFromEWKT('SRID=4326;POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1, 
0 0))') AS geom2")
+
+      // Use intercept to assert that an exception is thrown
+      val exception = intercept[InferredExpressionException] {
+        baseDf.select(ST_Rotate("geom1", 50, "geom2")).take(1)
+      }
+
+      // Check the exception message
+      assert(exception.getMessage.contains(
+        "[SRID=4326;POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 0 0)), 50.0, 
SRID=4326;POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1, 0 0))]"))
+      assert(exception.getMessage.contains("The origin must be a non-empty 
Point geometry."))
+    }
   }
 }

Reply via email to