[HIVEMALL-230] Revise Optimizer Implementation ## What changes were proposed in this pull request?
Revise Optimizer implementation. 1. Revise default hyperparameters of AdaDelta and Adam. 2. Support AdamW, Amsgrad, AdamHD, Eve, and YellowFin optimizer. - [x] Nesterov's Accelerated Gradient https://arxiv.org/abs/1212.0901 - [x] Rmsprop Geoffrey Hinton, Nitish Srivastava, Kevin Swersky. 2014. Lecture 6e: Rmsprop: Divide the gradient by a running average of its recent magnitude http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf - [x] RMSpropGraves - Generating Sequences With Recurrent Neural Networks https://arxiv.org/abs/1308.0850 - [x] Fixing Weight Decay Regularization in Adam https://openreview.net/forum?id=rk6qdGgCZ - [x] On the Convergence of Adam and Beyond https://openreview.net/forum?id=ryQu7f-RZ - [x] AdamHD (Adam with Hypergradient descent) https://arxiv.org/pdf/1703.04782.pdf - [x] Eve: A Gradient Based Optimization Method with Locally and Globally Adaptive Learning Rates https://openreview.net/forum?id=r1WUqIceg - [x] nadam: Adam with Nesterov momentum https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ http://cs229.stanford.edu/proj2015/054_report.pdf http://www.cs.toronto.edu/~fritz/absps/momentum.pdf - [ ] ~YellowFin and the Art of Momentum Tuning~ https://openreview.net/forum?id=SyrGJYlRZ ## What type of PR is it? Improvement, Feature ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-230 ## How was this patch tested? unit tests, emr ## How to use this feature? Described in [tutorial](http://hivemall.incubator.apache.org/userguide/index.html) ## Checklist - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <[email protected]> Closes #175 from myui/adam_test. 
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/31932fd7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/31932fd7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/31932fd7 Branch: refs/heads/master Commit: 31932fd7c63f9bb21eba8959944d03f280b6deb9 Parents: bc06c93 Author: Makoto Yui <[email protected]> Authored: Wed Dec 26 19:14:23 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Dec 26 19:14:23 2018 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 24 +- core/src/main/java/hivemall/UDFWithOptions.java | 2 +- .../src/main/java/hivemall/UDTFWithOptions.java | 77 ++- .../main/java/hivemall/model/IWeightValue.java | 4 + .../main/java/hivemall/model/WeightValue.java | 34 +- .../hivemall/model/WeightValueWithClock.java | 20 + .../optimizer/DenseOptimizerFactory.java | 328 ++++++++++- .../java/hivemall/optimizer/EtaEstimator.java | 26 + .../main/java/hivemall/optimizer/Optimizer.java | 558 ++++++++++++++++++- .../hivemall/optimizer/OptimizerOptions.java | 26 +- .../optimizer/SparseOptimizerFactory.java | 211 ++++++- .../java/hivemall/utils/math/MathUtils.java | 12 + .../classifier/GeneralClassifierUDTFTest.java | 543 ++++++++++++++++++ .../java/hivemall/optimizer/OptimizerTest.java | 14 +- .../hivemall/classifier/adam_test_10000.tsv.gz | Bin 0 -> 285318 bytes 15 files changed, 1765 insertions(+), 114 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 4aad70a..90ad97c 100644 --- 
a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -141,8 +141,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { - throw new UDFArgumentException( - "_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, float target [, constant string options]"); + showHelp( + "_FUNC_ takes two or three arguments: List<Int|BigInt|Text> features, float target [, constant string options]"); } this.featureListOI = HiveUtils.asListOI(argOIs[0]); this.featureType = getFeatureType(featureListOI); @@ -452,12 +452,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { protected void update(@Nonnull final FeatureValue[] features, final float target, final float predicted) { + optimizer.proceedStep(); + float loss = lossFunction.loss(predicted, target); cvState.incrLoss(loss); // retain cumulative loss to check convergence float dloss = lossFunction.dloss(predicted, target); if (dloss == 0.f) { - optimizer.proceedStep(); return; } if (dloss < MIN_DLOSS) { @@ -467,24 +468,25 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { } if (is_mini_batch) { - accumulateUpdate(features, dloss); + accumulateUpdate(features, loss, dloss); if (sampled >= mini_batch_size) { batchUpdate(); } } else { - onlineUpdate(features, dloss); + onlineUpdate(features, loss, dloss); } - optimizer.proceedStep(); } - protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float dloss) { + protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float loss, + final float dloss) { for (FeatureValue f : features) { Object feature = f.getFeature(); float xi = f.getValueAsFloat(); float weight = model.getWeight(feature); // compute new weight, but still not set to the model - float new_weight = 
optimizer.update(feature, weight, dloss * xi); + float gradient = dloss * xi; + float new_weight = optimizer.update(feature, weight, loss, gradient); // (w_i - eta * delta_1) + (w_i - eta * delta_2) + ... + (w_i - eta * delta_M) FloatAccumulator acc = accumulated.get(feature); @@ -519,12 +521,14 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { this.sampled = 0; } - protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float dloss) { + protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float loss, + final float dloss) { for (FeatureValue f : features) { Object feature = f.getFeature(); float xi = f.getValueAsFloat(); float weight = model.getWeight(feature); - final float new_weight = optimizer.update(feature, weight, dloss * xi); + float gradient = dloss * xi; + final float new_weight = optimizer.update(feature, weight, loss, gradient); if (new_weight == 0.f) { model.delete(feature); continue; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/UDFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDFWithOptions.java b/core/src/main/java/hivemall/UDFWithOptions.java index 89e7662..f8272ce 100644 --- a/core/src/main/java/hivemall/UDFWithOptions.java +++ b/core/src/main/java/hivemall/UDFWithOptions.java @@ -100,7 +100,7 @@ public abstract class UDFWithOptions extends GenericUDF { } private void showHelp(@Nonnull Options opts) throws UDFArgumentException { - showHelp(getOptions(), null); + showHelp(opts, null); } private void showHelp(@Nonnull Options opts, @Nullable String errMsg) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/UDTFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDTFWithOptions.java 
b/core/src/main/java/hivemall/UDTFWithOptions.java index b09cffa..43d9023 100644 --- a/core/src/main/java/hivemall/UDTFWithOptions.java +++ b/core/src/main/java/hivemall/UDTFWithOptions.java @@ -97,29 +97,46 @@ public abstract class UDTFWithOptions extends GenericUDTF { CommandLine cl = CommandLineUtils.parseOptions(args, opts); if (cl.hasOption("help")) { - Description funcDesc = getClass().getAnnotation(Description.class); - final String cmdLineSyntax; - if (funcDesc == null) { - cmdLineSyntax = getClass().getSimpleName(); - } else { - String funcName = funcDesc.name(); - cmdLineSyntax = funcName == null ? getClass().getSimpleName() - : funcDesc.value().replace("_FUNC_", funcDesc.name()); - } - StringWriter sw = new StringWriter(); - sw.write('\n'); - PrintWriter pw = new PrintWriter(sw); - HelpFormatter formatter = new HelpFormatter(); - formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, - HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); - pw.flush(); - String helpMsg = sw.toString(); - throw new UDFArgumentException(helpMsg); + showHelp(opts); } return cl; } + protected void showHelp(@Nullable String errMsg) throws UDFArgumentException { + showHelp(getOptions(), errMsg); + } + + private void showHelp(@Nonnull Options opts) throws UDFArgumentException { + showHelp(opts, null); + } + + private void showHelp(@Nonnull Options opts, @Nullable String errMsg) + throws UDFArgumentException { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + if (errMsg != null) { + sw.write(errMsg); + sw.write("\n\n"); + } + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + protected abstract CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException; @@ -149,4 +166,28 @@ public abstract class UDTFWithOptions extends GenericUDTF { return list; } + /** + * Raise {@link UDFArgumentException} if the given condition is false. + * + * @throws UDFArgumentException + */ + protected static void assumeTrue(final boolean condition, @Nonnull final String errMsg) + throws UDFArgumentException { + if (!condition) { + throw new UDFArgumentException(errMsg); + } + } + + /** + * Raise {@link UDFArgumentException} if the given condition is true. 
+ * + * @throws UDFArgumentException + */ + protected static void assumeFalse(final boolean condition, @Nonnull final String errMsg) + throws UDFArgumentException { + if (condition) { + throw new UDFArgumentException(errMsg); + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/model/IWeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/IWeightValue.java b/core/src/main/java/hivemall/model/IWeightValue.java index 731c310..76ab4f4 100644 --- a/core/src/main/java/hivemall/model/IWeightValue.java +++ b/core/src/main/java/hivemall/model/IWeightValue.java @@ -50,6 +50,10 @@ public interface IWeightValue extends Copyable<IWeightValue> { void setSumOfSquaredDeltaX(float value); + float getDelta(); + + void setDelta(float value); + float getSumOfGradients(); void setSumOfGradients(float value); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/model/WeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValue.java b/core/src/main/java/hivemall/model/WeightValue.java index 3d09a56..8653451 100644 --- a/core/src/main/java/hivemall/model/WeightValue.java +++ b/core/src/main/java/hivemall/model/WeightValue.java @@ -92,6 +92,16 @@ public class WeightValue implements IWeightValue { } @Override + public float getDelta() { + return 0.f; + } + + @Override + public void setDelta(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } @@ -202,6 +212,16 @@ public class WeightValue implements IWeightValue { this.f1 = value; } + @Override + public float getDelta() { + return f1; + } + + @Override + public void setDelta(float value) { + this.f1 = value; + } + } /** @@ -314,7 +334,7 @@ public class WeightValue implements IWeightValue { } @Override 
- public final float getSumOfSquaredGradients() { + public float getSumOfSquaredGradients() { return f1; } @@ -324,7 +344,7 @@ public class WeightValue implements IWeightValue { } @Override - public final float getSumOfSquaredDeltaX() { + public float getSumOfSquaredDeltaX() { return f2; } @@ -334,6 +354,16 @@ public class WeightValue implements IWeightValue { } @Override + public float getDelta() { + return f2; + } + + @Override + public void setDelta(float value) { + this.f2 = value; + } + + @Override public float getSumOfGradients() { return f3; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/model/WeightValueWithClock.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java index 679c519..78d2b6e 100644 --- a/core/src/main/java/hivemall/model/WeightValueWithClock.java +++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java @@ -94,6 +94,16 @@ public class WeightValueWithClock implements IWeightValue { } @Override + public float getDelta() { + return 0.f; + } + + @Override + public void setDelta(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } @@ -214,6 +224,16 @@ public class WeightValueWithClock implements IWeightValue { this.f1 = value; } + @Override + public float getDelta() { + return f1; + } + + @Override + public void setDelta(float value) { + this.f1 = value; + } + } public static final class WeightValueParamsF2Clock extends WeightValueWithClock { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java 
b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java index 5985868..2ead147 100644 --- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java @@ -18,8 +18,10 @@ */ package hivemall.optimizer; -import hivemall.model.IWeightValue; -import hivemall.model.WeightValue; +import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.model.WeightValue.WeightValueParamsF3; +import hivemall.optimizer.Optimizer.OptimizerBase; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; @@ -43,20 +45,24 @@ public final class DenseOptimizerFactory { if (optimizerName == null) { throw new IllegalArgumentException("`optimizer` not defined"); } + final String name = optimizerName.toLowerCase(); if ("rda".equalsIgnoreCase(options.get("regularization")) - && "adagrad".equalsIgnoreCase(optimizerName) == false) { + && "adagrad".equals(name) == false) { throw new IllegalArgumentException( "`-regularization rda` is only supported for AdaGrad but `-optimizer " + optimizerName + "`. Please specify `-regularization l1` and so on."); } - final Optimizer optimizerImpl; - if ("sgd".equalsIgnoreCase(optimizerName)) { + final OptimizerBase optimizerImpl; + if ("sgd".equals(name)) { optimizerImpl = new Optimizer.SGD(options); - } else if ("adadelta".equalsIgnoreCase(optimizerName)) { - optimizerImpl = new AdaDelta(ndims, options); - } else if ("adagrad".equalsIgnoreCase(optimizerName)) { + } else if ("momentum".equals(name)) { + optimizerImpl = new Momentum(ndims, options); + } else if ("nesterov".equals(name)) { + options.put("nesterov", ""); + optimizerImpl = new Momentum(ndims, options); + } else if ("adagrad".equals(name)) { // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
if ("rda".equalsIgnoreCase(options.get("regularization"))) { AdaGrad adagrad = new AdaGrad(ndims, options); @@ -64,8 +70,20 @@ public final class DenseOptimizerFactory { } else { optimizerImpl = new AdaGrad(ndims, options); } - } else if ("adam".equalsIgnoreCase(optimizerName)) { + } else if ("rmsprop".equals(name)) { + optimizerImpl = new RMSprop(ndims, options); + } else if ("rmspropgraves".equals(name) || "rmsprop_graves".equals(name)) { + optimizerImpl = new RMSpropGraves(ndims, options); + } else if ("adadelta".equals(name)) { + optimizerImpl = new AdaDelta(ndims, options); + } else if ("adam".equals(name)) { optimizerImpl = new Adam(ndims, options); + } else if ("nadam".equals(name)) { + optimizerImpl = new Nadam(ndims, options); + } else if ("eve".equals(name)) { + optimizerImpl = new Eve(ndims, options); + } else if ("adam_hd".equals(name) || "adamhd".equals(name)) { + optimizerImpl = new AdamHD(ndims, options); } else { throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); } @@ -73,40 +91,107 @@ public final class DenseOptimizerFactory { if (LOG.isInfoEnabled()) { LOG.info( "Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + options); + LOG.info("ETA estimator: " + optimizerImpl._eta); } return optimizerImpl; } @NotThreadSafe - static final class AdaDelta extends Optimizer.AdaDelta { + static final class Momentum extends Optimizer.Momentum { @Nonnull - private final IWeightValue weightValueReused; + private final WeightValueParamsF1 weightValueReused; + @Nonnull + private float[] delta; + public Momentum(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.delta = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setDelta(delta[i]); + 
update(weightValueReused, gradient); + delta[i] = weightValueReused.getDelta(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= delta.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.delta = Arrays.copyOf(delta, newSize); + } + } + + } + + @NotThreadSafe + static final class AdaGrad extends Optimizer.AdaGrad { + + @Nonnull + private final WeightValueParamsF1 weightValueReused; @Nonnull private float[] sum_of_squared_gradients; + + public AdaGrad(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.sum_of_squared_gradients = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + update(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= sum_of_squared_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + } + } + + } + + @NotThreadSafe + static final class RMSprop extends Optimizer.RMSprop { + @Nonnull - private float[] sum_of_squared_delta_x; + private final WeightValueParamsF1 weightValueReused; + @Nonnull + private float[] sum_of_squared_gradients; - public AdaDelta(int ndims, Map<String, String> options) { + public RMSprop(int ndims, Map<String, String> options) { super(options); - this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.weightValueReused = newWeightValue(0.f); this.sum_of_squared_gradients = new float[ndims]; - 
this.sum_of_squared_delta_x = new float[ndims]; } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weightValueReused.set(weight); weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); - weightValueReused.setSumOfSquaredDeltaX(sum_of_squared_delta_x[i]); update(weightValueReused, gradient); sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); - sum_of_squared_delta_x[i] = weightValueReused.getSumOfSquaredDeltaX(); return weightValueReused.get(); } @@ -115,35 +200,88 @@ public final class DenseOptimizerFactory { int bits = MathUtils.bitsRequired(index); int newSize = (1 << bits) + 1; this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); - this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); } } } @NotThreadSafe - static final class AdaGrad extends Optimizer.AdaGrad { + static final class RMSpropGraves extends Optimizer.RMSpropGraves { @Nonnull - private final IWeightValue weightValueReused; + private final WeightValueParamsF3 weightValueReused; + @Nonnull + private float[] sum_of_gradients; @Nonnull private float[] sum_of_squared_gradients; + @Nonnull + private float[] delta; - public AdaGrad(int ndims, Map<String, String> options) { + public RMSpropGraves(int ndims, Map<String, String> options) { super(options); - this.weightValueReused = new WeightValue.WeightValueParamsF1(0.f, 0.f); + this.weightValueReused = newWeightValue(0.f); + this.sum_of_gradients = new float[ndims]; this.sum_of_squared_gradients = new float[ndims]; + this.delta = new float[ndims]; } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); 
ensureCapacity(i); weightValueReused.set(weight); + weightValueReused.setSumOfGradients(sum_of_gradients[i]); weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + weightValueReused.setDelta(delta[i]); update(weightValueReused, gradient); + sum_of_gradients[i] = weightValueReused.getSumOfGradients(); sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + delta[i] = weightValueReused.getDelta(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= sum_of_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + this.delta = Arrays.copyOf(delta, newSize); + } + } + + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + @Nonnull + private final WeightValueParamsF2 weightValueReused; + + @Nonnull + private float[] sum_of_squared_gradients; + @Nonnull + private float[] sum_of_squared_delta_x; + + public AdaDelta(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.sum_of_squared_gradients = new float[ndims]; + this.sum_of_squared_delta_x = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + weightValueReused.setSumOfSquaredDeltaX(sum_of_squared_delta_x[i]); + update(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + sum_of_squared_delta_x[i] = weightValueReused.getSumOfSquaredDeltaX(); return weightValueReused.get(); } @@ -152,6 +290,7 @@ public final class DenseOptimizerFactory { int 
bits = MathUtils.bitsRequired(index); int newSize = (1 << bits) + 1; this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); } } @@ -161,7 +300,7 @@ public final class DenseOptimizerFactory { static final class Adam extends Optimizer.Adam { @Nonnull - private final IWeightValue weightValueReused; + private final WeightValueParamsF2 weightValueReused; @Nonnull private float[] val_m; @@ -170,13 +309,13 @@ public final class DenseOptimizerFactory { public Adam(int ndims, Map<String, String> options) { super(options); - this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.weightValueReused = newWeightValue(0.f); this.val_m = new float[ndims]; this.val_v = new float[ndims]; } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); @@ -201,10 +340,139 @@ public final class DenseOptimizerFactory { } @NotThreadSafe + static final class Nadam extends Optimizer.Nadam { + + @Nonnull + private final WeightValueParamsF2 weightValueReused; + + @Nonnull + private float[] val_m; + @Nonnull + private float[] val_v; + + public Nadam(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.val_m = new float[ndims]; + this.val_v = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setM(val_m[i]); + weightValueReused.setV(val_v[i]); + update(weightValueReused, gradient); + val_m[i] = weightValueReused.getM(); + val_v[i] = weightValueReused.getV(); + return weightValueReused.get(); + } + + private void 
ensureCapacity(final int index) { + if (index >= val_m.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.val_m = Arrays.copyOf(val_m, newSize); + this.val_v = Arrays.copyOf(val_v, newSize); + } + } + + } + + @NotThreadSafe + static final class Eve extends Optimizer.Eve { + + @Nonnull + private final WeightValueParamsF2 weightValueReused; + + @Nonnull + private float[] val_m; + @Nonnull + private float[] val_v; + + public Eve(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.val_m = new float[ndims]; + this.val_v = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setM(val_m[i]); + weightValueReused.setV(val_v[i]); + update(weightValueReused, gradient); + val_m[i] = weightValueReused.getM(); + val_v[i] = weightValueReused.getV(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= val_m.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.val_m = Arrays.copyOf(val_m, newSize); + this.val_v = Arrays.copyOf(val_v, newSize); + } + } + + } + + + @NotThreadSafe + static final class AdamHD extends Optimizer.AdamHD { + + @Nonnull + private final WeightValueParamsF2 weightValueReused; + + @Nonnull + private float[] val_m; + @Nonnull + private float[] val_v; + + public AdamHD(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.val_m = new float[ndims]; + this.val_v = new float[ndims]; + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + 
weightValueReused.setM(val_m[i]); + weightValueReused.setV(val_v[i]); + update(weightValueReused, gradient); + val_m[i] = weightValueReused.getM(); + val_v[i] = weightValueReused.getV(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= val_m.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.val_m = Arrays.copyOf(val_m, newSize); + this.val_v = Arrays.copyOf(val_v, newSize); + } + } + } + + @NotThreadSafe static final class AdagradRDA extends Optimizer.AdagradRDA { @Nonnull - private final IWeightValue weightValueReused; + private final WeightValueParamsF2 weightValueReused; @Nonnull private float[] sum_of_gradients; @@ -212,12 +480,12 @@ public final class DenseOptimizerFactory { public AdagradRDA(int ndims, @Nonnull Optimizer.AdaGrad optimizerImpl, @Nonnull Map<String, String> options) { super(optimizerImpl, options); - this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f); + this.weightValueReused = newWeightValue(0.f); this.sum_of_gradients = new float[ndims]; } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/optimizer/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java index a8847cd..32693bb 100644 --- a/core/src/main/java/hivemall/optimizer/EtaEstimator.java +++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java @@ -20,6 +20,7 @@ package hivemall.optimizer; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.StringUtils; import 
java.util.Map; @@ -57,6 +58,11 @@ public abstract class EtaEstimator { return eta0; } + @Override + public String toString() { + return "FixedEtaEstimator [ eta0 = " + eta0 + " ]"; + } + } public static final class SimpleEtaEstimator extends EtaEstimator { @@ -78,6 +84,12 @@ public abstract class EtaEstimator { return (float) (eta0 / (1.d + (t / total_steps))); } + @Override + public String toString() { + return "SimpleEtaEstimator [ eta0 = " + eta0 + ", totalSteps = " + total_steps + + ", finalEta = " + finalEta + " ]"; + } + } public static final class InvscalingEtaEstimator extends EtaEstimator { @@ -94,6 +106,11 @@ public abstract class EtaEstimator { return (float) (eta0 / Math.pow(t, power_t)); } + @Override + public String toString() { + return "InvscalingEtaEstimator [ eta0 = " + eta0 + ", power_t = " + power_t + " ]"; + } + } /** @@ -124,6 +141,11 @@ public abstract class EtaEstimator { this.eta = Math.min(eta0, newEta); // never be larger than eta0 } + @Override + public String toString() { + return "AdjustingEtaEstimator [ eta0 = " + eta0 + ", eta = " + eta + " ]"; + } + } @Nonnull @@ -184,6 +206,10 @@ public abstract class EtaEstimator { } else if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)) { return new InvscalingEtaEstimator(eta0, power_t); } else { + if (StringUtils.isNumber(etaScheme)) { + float eta = Float.parseFloat(etaScheme); + return new FixedEtaEstimator(eta); + } throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index 587adf2..f096f88 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -18,9 +18,20 
@@ */ package hivemall.optimizer; +import static hivemall.utils.math.MathUtils.square; +import static java.lang.Math.abs; +import static java.lang.Math.floor; +import static java.lang.Math.min; +import static java.lang.Math.pow; +import static java.lang.Math.sqrt; + import hivemall.model.IWeightValue; import hivemall.model.WeightValue; +import hivemall.model.WeightValue.WeightValueParamsF1; +import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.model.WeightValue.WeightValueParamsF3; import hivemall.utils.lang.Primitives; +import hivemall.utils.math.MathUtils; import java.util.Map; @@ -33,7 +44,7 @@ public interface Optimizer { /** * Update the weights of models */ - float update(@Nonnull Object feature, float weight, float gradient); + float update(@Nonnull Object feature, float weight, float loss, float gradient); /** * Count up #step to tune learning rate @@ -51,18 +62,36 @@ public interface Optimizer { @Nonnull protected final Regularization _reg; @Nonnegative - protected long _numStep = 1L; + protected long _numStep = 0L; public OptimizerBase(@Nonnull Map<String, String> options) { - this._eta = EtaEstimator.get(options); + this._eta = getEtaEstimator(options); this._reg = Regularization.get(options); } + @Nonnull + protected abstract IWeightValue newWeightValue(final float weight); + + @Nonnull + protected EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) { + return EtaEstimator.get(options); + } + @Override public void proceedStep() { _numStep++; } + @Override + public float update(@Nonnull Object feature, float weight, float loss, float gradient) { + return update(feature, weight, gradient); + } + + /** + * Update the weights of models + */ + protected abstract float update(@Nonnull Object feature, float weight, float gradient); + /** * Update the given weight by the given gradient. 
* @@ -71,7 +100,7 @@ public interface Optimizer { protected float update(@Nonnull final IWeightValue weight, final float gradient) { float oldWeight = weight.get(); float delta = computeDelta(weight, gradient); - float eta = _eta.eta(_numStep); + float eta = eta(_numStep); float reg = _reg.regularize(oldWeight, delta); float newWeight = oldWeight - eta * reg; weight.set(newWeight); @@ -79,6 +108,14 @@ public interface Optimizer { } /** + * @param t time step + * @return learning rate + */ + protected float eta(final long t) { + return _eta.eta(_numStep); + } + + /** * Compute a delta to update */ protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { @@ -93,11 +130,16 @@ public interface Optimizer { public SGD(@Nonnull Map<String, String> options) { super(options); - this.weightValueReused = new WeightValue(0.f); + this.weightValueReused = newWeightValue(0.f); } @Override - public float update(@Nonnull final Object feature, final float weight, + protected WeightValue newWeightValue(final float weight) { + return new WeightValue(weight); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { weightValueReused.set(weight); update(weightValueReused, gradient); @@ -111,6 +153,53 @@ public interface Optimizer { } + /** + * Momentum and Nesterov's Accelerated Gradient. 
+ * + * https://arxiv.org/abs/1212.0901 + */ + static abstract class Momentum extends OptimizerBase { + + @Nonnull + private final WeightValueParamsF1 weightValueReused; + + private final boolean nesterov; + private final float alpha; + private final float momentum; + + public Momentum(@Nonnull Map<String, String> options) { + super(options); + this.weightValueReused = newWeightValue(0.f); + this.nesterov = options.containsKey("nesterov"); + this.alpha = Primitives.parseFloat(options.get("alpha"), 1.f); + this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f); + } + + @Override + protected WeightValueParamsF1 newWeightValue(final float weight) { + return new WeightValueParamsF1(weight, 0.f); + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { + final float oldDelta = weight.getDelta(); + final float v = momentum * oldDelta + alpha * gradient; + weight.setDelta(v); + if (nesterov) { + //return momentum * momentum * oldDelta + (1.f + momentum) * alpha * gradient; + return momentum * momentum * v + (1.f + momentum) * alpha * gradient; + } else { + return v; // normal momentum + } + } + + @Override + public String getOptimizerName() { + return nesterov ? 
"nesterov" : "momentum"; + } + + } + static abstract class AdaGrad extends OptimizerBase { private final float eps; @@ -123,11 +212,16 @@ public interface Optimizer { } @Override + protected WeightValueParamsF1 newWeightValue(final float weight) { + return new WeightValueParamsF1(weight, 0.f); + } + + @Override protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { float old_scaled_gg = weight.getSumOfSquaredGradients(); float new_scaled_gg = old_scaled_gg + gradient * (gradient / scale); weight.setSumOfSquaredGradients(new_scaled_gg); - return (float) (gradient / Math.sqrt(eps + ((double) old_scaled_gg) * scale)); + return (float) (gradient / sqrt(eps + ((double) old_scaled_gg) * scale)); } @Override @@ -137,6 +231,106 @@ public interface Optimizer { } + + /** + * RMSprop optimizer introducing weight decay to AdaGrad. + * + * Geoffrey Hinton, Nitish Srivastava, Kevin Swersky. 2014. "Lecture 6e: Rmsprop: Divide the + * gradient by a running average of its recent magnitude" + * + * @see http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + */ + static abstract class RMSprop extends OptimizerBase { + + /** decay rate */ + private final float decay; + /** constant for numerical stability */ + private final float eps; + + private final float scale; // to hold g*g in float range + + public RMSprop(@Nonnull Map<String, String> options) { + super(options); + this.decay = Primitives.parseFloat(options.get("decay"), 0.95f); + this.eps = Primitives.parseFloat(options.get("eps"), 1.0f); + this.scale = Primitives.parseFloat(options.get("scale"), 100.0f); + } + + @Override + protected WeightValueParamsF1 newWeightValue(final float weight) { + return new WeightValueParamsF1(weight, 0.f); + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { + float old_scaled_gg = weight.getSumOfSquaredGradients(); + float new_scaled_gg = + decay * old_scaled_gg + (1.f - decay) * gradient 
* (gradient / scale); + weight.setSumOfSquaredGradients(new_scaled_gg); + return (float) (gradient / sqrt(eps + ((double) old_scaled_gg) * scale)); + } + + @Override + public String getOptimizerName() { + return "rmsprop"; + } + + } + + /** + * Alex Graves's RMSprop introducing weight decay and momentum. + * + * @see https://arxiv.org/abs/1308.0850 + */ + static abstract class RMSpropGraves extends OptimizerBase { + + /** decay rate */ + private final float decay; + private final float alpha; + private final float momentum; + /** constant for numerical stability */ + private final float eps; + + private final float scale; // to hold g*g in float range + + public RMSpropGraves(@Nonnull Map<String, String> options) { + super(options); + this.decay = Primitives.parseFloat(options.get("decay"), 0.95f); + this.alpha = Primitives.parseFloat(options.get("alpha"), 1.f); + this.momentum = Primitives.parseFloat(options.get("momentum"), 0.9f); + this.eps = Primitives.parseFloat(options.get("eps"), 1.0f); + this.scale = Primitives.parseFloat(options.get("scale"), 100.0f); + } + + @Override + protected WeightValueParamsF3 newWeightValue(final float weight) { + return new WeightValueParamsF3(weight, 0.f, 0.f, 0.f); + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { + float old_scaled_n = weight.getSumOfSquaredGradients(); + float new_scaled_n = + decay * old_scaled_n + (1.f - decay) * gradient * (gradient / scale); + weight.setSumOfSquaredGradients(new_scaled_n); + float old_scaled_g = weight.getSumOfGradients(); + float new_scaled_g = decay * old_scaled_g + (1.f - decay) * gradient / scale; + weight.setSumOfGradients(new_scaled_g); + double n = ((double) old_scaled_n) * scale; + double g = ((double) new_scaled_g) * scale; + float oldDelta = weight.getDelta(); + float delta = momentum * oldDelta + alpha * (float) (gradient / sqrt(n - g * g + eps)); + weight.setDelta(delta); + return delta; + } + + @Override + public 
String getOptimizerName() { + return "rmsprop_graves"; + } + + } + static abstract class AdaDelta extends OptimizerBase { private final float decay; @@ -151,12 +345,29 @@ public interface Optimizer { } @Override + protected WeightValueParamsF2 newWeightValue(final float weight) { + return new WeightValueParamsF2(weight, 0.f, 0.f); + } + + @Override + protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) { + // override default learning rate scheme + if (!options.containsKey("eta")) { + options.put("eta", "fixed"); + } + if (!options.containsKey("eta0")) { + options.put("eta0", "1.0"); + } + return super.getEtaEstimator(options); + } + + @Override protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients(); float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX(); float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * gradient * (gradient / scale)); - float delta = (float) Math.sqrt( + float delta = (float) sqrt( (old_sum_squared_delta_x + eps) / ((double) new_scaled_sum_sqgrad * scale + eps)) * gradient; float new_sum_squared_delta_x = @@ -177,37 +388,329 @@ public interface Optimizer { * Adam, an algorithm for first-order gradient-based optimization of stochastic objective * functions, based on adaptive estimates of lower-order moments. * - * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." arXiv preprint - * arXiv:1412.6980v8, 2014. + * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." 
+ * https://arxiv.org/abs/1412.6980v8 + * + * - "Fixing Weight Decay Regularization in Adam" https://arxiv.org/pdf/1711.05101.pdf + * + * - "On the Convergence of Adam and Beyond" https://openreview.net/forum?id=ryQu7f-RZ */ static abstract class Adam extends OptimizerBase { - private final float beta; - private final float gamma; - private final float eps_hat; + protected float alpha; + protected final float beta1, beta2; + protected final float eps; + protected final float decay; + + // amsgrad + protected final boolean amsgrad; + protected float max_vhat = Float.MIN_VALUE; public Adam(@Nonnull Map<String, String> options) { super(options); - this.beta = Primitives.parseFloat(options.get("beta"), 0.9f); - this.gamma = Primitives.parseFloat(options.get("gamma"), 0.999f); - this.eps_hat = Primitives.parseFloat(options.get("eps_hat"), 1e-8f); + //this.alpha = Primitives.parseFloat(options.get("alpha"), 0.001f); + this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f); + this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f); + this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f); + this.eps = Primitives.parseFloat(options.get("eps"), 1e-8f); + this.decay = Primitives.parseFloat(options.get("decay"), 0.f); + this.amsgrad = options.containsKey("amsgrad"); } @Override - protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { - float val_m = beta * weight.getM() + (1.f - beta) * gradient; - float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0)); - float val_m_hat = val_m / (float) (1.f - Math.pow(beta, _numStep)); - float val_v_hat = val_v / (float) (1.f - Math.pow(gamma, _numStep)); - float delta = val_m_hat / (float) (Math.sqrt(val_v_hat) + eps_hat); - weight.setM(val_m); - weight.setV(val_v); + protected WeightValueParamsF2 newWeightValue(final float weight) { + return new WeightValueParamsF2(weight, 0.f, 0.f); + } + + @Override + protected float eta(final long t) { + double 
fix1 = 1.d - pow(beta1, t); + double fix2 = 1.d - pow(beta2, t); + float eta = _eta.eta(t); + double fix = sqrt(fix2) / fix1; + return (float) (eta * fix); + } + + protected double alpha() { + double fix1 = 1.d - pow(beta1, _numStep); + double fix2 = 1.d - pow(beta2, _numStep); + double fix = sqrt(fix2) / fix1; + return alpha * fix; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + if (decay != 0.f) {// L2 regularization for weight decay + float oldWeight = weight.get(); + gradient += decay * oldWeight; + } + // update biased first moment estimate + float m = beta1 * weight.getM() + (1.f - beta1) * gradient; + // update biased second raw moment estimate + float v = beta2 * weight.getV() + (float) ((1.f - beta2) * square(gradient)); + float v_hat = v; + if (amsgrad) { + if (v_hat > max_vhat) { + this.max_vhat = v_hat; + } else {// v_hat <= max_vhat + v_hat = max_vhat; + } + } + // bias correlation using v_hat and m_hat + double deltaU = m / (sqrt(v_hat) + eps); + // compute delta update + double alpha_t = alpha(); + float delta = (float) (alpha_t * deltaU); + // weight decay + if (decay != 0.f) { + float oldWeight = weight.get(); + delta += decay * oldWeight; + } + weight.setM(m); + weight.setV(v); return delta; } @Override public String getOptimizerName() { - return "adam"; + return amsgrad ? "adam-amsgrad" : "adam"; + } + + } + + /** + * Nadam is Adam optimizer with Nesterov momentum. 
+ * + * @see https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ + * @see http://cs229.stanford.edu/proj2015/054_report.pdf + * @see http://www.cs.toronto.edu/~fritz/absps/momentum.pdf + */ + static abstract class Nadam extends OptimizerBase { + + protected float alpha; + protected final float beta1, beta2; + protected final float eps; + protected final float decay; + protected final float scheduleDecay; + + protected double mu_t, mu_t_1; + protected double mu_product = 1.d; + protected double mu_product_next = 1.d; + + public Nadam(@Nonnull Map<String, String> options) { + super(options); + //this.alpha = Primitives.parseFloat(options.get("alpha"), 0.001f); + this.alpha = Primitives.parseFloat(options.get("alpha"), 1.0f); + this.beta1 = Primitives.parseFloat(options.get("beta1"), 0.9f); + this.beta2 = Primitives.parseFloat(options.get("beta2"), 0.999f); + this.eps = Primitives.parseFloat(options.get("eps"), 1e-8f); + this.decay = Primitives.parseFloat(options.get("decay"), 0.f); + this.scheduleDecay = Primitives.parseFloat(options.get("scheduleDecay"), 0.004f); // 1/250=0.004 + } + + @Override + protected WeightValueParamsF2 newWeightValue(final float weight) { + return new WeightValueParamsF2(weight, 0.f, 0.f); + } + + @Override + public void proceedStep() { + long t = _numStep + 1; + this._numStep = t; + double mu_product_prev = this.mu_product; + // 0.9 * (1 - 0.5 * 0.96^(floor(t/250)+1)) + double mu_t = beta1 * (1.d - 0.5d * pow(0.96d, floor(t * scheduleDecay) + 1.d)); + double mu_t_1 = + beta1 * (1.d - 0.5d * pow(0.96d, floor((t + 1.d) * scheduleDecay) + 1.d)); + this.mu_t = mu_t; + this.mu_t_1 = mu_t_1; + this.mu_product = mu_product_prev * mu_t; + this.mu_product_next = mu_product_prev * mu_t * mu_t_1; + } + + @Override + protected float eta(final long t) { + double fix1 = 1.d - pow(beta1, t); + double fix2 = 1.d - pow(beta2, t); + float eta = _eta.eta(t); + double fix = sqrt(fix2) / fix1; + return (float) (eta * fix); + } + + protected double alpha() { + 
double fix1 = 1.d - pow(beta1, _numStep); + double fix2 = 1.d - pow(beta2, _numStep); + double fix = sqrt(fix2) / fix1; + return alpha * fix; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + if (decay != 0.f) {// L2 regularization for weight decay + float oldWeight = weight.get(); + gradient += decay * oldWeight; + } + // update biased first moment estimate + float m = beta1 * weight.getM() + (1.f - beta1) * gradient; + double m_hat = m / (1.d - mu_product_next); + // update biased second raw moment estimate + float v = beta2 * weight.getV() + (float) ((1.d - beta2) * square(gradient)); + double v_hat = v / (1.d - pow(beta2, _numStep)); + // gradient update for the current timestamp + double g_hat = gradient / (1.d - mu_product); + double m_bar = (1.d - mu_t) * g_hat + mu_t_1 * m_hat; + // bias correlation using v_hat and m_hat + double deltaU = m_bar / (sqrt(v_hat) + eps); + // compute delta update + double alpha_t = alpha(); + float delta = (float) (alpha_t * deltaU); + // weight decay + if (decay != 0.d) { + float oldWeight = weight.get(); + delta += decay * oldWeight; + } + weight.setM(m); + weight.setV(v); + return delta; + } + + @Override + public String getOptimizerName() { + return "nadam"; + } + + } + + /** + * Eve optimizer. 
+ * + * - "Eve: A Gradient Based Optimization Method with Locally and Globally Adaptive Learning + * Rates" https://openreview.net/forum?id=r1WUqIceg + */ + static abstract class Eve extends Adam { + + protected final float beta3; + private float c = 10.f; + private float inv_c = 0.1f; + + private float currLoss; + private float prevLoss = 0.f; + private double prevDt = 1.d; + + public Eve(@Nonnull Map<String, String> options) { + super(options); + this.beta3 = Primitives.parseFloat(options.get("beta3"), 0.999f); + this.c = Primitives.parseFloat(options.get("c"), 10f); + this.inv_c = 1f / c; + } + + @Override + protected double alpha() { + double fix1 = 1.d - pow(beta1, _numStep); + double fix2 = 1.d - pow(beta2, _numStep); + double fix = sqrt(fix2) / fix1; + double alpha_t = alpha * fix; + // feedback of Eve + if (_numStep > 1 && currLoss != prevLoss) { + double d = abs(currLoss - prevLoss) / min(currLoss, prevLoss); + d = MathUtils.clip(d, inv_c, c); // [alpha/c, c*alpha] + d = (beta3 * prevDt) + (1.d - beta3) * d; + this.prevDt = d; + alpha_t = alpha_t / d; + } + return alpha_t; + } + + @Override + public float update(Object feature, float weight, float loss, float gradient) { + this.currLoss = loss; + float delta = update(feature, weight, gradient); + this.prevLoss = loss; + return delta; + } + + @Override + public String getOptimizerName() { + return "eve"; + } + + } + + /** + * Adam optimizer with Hypergradient Descent. 
+ * + * - Online Learning Rate Adaptation with Hypergradient Descent + * https://openreview.net/forum?id=BkrsAzWAb + * + * - Convergence Analysis of an Adaptive Method of Gradient Descent + * https://damaru2.github.io/convergence_analysis_hypergradient_descent/dissertation_hypergradients.pdf + */ + static abstract class AdamHD extends Adam { + + private final float beta; + protected double deltaU = 0.d; + + public AdamHD(@Nonnull Map<String, String> options) { + super(options); + this.alpha = Primitives.parseFloat(options.get("alpha"), 0.02f); + this.beta = Primitives.parseFloat(options.get("beta"), 1e-6f); + } + + @Override + protected final EtaEstimator getEtaEstimator(@Nonnull Map<String, String> options) { + // override default learning rate scheme + if (!options.containsKey("eta")) { + options.put("eta", "fixed"); + } + if (!options.containsKey("eta0")) { + options.put("eta0", "1.0"); + } + return super.getEtaEstimator(options); + } + + private float alpha(final float gradient, final double deltaU) { + // multiplicative hypergradient descent + final double h = gradient * deltaU; + if (h > 0) {// g_{t-1}u_{t-2} > 0 + this.alpha = alpha * (1.f - beta); // decrease alpha + } else if (h < 0) {// g_{t-1}u_{t-2} < 0 + this.alpha = alpha * (1.f + beta); // increase alpha + } + return alpha; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + if (decay != 0.f) {// L2 regularization for weight decay + float oldWeight = weight.get(); + gradient += decay * oldWeight; + } + // update biased first moment estimate + float m = beta1 * weight.getM() + (1.f - beta1) * gradient; + // update biased second raw moment estimate + float v = beta2 * weight.getV() + (float) ((1.f - beta2) * square(gradient)); + // compute bias-corrected first moment estimate + double m_hat = m / (1.d - pow(beta1, _numStep)); + // compute bias-corrected second raw moment estimate + double v_hat = v / (1.d - pow(beta2, _numStep)); + // compute delta 
update + float alpha_t = alpha(gradient, deltaU); + double deltaU = m_hat / (sqrt(v_hat) + eps); + float delta = (float) (alpha_t * deltaU); + this.deltaU = deltaU; + // weight decay + if (decay != 0.f) { + float oldWeight = weight.get(); + delta += decay * oldWeight; + } + weight.setM(m); + weight.setV(v); + return delta; + } + + @Override + public String getOptimizerName() { + return "adam_hd"; } } @@ -225,6 +728,11 @@ public interface Optimizer { } @Override + protected WeightValueParamsF2 newWeightValue(final float weight) { + return new WeightValueParamsF2(weight, 0.f, 0.f); + } + + @Override protected float update(@Nonnull final IWeightValue weight, final float gradient) { final float new_sum_grad = weight.getSumOfGradients() + gradient; // sign(u_{t,i}) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/optimizer/OptimizerOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java index 2fc838d..fc4772f 100644 --- a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java +++ b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java @@ -41,10 +41,13 @@ public final class OptimizerOptions { } public static void setup(@Nonnull Options opts) { - opts.addOption("opt", "optimizer", true, - "Optimizer to update weights [default: adagrad, sgd, adadelta, adam]"); - opts.addOption("eps", true, "Denominator value of AdaDelta/AdaGrad [default 1e-6]"); - opts.addOption("rho", "decay", true, "Decay rate of AdaDelta [default 0.95]"); + opts.addOption("opt", "optimizer", true, "Optimizer to update weights " + + "[default: adagrad, sgd, momentum, nesterov, rmsprop, rmspropgraves, adadelta, adam, eve, adam_hd]"); + // hyperparameters + opts.addOption("eps", true, + "Denominator value of AdaDelta/AdaGrad/Adam [default: 1e-8 (AdaDelta/Adam), 1.0 (Adagrad)]"); + 
opts.addOption("rho", "decay", true, + " Exponential decay rate of the first and second order moments [default 0.95 (AdaDelta, rmsprop)]"); // regularization opts.addOption("reg", "regularization", true, "Regularization type [default: rda, l1, l2, elasticnet]"); @@ -57,6 +60,21 @@ public final class OptimizerOptions { opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps"); opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]"); + opts.addOption("alpha", true, + "Coefficient of learning rate [default: 1.0 (adam/RMSPropGraves), 0.02 (AdamHD/Nesterov)]"); + // ADAM hyperparameters + opts.addOption("beta1", "momentum", true, + "Exponential decay rate of the first order moment used in Adam [default: 0.9]"); + opts.addOption("beta2", true, + "Exponential decay rate of the second order moment used in Adam [default: 0.999]"); + opts.addOption("decay", false, "Weight decay rate [default: 0.0]"); + opts.addOption("amsgrad", false, "Whether to use AMSGrad variant of Adam"); + // ADAM-HD hyperparameters + opts.addOption("beta", true, "Hyperparameter for tuning alpha in Adam-HD [default: 1e-6f]"); + // Eve hyperparameters + opts.addOption("beta3", true, "Exponential decay rate of alpha value [default: 0.999]"); + opts.addOption("c", true, + "Clipping constant of alpha used in Eve optimizer so that clipped [default: 10]"); // other opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]"); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java index 1254740..d6d62fe 100644 --- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java +++ 
b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -19,7 +19,7 @@ package hivemall.optimizer; import hivemall.model.IWeightValue; -import hivemall.model.WeightValue; +import hivemall.optimizer.Optimizer.OptimizerBase; import it.unimi.dsi.fastutil.objects.Object2ObjectMap; import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; @@ -42,20 +42,24 @@ public final class SparseOptimizerFactory { if (optimizerName == null) { throw new IllegalArgumentException("`optimizer` not defined"); } + final String name = optimizerName.toLowerCase(); if ("rda".equalsIgnoreCase(options.get("regularization")) - && "adagrad".equalsIgnoreCase(optimizerName) == false) { + && "adagrad".equals(name) == false) { throw new IllegalArgumentException( "`-regularization rda` is only supported for AdaGrad but `-optimizer " - + optimizerName); + + optimizerName + "`. Please specify `-regularization l1` and so on."); } - final Optimizer optimizerImpl; - if ("sgd".equalsIgnoreCase(optimizerName)) { + final OptimizerBase optimizerImpl; + if ("sgd".equals(name)) { optimizerImpl = new Optimizer.SGD(options); - } else if ("adadelta".equalsIgnoreCase(optimizerName)) { - optimizerImpl = new AdaDelta(ndims, options); - } else if ("adagrad".equalsIgnoreCase(optimizerName)) { + } else if ("momentum".equals(name)) { + optimizerImpl = new Momentum(ndims, options); + } else if ("nesterov".equals(name)) { + options.put("nesterov", ""); + optimizerImpl = new Momentum(ndims, options); + } else if ("adagrad".equals(name)) { // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
if ("rda".equalsIgnoreCase(options.get("regularization"))) { AdaGrad adagrad = new AdaGrad(ndims, options); @@ -63,8 +67,20 @@ public final class SparseOptimizerFactory { } else { optimizerImpl = new AdaGrad(ndims, options); } - } else if ("adam".equalsIgnoreCase(optimizerName)) { + } else if ("rmsprop".equals(name)) { + optimizerImpl = new RMSprop(ndims, options); + } else if ("rmspropgraves".equals(name) || "rmsprop_graves".equals(name)) { + optimizerImpl = new RMSpropGraves(ndims, options); + } else if ("adadelta".equals(name)) { + optimizerImpl = new AdaDelta(ndims, options); + } else if ("adam".equals(name)) { optimizerImpl = new Adam(ndims, options); + } else if ("nadam".equals(name)) { + optimizerImpl = new Nadam(ndims, options); + } else if ("eve".equals(name)) { + optimizerImpl = new Eve(ndims, options); + } else if ("adam_hd".equals(name) || "adamhd".equals(name)) { + optimizerImpl = new AdamHD(ndims, options); } else { throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); } @@ -72,28 +88,29 @@ public final class SparseOptimizerFactory { if (LOG.isInfoEnabled()) { LOG.info( "Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + options); + LOG.info("ETA estimator: " + optimizerImpl._eta); } return optimizerImpl; } @NotThreadSafe - static final class AdaDelta extends Optimizer.AdaDelta { + static final class Momentum extends Optimizer.Momentum { @Nonnull private final Object2ObjectMap<Object, IWeightValue> auxWeights; - public AdaDelta(@Nonnegative int size, @Nonnull Map<String, String> options) { + public Momentum(@Nonnegative int size, @Nonnull Map<String, String> options) { super(options); this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { IWeightValue auxWeight = auxWeights.get(feature); if 
(auxWeight == null) { - auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeight = newWeightValue(weight); auxWeights.put(feature, auxWeight); } else { auxWeight.set(weight); @@ -115,11 +132,89 @@ public final class SparseOptimizerFactory { } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, gradient); + } + + } + + @NotThreadSafe + static final class RMSprop extends Optimizer.RMSprop { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public RMSprop(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, gradient); + } + + } + + @NotThreadSafe + static final class RMSpropGraves extends Optimizer.RMSpropGraves { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public RMSpropGraves(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + 
auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, gradient); + } + + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public AdaDelta(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { IWeightValue auxWeight = auxWeights.get(feature); if (auxWeight == null) { - auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeight = newWeightValue(weight); auxWeights.put(feature, auxWeight); } else { auxWeight.set(weight); @@ -141,11 +236,89 @@ public final class SparseOptimizerFactory { } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, gradient); + } + + } + + @NotThreadSafe + static final class Nadam extends Optimizer.Nadam { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public Nadam(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, 
gradient); + } + + } + + @NotThreadSafe + static final class Eve extends Optimizer.Eve { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public Eve(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, + final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { + auxWeight = newWeightValue(weight); + auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); + } + return update(auxWeight, gradient); + } + + } + + @NotThreadSafe + static final class AdamHD extends Optimizer.AdamHD { + + @Nonnull + private final Object2ObjectMap<Object, IWeightValue> auxWeights; + + public AdamHD(@Nonnegative int size, @Nonnull Map<String, String> options) { + super(options); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); + } + + @Override + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { IWeightValue auxWeight = auxWeights.get(feature); if (auxWeight == null) { - auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeight = newWeightValue(weight); auxWeights.put(feature, auxWeight); } else { auxWeight.set(weight); @@ -168,11 +341,11 @@ public final class SparseOptimizerFactory { } @Override - public float update(@Nonnull final Object feature, final float weight, + protected float update(@Nonnull final Object feature, final float weight, final float gradient) { IWeightValue auxWeight = auxWeights.get(feature); if (auxWeight == null) { - auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); + auxWeight = newWeightValue(weight); auxWeights.put(feature, auxWeight); } else { auxWeight.set(weight); 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 5b4a88a..e3f32f5 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -229,6 +229,10 @@ public final class MathUtils { return v < 0.f ? -1 : 1; } + public static double square(final double d) { + return d * d; + } + public static double log(final double n, final int base) { return Math.log(n) / Math.log(base); } @@ -429,4 +433,12 @@ public final class MathUtils { return arr; } + public static float clip(final float v, final float min, final float max) { + return Math.max(Math.min(v, max), min); + } + + public static double clip(final double v, final double min, final double max) { + return Math.max(Math.min(v, max), min); + } + }
