Repository: madlib
Updated Branches:
  refs/heads/master 35818fa39 -> b76a08344


kNN: Accept expressions for point_column_name and test_column_name

JIRA: MADLIB-1060

This commit allows expressions, in addition to plain column names, for the
point_column_name and test_column_name parameters of kNN. It also adds
dev-check test cases that exercise such expressions.

Closes #315


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b76a0834
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b76a0834
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b76a0834

Branch: refs/heads/master
Commit: b76a0834409f705c633f5bf80914542934d04bd7
Parents: 35818fa
Author: hpandeycodeit <[email protected]>
Authored: Mon Aug 27 23:40:13 2018 -0700
Committer: Nandish Jayaram <[email protected]>
Committed: Fri Sep 7 17:46:41 2018 -0700

----------------------------------------------------------------------
 src/ports/postgres/modules/knn/knn.py_in        | 96 +++++++++++++-------
 src/ports/postgres/modules/knn/knn.sql_in       |  6 +-
 src/ports/postgres/modules/knn/test/knn.sql_in  | 40 ++++++++
 .../modules/utilities/validate_args.py_in       | 20 ++--
 4 files changed, 122 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index c9ae918..04e74d1 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -35,6 +35,10 @@ from utilities.validate_args import get_expr_type
 from utilities.utilities import _assert
 from utilities.utilities import unique_string
 from utilities.control import MinWarning
+from utilities.validate_args import quote_ident
+from utilities.validate_args import is_var_valid
+from utilities.utilities import NUMERIC, ONLY_ARRAY
+from utilities.utilities import is_valid_psql_type
 
 MAX_WEIGHT_ZERO_DIST = 1e6
 
