iyerr3 commented on a change in pull request #352: Feature/kd tree knn
URL: https://github.com/apache/madlib/pull/352#discussion_r256158403
 
 

 ##########
 File path: src/ports/postgres/modules/knn/knn.py_in
 ##########
 @@ -124,16 +137,377 @@ def knn_validate_src(schema_madlib, point_source, 
point_column_name, point_id,
             """.format(fn_dist=fn_dist, profunc=profunc))[0]['output']
 
         if is_invalid_func or (fn_dist not in dist_functions):
-            plpy.error("KNN error: Distance function has invalid signature "
-                       "or is not a simple function.")
-
+            plpy.error("KNN error: Distance function ({0}) has invalid 
signature "
+                       "or is not a simple function.".format(fn_dist))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+    if leaf_nodes <= 0:
+        plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be 
greater "
+                   "than 0.".format(leaf_nodes))
+    if pow(2,depth) <= leaf_nodes:
+        plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+                   "The leaf_nodes value must be lower than 
2^depth".format(depth, leaf_nodes))
     return k
 # 
------------------------------------------------------------------------------
 
 
+def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
+                  depth, r_id, dim, **kwargs):
+    """
+        KD-tree function to create a partitioning for KNN
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param source_table         Training data table
+            @param output_table         Name of the table to store kd tree
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param depth                Depth of the kd tree
+            @param r_id                 Name of the region id column
+            @param dim                  Name of the dimension column
+    """
+    with MinWarning("error"):
+
+        validate_kd_tree(source_table, output_table, point_column_name, depth)
+        n_features = num_features(source_table, point_column_name)
+
+        clauses = [' 1=1 ']
+        cutoffs = []
+        centers_table = add_postfix(output_table, "_centers")
+        clause_counter = 0
+        current_feature = 1
+        for curr_level in range(depth):
+            for curr_leaf in range(pow(2,curr_level)):
+                clause = clauses[clause_counter]
+                cutoff_sql = """
+                    SELECT percentile_disc(0.5)
+                           WITHIN GROUP (
+                            ORDER BY ({point_column_name})[{current_feature}]
+                           ) AS cutoff
+                    FROM {source_table}
+                    WHERE {clause}
+                    """.format(**locals())
+
+                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
+                cutoff = cutoff if cutoff is not None else "NULL"
+                clause_counter += 1
+
+                cutoffs.append(cutoff)
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " < {cutoff} ".format(**locals()))
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " >= {cutoff} ".format(**locals()))
+            current_feature = current_feature % n_features + 1
+
+        output_table_tree = add_postfix(output_table, "_tree")
 
 Review comment:
   I'm wondering whether there's any use in placing this tree in a table.
   Downstream, it looks like we need the tree to compute `n_leaves`, but we can
   obtain that from other sources (e.g., from the depth), can't we?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to