Add optimizer implementations
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f81948c5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f81948c5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f81948c5 Branch: refs/heads/JIRA-22/pr-285 Commit: f81948c5c7b83155eb29369a59f1fc65bb607f91 Parents: 5a7df55 Author: Takeshi YAMAMURO <[email protected]> Authored: Mon May 2 23:43:42 2016 +0900 Committer: Takeshi YAMAMURO <[email protected]> Committed: Wed Sep 21 00:07:28 2016 +0900 ---------------------------------------------------------------------- .../src/main/java/hivemall/LearnerBaseUDTF.java | 22 + .../hivemall/classifier/AROWClassifierUDTF.java | 2 +- .../hivemall/classifier/AdaGradRDAUDTF.java | 123 +---- .../classifier/BinaryOnlineClassifierUDTF.java | 3 + .../classifier/GeneralClassifierUDTF.java | 121 +++++ .../classifier/PassiveAggressiveUDTF.java | 2 +- .../main/java/hivemall/common/EtaEstimator.java | 160 ------- .../java/hivemall/common/LossFunctions.java | 467 ------------------- .../java/hivemall/fm/FMHyperParameters.java | 2 +- .../hivemall/fm/FactorizationMachineModel.java | 2 +- .../hivemall/fm/FactorizationMachineUDTF.java | 8 +- .../fm/FieldAwareFactorizationMachineModel.java | 1 + .../hivemall/mf/BPRMatrixFactorizationUDTF.java | 2 +- .../hivemall/mf/MatrixFactorizationSGDUDTF.java | 2 +- .../main/java/hivemall/model/DenseModel.java | 87 +--- .../main/java/hivemall/model/IWeightValue.java | 16 +- .../java/hivemall/model/PredictionModel.java | 5 +- .../model/SpaceEfficientDenseModel.java | 93 +--- .../main/java/hivemall/model/SparseModel.java | 20 +- .../model/SynchronizedModelWrapper.java | 16 +- .../main/java/hivemall/model/WeightValue.java | 162 ++++++- .../hivemall/model/WeightValueWithClock.java | 167 ++++++- .../optimizer/DenseOptimizerFactory.java | 215 +++++++++ .../java/hivemall/optimizer/EtaEstimator.java | 191 ++++++++ .../java/hivemall/optimizer/LossFunctions.java | 467 +++++++++++++++++++ .../main/java/hivemall/optimizer/Optimizer.java | 246 ++++++++++ .../java/hivemall/optimizer/Regularization.java | 99 ++++ .../optimizer/SparseOptimizerFactory.java | 171 +++++++ .../hivemall/regression/AROWRegressionUDTF.java | 2 +- .../java/hivemall/regression/AdaDeltaUDTF.java | 117 +---- .../java/hivemall/regression/AdaGradUDTF.java | 118 +---- .../regression/GeneralRegressionUDTF.java | 125 +++++ .../java/hivemall/regression/LogressUDTF.java | 63 +-- .../PassiveAggressiveRegressionUDTF.java | 2 +- .../hivemall/regression/RegressionBaseUDTF.java | 14 +- .../java/hivemall/optimizer/OptimizerTest.java | 172 +++++++ .../java/hivemall/mix/server/MixServerTest.java | 14 +- resources/ddl/define-all-as-permanent.hive | 13 +- resources/ddl/define-all.hive | 12 +- 39 files changed, 2301 insertions(+), 1223 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index 4518cce..7fd5190 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -28,6 +28,9 @@ import hivemall.model.SparseModel; import hivemall.model.SynchronizedModelWrapper; import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.DenseOptimizerFactory; +import hivemall.optimizer.Optimizer; +import hivemall.optimizer.SparseOptimizerFactory; import hivemall.utils.datetime.StopWatch; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; @@ -38,6 +41,7 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.util.List; +import java.util.Map; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -195,6 +199,24 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { return model; } + // If a model implements a optimizer, it must override this + protected Map<String, String> getOptimzierOptions() { + return null; + } + + protected Optimizer createOptimizer() { + assert(!useCovariance()); + final Map<String, String> options = getOptimzierOptions(); + if(options != null) { + if (dense_model) { + return DenseOptimizerFactory.create(model_dims, options); + } else { + return SparseOptimizerFactory.create(model_dims, options); + } + } + return null; + } + protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) { assert (connectURIs != null); assert (model != null); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java index e5ef975..ac8afcb 100644 --- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; +import hivemall.optimizer.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionResult; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java index 1351bca..a6714f4 100644 --- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java +++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java @@ -18,124 +18,13 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; -import hivemall.model.FeatureValue; -import hivemall.model.IWeightValue; -import hivemall.model.WeightValue.WeightValueParamsF2; -import hivemall.utils.lang.Primitives; +@Deprecated +public final class AdaGradRDAUDTF extends GeneralClassifierUDTF { -import javax.annotation.Nonnull; - -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Options; -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; - -@Description(name = "train_adagrad_rda", - value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])" - + " - Returns a relation consists of <string|int|bigint feature, float weight>", - extended = "Build a prediction model by Adagrad+RDA regularization binary classifier") -public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF { - - private float eta; - private float lambda; - private float scaling; - - @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - final int numArgs = argOIs.length; - if (numArgs != 2 && numArgs != 3) { - throw new UDFArgumentException( - "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label [, constant string options]"); - } - - StructObjectInspector oi = super.initialize(argOIs); - model.configureParams(true, false, true); - return oi; + public AdaGradRDAUDTF() { + optimizerOptions.put("optimizer", "AdaGrad"); + optimizerOptions.put("regularization", "RDA"); + optimizerOptions.put("lambda", "1e-6"); } - @Override - protected Options getOptions() { - Options opts = super.getOptions(); - opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]"); - opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]"); - opts.addOption("scale", true, - "Internal scaling/descaling factor for cumulative weights [default: 100]"); - return opts; - } - - @Override - protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { - CommandLine cl = super.processOptions(argOIs); - if (cl == null) { - this.eta = 0.1f; - this.lambda = 1E-6f; - this.scaling = 100f; - } else { - this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.1f); - this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 1E-6f); - this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f); - } - return cl; - } - - @Override - protected void train(@Nonnull final FeatureValue[] features, final int label) { - final float y = label > 0 ? 1.f : -1.f; - - float p = predict(features); - float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p - if (loss <= 0.f) { // max(0, 1 - y * p) - return; - } - // subgradient => -y * W dot xi - update(features, y, count); - } - - protected void update(@Nonnull final FeatureValue[] features, final float y, final int t) { - for (FeatureValue f : features) {// w[f] += y * x[f] - if (f == null) { - continue; - } - Object x = f.getFeature(); - float xi = f.getValueAsFloat(); - - updateWeight(x, xi, y, t); - } - } - - protected void updateWeight(@Nonnull final Object x, final float xi, final float y, - final float t) { - final float gradient = -y * xi; - final float scaled_gradient = gradient * scaling; - - float scaled_sum_sqgrad = 0.f; - float scaled_sum_grad = 0.f; - IWeightValue old = model.get(x); - if (old != null) { - scaled_sum_sqgrad = old.getSumOfSquaredGradients(); - scaled_sum_grad = old.getSumOfGradients(); - } - scaled_sum_grad += scaled_gradient; - scaled_sum_sqgrad += (scaled_gradient * scaled_gradient); - - float sum_grad = scaled_sum_grad * scaling; - double sum_sqgrad = scaled_sum_sqgrad * scaling; - - // sign(u_{t,i}) - float sign = (sum_grad > 0.f) ? 1.f : -1.f; - // |u_{t,i}|/t - \lambda - float meansOfGradients = sign * sum_grad / t - lambda; - if (meansOfGradients < 0.f) { - // x_{t,i} = 0 - model.delete(x); - } else { - // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda) - float weight = -1.f * sign * eta * t * meansOfGradients / (float) Math.sqrt(sum_sqgrad); - IWeightValue new_w = new WeightValueParamsF2(weight, scaled_sum_sqgrad, scaled_sum_grad); - model.set(x, new_w); - } - } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index 43a124d..0ee5d5f 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -25,6 +25,7 @@ import hivemall.model.PredictionModel; import hivemall.model.PredictionResult; import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.Optimizer; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; @@ -56,6 +57,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { private boolean parseFeature; protected PredictionModel model; + protected Optimizer optimizerImpl; protected int count; @Override @@ -76,6 +78,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { if (preloadedModelFile != null) { loadPredictionModel(model, preloadedModelFile, featureOutputOI); } + this.optimizerImpl = createOptimizer(); this.count = 0; return getReturnOI(featureOutputOI); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java new file mode 100644 index 0000000..feebadd --- /dev/null +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -0,0 +1,121 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.classifier; + +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +import hivemall.optimizer.LossFunctions; +import hivemall.model.FeatureValue; + +/** + * A general classifier class with replaceable optimization functions. + */ +public class GeneralClassifierUDTF extends BinaryOnlineClassifierUDTF { + + protected final Map<String, String> optimizerOptions; + + public GeneralClassifierUDTF() { + this.optimizerOptions = new HashMap<String, String>(); + // Set default values + optimizerOptions.put("optimizer", "adagrad"); + optimizerOptions.put("eta", "fixed"); + optimizerOptions.put("eta0", "1.0"); + optimizerOptions.put("regularization", "RDA"); + optimizerOptions.put("lambda", "1e-6"); + optimizerOptions.put("scale", "100.0"); + optimizerOptions.put("lambda", "1.0"); + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if(argOIs.length != 2 && argOIs.length != 3) { + throw new UDFArgumentException( + this.getClass().getSimpleName() + + " takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label " + + "[, constant string options]"); + } + return super.initialize(argOIs); + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("optimizer", "opt", true, "Optimizer to update weights [default: adagrad+rda]"); + opts.addOption("eta", "eta0", true, "Initial learning rate [default 1.0]"); + opts.addOption("lambda", true, "Lambda value of RDA [default: 1e-6f]"); + opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]"); + opts.addOption("regularization", "reg", true, "Regularization type [default not-defined]"); + opts.addOption("lambda", true, "Regularization term on weights [default 1.0]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + final CommandLine cl = super.processOptions(argOIs); + assert(cl != null); + if(cl != null) { + for(final String arg : cl.getArgs()) { + optimizerOptions.put(arg, cl.getOptionValue(arg)); + } + } + return cl; + } + + @Override + protected Map<String, String> getOptimzierOptions() { + return optimizerOptions; + } + + @Override + protected void train(@Nonnull final FeatureValue[] features, final int label) { + float predicted = predict(features); + update(features, label > 0 ? 1.f : -1.f, predicted); + } + + @Override + protected void update(@Nonnull final FeatureValue[] features, final float label, + final float predicted) { + if(is_mini_batch) { + throw new UnsupportedOperationException( + this.getClass().getSimpleName() + " supports no `is_mini_batch` mode"); + } else { + float loss = LossFunctions.hingeLoss(predicted, label); + if(loss <= 0.f) { + return; + } + for(FeatureValue f : features) { + Object feature = f.getFeature(); + float xi = f.getValueAsFloat(); + float weight = model.getWeight(feature); + float new_weight = optimizerImpl.computeUpdatedValue(feature, weight, -label * xi); + model.setWeight(feature, new_weight); + } + optimizerImpl.proceedStep(); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java index 0213dec..9e404cd 100644 --- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; +import hivemall.optimizer.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/common/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/EtaEstimator.java b/core/src/main/java/hivemall/common/EtaEstimator.java deleted file mode 100644 index 3287641..0000000 --- a/core/src/main/java/hivemall/common/EtaEstimator.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Hivemall: Hive scalable Machine Learning Library - * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package hivemall.common; - -import hivemall.utils.lang.NumberUtils; -import hivemall.utils.lang.Primitives; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.cli.CommandLine; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; - -public abstract class EtaEstimator { - - protected final float eta0; - - public EtaEstimator(float eta0) { - this.eta0 = eta0; - } - - public float eta0() { - return eta0; - } - - public abstract float eta(long t); - - public void update(@Nonnegative float multipler) {} - - public static final class FixedEtaEstimator extends EtaEstimator { - - public FixedEtaEstimator(float eta) { - super(eta); - } - - @Override - public float eta(long t) { - return eta0; - } - - } - - public static final class SimpleEtaEstimator extends EtaEstimator { - - private final float finalEta; - private final double total_steps; - - public SimpleEtaEstimator(float eta0, long total_steps) { - super(eta0); - this.finalEta = (float) (eta0 / 2.d); - this.total_steps = total_steps; - } - - @Override - public float eta(final long t) { - if (t > total_steps) { - return finalEta; - } - return (float) (eta0 / (1.d + (t / total_steps))); - } - - } - - public static final class InvscalingEtaEstimator extends EtaEstimator { - - private final double power_t; - - public InvscalingEtaEstimator(float eta0, double power_t) { - super(eta0); - this.power_t = power_t; - } - - @Override - public float eta(final long t) { - return (float) (eta0 / Math.pow(t, power_t)); - } - - } - - /** - * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic - * gradient descent, KDD 2011. - */ - public static final class AdjustingEtaEstimator extends EtaEstimator { - - private float eta; - - public AdjustingEtaEstimator(float eta) { - super(eta); - this.eta = eta; - } - - @Override - public float eta(long t) { - return eta; - } - - @Override - public void update(@Nonnegative float multipler) { - float newEta = eta * multipler; - if (!NumberUtils.isFinite(newEta)) { - // avoid NaN or INFINITY - return; - } - this.eta = Math.min(eta0, newEta); // never be larger than eta0 - } - - } - - @Nonnull - public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException { - return get(cl, 0.1f); - } - - @Nonnull - public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) - throws UDFArgumentException { - if (cl == null) { - return new InvscalingEtaEstimator(defaultEta0, 0.1d); - } - - if (cl.hasOption("boldDriver")) { - float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f); - return new AdjustingEtaEstimator(eta); - } - - String etaValue = cl.getOptionValue("eta"); - if (etaValue != null) { - float eta = Float.parseFloat(etaValue); - return new FixedEtaEstimator(eta); - } - - float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0); - if (cl.hasOption("t")) { - long t = Long.parseLong(cl.getOptionValue("t")); - return new SimpleEtaEstimator(eta0, t); - } - - double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d); - return new InvscalingEtaEstimator(eta0, power_t); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/common/LossFunctions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/LossFunctions.java b/core/src/main/java/hivemall/common/LossFunctions.java deleted file mode 100644 index 6b403fd..0000000 --- a/core/src/main/java/hivemall/common/LossFunctions.java +++ /dev/null @@ -1,467 +0,0 @@ -/* - * Hivemall: Hive scalable Machine Learning Library - * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package hivemall.common; - -import hivemall.utils.math.MathUtils; - -/** - * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions - */ -public final class LossFunctions { - - public enum LossType { - SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss - } - - public static LossFunction getLossFunction(String type) { - if ("SquaredLoss".equalsIgnoreCase(type)) { - return new SquaredLoss(); - } else if ("LogLoss".equalsIgnoreCase(type)) { - return new LogLoss(); - } else if ("HingeLoss".equalsIgnoreCase(type)) { - return new HingeLoss(); - } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { - return new SquaredHingeLoss(); - } else if ("QuantileLoss".equalsIgnoreCase(type)) { - return new QuantileLoss(); - } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) { - return new EpsilonInsensitiveLoss(); - } - throw new IllegalArgumentException("Unsupported type: " + type); - } - - public static LossFunction getLossFunction(LossType type) { - switch (type) { - case SquaredLoss: - return new SquaredLoss(); - case LogLoss: - return new LogLoss(); - case HingeLoss: - return new HingeLoss(); - case SquaredHingeLoss: - return new SquaredHingeLoss(); - case QuantileLoss: - return new QuantileLoss(); - case EpsilonInsensitiveLoss: - return new EpsilonInsensitiveLoss(); - default: - throw new IllegalArgumentException("Unsupported type: " + type); - } - } - - public interface LossFunction { - - /** - * Evaluate the loss function. - * - * @param p The prediction, p = w^T x - * @param y The true value (aka target) - * @return The loss evaluated at `p` and `y`. - */ - public float loss(float p, float y); - - public double loss(double p, double y); - - /** - * Evaluate the derivative of the loss function with respect to the prediction `p`. - * - * @param p The prediction, p = w^T x - * @param y The true value (aka target) - * @return The derivative of the loss function w.r.t. `p`. - */ - public float dloss(float p, float y); - - public boolean forBinaryClassification(); - - public boolean forRegression(); - - } - - public static abstract class BinaryLoss implements LossFunction { - - protected static void checkTarget(float y) { - if (!(y == 1.f || y == -1.f)) { - throw new IllegalArgumentException("target must be [+1,-1]: " + y); - } - } - - protected static void checkTarget(double y) { - if (!(y == 1.d || y == -1.d)) { - throw new IllegalArgumentException("target must be [+1,-1]: " + y); - } - } - - @Override - public boolean forBinaryClassification() { - return true; - } - - @Override - public boolean forRegression() { - return false; - } - } - - public static abstract class RegressionLoss implements LossFunction { - - @Override - public boolean forBinaryClassification() { - return false; - } - - @Override - public boolean forRegression() { - return true; - } - - } - - /** - * Squared loss for regression problems. - * - * If you're trying to minimize the mean error, use squared-loss. - */ - public static final class SquaredLoss extends RegressionLoss { - - @Override - public float loss(float p, float y) { - final float z = p - y; - return z * z * 0.5f; - } - - @Override - public double loss(double p, double y) { - final double z = p - y; - return z * z * 0.5d; - } - - @Override - public float dloss(float p, float y) { - return p - y; // 2 (p - y) / 2 - } - } - - /** - * Logistic regression loss for binary classification with y in {-1, 1}. - */ - public static final class LogLoss extends BinaryLoss { - - /** - * <code>logloss(p,y) = log(1+exp(-p*y))</code> - */ - @Override - public float loss(float p, float y) { - checkTarget(y); - - final float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z); - } - if (z < -18.f) { - return -z; - } - return (float) Math.log(1.d + Math.exp(-z)); - } - - @Override - public double loss(double p, double y) { - checkTarget(y); - - final double z = y * p; - if (z > 18.d) { - return Math.exp(-z); - } - if (z < -18.d) { - return -z; - } - return Math.log(1.d + Math.exp(-z)); - } - - @Override - public float dloss(float p, float y) { - checkTarget(y); - - float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z) * -y; - } - if (z < -18.f) { - return -y; - } - return -y / ((float) Math.exp(z) + 1.f); - } - } - - /** - * Hinge loss for binary classification tasks with y in {-1,1}. - */ - public static final class HingeLoss extends BinaryLoss { - - private float threshold; - - public HingeLoss() { - this(1.f); - } - - /** - * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM. - * When threshold=0.0, one gets the loss used by the Perceptron. - */ - public HingeLoss(float threshold) { - this.threshold = threshold; - } - - public void setThreshold(float threshold) { - this.threshold = threshold; - } - - @Override - public float loss(float p, float y) { - float loss = hingeLoss(p, y, threshold); - return (loss > 0.f) ? loss : 0.f; - } - - @Override - public double loss(double p, double y) { - double loss = hingeLoss(p, y, threshold); - return (loss > 0.d) ? loss : 0.d; - } - - @Override - public float dloss(float p, float y) { - float loss = hingeLoss(p, y, threshold); - return (loss > 0.f) ? -y : 0.f; - } - } - - /** - * Squared Hinge loss for binary classification tasks with y in {-1,1}. - */ - public static final class SquaredHingeLoss extends BinaryLoss { - - @Override - public float loss(float p, float y) { - return squaredHingeLoss(p, y); - } - - @Override - public double loss(double p, double y) { - return squaredHingeLoss(p, y); - } - - @Override - public float dloss(float p, float y) { - checkTarget(y); - - float d = 1 - (y * p); - return (d > 0.f) ? -2.f * d * y : 0.f; - } - - } - - /** - * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase - * as long as you get the relative order correct. - * - * @link http://en.wikipedia.org/wiki/Quantile_regression - */ - public static final class QuantileLoss extends RegressionLoss { - - private float tau; - - public QuantileLoss() { - this.tau = 0.5f; - } - - public QuantileLoss(float tau) { - setTau(tau); - } - - public void setTau(float tau) { - if (tau <= 0 || tau >= 1.0) { - throw new IllegalArgumentException("tau must be in range (0, 1): " + tau); - } - this.tau = tau; - } - - @Override - public float loss(float p, float y) { - float e = y - p; - if (e > 0.f) { - return tau * e; - } else { - return -(1.f - tau) * e; - } - } - - @Override - public double loss(double p, double y) { - double e = y - p; - if (e > 0.d) { - return tau * e; - } else { - return -(1.d - tau) * e; - } - } - - @Override - public float dloss(float p, float y) { - float e = y - p; - if (e == 0.f) { - return 0.f; - } - return (e > 0.f) ? -tau : (1.f - tau); - } - - } - - /** - * Epsilon-Insensitive loss used by Support Vector Regression (SVR). - * <code>loss = max(0, |y - p| - epsilon)</code> - */ - public static final class EpsilonInsensitiveLoss extends RegressionLoss { - - private float epsilon; - - public EpsilonInsensitiveLoss() { - this(0.1f); - } - - public EpsilonInsensitiveLoss(float epsilon) { - this.epsilon = epsilon; - } - - public void setEpsilon(float epsilon) { - this.epsilon = epsilon; - } - - @Override - public float loss(float p, float y) { - float loss = Math.abs(y - p) - epsilon; - return (loss > 0.f) ? loss : 0.f; - } - - @Override - public double loss(double p, double y) { - double loss = Math.abs(y - p) - epsilon; - return (loss > 0.d) ? loss : 0.d; - } - - @Override - public float dloss(float p, float y) { - if ((y - p) > epsilon) {// real value > predicted value - epsilon - return -1.f; - } - if ((p - y) > epsilon) {// real value < predicted value - epsilon - return 1.f; - } - return 0.f; - } - - } - - public static float logisticLoss(final float target, final float predicted) { - if (predicted > -100.d) { - return target - (float) MathUtils.sigmoid(predicted); - } else { - return target; - } - } - - public static float logLoss(final float p, final float y) { - BinaryLoss.checkTarget(y); - - final float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z); - } - if (z < -18.f) { - return -z; - } - return (float) Math.log(1.d + Math.exp(-z)); - } - - public static double logLoss(final double p, final double y) { - BinaryLoss.checkTarget(y); - - final double z = y * p; - if (z > 18.d) { - return Math.exp(-z); - } - if (z < -18.d) { - return -z; - } - return Math.log(1.d + Math.exp(-z)); - } - - public static float squaredLoss(float p, float y) { - final float z = p - y; - return z * z * 0.5f; - } - - public static double squaredLoss(double p, double y) { - final double z = p - y; - return z * z * 0.5d; - } - - public static float hingeLoss(final float p, final float y, final float threshold) { - BinaryLoss.checkTarget(y); - - float z = y * p; - return threshold - z; - } - - public static double hingeLoss(final double p, final double y, final double threshold) { - BinaryLoss.checkTarget(y); - - double z = y * p; - return threshold - z; - } - - public static float hingeLoss(float p, float y) { - return hingeLoss(p, y, 1.f); - } - - public static double hingeLoss(double p, double y) { - return hingeLoss(p, y, 1.d); - } - - public static float squaredHingeLoss(final float p, final float y) { - BinaryLoss.checkTarget(y); - - float z = y * p; - float d = 1.f - z; - return (d > 0.f) ? (d * d) : 0.f; - } - - public static double squaredHingeLoss(final double p, final double y) { - BinaryLoss.checkTarget(y); - - double z = y * p; - double d = 1.d - z; - return (d > 0.d) ? d * d : 0.d; - } - - /** - * Math.abs(target - predicted) - epsilon - */ - public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { - return Math.abs(target - predicted) - epsilon; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FMHyperParameters.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index db69db3..512476d 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -17,8 +17,8 @@ */ package hivemall.fm; -import hivemall.common.EtaEstimator; import hivemall.fm.FactorizationMachineModel.VInitScheme; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java index 396328a..4b6ece6 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java @@ -18,7 +18,7 @@ */ package hivemall.fm; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.lang.NumberUtils; import hivemall.utils.math.MathUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 2388689..7739c52 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -20,10 +20,10 @@ package hivemall.fm; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.common.EtaEstimator; -import hivemall.common.LossFunctions; -import hivemall.common.LossFunctions.LossFunction; -import hivemall.common.LossFunctions.LossType; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java index 7e3cc50..fde7701 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java @@ -21,6 +21,7 @@ package hivemall.fm; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.DoubleArray3D; import hivemall.utils.collections.IntArrayList; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.lang.NumberUtils; import java.util.Arrays; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index d859f29..87d2654 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ -20,7 +20,7 @@ package hivemall.mf; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import hivemall.mf.FactorizedModel.RankInitScheme; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java index 317da85..ab79ce2 100644 --- a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java +++ b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.mf; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/DenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java index ee57574..6956875 100644 --- a/core/src/main/java/hivemall/model/DenseModel.java +++ b/core/src/main/java/hivemall/model/DenseModel.java @@ -18,21 +18,18 @@ */ package hivemall.model; -import hivemall.model.WeightValue.WeightValueParamsF1; -import hivemall.model.WeightValue.WeightValueParamsF2; -import hivemall.model.WeightValue.WeightValueWithCovar; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.Copyable; -import hivemall.utils.math.MathUtils; - import java.util.Arrays; - import javax.annotation.Nonnull; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Copyable; +import hivemall.utils.math.MathUtils; + public final class DenseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(DenseModel.class); @@ -40,13 +37,6 @@ public final class DenseModel extends AbstractPredictionModel { private float[] weights; private float[] covars; - // optional values for adagrad - private float[] sum_of_squared_gradients; - // optional value for adadelta - private float[] sum_of_squared_delta_x; - // optional value for adagrad+rda - private float[] sum_of_gradients; - // optional value for MIX private short[] clocks; private byte[] deltaUpdates; @@ -67,9 +57,6 @@ public final class DenseModel extends AbstractPredictionModel { } else { this.covars = null; } - this.sum_of_squared_gradients = null; - this.sum_of_squared_delta_x = null; - this.sum_of_gradients = null; this.clocks = null; this.deltaUpdates = null; } @@ -85,20 +72,6 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, - boolean sum_of_gradients) { - if (sum_of_squared_gradients) { - this.sum_of_squared_gradients = new float[size]; - } - if (sum_of_squared_delta_x) { - this.sum_of_squared_delta_x = new float[size]; - } - if (sum_of_gradients) { - this.sum_of_gradients = new float[size]; - } - } - - @Override public void configureClock() { if (clocks == null) { this.clocks = new short[size]; @@ -129,16 +102,7 @@ public final class DenseModel extends AbstractPredictionModel { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, 1.f); } - if (sum_of_squared_gradients != null) { - this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); - } - if (sum_of_squared_delta_x != null) { - this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); - } - if (sum_of_gradients != null) { - this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); - } - if (clocks != null) { + if(clocks != null) { this.clocks = Arrays.copyOf(clocks, newSize); this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); } @@ -152,17 +116,7 @@ public final class DenseModel extends AbstractPredictionModel { if (i >= size) { return null; } - if (sum_of_squared_gradients != null) { - if (sum_of_squared_delta_x != null) { - return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i], - sum_of_squared_delta_x[i]); - } else if (sum_of_gradients != null) { - return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i], - sum_of_gradients[i]); - } else { - return (T) new WeightValueParamsF1(weights[i], sum_of_squared_gradients[i]); - } - } else if (covars != null) { + if(covars != null) { return (T) new WeightValueWithCovar(weights[i], covars[i]); } else { return (T) new WeightValue(weights[i]); @@ -181,15 +135,6 @@ public final class DenseModel extends AbstractPredictionModel { covar = value.getCovariance(); covars[i] = covar; } - if (sum_of_squared_gradients != null) { - sum_of_squared_gradients[i] = value.getSumOfSquaredGradients(); - } - if (sum_of_squared_delta_x != null) { - sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX(); - } - if (sum_of_gradients != null) { - sum_of_gradients[i] = value.getSumOfGradients(); - } short clock = 0; int delta = 0; if (clocks != null && value.isTouched()) { @@ -213,15 +158,6 @@ public final class DenseModel extends AbstractPredictionModel { if (covars != null) { covars[i] = 1.f; } - if (sum_of_squared_gradients != null) { - sum_of_squared_gradients[i] = 0.f; - } - if (sum_of_squared_delta_x != null) { - sum_of_squared_delta_x[i] = 0.f; - } - if (sum_of_gradients != null) { - sum_of_gradients[i] = 0.f; - } // avoid clock/delta } @@ -235,6 +171,13 @@ public final class DenseModel extends AbstractPredictionModel { } @Override + public void setWeight(Object feature, float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weights[i] = value; + } + + @Override public float getCovariance(Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/IWeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/IWeightValue.java b/core/src/main/java/hivemall/model/IWeightValue.java index 988e4a1..259628f 100644 --- a/core/src/main/java/hivemall/model/IWeightValue.java +++ b/core/src/main/java/hivemall/model/IWeightValue.java @@ -25,7 +25,7 @@ import javax.annotation.Nonnegative; public interface IWeightValue extends Copyable<IWeightValue> { public enum WeightValueType { - NoParams, ParamsF1, ParamsF2, ParamsCovar; + NoParams, ParamsF1, ParamsF2, ParamsF3, ParamsCovar; } WeightValueType getType(); @@ -44,10 +44,24 @@ public interface IWeightValue extends Copyable<IWeightValue> { float getSumOfSquaredGradients(); + void setSumOfSquaredGradients(float value); + float getSumOfSquaredDeltaX(); + void setSumOfSquaredDeltaX(float value); + float getSumOfGradients(); + void setSumOfGradients(float value); + + float getM(); + + void setM(float value); + + float getV(); + + void setV(float value); + /** * @return whether touched in training or not */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/PredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java index a8efee0..8d8dd2b 100644 --- a/core/src/main/java/hivemall/model/PredictionModel.java +++ b/core/src/main/java/hivemall/model/PredictionModel.java @@ -34,9 +34,6 @@ public interface PredictionModel extends MixedModel { boolean hasCovariance(); - void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, - boolean sum_of_gradients); - void configureClock(); boolean hasClock(); @@ -56,6 +53,8 @@ public interface PredictionModel extends MixedModel { float getWeight(@Nonnull Object feature); + void setWeight(@Nonnull Object feature, float value); + float getCovariance(@Nonnull Object feature); <K, V extends IWeightValue> IMapIterator<K, V> entries(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java index b3cd3ff..8b668e7 100644 --- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java +++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java @@ -18,8 +18,6 @@ */ package hivemall.model; -import hivemall.model.WeightValue.WeightValueParamsF1; -import hivemall.model.WeightValue.WeightValueParamsF2; import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; @@ -28,7 +26,6 @@ import hivemall.utils.lang.HalfFloat; import hivemall.utils.math.MathUtils; import java.util.Arrays; - import javax.annotation.Nonnull; import org.apache.commons.logging.Log; @@ -41,13 +38,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { private short[] weights; private short[] covars; - // optional value for adagrad - private float[] sum_of_squared_gradients; - // optional value for adadelta - private float[] sum_of_squared_delta_x; - // optional value for adagrad+rda - private float[] sum_of_gradients; - // optional value for MIX private short[] clocks; private byte[] deltaUpdates; @@ -68,9 +58,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } else { this.covars = null; } - this.sum_of_squared_gradients = null; - this.sum_of_squared_delta_x = null; - this.sum_of_gradients = null; this.clocks = null; this.deltaUpdates = null; } @@ -86,20 +73,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, - boolean sum_of_gradients) { - if (sum_of_squared_gradients) { - this.sum_of_squared_gradients = new float[size]; - } - if (sum_of_squared_delta_x) { - this.sum_of_squared_delta_x = new float[size]; - } - if (sum_of_gradients) { - this.sum_of_gradients = new float[size]; - } - } - - @Override public void configureClock() { if (clocks == null) { this.clocks = new short[size]; @@ -126,8 +99,11 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { return HalfFloat.halfFloatToFloat(covars[i]); } - private void setWeight(final int i, final float v) { - HalfFloat.checkRange(v); + private void _setWeight(final int i, final float v) { + if(Math.abs(v) >= HalfFloat.MAX_FLOAT) { + throw new IllegalArgumentException("Acceptable maximum weight is " + + HalfFloat.MAX_FLOAT + ": " + v); + } weights[i] = HalfFloat.floatToHalfFloat(v); } @@ -149,16 +125,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE); } - if (sum_of_squared_gradients != null) { - this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); - } - if (sum_of_squared_delta_x != null) { - this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); - } - if (sum_of_gradients != null) { - this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); - } - if (clocks != null) { + if(clocks != null) { this.clocks = Arrays.copyOf(clocks, newSize); this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); } @@ -172,17 +139,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { if (i >= size) { return null; } - if (sum_of_squared_gradients != null) { - if (sum_of_squared_delta_x != null) { - return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i], - sum_of_squared_delta_x[i]); - } else if (sum_of_gradients != null) { - return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i], - sum_of_gradients[i]); - } else { - return (T) new WeightValueParamsF1(getWeight(i), sum_of_squared_gradients[i]); - } - } else if (covars != null) { + + if(covars != null) { return (T) new WeightValueWithCovar(getWeight(i), getCovar(i)); } else { return (T) new WeightValue(getWeight(i)); @@ -194,22 +152,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); - setWeight(i, weight); + _setWeight(i, weight); float covar = 1.f; boolean hasCovar = value.hasCovariance(); if (hasCovar) { covar = value.getCovariance(); setCovar(i, covar); } - if (sum_of_squared_gradients != null) { - sum_of_squared_gradients[i] = value.getSumOfSquaredGradients(); - } - if (sum_of_squared_delta_x != null) { - sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX(); - } - if (sum_of_gradients != null) { - sum_of_gradients[i] = value.getSumOfGradients(); - } short clock = 0; int delta = 0; if (clocks != null && value.isTouched()) { @@ -229,19 +178,10 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { if (i >= size) { return; } - setWeight(i, 0.f); - if (covars != null) { + _setWeight(i, 0.f); + if(covars != null) { setCovar(i, 1.f); } - if (sum_of_squared_gradients != null) { - sum_of_squared_gradients[i] = 0.f; - } - if (sum_of_squared_delta_x != null) { - sum_of_squared_delta_x[i] = 0.f; - } - if (sum_of_gradients != null) { - sum_of_gradients[i] = 0.f; - } // avoid clock/delta } @@ -255,6 +195,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override + public void setWeight(Object feature, float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + _setWeight(i, value); + } + + @Override public float getCovariance(Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { @@ -267,7 +214,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { protected void _set(Object feature, float weight, short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - setWeight(i, weight); + _setWeight(i, weight); clocks[i] = clock; deltaUpdates[i] = 0; } @@ -276,7 +223,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { protected void _set(Object feature, float weight, float covar, short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - setWeight(i, weight); + _setWeight(i, weight); setCovar(i, covar); clocks[i] = clock; deltaUpdates[i] = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index aaab869..bab982f 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -36,6 +36,10 @@ public final class SparseModel extends AbstractPredictionModel { private final boolean hasCovar; private boolean clockEnabled; + public SparseModel(int size) { + this(size, false); + } + public SparseModel(int size, boolean hasCovar) { super(); this.weights = new OpenHashMap<Object, IWeightValue>(size); @@ -54,10 +58,6 @@ public final class SparseModel extends AbstractPredictionModel { } @Override - public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, - boolean sum_of_gradients) {} - - @Override public void configureClock() { this.clockEnabled = true; } @@ -131,6 +131,18 @@ public final class SparseModel extends AbstractPredictionModel { } @Override + public void setWeight(Object feature, float value) { + if(weights.containsKey(feature)) { + IWeightValue weight = weights.get(feature); + weight.set(value); + } else { + IWeightValue weight = new WeightValue(value); + weight.set(value); + weights.put(feature, weight); + } + } + + @Override public float getCovariance(final Object feature) { IWeightValue v = weights.get(feature); return v == null ? 1.f : v.getCovariance(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java index 99ee69c..87e89b6 100644 --- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java +++ b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java @@ -63,12 +63,6 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, - boolean sum_of_gradients) { - model.configureParams(sum_of_squared_gradients, sum_of_squared_delta_x, sum_of_gradients); - } - - @Override public void configureClock() { model.configureClock(); } @@ -157,6 +151,16 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override + public void setWeight(Object feature, float value) { + try { + lock.lock(); + model.setWeight(feature, value); + } finally { + lock.unlock(); + } + } + + @Override public float getCovariance(Object feature) { try { lock.lock(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/WeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValue.java b/core/src/main/java/hivemall/model/WeightValue.java index e6d98c6..b329374 100644 --- a/core/src/main/java/hivemall/model/WeightValue.java +++ b/core/src/main/java/hivemall/model/WeightValue.java @@ -77,15 +77,50 @@ public class WeightValue implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfSquaredDeltaX() { return 0.f; } @Override + public void setSumOfSquaredDeltaX(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } + @Override + public void setSumOfGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getM() { + return 0.f; + } + + @Override + public void setM(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getV() { + return 0.f; + } + + @Override + public void setV(float value) { + throw new UnsupportedOperationException(); + } + /** * @return whether touched in training or not */ @@ -137,7 +172,7 @@ public class WeightValue implements IWeightValue { } public static final class WeightValueParamsF1 extends WeightValue { - private final float f1; + private float f1; public WeightValueParamsF1(float weight, float f1) { super(weight); @@ -162,14 +197,19 @@ public class WeightValue implements IWeightValue { return f1; } + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + } /** * WeightValue with Sum of Squared Gradients */ public static final class WeightValueParamsF2 extends WeightValue { - private final float f1; - private final float f2; + private float f1; + private float f2; public WeightValueParamsF2(float weight, float f1, float f2) { super(weight); @@ -198,15 +238,131 @@ public class WeightValue implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override public final float getSumOfSquaredDeltaX() { return f2; } @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override public float getSumOfGradients() { return f2; } + @Override + public void setSumOfGradients(float value) { + this.f2 = value; + } + + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + + } + + public static final class WeightValueParamsF3 extends WeightValue { + private float f1; + private float f2; + private float f3; + + public WeightValueParamsF3(float weight, float f1, float f2, float f3) { + super(weight); + this.f1 = f1; + this.f2 = f2; + this.f3 = f3; + } + + @Override + public WeightValueType getType() { + return WeightValueType.ParamsF3; + } + + @Override + public float getFloatParams(@Nonnegative final int i) { + if(i == 1) { + return f1; + } else if(i == 2) { + return f2; + } else if (i == 3) { + return f3; + } + throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called"); + } + + @Override + public final float getSumOfSquaredGradients() { + return f1; + } + + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override + public final float getSumOfSquaredDeltaX() { + return f2; + } + + @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override + public float getSumOfGradients() { + return f3; + } + + @Override + public void setSumOfGradients(float value) { + this.f3 = value; + } + + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + } public static final class WeightValueWithCovar extends WeightValue { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/WeightValueWithClock.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java index 249650a..9b31361 100644 --- a/core/src/main/java/hivemall/model/WeightValueWithClock.java +++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java @@ -79,15 +79,50 @@ public class WeightValueWithClock implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfSquaredDeltaX() { return 0.f; } @Override + public void setSumOfSquaredDeltaX(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } + @Override + public void setSumOfGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getM() { + return 0.f; + } + + @Override + public void setM(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getV() { + return 0.f; + } + + @Override + public void setV(float value) { + throw new UnsupportedOperationException(); + } + /** * @return whether touched in training or not */ @@ -144,7 +179,7 @@ public class WeightValueWithClock implements IWeightValue { * WeightValue with Sum of Squared Gradients */ public static final class WeightValueParamsF1Clock extends WeightValueWithClock { - private final float f1; + private float f1; public WeightValueParamsF1Clock(float value, float f1) { super(value); @@ -174,11 +209,16 @@ public class WeightValueWithClock implements IWeightValue { return f1; } + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + } public static final class WeightValueParamsF2Clock extends WeightValueWithClock { - private final float f1; - private final float f2; + private float f1; + private float f2; public WeightValueParamsF2Clock(float value, float f1, float f2) { super(value); @@ -213,15 +253,136 @@ public class WeightValueWithClock implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override + public float getSumOfSquaredDeltaX() { + return f2; + } + + @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override + public float getSumOfGradients() { + return f2; + } + + @Override + public void setSumOfGradients(float value) { + this.f2 = value; + } + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + + } + + public static final class WeightValueParamsF3Clock extends WeightValueWithClock { + private float f1; + private float f2; + private float f3; + + public WeightValueParamsF3Clock(float value, float f1, float f2, float f3) { + super(value); + this.f1 = f1; + this.f2 = f2; + this.f3 = f3; + } + + public WeightValueParamsF3Clock(IWeightValue src) { + super(src); + this.f1 = src.getFloatParams(1); + this.f2 = src.getFloatParams(2); + this.f3 = src.getFloatParams(3); + } + + @Override + public WeightValueType getType() { + return WeightValueType.ParamsF3; + } + + @Override + public float getFloatParams(@Nonnegative final int i) { + if(i == 1) { + return f1; + } else if(i == 2) { + return f2; + } else if(i == 3) { + return f3; + } + throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called"); + } + + @Override + public float getSumOfSquaredGradients() { + return f1; + } + + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override public float getSumOfSquaredDeltaX() { return f2; } @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override public float getSumOfGradients() { + return f3; + } + + @Override + public void setSumOfGradients(float value) { + this.f3 = value; + } + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { return f2; } + @Override + public void setV(float value) { + this.f2 = value; + } + } public static final class WeightValueWithCovarClock extends WeightValueWithClock { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java new file mode 100644 index 0000000..e2c5a10 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java @@ -0,0 +1,215 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; +import java.util.Arrays; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import hivemall.optimizer.Optimizer.OptimizerBase; +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.math.MathUtils; + +public final class DenseOptimizerFactory { + private static final Log logger = LogFactory.getLog(DenseOptimizerFactory.class); + + @Nonnull + public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + final String optimizerName = options.get("optimizer"); + if(optimizerName != null) { + OptimizerBase optimizerImpl; + if(optimizerName.toLowerCase().equals("sgd")) { + optimizerImpl = new Optimizer.SGD(options); + } else if(optimizerName.toLowerCase().equals("adadelta")) { + optimizerImpl = new AdaDelta(ndims, options); + } else if(optimizerName.toLowerCase().equals("adagrad")) { + optimizerImpl = new AdaGrad(ndims, options); + } else if(optimizerName.toLowerCase().equals("adam")) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } + + logger.info("set " + optimizerImpl.getClass().getSimpleName() + + " as an optimizer: " + options); + + // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. + if(options.get("regularization") != null + && options.get("regularization").toLowerCase().equals("rda")) { + optimizerImpl = new RDA(ndims, optimizerImpl, options); + } + + return optimizerImpl; + } + throw new IllegalArgumentException("`optimizer` not defined"); + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + private final IWeightValue weightValueReused; + + private float[] sum_of_squared_gradients; + private float[] sum_of_squared_delta_x; + + public AdaDelta(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.sum_of_squared_gradients = new float[ndims]; + this.sum_of_squared_delta_x = new float[ndims]; + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + weightValueReused.setSumOfSquaredDeltaX(sum_of_squared_delta_x[i]); + computeUpdateValue(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + sum_of_squared_delta_x[i] = weightValueReused.getSumOfSquaredDeltaX(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if(index >= sum_of_squared_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); + } + } + + } + + @NotThreadSafe + static final class AdaGrad extends Optimizer.AdaGrad { + + private final IWeightValue weightValueReused; + + private float[] sum_of_squared_gradients; + + public AdaGrad(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue.WeightValueParamsF1(0.f, 0.f); + this.sum_of_squared_gradients = new float[ndims]; + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + computeUpdateValue(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if(index >= sum_of_squared_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + } + } + + } + + @NotThreadSafe + static final class Adam extends Optimizer.Adam { + + private final IWeightValue weightValueReused; + + private float[] val_m; + private float[] val_v; + + public Adam(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.val_m = new float[ndims]; + this.val_v = new float[ndims]; + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setM(val_m[i]); + weightValueReused.setV(val_v[i]); + computeUpdateValue(weightValueReused, gradient); + val_m[i] = weightValueReused.getM(); + val_v[i] = weightValueReused.getV(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if(index >= val_m.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.val_m = Arrays.copyOf(val_m, newSize); + this.val_v = Arrays.copyOf(val_v, newSize); + } + } + + } + + @NotThreadSafe + static final class RDA extends Optimizer.RDA { + + private final IWeightValue weightValueReused; + + private float[] sum_of_gradients; + + public RDA(int ndims, final OptimizerBase optimizerImpl, Map<String, String> options) { + super(optimizerImpl, options); + this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f); + this.sum_of_gradients = new float[ndims]; + } + + @Override + public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfGradients(sum_of_gradients[i]); + computeUpdateValue(weightValueReused, gradient); + sum_of_gradients[i] = weightValueReused.getSumOfGradients(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if(index >= sum_of_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java new file mode 100644 index 0000000..ac1d112 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java @@ -0,0 +1,191 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package hivemall.optimizer; + +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; + +import java.util.Map; +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; + +public abstract class EtaEstimator { + + protected final float eta0; + + public EtaEstimator(float eta0) { + this.eta0 = eta0; + } + + public float eta0() { + return eta0; + } + + public abstract float eta(long t); + + public void update(@Nonnegative float multipler) {} + + public static final class FixedEtaEstimator extends EtaEstimator { + + public FixedEtaEstimator(float eta) { + super(eta); + } + + @Override + public float eta(long t) { + return eta0; + } + + } + + public static final class SimpleEtaEstimator extends EtaEstimator { + + private final float finalEta; + private final double total_steps; + + public SimpleEtaEstimator(float eta0, long total_steps) { + super(eta0); + this.finalEta = (float) (eta0 / 2.d); + this.total_steps = total_steps; + } + + @Override + public float eta(final long t) { + if (t > total_steps) { + return finalEta; + } + return (float) (eta0 / (1.d + (t / total_steps))); + } + + } + + public static final class InvscalingEtaEstimator extends EtaEstimator { + + private final double power_t; + + public InvscalingEtaEstimator(float eta0, double power_t) { + super(eta0); + this.power_t = power_t; + } + + @Override + public float eta(final long t) { + return (float) (eta0 / Math.pow(t, power_t)); + } + + } + + /** + * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic + * gradient descent, KDD 2011. + */ + public static final class AdjustingEtaEstimator extends EtaEstimator { + + private float eta; + + public AdjustingEtaEstimator(float eta) { + super(eta); + this.eta = eta; + } + + @Override + public float eta(long t) { + return eta; + } + + @Override + public void update(@Nonnegative float multipler) { + float newEta = eta * multipler; + if (!NumberUtils.isFinite(newEta)) { + // avoid NaN or INFINITY + return; + } + this.eta = Math.min(eta0, newEta); // never be larger than eta0 + } + + } + + @Nonnull + public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException { + return get(cl, 0.1f); + } + + @Nonnull + public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) + throws UDFArgumentException { + if (cl == null) { + return new InvscalingEtaEstimator(defaultEta0, 0.1d); + } + + if (cl.hasOption("boldDriver")) { + float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f); + return new AdjustingEtaEstimator(eta); + } + + String etaValue = cl.getOptionValue("eta"); + if (etaValue != null) { + float eta = Float.parseFloat(etaValue); + return new FixedEtaEstimator(eta); + } + + float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0); + if (cl.hasOption("t")) { + long t = Long.parseLong(cl.getOptionValue("t")); + return new SimpleEtaEstimator(eta0, t); + } + + double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d); + return new InvscalingEtaEstimator(eta0, power_t); + } + + @Nonnull + public static EtaEstimator get(@Nonnull final Map<String, String> options) + throws IllegalArgumentException { + final String etaName = options.get("eta"); + if(etaName == null) { + return new FixedEtaEstimator(1.f); + } + float eta0 = 0.1f; + if(options.containsKey("eta0")) { + eta0 = Float.parseFloat(options.get("eta0")); + } + if(etaName.toLowerCase().equals("fixed")) { + return new FixedEtaEstimator(eta0); + } else if(etaName.toLowerCase().equals("simple")) { + long t = 10000; + if(options.containsKey("t")) { + t = Long.parseLong(options.get("t")); + } + return new SimpleEtaEstimator(eta0, t); + } else if(etaName.toLowerCase().equals("inverse")) { + double power_t = 0.1; + if(options.containsKey("power_t")) { + power_t = Double.parseDouble(options.get("power_t")); + } + return new InvscalingEtaEstimator(eta0, power_t); + } else { + throw new IllegalArgumentException("Unsupported ETA name: " + etaName); + } + } + +}
