iyerr3 commented on a change in pull request #352: Feature/kd tree knn
URL: https://github.com/apache/madlib/pull/352#discussion_r255757154
 
 

 ##########
 File path: src/ports/postgres/modules/knn/knn.py_in
 ##########
 @@ -124,16 +137,312 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
             """.format(fn_dist=fn_dist, profunc=profunc))[0]['output']
 
         if is_invalid_func or (fn_dist not in dist_functions):
-            plpy.error("KNN error: Distance function has invalid signature "
-                       "or is not a simple function.")
-
+            plpy.error("KNN error: Distance function ({0}) has invalid signature "
+                       "or is not a simple function.".format(fn_dist))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+    if leaf_nodes <= 0:
+        plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be greater "
+                   "than 0.".format(leaf_nodes))
+    if pow(2,depth) <= leaf_nodes:
+        plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+                   "The leaf_nodes value must be lower than 2^depth".format(depth, leaf_nodes))
     return k
 # ------------------------------------------------------------------------------
 
 
+def kd_tree(schema_madlib, source_table, output_table, point_column_name, depth,
+            r_id, dim, **kwargs):
+    """
+        KD-tree function to create a partitioning for KNN
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param source_table         Training data table
+            @param output_table         Name of the table to store kd tree
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param depth                Depth of the kd tree
+            @param r_id                 Name of the region id column
+            @param dim                  Name of the dimension column
+
+    """
+    with MinWarning("error"):
+
+        validate_kd_tree(source_table, output_table, point_column_name, depth)
+        n_features = num_features(source_table, point_column_name)
+
+        clauses = [" WHERE 1=1 "]
+        cutoffs = []
+        centers_table = add_postfix(output_table, "_centers")
+        clause_counter = 0
+        current_feature = 1
+        for curr_level in range(depth):
+            for curr_leaf in range(pow(2,curr_level)):
+                clause = clauses[clause_counter]
+                cutoff_sql = """
+                    SELECT percentile_disc(0.5)
+                           WITHIN GROUP (
+                            ORDER BY ({point_column_name})[{current_feature}]
+                           ) AS cutoff
+                    FROM {source_table}
+                    {clause}
+                    """.format(**locals())
+
+                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
+                cutoff = cutoff if cutoff is not None else "NULL"
+                clause_counter += 1
+
+                cutoffs.append(cutoff)
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " < {cutoff} ".format(**locals()))
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " >= {cutoff} ".format(**locals()))
+            current_feature = current_feature % n_features + 1
+
+        output_table_tree = add_postfix(output_table, "_tree")
+        plpy.execute("CREATE TABLE {0} AS "
+                     "SELECT ('{{ {1} }}')::DOUBLE PRECISION[] AS tree".
+                     format(output_table_tree,
+                            " ,".join(map(str, cutoffs))))
+
+        n_leaves = pow(2,depth)
+        case_when_clause = ["WHEN {0} THEN {1}::INTEGER".format(cond[14:], i)
+                                     for i, cond in enumerate(clauses[-n_leaves:])]
+        output_sql = """
+            CREATE TABLE {output_table} AS
+                SELECT *, CASE {cases} END AS {r_id}
+                FROM {source_table}""".format(
+                    cases = ' '.join(case_when_clause),**locals())
+        plpy.execute(output_sql)
+
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
+        centers_sql = """
+            CREATE TABLE {centers_table} AS
+                SELECT {r_id}, {schema_madlib}.array_scalar_mult(
+                        {schema_madlib}.sum({point_column_name}):: DOUBLE PRECISION[],
+                        (1.0/count(*))::DOUBLE PRECISION) AS __center__
+                FROM {output_table}
+                GROUP BY {r_id}
+            """.format(**locals())
+        plpy.execute(centers_sql)
+# ------------------------------------------------------------------------------
+
+def validate_kd_tree(source_table, output_table, point_column_name, depth):
+
+    input_tbl_valid(source_table, 'kd_tree')
+    output_tbl_valid(output_table, 'kd_tree')
+    output_tbl_valid(output_table+"_tree", 'kd_tree')
+
+    _assert(is_var_valid(source_table, point_column_name),
+            "kd_tree error: {0} is an invalid column name or expression for "
+            "point_column_name param".format(point_column_name))
+    point_col_type = get_expr_type(point_column_name, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in train table is not"
+            " an array.".format(point_column_name))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+# ------------------------------------------------------------------------------
+
+def knn_tree(schema_madlib, kd_out, point_source, point_column_name, point_id,
+             label_column_name, test_source, test_column_name, test_id,
+             interim_table, in_k, output_neighbors, fn_dist, weighted_avg,
+             leaf_nodes, r_id, dim, label_out, comma_label_out_alias,
+             label_name, train, train_id, dist_inverse, test_id_temp, **kwargs):
+    """
+        KNN function to find the K Nearest neighbours
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param kd_out               Name of the kd tree table
+            @param point_source         Training data table
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param point_id             Name of the column having ids of data
+                                        point in train data table
+                                        points.
+            @param label_column_name    Name of the column with labels/values
+                                        of training data points.
+            @param test_source          Name of the table containing the test
+                                        data points.
+            @param test_column_name     Name of the column with testing data
+                                        points or expression that evaluates to a
+                                        numeric array
+            @param test_id              Name of the column having ids of data
+                                        points in test data table.
+            @param interim_table        Name of the table to store interim
+                                        results.
+            @param in_k                 default: 1. Number of nearest
+                                        neighbors to consider
+            @param output_neighbors     Outputs the list of k-nearest neighbors
+                                        that were used in the voting/averaging.
+            @param fn_dist              Distance metrics function. Default is
+                                        squared_dist_norm2. Following functions
+                                        are supported :
+                                        dist_norm1, dist_norm2, squared_dist_norm2,
+                                        dist_angle, dist_tanimoto
+                                        Or user defined function with signature
+                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
+            @param weighted_avg         Calculates the regression or classification of k-NN using
+                                        the weighted average method.
+            @param leaf_nodes           Number of leaf nodes to explore
+            @param r_id                 Name of the region id column
+            @param dim                  Name of the dimension column
+            The following parameters are passed to ensure the interim table has
+            identical features to the non-kd-tree implementation:
+            @param label_out
+            @param comma_label_out_alias
+            @param label_name
+            @param train
+            @param train_id
+            @param dist_inverse
+            @param test_id_temp
+    """
+    with MinWarning("error"):
+
+        tree_model = add_postfix(kd_out, "_tree")
+        centers_table = add_postfix(kd_out, "_centers")
+        n_features = num_features(test_source, test_column_name)
+
+        tree = plpy.execute("SELECT * FROM {0}".format(tree_model))[0]['tree']
+        # 'tree' contains only non-leaf nodes,
+        # hence 'n_leaves' is always 1 more than len(tree)
+        n_leaves = len(tree)+1
+
+        depth = int(log(n_leaves, 2))
+
+        # The borders table will have two rows for each dimension (upper & lower)
+        # even if a dimension does not have a branch.
+        # Its borders will be -Inf, Inf
+
+        # The first leaf node is the test point's own region;
+        # we expand to n-1 more nodes out of the 2 * n_features borders
+        quant = float(leaf_nodes) / n_leaves
+
+        clause_counter = 0
 
 Review comment:
   The logic here is the same as the one in tree building. We can either abstract it into a separate function or just store the complete string as the tree to avoid rebuilding it here.
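
   To illustrate the first option, here is a rough sketch of a shared helper that both `kd_tree()` and `knn_tree()` could call instead of duplicating the clause construction. The helper name, its signature, and the optional `precomputed_cutoffs` argument are placeholders for discussion, not code from this PR; `plpy` is the module-level import already available in `knn.py_in`.

```python
# Sketch only: the name, signature, and precomputed_cutoffs argument are
# placeholders, not code from this PR.
def _build_kd_clauses(source_table, point_column_name, depth, n_features,
                      precomputed_cutoffs=None):
    """Return (cutoffs, clauses) for a kd-tree of the given depth.

    At build time (kd_tree) the median cutoffs are computed with
    percentile_disc queries; at query time (knn_tree) the stored
    <output_table>_tree array could be passed as precomputed_cutoffs
    so the clauses are rebuilt without touching the source table.
    The last 2**depth entries of `clauses` are the leaf regions.
    """
    clauses = [" WHERE 1=1 "]
    cutoffs = []
    clause_counter = 0
    current_feature = 1
    for curr_level in range(depth):
        for _ in range(2 ** curr_level):
            clause = clauses[clause_counter]
            if precomputed_cutoffs is not None:
                # Reuse the cutoff stored in the tree (same build order).
                cutoff = precomputed_cutoffs[clause_counter]
            else:
                # Median of the current feature within the current region.
                cutoff = plpy.execute("""
                    SELECT percentile_disc(0.5)
                           WITHIN GROUP (
                            ORDER BY ({point_column_name})[{current_feature}]
                           ) AS cutoff
                    FROM {source_table}
                    {clause}
                    """.format(**locals()))[0]['cutoff']
            cutoff = cutoff if cutoff is not None else "NULL"
            clause_counter += 1
            cutoffs.append(cutoff)
            # Split the current region into its two children.
            clauses.append(clause +
                           "AND ({point_column_name})[{current_feature}]"
                           " < {cutoff} ".format(**locals()))
            clauses.append(clause +
                           "AND ({point_column_name})[{current_feature}]"
                           " >= {cutoff} ".format(**locals()))
        current_feature = current_feature % n_features + 1
    return cutoffs, clauses
```

   The second option would be to store the complete leaf clause strings themselves at build time (instead of, or alongside, the cutoff array), so that `knn_tree()` only reads them back and never recomputes them.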
