IGNITE-7438: LSQR solver for Linear Regression this closes #3494
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/2f330a1c Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/2f330a1c Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/2f330a1c Branch: refs/heads/master Commit: 2f330a1cd1430a198ac0ddd81075ef20ff1e7316 Parents: 64c9f50 Author: dmitrievanthony <dmitrievanth...@gmail.com> Authored: Fri Feb 9 14:17:27 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Fri Feb 9 14:17:27 2018 +0300 ---------------------------------------------------------------------- ...tedLinearRegressionExampleWithQRTrainer.java | 136 -------- ...edLinearRegressionExampleWithSGDTrainer.java | 137 -------- ...dLinearRegressionWithLSQRTrainerExample.java | 170 ++++++++++ ...tedLinearRegressionWithQRTrainerExample.java | 136 ++++++++ ...edLinearRegressionWithSGDTrainerExample.java | 137 ++++++++ .../org/apache/ignite/ml/DatasetTrainer.java | 42 +++ .../main/java/org/apache/ignite/ml/Trainer.java | 2 + .../ml/math/isolve/IterativeSolverResult.java | 64 ++++ .../LinSysPartitionDataBuilderOnHeap.java | 85 +++++ .../math/isolve/LinSysPartitionDataOnHeap.java | 75 +++++ .../ml/math/isolve/lsqr/AbstractLSQR.java | 333 +++++++++++++++++++ .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java | 102 ++++++ .../math/isolve/lsqr/LSQRPartitionContext.java | 41 +++ .../ignite/ml/math/isolve/lsqr/LSQRResult.java | 140 ++++++++ .../ml/math/isolve/lsqr/package-info.java | 22 ++ .../ignite/ml/math/isolve/package-info.java | 22 ++ .../linear/LinearRegressionLSQRTrainer.java | 70 ++++ .../org/apache/ignite/ml/trainers/Trainer.java | 2 + .../ignite/ml/math/MathImplLocalTestSuite.java | 4 +- .../ml/math/isolve/lsqr/LSQROnHeapTest.java | 134 ++++++++ .../ml/regressions/RegressionsTestSuite.java | 4 +- .../linear/LinearRegressionLSQRTrainerTest.java | 124 +++++++ 22 files changed, 1707 insertions(+), 275 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithQRTrainer.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithQRTrainer.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithQRTrainer.java deleted file mode 100644 index 98ff2a2..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithQRTrainer.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.regression.linear; - -import java.util.Arrays; -import org.apache.ignite.Ignite; -import org.apache.ignite.Ignition; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; -import org.apache.ignite.thread.IgniteThread; - -/** - * Run linear regression model over distributed matrix. - * - * @see LinearRegressionQRTrainer - */ -public class DistributedLinearRegressionExampleWithQRTrainer { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 354, 8.399999619, 103}, - {9.800000191, 118, 1632, 9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - 
{11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 10.19999981, 91}, - {9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - {8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { - - // Create SparseDistributedMatrix, new cache will be created automagically. 
- System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); - - System.out.println(">>> Create new linear regression trainer object."); - Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer(); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel model = trainer.train(distributedMatrix); - System.out.println(">>> Linear regression model: " + model); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - for (double[] observation : data) { - Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); - double prediction = model.apply(inputs); - double groundTruth = observation[0]; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithSGDTrainer.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithSGDTrainer.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithSGDTrainer.java deleted file mode 100644 index 3f61762..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithSGDTrainer.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.linear; - -import java.util.Arrays; -import org.apache.ignite.Ignite; -import org.apache.ignite.Ignition; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; -import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; -import org.apache.ignite.thread.IgniteThread; - -/** - * Run linear regression model over distributed matrix. 
- * - * @see LinearRegressionQRTrainer - */ -public class DistributedLinearRegressionExampleWithSGDTrainer { - /** */ - private static final double[][] data = { - {8, 78, 284, 9.100000381, 109}, - {9.300000191, 68, 433, 8.699999809, 144}, - {7.5, 70, 739, 7.199999809, 113}, - {8.899999619, 96, 1792, 8.899999619, 97}, - {10.19999981, 74, 477, 8.300000191, 206}, - {8.300000191, 111, 362, 10.89999962, 124}, - {8.800000191, 77, 671, 10, 152}, - {8.800000191, 168, 636, 9.100000381, 162}, - {10.69999981, 82, 329, 8.699999809, 150}, - {11.69999981, 89, 634, 7.599999905, 134}, - {8.5, 149, 631, 10.80000019, 292}, - {8.300000191, 60, 257, 9.5, 108}, - {8.199999809, 96, 284, 8.800000191, 111}, - {7.900000095, 83, 603, 9.5, 182}, - {10.30000019, 130, 686, 8.699999809, 129}, - {7.400000095, 145, 345, 11.19999981, 158}, - {9.600000381, 112, 1357, 9.699999809, 186}, - {9.300000191, 131, 544, 9.600000381, 177}, - {10.60000038, 80, 205, 9.100000381, 127}, - {9.699999809, 130, 1264, 9.199999809, 179}, - {11.60000038, 140, 688, 8.300000191, 80}, - {8.100000381, 154, 354, 8.399999619, 103}, - {9.800000191, 118, 1632, 9.399999619, 101}, - {7.400000095, 94, 348, 9.800000191, 117}, - {9.399999619, 119, 370, 10.39999962, 88}, - {11.19999981, 153, 648, 9.899999619, 78}, - {9.100000381, 116, 366, 9.199999809, 102}, - {10.5, 97, 540, 10.30000019, 95}, - {11.89999962, 176, 680, 8.899999619, 80}, - {8.399999619, 75, 345, 9.600000381, 92}, - {5, 134, 525, 10.30000019, 126}, - {9.800000191, 161, 870, 10.39999962, 108}, - {9.800000191, 111, 669, 9.699999809, 77}, - {10.80000019, 114, 452, 9.600000381, 60}, - {10.10000038, 142, 430, 10.69999981, 71}, - {10.89999962, 238, 822, 10.30000019, 86}, - {9.199999809, 78, 190, 10.69999981, 93}, - {8.300000191, 196, 867, 9.600000381, 106}, - {7.300000191, 125, 969, 10.5, 162}, - {9.399999619, 82, 499, 7.699999809, 95}, - {9.399999619, 125, 925, 10.19999981, 91}, - {9.800000191, 129, 353, 9.899999619, 52}, - {3.599999905, 84, 288, 8.399999619, 110}, - 
{8.399999619, 183, 718, 10.39999962, 69}, - {10.80000019, 119, 540, 9.199999809, 57}, - {10.10000038, 180, 668, 13, 106}, - {9, 82, 347, 8.800000191, 40}, - {10, 71, 345, 9.199999809, 50}, - {11.30000019, 118, 463, 7.800000191, 35}, - {11.30000019, 121, 728, 8.199999809, 86}, - {12.80000019, 68, 383, 7.400000095, 57}, - {10, 112, 316, 10.39999962, 57}, - {6.699999809, 109, 388, 8.899999619, 94} - }; - - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { - - // Create SparseDistributedMatrix, new cache will be created automagically. 
- System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); - - System.out.println(">>> Create new linear regression trainer object."); - Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionSGDTrainer(100_000, 1e-12); - - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel model = trainer.train(distributedMatrix); - System.out.println(">>> Linear regression model: " + model); - - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - for (double[] observation : data) { - Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); - double prediction = model.apply(inputs); - double groundTruth = observation[0]; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java new file mode 100644 index 0000000..20e0653 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.linear; + +import java.util.Arrays; +import java.util.UUID; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run linear regression model over distributed matrix. 
+ * + * @see LinearRegressionLSQRTrainer + */ +public class DistributedLinearRegressionWithLSQRTrainerExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + {9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + 
{8.399999619, 183, 718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + SparseDistributedMatrixExample.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>(); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, dataCache), + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0], + 4 + ); + + System.out.println(">>> Linear regression model: " + mdl); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] 
val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + } + + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java new file mode 100644 index 0000000..2b45aa2 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.linear; + +import java.util.Arrays; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run linear regression model over distributed matrix. 
+ * + * @see LinearRegressionQRTrainer + */ +public class DistributedLinearRegressionWithQRTrainerExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + {9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + 
{8.399999619, 183, 718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + SparseDistributedMatrixExample.class.getSimpleName(), () -> { + + // Create SparseDistributedMatrix, new cache will be created automagically. 
+ System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); + SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); + + System.out.println(">>> Create new linear regression trainer object."); + Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer(); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel model = trainer.train(distributedMatrix); + System.out.println(">>> Linear regression model: " + model); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + for (double[] observation : data) { + Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); + double prediction = model.apply(inputs); + double groundTruth = observation[0]; + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java new file mode 100644 index 0000000..f3b2655 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.linear; + +import java.util.Arrays; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run linear regression model over distributed matrix. 
+ * + * @see LinearRegressionQRTrainer + */ +public class DistributedLinearRegressionWithSGDTrainerExample { + /** */ + private static final double[][] data = { + {8, 78, 284, 9.100000381, 109}, + {9.300000191, 68, 433, 8.699999809, 144}, + {7.5, 70, 739, 7.199999809, 113}, + {8.899999619, 96, 1792, 8.899999619, 97}, + {10.19999981, 74, 477, 8.300000191, 206}, + {8.300000191, 111, 362, 10.89999962, 124}, + {8.800000191, 77, 671, 10, 152}, + {8.800000191, 168, 636, 9.100000381, 162}, + {10.69999981, 82, 329, 8.699999809, 150}, + {11.69999981, 89, 634, 7.599999905, 134}, + {8.5, 149, 631, 10.80000019, 292}, + {8.300000191, 60, 257, 9.5, 108}, + {8.199999809, 96, 284, 8.800000191, 111}, + {7.900000095, 83, 603, 9.5, 182}, + {10.30000019, 130, 686, 8.699999809, 129}, + {7.400000095, 145, 345, 11.19999981, 158}, + {9.600000381, 112, 1357, 9.699999809, 186}, + {9.300000191, 131, 544, 9.600000381, 177}, + {10.60000038, 80, 205, 9.100000381, 127}, + {9.699999809, 130, 1264, 9.199999809, 179}, + {11.60000038, 140, 688, 8.300000191, 80}, + {8.100000381, 154, 354, 8.399999619, 103}, + {9.800000191, 118, 1632, 9.399999619, 101}, + {7.400000095, 94, 348, 9.800000191, 117}, + {9.399999619, 119, 370, 10.39999962, 88}, + {11.19999981, 153, 648, 9.899999619, 78}, + {9.100000381, 116, 366, 9.199999809, 102}, + {10.5, 97, 540, 10.30000019, 95}, + {11.89999962, 176, 680, 8.899999619, 80}, + {8.399999619, 75, 345, 9.600000381, 92}, + {5, 134, 525, 10.30000019, 126}, + {9.800000191, 161, 870, 10.39999962, 108}, + {9.800000191, 111, 669, 9.699999809, 77}, + {10.80000019, 114, 452, 9.600000381, 60}, + {10.10000038, 142, 430, 10.69999981, 71}, + {10.89999962, 238, 822, 10.30000019, 86}, + {9.199999809, 78, 190, 10.69999981, 93}, + {8.300000191, 196, 867, 9.600000381, 106}, + {7.300000191, 125, 969, 10.5, 162}, + {9.399999619, 82, 499, 7.699999809, 95}, + {9.399999619, 125, 925, 10.19999981, 91}, + {9.800000191, 129, 353, 9.899999619, 52}, + {3.599999905, 84, 288, 8.399999619, 110}, + 
{8.399999619, 183, 718, 10.39999962, 69}, + {10.80000019, 119, 540, 9.199999809, 57}, + {10.10000038, 180, 668, 13, 106}, + {9, 82, 347, 8.800000191, 40}, + {10, 71, 345, 9.199999809, 50}, + {11.30000019, 118, 463, 7.800000191, 35}, + {11.30000019, 121, 728, 8.199999809, 86}, + {12.80000019, 68, 383, 7.400000095, 57}, + {10, 112, 316, 10.39999962, 57}, + {6.699999809, 109, 388, 8.899999619, 94} + }; + + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread + // because we create ignite cache internally. + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + SparseDistributedMatrixExample.class.getSimpleName(), () -> { + + // Create SparseDistributedMatrix, new cache will be created automagically. 
+ System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); + SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); + + System.out.println(">>> Create new linear regression trainer object."); + Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionSGDTrainer(100_000, 1e-12); + + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel model = trainer.train(distributedMatrix); + System.out.println(">>> Linear regression model: " + model); + + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); + for (double[] observation : data) { + Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); + double prediction = model.apply(inputs); + double groundTruth = observation[0]; + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java new file mode 100644 index 0000000..aa04d8e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
package org.apache.ignite.ml;