@@ -53,15 +57,23 @@ def knn_validate_src(schema_madlib, point_source, 
point_column_name, point_id,
 
     if label_column_name and label_column_name.strip():
         cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
-    cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
-    cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')
 
-    if not is_col_array(point_source, point_column_name):
-        plpy.error("kNN Error: Feature column '{0}' in train table is not"
-                   " an array.".format(point_column_name))
-    if not is_col_array(test_source, test_column_name):
-        plpy.error("kNN Error: Feature column '{0}' in test table is not"
-                   " an array.".format(test_column_name))
+    _assert(is_var_valid(point_source, point_column_name),
+            "kNN error: {0} is an invalid column name or expression for 
point_column_name param".format(point_column_name))
+    point_col_type = get_expr_type(point_column_name, point_source)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in train table is 
not"
+            " an array.".format(point_column_name))
+
+    _assert(is_var_valid(test_source, test_column_name),
+            "kNN error: {0} is an invalid column name or expression for 
test_column_name param".format(test_column_name))
+    test_col_type = get_expr_type(test_column_name, test_source)
+    _assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in test table is 
not"
+            " an array.".format(test_column_name))
+
+    cols_in_tbl_valid(point_source, [point_id], 'kNN')
+    cols_in_tbl_valid(test_source, [test_id], 'kNN')
 
     if not array_col_has_no_null(point_source, point_column_name):
         plpy.error("kNN Error: Feature column '{0}' in train table has some"
@@ -124,6 +136,8 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
             @param schema_madlib        Name of the Madlib Schema
             @param point_source         Training data table
             @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
             @param point_id             Name of the column having ids of data
                                         point in train data table
                                         points.
@@ -132,7 +146,8 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
             @param test_source          Name of the table containing the test
                                         data points.
             @param test_column_name     Name of the column with testing data
-                                        points.
+                                        points or expression that evaluates to 
a
+                                        numeric array
             @param test_id              Name of the column having ids of data
                                         points in test data table.
             @param output_table         Name of the table to store final
@@ -160,10 +175,19 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
                          test_source, test_column_name, test_id,
                          output_table, k, output_neighbors, fn_dist)
 
+        # Unique Strings
         x_temp_table = unique_string(desp='x_temp_table')
         y_temp_table = unique_string(desp='y_temp_table')
         label_col_temp = unique_string(desp='label_col_temp')
         test_id_temp = unique_string(desp='test_id_temp')
+        train = unique_string(desp='train')
+        test = unique_string(desp='test')
+        p_col_name = unique_string(desp='p_col_name')
+        t_col_name = unique_string(desp='t_col_name')
+        dist = unique_string(desp='dist')
+        train_id = unique_string(desp='train_id')
+        dist_inverse = unique_string(desp='dist_inverse')
+        r = unique_string(desp='r')
 
         if not fn_dist:
             fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
@@ -211,7 +235,7 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
                                 SELECT
                                     {test_id_temp},
                                     {label_col_temp},
-                                    sum(dist_inverse) data_sum
+                                    sum({dist_inverse}) data_sum
                                 FROM pg_temp.{interim_table}
                                 GROUP BY {test_id_temp},
                                          {label_col_temp}
@@ -226,53 +250,63 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
                     view_grp_by = ", vw.{0}".format(label_col_temp)
                     pred_out = ", vw.{0}".format(label_col_temp)
                 else:
-                    pred_out = ", {0}.mode({1})".format(schema_madlib, 
label_col_temp)
+                    pred_out = ", {0}.mode({1})".format(
+                        schema_madlib, label_col_temp)
             else:
                 if weighted_avg:
-                    pred_out = (", sum({0} * dist_inverse) / 
sum(dist_inverse)".
-                                format(label_col_temp))
+                    pred_out = (", sum({0} * {dist_inverse}) / 
sum({dist_inverse})".
+                                format(label_col_temp, 
dist_inverse=dist_inverse))
                 else:
                     pred_out = ", avg({0})".format(label_col_temp)
 
             pred_out += " AS prediction"
-            label_out = (", train.{label_column_name}{cast_to_int}"
+            label_out = (", {train}.{label_column_name}{cast_to_int}"
                          " AS {label_col_temp}").format(**locals())
             comma_label_out_alias = ', ' + label_col_temp
+            label_name = ", {label_column_name}".format(
+                label_column_name=label_column_name)
+
         else:
             pred_out = ""
             label_out = ""
             comma_label_out_alias = ""
+            label_name = ""
 
         # interim_table picks the 'k' nearest neighbors for each test point
         if output_neighbors:
-            knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
-                             "knn_temp.dist_inverse DESC) AS 
k_nearest_neighbours ")
+            knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY "
+                             "knn_temp.{dist_inverse} DESC) AS 
k_nearest_neighbours ").format(**locals())
         else:
             knn_neighbors = ''
         plpy.execute("""
             CREATE TEMP TABLE {interim_table} AS
                 SELECT * FROM (
                     SELECT row_number() over
-                            (partition by {test_id_temp} order by dist) AS r,
+                            (partition by {test_id_temp} order by {dist}) AS 
{r},
                             {test_id_temp},
-                            train_id,
-                            CASE WHEN dist = 0.0 THEN {max_weight_zero_dist}
-                                 ELSE 1.0 / dist
-                            END AS dist_inverse
+                            {train_id},
+                            CASE WHEN {dist} = 0.0 THEN {max_weight_zero_dist}
+                                 ELSE 1.0 / {dist}
+                            END AS {dist_inverse}
                             {comma_label_out_alias}
                     FROM (
-                        SELECT test.{test_id} AS {test_id_temp},
-                            train.{point_id} as train_id,
+                        SELECT {test}.{test_id} AS {test_id_temp},
+                            {train}.{point_id} as {train_id},
                             {fn_dist}(
-                                train.{point_column_name},
-                                test.{test_column_name})
-                            AS dist
+                                {p_col_name},
+                                {t_col_name})
+                            AS {dist}
                             {label_out}
-                            FROM {point_source} AS train,
-                                 {test_source} AS test
+                            FROM
+                            (
+                            SELECT {point_id} , {point_column_name} as 
{p_col_name} {label_name} from {point_source}
+                            ) {train},
+                            (
+                            SELECT {test_id} ,{test_column_name} as 
{t_col_name} from {test_source}
+                            ) {test}
                         ) {x_temp_table}
                     ) {y_temp_table}
-            WHERE {y_temp_table}.r <= {k}
+            WHERE {y_temp_table}.{r} <= {k}
             """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals()))
 
         sql = """
@@ -280,7 +314,7 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
                 {view_def}
                 SELECT
                     knn_temp.{test_id_temp} AS id,
-                    knn_test.{test_column_name}
+                    {test_column_name} as "{test_column_name}"
                     {pred_out}
                     {knn_neighbors}
                 FROM
@@ -290,7 +324,7 @@ def knn(schema_madlib, point_source, point_column_name, 
point_id,
                 ON knn_temp.{test_id_temp} = knn_test.{test_id}
                     {view_join}
                 GROUP BY knn_temp.{test_id_temp},
-                         knn_test.{test_column_name}
+                    {test_column_name}
                          {view_grp_by}
             """
         plpy.execute(sql.format(**locals()))

http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 24f19c7..49e0c22 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -93,7 +93,8 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 </dd>
 
 <dt>point_column_name</dt>
-<dd>TEXT. Name of the column with training data points.</dd>
+<dd>TEXT. Name of the column with training data points 
+or expression that evaluates to a numeric array</dd>
 
 <dt>point_id</dt>
 <dd>TEXT. Name of the column in 'point_source’ containing source data ids.
@@ -115,7 +116,8 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 </dd>
 
 <dt>test_column_name</dt>
-<dd>TEXT. Name of the column with testing data points.</dd>
+<dd>TEXT. Name of the column with testing data points 
+or expression that evaluates to a numeric array</dd>
 
 <dt>test_id</dt>
 <dd>TEXT. Name of the column having ids of data points in test data table.</dd>

http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 20348af..6dbed36 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -70,6 +70,25 @@ copy knn_test_data (id, data) from stdin delimiter '|';
 5|{2,90}
 6|{50,45}
 \.
+drop table if exists knn_train_data_expr;
+create table knn_train_data_expr (
+id  integer,
+data1    integer,
+data2    integer,
+label   integer);
+copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|';
+1| 1  |  1  |1
+2| 2  |  2  |1
+3| 3  |  3  |1
+4| 4  |  4   |1
+5| 4  |  5   |1
+6| 20 |  50  |0
+7| 10 |  31  |0
+8| 81 |  13  |0
+9| 1  |  111 |0
+\.
+
+
 
 drop table if exists madlib_knn_result_classification;
 select 
knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
@@ -122,5 +141,26 @@ select 
knn('knn_train_data_reg','data','id','label','knn_test_data','data','id',
 select assert(array_agg(prediction::numeric order by 
id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from 
madlib_knn_result_regression;
 
 
+drop table if exists madlib_knn_result_classification;
+select 
knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',
 True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 
'Wrong output in classification') from madlib_knn_result_classification;
+
+drop table if exists madlib_knn_result_regression;
+select 
knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2',
 True);
+select assert(array_agg(prediction::numeric order by 
id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from 
madlib_knn_result_regression;
+
+
+
+
+drop table if exists madlib_knn_result_classification;
+select 
knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',
 True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 
'Wrong output in classification') from madlib_knn_result_classification;
+
+
+
+drop table if exists madlib_knn_result_classification;
+select 
knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3);
+select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in 
classification with k=3') from (select unnest(k_nearest_neighbours) as x, id 
from madlib_knn_result_classification where id = 1 order by x asc) y;
+
 select knn();
 select knn('help');

http://git-wip-us.apache.org/repos/asf/madlib/blob/b76a0834/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index 28e6aa4..b400e1a 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -267,6 +267,7 @@ def get_first_schema(table_name):
     return None
 # -------------------------------------------------------------------------
 
+
 def drop_tables(table_list):
     """
         Drop tables specified in table_list.
@@ -275,6 +276,7 @@ def drop_tables(table_list):
     if drop_str:
         plpy.execute("DROP TABLE IF EXISTS {0}".format(drop_str))
 
+
 def table_is_empty(tbl, filter_str=None):
     """
     Returns True if the input table has no rows
@@ -376,7 +378,8 @@ def get_expr_type(expressions, tbl):
         str
     """
     # FIXME: Below transformation exist to ensure backwards compatibility
-    # Remove this when all callers have been modified to pass an Iterable 
'expressions'
+    # Remove this when all callers have been modified to pass an Iterable
+    # 'expressions'
     if (isinstance(expressions, StringTypes) or
             not isinstance(expressions, Iterable)):
         expressions = [expressions]
@@ -537,7 +540,7 @@ def explicit_bool_to_text(tbl, cols, schema_madlib):
     """
     Patch madlib.bool_to_text for columns that are of type boolean.
     """
-    m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!>, <!return cols!>, <!!>)
+    m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!> , <!return cols!> , <!!>)
     patched = []
     col_types = get_expr_type(cols, tbl)
     for col, col_type in zip(cols, col_types):
@@ -565,13 +568,14 @@ def array_col_has_no_null(tbl, col):
                        FROM {tbl}
                        """.format(col=col, tbl=tbl))[0]["dim"]
     for i in range(1, dim + 1):
