This is an automated email from the ASF dual-hosted git repository. okislal pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 7a130ec548f5e4f9e362e42de47cf04739cce64c Author: Orhan Kislal <[email protected]> AuthorDate: Fri Sep 13 12:38:27 2019 -0400 Kmeans: Add simple silhouette score for every point JIRA: MADLIB-1382 This commit adds a function to calculate the simple silhouette score for every input data point. Closes #441 Co-authored-by: Domino Valdano <[email protected]> --- src/modules/linalg/metric.cpp | 55 +++++++----- src/ports/postgres/modules/kmeans/kmeans.py_in | 90 +++++++++++++++++-- src/ports/postgres/modules/kmeans/kmeans.sql_in | 53 +++++++++++ .../postgres/modules/kmeans/test/kmeans.sql_in | 100 ++++++++++++++++++--- src/ports/postgres/modules/linalg/linalg.sql_in | 2 +- .../postgres/modules/utilities/utilities.py_in | 2 +- 6 files changed, 259 insertions(+), 43 deletions(-) diff --git a/src/modules/linalg/metric.cpp b/src/modules/linalg/metric.cpp index e3835ee..4809762 100644 --- a/src/modules/linalg/metric.cpp +++ b/src/modules/linalg/metric.cpp @@ -365,31 +365,40 @@ closest_column::run(AnyType& args) { * This function calls a user-supplied function, for which it does not do * garbage collection. It is therefore meant to be called only constantly many * times before control is returned to the backend. - */ + */ AnyType closest_columns::run(AnyType& args) { - MappedMatrix M = args[0].getAs<MappedMatrix>(); - MappedColumnVector x = args[1].getAs<MappedColumnVector>(); - uint32_t num = args[2].getAs<uint32_t>(); - FunctionHandle dist = args[3].getAs<FunctionHandle>() - .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall); - string dist_fname = args[4].getAs<char *>(); - - std::string fname = dist_fn_name(dist_fname); - - std::vector<std::tuple<Index, double> > result(num); - closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(), - result.end()); - - MutableArrayHandle<int32_t> indices = allocateArray<int32_t, - dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num); - MutableArrayHandle<double> distances = allocateArray<double, - dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num); - for (uint32_t i = 0; i < num; ++i) - std::tie(indices[i], distances[i]) = result[i]; - - AnyType tuple; - return tuple << indices << distances; + + /* If the input has a null value, we want to return nothing for that + * particular data point (because we cannot calculate the distance) + * instead of failing. + */ + try{ + MappedMatrix M = args[0].getAs<MappedMatrix>(); + MappedColumnVector x = args[1].getAs<MappedColumnVector>(); + uint32_t num = args[2].getAs<uint32_t>(); + FunctionHandle dist = args[3].getAs<FunctionHandle>() + .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall); + string dist_fname = args[4].getAs<char *>(); + + std::string fname = dist_fn_name(dist_fname); + + std::vector<std::tuple<Index, double> > result(num); + closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(), + result.end()); + + MutableArrayHandle<int32_t> indices = allocateArray<int32_t, + dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num); + MutableArrayHandle<double> distances = allocateArray<double, + dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num); + for (uint32_t i = 0; i < num; ++i) + std::tie(indices[i], distances[i]) = result[i]; + + AnyType tuple; + return tuple << indices << distances; + }catch (const ArrayWithNullException &e) { + return Null(); + } } AnyType diff --git a/src/ports/postgres/modules/kmeans/kmeans.py_in b/src/ports/postgres/modules/kmeans/kmeans.py_in index 628b690..30e4005 100644 --- a/src/ports/postgres/modules/kmeans/kmeans.py_in +++ b/src/ports/postgres/modules/kmeans/kmeans.py_in @@ -15,11 +15,15 @@ import plpy import re from utilities.control import IterationController2D +from utilities.control import MinWarning from utilities.control_composite import IterationControllerComposite from utilities.validate_args import table_exists from utilities.validate_args import columns_exist_in_table from utilities.validate_args import table_is_empty from utilities.validate_args import get_expr_type +from utilities.validate_args import input_tbl_valid +from utilities.validate_args import output_tbl_valid +from utilities.utilities import _assert from utilities.utilities import unique_string HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>) @@ -224,7 +228,21 @@ def compute_kmeans_random_seeding(schema_madlib, rel_args, rel_state, m = it.evaluate("_args.k - coalesce(array_upper({0}, 1), 0)".format(state_str)) return iterationCtrl.iteration # ------------------------------------------------------------------------------ +def _create_temp_view_for_expr(schema_madlib, rel_source, expr_point): + """ + Create a temporary view to evaluate the expr_point. + """ + + if kmeans_validate_expr(schema_madlib, rel_source, expr_point): + view_name = unique_string('km_view') + + plpy.execute(""" CREATE TEMP VIEW {view_name} AS + SELECT {expr_point} AS expr FROM {rel_source} + """.format(**locals())) + rel_source = view_name + expr_point = 'expr' + return rel_source,expr_point def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source, expr_point, agg_centroid, **kwargs): @@ -246,14 +264,9 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source, result in \c rel_state """ - if kmeans_validate_expr(schema_madlib, rel_source, expr_point): - view_name = unique_string('km_view') - - plpy.execute(""" CREATE TEMP VIEW {view_name} AS - SELECT {expr_point} AS expr FROM {rel_source} - """.format(**locals())) - rel_source = view_name - expr_point = 'expr' + rel_source, expr_point = _create_temp_view_for_expr(schema_madlib, + rel_source, + expr_point) fn_dist_name = plpy.execute("SELECT fn_dist_name FROM " + rel_args)[0]['fn_dist_name'] @@ -387,5 +400,66 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source, 'old_centroid': old_centroid_str})) return iterationCtrl.iteration +def simple_silhouette_points(schema_madlib, rel_source, output_table, pid, + expr_point, centroids, fn_dist, **kwargs): + + """ + Calculate the simple silhouette score for every data point. + """ + + with MinWarning("error"): + kmeans_validate_src(schema_madlib, rel_source) + output_tbl_valid(output_table, 'kmeans') + + _assert(type(centroids) == list and + type(centroids[0]) == list and + len(centroids) > 1, + 'kmeans: Invalid centroids shape. Centroids have to be a 2D numeric array.') + + rel_source, expr_point = _create_temp_view_for_expr(schema_madlib, + rel_source, + expr_point) + + plpy.execute(""" + CREATE TABLE {output_table} AS + SELECT {pid}, centroids[1] AS centroid_id, + centroids[2] AS neighbor_centroid_id, + (CASE + WHEN distances[2] = 0 THEN 0 + ELSE (distances[2] - distances[1]) / distances[2] + END) AS silh + FROM + (SELECT {pid}, + (cc_out).column_ids::integer[] AS centroids, + (cc_out).distances::double precision[] AS distances + FROM ( + SELECT {pid}, + {schema_madlib}._closest_columns( + array{centroids}, + {expr_point}, + 2, + '{fn_dist}'::REGPROC, '{fn_dist}') AS cc_out + FROM {rel_source})q1 + )q2 + """.format(**locals())) + +def simple_silhouette_points_dbl_wrapper(schema_madlib, rel_source, output_table, pid, + expr_point, centroids, fn_dist, **kwargs): + + simple_silhouette_points(schema_madlib, rel_source, output_table, pid, + expr_point, centroids, fn_dist) + + +def simple_silhouette_points_str_wrapper(schema_madlib, rel_source, output_table, pid, + expr_point, centroids_table, centroids_col, fn_dist, **kwargs): + + input_tbl_valid(centroids_table, 'kmeans') + columns_exist_in_table(centroids_table, centroids_col) + centroids = plpy.execute(""" + SELECT {centroids_col} AS centroids FROM {centroids_table} + """.format(**locals()))[0]['centroids'] + + simple_silhouette_points(schema_madlib, rel_source, output_table, pid, + expr_point, centroids, fn_dist, **kwargs) m4_changequote(<!`!>, <!'!>) diff --git a/src/ports/postgres/modules/kmeans/kmeans.sql_in b/src/ports/postgres/modules/kmeans/kmeans.sql_in index 1eae525..e354e05 100644 --- a/src/ports/postgres/modules/kmeans/kmeans.sql_in +++ b/src/ports/postgres/modules/kmeans/kmeans.sql_in @@ -1906,3 +1906,56 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto( SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL) $$ LANGUAGE sql VOLATILE m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `CONTAINS SQL', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points( + rel_source VARCHAR, + output_table VARCHAR, + pid VARCHAR, + expr_point VARCHAR, + centroids_table VARCHAR, + centroids_col VARCHAR, + fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */ +) RETURNS VOID AS $$ + PythonFunction(kmeans, kmeans, simple_silhouette_points_str_wrapper) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points( + rel_source VARCHAR, + output_table VARCHAR, + pid VARCHAR, + expr_point VARCHAR, + centroids_table VARCHAR, + centroids_col VARCHAR +) RETURNS VOID +AS $$ + SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5, $6, + 'MADLIB_SCHEMA.dist_norm2') +$$ LANGUAGE sql VOLATILE +m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `'); + + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points( + rel_source VARCHAR, + output_table VARCHAR, + pid VARCHAR, + expr_point VARCHAR, + centroids DOUBLE PRECISION[], + fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */ +) RETURNS VOID AS $$ + PythonFunction(kmeans, kmeans, simple_silhouette_points_dbl_wrapper) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points( + rel_source VARCHAR, + output_table VARCHAR, + pid VARCHAR, + expr_point VARCHAR, + centroids DOUBLE PRECISION[] +) RETURNS VOID +AS $$ + SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5, + 'MADLIB_SCHEMA.dist_norm2') +$$ LANGUAGE sql VOLATILE +m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in index 4553b6c..b0e5024 100644 --- a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in +++ b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in @@ -64,6 +64,10 @@ SELECT * FROM kmeans('kmeans_2d', 'position', ARRAY[ SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm1'); SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm2'); +SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]'); +SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10); +SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10); + DROP TABLE IF EXISTS km_sample; CREATE TABLE km_sample(pid int, points double precision[]); @@ -81,16 +85,6 @@ COPY km_sample (pid, points) FROM stdin DELIMITER '|'; 10 | {13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, NULL, 1045} \. - -SELECT * FROM kmeanspp('km_sample', 'points', 2, - 'MADLIB_SCHEMA.squared_dist_norm2', - 'MADLIB_SCHEMA.avg', 20, 0.001); - - -SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]'); -SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10); -SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10); - -- Test kmeanspp_auto DROP TABLE IF EXISTS autokm_out,autokm_out_summary; SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2', @@ -163,6 +157,43 @@ DROP TABLE IF EXISTS autokm_out,autokm_out_summary; SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[12,3,5,6,8], 'MADLIB_SCHEMA.squared_dist_norm2', 'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouette'); +-- Silhouette Tests +DROP TABLE IF EXISTS km_sample_out, silh_out; + +CREATE TABLE km_sample_out AS +SELECT * FROM kmeanspp('km_sample', 'points', 2, + 'MADLIB_SCHEMA.squared_dist_norm2', + 'MADLIB_SCHEMA.avg', 20, 0.001); + +-- Test simple_silhouette_points full interface +SELECT * FROM simple_silhouette_points('km_sample', 'silh_out', 'pid', 'points', + 'km_sample_out', 'centroids', + 'MADLIB_SCHEMA.squared_dist_norm2'); + +SELECT assert(silh > 0, 'Incorrect silhouette value') +FROM silh_out +WHERE silh IS NOT NULL; + +DROP TABLE IF EXISTS silh_out; +-- Test simple_silhouette_points default distance func +SELECT * FROM simple_silhouette_points( + 'km_sample', 'silh_out', 'pid', 'points', + 'km_sample_out', 'centroids'); + +SELECT assert(count(*) = 9, 'Incorrect silhouette count') +FROM silh_out +WHERE silh IS NOT NULL; + +DROP TABLE IF EXISTS silh_out; +-- Test simple_silhouette_points double precision array centroids +SELECT * FROM simple_silhouette_points( + 'km_sample', 'silh_out', 'pid', 'points', + (SELECT centroids FROM km_sample_out)); + +SELECT assert(silh > 0, 'Incorrect silhouette value') +FROM silh_out +WHERE silh IS NOT NULL; + SELECT assert( silhouette > 0 AND objective_fn > 0, 'Kmeans: Auto Kmeans_random failed for silhouette on unordered k vals') @@ -206,6 +237,55 @@ SELECT assert( 'Kmeans: Auto Kmeans_random failed for both.') FROM autokm_out_summary; +DROP TABLE IF EXISTS silh_out; +-- Test simple_silhouette_points actual values +SELECT * FROM simple_silhouette_points( + 'km_sample', 'silh_out', 'pid', 'points', + ARRAY[[1,1,1,1,1,1,1,1,1,1,1,1,1], + [14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065]]::DOUBLE PRECISION[][]); + +SELECT assert(relative_error(1, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 1; +SELECT assert(relative_error(0.8789, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 2; +SELECT assert(relative_error(0.8966, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 3; +SELECT assert(relative_error(0.7200, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 4; +SELECT assert(relative_error(0.5560, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 5; +SELECT assert(relative_error(0.7348, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 6; +SELECT assert(relative_error(0.8242, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 7; +SELECT assert(relative_error(0.8229, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 8; +SELECT assert(relative_error(0.9655, silh) < 1e-3, + 'Incorrect silhouette value') +FROM silh_out +WHERE pid = 9; + +SELECT assert(centroid_id = 1 AND neighbor_centroid_id = 0, + 'Incorrect centroid ids') +FROM silh_out; + +DROP TABLE IF EXISTS km_sample_out, silh_out; DROP TABLE IF EXISTS km_sample CASCADE; DROP TABLE IF EXISTS centroids CASCADE; DROP TABLE IF EXISTS kmeans_2d CASCADE; diff --git a/src/ports/postgres/modules/linalg/linalg.sql_in b/src/ports/postgres/modules/linalg/linalg.sql_in index b10cef7..3c74451 100644 --- a/src/ports/postgres/modules/linalg/linalg.sql_in +++ b/src/ports/postgres/modules/linalg/linalg.sql_in @@ -428,7 +428,7 @@ LANGUAGE C IMMUTABLE STRICT m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!NO SQL!>, <!!>); --- Because of Jiara MPP-23166, ORCA makes the following +-- Because of Jira MPP-23166, ORCA makes the following -- function extremely slow because "NO SQL" now becomes -- "CONTAINS SQL". This is why we disabled the optimizer -- in kmeans. diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 4e142aa..8f5b2ff 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -209,7 +209,7 @@ def add_postfix(quoted_string, postfix): NUMERIC = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric', - 'real', 'double precision', 'serial', 'bigserial']) + 'real', 'double precision', 'float', 'serial', 'bigserial']) INTEGER = set(['smallint', 'integer', 'bigint']) TEXT = set(['text', 'varchar', 'character varying', 'char', 'character']) BOOLEAN = set(['boolean'])
