IGNITE-10380: Drop Multi-label Classification for Logistic Regression and SVM
This closes #5559 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/098caf44 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/098caf44 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/098caf44 Branch: refs/heads/ignite-10044 Commit: 098caf4469469763224e80fd9cd4ac5629b771bf Parents: 46ff268 Author: zaleslaw <[email protected]> Authored: Wed Dec 5 13:02:41 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Wed Dec 5 13:02:41 2018 +0300 ---------------------------------------------------------------------- ...ggedLogisticRegressionSGDTrainerExample.java | 2 +- .../LogisticRegressionSGDTrainerExample.java | 4 +- ...gressionMultiClassClassificationExample.java | 164 --------- .../logistic/multiclass/package-info.java | 22 -- .../ml/svm/SVMBinaryClassificationExample.java | 112 ++++++ .../binary/SVMBinaryClassificationExample.java | 112 ------ .../examples/ml/svm/binary/package-info.java | 22 -- .../SVMMultiClassClassificationExample.java | 153 -------- .../ml/svm/multiclass/package-info.java | 22 -- .../ml/tutorial/Step_9_Go_to_LogReg.java | 4 +- .../logistic/LogisticRegressionModel.java | 205 +++++++++++ .../logistic/LogisticRegressionSGDTrainer.java | 246 +++++++++++++ .../binomial/LogisticRegressionModel.java | 205 ----------- .../binomial/LogisticRegressionSGDTrainer.java | 246 ------------- .../logistic/binomial/package-info.java | 22 -- .../LogRegressionMultiClassModel.java | 115 ------ .../LogRegressionMultiClassTrainer.java | 269 -------------- .../logistic/multiclass/package-info.java | 22 -- .../svm/SVMLinearBinaryClassificationModel.java | 194 ---------- .../SVMLinearBinaryClassificationTrainer.java | 356 ------------------- .../ml/svm/SVMLinearClassificationModel.java | 194 ++++++++++ .../ml/svm/SVMLinearClassificationTrainer.java | 356 +++++++++++++++++++ .../SVMLinearMultiClassClassificationModel.java | 114 ------ 
...VMLinearMultiClassClassificationTrainer.java | 269 -------------- .../ignite/ml/common/CollectionsTest.java | 19 +- .../ignite/ml/common/LocalModelsTest.java | 54 +-- .../ml/multiclass/OneVsRestTrainerTest.java | 4 +- .../ignite/ml/pipeline/PipelineMdlTest.java | 2 +- .../apache/ignite/ml/pipeline/PipelineTest.java | 2 +- .../ml/regressions/RegressionsTestSuite.java | 4 +- .../linear/LinearRegressionModelTest.java | 23 +- .../logistic/LogRegMultiClassTrainerTest.java | 141 -------- .../logistic/LogisticRegressionModelTest.java | 1 - .../LogisticRegressionSGDTrainerTest.java | 2 - .../ignite/ml/svm/SVMBinaryTrainerTest.java | 14 +- .../org/apache/ignite/ml/svm/SVMModelTest.java | 36 +- .../ignite/ml/svm/SVMMultiClassTrainerTest.java | 100 ------ .../org/apache/ignite/ml/svm/SVMTestSuite.java | 1 - .../apache/ignite/ml/trainers/BaggingTest.java | 4 +- 39 files changed, 1151 insertions(+), 2686 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java index 44fb77e..98745a4 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java @@ -30,7 +30,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.nn.UpdatesStrategy; import 
org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.trainers.DatasetTrainer; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java index 52ee330..8530045 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java @@ -31,8 +31,8 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; /** * Run logistic regression model based on <a 
href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent"> http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java deleted file mode 100644 index 962fdac..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.regression.logistic.multiclass; - -import java.io.FileNotFoundException; -import java.util.Arrays; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.examples.ml.util.MLSandboxDatasets; -import org.apache.ignite.examples.ml.util.SandboxMLCache; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; - -/** - * Run Logistic Regression multi-class classification trainer ({@link LogRegressionMultiClassModel}) over distributed - * dataset to build two models: one with minmaxscaling and one without minmaxscaling. 
- * <p> - * Code in this example launches Ignite grid and fills the cache with test data points (preprocessed - * <a href="https://archive.ics.uci.edu/ml/datasets/Glass+Identification">Glass dataset</a>).</p> - * <p> - * After that it trains two logistic regression models based on the specified data - one model is with minmaxscaling - * and one without minmaxscaling.</p> - * <p> - * Finally, this example loops over the test set of data points, applies the trained models to predict the target value, - * compares prediction to expected outcome (ground truth), and builds - * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrices</a>.</p> - * <p> - * You can change the test data used in this example and re-run it to explore this algorithm further.</p> - */ -public class LogRegressionMultiClassClassificationExample { - /** Run example. */ - public static void main(String[] args) throws FileNotFoundException { - System.out.println(); - System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started."); - // Start ignite grid. 
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) - .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION); - - LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - )) - .withAmountOfIterations(100000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - LogRegressionMultiClassModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()), - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model"); - System.out.println(mdl.toString()); - - MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>(); - - IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()) - ); - - LogRegressionMultiClassModel mdlWithNormalization = trainer.fit( - ignite, - dataCache, - preprocessor, - (k, v) -> v.get(0) - ); - - System.out.println(">>> Logistic Regression Multi-class model with normalization"); - System.out.println(mdlWithNormalization.toString()); - - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|"); - System.out.println(">>> ----------------------------------------------------------------"); - - int amountOfErrors = 0; - int amountOfErrorsWithNormalization = 0; - int totalAmount = 0; - - // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - - try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, Vector> observation : observations) { - Vector val = observation.getValue(); - Vector inputs = val.copyOfRange(1, val.size()); - double groundTruth = val.get(0); - - double prediction = mdl.apply(inputs); - double predictionWithNormalization = mdlWithNormalization.apply(inputs); - - totalAmount++; - - // Collect data for model - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); - int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtx[idx1][idx2]++; - - // Collect data for model with normalization - if(groundTruth != predictionWithNormalization) - amountOfErrorsWithNormalization++; - - idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2); - idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 
1 : 2); - - confusionMtxWithNormalization[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth); - } - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println("\n>>> -----------------Logistic Regression model-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - - System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization)); - - System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed."); - } - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java deleted file mode 100644 index c7b7fe8..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * ML multi-class logistic regression examples. - */ -package org.apache.ignite.examples.ml.regression.logistic.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java new file mode 100644 index 0000000..d9d1805 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.svm; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.examples.ml.util.MLSandboxDatasets; +import org.apache.ignite.examples.ml.util.SandboxMLCache; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer; + +/** + * Run SVM binary-class classification model ({@link SVMLinearClassificationModel}) over distributed dataset. + * <p> + * Code in this example launches Ignite grid and fills the cache with test data points (based on the + * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p> + * <p> + * After that it trains the model based on the specified data using the linear SVM algorithm.</p> + * <p> + * Finally, this example loops over the test set of data points, applies the trained model to predict which class + * each point belongs to, compares prediction to expected outcome (ground truth), and builds + * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p> + * <p> + * You can change the test data used in this example and re-run it to explore this algorithm further.</p> + */ +public class SVMBinaryClassificationExample { + /** Run example.
*/ + public static void main(String[] args) throws FileNotFoundException { + System.out.println(); + System.out.println(">>> SVM Binary classification model over cached dataset usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); + + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer(); + + SVMLinearClassificationModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> v.copyOfRange(1, v.size()), + (k, v) -> v.get(0) + ); + + System.out.println(">>> SVM model " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Vector> observation : observations) { + Vector val = observation.getValue(); + Vector inputs = val.copyOfRange(1, val.size()); + double groundTruth = val.get(0); + + double prediction = mdl.apply(inputs); + + totalAmount++; + if(groundTruth != prediction) + amountOfErrors++; + + int idx1 = prediction == 0.0 ? 0 : 1; + int idx2 = groundTruth == 0.0 ? 
0 : 1; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + } + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java deleted file mode 100644 index 679bd77..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.svm.binary; - -import java.io.FileNotFoundException; -import java.util.Arrays; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.examples.ml.util.MLSandboxDatasets; -import org.apache.ignite.examples.ml.util.SandboxMLCache; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; - -/** - * Run SVM binary-class classification model ({@link SVMLinearBinaryClassificationModel}) over distributed dataset. - * <p> - * Code in this example launches Ignite grid and fills the cache with test data points (based on the - * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p> - * <p> - * After that it trains the model based on the specified data using KMeans algorithm.</p> - * <p> - * Finally, this example loops over the test set of data points, applies the trained model to predict what cluster - * does this point belong to, compares prediction to expected outcome (ground truth), and builds - * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p> - * <p> - * You can change the test data used in this example and re-run it to explore this algorithm further.</p> - */ -public class SVMBinaryClassificationExample { - /** Run example. */ - public static void main(String[] args) throws FileNotFoundException { - System.out.println(); - System.out.println(">>> SVM Binary classification model over cached dataset usage example started."); - // Start ignite grid. 
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) - .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); - - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); - - SVMLinearBinaryClassificationModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()), - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM model " + mdl); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - int amountOfErrors = 0; - int totalAmount = 0; - - // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0}, {0, 0}}; - - try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, Vector> observation : observations) { - Vector val = observation.getValue(); - Vector inputs = val.copyOfRange(1, val.size()); - double groundTruth = val.get(0); - - double prediction = mdl.apply(inputs); - - totalAmount++; - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = prediction == 0.0 ? 0 : 1; - int idx2 = groundTruth == 0.0 ? 
0 : 1; - - confusionMtx[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - - System.out.println(">>> ---------------------------------"); - - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - } - - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java deleted file mode 100644 index 22c9ad7..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. 
--> - * SVM Binary Classification Examples. - */ -package org.apache.ignite.examples.ml.svm.binary; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java deleted file mode 100644 index 987ac41..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.svm.multiclass; - -import java.io.FileNotFoundException; -import java.util.Arrays; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.examples.ml.util.MLSandboxDatasets; -import org.apache.ignite.examples.ml.util.SandboxMLCache; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer; - -/** - * Run SVM multi-class classification trainer ({@link SVMLinearMultiClassClassificationModel}) over distributed dataset - * to build two models: one with minmaxscaling and one without minmaxscaling. - * <p> - * Code in this example launches Ignite grid and fills the cache with test data points (preprocessed - * <a href="https://archive.ics.uci.edu/ml/datasets/Glass+Identification">Glass dataset</a>).</p> - * <p> - * After that it trains two SVM multi-class models based on the specified data - one model is with minmaxscaling - * and one without minmaxscaling.</p> - * <p> - * Finally, this example loops over the test set of data points, applies the trained models to predict what cluster - * does this point belong to, compares prediction to expected outcome (ground truth), and builds - * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p> - * <p> - * You can change the test data used in this example and re-run it to explore this algorithm further.</p> - * NOTE: the smallest 3rd class could be classified via linear SVM here. - */ -public class SVMMultiClassClassificationExample { - /** Run example. 
*/ - public static void main(String[] args) throws FileNotFoundException { - System.out.println(); - System.out.println(">>> SVM Multi-class classification model over cached dataset usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) - .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION); - - SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer(); - - SVMLinearMultiClassClassificationModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()), - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model"); - System.out.println(mdl.toString()); - - MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>(); - - IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit( - ignite, - dataCache, - (k, v) -> v.copyOfRange(1, v.size()) - ); - - SVMLinearMultiClassClassificationModel mdlWithScaling = trainer.fit( - ignite, - dataCache, - preprocessor, - (k, v) -> v.get(0) - ); - - System.out.println(">>> SVM Multi-class model with MinMaxScaling"); - System.out.println(mdlWithScaling.toString()); - - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println(">>> | Prediction\t| Prediction with MinMaxScaling\t| Ground Truth\t|"); - System.out.println(">>> ----------------------------------------------------------------"); - - int amountOfErrors = 0; - int amountOfErrorsWithMinMaxScaling = 0; - int totalAmount = 0; - - // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - int[][] confusionMtxWithMinMaxScaling = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - - try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, Vector> observation : observations) { - Vector val = observation.getValue(); - Vector inputs = val.copyOfRange(1, val.size()); - double groundTruth = val.get(0); - - double prediction = mdl.apply(inputs); - double predictionWithMinMaxScaling = mdlWithScaling.apply(inputs); - - totalAmount++; - - // Collect data for model - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); - int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - - confusionMtx[idx1][idx2]++; - - // Collect data for model with minmaxscaling - if (groundTruth != predictionWithMinMaxScaling) - amountOfErrorsWithMinMaxScaling++; - - idx1 = (int)predictionWithMinMaxScaling == 1 ? 0 : ((int)predictionWithMinMaxScaling == 3 ? 1 : 2); - idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 
1 : 2); - - confusionMtxWithMinMaxScaling[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithMinMaxScaling, groundTruth); - } - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println("\n>>> -----------------SVM model-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - - System.out.println("\n>>> -----------------SVM model with MinMaxScaling-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithMinMaxScaling); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithMinMaxScaling / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithMinMaxScaling)); - - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - } - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java deleted file mode 100644 index 8b685a4..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * SVM Multi-class Classification Examples. - */ -package org.apache.ignite.examples.ml.svm.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java index b98b0eb..2c6a820 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java @@ -32,8 +32,8 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import 
org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java new file mode 100644 index 0000000..5cc44f8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.logistic; + +import java.io.Serializable; +import java.util.Objects; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * Logistic regression (logit model) is a generalized linear model used for binomial regression. 
+ */ +public class LogisticRegressionModel implements Model<Vector, Double>, Exportable<LogisticRegressionModel>, Serializable { + /** */ + private static final long serialVersionUID = -133984600091550776L; + + /** Multiplier of the objects's vector required to make prediction. */ + private Vector weights; + + /** Intercept of the linear regression model. */ + private double intercept; + + /** Output label format. 0 and 1 for false value and raw sigmoid regression value otherwise. */ + private boolean isKeepingRawLabels = false; + + /** Threshold to assign '1' label to the observation if raw value more than this threshold. */ + private double threshold = 0.5; + + /** */ + public LogisticRegressionModel(Vector weights, double intercept) { + this.weights = weights; + this.intercept = intercept; + } + + /** + * Set up the output label format. + * + * @param isKeepingRawLabels The parameter value. + * @return Model with new isKeepingRawLabels parameter value. + */ + public LogisticRegressionModel withRawLabels(boolean isKeepingRawLabels) { + this.isKeepingRawLabels = isKeepingRawLabels; + return this; + } + + /** + * Set up the threshold. + * + * @param threshold The parameter value. + * @return Model with new threshold parameter value. + */ + public LogisticRegressionModel withThreshold(double threshold) { + this.threshold = threshold; + return this; + } + + /** + * Set up the weights. + * + * @param weights The parameter value. + * @return Model with new weights parameter value. + */ + public LogisticRegressionModel withWeights(Vector weights) { + this.weights = weights; + return this; + } + + /** + * Set up the intercept. + * + * @param intercept The parameter value. + * @return Model with new intercept parameter value. + */ + public LogisticRegressionModel withIntercept(double intercept) { + this.intercept = intercept; + return this; + } + + /** + * Gets the output label format mode. + * + * @return The parameter value. 
+ */ + public boolean isKeepingRawLabels() { + return isKeepingRawLabels; + } + + /** + * Gets the threshold. + * + * @return The parameter value. + */ + public double threshold() { + return threshold; + } + + /** + * Gets the weights. + * + * @return The parameter value. + */ + public Vector weights() { + return weights; + } + + /** + * Gets the intercept. + * + * @return The parameter value. + */ + public double intercept() { + return intercept; + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector input) { + + final double res = sigmoid(input.dot(weights) + intercept); + + if (isKeepingRawLabels) + return res; + else + return res - threshold > 0 ? 1.0 : 0; + } + + /** + * Sigmoid function. + * @param z The regression value. + * @return The result. + */ + private static double sigmoid(double z) { + return 1.0 / (1.0 + Math.exp(-z)); + } + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<LogisticRegressionModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + LogisticRegressionModel mdl = (LogisticRegressionModel)o; + + return Double.compare(mdl.intercept, intercept) == 0 + && Double.compare(mdl.threshold, threshold) == 0 + && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0 + && Objects.equals(weights, mdl.weights); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return Objects.hash(weights, intercept, isKeepingRawLabels, threshold); + } + + /** {@inheritDoc} */ + @Override public String toString() { + if (weights.size() < 20) { + StringBuilder builder = new StringBuilder(); + + for (int i = 0; i < weights.size(); i++) { + double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1); + + builder.append(String.format("%.4f", Math.abs(weights.get(i)))) + .append("*x") + .append(i) + .append(nextItem > 0 ? 
" + " : " - "); + } + + builder.append(String.format("%.4f", Math.abs(intercept))); + return builder.toString(); + } + + return "LogisticRegressionModel [" + + "weights=" + weights + + ", intercept=" + intercept + + ']'; + } + + /** {@inheritDoc} */ + @Override public String toString(boolean pretty) { + return toString(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java new file mode 100644 index 0000000..cdbfe4c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.logistic; + +import java.io.Serializable; +import java.util.Arrays; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.nn.Activators; +import org.apache.ignite.ml.nn.MLPTrainer; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; +import org.jetbrains.annotations.NotNull; + +/** + * Trainer of the logistic regression model based on stochastic gradient descent algorithm. + */ +public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> { + /** Update strategy. */ + private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + /** Max number of iteration. */ + private int maxIterations = 100; + + /** Batch size. */ + private int batchSize = 100; + + /** Number of local iterations. */ + private int locIterations = 100; + + /** Seed for random generator. 
*/ + private long seed = 1234L; + + /** {@inheritDoc} */ + @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { + Integer cols = dataset.compute(data -> { + if (data.getFeatures() == null) + return null; + return data.getFeatures().length / data.getRows(); + }, (a, b) -> { + // If both are null then zero will be propagated, no good. + if (a == null) + return b; + return a; + }); + + MLPArchitecture architecture = new MLPArchitecture(cols); + architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID); + + return architecture; + }; + + MLPTrainer<?> trainer = new MLPTrainer<>( + archSupplier, + LossFunctions.L2, + updatesStgy, + maxIterations, + batchSize, + locIterations, + seed + ); + + IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)}; + MultilayerPerceptron mlp; + if (mdl != null) { + mlp = restoreMLPState(mdl); + mlp = trainer.update(mlp, datasetBuilder, featureExtractor, lbExtractorWrapper); + } + else + mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper); + + double[] params = mlp.parameters().getStorage().data(); + + return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)), + params[params.length - 1] + ); + } + + /** + * @param mdl Model. + * @return state of MLP from last learning. 
+ */ + @NotNull private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) { + Vector weights = mdl.weights(); + double intercept = mdl.intercept(); + + MLPArchitecture architecture1 = new MLPArchitecture(weights.size()); + architecture1 = architecture1.withAddedLayer(1, true, Activators.SIGMOID); + + MLPArchitecture architecture = architecture1; + MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture); + + Vector mlpState = weights.like(weights.size() + 1); + weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get())); + mlpState.set(mlpState.size() - 1, intercept); + perceptron.setParameters(mlpState); + + return perceptron; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(LogisticRegressionModel mdl) { + return true; + } + + /** + * Set up the max amount of iterations before convergence. + * + * @param maxIterations The parameter value. + * @return Model with new max number of iterations before convergence parameter value. + */ + public LogisticRegressionSGDTrainer<P> withMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + return this; + } + + /** + * Set up the batchSize parameter. + * + * @param batchSize The size of learning batch. + * @return Trainer with new batch size parameter value. + */ + public LogisticRegressionSGDTrainer<P> withBatchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Set up the amount of local iterations of SGD algorithm. + * + * @param amountOfLocIterations The parameter value. + * @return Trainer with new locIterations parameter value. + */ + public LogisticRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) { + this.locIterations = amountOfLocIterations; + return this; + } + + /** + * Set up the random seed parameter. + * + * @param seed Seed for random generator. + * @return Trainer with new seed parameter value. 
+ */ + public LogisticRegressionSGDTrainer<P> withSeed(long seed) { + this.seed = seed; + return this; + } + + /** + * Set up the regularization parameter. + * + * @param updatesStgy Update strategy. + * @return Trainer with new update strategy parameter value. + */ + public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { + this.updatesStgy = updatesStgy; + return this; + } + + /** + * Get the update strategy. + * + * @return The property value. + */ + public UpdatesStrategy getUpdatesStgy() { + return updatesStgy; + } + + /** + * Get the max amount of iterations. + * + * @return The property value. + */ + public int getMaxIterations() { + return maxIterations; + } + + /** + * Get the batch size. + * + * @return The property value. + */ + public int getBatchSize() { + return batchSize; + } + + /** + * Get the amount of local iterations. + * + * @return The property value. + */ + public int getLocIterations() { + return locIterations; + } + + /** + * Get the seed for random generator. + * + * @return The property value. + */ + public long getSeed() { + return seed; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java deleted file mode 100644 index f206532..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.logistic.binomial; - -import java.io.Serializable; -import java.util.Objects; -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; - -/** - * Logistic regression (logit model) is a generalized linear model used for binomial regression. - */ -public class LogisticRegressionModel implements Model<Vector, Double>, Exportable<LogisticRegressionModel>, Serializable { - /** */ - private static final long serialVersionUID = -133984600091550776L; - - /** Multiplier of the objects's vector required to make prediction. */ - private Vector weights; - - /** Intercept of the linear regression model. */ - private double intercept; - - /** Output label format. 0 and 1 for false value and raw sigmoid regression value otherwise. */ - private boolean isKeepingRawLabels = false; - - /** Threshold to assign '1' label to the observation if raw value more than this threshold. */ - private double threshold = 0.5; - - /** */ - public LogisticRegressionModel(Vector weights, double intercept) { - this.weights = weights; - this.intercept = intercept; - } - - /** - * Set up the output label format. - * - * @param isKeepingRawLabels The parameter value. - * @return Model with new isKeepingRawLabels parameter value. 
- */ - public LogisticRegressionModel withRawLabels(boolean isKeepingRawLabels) { - this.isKeepingRawLabels = isKeepingRawLabels; - return this; - } - - /** - * Set up the threshold. - * - * @param threshold The parameter value. - * @return Model with new threshold parameter value. - */ - public LogisticRegressionModel withThreshold(double threshold) { - this.threshold = threshold; - return this; - } - - /** - * Set up the weights. - * - * @param weights The parameter value. - * @return Model with new weights parameter value. - */ - public LogisticRegressionModel withWeights(Vector weights) { - this.weights = weights; - return this; - } - - /** - * Set up the intercept. - * - * @param intercept The parameter value. - * @return Model with new intercept parameter value. - */ - public LogisticRegressionModel withIntercept(double intercept) { - this.intercept = intercept; - return this; - } - - /** - * Gets the output label format mode. - * - * @return The parameter value. - */ - public boolean isKeepingRawLabels() { - return isKeepingRawLabels; - } - - /** - * Gets the threshold. - * - * @return The parameter value. - */ - public double threshold() { - return threshold; - } - - /** - * Gets the weights. - * - * @return The parameter value. - */ - public Vector weights() { - return weights; - } - - /** - * Gets the intercept. - * - * @return The parameter value. - */ - public double intercept() { - return intercept; - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector input) { - - final double res = sigmoid(input.dot(weights) + intercept); - - if (isKeepingRawLabels) - return res; - else - return res - threshold > 0 ? 1.0 : 0; - } - - /** - * Sigmoid function. - * @param z The regression value. - * @return The result. 
- */ - private static double sigmoid(double z) { - return 1.0 / (1.0 + Math.exp(-z)); - } - - /** {@inheritDoc} */ - @Override public <P> void saveModel(Exporter<LogisticRegressionModel, P> exporter, P path) { - exporter.save(this, path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - LogisticRegressionModel mdl = (LogisticRegressionModel)o; - - return Double.compare(mdl.intercept, intercept) == 0 - && Double.compare(mdl.threshold, threshold) == 0 - && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0 - && Objects.equals(weights, mdl.weights); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - return Objects.hash(weights, intercept, isKeepingRawLabels, threshold); - } - - /** {@inheritDoc} */ - @Override public String toString() { - if (weights.size() < 20) { - StringBuilder builder = new StringBuilder(); - - for (int i = 0; i < weights.size(); i++) { - double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1); - - builder.append(String.format("%.4f", Math.abs(weights.get(i)))) - .append("*x") - .append(i) - .append(nextItem > 0 ? 
" + " : " - "); - } - - builder.append(String.format("%.4f", Math.abs(intercept))); - return builder.toString(); - } - - return "LogisticRegressionModel [" + - "weights=" + weights + - ", intercept=" + intercept + - ']'; - } - - /** {@inheritDoc} */ - @Override public String toString(boolean pretty) { - return toString(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java deleted file mode 100644 index 47fa59d..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.regressions.logistic.binomial; - -import java.io.Serializable; -import java.util.Arrays; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.nn.Activators; -import org.apache.ignite.ml.nn.MLPTrainer; -import org.apache.ignite.ml.nn.MultilayerPerceptron; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.nn.architecture.MLPArchitecture; -import org.apache.ignite.ml.optimization.LossFunctions; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; -import org.jetbrains.annotations.NotNull; - -/** - * Trainer of the logistic regression model based on stochastic gradient descent algorithm. - */ -public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> { - /** Update strategy. */ - private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ); - - /** Max number of iteration. */ - private int maxIterations = 100; - - /** Batch size. */ - private int batchSize = 100; - - /** Number of local iterations. */ - private int locIterations = 100; - - /** Seed for random generator. 
*/ - private long seed = 1234L; - - /** {@inheritDoc} */ - @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - - return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, - DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { - - IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { - Integer cols = dataset.compute(data -> { - if (data.getFeatures() == null) - return null; - return data.getFeatures().length / data.getRows(); - }, (a, b) -> { - // If both are null then zero will be propagated, no good. - if (a == null) - return b; - return a; - }); - - MLPArchitecture architecture = new MLPArchitecture(cols); - architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID); - - return architecture; - }; - - MLPTrainer<?> trainer = new MLPTrainer<>( - archSupplier, - LossFunctions.L2, - updatesStgy, - maxIterations, - batchSize, - locIterations, - seed - ); - - IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)}; - MultilayerPerceptron mlp; - if (mdl != null) { - mlp = restoreMLPState(mdl); - mlp = trainer.update(mlp, datasetBuilder, featureExtractor, lbExtractorWrapper); - } - else - mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper); - - double[] params = mlp.parameters().getStorage().data(); - - return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)), - params[params.length - 1] - ); - } - - /** - * @param mdl Model. - * @return state of MLP from last learning. 
- */ - @NotNull private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) { - Vector weights = mdl.weights(); - double intercept = mdl.intercept(); - - MLPArchitecture architecture1 = new MLPArchitecture(weights.size()); - architecture1 = architecture1.withAddedLayer(1, true, Activators.SIGMOID); - - MLPArchitecture architecture = architecture1; - MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture); - - Vector mlpState = weights.like(weights.size() + 1); - weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get())); - mlpState.set(mlpState.size() - 1, intercept); - perceptron.setParameters(mlpState); - - return perceptron; - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(LogisticRegressionModel mdl) { - return true; - } - - /** - * Set up the max amount of iterations before convergence. - * - * @param maxIterations The parameter value. - * @return Model with new max number of iterations before convergence parameter value. - */ - public LogisticRegressionSGDTrainer<P> withMaxIterations(int maxIterations) { - this.maxIterations = maxIterations; - return this; - } - - /** - * Set up the batchSize parameter. - * - * @param batchSize The size of learning batch. - * @return Trainer with new batch size parameter value. - */ - public LogisticRegressionSGDTrainer<P> withBatchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } - - /** - * Set up the amount of local iterations of SGD algorithm. - * - * @param amountOfLocIterations The parameter value. - * @return Trainer with new locIterations parameter value. - */ - public LogisticRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) { - this.locIterations = amountOfLocIterations; - return this; - } - - /** - * Set up the random seed parameter. - * - * @param seed Seed for random generator. - * @return Trainer with new seed parameter value. 
- */ - public LogisticRegressionSGDTrainer<P> withSeed(long seed) { - this.seed = seed; - return this; - } - - /** - * Set up the regularization parameter. - * - * @param updatesStgy Update strategy. - * @return Trainer with new update strategy parameter value. - */ - public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { - this.updatesStgy = updatesStgy; - return this; - } - - /** - * Get the update strategy. - * - * @return The property value. - */ - public UpdatesStrategy getUpdatesStgy() { - return updatesStgy; - } - - /** - * Get the max amount of iterations. - * - * @return The property value. - */ - public int getMaxIterations() { - return maxIterations; - } - - /** - * Get the batch size. - * - * @return The property value. - */ - public int getBatchSize() { - return batchSize; - } - - /** - * Get the amount of local iterations. - * - * @return The property value. - */ - public int getLocIterations() { - return locIterations; - } - - /** - * Get the seed for random generator. - * - * @return The property value. - */ - public long getSeed() { - return seed; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java deleted file mode 100644 index d32b1ee..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains binomial logistic regression. - */ -package org.apache.ignite.ml.regressions.logistic.binomial; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java deleted file mode 100644 index a7c9118..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.logistic.multiclass; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.TreeMap; -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; - -/** Base class for multi-classification model for set of Logistic Regression classifiers. */ -public class LogRegressionMultiClassModel implements Model<Vector, Double>, Exportable<LogRegressionMultiClassModel>, Serializable { - /** */ - private static final long serialVersionUID = -114986533350117L; - - /** List of models associated with each class. 
*/ - private Map<Double, LogisticRegressionModel> models; - - /** */ - public LogRegressionMultiClassModel() { - this.models = new HashMap<>(); - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector input) { - TreeMap<Double, Double> maxMargins = new TreeMap<>(); - - models.forEach((k, v) -> maxMargins.put(1.0 / (1.0 + Math.exp(-(input.dot(v.weights()) + v.intercept()))), k)); - - return maxMargins.lastEntry().getValue(); - } - - /** {@inheritDoc} */ - @Override public <P> void saveModel(Exporter<LogRegressionMultiClassModel, P> exporter, P path) { - exporter.save(this, path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - - if (o == null || getClass() != o.getClass()) - return false; - - LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o; - - return Objects.equals(models, mdl.models); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - return Objects.hash(models); - } - - /** {@inheritDoc} */ - @Override public String toString() { - StringBuilder wholeStr = new StringBuilder(); - - models.forEach((clsLb, mdl) -> - wholeStr - .append("The class with label ") - .append(clsLb) - .append(" has classifier: ") - .append(mdl.toString()) - .append(System.lineSeparator()) - ); - - return wholeStr.toString(); - } - - /** {@inheritDoc} */ - @Override public String toString(boolean pretty) { - return toString(); - } - - /** - * Adds a specific Log Regression binary classifier to the bunch of same classifiers. - * - * @param clsLb The class label for the added model. - * @param mdl The model. - */ - public void add(double clsLb, LogisticRegressionModel mdl) { - models.put(clsLb, mdl); - } - - /** - * @param clsLb Class label. - * @return model for class label if it exists. - */ - public Optional<LogisticRegressionModel> getModel(Double clsLb) { - return Optional.ofNullable(models.get(clsLb)); - } -}