import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;

/**
 * Interface for trainers. Trainer is just a function which produces model from the data.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @param <M> Type of a produced model.
 */
public interface DatasetTrainer<K, V, M extends Model> {
    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder used to materialize the training dataset from upstream data.
     * @param featureExtractor Feature extractor mapping an upstream (key, value) pair to a feature row.
     * @param lbExtractor Label extractor mapping an upstream (key, value) pair to a regression/classification label.
     * @param cols Number of columns (presumably the length of every feature row — confirm against implementations).
     * @return Trained model.
     */
    public M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor, int cols);
}
import java.io.Serializable;
import java.util.Arrays;

/**
 * Base class for iterative solver results: the solution vector together with the number of iterations
 * the solver performed before it terminated.
 */
public class IterativeSolverResult implements Serializable {
    /** */
    private static final long serialVersionUID = 8084061028708491097L;

    /** The final solution. */
    private final double[] x;

    /** Iteration number upon termination. */
    private final int iterations;

    /**
     * Constructs a new instance of iterative solver result.
     *
     * @param x The final solution.
     * @param iterations Iteration number upon termination.
     */
    public IterativeSolverResult(double[] x, int iterations) {
        this.x = x;
        this.iterations = iterations;
    }

    /** Returns the final solution vector. */
    public double[] getX() {
        return x;
    }

