Repository: ignite Updated Branches: refs/heads/master 85525eed4 -> 6bb6b3e51
IGNITE-7397: Fix cache configuration and reduced trainings count in MLP group training test. this closes #3398 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/6bb6b3e5 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/6bb6b3e5 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/6bb6b3e5 Branch: refs/heads/master Commit: 6bb6b3e514899244d2ab813e4bf752a8c47c785d Parents: 85525ee Author: YuriBabak <[email protected]> Authored: Tue Jan 23 13:03:34 2018 +0300 Committer: YuriBabak <[email protected]> Committed: Tue Jan 23 13:03:34 2018 +0300 ---------------------------------------------------------------------- .../examples/ml/nn/MLPGroupTrainerExample.java | 17 ++-- .../examples/ml/nn/MLPLocalTrainerExample.java | 2 +- .../cassandra/serializer/package-info.java | 21 +++++ .../ignite/tests/CassandraSessionImplTest.java | 17 ++++ .../java/org/apache/ignite/IgniteLogger.java | 10 --- .../ml/nn/MLPGroupUpdateTrainerCacheInput.java | 28 +++--- .../distributed/MLPGroupUpdateTrainer.java | 74 ++++++++------- .../distributed/MLPGroupUpdateTrainingData.java | 6 +- .../MLPGroupUpdateTrainingLoopData.java | 6 +- .../trainers/distributed/MLPMetaoptimizer.java | 3 +- .../nn/trainers/local/MLPLocalBatchTrainer.java | 4 +- .../RPropUpdateCalculator.java | 6 +- .../updatecalculators/SimpleGDParameter.java | 77 ---------------- .../SimpleGDParameterUpdate.java | 89 ++++++++++++++++++ .../SimpleGDUpdateCalculator.java | 38 +++++--- .../trainers/group/BaseLocalProcessorJob.java | 3 +- .../group/GroupTrainerBaseProcessorTask.java | 3 +- .../ml/trainers/group/ResultAndUpdates.java | 3 +- .../ml/trainers/group/UpdateStrategies.java | 47 ++++++++++ .../ml/trainers/group/UpdatesStrategy.java | 94 ++++++++++++++++++++ .../ml/trainers/local/LocalBatchTrainer.java | 8 +- .../java/org/apache/ignite/ml/util/Utils.java | 2 +- .../ignite/ml/nn/MLPGroupTrainerTest.java | 38 ++++++-- 
.../ignite/ml/nn/MLPLocalTrainerTest.java | 8 +- .../ml/nn/performance/MnistDistributed.java | 6 +- .../ignite/ml/nn/performance/MnistLocal.java | 4 +- 26 files changed, 427 insertions(+), 187 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java index d106fad..d45e957 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java @@ -40,8 +40,7 @@ import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.thread.IgniteThread; /** - * <p> - * Example of using distributed {@link MultilayerPerceptron}.</p> + * Example of using distributed {@link MultilayerPerceptron}. * <p> * Remote nodes should always be started with special configuration file which * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> @@ -57,7 +56,7 @@ public class MLPGroupTrainerExample { */ public static void main(String[] args) throws InterruptedException { // IMPL NOTE based on MLPGroupTrainerTest#testXOR - System.out.println(">>> Distributed multilayer perceptron example started."); + System.out.println(">>> Distributed multilayer perceptron example started."); // Start ignite grid. 
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { @@ -68,7 +67,7 @@ public class MLPGroupTrainerExample { IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), MLPGroupTrainerExample.class.getSimpleName(), () -> { - int samplesCnt = 1000; + int samplesCnt = 10000; Matrix xorInputs = new DenseLocalOnHeapMatrix( new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}, @@ -88,7 +87,7 @@ public class MLPGroupTrainerExample { try (IgniteDataStreamer<Integer, LabeledVector<Vector, Vector>> streamer = ignite.dataStreamer(cacheName)) { - streamer.perNodeBufferSize(10000); + streamer.perNodeBufferSize(100); for (int i = 0; i < samplesCnt; i++) { int col = Math.abs(rnd.nextInt()) % 4; @@ -99,14 +98,14 @@ public class MLPGroupTrainerExample { int totalCnt = 100; int failCnt = 0; MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite). - withSyncRate(3). + withSyncPeriod(3). withTolerance(0.001). 
- withMaxGlobalSteps(1000); + withMaxGlobalSteps(20); for (int i = 0; i < totalCnt; i++) { MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf, - new RandomInitializer(rnd), 6, cache, 4); + new RandomInitializer(rnd), 6, cache, 10); MultilayerPerceptron mlp = trainer.train(trainerInput); @@ -125,7 +124,7 @@ public class MLPGroupTrainerExample { System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%."); - System.out.println("\n>>> Distributed multilayer perceptron example completed."); + System.out.println("\n>>> Distributed multilayer perceptron example completed."); }); igniteThread.start(); http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java index b557458..02280ce 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java @@ -67,7 +67,7 @@ public class MLPLocalTrainerExample { System.out.println("\n>>> Perform training."); MultilayerPerceptron mlp = new MLPLocalBatchTrainer<>(LossFunctions.MSE, - () -> new RPropUpdateCalculator<>(), + RPropUpdateCalculator::new, 0.0001, 16000).train(trainerInput); http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/cassandra/serializers/src/main/java/org/apache/ignite/cache/store/cassandra/serializer/package-info.java ---------------------------------------------------------------------- diff --git a/modules/cassandra/serializers/src/main/java/org/apache/ignite/cache/store/cassandra/serializer/package-info.java 
b/modules/cassandra/serializers/src/main/java/org/apache/ignite/cache/store/cassandra/serializer/package-info.java new file mode 100644 index 0000000..cad12b7 --- /dev/null +++ b/modules/cassandra/serializers/src/main/java/org/apache/ignite/cache/store/cassandra/serializer/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains Cassandra serializers. 
+ */ +package org.apache.ignite.cache.store.cassandra.serializer; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/cassandra/store/src/test/java/org/apache/ignite/tests/CassandraSessionImplTest.java ---------------------------------------------------------------------- diff --git a/modules/cassandra/store/src/test/java/org/apache/ignite/tests/CassandraSessionImplTest.java b/modules/cassandra/store/src/test/java/org/apache/ignite/tests/CassandraSessionImplTest.java index 9546d46..27fd741 100644 --- a/modules/cassandra/store/src/test/java/org/apache/ignite/tests/CassandraSessionImplTest.java +++ b/modules/cassandra/store/src/test/java/org/apache/ignite/tests/CassandraSessionImplTest.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.ignite.tests; import static org.junit.Assert.assertEquals; http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/core/src/main/java/org/apache/ignite/IgniteLogger.java ---------------------------------------------------------------------- diff --git a/modules/core/src/main/java/org/apache/ignite/IgniteLogger.java b/modules/core/src/main/java/org/apache/ignite/IgniteLogger.java index 0f5e4aa..412e414 100644 --- a/modules/core/src/main/java/org/apache/ignite/IgniteLogger.java +++ b/modules/core/src/main/java/org/apache/ignite/IgniteLogger.java @@ -87,8 +87,6 @@ public interface IgniteLogger { /** * Logs out trace message. - * - * @implSpec * The default implementation calls {@code this.trace(msg)}. * * @param marker Name of the marker to be associated with the message. @@ -107,8 +105,6 @@ public interface IgniteLogger { /** * Logs out debug message. - * - * @implSpec * The default implementation calls {@code this.debug(msg)}. * * @param marker Name of the marker to be associated with the message. @@ -127,8 +123,6 @@ public interface IgniteLogger { /** * Logs out information message. - * - * @implSpec * The default implementation calls {@code this.info(msg)}. * * @param marker Name of the marker to be associated with the message. @@ -157,8 +151,6 @@ public interface IgniteLogger { /** * Logs out warning message with optional exception. - * - * @implSpec * The default implementation calls {@code this.warning(msg)}. * * @param marker Name of the marker to be associated with the message. @@ -188,8 +180,6 @@ public interface IgniteLogger { /** * Logs error message with optional exception. - * - * @implSpec * The default implementation calls {@code this.error(msg)}. * * @param marker Name of the marker to be associated with the message. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java index 783effa..ce42938 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java @@ -60,7 +60,7 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain /** * Random number generator. */ - private Random rand; + private final Random rand; /** * Construct instance of this class with given parameters. @@ -80,6 +80,7 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain this.batchSize = batchSize; this.cache = cache; this.mlp = new MultilayerPerceptron(arch, init); + this.rand = rand; } /** @@ -94,17 +95,17 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init, int networksCnt, IgniteCache<Integer, LabeledVector<Vector, Vector>> cache, int batchSize) { - this(arch, init, networksCnt, cache, batchSize, new Random()); + this(arch, init, networksCnt, cache, batchSize, null); } - /** - * Construct instance of this class with given parameters and default initializer. - * - * @param arch Architecture of multilayer perceptron. - * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}. - * @param cache Cache with labeled vectors. - * @param batchSize Size of batch to return on each training iteration. - */ + /** + * Construct instance of this class with given parameters and default initializer. 
+ * + * @param arch Architecture of multilayer perceptron. + * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}. + * @param cache Cache with labeled vectors. + * @param batchSize Size of batch to return on each training iteration. + */ public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, int networksCnt, IgniteCache<Integer, LabeledVector<Vector, Vector>> cache, int batchSize) { @@ -114,8 +115,9 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain /** {@inheritDoc} */ @Override public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() { String cName = cache.getName(); - int bs = batchSize; - Random r = rand; + + int bs = batchSize; // This line is for prohibiting of 'this' object be caught into serialization context of lambda. + Random r = rand; // This line is for prohibiting of 'this' object be caught into serialization context of lambda. return () -> { Ignite ignite = Ignition.localIgnite(); @@ -138,7 +140,7 @@ public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrain Matrix groundTruth = new DenseLocalOnHeapMatrix(dimEntry.label().size(), bs); for (int i = 0; i < selected.length; i++) { - LabeledVector<Vector, Vector> labeled = cache.get(selected[i]); + LabeledVector<Vector, Vector> labeled = cache.get(keys.get(selected[i])); inputs.assignColumn(i, labeled.features()); groundTruth.assignColumn(i, labeled.label()); http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java index f4647d5..8e97d87 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java @@ -35,12 +35,14 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.util.MatrixUtil; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.optimization.SmoothParametrized; import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; import org.apache.ignite.ml.trainers.group.MetaoptimizerGroupTrainer; import org.apache.ignite.ml.trainers.group.ResultAndUpdates; +import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.ml.trainers.group.chain.EntryAndContext; import org.apache.ignite.ml.util.Utils; @@ -78,9 +80,9 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends private final int maxGlobalSteps; /** - * Synchronize updates between networks every syncRate steps. + * Synchronize updates between networks every syncPeriod steps. */ - private final int syncRate; + private final int syncPeriod; /** * Function used to reduce updates from different networks (for example, averaging of gradients of all networks). @@ -96,7 +98,7 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends /** * Updates calculator. */ - private final ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator; + private final ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator; /** * Default maximal count of global steps. 
@@ -123,8 +125,8 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends /** * Default update calculator. */ - private static final ParameterUpdateCalculator<MultilayerPerceptron, RPropParameterUpdate> - DEFAULT_UPDATE_CALCULATOR = new RPropUpdateCalculator<>(); + private static final ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> + DEFAULT_UPDATE_CALCULATOR = new RPropUpdateCalculator(); /** * Default loss function. @@ -140,16 +142,16 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends * @param tolerance Error tolerance. */ public MLPGroupUpdateTrainer(int maxGlobalSteps, - int syncRate, + int syncPeriod, IgniteFunction<List<U>, U> allUpdatesReducer, IgniteFunction<List<U>, U> locStepUpdatesReducer, - ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator, + ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, Ignite ignite, double tolerance) { super(new MLPMetaoptimizer<>(allUpdatesReducer), MLPCache.getOrCreate(ignite), ignite); this.maxGlobalSteps = maxGlobalSteps; - this.syncRate = syncRate; + this.syncPeriod = syncPeriod; this.allUpdatesReducer = allUpdatesReducer; this.locStepUpdatesReducer = locStepUpdatesReducer; this.updateCalculator = updateCalculator; @@ -174,7 +176,7 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends MLPGroupUpdateTrainerDataCache.getOrCreate(ignite).put(trainingUUID, new MLPGroupUpdateTrainingData<>( updateCalculator, - syncRate, + syncPeriod, locStepUpdatesReducer, data.batchSupplier(), loss, @@ -237,26 +239,32 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends return data -> { MultilayerPerceptron mlp = data.mlp(); - MultilayerPerceptron mlpCp = Utils.copy(mlp); - ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator = data.updateCalculator(); + // Apply previous update. 
+ MultilayerPerceptron newMlp = updateCalculator.update(mlp, data.previousUpdate()); + + MultilayerPerceptron mlpCp = Utils.copy(newMlp); + ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator = data.updateCalculator(); IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss = data.loss(); // ParameterUpdateCalculator API to have proper way to setting loss. updateCalculator.init(mlpCp, loss); - U curUpdate = data.previousUpdate(); - + // Generate new update. int steps = data.stepsCnt(); List<U> updates = new ArrayList<>(steps); - - IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get(); + U curUpdate = data.previousUpdate(); for (int i = 0; i < steps; i++) { + IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get(); Matrix input = batch.get1(); Matrix truth = batch.get2(); int batchSize = truth.columnSize(); + curUpdate = updateCalculator.calculateNewUpdate(mlpCp, curUpdate, i, input, truth); + mlpCp = updateCalculator.update(mlpCp, curUpdate); + updates.add(curUpdate); + Matrix predicted = mlpCp.apply(input); double err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) -> @@ -264,18 +272,11 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends if (err < data.tolerance()) break; - - mlpCp = updateCalculator.update(mlpCp, curUpdate); - updates.add(curUpdate); - - curUpdate = updateCalculator.calculateNewUpdate(mlpCp, curUpdate, i, input, truth); } - U update = data.getUpdateReducer().apply(updates); + U accumulatedUpdate = data.getUpdateReducer().apply(updates); - MultilayerPerceptron newMlp = updateCalculator.update(mlp, data.previousUpdate()); - - return new ResultAndUpdates<>(update). + return new ResultAndUpdates<>(accumulatedUpdate). 
updateCache(MLPCache.getOrCreate(Ignition.localIgnite()), data.key(), new MLPGroupTrainingCacheValue(newMlp)); }; @@ -333,18 +334,18 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends * @return New {@link MLPGroupUpdateTrainer} with new maxGlobalSteps value. */ public MLPGroupUpdateTrainer<U> withMaxGlobalSteps(int maxGlobalSteps) { - return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncRate, allUpdatesReducer, locStepUpdatesReducer, + return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer, updateCalculator, loss, ignite, tolerance); } /** - * Create new {@link MLPGroupUpdateTrainer} with new syncRate value. + * Create new {@link MLPGroupUpdateTrainer} with new syncPeriod value. * - * @param syncRate New syncRate value. - * @return New {@link MLPGroupUpdateTrainer} with new syncRate value. + * @param syncPeriod New syncPeriod value. + * @return New {@link MLPGroupUpdateTrainer} with new syncPeriod value. */ - public MLPGroupUpdateTrainer<U> withSyncRate(int syncRate) { - return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncRate + public MLPGroupUpdateTrainer<U> withSyncPeriod(int syncPeriod) { + return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod , allUpdatesReducer, locStepUpdatesReducer, updateCalculator, loss, ignite, tolerance); } @@ -355,7 +356,18 @@ public class MLPGroupUpdateTrainer<U extends Serializable> extends * @return New {@link MLPGroupUpdateTrainer} with new tolerance value. */ public MLPGroupUpdateTrainer<U> withTolerance(double tolerance) { - return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncRate, allUpdatesReducer, locStepUpdatesReducer, + return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer, updateCalculator, loss, ignite, tolerance); } + + /** + * Create new {@link MLPGroupUpdateTrainer} with new update strategy. + * + * @param stgy New update strategy. 
+ * @return New {@link MLPGroupUpdateTrainer} with new tolerance value. + */ + public <U1 extends Serializable> MLPGroupUpdateTrainer<U1> withUpdateStrategy(UpdatesStrategy<? super MultilayerPerceptron, U1> stgy) { + return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, stgy.allUpdatesReducer(), stgy.locStepUpdatesReducer(), + stgy.getUpdatesCalculator(), loss, ignite, tolerance); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java index 740fac6..3031c8f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java @@ -30,7 +30,7 @@ import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalcul /** Multilayer perceptron group update training data. */ public class MLPGroupUpdateTrainingData<U> { /** {@link ParameterUpdateCalculator}. */ - private final ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator; + private final ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator; /** * Count of steps which should be done by each of parallel trainings before sending it's update for combining with @@ -59,7 +59,7 @@ public class MLPGroupUpdateTrainingData<U> { /** Construct multilayer perceptron group update training data with all parameters provided. */ public MLPGroupUpdateTrainingData( - ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator, int stepsCnt, + ParameterUpdateCalculator<? 
super MultilayerPerceptron, U> updateCalculator, int stepsCnt, IgniteFunction<List<U>, U> updateReducer, IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, double tolerance) { @@ -72,7 +72,7 @@ public class MLPGroupUpdateTrainingData<U> { } /** Get update calculator. */ - public ParameterUpdateCalculator<MultilayerPerceptron, U> updateCalculator() { + public ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator() { return updateCalculator; } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java index 2050ee5..342e7d5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java @@ -32,7 +32,7 @@ import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey; /** Multilayer perceptron group update training loop data. */ public class MLPGroupUpdateTrainingLoopData<P> implements Serializable { /** {@link ParameterUpdateCalculator}. */ - private final ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator; + private final ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator; /** * Count of steps which should be done by each of parallel trainings before sending it's update for combining with @@ -63,7 +63,7 @@ public class MLPGroupUpdateTrainingLoopData<P> implements Serializable { /** Create multilayer perceptron group update training loop data. 
*/ public MLPGroupUpdateTrainingLoopData(MultilayerPerceptron mlp, - ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator, int stepsCnt, + ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator, int stepsCnt, IgniteFunction<List<P>, P> updateReducer, P previousUpdate, GroupTrainerCacheKey<Void> key, IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, @@ -85,7 +85,7 @@ public class MLPGroupUpdateTrainingLoopData<P> implements Serializable { } /** Get update calculator. */ - public ParameterUpdateCalculator<MultilayerPerceptron, P> updateCalculator() { + public ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator() { return updateCalculator; } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java index 6e314f1..ff95a27 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.nn.trainers.distributed; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.trainers.group.Metaoptimizer; @@ -65,7 +66,7 @@ public class MLPMetaoptimizer<P> implements Metaoptimizer<MLPGroupUpdateTrainerL @Override public P localProcessor(ArrayList<P> input, MLPGroupUpdateTrainerLocalContext locCtx) { locCtx.incrementCurrentStep(); - return 
allUpdatesReducer.apply(input); + return allUpdatesReducer.apply(input.stream().filter(Objects::nonNull).collect(Collectors.toList())); } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java index 059d15a..ebb78c0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java @@ -61,7 +61,7 @@ public class MLPLocalBatchTrainer<P> */ public MLPLocalBatchTrainer( IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - IgniteSupplier<ParameterUpdateCalculator<MultilayerPerceptron, P>> updaterSupplier, + IgniteSupplier<ParameterUpdateCalculator<? super MultilayerPerceptron, P>> updaterSupplier, double errorThreshold, int maxIterations) { super(loss, updaterSupplier, errorThreshold, maxIterations); } @@ -72,7 +72,7 @@ public class MLPLocalBatchTrainer<P> * @return MLPLocalBatchTrainer with default parameters. 
*/ public static MLPLocalBatchTrainer<RPropParameterUpdate> getDefault() { - return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, () -> new RPropUpdateCalculator<>(), DEFAULT_ERROR_THRESHOLD, + return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, () -> new RPropUpdateCalculator(), DEFAULT_ERROR_THRESHOLD, DEFAULT_MAX_ITERATIONS); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java index 80345d9..f706a6c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java @@ -30,7 +30,7 @@ import org.apache.ignite.ml.optimization.SmoothParametrized; * <p> * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p> */ -public class RPropUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, RPropParameterUpdate> { +public class RPropUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> { /** * Default initial update. 
*/ @@ -138,14 +138,14 @@ public class RPropUpdateCalculator<M extends SmoothParametrized> implements Para } /** {@inheritDoc} */ - @Override public RPropParameterUpdate init(M mdl, + @Override public RPropParameterUpdate init(SmoothParametrized mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { this.loss = loss; return new RPropParameterUpdate(mdl.parametersCount(), initUpdate); } /** {@inheritDoc} */ - @Override public <M1 extends M> M1 update(M1 obj, RPropParameterUpdate update) { + @Override public <M1 extends SmoothParametrized> M1 update(M1 obj, RPropParameterUpdate update) { Vector updatesToAdd = VectorUtils.elementWiseTimes(update.updatesMask().copy(), update.prevIterationUpdates()); return (M1)obj.setParameters(obj.parameters().plus(updatesToAdd)); } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java deleted file mode 100644 index 22fc18a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.optimization.updatecalculators; - -import java.io.Serializable; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Parameters for {@link SimpleGDUpdateCalculator}. - */ -public class SimpleGDParameter implements Serializable { - /** - * Gradient. - */ - private Vector gradient; - - /** - * Learning rate. - */ - private double learningRate; - - /** - * Construct instance of this class. - * - * @param paramsCnt Count of parameters. - * @param learningRate Learning rate. - */ - public SimpleGDParameter(int paramsCnt, double learningRate) { - gradient = new DenseLocalOnHeapVector(paramsCnt); - this.learningRate = learningRate; - } - - /** - * Construct instance of this class. - * - * @param gradient Gradient. - * @param learningRate Learning rate. - */ - public SimpleGDParameter(Vector gradient, double learningRate) { - this.gradient = gradient; - this.learningRate = learningRate; - } - - /** - * Get gradient. - * - * @return Get gradient. - */ - public Vector gradient() { - return gradient; - } - - /** - * Get learning rate. - * - * @return learning rate. 
- */ - public double learningRate() { - return learningRate; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameterUpdate.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameterUpdate.java new file mode 100644 index 0000000..13731ea --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameterUpdate.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.optimization.updatecalculators; + +import java.io.Serializable; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Parameters for {@link SimpleGDUpdateCalculator}. + */ +public class SimpleGDParameterUpdate implements Serializable { + /** Gradient. 
*/ + private Vector gradient; + + /** + * Construct instance of this class. + * + * @param paramsCnt Count of parameters. + */ + public SimpleGDParameterUpdate(int paramsCnt) { + gradient = new DenseLocalOnHeapVector(paramsCnt); + } + + /** + * Construct instance of this class. + * + * @param gradient Gradient. + */ + public SimpleGDParameterUpdate(Vector gradient) { + this.gradient = gradient; + } + + /** + * Get gradient. + * + * @return Get gradient. + */ + public Vector gradient() { + return gradient; + } + + /** + * Method used to sum updates inside of one of parallel trainings. + * + * @param updates Updates. + * @return Sum of SimpleGDParameterUpdate. + */ + public static SimpleGDParameterUpdate sumLocal(List<SimpleGDParameterUpdate> updates) { + Vector accumulatedGrad = updates. + stream(). + filter(Objects::nonNull). + map(SimpleGDParameterUpdate::gradient). + reduce(Vector::plus). + orElse(null); + + return accumulatedGrad != null ? new SimpleGDParameterUpdate(accumulatedGrad) : null; + } + + /** + * Method used to get total update of all parallel trainings. + * + * @param updates Updates. + * @return Avg of SimpleGDParameterUpdate. + */ + public static SimpleGDParameterUpdate avg(List<SimpleGDParameterUpdate> updates) { + SimpleGDParameterUpdate sum = sumLocal(updates); + return sum != null ? new SimpleGDParameterUpdate(sum.gradient(). 
+ divide(updates.stream().filter(Objects::nonNull).collect(Collectors.toList()).size())) : null; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java index 291e63d..f102396 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java @@ -26,17 +26,21 @@ import org.apache.ignite.ml.optimization.SmoothParametrized; /** * Simple gradient descent parameters updater. */ -public class SimpleGDUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, SimpleGDParameter> { - /** - * Learning rate. - */ +public class SimpleGDUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, SimpleGDParameterUpdate> { + /** Learning rate. */ private double learningRate; - /** - * Loss function. - */ + /** Loss function. */ protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + /** Default learning rate. */ + private static final double DEFAULT_LEARNING_RATE = 0.1; + + /** Construct instance of this class with default parameters. */ + public SimpleGDUpdateCalculator() { + this(DEFAULT_LEARNING_RATE); + } + /** * Construct SimpleGDUpdateCalculator. 
* @@ -47,21 +51,31 @@ public class SimpleGDUpdateCalculator<M extends SmoothParametrized> implements P } /** {@inheritDoc} */ - @Override public SimpleGDParameter init(M mdl, + @Override public SimpleGDParameterUpdate init(SmoothParametrized mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { this.loss = loss; - return new SimpleGDParameter(mdl.parametersCount(), learningRate); + return new SimpleGDParameterUpdate(mdl.parametersCount()); } /** {@inheritDoc} */ - @Override public SimpleGDParameter calculateNewUpdate(SmoothParametrized mlp, SimpleGDParameter updaterParameters, + @Override public SimpleGDParameterUpdate calculateNewUpdate(SmoothParametrized mlp, SimpleGDParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) { - return new SimpleGDParameter(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate); + return new SimpleGDParameterUpdate(mlp.differentiateByParameters(loss, inputs, groundTruth)); } /** {@inheritDoc} */ - @Override public <M1 extends M> M1 update(M1 obj, SimpleGDParameter update) { + @Override public <M1 extends SmoothParametrized> M1 update(M1 obj, SimpleGDParameterUpdate update) { Vector params = obj.parameters(); return (M1)obj.setParameters(params.minus(update.gradient().times(learningRate))); } + + /** + * Create new instance of this class with same parameters as this one, but with new learning rate. + * + * @param learningRate Learning rate. + * @return New instance of this class with same parameters as this one, but with new learning rate. 
+ */ + public SimpleGDUpdateCalculator withLearningRate(double learningRate) { + return new SimpleGDUpdateCalculator(learningRate); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java index d70b3f1..e20a55a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.trainers.group; import java.io.Serializable; import java.util.List; +import java.util.Objects; import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -102,7 +103,7 @@ public abstract class BaseLocalProcessorJob<K, V, T, R extends Serializable> imp map(worker). 
collect(Collectors.toList()); - ResultAndUpdates<R> totalRes = ResultAndUpdates.sum(reducer, resultsAndUpdates); + ResultAndUpdates<R> totalRes = ResultAndUpdates.sum(reducer, resultsAndUpdates.stream().filter(Objects::nonNull).collect(Collectors.toList())); totalRes.applyUpdates(ignite()); http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java index 755f200..b192f42 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -122,7 +123,7 @@ public abstract class GroupTrainerBaseProcessorTask<K, V, C, T, R extends Serial /** {@inheritDoc} */ @Nullable @Override public R reduce(List<ComputeJobResult> results) throws IgniteException { - return reducer.apply(results.stream().map(res -> (R)res.getData()).collect(Collectors.toList())); + return reducer.apply(results.stream().map(res -> (R)res.getData()).filter(Objects::nonNull).collect(Collectors.toList())); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java index 2e3f457..9ed18af 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import org.apache.ignite.Ignite; @@ -137,7 +138,7 @@ public class ResultAndUpdates<R> { } } - List<R> results = resultsAndUpdates.stream().map(ResultAndUpdates::result).collect(Collectors.toList()); + List<R> results = resultsAndUpdates.stream().map(ResultAndUpdates::result).filter(Objects::nonNull).collect(Collectors.toList()); return new ResultAndUpdates<>(reducer.apply(results), allUpdates); } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java new file mode 100644 index 0000000..33ec96a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdateStrategies.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trainers.group; + +import org.apache.ignite.ml.optimization.SmoothParametrized; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; + +/** + * Holder class for various update strategies. + */ +public class UpdateStrategies { + /** + * Simple GD update strategy. + * + * @return GD update strategy. + */ + public static UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> GD() { + return new UpdatesStrategy<>(new SimpleGDUpdateCalculator(), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg); + } + + /** + * RProp update strategy. + * + * @return RProp update strategy. 
+ */ + public static UpdatesStrategy<SmoothParametrized, RPropParameterUpdate> RProp() { + return new UpdatesStrategy<>(new RPropUpdateCalculator(), RPropParameterUpdate::sumLocal, RPropParameterUpdate::avg); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java new file mode 100644 index 0000000..9deb460 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trainers.group; + +import java.util.List; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; + +/** + * Class encapsulating update strategies for group trainers based on updates. + * + * @param <M> Type of model to be optimized. + * @param <U> Type of update. 
+ */ +public class UpdatesStrategy<M, U> { + /** + * {@link ParameterUpdateCalculator}. + */ + private ParameterUpdateCalculator<M, U> updatesCalculator; + + /** + * Function used to reduce updates in one training (for example, sum all sequential gradient updates to get one + * gradient update). + */ + private IgniteFunction<List<U>, U> locStepUpdatesReducer; + + /** + * Function used to reduce updates from different trainings (for example, averaging of gradients of all parallel trainings). + */ + private IgniteFunction<List<U>, U> allUpdatesReducer; + + /** + * Construct instance of this class with given parameters. + * + * @param updatesCalculator Parameter update calculator. + * @param locStepUpdatesReducer Function used to reduce updates in one training + * (for example, sum all sequential gradient updates to get one gradient update). + * @param allUpdatesReducer Function used to reduce updates from different trainings + * (for example, averaging of gradients of all parallel trainings). + */ + public UpdatesStrategy( + ParameterUpdateCalculator<M, U> updatesCalculator, + IgniteFunction<List<U>, U> locStepUpdatesReducer, + IgniteFunction<List<U>, U> allUpdatesReducer) { + this.updatesCalculator = updatesCalculator; + this.locStepUpdatesReducer = locStepUpdatesReducer; + this.allUpdatesReducer = allUpdatesReducer; + } + + /** + * Get parameter update calculator (see {@link ParameterUpdateCalculator}). + * + * @return Parameter update calculator. + */ + public ParameterUpdateCalculator<M, U> getUpdatesCalculator() { + return updatesCalculator; + } + + /** + * Get function used to reduce updates in one training + * (for example, sum all sequential gradient updates to get one gradient update). + * + * @return Function used to reduce updates in one training + * (for example, sum all sequential gradient updates to get one gradient update). 
+ */ + public IgniteFunction<List<U>, U> locStepUpdatesReducer() { + return locStepUpdatesReducer; + } + + /** + * Get function used to reduce updates from different trainings + * (for example, averaging of gradients of all parallel trainings). + * + * @return Function used to reduce updates from different trainings. + */ + public IgniteFunction<List<U>, U> allUpdatesReducer() { + return allUpdatesReducer; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java index ab31f9f..cb6fd89 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java @@ -38,7 +38,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> /** * Supplier for updater function. */ - private final IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier; + private final IgniteSupplier<ParameterUpdateCalculator<? super M, P>> updaterSupplier; /** * Error threshold. @@ -69,7 +69,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> * @param maxIterations Maximal iterations count. */ public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, - IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier, double errorThreshold, int maxIterations) { + IgniteSupplier<ParameterUpdateCalculator<? 
super M, P>> updaterSupplier, double errorThreshold, int maxIterations) { this.loss = loss; this.updaterSupplier = updaterSupplier; this.errorThreshold = errorThreshold; @@ -82,7 +82,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> M mdl = data.mdl(); double err; - ParameterUpdateCalculator<M, P> updater = updaterSupplier.get(); + ParameterUpdateCalculator<? super M, P> updater = updaterSupplier.get(); P updaterParams = updater.init(mdl, loss); @@ -130,7 +130,7 @@ public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P> * @param updaterSupplier New updater supplier. * @return new trainer with the same parameters as this trainer, but with new updater supplier. */ - public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier) { + public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdateCalculator<? super M, P>> updaterSupplier) { return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); } http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java index 206e1e9..ed0ebd3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java @@ -53,7 +53,7 @@ public class Utils { obj = in.readObject(); } catch (IOException | ClassNotFoundException e) { - throw new IgniteException("Couldn't copy the object."); + throw new IgniteException("Couldn't copy the object.", e); } return (T)obj; http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java index 151fead..abd8ad2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.nn; +import java.io.Serializable; import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; @@ -31,8 +32,11 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer; -import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.trainers.group.UpdateStrategies; +import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** @@ -64,9 +68,26 @@ public class MLPGroupTrainerTest extends GridCommonAbstractTest { } /** + * Test training 'xor' by RProp. + */ + public void testXORRProp() { + doTestXOR(UpdateStrategies.RProp()); + } + + /** + * Test training 'xor' by SimpleGD. + */ + public void testXORGD() { + doTestXOR(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator().withLearningRate(0.5), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg)); + } + + /** * Test training of 'xor' by {@link MLPGroupUpdateTrainer}. */ - public void testXOR() { + private <U extends Serializable> void doTestXOR(UpdatesStrategy<? 
super MultilayerPerceptron, U> stgy) { int samplesCnt = 1000; Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}, @@ -93,18 +114,19 @@ public class MLPGroupTrainerTest extends GridCommonAbstractTest { } } - int totalCnt = 20; + int totalCnt = 30; int failCnt = 0; double maxFailRatio = 0.3; - MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite). - withSyncRate(3). + + MLPGroupUpdateTrainer<U> trainer = MLPGroupUpdateTrainer.getDefault(ignite). + withSyncPeriod(3). withTolerance(0.001). - withMaxGlobalSteps(1000); + withMaxGlobalSteps(100). + withUpdateStrategy(stgy); for (int i = 0; i < totalCnt; i++) { - MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf, - new RandomInitializer(new Random(123L)), 6, cache, 4, new Random(123L)); + new RandomInitializer(new Random(123L + i)), 6, cache, 10, new Random(123L + i)); MultilayerPerceptron mlp = trainer.train(trainerInput); http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java index b4c14e1..3119170 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java @@ -43,7 +43,7 @@ public class MLPLocalTrainerTest { */ @Test public void testXORSimpleGD() { - xorTest(() -> new SimpleGDUpdateCalculator<>(0.3)); + xorTest(() -> new SimpleGDUpdateCalculator(0.3)); } /** @@ -51,7 +51,7 @@ public class MLPLocalTrainerTest { */ @Test public void testXORRProp() { - xorTest(() -> new RPropUpdateCalculator<>()); + xorTest(RPropUpdateCalculator::new); } /** @@ -67,7 +67,7 @@ public class 
MLPLocalTrainerTest { * @param updaterSupplier Updater supplier. * @param <P> Updater parameters type. */ - private <P> void xorTest(IgniteSupplier<ParameterUpdateCalculator<MultilayerPerceptron, P>> updaterSupplier) { + private <P> void xorTest(IgniteSupplier<ParameterUpdateCalculator<? super MultilayerPerceptron, P>> updaterSupplier) { Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}, StorageConstants.ROW_STORAGE_MODE).transpose(); @@ -79,7 +79,7 @@ public class MLPLocalTrainerTest { withAddedLayer(1, false, Activators.SIGMOID); SimpleMLPLocalBatchTrainerInput trainerInput = new SimpleMLPLocalBatchTrainerInput(conf, - new Random(1234L), xorInputs, xorOutputs, 4); + new Random(123L), xorInputs, xorOutputs, 4); MultilayerPerceptron mlp = new MLPLocalBatchTrainer<>(LossFunctions.MSE, updaterSupplier, http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java index 112aade..5656f68 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java @@ -90,7 +90,9 @@ public class MnistDistributed extends GridCommonAbstractTest { IgniteCache<Integer, LabeledVector<Vector, Vector>> labeledVectorsCache = LabeledVectorsCache.createNew(ignite); loadIntoCache(trainingMnistLst, labeledVectorsCache); - MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite).withMaxGlobalSteps(35).withSyncRate(2); + MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite). + withMaxGlobalSteps(35). 
+ withSyncPeriod(2); MLPArchitecture arch = new MLPArchitecture(FEATURES_CNT). withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID). @@ -105,6 +107,8 @@ public class MnistDistributed extends GridCommonAbstractTest { Tracer.showAscii(truth); Tracer.showAscii(predicted); + + X.println("Accuracy: " + VectorUtils.zipWith(predicted, truth, (x, y) -> x.equals(y) ? 1.0 : 0.0).sum() / truth.size() * 100 + "%."); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/6bb6b3e5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java index cda0413a..14c02aa 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java @@ -74,7 +74,7 @@ public class MnistLocal { 2000); MultilayerPerceptron mdl = new MLPLocalBatchTrainer<>(LossFunctions.MSE, - () -> new RPropUpdateCalculator<>(0.1, 1.2, 0.5), + () -> new RPropUpdateCalculator(0.1, 1.2, 0.5), 1E-7, 200). train(input); @@ -89,5 +89,7 @@ public class MnistLocal { Tracer.showAscii(truth); Tracer.showAscii(predicted); + + X.println("Accuracy: " + VectorUtils.zipWith(predicted, truth, (x, y) -> x.equals(y) ? 1.0 : 0.0).sum() / truth.size() * 100 + "%."); } }
