GitHub user hpandeycodeit commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/225#discussion_r162684704
  
    --- Diff: src/ports/postgres/modules/knn/knn.py_in ---
    @@ -211,23 +222,43 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                         ) {y_temp_table}
                 WHERE {y_temp_table}.r <= {k_val}
                 """.format(**locals()))
    -
    -        plpy.execute(
    -            """
    -            CREATE TABLE {output_table} AS
    -                SELECT {test_id_temp} AS id, {test_column_name}
    -                    {pred_out}
    -                    {knn_neighbors}
    -                FROM pg_temp.{interim_table} AS knn_temp
    -                    JOIN
    -                    {test_source} AS knn_test ON
    -                    knn_temp.{test_id_temp} = knn_test.{test_id}
    -                GROUP BY {test_id_temp} , {test_column_name}
    -            """.format(**locals()))
    -
    -        plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
    +        if weighted_avg and is_classification:
    +            plpy.execute(
    +                """
    +                CREATE TABLE {output_table} AS
    +                    SELECT id, {test_column_name} ,max(prediction) as prediction
    +                        {k_neighbours}
    +                    FROM
    +                        ( SELECT {test_id_temp} AS id, {test_column_name}
    +                                {pred_out}
    +                                {knn_neighbors}
    +                            FROM pg_temp.{interim_table} AS knn_temp
    +                                JOIN
    +                                {test_source} AS knn_test ON
    +                                knn_temp.{test_id_temp} = knn_test.{test_id}
    +                            GROUP BY {test_id_temp} ,
    +                                {test_column_name}, {label_col_temp})
    +                            a {k_neighbours_unnest}
    +                    GROUP BY id, {test_column_name}
    +                """.format(**locals()))
    +        else:
    +            plpy.execute(
    +                """
    +                CREATE TABLE {output_table} AS
    +                    SELECT {test_id_temp} AS id, {test_column_name}
    --- End diff --
    
    @rahiyer , 
    
    I am changing the query like this: 
    
    ```
    WITH vw 
         AS (SELECT test_id, 
                    max(data2) data_dist 
             FROM   (SELECT test_id, 
                            Sum(1 / dist) data2 
                     FROM   madlib_temp 
                     GROUP  BY test_id, 
                               label) a 
             GROUP  BY test_id) 
    SELECT test_vw.test_id, 
           knn_test.data, 
           test_vw.data_dist   prediction, 
           Array_agg(knn_temp.train_id order by dist asc) AS k_nearest_neighbours 
    FROM   madlib_temp AS knn_temp 
           JOIN knn_test_data AS knn_test 
             ON knn_temp.test_id = knn_test.id 
           JOIN vw AS test_vw 
             ON knn_temp.test_id = test_vw.test_id 
    GROUP  BY knn_temp.test_id, 
              test_vw.test_id, 
              test_vw.data_dist,  
              data 
    ORDER  BY test_vw.test_id;
    ```
    
    Here the per-label weight (the sum of 1 / dist) is computed for each test 
    point, and its maximum is taken separately in a CTE (`vw`). The CTE is then 
    joined back to the existing data to populate the prediction column. Let me 
    know your thoughts on this. 
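
For reference, the aggregation in the query above can be sketched in plain Python. This is only an illustration, not the code in the PR: the function name `weighted_vote` and the tuple layout are hypothetical, `dist` is assumed to be strictly positive, and returning the winning label alongside its weight is my assumption (the query as posted exposes only the weight as `prediction`).

```python
from collections import defaultdict


def weighted_vote(rows):
    """Mimic the query: rows are (test_id, train_id, label, dist), dist > 0."""
    weights = defaultdict(lambda: defaultdict(float))  # test_id -> label -> sum(1/dist)
    neighbours = defaultdict(list)                     # test_id -> [(dist, train_id)]
    for test_id, train_id, label, dist in rows:
        # Inner subquery: Sum(1 / dist) ... GROUP BY test_id, label
        weights[test_id][label] += 1.0 / dist
        neighbours[test_id].append((dist, train_id))

    result = {}
    for test_id, per_label in weights.items():
        # Outer level of the CTE: max(data2) ... GROUP BY test_id
        best_label, best_weight = max(per_label.items(), key=lambda kv: kv[1])
        # Array_agg(train_id ORDER BY dist ASC)
        ordered = [train_id for _, train_id in sorted(neighbours[test_id])]
        result[test_id] = (best_label, best_weight, ordered)
    return result


if __name__ == "__main__":
    sample = [(1, 10, "a", 1.0), (1, 11, "b", 0.5), (1, 12, "a", 2.0)]
    print(weighted_vote(sample))
```

In the sample, label "b" wins even though "a" has more neighbours, because its single neighbour is closer and therefore carries more weight.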
    
    Also, can you provide more details on the "add a wrapper if it's 
    weighted_avg" comment?

