[ 
https://issues.apache.org/jira/browse/MADLIB-1222?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16426112#comment-16426112
 ] 

Frank McQuillan commented on MADLIB-1222:
-----------------------------------------

For minibatch this seems to work OK, e.g., continuing with a modified version 
of the user docs example from above:

{code:sql}
-- Remove the packed table and its companion tables left by any earlier run.
DROP TABLE IF EXISTS
    iris_data_packed,
    iris_data_packed_standardization,
    iris_data_packed_summary;

-- Pack iris_data into iris_data_packed for minibatch training.
SELECT madlib.minibatch_preprocessor(
    'iris_data', 'iris_data_packed', 'class_integer', 'attributes', 10
);
 {code}

{code:sql}
-- Start clean: drop the model table and its summary/standardization tables.
DROP TABLE IF EXISTS
    mlp_model,
    mlp_model_summary,
    mlp_model_standardization;

-- Fix the random seed so that repeated runs give the same results.
SELECT setseed(0);

-- Train the classifier on the packed (minibatch) table.
SELECT madlib.mlp_classification(
    'iris_data_packed',     -- source table
    'mlp_model',            -- destination table
    'independent_varname',  -- input features
    'dependent_varname',    -- label
    ARRAY[5],               -- number of units per layer
    'learning_rate_init=0.003,
    n_iterations=500,
    tolerance=0',           -- optimizer parameters
    'tanh',                 -- activation function
    NULL,                   -- weights: use the default (1)
    FALSE,                  -- warm start: off
    FALSE                   -- verbose: off
);
{code}

{code:sql}
-- Clear out predictions from any previous run.
DROP TABLE IF EXISTS mlp_prediction;

-- Predict on iris_data and store the predicted class for each row.
SELECT madlib.mlp_predict(
    'mlp_model',       -- model table
    'iris_data',       -- table to predict on
    'id',              -- id column of that table
    'mlp_prediction',  -- table that receives the predictions
    'response'         -- return classes rather than probabilities
);

-- Show each prediction next to the row it was made for.
SELECT *
FROM mlp_prediction
INNER JOIN iris_data USING (id)
ORDER BY id;
{code}

produces

{code}
 id | estimated_class_integer |    attributes     | class_integer | class |   
state   
----+-------------------------+-------------------+---------------+-------+-----------
  1 | {1,0}                   | {5.0,3.2,1.2,0.2} | {1,0}         |     1 | 
Alaska
  2 | {1,0}                   | {5.5,3.5,1.3,0.2} | {1,0}         |     1 | 
Alaska
  3 | {1,0}                   | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | 
Alaska
  4 | {1,0}                   | {4.4,3.0,1.3,0.2} | {1,0}         |     1 | 
Alaska
  5 | {1,0}                   | {5.1,3.4,1.5,0.2} | {1,0}         |     1 | 
Alaska
  6 | {1,0}                   | {5.0,3.5,1.3,0.3} | {1,0}         |     1 | 
Alaska
  7 | {1,0}                   | {4.5,2.3,1.3,0.3} | {1,0}         |     1 | 
Alaska
  8 | {1,0}                   | {4.4,3.2,1.3,0.2} | {1,0}         |     1 | 
Alaska
  9 | {1,0}                   | {5.0,3.5,1.6,0.6} | {1,0}         |     1 | 
Alaska
 10 | {1,0}                   | {5.1,3.8,1.9,0.4} | {1,0}         |     1 | 
Alaska
 11 | {1,0}                   | {4.8,3.0,1.4,0.3} | {1,0}         |     1 | 
Alaska
 12 | {1,0}                   | {5.1,3.8,1.6,0.2} | {1,0}         |     1 | 
Alaska
 13 | {0,1}                   | {5.7,2.8,4.5,1.3} | {0,1}         |     2 | 
Alaska
 14 | {0,1}                   | {6.3,3.3,4.7,1.6} | {0,1}         |     2 | 
Alaska
 15 | {0,1}                   | {4.9,2.4,3.3,1.0} | {0,1}         |     2 | 
Alaska
 16 | {0,1}                   | {6.6,2.9,4.6,1.3} | {0,1}         |     2 | 
Alaska
 17 | {0,1}                   | {5.2,2.7,3.9,1.4} | {0,1}         |     2 | 
Alaska
 18 | {0,1}                   | {5.0,2.0,3.5,1.0} | {0,1}         |     2 | 
