iyerr3 commented on a change in pull request #352: Feature/kd tree knn
URL: https://github.com/apache/madlib/pull/352#discussion_r255757154
##########
File path: src/ports/postgres/modules/knn/knn.py_in
##########

@@ -124,16 +137,312 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
             """.format(fn_dist=fn_dist, profunc=profunc))[0]['output']
         if is_invalid_func or (fn_dist not in dist_functions):
-            plpy.error("KNN error: Distance function has invalid signature "
-                       "or is not a simple function.")
-
+            plpy.error("KNN error: Distance function ({0}) has invalid signature "
+                       "or is not a simple function.".format(fn_dist))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+    if leaf_nodes <= 0:
+        plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be greater "
+                   "than 0.".format(leaf_nodes))
+    if pow(2,depth) <= leaf_nodes:
+        plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+                   "The leaf_nodes value must be lower than 2^depth".format(depth, leaf_nodes))
     return k
 # ------------------------------------------------------------------------------
+
+def kd_tree(schema_madlib, source_table, output_table, point_column_name, depth,
+            r_id, dim, **kwargs):
+    """
+    KD-tree function to create a partitioning for KNN
+    Args:
+        @param schema_madlib        Name of the Madlib Schema
+        @param source_table         Training data table
+        @param output_table         Name of the table to store kd tree
+        @param point_column_name    Name of the column with training data
+                                    or expression that evaluates to a
+                                    numeric array
+        @param depth                Depth of the kd tree
+        @param r_id                 Name of the region id column
+        @param dim                  Name of the dimension column
+
+    """
+    with MinWarning("error"):
+
+        validate_kd_tree(source_table, output_table, point_column_name, depth)
+        n_features = num_features(source_table, point_column_name)
+
+        clauses = [" WHERE 1=1 "]
+        cutoffs = []
+        centers_table = add_postfix(output_table, "_centers")
+        clause_counter = 0
+        current_feature = 1
+        for curr_level in range(depth):
+            for curr_leaf in range(pow(2,curr_level)):
+                clause = clauses[clause_counter]
+                cutoff_sql = """
+                    SELECT percentile_disc(0.5)
+                           WITHIN GROUP (
+                               ORDER BY ({point_column_name})[{current_feature}]
+                           ) AS cutoff
+                    FROM {source_table}
+                    {clause}
+                    """.format(**locals())
+
+                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
+                cutoff = cutoff if cutoff is not None else "NULL"
+                clause_counter += 1
+
+                cutoffs.append(cutoff)
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " < {cutoff} ".format(**locals()))
+                clauses.append(clause +
+                               "AND ({point_column_name})[{current_feature}]"
+                               " >= {cutoff} ".format(**locals()))
+            current_feature = current_feature % n_features + 1
+
+        output_table_tree = add_postfix(output_table, "_tree")
+        plpy.execute("CREATE TABLE {0} AS "
+                     "SELECT ('{{ {1} }}')::DOUBLE PRECISION[] AS tree".
+                     format(output_table_tree,
+                            " ,".join(map(str, cutoffs))))
+
+        n_leaves = pow(2,depth)
+        case_when_clause = ["WHEN {0} THEN {1}::INTEGER".format(cond[14:], i)
+                            for i, cond in enumerate(clauses[-n_leaves:])]
+        output_sql = """
+            CREATE TABLE {output_table} AS
+            SELECT *, CASE {cases} END AS {r_id}
+            FROM {source_table}""".format(
+                cases = ' '.join(case_when_clause),**locals())
+        plpy.execute(output_sql)
+
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
+        centers_sql = """
+            CREATE TABLE {centers_table} AS
+            SELECT {r_id}, {schema_madlib}.array_scalar_mult(
+                       {schema_madlib}.sum({point_column_name}):: DOUBLE PRECISION[],
+                       (1.0/count(*))::DOUBLE PRECISION) AS __center__
+            FROM {output_table}
+            GROUP BY {r_id}
+            """.format(**locals())
+        plpy.execute(centers_sql)
+# ------------------------------------------------------------------------------
+
+def validate_kd_tree(source_table, output_table, point_column_name, depth):
+
+    input_tbl_valid(source_table, 'kd_tree')
+    output_tbl_valid(output_table, 'kd_tree')
+    output_tbl_valid(output_table+"_tree", 'kd_tree')
+
+    _assert(is_var_valid(source_table, point_column_name),
+            "kd_tree error: {0} is an invalid column name or expression for "
+            "point_column_name param".format(point_column_name))
+    point_col_type = get_expr_type(point_column_name, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in train table is not"
+            " an array.".format(point_column_name))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+# ------------------------------------------------------------------------------
+
+def knn_tree(schema_madlib, kd_out, point_source, point_column_name, point_id,
+             label_column_name, test_source, test_column_name, test_id,
+             interim_table, in_k, output_neighbors, fn_dist, weighted_avg,
+             leaf_nodes, r_id, dim, label_out, comma_label_out_alias,
+             label_name, train, train_id, dist_inverse, test_id_temp, **kwargs):
+    """
+    KNN function to find the K Nearest neighbours
+    Args:
+        @param schema_madlib        Name of the Madlib Schema
+        @param kd_out               Name of the kd tree table
+        @param point_source         Training data table
+        @param point_column_name    Name of the column with training data
+                                    or expression that evaluates to a
+                                    numeric array
+        @param point_id             Name of the column having ids of data
+                                    points in the train data table.
+        @param label_column_name    Name of the column with labels/values
+                                    of training data points.
+        @param test_source          Name of the table containing the test
+                                    data points.
+        @param test_column_name     Name of the column with testing data
+                                    points or expression that evaluates to a
+                                    numeric array
+        @param test_id              Name of the column having ids of data
+                                    points in the test data table.
+        @param interim_table        Name of the table to store interim
+                                    results.
+        @param in_k                 default: 1. Number of nearest
+                                    neighbors to consider
+        @param output_neighbors     Outputs the list of k-nearest neighbors
+                                    that were used in the voting/averaging.
+        @param fn_dist              Distance metric function. Default is
+                                    squared_dist_norm2. The following functions
+                                    are supported:
+                                    dist_norm1, dist_norm2, squared_dist_norm2,
+                                    dist_angle, dist_tanimoto
+                                    Or a user defined function with signature
+                                    DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
+        @param weighted_avg         Calculates the regression or classification of k-NN using
+                                    the weighted average method.
+        @param leaf_nodes           Number of leaf nodes to explore
+        @param r_id                 Name of the region id column
+        @param dim                  Name of the dimension column
+        The following parameters are passed to ensure the interim table has
+        identical features to the non-kd-tree implementation:
+        @param label_out
+        @param comma_label_out_alias
+        @param label_name
+        @param train
+        @param train_id
+        @param dist_inverse
+        @param test_id_temp
+    """
+    with MinWarning("error"):
+
+        tree_model = add_postfix(kd_out, "_tree")
+        centers_table = add_postfix(kd_out, "_centers")
+        n_features = num_features(test_source, test_column_name)
+
+        tree = plpy.execute("SELECT * FROM {0}".format(tree_model))[0]['tree']
+        # 'tree' contains only non-leaf nodes,
+        # hence 'n_leaves' is always 1 more than len(tree)
+        n_leaves = len(tree)+1
+
+        depth = int(log(n_leaves, 2))
+
+        # The borders table will have two rows for each dimension (upper & lower),
+        # even if a dimension does not have a branch.
+        # Its borders will be -Inf, Inf
+
+        # The first leaf node is itself,
+        # we expand to n-1 nodes out of 2 * n_features borders
+        quant = float(leaf_nodes) / n_leaves
+
+        clause_counter = 0

Review comment:
   The logic here is the same as the one in tree building. We can either abstract it into a separate function or just store the complete string as the tree to avoid rebuilding it here.
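   For concreteness, a rough sketch of the first option: a single helper that rebuilds the per-region predicates from the breadth-first cutoff array that kd_tree already persists in the <output_table>_tree relation, so the search path would not repeat this loop. The helper name build_leaf_clauses, its signature, and the standalone example below are hypothetical and not part of this PR:

```python
# Hypothetical sketch (not MADlib code): rebuild the per-region WHERE clauses
# from the breadth-first cutoff list stored in the <output_table>_tree relation.
from math import log

def build_leaf_clauses(cutoffs, point_column_name, n_features):
    """Return one WHERE predicate per leaf region of a complete kd-tree."""
    n_leaves = len(cutoffs) + 1              # cutoffs hold only internal nodes
    depth = int(round(log(n_leaves, 2)))
    clauses = [" WHERE 1=1 "]
    node = 0                                 # breadth-first index into cutoffs
    current_feature = 1                      # 1-based, matching SQL arrays
    for curr_level in range(depth):
        for _ in range(pow(2, curr_level)):
            clause = clauses[node]
            cutoff = cutoffs[node]
            node += 1
            clauses.append(clause + "AND ({0})[{1}] < {2} ".format(
                point_column_name, current_feature, cutoff))
            clauses.append(clause + "AND ({0})[{1}] >= {2} ".format(
                point_column_name, current_feature, cutoff))
        current_feature = current_feature % n_features + 1
    return clauses[-n_leaves:]               # region id == position in this list

# depth=2 tree over a 2-d point column: 3 cutoffs -> 4 leaf regions
for region_id, clause in enumerate(build_leaf_clauses([5.0, 2.0, 7.0], 'pts', 2)):
    print(region_id, clause)
```

   Storing only the cutoffs keeps the _tree table compact; the second option (persisting the finished clause strings) would make a helper like this unnecessary at the cost of a larger model table.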