http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java new file mode 100644 index 0000000..639bed0 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.architecture; + +import org.apache.ignite.ml.math.functions.IgniteDifferentiableDoubleToDoubleFunction; + +/** + * Class encapsulating the architecture of a transformation layer (i.e. non-input layer). + */ +public class TransformationLayerArchitecture extends LayerArchitecture { + /** + * Flag indicating presence of bias in layer. + */ + private boolean hasBias; + + /** + * Activation function for layer. 
+ */ + private IgniteDifferentiableDoubleToDoubleFunction activationFunction; + + /** + * Construct TransformationLayerArchitecture. + * + * @param neuronsCnt Count of neurons in this layer. + * @param hasBias Flag indicating presence of bias in layer. + * @param activationFunction Activation function for layer. + */ + public TransformationLayerArchitecture(int neuronsCnt, boolean hasBias, + IgniteDifferentiableDoubleToDoubleFunction activationFunction) { + super(neuronsCnt); + + this.hasBias = hasBias; + this.activationFunction = activationFunction; + } + + /** + * Checks if this layer has a bias. + * + * @return Value of predicate "this layer has a bias". + */ + public boolean hasBias() { + return hasBias; + } + + /** + * Get activation function for this layer. + * + * @return Activation function for this layer. + */ + public IgniteDifferentiableDoubleToDoubleFunction activationFunction() { + return activationFunction; + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java new file mode 100644 index 0000000..aff2d20 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multilayer perceptron architecture classes. 
+ */ +package org.apache.ignite.ml.nn.architecture; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java new file mode 100644 index 0000000..680508c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.initializers; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; + +/** + * Interface for classes encapsulating logic for initialization of weights and biases of MLP. + */ +public interface MLPInitializer { + /** + * In-place change values of matrix representing weights. + * + * @param weights Matrix representing weights. + */ + void initWeights(Matrix weights); + + /** + * In-place change values of vector representing biases. + * + * @param biases Vector representing biases. 
+ */ + void initBiases(Vector biases); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java new file mode 100644 index 0000000..18cb8a6 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.initializers; + +import java.util.Random; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; + +/** + * Class for initialization of MLP parameters with random uniformly distributed numbers from -1 to 1. + */ +public class RandomInitializer implements MLPInitializer { + /** + * RNG. + */ + Random rnd; + + /** + * Construct RandomInitializer from given RNG. + * + * @param rnd RNG. 
+ */ + public RandomInitializer(Random rnd) { + this.rnd = rnd; + } + + /** {@inheritDoc} */ + @Override public void initWeights(Matrix weights) { + weights.map(value -> 2 * (rnd.nextDouble() - 0.5)); + } + + /** {@inheritDoc} */ + @Override public void initBiases(Vector biases) { + biases.map(value -> 2 * (rnd.nextDouble() - 0.5)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java new file mode 100644 index 0000000..351783b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multilayer perceptron parameters initializers. 
+ */ +package org.apache.ignite.ml.nn.initializers; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java new file mode 100644 index 0000000..1641147 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains neural networks and related classes. 
+ */ +package org.apache.ignite.ml.nn; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java new file mode 100644 index 0000000..64a1956 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn.trainers.local; + +import org.apache.ignite.IgniteLogger; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.math.util.MatrixUtil; +import org.apache.ignite.ml.nn.LocalBatchTrainerInput; +import org.apache.ignite.ml.nn.updaters.ParameterUpdater; +import org.apache.ignite.ml.nn.updaters.UpdaterParams; + +/** + * Batch trainer. This trainer is not distributed on the cluster, but input can theoretically read data from + * Ignite cache. + */ +public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P extends UpdaterParams<? super M>> + implements Trainer<M, LocalBatchTrainerInput<M>> { + /** + * Supplier for updater function. + */ + private final IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier; + + /** + * Error threshold. + */ + private final double errorThreshold; + + /** + * Maximal iterations count. + */ + private final int maxIterations; + + /** + * Loss function. + */ + private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Logger. + */ + private IgniteLogger log; + + /** + * Construct a trainer. + * + * @param loss Loss function. + * @param updaterSupplier Supplier of updater function. + * @param errorThreshold Error threshold. + * @param maxIterations Maximal iterations count. + */ + public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, + IgniteSupplier<ParameterUpdater<? 
super M, P>> updaterSupplier, double errorThreshold, int maxIterations) { + this.loss = loss; + this.updaterSupplier = updaterSupplier; + this.errorThreshold = errorThreshold; + this.maxIterations = maxIterations; + } + + /** {@inheritDoc} */ + @Override public M train(LocalBatchTrainerInput<M> data) { + int i = 0; + M mdl = data.mdl(); + double err; + + ParameterUpdater<? super M, P> updater = updaterSupplier.get(); + + P updaterParams = updater.init(mdl, loss); + + while (i < maxIterations) { + IgniteBiTuple<Matrix, Matrix> batch = data.getBatch(); + Matrix input = batch.get1(); + Matrix truth = batch.get2(); + + updaterParams = updater.updateParams(mdl, updaterParams, i, input, truth); + + // Update mdl with updater parameters. + mdl = updaterParams.update(mdl); + + Matrix predicted = mdl.apply(input); + + int batchSize = input.columnSize(); + + err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) -> + loss.apply(truthCol).apply(predCol)).sum() / batchSize; + + debug("Error: " + err); + + if (err < errorThreshold) + break; + + i++; + } + + return mdl; + } + + /** + * Construct new trainer with the same parameters as this trainer, but with new loss. + * + * @param loss New loss function. + * @return new trainer with the same parameters as this trainer, but with new loss. + */ + public LocalBatchTrainer withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); + } + + /** + * Construct new trainer with the same parameters as this trainer, but with new updater supplier. + * + * @param updaterSupplier New updater supplier. + * @return new trainer with the same parameters as this trainer, but with new updater supplier. + */ + public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdater<? 
super M, P>> updaterSupplier) { + return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); + } + + /** + * Construct new trainer with the same parameters as this trainer, but with new error threshold. + * + * @param errorThreshold New error threshold. + * @return new trainer with the same parameters as this trainer, but with new error threshold. + */ + public LocalBatchTrainer withErrorThreshold(double errorThreshold) { + return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); + } + + /** + * Construct new trainer with the same parameters as this trainer, but with new maximal iterations count. + * + * @param maxIterations New maximal iterations count. + * @return new trainer with the same parameters as this trainer, but with new maximal iterations count. + */ + public LocalBatchTrainer withMaxIterations(int maxIterations) { + return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations); + } + + /** + * Set logger. + * + * @param log Logger. + * @return This object. + */ + public LocalBatchTrainer setLogger(IgniteLogger log) { + this.log = log; + + return this; + } + + /** + * Output debug message. + * + * @param msg Message. 
+ */ + private void debug(String msg) { + if (log != null && log.isDebugEnabled()) + log.debug(msg); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java new file mode 100644 index 0000000..7065e2f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn.trainers.local; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.nn.LossFunctions; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.updaters.ParameterUpdater; +import org.apache.ignite.ml.nn.updaters.RPropUpdater; +import org.apache.ignite.ml.nn.updaters.RPropUpdaterParams; +import org.apache.ignite.ml.nn.updaters.UpdaterParams; + +/** + * Local batch trainer for MLP. + * + * @param <P> Parameter updater parameters. + */ +public class MLPLocalBatchTrainer<P extends UpdaterParams<? super MultilayerPerceptron>> + extends LocalBatchTrainer<MultilayerPerceptron, P> { + /** + * Default loss function. + */ + private static final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> DEFAULT_LOSS = + LossFunctions.MSE; + + /** + * Default error threshold. + */ + private static final double DEFAULT_ERROR_THRESHOLD = 1E-5; + + /** + * Default maximal iterations count. + */ + private static final int DEFAULT_MAX_ITERATIONS = 100; + + + /** + * Construct a trainer. + * + * @param loss Loss function. + * @param updaterSupplier Supplier of updater function. + * @param errorThreshold Error threshold. + * @param maxIterations Maximal iterations count. + */ + public MLPLocalBatchTrainer( + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, + IgniteSupplier<ParameterUpdater<? super MultilayerPerceptron, P>> updaterSupplier, + double errorThreshold, int maxIterations) { + super(loss, updaterSupplier, errorThreshold, maxIterations); + } + + /** + * Get MLPLocalBatchTrainer with default parameters. + * + * @return MLPLocalBatchTrainer with default parameters. 
+ */ + public static MLPLocalBatchTrainer<RPropUpdaterParams> getDefault() { + return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, RPropUpdater::new, DEFAULT_ERROR_THRESHOLD, DEFAULT_MAX_ITERATIONS); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java new file mode 100644 index 0000000..b78adb8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multilayer perceptron local trainers. 
+ */ +package org.apache.ignite.ml.nn.trainers.local; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java new file mode 100644 index 0000000..c90f67a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multilayer perceptron trainers. 
+ */ +package org.apache.ignite.ml.nn.trainers; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java new file mode 100644 index 0000000..b33c2c7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Interface for models which are smooth functions of their parameters. 
+ */ +interface BaseSmoothParametrized<M extends BaseSmoothParametrized<M>> { + /** + * Compose function in the following way: feed output of this model as input to second argument to loss function. + * After that we have a function g of three arguments: input, ground truth, parameters. + * If we consider function + * h(w) = 1 / M sum_{i=1}^{M} g(w, input_i, groundTruth_i), + * where M is number of entries in batch, we get function of one argument: parameters vector w. + * This function is being differentiated. + * + * @param loss Loss function. + * @param inputsBatch Batch of inputs. + * @param truthBatch Batch of ground truths. + * @return Gradient of h at current point in parameters space. + */ + Vector differentiateByParameters(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, Matrix inputsBatch, Matrix truthBatch); + + /** + * Get parameters vector. + * + * @return Parameters vector. + */ + Vector parameters(); + + /** + * Set parameters. + * + * @param vector Parameters vector. + */ + M setParameters(Vector vector); + + /** + * Get count of parameters of this model. + * + * @return Count of parameters of this model. + */ + int parametersCount(); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java new file mode 100644 index 0000000..7b6a0c7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Class encapsulating Nesterov algorithm for MLP parameters update. + */ +public class NesterovUpdater implements ParameterUpdater<SmoothParametrized, NesterovUpdaterParams> { + /** + * Learning rate. + */ + private final double learningRate; + + /** + * Loss function. + */ + private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Momentum constant. + */ + protected double momentum; + + /** + * Construct NesterovUpdater. + * + * @param learningRate Learning rate. + * @param momentum Momentum constant. 
+ */ + public NesterovUpdater(double learningRate, double momentum) { + this.learningRate = learningRate; + this.momentum = momentum; + } + + /** {@inheritDoc} */ + @Override public NesterovUpdaterParams init(SmoothParametrized mdl, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + + return new NesterovUpdaterParams(mdl.parametersCount()); + } + + /** {@inheritDoc} */ + @Override public NesterovUpdaterParams updateParams(SmoothParametrized mdl, NesterovUpdaterParams updaterParameters, + int iteration, Matrix inputs, Matrix groundTruth) { + + if (iteration > 0) { + Vector curParams = mdl.parameters(); + mdl.setParameters(curParams.minus(updaterParameters.prevIterationUpdates().times(momentum))); + } + + Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); + updaterParameters.setPreviousUpdates(updaterParameters.prevIterationUpdates().plus(gradient.times(learningRate))); + + return updaterParameters; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java new file mode 100644 index 0000000..d403ea1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Data needed for Nesterov parameters updater. + */ +public class NesterovUpdaterParams implements UpdaterParams<SmoothParametrized> { + /** + * Previous step weights updates. + */ + protected Vector prevIterationUpdates; + + /** + * Construct NesterovUpdaterParams. + * + * @param paramsCnt Count of parameters on which update happens. + */ + public NesterovUpdaterParams(int paramsCnt) { + prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt).assign(0); + } + + /** + * Set previous step parameters updates. + * + * @param updates Parameters updates. + * @return This object with updated parameters updates. + */ + public NesterovUpdaterParams setPreviousUpdates(Vector updates) { + prevIterationUpdates = updates; + return this; + } + + /** + * Get previous step parameters updates. + * + * @return Previous step parameters updates. 
+ */ + public Vector prevIterationUpdates() { + return prevIterationUpdates; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector parameters = obj.parameters(); + return (M)obj.setParameters(parameters.minus(prevIterationUpdates)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java new file mode 100644 index 0000000..e8e28fd --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Interface for classes encapsulating parameters update logic. + * + * @param <M> Type of model to be updated. + * @param <P> Type of parameters needed for this updater. + */ +public interface ParameterUpdater<M, P extends UpdaterParams> { + /** + * Initializes the updater. + * + * @param mdl Model to be trained. + * @param loss Loss function. + */ + P init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss); + + /** + * Update updater parameters. + * + * @param mdl Model to be updated. + * @param updaterParameters Updater parameters to update. + * @param iteration Current trainer iteration. + * @param inputs Inputs. + * @param groundTruth True values. + * @return Updated parameters. + */ + P updateParams(M mdl, P updaterParameters, int iteration, Matrix inputs, Matrix groundTruth); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java new file mode 100644 index 0000000..c9d8843 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Class encapsulating RProp algorithm. + * + * @see <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf</a>. + */ +public class RPropUpdater implements ParameterUpdater<SmoothParametrized, RPropUpdaterParams> { + /** + * Default initial update. + */ + private static double DFLT_INIT_UPDATE = 0.1; + + /** + * Default acceleration rate. + */ + private static double DFLT_ACCELERATION_RATE = 1.2; + + /** + * Default deacceleration rate. + */ + private static double DFLT_DEACCELERATION_RATE = 0.5; + + /** + * Initial update. + */ + private final double initUpdate; + + /** + * Acceleration rate. + */ + private final double accelerationRate; + + /** + * Deacceleration rate. + */ + private final double deaccelerationRate; + + /** + * Maximal value for update. + */ + private final static double UPDATE_MAX = 50.0; + + /** + * Minimal value for update. + */ + private final static double UPDATE_MIN = 1E-6; + + /** + * Loss function. 
+ */ + protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Construct RPropUpdater. + * + * @param initUpdate Initial update. + * @param accelerationRate Acceleration rate. + * @param deaccelerationRate Deacceleration rate. + */ + public RPropUpdater(double initUpdate, double accelerationRate, double deaccelerationRate) { + this.initUpdate = initUpdate; + this.accelerationRate = accelerationRate; + this.deaccelerationRate = deaccelerationRate; + } + + /** + * Construct RPropUpdater with default parameters. + */ + public RPropUpdater() { + this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE); + } + + /** {@inheritDoc} */ + @Override public RPropUpdaterParams init(SmoothParametrized mdl, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + return new RPropUpdaterParams(mdl.parametersCount(), initUpdate); + } + + /** {@inheritDoc} */ + @Override public RPropUpdaterParams updateParams(SmoothParametrized mdl, RPropUpdaterParams updaterParams, + int iteration, Matrix inputs, Matrix groundTruth) { + Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth); + Vector prevGradient = updaterParams.prevIterationGradient(); + Vector derSigns; + + if (prevGradient != null) + derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y)); + else + derSigns = gradient.like(gradient.size()).assign(1.0); + + updaterParams.deltas().map(derSigns, (prevDelta, sign) -> { + if (sign > 0) + return Math.min(prevDelta * accelerationRate, UPDATE_MAX); + else if (sign < 0) + return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN); + else + return prevDelta; + }); + + updaterParams.setPrevIterationBiasesUpdates(MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> { + if (derSigns.getX(i) >= 0) + return -Math.signum(der) * delta; + + return updaterParams.prevIterationUpdates().getX(i); + })); + + Vector updatesMask = 
MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> { + if (sign < 0) + gradient.setX(i, 0.0); + + if (sign >= 0) + return 1.0; + else + return -1.0; + }); + + updaterParams.setUpdatesMask(updatesMask); + updaterParams.setPrevIterationWeightsDerivatives(gradient.copy()); + + return updaterParams; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java new file mode 100644 index 0000000..cff5f5b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Data needed for RProp updater. 
+ * @see <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf</a>. + */ +public class RPropUpdaterParams implements UpdaterParams<SmoothParametrized> { + /** + * Previous iteration parameters updates. In original paper they are labeled with "delta w". + */ + protected Vector prevIterationUpdates; + + /** + * Previous iteration loss function partial derivatives by parameters. + */ + protected Vector prevIterationGradient; + /** + * Previous iteration parameters deltas. In original paper they are labeled with "delta". + */ + protected Vector deltas; + + /** + * Updates mask (values by which update is multiplied). + */ + protected Vector updatesMask; + + /** + * Construct RPropUpdaterParams. + * + * @param paramsCnt Parameters count. + * @param initUpdate Initial update (in original work labeled as "delta_0"). + */ + RPropUpdaterParams(int paramsCnt, double initUpdate) { + prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt); + prevIterationGradient = new DenseLocalOnHeapVector(paramsCnt); + deltas = new DenseLocalOnHeapVector(paramsCnt).assign(initUpdate); + updatesMask = new DenseLocalOnHeapVector(paramsCnt); + } + + /** + * Get parameters deltas. In original paper they are labeled with "delta". + * + * @return Parameters deltas. + */ + Vector deltas() { + return deltas; + } + + /** + * Get previous iteration parameters updates. In original paper they are labeled with "delta w". + * + * @return Previous iteration parameters updates. + */ + Vector prevIterationUpdates() { + return prevIterationUpdates; + } + + /** + * Set previous iteration parameters updates. In original paper they are labeled with "delta w". + * + * @param updates New parameters updates value. + * @return The newly assigned parameters updates. + */ + Vector setPrevIterationBiasesUpdates(Vector updates) { + return prevIterationUpdates = updates; + } + + /** + * Get previous iteration loss function partial derivatives by parameters. + * + * @return Previous iteration loss function partial derivatives by parameters. 
+ */ + Vector prevIterationGradient() { + return prevIterationGradient; + } + + /** + * Set previous iteration loss function partial derivatives by parameters. + * + * @return This object. + */ + RPropUpdaterParams setPrevIterationWeightsDerivatives(Vector gradient) { + prevIterationGradient = gradient; + return this; + } + + /** + * Get updates mask (values by which update is multiplied). + * + * @return Updates mask (values by which update is multiplied). + */ + public Vector updatesMask() { + return updatesMask; + } + + /** + * Set updates mask (values by which update is multiplied). + * + * @param updatesMask New updatesMask. + */ + public RPropUpdaterParams setUpdatesMask(Vector updatesMask) { + this.updatesMask = updatesMask; + + return this; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector updatesToAdd = VectorUtils.elementWiseTimes(updatesMask.copy(), prevIterationUpdates); + return (M)obj.setParameters(obj.parameters().plus(updatesToAdd)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java new file mode 100644 index 0000000..50a120a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Parameters for {@link SimpleGDUpdater}. + */ +public class SimpleGDParams implements UpdaterParams<SmoothParametrized> { + /** + * Gradient. + */ + private Vector gradient; + + /** + * Learning rate. + */ + private double learningRate; + + /** + * Construct instance of this class. + * + * @param paramsCnt Count of parameters. + * @param learningRate Learning rate. + */ + public SimpleGDParams(int paramsCnt, double learningRate) { + gradient = new DenseLocalOnHeapVector(paramsCnt); + this.learningRate = learningRate; + } + + /** + * Construct instance of this class. + * + * @param gradient Gradient. + * @param learningRate Learning rate. 
+ */ + public SimpleGDParams(Vector gradient, double learningRate) { + this.gradient = gradient; + this.learningRate = learningRate; + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public <M extends SmoothParametrized> M update(M obj) { + Vector params = obj.parameters(); + return (M)obj.setParameters(params.minus(gradient.times(learningRate))); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java new file mode 100644 index 0000000..5bf9c3f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn.updaters; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Simple gradient descent parameters updater. + */ +public class SimpleGDUpdater implements ParameterUpdater<SmoothParametrized, SimpleGDParams> { + /** + * Learning rate. + */ + private double learningRate; + + /** + * Loss function. + */ + protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; + + /** + * Construct SimpleGDUpdater. + * + * @param learningRate Learning rate. + */ + public SimpleGDUpdater(double learningRate) { + this.learningRate = learningRate; + } + + /** {@inheritDoc} */ + @Override public SimpleGDParams init(SmoothParametrized mlp, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) { + this.loss = loss; + return new SimpleGDParams(mlp.parametersCount(), learningRate); + } + + /** {@inheritDoc} */ + @Override public SimpleGDParams updateParams(SmoothParametrized mlp, SimpleGDParams updaterParameters, + int iteration, Matrix inputs, Matrix groundTruth) { + return new SimpleGDParams(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java new file mode 100644 index 0000000..5c4f59f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +/** + * Interface for models which are smooth functions of their parameters. + */ +public interface SmoothParametrized<M extends SmoothParametrized<M>> extends BaseSmoothParametrized<M> { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java new file mode 100644 index 0000000..cd5bc32 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn.updaters; + +/** + * A common interface for parameters of parameter updaters. + * + * @param <T> Type of object to be updated with these params. + */ +public interface UpdaterParams<T> { + /** + * Update given object with these parameters. + * + * @param obj Object to be updated. + */ + <M extends T> M update(M obj); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java new file mode 100644 index 0000000..13bc3c8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains parameters updaters. + */ +package org.apache.ignite.ml.nn.updaters; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java index 76a90fc..b95cbf3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java @@ -47,7 +47,7 @@ public class OLSMultipleLinearRegressionModel implements Model<Vector, Vector>, } /** {@inheritDoc} */ - @Override public Vector predict(Vector val) { + @Override public Vector apply(Vector val) { return xMatrix.times(solver.solve(val)); } http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java index 86e9326..572e64a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java @@ -38,7 +38,7 @@ public class DecisionTreeModel implements Model<Vector, Double> { } /** {@inheritDoc} */ - @Override public Double predict(Vector val) { + @Override public Double apply(Vector val) { return 
root.process(val); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java index 847b1f1..4472300 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java @@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.util.Random; import org.apache.ignite.IgniteException; /** @@ -57,4 +58,30 @@ public class Utils { return (T)obj; } + + /** + * Select k distinct integers from range [0, n) with reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling. + * + * @param n Number specifying the right (exclusive) end of the range [0, n) of integers to pick values from. + * @param k Count specifying how many integers should be picked; must not exceed n, otherwise values outside [0, n) are returned. 
+ * @return Array containing k distinct integers from range [0, n); + */ + public static int[] selectKDistinct(int n, int k) { + int i; + + int res[] = new int[k]; + for (i = 0; i < k; i++) + res[i] = i; + + Random r = new Random(); + + for (; i < n; i++) { + int j = r.nextInt(i + 1); + + if (j < k) + res[j] = i; + } + + return res; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 05c91bd..fafd364 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml; import org.apache.ignite.ml.clustering.ClusteringTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; +import org.apache.ignite.ml.nn.MLPTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.trees.DecisionTreesTestSuite; import org.junit.runner.RunWith; @@ -35,7 +36,8 @@ import org.junit.runners.Suite; ClusteringTestSuite.class, DecisionTreesTestSuite.class, KNNTestSuite.class, - LocalModelsTest.class + LocalModelsTest.class, + MLPTestSuite.class }) public class IgniteMLTestSuite { // No-op. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index e010553..28af6fa 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -46,9 +46,9 @@ public class KNNClassificationTest extends BaseKNNTest { KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training); Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.predict(firstVector), 1.0); + assertEquals(knnMdl.apply(firstVector), 1.0); Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.predict(secondVector), 2.0); + assertEquals(knnMdl.apply(secondVector), 2.0); } /** */ @@ -69,9 +69,9 @@ public class KNNClassificationTest extends BaseKNNTest { KNNModel knnMdl = new KNNModel(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training); Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0}); - assertEquals(knnMdl.predict(firstVector), 1.0); + assertEquals(knnMdl.apply(firstVector), 1.0); Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0}); - assertEquals(knnMdl.predict(secondVector), 2.0); + assertEquals(knnMdl.apply(secondVector), 2.0); } /** */ @@ -91,7 +91,7 @@ public class KNNClassificationTest extends BaseKNNTest { KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training); Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.predict(vector), 2.0); + assertEquals(knnMdl.apply(vector), 2.0); } /** */ @@ -112,7 +112,7 @@ public class 
KNNClassificationTest extends BaseKNNTest { KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.WEIGHTED, training); Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01}); - assertEquals(knnMdl.predict(vector), 1.0); + assertEquals(knnMdl.apply(vector), 1.0); } /** */ @@ -122,7 +122,7 @@ public class KNNClassificationTest extends BaseKNNTest { KNNModel knnMdl = new KNNModel(7, new EuclideanDistance(), KNNStrategy.SIMPLE, training); Vector vector = new DenseLocalOnHeapVector(new double[] {5.15, 3.55, 1.45, 0.25}); - assertEquals(knnMdl.predict(vector), 1.0); + assertEquals(knnMdl.apply(vector), 1.0); } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java index 9a918b6..d973686 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java @@ -56,8 +56,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training); Vector vector = new SparseBlockDistributedVector(new double[] {0, 0, 0, 5.0, 0.0}); - System.out.println(knnMdl.predict(vector)); - Assert.assertEquals(15, knnMdl.predict(vector), 1E-12); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(15, knnMdl.apply(vector), 1E-12); } /** */ @@ -87,8 +87,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training); 
Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - System.out.println(knnMdl.predict(vector)); - Assert.assertEquals(67857, knnMdl.predict(vector), 2000); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } /** */ @@ -119,8 +119,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.SIMPLE, normalizedTrainingDataset); Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - System.out.println(knnMdl.predict(vector)); - Assert.assertEquals(67857, knnMdl.predict(vector), 2000); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } /** */ @@ -151,7 +151,7 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.WEIGHTED, normalizedTrainingDataset); Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); - System.out.println(knnMdl.predict(vector)); - Assert.assertEquals(67857, knnMdl.predict(vector), 2000); + System.out.println(knnMdl.apply(vector)); + Assert.assertEquals(67857, knnMdl.apply(vector), 2000); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java new file mode 100644 index 0000000..fa2b5e2 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.nn.initializers.MLPInitializer; + +/** + * Initialize weights and biases with specified constant. + */ +public class MLPConstInitializer implements MLPInitializer { + /** + * Constant to be used as bias for all layers. + */ + private double bias; + + /** + * Constant to be used as weight from any neuron to any neuron in next layer. + */ + private double weight; + + /** + * Construct MLPConstInitializer. + * + * @param weight Constant to be used as weight from any neuron to any neuron in next layer. + * @param bias Constant to be used as bias for all layers. + */ + public MLPConstInitializer(double weight, double bias) { + this.bias = bias; + this.weight = weight; + } + + /** + * Construct MLPConstInitializer with biases constant equal to 0.0. + * + * @param weight Constant to be used as weight from any neuron to any neuron in next layer. 
+ */ + public MLPConstInitializer(double weight) { + this(weight, 0.0); + } + + /** {@inheritDoc} */ + @Override public void initWeights(Matrix weights) { + weights.assign(weight); + } + + /** {@inheritDoc} */ + @Override public void initBiases(Vector biases) { + biases.assign(bias); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java new file mode 100644 index 0000000..2a6b55d --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn; + +import java.util.Random; +import org.apache.ignite.internal.util.typedef.X; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.StorageConstants; +import org.apache.ignite.ml.math.Tracer; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer; +import org.apache.ignite.ml.nn.updaters.NesterovUpdater; +import org.apache.ignite.ml.nn.updaters.ParameterUpdater; +import org.apache.ignite.ml.nn.updaters.RPropUpdater; +import org.apache.ignite.ml.nn.updaters.SimpleGDUpdater; +import org.apache.ignite.ml.nn.updaters.UpdaterParams; +import org.junit.Test; + +/** + * Tests for {@link MLPLocalBatchTrainer}. + */ +public class MLPLocalTrainerTest { + /** + * Test 'XOR' operation training with {@link SimpleGDUpdater} updater. + */ + @Test + public void testXORSimpleGD() { + xorTest(() -> new SimpleGDUpdater(0.3)); + } + + /** + * Test 'XOR' operation training with {@link RPropUpdater}. + */ + @Test + public void testXORRProp() { + xorTest(RPropUpdater::new); + } + + /** + * Test 'XOR' operation training with {@link NesterovUpdater}. + */ + @Test + public void testXORNesterov() { + xorTest(() -> new NesterovUpdater(0.1, 0.7)); + } + + /** + * Common method for testing 'XOR' with various updaters. + * @param updaterSupplier Updater supplier. + * @param <P> Updater parameters type. + */ + private <P extends UpdaterParams<? super MultilayerPerceptron>> void xorTest(IgniteSupplier<ParameterUpdater<? 
super MultilayerPerceptron, P>> updaterSupplier) { + Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}, + StorageConstants.ROW_STORAGE_MODE).transpose(); + + Matrix xorOutputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}}, + StorageConstants.ROW_STORAGE_MODE).transpose(); + + MLPArchitecture conf = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + SimpleMLPLocalBatchTrainerInput trainerInput = new SimpleMLPLocalBatchTrainerInput(conf, + new Random(1234L), xorInputs, xorOutputs, 4); + + MultilayerPerceptron mlp = new MLPLocalBatchTrainer<>(LossFunctions.MSE, + updaterSupplier, + 0.0001, + 16000).train(trainerInput); + + Matrix predict = mlp.apply(xorInputs); + + Tracer.showAscii(predict); + + X.println(xorOutputs.getRow(0).minus(predict.getRow(0)).kNorm(2) + ""); + + TestUtils.checkIsInEpsilonNeighbourhood(xorOutputs.getRow(0), predict.getRow(0), 1E-1); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java new file mode 100644 index 0000000..d757fcb --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.nn; + +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Tracer; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for Multilayer perceptron. + */ +public class MLPTest { + /** + * Tests that MLP with 2 layer, 1 neuron in each layer and weight equal to 1 is equivalent to sigmoid function. + */ + @Test + public void testSimpleMLPPrediction() { + MLPArchitecture conf = new MLPArchitecture(1).withAddedLayer(1, false, Activators.SIGMOID); + + MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1)); + + int input = 2; + + Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{input}})); + + Assert.assertEquals(predict, new DenseLocalOnHeapMatrix(new double[][] {{Activators.SIGMOID.apply(input)}})); + } + + /** + * Test that MLP with parameters that should produce function close to 'XOR' is close to 'XOR' on 'XOR' domain. + */ + @Test + public void testXOR() { + MLPArchitecture conf = new MLPArchitecture(2). + withAddedLayer(2, true, Activators.SIGMOID). 
+ withAddedLayer(1, true, Activators.SIGMOID); + + MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1, 2)); + + mlp.setWeights(1, new DenseLocalOnHeapMatrix(new double[][] {{20.0, 20.0}, {-20.0, -20.0}})); + mlp.setBiases(1, new DenseLocalOnHeapVector(new double[] {-10.0, 30.0})); + + mlp.setWeights(2, new DenseLocalOnHeapMatrix(new double[][] {{20.0, 20.0}})); + mlp.setBiases(2, new DenseLocalOnHeapVector(new double[] {-30.0})); + + Matrix input = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}).transpose(); + + Matrix predict = mlp.apply(input); + Vector truth = new DenseLocalOnHeapVector(new double[] {0.0, 1.0, 1.0, 0.0}); + + TestUtils.checkIsInEpsilonNeighbourhood(predict.getRow(0), truth, 1E-4); + } + + /** + * Test that two layer MLP is equivalent to it's subparts stacked on each other. + */ + @Test + public void testStackedMLP() { + int firstLayerNeuronsCnt = 3; + int secondLayerNeuronsCnt = 2; + MLPConstInitializer initer = new MLPConstInitializer(1, 2); + + MLPArchitecture conf = new MLPArchitecture(4). + withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID). + withAddedLayer(secondLayerNeuronsCnt, false, Activators.SIGMOID); + + MultilayerPerceptron mlp = new MultilayerPerceptron(conf, initer); + + MLPArchitecture mlpLayer1Conf = new MLPArchitecture(4). + withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID); + MLPArchitecture mlpLayer2Conf = new MLPArchitecture(firstLayerNeuronsCnt). 
+ withAddedLayer(secondLayerNeuronsCnt, false, Activators.SIGMOID); + + MultilayerPerceptron mlp1 = new MultilayerPerceptron(mlpLayer1Conf, initer); + MultilayerPerceptron mlp2 = new MultilayerPerceptron(mlpLayer2Conf, initer); + + MultilayerPerceptron stackedMLP = mlp1.add(mlp2); + + Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose()); + Matrix stackedPredict = stackedMLP.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose()); + + Assert.assertEquals(predict, stackedPredict); + } + + /** + * Test parameters count works well. + */ + @Test + public void paramsCountTest() { + int inputSize = 10; + int layerWithBiasNeuronsCnt = 13; + int layerWithoutBiasNeuronsCnt = 17; + + MLPArchitecture conf = new MLPArchitecture(inputSize). + withAddedLayer(layerWithBiasNeuronsCnt, true, Activators.SIGMOID). + withAddedLayer(layerWithoutBiasNeuronsCnt, false, Activators.SIGMOID); + + Assert.assertEquals(layerWithBiasNeuronsCnt * inputSize + layerWithBiasNeuronsCnt + (layerWithoutBiasNeuronsCnt * layerWithBiasNeuronsCnt), + conf.parametersCount()); + } + + /** + * Test methods related to parameters flattening. + */ + @Test + public void setParamsFlattening() { + int inputSize = 3; + int firstLayerNeuronsCnt = 2; + int secondLayerNeurons = 1; + + DenseLocalOnHeapVector paramsVector = new DenseLocalOnHeapVector(new double[] { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // First layer weight matrix. + 7.0, 8.0, // Second layer weight matrix. + 9.0 // Second layer biases. + }); + + DenseLocalOnHeapMatrix firstLayerWeights = new DenseLocalOnHeapMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + DenseLocalOnHeapMatrix secondLayerWeights = new DenseLocalOnHeapMatrix(new double[][] {{7.0, 8.0}}); + DenseLocalOnHeapVector secondLayerBiases = new DenseLocalOnHeapVector(new double[] {9.0}); + + MLPArchitecture conf = new MLPArchitecture(inputSize). + withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID). 
+ withAddedLayer(secondLayerNeurons, true, Activators.SIGMOID); + + MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(100, 200)); + + mlp.setParameters(paramsVector); + Assert.assertEquals(paramsVector, mlp.parameters()); + + Assert.assertEquals(mlp.weights(1), firstLayerWeights); + Assert.assertEquals(mlp.weights(2), secondLayerWeights); + Assert.assertEquals(mlp.biases(2), secondLayerBiases); + } + + /** + * Test differentiation. + */ + @Test + public void testDifferentiation() { + int inputSize = 2; + int firstLayerNeuronsCnt = 1; + + double w10 = 0.1; + double w11 = 0.2; + + MLPArchitecture conf = new MLPArchitecture(inputSize). + withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID); + + MultilayerPerceptron mlp = new MultilayerPerceptron(conf); + + mlp.setWeight(1, 0, 0, w10); + mlp.setWeight(1, 1, 0, w11); + double x0 = 1.0; + double x1 = 3.0; + + Matrix inputs = new DenseLocalOnHeapMatrix(new double[][] {{x0, x1}}).transpose(); + double ytt = 1.0; + Matrix truth = new DenseLocalOnHeapMatrix(new double[][] {{ytt}}).transpose(); + + Vector grad = mlp.differentiateByParameters(LossFunctions.MSE, inputs, truth); + + // Let yt be y ground truth value. + // d/dw1i [(yt - sigma(w10 * x0 + w11 * x1))^2] = + // 2 * (yt - sigma(w10 * x0 + w11 * x1)) * (-1) * (sigma(w10 * x0 + w11 * x1)) * (1 - sigma(w10 * x0 + w11 * x1)) * xi = + // let z = sigma(w10 * x0 + w11 * x1) + // - 2* (yt - z) * (z) * (1 - z) * xi. 
+ + IgniteTriFunction<Double, Vector, Vector, Vector> partialDer = (yt, w, x) -> { + Double z = Activators.SIGMOID.apply(w.dot(x)); + + return x.copy().map(xi -> -2 * (yt - z) * z * (1 - z) * xi); + }; + + Vector weightsVec = mlp.weights(1).getRow(0); + Tracer.showAscii(weightsVec); + + Vector trueGrad = partialDer.apply(ytt, weightsVec, inputs.getCol(0)); + + Tracer.showAscii(trueGrad); + Tracer.showAscii(grad); + + Assert.assertEquals(mlp.architecture().parametersCount(), grad.size()); + Assert.assertEquals(trueGrad, grad); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java new file mode 100644 index 0000000..d006cd9 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.nn; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for multilayer perceptrons. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + MLPTest.class, + MLPLocalTrainerTest.class, +}) +public class MLPTestSuite { + // No-op. +}