Alaska
 19 | {0,1}                   | {5.9,3.0,4.2,1.5} | {0,1}         |     2 | 
Alaska
 20 | {0,1}                   | {6.0,2.2,4.0,1.0} | {0,1}         |     2 | 
Alaska
 21 | {0,1}                   | {6.1,2.9,4.7,1.4} | {0,1}         |     2 | 
Alaska
 22 | {0,1}                   | {5.6,2.9,3.6,1.3} | {0,1}         |     2 | 
Alaska
 23 | {0,1}                   | {6.7,3.1,4.4,1.4} | {0,1}         |     2 | 
Alaska
 24 | {0,1}                   | {5.6,3.0,4.5,1.5} | {0,1}         |     2 | 
Alaska
 25 | {0,1}                   | {5.8,2.7,4.1,1.0} | {0,1}         |     2 | 
Alaska
 26 | {0,1}                   | {6.2,2.2,4.5,1.5} | {0,1}         |     2 | 
Alaska
 27 | {0,1}                   | {5.6,2.5,3.9,1.1} | {0,1}         |     2 | 
Alaska
 28 | {1,0}                   | {5.0,3.4,1.5,0.2} | {1,0}         |     1 | 
Tennessee
 29 | {1,0}                   | {4.4,2.9,1.4,0.2} | {1,0}         |     1 | 
Tennessee
 30 | {1,0}                   | {4.9,3.1,1.5,0.1} | {1,0}         |     1 | 
Tennessee
 31 | {1,0}                   | {5.4,3.7,1.5,0.2} | {1,0}         |     1 | 
Tennessee
 32 | {1,0}                   | {4.8,3.4,1.6,0.2} | {1,0}         |     1 | 
Tennessee
 33 | {1,0}                   | {4.8,3.0,1.4,0.1} | {1,0}         |     1 | 
Tennessee
 34 | {1,0}                   | {4.3,3.0,1.1,0.1} | {1,0}         |     1 | 
Tennessee
 35 | {1,0}                   | {5.8,4.0,1.2,0.2} | {1,0}         |     1 | 
Tennessee
 36 | {1,0}                   | {5.7,4.4,1.5,0.4} | {1,0}         |     1 | 
Tennessee
 37 | {1,0}                   | {5.4,3.9,1.3,0.4} | {1,0}         |     1 | 
Tennessee
 38 | {0,1}                   | {6.0,2.9,4.5,1.5} | {0,1}         |     2 | 
Tennessee
 39 | {0,1}                   | {5.7,2.6,3.5,1.0} | {0,1}         |     2 | 
Tennessee
 40 | {0,1}                   | {5.5,2.4,3.8,1.1} | {0,1}         |     2 | 
Tennessee
 41 | {0,1}                   | {5.5,2.4,3.7,1.0} | {0,1}         |     2 | 
Tennessee
 42 | {0,1}                   | {5.8,2.7,3.9,1.2} | {0,1}         |     2 | 
Tennessee
 43 | {0,1}                   | {6.0,2.7,5.1,1.6} | {0,1}         |     2 | 
Tennessee
 44 | {0,1}                   | {5.4,3.0,4.5,1.5} | {0,1}         |     2 | 
Tennessee
 45 | {0,1}                   | {6.0,3.4,4.5,1.6} | {0,1}         |     2 | 
Tennessee
 46 | {0,1}                   | {6.7,3.1,4.7,1.5} | {0,1}         |     2 | 
Tennessee
 47 | {0,1}                   | {6.3,2.3,4.4,1.3} | {0,1}         |     2 | 
Tennessee
 48 | {0,1}                   | {5.6,3.0,4.1,1.3} | {0,1}         |     2 | 
Tennessee
 49 | {0,1}                   | {5.5,2.5,4.0,1.3} | {0,1}         |     2 | 
Tennessee
 50 | {0,1}                   | {5.5,2.6,4.4,1.2} | {0,1}         |     2 | 
Tennessee
 51 | {0,1}                   | {6.1,3.0,4.6,1.4} | {0,1}         |     2 | 
Tennessee
 52 | {0,1}                   | {5.8,2.6,4.0,1.2} | {0,1}         |     2 | 
Tennessee
(52 rows)
{code}

{code:sql}
-- Re-run the prediction, this time asking for per-class probabilities.
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
         'mlp_model',         -- Model table
         'iris_data',         -- Test data table
         'id',                -- Id column in test table
         'mlp_prediction',    -- Output table for predictions
         'prob'               -- Output probabilities, not classes
     );
