Repository: madlib Updated Branches: refs/heads/master 35818fa39 -> b76a08344
kNN: Accept expressions for point_column_name and test_column_name JIRA: MADLIB-1060 This commit adds code to allow expressions for point and test column names in kNN. This also adds test cases for the same in dev-check. Closes #315 Project: http://git-wip-us.apache.org/repos/asf/madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b76a0834 Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b76a0834 Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b76a0834 Branch: refs/heads/master Commit: b76a0834409f705c633f5bf80914542934d04bd7 Parents: 35818fa Author: hpandeycodeit <[email protected]> Authored: Mon Aug 27 23:40:13 2018 -0700 Committer: Nandish Jayaram <[email protected]> Committed: Fri Sep 7 17:46:41 2018 -0700 ---------------------------------------------------------------------- src/ports/postgres/modules/knn/knn.py_in | 96 +++++++++++++------- src/ports/postgres/modules/knn/knn.sql_in | 6 +- src/ports/postgres/modules/knn/test/knn.sql_in | 40 ++++++++ .../modules/utilities/validate_args.py_in | 20 ++-- 4 files changed, 122 insertions(+), 40 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/knn.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in index c9ae918..04e74d1 100644 --- a/src/ports/postgres/modules/knn/knn.py_in +++ b/src/ports/postgres/modules/knn/knn.py_in @@ -35,6 +35,10 @@ from utilities.validate_args import get_expr_type from utilities.utilities import _assert from utilities.utilities import unique_string from utilities.control import MinWarning +from utilities.validate_args import quote_ident +from utilities.validate_args import is_var_valid +from utilities.utilities import NUMERIC, ONLY_ARRAY +from utilities.utilities import is_valid_psql_type MAX_WEIGHT_ZERO_DIST = 
1e6 @@ -53,15 +57,23 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id, if label_column_name and label_column_name.strip(): cols_in_tbl_valid(point_source, [label_column_name], 'kNN') - cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN') - cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN') - if not is_col_array(point_source, point_column_name): - plpy.error("kNN Error: Feature column '{0}' in train table is not" - " an array.".format(point_column_name)) - if not is_col_array(test_source, test_column_name): - plpy.error("kNN Error: Feature column '{0}' in test table is not" - " an array.".format(test_column_name)) + _assert(is_var_valid(point_source, point_column_name), + "kNN error: {0} is an invalid column name or expression for point_column_name param".format(point_column_name)) + point_col_type = get_expr_type(point_column_name, point_source) + _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY), + "kNN Error: Feature column or expression '{0}' in train table is not" + " an array.".format(point_column_name)) + + _assert(is_var_valid(test_source, test_column_name), + "kNN error: {0} is an invalid column name or expression for test_column_name param".format(test_column_name)) + test_col_type = get_expr_type(test_column_name, test_source) + _assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY), + "kNN Error: Feature column or expression '{0}' in test table is not" + " an array.".format(test_column_name)) + + cols_in_tbl_valid(point_source, [point_id], 'kNN') + cols_in_tbl_valid(test_source, [test_id], 'kNN') if not array_col_has_no_null(point_source, point_column_name): plpy.error("kNN Error: Feature column '{0}' in train table has some" @@ -124,6 +136,8 @@ def knn(schema_madlib, point_source, point_column_name, point_id, @param schema_madlib Name of the Madlib Schema @param point_source Training data table @param point_column_name Name of the column with training data + or 
expression that evaluates to a + numeric array @param point_id Name of the column having ids of data point in train data table points. @@ -132,7 +146,8 @@ def knn(schema_madlib, point_source, point_column_name, point_id, @param test_source Name of the table containing the test data points. @param test_column_name Name of the column with testing data - points. + points or expression that evaluates to a + numeric array @param test_id Name of the column having ids of data points in test data table. @param output_table Name of the table to store final @@ -160,10 +175,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id, test_source, test_column_name, test_id, output_table, k, output_neighbors, fn_dist) + # Unique Strings x_temp_table = unique_string(desp='x_temp_table') y_temp_table = unique_string(desp='y_temp_table') label_col_temp = unique_string(desp='label_col_temp') test_id_temp = unique_string(desp='test_id_temp') + train = unique_string(desp='train') + test = unique_string(desp='test') + p_col_name = unique_string(desp='p_col_name') + t_col_name = unique_string(desp='t_col_name') + dist = unique_string(desp='dist') + train_id = unique_string(desp='train_id') + dist_inverse = unique_string(desp='dist_inverse') + r = unique_string(desp='r') if not fn_dist: fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib) @@ -211,7 +235,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, SELECT {test_id_temp}, {label_col_temp}, - sum(dist_inverse) data_sum + sum({dist_inverse}) data_sum FROM pg_temp.{interim_table} GROUP BY {test_id_temp}, {label_col_temp} @@ -226,53 +250,63 @@ def knn(schema_madlib, point_source, point_column_name, point_id, view_grp_by = ", vw.{0}".format(label_col_temp) pred_out = ", vw.{0}".format(label_col_temp) else: - pred_out = ", {0}.mode({1})".format(schema_madlib, label_col_temp) + pred_out = ", {0}.mode({1})".format( + schema_madlib, label_col_temp) else: if weighted_avg: - pred_out = (", sum({0} * 
dist_inverse) / sum(dist_inverse)". - format(label_col_temp)) + pred_out = (", sum({0} * {dist_inverse}) / sum({dist_inverse})". + format(label_col_temp, dist_inverse=dist_inverse)) else: pred_out = ", avg({0})".format(label_col_temp) pred_out += " AS prediction" - label_out = (", train.{label_column_name}{cast_to_int}" + label_out = (", {train}.{label_column_name}{cast_to_int}" " AS {label_col_temp}").format(**locals()) comma_label_out_alias = ', ' + label_col_temp + label_name = ", {label_column_name}".format( + label_column_name=label_column_name) + else: pred_out = "" label_out = "" comma_label_out_alias = "" + label_name = "" # interim_table picks the 'k' nearest neighbors for each test point if output_neighbors: - knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY " - "knn_temp.dist_inverse DESC) AS k_nearest_neighbours ") + knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY " + "knn_temp.{dist_inverse} DESC) AS k_nearest_neighbours ").format(**locals()) else: knn_neighbors = '' plpy.execute(""" CREATE TEMP TABLE {interim_table} AS SELECT * FROM ( SELECT row_number() over - (partition by {test_id_temp} order by dist) AS r, + (partition by {test_id_temp} order by {dist}) AS {r}, {test_id_temp}, - train_id, - CASE WHEN dist = 0.0 THEN {max_weight_zero_dist} - ELSE 1.0 / dist - END AS dist_inverse + {train_id}, + CASE WHEN {dist} = 0.0 THEN {max_weight_zero_dist} + ELSE 1.0 / {dist} + END AS {dist_inverse} {comma_label_out_alias} FROM ( - SELECT test.{test_id} AS {test_id_temp}, - train.{point_id} as train_id, + SELECT {test}.{test_id} AS {test_id_temp}, + {train}.{point_id} as {train_id}, {fn_dist}( - train.{point_column_name}, - test.{test_column_name}) - AS dist + {p_col_name}, + {t_col_name}) + AS {dist} {label_out} - FROM {point_source} AS train, - {test_source} AS test + FROM + ( + SELECT {point_id} , {point_column_name} as {p_col_name} {label_name} from {point_source} + ) {train}, + ( + SELECT {test_id} ,{test_column_name} as {t_col_name} 
from {test_source} + ) {test} ) {x_temp_table} ) {y_temp_table} - WHERE {y_temp_table}.r <= {k} + WHERE {y_temp_table}.{r} <= {k} """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals())) sql = """ @@ -280,7 +314,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, {view_def} SELECT knn_temp.{test_id_temp} AS id, - knn_test.{test_column_name} + {test_column_name} as "{test_column_name}" {pred_out} {knn_neighbors} FROM @@ -290,7 +324,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, ON knn_temp.{test_id_temp} = knn_test.{test_id} {view_join} GROUP BY knn_temp.{test_id_temp}, - knn_test.{test_column_name} + {test_column_name} {view_grp_by} """ plpy.execute(sql.format(**locals())) http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/knn.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in index 24f19c7..49e0c22 100644 --- a/src/ports/postgres/modules/knn/knn.sql_in +++ b/src/ports/postgres/modules/knn/knn.sql_in @@ -93,7 +93,8 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>. </dd> <dt>point_column_name</dt> -<dd>TEXT. Name of the column with training data points.</dd> +<dd>TEXT. Name of the column with training data points +or expression that evaluates to a numeric array</dd> <dt>point_id</dt> <dd>TEXT. Name of the column in 'point_source' containing source data ids. @@ -115,7 +116,8 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>. </dd> <dt>test_column_name</dt> -<dd>TEXT. Name of the column with testing data points.</dd> +<dd>TEXT. Name of the column with testing data points +or expression that evaluates to a numeric array</dd> <dt>test_id</dt> <dd>TEXT. 
Name of the column having ids of data points in test data table.</dd> http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/test/knn.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in index 20348af..6dbed36 100644 --- a/src/ports/postgres/modules/knn/test/knn.sql_in +++ b/src/ports/postgres/modules/knn/test/knn.sql_in @@ -70,6 +70,25 @@ copy knn_test_data (id, data) from stdin delimiter '|'; 5|{2,90} 6|{50,45} \. +drop table if exists knn_train_data_expr; +create table knn_train_data_expr ( +id integer, +data1 integer, +data2 integer, +label integer); +copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|'; +1| 1 | 1 |1 +2| 2 | 2 |1 +3| 3 | 3 |1 +4| 4 | 4 |1 +5| 4 | 5 |1 +6| 20 | 50 |0 +7| 10 | 31 |0 +8| 81 | 13 |0 +9| 1 | 111 |0 +\. + + drop table if exists madlib_knn_result_classification; select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False); @@ -122,5 +141,26 @@ select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id', select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression; +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True); +select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification; + +drop table if exists madlib_knn_result_regression; +select 
knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True); +select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression; + + + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True); +select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification; + + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3); +select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_classification where id = 1 order by x asc) y; + select knn(); select knn('help'); http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/utilities/validate_args.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index 28e6aa4..b400e1a 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ -267,6 +267,7 @@ def get_first_schema(table_name): return None # ------------------------------------------------------------------------- + def drop_tables(table_list): """ Drop tables specified in table_list. 
@@ -275,6 +276,7 @@ def drop_tables(table_list): if drop_str: plpy.execute("DROP TABLE IF EXISTS {0}".format(drop_str)) + def table_is_empty(tbl, filter_str=None): """ Returns True if the input table has no rows @@ -376,7 +378,8 @@ def get_expr_type(expressions, tbl): str """ # FIXME: Below transformation exist to ensure backwards compatibility - # Remove this when all callers have been modified to pass an Iterable 'expressions' + # Remove this when all callers have been modified to pass an Iterable + # 'expressions' if (isinstance(expressions, StringTypes) or not isinstance(expressions, Iterable)): expressions = [expressions] @@ -537,7 +540,7 @@ def explicit_bool_to_text(tbl, cols, schema_madlib): """ Patch madlib.bool_to_text for columns that are of type boolean. """ - m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!>, <!return cols!>, <!!>) + m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!> , <!return cols!> , <!!>) patched = [] col_types = get_expr_type(cols, tbl) for col, col_type in zip(cols, col_types): @@ -565,13 +568,14 @@ def array_col_has_no_null(tbl, col): FROM {tbl} """.format(col=col, tbl=tbl))[0]["dim"] for i in range(1, dim + 1): - n_non_nulls = plpy.execute("SELECT count({col}[{i}]) FROM {tbl}". + n_non_nulls = plpy.execute("SELECT count(({col})[{i}]) FROM {tbl}". 
format(col=col, tbl=tbl, i=i))[0]["count"] if row_len != n_non_nulls: return False return True # ------------------------------------------------------------------------- + def get_col_dimension(tbl, col_name, dim=1): """ Returns upper bound of the requested array dimension @@ -588,6 +592,7 @@ def get_col_dimension(tbl, col_name, dim=1): """.format(col_name=col_name, dim=dim, tbl=tbl))[0]["dimension"] return col_dim + def _tbl_dimension_rownum(schema_madlib, tbl, col_name, skip_row_count=False): """ Measure the dimension and row number of source data table @@ -694,6 +699,7 @@ def regproc_valid(qualified_name, args_str, module): """{module} error: Required function "{qualified_name}({args_str})" not found!""".format(**locals())) # ------------------------------------------------------------------------- + def does_exclude_reserved(targets, reserved): """ Function to check if any target column name is part of reserved column @@ -701,10 +707,10 @@ def does_exclude_reserved(targets, reserved): """ intersect = frozenset(targets).intersection(frozenset(reserved)) if len(intersect) != 0: - plpy.error( "Error: Conflicting column names.\n" - "Some predefined keyword(s) ({0}) are not allowed " - "for column names in module input params.".format( - ', '.join(intersect))) + plpy.error("Error: Conflicting column names.\n" + "Some predefined keyword(s) ({0}) are not allowed " + "for column names in module input params.".format( + ', '.join(intersect))) # ------------------------------------------------------------------------- import unittest