    /** Returns the number of iterations performed before termination. */
    public int getIterations() {
        return iterations;
    }

    /** {@inheritDoc} */
    @Override public String toString() {
        return String.format("IterativeSolverResult{x=%s, iterations=%d}", Arrays.toString(x), iterations);
    }
}
package org.apache.ignite.ml.math.isolve;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;

/**
 * Linear system partition data builder that builds {@link LinSysPartitionDataOnHeap}.
 *
 * @param <K> Type of a key in <tt>upstream</tt> data.
 * @param <V> Type of a value in <tt>upstream</tt> data.
 * @param <C> Type of a partition <tt>context</tt>.
 */
public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable>
    implements PartitionDataBuilder<K, V, C, LinSysPartitionDataOnHeap> {
    /** */
    private static final long serialVersionUID = -7820760153954269227L;

    /** Extractor of X matrix row. */
    private final IgniteBiFunction<K, V, double[]> xExtractor;

    /** Extractor of Y vector value. */
    private final IgniteBiFunction<K, V, Double> yExtractor;

    /** Number of columns. */
    private final int cols;

    /**
     * Constructs a new instance of linear system partition data builder.
     *
     * @param xExtractor Extractor of X matrix row.
     * @param yExtractor Extractor of Y vector value.
     * @param cols Number of columns.
     */
    public LinSysPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
        IgniteBiFunction<K, V, Double> yExtractor, int cols) {
        this.xExtractor = xExtractor;
        this.yExtractor = yExtractor;
        this.cols = cols;
    }

    /** {@inheritDoc} */
    @Override public LinSysPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
        C ctx) {
        int rows = Math.toIntExact(upstreamDataSize);

        // Feature matrix in flat column-major layout: element (i, j) lives at x[j * rows + i].
        // The long multiplication is kept so that an oversized partition fails fast in toIntExact.
        double[] x = new double[Math.toIntExact(upstreamDataSize * cols)];
        double[] y = new double[rows];

        int rowIdx = 0;

        while (upstreamData.hasNext()) {
            UpstreamEntry<K, V> entry = upstreamData.next();

            double[] featureRow = xExtractor.apply(entry.getKey(), entry.getValue());

            assert featureRow.length == cols : "X extractor must return exactly " + cols + " columns";

            // Scatter the row across the column-major buffer.
            for (int col = 0; col < cols; col++)
                x[col * rows + rowIdx] = featureRow[col];

            y[rowIdx] = yExtractor.apply(entry.getKey(), entry.getValue());

            rowIdx++;
        }

        return new LinSysPartitionDataOnHeap(x, rows, cols, y);
    }
}
/**
 * On Heap partition data that keeps part of a linear system: a slice of the X matrix (flat array)
 * together with the matching slice of the Y vector.
 */
