Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/225#discussion_r163652456
  
    --- Diff: src/ports/postgres/modules/knn/knn.py_in ---
    @@ -178,11 +183,38 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                 if label_column_type in ['boolean', 'integer', 'text']:
                     is_classification = True
                     cast_to_int = '::INTEGER'
    +            if weighted_avg:
    +                pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(
    +                    label_col_temp=label_col_temp)
    +            else:
    +                pred_out = ", avg({label_col_temp})".format(
    +                    label_col_temp=label_col_temp)
     
    -            pred_out = ", avg({label_col_temp})".format(**locals())
                 if is_classification:
    -                pred_out = (", {schema_madlib}.mode({label_col_temp})"
    -                            ).format(**locals())
    +                if weighted_avg:
    +                    # This view is to calculate the max value of sum of the 1/distance grouped by label and Id.
    +                    # And this max value will be the prediction for the
    +                    # classification model.
    +                    view_def = ("   WITH vw "
    +                                "   AS (SELECT {test_id_temp}  ,"
    +                                "   max(data_sum) data_dist "
    +                                "   FROM   (SELECT {test_id_temp}, "
    +                                "   sum(1 / dist) data_sum"
    +                                "   FROM   pg_temp.{interim_table} "
    +                                "   GROUP  BY {test_id_temp}, "
    +                                "   {label_col_temp}) a "
    +                                "   GROUP  BY {test_id_temp} )").format(**locals())
    +                    # This join is needed to get the max value of predicion
    +                    # calculated above
    +                    view_join = (" JOIN vw AS knn_vw "
    +                                 "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format(
    +                        test_id_temp=test_id_temp)
    +                    view_grp_by = ", knn_vw.data_dist "
    +                    pred_out = ", knn_vw.data_dist"
    +                else:
    +                    pred_out = (", {schema_madlib}.mode({label_col_temp})"
    +                                ).format(**locals())
    +
    --- End diff --
    
    This SQL string could be written as a single triple-quoted `"""..."""` literal instead of concatenating multiple `"..."` fragments, which would make the query easier to read.


---

Reply via email to