GitHub user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/225#discussion_r163652456 --- Diff: src/ports/postgres/modules/knn/knn.py_in --- @@ -178,11 +183,38 @@ def knn(schema_madlib, point_source, point_column_name, point_id, if label_column_type in ['boolean', 'integer', 'text']: is_classification = True cast_to_int = '::INTEGER' + if weighted_avg: + pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format( + label_col_temp=label_col_temp) + else: + pred_out = ", avg({label_col_temp})".format( + label_col_temp=label_col_temp) - pred_out = ", avg({label_col_temp})".format(**locals()) if is_classification: - pred_out = (", {schema_madlib}.mode({label_col_temp})" - ).format(**locals()) + if weighted_avg: + # This view is to calculate the max value of sum of the 1/distance grouped by label and Id. + # And this max value will be the prediction for the + # classification model. + view_def = (" WITH vw " + " AS (SELECT {test_id_temp} ," + " max(data_sum) data_dist " + " FROM (SELECT {test_id_temp}, " + " sum(1 / dist) data_sum" + " FROM pg_temp.{interim_table} " + " GROUP BY {test_id_temp}, " + " {label_col_temp}) a " + " GROUP BY {test_id_temp} )").format(**locals()) + # This join is needed to get the max value of predicion + # calculated above + view_join = (" JOIN vw AS knn_vw " + "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format( + test_id_temp=test_id_temp) + view_grp_by = ", knn_vw.data_dist " + pred_out = ", knn_vw.data_dist" + else: + pred_out = (", {schema_madlib}.mode({label_col_temp})" + ).format(**locals()) + --- End diff -- We can have this string inside a single pair of `"""..."""`, instead of multiple `"..."`
---