public class LinSysPartitionDataOnHeap implements AutoCloseable {
    /** Part of X matrix. */
    private final double[] x;

    /** Number of rows. */
    private final int rows;

    /** Number of columns. */
    private final int cols;

    /** Part of Y vector. */
    private final double[] y;

    /**
     * Constructs a new instance of linear system partition data.
     *
     * @param x Part of X matrix.
     * @param rows Number of rows.
     * @param cols Number of columns.
     * @param y Part of Y vector.
     */
    public LinSysPartitionDataOnHeap(double[] x, int rows, int cols, double[] y) {
        this.x = x;
        this.rows = rows;
        this.cols = cols;
        this.y = y;
    }

    /** Returns the partition's slice of the X matrix. */
    public double[] getX() {
        return x;
    }

    /** Returns the number of rows in this partition. */
    public int getRows() {
        return rows;
    }

    /** Returns the number of columns of the X matrix. */
    public int getCols() {
        return cols;
    }

    /** Returns the partition's slice of the Y vector. */
    public double[] getY() {
        return y;
    }

    /** {@inheritDoc} */
    @Override public void close() {
        // Do nothing, GC will clean up.
    }
}
package org.apache.ignite.ml.math.isolve.lsqr;

import com.github.fommil.netlib.BLAS;
import java.util.Arrays;

/**
 * Basic implementation of the LSQR algorithm without assumptions about dataset storage format or data processing
 * device.
 *
 * <p>Subclasses supply the data-dependent primitives ({@link #bnorm()}, {@link #beta}, {@link #iter},
 * {@link #getColumns()}), so the core Golub-Kahan bidiagonalization loop below is storage-agnostic.
 *
 * This implementation is based on SciPy implementation.
 * SciPy implementation: https://github.com/scipy/scipy/blob/master/scipy/sparse/linalg/isolve/lsqr.py#L98.
 */
// TODO: IGNITE-7660: Refactor LSQR algorithm
public abstract class AbstractLSQR {
    /** The smallest representable positive number such that 1.0 + eps != 1.0. */
    private static final double eps = Double.longBitsToDouble(Double.doubleToLongBits(1.0) | 1) - 1.0;

    /** BLAS (Basic Linear Algebra Subprograms) instance. */
    private static BLAS blas = BLAS.getInstance();