-        n_non_nulls = plpy.execute("SELECT count({col}[{i}]) FROM {tbl}".
+        n_non_nulls = plpy.execute("SELECT count(({col})[{i}]) FROM {tbl}".
                                    format(col=col, tbl=tbl, i=i))[0]["count"]
         if row_len != n_non_nulls:
             return False
     return True
 # -------------------------------------------------------------------------
 
+
 def get_col_dimension(tbl, col_name, dim=1):
     """
     Returns upper bound of the requested array dimension
@@ -588,6 +592,7 @@ def get_col_dimension(tbl, col_name, dim=1):
         """.format(col_name=col_name, dim=dim, tbl=tbl))[0]["dimension"]
     return col_dim
 
+
 def _tbl_dimension_rownum(schema_madlib, tbl, col_name, skip_row_count=False):
     """
     Measure the dimension and row number of source data table
@@ -694,6 +699,7 @@ def regproc_valid(qualified_name, args_str, module):
             """{module} error: Required function 
"{qualified_name}({args_str})" not found!""".format(**locals()))
 # -------------------------------------------------------------------------
 
+
 def does_exclude_reserved(targets, reserved):
     """
         Function to check if any target column name is part of reserved column
@@ -701,10 +707,10 @@ def does_exclude_reserved(targets, reserved):
     """
     intersect = frozenset(targets).intersection(frozenset(reserved))
     if len(intersect) != 0:
-        plpy.error( "Error: Conflicting column names.\n"
-                    "Some predefined keyword(s) ({0}) are not allowed "
-                    "for column names in module input params.".format(
-                    ', '.join(intersect)))
+        plpy.error("Error: Conflicting column names.\n"
+                   "Some predefined keyword(s) ({0}) are not allowed "
+                   "for column names in module input params.".format(
+                       ', '.join(intersect)))
 # -------------------------------------------------------------------------
 
 import unittest

Reply via email to