-- Show each probability array next to the row it was computed for.
SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;
{code}

produces

{code}
 id |             estimated_prob             |    attributes     | 
class_integer | class |   state   
----+----------------------------------------+-------------------+---------------+-------+-----------
  1 | {0.930759252672095,0.069240747327905}  | {5.0,3.2,1.2,0.2} | {1,0}        
 |     1 | Alaska
  2 | {0.929395372110727,0.0706046278892731} | {5.5,3.5,1.3,0.2} | {1,0}        
 |     1 | Alaska
  3 | {0.92275296493747,0.0772470350625298}  | {4.9,3.1,1.5,0.1} | {1,0}        
 |     1 | Alaska
  4 | {0.92923853862346,0.0707614613765397}  | {4.4,3.0,1.3,0.2} | {1,0}        
 |     1 | Alaska
  5 | {0.930203943536138,0.0697960564638618} | {5.1,3.4,1.5,0.2} | {1,0}        
 |     1 | Alaska
  6 | {0.937097480813401,0.062902519186599}  | {5.0,3.5,1.3,0.3} | {1,0}        
 |     1 | Alaska
  7 | {0.809864020154205,0.190135979845795}  | {4.5,2.3,1.3,0.3} | {1,0}        
 |     1 | Alaska
  8 | {0.938492444302248,0.0615075556977523} | {4.4,3.2,1.3,0.2} | {1,0}        
 |     1 | Alaska
  9 | {0.909421618572682,0.090578381427318}  | {5.0,3.5,1.6,0.6} | {1,0}        
 |     1 | Alaska
 10 | {0.927170837453955,0.0728291625460452} | {5.1,3.8,1.9,0.4} | {1,0}        
 |     1 | Alaska
 11 | {0.907769148253907,0.0922308517460933} | {4.8,3.0,1.4,0.3} | {1,0}        
 |     1 | Alaska
 12 | {0.943518017066475,0.0564819829335253} | {5.1,3.8,1.6,0.2} | {1,0}        
 |     1 | Alaska
 13 | {0.0529094443610184,0.947090555638982} | {5.7,2.8,4.5,1.3} | {0,1}        
 |     2 | Alaska
 14 | {0.0529742392448023,0.947025760755198} | {6.3,3.3,4.7,1.6} | {0,1}        
 |     2 | Alaska
 15 | {0.154232916835593,0.845767083164407}  | {4.9,2.4,3.3,1.0} | {0,1}        
 |     2 | Alaska
 16 | {0.0432082742886866,0.956791725711313} | {6.6,2.9,4.6,1.3} | {0,1}        
 |     2 | Alaska
 17 | {0.0848279782808559,0.915172021719144} | {5.2,2.7,3.9,1.4} | {0,1}        
 |     2 | Alaska
 18 | {0.0757044751883623,0.924295524811638} | {5.0,2.0,3.5,1.0} | {0,1}        
 |     2 | Alaska
 19 | {0.0611931643454561,0.938806835654544} | {5.9,3.0,4.2,1.5} | {0,1}        
 |     2 | Alaska
 20 | {0.0449649419417731,0.955035058058227} | {6.0,2.2,4.0,1.0} | {0,1}        
 |     2 | Alaska
 21 | {0.0430757587622325,0.956924241237768} | {6.1,2.9,4.7,1.4} | {0,1}        
 |     2 | Alaska
 22 | {0.111330143272174,0.888669856727826}  | {5.6,2.9,3.6,1.3} | {0,1}        
 |     2 | Alaska
 23 | {0.0517875328297457,0.948212467170254} | {6.7,3.1,4.4,1.4} | {0,1}        
 |     2 | Alaska
 24 | {0.0610712779633371,0.938928722036663} | {5.6,3.0,4.5,1.5} | {0,1}        
 |     2 | Alaska
 25 | {0.0697058912971787,0.930294108702821} | {5.8,2.7,4.1,1.0} | {0,1}        
 |     2 | Alaska
 26 | {0.0300465449544714,0.969953455045529} | {6.2,2.2,4.5,1.5} | {0,1}        
 |     2 | Alaska
 27 | {0.0641965800166526,0.935803419983347} | {5.6,2.5,3.9,1.1} | {0,1}        
 |     2 | Alaska
 28 | {0.932679530162975,0.0673204698370254} | {5.0,3.4,1.5,0.2} | {1,0}        
 |     1 | Tennessee
 29 | {0.91913460018541,0.0808653998145895}  | {4.4,2.9,1.4,0.2} | {1,0}        
 |     1 | Tennessee
 30 | {0.92275296493747,0.0772470350625298}  | {4.9,3.1,1.5,0.1} | {1,0}        
 |     1 | Tennessee
 31 | {0.936685220371634,0.0633147796283663} | {5.4,3.7,1.5,0.2} | {1,0}        
 |     1 | Tennessee
 32 | {0.934032404740506,0.0659675952594938} | {4.8,3.4,1.6,0.2} | {1,0}        
 |     1 | Tennessee
 33 | {0.922226922202426,0.0777730777975738} | {4.8,3.0,1.4,0.1} | {1,0}        
 |     1 | Tennessee
 34 | {0.94042684548622,0.0595731545137805}  | {4.3,3.0,1.1,0.1} | {1,0}        
 |     1 | Tennessee
 35 | {0.943820498537346,0.0561795014626537} | {5.8,4.0,1.2,0.2} | {1,0}        
 |     1 | Tennessee
 36 | {0.942322282886469,0.0576777171135306} | {5.7,4.4,1.5,0.4} | {1,0}        
 |     1 | Tennessee
 37 | {0.938684928938641,0.0613150710613592} | {5.4,3.9,1.3,0.4} | {1,0}        
 |     1 | Tennessee
 38 | {0.0457032748591934,0.954296725140807} | {6.0,2.9,4.5,1.5} | {0,1}        
 |     2 | Tennessee
 39 | {0.0944184541813754,0.905581545818625} | {5.7,2.6,3.5,1.0} | {0,1}        
 |     2 | Tennessee
 40 | {0.0645243589381724,0.935475641061828} | {5.5,2.4,3.8,1.1} | {0,1}        
 |     2 | Tennessee
 41 | {0.0739865590316946,0.926013440968305} | {5.5,2.4,3.7,1.0} | {0,1}        
 |     2 | Tennessee
 42 | {0.0665047837634499,0.93349521623655}  | {5.8,2.7,3.9,1.2} | {0,1}        
 |     2 | Tennessee
 43 | {0.0315714539349891,0.968428546065011} | {6.0,2.7,5.1,1.6} | {0,1}        
 |     2 | Tennessee
 44 | {0.0700314082679038,0.929968591732096} | {5.4,3.0,4.5,1.5} | {0,1}        
 |     2 | Tennessee
 45 | {0.0769778072718228,0.923022192728177} | {6.0,3.4,4.5,1.6} | {0,1}        
 |     2 | Tennessee
 46 | {0.0432280654691233,0.956771934530877} | {6.7,3.1,4.7,1.5} | {0,1}        
 |     2 | Tennessee
 47 | {0.0340971244056786,0.965902875594321} | {6.3,2.3,4.4,1.3} | {0,1}        
 |     2 | Tennessee
 48 | {0.0911111408301112,0.908888859169889} | {5.6,3.0,4.1,1.3} | {0,1}        
 |     2 | Tennessee
 49 | {0.0556041814589958,0.944395818541004} | {5.5,2.5,4.0,1.3} | {0,1}        
 |     2 | Tennessee
 50 | {0.0542022706221145,0.945797729377886} | {5.5,2.6,4.4,1.2} | {0,1}        
 |     2 | Tennessee
 51 | {0.0490083738977815,0.950991626102219} | {6.1,3.0,4.6,1.4} | {0,1}        
 |     2 | Tennessee
 52 | {0.0567530672794966,0.943246932720503} | {5.8,2.6,4.0,1.2} | {0,1}        
 |     2 | Tennessee
(52 rows)
{code}


 

> Support already encoded arrays for dependent var in MLP classification
> ----------------------------------------------------------------------
>
>                 Key: MADLIB-1222
>                 URL: https://issues.apache.org/jira/browse/MADLIB-1222
>             Project: Apache MADlib
>          Issue Type: New Feature
>          Components: Module: Neural Networks
>            Reporter: Nandish Jayaram
>            Priority: Major
>             Fix For: v1.14
>
>
> MLP currently only supports scalar dependent variables for MLP 
> classification. If a user has already one-hot encoded categorical variables 
> the dependent variable will be an array, and hence unusable with 
> mlp_classification. This feature request is to allow the use of one-hot 
> encoded array for dependent vars in MLP classification.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to