IGNITE-9717: [ML] Add setter methods to Logistic Regression and fix examples/tests
this closes #4865 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4da48e6f Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4da48e6f Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4da48e6f Branch: refs/heads/ignite-gg-14206 Commit: 4da48e6f90ceb7ee585b66af4f384cc868f6ca8e Parents: a373486 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Fri Sep 28 16:05:39 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Fri Sep 28 16:05:39 2018 +0300 ---------------------------------------------------------------------- .../LogisticRegressionSGDTrainerExample.java | 16 ++++--- .../ml/tutorial/Step_1_Read_and_Learn.java | 2 +- .../examples/ml/tutorial/Step_2_Imputing.java | 2 +- .../examples/ml/tutorial/Step_3_Categorial.java | 2 +- .../Step_3_Categorial_with_One_Hot_Encoder.java | 2 +- .../ml/tutorial/Step_4_Add_age_fare.java | 2 +- .../examples/ml/tutorial/Step_5_Scaling.java | 2 +- .../tutorial/Step_5_Scaling_with_Pipeline.java | 2 +- .../ignite/examples/ml/tutorial/Step_6_KNN.java | 2 +- .../ml/tutorial/Step_7_Split_train_test.java | 2 +- .../ignite/examples/ml/tutorial/Step_8_CV.java | 2 +- .../ml/tutorial/Step_8_CV_with_Param_Grid.java | 2 +- .../ml/tutorial/Step_9_Go_to_LogReg.java | 27 ++++++----- .../ml/tutorial/TutorialStepByStepExample.java | 2 +- .../binomial/LogisticRegressionSGDTrainer.java | 47 ++++++++++---------- .../LogRegressionMultiClassTrainer.java | 29 +++++++----- .../SVMLinearBinaryClassificationTrainer.java | 2 +- ...VMLinearMultiClassClassificationTrainer.java | 2 +- .../apache/ignite/ml/pipeline/PipelineTest.java | 18 +++----- .../logistic/LogRegMultiClassTrainerTest.java | 1 - .../logistic/LogisticRegressionModelTest.java | 17 +++---- .../LogisticRegressionSGDTrainerTest.java | 24 +++++----- 22 files changed, 111 insertions(+), 96 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java index 8d4218d..15330d0 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java @@ -60,11 +60,16 @@ public class LogisticRegressionSGDTrainerExample { IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); System.out.println(">>> Create new logistic regression trainer object."); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + )) + .withMaxIterations(100000) + .withLocIterations(100) + .withBatchSize(10) + .withSeed(123L); System.out.println(">>> Perform the training to get the model."); LogisticRegressionModel mdl = trainer.fit( @@ -218,5 +223,4 @@ public class LogisticRegressionSGDTrainerExample { {1, 5.1, 2.5, 3, 1.1}, {1, 5.7, 2.8, 4.1, 1.3}, }; - } http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java 
---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java index 264dbf4..481fa1d 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java @@ -42,7 +42,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_1_Read_and_Learn { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 1 (read and learn) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java index df73235..d60dc4b 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java @@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_2_Imputing { /** Run example. 
*/ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 2 (imputing) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java index 463a6ba..ac2fe08 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java @@ -47,7 +47,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_3_Categorial { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 3 (categorial) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java index 93e7e79..f0b6efe 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java @@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_3_Categorial_with_One_Hot_Encoder 
{ /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 3 (categorial with One-hot encoder) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java index bbeedb6..71e9efd 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java @@ -45,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_4_Add_age_fare { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 4 (add age and fare) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java index 7d934d7..fe7bf91 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java @@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_5_Scaling { /** Run example. 
*/ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 5 (scaling) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java index cc0a278..bd7cc21 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java @@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; */ public class Step_5_Scaling_with_Pipeline { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 5 (scaling) via Pipeline example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java index 0c8b562..a35b841 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java @@ -49,7 +49,7 @@ import org.apache.ignite.ml.selection.scoring.metric.Accuracy; */ public class Step_6_KNN { /** Run example. 
*/ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 6 (kNN) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java index c6d033c..53d4d0a 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java @@ -51,7 +51,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_7_Split_train_test { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 7 (split to train and test) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java index d83e14a..feedccf 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java @@ -63,7 +63,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_8_CV { /** Run example. 
*/ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 8 (cross-validation) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java index 594c0eb..670f025 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java @@ -65,7 +65,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; */ public class Step_8_CV_with_Param_Grid { /** Run example. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started."); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java index 4e1e005..b98b0eb 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java @@ -56,7 +56,7 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit; */ public class Step_9_Go_to_LogReg { /** Run example. 
*/ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { System.out.println(); System.out.println(">>> Tutorial step 9 (logistic regression) example started."); @@ -124,12 +124,13 @@ public class Step_9_Go_to_LogReg { minMaxScalerPreprocessor ); - LogisticRegressionSGDTrainer<?> trainer - = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(learningRate), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), maxIterations, batchSize, locIterations, 123L); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(learningRate), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(maxIterations) + .withLocIterations(locIterations) + .withBatchSize(batchSize) + .withSeed(123L); CrossValidation<LogisticRegressionModel, Double, Integer, Object[]> scoreCalculator = new CrossValidation<>(); @@ -187,11 +188,13 @@ public class Step_9_Go_to_LogReg { minMaxScalerPreprocessor ); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(bestLearningRate), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), bestMaxIterations, bestBatchSize, bestLocIterations, 123L); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(bestLearningRate), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(bestMaxIterations) + .withLocIterations(bestLocIterations) + .withBatchSize(bestBatchSize) + .withSeed(123L); System.out.println(">>> Perform the training to get the model."); LogisticRegressionModel bestMdl = trainer.fit( 
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java index 67f4bf5..a376ae6 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java @@ -23,7 +23,7 @@ package org.apache.ignite.examples.ml.tutorial; */ public class TutorialStepByStepExample { /** Run examples with default settings. */ - public static void main(String[] args) throws InterruptedException { + public static void main(String[] args) { Step_1_Read_and_Learn.main(args); Step_2_Imputing.main(args); Step_3_Categorial.main(args); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java index fb5d5a0..74a296d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java @@ -33,6 +33,8 @@ import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; +import 
org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.jetbrains.annotations.NotNull; @@ -41,37 +43,23 @@ import org.jetbrains.annotations.NotNull; */ public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> { /** Update strategy. */ - private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); /** Max number of iteration. */ - private int maxIterations; + private int maxIterations = 100; /** Batch size. */ - private int batchSize; + private int batchSize = 100; /** Number of local iterations. */ - private int locIterations; + private int locIterations = 100; /** Seed for random generator. */ - private long seed; - - /** - * Constructs a new instance of linear regression SGD trainer. - * - * @param updatesStgy Update strategy. - * @param maxIterations Max number of iteration. - * @param batchSize Batch size. - * @param locIterations Number of local iterations. - * @param seed Seed for random generator. - */ - public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, - int batchSize, int locIterations, long seed) { - this.updatesStgy = updatesStgy; - this.maxIterations = maxIterations; - this.batchSize = batchSize; - this.locIterations = locIterations; - this.seed = seed; - } + private long seed = 1234L; /** {@inheritDoc} */ @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, @@ -202,11 +190,22 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single } /** + * Set up the regularization parameter. 
+ * + * @param updatesStgy Update strategy. + * @return Trainer with new update strategy parameter value. + */ + public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { + this.updatesStgy = updatesStgy; + return this; + } + + /** * Get the update strategy. * * @return The property value. */ - public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() { + public UpdatesStrategy getUpdatesStgy() { return updatesStgy; } http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java index b9cdcc7..71d54fa 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -32,8 +32,9 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; import 
org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; @@ -46,19 +47,23 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; public class LogRegressionMultiClassTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel> { /** Update strategy. */ - private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); /** Max number of iteration. */ - private int amountOfIterations; + private int amountOfIterations = 100; /** Batch size. */ - private int batchSize; + private int batchSize = 100; /** Number of local iterations. */ - private int amountOfLocIterations; + private int amountOfLocIterations = 100; /** Seed for random generator. */ - private long seed; + private long seed = 1234L; /** * Trains model based on the specified data. @@ -90,7 +95,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> classes.forEach(clsLb -> { LogisticRegressionSGDTrainer<?> trainer = - new LogisticRegressionSGDTrainer<>(updatesStgy, amountOfIterations, batchSize, amountOfLocIterations, seed); + new LogisticRegressionSGDTrainer<>() + .withBatchSize(batchSize) + .withLocIterations(amountOfLocIterations) + .withMaxIterations(amountOfIterations) + .withSeed(seed); IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> { Double lb = lbExtractor.apply(k, v); @@ -238,7 +247,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Set up the regularization parameter. + * Set up the updates strategy. * * @param updatesStgy Update strategy. * @return Trainer with new update strategy parameter value. @@ -249,11 +258,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Get the update strategy.. + * Get the update strategy. * * @return The parameter value. 
*/ - public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() { + public UpdatesStrategy getUpdatesStgy() { return updatesStgy; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 2c621c8..47666f4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -50,7 +50,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai private double lambda = 0.4; /** The seed number. */ - private long seed; + private long seed = 1234L; /** * Trains model based on the specified data. http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index ec60034..b161914 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -52,7 +52,7 @@ public class SVMLinearMultiClassClassificationTrainer private double lambda = 0.2; /** The seed number. */ - private long seed; + private long seed = 1234L; /** * Trains model based on the specified data. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java index 91bbcd4..d517ce6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java @@ -51,11 +51,13 @@ public class PipelineTest extends TrainerTest { cacheMock.put(i, convertedRow); } - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator().withLearningRate(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(100000) + .withLocIterations(100) + .withBatchSize(10) + .withSeed(123L); PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>() .addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length))) @@ -88,12 +90,6 @@ public class PipelineTest extends TrainerTest { cacheMock.put(i, convertedRow); } - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator().withLearningRate(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); - PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>() .addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length))) .addLabelExtractor((k, v) -> v[0]) 
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index 78cd08d..c99bf02 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -133,7 +133,6 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { VectorUtils.of(10, -10) ); - for (Vector vec : vectors) { TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java index 89c9cca..e8aaacd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java @@ -38,7 +38,7 @@ public class LogisticRegressionModelTest { /** */ @Test public void testPredict() { - Vector weights = new DenseVector(new double[]{2.0, 3.0}); + Vector weights = new DenseVector(new double[] {2.0, 3.0}); assertFalse(new LogisticRegressionModel(weights, 1.0).isKeepingRawLabels()); @@ 
-57,35 +57,36 @@ public class LogisticRegressionModelTest { /** */ @Test(expected = CardinalityException.class) public void testPredictOnAnObservationWithWrongCardinality() { - Vector weights = new DenseVector(new double[]{2.0, 3.0}); + Vector weights = new DenseVector(new double[] {2.0, 3.0}); LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0); - Vector observation = new DenseVector(new double[]{1.0}); + Vector observation = new DenseVector(new double[] {1.0}); mdl.apply(observation); } /** */ private void verifyPredict(LogisticRegressionModel mdl) { - Vector observation = new DenseVector(new double[]{1.0, 1.0}); + Vector observation = new DenseVector(new double[] {1.0, 1.0}); TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); - observation = new DenseVector(new double[]{2.0, 1.0}); + observation = new DenseVector(new double[] {2.0, 1.0}); TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); - observation = new DenseVector(new double[]{1.0, 2.0}); + observation = new DenseVector(new double[] {1.0, 2.0}); TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION); - observation = new DenseVector(new double[]{-2.0, 1.0}); + observation = new DenseVector(new double[] {-2.0, 1.0}); TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); - observation = new DenseVector(new double[]{1.0, -2.0}); + observation = new DenseVector(new double[] {1.0, -2.0}); TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION); } /** * Sigmoid function. + * * @param z The regression value. * @return The result. 
*/ http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index 723677c..d9b6f7a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -45,11 +45,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator().withLearningRate(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(100000) + .withLocIterations(100) + .withBatchSize(10) + .withSeed(123L); LogisticRegressionModel mdl = trainer.fit( cacheMock, @@ -70,11 +72,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { for (int i = 0; i < twoLinearlySeparableClasses.length; i++) cacheMock.put(i, twoLinearlySeparableClasses[i]); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator().withLearningRate(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); + 
LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(100000) + .withLocIterations(100) + .withBatchSize(10) + .withSeed(123L); LogisticRegressionModel originalMdl = trainer.fit( cacheMock,