IGNITE-9711: [ML] Remove IgniteThread wrapper from ML examples. This closes #4849.
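Summary of the change (not part of the patch): every example touched below gets the same rewrite — the body that used to be wrapped in an IgniteThread (constructed, started and joined explicitly) now runs directly on the calling thread inside the try-with-resources block that starts the Ignite node. A minimal sketch of the before/after pattern follows; the class name PatternSketch is hypothetical, the config path is the one used by the examples in the hunks below, and the per-example training/evaluation code is elided — see the individual hunks for the real bodies.

    import org.apache.ignite.Ignite;
    import org.apache.ignite.Ignition;

    public class PatternSketch {
        public static void main(String[] args) {
            // Old pattern (removed by this commit): wrap the example body in an IgniteThread.
            //
            //   IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
            //       SomeExample.class.getSimpleName(), () -> { /* build cache, fit trainer, evaluate */ });
            //   igniteThread.start();
            //   igniteThread.join();
            //
            // New pattern (used by all 32 examples): run the body directly on the caller thread.
            try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
                System.out.println(">>> Ignite grid started.");

                // Build the data cache, fit the trainer and evaluate the model right here;
                // no IgniteThread wrapper, start() or join() is needed any more.
            }
        }
    }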
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/609266fe Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/609266fe Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/609266fe Branch: refs/heads/master Commit: 609266fe2797c07599a893625f933740a25d049d Parents: c7227cf Author: YuriBabak <[email protected]> Authored: Fri Sep 28 11:57:58 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Sep 28 11:57:58 2018 +0300 ---------------------------------------------------------------------- .../clustering/KMeansClusterizationExample.java | 75 +++-- .../ml/knn/ANNClassificationExample.java | 97 +++--- .../ml/knn/KNNClassificationExample.java | 69 ++--- .../examples/ml/knn/KNNRegressionExample.java | 79 +++-- .../examples/ml/nn/MLPTrainerExample.java | 130 ++++---- .../LinearRegressionLSQRTrainerExample.java | 56 ++-- ...ssionLSQRTrainerWithMinMaxScalerExample.java | 69 ++--- .../LinearRegressionSGDTrainerExample.java | 65 ++-- .../LogisticRegressionSGDTrainerExample.java | 84 +++--- ...gressionMultiClassClassificationExample.java | 169 +++++------ .../ml/selection/cv/CrossValidationExample.java | 58 ++-- .../split/TrainTestDatasetSplitterExample.java | 69 ++--- .../binary/SVMBinaryClassificationExample.java | 79 +++-- .../SVMMultiClassClassificationExample.java | 151 +++++----- ...ecisionTreeClassificationTrainerExample.java | 74 ++--- .../DecisionTreeRegressionTrainerExample.java | 63 ++-- .../GDBOnTreesClassificationTrainerExample.java | 58 ++-- .../GDBOnTreesRegressionTrainerExample.java | 55 ++-- .../RandomForestClassificationExample.java | 76 +++-- .../RandomForestRegressionExample.java | 91 +++--- .../ml/tutorial/Step_1_Read_and_Learn.java | 61 ++-- .../examples/ml/tutorial/Step_2_Imputing.java | 71 ++--- .../examples/ml/tutorial/Step_3_Categorial.java | 96 +++--- .../Step_3_Categorial_with_One_Hot_Encoder.java | 98 +++--- .../ml/tutorial/Step_4_Add_age_fare.java | 98 +++--- .../examples/ml/tutorial/Step_5_Scaling.java | 125 ++++---- .../tutorial/Step_5_Scaling_with_Pipeline.java | 77 +++-- .../ignite/examples/ml/tutorial/Step_6_KNN.java | 127 ++++---- .../ml/tutorial/Step_7_Split_train_test.java | 136 ++++----- .../ignite/examples/ml/tutorial/Step_8_CV.java | 218 +++++++------- .../ml/tutorial/Step_8_CV_with_Param_Grid.java | 200 ++++++------- .../ml/tutorial/Step_9_Go_to_LogReg.java | 296 +++++++++---------- 32 files changed, 1507 insertions(+), 1763 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java index 152375a..567775b 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java @@ -30,7 +30,6 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; import org.apache.ignite.ml.math.Tracer; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.thread.IgniteThread; /** * 
Run KMeans clustering algorithm ({@link KMeansTrainer}) over distributed dataset. @@ -55,58 +54,52 @@ public class KMeansClusterizationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - KMeansClusterizationExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - KMeansTrainer trainer = new KMeansTrainer() - .withSeed(7867L); + KMeansTrainer trainer = new KMeansTrainer() + .withSeed(7867L); - KMeansModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + KMeansModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> KMeans centroids"); - Tracer.showAscii(mdl.getCenters()[0]); - Tracer.showAscii(mdl.getCenters()[1]); - System.out.println(">>>"); + System.out.println(">>> KMeans centroids"); + Tracer.showAscii(mdl.getCenters()[0]); + Tracer.showAscii(mdl.getCenters()[1]); + System.out.println(">>>"); - System.out.println(">>> -----------------------------------"); - System.out.println(">>> | Predicted cluster\t| Real Label\t|"); - System.out.println(">>> -----------------------------------"); + System.out.println(">>> -----------------------------------"); + System.out.println(">>> | Predicted cluster\t| Real Label\t|"); + System.out.println(">>> -----------------------------------"); - int amountOfErrors = 0; - int totalAmount = 0; + int amountOfErrors = 0; + int totalAmount = 0; - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - double prediction = mdl.apply(new DenseVector(inputs)); + double prediction = mdl.apply(new DenseVector(inputs)); - totalAmount++; - if (groundTruth != prediction) - amountOfErrors++; + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; - System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - - System.out.println(">>> ---------------------------------"); + System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth); + } - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.println(">>> ---------------------------------"); - System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed."); - } - }); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> KMeans clustering algorithm over cached dataset usage example 
completed."); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java index 8a2d786..c9490fc 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java @@ -34,7 +34,6 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.distances.ManhattanDistance; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.thread.IgniteThread; /** * Run ANN multi-class classification trainer ({@link ANNClassificationTrainer}) over distributed dataset. @@ -59,73 +58,67 @@ public class ANNClassificationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - ANNClassificationExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - ANNClassificationTrainer trainer = new ANNClassificationTrainer() - .withDistance(new ManhattanDistance()) - .withK(50) - .withMaxIterations(1000) - .withSeed(1234L) - .withEpsilon(1e-2); + ANNClassificationTrainer trainer = new ANNClassificationTrainer() + .withDistance(new ManhattanDistance()) + .withK(50) + .withMaxIterations(1000) + .withSeed(1234L) + .withEpsilon(1e-2); - long startTrainingTime = System.currentTimeMillis(); + long startTrainingTime = System.currentTimeMillis(); - NNClassificationModel knnMdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ).withK(5) - .withDistanceMeasure(new EuclideanDistance()) - .withStrategy(NNStrategy.WEIGHTED); + NNClassificationModel knnMdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ).withK(5) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.WEIGHTED); - long endTrainingTime = System.currentTimeMillis(); + long endTrainingTime = System.currentTimeMillis(); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - int amountOfErrors = 0; - int totalAmount = 0; + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - long totalPredictionTime = 0L; + int amountOfErrors = 0; + int totalAmount = 0; - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + long totalPredictionTime = 0L; - long startPredictionTime = System.currentTimeMillis(); - double 
prediction = knnMdl.apply(new DenseVector(inputs)); - long endPredictionTime = System.currentTimeMillis(); + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - totalPredictionTime += (endPredictionTime - startPredictionTime); + long startPredictionTime = System.currentTimeMillis(); + double prediction = knnMdl.apply(new DenseVector(inputs)); + long endPredictionTime = System.currentTimeMillis(); - totalAmount++; - if (groundTruth != prediction) - amountOfErrors++; + totalPredictionTime += (endPredictionTime - startPredictionTime); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; - System.out.println(">>> ---------------------------------"); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } - System.out.println("Training costs = " + (endTrainingTime - startTrainingTime)); - System.out.println("Prediction costs = " + totalPredictionTime); + System.out.println(">>> ---------------------------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount)); - System.out.println(totalAmount); + System.out.println("Training costs = " + (endTrainingTime - startTrainingTime)); + System.out.println("Prediction costs = " + totalPredictionTime); - System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed."); - } - }); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount)); + System.out.println(totalAmount); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed."); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java index cf285a4..5cbb2ad 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java @@ -31,7 +31,6 @@ import org.apache.ignite.ml.knn.classification.NNStrategy; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.thread.IgniteThread; /** * Run kNN multi-class classification trainer ({@link KNNClassificationTrainer}) over distributed dataset. 
@@ -56,54 +55,48 @@ public class KNNClassificationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - KNNClassificationExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - KNNClassificationTrainer trainer = new KNNClassificationTrainer(); + KNNClassificationTrainer trainer = new KNNClassificationTrainer(); - NNClassificationModel knnMdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ).withK(3) - .withDistanceMeasure(new EuclideanDistance()) - .withStrategy(NNStrategy.WEIGHTED); + NNClassificationModel knnMdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ).withK(3) + .withDistanceMeasure(new EuclideanDistance()) + .withStrategy(NNStrategy.WEIGHTED); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - int amountOfErrors = 0; - int totalAmount = 0; + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + int amountOfErrors = 0; + int totalAmount = 0; - double prediction = knnMdl.apply(new DenseVector(inputs)); + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - totalAmount++; - if (groundTruth != prediction) - amountOfErrors++; + double prediction = knnMdl.apply(new DenseVector(inputs)); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; - System.out.println(">>> ---------------------------------"); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount)); + System.out.println(">>> ---------------------------------"); - System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed."); - } - }); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount)); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed."); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java 
---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java index 78f38c8..3969f0c 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java @@ -31,7 +31,6 @@ import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer; import org.apache.ignite.ml.math.distances.ManhattanDistance; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.thread.IgniteThread; /** * Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset. @@ -57,61 +56,55 @@ public class KNNRegressionExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - KNNRegressionExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - KNNRegressionTrainer trainer = new KNNRegressionTrainer(); + KNNRegressionTrainer trainer = new KNNRegressionTrainer(); - KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ).withK(5) - .withDistanceMeasure(new ManhattanDistance()) - .withStrategy(NNStrategy.WEIGHTED); + KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ).withK(5) + .withDistanceMeasure(new ManhattanDistance()) + .withStrategy(NNStrategy.WEIGHTED); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - int totalAmount = 0; - // Calculate mean squared error (MSE) - double mse = 0.0; - // Calculate mean absolute error (MAE) - double mae = 0.0; + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + int totalAmount = 0; + // Calculate mean squared error (MSE) + double mse = 0.0; + // Calculate mean absolute error (MAE) + double mae = 0.0; - double prediction = knnMdl.apply(new DenseVector(inputs)); + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - mse += Math.pow(prediction - groundTruth, 2.0); - mae += Math.abs(prediction - groundTruth); + double prediction = knnMdl.apply(new DenseVector(inputs)); - totalAmount++; + mse += 
Math.pow(prediction - groundTruth, 2.0); + mae += Math.abs(prediction - groundTruth); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + totalAmount++; - System.out.println(">>> ---------------------------------"); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } - mse = mse / totalAmount; - System.out.println("\n>>> Mean squared error (MSE) " + mse); + System.out.println(">>> ---------------------------------"); - mae = mae / totalAmount; - System.out.println("\n>>> Mean absolute error (MAE) " + mae); + mse = mse / totalAmount; + System.out.println("\n>>> Mean squared error (MSE) " + mse); - System.out.println(">>> kNN regression over cached dataset usage example completed."); - } - }); + mae = mae / totalAmount; + System.out.println("\n>>> Mean absolute error (MAE) " + mae); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> kNN regression over cached dataset usage example completed."); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java index 3e5a98c..6d5745e 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java @@ -34,7 +34,6 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.thread.IgniteThread; /** * Example of using distributed {@link MultilayerPerceptron}. @@ -70,76 +69,65 @@ public class MLPTrainerExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - MLPTrainerExample.class.getSimpleName(), () -> { - - // Create cache with training data. - CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>(); - trainingSetCfg.setName("TRAINING_SET"); - trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg); - - // Fill cache with training data. - trainingSet.put(0, new LabeledPoint(0, 0, 0)); - trainingSet.put(1, new LabeledPoint(0, 1, 1)); - trainingSet.put(2, new LabeledPoint(1, 0, 1)); - trainingSet.put(3, new LabeledPoint(1, 1, 0)); - - // Define a layered architecture. - MLPArchitecture arch = new MLPArchitecture(2). - withAddedLayer(10, true, Activators.RELU). - withAddedLayer(1, false, Activators.SIGMOID); - - // Define a neural network trainer. 
- MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>( - arch, - LossFunctions.MSE, - new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.1), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), - 3000, - 4, - 50, - 123L - ); - - // Train neural network and get multilayer perceptron model. - MultilayerPerceptron mlp = trainer.fit( - ignite, - trainingSet, - (k, v) -> VectorUtils.of(v.x, v.y), - (k, v) -> new double[] {v.lb} - ); - - int totalCnt = 4; - int failCnt = 0; - - // Calculate score. - for (int i = 0; i < 4; i++) { - LabeledPoint pnt = trainingSet.get(i); - Matrix predicted = mlp.apply(new DenseMatrix(new double[][] {{pnt.x, pnt.y}})); - - double predictedVal = predicted.get(0, 0); - double lbl = pnt.lb; - System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl); - failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 0 : 1; - } - - double failRatio = (double)failCnt / totalCnt; - - System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%."); - - System.out.println("\n>>> Distributed multilayer perceptron example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + // Create cache with training data. + CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg); + + // Fill cache with training data. + trainingSet.put(0, new LabeledPoint(0, 0, 0)); + trainingSet.put(1, new LabeledPoint(0, 1, 1)); + trainingSet.put(2, new LabeledPoint(1, 0, 1)); + trainingSet.put(3, new LabeledPoint(1, 1, 0)); + + // Define a layered architecture. + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + // Define a neural network trainer. + MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>( + arch, + LossFunctions.MSE, + new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.1), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ), + 3000, + 4, + 50, + 123L + ); + + // Train neural network and get multilayer perceptron model. + MultilayerPerceptron mlp = trainer.fit( + ignite, + trainingSet, + (k, v) -> VectorUtils.of(v.x, v.y), + (k, v) -> new double[] {v.lb} + ); + + int totalCnt = 4; + int failCnt = 0; + + // Calculate score. + for (int i = 0; i < 4; i++) { + LabeledPoint pnt = trainingSet.get(i); + Matrix predicted = mlp.apply(new DenseMatrix(new double[][] {{pnt.x, pnt.y}})); + + double predictedVal = predicted.get(0, 0); + double lbl = pnt.lb; + System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl); + failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 
0 : 1; + } + + double failRatio = (double)failCnt / totalCnt; + + System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%."); + System.out.println("\n>>> Distributed multilayer perceptron example completed."); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java index 6ac445c..862a37f 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java @@ -29,7 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.thread.IgniteThread; /** * Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a> @@ -110,47 +109,40 @@ public class LinearRegressionLSQRTrainerExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LinearRegressionLSQRTrainerExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> Linear regression model: " + mdl); + System.out.println(">>> Linear regression model: " + mdl); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, 
double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - double prediction = mdl.apply(new DenseVector(inputs)); + double prediction = mdl.apply(new DenseVector(inputs)); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } + } - System.out.println(">>> ---------------------------------"); + System.out.println(">>> ---------------------------------"); - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java index 320d464..5692cb3 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java @@ -32,7 +32,6 @@ import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.thread.IgniteThread; /** * Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a> @@ -116,55 +115,47 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LinearRegressionLSQRTrainerWithMinMaxScalerExample.class.getSimpleName(), () -> { - IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data); + IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data); - System.out.println(">>> Create new minmaxscaling trainer object."); - MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>(); + System.out.println(">>> Create new minmaxscaling trainer object."); + MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>(); - System.out.println(">>> Perform the training to get the minmaxscaling preprocessor."); - IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit( - ignite, - dataCache, - (k, v) -> { - double[] arr = v.asArray(); - return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); - } - ); + System.out.println(">>> Perform the training to get the minmaxscaling preprocessor."); + 
IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> { + double[] arr = v.asArray(); + return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); + } + ); - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0)); + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0)); - System.out.println(">>> Linear regression model: " + mdl); + System.out.println(">>> Linear regression model: " + mdl); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, Vector> observation : observations) { - Integer key = observation.getKey(); - Vector val = observation.getValue(); - double groundTruth = val.get(0); + try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Vector> observation : observations) { + Integer key = observation.getKey(); + Vector val = observation.getValue(); + double groundTruth = val.get(0); - double prediction = mdl.apply(preprocessor.apply(key, val)); + double prediction = mdl.apply(preprocessor.apply(key, val)); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } + } - System.out.println(">>> ---------------------------------"); - - System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cache based dataset usage example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cache based dataset usage example completed."); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java index 9fdc0df..1e9bd5a 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java @@ -32,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import 
org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; -import org.apache.ignite.thread.IgniteThread; /** * Run linear regression model based on based on @@ -114,52 +113,44 @@ public class LinearRegressionSGDTrainerExample { // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LinearRegressionSGDTrainerExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( - new RPropUpdateCalculator(), - RPropParameterUpdate::sumLocal, - RPropParameterUpdate::avg - ), 100000, 10, 100, 123L); + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( + new RPropUpdateCalculator(), + RPropParameterUpdate::sumLocal, + RPropParameterUpdate::avg + ), 100000, 10, 100, 123L); - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> Linear regression model: " + mdl); + System.out.println(">>> Linear regression model: " + mdl); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - double prediction = mdl.apply(new DenseVector(inputs)); + double prediction = mdl.apply(new DenseVector(inputs)); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } + } - System.out.println(">>> ---------------------------------"); - - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + System.out.println(">>> 
---------------------------------"); + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java index 0a6ff01..8d4218d 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java @@ -32,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; -import org.apache.ignite.thread.IgniteThread; /** * Run logistic regression model based on <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent"> @@ -57,69 +56,62 @@ public class LogisticRegressionSGDTrainerExample { // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LogisticRegressionSGDTrainerExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - System.out.println(">>> Create new logistic regression trainer object."); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); + System.out.println(">>> Create new logistic regression trainer object."); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ), 100000, 10, 100, 123L); - System.out.println(">>> Perform the training to get the model."); - LogisticRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + System.out.println(">>> Perform the training to get the model."); + LogisticRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> Logistic regression model: " + mdl); + System.out.println(">>> Logistic regression model: " + mdl); - int amountOfErrors = 0; - int totalAmount = 0; + int amountOfErrors = 0; + int totalAmount = 0; - // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0}, {0, 0}}; + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - double prediction = mdl.apply(new DenseVector(inputs)); + double prediction = mdl.apply(new DenseVector(inputs)); - totalAmount++; - if(groundTruth != prediction) - amountOfErrors++; + totalAmount++; + if(groundTruth != prediction) + amountOfErrors++; - int idx1 = (int)prediction; - int idx2 = (int)groundTruth; + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; - confusionMtx[idx1][idx2]++; + confusionMtx[idx1][idx2]++; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - - System.out.println(">>> ---------------------------------"); - - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); System.out.println(">>> ---------------------------------"); - System.out.println(">>> Logistic regression model over partitioned dataset usage example completed."); - }); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + } - igniteThread.start(); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); - igniteThread.join(); + System.out.println(">>> Logistic regression model over partitioned dataset usage example completed."); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java index e670f01..ff2761a 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java @@ -35,7 +35,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalcula import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; -import org.apache.ignite.thread.IgniteThread; /** * Run Logistic Regression 
multi-class classification trainer ({@link LogRegressionMultiClassModel}) over distributed @@ -62,115 +61,109 @@ public class LogRegressionMultiClassClassificationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LogRegressionMultiClassClassificationExample.class.getSimpleName(), () -> { - IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data); + IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data); - LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - )) - .withAmountOfIterations(100000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); + LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + )) + .withAmountOfIterations(100000) + .withAmountOfLocIterations(10) + .withBatchSize(100) + .withSeed(123L); - LogRegressionMultiClassModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> { - double[] arr = v.asArray(); - return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); - }, - (k, v) -> v.get(0) - ); + LogRegressionMultiClassModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> { + double[] arr = v.asArray(); + return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); + }, + (k, v) -> v.get(0) + ); - System.out.println(">>> SVM Multi-class model"); - System.out.println(mdl.toString()); + System.out.println(">>> SVM Multi-class model"); + System.out.println(mdl.toString()); - MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>(); + MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>(); - IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit( - ignite, - dataCache, - (k, v) -> { - double[] arr = v.asArray(); - return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); - } - ); - - LogRegressionMultiClassModel mdlWithNormalization = trainer.fit( - ignite, - dataCache, - preprocessor, - (k, v) -> v.get(0) - ); + IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> { + double[] arr = v.asArray(); + return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length)); + } + ); - System.out.println(">>> Logistic Regression Multi-class model with minmaxscaling"); - System.out.println(mdlWithNormalization.toString()); + LogRegressionMultiClassModel mdlWithNormalization = trainer.fit( + ignite, + dataCache, + preprocessor, + (k, v) -> v.get(0) + ); - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|"); - System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> Logistic Regression Multi-class model with minmaxscaling"); + System.out.println(mdlWithNormalization.toString()); - int amountOfErrors = 0; - int amountOfErrorsWithNormalization = 0; - int totalAmount = 0; + System.out.println(">>> 
----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|"); + System.out.println(">>> ----------------------------------------------------------------"); - // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + int amountOfErrors = 0; + int amountOfErrorsWithNormalization = 0; + int totalAmount = 0; - try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, Vector> observation : observations) { - double[] val = observation.getValue().asArray(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; - double prediction = mdl.apply(new DenseVector(inputs)); - double predictionWithNormalization = mdlWithNormalization.apply(new DenseVector(inputs)); + try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Vector> observation : observations) { + double[] val = observation.getValue().asArray(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - totalAmount++; + double prediction = mdl.apply(new DenseVector(inputs)); + double predictionWithNormalization = mdlWithNormalization.apply(new DenseVector(inputs)); - // Collect data for model - if(groundTruth != prediction) - amountOfErrors++; + totalAmount++; - int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); - int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); + // Collect data for model + if(groundTruth != prediction) + amountOfErrors++; - confusionMtx[idx1][idx2]++; + int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); + int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); - // Collect data for model with minmaxscaling - if(groundTruth != predictionWithNormalization) - amountOfErrorsWithNormalization++; + confusionMtx[idx1][idx2]++; - idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2); - idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); + // Collect data for model with minmaxscaling + if(groundTruth != predictionWithNormalization) + amountOfErrorsWithNormalization++; - confusionMtxWithNormalization[idx1][idx2]++; + idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2); + idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 
1 : 2); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth); - } - System.out.println(">>> ----------------------------------------------------------------"); - System.out.println("\n>>> -----------------Logistic Regression model-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + confusionMtxWithNormalization[idx1][idx2]++; - System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------"); - System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount)); - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization)); - - System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed."); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth); } - }); + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println("\n>>> -----------------Logistic Regression model-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + + System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization)); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed."); + } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java index eb4c8f3..25ce156 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java @@ -24,13 +24,11 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ml.tree.DecisionTreeClassificationTrainerExample; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import 
org.apache.ignite.ml.tree.DecisionTreeNode; -import org.apache.ignite.thread.IgniteThread; /** * Run <a href="https://en.wikipedia.org/wiki/Decision_tree">decision tree</a> classification with @@ -54,46 +52,38 @@ public class CrossValidationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> { + // Create cache with training data. + CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); - // Create cache with training data. - CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>(); - trainingSetCfg.setName("TRAINING_SET"); - trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg); - IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg); + Random rnd = new Random(0); - Random rnd = new Random(0); + // Fill training data. + for (int i = 0; i < 1000; i++) + trainingSet.put(i, generatePoint(rnd)); - // Fill training data. - for (int i = 0; i < 1000; i++) - trainingSet.put(i, generatePoint(rnd)); + // Create classification trainer. + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); - // Create classification trainer. - DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); + CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator + = new CrossValidation<>(); - CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator - = new CrossValidation<>(); + double[] scores = scoreCalculator.score( + trainer, + new Accuracy<>(), + ignite, + trainingSet, + (k, v) -> VectorUtils.of(v.x, v.y), + (k, v) -> v.lb, + 4 + ); - double[] scores = scoreCalculator.score( - trainer, - new Accuracy<>(), - ignite, - trainingSet, - (k, v) -> VectorUtils.of(v.x, v.y), - (k, v) -> v.lb, - 4 - ); + System.out.println(">>> Accuracy: " + Arrays.toString(scores)); - System.out.println(">>> Accuracy: " + Arrays.toString(scores)); - - System.out.println(">>> Cross validation score calculator example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + System.out.println(">>> Cross validation score calculator example completed."); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java index fa1c2ca..8b104f5 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java @@ -31,7 +31,6 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import 
org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; import org.apache.ignite.ml.selection.split.TrainTestSplit; -import org.apache.ignite.thread.IgniteThread; /** * Run linear regression model over dataset split on train and test subsets ({@link TrainTestDatasetSplitter}). @@ -113,55 +112,47 @@ public class TrainTestDatasetSplitterExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - TrainTestDatasetSplitterExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - System.out.println(">>> Create new linear regression trainer object."); - LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); + System.out.println(">>> Create new linear regression trainer object."); + LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - System.out.println(">>> Create new training dataset splitter object."); - TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>() - .split(0.75); + System.out.println(">>> Create new training dataset splitter object."); + TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>() + .split(0.75); - System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - ignite, - dataCache, - split.getTrainFilter(), - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + System.out.println(">>> Perform the training to get the model."); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + split.getTrainFilter(), + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> Linear regression model: " + mdl); + System.out.println(">>> Linear regression model: " + mdl); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - ScanQuery<Integer, double[]> qry = new ScanQuery<>(); - qry.setFilter(split.getTestFilter()); + ScanQuery<Integer, double[]> qry = new ScanQuery<>(); + qry.setFilter(split.getTestFilter()); - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - double prediction = mdl.apply(new DenseVector(inputs)); + double prediction = mdl.apply(new DenseVector(inputs)); - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } + } - System.out.println(">>> 
---------------------------------"); - - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - }); - - igniteThread.start(); - - igniteThread.join(); + System.out.println(">>> ---------------------------------"); + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java index f71db2d..c219441 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java @@ -29,7 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; -import org.apache.ignite.thread.IgniteThread; /** * Run SVM binary-class classification model ({@link SVMLinearBinaryClassificationModel}) over distributed dataset. @@ -54,64 +53,58 @@ public class SVMBinaryClassificationExample { try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SVMBinaryClassificationExample.class.getSimpleName(), () -> { - IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); + IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data); - SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); + SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); - SVMLinearBinaryClassificationModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); + SVMLinearBinaryClassificationModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); - System.out.println(">>> SVM model " + mdl); + System.out.println(">>> SVM model " + mdl); - System.out.println(">>> ---------------------------------"); - System.out.println(">>> | Prediction\t| Ground Truth\t|"); - System.out.println(">>> ---------------------------------"); - - int amountOfErrors = 0; - int totalAmount = 0; + System.out.println(">>> ---------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t|"); + System.out.println(">>> ---------------------------------"); - // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0}, {0, 0}}; + int amountOfErrors = 0; + int totalAmount = 0; - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; + // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; - double prediction = mdl.apply(new DenseVector(inputs)); + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; - totalAmount++; - if(groundTruth != prediction) - amountOfErrors++; + double prediction = mdl.apply(new DenseVector(inputs)); - int idx1 = prediction == 0.0 ? 0 : 1; - int idx2 = groundTruth == 0.0 ? 0 : 1; + totalAmount++; + if(groundTruth != prediction) + amountOfErrors++; - confusionMtx[idx1][idx2]++; + int idx1 = prediction == 0.0 ? 0 : 1; + int idx2 = groundTruth == 0.0 ? 0 : 1; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } + confusionMtx[idx1][idx2]++; - System.out.println(">>> ---------------------------------"); - - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + } - System.out.println(">>> Linear regression model over cache based dataset usage example completed."); - }); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - igniteThread.start(); - igniteThread.join(); + System.out.println(">>> Linear regression model over cache based dataset usage example completed."); } }
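
For reference, every hunk above applies the same refactoring: the training/evaluation body that used to be submitted to an IgniteThread and then joined now runs directly on the caller thread, inside the try-with-resources block that starts the node. Below is a minimal sketch of that pattern, not taken from any single example in this commit; the cache name, the toy data and the LSQR trainer are illustrative placeholders, and each real example keeps its own trainer, dataset and output.

    import java.util.Arrays;
    import org.apache.ignite.Ignite;
    import org.apache.ignite.IgniteCache;
    import org.apache.ignite.Ignition;
    import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
    import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
    import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;

    public class TrainOnCallerThreadSketch {
        public static void main(String[] args) {
            // Before this commit the examples wrapped the body in an IgniteThread:
            //
            //     IgniteThread igniteThread = new IgniteThread(
            //         ignite.configuration().getIgniteInstanceName(),
            //         SomeExample.class.getSimpleName(),
            //         () -> { /* training and evaluation */ });
            //     igniteThread.start();
            //     igniteThread.join();
            //
            // After the change the same body executes directly on the caller thread:
            try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
                // Hypothetical cache with a tiny dataset: label followed by one feature.
                IgniteCache<Integer, double[]> dataCache = ignite.getOrCreateCache("EXAMPLE_DATA");
                dataCache.put(1, new double[] {1.0, 1.0});
                dataCache.put(2, new double[] {2.0, 2.0});
                dataCache.put(3, new double[] {3.0, 3.0});

                LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();

                // Same fit(...) shape the examples use: feature extractor skips the label column,
                // label extractor returns the first element.
                LinearRegressionModel mdl = trainer.fit(
                    ignite,
                    dataCache,
                    (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                    (k, v) -> v[0]
                );

                System.out.println(">>> Model: " + mdl);
            }
        }
    }

Running the body on the caller thread keeps the examples single-threaded and removes the start()/join() boilerplate, which is exactly what the per-file diffs above show.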
