This is an automated email from the ASF dual-hosted git repository.
riyer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 5e601fb K-NN: Add kd-tree method for approximate knn
5e601fb is described below
commit 5e601fbdb4c6423c148f8bdfead0a9988f31800d
Author: Orhan Kislal <[email protected]>
AuthorDate: Wed Feb 20 16:33:46 2019 -0800
K-NN: Add kd-tree method for approximate knn
JIRA: MADLIB-1061
This commit adds a kd-tree option to the 'knn' function. A kd-tree is
used to reduce the search space when finding nearest neighbors. The
method implemented here does not produce the complete kd-tree; instead,
it allows the user to specify a maximum depth for the binary tree.
Additional changes:
- Add function to clean madlib views
- Move k-nn out of 'Early Stage Development'
Closes #352
Co-authored-by: Rahul Iyer <[email protected]>
Co-authored-by: Frank McQuillan <[email protected]>
---
doc/design/design.tex | 1 +
doc/design/figures/2d_kdtree.pdf | Bin 0 -> 10652 bytes
doc/design/modules/knn.tex | 146 +++++++
doc/literature.bib | 11 +
doc/mainpage.dox.in | 2 +-
src/ports/postgres/modules/knn/knn.py_in | 480 +++++++++++++++++----
src/ports/postgres/modules/knn/knn.sql_in | 249 +++++++++--
src/ports/postgres/modules/knn/test/knn.sql_in | 287 +++++++++---
src/ports/postgres/modules/utilities/admin.py_in | 22 +
.../postgres/modules/utilities/utilities.py_in | 1 -
.../postgres/modules/utilities/utilities.sql_in | 8 +
11 files changed, 1033 insertions(+), 174 deletions(-)
diff --git a/doc/design/design.tex b/doc/design/design.tex
index e9ed7b8..6772f89 100644
--- a/doc/design/design.tex
+++ b/doc/design/design.tex
@@ -231,6 +231,7 @@
\input{modules/SVM}
\input{modules/graph}
\input{modules/neural-network}
+\input{modules/knn}
\printbibliography
\end{document}
diff --git a/doc/design/figures/2d_kdtree.pdf b/doc/design/figures/2d_kdtree.pdf
new file mode 100644
index 0000000..062ae23
Binary files /dev/null and b/doc/design/figures/2d_kdtree.pdf differ
diff --git a/doc/design/modules/knn.tex b/doc/design/modules/knn.tex
new file mode 100644
index 0000000..71af411
--- /dev/null
+++ b/doc/design/modules/knn.tex
@@ -0,0 +1,146 @@
+% Licensed to the Apache Software Foundation (ASF) under one
+% or more contributor license agreements. See the NOTICE file
+% distributed with this work for additional information
+% regarding copyright ownership. The ASF licenses this file
+% to you under the Apache License, Version 2.0 (the
+% "License"); you may not use this file except in compliance
+% with the License. You may obtain a copy of the License at
+
+% http://www.apache.org/licenses/LICENSE-2.0
+
+% Unless required by applicable law or agreed to in writing,
+% software distributed under the License is distributed on an
+% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+% KIND, either express or implied. See the License for the
+% specific language governing permissions and limitations
+% under the License.
+
+!TEX root = ../design.tex
+
+
+\chapter[k Nearest Neighbors]{k Nearest Neighbors}
+
+\begin{moduleinfo}
+\item[Authors] \href{mailto:[email protected]}{Orhan Kislal}
+
+\item[History]
+ \begin{modulehistory}
+ \item[v0.1] Initial version: knn and kd-tree.
+ \end{modulehistory}
+\end{moduleinfo}
+
+
+% Abstract. What is the problem we want to solve?
+\section{Introduction} % (fold)
+\label{sec:knn_introduction}
+
+\emph{Some notes and figures in this section are borrowed from
+\cite{medium_knn} and \cite{point_knn}}.
+
+K-nearest neighbors (KNN) is one of the most commonly used learning
+algorithms. The goal of KNN is to find a number ($k$) of training data points
+closest to the test point. These neighbors can be used to predict labels via
+classification or regression.
+
+Unlike most learning techniques, KNN does not have a training phase. It does
+not build a model to generalize the data; instead, the algorithm uses the
+whole training dataset (or a specific subset of it).
+
+KNN can be used for classification, where the output is a class membership (a
+discrete value). An object is classified by a majority vote of its neighbors,
+with the object being assigned to the class most common among its $k$ nearest
+neighbors. It can also be used for regression, where the output is the value
+predicted for the object (a continuous value). This value is the average (or
+median) of the values of its $k$ nearest neighbors.
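+
+As a compact summary (writing $N_k(x)$ for the set of $k$ nearest training
+points to a test point $x$ and $y_i$ for their labels or values):
+\[
+    \hat{y}_{\mathrm{class}} = \operatorname*{arg\,max}_{c}\,
+        \left|\{\, i \in N_k(x) : y_i = c \,\}\right|,
+    \qquad
+    \hat{y}_{\mathrm{reg}} = \frac{1}{k} \sum_{i \in N_k(x)} y_i.
+\]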
+
+\section{Implementation Details}
+
+The basic KNN implementation depends on the table join between the training
+dataset and the test dataset.
+
+\begin{sql}
+ (SELECT test_id,
+ train_id,
+ fn_dist(train_col_name, test_col_name) AS dist,
+ label
+ FROM train_table, test_table) AS knn_sub
+\end{sql}
+
+Once we have the distance between every train-test pair, the algorithm picks
+the $k$ smallest values.
+
+\begin{sql}
+    SELECT * FROM (
+        SELECT row_number() OVER
+                    (PARTITION BY test_id ORDER BY dist) AS r,
+               test_id,
+               train_id,
+               label
+        FROM knn_sub) knn_filtered
+    WHERE r <= k
+\end{sql}
+
+Finally, the prediction is completed based on the labels of the selected
+training points for each test point.
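+
+As a minimal sketch of this unweighted prediction step (using the
+\texttt{knn\_filtered} alias from the query above and PostgreSQL's
+ordered-set aggregates; the module's actual queries also support
+distance-weighted voting and averaging):
+
+\begin{sql}
+    -- Classification: majority vote over the k neighbor labels.
+    SELECT test_id,
+           mode() WITHIN GROUP (ORDER BY label) AS prediction
+    FROM knn_filtered
+    GROUP BY test_id;
+
+    -- Regression: average of the k neighbor values.
+    SELECT test_id,
+           avg(label) AS prediction
+    FROM knn_filtered
+    GROUP BY test_id;
+\end{sql}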
+
+\section{Enabling KD-tree}
+
+One of the major shortcomings of KNN is the fact that it is computationally
+expensive. In addition, there is no training phase, which means every single
+prediction has to compute the full table join. One way to improve the
+performance is to reduce the search space for test points. The kd-tree option
+enables trading the accuracy of the output for higher performance by reducing
+the neighbor search space.
+
+Kd-trees are used for organizing data in $k$ dimensions. A kd-tree is
+constructed like a binary search tree, where each level of the tree uses a
+specific dimension for finding splits.
+
+
+\begin{figure}[h]
+ \centering
+ \includegraphics[width=0.9\textwidth]{figures/2d_kdtree.pdf}
+\caption{A 2D kd-tree of depth 3}
+\label{kdd:2d_kdtree}
+\end{figure}
+
+A kd-tree is constructed by finding the median value of the data in a
+particular dimension and separating the data into two sections based on this
+value. This process is repeated a number of times to construct smaller
+regions. Once the kd-tree is prepared, any test point can be mapped to its
+assigned region, and this partitioning can be used to limit the search space
+for nearest neighbors.
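+
+As a sketch of a single split (the implementation computes the cutoff with
+the \texttt{percentile\_disc} ordered-set aggregate; the table and column
+names below are illustrative):
+
+\begin{sql}
+    -- Median of the first feature over the points in the current region.
+    SELECT percentile_disc(0.5)
+           WITHIN GROUP (ORDER BY (data)[1]) AS cutoff
+    FROM train_table;
+    -- Rows with (data)[1] <  cutoff form the left child region;
+    -- rows with (data)[1] >= cutoff form the right child region.
+\end{sql}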
+
+Once we have the kd-tree regions and their borders, we find the associated
+regions for the test points. This gives us the first region to search for
+nearest neighbors. In addition, we allow the user to request multiple regions
+to search, which means we have to decide which additional regions to include
+in our search. We implemented a backtracking algorithm to find these regions.
+The core idea is to find the closest border for each test point and select
+the region on the other side of the border. Note that points that reside in
+the same region might have different secondary (or tertiary, etc.) regions.
+Consider the tree in Figure~\ref{kdd:2d_kdtree}. A test point at
+$\langle 5, 2 \rangle$ is in the same region as $\langle 3, 3.9 \rangle$.
+However, their closest borders and the associated secondary regions are
+wildly different. In addition, consider $\langle 3, 3.9 \rangle$ and
+$\langle 6, 3.9 \rangle$. They both have the same closest border ($y=4$), but
+their closest regions differ. To make sure that we get the correct region,
+the following scheme is implemented. For a given point $P$, we find the
+closest border, $dim[i] = x$, and $P$'s relative position to it ($pos = -1$
+for lower and $pos = +1$ for higher). We construct a new point that consists
+of the same values as the test point in every dimension except $i$. For
+$dim[i]$, we set the value to $x - pos \cdot \epsilon$ for a small
+$\epsilon > 0$. Finally, we use the existing kd-tree to find this new point's
+assigned region. This region is our expansion target for the point $P$. We
+repeat this process with the next closest border as requested by the user.
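+
+As an illustrative sketch of the probe-point construction (here \texttt{i},
+\texttt{x}, \texttt{pos}, \texttt{eps} and \texttt{n} stand for the border
+dimension, the border value, the side ($\pm 1$), $\epsilon$ and the number of
+features; they are placeholders rather than columns of the actual tables):
+
+\begin{sql}
+    -- Copy each test point, nudging dimension i just across its closest
+    -- border; running the probe through the kd-tree CASE expression then
+    -- yields the secondary region for that point.
+    SELECT t.id,
+           t.data[1:i-1] || ARRAY[x - pos * eps] || t.data[i+1:n] AS probe
+    FROM test_table t;
+\end{sql}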
+
+The KNN algorithm does not change significantly with the addition of regions.
+Assuming that the training and test datasets have their region information
+stored in the tables, the only necessary change is ensuring that the table
+join uses these region ids to limit the search space.
+
+
+\begin{sql}
+ (SELECT test_id,
+ train_id,
+ fn_dist(train_col_name, test_col_name) AS dist,
+ label
+ FROM train_table, test_table
+ WHERE train_table.region_id = test_table.region_id
+ ) AS knn_sub
+\end{sql}
diff --git a/doc/literature.bib b/doc/literature.bib
index 3c07260..5aa19ab 100644
--- a/doc/literature.bib
+++ b/doc/literature.bib
@@ -986,3 +986,14 @@ Applied Survival Analysis},
Title = {{TRAINING RECURRENT NEURAL NETWORKS}},
Author = {{Ilya Sutskever}}
}
+
+@misc{medium_knn,
+  Url = {https://medium.com/@adi.bronshtein/a-quick-introduction-to-k-nearest-neighbors-algorithm-62214cea29c7},
+ Title = {{A quick introduction to k nearest neighbors algorithm}},
+}
+
+@misc{point_knn,
+ Url = {http://pointclouds.org/documentation/tutorials/kdtree_search.php},
+ Title = {{How to use a KdTree to search}},
+}
+
diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in
index 5568da6..c8b308d 100644
--- a/doc/mainpage.dox.in
+++ b/doc/mainpage.dox.in
@@ -191,6 +191,7 @@ complete matrix stored as a distributed table.
@details Methods to perform a variety of supervised learning tasks.
@{
@defgroup grp_crf Conditional Random Field
+ @defgroup grp_knn k-Nearest Neighbors
@defgroup grp_nn Neural Network
@defgroup grp_regml Regression Models
    @brief A collection of methods for modeling conditional expectation of a response variable.
@@ -291,7 +292,6 @@ Interface and implementation are subject to change.
@{
        @defgroup grp_minibatch_preprocessing_dl Mini-Batch Preprocessor for Deep Learning
@}
- @defgroup grp_knn k-Nearest Neighbors
@defgroup grp_bayes Naive Bayes Classification
@defgroup grp_sample Random Sampling
@}
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 4db7ac1..bf64352 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -27,27 +27,38 @@
"""
import plpy
-from utilities.validate_args import input_tbl_valid, output_tbl_valid
-from utilities.validate_args import cols_in_tbl_valid
-from utilities.validate_args import is_col_array
-from utilities.validate_args import array_col_has_no_null
-from utilities.validate_args import get_expr_type
+import copy
+from collections import defaultdict
+from math import log
+from utilities.control import MinWarning
+from utilities.utilities import INTEGER
from utilities.utilities import _assert
+from utilities.utilities import add_postfix
+from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import py_list_to_sql_string
from utilities.utilities import unique_string
-from utilities.control import MinWarning
-from utilities.validate_args import quote_ident
-from utilities.validate_args import is_var_valid
from utilities.utilities import NUMERIC, ONLY_ARRAY
from utilities.utilities import is_valid_psql_type
from utilities.utilities import is_pg_major_version_less_than
+from utilities.utilities import num_features
+from utilities.validate_args import array_col_has_no_null
+from utilities.validate_args import cols_in_tbl_valid
+from utilities.validate_args import drop_tables
+from utilities.validate_args import get_cols
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import input_tbl_valid, output_tbl_valid
+from utilities.validate_args import is_col_array
+from utilities.validate_args import is_var_valid
+from utilities.validate_args import quote_ident
-MAX_WEIGHT_ZERO_DIST = 1e6
-
+WEIGHT_FOR_ZERO_DIST = 1e6
+BRUTE_FORCE = 'brute_force'
+KD_TREE = 'kd_tree'
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name,
test_id, output_table, k, output_neighbors, fn_dist,
- **kwargs):
+ is_brute_force, depth, leaf_nodes, **kwargs):
input_tbl_valid(point_source, 'kNN')
input_tbl_valid(test_source, 'kNN')
output_tbl_valid(output_table, 'kNN')
@@ -60,14 +71,16 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
_assert(is_var_valid(point_source, point_column_name),
- "kNN error: {0} is an invalid column name or expression for
point_column_name param".format(point_column_name))
+ "kNN error: {0} is an invalid column name or "
+ "expression for point_column_name param".format(point_column_name))
point_col_type = get_expr_type(point_column_name, point_source)
_assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
"kNN Error: Feature column or expression '{0}' in train table is
not"
" an array.".format(point_column_name))
_assert(is_var_valid(test_source, test_column_name),
- "kNN error: {0} is an invalid column name or expression for
test_column_name param".format(test_column_name))
+ "kNN error: {0} is an invalid column name or expression for "
+ "test_column_name param".format(test_column_name))
test_col_type = get_expr_type(test_column_name, test_source)
_assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY),
"kNN Error: Feature column or expression '{0}' in test table is
not"
@@ -101,7 +114,7 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
format(col_type, point_source))
col_type_test = get_expr_type(test_id, test_source).lower()
- if col_type_test not in ['integer']:
+ if col_type_test not in INTEGER:
plpy.error("kNN Error: Invalid data type '{0}' for"
" test_id column in table '{1}'.".
format(col_type_test, test_source))
@@ -113,27 +126,325 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
'squared_dist_norm2', 'dist_angle',
'dist_tanimoto')])
-    profunc = ("proisagg = TRUE" if is_pg_major_version_less_than(schema_madlib, 11)
- else "prokind = 'a'")
+ profunc = ("proisagg = TRUE"
+ if is_pg_major_version_less_than(schema_madlib, 11)
+ else "prokind = 'a'")
is_invalid_func = plpy.execute("""
- SELECT prorettype != 'DOUBLE PRECISION'::regtype OR
- {profunc} AS OUTPUT
+        SELECT prorettype != 'DOUBLE PRECISION'::regtype OR {profunc} AS OUTPUT
        FROM pg_proc
        WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
""".format(fn_dist=fn_dist, profunc=profunc))[0]['output']
if is_invalid_func or (fn_dist not in dist_functions):
- plpy.error("KNN error: Distance function has invalid signature "
- "or is not a simple function.")
-
+        plpy.error("KNN error: Distance function ({0}) has invalid signature "
+ "or is not a simple function.".format(fn_dist))
+ if not is_brute_force:
+ if depth <= 0:
+ plpy.error("kNN Error: depth={0} is an invalid value, must be "
+ "greater than 0.".format(depth))
+ if leaf_nodes <= 0:
+            plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be "
+ "greater than 0.".format(leaf_nodes))
+ if pow(2, depth) <= leaf_nodes:
+ plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+ "The leaf_nodes value must be lower than 2^depth".
+ format(depth, leaf_nodes))
return k
# ------------------------------------------------------------------------------
+def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
+ depth, r_id, **kwargs):
+ """
+ KD-tree function to create a partitioning for KNN
+ Args:
+ @param schema_madlib Name of the Madlib Schema
+ @param source_table Training data table
+ @param output_table Name of the table to store kd tree
+ @param point_column_name Name of the column with training data
+ or expression that evaluates to a
+ numeric array
+ @param depth Depth of the kd tree
+ @param r_id Name of the region id column
+ """
+ with MinWarning("error"):
+
+ validate_kd_tree(source_table, output_table, point_column_name, depth)
+ n_features = num_features(source_table, point_column_name)
+
+ clauses = [' 1=1 ']
+ centers_table = add_postfix(output_table, "_centers")
+ clause_counter = 0
+ for curr_level in range(depth):
+ curr_feature = (curr_level % n_features) + 1
+ for curr_leaf in range(pow(2,curr_level)):
+ clause = clauses[clause_counter]
+ cutoff_sql = """
+ SELECT percentile_disc(0.5)
+ WITHIN GROUP (
+ ORDER BY ({point_column_name})[{curr_feature}]
+ ) AS cutoff
+ FROM {source_table}
+ WHERE {clause}
+ """.format(**locals())
+
+ cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
+ cutoff = "NULL" if cutoff is None else cutoff
+ clause_counter += 1
+ clauses.append(clause +
+                               "AND ({point_column_name})[{curr_feature}] < {cutoff} ".
+                               format(**locals()))
+                clauses.append(clause +
+                               "AND ({point_column_name})[{curr_feature}] >= {cutoff} ".
+ format(**locals()))
+
+ n_leaves = pow(2, depth)
+        case_when_clause = '\n'.join(["WHEN {0} THEN {1}::INTEGER".format(cond, i)
+                                      for i, cond in enumerate(clauses[-n_leaves:])])
+ output_sql = """
+ CREATE TABLE {output_table} AS
+ SELECT *,
+ CASE {case_when_clause} END AS {r_id}
+ FROM {source_table}
+ """.format(**locals())
+ plpy.execute(output_sql)
+ plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
+ centers_sql = """
+ CREATE TABLE {centers_table} AS
+ SELECT {r_id}, {schema_madlib}.array_scalar_mult(
+                    {schema_madlib}.sum({point_column_name})::DOUBLE PRECISION[],
+ (1.0/count(*))::DOUBLE PRECISION) AS __center__
+ FROM {output_table}
+ GROUP BY {r_id}
+ """.format(**locals())
+ plpy.execute(centers_sql)
+ return case_when_clause
+# ------------------------------------------------------------------------------
+
+
+def validate_kd_tree(source_table, output_table, point_column_name, depth):
+
+ input_tbl_valid(source_table, 'kd_tree')
+ output_tbl_valid(output_table, 'kd_tree')
+ output_tbl_valid(output_table+"_centers", 'kd_tree')
+
+ _assert(is_var_valid(source_table, point_column_name),
+ "kd_tree error: {0} is an invalid column name or expression for "
+ "point_column_name param".format(point_column_name))
+ point_col_type = get_expr_type(point_column_name, source_table)
+ _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in train table is not"
+ " an array.".format(point_column_name))
+ if depth <= 0:
+ plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+ "than 0.".format(depth))
+# ------------------------------------------------------------------------------
+
+
+def knn_kd_tree(schema_madlib, kd_out, test_source, test_column_name, test_id,
+ fn_dist, max_leaves_to_explore, depth, r_id, case_when_clause,
+ t_col_name, **kwargs):
+ """
+ KNN function to find the K Nearest neighbours using kd tree
+ Args:
+ @param schema_madlib Name of the Madlib Schema
+ @param kd_out Name of the kd tree table
+ @param test_source Name of the table containing the test
+ data points.
+ @param test_column_name Name of the column with testing data
+                                      points or expression that evaluates to
+                                      a numeric array
+ @param test_id Name of the column having ids of data
+ points in test data table.
+ @param fn_dist Distance metrics function.
+ @param max_leaves_to_explore Number of leaf nodes to explore
+ @param depth Depth of the kd tree
+ @param r_id Name of the region id column
+ @param case_when_clause SQL string for reconstructing the
+ kd-tree
+ @param t_col_name Unique test point column name
+ """
+ with MinWarning("error"):
+ centers_table = add_postfix(kd_out, "_centers")
+ test_view = add_postfix(kd_out, "_test_view")
+
+ n_leaves = pow(2,depth)
+ plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals()))
+ test_view_sql = """
+ CREATE VIEW {test_view} AS
+ SELECT {test_id},
+                   ({test_column_name})::DOUBLE PRECISION[] AS {t_col_name},
+ CASE
+ {case_when_clause}
+ END AS {r_id}
+ FROM {test_source}""".format(**locals())
+ plpy.execute(test_view_sql)
+
+ if max_leaves_to_explore > 1:
+ ext_test_view = add_postfix(kd_out, "_ext_test_view")
+ ext_test_view_sql = """
+ CREATE VIEW {ext_test_view} AS
+ SELECT * FROM(
+ SELECT
+ row_number() OVER (PARTITION BY {test_id}
+ ORDER BY __dist_center__) AS r,
+ {test_id},
+ {t_col_name},
+ {r_id}
+ FROM (
+ SELECT
+ {test_id},
+ {t_col_name},
+ {centers_table}.{r_id} AS {r_id},
+                            {fn_dist}({t_col_name}, __center__) AS __dist_center__
+ FROM {test_view}, {centers_table}
+ ) q1
+ ) q2
+ WHERE r <= {max_leaves_to_explore}
+ """.format(**locals())
+ plpy.execute(ext_test_view_sql)
+ else:
+ ext_test_view = test_view
+
+ return ext_test_view
+# ------------------------------------------------------------------------------
+
+def _create_interim_tbl(schema_madlib, point_source, point_column_name, point_id,
+ label_name, test_source, test_column_name, test_id, interim_table, k,
+ fn_dist, test_id_temp, train_id, dist_inverse, comma_label_out_alias,
+ label_out, r_id, kd_out, train, t_col_name, **kwargs):
+ """
+ KNN function to create the interim table
+ Args:
+ @param schema_madlib Name of the Madlib Schema
+ @param point_source Training data table
+ @param point_column_name Name of the column with training data
+ or expression that evaluates to a
+ numeric array
+        @param point_id               Name of the column having ids of data
+                                      points in train data table.
+ @param label_name Name of the column with labels/values
+ of training data points.
+ @param test_source Name of the table containing the test
+ data points.
+ @param test_column_name Name of the column with testing data
+                                      points or expression that evaluates to
+                                      a numeric array
+ @param test_id Name of the column having ids of data
+ points in test data table.
+ @param interim_table Name of the table to store interim
+ results.
+ @param k default: 1. Number of nearest
+ neighbors to consider
+        @param fn_dist                Distance metrics function. Default is
+                                      squared_dist_norm2. Following functions
+                                      are supported:
+                                      dist_norm1, dist_norm2, squared_dist_norm2,
+                                      dist_angle, dist_tanimoto
+                                      Or user defined function with signature
+                                      DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
+                                      -> DOUBLE PRECISION
+ Following parameters are passed to ensure the interim table has
+ identical features in both implementations
+ @param test_id_temp
+ @param train_id
+ @param dist_inverse
+ @param comma_label_out_alias
+ @param label_out
+ @param r_id
+ @param kd_out
+ @param train
+ @param t_col_name
+ """
+ with MinWarning("error"):
+ # If r_id is None, we are using the brute force algorithm.
+ is_brute_force = not bool(r_id)
+ r_id = "NULL AS {0}".format(unique_string()) if not r_id else r_id
+
+ p_col_name = unique_string(desp='p_col_name')
+ x_temp_table = unique_string(desp='x_temp_table')
+ y_temp_table = unique_string(desp='y_temp_table')
+ test = unique_string(desp='test')
+ r = unique_string(desp='r')
+ dist = unique_string(desp='dist')
+
+ if not is_brute_force:
+ point_source = kd_out
+            where_condition = "{train}.{r_id} = {test}.{r_id} ".format(**locals())
+            select_sql = """ {train}.{r_id} AS tr_{r_id},
+                             {test}.{r_id} AS test_{r_id}, """.format(**locals())
+ t_col_cast = t_col_name
+ else:
+ where_condition = "1 = 1"
+ select_sql = ""
+            t_col_cast = "({test_column_name}) AS {t_col_name}".format(**locals())
+
+ plpy.execute("""
+ CREATE TABLE {interim_table} AS
+ SELECT *
+ FROM (
+ SELECT row_number() OVER
+                    (PARTITION BY {test_id_temp} ORDER BY {dist}) AS {r},
+ {test_id_temp},
+ {train_id},
+ CASE WHEN {dist} = 0.0 THEN {weight_for_zero_dist}
+ ELSE 1.0 / {dist}
+ END AS {dist_inverse}
+ {comma_label_out_alias}
+ FROM (
+ SELECT {select_sql}
+ {test}.{test_id} AS {test_id_temp},
+ {train}.{point_id} AS {train_id},
+ {fn_dist}({p_col_name}, {t_col_name}) AS {dist}
+ {label_out}
+ FROM
+ (
+ SELECT {point_id},
+ {r_id},
+ {point_column_name} AS {p_col_name}
+ {label_name}
+ FROM {point_source}
+ ) {train},
+ (
+ SELECT {test_id},
+ {t_col_cast},
+ {r_id}
+ FROM {test_source}
+ ) {test}
+ WHERE
+ {where_condition}
+ ) {x_temp_table}
+ ) {y_temp_table}
+ WHERE {y_temp_table}.{r} <= {k}
+ """.format(weight_for_zero_dist=WEIGHT_FOR_ZERO_DIST, **locals()))
+
+# ------------------------------------------------------------------------------
+
+def _get_algorithm_name(algorithm):
+ if not algorithm:
+ algorithm = BRUTE_FORCE
+ else:
+ supported_algorithms = [BRUTE_FORCE, KD_TREE]
+ try:
+ # allow user to specify a prefix substring of
+ # supported algorithms. This works because the supported
+ # algorithms have unique prefixes.
+ algorithm = next(x for x in supported_algorithms
+ if x.startswith(algorithm))
+        except StopIteration:
+            # next() raises StopIteration if no element is found
+            plpy.error("kNN Error: Invalid algorithm: "
+                       "{0}. Supported algorithms are ({1})"
+                       .format(algorithm, ','.join(sorted(supported_algorithms))))
+ return algorithm
+# ------------------------------------------------------------------------------
+
def knn(schema_madlib, point_source, point_column_name, point_id,
        label_column_name, test_source, test_column_name, test_id, output_table,
- k, output_neighbors, fn_dist, weighted_avg, **kwargs):
+        k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params,
+ **kwargs):
"""
KNN function to find the K Nearest neighbours
Args:
@@ -158,7 +469,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
results.
@param k default: 1. Number of nearest
neighbors to consider
- @output_neighbours Outputs the list of k-nearest neighbors
+    @param output_neighbors          Outputs the list of k-nearest neighbors
                                      that were used in the voting/averaging.
@param fn_dist Distance metrics function. Default is
squared_dist_norm2. Following functions
@@ -166,32 +477,52 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                      dist_norm1 , dist_norm2,squared_dist_norm2,
                                      dist_angle , dist_tanimoto
                                      Or user defined function with signature
-                                      DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
-    @param weighted_avg           Calculates the Regression or classication of k-NN using
+                                      DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
+                                      -> DOUBLE PRECISION
+    @param weighted_avg           Calculates the regression or
+                                  classification of k-NN using
                                   the weighted average method.
+    @param algorithm              The algorithm to use for knn
+    @param algorithm_params       The parameters for the kd-tree algorithm
"""
with MinWarning('warning'):
        output_neighbors = True if output_neighbors is None else output_neighbors
if k is None:
k = 1
+
+ algorithm = _get_algorithm_name(algorithm)
+
+ # Default values for depth and leaf nodes
+ depth = 3
+ max_leaves_to_explore = 2
+
+ if algorithm_params:
+ params_types = {'depth': int, 'leaf_nodes': int}
+ default_args = {'depth': 3, 'leaf_nodes': 2}
+ algorithm_params_dict = extract_keyvalue_params(algorithm_params,
+ params_types,
+ default_args)
+
+ depth = algorithm_params_dict['depth']
+ max_leaves_to_explore = algorithm_params_dict['leaf_nodes']
+
knn_validate_src(schema_madlib, point_source,
point_column_name, point_id, label_column_name,
test_source, test_column_name, test_id,
- output_table, k, output_neighbors, fn_dist)
+ output_table, k, output_neighbors, fn_dist,
+                         algorithm == BRUTE_FORCE, depth, max_leaves_to_explore)
+
+ n_features = num_features(test_source, test_column_name)
# Unique Strings
- x_temp_table = unique_string(desp='x_temp_table')
- y_temp_table = unique_string(desp='y_temp_table')
label_col_temp = unique_string(desp='label_col_temp')
test_id_temp = unique_string(desp='test_id_temp')
+
train = unique_string(desp='train')
- test = unique_string(desp='test')
- p_col_name = unique_string(desp='p_col_name')
- t_col_name = unique_string(desp='t_col_name')
- dist = unique_string(desp='dist')
train_id = unique_string(desp='train_id')
dist_inverse = unique_string(desp='dist_inverse')
- r = unique_string(desp='r')
+ dim = unique_string(desp='dim')
+ t_col_name = unique_string(desp='t_col_name')
if not fn_dist:
fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
@@ -206,6 +537,9 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
view_def = ""
view_join = ""
view_grp_by = ""
+ r_id = None
+ kd_output_table = None
+ test_data = None
if label_column_name:
label_column_type = get_expr_type(
@@ -240,11 +574,10 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
{test_id_temp},
{label_col_temp},
sum({dist_inverse}) data_sum
- FROM pg_temp.{interim_table}
+ FROM {interim_table}
GROUP BY {test_id_temp},
{label_col_temp}
) a
- -- GROUP BY {test_id_temp} , {label_col_temp}
)
""".format(**locals())
        # This join is needed to get the max value of prediction
@@ -276,44 +609,34 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
comma_label_out_alias = ""
label_name = ""
- # interim_table picks the 'k' nearest neighbors for each test point
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY "
"knn_temp.{dist_inverse} DESC) AS
k_nearest_neighbours ").format(**locals())
else:
knn_neighbors = ''
- plpy.execute("""
- CREATE TEMP TABLE {interim_table} AS
- SELECT * FROM (
- SELECT row_number() over
-                (partition by {test_id_temp} order by {dist}) AS {r},
- {test_id_temp},
- {train_id},
- CASE WHEN {dist} = 0.0 THEN {max_weight_zero_dist}
- ELSE 1.0 / {dist}
- END AS {dist_inverse}
- {comma_label_out_alias}
- FROM (
- SELECT {test}.{test_id} AS {test_id_temp},
- {train}.{point_id} as {train_id},
- {fn_dist}(
- {p_col_name},
- {t_col_name})
- AS {dist}
- {label_out}
- FROM
- (
-                   SELECT {point_id} , {point_column_name} as {p_col_name} {label_name} from {point_source}
- ) {train},
- (
-                   SELECT {test_id} ,{test_column_name} as {t_col_name} from {test_source}
- ) {test}
- ) {x_temp_table}
- ) {y_temp_table}
- WHERE {y_temp_table}.{r} <= {k}
- """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals()))
- sql = """
+ if 'kd_tree' in algorithm:
+ r_id = unique_string(desp='r_id')
+ kd_output_table = unique_string(desp='kd_tree')
+ case_when_clause = build_kd_tree(schema_madlib,
+ point_source,
+ kd_output_table,
+ point_column_name,
+ depth, r_id)
+            test_data = knn_kd_tree(schema_madlib, kd_output_table, test_source,
+ test_column_name, test_id, fn_dist,
+ max_leaves_to_explore, depth, r_id,
+ case_when_clause, t_col_name)
+ else:
+ test_data = test_source
+
+ # interim_table picks the 'k' nearest neighbors for each test point
+ _create_interim_tbl(schema_madlib, point_source, point_column_name,
+ point_id, label_name, test_data, test_column_name,
+ test_id, interim_table, k, fn_dist, test_id_temp,
+ train_id, dist_inverse, comma_label_out_alias,
+                            label_out, r_id, kd_output_table, train, t_col_name)
+ output_sql = """
CREATE TABLE {output_table} AS
{view_def}
SELECT
@@ -322,7 +645,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
{pred_out}
{knn_neighbors}
FROM
- pg_temp.{interim_table} AS knn_temp
+ {interim_table} AS knn_temp
JOIN
{test_source} AS knn_test
ON knn_temp.{test_id_temp} = knn_test.{test_id}
@@ -330,11 +653,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
GROUP BY knn_temp.{test_id_temp},
{test_column_name}
{view_grp_by}
- """
- plpy.execute(sql.format(**locals()))
- plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ drop_tables([interim_table])
+
+ if 'kd_tree' in algorithm:
+ centers_table = add_postfix(kd_output_table, "_centers")
+ test_view = add_postfix(kd_output_table, "_test_view")
+ ext_test_view = add_postfix(kd_output_table, "_ext_test_view")
+ plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(test_view))
+            plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(ext_test_view))
+ drop_tables([centers_table, kd_output_table])
return
-
+# ------------------------------------------------------------------------------
def knn_help(schema_madlib, message, **kwargs):
"""
@@ -366,7 +697,9 @@ SELECT {schema_madlib}.knn(
k, -- value of k. Default will go as 1
    output_neighbors -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
    fn_dist          -- The name of the function to use to calculate the distance from a data point to a centroid.
-    weighted_avg      Calculates the Regression or classication of k-NN using the weighted average method.
+    weighted_avg     -- Calculates the regression or classification of k-NN using the weighted average method.
+    algorithm        -- The algorithm to use for knn.
+    algorithm_params -- The parameters for the kd-tree algorithm.
);
-----------------------------------------------------------------------
@@ -397,6 +730,5 @@ of k nearest neighbors of the given testing example.
For an overview on usage, run:
SELECT {schema_madlib}.knn('usage');
"""
-
return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 6fe1672..0693e94 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -32,7 +32,6 @@
m4_include(`SQLCommon.m4')
-
/**
@addtogroup grp_knn
@@ -47,24 +46,22 @@ m4_include(`SQLCommon.m4')
</ul>
</div>
-@brief Finds k nearest data points to the given data point and outputs majority vote value of output classes for classification, and average value of target values for regression.
-
-\warning <em> This MADlib method is still in early stage development. There may be some
-issues that will be addressed in a future version. Interface and implementation
-are subject to change. </em>
+@brief Finds \f$k\f$ nearest data points to the given data point and outputs majority
+vote value of output classes for classification, or average value of target
+values for regression.
@anchor knn
-K-nearest neighbors is a method for finding the k closest points to a
-given data point in terms of a given metric. Its input consists of
-data points as features from testing examples, and it
-looks for k closest points in the training set for each of the data
-points in test set. The output of KNN depends on the type of task.
-For classification, the output is the majority vote of the classes of
-the k nearest data points. That is, the testing example gets assigned the
-most popular class from the nearest neighbors.
-For regression, the output is the average of the values of k nearest
-neighbors of the given test point.
+K-nearest neighbors is a method for finding the \f$k\f$ closest points to a given data
+point in terms of a given metric. Its input consists of data points as features
+from testing examples, and it looks for the \f$k\f$ closest points in the training set
+for each of the data points in the test set. The output of KNN depends on the type
+of task. For classification, the output is the majority vote of the classes of
+the \f$k\f$ nearest data points. For regression, the output is the average of the
+values of the \f$k\f$ nearest neighbors of the given test point.
+
+Both exact and approximate methods are supported. The approximate method can be
+used when the run-time of the exact method is too long.
@anchor usage
@par Usage
@@ -80,7 +77,9 @@ knn( point_source,
k,
output_neighbors,
fn_dist,
- weighted_avg
+ weighted_avg,
+ algorithm,
+ algorithm_params
)
</pre>
@@ -93,7 +92,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
</dd>
<dt>point_column_name</dt>
-<dd>TEXT. Name of the column with training data points
+<dd>TEXT. Name of the column with training data points
or expression that evaluates to a numeric array</dd>
<dt>point_id</dt>
@@ -116,7 +115,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
</dd>
<dt>test_column_name</dt>
-<dd>TEXT. Name of the column with testing data points
+<dd>TEXT. Name of the column with testing data points
or expression that evaluates to a numeric array</dd>
<dt>test_id</dt>
@@ -154,13 +153,38 @@ regression values using a weighted average. The idea is to
weigh the contribution of each of the k neighbors according
to their distance to the test point, giving greater influence
to closer neighbors. The distance function 'fn_dist' specified
-above is used.
+above is used. For classification, majority voting weighs a neighbor
+according to inverse distance. For regression, the inverse distance
+weighting approach from Shepard [4] is used.
+
+<dt>algorithm (optional)</dt>
+<dd>TEXT, default: 'brute_force'. The name of the algorithm
+used to compute nearest neighbors. The following options are supported:
+<ul>
+<li><b>\ref brute_force</b>: Produces an exact result by searching
+all points in the search space. You can also use a short
+form "b" or "brute" etc. to select brute force.</li>
+<li><b>\ref kd_tree</b>: Produces an approximate result by searching
+a subset of the search space, that is, only certain leaf nodes in the
+kd-tree as specified by "algorithm_params" below.
+You can also use a short
+form "k" or "kd" etc. to select kd-tree.</li></ul></dd>
+
+<dt>algorithm_params (optional)</dt>
+<dd>TEXT, default: 'depth=3, leaf_nodes=2'. These parameters apply to the
+kd-tree algorithm only.
+<ul>
+<li><b>\ref depth</b>: Depth of the kd-tree. Increasing this value will
+decrease run-time but reduce the accuracy.</li>
+<li><b>\ref leaf_nodes</b>: Number of leaf nodes (regions) to search for each
+test point. Increasing this value will improve the accuracy but increase
+run-time.</li></ul>
-For classification, majority voting weighs a neighbor
-according to inverse distance.
+@note
+Please note that the kd-tree accuracy will be lower for datasets with a high
+number of features. It is advised to use at least two leaf nodes.
+Refer to the <a href="#background">Technical Background</a> for more information
+on how the kd-tree is implemented.</dd>
-For regression, the inverse distance weighting approach is
-used from Shepard [4].
</dl>
@@ -234,7 +258,7 @@ INSERT INTO knn_train_data_reg VALUES
-# Prepare some testing data:
<pre class="example">
-DROP TABLE IF EXISTS knn_test_data;
+DROP TABLE IF EXISTS knn_test_data CASCADE;
CREATE TABLE knn_test_data (
id integer,
data integer[]
@@ -375,6 +399,73 @@ SELECT * FROM knn_result_classification ORDER BY id;
(6 rows)
</pre>
+-# Use the kd-tree option. First we build a kd-tree to depth 4, which has
+2^4 = 16 leaf nodes in total, and search half (8) of them:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification_kd;
+SELECT madlib.knn(
+ 'knn_train_data', -- Table of training data
+ 'data', -- Col name of training data
+ 'id', -- Col name of id in train data
+ NULL, -- Training labels
+ 'knn_test_data', -- Table of test data
+ 'data', -- Col name of test data
+ 'id', -- Col name of id in test data
+ 'knn_result_classification_kd', -- Output table
+ 3, -- Number of nearest neighbors
+    True,                        -- True to list nearest-neighbors by id
+ 'madlib.squared_dist_norm2', -- Distance function
+ False, -- For weighted average
+ 'kd_tree', -- Use kd-tree
+ 'depth=4, leaf_nodes=8' -- Kd-tree options
+ );
+SELECT * FROM knn_result_classification_kd ORDER BY id;
+</pre>
+<pre class="result">
+ id | data | k_nearest_neighbours
+----+---------+----------------------
+ 1 | {2,1} | {1,2,3}
+ 2 | {2,6} | {5,4,3}
+ 3 | {15,40} | {7,6,5}
+ 4 | {12,1} | {4,5,3}
+ 5 | {2,90} | {9,6,7}
+ 6 | {50,45} | {6,7,8}
+(6 rows)
+</pre>
+The result above is the same as brute force. If we search just 1 leaf node,
+run-time will be faster but accuracy will be lower. This shows up in this
+very small data set by not being able to find 3 nearest neighbors for all
+test points:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification_kd;
+SELECT madlib.knn(
+ 'knn_train_data', -- Table of training data
+ 'data', -- Col name of training data
+ 'id', -- Col name of id in train data
+ NULL, -- Training labels
+ 'knn_test_data', -- Table of test data
+ 'data', -- Col name of test data
+ 'id', -- Col name of id in test data
+ 'knn_result_classification_kd', -- Output table
+ 3, -- Number of nearest neighbors
+    True,                        -- True to list nearest-neighbors by id
+ 'madlib.squared_dist_norm2', -- Distance function
+ False, -- For weighted average
+ 'kd_tree', -- Use kd-tree
+ 'depth=4, leaf_nodes=1' -- Kd-tree options
+ );
+SELECT * FROM knn_result_classification_kd ORDER BY id;
+</pre>
+<pre class="result">
+ id | data | k_nearest_neighbours
+----+---------+----------------------
+ 1 | {2,1} | {1}
+ 2 | {2,6} | {3,2}
+ 3 | {15,40} | {7}
+ 5 | {2,90} | {3,2}
+ 6 | {50,45} | {6,8}
+(5 rows)
+</pre>
+
@anchor background
@par Technical Background
@@ -382,11 +473,37 @@ The training data points are vectors in a multidimensional feature space,
each with a class label. The training phase of the algorithm consists
only of storing the feature vectors and class labels of the training points.
-In the classification phase, k is a user-defined constant, and an unlabeled
-vector (a test point) is classified by assigning the label which is most
-frequent among the k training samples nearest to that test point.
-In case of regression, average of the values of these k training samples
-is assigned to the test point.
+In the prediction phase, \f$k\f$ is a user-defined constant, and an unlabeled vector
+(a test point) is predicted by using the labels of the \f$k\f$ training samples
+nearest to that test point.
+
+Since distances between points are used to find the nearest neighbors, the data
+should be standardized across features. This ensures that all features are given
+equal weight in the distance computation.
+
+An approximation method can be used to speed up the prediction phase by building
+appropriate data structures in the training phase. An example of such a data
+structure is the kd-tree [5]. Using the kd-tree algorithm can improve the execution
+time of the \f$k\f$-NN operation, but at the expense of sacrificing some accuracy. The
+kd-tree implementation divides the training dataset into multiple regions that
+correspond to the leaf nodes of a tree. For example, a tree of depth \f$3\f$ will have
+a total of \f$2^3 = 8\f$ regions. The algorithm will look for the nearest neighbors
+in a subset of all regions instead of searching the whole dataset. For a given
+test point, the first (home) region is found by traversing the tree and finding
+its associated node. If the user requests additional leaf nodes to be searched,
+we look at the distance between the point and the centroids of other regions and
+expand the search to the specified number of closest regions.
+
+It is important to note that the nodes at each level of the kd-tree split over
+a single feature, and the features are explored in the same order as they
+appear in the data.
+
+The kd-tree accuracy might suffer on datasets with a high number of features
+(dimensions). For example, let's say we are using a dataset with 20 features and
+a kd-tree depth of only 3. This means the kd-tree is constructed based on the
+first 3 features. Therefore, it is possible to miss nearest neighbors that are
+closer in the remaining 17 dimensions because they got assigned to a farther
+region (the distance computation still uses all 20 features).
@anchor literature
@literature
@@ -404,14 +521,20 @@ is assigned to the test point.
https://ai2-s2-pdfs.s3.amazonaws.com/a7e2/814ec5db800d2f8c4313fd436e9cf8273821.pdf
@anchor knn-lit-4
-[4] Shepard, Donald (1968). "A two-dimensional interpolation function for
+[4] Shepard, Donald (1968). "A two-dimensional interpolation function for
irregularly-spaced data". Proceedings of the 1968 ACM National Conference. pp. 517–524.
+@anchor knn-lit-5
+[5] Bentley, J. L. (1975). "Multidimensional binary search trees used for
+associative searching". Communications of the ACM. 18 (9): 509.
+doi:10.1145/361002.361007.
+
+
@internal
@sa namespace knn (documenting the implementation in Python)
@endinternal
*/
+
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
point_source VARCHAR,
point_column_name VARCHAR,
@@ -440,12 +563,61 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
k INTEGER,
output_neighbors BOOLEAN,
fn_dist TEXT,
- weighted_avg BOOLEAN
+ weighted_avg BOOLEAN,
+ algorithm VARCHAR,
+ algorithm_params VARCHAR
) RETURNS VARCHAR AS $$
PythonFunction(`knn', `knn', `knn')
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+ point_source VARCHAR,
+ point_column_name VARCHAR,
+ point_id VARCHAR,
+ label_column_name VARCHAR,
+ test_source VARCHAR,
+ test_column_name VARCHAR,
+ test_id VARCHAR,
+ output_table VARCHAR,
+ k INTEGER,
+ output_neighbors BOOLEAN,
+ fn_dist TEXT,
+ weighted_avg BOOLEAN,
+ algorithm VARCHAR
+) RETURNS VARCHAR AS $$
+ DECLARE
+ returnstring VARCHAR;
+BEGIN
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,
+                                     NULL);
+ RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+ point_source VARCHAR,
+ point_column_name VARCHAR,
+ point_id VARCHAR,
+ label_column_name VARCHAR,
+ test_source VARCHAR,
+ test_column_name VARCHAR,
+ test_id VARCHAR,
+ output_table VARCHAR,
+ k INTEGER,
+ output_neighbors BOOLEAN,
+ fn_dist TEXT,
+ weighted_avg BOOLEAN
+) RETURNS VARCHAR AS $$
+ DECLARE
+ returnstring VARCHAR;
+BEGIN
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,
+ NULL, NULL);
+ RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
point_source VARCHAR,
@@ -463,7 +635,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
DECLARE
returnstring VARCHAR;
BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11, FALSE);
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,
+ FALSE, NULL, NULL);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -486,7 +659,8 @@ DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,
-                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
+ 'MADLIB_SCHEMA.squared_dist_norm2', FALSE,
+ NULL, NULL);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -507,7 +681,8 @@ DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,
-                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
+ 'MADLIB_SCHEMA.squared_dist_norm2', FALSE,
+ NULL, NULL);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -527,7 +702,8 @@ DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,
- 'MADLIB_SCHEMA.squared_dist_norm2',FALSE);
+ 'MADLIB_SCHEMA.squared_dist_norm2',FALSE,
+ NULL, NULL);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -546,4 +722,3 @@ RETURNS VARCHAR AS $$
SELECT MADLIB_SCHEMA.knn('');
$$ LANGUAGE sql IMMUTABLE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
-
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 6dbed36..86f0eb4 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -8,7 +8,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
@@ -25,12 +25,12 @@ m4_include(`SQLCommon.m4')
*
* --------------------------------------------------------------------------
*/
-drop table if exists knn_train_data;
-create table knn_train_data (
+DROP TABLE if exists knn_train_data;
+create TABLE knn_train_data (
id integer,
data integer[],
label integer);
-copy knn_train_data (id, data, label) from stdin delimiter '|';
+copy knn_train_data (id, data, label) FROM stdin delimiter '|';
1|{1,1}|1
2|{2,2}|1
3|{3,3}|1
@@ -47,7 +47,7 @@ CREATE TABLE knn_train_data_reg (
data integer[],
label float
);
-COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
+COPY knn_train_data_reg (id, data, label) FROM stdin delimiter '|';
1|{1,1}|1.0
2|{2,2}|1.0
3|{3,3}|1.0
@@ -59,10 +59,10 @@ COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
9|{1,111}|0.0
\.
DROP TABLE IF EXISTS knn_test_data;
-create table knn_test_data (
+create TABLE knn_test_data (
id integer,
data integer[]);
-copy knn_test_data (id, data) from stdin delimiter '|';
+copy knn_test_data (id, data) FROM stdin delimiter '|';
1|{2,1}
2|{2,6}
3|{15,40}
@@ -70,13 +70,13 @@ copy knn_test_data (id, data) from stdin delimiter '|';
5|{2,90}
6|{50,45}
\.
-drop table if exists knn_train_data_expr;
-create table knn_train_data_expr (
+DROP TABLE if exists knn_train_data_expr;
+create TABLE knn_train_data_expr (
id integer,
-data1 integer,
+data1 integer,
data2 integer,
label integer);
-copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|';
+copy knn_train_data_expr (id, data1 , data2, label) FROM stdin delimiter '|';
1| 1 | 1 |1
2| 2 | 2 |1
3| 3 | 3 |1
@@ -90,77 +90,242 @@ copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|';
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
-select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_classification where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
+SELECT assert(array_agg(x ORDER BY id)= '{1,2,3}','Wrong output in classification with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id FROM madlib_knn_result_classification WHERE id = 1 ORDER BY x ASC) y;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
-select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') FROM madlib_knn_result_regression;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
-select assert(array_agg(x order by id)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_regression where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+SELECT assert(array_agg(x ORDER BY id)= '{1,2,3}' , 'Wrong output in regression with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id FROM madlib_knn_result_regression WHERE id = 1 ORDER BY x ASC) y;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_norm1');
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_norm1');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_angle');
-select assert(array_agg(prediction order by id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_angle');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_tanimoto');
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_tanimoto');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_norm1');
-select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_norm1');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') FROM madlib_knn_result_regression;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
-select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
+SELECT assert(array_agg(prediction ORDER BY id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') FROM madlib_knn_result_regression;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') FROM madlib_knn_result_regression;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') FROM madlib_knn_result_regression;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification') FROM madlib_knn_result_classification;
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3);
-select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_classification where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3);
+SELECT assert(array_agg(x ORDER BY id)= '{1,2,3}','Wrong output in
classification with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id
FROM madlib_knn_result_classification WHERE id = 1 ORDER BY x ASC) y;
-select knn();
-select knn('help');
+SELECT knn();
+SELECT knn('help');
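
Note on the NULL-label test above: when the label column argument is NULL,
knn() skips prediction and only reports the ids of the nearest training
points in a k_nearest_neighbours column. A minimal sketch of that call
outside the test harness, assuming MADlib is installed in a schema named
madlib (the output table name here is arbitrary):

    DROP TABLE IF EXISTS knn_neighbors_only;
    SELECT madlib.knn('knn_train_data', 'data', 'id', NULL,
                      'knn_test_data', 'data', 'id',
                      'knn_neighbors_only', 3);
    -- Each test point now lists the ids of its 3 nearest training points.
    SELECT id, k_nearest_neighbours FROM knn_neighbors_only ORDER BY id;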
+
+
+
+DROP TABLE if exists knn_train_data2;
+CREATE TABLE knn_train_data2 (
+ id integer,
+ data double precision[],
+ label integer
+);
+COPY knn_train_data2 (id, data, label) FROM stdin delimiter '|';
+1|{43983,164834}|0
+2|{491231,38953}|0
+3|{587484,467668}|0
+4|{882448,507209}|0
+5|{17326,595844}|0
+6|{236408,453230}|0
+7|{283929,237605}|0
+8|{392623,153808}|0
+9|{267864,179054}|0
+10|{428486,618138}|0
+11|{963752,141363}|0
+12|{980623,652584}|0
+13|{398411,894748}|0
+14|{559681,670919}|0
+15|{297984,171933}|0
+16|{254190,341966}|0
+17|{336766,745420}|0
+18|{380918,924250}|0
+19|{213087,263365}|0
+20|{431458,230413}|0
+21|{859208,667865}|0
+22|{683642,143136}|0
+23|{905470,76265}|0
+24|{296944,173333}|0
+25|{255319,725429}|0
+26|{791471,219070}|0
+27|{866791,772094}|0
+28|{871653,265202}|0
+29|{666841,431334}|0
+30|{936120,964824}|0
+31|{603267,190309}|0
+32|{306790,940033}|1
+33|{935729,687708}|1
+34|{864282,148815}|1
+35|{951072,295739}|1
+36|{379228,810280}|1
+37|{963604,62869}|1
+38|{953416,869073}|1
+39|{139133,250360}|1
+40|{42406,394452}|1
+41|{975789,833877}|1
+42|{613521,842579}|1
+43|{605970,485173}|1
+44|{107780,272810}|1
+45|{916507,43900}|1
+46|{237634,519773}|1
+47|{234208,544424}|1
+48|{459805,169937}|1
+49|{232131,324086}|1
+50|{318751,183202}|1
+51|{619825,697978}|1
+52|{993482,583428}|1
+53|{760847,946898}|1
+54|{452501,899980}|1
+55|{197257,494907}|1
+56|{294431,173045}|1
+57|{328783,907951}|1
+58|{15624,934752}|1
+59|{393124,123404}|1
+60|{207562,309630}|1
+61|{167303,445196}|1
+62|{829402,401511}|1
+63|{989619,289207}|1
+64|{571447,221749}|1
+65|{613292,890198}|1
+66|{404951,233116}|1
+67|{588176,398433}|1
+68|{816544,349023}|1
+69|{345330,269045}|1
+70|{249002,542587}|1
+71|{763951,543433}|1
+72|{715632,92734}|1
+73|{451384,731255}|1
+74|{27485,844507}|1
+75|{854659,235047}|1
+76|{154137,21962}|1
+77|{680243,983539}|1
+78|{423473,669861}|1
+79|{272745,994920}|1
+80|{891610,886037}|1
+81|{885117,296561}|1
+82|{119153,473293}|2
+83|{694994,935696}|2
+84|{822315,40323}|2
+85|{204741,71317}|2
+86|{582910,968691}|2
+87|{614749,298541}|2
+88|{61424,66132}|2
+89|{29796,88909}|2
+90|{910639,884455}|2
+91|{323956,64775}|2
+92|{906416,4198}|2
+93|{48314,329888}|2
+94|{674059,321058}|2
+95|{324807,565669}|2
+96|{207094,209924}|2
+97|{862229,326247}|2
+98|{683217,557222}|2
+99|{261943,505531}|2
+100|{597545,466683}|2
+\.
+
+
+CREATE TABLE knn_test_data2 (
+ id integer NOT NULL,
+ data integer[]
+);
+
+COPY knn_test_data2 (id, data) FROM stdin delimiter '|';
+1|{576848,180455}
+2|{435374,191597}
+3|{478996,496797}
+4|{257729,508791}
+5|{585706,168367}
+\.
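
A quick sanity check on the fixtures above after loading, assuming the COPY
statements succeeded (the expected counts follow directly from the data):

    SELECT label, count(*) AS n
    FROM knn_train_data2
    GROUP BY label
    ORDER BY label;
    -- expected: label 0 -> 31 rows, label 1 -> 50 rows, label 2 -> 19 rows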
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data2','data','id',NULL,'knn_test_data2','data','id',
+ 'madlib_knn_result_classification_kd',1,True,
+ 'MADLIB_SCHEMA.squared_dist_norm2',False,
+ 'kd_tree', 'depth=2, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data2','data','id','label','knn_test_data2','data','id',
+ 'madlib_knn_result_classification_kd',2,True,
+ 'MADLIB_SCHEMA.squared_dist_norm2',True,
+ 'kd_tree', 'depth=2, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data', 'data', 'id', NULL, 'knn_test_data', 'data', 'id',
+ 'madlib_knn_result_classification_kd', 2, True,
+ 'MADLIB_SCHEMA.squared_dist_norm2', False, 'kd_tree',
+ 'depth=2, leaf_nodes=1');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data', 'data', 'id', NULL, 'knn_test_data', 'data', 'id',
+ 'madlib_knn_result_classification_kd', 2, True,
+ 'MADLIB_SCHEMA.squared_dist_norm2', False, 'kd_tree',
+ 'depth=3, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
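
On the algorithm_params used in these tests: 'depth=2' grows the kd-tree to
2^2 = 4 leaf regions over the training points, and 'leaf_nodes=2' has the
search visit only the 2 regions closest to each test point, so true neighbors
that fall outside those regions can be missed. That is the accuracy-for-speed
trade-off the commit message describes. A usage sketch against the fixture
above, assuming MADlib is installed in a schema named madlib; the parameter
roles are inferred from the test calls, and the output table name is
arbitrary:

    DROP TABLE IF EXISTS knn_kd_demo;
    SELECT madlib.knn('knn_train_data2', 'data', 'id', 'label',
                      'knn_test_data2', 'data', 'id',
                      'knn_kd_demo',            -- output table
                      3,                        -- k
                      True,                     -- output_neighbors
                      'madlib.squared_dist_norm2',
                      False,                    -- weighted_avg
                      'kd_tree',                -- algorithm
                      'depth=2, leaf_nodes=2'); -- algorithm_params
    SELECT * FROM knn_kd_demo ORDER BY id;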
diff --git a/src/ports/postgres/modules/utilities/admin.py_in b/src/ports/postgres/modules/utilities/admin.py_in
index 2fa5e62..6f88fd9 100644
--- a/src/ports/postgres/modules/utilities/admin.py_in
+++ b/src/ports/postgres/modules/utilities/admin.py_in
@@ -11,6 +11,15 @@ def __get_madlib_temp_tables(target_schema):
""".format(**locals())
return plpy.execute(sql_get_tables_to_drop)
+def __get_madlib_temp_views(target_schema):
+    sql_get_views_to_drop = """
+ SELECT quote_ident(viewname) AS viewname
+ FROM pg_views
+ WHERE viewname LIKE E'%madlib\_temp%'
+ AND quote_ident(schemaname) = '{target_schema}'
+ """.format(**locals())
+    return plpy.execute(sql_get_views_to_drop)
+
# ------------------------------------------------------------------------------
def cleanup_madlib_temp_tables(schema_madlib, target_schema, **kwargs):
""" Drop all tables matching '%madlib_temp%' in the given schema
@@ -65,3 +74,16 @@ def cleanup_madlib_temp_tables_script(schema_madlib, target_schema, **kwargs):
sql_drop = "DROP TABLE {target_schema}.{tablename};".format(**locals())
sql_content += sql_drop + "\n"
return sql_content
+
+# ------------------------------------------------------------------------------
+def cleanup_madlib_temp_views(schema_madlib, target_schema, **kwargs):
+ to_drop_list = __get_madlib_temp_views(target_schema)
+ if len(to_drop_list) == 0:
+        plpy.info("No madlib temp views found in schema {target_schema}.".format(**locals()))
+ return None
+ sql_drop_all = 'DROP VIEW IF EXISTS '
+    sql_drop_all += ",".join(["{target_schema}.{viewname}".format(viewname=row['viewname'], **locals()) for row in to_drop_list])
+ sql_drop_all += " CASCADE;"
+ plpy.notice("Dropping {0} views ...".format(len(to_drop_list)))
+ plpy.execute(sql_drop_all)
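
The view cleanup helper above mirrors the existing table cleanup path. To
preview which views a call would drop before dropping anything, the same
catalog query can be run by hand; a sketch, with 'public' standing in for
whatever schema you want to inspect:

    SELECT quote_ident(viewname) AS viewname
    FROM pg_views
    WHERE viewname LIKE E'%madlib\_temp%'
      AND quote_ident(schemaname) = 'public';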
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 1b0069f..d2f14a5 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -432,7 +432,6 @@ def is_pg_major_version_less_than(schema_madlib, compare_version, **kwargs):
version = plpy.execute("select version()")[0]["version"]
regex = re.compile('PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
version = regex.findall(version)
- plpy.info("{0}".format(version))
if len(version) > 0 and int(version[0][0]) < compare_version:
return True
else:
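
For reference, the function above matches against the output of version(); a
sketch of the string being parsed (the exact wording varies by build):

    SELECT version();
    -- e.g. 'PostgreSQL 10.6 on x86_64-pc-linux-gnu, compiled by gcc ...'
    -- The regex captures '10' as the major version used in the comparison.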
diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in b/src/ports/postgres/modules/utilities/utilities.sql_in
index e598566..7035940 100644
--- a/src/ports/postgres/modules/utilities/utilities.sql_in
+++ b/src/ports/postgres/modules/utilities/utilities.sql_in
@@ -114,6 +114,14 @@ PythonFunction(utilities, admin, cleanup_madlib_temp_tables_script)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cleanup_madlib_temp_views(
+ target_schema text
+)
+RETURNS void AS $$
+PythonFunction(utilities, admin, cleanup_madlib_temp_views)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
/**
* @brief Return MADlib build information.
*
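
Once the wrapper above is installed, cleaning up leftover temp views is a
single call; a sketch, assuming MADlib is installed in a schema named madlib
and the temp views accumulated in 'public':

    SELECT madlib.cleanup_madlib_temp_views('public');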