This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch add_p
in repository https://gitbox.apache.org/repos/asf/datasketches-bigquery.git

commit f02a45ab89186d18717473c00bf289592bcffe55
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Wed Sep 4 17:09:28 2024 -0700

    Add sampling probability parameter (p) to theta_sketch_agg_string
---
 theta/sqlx/theta_sketch_agg_string.sqlx | 73 +++++++++++++++------------------
 theta/test/theta_sketch_test.sql        | 20 ++++-----
 theta/theta_sketch.cpp                  | 28 +------------
 3 files changed, 46 insertions(+), 75 deletions(-)

diff --git a/theta/sqlx/theta_sketch_agg_string.sqlx b/theta/sqlx/theta_sketch_agg_string.sqlx
index 013942b..c676f0c 100644
--- a/theta/sqlx/theta_sketch_agg_string.sqlx
+++ b/theta/sqlx/theta_sketch_agg_string.sqlx
@@ -19,7 +19,7 @@
 
 config { hasOutput: true }
 
-CREATE OR REPLACE AGGREGATE FUNCTION ${self()}(str STRING, params STRUCT<lg_k 
BYTEINT, seed INT64> NOT AGGREGATE)
+CREATE OR REPLACE AGGREGATE FUNCTION ${self()}(str STRING, params STRUCT<lg_k 
BYTEINT, seed INT64, p FLOAT64> NOT AGGREGATE)
 RETURNS BYTES 
 LANGUAGE js
 OPTIONS (
@@ -28,6 +28,9 @@ OPTIONS (
 Param str: the STRING column of identifiers.
 Param lg_k: the sketch accuracy/size parameter as an integer in the range [4, 
26].
 Param seed: the seed to be used by the underlying hash function.
+Param p: sampling probability (initial theta). The default is 1, so the sketch 
retains
+all entries until it reaches the limit, at which point it goes into the 
estimation mode
+and reduces the effective sampling probability (theta) as necessary.
 Returns: a Compact, Compressed Theta Sketch, as bytes, from which the 
cardinality can be obtained. 
 For more details: 
https://datasketches.apache.org/docs/Theta/ThetaSketchFramework.html'''
 ) AS R"""
@@ -35,6 +38,7 @@ import ModuleFactory from "gs://$GCS_BUCKET/theta_sketch.mjs";
 var Module = await ModuleFactory();
 const default_lg_k = Number(Module.DEFAULT_LG_K);
 const default_seed = BigInt(Module.DEFAULT_SEED);
+const default_p = 1.0;
 
 function destroyState(state) {
   if (state.sketch) {
@@ -53,83 +57,74 @@ export function initialState(params) {
   var state = {
     lg_k: params.lg_k == null ? default_lg_k : Number(params.lg_k),
     seed: params.seed == null ? default_seed : BigInt(params.seed),
-    sketch: null,
-    union: null,
-    serialized: null
+    p: params.p == null ? default_p : params.p
   };
-  state.sketch = new Module.update_theta_sketch(state.lg_k, state.seed);
+  state.sketch = new Module.update_theta_sketch(state.lg_k, state.seed, 
state.p);
   return state;
 }
 
 export function aggregate(state, str) {
   if (state.sketch == null) {
-    state.sketch = new Module.update_theta_sketch(state.lg_k, state.seed);
+    state.sketch = new Module.update_theta_sketch(state.lg_k, state.seed, 
state.p);
   }
   state.sketch.updateString(str);
 }
 
 export function serialize(state) {
+  if (state.sketch == null) return state; // for transition 
deserialize-serialize
   try {
+    // for prior transition deserialize-aggregate
+    // merge aggregated and serialized state
     if (state.sketch != null && state.serialized != null) {
-      // merge aggregated and serialized state
-      var u = new Module.theta_union(state.lg_k, state.seed);
+      var u = null;
       try {
+        u = new Module.theta_union(state.lg_k, state.seed);
         u.updateWithUpdateSketch(state.sketch);
         u.updateWithBytes(state.serialized, state.seed);
         state.serialized = u.getResultAsUint8ArrayCompressed();
       } finally {
-        u.delete();
+        if (u != null) u.delete();
       }
     } else if (state.sketch != null) {
       state.serialized = state.sketch.serializeAsUint8ArrayCompressed();
     } else if (state.union != null) {
       state.serialized = state.union.getResultAsUint8ArrayCompressed();
-    } else if (state.serialized == null) {
-      throw new Error("Unexpected state in serialization " + 
JSON.stringify(state));
     }
     return {
       lg_k: state.lg_k,
       seed: state.seed,
-      bytes: state.serialized
+      serialized: state.serialized
     };
+  } catch (e) {
+    throw new Error(Module.getExceptionMessage(e));
   } finally {
     destroyState(state);
   }
 }
 
-export function deserialize(serialized) {
-  return {
-    sketch: null,
-    union: null,
-    serialized: serialized.bytes,
-    lg_k: serialized.lg_k,
-    seed: serialized.seed
-  };
+export function deserialize(state) {
+  return state;
 }
 
 export function merge(state, other_state) {
-  if (!state.union) {
-    state.union = new Module.theta_union(state.lg_k, state.seed);
-  }
-  if (state.sketch || other_state.sketch) {
-    throw new Error("update_theta_sketch not expected during merge()");
-  }
-  if (other_state.union) {
-    throw new Error("other_state should not have union during merge()");
-  }
-  if (state.serialized) {
-    state.union.updateWithBytes(state.serialized, state.seed);
-    state.serialized = null;
-  }
-  if (other_state.serialized) {
-    state.union.updateWithBytes(other_state.serialized, state.seed);
-    other_state.serialized = null;
-  } else {
-    throw new Error("other_state should have serialized sketch during merge");
+  try {
+    if (!state.union) {
+      state.union = new Module.theta_union(state.lg_k, state.seed);
+    }
+    if (state.serialized) {
+      state.union.updateWithBytes(state.serialized, state.seed);
+      state.serialized = null;
+    }
+    if (other_state.serialized) {
+      state.union.updateWithBytes(other_state.serialized, state.seed);
+      other_state.serialized = null;
+    }
+  } catch (e) {
+    throw new Error(Module.getExceptionMessage(e));
   }
 }
 
 export function finalize(state) {
-  return serialize(state).bytes
+  return serialize(state).serialized
 }
 """;
diff --git a/theta/test/theta_sketch_test.sql b/theta/test/theta_sketch_test.sql
index 1003a21..6d8f336 100644
--- a/theta/test/theta_sketch_test.sql
+++ b/theta/test/theta_sketch_test.sql
@@ -20,9 +20,9 @@
 create or replace table $BQ_DATASET.theta_sketch(sketch bytes);
 
 insert into $BQ_DATASET.theta_sketch
-(select $BQ_DATASET.theta_sketch_agg_string(cast(value as string), struct<int, 
int>(14, null)) from unnest(GENERATE_ARRAY(1, 10000, 1)) as value);
+(select $BQ_DATASET.theta_sketch_agg_string(cast(value as string), struct<int, 
int, float64>(14, null, 0.8)) from unnest(GENERATE_ARRAY(1, 10000, 1)) as 
value);
 insert into $BQ_DATASET.theta_sketch
-(select $BQ_DATASET.theta_sketch_agg_string(cast(value as string), struct<int, 
int>(14, null)) from unnest(GENERATE_ARRAY(100000, 110000, 1)) as value);
+(select $BQ_DATASET.theta_sketch_agg_string(cast(value as string), struct<int, 
int, float64>(14, null, 0.8)) from unnest(GENERATE_ARRAY(100000, 110000, 1)) as 
value);
 
 # expected about 20000
 select $BQ_DATASET.theta_sketch_get_estimate_and_bounds(
@@ -42,8 +42,8 @@ drop table $BQ_DATASET.theta_sketch;
 # expected 5
 select $BQ_DATASET.theta_sketch_get_estimate(
   $BQ_DATASET.theta_sketch_scalar_union(
-    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int>(null, 
null)) from unnest(["a", "b", "c"]) as str),
-    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int>(null, 
null)) from unnest(["c", "d", "e"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int, 
float64>(null, null, null)) from unnest(["a", "b", "c"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int, 
float64>(null, null, null)) from unnest(["c", "d", "e"]) as str),
     null,
     null
   ),
@@ -53,8 +53,8 @@ select $BQ_DATASET.theta_sketch_get_estimate(
 # expected 1
 select $BQ_DATASET.theta_sketch_get_estimate(
   $BQ_DATASET.theta_sketch_scalar_intersection(
-    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int>(null, 
null)) from unnest(["a", "b", "c"]) as str),
-    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int>(null, 
null)) from unnest(["c", "d", "e"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int, 
float64>(null, null, null)) from unnest(["a", "b", "c"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int, 
float64>(null, null, null)) from unnest(["c", "d", "e"]) as str),
     null
   ),
   null
@@ -63,8 +63,8 @@ select $BQ_DATASET.theta_sketch_get_estimate(
 # expected 2
 select $BQ_DATASET.theta_sketch_get_estimate(
   $BQ_DATASET.theta_sketch_a_not_b(
-    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int>(null, 
null)) from unnest(["a", "b", "c"]) as str),
-    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int>(null, 
null)) from unnest(["c", "d", "e"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int, 
float64>(null, null, null)) from unnest(["a", "b", "c"]) as str),
+    (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int, 
float64>(null, null, null)) from unnest(["c", "d", "e"]) as str),
     null
   ),
   null
