This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 1dd23212 [SEDONA-291] Python test cases failing due to exact equality comparisons between floats (#850)
1dd23212 is described below
commit 1dd232129455551bf5ed6c49604180e001c61f37
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Tue Jun 6 20:54:40 2023 -0700
[SEDONA-291] Python test cases failing due to exact equality comparisons between floats (#850)
---
python/sedona/core/geom/envelope.py | 11 +++++-
.../streaming/spark/test_constructor_functions.py | 41 +++++++++++++---------
2 files changed, 34 insertions(+), 18 deletions(-)
diff --git a/python/sedona/core/geom/envelope.py b/python/sedona/core/geom/envelope.py
index c0a90eb3..d449f7c6 100644
--- a/python/sedona/core/geom/envelope.py
+++ b/python/sedona/core/geom/envelope.py
@@ -19,7 +19,7 @@ from shapely.geometry import Polygon, Point
from shapely.geometry.base import BaseGeometry
from sedona.utils.decorators import require
-
+import math
class Envelope(Polygon):
@@ -35,6 +35,15 @@ class Envelope(Polygon):
[self.maxx, self.miny]
])
+ def isClose(self, a, b) -> bool:
+ return math.isclose(a, b, rel_tol=1e-9)
+
+ def __eq__(self, other) -> bool:
+ return self.isClose(self.minx, other.minx) and\
+ self.isClose(self.miny, other.miny) and\
+ self.isClose(self.maxx, other.maxx) and\
+ self.isClose(self.maxy, other.maxy)
+
@require(["Envelope"])
def create_jvm_instance(self, jvm):
return jvm.Envelope(
diff --git a/python/tests/streaming/spark/test_constructor_functions.py b/python/tests/streaming/spark/test_constructor_functions.py
index 6fc55b99..3ab4258d 100644
--- a/python/tests/streaming/spark/test_constructor_functions.py
+++ b/python/tests/streaming/spark/test_constructor_functions.py
@@ -27,6 +27,7 @@ from sedona.sql.types import GeometryType
from tests import tests_resource
from tests.streaming.spark.cases_builder import SuiteContainer
from tests.test_base import TestBase
+import math
SCHEMA = StructType(
[
@@ -80,8 +81,10 @@ SEDONA_LISTED_SQL_FUNCTIONS = [
(SuiteContainer.empty()
.with_function_name("ST_Transform")
.with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')", "'epsg:4326'", "'epsg:2180'"])
- .with_expected_result("POINT (-2501415.806893427 4119952.52325666)")
- .with_transform("ST_ASText")),
+ .with_expected_result(-2501415.806893427)
+ #.with_expected_result("POINT (-2501415.806893427 4119952.52325666)")
+ .with_transform("ST_X")),
+ #.with_transform("ST_ASText")),
(SuiteContainer.empty()
.with_function_name("ST_Intersection")
.with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')", "ST_GeomFromText('POINT(21.5 52.5)')"])
@@ -319,23 +322,27 @@ class TestConstructorFunctions(TestBase):
@pytest.mark.sparkstreaming
def test_geospatial_function_on_stream(self, function_name: str, arguments: List[str], expected_result: Any, transform: Optional[str]):
- # given input stream
+ # given input stream
- input_stream = self.spark.readStream.schema(SCHEMA).parquet(os.path.join(
- tests_resource,
- "streaming/geometry_example")
- ).selectExpr(f"{function_name}({', '.join(arguments)}) AS result")
+ input_stream = self.spark.readStream.schema(SCHEMA).parquet(os.path.join(
+ tests_resource,
+ "streaming/geometry_example")
+ ).selectExpr(f"{function_name}({', '.join(arguments)}) AS result")
- # and target table
- random_table_name = f"view_{uuid.uuid4().hex}"
+ # and target table
+ random_table_name = f"view_{uuid.uuid4().hex}"
- # when saving stream to memory
- streaming_query = input_stream.writeStream.format("memory") \
- .queryName(random_table_name) \
- .outputMode("append").start()
+ # when saving stream to memory
+ streaming_query = input_stream.writeStream.format("memory") \
+ .queryName(random_table_name) \
+ .outputMode("append").start()
- streaming_query.processAllAvailable()
+ streaming_query.processAllAvailable()
- # then result should be as expected
- transform_query = "result" if not transform else f"{transform}(result)"
- assert self.spark.sql(f"select {transform_query} from {random_table_name}").collect()[0][0] == expected_result
+ # then result should be as expected
+ transform_query = "result" if not transform else f"{transform}(result)"
+ queryResult = self.spark.sql(f"select {transform_query} from {random_table_name}").collect()[0][0]
+ if (type(queryResult) is float and type(expected_result) is float):
+ assert math.isclose(queryResult, expected_result, rel_tol=1e-9)
+ else:
+ assert queryResult == expected_result