Repository: ignite Updated Branches: refs/heads/master 7633e18f1 -> e82ab505b
IGNITE-10428: [ML] Add example for OneVsRest trainer/model usage This closes #5631 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/e82ab505 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/e82ab505 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/e82ab505 Branch: refs/heads/master Commit: e82ab505bb9b416239a6b943ec07db26162a7400 Parents: 7633e18 Author: zaleslaw <[email protected]> Authored: Tue Dec 11 14:41:38 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 11 14:41:38 2018 +0300 ---------------------------------------------------------------------- .../OneVsRestClassificationExample.java | 162 +++++++++++++++++++ .../examples/ml/multiclass/package-info.java | 22 +++ .../ignite/ml/multiclass/package-info.java | 22 +++ 3 files changed, 206 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/e82ab505/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java new file mode 100644 index 0000000..d3e030b --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.multiclass; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.multiclass.MultiClassModel; +import org.apache.ignite.ml.multiclass.OneVsRestTrainer; +import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer; +import org.apache.ignite.ml.util.MLSandboxDatasets; +import org.apache.ignite.ml.util.SandboxMLCache; + +/** + * Runs the One-vs-Rest multi-class classification trainer ({@link OneVsRestTrainer}), parametrized by a binary SVM + * classifier ({@link SVMLinearClassificationTrainer}), over a distributed dataset + * to build two models: one with min-max scaling of the features and one without it.
+ * <p> + * Code in this example launches an Ignite grid and fills the cache with data points (preprocessed + * <a href="https://archive.ics.uci.edu/ml/datasets/Glass+Identification">Glass dataset</a>).</p> + * <p> + * After that it trains two One-vs-Rest multi-class models on this data: one with min-max scaling of the features + * and one without it.</p> + * <p> + * Finally, this example loops over the same cached data points that were used for training, applies both models + * to predict the class of each point, compares the predictions to the expected outcome (ground truth), and builds a + * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> for each model.</p> + * <p> + * You can change the data used in this example and re-run it to explore this algorithm further.</p> + * <p>NOTE: the linear SVM used here fails to classify the smallest (third) class.</p> + */ +public class OneVsRestClassificationExample { + /** Runs the example. */ + public static void main(String[] args) throws FileNotFoundException { + System.out.println(); + System.out.println(">>> One-vs-Rest SVM Multi-class classification model over cached dataset usage example started."); + // Start the Ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION); + + OneVsRestTrainer<SVMLinearClassificationModel> trainer + = new OneVsRestTrainer<>(new SVMLinearClassificationTrainer() + .withAmountOfIterations(20) + .withAmountOfLocIterations(50) + .withLambda(0.2) + .withSeed(1234L) + ); + + MultiClassModel<SVMLinearClassificationModel> mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> v.copyOfRange(1, v.size()), + (k, v) -> v.get(0) + ); + + System.out.println(">>> One-vs-Rest SVM Multi-class model"); + System.out.println(mdl.toString()); + + MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>(); + + IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit( + ignite, + dataCache, + (k, v) -> v.copyOfRange(1, v.size()) + ); + + MultiClassModel<SVMLinearClassificationModel> mdlWithScaling = trainer.fit( + ignite, + dataCache, + preprocessor, + (k, v) -> v.get(0) + ); + + System.out.println(">>> One-vs-Rest SVM Multi-class model with MinMaxScaling"); + System.out.println(mdlWithScaling.toString()); + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Prediction with MinMaxScaling\t| Ground Truth\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int amountOfErrorsWithMinMaxScaling = 0; + int totalAmount = 0; + + // Build confusion matrices (rows: predicted class, columns: ground truth; labels 1 and 3 map to indices 0 and 1, any other label to index 2).
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + int[][] confusionMtxWithMinMaxScaling = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Vector> observation : observations) { + Vector val = observation.getValue(); + Vector inputs = val.copyOfRange(1, val.size()); + double groundTruth = val.get(0); + // Feed each model features prepared exactly as during its training (the second model was trained on min-max scaled features). + double prediction = mdl.apply(inputs); + double predictionWithMinMaxScaling = mdlWithScaling.apply(preprocessor.apply(observation.getKey(), val)); + + totalAmount++; + + // Collect statistics for the model trained on raw features. + if (groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); + int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); + + confusionMtx[idx1][idx2]++; + + // Collect statistics for the model trained on min-max scaled features. + if (groundTruth != predictionWithMinMaxScaling) + amountOfErrorsWithMinMaxScaling++; + + idx1 = (int)predictionWithMinMaxScaling == 1 ? 0 : ((int)predictionWithMinMaxScaling == 3 ? 1 : 2); + idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ?
1 : 2); + + confusionMtxWithMinMaxScaling[idx1][idx2]++; + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithMinMaxScaling, groundTruth); + } + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println("\n>>> -----------------One-vs-Rest SVM model-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + + System.out.println("\n>>> -----------------One-vs-Rest SVM model with MinMaxScaling-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithMinMaxScaling); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithMinMaxScaling / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithMinMaxScaling)); + + System.out.println(">>> One-vs-Rest SVM model over cache based dataset usage example completed."); + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e82ab505/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/package-info.java new file mode 100644 index 0000000..dc1d8ad --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Multi-class classification examples. + */ +package org.apache.ignite.examples.ml.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/e82ab505/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/package-info.java new file mode 100644 index 0000000..3a352f7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multi-class classification models and trainers. + */ +package org.apache.ignite.ml.multiclass;
