[ https://issues.apache.org/jira/browse/MADLIB-1222?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16426112#comment-16426112 ]
Frank McQuillan commented on MADLIB-1222: ----------------------------------------- For minibatch this seems to work OK. e.g., continuing modified version user docs example from above: {code:sql} DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_standardization, iris_data_packed_summary; SELECT madlib.minibatch_preprocessor( 'iris_data', 'iris_data_packed', 'class_integer', 'attributes', 10 ); {code} {code:sql} DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization; -- Set seed so results are reproducible SELECT setseed(0); SELECT madlib.mlp_classification( 'iris_data_packed', -- Source table 'mlp_model', -- Destination table 'independent_varname', -- Input features 'dependent_varname', -- Label ARRAY[5], -- Number of units per layer 'learning_rate_init=0.003, n_iterations=500, tolerance=0', -- Optimizer params 'tanh', -- Activation function NULL, -- Default weight (1) FALSE, -- No warm start FALSE -- Not verbose ); {code} {code:sql} DROP TABLE IF EXISTS mlp_prediction; SELECT madlib.mlp_predict( 'mlp_model', -- Model table 'iris_data', -- Test data table 'id', -- Id column in test table 'mlp_prediction', -- Output table for predictions 'response' -- Output classes, not probabilities ); SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id; {code} produces {code} id | estimated_class_integer | attributes | class_integer | class | state ----+-------------------------+-------------------+---------------+-------+----------- 1 | {1,0} | {5.0,3.2,1.2,0.2} | {1,0} | 1 | Alaska 2 | {1,0} | {5.5,3.5,1.3,0.2} | {1,0} | 1 | Alaska 3 | {1,0} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Alaska 4 | {1,0} | {4.4,3.0,1.3,0.2} | {1,0} | 1 | Alaska 5 | {1,0} | {5.1,3.4,1.5,0.2} | {1,0} | 1 | Alaska 6 | {1,0} | {5.0,3.5,1.3,0.3} | {1,0} | 1 | Alaska 7 | {1,0} | {4.5,2.3,1.3,0.3} | {1,0} | 1 | Alaska 8 | {1,0} | {4.4,3.2,1.3,0.2} | {1,0} | 1 | Alaska 9 | {1,0} | {5.0,3.5,1.6,0.6} | {1,0} | 1 | Alaska 10 | {1,0} | {5.1,3.8,1.9,0.4} | {1,0} | 1 | Alaska 11 
| {1,0} | {4.8,3.0,1.4,0.3} | {1,0} | 1 | Alaska 12 | {1,0} | {5.1,3.8,1.6,0.2} | {1,0} | 1 | Alaska 13 | {0,1} | {5.7,2.8,4.5,1.3} | {0,1} | 2 | Alaska 14 | {0,1} | {6.3,3.3,4.7,1.6} | {0,1} | 2 | Alaska 15 | {0,1} | {4.9,2.4,3.3,1.0} | {0,1} | 2 | Alaska 16 | {0,1} | {6.6,2.9,4.6,1.3} | {0,1} | 2 | Alaska 17 | {0,1} | {5.2,2.7,3.9,1.4} | {0,1} | 2 | Alaska 18 | {0,1} | {5.0,2.0,3.5,1.0} | {0,1} | 2 | Alaska 19 | {0,1} | {5.9,3.0,4.2,1.5} | {0,1} | 2 | Alaska 20 | {0,1} | {6.0,2.2,4.0,1.0} | {0,1} | 2 | Alaska 21 | {0,1} | {6.1,2.9,4.7,1.4} | {0,1} | 2 | Alaska 22 | {0,1} | {5.6,2.9,3.6,1.3} | {0,1} | 2 | Alaska 23 | {0,1} | {6.7,3.1,4.4,1.4} | {0,1} | 2 | Alaska 24 | {0,1} | {5.6,3.0,4.5,1.5} | {0,1} | 2 | Alaska 25 | {0,1} | {5.8,2.7,4.1,1.0} | {0,1} | 2 | Alaska 26 | {0,1} | {6.2,2.2,4.5,1.5} | {0,1} | 2 | Alaska 27 | {0,1} | {5.6,2.5,3.9,1.1} | {0,1} | 2 | Alaska 28 | {1,0} | {5.0,3.4,1.5,0.2} | {1,0} | 1 | Tennessee 29 | {1,0} | {4.4,2.9,1.4,0.2} | {1,0} | 1 | Tennessee 30 | {1,0} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Tennessee 31 | {1,0} | {5.4,3.7,1.5,0.2} | {1,0} | 1 | Tennessee 32 | {1,0} | {4.8,3.4,1.6,0.2} | {1,0} | 1 | Tennessee 33 | {1,0} | {4.8,3.0,1.4,0.1} | {1,0} | 1 | Tennessee 34 | {1,0} | {4.3,3.0,1.1,0.1} | {1,0} | 1 | Tennessee 35 | {1,0} | {5.8,4.0,1.2,0.2} | {1,0} | 1 | Tennessee 36 | {1,0} | {5.7,4.4,1.5,0.4} | {1,0} | 1 | Tennessee 37 | {1,0} | {5.4,3.9,1.3,0.4} | {1,0} | 1 | Tennessee 38 | {0,1} | {6.0,2.9,4.5,1.5} | {0,1} | 2 | Tennessee 39 | {0,1} | {5.7,2.6,3.5,1.0} | {0,1} | 2 | Tennessee 40 | {0,1} | {5.5,2.4,3.8,1.1} | {0,1} | 2 | Tennessee 41 | {0,1} | {5.5,2.4,3.7,1.0} | {0,1} | 2 | Tennessee 42 | {0,1} | {5.8,2.7,3.9,1.2} | {0,1} | 2 | Tennessee 43 | {0,1} | {6.0,2.7,5.1,1.6} | {0,1} | 2 | Tennessee 44 | {0,1} | {5.4,3.0,4.5,1.5} | {0,1} | 2 | Tennessee 45 | {0,1} | {6.0,3.4,4.5,1.6} | {0,1} | 2 | Tennessee 46 | {0,1} | {6.7,3.1,4.7,1.5} | {0,1} | 2 | Tennessee 47 | {0,1} | {6.3,2.3,4.4,1.3} | {0,1} | 2 | Tennessee 48 | {0,1} | 
{5.6,3.0,4.1,1.3} | {0,1} | 2 | Tennessee 49 | {0,1} | {5.5,2.5,4.0,1.3} | {0,1} | 2 | Tennessee 50 | {0,1} | {5.5,2.6,4.4,1.2} | {0,1} | 2 | Tennessee 51 | {0,1} | {6.1,3.0,4.6,1.4} | {0,1} | 2 | Tennessee 52 | {0,1} | {5.8,2.6,4.0,1.2} | {0,1} | 2 | Tennessee (52 rows) {code} {code:sql} DROP TABLE IF EXISTS mlp_prediction; SELECT madlib.mlp_predict( 'mlp_model', -- Model table 'iris_data', -- Test data table 'id', -- Id column in test table 'mlp_prediction', -- Output table for predictions 'prob' -- Output probabilities, not classes ); SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id; {code} produces {code} id | estimated_prob | attributes | class_integer | class | state ----+----------------------------------------+-------------------+---------------+-------+----------- 1 | {0.930759252672095,0.069240747327905} | {5.0,3.2,1.2,0.2} | {1,0} | 1 | Alaska 2 | {0.929395372110727,0.0706046278892731} | {5.5,3.5,1.3,0.2} | {1,0} | 1 | Alaska 3 | {0.92275296493747,0.0772470350625298} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Alaska 4 | {0.92923853862346,0.0707614613765397} | {4.4,3.0,1.3,0.2} | {1,0} | 1 | Alaska 5 | {0.930203943536138,0.0697960564638618} | {5.1,3.4,1.5,0.2} | {1,0} | 1 | Alaska 6 | {0.937097480813401,0.062902519186599} | {5.0,3.5,1.3,0.3} | {1,0} | 1 | Alaska 7 | {0.809864020154205,0.190135979845795} | {4.5,2.3,1.3,0.3} | {1,0} | 1 | Alaska 8 | {0.938492444302248,0.0615075556977523} | {4.4,3.2,1.3,0.2} | {1,0} | 1 | Alaska 9 | {0.909421618572682,0.090578381427318} | {5.0,3.5,1.6,0.6} | {1,0} | 1 | Alaska 10 | {0.927170837453955,0.0728291625460452} | {5.1,3.8,1.9,0.4} | {1,0} | 1 | Alaska 11 | {0.907769148253907,0.0922308517460933} | {4.8,3.0,1.4,0.3} | {1,0} | 1 | Alaska 12 | {0.943518017066475,0.0564819829335253} | {5.1,3.8,1.6,0.2} | {1,0} | 1 | Alaska 13 | {0.0529094443610184,0.947090555638982} | {5.7,2.8,4.5,1.3} | {0,1} | 2 | Alaska 14 | {0.0529742392448023,0.947025760755198} | {6.3,3.3,4.7,1.6} | {0,1} | 2 | Alaska 15 | 
{0.154232916835593,0.845767083164407} | {4.9,2.4,3.3,1.0} | {0,1} | 2 | Alaska 16 | {0.0432082742886866,0.956791725711313} | {6.6,2.9,4.6,1.3} | {0,1} | 2 | Alaska 17 | {0.0848279782808559,0.915172021719144} | {5.2,2.7,3.9,1.4} | {0,1} | 2 | Alaska 18 | {0.0757044751883623,0.924295524811638} | {5.0,2.0,3.5,1.0} | {0,1} | 2 | Alaska 19 | {0.0611931643454561,0.938806835654544} | {5.9,3.0,4.2,1.5} | {0,1} | 2 | Alaska 20 | {0.0449649419417731,0.955035058058227} | {6.0,2.2,4.0,1.0} | {0,1} | 2 | Alaska 21 | {0.0430757587622325,0.956924241237768} | {6.1,2.9,4.7,1.4} | {0,1} | 2 | Alaska 22 | {0.111330143272174,0.888669856727826} | {5.6,2.9,3.6,1.3} | {0,1} | 2 | Alaska 23 | {0.0517875328297457,0.948212467170254} | {6.7,3.1,4.4,1.4} | {0,1} | 2 | Alaska 24 | {0.0610712779633371,0.938928722036663} | {5.6,3.0,4.5,1.5} | {0,1} | 2 | Alaska 25 | {0.0697058912971787,0.930294108702821} | {5.8,2.7,4.1,1.0} | {0,1} | 2 | Alaska 26 | {0.0300465449544714,0.969953455045529} | {6.2,2.2,4.5,1.5} | {0,1} | 2 | Alaska 27 | {0.0641965800166526,0.935803419983347} | {5.6,2.5,3.9,1.1} | {0,1} | 2 | Alaska 28 | {0.932679530162975,0.0673204698370254} | {5.0,3.4,1.5,0.2} | {1,0} | 1 | Tennessee 29 | {0.91913460018541,0.0808653998145895} | {4.4,2.9,1.4,0.2} | {1,0} | 1 | Tennessee 30 | {0.92275296493747,0.0772470350625298} | {4.9,3.1,1.5,0.1} | {1,0} | 1 | Tennessee 31 | {0.936685220371634,0.0633147796283663} | {5.4,3.7,1.5,0.2} | {1,0} | 1 | Tennessee 32 | {0.934032404740506,0.0659675952594938} | {4.8,3.4,1.6,0.2} | {1,0} | 1 | Tennessee 33 | {0.922226922202426,0.0777730777975738} | {4.8,3.0,1.4,0.1} | {1,0} | 1 | Tennessee 34 | {0.94042684548622,0.0595731545137805} | {4.3,3.0,1.1,0.1} | {1,0} | 1 | Tennessee 35 | {0.943820498537346,0.0561795014626537} | {5.8,4.0,1.2,0.2} | {1,0} | 1 | Tennessee 36 | {0.942322282886469,0.0576777171135306} | {5.7,4.4,1.5,0.4} | {1,0} | 1 | Tennessee 37 | {0.938684928938641,0.0613150710613592} | {5.4,3.9,1.3,0.4} | {1,0} | 1 | Tennessee 38 | 
{0.0457032748591934,0.954296725140807} | {6.0,2.9,4.5,1.5} | {0,1} | 2 | Tennessee 39 | {0.0944184541813754,0.905581545818625} | {5.7,2.6,3.5,1.0} | {0,1} | 2 | Tennessee 40 | {0.0645243589381724,0.935475641061828} | {5.5,2.4,3.8,1.1} | {0,1} | 2 | Tennessee 41 | {0.0739865590316946,0.926013440968305} | {5.5,2.4,3.7,1.0} | {0,1} | 2 | Tennessee 42 | {0.0665047837634499,0.93349521623655} | {5.8,2.7,3.9,1.2} | {0,1} | 2 | Tennessee 43 | {0.0315714539349891,0.968428546065011} | {6.0,2.7,5.1,1.6} | {0,1} | 2 | Tennessee 44 | {0.0700314082679038,0.929968591732096} | {5.4,3.0,4.5,1.5} | {0,1} | 2 | Tennessee 45 | {0.0769778072718228,0.923022192728177} | {6.0,3.4,4.5,1.6} | {0,1} | 2 | Tennessee 46 | {0.0432280654691233,0.956771934530877} | {6.7,3.1,4.7,1.5} | {0,1} | 2 | Tennessee 47 | {0.0340971244056786,0.965902875594321} | {6.3,2.3,4.4,1.3} | {0,1} | 2 | Tennessee 48 | {0.0911111408301112,0.908888859169889} | {5.6,3.0,4.1,1.3} | {0,1} | 2 | Tennessee 49 | {0.0556041814589958,0.944395818541004} | {5.5,2.5,4.0,1.3} | {0,1} | 2 | Tennessee 50 | {0.0542022706221145,0.945797729377886} | {5.5,2.6,4.4,1.2} | {0,1} | 2 | Tennessee 51 | {0.0490083738977815,0.950991626102219} | {6.1,3.0,4.6,1.4} | {0,1} | 2 | Tennessee 52 | {0.0567530672794966,0.943246932720503} | {5.8,2.6,4.0,1.2} | {0,1} | 2 | Tennessee (52 rows) {code} > Support already encoded arrays for dependent var in MLP classification > ---------------------------------------------------------------------- > > Key: MADLIB-1222 > URL: https://issues.apache.org/jira/browse/MADLIB-1222 > Project: Apache MADlib > Issue Type: New Feature > Components: Module: Neural Networks > Reporter: Nandish Jayaram > Priority: Major > Fix For: v1.14 > > > MLP currently only supports scalar dependent variables for MLP > classification. If a user has already one-hot encoded categorical variables > the dependent variable will be an array, and hence unusable with > mlp_classification. 
This feature request is to allow the use of one-hot > encoded arrays for dependent vars in MLP classification. -- This message was sent by Atlassian JIRA (v7.6.3#76005)