Repository: ignite
Updated Branches:
  refs/heads/master a8170f78a -> 6dbcf7e34


IGNITE-7897: Add example for LSQR with data normalization.

This closes #3614
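
For context: LSQR is an iterative solver for the least-squares problem
min ||Ax - b|| that never forms A^T A explicitly, which makes it a good fit
for partition-based distributed training. The normalization step rescales
every feature into a common range before the trainer runs, so features with
large raw magnitudes do not dominate the solution. Below is a minimal sketch
of the min-max scaling this preprocessor is expected to apply (hypothetical
helper, not code from this commit):

    // Hypothetical helper: scale each feature of a row into [0, 1] given
    // per-feature minima and maxima computed over the whole dataset
    // (assumes max[i] > min[i] for every feature).
    static double[] minMaxScale(double[] row, double[] min, double[] max) {
        double[] res = new double[row.length];
        for (int i = 0; i < row.length; i++)
            res[i] = (row[i] - min[i]) / (max[i] - min[i]);
        return res;
    }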


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/6dbcf7e3
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/6dbcf7e3
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/6dbcf7e3

Branch: refs/heads/master
Commit: 6dbcf7e34ec3b1d3cd80fc83ab53a9959c3b1fe1
Parents: a8170f7
Author: Anton Dmitriev <dmitrievanth...@gmail.com>
Authored: Wed Mar 7 15:11:53 2018 +0300
Committer: Yury Babak <yba...@gridgain.com>
Committed: Wed Mar 7 15:11:53 2018 +0300

----------------------------------------------------------------------
 ...nWithLSQRTrainerAndNormalizationExample.java | 184 +++++++++++++++++++
 1 file changed, 184 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/6dbcf7e3/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
new file mode 100644
index 0000000..61195c4
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.linear;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression model over distributed matrix.
+ *
+ * @see LinearRegressionLSQRTrainer
+ * @see NormalizationTrainer
+ * @see NormalizationPreprocessor
+ */
+public class DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample {
+    /** */
+    private static final double[][] data = {
+        {8, 78, 284, 9.100000381, 109},
+        {9.300000191, 68, 433, 8.699999809, 144},
+        {7.5, 70, 739, 7.199999809, 113},
+        {8.899999619, 96, 1792, 8.899999619, 97},
+        {10.19999981, 74, 477, 8.300000191, 206},
+        {8.300000191, 111, 362, 10.89999962, 124},
+        {8.800000191, 77, 671, 10, 152},
+        {8.800000191, 168, 636, 9.100000381, 162},
+        {10.69999981, 82, 329, 8.699999809, 150},
+        {11.69999981, 89, 634, 7.599999905, 134},
+        {8.5, 149, 631, 10.80000019, 292},
+        {8.300000191, 60, 257, 9.5, 108},
+        {8.199999809, 96, 284, 8.800000191, 111},
+        {7.900000095, 83, 603, 9.5, 182},
+        {10.30000019, 130, 686, 8.699999809, 129},
+        {7.400000095, 145, 345, 11.19999981, 158},
+        {9.600000381, 112, 1357, 9.699999809, 186},
+        {9.300000191, 131, 544, 9.600000381, 177},
+        {10.60000038, 80, 205, 9.100000381, 127},
+        {9.699999809, 130, 1264, 9.199999809, 179},
+        {11.60000038, 140, 688, 8.300000191, 80},
+        {8.100000381, 154, 354, 8.399999619, 103},
+        {9.800000191, 118, 1632, 9.399999619, 101},
+        {7.400000095, 94, 348, 9.800000191, 117},
+        {9.399999619, 119, 370, 10.39999962, 88},
+        {11.19999981, 153, 648, 9.899999619, 78},
+        {9.100000381, 116, 366, 9.199999809, 102},
+        {10.5, 97, 540, 10.30000019, 95},
+        {11.89999962, 176, 680, 8.899999619, 80},
+        {8.399999619, 75, 345, 9.600000381, 92},
+        {5, 134, 525, 10.30000019, 126},
+        {9.800000191, 161, 870, 10.39999962, 108},
+        {9.800000191, 111, 669, 9.699999809, 77},
+        {10.80000019, 114, 452, 9.600000381, 60},
+        {10.10000038, 142, 430, 10.69999981, 71},
+        {10.89999962, 238, 822, 10.30000019, 86},
+        {9.199999809, 78, 190, 10.69999981, 93},
+        {8.300000191, 196, 867, 9.600000381, 106},
+        {7.300000191, 125, 969, 10.5, 162},
+        {9.399999619, 82, 499, 7.699999809, 95},
+        {9.399999619, 125, 925, 10.19999981, 91},
+        {9.800000191, 129, 353, 9.899999619, 52},
+        {3.599999905, 84, 288, 8.399999619, 110},
+        {8.399999619, 183, 718, 10.39999962, 69},
+        {10.80000019, 119, 540, 9.199999809, 57},
+        {10.10000038, 180, 668, 13, 106},
+        {9, 82, 347, 8.800000191, 40},
+        {10, 71, 345, 9.199999809, 50},
+        {11.30000019, 118, 463, 7.800000191, 35},
+        {11.30000019, 121, 728, 8.199999809, 86},
+        {12.80000019, 68, 383, 7.400000095, 57},
+        {10, 112, 316, 10.39999962, 57},
+        {6.699999809, 109, 388, 8.899999619, 94}
+    };
+
+    /** Run example. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression model over sparse 
distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // Create IgniteThread; we must run the training inside an IgniteThread
+            // because we create an Ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+                System.out.println(">>> Create new normalization trainer 
object.");
+                NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>();
+
+                System.out.println(">>> Perform the training to get the 
normalization preprocessor.");
+                NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit(
+                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
+                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+                    4
+                );
+
+                System.out.println(">>> Create new linear regression trainer 
object.");
+                LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+
+                System.out.println(">>> Perform the training to get the 
model.");
+                LinearRegressionModel mdl = trainer.fit(
+                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
+                    preprocessor,
+                    (k, v) -> v[0],
+                    4
+                );
+
+                System.out.println(">>> Linear regression model: " + mdl);
+
+                System.out.println(">>> ---------------------------------");
+                System.out.println(">>> | Prediction\t| Ground Truth\t|");
+                System.out.println(">>> ---------------------------------");
+
+                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+                    for (Cache.Entry<Integer, double[]> observation : observations) {
+                        Integer key = observation.getKey();
+                        double[] val = observation.getValue();
+                        double groundTruth = val[0];
+
+                        double prediction = mdl.apply(new DenseLocalOnHeapVector(preprocessor.apply(key, val)));
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+                    }
+                }
+
+                System.out.println(">>> ---------------------------------");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Fills cache with data and returns it.
+     *
+     * @param ignite Ignite instance.
+     * @return Filled Ignite Cache.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+        for (int i = 0; i < data.length; i++)
+            cache.put(i, data[i]);
+
+        return cache;
+    }
+}
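
Note: once trained, the model is a plain function from a feature vector to a
predicted value, so it can also be applied outside the scan-query loop above.
A minimal sketch, assuming the preprocessor and mdl variables from the example
are in scope (the row values below are illustrative only):

    // Predict for a single new observation. The row layout matches the
    // dataset above: {label, f1, f2, f3, f4}; the label slot is ignored by
    // the feature extractor wrapped inside the preprocessor.
    double[] row = {0.0, 100, 500, 9.0, 120};
    double[] features = preprocessor.apply(0, row); // the key is arbitrary here
    double prediction = mdl.apply(new DenseLocalOnHeapVector(features));
    System.out.println(">>> Prediction for the new observation: " + prediction);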