    /**
     * Solves given Sparse Linear Systems.
     *
     * @param damp Damping coefficient.
     * @param atol Stopping tolerances, if both (atol and btol) are 1.0e-9 (say), the final residual norm should be
     * accurate to about 9 digits.
     * @param btol Stopping tolerances, if both (atol and btol) are 1.0e-9 (say), the final residual norm should be
     * accurate to about 9 digits.
     * @param conlim Another stopping tolerance, LSQR terminates if an estimate of cond(A) exceeds conlim.
     * @param iterLim Explicit limitation on number of iterations (for safety); a negative value means "use 2 * n".
     * @param calcVar Whether to estimate diagonals of (A'A + damp^2*I)^{-1}.
     * @param x0 Initial value of x; {@code null} means start from the zero vector.
     * @return Solver result.
     */
    public LSQRResult solve(double damp, double atol, double btol, double conlim, double iterLim, boolean calcVar,
        double[] x0) {
        int n = getColumns();

        if (iterLim < 0)
            iterLim = 2 * n;

        double[] var = new double[n];
        int itn = 0;
        // istop encodes the termination reason (0 = still running, 1..3 = converged by a tolerance,
        // 4..6 = converged at machine precision, 7 = iteration limit reached).
        int istop = 0;
        double ctol = 0;

        if (conlim > 0)
            ctol = 1 / conlim;

        double anorm = 0;
        double acond = 0;
        double dampsq = Math.pow(damp, 2.0);
        double ddnorm = 0;
        double res2 = 0;
        double xnorm = 0;
        double xxnorm = 0;
        double z = 0;
        double cs2 = -1;
        double sn2 = 0;

        // Set up the first vectors u and v for the bidiagonalization.
        // These satisfy beta*u = b - A*x, alfa*v = A'*u.
        double bnorm = bnorm();
        double[] x;
        double beta;

        if (x0 == null) {
            x = new double[n];
            beta = bnorm;
        }
        else {
            x = x0;
            // With a warm start the initial residual is b - A*x0, so beta = ||b - A*x0||.
            beta = beta(x, -1.0, 1.0);
        }

        double[] v = new double[n];
        double alfa;

        if (beta > 0) {
            v = iter(beta, v);
            alfa = blas.dnrm2(v.length, v, 1);
        }
        else {
            System.arraycopy(x, 0, v, 0, v.length);
            alfa = 0;
        }

        if (alfa > 0)
            blas.dscal(v.length, 1 / alfa, v, 1);

        double[] w = Arrays.copyOf(v, v.length);

        double rhobar = alfa;
        double phibar = beta;
        double rnorm = beta;
        double r1norm = rnorm;
        double r2norm = rnorm;
        double arnorm = alfa * beta;
        double[] dk = new double[w.length];

        // arnorm == 0 means A'*b = 0, i.e. x = 0 is already the exact least-squares solution.
        if (arnorm == 0)
            return new LSQRResult(x, itn, istop, r1norm, r2norm, anorm, acond, arnorm, xnorm, var);

        // Main iteration loop.
        while (itn < iterLim) {
            itn = itn + 1;

            // Perform the next step of the bidiagonalization to obtain the
            // next beta, u, alfa, v. These satisfy the relations
            // beta*u = A*v - alfa*u,
            // alfa*v = A'*u - beta*v.
            beta = beta(v, 1.0, -alfa);

            if (beta > 0) {
                anorm = Math.sqrt(Math.pow(anorm, 2) + Math.pow(alfa, 2) + Math.pow(beta, 2) + Math.pow(damp, 2));

                blas.dscal(v.length, -beta, v, 1);

                // NOTE(review): iter() mutates v in place here; the returned reference is deliberately ignored.
                iter(beta, v);

                //v = dataset.iter(beta, n);
                alfa = blas.dnrm2(v.length, v, 1);

                if (alfa > 0)
                    blas.dscal(v.length, 1 / alfa, v, 1);
            }

            // Use a plane rotation to eliminate the damping parameter.
            // This alters the diagonal (rhobar) of the lower-bidiagonal matrix.
            double rhobar1 = Math.sqrt(Math.pow(rhobar, 2) + Math.pow(damp, 2));
            double cs1 = rhobar / rhobar1;
            double sn1 = damp / rhobar1;
            double psi = sn1 * phibar;
            phibar = cs1 * phibar;

            // Use a plane rotation to eliminate the subdiagonal element (beta)
            // of the lower-bidiagonal matrix, giving an upper-bidiagonal matrix.
            double[] symOrtho = symOrtho(rhobar1, beta);
            double cs = symOrtho[0];
            double sn = symOrtho[1];
            double rho = symOrtho[2];

            double theta = sn * alfa;
            rhobar = -cs * alfa;
            double phi = cs * phibar;
            phibar = sn * phibar;
            double tau = sn * phi;

            double t1 = phi / rho;
            double t2 = -theta / rho;

            blas.dcopy(w.length, w, 1, dk, 1);
            blas.dscal(dk.length, 1 / rho, dk, 1);

            // x = x + t1*w
            blas.daxpy(w.length, t1, w, 1, x, 1);
            // w = v + t2*w
            blas.dscal(w.length, t2, w, 1);
            blas.daxpy(w.length, 1, v, 1, w, 1);

            ddnorm = ddnorm + Math.pow(blas.dnrm2(dk.length, dk, 1), 2);

            // Accumulate dk^2 into the variance estimate (pow2 returns a fresh array; dk is untouched).
            if (calcVar)
                blas.daxpy(var.length, 1.0, pow2(dk), 1, var, 1);

            // Use a plane rotation on the right to eliminate the
            // super-diagonal element (theta) of the upper-bidiagonal matrix.
            // Then use the result to estimate norm(x).
            double delta = sn2 * rho;
            double gambar = -cs2 * rho;
            double rhs = phi - delta * z;
            double zbar = rhs / gambar;
            xnorm = Math.sqrt(xxnorm + Math.pow(zbar, 2));
            double gamma = Math.sqrt(Math.pow(gambar, 2) + Math.pow(theta, 2));
            cs2 = gambar / gamma;
            sn2 = theta / gamma;
            z = rhs / gamma;
            xxnorm = xxnorm + Math.pow(z, 2);

            // Test for convergence.
            // First, estimate the condition of the matrix Abar,
            // and the norms of rbar and Abar'rbar.
            acond = anorm * Math.sqrt(ddnorm);
            double res1 = Math.pow(phibar, 2);
            res2 = res2 + Math.pow(psi, 2);
            rnorm = Math.sqrt(res1 + res2);
            arnorm = alfa * Math.abs(tau);

            // Distinguish between
            // r1norm = ||b - Ax|| and
            // r2norm = rnorm in current code
            // = sqrt(r1norm^2 + damp^2*||x||^2).
            // Estimate r1norm from
            // r1norm = sqrt(r2norm^2 - damp^2*||x||^2).
            // Although there is cancellation, it might be accurate enough.
            double r1sq = Math.pow(rnorm, 2) - dampsq * xxnorm;
            r1norm = Math.sqrt(Math.abs(r1sq));

            if (r1sq < 0)
                r1norm = -r1norm;

            r2norm = rnorm;

            // Now use these norms to estimate certain other quantities,
            // some of which will be small near a solution.
            double test1 = rnorm / bnorm;
            double test2 = arnorm / (anorm * rnorm + eps);
            double test3 = 1 / (acond + eps);
            t1 = test1 / (1 + anorm * xnorm / bnorm);
            double rtol = btol + atol * anorm * xnorm / bnorm;

            // The following tests guard against extremely small values of
            // atol, btol or ctol. (The user may have set any or all of
            // the parameters atol, btol, conlim to 0.)
            // The effect is equivalent to the normal tests using
            // atol = eps, btol = eps, conlim = 1/eps.
            if (itn >= iterLim)
                istop = 7;

            if (1 + test3 <= 1)
                istop = 6;

            if (1 + test2 <= 1)
                istop = 5;

            if (1 + t1 <= 1)
                istop = 4;

            // Allow for tolerances set by the user.
            if (test3 <= ctol)
                istop = 3;

            if (test2 <= atol)
                istop = 2;

            if (test1 <= rtol)
                istop = 1;

            if (istop != 0)
                break;
        }

        return new LSQRResult(x, itn, istop, r1norm, r2norm, anorm, acond, arnorm, xnorm, var);
    }

