iyerr3 commented on a change in pull request #352: Feature/kd tree knn
URL: https://github.com/apache/madlib/pull/352#discussion_r256175598
##########
File path: src/ports/postgres/modules/knn/knn.py_in
##########
@@ -124,16 +137,377 @@ def knn_validate_src(schema_madlib, point_source,
point_column_name, point_id,
""".format(fn_dist=fn_dist, profunc=profunc))[0]['output']
if is_invalid_func or (fn_dist not in dist_functions):
- plpy.error("KNN error: Distance function has invalid signature "
- "or is not a simple function.")
-
+ plpy.error("KNN error: Distance function ({0}) has invalid
signature "
+ "or is not a simple function.".format(fn_dist))
+ if depth <= 0:
+ plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+ "than 0.".format(depth))
+ if leaf_nodes <= 0:
+ plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be
greater "
+ "than 0.".format(leaf_nodes))
+ if pow(2,depth) <= leaf_nodes:
+ plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+ "The leaf_nodes value must be lower than
2^depth".format(depth, leaf_nodes))
return k
#
------------------------------------------------------------------------------
def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
                  depth, r_id, dim, **kwargs):
    """ Construct a kd-tree partitioning of the training data for KNN.
    Args:
        @param schema_madlib        Name of the MADlib schema
        @param source_table         Training data table
        @param output_table         Name of the table to store the kd-tree
        @param point_column_name    Column (or expression evaluating to a
                                    numeric array) with the training points
        @param depth                Depth of the kd-tree
        @param r_id                 Name of the region id column
        @param dim                  Name of the dimension column
    """
    with MinWarning("error"):

        validate_kd_tree(source_table, output_table, point_column_name, depth)
        n_features = num_features(source_table, point_column_name)

        centers_table = add_postfix(output_table, "_centers")
        clauses = [' 1=1 ']
        cutoffs = []
        # Index of the next region clause to split; the clause list is
        # visited in breadth-first order.
        clause_counter = 0
        # Split feature for the current level; cycles through 1..n_features.
        current_feature = 1
        for level in range(depth):
            # Each node at this level splits its region at the median of the
            # current feature.
            for _ in range(2 ** level):
                clause = clauses[clause_counter]
                clause_counter += 1
                cutoff_sql = """
                    SELECT percentile_disc(0.5)
                           WITHIN GROUP (
                            ORDER BY ({point_column_name})[{current_feature}]
                           ) AS cutoff
                    FROM {source_table}
                    WHERE {clause}
                    """.format(point_column_name=point_column_name,
                               current_feature=current_feature,
                               source_table=source_table,
                               clause=clause)

                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
                # An empty region yields no median; store SQL NULL instead.
                if cutoff is None:
                    cutoff = "NULL"
                cutoffs.append(cutoff)

                # Left child: strictly below the median.
                # Right child: at or above the median.
                branch = ("AND ({point_column_name})[{current_feature}]"
                          " {op} {cutoff} ")
                clauses.append(clause +
                               branch.format(point_column_name=point_column_name,
                                             current_feature=current_feature,
                                             op='<', cutoff=cutoff))
                clauses.append(clause +
                               branch.format(point_column_name=point_column_name,
                                             current_feature=current_feature,
                                             op='>=', cutoff=cutoff))
            current_feature = current_feature % n_features + 1

        # Persist the per-node cutoffs (breadth-first order) as one array.
        output_table_tree = add_postfix(output_table, "_tree")
        plpy.execute("CREATE TABLE {0} AS "
                     "SELECT ('{{ {1} }}')::DOUBLE PRECISION[] AS tree".
                     format(output_table_tree,
                            " ,".join(str(c) for c in cutoffs)))

        # The last 2^depth clauses describe the leaf regions; tag every
        # training point with the id of the leaf it falls in.
        n_leaves = 2 ** depth
        case_when_clause = ' '.join(
            ["WHEN {0} THEN {1}::INTEGER".format(cond, i)
             for i, cond in enumerate(clauses[-n_leaves:])])
        output_sql = """
            CREATE TABLE {output_table} AS
            SELECT *, CASE {case_when_clause} END AS {r_id}
            FROM {source_table}""".format(output_table=output_table,
                                          case_when_clause=case_when_clause,
                                          r_id=r_id,
                                          source_table=source_table)
        plpy.execute(output_sql)

        # Per-region centroid: sum of the points scaled by 1/count.
        plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
        centers_sql = """
            CREATE TABLE {centers_table} AS
            SELECT {r_id}, {schema_madlib}.array_scalar_mult(
                        {schema_madlib}.sum({point_column_name}):: DOUBLE PRECISION[],
                        (1.0/count(*))::DOUBLE PRECISION) AS __center__
            FROM {output_table}
            GROUP BY {r_id}
            """.format(centers_table=centers_table,
                       r_id=r_id,
                       schema_madlib=schema_madlib,
                       point_column_name=point_column_name,
                       output_table=output_table)
        plpy.execute(centers_sql)
        return case_when_clause
+#
------------------------------------------------------------------------------
+
def validate_kd_tree(source_table, output_table, point_column_name, depth):
    """ Validate the inputs for kd-tree construction.

    Errors out (via the shared validation helpers / plpy.error) when the
    source table is missing, the output tables already exist, the point
    column/expression is invalid or not a numeric array, or depth is not
    positive.
    """
    input_tbl_valid(source_table, 'kd_tree')
    output_tbl_valid(output_table, 'kd_tree')
    output_tbl_valid(output_table + "_tree", 'kd_tree')

    _assert(is_var_valid(source_table, point_column_name),
            "kd_tree error: {0} is an invalid column name or expression for "
            "point_column_name param".format(point_column_name))
    point_type = get_expr_type(point_column_name, source_table)
    _assert(is_valid_psql_type(point_type, NUMERIC | ONLY_ARRAY),
            "kNN Error: Feature column or expression '{0}' in train table is not"
            " an array.".format(point_column_name))
    if depth <= 0:
        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
                   "than 0.".format(depth))
+#
------------------------------------------------------------------------------
+
+def knn_kd_tree(schema_madlib, kd_out, point_source, point_column_name,
point_id,
+ label_column_name, test_source, test_column_name, test_id,
+ interim_table, in_k, output_neighbors, fn_dist, weighted_avg,
+ leaf_nodes, r_id, dim, label_out, comma_label_out_alias,
+ label_name, train, train_id, dist_inverse, test_id_temp,
+ case_when_clause, **kwargs):
+ """
+ KNN function to find the K Nearest neighbours using kd tree
+ Args:
+ @param schema_madlib Name of the Madlib Schema
+ @param kd_out Name of the kd tree table
+ @param point_source Training data table
+ @param point_column_name Name of the column with training data
+ or expression that evaluates to a
+ numeric array
+ @param point_id Name of the column having ids of data
+ point in train data table
+ points.
+ @param label_column_name Name of the column with labels/values
+ of training data points.
+ @param test_source Name of the table containing the test
+ data points.
+ @param test_column_name Name of the column with testing data
+ points or expression that evaluates to
a
+ numeric array
+ @param test_id Name of the column having ids of data
+ points in test data table.
+ @param interim_table Name of the table to store interim
+ results.
+ @param in_k default: 1. Number of nearest
+ neighbors to consider
+ @param output_neighbours Outputs the list of k-nearest neighbors
+ that were used in the voting/averaging.
+ @param fn_dist Distance metrics function. Default is
+ squared_dist_norm2. Following functions
+ are supported :
+ dist_norm1 ,
dist_norm2,squared_dist_norm2,
+ dist_angle , dist_tanimoto
+ Or user defined function with signature
+ DOUBLE PRECISION[] x, DOUBLE
PRECISION[] y -> DOUBLE PRECISION
+ @param weighted_avg Calculates the Regression or
classication of k-NN using
+ the weighted average method.
+ @param leaf_nodes Number of leaf nodes to explore
+ @param r_id Name of the region id column
+ @param dim Name of the dimension column
+ Following parameters are passed to ensure the interim table has
+ identical features to non-kd-tree implementation
+ @param label_out
+ @param comma_label_out_alias
+ @param label_name
+ @param train
+ @param train_id
+ @param dist_inverse
+ @param test_id_temp
+ @param case_when_clause
+ """
+ with MinWarning("error"):
+
+ tree_model = add_postfix(kd_out, "_tree")
+ centers_table = add_postfix(kd_out, "_centers")
+ n_features = num_features(test_source, test_column_name)
+
+ tree = plpy.execute("SELECT * FROM {0}".format(tree_model))[0]['tree']
+ # 'tree' contains only non-leaf nodes,
+ # hence 'n_leaves' is always 1 more than len(tree)
+ n_leaves = len(tree)+1
+
+ depth = int(log(n_leaves, 2))
+
+ # The borders table will have two rows for each dimension (Upper &
lower)
+ # even if a dimension does not have a branch.
+ # Its borders will be -Inf, Inf
+
+ # The first leaf_note is itself,
+ # we expand to n-1 nodes out of 2 * n_features borders
+ quant = float(leaf_nodes) / n_leaves
+
+ test_view = unique_string("test_view")
+ t_col_name = unique_string("t_col_name")
+ plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals()))
+
+ test_view_sql = """
+ CREATE VIEW {test_view} AS
+ SELECT {test_id},
+ {test_column_name}::DOUBLE PRECISION[] AS {t_col_name},
+ CASE
+ {case_when_clause}
+ END AS {r_id}
+ FROM {test_source}""".format(**locals())
+ plpy.execute(test_view_sql)
+
+ if leaf_nodes > 1:
+ ext_test_view = unique_string("ext_test_view")
+ ext_test_view_sql = """
+ CREATE VIEW {ext_test_view} AS
+ SELECT {test_id},
+ {t_col_name},
+ {centers_table}.{r_id}
+ FROM {test_view} INNER JOIN
+ (SELECT {test_id}, percentile_disc({quant}) WITHIN
GROUP
+ (ORDER BY {fn_dist}({t_col_name}, __center__)
+ ) AS __dist_center__
+ FROM {test_view}, {centers_table}
+ GROUP BY {test_id}
+ )__q1__ USING ({test_id}), {centers_table}
+ WHERE {fn_dist}({t_col_name}, __center__) <=
__dist_center__
+ """. format(**locals())
+ plpy.execute(ext_test_view_sql)
+ else:
+ ext_test_view = test_view
+
+ sql = """ CREATE TABLE {interim_table} AS
+ SELECT * FROM (
 Review comment:
   This query is the same as the brute-force query, except for the definition of
 train/test and the fact that brute force uses a cross join instead of an inner
 join. Do you think it would help if we combined the two into a single
 query/function and treated `knn_kd_tree` as preprocessing for train/test? That
 way future versions (like ball tree) would build upon the same structure of
 preprocessing the tables. This would also allow us to keep the temp names local
 to the function.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services