Repository: madlib Updated Branches: refs/heads/master daf67f81b -> 5a291aa81
KNN: Add additional distance metrics JIRA: MADLIB-1059 Project: http://git-wip-us.apache.org/repos/asf/madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/5a291aa8 Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/5a291aa8 Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/5a291aa8 Branch: refs/heads/master Commit: 5a291aa81fa7b735b43fc0198b0eb2bf2955364b Parents: daf67f8 Author: Himanshu Pandey <[email protected]> Authored: Thu Nov 30 13:35:54 2017 -0800 Committer: Rahul Iyer <[email protected]> Committed: Thu Nov 30 13:35:54 2017 -0800 ---------------------------------------------------------------------- src/ports/postgres/modules/knn/knn.py_in | 42 ++++++++++--- src/ports/postgres/modules/knn/knn.sql_in | 37 +++++++++--- src/ports/postgres/modules/knn/test/knn.sql_in | 67 +++++++++++++++++++-- 3 files changed, 126 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in index 0e21cdd..caa89e0 100644 --- a/src/ports/postgres/modules/knn/knn.py_in +++ b/src/ports/postgres/modules/knn/knn.py_in @@ -38,7 +38,7 @@ from utilities.control import MinWarning def knn_validate_src(schema_madlib, point_source, point_column_name, point_id, label_column_name, test_source, test_column_name, - test_id, output_table, k, output_neighbors, **kwargs): + test_id, output_table, k, output_neighbors, fn_dist, **kwargs): input_tbl_valid(point_source, 'kNN') input_tbl_valid(test_source, 'kNN') output_tbl_valid(output_table, 'kNN') @@ -88,12 +88,28 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id, plpy.error("kNN Error: Data type '{0}' is not a valid type for" " column '{1}' in table '{2}'.". 
format(col_type_test, test_id, test_source)) + + if fn_dist: + fn_dist = fn_dist.lower().strip() + dist_functions = set([schema_madlib + dist for dist in + ('.dist_norm1', '.dist_norm2', '.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')]) + + is_invalid_func = plpy.execute( + """select prorettype != 'DOUBLE PRECISION'::regtype + OR proisagg = TRUE AS OUTPUT from pg_proc where + oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure; + """.format(**locals()))[0]['output'] + + if is_invalid_func or fn_dist not in dist_functions: + plpy.error( + "KNN error: Distance function has wrong signature or is not a simple function.") + return k # ------------------------------------------------------------------------------ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name, - test_source, test_column_name, test_id, output_table, k, output_neighbors): + test_source, test_column_name, test_id, output_table, k, output_neighbors, fn_dist): """ KNN function to find the K Nearest neighbours Args: @@ -117,12 +133,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n neighbors to consider @output_neighbours Outputs the list of k-nearest neighbors that were used in the voting/averaging. + @param fn_dist Distance metrics function. Default is + squared_dist_norm2. 
The following functions + are supported: + dist_norm1, dist_norm2, squared_dist_norm2, + dist_angle, dist_tanimoto, + or a user-defined function with signature + DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION """ with MinWarning('warning'): k_val = knn_validate_src(schema_madlib, point_source, point_column_name, point_id, label_column_name, test_source, test_column_name, test_id, - output_table, k, output_neighbors) + output_table, k, output_neighbors, fn_dist) x_temp_table = unique_string(desp='x_temp_table') y_temp_table = unique_string(desp='y_temp_table') @@ -132,6 +155,10 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n if output_neighbors is None: output_neighbors = True + if not fn_dist: + fn_dist = schema_madlib + '.squared_dist_norm2' + + fn_dist = fn_dist.lower().strip() interim_table = unique_string(desp='interim_table') pred_out = "" @@ -141,7 +168,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n if output_neighbors: knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY " - "knn_temp.dist ASC) AS k_nearest_neighbours ") + "knn_temp.dist ASC) AS k_nearest_neighbours ") if label_column_name: is_classification = False label_column_type = get_expr_type( @@ -156,12 +183,12 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n ).format(**locals()) pred_out += " AS prediction" label_out = (", train.{label_column_name}{cast_to_int}" - " AS {label_col_temp}").format(**locals()) + " AS {label_col_temp}").format(**locals()) if not label_column_name and not output_neighbors: plpy.error("kNN error: Either label_column_name or " - "output_neighbors has to be non-NULL.") + "output_neighbors has to be non-NULL.") plpy.execute(""" CREATE TEMP TABLE {interim_table} AS @@ -172,7 +199,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n FROM ( SELECT test.{test_id} AS {test_id_temp} , train.{point_id} as train_id , - 
{schema_madlib}.squared_dist_norm2( + {fn_dist}( train.{point_column_name}, test.{test_column_name}) AS dist {label_out} @@ -196,7 +223,6 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n GROUP BY {test_id_temp} , {test_column_name} """.format(**locals())) - plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table)) return # ------------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in index befb768..17d81ad 100644 --- a/src/ports/postgres/modules/knn/knn.sql_in +++ b/src/ports/postgres/modules/knn/knn.sql_in @@ -78,7 +78,8 @@ knn( point_source, test_id, output_table, k, - output_neighbors + output_neighbors, + fn_dist ) </pre> @@ -131,6 +132,22 @@ otherwise the result may depend on ordering of the input data.</dd> neighbors that were used in the voting/averaging, sorted from closest to furthest.</dd> +<dt>fn_dist (optional)</dt> +<dd>TEXT, default: 'squared_dist_norm2'. The name of the function to use to calculate the distance between a test point and a training point. 
+ +The following distance functions can be used (computation of barycenter/mean in parentheses): +<ul> +<li><b>\ref dist_norm1</b>: 1-norm/Manhattan (element-wise median +[Note that MADlib does not provide a median aggregate function for support and +performance reasons.])</li> +<li><b>\ref dist_norm2</b>: 2-norm/Euclidean (element-wise mean)</li> +<li><b>\ref squared_dist_norm2</b>: squared Euclidean distance (element-wise mean)</li> +<li><b>\ref dist_angle</b>: angle (element-wise mean of normalized points)</li> +<li><b>\ref dist_tanimoto</b>: tanimoto (element-wise mean of normalized points <a href="#kmeans-lit-5">[5]</a>)</li> +<li><b>user defined function</b> with signature <tt>DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION</tt></li></ul></dd> + + + </dl> @@ -227,6 +244,7 @@ SELECT * FROM madlib.knn( 'knn_result_classification', -- Output table 3, -- Number of nearest neighbors True -- True if you want to show Nearest-Neighbors by id, False otherwise + 'madlib.squared_dist_norm2' -- Distance function ); SELECT * from knn_result_classification ORDER BY id; </pre> @@ -258,6 +276,7 @@ SELECT * FROM madlib.knn( 'knn_result_regression', -- Output table 3, -- Number of nearest neighbors True -- True if you want to show Nearest-Neighbors, False otherwise + 'madlib.squared_dist_norm2' -- Distance function ); SELECT * FROM knn_result_regression ORDER BY id; </pre> @@ -388,6 +407,7 @@ SELECT {schema_madlib}.knn( output_table, -- Name of output table k, -- value of k. Default will go as 1 output_neighbors -- Outputs the list of k-nearest neighbors that were used in the voting/averaging. + fn_dist -- The name of the function to use to calculate the distance from a data point to a centroid. 
); ----------------------------------------------------------------------- @@ -435,7 +455,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn( test_id VARCHAR, output_table VARCHAR, k INTEGER, - output_neighbors Boolean + output_neighbors Boolean, + fn_dist TEXT ) RETURNS VARCHAR AS $$ PythonFunctionBodyOnly(`knn', `knn') return knn.knn( @@ -449,7 +470,9 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn( test_id, output_table, k, - output_neighbors + output_neighbors, + fn_dist + ) $$ LANGUAGE plpythonu VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); @@ -469,7 +492,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn( DECLARE returnstring VARCHAR; BEGIN - returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9); + returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9, 'MADLIB_SCHEMA.squared_dist_norm2'); RETURN returnstring; END; $$ LANGUAGE plpgsql VOLATILE @@ -489,7 +512,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn( DECLARE returnstring VARCHAR; BEGIN - returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE); + returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,'MADLIB_SCHEMA.squared_dist_norm2'); RETURN returnstring; END; $$ LANGUAGE plpgsql VOLATILE @@ -508,8 +531,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn( DECLARE returnstring VARCHAR; BEGIN - returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE); + returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,'MADLIB_SCHEMA.squared_dist_norm2'); RETURN returnstring; END; $$ LANGUAGE plpgsql VOLATILE -m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/test/knn.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in index 
1e71a0e..8bb8f20 100644 --- a/src/ports/postgres/modules/knn/test/knn.sql_in +++ b/src/ports/postgres/modules/knn/test/knn.sql_in @@ -73,24 +73,81 @@ copy knn_test_data (id, data) from stdin delimiter '|'; \. drop table if exists madlib_knn_result_classification; -select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False); +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2'); select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification; drop table if exists madlib_knn_result_classification; -select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True); +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True,'madlib.squared_dist_norm2'); select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y; drop table if exists madlib_knn_result_regression; -select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False); +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.squared_dist_norm2'); select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression; drop table if exists madlib_knn_result_regression; -select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True); +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True,'madlib.squared_dist_norm2'); select 
assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y; drop table if exists madlib_knn_result_classification; -select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',False); +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2'); select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification; +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL); +select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification; + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True, NULL ); +select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y; + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,NULL); +select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression; + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm1'); +select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') 
from madlib_knn_result_classification; + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm2'); +select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification; + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_angle'); +select assert(array_agg(prediction order by id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') from madlib_knn_result_classification; + + +drop table if exists madlib_knn_result_classification; +select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_tanimoto'); +select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification; + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm1'); +select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression; + + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm2'); +select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression; + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_angle'); +select 
assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression; + + +drop table if exists madlib_knn_result_regression; +select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_tanimoto'); +select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression; + + + select knn(); select knn('help');
