This is an automated email from the ASF dual-hosted git repository.
jmalkin pushed a commit to branch python
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
The following commit(s) were added to refs/heads/python by this push:
new b1ebc1f Tidy up kll python tests a bit
b1ebc1f is described below
commit b1ebc1fc6a5146167ecbe9f78277a6a49f5b7f2d
Author: Jon Malkin <[email protected]>
AuthorDate: Fri Feb 14 09:34:43 2025 -0800
Tidy up kll python tests a bit
---
python/tests/kll_test.py | 22 +++++++++++-----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py
index 0ca850d..1c3cf38 100644
--- a/python/tests/kll_test.py
+++ b/python/tests/kll_test.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from pyspark.sql.types import StructType, StructField, DoubleType
+from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType
from datasketches import kll_doubles_sketch
from datasketches_spark.kll import *
@@ -40,24 +40,24 @@ def test_kll_build(spark):
assert(sk.n == n)
assert(sk.k == k)
- assert(result["min"] == sk.get_min_value())
- assert(result["max"] == sk.get_max_value())
+ assert(sk.get_min_value() == result["min"])
+ assert(sk.get_max_value() == result["max"])
assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"])
assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"])
def test_kll_merge(spark):
n = 75 # stay in exact mode
k = 200
- data1 = [(float(i),) for i in range(1, n + 1)]
- data2 = [(float(i),) for i in range(n + 1, 2 * n + 1)]
- schema = StructType([StructField("value", DoubleType(), True)])
- df1 = spark.createDataFrame(data1, schema)
- df2 = spark.createDataFrame(data2, schema)
+ data1 = [(1, float(i)) for i in range(1, n + 1)]
+ data2 = [(2, float(i)) for i in range(n + 1, 2 * n + 1)]
+ schema = StructType([StructField("id", IntegerType(), True),
+ StructField("value", DoubleType(), True)])
+ df = spark.createDataFrame(data1 + data2, schema)
- df_agg1 = df1.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
- df_agg2 = df2.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+ df_agg = df.groupBy("id").agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+ assert(df_agg.count() == 2)
- result = df_agg1.union(df_agg2).select(
+ result = df_agg.select(
kll_sketch_double_agg_merge("sketch").alias("sketch")
).first()
sk = result["sketch"]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]