Github user hpandeycodeit commented on a diff in the pull request:

https://github.com/apache/madlib/pull/225#discussion_r162684704

--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -211,23 +222,43 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                 ) {y_temp_table}
             WHERE {y_temp_table}.r <= {k_val}
             """.format(**locals()))
-
-        plpy.execute(
-            """
-            CREATE TABLE {output_table} AS
-                SELECT {test_id_temp} AS id, {test_column_name}
-                    {pred_out}
-                    {knn_neighbors}
-                FROM pg_temp.{interim_table} AS knn_temp
-                    JOIN
-                    {test_source} AS knn_test ON
-                    knn_temp.{test_id_temp} = knn_test.{test_id}
-                GROUP BY {test_id_temp} , {test_column_name}
-            """.format(**locals()))
-
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
+        if weighted_avg and is_classification:
+            plpy.execute(
+                """
+                CREATE TABLE {output_table} AS
+                    SELECT id, {test_column_name} ,max(prediction) as prediction
+                        {k_neighbours}
+                    FROM
+                    ( SELECT {test_id_temp} AS id, {test_column_name}
+                        {pred_out}
+                        {knn_neighbors}
+                    FROM pg_temp.{interim_table} AS knn_temp
+                        JOIN
+                        {test_source} AS knn_test ON
+                        knn_temp.{test_id_temp} = knn_test.{test_id}
+                    GROUP BY {test_id_temp} ,
+                        {test_column_name}, {label_col_temp})
+                    a {k_neighbours_unnest}
+                    GROUP BY id, {test_column_name}
+                """.format(**locals()))
+        else:
+            plpy.execute(
+                """
+                CREATE TABLE {output_table} AS
+                    SELECT {test_id_temp} AS id, {test_column_name}
--- End diff --

@rahiyer , I am changing the query like this:

```
WITH vw AS
  (SELECT test_id,
          max(data2) data_dist
   FROM
     (SELECT test_id,
             Sum(1 / dist) data2
      FROM madlib_temp
      GROUP BY test_id,
               label) a
   GROUP BY test_id)
SELECT test_vw.test_id,
       knn_test.data,
       test_vw.data_dist prediction,
       Array_agg(knn_temp.train_id order by dist asc) AS k_nearest_neighbours
FROM madlib_temp AS knn_temp
JOIN knn_test_data AS knn_test ON knn_temp.test_id = knn_test.id
JOIN vw AS test_vw ON knn_temp.test_id = test_vw.test_id
GROUP BY knn_temp.test_id,
         test_vw.test_id,
         test_vw.data_dist,
         data
ORDER BY test_vw.test_id;
```

Here the maximum of the per-label sums of inverse distances is calculated separately in a view (the `vw` CTE), which is then joined back to the existing data to populate the output. Let me know your thoughts on this.

Also, can you provide more details on the "add a wrapper if it's weighted_avg" comment?
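
To make the aggregation in the query easier to reason about, here is a minimal Python sketch of the same two steps: sum `1 / dist` per `(test_id, label)`, then keep the largest sum per `test_id`. The function name `weighted_vote` and the sample rows are hypothetical and are not part of the PR. The sketch assumes every distance is strictly positive. Unlike the query as written, which exposes only the winning weight as `prediction`, the sketch also returns the label that produced it.

```python
from collections import defaultdict


def weighted_vote(neighbours):
    """Pick the label with the largest summed inverse distance per test point.

    neighbours: iterable of (test_id, label, dist) rows, one per retained
    nearest neighbour. Assumes dist > 0 (hypothetical helper, not MADlib code).
    """
    # Inner aggregation: Sum(1 / dist) ... GROUP BY test_id, label
    weights = defaultdict(float)
    for test_id, label, dist in neighbours:
        weights[(test_id, label)] += 1.0 / dist

    # Outer aggregation: max(...) GROUP BY test_id, also remembering the label
    best = {}
    for (test_id, label), weight in weights.items():
        if test_id not in best or weight > best[test_id][1]:
            best[test_id] = (label, weight)
    return best


# Made-up rows standing in for the interim table of nearest neighbours
rows = [
    (1, "a", 1.0),
    (1, "b", 0.5),
    (1, "a", 2.0),
    (2, "b", 4.0),
    (2, "a", 1.0),
]
result = weighted_vote(rows)
```

For test point 1, label "a" accumulates 1/1.0 + 1/2.0 = 1.5 while label "b" gets 1/0.5 = 2.0, so the single closer "b" neighbour outweighs the two "a" neighbours.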
---