IGNITE-8232: ML package cleanup for 2.5 release this closes #3806
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/47cfdc27 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/47cfdc27 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/47cfdc27 Branch: refs/heads/ignite-7708 Commit: 47cfdc27e5079ee0ec91de1539bff498ffc1edc2 Parents: ee9ca06 Author: dmitrievanthony <dmitrievanth...@gmail.com> Authored: Fri Apr 13 18:08:08 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Fri Apr 13 18:08:08 2018 +0300 ---------------------------------------------------------------------- .../examples/ml/nn/MLPTrainerExample.java | 2 +- ...nWithLSQRTrainerAndNormalizationExample.java | 180 --------- ...dLinearRegressionWithLSQRTrainerExample.java | 169 -------- ...tedLinearRegressionWithQRTrainerExample.java | 137 ------- ...edLinearRegressionWithSGDTrainerExample.java | 177 -------- .../LinearRegressionLSQRTrainerExample.java | 169 ++++++++ ...sionLSQRTrainerWithNormalizationExample.java | 180 +++++++++ .../LinearRegressionSGDTrainerExample.java | 176 ++++++++ .../main/java/org/apache/ignite/ml/Trainer.java | 36 -- .../apache/ignite/ml/estimators/Estimators.java | 50 --- .../ignite/ml/estimators/package-info.java | 22 - .../ml/math/functions/IgniteBiFunction.java | 8 +- .../LinSysPartitionDataBuilderOnHeap.java | 86 ---- .../math/isolve/LinSysPartitionDataOnHeap.java | 65 --- .../ml/math/isolve/lsqr/AbstractLSQR.java | 3 +- .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java | 27 +- .../org/apache/ignite/ml/nn/MLPTrainer.java | 1 - .../apache/ignite/ml/nn/UpdatesStrategy.java | 95 +++++ .../ignite/ml/optimization/GradientDescent.java | 202 ---------- .../ml/optimization/GradientFunction.java | 31 -- .../LeastSquaresGradientFunction.java | 33 -- .../util/SparseDistributedMatrixMapReducer.java | 84 ---- .../ml/optimization/util/package-info.java | 22 - .../linear/LinearRegressionLSQRTrainer.java | 10 +- .../linear/LinearRegressionQRTrainer.java | 72 ---- 
.../linear/LinearRegressionSGDTrainer.java | 7 +- .../org/apache/ignite/ml/trainers/Trainer.java | 33 -- .../trainers/group/BaseLocalProcessorJob.java | 146 ------- .../ignite/ml/trainers/group/ConstModel.java | 46 --- .../ignite/ml/trainers/group/GroupTrainer.java | 208 ---------- .../group/GroupTrainerBaseProcessorTask.java | 144 ------- .../ml/trainers/group/GroupTrainerCacheKey.java | 125 ------ .../group/GroupTrainerEntriesProcessorTask.java | 64 --- .../ml/trainers/group/GroupTrainerInput.java | 37 -- .../group/GroupTrainerKeysProcessorTask.java | 62 --- .../ml/trainers/group/GroupTrainingContext.java | 98 ----- .../group/LocalEntriesProcessorJob.java | 85 ---- .../trainers/group/LocalKeysProcessorJob.java | 78 ---- .../ignite/ml/trainers/group/Metaoptimizer.java | 93 ----- .../group/MetaoptimizerDistributedStep.java | 97 ----- .../group/MetaoptimizerGroupTrainer.java | 132 ------ .../ml/trainers/group/ResultAndUpdates.java | 178 -------- .../ml/trainers/group/UpdateStrategies.java | 47 --- .../ml/trainers/group/UpdatesStrategy.java | 95 ----- .../ignite/ml/trainers/group/chain/Chains.java | 56 --- .../trainers/group/chain/ComputationsChain.java | 246 ----------- .../chain/DistributedEntryProcessingStep.java | 34 -- .../chain/DistributedKeyProcessingStep.java | 33 -- .../trainers/group/chain/DistributedStep.java | 70 ---- .../trainers/group/chain/EntryAndContext.java | 70 ---- .../trainers/group/chain/HasTrainingUUID.java | 32 -- .../ml/trainers/group/chain/KeyAndContext.java | 67 --- .../ml/trainers/group/chain/package-info.java | 22 - .../ignite/ml/trainers/group/package-info.java | 22 - .../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 - .../ml/math/isolve/lsqr/LSQROnHeapTest.java | 14 +- .../ignite/ml/nn/MLPTrainerIntegrationTest.java | 1 - .../org/apache/ignite/ml/nn/MLPTrainerTest.java | 1 - .../MLPTrainerMnistIntegrationTest.java | 2 +- .../ml/nn/performance/MLPTrainerMnistTest.java | 2 +- .../ml/optimization/GradientDescentTest.java | 64 --- 
.../ml/optimization/OptimizationTestSuite.java | 33 -- .../SparseDistributedMatrixMapReducerTest.java | 135 ------- .../ml/regressions/RegressionsTestSuite.java | 3 - .../linear/ArtificialRegressionDatasets.java | 404 ------------------- ...istributedLinearRegressionQRTrainerTest.java | 36 -- ...istributedLinearRegressionQRTrainerTest.java | 36 -- .../GenericLinearRegressionTrainerTest.java | 206 ---------- ...wareAbstractLinearRegressionTrainerTest.java | 127 ------ .../linear/LinearRegressionSGDTrainerTest.java | 2 +- .../LocalLinearRegressionQRTrainerTest.java | 36 -- .../group/DistributedWorkersChainTest.java | 189 --------- .../ml/trainers/group/GroupTrainerTest.java | 90 ----- .../trainers/group/SimpleGroupTrainerInput.java | 63 --- .../ml/trainers/group/TestGroupTrainer.java | 144 ------- .../group/TestGroupTrainerLocalContext.java | 85 ---- .../trainers/group/TestGroupTrainingCache.java | 70 ---- .../group/TestGroupTrainingSecondCache.java | 56 --- .../ml/trainers/group/TestLocalContext.java | 51 --- .../ml/trainers/group/TestTrainingLoopStep.java | 65 --- .../trainers/group/TrainersGroupTestSuite.java | 32 -- ...iteOLSMultipleLinearRegressionBenchmark.java | 69 ---- .../yardstick/ml/regression/package-info.java | 22 - 83 files changed, 666 insertions(+), 5840 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java index ce44cc6..5d1ac38 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java @@ -32,7 +32,7 @@ import 
org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.thread.IgniteThread; /** http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java deleted file mode 100644 index 99e6577..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.linear; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; -import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.thread.IgniteThread; - -import javax.cache.Cache; -import java.util.Arrays; -import java.util.UUID; - -/** - * Run linear regression model over distributed matrix. 
- * - * @see LinearRegressionLSQRTrainer - * @see NormalizationTrainer - * @see NormalizationPreprocessor - */ -public class DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 354, 8.399999619, 103}, - {9.800000191, 118, 1632, 9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - {11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 10.19999981, 91}, - 
{9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - {8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. 
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - - System.out.println(">>> Create new normalization trainer object."); - NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); - - System.out.println(">>> Perform the training to get the normalization preprocessor."); - IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit( - ignite, - dataCache, - (k, v) -> Arrays.copyOfRange(v, 1, v.length) - ); - - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v[0]); - - System.out.println(">>> Linear regression model: " + mdl); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - Integer key = observation.getKey(); - double[] val = observation.getValue(); - double groundTruth = val[0]; - - double prediction = mdl.apply(new DenseLocalOnHeapVector(preprocessor.apply(key, val))); - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - } - - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } - - /** - * Fills cache with data and returns it. - * - * @param ignite Ignite instance. - * @return Filled Ignite Cache. 
- */ - private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { - CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName("TEST_" + UUID.randomUUID()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, data[i]); - - return cache; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java deleted file mode 100644 index 25aec0c..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.linear; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.thread.IgniteThread; - -import javax.cache.Cache; -import java.util.Arrays; -import java.util.UUID; - -/** - * Run linear regression model over distributed matrix. - * - * @see LinearRegressionLSQRTrainer - */ -public class DistributedLinearRegressionWithLSQRTrainerExample { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 
354, 8.399999619, 103}, - {9.800000191, 118, 1632, 9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - {11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 10.19999981, 91}, - {9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - {8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. 
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - DistributedLinearRegressionWithLSQRTrainerExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), - (k, v) -> v[0] - ); - - System.out.println(">>> Linear regression model: " + mdl); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; - - double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - } - - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } - - /** - * Fills cache with data and returns it. - * - * @param ignite Ignite instance. - * @return Filled Ignite Cache. 
- */ - private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { - CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName("TEST_" + UUID.randomUUID()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, data[i]); - - return cache; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java deleted file mode 100644 index 98d5e4e..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.linear; - -import org.apache.ignite.Ignite; -import org.apache.ignite.Ignition; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; -import org.apache.ignite.thread.IgniteThread; - -import java.util.Arrays; - -/** - * Run linear regression model over distributed matrix. - * - * @see LinearRegressionQRTrainer - */ -public class DistributedLinearRegressionWithQRTrainerExample { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 354, 8.399999619, 103}, - {9.800000191, 118, 1632, 
9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - {11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 10.19999981, 91}, - {9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - {8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. 
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { - - // Create SparseDistributedMatrix, new cache will be created automagically. - System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); - - System.out.println(">>> Create new linear regression trainer object."); - Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer(); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.train(distributedMatrix); - System.out.println(">>> Linear regression model: " + mdl); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - for (double[] observation : data) { - Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); - double prediction = mdl.apply(inputs); - double groundTruth = observation[0]; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java deleted file mode 100644 index 44366e1..0000000 --- 
a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.linear; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; -import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; -import org.apache.ignite.thread.IgniteThread; - -import javax.cache.Cache; -import 
java.util.Arrays; -import java.util.UUID; - -/** - * Run linear regression model over distributed matrix. - * - * @see LinearRegressionQRTrainer - */ -public class DistributedLinearRegressionWithSGDTrainerExample { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 354, 8.399999619, 103}, - {9.800000191, 118, 1632, 9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - {11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 
10.19999981, 91}, - {9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - {8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. 
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - DistributedLinearRegressionWithSGDTrainerExample.class.getSimpleName(), () -> { - - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( - new RPropUpdateCalculator(), - RPropParameterUpdate::sumLocal, - RPropParameterUpdate::avg - ), 100000, 10, 100, 123L); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), - (k, v) -> v[0] - ); - - System.out.println(">>> Linear regression model: " + mdl); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; - - double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - } - - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } - - /** - * Fills cache with data and returns it. - * - * @param ignite Ignite instance. - * @return Filled Ignite Cache. 
- */ - private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { - CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName("TEST_" + UUID.randomUUID()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, data[i]); - - return cache; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java new file mode 100644 index 0000000..276d43f --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.regression.linear; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.thread.IgniteThread; + +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + +/** + * Run linear regression model over distributed matrix. + * + * @see LinearRegressionLSQRTrainer + */ +public class LinearRegressionLSQRTrainerExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + 
{9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + {8.399999619, 183, 718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. 
+ IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LinearRegressionLSQRTrainerExample.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ); + + System.out.println(">>> Linear regression model: " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + } + + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. 
+ */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java new file mode 100644 index 0000000..0358f44 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithNormalizationExample.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.linear; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.thread.IgniteThread; + +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + +/** + * Run linear regression model over distributed matrix. 
+ * + * @see LinearRegressionLSQRTrainer + * @see NormalizationTrainer + * @see NormalizationPreprocessor + */ +public class LinearRegressionLSQRTrainerWithNormalizationExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + {9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 
9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + {8.399999619, 183, 718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LinearRegressionLSQRTrainerWithNormalizationExample.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new normalization trainer object."); + NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); + + System.out.println(">>> Perform the training to get the normalization preprocessor."); + IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length) + ); + + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v[0]); + + System.out.println(">>> Linear regression 
model: " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + Integer key = observation.getKey(); + double[] val = observation.getValue(); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(preprocessor.apply(key, val))); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + } + + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java new file mode 100644 index 0000000..ce6ad3b --- /dev/null +++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.linear; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.thread.IgniteThread; + +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + +/** + * Run linear regression model over distributed matrix. 
+ * + * @see LinearRegressionSGDTrainer + */ +public class LinearRegressionSGDTrainerExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + {9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + {8.399999619, 183, 
718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LinearRegressionSGDTrainerExample.class.getSimpleName(), () -> { + + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( + new RPropUpdateCalculator(), + RPropParameterUpdate::sumLocal, + RPropParameterUpdate::avg + ), 100000, 10, 100, 123L); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ); + + System.out.println(">>> Linear regression model: " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for 
(Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + } + + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java deleted file mode 100644 index f53b801..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml; - -/** - * Interface for Trainers. Trainer is just a function which produces model from the data. - * - * @param <M> Type of produced model. - * @param <T> Type of data needed for model producing. - */ -// TODO: IGNITE-7659: Reduce multiple Trainer interfaces to one -@Deprecated -public interface Trainer<M extends Model, T> { - /** - * Returns model based on data - * - * @param data data to build model - * @return model - */ - M train(T data); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java deleted file mode 100644 index b2731ff..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/Estimators.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.estimators; - -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; -import java.util.stream.Stream; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; - -/** Estimators. */ -public class Estimators { - /** Simple implementation of mean squared error estimator. */ - public static <T, V> IgniteTriFunction<Model<T, V>, Stream<IgniteBiTuple<T, V>>, Function<V, Double>, Double> MSE() { - return (model, stream, f) -> stream.mapToDouble(dp -> { - double diff = f.apply(dp.get2()) - f.apply(model.apply(dp.get1())); - return diff * diff; - }).average().orElse(0); - } - - /** Simple implementation of errors percentage estimator. */ - public static <T, V> IgniteTriFunction<Model<T, V>, Stream<IgniteBiTuple<T, V>>, Function<V, Double>, Double> errorsPercentage() { - return (model, stream, f) -> { - AtomicLong total = new AtomicLong(0); - - long cnt = stream. - peek((ib) -> total.incrementAndGet()). - filter(dp -> !model.apply(dp.get1()).equals(dp.get2())). 
- count(); - - return (double)cnt / total.get(); - }; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java deleted file mode 100644 index c03827f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/estimators/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains estimation algorithms. 
- */ -package org.apache.ignite.ml.estimators; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBiFunction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBiFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBiFunction.java index dc49739..45fd035 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBiFunction.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteBiFunction.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.math.functions; import java.io.Serializable; +import java.util.Objects; import java.util.function.BiFunction; /** @@ -25,5 +26,10 @@ import java.util.function.BiFunction; * * @see java.util.function.BiFunction */ -public interface IgniteBiFunction<A, B, T> extends BiFunction<A, B, T>, Serializable { +public interface IgniteBiFunction<T, U, R> extends BiFunction<T, U, R>, Serializable { + /** {@inheritDoc} */ + default <V> IgniteBiFunction<T, U, V> andThen(IgniteFunction<? super R, ? 
extends V> after) { + Objects.requireNonNull(after); + return (T t, U u) -> after.apply(apply(t, u)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java deleted file mode 100644 index e80b935..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.math.isolve; - -import java.io.Serializable; -import java.util.Iterator; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.UpstreamEntry; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; - -/** - * Linear system partition data builder that builds {@link LinSysPartitionDataOnHeap}. - * - * @param <K> Type of a key in <tt>upstream</tt> data. 
- * @param <V> Type of a value in <tt>upstream</tt> data. - * @param <C> Type of a partition <tt>context</tt>. - */ -public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable> - implements PartitionDataBuilder<K, V, C, LinSysPartitionDataOnHeap> { - /** */ - private static final long serialVersionUID = -7820760153954269227L; - - /** Extractor of X matrix row. */ - private final IgniteBiFunction<K, V, double[]> xExtractor; - - /** Extractor of Y vector value. */ - private final IgniteBiFunction<K, V, Double> yExtractor; - - /** - * Constructs a new instance of linear system partition data builder. - * - * @param xExtractor Extractor of X matrix row. - * @param yExtractor Extractor of Y vector value. - */ - public LinSysPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor, - IgniteBiFunction<K, V, Double> yExtractor) { - this.xExtractor = xExtractor; - this.yExtractor = yExtractor; - } - - /** {@inheritDoc} */ - @Override public LinSysPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, - C ctx) { - // Prepares the matrix of features in flat column-major format. 
- int xCols = -1; - double[] x = null;//new double[Math.toIntExact(upstreamDataSize * cols)]; - double[] y = new double[Math.toIntExact(upstreamDataSize)]; - - int ptr = 0; - while (upstreamData.hasNext()) { - UpstreamEntry<K, V> entry = upstreamData.next(); - double[] row = xExtractor.apply(entry.getKey(), entry.getValue()); - - if (xCols < 0) { - xCols = row.length; - x = new double[Math.toIntExact(upstreamDataSize * xCols)]; - } - else - assert row.length == xCols : "X extractor must return exactly " + xCols + " columns"; - - for (int i = 0; i < xCols; i++) - x[Math.toIntExact(i * upstreamDataSize) + ptr] = row[i]; - - y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue()); - - ptr++; - } - - return new LinSysPartitionDataOnHeap(x, y, Math.toIntExact(upstreamDataSize)); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java deleted file mode 100644 index 89c8e44..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.math.isolve; - -/** - * On Heap partition data that keeps part of a linear system. - */ -public class LinSysPartitionDataOnHeap implements AutoCloseable { - /** Part of X matrix. */ - private final double[] x; - - /** Part of Y vector. */ - private final double[] y; - - /** Number of rows. */ - private final int rows; - - /** - * Constructs a new instance of linear system partition data. - * - * @param x Part of X matrix. - * @param y Part of Y vector. - * @param rows Number of rows. - */ - public LinSysPartitionDataOnHeap(double[] x, double[] y, int rows) { - this.x = x; - this.rows = rows; - this.y = y; - } - - /** */ - public double[] getX() { - return x; - } - - /** */ - public int getRows() { - return rows; - } - - /** */ - public double[] getY() { - return y; - } - - /** {@inheritDoc} */ - @Override public void close() { - // Do nothing, GC will clean up. 
- } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java index 8d190cd..d1d3219 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.math.isolve.lsqr; import com.github.fommil.netlib.BLAS; import java.util.Arrays; +import org.apache.ignite.ml.math.Precision; /** * Basic implementation of the LSQR algorithm without assumptions about dataset storage format or data processing @@ -30,7 +31,7 @@ import java.util.Arrays; // TODO: IGNITE-7660: Refactor LSQR algorithm public abstract class AbstractLSQR { /** The smallest representable positive number such that 1.0 + eps != 1.0. */ - private static final double eps = Double.longBitsToDouble(Double.doubleToLongBits(1.0) | 1) - 1.0; + private static final double eps = Precision.EPSILON; /** BLAS (Basic Linear Algebra Subprograms) instance. 
*/ private static BLAS blas = BLAS.getInstance(); http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java index b1cc4c9..e138cf3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java @@ -22,14 +22,14 @@ import java.util.Arrays; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.math.isolve.LinSysPartitionDataOnHeap; +import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; /** * Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}. */ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { /** Dataset. */ - private final Dataset<LSQRPartitionContext, LinSysPartitionDataOnHeap> dataset; + private final Dataset<LSQRPartitionContext, SimpleLabeledDatasetData> dataset; /** * Constructs a new instance of OnHeap LSQR algorithm implementation. @@ -38,7 +38,7 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { * @param partDataBuilder Partition data builder. 
*/ public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder, - PartitionDataBuilder<K, V, LSQRPartitionContext, LinSysPartitionDataOnHeap> partDataBuilder) { + PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partDataBuilder) { this.dataset = datasetBuilder.build( (upstream, upstreamSize) -> new LSQRPartitionContext(), partDataBuilder @@ -48,20 +48,20 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { /** {@inheritDoc} */ @Override protected double bnorm() { return dataset.computeWithCtx((ctx, data) -> { - ctx.setU(Arrays.copyOf(data.getY(), data.getY().length)); + ctx.setU(Arrays.copyOf(data.getLabels(), data.getLabels().length)); - return BLAS.getInstance().dnrm2(data.getY().length, data.getY(), 1); + return BLAS.getInstance().dnrm2(data.getLabels().length, data.getLabels(), 1); }, (a, b) -> a == null ? b : b == null ? a : Math.sqrt(a * a + b * b)); } /** {@inheritDoc} */ @Override protected double beta(double[] x, double alfa, double beta) { return dataset.computeWithCtx((ctx, data) -> { - if (data.getX() == null) + if (data.getFeatures() == null) return null; - int cols = data.getX().length / data.getRows(); - BLAS.getInstance().dgemv("N", data.getRows(), cols, alfa, data.getX(), + int cols = data.getFeatures().length / data.getRows(); + BLAS.getInstance().dgemv("N", data.getRows(), cols, alfa, data.getFeatures(), Math.max(1, data.getRows()), x, 1, beta, ctx.getU(), 1); return BLAS.getInstance().dnrm2(ctx.getU().length, ctx.getU(), 1); @@ -71,13 +71,13 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { /** {@inheritDoc} */ @Override protected double[] iter(double bnorm, double[] target) { double[] res = dataset.computeWithCtx((ctx, data) -> { - if (data.getX() == null) + if (data.getFeatures() == null) return null; - int cols = data.getX().length / data.getRows(); + int cols = data.getFeatures().length / data.getRows(); BLAS.getInstance().dscal(ctx.getU().length, 1 / bnorm, 
ctx.getU(), 1); double[] v = new double[cols]; - BLAS.getInstance().dgemv("T", data.getRows(), cols, 1.0, data.getX(), + BLAS.getInstance().dgemv("T", data.getRows(), cols, 1.0, data.getFeatures(), Math.max(1, data.getRows()), ctx.getU(), 1, 0, v, 1); return v; @@ -101,7 +101,10 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { * @return number of columns */ @Override protected int getColumns() { - return dataset.compute(data -> data.getX() == null ? null : data.getX().length / data.getRows(), (a, b) -> a == null ? b : a); + return dataset.compute( + data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(), + (a, b) -> a == null ? b : a + ); } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index fe955cb..d12a276 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -33,7 +33,6 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.ml.util.Utils; import java.io.Serializable;