This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 1737174f0 [SEDONA-607] Fix error message enhancements for geometry functions (#1555)
1737174f0 is described below
commit 1737174f097e02270c828e5aff26b2e1064b11c7
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Aug 21 15:14:30 2024 +0800
[SEDONA-607] Fix error message enhancements for geometry functions (#1555)
* Fix error message enhancements for geometry functions
* Fix compilation for scala 2.13
---
.../sql/sedona_sql/expressions/Constructors.scala | 43 ++++++-----
.../sql/sedona_sql/expressions/Functions.scala | 39 ++++++----
.../expressions/InferredExpression.scala | 89 ++++++++++++----------
.../sql/sedona_sql/expressions/Predicates.scala | 9 ++-
.../expressions/collect/ST_Collect.scala | 52 ++++++++-----
.../apache/sedona/sql/constructorTestScala.scala | 10 ++-
.../apache/sedona/sql/dataFrameAPITestScala.scala | 18 ++++-
7 files changed, 155 insertions(+), 105 deletions(-)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index 6c3212e6c..34c8d0351 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -160,21 +160,20 @@ case class ST_GeomFromWKB(inputExpressions:
Seq[Expression])
override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
+ val arg = inputExpressions.head.eval(inputRow)
try {
- (inputExpressions.head.eval(inputRow)) match {
- case (geomString: UTF8String) => {
+ arg match {
+ case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
- }
- case (wkb: Array[Byte]) => {
+ case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
- }
case null => null
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName,
Seq(arg), e)
}
}
@@ -201,21 +200,20 @@ case class ST_GeomFromEWKB(inputExpressions:
Seq[Expression])
override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
+ val arg = inputExpressions.head.eval(inputRow)
try {
- (inputExpressions.head.eval(inputRow)) match {
- case (geomString: UTF8String) => {
+ arg match {
+ case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString,
FileDataSplitter.WKB).toGenericArrayData
- }
- case (wkb: Array[Byte]) => {
+ case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
- }
case null => null
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName,
Seq(arg), e)
}
}
@@ -267,7 +265,10 @@ case class ST_LineFromWKB(inputExpressions:
Seq[Expression])
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(wkb, srid),
+ e)
}
}
@@ -321,7 +322,10 @@ case class ST_LinestringFromWKB(inputExpressions:
Seq[Expression])
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(wkb, srid),
+ e)
}
}
@@ -375,7 +379,10 @@ case class ST_PointFromWKB(inputExpressions:
Seq[Expression])
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(wkb, srid),
+ e)
}
}
@@ -413,7 +420,6 @@ case class ST_GeomFromGeoJSON(inputExpressions:
Seq[Expression])
override def eval(inputRow: InternalRow): Any = {
val geomString =
inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
try {
-
val geometry = Constructors.geomFromText(geomString,
FileDataSplitter.GEOJSON)
// If the user specify a bunch of attributes to go with each geometry,
we need to store all of them in this geometry
if (inputExpressions.length > 1) {
@@ -422,7 +428,10 @@ case class ST_GeomFromGeoJSON(inputExpressions:
Seq[Expression])
GeometrySerializer.serialize(geometry)
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geomString),
+ e)
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 41dd99589..0a4f5ecca 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -353,7 +353,10 @@ case class ST_IsValidDetail(children: Seq[Expression])
Seq(validDetail.valid, UTF8String.fromString(validDetail.reason),
serLocation))
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input, children,
e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geometry),
+ e)
}
}
@@ -627,20 +630,19 @@ case class ST_MinimumBoundingRadius(inputExpressions:
Seq[Expression])
override def eval(input: InternalRow): Any = {
val expr = inputExpressions(0)
+ val geometry = expr.toGeometry(input)
try {
- val geometry = expr match {
- case s: SerdeAware => s.evalWithoutSerialization(input)
- case _ => expr.toGeometry(input)
- }
-
geometry match {
case geometry: Geometry => getMinimumBoundingRadius(geometry)
case _ => null
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geometry),
+ e)
}
}
@@ -932,22 +934,24 @@ case class ST_SubDivideExplode(children: Seq[Expression])
extends Generator with
children.validateLength(2)
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
- val geometryRaw = children.head
- val maxVerticesRaw = children(1)
+ val geometry = children.head.toGeometry(input)
+ val maxVertices = children(1).toInt(input)
try {
- geometryRaw.toGeometry(input) match {
+ geometry match {
case geom: Geometry =>
- ArrayData.toArrayData(
- Functions.subDivide(geom,
maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
+ ArrayData.toArrayData(Functions.subDivide(geom,
maxVertices).map(_.toGenericArrayData))
Functions
- .subDivide(geom, maxVerticesRaw.toInt(input))
+ .subDivide(geom, maxVertices)
.map(_.toGenericArrayData)
.map(InternalRow(_))
case _ => new Array[InternalRow](0)
}
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input, children,
e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geometry, maxVertices),
+ e)
}
}
@@ -1008,8 +1012,8 @@ case class ST_MaximumInscribedCircle(children:
Seq[Expression])
with CodegenFallback {
override def eval(input: InternalRow): Any = {
+ val geometry = children.head.toGeometry(input)
try {
- val geometry = children.head.toGeometry(input)
var inscribedCircle: InscribedCircle = null
inscribedCircle = Functions.maximumInscribedCircle(geometry)
@@ -1018,7 +1022,10 @@ case class ST_MaximumInscribedCircle(children:
Seq[Expression])
InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input, children,
e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geometry),
+ e)
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index 44d16b08a..935c2d5e3 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -18,17 +18,19 @@
*/
package org.apache.spark.sql.sedona_sql.expressions
+import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType,
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType,
StructField, StructType}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType,
DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import scala.collection.convert.ImplicitConversions.`collection
AsScalaIterable`
+import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
@@ -77,27 +79,40 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
override def dataType: DataType = f.sparkReturnType
- private lazy val argExtractors: Array[InternalRow => Any] =
f.buildExtractors(inputExpressions)
+ private lazy val argExtractors: Array[InternalRow => Any] =
buildExtractors(inputExpressions)
private lazy val evaluator: InternalRow => Any =
f.evaluatorBuilder(argExtractors)
- private def findAllLiterals(expression: Expression): Seq[Literal] = {
- expression match {
- case lit: Literal => Seq(lit)
- case _ => expression.children.flatMap(findAllLiterals)
- }
- }
+ // Remember input args to generate error messages when exceptions occur. The
input arguments are
+ // helpful for troubleshooting the cause of errors.
+ private val inputArgs: ArrayBuffer[AnyRef] = ArrayBuffer.empty[AnyRef]
- private def findAllLiteralsInExpressions(expressions: Seq[Expression]):
Seq[String] = {
- expressions.flatMap(findAllLiterals).map(_.value.toString)
+ private def buildExtractors(expressions: Seq[Expression]): Array[InternalRow
=> Any] = {
+ f.argExtractorBuilders
+ .zipAll(expressions, null, null)
+ .flatMap {
+ case (null, _) => None
+ case (builder, expr) =>
+ val extractor = builder(expr)
+ Some((input: InternalRow) => {
+ val arg = extractor(input)
+ inputArgs += arg.asInstanceOf[AnyRef]
+ arg
+ })
+ }
+ .toArray
}
override def eval(input: InternalRow): Any = {
-
try {
f.serializer(evaluator(input))
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ inputArgs.toSeq,
+ e)
+ } finally {
+ inputArgs.clear()
}
}
@@ -106,32 +121,32 @@ abstract class InferredExpression(fSeq:
InferrableFunction*)
evaluator(input)
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ inputArgs.toSeq,
+ e)
+ } finally {
+ inputArgs.clear()
}
}
}
object InferredExpression {
def throwExpressionInferenceException(
- input: InternalRow,
- inputExpressions: Seq[Expression],
+ name: String,
+ inputArgs: Seq[Any],
e: Exception): Nothing = {
- val literalsAsStrings = if (input == null) {
- // In case no input row is provided, we can't extract literals from the
input expressions.
- inputExpressions.flatMap(findAllLiterals).map(_.value.toString)
+ if (e.isInstanceOf[InferredExpressionException]) {
+ throw e
} else {
- Seq.empty[String]
- }
- val literalsOrInputString = literalsAsStrings.mkString(", ")
- throw new InferredExpressionException(
- s"Exception occurred while evaluating expression - source:
[$literalsOrInputString]",
- e)
- }
-
- def findAllLiterals(expression: Expression): Seq[Literal] = {
- expression match {
- case lit: Literal => Seq(lit)
- case _ => expression.children.flatMap(findAllLiterals)
+ val inputsAsStrings = inputArgs.map { arg =>
+ val argStr = if (arg != null) arg.toString else "null"
+ StringUtils.abbreviate(argStr, 5000)
+ }
+ val inputsString = inputsAsStrings.mkString(", ")
+ throw new InferredExpressionException(
+ s"Exception occurred while evaluating expression $name - inputs:
[$inputsString]",
+ e)
}
}
}
@@ -301,17 +316,7 @@ case class InferrableFunction(
sparkReturnType: DataType,
serializer: Any => Any,
argExtractorBuilders: Seq[Expression => InternalRow => Any],
- evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
- def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any]
= {
- argExtractorBuilders
- .zipAll(expressions, null, null)
- .flatMap {
- case (null, _) => None
- case (builder, expr) => Some(builder(expr))
- }
- .toArray
- }
-}
+ evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any)
object InferrableFunction {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index abdc066b0..bd7202cd9 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -55,13 +55,16 @@ abstract class ST_Predicate
if (rightArray == null) {
null
} else {
+ val leftGeometry = GeometrySerializer.deserialize(leftArray)
+ val rightGeometry = GeometrySerializer.deserialize(rightArray)
try {
- val leftGeometry = GeometrySerializer.deserialize(leftArray)
- val rightGeometry = GeometrySerializer.deserialize(rightArray)
evalGeom(leftGeometry, rightGeometry)
} catch {
case e: Exception =>
- InferredExpression.throwExpressionInferenceException(inputRow,
inputExpressions, e)
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(leftGeometry, rightGeometry),
+ e)
}
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index 95846526b..0a52b8d10 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -44,29 +44,41 @@ case class ST_Collect(inputExpressions: Seq[Expression])
override def evalWithoutSerialization(input: InternalRow): Any = {
val firstElement = inputExpressions.head
- try {
- firstElement.dataType match {
- case ArrayType(elementType, _) =>
- elementType match {
- case _: GeometryUDT =>
- val data = firstElement.eval(input).asInstanceOf[ArrayData]
- val numElements = data.numElements()
- val geomElements = (0 until numElements)
- .map(element => data.getBinary(element))
- .filter(_ != null)
- .map(_.toGeometry)
+ firstElement.dataType match {
+ case ArrayType(elementType, _) =>
+ elementType match {
+ case _: GeometryUDT =>
+ val data = firstElement.eval(input).asInstanceOf[ArrayData]
+ val numElements = data.numElements()
+ val geomElements = (0 until numElements)
+ .map(element => data.getBinary(element))
+ .filter(_ != null)
+ .map(_.toGeometry)
+ try {
Functions.createMultiGeometry(geomElements.toArray)
- case _ => Functions.createMultiGeometry(Array())
- }
- case _ =>
- val geomElements =
- inputExpressions.map(_.toGeometry(input)).filter(_ != null)
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geomElements),
+ e)
+ }
+
+ case _ => Functions.createMultiGeometry(Array())
+ }
+ case _ =>
+ val geomElements =
+ inputExpressions.map(_.toGeometry(input)).filter(_ != null)
+ try {
Functions.createMultiGeometry(geomElements.toArray)
- }
- } catch {
- case e: Exception =>
- InferredExpression.throwExpressionInferenceException(input,
inputExpressions, e)
+ } catch {
+ case e: Exception =>
+ InferredExpression.throwExpressionInferenceException(
+ getClass.getSimpleName,
+ Seq(geomElements),
+ e)
+ }
}
}
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 2c9564a3c..d1eff9deb 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -216,8 +216,9 @@ class constructorTestScala extends TestBaseScala {
val thrown = intercept[Exception] {
sparkSession.sql("SELECT ST_GeomFromWKT('not wkt')").collect()
}
- assert(
- thrown.getMessage == "Exception occurred while evaluating expression -
source: [not wkt, 0], cause: Unknown geometry type: NOT (line 1)")
+ assert(thrown.getMessage.contains("ST_GeomFromWKT"))
+ assert(thrown.getMessage.contains("not wkt"))
+ assert(thrown.getMessage.contains("Unknown geometry type"))
}
it("Passed ST_GeomFromEWKT") {
@@ -246,8 +247,9 @@ class constructorTestScala extends TestBaseScala {
val thrown = intercept[Exception] {
sparkSession.sql("SELECT ST_GeomFromEWKT('not wkt')").collect()
}
- assert(
- thrown.getMessage == "Exception occurred while evaluating expression -
source: [not wkt], cause: Unknown geometry type: NOT (line 1)")
+ assert(thrown.getMessage.contains("ST_GeomFromEWKT"))
+ assert(thrown.getMessage.contains("not wkt"))
+ assert(thrown.getMessage.contains("Unknown geometry type"))
}
it("Passed ST_LineFromText") {
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index fa2e0e644..ee9db762d 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -20,7 +20,7 @@ package org.apache.sedona.sql
import org.apache.commons.codec.binary.Hex
import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.{array, col, element_at, lit}
+import org.apache.spark.sql.functions.{array, col, element_at, expr, lit}
import org.apache.spark.sql.sedona_sql.expressions.InferredExpressionException
import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
@@ -2275,9 +2275,21 @@ class dataFrameAPITestScala extends TestBaseScala {
}
// Check the exception message
- assert(exception.getMessage.contains(
- "[SRID=4326;POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1, 0 0)), 50.0,
SRID=4326;POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1, 0 0))]"))
+ assert(exception.getMessage.contains("POLYGON ((0 0, 2 0, 2 2, 0 2, 1 1,
0 0))"))
+ assert(exception.getMessage.contains("POLYGON ((0 0, 1 0, 1 1, 0 1, 1 1,
0 0))"))
+ assert(exception.getMessage.contains("ST_Rotate"))
assert(exception.getMessage.contains("The origin must be a non-empty
Point geometry."))
+
+ // Non-literal test case: ST_MakeLine will raise an exception
+ val exception2 = intercept[Exception] {
+ sparkSession
+ .range(0, 1)
+ .withColumn("geom", expr("ST_PolygonFromEnvelope(id, id, id + 1, id
+ 1)"))
+ .selectExpr("id", "ST_Envelope(ST_MakeLine(geom, geom))")
+ .collect()
+ }
+ assert(exception2.getMessage.contains("POLYGON ((0 0, 0 1, 1 1, 1 0, 0
0))"))
+ assert(exception2.getMessage.contains("ST_MakeLine"))
}
}
}