    /**
     * Calculates bnorm, the norm of the right-hand side vector b.
     *
     * @return bnorm
     */
    protected abstract double bnorm();

    /**
     * Calculates beta: updates the (implementation-held) u vector as u = alfa*A*x + beta*u and returns its norm.
     *
     * @param x X value.
     * @param alfa Alfa value.
     * @param beta Beta value.
     * @return Beta.
     */
    protected abstract double beta(double[] x, double alfa, double beta);

    /**
     * Perform LSQR iteration: scales the (implementation-held) u vector by 1/bnorm and accumulates A'*u
     * into {@code target}.
     *
     * NOTE(review): at both call sites in {@link #solve} the first argument receives the current {@code beta},
     * not bnorm — the parameter name appears to be a leftover; confirm before relying on it.
     *
     * @param bnorm Bnorm value.
     * @param target Target value.
     * @return Iteration result (the mutated {@code target}).
     */
    protected abstract double[] iter(double bnorm, double[] target);

    /** Returns the number of columns of the matrix A. */
    protected abstract int getColumns();

    /**
     * Stable computation of a Givens rotation, i.e. (c, s, r) such that [c s; s -c] * [a; b] = [r; 0].
     * Mirrors SciPy's {@code _sym_ortho} helper.
     */
    private static double[] symOrtho(double a, double b) {
        if (b == 0)
            return new double[] {Math.signum(a), 0, Math.abs(a)};
        else if (a == 0)
            return new double[] {0, Math.signum(b), Math.abs(b)};
        else {
            double c, s, r;

            // Branch on the larger magnitude to avoid overflow in tau * tau.
            if (Math.abs(b) > Math.abs(a)) {
                double tau = a / b;
                s = Math.signum(b) / Math.sqrt(1 + tau * tau);
                c = s * tau;
                r = b / s;
            }
            else {
                double tau = b / a;
                c = Math.signum(a) / Math.sqrt(1 + tau * tau);
                s = c * tau;
                r = a / c;
            }

            return new double[] {c, s, r};
        }
    }

