This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch theta_compressed in repository https://gitbox.apache.org/repos/asf/datasketches-python.git
commit 9dc7e7c8ba153027475544f92661ef9d78f94511 Author: Jon Malkin <[email protected]> AuthorDate: Mon Feb 12 10:07:51 2024 -0800 Allow compression for theta sketches --- src/theta_wrapper.cpp | 8 ++++---- tests/theta_test.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/theta_wrapper.cpp b/src/theta_wrapper.cpp index 1e4eee5..2915e17 100644 --- a/src/theta_wrapper.cpp +++ b/src/theta_wrapper.cpp @@ -104,11 +104,11 @@ void init_theta(nb::module_ &m) { .def("__copy__", [](const compact_theta_sketch& sk){ return compact_theta_sketch(sk); }) .def( "serialize", - [](const compact_theta_sketch& sk) { - auto bytes = sk.serialize(); + [](const compact_theta_sketch& sk, bool compress) { + auto bytes = compress ? sk.serialize_compressed() : sk.serialize(); return nb::bytes(reinterpret_cast<const char*>(bytes.data()), bytes.size()); - }, - "Serializes the sketch into a bytes object" + }, nb::arg("compress")=false, + "Serializes the sketch into a bytes object, optionally compressing the data" ) .def_static( "deserialize", diff --git a/tests/theta_test.py b/tests/theta_test.py index 8cdb2a7..1d2da0f 100644 --- a/tests/theta_test.py +++ b/tests/theta_test.py @@ -48,6 +48,16 @@ class ThetaTest(unittest.TestCase): self.assertFalse(sk.is_empty()) self.assertEqual(sk.get_estimate(), new_sk.get_estimate()) + # can also serialize in a compressed format + sk_compresed_bytes = sk.compact().serialize(compress=True) + self.assertLess(len(sk_compresed_bytes), len(sk_bytes)) + sk_from_compressed = compact_theta_sketch.deserialize(sk_compresed_bytes) + + # compressed and non-compressed sketches should match + self.assertEqual(sk_from_compressed.get_estimate(), new_sk.get_estimate()) + self.assertEqual(sk_from_compressed.get_upper_bound(1), new_sk.get_upper_bound(1)) + self.assertEqual(sk_from_compressed.get_lower_bound(1), new_sk.get_lower_bound(1)) + # check that printing works as expected self.assertGreater(len(sk.to_string(True)), 0) 
self.assertEqual(len(sk.__str__()), len(sk.to_string())) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
