IGNITE-8059: Integrate decision tree with partition based dataset. this closes #3760
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/139c2af6 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/139c2af6 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/139c2af6 Branch: refs/heads/ignite-8201 Commit: 139c2af66a9f745f89429842810f5d5fe1addf28 Parents: a64b941 Author: dmitrievanthony <dmitrievanth...@gmail.com> Authored: Tue Apr 10 12:46:43 2018 +0300 Committer: YuriBabak <y.ch...@gmail.com> Committed: Tue Apr 10 12:46:44 2018 +0300 ---------------------------------------------------------------------- ...ecisionTreeClassificationTrainerExample.java | 147 +++++ .../DecisionTreeRegressionTrainerExample.java | 124 ++++ .../ignite/examples/ml/tree/package-info.java | 22 + .../examples/ml/trees/DecisionTreesExample.java | 354 ------------ .../ignite/examples/ml/trees/package-info.java | 22 - .../main/java/org/apache/ignite/ml/Trainer.java | 3 - .../org/apache/ignite/ml/tree/DecisionTree.java | 252 ++++++++ .../tree/DecisionTreeClassificationTrainer.java | 93 +++ .../ml/tree/DecisionTreeConditionalNode.java | 78 +++ .../ignite/ml/tree/DecisionTreeLeafNode.java | 48 ++ .../apache/ignite/ml/tree/DecisionTreeNode.java | 26 + .../ml/tree/DecisionTreeRegressionTrainer.java | 60 ++ .../org/apache/ignite/ml/tree/TreeFilter.java | 38 ++ .../ignite/ml/tree/data/DecisionTreeData.java | 128 +++++ .../ml/tree/data/DecisionTreeDataBuilder.java | 73 +++ .../ignite/ml/tree/data/package-info.java | 22 + .../ml/tree/impurity/ImpurityMeasure.java | 55 ++ .../impurity/ImpurityMeasureCalculator.java | 38 ++ .../tree/impurity/gini/GiniImpurityMeasure.java | 115 ++++ .../gini/GiniImpurityMeasureCalculator.java | 110 ++++ .../ml/tree/impurity/gini/package-info.java | 22 + .../tree/impurity/mse/MSEImpurityMeasure.java | 133 +++++ .../mse/MSEImpurityMeasureCalculator.java | 80 +++ .../ml/tree/impurity/mse/package-info.java | 22 + .../ignite/ml/tree/impurity/package-info.java | 22 + 
.../util/SimpleStepFunctionCompressor.java | 149 +++++ .../ml/tree/impurity/util/StepFunction.java | 162 ++++++ .../impurity/util/StepFunctionCompressor.java | 55 ++ .../ml/tree/impurity/util/package-info.java | 22 + .../ml/tree/leaf/DecisionTreeLeafBuilder.java | 38 ++ .../tree/leaf/MeanDecisionTreeLeafBuilder.java | 73 +++ .../leaf/MostCommonDecisionTreeLeafBuilder.java | 86 +++ .../ignite/ml/tree/leaf/package-info.java | 22 + .../org/apache/ignite/ml/tree/package-info.java | 22 + .../ignite/ml/trees/CategoricalRegionInfo.java | 72 --- .../ignite/ml/trees/CategoricalSplitInfo.java | 68 --- .../ignite/ml/trees/ContinuousRegionInfo.java | 74 --- .../ml/trees/ContinuousSplitCalculator.java | 51 -- .../org/apache/ignite/ml/trees/RegionInfo.java | 62 -- .../ml/trees/models/DecisionTreeModel.java | 44 -- .../ignite/ml/trees/models/package-info.java | 22 - .../ml/trees/nodes/CategoricalSplitNode.java | 50 -- .../ml/trees/nodes/ContinuousSplitNode.java | 56 -- .../ignite/ml/trees/nodes/DecisionTreeNode.java | 33 -- .../org/apache/ignite/ml/trees/nodes/Leaf.java | 49 -- .../apache/ignite/ml/trees/nodes/SplitNode.java | 100 ---- .../ignite/ml/trees/nodes/package-info.java | 22 - .../apache/ignite/ml/trees/package-info.java | 22 - .../ml/trees/trainers/columnbased/BiIndex.java | 113 ---- ...exedCacheColumnDecisionTreeTrainerInput.java | 57 -- .../CacheColumnDecisionTreeTrainerInput.java | 141 ----- .../columnbased/ColumnDecisionTreeTrainer.java | 568 ------------------- .../ColumnDecisionTreeTrainerInput.java | 55 -- .../MatrixColumnDecisionTreeTrainerInput.java | 83 --- .../trainers/columnbased/RegionProjection.java | 109 ---- .../trainers/columnbased/TrainingContext.java | 166 ------ .../columnbased/caches/ContextCache.java | 68 --- .../columnbased/caches/FeaturesCache.java | 151 ----- .../columnbased/caches/ProjectionsCache.java | 286 ---------- .../trainers/columnbased/caches/SplitCache.java | 206 ------- .../columnbased/caches/package-info.java | 22 - 
.../ContinuousSplitCalculators.java | 34 -- .../contsplitcalcs/GiniSplitCalculator.java | 234 -------- .../contsplitcalcs/VarianceSplitCalculator.java | 179 ------ .../contsplitcalcs/package-info.java | 22 - .../trainers/columnbased/package-info.java | 22 - .../columnbased/regcalcs/RegionCalculators.java | 85 --- .../columnbased/regcalcs/package-info.java | 22 - .../vectors/CategoricalFeatureProcessor.java | 212 ------- .../vectors/ContinuousFeatureProcessor.java | 111 ---- .../vectors/ContinuousSplitInfo.java | 71 --- .../columnbased/vectors/FeatureProcessor.java | 82 --- .../vectors/FeatureVectorProcessorUtils.java | 57 -- .../columnbased/vectors/SampleInfo.java | 80 --- .../trainers/columnbased/vectors/SplitInfo.java | 106 ---- .../columnbased/vectors/package-info.java | 22 - .../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +- .../ml/nn/performance/MnistMLPTestUtil.java | 9 +- ...reeClassificationTrainerIntegrationTest.java | 100 ++++ .../DecisionTreeClassificationTrainerTest.java | 91 +++ ...ionTreeRegressionTrainerIntegrationTest.java | 100 ++++ .../tree/DecisionTreeRegressionTrainerTest.java | 91 +++ .../ignite/ml/tree/DecisionTreeTestSuite.java | 48 ++ .../ml/tree/data/DecisionTreeDataTest.java | 59 ++ .../gini/GiniImpurityMeasureCalculatorTest.java | 103 ++++ .../impurity/gini/GiniImpurityMeasureTest.java | 131 +++++ .../mse/MSEImpurityMeasureCalculatorTest.java | 59 ++ .../impurity/mse/MSEImpurityMeasureTest.java | 109 ++++ .../util/SimpleStepFunctionCompressorTest.java | 75 +++ .../ml/tree/impurity/util/StepFunctionTest.java | 71 +++ .../tree/impurity/util/TestImpurityMeasure.java | 88 +++ .../DecisionTreeMNISTIntegrationTest.java | 105 ++++ .../tree/performance/DecisionTreeMNISTTest.java | 74 +++ .../ignite/ml/trees/BaseDecisionTreeTest.java | 70 --- .../ml/trees/ColumnDecisionTreeTrainerTest.java | 191 ------- .../ignite/ml/trees/DecisionTreesTestSuite.java | 33 -- .../ml/trees/GiniSplitCalculatorTest.java | 141 ----- 
.../ignite/ml/trees/SplitDataGenerator.java | 390 ------------- .../ml/trees/VarianceSplitCalculatorTest.java | 84 --- .../ColumnDecisionTreeTrainerBenchmark.java | 456 --------------- .../IgniteColumnDecisionTreeGiniBenchmark.java | 70 --- ...niteColumnDecisionTreeVarianceBenchmark.java | 71 --- .../yardstick/ml/trees/SplitDataGenerator.java | 426 -------------- .../ignite/yardstick/ml/trees/package-info.java | 22 - 104 files changed, 3647 insertions(+), 6429 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java new file mode 100644 index 0000000..cef6368 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.tree; + +import java.util.Random; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Example of using distributed {@link DecisionTreeClassificationTrainer}. + */ +public class DecisionTreeClassificationTrainerExample { + /** + * Executes example. + * + * @param args Command line arguments, none required. + */ + public static void main(String... args) throws InterruptedException { + System.out.println(">>> Decision tree classification trainer example started."); + + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> { + + // Create cache with training data. + CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg); + + Random rnd = new Random(0); + + // Fill training data. + for (int i = 0; i < 1000; i++) + trainingSet.put(i, generatePoint(rnd)); + + // Create classification trainer. + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); + + // Train decision tree model. 
+ DecisionTreeNode mdl = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, trainingSet), + (k, v) -> new double[]{v.x, v.y}, + (k, v) -> v.lb + ); + + // Calculate score. + int correctPredictions = 0; + for (int i = 0; i < 1000; i++) { + LabeledPoint pnt = generatePoint(rnd); + + double prediction = mdl.apply(new double[]{pnt.x, pnt.y}); + + if (prediction == pnt.lb) + correctPredictions++; + } + + System.out.println(">>> Accuracy: " + correctPredictions / 10.0 + "%"); + + System.out.println(">>> Decision tree classification trainer example completed."); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Generate point with {@code x} in (-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label + * is 1, otherwise 0. + * + * @param rnd Random. + * @return Point with label. + */ + private static LabeledPoint generatePoint(Random rnd) { + + double x = rnd.nextDouble() - 0.5; + double y = rnd.nextDouble() - 0.5; + + return new LabeledPoint(x, y, x * y > 0 ? 1 : 0); + } + + /** Point data class. */ + private static class Point { + /** X coordinate. */ + final double x; + + /** Y coordinate. */ + final double y; + + /** + * Constructs a new instance of point. + * + * @param x X coordinate. + * @param y Y coordinate. + */ + Point(double x, double y) { + this.x = x; + this.y = y; + } + } + + /** Labeled point data class. */ + private static class LabeledPoint extends Point { + /** Point label. */ + final double lb; + + /** + * Constructs a new instance of labeled point data. + * + * @param x X coordinate. + * @param y Y coordinate. + * @param lb Point label. 
+ */ + LabeledPoint(double x, double y, double lb) { + super(x, y); + this.lb = lb; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java new file mode 100644 index 0000000..61ba5f9 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.tree; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; +import org.apache.ignite.thread.IgniteThread; + +/** + * Example of using distributed {@link DecisionTreeRegressionTrainer}. + */ +public class DecisionTreeRegressionTrainerExample { + /** + * Executes example. + * + * @param args Command line arguments, none required. + */ + public static void main(String... args) throws InterruptedException { + System.out.println(">>> Decision tree regression trainer example started."); + + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + DecisionTreeRegressionTrainerExample.class.getSimpleName(), () -> { + + // Create cache with training data. + CacheConfiguration<Integer, Point> trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, Point> trainingSet = ignite.createCache(trainingSetCfg); + + // Fill training data. + generatePoints(trainingSet); + + // Create regression trainer. + DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0); + + // Train decision tree model. 
+ DecisionTreeNode mdl = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, trainingSet), + (k, v) -> new double[] {v.x}, + (k, v) -> v.y + ); + + System.out.println(">>> Linear regression model: " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + // Calculate score. + for (int x = 0; x < 10; x++) { + double predicted = mdl.apply(new double[] {x}); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x)); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println(">>> Decision tree regression trainer example completed."); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Generates {@code sin(x)} on interval [0, 10) and loads into the specified cache. + */ + private static void generatePoints(IgniteCache<Integer, Point> trainingSet) { + for (int i = 0; i < 1000; i++) { + double x = i / 100.0; + double y = Math.sin(x); + + trainingSet.put(i, new Point(x, y)); + } + } + + /** Point data class. */ + private static class Point { + /** X coordinate. */ + final double x; + + /** Y coordinate. */ + final double y; + + /** + * Constructs a new instance of point. + * + * @param x X coordinate. + * @param y Y coordinate. 
+ */ + Point(double x, double y) { + this.x = x; + this.y = y; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java new file mode 100644 index 0000000..d8d9de6 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Decision trees examples. 
+ */ +package org.apache.ignite.examples.ml.tree; http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java deleted file mode 100644 index b1b2c42..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.trees; - -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.net.URL; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; -import java.util.Collection; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Random; -import java.util.Scanner; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import java.util.zip.GZIPInputStream; -import org.apache.commons.cli.BasicParser; -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.CommandLineParser; -import org.apache.commons.cli.Option; -import org.apache.commons.cli.OptionBuilder; -import org.apache.commons.cli.Options; -import org.apache.commons.cli.ParseException; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteDataStreamer; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ExampleNodeStartup; -import org.apache.ignite.examples.ml.MLExamplesCommonArgs; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.estimators.Estimators; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import 
org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.ml.util.MnistUtils; -import org.jetbrains.annotations.NotNull; - -/** - * <p> - * Example of usage of decision trees algorithm for MNIST dataset - * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p> - * <p> - * Remote nodes should always be started with special configuration file which - * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> - * <p> - * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node - * with {@code examples/config/example-ignite.xml} configuration.</p> - * <p> - * It is recommended to start at least one node prior to launching this example if you intend - * to run it with default memory settings.</p> - * <p> - * This example should be run with program arguments, for example - * -cfg examples/config/example-ignite.xml.</p> - * <p> - * -cfg specifies path to a config path.</p> - */ -public class DecisionTreesExample { - /** Name of parameter specifying path of Ignite config. */ - private static final String CONFIG = "cfg"; - - /** Default config path. */ - private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml"; - - /** - * Folder in which MNIST dataset is expected. - */ - private static String MNIST_DIR = "examples/src/main/resources/"; - - /** - * Key for MNIST training images. - */ - private static String MNIST_TRAIN_IMAGES = "train_images"; - - /** - * Key for MNIST training labels. - */ - private static String MNIST_TRAIN_LABELS = "train_labels"; - - /** - * Key for MNIST test images. - */ - private static String MNIST_TEST_IMAGES = "test_images"; - - /** - * Key for MNIST test labels. 
- */ - private static String MNIST_TEST_LABELS = "test_labels"; - - /** - * Launches example. - * - * @param args Program arguments. - */ - public static void main(String[] args) throws IOException { - System.out.println(">>> Decision trees example started."); - - String igniteCfgPath; - - CommandLineParser parser = new BasicParser(); - - String trainingImagesPath; - String trainingLabelsPath; - - String testImagesPath; - String testLabelsPath; - - Map<String, String> mnistPaths = new HashMap<>(); - - mnistPaths.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte"); - mnistPaths.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte"); - mnistPaths.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte"); - mnistPaths.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte"); - - try { - // Parse the command line arguments. - CommandLine line = parser.parse(buildOptions(), args); - - if (line.hasOption(MLExamplesCommonArgs.UNATTENDED)) { - System.out.println(">>> Skipped example execution because 'unattended' mode is used."); - System.out.println(">>> Decision trees example finished."); - return; - } - - igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG); - } - catch (ParseException e) { - e.printStackTrace(); - return; - } - - if (!getMNIST(mnistPaths.values())) { - System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example."); - return; - } - - trainingImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + - mnistPaths.get(MNIST_TRAIN_IMAGES))).getPath(); - trainingLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + - mnistPaths.get(MNIST_TRAIN_LABELS))).getPath(); - testImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + - mnistPaths.get(MNIST_TEST_IMAGES))).getPath(); - testLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + - mnistPaths.get(MNIST_TEST_LABELS))).getPath(); - - try (Ignite ignite = 
Ignition.start(igniteCfgPath)) { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - int ptsCnt = 60000; - int featCnt = 28 * 28; - - Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(trainingImagesPath, trainingLabelsPath, - new Random(123L), ptsCnt); - - Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(testImagesPath, testLabelsPath, - new Random(123L), 10_000); - - IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite); - - loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite); - - ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10, - ContinuousSplitCalculators.GINI.apply(ignite), - RegionCalculators.GINI, - RegionCalculators.MOST_COMMON, - ignite); - - System.out.println(">>> Training started"); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); - System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); - - IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = - Estimators.errorsPercentage(); - - Double accuracy = mse.apply(mdl, testMnistStream.map(v -> - new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - - System.out.println(">>> Errs percentage: " + accuracy); - } - catch (IOException e) { - e.printStackTrace(); - } - - System.out.println(">>> Decision trees example finished."); - } - - /** - * Get MNIST dataset. Value of predicate 'MNIST dataset is present in expected folder' is returned. - * - * @param mnistFileNames File names of MNIST dataset. - * @return Value of predicate 'MNIST dataset is present in expected folder'. - * @throws IOException In case of file system errors. 
- */ - private static boolean getMNIST(Collection<String> mnistFileNames) throws IOException { - List<String> missing = mnistFileNames.stream(). - filter(f -> IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f) == null). - collect(Collectors.toList()); - - if (!missing.isEmpty()) { - System.out.println(">>> You have not fully downloaded MNIST dataset in directory " + MNIST_DIR + - ", do you want it to be downloaded? [y]/n"); - Scanner s = new Scanner(System.in); - String str = s.nextLine(); - - if (!str.isEmpty() && !str.toLowerCase().equals("y")) - return false; - } - - for (String s : missing) { - String f = s + ".gz"; - System.out.println(">>> Downloading " + f + "..."); - URL website = new URL("http://yann.lecun.com/exdb/mnistAsStream/" + f); - ReadableByteChannel rbc = Channels.newChannel(website.openStream()); - FileOutputStream fos = new FileOutputStream(MNIST_DIR + "/" + f); - fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); - System.out.println(">>> Done."); - - System.out.println(">>> Unzipping " + f + "..."); - unzip(MNIST_DIR + "/" + f, MNIST_DIR + "/" + s); - - System.out.println(">>> Deleting gzip " + f + ", status: " + - Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f)).delete()); - - System.out.println(">>> Done."); - } - - return true; - } - - /** - * Unzip file located in {@code input} to {@code output}. - * - * @param input Input file path. - * @param output Output file path. - * @throws IOException In case of file system errors. - */ - private static void unzip(String input, String output) throws IOException { - byte[] buf = new byte[1024]; - - try (GZIPInputStream gis = new GZIPInputStream(new FileInputStream(input)); - FileOutputStream out = new FileOutputStream(output)) { - int sz; - while ((sz = gis.read(buf)) > 0) - out.write(buf, 0, sz); - } - } - - /** - * Build cli options. 
- */ - @NotNull private static Options buildOptions() { - Options options = new Options(); - - Option cfgOpt = OptionBuilder - .withArgName(CONFIG) - .withLongOpt(CONFIG) - .hasArg() - .withDescription("Path to the config.") - .isRequired(false).create(); - - Option unattended = OptionBuilder - .withArgName(MLExamplesCommonArgs.UNATTENDED) - .withLongOpt(MLExamplesCommonArgs.UNATTENDED) - .withDescription("Is example run unattended.") - .isRequired(false).create(); - - options.addOption(cfgOpt); - options.addOption(unattended); - - return options; - } - - /** - * Creates cache where data for training is stored. - * - * @param ignite Ignite instance. - * @return cache where data for training is stored. - */ - private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) { - CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // No copying of values. - cfg.setCopyOnRead(false); - - cfg.setName("TMP_BI_INDEXED_CACHE"); - - return ignite.getOrCreateCache(cfg); - } - - /** - * Loads vectors into cache. - * - * @param cacheName Name of cache. - * @param vectorsIter Iterator over vectors to load. - * @param vectorSize Size of vector. - * @param ignite Ignite instance. - */ - private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? 
extends Vector> vectorsIter, - int vectorSize, Ignite ignite) { - try (IgniteDataStreamer<BiIndex, Double> streamer = - ignite.dataStreamer(cacheName)) { - int sampleIdx = 0; - - streamer.perNodeBufferSize(10000); - - while (vectorsIter.hasNext()) { - org.apache.ignite.ml.math.Vector next = vectorsIter.next(); - - for (int i = 0; i < vectorSize; i++) - streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); - - sampleIdx++; - - if (sampleIdx % 1000 == 0) - System.out.println(">>> Loaded " + sampleIdx + " vectors."); - } - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java deleted file mode 100644 index d944f60..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Decision trees examples. 
- */ -package org.apache.ignite.examples.ml.trees; http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java index 4e0a570..f53b801 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java @@ -17,11 +17,8 @@ package org.apache.ignite.ml; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; - /** * Interface for Trainers. Trainer is just a function which produces model from the data. - * See for example {@link ColumnDecisionTreeTrainer}. * * @param <M> Type of produced model. * @param <T> Type of data needed for model producing. http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java new file mode 100644 index 0000000..c0b88fc --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.io.Serializable; +import java.util.Arrays; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.util.StepFunction; +import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor; +import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder; + +/** + * Distributed decision tree trainer that allows to fit trees using row-partitioned dataset. + * + * @param <T> Type of impurity measure. + */ +abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> { + /** Max tree deep. */ + private final int maxDeep; + + /** Min impurity decrease. */ + private final double minImpurityDecrease; + + /** Step function compressor. */ + private final StepFunctionCompressor<T> compressor; + + /** Decision tree leaf builder. */ + private final DecisionTreeLeafBuilder decisionTreeLeafBuilder; + + /** + * Constructs a new distributed decision tree trainer. 
+ * + * @param maxDeep Max tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param compressor Impurity function compressor. + * @param decisionTreeLeafBuilder Decision tree leaf builder. + */ + DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor, DecisionTreeLeafBuilder decisionTreeLeafBuilder) { + this.maxDeep = maxDeep; + this.minImpurityDecrease = minImpurityDecrease; + this.compressor = compressor; + this.decisionTreeLeafBuilder = decisionTreeLeafBuilder; + } + + /** {@inheritDoc} */ + @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( + new EmptyContextBuilder<>(), + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor) + )) { + return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Returns impurity measure calculator. + * + * @param dataset Dataset. + * @return Impurity measure calculator. + */ + abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset); + + /** + * Splits the node specified by the given dataset and predicate and returns decision tree node. + * + * @param dataset Dataset. + * @param filter Decision tree node predicate. + * @param deep Current tree deep. + * @param impurityCalc Impurity measure calculator. + * @return Decision tree node. 
+ */ + private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int deep, + ImpurityMeasureCalculator<T> impurityCalc) { + if (deep >= maxDeep) + return decisionTreeLeafBuilder.createLeafNode(dataset, filter); + + StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc); + + if (criterionFunctions == null) + return decisionTreeLeafBuilder.createLeafNode(dataset, filter); + + SplitPoint splitPnt = calculateBestSplitPoint(criterionFunctions); + + if (splitPnt == null) + return decisionTreeLeafBuilder.createLeafNode(dataset, filter); + + return new DecisionTreeConditionalNode( + splitPnt.col, + splitPnt.threshold, + split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc), + split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc) + ); + } + + /** + * Calculates impurity measure functions for all columns for the node specified by the given dataset and predicate. + * + * @param dataset Dataset. + * @param filter Decision tree node predicate. + * @param impurityCalc Impurity measure calculator. + * @return Array of impurity measure functions for all columns. + */ + private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset, + TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) { + return dataset.compute( + part -> { + if (compressor != null) + return compressor.compress(impurityCalc.calculate(part.filter(filter))); + else + return impurityCalc.calculate(part.filter(filter)); + }, this::reduce + ); + } + + /** + * Calculates best split point. + * + * @param criterionFunctions Array of impurity measure functions for all columns. + * @return Best split point. 
+ */ + private SplitPoint calculateBestSplitPoint(StepFunction<T>[] criterionFunctions) { + SplitPoint<T> res = null; + + for (int col = 0; col < criterionFunctions.length; col++) { + StepFunction<T> criterionFunctionForCol = criterionFunctions[col]; + + double[] arguments = criterionFunctionForCol.getX(); + T[] values = criterionFunctionForCol.getY(); + + for (int leftSize = 1; leftSize < values.length - 1; leftSize++) { + if ((values[0].impurity() - values[leftSize].impurity()) > minImpurityDecrease + && (res == null || values[leftSize].compareTo(res.val) < 0)) + res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize)); + } + } + + return res; + } + + /** + * Merges two arrays gotten from two partitions. + * + * @param a First step function. + * @param b Second step function. + * @return Merged step function. + */ + private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) { + if (a == null) + return b; + if (b == null) + return a; + else { + StepFunction<T>[] res = Arrays.copyOf(a, a.length); + + for (int i = 0; i < res.length; i++) + res[i] = res[i].add(b[i]); + + return res; + } + } + + /** + * Calculates threshold based on the given step function arguments and split point (specified left size). + * + * @param arguments Step function arguments. + * @param leftSize Split point (left size). + * @return Threshold. + */ + private double calculateThreshold(double[] arguments, int leftSize) { + return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0; + } + + /** + * Constructs a new predicate for "then" node based on the parent node predicate and split point. + * + * @param filter Parent node predicate. + * @param splitPnt Split point. + * @return Predicate for "then" node. 
+ */ + private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint splitPnt) { + return filter.and(f -> f[splitPnt.col] > splitPnt.threshold); + } + + /** + * Constructs a new predicate for "else" node based on the parent node predicate and split point. + * + * @param filter Parent node predicate. + * @param splitPnt Split point. + * @return Predicate for "else" node. + */ + private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint splitPnt) { + return filter.and(f -> f[splitPnt.col] <= splitPnt.threshold); + } + + /** + * Util class that represents split point. + */ + private static class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable { + /** */ + private static final long serialVersionUID = -1758525953544425043L; + + /** Split point impurity measure value. */ + private final T val; + + /** Column. */ + private final int col; + + /** Threshold. */ + private final double threshold; + + /** + * Constructs a new instance of split point. + * + * @param val Split point impurity measure value. + * @param col Column. + * @param threshold Threshold. + */ + SplitPoint(T val, int col, double threshold) { + this.val = val; + this.col = col; + this.threshold = threshold; + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java new file mode 100644 index 0000000..ce75190 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure; +import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor; +import org.apache.ignite.ml.tree.leaf.MostCommonDecisionTreeLeafBuilder; + +/** + * Decision tree classifier based on distributed decision tree trainer that allows to fit trees using row-partitioned + * dataset. + */ +public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurityMeasure> { + /** + * Constructs a new decision tree classifier with default impurity function compressor. + * + * @param maxDeep Max tree deep. + * @param minImpurityDecrease Min impurity decrease. 
+ */ + public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease) { + this(maxDeep, minImpurityDecrease, null); + } + + /** + * Constructs a new instance of decision tree classifier. + * + * @param maxDeep Max tree deep. + * @param minImpurityDecrease Min impurity decrease. + */ + public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease, + StepFunctionCompressor<GiniImpurityMeasure> compressor) { + super(maxDeep, minImpurityDecrease, compressor, new MostCommonDecisionTreeLeafBuilder()); + } + + /** {@inheritDoc} */ + @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator( + Dataset<EmptyContext, DecisionTreeData> dataset) { + Set<Double> labels = dataset.compute(part -> { + + if (part.getLabels() != null) { + Set<Double> list = new HashSet<>(); + + for (double lb : part.getLabels()) + list.add(lb); + + return list; + } + + return null; + }, (a, b) -> { + if (a == null) + return b; + else if (b == null) + return a; + else { + a.addAll(b); + return a; + } + }); + + Map<Double, Integer> encoder = new HashMap<>(); + + int idx = 0; + for (Double lb : labels) + encoder.put(lb, idx++); + + return new GiniImpurityMeasureCalculator(encoder); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java new file mode 100644 index 0000000..9818239 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +/** + * Decision tree conditional (non-leaf) node. + */ +public class DecisionTreeConditionalNode implements DecisionTreeNode { + /** */ + private static final long serialVersionUID = 981630737007982172L; + + /** Column of the value to be tested. */ + private final int col; + + /** Threshold. */ + private final double threshold; + + /** Node that will be used in case tested value is greater then threshold. */ + private final DecisionTreeNode thenNode; + + /** Node that will be used in case tested value is not greater then threshold. */ + private final DecisionTreeNode elseNode; + + /** + * Constructs a new instance of decision tree conditional node. + * + * @param col Column of the value to be tested. + * @param threshold Threshold. + * @param thenNode Node that will be used in case tested value is greater then threshold. + * @param elseNode Node that will be used in case tested value is not greater then threshold. 
+ */ + DecisionTreeConditionalNode(int col, double threshold, DecisionTreeNode thenNode, DecisionTreeNode elseNode) { + this.col = col; + this.threshold = threshold; + this.thenNode = thenNode; + this.elseNode = elseNode; + } + + /** {@inheritDoc} */ + @Override public Double apply(double[] features) { + return features[col] > threshold ? thenNode.apply(features) : elseNode.apply(features); + } + + /** */ + public int getCol() { + return col; + } + + /** */ + public double getThreshold() { + return threshold; + } + + /** */ + public DecisionTreeNode getThenNode() { + return thenNode; + } + + /** */ + public DecisionTreeNode getElseNode() { + return elseNode; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java new file mode 100644 index 0000000..4c6369d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +/** + * Decision tree leaf node which contains value. + */ +public class DecisionTreeLeafNode implements DecisionTreeNode { + /** */ + private static final long serialVersionUID = -472145568088482206L; + + /** Value of the node. */ + private final double val; + + /** + * Constructs a new decision tree leaf node. + * + * @param val Value of the node. + */ + public DecisionTreeLeafNode(double val) { + this.val = val; + } + + /** {@inheritDoc} */ + @Override public Double apply(double[] doubles) { + return val; + } + + /** */ + public double getVal() { + return val; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java new file mode 100644 index 0000000..94878eb --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import org.apache.ignite.ml.Model; + +/** + * Base interface for decision tree nodes. + */ +public interface DecisionTreeNode extends Model<double[], Double> { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java new file mode 100644 index 0000000..2bf09d3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasure; +import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor; +import org.apache.ignite.ml.tree.leaf.MeanDecisionTreeLeafBuilder; + +/** + * Decision tree regressor based on distributed decision tree trainer that allows to fit trees using row-partitioned + * dataset. + */ +public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasure> { + /** + * Constructs a new decision tree regressor with default impurity function compressor. + * + * @param maxDeep Max tree deep. + * @param minImpurityDecrease Min impurity decrease. + */ + public DecisionTreeRegressionTrainer(int maxDeep, double minImpurityDecrease) { + this(maxDeep, minImpurityDecrease, null); + } + + /** + * Constructs a new decision tree regressor. + * + * @param maxDeep Max tree deep. + * @param minImpurityDecrease Min impurity decrease. 
+ */ + public DecisionTreeRegressionTrainer(int maxDeep, double minImpurityDecrease, + StepFunctionCompressor<MSEImpurityMeasure> compressor) { + super(maxDeep, minImpurityDecrease, compressor, new MeanDecisionTreeLeafBuilder()); + } + + /** {@inheritDoc} */ + @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator( + Dataset<EmptyContext, DecisionTreeData> dataset) { + return new MSEImpurityMeasureCalculator(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java new file mode 100644 index 0000000..3e4dc00 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.io.Serializable; +import java.util.Objects; +import java.util.function.Predicate; + +/** + * Predicate used to define objects that placed in decision tree node. 
+ */ +public interface TreeFilter extends Predicate<double[]>, Serializable { + /** + * Returns a composed predicate. + * + * @param other Predicate that will be logically-ANDed with this predicate. + * @return Returns a composed predicate + */ + default TreeFilter and(TreeFilter other) { + Objects.requireNonNull(other); + return (t) -> test(t) && other.test(t); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java new file mode 100644 index 0000000..34deb46 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.data; + +import org.apache.ignite.ml.tree.TreeFilter; + +/** + * A partition {@code data} of the containing matrix of features and vector of labels stored in heap. 
+ */ +public class DecisionTreeData implements AutoCloseable { + /** Matrix with features. */ + private final double[][] features; + + /** Vector with labels. */ + private final double[] labels; + + /** + * Constructs a new instance of decision tree data. + * + * @param features Matrix with features. + * @param labels Vector with labels. + */ + public DecisionTreeData(double[][] features, double[] labels) { + assert features.length == labels.length : "Features and labels have to be the same length"; + + this.features = features; + this.labels = labels; + } + + /** + * Filters objects and returns only data that passed filter. + * + * @param filter Filter. + * @return Data passed filter. + */ + public DecisionTreeData filter(TreeFilter filter) { + int size = 0; + + for (int i = 0; i < features.length; i++) + if (filter.test(features[i])) + size++; + + double[][] newFeatures = new double[size][]; + double[] newLabels = new double[size]; + + int ptr = 0; + + for (int i = 0; i < features.length; i++) { + if (filter.test(features[i])) { + newFeatures[ptr] = features[i]; + newLabels[ptr] = labels[i]; + + ptr++; + } + } + + return new DecisionTreeData(newFeatures, newLabels); + } + + /** + * Sorts data by specified column in ascending order. + * + * @param col Column. 
+ */ + public void sort(int col) { + sort(col, 0, features.length - 1); + } + + /** */ + private void sort(int col, int from, int to) { + if (from < to) { + double pivot = features[(from + to) / 2][col]; + + int i = from, j = to; + + while (i <= j) { + while (features[i][col] < pivot) i++; + while (features[j][col] > pivot) j--; + + if (i <= j) { + double[] tmpFeature = features[i]; + features[i] = features[j]; + features[j] = tmpFeature; + + double tmpLb = labels[i]; + labels[i] = labels[j]; + labels[j] = tmpLb; + + i++; + j--; + } + } + + sort(col, from, j); + sort(col, i, to); + } + } + + /** */ + public double[][] getFeatures() { + return features; + } + + /** */ + public double[] getLabels() { + return labels; + } + + /** {@inheritDoc} */ + @Override public void close() { + // Do nothing, GC will clean up. + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java new file mode 100644 index 0000000..67109ae --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+/**
+ * A partition {@code data} builder that makes {@link DecisionTreeData}.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class DecisionTreeDataBuilder<K, V, C extends Serializable>
+    implements PartitionDataBuilder<K, V, C, DecisionTreeData> {
+    /** */
+    private static final long serialVersionUID = 3678784980215216039L;
+
+    /** Function that extracts features from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+    /** Function that extracts labels from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Double> lbExtractor;
+
+    /**
+     * Constructs a new instance of decision tree data builder.
+     *
+     * @param featureExtractor Function that extracts features from an {@code upstream} data.
+     * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+     */
+    public DecisionTreeDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+    }
+
+    /** {@inheritDoc} Rows are filled in iteration order; the iterator must yield exactly {@code upstreamDataSize} entries. */
+    @Override public DecisionTreeData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+        int size = Math.toIntExact(upstreamDataSize);
+        double[][] features = new double[size][];
+        double[] labels = new double[size];
+        int ptr = 0;
+        while (upstreamData.hasNext()) {
+            UpstreamEntry<K, V> entry = upstreamData.next();
+
+            features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue());
+            labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            ptr++;
+        }
+        assert ptr == size : "Expected " + size + " upstream entries, but got " + ptr;
+        return new DecisionTreeData(features, labels);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
new file mode 100644
index 0000000..192b07f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains data and data builder required for decision tree trainers built on top of partition based dataset.
+ */
+package org.apache.ignite.ml.tree.data;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
new file mode 100644
index 0000000..7ad2b80
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for impurity measures that can be used in the distributed decision tree algorithm.
+ *
+ * @param <T> Type of this impurity measure.
+ */
+public interface ImpurityMeasure<T extends ImpurityMeasure<T>> extends Comparable<T>, Serializable {
+    /**
+     * Calculates impurity measure as a single double value.
+     *
+     * @return Impurity measure value.
+     */
+    public double impurity();
+
+    /**
+     * Adds the given impurity to this.
+     *
+     * @param measure Another impurity.
+     * @return Sum of this and the given impurity.
+     */
+    public T add(T measure);
+
+    /**
+     * Subtracts the given impurity from this.
+     *
+     * @param measure Another impurity.
+     * @return Difference of this and the given impurity.
+     */
+    public T subtract(T measure);
+
+    /** Compares impurity measures by their {@link #impurity()} values (smaller impurity orders first). */
+    @Override default public int compareTo(T o) {
+        return Double.compare(impurity(), o.impurity());
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
new file mode 100644
index 0000000..2b69356
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+
+/**
+ * Base interface for impurity measure calculators that calculate all impurity measures required to find the best split.
+ *
+ * @param <T> Type of impurity measure.
+ */
+public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends Serializable {
+    /**
+     * Calculates all impurity measures required to find the best split and returns them as an array of
+     * {@link StepFunction} (one per feature column).
+     *
+     * @param data Features and labels.
+     * @return Impurity measures as an array of {@link StepFunction} (one per feature column).
+     */
+    public StepFunction<T>[] calculate(DecisionTreeData data);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
new file mode 100644
index 0000000..817baf5
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
+
+/**
+ * Gini impurity measure which is calculated the following way:
+ * {@code -\frac{1}{L}\sum_{i=1}^{k}l_i^2 - \frac{1}{R}\sum_{i=1}^{k}r_i^2}, where {@code l_i}/{@code r_i} is the number of elements of type {@code i} in the left/right part, {@code L}/{@code R} is the size of that part (its term is 0 when the part is empty) and {@code k} is the number of types.
+ */
+public class GiniImpurityMeasure implements ImpurityMeasure<GiniImpurityMeasure> {
+    /** */
+    private static final long serialVersionUID = 5338129703395229970L;
+
+    /** Number of elements of each type in the left part. */
+    private final long[] left;
+
+    /** Number of elements of each type in the right part. */
+    private final long[] right;
+
+    /**
+     * Constructs a new instance of Gini impurity measure.
+     *
+     * @param left Number of elements of each type in the left part.
+     * @param right Number of elements of each type in the right part.
+     */
+    GiniImpurityMeasure(long[] left, long[] right) {
+        assert left.length == right.length : "Left and right parts have to be the same length";
+
+        this.left = left;
+        this.right = right;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double impurity() {
+        long leftCnt = 0;
+        long rightCnt = 0;
+
+        double leftImpurity = 0;
+        double rightImpurity = 0;
+
+        for (long e : left)
+            leftCnt += e;
+
+        for (long e : right)
+            rightCnt += e;
+
+        if (leftCnt > 0)
+            for (long e : left)
+                leftImpurity += Math.pow(e, 2) / leftCnt;
+
+        if (rightCnt > 0)
+            for (long e : right)
+                rightImpurity += Math.pow(e, 2) / rightCnt;
+
+        return -(leftImpurity + rightImpurity);
+    }
+
+    /** {@inheritDoc} */
+    @Override public GiniImpurityMeasure add(GiniImpurityMeasure b) {
+        assert left.length == b.left.length : "Added measure has to have length " + left.length;
+        assert left.length == b.right.length : "Added measure has to have length " + left.length;
+
+        long[] leftRes = new long[left.length];
+        long[] rightRes = new long[left.length];
+
+        for (int i = 0; i < left.length; i++) {
+            leftRes[i] = left[i] + b.left[i];
+            rightRes[i] = right[i] + b.right[i];
+        }
+
+        return new GiniImpurityMeasure(leftRes, rightRes);
+    }
+
+    /** {@inheritDoc} */
+    @Override public GiniImpurityMeasure subtract(GiniImpurityMeasure b) {
+        assert left.length == b.left.length : "Subtracted measure has to have length " + left.length;
+        assert left.length == b.right.length : "Subtracted measure has to have length " + left.length;
+
+        long[] leftRes = new long[left.length];
+        long[] rightRes = new long[left.length];
+
+        for (int i = 0; i < left.length; i++) {
+            leftRes[i] = left[i] - b.left[i];
+            rightRes[i] = right[i] - b.right[i];
+        }
+
+        return new GiniImpurityMeasure(leftRes, rightRes);
+    }
+
+    /** @return Number of elements of each type in the left part (internal array, not a copy). */
+    public long[] getLeft() {
+        return left;
+    }
+
+    /** @return Number of elements of each type in the right part (internal array, not a copy). */
+    public long[] getRight() {
+        return right;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
new file mode 100644
index 0000000..0dd0a10
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import java.util.Arrays;
+import java.util.Map;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+
+/**
+ * Gini impurity measure calculator.
+ */
+public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> {
+    /** */
+    private static final long serialVersionUID = -522995134128519679L;
+
+    /** Label encoder which defines integer value for every label class. */
+    private final Map<Double, Integer> lbEncoder;
+
+    /**
+     * Constructs a new instance of Gini impurity measure calculator.
+     *
+     * @param lbEncoder Label encoder which defines integer value for every label class.
+     */
+    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) {
+        this.lbEncoder = lbEncoder;
+    }
+
+    /** {@inheritDoc} Sorts {@code data} by each column in turn; returns {@code null} if {@code data} has no rows. */
+    @SuppressWarnings("unchecked")
+    @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) {
+        double[][] features = data.getFeatures();
+        double[] labels = data.getLabels();
+
+        if (features.length > 0) {
+            StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length];
+
+            for (int col = 0; col < res.length; col++) {
+                data.sort(col);
+
+                double[] x = new double[features.length + 1];
+                GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1];
+
+                int xPtr = 0, yPtr = 0;
+
+                long[] left = new long[lbEncoder.size()];
+                long[] right = new long[lbEncoder.size()];
+
+                for (int i = 0; i < labels.length; i++)
+                    right[getLabelCode(labels[i])]++;
+
+                x[xPtr++] = Double.NEGATIVE_INFINITY;
+                y[yPtr++] = new GiniImpurityMeasure(
+                    Arrays.copyOf(left, left.length),
+                    Arrays.copyOf(right, right.length)
+                );
+
+                for (int i = 0; i < features.length; i++) {
+                    left[getLabelCode(labels[i])]++;
+                    right[getLabelCode(labels[i])]--;
+
+                    if (i < (features.length - 1) && features[i + 1][col] == features[i][col])
+                        continue;
+
+                    x[xPtr++] = features[i][col];
+                    y[yPtr++] = new GiniImpurityMeasure(
+                        Arrays.copyOf(left, left.length),
+                        Arrays.copyOf(right, right.length)
+                    );
+                }
+
+                res[col] = new StepFunction<>(Arrays.copyOf(x, xPtr), Arrays.copyOf(y, yPtr));
+            }
+
+            return res;
+        }
+
+        return null;
+    }
+
+    /**
+     * Returns label code.
+     *
+     * @param lb Label.
+     * @return Label code.
+     */
+    int getLabelCode(double lb) {
+        Integer code = lbEncoder.get(lb);
+
+        assert code != null : "Can't find code for label " + lb;
+
+        return code;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
new file mode 100644
index 0000000..d14cd92
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Gini impurity measure and calculator.
+ */
+package org.apache.ignite.ml.tree.impurity.gini;
\ No newline at end of file