    /**
     * Returns a new array whose elements are the squares of the elements of the specified vector {@code a}.
     * Note: contrary to the earlier doc, this is NOT an in-place operation — the input array is left untouched
     * and a freshly allocated array is returned.
     *
     * @param a Vector or matrix of doubles.
     * @return New array with elements squared.
     */
    private static double[] pow2(double[] a) {
        double[] res = new double[a.length];

        for (int i = 0; i < res.length; i++)
            res[i] = Math.pow(a[i], 2);

        return res;
    }

}
package org.apache.ignite.ml.math.isolve.lsqr;

import com.github.fommil.netlib.BLAS;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.math.isolve.LinSysPartitionDataOnHeap;

/**
 * Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}.
 *
 * <p>The u vector of the bidiagonalization is held per partition in {@link LSQRPartitionContext};
 * each primitive below computes its partition-local contribution and the reduce step merges them.
 * In every reduce the identity element is {@code null} (an empty/absent partition), hence the
 * {@code a == null ? b : ...} pattern.
 */
public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
    /** Dataset. */
    private final Dataset<LSQRPartitionContext, LinSysPartitionDataOnHeap> dataset;

    /**
     * Constructs a new instance of OnHeap LSQR algorithm implementation.
     *
     * @param datasetBuilder Dataset builder.
     * @param partDataBuilder Partition data builder.
     */
    public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
        PartitionDataBuilder<K, V, LSQRPartitionContext, LinSysPartitionDataOnHeap> partDataBuilder) {
        this.dataset = datasetBuilder.build(
            (upstream, upstreamSize) -> new LSQRPartitionContext(),
            partDataBuilder
        );
    }

    /** {@inheritDoc} */
    @Override protected double bnorm() {
        // Side effect: seeds each partition's u with a copy of its local slice of b (the y vector).
        return dataset.computeWithCtx((ctx, data) -> {
            ctx.setU(Arrays.copyOf(data.getY(), data.getY().length));

            return BLAS.getInstance().dnrm2(data.getY().length, data.getY(), 1);
        // Partial norms of disjoint slices combine as sqrt(a^2 + b^2).
        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b));
    }

    /** {@inheritDoc} */
    @Override protected double beta(double[] x, double alfa, double beta) {
        return dataset.computeWithCtx((ctx, data) -> {
            // u = alfa * X * x + beta * u on this partition (column-major X, leading dimension = rows;
            // max(1, rows) keeps LAPACK's lda >= 1 requirement satisfied for empty partitions).
            BLAS.getInstance().dgemv("N", data.getRows(), data.getCols(), alfa, data.getX(),
                Math.max(1, data.getRows()), x, 1, beta, ctx.getU(), 1);

            return BLAS.getInstance().dnrm2(ctx.getU().length, ctx.getU(), 1);
        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b));
    }

    /** {@inheritDoc} */
    @Override protected double[] iter(double bnorm, double[] target) {
        double[] res = dataset.computeWithCtx((ctx, data) -> {
            // Normalize the partition-local u, then compute the partial product v = X' * u.
            BLAS.getInstance().dscal(ctx.getU().length, 1 / bnorm, ctx.getU(), 1);
            double[] v = new double[data.getCols()];
            BLAS.getInstance().dgemv("T", data.getRows(), data.getCols(), 1.0, data.getX(),
                Math.max(1, data.getRows()), ctx.getU(), 1, 0, v, 1);

            return v;
        }, (a, b) -> {
            if (a == null)
                return b;
            else {
                // Partial X'*u vectors sum element-wise; b is reused as the accumulator.
                BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);

                return b;
            }
        });
        // target += reduced X'*u; the caller receives the mutated target back.
        BLAS.getInstance().daxpy(res.length, 1.0, res, 1, target, 1);
        return target;
    }

    /**
     * Returns number of columns in dataset.
     *
     * @return number of columns
     */
    @Override protected int getColumns() {
        // All partitions share the same column count, so the first non-null answer wins.
        return dataset.compute(LinSysPartitionDataOnHeap::getCols, (a, b) -> a == null ? b : a);
    }

    /** {@inheritDoc} */
    @Override public void close() throws Exception {
        dataset.close();
    }
}
import java.io.Serializable;

/**
 * Partition context of the LSQR algorithm: holds the partition-local slice of the u vector
 * between iterations. The slice is {@code null} until the solver initializes it.
 */
public class LSQRPartitionContext implements Serializable {
    /** */
    private static final long serialVersionUID = -8159608186899430315L;

    /** Part of U vector. */
    private double[] u;

    /** Returns the partition's slice of the u vector (may be {@code null} before initialization). */
    public double[] getU() {
        return u;
    }

