GitHub user hpandeycodeit commented on a diff in the pull request:
https://github.com/apache/madlib/pull/204#discussion_r154183869
--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -89,20 +89,20 @@ def knn_validate_src(schema_madlib, point_source,
point_column_name, point_id,
" column '{1}' in table '{2}'.".
format(col_type_test, test_id, test_source))
- fn_dist = fn_dist.lower().strip()
- dist_functions = {
- schema_madlib + '.dist_norm1',
- schema_madlib + '.dist_norm2',
- schema_madlib + '.squared_dist_norm2',
- schema_madlib + '.dist_angle',
- schema_madlib + '.dist_tanimoto'}
-
- if plpy.execute("""select prorettype != 'DOUBLE PRECISION'::regtype
- OR proisagg = TRUE AS OUTPUT from pg_proc where
- oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE
PRECISION[])'::regprocedure;"""
- .format(**locals()))[0]['output'] or fn_dist not in
dist_functions:
- plpy.error(
- "KNN error: Distance function has wrong signature or is not a
simple function.")
+ if fn_dist and fn_dist is not None:
+ fn_dist = fn_dist.lower().strip()
+ dist_functions = set([schema_madlib + dist for dist in
+ ('.dist_norm1', '.dist_norm2',
'.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')])
+
+ is_invalid_func = plpy.execute(
+ """select prorettype != 'DOUBLE PRECISION'::regtype
+ OR proisagg = TRUE AS OUTPUT from pg_proc where
+ oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE
PRECISION[])'::regprocedure;
+ """.format(**locals()))[0]['output']
+
+ if is_invalid_func or fn_dist not in dist_functions:
--- End diff --
@orhankislal @njayaram2 @kaknikhil @fmcquillan99 Does the user input have to
be in the form `schema_name.fn_dist`? I used k-means as a reference, and there
`schema_name.fn_dist` is passed as the user input.
---