IGNITE-5218: First version of decision trees. This closes #2936
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/db7697b1 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/db7697b1 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/db7697b1 Branch: refs/heads/ignite-zk Commit: db7697b17cf6eb94754edb2b5e200655a3610dc1 Parents: 6579e69 Author: Artem Malykh <[email protected]> Authored: Fri Nov 10 18:03:33 2017 +0300 Committer: Igor Sapego <[email protected]> Committed: Fri Nov 10 18:03:33 2017 +0300 ---------------------------------------------------------------------- .gitignore | 2 + examples/pom.xml | 5 + .../examples/ml/math/trees/MNISTExample.java | 261 +++++++++ .../examples/ml/math/trees/package-info.java | 22 + modules/ml/licenses/netlib-java-bsd3.txt | 51 ++ modules/ml/pom.xml | 12 +- .../main/java/org/apache/ignite/ml/Model.java | 4 +- .../main/java/org/apache/ignite/ml/Trainer.java | 30 + .../clustering/KMeansDistributedClusterer.java | 19 +- .../apache/ignite/ml/estimators/Estimators.java | 50 ++ .../ignite/ml/estimators/package-info.java | 22 + .../ignite/ml/math/distributed/CacheUtils.java | 178 +++++- .../math/distributed/keys/MatrixCacheKey.java | 6 +- .../distributed/keys/impl/BlockMatrixKey.java | 17 +- .../distributed/keys/impl/SparseMatrixKey.java | 59 +- .../ignite/ml/math/functions/Functions.java | 38 ++ .../ml/math/functions/IgniteBinaryOperator.java | 29 + .../math/functions/IgniteCurriedBiFunction.java | 29 + .../ml/math/functions/IgniteSupplier.java | 30 + .../math/functions/IgniteToDoubleFunction.java | 25 + .../matrix/SparseBlockDistributedMatrix.java | 4 +- .../impls/matrix/SparseDistributedMatrix.java | 3 +- .../storage/matrix/BlockMatrixStorage.java | 12 +- .../matrix/SparseDistributedMatrixStorage.java | 17 +- .../ignite/ml/structures/LabeledVector.java | 63 +++ .../ml/structures/LabeledVectorDouble.java | 46 ++ .../ignite/ml/structures/package-info.java | 22 + .../ignite/ml/trees/CategoricalRegionInfo.java | 72 
+++ .../ignite/ml/trees/CategoricalSplitInfo.java | 68 +++ .../ignite/ml/trees/ContinuousRegionInfo.java | 74 +++ .../ml/trees/ContinuousSplitCalculator.java | 50 ++ .../org/apache/ignite/ml/trees/RegionInfo.java | 62 +++ .../ml/trees/models/DecisionTreeModel.java | 44 ++ .../ignite/ml/trees/models/package-info.java | 22 + .../ml/trees/nodes/CategoricalSplitNode.java | 50 ++ .../ml/trees/nodes/ContinuousSplitNode.java | 56 ++ .../ignite/ml/trees/nodes/DecisionTreeNode.java | 33 ++ .../org/apache/ignite/ml/trees/nodes/Leaf.java | 49 ++ .../apache/ignite/ml/trees/nodes/SplitNode.java | 100 ++++ .../ignite/ml/trees/nodes/package-info.java | 22 + .../apache/ignite/ml/trees/package-info.java | 22 + .../ml/trees/trainers/columnbased/BiIndex.java | 113 ++++ ...exedCacheColumnDecisionTreeTrainerInput.java | 57 ++ .../CacheColumnDecisionTreeTrainerInput.java | 142 +++++ .../columnbased/ColumnDecisionTreeTrainer.java | 557 +++++++++++++++++++ .../ColumnDecisionTreeTrainerInput.java | 55 ++ .../MatrixColumnDecisionTreeTrainerInput.java | 82 +++ .../trainers/columnbased/RegionProjection.java | 109 ++++ .../trainers/columnbased/TrainingContext.java | 166 ++++++ .../columnbased/caches/ContextCache.java | 68 +++ .../columnbased/caches/FeaturesCache.java | 151 +++++ .../columnbased/caches/ProjectionsCache.java | 284 ++++++++++ .../trainers/columnbased/caches/SplitCache.java | 206 +++++++ .../ContinuousSplitCalculators.java | 34 ++ .../contsplitcalcs/GiniSplitCalculator.java | 234 ++++++++ .../contsplitcalcs/VarianceSplitCalculator.java | 179 ++++++ .../contsplitcalcs/package-info.java | 22 + .../trainers/columnbased/package-info.java | 22 + .../columnbased/regcalcs/RegionCalculators.java | 85 +++ .../columnbased/regcalcs/package-info.java | 22 + .../vectors/CategoricalFeatureProcessor.java | 211 +++++++ .../vectors/ContinuousFeatureProcessor.java | 111 ++++ .../vectors/ContinuousSplitInfo.java | 54 ++ .../columnbased/vectors/FeatureProcessor.java | 81 +++ 
.../vectors/FeatureVectorProcessorUtils.java | 57 ++ .../columnbased/vectors/SampleInfo.java | 80 +++ .../trainers/columnbased/vectors/SplitInfo.java | 106 ++++ .../columnbased/vectors/package-info.java | 22 + .../org/apache/ignite/ml/util/MnistUtils.java | 121 ++++ .../java/org/apache/ignite/ml/util/Utils.java | 53 ++ .../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +- .../java/org/apache/ignite/ml/TestUtils.java | 15 + .../SparseDistributedBlockMatrixTest.java | 1 + .../ignite/ml/trees/BaseDecisionTreeTest.java | 70 +++ .../ml/trees/ColumnDecisionTreeTrainerTest.java | 190 +++++++ .../ignite/ml/trees/DecisionTreesTestSuite.java | 33 ++ .../ml/trees/GiniSplitCalculatorTest.java | 141 +++++ .../ignite/ml/trees/SplitDataGenerator.java | 390 +++++++++++++ .../ml/trees/VarianceSplitCalculatorTest.java | 84 +++ .../ColumnDecisionTreeTrainerBenchmark.java | 455 +++++++++++++++ .../trees/columntrees.manualrun.properties | 5 + 81 files changed, 6538 insertions(+), 114 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/.gitignore ---------------------------------------------------------------------- diff --git a/.gitignore b/.gitignore index d8dd951..18146f8 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,5 @@ packages /modules/platforms/cpp/odbc-test/ignite-odbc-tests /modules/platforms/cpp/stamp-h1 +#Files related to ML manual-runnable tests +/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/examples/pom.xml ---------------------------------------------------------------------- diff --git a/examples/pom.xml b/examples/pom.xml index 30d23ae..2b95e65 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -248,6 +248,11 @@ <artifactId>ignite-ml</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>commons-cli</groupId> + 
<artifactId>commons-cli</artifactId> + <version>1.2</version> + </dependency> </dependencies> </profile> http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java new file mode 100644 index 0000000..6aaadd9 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.math.trees; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Random; +import java.util.function.Function; +import java.util.stream.Stream; +import org.apache.commons.cli.BasicParser; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.OptionBuilder; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.IgniteDataStreamer; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.CacheAtomicityMode; +import org.apache.ignite.cache.CacheMode; +import org.apache.ignite.cache.CacheWriteSynchronizationMode; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ExampleNodeStartup; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.estimators.Estimators; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.trees.models.DecisionTreeModel; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; +import org.apache.ignite.ml.util.MnistUtils; +import 
org.jetbrains.annotations.NotNull; + +/** + * <p> + * Example of usage of decision trees algorithm for MNIST dataset + * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p> + * <p> + * Remote nodes should always be started with special configuration file which + * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> + * <p> + * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node + * with {@code examples/config/example-ignite.xml} configuration.</p> + * <p> + * It is recommended to start at least one node prior to launching this example if you intend + * to run it with default memory settings.</p> + * <p> + * This example should be run with program arguments, for example + * -ts_i /path/to/train-images-idx3-ubyte + * -ts_l /path/to/train-labels-idx1-ubyte + * -tss_i /path/to/t10k-images-idx3-ubyte + * -tss_l /path/to/t10k-labels-idx1-ubyte + * -cfg examples/config/example-ignite.xml.</p> + * <p> + * -ts_i specifies path to training set images of MNIST; + * -ts_l specifies path to training set labels of MNIST; + * -tss_i specifies path to test set images of MNIST; + * -tss_l specifies path to test set labels of MNIST; + * -cfg specifies path to a config file.</p> + */ +public class MNISTExample { + /** Name of parameter specifying path to training set images. */ + private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i"; + + /** Name of parameter specifying path to training set labels. */ + private static final String MNIST_TRAINING_LABELS_PATH = "ts_l"; + + /** Name of parameter specifying path to test set images. */ + private static final String MNIST_TEST_IMAGES_PATH = "tss_i"; + + /** Name of parameter specifying path to test set labels. */ + private static final String MNIST_TEST_LABELS_PATH = "tss_l"; + + /** Name of parameter specifying path of Ignite config. */ + private static final String CONFIG = "cfg"; + + /** Default config path. 
*/ + private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml"; + + /** + * Launches example. + * + * @param args Program arguments. + */ + public static void main(String[] args) { + String igniteCfgPath; + + CommandLineParser parser = new BasicParser(); + + String trainingImagesPath; + String trainingLabelsPath; + + String testImagesPath; + String testLabelsPath; + + try { + // Parse the command line arguments. + CommandLine line = parser.parse(buildOptions(), args); + + trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH); + trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH); + testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH); + testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH); + igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG); + } + catch (ParseException e) { + e.printStackTrace(); + return; + } + + try (Ignite ignite = Ignition.start(igniteCfgPath)) { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + int ptsCnt = 60000; + int featCnt = 28 * 28; + + Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt); + Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000); + + IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite); + + loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite); + + ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = + new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); + + System.out.println(">>> Training started"); + long before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); + 
System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + + IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); + Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); + System.out.println(">>> Errs percentage: " + accuracy); + } + catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Build cli options. + */ + @NotNull private static Options buildOptions() { + Options options = new Options(); + + Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg() + .withDescription("Path to the MNIST training set.") + .isRequired(true).create(); + + Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg() + .withDescription("Path to the MNIST training set.") + .isRequired(true).create(); + + Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg() + .withDescription("Path to the MNIST test set.") + .isRequired(true).create(); + + Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg() + .withDescription("Path to the MNIST test set.") + .isRequired(true).create(); + + Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg() + .withDescription("Path to the config.") + .isRequired(false).create(); + + options.addOption(trsImagesPathOpt); + options.addOption(trsLabelsPathOpt); + options.addOption(tssImagesPathOpt); + options.addOption(tssLabelsPathOpt); + options.addOption(configOpt); + + return options; + } + + /** + * Creates cache where data for training is stored. + * + * @param ignite Ignite instance. + * @return cache where data for training is stored. 
+ */ + private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) { + CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); + + // Write to primary. + cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); + + // Atomic transactions only. + cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); + + // No eviction. + cfg.setEvictionPolicy(null); + + // No copying of values. + cfg.setCopyOnRead(false); + + // Cache is partitioned. + cfg.setCacheMode(CacheMode.PARTITIONED); + + cfg.setBackups(0); + + cfg.setName("TMP_BI_INDEXED_CACHE"); + + return ignite.getOrCreateCache(cfg); + } + + /** + * Loads vectors into cache. + * + * @param cacheName Name of cache. + * @param vectorsIterator Iterator over vectors to load. + * @param vectorSize Size of vector. + * @param ignite Ignite instance. + */ + private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator, + int vectorSize, Ignite ignite) { + try (IgniteDataStreamer<BiIndex, Double> streamer = + ignite.dataStreamer(cacheName)) { + int sampleIdx = 0; + + streamer.perNodeBufferSize(10000); + + while (vectorsIterator.hasNext()) { + org.apache.ignite.ml.math.Vector next = vectorsIterator.next(); + + for (int i = 0; i < vectorSize; i++) + streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); + + sampleIdx++; + + if (sampleIdx % 1000 == 0) + System.out.println("Loaded " + sampleIdx + " vectors."); + } + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java new file mode 100644 index 0000000..9b6867b --- /dev/null +++ 
b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Decision trees examples. + */ +package org.apache.ignite.examples.ml.math.trees; http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/licenses/netlib-java-bsd3.txt ---------------------------------------------------------------------- diff --git a/modules/ml/licenses/netlib-java-bsd3.txt b/modules/ml/licenses/netlib-java-bsd3.txt new file mode 100644 index 0000000..d6b30c1 --- /dev/null +++ b/modules/ml/licenses/netlib-java-bsd3.txt @@ -0,0 +1,51 @@ +This product binaries redistribute netlib-java which is available under the following license: + +Copyright (c) 2013 Samuel Halliday +Copyright (c) 1992-2011 The University of Tennessee and The University + of Tennessee Research Foundation. All rights + reserved. +Copyright (c) 2000-2011 The University of California Berkeley. All + rights reserved. +Copyright (c) 2006-2011 The University of Colorado Denver. All rights + reserved. 
+ +$COPYRIGHT$ + +Additional copyrights may follow + +$HEADER$ + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer listed + in this license in the documentation and/or other materials + provided with the distribution. + +- Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +The copyright holders provide no reassurances that the source code +provided does not infringe any patent, copyright, or any other +intellectual property rights of third parties. The copyright holders +disclaim any liability to any recipient for claims brought against +recipient by any third party for infringement of that parties +intellectual property rights. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/pom.xml ---------------------------------------------------------------------- diff --git a/modules/ml/pom.xml b/modules/ml/pom.xml index 94cfc51..c495f44 100644 --- a/modules/ml/pom.xml +++ b/modules/ml/pom.xml @@ -75,13 +75,6 @@ <dependency> <groupId>org.springframework</groupId> - <artifactId>spring-beans</artifactId> - <version>${spring.version}</version> - <scope>test</scope> - </dependency> - - <dependency> - <groupId>org.springframework</groupId> <artifactId>spring-context</artifactId> <version>${spring.version}</version> <scope>test</scope> @@ -105,6 +98,11 @@ <version>1.0</version> </dependency> + <dependency> + <groupId>com.zaxxer</groupId> + <artifactId>SparseBitSet</artifactId> + <version>1.0</version> + </dependency> </dependencies> <profiles> http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/Model.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/Model.java b/modules/ml/src/main/java/org/apache/ignite/ml/Model.java index 3c60bfa..05ce774 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/Model.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/Model.java @@ -24,7 +24,7 @@ import java.util.function.BiFunction; @FunctionalInterface public interface Model<T, V> extends Serializable { /** Predict a result for value. */ - public V predict(T val); + V predict(T val); /** * Combines this model with other model via specified combiner @@ -33,7 +33,7 @@ public interface Model<T, V> extends Serializable { * @param combiner Combiner. * @return Combination of models. 
*/ - public default <X, W> Model<T, X> combine(Model<T, W> other, BiFunction<V, W, X> combiner) { + default <X, W> Model<T, X> combine(Model<T, W> other, BiFunction<V, W, X> combiner) { return v -> combiner.apply(predict(v), other.predict(v)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java new file mode 100644 index 0000000..795e218 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml; + +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; + +/** + * Interface for Trainers. Trainer is just a function which produces model from the data. + * See for example {@link ColumnDecisionTreeTrainer}. + * @param <M> Type of produced model. + * @param <T> Type of data needed for model producing. 
+ */ +public interface Trainer<M extends Model, T> { + M train(T data); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java index d6a3fc3..6c25edc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import javax.cache.Cache; @@ -94,7 +95,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri boolean converged = false; int iteration = 0; int dim = pointsCp.viewRow(0).size(); - IgniteUuid uid = pointsCp.getUUID(); + UUID uid = pointsCp.getUUID(); // Execute iterations of Lloyd's algorithm until converged while (iteration < maxIterations && !converged) { @@ -140,7 +141,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri // to their squared distance from the centers. Note that only distances between points // and new centers are computed in each iteration. int step = 0; - IgniteUuid uid = points.getUUID(); + UUID uid = points.getUUID(); while (step < initSteps) { // We assume here that costs can fit into memory of one node. 
@@ -180,7 +181,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri } /** */ - private List<Vector> getNewCenters(int k, ConcurrentHashMap<Integer, Double> costs, IgniteUuid uid, + private List<Vector> getNewCenters(int k, ConcurrentHashMap<Integer, Double> costs, UUID uid, double sumCosts, String cacheName) { return distributedFold(cacheName, (IgniteBiFunction<Cache.Entry<SparseMatrixKey, Map<Integer, Double>>, @@ -200,7 +201,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri list1.addAll(list2); return list1; }, - new ArrayList<>() + ArrayList::new ); } @@ -219,11 +220,11 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri (map1, map2) -> { map1.putAll(map2); return map1; - }, new ConcurrentHashMap<>()); + }, ConcurrentHashMap::new); } /** */ - private ConcurrentHashMap<Integer, Integer> weightCenters(IgniteUuid uid, List<Vector> distinctCenters, String cacheName) { + private ConcurrentHashMap<Integer, Integer> weightCenters(UUID uid, List<Vector> distinctCenters, String cacheName) { return distributedFold(cacheName, (IgniteBiFunction<Cache.Entry<SparseMatrixKey, Map<Integer, Double>>, ConcurrentHashMap<Integer, Integer>, @@ -249,7 +250,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri key -> key.matrixId().equals(uid), (map1, map2) -> MapUtil.mergeMaps(map1, map2, (integer, integer2) -> integer2 + integer, ConcurrentHashMap::new), - new ConcurrentHashMap<>()); + ConcurrentHashMap::new); } /** */ @@ -258,7 +259,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri } /** */ - private SumsAndCounts getSumsAndCounts(Vector[] centers, int dim, IgniteUuid uid, String cacheName) { + private SumsAndCounts getSumsAndCounts(Vector[] centers, int dim, UUID uid, String cacheName) { return CacheUtils.distributedFold(cacheName, (IgniteBiFunction<Cache.Entry<SparseMatrixKey, Map<Integer, Double>>, SumsAndCounts, 
SumsAndCounts>)(entry, counts) -> { Map<Integer, Double> vec = entry.getValue(); @@ -278,7 +279,7 @@ public class KMeansDistributedClusterer extends BaseKMeansClusterer<SparseDistri return counts; }, key -> key.matrixId().equals(uid), - SumsAndCounts::merge, new SumsAndCounts() + SumsAndCounts::merge, SumsAndCounts::new ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java new file mode 100644 index 0000000..13331d1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.estimators; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Stream; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; + +/** Estimators. 
*/ +public class Estimators { + /** Simple implementation of mean squared error estimator. */ + public static <T, V> IgniteTriFunction<Model<T, V>, Stream<IgniteBiTuple<T, V>>, Function<V, Double>, Double> MSE() { + return (model, stream, f) -> stream.mapToDouble(dp -> { + double diff = f.apply(dp.get2()) - f.apply(model.predict(dp.get1())); + return diff * diff; + }).average().orElse(0); + } + + /** Simple implementation of errors percentage estimator. */ + public static <T, V> IgniteTriFunction<Model<T, V>, Stream<IgniteBiTuple<T, V>>, Function<V, Double>, Double> errorsPercentage() { + return (model, stream, f) -> { + AtomicLong total = new AtomicLong(0); + + long cnt = stream. + peek((ib) -> total.incrementAndGet()). + filter(dp -> !model.predict(dp.get1()).equals(dp.get2())). + count(); + + return (double)cnt / total.get(); + }; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java new file mode 100644 index 0000000..c03827f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains estimation algorithms. + */ +package org.apache.ignite.ml.estimators; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java index 8c8bba7..b9eb386 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java @@ -21,7 +21,11 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BinaryOperator; +import java.util.stream.Stream; import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; @@ -32,17 +36,21 @@ import org.apache.ignite.cluster.ClusterGroup; import org.apache.ignite.cluster.ClusterNode; import org.apache.ignite.internal.processors.cache.CacheEntryImpl; import org.apache.ignite.internal.util.typedef.internal.A; +import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.lang.IgniteCallable; import org.apache.ignite.lang.IgnitePredicate; import org.apache.ignite.lang.IgniteRunnable; import 
org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.KeyMapper; -import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey; -import org.apache.ignite.ml.math.distributed.keys.impl.BlockMatrixKey; +import org.apache.ignite.ml.math.distributed.keys.BlockMatrixKey; +import org.apache.ignite.ml.math.distributed.keys.MatrixCacheKey; import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; import org.apache.ignite.ml.math.functions.IgniteConsumer; import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.impls.matrix.BlockEntry; /** @@ -131,7 +139,7 @@ public class CacheUtils { * @return Sum obtained using sparse logic. */ @SuppressWarnings("unchecked") - public static <K, V> double sparseSum(IgniteUuid matrixUuid, String cacheName) { + public static <K, V> double sparseSum(UUID matrixUuid, String cacheName) { A.notNull(matrixUuid, "matrixUuid"); A.notNull(cacheName, "cacheName"); @@ -198,7 +206,7 @@ public class CacheUtils { * @return Minimum value obtained using sparse logic. */ @SuppressWarnings("unchecked") - public static <K, V> double sparseMin(IgniteUuid matrixUuid, String cacheName) { + public static <K, V> double sparseMin(UUID matrixUuid, String cacheName) { A.notNull(matrixUuid, "matrixUuid"); A.notNull(cacheName, "cacheName"); @@ -235,7 +243,7 @@ public class CacheUtils { * @return Maximum value obtained using sparse logic. 
*/ @SuppressWarnings("unchecked") - public static <K, V> double sparseMax(IgniteUuid matrixUuid, String cacheName) { + public static <K, V> double sparseMax(UUID matrixUuid, String cacheName) { A.notNull(matrixUuid, "matrixUuid"); A.notNull(cacheName, "cacheName"); @@ -316,7 +324,7 @@ public class CacheUtils { * @param mapper Mapping {@link IgniteFunction}. */ @SuppressWarnings("unchecked") - public static <K, V> void sparseMap(IgniteUuid matrixUuid, IgniteDoubleFunction<Double> mapper, String cacheName) { + public static <K, V> void sparseMap(UUID matrixUuid, IgniteDoubleFunction<Double> mapper, String cacheName) { A.notNull(matrixUuid, "matrixUuid"); A.notNull(cacheName, "cacheName"); A.notNull(mapper, "mapper"); @@ -350,12 +358,12 @@ public class CacheUtils { * * @param matrixUuid Matrix uuid. */ - private static <K> IgnitePredicate<K> sparseKeyFilter(IgniteUuid matrixUuid) { + private static <K> IgnitePredicate<K> sparseKeyFilter(UUID matrixUuid) { return key -> { - if (key instanceof BlockMatrixKey) - return ((BlockMatrixKey)key).matrixId().equals(matrixUuid); - else if (key instanceof RowColMatrixKey) - return ((RowColMatrixKey)key).matrixId().equals(matrixUuid); + if (key instanceof MatrixCacheKey) + return ((MatrixCacheKey)key).matrixId().equals(matrixUuid); + else if (key instanceof IgniteBiTuple) + return ((IgniteBiTuple<Integer, UUID>)key).get2().equals(matrixUuid); else throw new UnsupportedOperationException(); }; @@ -404,6 +412,76 @@ public class CacheUtils { } /** + * @param cacheName Cache name. + * @param fun An operation that accepts a cache entry and processes it. + * @param ignite Ignite. + * @param keysGen Keys generator. + * @param <K> Cache key object type. + * @param <V> Cache value object type. 
+ */ + public static <K, V> void update(String cacheName, Ignite ignite, + IgniteBiFunction<Ignite, Cache.Entry<K, V>, Stream<Cache.Entry<K, V>>> fun, IgniteSupplier<Set<K>> keysGen) { + bcast(cacheName, ignite, () -> { + Ignite ig = Ignition.localIgnite(); + IgniteCache<K, V> cache = ig.getOrCreateCache(cacheName); + + Affinity<K> affinity = ig.affinity(cacheName); + ClusterNode locNode = ig.cluster().localNode(); + + Collection<K> ks = affinity.mapKeysToNodes(keysGen.get()).get(locNode); + + if (ks == null) + return; + + Map<K, V> m = new ConcurrentHashMap<>(); + + ks.parallelStream().forEach(k -> { + V v = cache.localPeek(k); + if (v != null) + (fun.apply(ignite, new CacheEntryImpl<>(k, v))).forEach(ent -> m.put(ent.getKey(), ent.getValue())); + }); + + cache.putAll(m); + }); + } + + /** + * @param cacheName Cache name. + * @param fun An operation that accepts a cache entry and processes it. + * @param ignite Ignite. + * @param keysGen Keys generator. + * @param <K> Cache key object type. + * @param <V> Cache value object type. + */ + public static <K, V> void update(String cacheName, Ignite ignite, IgniteConsumer<Cache.Entry<K, V>> fun, + IgniteSupplier<Set<K>> keysGen) { + bcast(cacheName, ignite, () -> { + Ignite ig = Ignition.localIgnite(); + IgniteCache<K, V> cache = ig.getOrCreateCache(cacheName); + + Affinity<K> affinity = ig.affinity(cacheName); + ClusterNode locNode = ig.cluster().localNode(); + + Collection<K> ks = affinity.mapKeysToNodes(keysGen.get()).get(locNode); + + if (ks == null) + return; + + Map<K, V> m = new ConcurrentHashMap<>(); + + for (K k : ks) { + V v = cache.localPeek(k); + fun.accept(new CacheEntryImpl<>(k, v)); + m.put(k, v); + } + + long before = System.currentTimeMillis(); + cache.putAll(m); + System.out.println("PutAll took: " + (System.currentTimeMillis() - before)); + }); + } + + /** * <b>Currently fold supports only commutative operations.<b/> * * @param cacheName Cache name. 
@@ -463,11 +541,11 @@ public class CacheUtils { * @param folder Folder. * @param keyFilter Key filter. * @param accumulator Accumulator. - * @param zeroVal Zero value. + * @param zeroValSupp Zero value supplier. */ public static <K, V, A> A distributedFold(String cacheName, IgniteBiFunction<Cache.Entry<K, V>, A, A> folder, - IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, A zeroVal) { - return sparseFold(cacheName, folder, keyFilter, accumulator, zeroVal, null, null, 0, + IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp) { + return sparseFold(cacheName, folder, keyFilter, accumulator, zeroValSupp, null, null, 0, false); } @@ -478,17 +556,17 @@ public class CacheUtils { * @param folder Folder. * @param keyFilter Key filter. * @param accumulator Accumulator. - * @param zeroVal Zero value. - * @param defVal Def value. - * @param defKey Def key. + * @param zeroValSupp Zero value supplier. + * @param defVal Default value. + * @param defKey Default key. * @param defValCnt Def value count. * @param isNilpotent Is nilpotent. */ private static <K, V, A> A sparseFold(String cacheName, IgniteBiFunction<Cache.Entry<K, V>, A, A> folder, - IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, A zeroVal, V defVal, K defKey, long defValCnt, - boolean isNilpotent) { + IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp, V defVal, K defKey, + long defValCnt, boolean isNilpotent) { - A defRes = zeroVal; + A defRes = zeroValSupp.get(); if (!isNilpotent) for (int i = 0; i < defValCnt; i++) @@ -504,7 +582,7 @@ public class CacheUtils { Affinity affinity = ignite.affinity(cacheName); ClusterNode locNode = ignite.cluster().localNode(); - A a = zeroVal; + A a = zeroValSupp.get(); // Iterate over all partitions. Some of them will be stored on that local node. 
for (int part = 0; part < partsCnt; part++) { @@ -519,16 +597,54 @@ public class CacheUtils { return a; }); - totalRes.add(defRes); - return totalRes.stream().reduce(zeroVal, accumulator); + return totalRes.stream().reduce(defRes, accumulator); + } + + public static <K, V, A, W> A reduce(String cacheName, Ignite ignite, + IgniteTriFunction<W, Cache.Entry<K, V>, A, A> acc, + IgniteSupplier<W> supp, + IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb, + IgniteSupplier<A> zeroValSupp) { + + A defRes = zeroValSupp.get(); + + Collection<A> totalRes = bcast(cacheName, ignite, () -> { + // Use affinity in filter for ScanQuery. Otherwise we accept consumer in each node which is wrong. + A a = zeroValSupp.get(); + + W w = supp.get(); + + for (Cache.Entry<K, V> kvEntry : entriesGen.get()) + a = acc.apply(w, kvEntry, a); + + return a; + }); + + return totalRes.stream().reduce(defRes, comb); + } + + public static <K, V, A, W> A reduce(String cacheName, IgniteTriFunction<W, Cache.Entry<K, V>, A, A> acc, + IgniteSupplier<W> supp, + IgniteSupplier<Iterable<Cache.Entry<K, V>>> entriesGen, IgniteBinaryOperator<A> comb, + IgniteSupplier<A> zeroValSupp) { + return reduce(cacheName, Ignition.localIgnite(), acc, supp, entriesGen, comb, zeroValSupp); } /** * @param cacheName Cache name. * @param run {@link Runnable} to broadcast to cache nodes for given cache name. */ + public static void bcast(String cacheName, Ignite ignite, IgniteRunnable run) { + ignite.compute(ignite.cluster().forDataNodes(cacheName)).broadcast(run); + } + + /** + * Broadcast runnable to data nodes of given cache. + * @param cacheName Cache name. + * @param run Runnable. + */ public static void bcast(String cacheName, IgniteRunnable run) { - ignite().compute(ignite().cluster().forCacheNodes(cacheName)).broadcast(run); + bcast(cacheName, ignite(), run); } /** @@ -537,6 +653,18 @@ public class CacheUtils { * @param <A> Type returned by the callable. 
*/ public static <A> Collection<A> bcast(String cacheName, IgniteCallable<A> call) { - return ignite().compute(ignite().cluster().forCacheNodes(cacheName)).broadcast(call); + return bcast(cacheName, ignite(), call); + } + + /** + * Broadcast callable to data nodes of given cache. + * @param cacheName Cache name. + * @param ignite Ignite instance. + * @param call Callable to broadcast. + * @param <A> Type of callable result. + * @return Results of callable from each node. + */ + public static <A> Collection<A> bcast(String cacheName, Ignite ignite, IgniteCallable<A> call) { + return ignite.compute(ignite.cluster().forDataNodes(cacheName)).broadcast(call); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java index 669e9a4..0242560 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/MatrixCacheKey.java @@ -17,7 +17,7 @@ package org.apache.ignite.ml.math.distributed.keys; -import org.apache.ignite.lang.IgniteUuid; +import java.util.UUID; /** * Base matrix cache key. @@ -26,10 +26,10 @@ public interface MatrixCacheKey { /** * @return matrix id. */ - public IgniteUuid matrixId(); + public UUID matrixId(); /** * @return affinity key. 
*/ - public IgniteUuid affinityKey(); + public Object affinityKey(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java index 2edd9cb..cc8c488 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/BlockMatrixKey.java @@ -21,6 +21,7 @@ import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.UUID; import org.apache.ignite.binary.BinaryObjectException; import org.apache.ignite.binary.BinaryRawReader; import org.apache.ignite.binary.BinaryRawWriter; @@ -47,7 +48,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key /** Block col ID */ private long blockIdCol; /** Matrix ID */ - private IgniteUuid matrixUuid; + private UUID matrixUuid; /** Block affinity key. */ private IgniteUuid affinityKey; @@ -64,7 +65,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key * @param matrixUuid Matrix uuid. * @param affinityKey Affinity key. 
*/ - public BlockMatrixKey(long rowId, long colId, IgniteUuid matrixUuid, @Nullable IgniteUuid affinityKey) { + public BlockMatrixKey(long rowId, long colId, UUID matrixUuid, @Nullable IgniteUuid affinityKey) { assert rowId >= 0; assert colId >= 0; assert matrixUuid != null; @@ -86,7 +87,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key } /** {@inheritDoc} */ - @Override public IgniteUuid matrixId() { + @Override public UUID matrixId() { return matrixUuid; } @@ -97,7 +98,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key /** {@inheritDoc} */ @Override public void writeExternal(ObjectOutput out) throws IOException { - U.writeGridUuid(out, matrixUuid); + out.writeObject(matrixUuid); U.writeGridUuid(out, affinityKey); out.writeLong(blockIdRow); out.writeLong(blockIdCol); @@ -105,7 +106,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key /** {@inheritDoc} */ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - matrixUuid = U.readGridUuid(in); + matrixUuid = (UUID)in.readObject(); affinityKey = U.readGridUuid(in); blockIdRow = in.readLong(); blockIdCol = in.readLong(); @@ -115,7 +116,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key @Override public void writeBinary(BinaryWriter writer) throws BinaryObjectException { BinaryRawWriter out = writer.rawWriter(); - BinaryUtils.writeIgniteUuid(out, matrixUuid); + out.writeUuid(matrixUuid); BinaryUtils.writeIgniteUuid(out, affinityKey); out.writeLong(blockIdRow); out.writeLong(blockIdCol); @@ -125,7 +126,7 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key @Override public void readBinary(BinaryReader reader) throws BinaryObjectException { BinaryRawReader in = reader.rawReader(); - matrixUuid = BinaryUtils.readIgniteUuid(in); + matrixUuid = in.readUuid(); affinityKey = BinaryUtils.readIgniteUuid(in); blockIdRow 
= in.readLong(); blockIdCol = in.readLong(); @@ -160,6 +161,4 @@ public class BlockMatrixKey implements org.apache.ignite.ml.math.distributed.key @Override public String toString() { return S.toString(BlockMatrixKey.class, this); } - - } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java index 0c34c8b..aa5e0ad 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java @@ -21,30 +21,24 @@ import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; -import org.apache.ignite.binary.BinaryObjectException; -import org.apache.ignite.binary.BinaryRawReader; -import org.apache.ignite.binary.BinaryRawWriter; -import org.apache.ignite.binary.BinaryReader; -import org.apache.ignite.binary.BinaryWriter; -import org.apache.ignite.binary.Binarylizable; -import org.apache.ignite.internal.binary.BinaryUtils; +import java.util.UUID; +import org.apache.ignite.cache.affinity.AffinityKeyMapped; import org.apache.ignite.internal.util.typedef.F; import org.apache.ignite.internal.util.typedef.internal.S; -import org.apache.ignite.internal.util.typedef.internal.U; -import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey; import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; /** * Key implementation for {@link SparseDistributedMatrix}. 
*/ -public class SparseMatrixKey implements RowColMatrixKey, Externalizable, Binarylizable { +public class SparseMatrixKey implements RowColMatrixKey, Externalizable { /** */ private int idx; /** */ - private IgniteUuid matrixId; + private UUID matrixId; /** */ - private IgniteUuid affinityKey; + @AffinityKeyMapped + private Object affinityKey; /** * Default constructor (required by Externalizable). @@ -56,7 +50,7 @@ public class SparseMatrixKey implements RowColMatrixKey, Externalizable, Binaryl /** * Build Key. */ - public SparseMatrixKey(int idx, IgniteUuid matrixId, IgniteUuid affinityKey) { + public SparseMatrixKey(int idx, UUID matrixId, Object affinityKey) { assert idx >= 0 : "Index must be positive."; assert matrixId != null : "Matrix id can`t be null."; @@ -71,54 +65,35 @@ public class SparseMatrixKey implements RowColMatrixKey, Externalizable, Binaryl } /** {@inheritDoc} */ - @Override public IgniteUuid matrixId() { + @Override public UUID matrixId() { return matrixId; } /** {@inheritDoc} */ - @Override public IgniteUuid affinityKey() { + @Override public Object affinityKey() { return affinityKey; } /** {@inheritDoc} */ @Override public void writeExternal(ObjectOutput out) throws IOException { - U.writeGridUuid(out, matrixId); - U.writeGridUuid(out, affinityKey); +// U.writeGridUuid(out, matrixId); + out.writeObject(matrixId); + out.writeObject(affinityKey); out.writeInt(idx); } /** {@inheritDoc} */ @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - matrixId = U.readGridUuid(in); - affinityKey = U.readGridUuid(in); - idx = in.readInt(); - } - - /** {@inheritDoc} */ - @Override public void writeBinary(BinaryWriter writer) throws BinaryObjectException { - BinaryRawWriter out = writer.rawWriter(); - - BinaryUtils.writeIgniteUuid(out, matrixId); - BinaryUtils.writeIgniteUuid(out, affinityKey); - out.writeInt(idx); - } - - /** {@inheritDoc} */ - @Override public void readBinary(BinaryReader reader) throws 
BinaryObjectException { - BinaryRawReader in = reader.rawReader(); - - matrixId = BinaryUtils.readIgniteUuid(in); - affinityKey = BinaryUtils.readIgniteUuid(in); + matrixId = (UUID)in.readObject(); + affinityKey = in.readObject(); idx = in.readInt(); } /** {@inheritDoc} */ @Override public int hashCode() { - int res = 1; - - res += res * 37 + matrixId.hashCode(); - res += res * 37 + idx; - + int res = idx; + res = 31 * res + (matrixId != null ? matrixId.hashCode() : 0); + res = 31 * res + (affinityKey != null ? affinityKey.hashCode() : 0); return res; } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java index 022dd04..0b4ad12 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java @@ -17,7 +17,9 @@ package org.apache.ignite.ml.math.functions; +import java.util.Comparator; import java.util.List; +import java.util.function.BiFunction; import org.apache.ignite.lang.IgniteBiTuple; /** @@ -75,6 +77,30 @@ public final class Functions { /** Function that returns {@code max(abs(a), abs(b))}. */ public static final IgniteBiFunction<Double, Double, Double> MAX_ABS = (a, b) -> Math.max(Math.abs(a), Math.abs(b)); + /** + * Generic 'max' function. + * @param a First object to compare. + * @param b Second object to compare. + * @param f Comparator. + * @param <T> Type of objects to compare. + * @return Maximum between {@code a} and {@code b} in terms of comparator {@code f}. + */ + public static <T> T MAX_GENERIC(T a, T b, Comparator<T> f) { + return f.compare(a, b) > 0 ? a : b; + } + + /** + * Generic 'min' function. 
+ * @param a First object to compare. + * @param b Second object to compare. + * @param f Comparator. + * @param <T> Type of objects to compare. + * @return Minimum between {@code a} and {@code b} in terms of comparator {@code f}. + */ + public static <T> T MIN_GENERIC(T a, T b, Comparator<T> f) { + return f.compare(a, b) < 0 ? a : b; + } + /** Function that returns {@code min(abs(a), abs(b))}. */ public static final IgniteBiFunction<Double, Double, Double> MIN_ABS = (a, b) -> Math.min(Math.abs(a), Math.abs(b)); @@ -185,4 +211,16 @@ public final class Functions { return Math.pow(a, b); }; } + + /** + * Curry bifunction. + * @param f Bifunction to curry. + * @param <A> Type of first argument of {@code f}. + * @param <B> Type of second argument of {@code f}. + * @param <C> Return type of {@code f}. + * @return Curried bifunction. + */ + public static <A, B, C> IgniteCurriedBiFunction<A, B, C> curry(BiFunction<A, B, C> f) { + return a -> b -> f.apply(a, b); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBinaryOperator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBinaryOperator.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBinaryOperator.java new file mode 100644 index 0000000..1170b67 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBinaryOperator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.functions; + +import java.io.Serializable; +import java.util.function.BinaryOperator; + +/** + * Serializable binary operator. + * + * @see java.util.function.BinaryOperator + */ +public interface IgniteBinaryOperator<A> extends BinaryOperator<A>, Serializable { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedBiFunction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedBiFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedBiFunction.java new file mode 100644 index 0000000..3dd8490 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedBiFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.functions; + +import java.io.Serializable; +import java.util.function.BiFunction; + +/** + * Serializable curried binary function. + * + * @see BiFunction + */ +public interface IgniteCurriedBiFunction<A, B, T> extends IgniteFunction<A, IgniteFunction<B, T>>, Serializable { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteSupplier.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteSupplier.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteSupplier.java new file mode 100644 index 0000000..8c05b75 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteSupplier.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.functions; + +import java.io.Serializable; +import java.util.function.Supplier; + +/** + * Serializable supplier. + * + * @see java.util.function.Supplier + */ +@FunctionalInterface +public interface IgniteSupplier<T> extends Supplier<T>, Serializable { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteToDoubleFunction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteToDoubleFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteToDoubleFunction.java new file mode 100644 index 0000000..59a8bf3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteToDoubleFunction.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.math.functions; + +import java.io.Serializable; +import java.util.function.ToDoubleFunction; + +@FunctionalInterface +public interface IgniteToDoubleFunction<T> extends ToDoubleFunction<T>, Serializable { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java index 3d542bc..e829168 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseBlockDistributedMatrix.java @@ -20,13 +20,13 @@ package org.apache.ignite.ml.math.impls.matrix; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.UUID; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.Affinity; import org.apache.ignite.cluster.ClusterNode; import org.apache.ignite.internal.util.lang.IgnitePair; -import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.Vector; @@ -190,7 +190,7 @@ public class SparseBlockDistributedMatrix extends AbstractMatrix implements Stor } /** */ - private IgniteUuid getUUID() { + private UUID getUUID() { return ((BlockMatrixStorage)getStorage()).getUUID(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java index 9a18f8b..594aebc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedMatrix.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.math.impls.matrix; import java.util.Collection; import java.util.Map; +import java.util.UUID; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -211,7 +212,7 @@ public class SparseDistributedMatrix extends AbstractMatrix implements StorageCo } /** */ - public IgniteUuid getUUID() { + public UUID getUUID() { return ((SparseDistributedMatrixStorage)getStorage()).getUUID(); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java index 0d5cf0a..cd76e5a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/BlockMatrixStorage.java @@ -24,6 +24,7 @@ import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.UUID; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.CacheAtomicityMode; @@ -32,7 +33,6 @@ import org.apache.ignite.cache.CachePeekMode; import org.apache.ignite.cache.CacheWriteSynchronizationMode; import 
org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.lang.IgnitePair; -import org.apache.ignite.internal.util.typedef.internal.U; import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.MatrixStorage; import org.apache.ignite.ml.math.StorageConstants; @@ -59,7 +59,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto /** Amount of columns in the matrix. */ private int cols; /** Matrix uuid. */ - private IgniteUuid uuid; + private UUID uuid; /** Block size about 8 KB of data. */ private int maxBlockEdge = MAX_BLOCK_SIZE; @@ -92,7 +92,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto cache = newCache(); - uuid = IgniteUuid.randomUuid(); + uuid = UUID.randomUUID(); } /** @@ -152,7 +152,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto out.writeInt(cols); out.writeInt(blocksInRow); out.writeInt(blocksInCol); - U.writeGridUuid(out, uuid); + out.writeObject(uuid); out.writeUTF(cache.getName()); } @@ -162,7 +162,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto cols = in.readInt(); blocksInRow = in.readInt(); blocksInCol = in.readInt(); - uuid = U.readGridUuid(in); + uuid = (UUID)in.readObject(); cache = ignite().getOrCreateCache(in.readUTF()); } @@ -201,7 +201,7 @@ public class BlockMatrixStorage extends CacheUtils implements MatrixStorage, Sto * * @return storage UUID. 
*/ - public IgniteUuid getUUID() { + public UUID getUUID() { return uuid; } http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java index 95852b7..c40e73d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/matrix/SparseDistributedMatrixStorage.java @@ -24,6 +24,7 @@ import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Map; import java.util.Set; +import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.ignite.IgniteCache; @@ -33,7 +34,6 @@ import org.apache.ignite.cache.CacheMode; import org.apache.ignite.cache.CachePeekMode; import org.apache.ignite.cache.CacheWriteSynchronizationMode; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.lang.IgniteUuid; import org.apache.ignite.ml.math.MatrixStorage; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.distributed.CacheUtils; @@ -57,7 +57,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix /** Random or sequential access mode. */ private int acsMode; /** Matrix uuid. */ - private IgniteUuid uuid; + private UUID uuid; /** Actual distributed storage. 
*/ private IgniteCache< @@ -91,7 +91,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix cache = newCache(); - uuid = IgniteUuid.randomUuid(); + uuid = UUID.randomUUID(); } /** @@ -115,6 +115,9 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix // Cache is partitioned. cfg.setCacheMode(CacheMode.PARTITIONED); + // TODO: Possibly we should add a fix of https://issues.apache.org/jira/browse/IGNITE-6862 here commented below. + // cfg.setReadFromBackup(false); + // Random cache name. cfg.setName(CACHE_NAME); @@ -205,7 +208,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix /** Build cache key for row/column. */ public RowColMatrixKey getCacheKey(int idx) { - return new SparseMatrixKey(idx, uuid, null); + return new SparseMatrixKey(idx, uuid, idx); } /** {@inheritDoc} */ @@ -239,7 +242,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix cols = in.readInt(); acsMode = in.readInt(); stoMode = in.readInt(); - uuid = (IgniteUuid)in.readObject(); + uuid = (UUID)in.readObject(); cache = ignite().getOrCreateCache(in.readUTF()); } @@ -304,7 +307,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix } /** */ - public IgniteUuid getUUID() { + public UUID getUUID() { return uuid; } @@ -312,7 +315,7 @@ public class SparseDistributedMatrixStorage extends CacheUtils implements Matrix @Override public Set<RowColMatrixKey> getAllKeys() { int range = stoMode == ROW_STORAGE_MODE ? 
rows : cols; - return IntStream.range(0, range).mapToObj(i -> new SparseMatrixKey(i, getUUID(), null)).collect(Collectors.toSet()); + return IntStream.range(0, range).mapToObj(i -> new SparseMatrixKey(i, getUUID(), i)).collect(Collectors.toSet()); } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java new file mode 100644 index 0000000..51b973a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import org.apache.ignite.ml.math.Vector; + +/** + * Class for vector with label. + * + * @param <V> Some class extending {@link Vector}. + * @param <T> Type of label. + */ +public class LabeledVector<V extends Vector, T> { + /** Vector. */ + private final V vector; + + /** Label. 
*/ + private final T lb; + + /** + * Construct labeled vector. + * + * @param vector Vector. + * @param lb Label. + */ + public LabeledVector(V vector, T lb) { + this.vector = vector; + this.lb = lb; + } + + /** + * Get the vector. + * + * @return Vector. + */ + public V vector() { + return vector; + } + + /** + * Get the label. + * + * @return Label. + */ + public T label() { + return lb; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorDouble.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorDouble.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorDouble.java new file mode 100644 index 0000000..4ef9eae --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorDouble.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import org.apache.ignite.ml.math.Vector; + +/** + * Labeled vector specialized to double label. + * + * @param <V> Type of vector. 
+ */ +public class LabeledVectorDouble<V extends Vector> extends LabeledVector<V, Double> { + /** + * Construct LabeledVectorDouble. + * + * @param vector Vector. + * @param lb Label. + */ + public LabeledVectorDouble(V vector, Double lb) { + super(vector, lb); + } + + /** + * Get label as double. + * + * @return label as double. + */ + public double doubleLabel() { + return label(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/structures/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/package-info.java new file mode 100644 index 0000000..ec9d79e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains some utility structures. 
+ */ +package org.apache.ignite.ml.structures; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java new file mode 100644 index 0000000..3ae474e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.BitSet; + +/** + * Information about categorical region. + */ +public class CategoricalRegionInfo extends RegionInfo implements Externalizable { + /** + * Bitset representing categories of this region. + */ + private BitSet cats; + + /** + * @param impurity Impurity of region. + * @param cats Bitset representing categories of this region. 
+ */ + public CategoricalRegionInfo(double impurity, BitSet cats) { + super(impurity); + + this.cats = cats; + } + + /** + * No-op constructor for serialization/deserialization. + */ + public CategoricalRegionInfo() { + // No-op + } + + /** + * Get bitset representing categories of this region. + * + * @return Bitset representing categories of this region. + */ + public BitSet cats() { + return cats; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + out.writeObject(cats); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + cats = (BitSet)in.readObject(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java new file mode 100644 index 0000000..94cb1e8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.util.BitSet; +import org.apache.ignite.ml.trees.nodes.CategoricalSplitNode; +import org.apache.ignite.ml.trees.nodes.SplitNode; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; + +/** + * Information about split of categorical feature. + * + * @param <D> Class representing information of left and right subregions. + */ +public class CategoricalSplitInfo<D extends RegionInfo> extends SplitInfo<D> { + /** Bitset indicating which vectors are assigned to left subregion. */ + private final BitSet bs; + + /** + * @param regionIdx Index of region which is split. + * @param leftData Data of left subregion. + * @param rightData Data of right subregion. + * @param bs Bitset indicating which vectors are assigned to left subregion. + */ + public CategoricalSplitInfo(int regionIdx, D leftData, D rightData, + BitSet bs) { + super(regionIdx, leftData, rightData); + this.bs = bs; + } + + /** {@inheritDoc} */ + @Override public SplitNode createSplitNode(int featureIdx) { + return new CategoricalSplitNode(featureIdx, bs); + } + + /** + * Get bitset indicating which vectors are assigned to left subregion. + */ + public BitSet bitSet() { + return bs; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "CategoricalSplitInfo [" + + "infoGain=" + infoGain + + ", regionIdx=" + regionIdx + + ", leftData=" + leftData + + ", bs=" + bs + + ", rightData=" + rightData + + ']'; + } +}
