This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dd23212 [SEDONA-291] Python test cases failing due to exact equality comparisons between floats (#850)
1dd23212 is described below

commit 1dd232129455551bf5ed6c49604180e001c61f37
Author: Nilesh Gajwani <[email protected]>
AuthorDate: Tue Jun 6 20:54:40 2023 -0700

    [SEDONA-291] Python test cases failing due to exact equality comparisons between floats (#850)
---
 python/sedona/core/geom/envelope.py                | 11 +++++-
 .../streaming/spark/test_constructor_functions.py  | 41 +++++++++++++---------
 2 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/python/sedona/core/geom/envelope.py b/python/sedona/core/geom/envelope.py
index c0a90eb3..d449f7c6 100644
--- a/python/sedona/core/geom/envelope.py
+++ b/python/sedona/core/geom/envelope.py
@@ -19,7 +19,7 @@ from shapely.geometry import Polygon, Point
 from shapely.geometry.base import BaseGeometry
 
 from sedona.utils.decorators import require
-
+import math
 
 class Envelope(Polygon):
 
@@ -35,6 +35,15 @@ class Envelope(Polygon):
             [self.maxx, self.miny]
         ])
 
+    def isClose(self, a, b) -> bool:
+        return math.isclose(a, b, rel_tol=1e-9)
+
+    def __eq__(self, other) -> bool:
+        return self.isClose(self.minx, other.minx) and\
+                self.isClose(self.miny, other.miny) and\
+                self.isClose(self.maxx, other.maxx) and\
+                self.isClose(self.maxy, other.maxy)
+
     @require(["Envelope"])
     def create_jvm_instance(self, jvm):
         return jvm.Envelope(
diff --git a/python/tests/streaming/spark/test_constructor_functions.py b/python/tests/streaming/spark/test_constructor_functions.py
index 6fc55b99..3ab4258d 100644
--- a/python/tests/streaming/spark/test_constructor_functions.py
+++ b/python/tests/streaming/spark/test_constructor_functions.py
@@ -27,6 +27,7 @@ from sedona.sql.types import GeometryType
 from tests import tests_resource
 from tests.streaming.spark.cases_builder import SuiteContainer
 from tests.test_base import TestBase
+import math
 
 SCHEMA = StructType(
     [
@@ -80,8 +81,10 @@ SEDONA_LISTED_SQL_FUNCTIONS = [
     (SuiteContainer.empty()
      .with_function_name("ST_Transform")
     .with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')", "'epsg:4326'", "'epsg:2180'"])
-     .with_expected_result("POINT (-2501415.806893427 4119952.52325666)")
-     .with_transform("ST_ASText")),
+     .with_expected_result(-2501415.806893427)
+     #.with_expected_result("POINT (-2501415.806893427 4119952.52325666)")
+     .with_transform("ST_X")),
+     #.with_transform("ST_ASText")),
     (SuiteContainer.empty()
      .with_function_name("ST_Intersection")
     .with_arguments(["ST_GeomFromText('POINT(21.5 52.5)')", "ST_GeomFromText('POINT(21.5 52.5)')"])
@@ -319,23 +322,27 @@ class TestConstructorFunctions(TestBase):
     @pytest.mark.sparkstreaming
     def test_geospatial_function_on_stream(self, function_name: str, arguments: List[str],
                                            expected_result: Any, transform: Optional[str]):
-        # given input stream
+      # given input stream
 
-        input_stream = self.spark.readStream.schema(SCHEMA).parquet(os.path.join(
-            tests_resource,
-            "streaming/geometry_example")
-        ).selectExpr(f"{function_name}({', '.join(arguments)}) AS result")
+      input_stream = self.spark.readStream.schema(SCHEMA).parquet(os.path.join(
+         tests_resource,
+         "streaming/geometry_example")
+      ).selectExpr(f"{function_name}({', '.join(arguments)}) AS result")
 
-        # and target table
-        random_table_name = f"view_{uuid.uuid4().hex}"
+      # and target table
+      random_table_name = f"view_{uuid.uuid4().hex}"
 
-        # when saving stream to memory
-        streaming_query = input_stream.writeStream.format("memory") \
-            .queryName(random_table_name) \
-            .outputMode("append").start()
+      # when saving stream to memory
+      streaming_query = input_stream.writeStream.format("memory") \
+         .queryName(random_table_name) \
+         .outputMode("append").start()
 
-        streaming_query.processAllAvailable()
+      streaming_query.processAllAvailable()
 
-        # then result should be as expected
-        transform_query = "result" if not transform else f"{transform}(result)"
-        assert self.spark.sql(f"select {transform_query} from {random_table_name}").collect()[0][0] == expected_result
+      # then result should be as expected
+      transform_query = "result" if not transform else f"{transform}(result)"
+      queryResult = self.spark.sql(f"select {transform_query} from {random_table_name}").collect()[0][0]
+      if (type(queryResult) is float and type(expected_result) is float):
+         assert math.isclose(queryResult, expected_result, rel_tol=1e-9)
+      else:
+         assert queryResult == expected_result

Reply via email to