    /** Stores the partition's slice of the u vector. */
    public void setU(double[] u) {
        this.u = u;
    }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.isolve.lsqr; + +import java.util.Arrays; +import org.apache.ignite.ml.math.isolve.IterativeSolverResult; + +/** + * LSQR iterative solver result. + */ +public class LSQRResult extends IterativeSolverResult { + /** */ + private static final long serialVersionUID = -8866269808589635947L; + + /** + * Gives the reason for termination. 1 means x is an approximate solution to Ax = b. 2 means x approximately solves + * the least-squares problem. + */ + private final int isstop; + + /** Represents norm(r), where r = b - Ax. */ + private final double r1norn; + + /**Represents sqrt( norm(r)^2 + damp^2 * norm(x)^2 ). Equal to r1norm if damp == 0. */ + private final double r2norm; + + /** Estimate of Frobenius norm of Abar = [[A]; [damp*I]]. */ + private final double anorm; + + /** Estimate of cond(Abar). */ + private final double acond; + + /** Estimate of norm(A'*r - damp^2*x). */ + private final double arnorm; + + /** Represents norm(x). */ + private final double xnorm; + + /** + * If calc_var is True, estimates all diagonals of (A'A)^{-1} (if damp == 0) or more generally + * (A'A + damp^2*I)^{-1}. This is well defined if A has full column rank or damp > 0. + */ + private final double[] var; + + /** + * Constructs a new instance of LSQR result. + * + * @param x X value. + * @param iterations Number of performed iterations. + * @param isstop Stop reason. + * @param r1norn R1 norm value. + * @param r2norm R2 norm value. + * @param anorm A norm value. + * @param acond A cond value. + * @param arnorm AR norm value. 
+ * @param xnorm X norm value. + * @param var Var value. + */ + public LSQRResult(double[] x, int iterations, int isstop, double r1norn, double r2norm, double anorm, double acond, + double arnorm, double xnorm, double[] var) { + super(x, iterations); + this.isstop = isstop; + this.r1norn = r1norn; + this.r2norm = r2norm; + this.anorm = anorm; + this.acond = acond; + this.arnorm = arnorm; + this.xnorm = xnorm; + this.var = var; + } + + /** */ + public int getIsstop() { + return isstop; + } + + /** */ + public double getR1norn() { + return r1norn; + } + + /** */ + public double getR2norm() { + return r2norm; + } + + /** */ + public double getAnorm() { + return anorm; + } + + /** */ + public double getAcond() { + return acond; + } + + /** */ + public double getArnorm() { + return arnorm; + } + + /** */ + public double getXnorm() { + return xnorm; + } + + /** */ + public double[] getVar() { + return var; + } + + /** */ + @Override public String toString() { + return "LSQRResult{" + + "isstop=" + isstop + + ", r1norn=" + r1norn + + ", r2norm=" + r2norm + + ", anorm=" + anorm + + ", acond=" + acond + + ", arnorm=" + arnorm + + ", xnorm=" + xnorm + + ", var=" + Arrays.toString(var) + + '}'; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java new file mode 100644 index 0000000..a667eb7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * <!-- Package description. -->
 * Contains LSQR algorithm implementation for the iterative solution of linear least-squares problems.
 */
package org.apache.ignite.ml.math.isolve.lsqr;
/**
 * <!-- Package description. -->
 * Contains iterative algorithms for solving linear systems, together with their shared result
 * and partition-data classes.
 */
package org.apache.ignite.ml.math.isolve;
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import java.util.Arrays; +import org.apache.ignite.ml.DatasetTrainer; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR; +import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap; +import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult; + +/** + * Trainer of the linear regression model based on LSQR algorithm. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * + * @see AbstractLSQR + */ +public class LinearRegressionLSQRTrainer<K, V> implements DatasetTrainer<K, V, LinearRegressionModel> { + /** {@inheritDoc} */ + @Override public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) { + + LSQRResult res; + + try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>( + datasetBuilder, + new LinSysPartitionDataBuilderOnHeap<>( + (k, v) -> { + double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1); + + row[cols] = 1.0; + + return row; + }, + lbExtractor, + cols + 1 + ) + )) { + res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + } + catch (Exception e) { + throw new RuntimeException(e); + } + + Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(res.getX(), 0, cols)); + + return new LinearRegressionModel(weights, res.getX()[cols]); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java index 5efdf57..b4f83d9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java @@ -20,6 +20,8 @@ package org.apache.ignite.ml.trainers; import org.apache.ignite.ml.Model; /** Trainer interface. */ +@Deprecated +// TODO: IGNITE-7659: Reduce multiple Trainer interfaces to one public interface Trainer<M extends Model, T> { /** * Train the model based on provided data. http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java index bb41239..926d872 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java @@ -61,6 +61,7 @@ import org.apache.ignite.ml.math.impls.vector.VectorIterableTest; import org.apache.ignite.ml.math.impls.vector.VectorNormTest; import org.apache.ignite.ml.math.impls.vector.VectorToMatrixTest; import org.apache.ignite.ml.math.impls.vector.VectorViewTest; +import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeapTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -120,7 +121,8 @@ import org.junit.runners.Suite; QRDecompositionTest.class, SingularValueDecompositionTest.class, QRDSolverTest.class, - DistanceTest.class + DistanceTest.class, + LSQROnHeapTest.class }) public class MathImplLocalTestSuite { // No-op.