GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/incubator-madlib/pull/4#discussion_r45919749
--- Diff: src/ports/postgres/modules/svm/svm.py_in ---
@@ -66,211 +253,118 @@ def svm(schema_madlib, source_table, model_table,
"""
Executes the linear support vector classification algorithm.
"""
- # verbosing
+ # verbosing
verbosity_level = "info" if verbose else "error"
with MinWarning(verbosity_level):
- # validate input
- input_tbl_valid(source_table, 'SVM')
- _assert(is_var_valid(source_table, dependent_varname),
- "SVM error: invalid dependent_varname ('" +
str(dependent_varname) +
- "') for source_table (" + source_table + ")!")
- _assert(is_var_valid(source_table, independent_varname),
- "SVM error: invalid independent_varname ('" +
str(independent_varname) +
- "') for source_table (" + source_table + ")!")
-
- dep_type = get_expr_type(dependent_varname, source_table)
- if '[]' in dep_type:
- plpy.error("SVM error: dependent_varname cannot be of array
type!")
-
- # validate output tables
- output_tbl_valid(model_table, 'SVM')
- summary_table = add_postfix(model_table, "_summary")
- output_tbl_valid(summary_table, 'SVM')
-
- # arguments for iterating
- n_features = plpy.execute("SELECT array_upper({0}, 1) AS dim "
- "FROM {1} LIMIT 1".
- format(independent_varname, source_table)
- )[0]['dim']
- if grouping_col:
- grouping_list = [i + "::text"
- for i in explicit_bool_to_text(
- source_table,
- _string_to_array_with_quotes(grouping_col),
- schema_madlib)]
- grouping_str = ','.join(grouping_list)
- else:
- grouping_str = "Null"
- grouping_str1 = "" if not grouping_col else grouping_col + ","
- grouping_str2 = "1 = 1" if not grouping_col else grouping_col
-
- args = {
- 'rel_args': unique_string(desp='rel_args'),
- 'rel_state': unique_string(desp='rel_state'),
- 'col_grp_iteration':
unique_string(desp='col_grp_iteration'),
- 'col_grp_state': unique_string(desp='col_grp_state'),
- 'col_grp_key': unique_string(desp='col_grp_key'),
- 'col_n_tuples': unique_string(desp='col_n_tuples'),
- 'state_type': "double precision[]",
- 'rel_source': source_table,
- 'col_ind_var': independent_varname,
- 'col_dep_var': dependent_varname}
- args.update(locals())
- # variables defined above cannot be moved below this line
- # -------------------------------------------------------
-
- # other params
- kernel_func = 'linear' if not kernel_func else kernel_func.lower()
- # Add non-linear kernels below after implementing them.
- supported_kernels = ['linear']
- try:
- # allow user to specify a prefix substring of
- # supported kernel function names. This works because the
supported
- # kernel functions have unique prefixes.
- kernel_func = next(x for x in supported_kernels if
x.startswith(kernel_func))
- except StopIteration:
- # next() returns a StopIteration if no element found
- plpy.error("SVM Error: Invalid kernel function: {0}. "
- "Supported kernel functions are ({1})"
- .format(kernel_func,
','.join(sorted(supported_kernels))))
-
- if grouping_col:
- cols_in_tbl_valid(source_table,
_string_to_array_with_quotes(grouping_col), 'SVM')
- intersect =
frozenset(_string_to_array(grouping_col)).intersection(
- frozenset(
- ('coef', '__random_feature_data',
- '__random_feature_data', 'loss'
- 'num_rows_processed',
'num_rows_skipped',
- 'norm_of_gradient',
'num_iterations')))
- if len(intersect) > 0:
- plpy.error("SVM error: Conflicting grouping column name.\n"
- "Some predefined keyword(s) ({0}) are not
allowed!".format(
- ', '.join(intersect)))
-
- args.update(_extract_params(schema_madlib, params))
- args.update(_process_epsilon(is_svc, args))
-
- if not is_svc:
- # transform col_dep_var to binary (1 or -1) if classification
- args.update({
- 'col_dep_var_trans': dependent_varname,
- 'mapping': 'NULL',
- 'method': 'SVR'})
- else:
- # dependent variable mapping
- dep_labels=plpy.execute("""
- SELECT {dependent_varname} AS y
- FROM {source_table}
- WHERE ({dependent_varname}) IS NOT NULL
- GROUP BY ({dependent_varname})
- ORDER BY ({dependent_varname})""".format(**locals()))
- dep_var_mapping = ["'" + d['y'] + "'" if isinstance(d['y'],
basestring)
- else str(d['y']) for d in dep_labels]
- if len(dep_var_mapping) != 2:
- plpy.error("SVM error: Classification currently only
supports binary output")
-
- col_dep_var_trans = (
- """
- CASE WHEN ({col_dep_var}) IS NULL THEN NULL
- WHEN ({col_dep_var}) = {mapped_value_for_negative}
THEN -1.0
- ELSE 1.0
- END
- """
- .format(col_dep_var=dependent_varname,
- mapped_value_for_negative=dep_var_mapping[0])
- )
-
- args.update({
- 'mapped_value_for_negative': dep_var_mapping[0],
- 'col_dep_var_trans': col_dep_var_trans,
- 'mapping': dep_var_mapping[0] + "," + dep_var_mapping[1],
- 'method': 'SVC'})
-
- args['stepsize'] = args['init_stepsize']
- args['is_l2'] = True if args['norm'] == 'l2' else False
-
- # place holder for compatibility
- plpy.execute("CREATE TABLE pg_temp.{0} AS SELECT
1".format(args['rel_args']))
- # actual iterative algorithm computation
- n_iters_run = _compute_svm(args)
-
- # organizing results
- groupby_str = "GROUP BY {grouping_col},
{col_grp_key}".format(**args) if grouping_col else ""
- using_str = "USING ({col_grp_key})".format(**args) if grouping_col
else "ON TRUE"
- model_table_query = """
- CREATE TABLE {model_table} AS
- SELECT
- {grouping_str1}
- (result).coefficients AS coef,
- (result).loss AS loss,
- (result).norm_of_gradient AS norm_of_gradient,
- {n_iters_run} AS num_iterations,
- (result).num_rows_processed AS num_rows_processed,
- n_tuples_including_nulls - (result).num_rows_processed
- AS num_rows_skipped,
- NULL AS
__random_feature_data,
- ARRAY[{mapping}]::{dep_type}[] AS dep_var_mapping
- FROM
- (
- SELECT
- {schema_madlib}.internal_linear_svm_igd_result(
- {col_grp_state}
- ) AS result,
- {col_grp_key}
- FROM {rel_state}
- WHERE {col_grp_iteration} = {n_iters_run}
- ) rel_state_subq
- JOIN
- (
- SELECT
- {grouping_str1}
- count(*) AS n_tuples_including_nulls,
- array_to_string(ARRAY[{grouping_str}],
- ','
- ) AS {col_grp_key}
- FROM {source_table}
- {groupby_str}
- ) n_tuples_including_nulls_subq
- {using_str}
- """.format(n_iters_run=n_iters_run,
- groupby_str=groupby_str,
- using_str=using_str, **args)
- plpy.execute(model_table_query)
-
- if isinstance(args['lambda'], list):
- args['lambda_str'] = '{' + ','.join(str(e) for e in
args['lambda']) + '}'
- else:
- args['lambda_str'] = str(args['lambda'])
-
- plpy.execute("""
- CREATE TABLE {summary_table} AS
- SELECT
- '{method}'::text AS method,
- '__MADLIB_VERSION__'::text AS version_number,
- '{source_table}'::text AS source_table,
- '{model_table}'::text AS model_table,
- '{dependent_varname}'::text AS
dependent_varname,
- '{independent_varname}'::text AS
independent_varname,
- 'linear'::text AS kernel_func,
- NULL::text AS kernel_params,
- '{grouping_text}'::text AS grouping_col,
- 'init_stepsize={init_stepsize}, ' ||
- 'decay_factor={decay_factor}, ' ||
- 'max_iter={max_iter}, ' ||
- 'tolerance={tolerance}'::text AS optim_params,
- 'lambda={lambda_str}, ' ||
- 'norm={norm}, ' ||
- 'n_folds={n_folds}'::text AS reg_params,
- count(*)::integer AS num_all_groups,
- 0::integer AS
num_failed_groups,
- sum(num_rows_processed)::bigint AS
total_rows_processed,
- sum(num_rows_skipped)::bigint AS
total_rows_skipped,
- '{epsilon}'::double precision AS epsilon,
- '{eps_table}'::text AS eps_table
- FROM {model_table};
- """.format(grouping_text="NULL" if not grouping_col else
grouping_col,
- **args))
-#
------------------------------------------------------------------------------
+ _verify_table(source_table,
+ model_table,
+ dependent_varname,
+ independent_varname)
+ args = locals()
+ args['params_dict'] = _extract_params(schema_madlib, params)
+ _cross_validate_svm(args)
+ _svm_parsed_params(**args)
+
+
+def _cross_validate_svm(args):
+ # updating params_dict will also update
+ # also update args['params_dict']
+ params_dict = args['params_dict']
+
+ if params_dict['n_folds'] > 1 and args['grouping_col']:
+ plpy.error('SVM error: cross validation '
+ 'with grouping is not supported!')
+
+ # currently only support cross validation
+ # on lambda and epsilon
+ cv_params = {}
+ if len(params_dict['lambda']) > 1:
+ cv_params['lambda'] = params_dict['lambda']
+ else:
+ params_dict['lambda'] = params_dict['lambda'][0]
+ if len(params_dict['epsilon']) > 1 and not args['is_svc']:
+ cv_params['epsilon'] = params_dict['epsilon']
+ else:
+ params_dict['epsilon'] = params_dict['epsilon'][0]
+ if len(params_dict['init_stepsize']) > 1:
+ cv_params['init_stepsize'] = params_dict['init_stepsize']
+ else:
+ params_dict['init_stepsize'] = params_dict['init_stepsize'][0]
+ if len(params_dict['max_iter']) > 1:
+ cv_params['max_iter'] = params_dict['max_iter']
+ else:
+ params_dict['max_iter'] = params_dict['max_iter'][0]
+ if len(params_dict['decay_factor']) > 1:
+ cv_params['decay_factor'] = params_dict['decay_factor']
+ else:
+ params_dict['decay_factor'] = params_dict['decay_factor'][0]
+
+ if not cv_params and params_dict['n_folds'] <= 1:
+ return
+
+ if cv_params and params_dict['n_folds'] <= 1:
+ plpy.error("SVM Error: parameters must be a scalar "
+ "or of length 1 when n_folds is 0 or 1")
+ return
+
+ if not cv_params and params_dict['n_folds'] > 1:
+ plpy.warning('SVM Warning: no cross validate params provided! '
+ 'Ignore {}-folds cross validation request.'
+ .format(params_dict['n_folds']))
+ return
+
+ scorer = 'classification' if args['is_svc'] else 'regression'
+ sub_args = {'params_dict':cv_params}
+ cv = CrossValidator(_svm_parsed_params,svm_predict,scorer,args)
+ val_res = cv.validate(sub_args, params_dict['n_folds']).sorted()
+ val_res.output_tbl(params_dict['validation_result'])
+ params_dict.update(val_res.first('sub_args')['params_dict'])
+
+
+def _svm_parsed_params(schema_madlib, source_table, model_table,
+ dependent_varname, independent_varname, kernel_func,
+ kernel_params, grouping_col, params_dict, is_svc,
+ verbose, **kwargs):
+ """
+ Executes the linear support vector classification algorithm.
+ """
+ grouping_str = _verify_grouping(schema_madlib,
+ source_table,
+ grouping_col)
+
+ kernel_func = _verify_kernel(kernel_func)
+
+ # arguments for iterating
+ n_features = num_features(source_table,
+ independent_varname)
+
+ args = {
+ 'rel_args': unique_string(desp='rel_args'),
+ 'rel_state': unique_string(desp='rel_state'),
+ 'col_grp_iteration': unique_string(desp='col_grp_iteration'),
+ 'col_grp_state': unique_string(desp='col_grp_state'),
+ 'col_grp_key': unique_string(desp='col_grp_key'),
+ 'col_n_tuples': unique_string(desp='col_n_tuples'),
+ 'state_type': "double precision[]",
+ 'n_features': n_features,
+ 'verbose': verbose,
+ 'schema_madlib': schema_madlib,
+ 'grouping_str': grouping_str,
+ 'grouping_col': grouping_col,
+ 'rel_source': source_table,
+ 'col_ind_var': independent_varname,
+ 'col_dep_var': dependent_varname}
+
+ args.update(_verify_params_dict(params_dict))
+ args.update(_process_epsilon(is_svc, args))
+ args.update(_svc_or_svr(is_svc, source_table, dependent_varname))
+
+ # place holder for compatibility
+ plpy.execute("CREATE TABLE pg_temp.{0} AS SELECT
1".format(args['rel_args']))
+ # actual iterative algorithm computation
+ n_iters_run = _compute_svm(args)
+ _summary(n_iters_run, model_table, args)
+>>>>>>> b105d1c... SVM: Add cross validation support and generic
CrossValidator class
--- End diff --
No problem - this has happened to me before :)
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---