This is an automated email from the ASF dual-hosted git repository.

fmcquillan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new e4b53a7  add examples for generalize cross validation
e4b53a7 is described below

commit e4b53a75f62e9e6688d611fc3bc029af26961b0f
Author: Frank McQuillan <fmcquil...@pivotal.io>
AuthorDate: Thu May 2 15:32:28 2019 -0700

    add examples for generalize cross validation
---
 .../modules/validation/cross_validation.sql_in     | 198 ++++++++++++++++++---
 1 file changed, 173 insertions(+), 25 deletions(-)

diff --git a/src/ports/postgres/modules/validation/cross_validation.sql_in b/src/ports/postgres/modules/validation/cross_validation.sql_in
index a5eeeff..77b2c2b 100644
--- a/src/ports/postgres/modules/validation/cross_validation.sql_in
+++ b/src/ports/postgres/modules/validation/cross_validation.sql_in
@@ -28,7 +28,8 @@ m4_include(`SQLCommon.m4')
 </ul>
 </div>
 
-Estimates the fit of a predictive model given a data set and specifications for the training, prediction, and error estimation functions.
+Estimates the fit of a predictive model given a data set and specifications for
+the training, prediction, and error estimation functions.
 
 Cross validation, sometimes called rotation estimation, is a technique for
 assessing how the results of a statistical analysis will generalize to an
@@ -56,12 +57,12 @@ output table. The prediction function should take a unique ID column name in
 the data table as one of the inputs, so that the prediction result can be
 compared with the validation values.
Note: Prediction functions in some MADlib modules do not save results into an output
-table. These prediction functions are not suitable for cross-validation.
+table. These prediction functions are not suitable for this cross-validation module.
 
 - The error metric function compares the prediction results with the known
 values of the dependent variables in the data set that was fed into the
 prediction function. It computes the error metric using the specified error
-metric function, storing the  results in a table.
+metric function, and stores the results in a table.
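+
+For illustration, a custom metric function can follow the same calling
+convention as the \e mse_error call in the examples below: it receives the
+prediction table, the data table, the ID column name, the dependent-variable
+column name, and the name of an output table to create. The sketch below is a
+hypothetical mean-absolute-error metric; it assumes the prediction table
+exposes a column named \e prediction, which may differ between modules.
+<pre class="example">
+CREATE OR REPLACE FUNCTION my_mae_error(
+    pred_tbl  VARCHAR,  -- predictions written by the prediction function
+    data_tbl  VARCHAR,  -- validation fold with the known dependent values
+    id_col    VARCHAR,  -- unique row ID used to join the two tables
+    dep_col   VARCHAR,  -- dependent-variable column in data_tbl
+    error_tbl VARCHAR)  -- output table to create with the error value
+RETURNS VOID AS $$
+BEGIN
+    EXECUTE 'CREATE TABLE ' || error_tbl || ' AS
+             SELECT avg(abs(p.prediction - d.' || dep_col || ')) AS mae
+             FROM ' || pred_tbl || ' p
+             JOIN ' || data_tbl || ' d ON p.' || id_col || ' = d.' || id_col;
+END;
+$$ LANGUAGE plpgsql;
+</pre>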
 
 Other inputs include the output table name, k value for the k-fold
 cross validation, and how many folds to try. For example, you can choose to run a
@@ -94,40 +95,54 @@ cross_validation_general( modelling_func,
 <dl class="arglist">
 <dt>modelling_func</dt>
 <dd>VARCHAR. The name of the function that trains the model.</dd>
+
 <dt>modelling_params</dt>
 <dd>VARCHAR[]. An array of parameters to supply to the modelling function.</dd>
+
 <dt>modelling_params_type</dt>
 <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the modelling function.</dd>
+
 <dt>param_explored</dt>
 <dd>VARCHAR. The name of the parameter that will be checked to find the optimum value. The name must appear in the \e modelling_params array.</dd>
+
 <dt>explore_values</dt>
 <dd>VARCHAR[]. An array of values of the \e param_explored parameter to evaluate during cross validation.</dd>
+
 <dt>predict_func</dt>
 <dd>VARCHAR. The name of the prediction function.</dd>
+
 <dt>predict_params</dt>
 <dd>VARCHAR[]. An array of parameters to supply to the prediction function.</dd>
+
 <dt>predict_params_type</dt>
 <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the prediction function.</dd>
+
 <dt>metric_func</dt>
 <dd>VARCHAR. The name of the function for measuring errors.</dd>
+
 <dt>metric_params</dt>
 <dd>VARCHAR[]. An array of parameters to supply to the error metric function.</dd>
+
 <dt>metric_params_type</dt>
 <dd>VARCHAR[]. An array of data type names for each of the parameters supplied to the metric function.</dd>
+
 <dt>data_tbl</dt>
 <dd>VARCHAR. The name of the data table that will be split into training and validation parts.</dd>
+
 <dt>data_id</dt>
 <dd>VARCHAR. The name of the column containing a unique ID associated with
 each row, or NULL if the table has no such column.
 
-Ideally, the data set has a unique ID for each row, so that it is easier to
+Ideally, the data set has a unique ID for each row so that it is easier to
 partition the data set into the training part and the validation part. Set the
 \e id_is_random argument to inform the cross-validation function whether
 the ID value is randomly assigned to each row. If it is not randomly
 assigned, the cross-validation function generates a random ID for each row.
 </dd>
+
 <dt>id_is_random</dt>
 <dd>BOOLEAN. TRUE if the provided ID is randomly assigned to each row.</dd>
+
 <dt>validation_result</dt>
 <dd>VARCHAR. The name of the table to store the output of the cross-validation function. The output table has the following columns:
 <table class="output">
@@ -146,6 +161,7 @@ same name specified in the \e param_explored argument of the \e cross_validation
 </tr>
 </table>
 </dd>
+
 <dt>data_cols</dt>
 <dd>A comma-separated list of names of data columns to use in the calculation.
 When its value is NULL, the function will automatically figure out all the column names of the data table.
@@ -183,42 +199,174 @@ The parameter arrays for the modelling, prediction and metric functions can incl
 @anchor examples
 @examp
 
-This example uses cross validation with an elastic net regression to find the best value of the regularization parameter.
-
--# Populate the table \c cvtest with 101 dimensional independent variables \c val, and dependent variable \c dep.
+-# Load some sample data:
+<pre class="example">
+DROP TABLE IF EXISTS houses;
+CREATE TABLE houses ( id INT,
+                      tax INT,
+                      bedroom INT,
+                      bath FLOAT,
+                      size INT,
+                      lot INT,
+                      zipcode INT,
+                      price INT,
+                      high_priced BOOLEAN
+                      );
+INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode, high_priced) VALUES
+(1  ,  590 ,       2 ,    1 ,  50000 ,  770 , 22100  , 94301, 'f'::boolean),
+(2  , 1050 ,       3 ,    2 ,  85000 , 1410 , 12000  , 94301, 'f'::boolean),
+(3  ,   20 ,       3 ,    1 ,  22500 , 1060 ,  3500  , 94301, 'f'::boolean),
+(4  ,  870 ,       2 ,    2 ,  90000 , 1300 , 17500  , 94301, 'f'::boolean),
+(5  , 1320 ,       3 ,    2 , 133000 , 1500 , 30000  , 94301, 't'::boolean),
+(6  , 1350 ,       2 ,    1 ,  90500 ,  820 , 25700  , 94301, 'f'::boolean),
+(7  , 2790 ,       3 ,  2.5 , 260000 , 2130 , 25000  , 94301, 't'::boolean),
+(8  ,  680 ,       2 ,    1 , 142500 , 1170 , 22000  , 94301, 't'::boolean),
+(9  , 1840 ,       3 ,    2 , 160000 , 1500 , 19000  , 94301, 't'::boolean),
+(10 , 3680 ,       4 ,    2 , 240000 , 2790 , 20000  , 94301, 't'::boolean),
+(11 , 1660 ,       3 ,    1 ,  87000 , 1030 , 17500  , 94301, 'f'::boolean),
+(12 , 1620 ,       3 ,    2 , 118600 , 1250 , 20000  , 94301, 't'::boolean),
+(13 , 3100 ,       3 ,    2 , 140000 , 1760 , 38000  , 94301, 't'::boolean),
+(14 , 2070 ,       2 ,    3 , 148000 , 1550 , 14000  , 94301, 't'::boolean),
+(15 ,  650 ,       3 ,  1.5 ,  65000 , 1450 , 12000  , 94301, 'f'::boolean),
+(16 ,  770 ,       2 ,    2 ,  91000 , 1300 , 17500  , 76010, 'f'::boolean),
+(17 , 1220 ,       3 ,    2 , 132300 , 1500 , 30000  , 76010, 't'::boolean),
+(18 , 1150 ,       2 ,    1 ,  91100 ,  820 , 25700  , 76010, 'f'::boolean),
+(19 , 2690 ,       3 ,  2.5 , 260011 , 2130 , 25000  , 76010, 't'::boolean),
+(20 ,  780 ,       2 ,    1 , 141800 , 1170 , 22000  , 76010, 't'::boolean),
+(21 , 1910 ,       3 ,    2 , 160900 , 1500 , 19000  , 76010, 't'::boolean),
+(22 , 3600 ,       4 ,    2 , 239000 , 2790 , 20000  , 76010, 't'::boolean),
+(23 , 1600 ,       3 ,    1 ,  81010 , 1030 , 17500  , 76010, 'f'::boolean),
+(24 , 1590 ,       3 ,    2 , 117910 , 1250 , 20000  , 76010, 'f'::boolean),
+(25 , 3200 ,       3 ,    2 , 141100 , 1760 , 38000  , 76010, 't'::boolean),
+(26 , 2270 ,       2 ,    3 , 148011 , 1550 , 14000  , 76010, 't'::boolean),
+(27 ,  750 ,       3 ,  1.5 ,  66000 , 1450 , 12000  , 76010, 'f'::boolean),
+(28 , 2690 ,       3 ,  2.5 , 260011 , 2130 , 25000  , 76010, 't'::boolean),
+(29 ,  780 ,       2 ,    1 , 141800 , 1170 , 22000  , 76010, 't'::boolean),
+(30 , 1910 ,       3 ,    2 , 160900 , 1500 , 19000  , 76010, 't'::boolean),
+(31 , 3600 ,       4 ,    2 , 239000 , 2790 , 20000  , 76010, 't'::boolean),
+(32 , 1600 ,       3 ,    1 ,  81010 , 1030 , 17500  , 76010, 'f'::boolean),
+(33 , 1590 ,       3 ,    2 , 117910 , 1250 , 20000  , 76010, 'f'::boolean),
+(34 , 3200 ,       3 ,    2 , 141100 , 1760 , 38000  , 76010, 't'::boolean),
+(35 , 2270 ,       2 ,    3 , 148011 , 1550 , 14000  , 76010, 't'::boolean),
+(36 ,  750 ,       3 ,  1.5 ,  66000 , 1450 , 12000  , 76010, 'f'::boolean);
+</pre>
 
--# Run the general cross-validation function.
+-# Use the general cross-validation function to explore lambda values
+for elastic net.  (Note that the elastic net module also has a
+built-in cross-validation function for selecting the elastic net
+control parameter alpha and the regularization value lambda.)
 <pre class="example">
-SELECT madlib.cross_validation_general
-    ( 'madlib.elastic_net_train',
-        '{\%data%, \%model%, dep, indep, gaussian, 1, lambda, TRUE, NULL, fista,
+DROP TABLE IF EXISTS houses_cv_results;
+SELECT madlib.cross_validation_general(
+    -- modelling_func
+      'madlib.elastic_net_train',
+    -- modelling_params
+        '{%%data%, %%model%, price, "array[tax, bath, size]", gaussian, 0.5, lambda, TRUE, NULL, fista,
           "{eta = 2, max_stepsize = 2, use_active_set = t}",
-          NULL, 2000, 1e-6}'::varchar[],
+          NULL, 10000, 1e-6}'::varchar[],
+    -- modelling_params_type
         '{varchar, varchar, varchar, varchar, varchar, double precision,
           double precision, boolean, varchar, varchar, varchar, varchar,
           integer, double precision}'::varchar[],
+    -- param_explored
       'lambda',
-      '{0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20,
-        0.22, 0.24, 0.26, 0.28, 0.30, 0.32, 0.34, 0.36}'::varchar[],
+    -- explore_values
+      '{0.1, 0.2}'::varchar[],
+    -- predict_func
       'madlib.elastic_net_predict',
-        '{\%model%, \%data%, \%id%, \%prediction%}'::varchar[],
+    -- predict_params
+        '{%%model%, %%data%, %%id%, %%prediction%}'::varchar[],
+    -- predict_params_type
         '{text, text, text, text}'::varchar[],
+    -- metric_func
       'madlib.mse_error',
-        '{\%prediction%, \%data%, \%id%, dep, \%error%}'::varchar[],
+    -- metric_params
+        '{%%prediction%, %%data%, %%id%, price, %%error%}'::varchar[],
+    -- metric_params_type
         '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
-      'cvtest',
-      NULL::varchar,
+    -- data_tbl
+      'houses',
+    -- data_id
+      'id',
+    -- id_is_random
       FALSE,
-      'valid_rst_tbl',
-      '{indep, dep}'::varchar[],
-      10
+    -- validation_result
+      'houses_cv_results',
+    -- data_cols
+      NULL,
+    -- fold_num
+      3
+);
+SELECT * FROM houses_cv_results;
+</pre>
+Results from the lambda values explored:
+<pre class="result">
+ lambda | mean_squared_error_avg | mean_squared_error_stddev
+--------+------------------------+---------------------------
+    0.1 |        1194685622.1604 |          366687470.779826
+    0.2 |       1181768409.98238 |          352203200.758414
+(2 rows)
+</pre>
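+Of the two values tried, lambda = 0.2 gives the lower mean squared error.
+As a sketch of a possible follow-up, the final model could be trained on the
+full data set with that value, mirroring the modelling parameters above (the
+output table name \e houses_en_model is hypothetical):
+<pre class="example">
+DROP TABLE IF EXISTS houses_en_model, houses_en_model_summary;
+SELECT madlib.elastic_net_train( 'houses',                  -- source table
+                                 'houses_en_model',         -- model output table (hypothetical name)
+                                 'price',                   -- dependent variable
+                                 'array[tax, bath, size]',  -- independent variables
+                                 'gaussian',                -- regression family
+                                 0.5,                       -- alpha, as above
+                                 0.2,                       -- lambda chosen by cross validation
+                                 TRUE,                      -- standardize
+                                 NULL,                      -- grouping columns
+                                 'fista',                   -- optimizer
+                                 'eta = 2, max_stepsize = 2, use_active_set = t',
+                                 NULL,                      -- excluded columns
+                                 10000,                     -- maximum iterations
+                                 1e-6                       -- tolerance
+                               );
+</pre>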
+
+-# Here we use the general function to explore the
+maximum number of iterations for logistic regression:
+<pre class="example">
+DROP TABLE IF EXISTS houses_logregr_cv;
+SELECT madlib.cross_validation_general(
+    -- modelling_func
+        'madlib.logregr_train',
+    -- modelling_params
+        '{%%data%, %%model%, high_priced, "ARRAY[1, bedroom, bath, size]", NULL, max_iter}'::varchar[],
+    -- modelling_params_type
+        '{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[],
+    -- param_explored
+        'max_iter',
+    -- explore_values
+        '{2, 10, 40, 100}'::varchar[],
+    -- predict_func
+        'madlib.cv_logregr_predict',
+    -- predict_params
+        '{%%model%, %%data%, "ARRAY[1, bedroom, bath, size]", id, %%prediction%}'::varchar[],
+    -- predict_params_type
+        '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
+    -- metric_func
+        'madlib.misclassification_avg',
+    -- metric_params
+        '{%%prediction%, %%data%, id, high_priced, %%error%}'::varchar[],
+    -- metric_params_type
+        '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
+    -- data_tbl
+        'houses',
+    -- data_id
+        'id',
+    -- id_is_random
+        FALSE,
+    -- validation_result
+        'houses_logregr_cv',
+    -- data_cols
+        NULL,
+    -- fold_num
+        5
 );
+SELECT * FROM houses_logregr_cv;
+</pre>
+Results from the \e max_iter values explored:
+<pre class="result">
+ max_iter |     error_rate_avg     |             error_rate_stddev
+----------+------------------------+--------------------------------------------
+        2 | 0.28214285714285714286 | 0.2053183193114972855562362870460638951565
+       10 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
+       40 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
+      100 | 0.25357142857142857143 | 0.1239753925550698688724699258837060065614
+(4 rows)
 </pre>
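+The misclassification rate does not improve beyond 10 iterations on this
+small data set. As a sketch of a possible follow-up, the final model could be
+trained with that setting, mirroring the modelling parameters above (the
+output table name \e houses_logregr_model is hypothetical):
+<pre class="example">
+DROP TABLE IF EXISTS houses_logregr_model, houses_logregr_model_summary;
+SELECT madlib.logregr_train( 'houses',                         -- source table
+                             'houses_logregr_model',           -- model output table (hypothetical name)
+                             'high_priced',                    -- dependent variable
+                             'ARRAY[1, bedroom, bath, size]',  -- independent variables
+                             NULL,                             -- grouping columns
+                             10                                -- max_iter chosen above
+                           );
+</pre>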
 
 @anchor notes
 @par Notes
 
-<em>max_locks_per_transaction</em>, which usually is set to the default value
+The lock management parameter <em>max_locks_per_transaction</em>,
+which is usually set to the default value
 of 64, limits the number of tables that can be dropped inside a single
 transaction (the cross-validation function). Thus, the number of different
 values of <em>param_explored</em> (or the length of the
@@ -231,8 +379,8 @@ One way to overcome this limitation is to run the cross-validation function
 multiple times, with each run covering a different region of values of the
 parameter.
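+
+For example, to cover four lambda values without exceeding the lock limit,
+the cross-validation call in the examples above could be issued twice, once
+with <em>explore_values</em> set to '{0.1, 0.2}'::varchar[] and once with
+'{0.3, 0.4}'::varchar[] (the second pair is an arbitrary illustration),
+writing each run to its own result table. The current lock limit can be
+checked with:
+<pre class="example">
+SHOW max_locks_per_transaction;
+</pre>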
 
-In the future, MADlib may implement cross-validation functions for each
-individual applicable module, where it is possible to optimize the calculation
+Note that MADlib implements cross-validation functions within certain
+individual modules, where it is possible to optimize the calculation
 to avoid dropping tables and prevent exceeding the \e
 max_locks_per_transaction limitation. Since module-specific cross-validation
 functions depend upon the implementation details of the modules to perform the
@@ -245,7 +393,7 @@ function provided here.
 One round of cross validation involves partitioning a sample of data into
 complementary subsets, performing the analysis on one subset (called the
 training set), and validating the analysis on the other subset (called the
-validation set or testing set). To reduce variability, multiple rounds of
+validation set or test set). To reduce variability, multiple rounds of
 cross validation are performed using different partitions, and the validation
 results are averaged over the rounds.
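+For example, with the 36-row \e houses table and 3 folds used in the examples
+above, each round trains on 24 rows and validates on the remaining 12.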
 