@@ -72,7 +72,7 @@ select $BQ_DATASET.theta_sketch_get_estimate(
 
 # expected 0.2
 select $BQ_DATASET.theta_sketch_jaccard_similarity(
-  (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int>(null, 
null)) from unnest(["a", "b", "c"]) as str),
-  (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int>(null, 
null)) from unnest(["c", "d", "e"]) as str),
+  (select $BQ_DATASET.theta_sketch_agg_string(str, struct<int, int, 
float64>(null, null, null)) from unnest(["a", "b", "c"]) as str),
+  (select $BQ_DATASET.theta_sketch_agg_string(str, STRUCT<int, int, 
float64>(null, null, null)) from unnest(["c", "d", "e"]) as str),
   null
 );
diff --git a/theta/theta_sketch.cpp b/theta/theta_sketch.cpp
index 86fb5b8..f91da30 100644
--- a/theta/theta_sketch.cpp
+++ b/theta/theta_sketch.cpp
@@ -48,16 +48,10 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
   emscripten::constant("DEFAULT_SEED", datasketches::DEFAULT_SEED);
 
   emscripten::class_<update_theta_sketch>("update_theta_sketch")
-    .constructor(emscripten::optional_override([](uint8_t lg_k) {
-      return new 
update_theta_sketch(update_theta_sketch::builder().set_lg_k(lg_k).build());
-    }))
-    .constructor(emscripten::optional_override([](uint8_t lg_k, uint64_t seed) 
{
-      return new 
update_theta_sketch(update_theta_sketch::builder().set_lg_k(lg_k).set_seed(seed).build());
+    .constructor(emscripten::optional_override([](uint8_t lg_k, uint64_t seed, 
float p) {
+      return new 
update_theta_sketch(update_theta_sketch::builder().set_lg_k(lg_k).set_seed(seed).set_p(p).build());
     }))
     .function("updateString", emscripten::select_overload<void(const 
std::string&)>(&update_theta_sketch::update))
-    .function("serialize", 
emscripten::optional_override([](update_theta_sketch& self) {
-      return self.compact().serialize_compressed();
-    }))
     .function("serializeB64", emscripten::optional_override([](const 
update_theta_sketch& self) {
       auto bytes = self.compact().serialize();
       std::vector<char> b64(b64_enc_len(bytes.size()));
@@ -104,9 +98,6 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
     .function("toString", emscripten::optional_override([](const 
compact_theta_sketch& self) {
       return std::string(self.to_string());
     }))
-    .function("serialize", emscripten::optional_override([](const 
compact_theta_sketch& self) {
-      return self.serialize_compressed();
-    }))
     .class_function("getMaxSerializedSizeBytes", 
&compact_theta_sketch::get_max_serialized_size_bytes)
     ;
 
@@ -123,12 +114,6 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
     ;
 
   emscripten::class_<theta_union>("theta_union")
-    .constructor(emscripten::optional_override([]() {
-      return new theta_union(theta_union::builder().build());
-    }))
-    .constructor(emscripten::optional_override([](uint8_t lg_k) {
-      return new theta_union(theta_union::builder().set_lg_k(lg_k).build());
-    }))
     .constructor(emscripten::optional_override([](uint8_t lg_k, uint64_t seed) 
{
       return new 
theta_union(theta_union::builder().set_lg_k(lg_k).set_seed(seed).build());
     }))
@@ -152,9 +137,6 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
     .function("updateWithBuffer", 
emscripten::optional_override([](theta_union& self, intptr_t bytes, size_t 
size, uint64_t seed) {
       
self.update(wrapped_compact_theta_sketch::wrap(reinterpret_cast<void*>(bytes), 
size, seed));
     }))
-    .function("getResultSerialized", 
emscripten::optional_override([](theta_union& self) {
-      return self.get_result().serialize_compressed();
-    }))
     .function("getResultStreamCompressed", 
emscripten::optional_override([](theta_union& self, intptr_t bytes, size_t 
size) {
       std::strstream stream(reinterpret_cast<char*>(bytes), size);
       self.get_result().serialize_compressed(stream);
@@ -183,9 +165,6 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
     ;
 
   emscripten::class_<theta_intersection>("theta_intersection")
-    .constructor(emscripten::optional_override([]() {
-      return new theta_intersection();
-    }))
     .constructor(emscripten::optional_override([](uint64_t seed) {
       return new theta_intersection(seed);
     }))
@@ -215,9 +194,6 @@ EMSCRIPTEN_BINDINGS(theta_sketch) {
     ;
 
   emscripten::class_<theta_a_not_b>("theta_a_not_b")
-    .constructor(emscripten::optional_override([]() {
-      return new theta_a_not_b();
-    }))
     .constructor(emscripten::optional_override([](uint64_t seed) {
       return new theta_a_not_b(seed);
     }))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to