This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch python
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git


The following commit(s) were added to refs/heads/python by this push:
     new b1ebc1f  Tidy up kll python tests a bit
b1ebc1f is described below

commit b1ebc1fc6a5146167ecbe9f78277a6a49f5b7f2d
Author: Jon Malkin <[email protected]>
AuthorDate: Fri Feb 14 09:34:43 2025 -0800

    Tidy up kll python tests a bit
---
 python/tests/kll_test.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py
index 0ca850d..1c3cf38 100644
--- a/python/tests/kll_test.py
+++ b/python/tests/kll_test.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from pyspark.sql.types import StructType, StructField, DoubleType
+from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType
 
 from datasketches import kll_doubles_sketch
 from datasketches_spark.kll import *
@@ -40,24 +40,24 @@ def test_kll_build(spark):
 
   assert(sk.n == n)
   assert(sk.k == k)
-  assert(result["min"] == sk.get_min_value())
-  assert(result["max"] == sk.get_max_value())
+  assert(sk.get_min_value() == result["min"])
+  assert(sk.get_max_value() == result["max"])
   assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"])
   assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"])
 
 def test_kll_merge(spark):
   n = 75 # stay in exact mode
   k = 200
-  data1 = [(float(i),) for i in range(1, n + 1)]
-  data2 = [(float(i),) for i in range(n + 1, 2 * n + 1)]
-  schema = StructType([StructField("value", DoubleType(), True)])
-  df1 = spark.createDataFrame(data1, schema)
-  df2 = spark.createDataFrame(data2, schema)
+  data1 = [(1, float(i)) for i in range(1, n + 1)]
+  data2 = [(2, float(i)) for i in range(n + 1, 2 * n + 1)]
+  schema = StructType([StructField("id", IntegerType(), True),
+                       StructField("value", DoubleType(), True)])
+  df = spark.createDataFrame(data1 + data2, schema)
 
-  df_agg1 = df1.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
-  df_agg2 = df2.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+  df_agg = df.groupBy("id").agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+  assert(df_agg.count() == 2)
 
-  result = df_agg1.union(df_agg2).select(
+  result = df_agg.select(
     kll_sketch_double_agg_merge("sketch").alias("sketch")
   ).first()
   sk = result["sketch"]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to