Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/incubator-madlib/pull/80#discussion_r92721716
  
    --- Diff: src/ports/postgres/modules/knn/knn.sql_in ---
    @@ -0,0 +1,126 @@
    +/* ----------------------------------------------------------------------- 
*//**
    + *
    + * @file knn.sql_in
    + *
    + * @brief Set of functions for k-nearest neighbors.
    + *
    + *
    + *//* 
----------------------------------------------------------------------- */
    +
    +m4_include(`SQLCommon.m4')
    +
    +DROP TYPE IF EXISTS MADLIB_SCHEMA.knn_result CASCADE;  -- CASCADE also removes any objects that depend on the previous definition
    +CREATE TYPE MADLIB_SCHEMA.knn_result AS (  -- composite type; knn() declares its theResult variable with it
    +    prediction float  -- the only field of the composite
    +);
    +DROP TYPE IF EXISTS MADLIB_SCHEMA.test_table_spec CASCADE;  -- CASCADE also removes any objects that depend on the previous definition
    +CREATE TYPE MADLIB_SCHEMA.test_table_spec AS (  -- row shape knn() uses for its loop variable r
    +    id integer,  -- read by knn() as r.id
    +    vector DOUBLE PRECISION[]  -- read by knn() as r.vector
    +);
    +
    +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(  -- helper that knn() invokes once for test_source and once for point_source
    +rel_source VARCHAR  -- relation name handed over by knn(); presumably the table to validate (TODO confirm against the Python side)
    +) RETURNS VOID AS $$
    +    PythonFunction(knn, knn, knn_validate_src)
    +$$ LANGUAGE plpythonu
    +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');  -- the READS SQL DATA property is emitted only when the m4 flag is set
    +
    +
    +
    +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
    +    point_source VARCHAR,
    +    point_column_name VARCHAR,
    +    label_column_name VARCHAR,
    +    test_source VARCHAR,
    +    test_column_name VARCHAR,
    +    id_column_name VARCHAR,
    +    operation VARCHAR,
    +    k INTEGER
    +) RETURNS VARCHAR AS $$
    +DECLARE
    +    class_test_source REGCLASS;
    +    class_point_source REGCLASS;
    +    l FLOAT;
    +    id INTEGER;
    +    vector DOUBLE PRECISION[];
    +    cur_pid integer;
    +    theResult MADLIB_SCHEMA.knn_result;
    +    r MADLIB_SCHEMA.test_table_spec;
    +    oldClientMinMessages VARCHAR;
    +    returnstring VARCHAR;
    +BEGIN
    +    oldClientMinMessages :=
    +        (SELECT setting FROM pg_settings WHERE name = 
'client_min_messages');
    +    EXECUTE 'SET client_min_messages TO warning';
    +    PERFORM MADLIB_SCHEMA.__knn_validate_src(test_source);
    +    PERFORM MADLIB_SCHEMA.__knn_validate_src(point_source);
    +    class_test_source := test_source;
    +    class_point_source := point_source;
    +    --checks
    +    IF (k <= 0) THEN
    +        RAISE EXCEPTION 'KNN error: Number of neighbors k must be a 
positive integer.';
    +    END IF;
    +    IF (operation != 'c' AND operation != 'r') THEN
    +        RAISE EXCEPTION 'KNN error: put r for regression OR c for 
classification.';
    +    END IF;
    +    PERFORM MADLIB_SCHEMA.create_schema_pg_temp();
    +    EXECUTE
    +        $sql$
    +           DROP TABLE IF EXISTS pg_temp.knn_label;
    +           CREATE TABLE pg_temp.knn_label(pid integer, predlabel float);
    +   $sql$;
    +    
    +    --FOR r IN EXECUTE format('SELECT * FROM %I', test_source)
    +    FOR r IN EXECUTE format('SELECT %I,%I FROM %I', id_column_name, 
test_column_name, test_source)
    +    LOOP
    +        
    +   --RAISE NOTICE 'Original: %',r.pid;
    +   --RAISE NOTICE 'Original: %',r.p;
    +   cur_pid := r.id;
    +   vector := r.vector;
    +   EXECUTE
    +        $sql$
    +   DROP TABLE IF EXISTS pg_temp.knn_vector;
    +        CREATE TABLE pg_temp.knn_vector(vec DOUBLE PRECISION[]);
    +   $sql$;
    +   EXECUTE 'INSERT INTO pg_temp.knn_vector values($1)'
    +   USING vector;
    +   EXECUTE
    +        $sql$
    +   DROP TABLE IF EXISTS pg_temp.knn_interm;
    +        --CREATE TABLE pg_temp.knn_interm(dist DOUBLE PRECISION, lable 
integer );
    +   CREATE TABLE pg_temp.knn_interm AS
    +   
    +        SELECT madlib.squared_dist_norm2($sql$ || point_column_name || 
$sql$, vec) as dist, $sql$ || label_column_name || $sql$ FROM $sql$ || 
textin(regclassout(point_source)) || $sql$, knn_vector order by dist limit 
$sql$ || k;
    +   IF (operation = 'c') THEN
    +           EXECUTE
    +        $sql$
    +        SELECT mode() within group (order by $sql$ || label_column_name || 
$sql$) FROM  pg_temp.knn_interm $sql$
    +           INTO l;
    +        ELSE
    +        EXECUTE
    +        $sql$
    +        SELECT avg( $sql$ || label_column_name || $sql$ ) FROM  
pg_temp.knn_interm $sql$
    +        INTO l;
    +        END IF;
    +   EXECUTE 'INSERT INTO pg_temp.knn_label values($1,$2)'
    +   USING cur_pid, l;
    +    END LOOP;
    +    
    +    EXECUTE
    +        $sql$
    +   DROP TABLE IF EXISTS public.knn_final;
    +        CREATE TABLE public.knn_final AS
    --- End diff --
    
    Why do we need both `pg_temp.knn_label` and `public.knn_final`?
    Can't we just use the latter directly instead of knn_label?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to