Repository: incubator-hivemall Updated Branches: refs/heads/master 5e27993b6 -> 50b4c9a75 (forced update)
[HIVEMALL-101] refactored the previous commit Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/50b4c9a7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/50b4c9a7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/50b4c9a7 Branch: refs/heads/master Commit: 50b4c9a752b3d3f99a1e519c609176fc46debe69 Parents: 3848ea6 Author: Makoto Yui <[email protected]> Authored: Thu Jun 15 02:50:27 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Thu Jun 15 03:17:29 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 31 ++--- .../src/main/java/hivemall/LearnerBaseUDTF.java | 6 +- core/src/main/java/hivemall/UDFWithOptions.java | 6 +- .../src/main/java/hivemall/UDTFWithOptions.java | 5 +- .../classifier/BinaryOnlineClassifierUDTF.java | 6 +- .../classifier/GeneralClassifierUDTF.java | 13 +- core/src/main/java/hivemall/fm/Feature.java | 2 +- .../main/java/hivemall/model/NewDenseModel.java | 8 +- .../model/NewSpaceEfficientDenseModel.java | 19 ++- .../java/hivemall/model/NewSparseModel.java | 4 +- .../main/java/hivemall/model/SparseModel.java | 4 +- .../hivemall/model/WeightValueWithClock.java | 8 +- .../optimizer/DenseOptimizerFactory.java | 62 +++++----- .../java/hivemall/optimizer/EtaEstimator.java | 14 +-- .../java/hivemall/optimizer/LossFunctions.java | 75 ++++++----- .../main/java/hivemall/optimizer/Optimizer.java | 123 +++++++++---------- .../hivemall/optimizer/OptimizerOptions.java | 17 +-- .../java/hivemall/optimizer/Regularization.java | 55 ++++----- .../optimizer/SparseOptimizerFactory.java | 107 ++++++++-------- .../regression/GeneralRegressionUDTF.java | 5 +- .../utils/collections/maps/IntOpenHashMap.java | 4 +- .../classifier/GeneralClassifierUDTFTest.java | 10 +- .../regression/GeneralRegressionUDTFTest.java | 6 +- 23 files changed, 295 
insertions(+), 295 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index e798fdf..34c7ec9 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -70,14 +70,15 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private Optimizer optimizer; private LossFunction lossFunction; - protected PredictionModel model; - protected int count; + private PredictionModel model; + private long count; // The accumulated delta of each weight values. - protected transient Map<Object, FloatAccumulator> accumulated; - protected int sampled; + @Nullable + private transient Map<Object, FloatAccumulator> accumulated; + private int sampled; - private float cumLoss; + private double cumLoss; public GeneralLearnerBaseUDTF() { this(true); @@ -122,12 +123,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { try { this.optimizer = createOptimizer(optimizerOptions); } catch (Throwable e) { - throw new UDFArgumentException(e.getMessage()); + throw new UDFArgumentException(e); } - this.count = 0; + this.count = 0L; this.sampled = 0; - this.cumLoss = 0.f; + this.cumLoss = 0.d; return getReturnOI(featureOutputOI); } @@ -160,7 +161,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { return cl; } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = 
featureListOI.getListElementObjectInspector(); @@ -169,7 +171,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector featureOutputOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); @@ -241,7 +244,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { final float v = f.getValueAsFloat(); float old_w = model.getWeight(k); - if (old_w != 0f) { + if (old_w != 0.f) { score += (old_w * v); } } @@ -302,7 +305,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { this.sampled = 0; } - protected void onlineUpdate(@Nonnull final FeatureValue[] features, float dloss) { + protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float dloss) { for (FeatureValue f : features) { Object feature = f.getFeature(); float xi = f.getValueAsFloat(); @@ -368,13 +371,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { } @VisibleForTesting - public float getCumulativeLoss() { + public double getCumulativeLoss() { return cumLoss; } @VisibleForTesting public void resetCumulativeLoss() { - this.cumLoss = 0.f; + this.cumLoss = 0.d; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index bb15bb3..fdb22f8 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -255,9 +255,9 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { } } - protected MixClient 
configureMixClient(String connectURIs, String label, PredictionModel model) { - assert (connectURIs != null); - assert (model != null); + @Nonnull + protected MixClient configureMixClient(@Nonnull String connectURIs, @Nullable String label, + @Nonnull PredictionModel model) { String jobId = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName; if (label != null) { jobId = jobId + '-' + label; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/UDFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDFWithOptions.java b/core/src/main/java/hivemall/UDFWithOptions.java index 3aaf9ef..9908cd9 100644 --- a/core/src/main/java/hivemall/UDFWithOptions.java +++ b/core/src/main/java/hivemall/UDFWithOptions.java @@ -77,9 +77,12 @@ public abstract class UDFWithOptions extends GenericUDF { } } + @Nonnull protected abstract Options getOptions(); - protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + @Nonnull + protected final CommandLine parseOptions(@Nonnull String optionValue) + throws UDFArgumentException { String[] args = optionValue.split("\\s+"); Options opts = getOptions(); opts.addOption("help", false, "Show function help"); @@ -109,6 +112,7 @@ public abstract class UDFWithOptions extends GenericUDF { return cl; } + @Nonnull protected abstract CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/UDTFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDTFWithOptions.java b/core/src/main/java/hivemall/UDTFWithOptions.java index 1556a4f..39ab233 100644 --- a/core/src/main/java/hivemall/UDTFWithOptions.java +++ b/core/src/main/java/hivemall/UDTFWithOptions.java @@ -123,8 +123,9 @@ public 
abstract class UDTFWithOptions extends GenericUDTF { protected abstract CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException; - protected final List<FeatureValue> parseFeatures(final List<?> features, - final ObjectInspector featureInspector, final boolean parseFeature) { + @Nonnull + protected final List<FeatureValue> parseFeatures(@Nonnull final List<?> features, + @Nonnull final ObjectInspector featureInspector, final boolean parseFeature) { final int numFeatures = features.size(); if (numFeatures == 0) { return Collections.emptyList(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index d25f254..2dcf521 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -167,8 +167,10 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { return featureVector; } - protected void checkLabelValue(int label) throws UDFArgumentException { - assert (label == -1 || label == 0 || label == 1) : label; + protected void checkLabelValue(final int label) throws UDFArgumentException { + if (label != -1 && label != 0 && label != 1) { + throw new UDFArgumentException("Invalid label value for classification: " + label); + } } @VisibleForTesting http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java index 
753a498..d7cb539 100644 --- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -51,11 +51,18 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF { } @Override - protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException {}; + protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException { + if(!lossFunction.forBinaryClassification()) { + throw new UDFArgumentException("The loss function `" + lossFunction.getType() + + "` is not designed for binary classification"); + } + } @Override - protected void checkTargetValue(float label) throws UDFArgumentException { - assert (label == -1.f || label == 0.f || label == 1.f) : label; + protected void checkTargetValue(final float label) throws UDFArgumentException { + if (label != -1 && label != 0 && label != 1) { + throw new UDFArgumentException("Invalid label value for classification: " + label); + } } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/fm/Feature.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java index f2d977e..2966a02 100644 --- a/core/src/main/java/hivemall/fm/Feature.java +++ b/core/src/main/java/hivemall/fm/Feature.java @@ -262,7 +262,7 @@ public abstract class Feature { if (asIntFeature) { int index = parseFeatureIndex(indexStr); probe.setFeatureIndex(index); - probe.value = parseFeatureValue(valueStr);; + probe.value = parseFeatureValue(valueStr); } else { probe.setFeature(indexStr); probe.value = parseFeatureValue(valueStr); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/model/NewDenseModel.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java index aab3c2b..b5db580 100644 --- a/core/src/main/java/hivemall/model/NewDenseModel.java +++ b/core/src/main/java/hivemall/model/NewDenseModel.java @@ -53,7 +53,7 @@ public final class NewDenseModel extends AbstractPredictionModel { this.weights = new float[size]; if (withCovar) { float[] covars = new float[size]; - Arrays.fill(covars, 1f); + Arrays.fill(covars, 1.f); this.covars = covars; } else { this.covars = null; @@ -99,8 +99,10 @@ public final class NewDenseModel extends AbstractPredictionModel { int bits = MathUtils.bitsRequired(index); int newSize = (1 << bits) + 1; int oldSize = size; - logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" - + bits + " bits)"); + if (logger.isInfoEnabled()) { + logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" + + bits + " bits)"); + } this.size = newSize; this.weights = Arrays.copyOf(weights, newSize); if (covars != null) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java index 1848529..0a473b4 100644 --- a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java +++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java @@ -18,7 +18,6 @@ */ package hivemall.model; -import hivemall.annotations.InternalAPI; import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; @@ -105,12 +104,8 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { return HalfFloat.halfFloatToFloat(covars[i]); } - @InternalAPI - private void 
_setWeight(final int i, final float v) { - if (Math.abs(v) >= HalfFloat.MAX_FLOAT) { - throw new IllegalArgumentException("Acceptable maximum weight is " - + HalfFloat.MAX_FLOAT + ": " + v); - } + private void setWeight(final int i, final float v) { + HalfFloat.checkRange(v); weights[i] = HalfFloat.floatToHalfFloat(v); } @@ -159,7 +154,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); - _setWeight(i, weight); + setWeight(i, weight); float covar = 1.f; boolean hasCovar = value.hasCovariance(); if (hasCovar) { @@ -185,7 +180,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { if (i >= size) { return; } - _setWeight(i, 0.f); + setWeight(i, 0.f); if (covars != null) { setCovar(i, 1.f); } @@ -205,7 +200,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { public void setWeight(@Nonnull final Object feature, final float value) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); - _setWeight(i, value); + setWeight(i, value); } @Override @@ -221,7 +216,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { protected void _set(@Nonnull final Object feature, final float weight, final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - _setWeight(i, weight); + setWeight(i, weight); clocks[i] = clock; deltaUpdates[i] = 0; } @@ -231,7 +226,7 @@ public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); - _setWeight(i, weight); + setWeight(i, weight); setCovar(i, covar); clocks[i] = clock; deltaUpdates[i] = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/model/NewSparseModel.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java index e312ae4..8326d22 100644 --- a/core/src/main/java/hivemall/model/NewSparseModel.java +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory; public final class NewSparseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(NewSparseModel.class); + @Nonnull private final OpenHashMap<Object, IWeightValue> weights; private final boolean hasCovar; private boolean clockEnabled; @@ -80,9 +81,6 @@ public final class NewSparseModel extends AbstractPredictionModel { @Override public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { - assert (feature != null); - assert (value != null); - final IWeightValue wrapperValue = wrapIfRequired(value); if (clockEnabled && value.isTouched()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index ec26552..cb8ab9f 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -33,6 +33,7 @@ import org.apache.commons.logging.LogFactory; public final class SparseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(SparseModel.class); + @Nonnull private final OpenHashMap<Object, IWeightValue> weights; private final boolean hasCovar; private boolean clockEnabled; @@ -76,9 +77,6 @@ public final class SparseModel extends AbstractPredictionModel { @Override public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { - assert (feature != null); - assert (value != null); - final IWeightValue 
wrapperValue = wrapIfRequired(value); if (clockEnabled && value.isTouched()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/model/WeightValueWithClock.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java index 524fa94..679c519 100644 --- a/core/src/main/java/hivemall/model/WeightValueWithClock.java +++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java @@ -276,6 +276,7 @@ public class WeightValueWithClock implements IWeightValue { public void setSumOfGradients(float value) { this.f2 = value; } + @Override public float getM() { return f1; @@ -324,11 +325,11 @@ public class WeightValueWithClock implements IWeightValue { @Override public float getFloatParams(@Nonnegative final int i) { - if(i == 1) { + if (i == 1) { return f1; - } else if(i == 2) { + } else if (i == 2) { return f2; - } else if(i == 3) { + } else if (i == 3) { return f3; } throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called"); @@ -363,6 +364,7 @@ public class WeightValueWithClock implements IWeightValue { public void setSumOfGradients(float value) { this.f3 = value; } + @Override public float getM() { return f1; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java index 2bf030b..e273f91 100644 --- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java @@ -20,7 +20,6 @@ package hivemall.optimizer; import hivemall.model.IWeightValue; import hivemall.model.WeightValue; -import 
hivemall.optimizer.Optimizer.OptimizerBase; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; @@ -34,37 +33,40 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; public final class DenseOptimizerFactory { - private static final Log logger = LogFactory.getLog(DenseOptimizerFactory.class); + private static final Log LOG = LogFactory.getLog(DenseOptimizerFactory.class); @Nonnull public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { final String optimizerName = options.get("optimizer"); - if (optimizerName != null) { - OptimizerBase optimizerImpl; - if (optimizerName.toLowerCase().equals("sgd")) { - optimizerImpl = new Optimizer.SGD(options); - } else if (optimizerName.toLowerCase().equals("adadelta")) { - optimizerImpl = new AdaDelta(ndims, options); - } else if (optimizerName.toLowerCase().equals("adagrad")) { - optimizerImpl = new AdaGrad(ndims, options); - } else if (optimizerName.toLowerCase().equals("adam")) { - optimizerImpl = new Adam(ndims, options); - } else { - throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); - } - - logger.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: " - + options); + if (optimizerName == null) { + throw new IllegalArgumentException("`optimizer` not defined"); + } + final Optimizer optimizerImpl; + if ("sgd".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new Optimizer.SGD(options); + } else if ("adadelta".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new AdaDelta(ndims, options); + } else if ("adagrad".equalsIgnoreCase(optimizerName)) { // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
- if (options.get("regularization") != null - && options.get("regularization").toLowerCase().equals("rda")) { - optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options); + if ("rda".equalsIgnoreCase(options.get("regularization"))) { + AdaGrad adagrad = new AdaGrad(ndims, options); + optimizerImpl = new AdagradRDA(ndims, adagrad, options); + } else { + optimizerImpl = new AdaGrad(ndims, options); } + } else if ("adam".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } - return optimizerImpl; + if (LOG.isInfoEnabled()) { + LOG.info("Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + + options); } - throw new IllegalArgumentException("`optimizer` not defined"); + + return optimizerImpl; } @NotThreadSafe @@ -86,7 +88,7 @@ public final class DenseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weightValueReused.set(weight); @@ -112,8 +114,9 @@ public final class DenseOptimizerFactory { @NotThreadSafe static final class AdaGrad extends Optimizer.AdaGrad { + @Nonnull private final IWeightValue weightValueReused; - + @Nonnull private float[] sum_of_squared_gradients; public AdaGrad(int ndims, Map<String, String> options) { @@ -123,7 +126,7 @@ public final class DenseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weightValueReused.set(weight); @@ -162,7 +165,7 @@ public final class DenseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { + 
public float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weightValueReused.set(weight); @@ -194,14 +197,15 @@ public final class DenseOptimizerFactory { @Nonnull private float[] sum_of_gradients; - public AdagradRDA(int ndims, final OptimizerBase optimizerImpl, Map<String, String> options) { + public AdagradRDA(int ndims, @Nonnull Optimizer.AdaGrad optimizerImpl, + @Nonnull Map<String, String> options) { super(optimizerImpl, options); this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f); this.sum_of_gradients = new float[ndims]; } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weightValueReused.set(weight); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java index 17b39d1..1a4c07d 100644 --- a/core/src/main/java/hivemall/optimizer/EtaEstimator.java +++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java @@ -161,14 +161,12 @@ public abstract class EtaEstimator { @Nonnull public static EtaEstimator get(@Nonnull final Map<String, String> options) throws IllegalArgumentException { + final float eta0 = Primitives.parseFloat(options.get("eta0"), 0.1f); + final double power_t = Primitives.parseDouble(options.get("power_t"), 0.1d); + final String etaScheme = options.get("eta"); if (etaScheme == null) { - return new InvscalingEtaEstimator(0.1f, 0.1d); - } - - float eta0 = 0.1f; - if (options.containsKey("eta0")) { - eta0 = Float.parseFloat(options.get("eta0")); + return new 
InvscalingEtaEstimator(eta0, power_t); } if ("fixed".equalsIgnoreCase(etaScheme)) { @@ -183,10 +181,6 @@ public abstract class EtaEstimator { } return new SimpleEtaEstimator(eta0, t); } else if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)) { - double power_t = 0.1; - if (options.containsKey("power_t")) { - power_t = Double.parseDouble(options.get("power_t")); - } return new InvscalingEtaEstimator(eta0, power_t); } else { throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/LossFunctions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java index 0dff4aa..a1ade3d 100644 --- a/core/src/main/java/hivemall/optimizer/LossFunctions.java +++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java @@ -20,6 +20,9 @@ package hivemall.optimizer; import hivemall.utils.math.MathUtils; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + /** * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions */ @@ -30,7 +33,8 @@ public final class LossFunctions { SquaredHingeLoss, ModifiedHuberLoss } - public static LossFunction getLossFunction(String type) { + @Nonnull + public static LossFunction getLossFunction(@Nullable final String type) { if ("SquaredLoss".equalsIgnoreCase(type)) { return new SquaredLoss(); } else if ("QuantileLoss".equalsIgnoreCase(type)) { @@ -41,7 +45,7 @@ public final class LossFunctions { return new HuberLoss(); } else if ("HingeLoss".equalsIgnoreCase(type)) { return new HingeLoss(); - } else if ("LogLoss".equalsIgnoreCase(type)) { + } else if ("LogLoss".equalsIgnoreCase(type) || "LogisticLoss".equalsIgnoreCase(type)) { return new LogLoss(); } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { return new 
SquaredHingeLoss(); @@ -51,7 +55,8 @@ public final class LossFunctions { throw new IllegalArgumentException("Unsupported loss function name: " + type); } - public static LossFunction getLossFunction(LossType type) { + @Nonnull + public static LossFunction getLossFunction(@Nonnull final LossType type) { switch (type) { case SquaredLoss: return new SquaredLoss(); @@ -100,6 +105,7 @@ public final class LossFunctions { public boolean forRegression(); + @Nonnull public LossType getType(); } @@ -119,13 +125,13 @@ public final class LossFunctions { public static abstract class BinaryLoss implements LossFunction { - protected static void checkTarget(float y) { + protected static void checkTarget(final float y) { if (!(y == 1.f || y == -1.f)) { throw new IllegalArgumentException("target must be [+1,-1]: " + y); } } - protected static void checkTarget(double y) { + protected static void checkTarget(final double y) { if (!(y == 1.d || y == -1.d)) { throw new IllegalArgumentException("target must be [+1,-1]: " + y); } @@ -150,19 +156,19 @@ public final class LossFunctions { public static final class SquaredLoss extends RegressionLoss { @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { final float z = p - y; return z * z * 0.5f; } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { final double z = p - y; return z * z * 0.5d; } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { return p - y; // 2 (p - y) / 2 } @@ -197,7 +203,7 @@ public final class LossFunctions { } @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { float e = y - p; if (e > 0.f) { return tau * e; @@ -207,7 +213,7 @@ public final class LossFunctions { } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { double e = y - p; if (e > 0.d) { return tau * 
e; @@ -217,7 +223,7 @@ public final class LossFunctions { } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { float e = y - p; if (e == 0.f) { return 0.f; @@ -251,19 +257,19 @@ public final class LossFunctions { } @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { float loss = Math.abs(y - p) - epsilon; return (loss > 0.f) ? loss : 0.f; } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { double loss = Math.abs(y - p) - epsilon; return (loss > 0.d) ? loss : 0.d; } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { if ((y - p) > epsilon) {// real value > predicted value - epsilon return -1.f; } @@ -303,7 +309,7 @@ public final class LossFunctions { } @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { final float r = p - y; final float rAbs = Math.abs(r); if (rAbs <= c) { @@ -313,7 +319,7 @@ public final class LossFunctions { } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { final double r = p - y; final double rAbs = Math.abs(r); if (rAbs <= c) { @@ -323,7 +329,7 @@ public final class LossFunctions { } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { final float r = p - y; final float rAbs = Math.abs(r); if (rAbs <= c) { @@ -364,19 +370,19 @@ public final class LossFunctions { } @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { float loss = hingeLoss(p, y, threshold); return (loss > 0.f) ? loss : 0.f; } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { double loss = hingeLoss(p, y, threshold); return (loss > 0.d) ? 
loss : 0.d; } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { float loss = hingeLoss(p, y, threshold); return (loss > 0.f) ? -y : 0.f; } @@ -396,7 +402,7 @@ public final class LossFunctions { * <code>logloss(p,y) = log(1+exp(-p*y))</code> */ @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { checkTarget(y); final float z = y * p; @@ -410,7 +416,7 @@ public final class LossFunctions { } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { checkTarget(y); final double z = y * p; @@ -424,7 +430,7 @@ public final class LossFunctions { } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { checkTarget(y); float z = y * p; @@ -449,17 +455,17 @@ public final class LossFunctions { public static final class SquaredHingeLoss extends BinaryLoss { @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { return squaredHingeLoss(p, y); } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { return squaredHingeLoss(p, y); } @Override - public float dloss(float p, float y) { + public float dloss(final float p, final float y) { checkTarget(y); float d = 1 - (y * p); @@ -480,7 +486,7 @@ public final class LossFunctions { public static final class ModifiedHuberLoss extends BinaryLoss { @Override - public float loss(float p, float y) { + public float loss(final float p, final float y) { final float z = p * y; if (z >= 1.f) { return 0.f; @@ -491,7 +497,7 @@ public final class LossFunctions { } @Override - public double loss(double p, double y) { + public double loss(final double p, final double y) { final double z = p * y; if (z >= 1.d) { return 0.d; @@ -502,7 +508,7 @@ public final class LossFunctions { } @Override - public float dloss(float p, float y) { + public float 
dloss(final float p, final float y) { final float z = p * y; if (z >= 1.f) { return 0.f; @@ -552,12 +558,12 @@ public final class LossFunctions { return Math.log(1.d + Math.exp(-z)); } - public static float squaredLoss(float p, float y) { + public static float squaredLoss(final float p, final float y) { final float z = p - y; return z * z * 0.5f; } - public static double squaredLoss(double p, double y) { + public static double squaredLoss(final double p, final double y) { final double z = p - y; return z * z * 0.5d; } @@ -576,11 +582,11 @@ public final class LossFunctions { return threshold - z; } - public static float hingeLoss(float p, float y) { + public static float hingeLoss(final float p, final float y) { return hingeLoss(p, y, 1.f); } - public static double hingeLoss(double p, double y) { + public static double hingeLoss(final double p, final double y) { return hingeLoss(p, y, 1.d); } @@ -603,7 +609,8 @@ public final class LossFunctions { /** * Math.abs(target - predicted) - epsilon */ - public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { + public static float epsilonInsensitiveLoss(final float predicted, final float target, + final float epsilon) { return Math.abs(target - predicted) - epsilon; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index ad70e61..4b11bd1 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -20,6 +20,7 @@ package hivemall.optimizer; import hivemall.model.IWeightValue; import hivemall.model.WeightValue; +import hivemall.utils.lang.Primitives; import java.util.Map; @@ -39,6 +40,9 @@ public interface Optimizer { */ void proceedStep(); + @Nonnull + String 
getOptimizerName(); + @NotThreadSafe static abstract class OptimizerBase implements Optimizer { @@ -49,7 +53,7 @@ public interface Optimizer { @Nonnegative protected int _numStep = 1; - public OptimizerBase(final Map<String, String> options) { + public OptimizerBase(@Nonnull Map<String, String> options) { this._eta = EtaEstimator.get(options); this._reg = Regularization.get(options); } @@ -61,11 +65,14 @@ public interface Optimizer { /** * Update the given weight by the given gradient. + * + * @return new weight to be set */ - protected float update(@Nonnull final IWeightValue weight, float gradient) { - float g = _reg.regularize(weight.get(), gradient); + protected float update(@Nonnull final IWeightValue weight, final float gradient) { + float oldWeight = weight.get(); + float g = _reg.regularize(oldWeight, gradient); float delta = computeDelta(weight, g); - float newWeight = weight.get() - _eta.eta(_numStep) * delta; + float newWeight = oldWeight - _eta.eta(_numStep) * delta; weight.set(newWeight); return newWeight; } @@ -73,7 +80,7 @@ public interface Optimizer { /** * Compute a delta to update */ - protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { return gradient; } @@ -89,12 +96,17 @@ public interface Optimizer { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { weightValueReused.set(weight); update(weightValueReused, gradient); return weightValueReused.get(); } + @Override + public String getOptimizerName() { + return "sgd"; + } + } static abstract class AdaGrad extends OptimizerBase { @@ -104,26 +116,23 @@ public interface Optimizer { public AdaGrad(Map<String, String> options) { super(options); - float eps = 1.0f; - float scale = 100.0f; - if (options.containsKey("eps")) { - eps = 
Float.parseFloat(options.get("eps")); - } - if (options.containsKey("scale")) { - scale = Float.parseFloat(options.get("scale")); - } - this.eps = eps; - this.scale = scale; + this.eps = Primitives.parseFloat(options.get("eps"), 1.0f); + this.scale = Primitives.parseFloat(options.get("scale"), 100.0f); } @Override - protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient * (gradient / scale); weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad); return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps); } + @Override + public String getOptimizerName() { + return "adagrad"; + } + } static abstract class AdaDelta extends OptimizerBase { @@ -134,25 +143,13 @@ public interface Optimizer { public AdaDelta(Map<String, String> options) { super(options); - float decay = 0.95f; - float eps = 1e-6f; - float scale = 100.0f; - if (options.containsKey("decay")) { - decay = Float.parseFloat(options.get("decay")); - } - if (options.containsKey("eps")) { - eps = Float.parseFloat(options.get("eps")); - } - if (options.containsKey("scale")) { - scale = Float.parseFloat(options.get("scale")); - } - this.decay = decay; - this.eps = eps; - this.scale = scale; + this.decay = Primitives.parseFloat(options.get("decay"), 0.95f); + this.eps = Primitives.parseFloat(options.get("eps"), 1e-6f); + this.scale = Primitives.parseFloat(options.get("scale"), 100.0f); } @Override - protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients(); float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX(); float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) @@ -167,6 +164,11 @@ public interface Optimizer { 
return delta; } + @Override + public String getOptimizerName() { + return "adadelta"; + } + } /** @@ -183,25 +185,13 @@ public interface Optimizer { public Adam(Map<String, String> options) { super(options); - float beta = 0.9f; - float gamma = 0.999f; - float eps_hat = 1e-8f; - if (options.containsKey("beta")) { - beta = Float.parseFloat(options.get("beta")); - } - if (options.containsKey("gamma")) { - gamma = Float.parseFloat(options.get("gamma")); - } - if (options.containsKey("eps_hat")) { - eps_hat = Float.parseFloat(options.get("eps_hat")); - } - this.beta = beta; - this.gamma = gamma; - this.eps_hat = eps_hat; + this.beta = Primitives.parseFloat(options.get("beta"), 0.9f); + this.gamma = Primitives.parseFloat(options.get("gamma"), 0.999f); + this.eps_hat = Primitives.parseFloat(options.get("eps_hat"), 1e-8f); } @Override - protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + protected float computeDelta(@Nonnull final IWeightValue weight, final float gradient) { float val_m = beta * weight.getM() + (1.f - beta) * gradient; float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0)); float val_m_hat = val_m / (float) (1.f - Math.pow(beta, _numStep)); @@ -212,36 +202,32 @@ public interface Optimizer { return delta; } + @Override + public String getOptimizerName() { + return "adam"; + } + } static abstract class AdagradRDA extends OptimizerBase { - private final OptimizerBase optimizerImpl; - + @Nonnull + private final AdaGrad optimizerImpl; private final float lambda; - public AdagradRDA(final OptimizerBase optimizerImpl, Map<String, String> options) { + public AdagradRDA(@Nonnull AdaGrad optimizerImpl, @Nonnull Map<String, String> options) { super(options); - // We assume `optimizerImpl` has the `AdaGrad` implementation only - if (!(optimizerImpl instanceof AdaGrad)) { - throw new IllegalArgumentException(optimizerImpl.getClass().getSimpleName() - + " currently does not support RDA regularization"); 
- } - float lambda = 1e-6f; - if (options.containsKey("lambda")) { - lambda = Float.parseFloat(options.get("lambda")); - } this.optimizerImpl = optimizerImpl; - this.lambda = lambda; + this.lambda = Primitives.parseFloat(options.get("lambda"), 1e-6f); } @Override - protected float update(@Nonnull final IWeightValue weight, float gradient) { - float new_sum_grad = weight.getSumOfGradients() + gradient; + protected float update(@Nonnull final IWeightValue weight, final float gradient) { + final float new_sum_grad = weight.getSumOfGradients() + gradient; // sign(u_{t,i}) - float sign = (new_sum_grad > 0.f) ? 1.f : -1.f; + final float sign = (new_sum_grad > 0.f) ? 1.f : -1.f; // |u_{t,i}|/t - \lambda - float meansOfGradients = (sign * new_sum_grad / _numStep) - lambda; + final float meansOfGradients = (sign * new_sum_grad / _numStep) - lambda; if (meansOfGradients < 0.f) { // x_{t,i} = 0 weight.set(0.f); @@ -258,6 +244,11 @@ public interface Optimizer { } } + @Override + public String getOptimizerName() { + return "adagrad_rda"; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/OptimizerOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java index 19fecb1..be65609 100644 --- a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java +++ b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java @@ -25,8 +25,8 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Options; import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; public final class OptimizerOptions { @@ -63,14 +63,15 @@ public final class OptimizerOptions { public static void propcessOptions(@Nullable CommandLine cl, @Nonnull Map<String, String> options) { - if 
(cl != null) { - for (Option opt : cl.getOptions()) { - String optName = opt.getLongOpt(); - if (optName == null) { - optName = opt.getOpt(); - } - options.put(optName, opt.getValue()); + if (cl == null) { + return; + } + for (Option opt : cl.getOptions()) { + String optName = opt.getLongOpt(); + if (optName == null) { + optName = opt.getOpt(); } + options.put(optName, opt.getValue()); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/Regularization.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java index 4939f60..9650826 100644 --- a/core/src/main/java/hivemall/optimizer/Regularization.java +++ b/core/src/main/java/hivemall/optimizer/Regularization.java @@ -18,24 +18,23 @@ */ package hivemall.optimizer; -import javax.annotation.Nonnull; +import hivemall.utils.lang.Primitives; + import java.util.Map; +import javax.annotation.Nonnull; + public abstract class Regularization { /** the default regularization term 0.0001 */ - public static final float DEFAULT_LAMBDA = 0.0001f; + private static final float DEFAULT_LAMBDA = 0.0001f; protected final float lambda; public Regularization(@Nonnull Map<String, String> options) { - float lambda = DEFAULT_LAMBDA; - if (options.containsKey("lambda")) { - lambda = Float.parseFloat(options.get("lambda")); - } - this.lambda = lambda; + this.lambda = Primitives.parseFloat(options.get("lambda"), DEFAULT_LAMBDA); } - public float regularize(float weight, float gradient) { + public float regularize(final float weight, final float gradient) { return gradient + lambda * getRegularizer(weight); } @@ -61,8 +60,8 @@ public abstract class Regularization { } @Override - public float getRegularizer(float weight) { - return (weight > 0.f ? 1.f : -1.f); + public float getRegularizer(final float weight) { + return weight > 0.f ? 
1.f : -1.f; } } @@ -81,32 +80,30 @@ public abstract class Regularization { } public static final class ElasticNet extends Regularization { - public static final float DEFAULT_L1_RATIO = 0.5f; + private static final float DEFAULT_L1_RATIO = 0.5f; - protected final L1 l1; - protected final L2 l2; + @Nonnull + private final L1 l1; + @Nonnull + private final L2 l2; - protected final float l1Ratio; + private final float l1Ratio; - public ElasticNet(Map<String, String> options) { + public ElasticNet(@Nonnull Map<String, String> options) { super(options); this.l1 = new L1(options); this.l2 = new L2(options); - float l1Ratio = DEFAULT_L1_RATIO; - if (options.containsKey("l1_ratio")) { - l1Ratio = Float.parseFloat(options.get("l1_ratio")); - if (l1Ratio < 0.f || l1Ratio > 1.f) { - throw new IllegalArgumentException("L1 ratio should be in [0.0, 1.0], but got " - + l1Ratio); - } + this.l1Ratio = Primitives.parseFloat(options.get("l1_ratio"), DEFAULT_L1_RATIO); + if (l1Ratio < 0.f || l1Ratio > 1.f) { + throw new IllegalArgumentException("L1 ratio should be in [0.0, 1.0], but got " + + l1Ratio); } - this.l1Ratio = l1Ratio; } @Override - public float getRegularizer(float weight) { + public float getRegularizer(final float weight) { return l1Ratio * l1.getRegularizer(weight) + (1.f - l1Ratio) * l2.getRegularizer(weight); } @@ -120,15 +117,15 @@ public abstract class Regularization { return new PassThrough(options); } - if (regName.toLowerCase().equals("no")) { + if ("no".equalsIgnoreCase(regName)) { return new PassThrough(options); - } else if (regName.toLowerCase().equals("l1")) { + } else if ("l1".equalsIgnoreCase(regName)) { return new L1(options); - } else if (regName.toLowerCase().equals("l2")) { + } else if ("l2".equalsIgnoreCase(regName)) { return new L2(options); - } else if (regName.toLowerCase().equals("elasticnet")) { + } else if ("elasticnet".equalsIgnoreCase(regName)) { return new ElasticNet(options); - } else if (regName.toLowerCase().equals("rda")) { + } else if 
("rda".equalsIgnoreCase(regName)) { // Return `PassThrough` because we need special handling for RDA. // See an implementation of `Optimizer#RDA`. return new PassThrough(options); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java index 4a003f3..7bcac1b 100644 --- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -20,7 +20,6 @@ package hivemall.optimizer; import hivemall.model.IWeightValue; import hivemall.model.WeightValue; -import hivemall.optimizer.Optimizer.OptimizerBase; import hivemall.utils.collections.maps.OpenHashMap; import java.util.Map; @@ -37,34 +36,35 @@ public final class SparseOptimizerFactory { @Nonnull public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { final String optimizerName = options.get("optimizer"); - if (optimizerName != null) { - OptimizerBase optimizerImpl; - if (optimizerName.toLowerCase().equals("sgd")) { - optimizerImpl = new Optimizer.SGD(options); - } else if (optimizerName.toLowerCase().equals("adadelta")) { - optimizerImpl = new AdaDelta(ndims, options); - } else if (optimizerName.toLowerCase().equals("adagrad")) { - optimizerImpl = new AdaGrad(ndims, options); - } else if (optimizerName.toLowerCase().equals("adam")) { - optimizerImpl = new Adam(ndims, options); - } else { - throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); - } + if (optimizerName == null) { + throw new IllegalArgumentException("`optimizer` not defined"); + } + final Optimizer optimizerImpl; + if ("sgd".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new Optimizer.SGD(options); + } else if 
("adadelta".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new AdaDelta(ndims, options); + } else if ("adagrad".equalsIgnoreCase(optimizerName)) { // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. - if (options.get("regularization") != null - && options.get("regularization").toLowerCase().equals("rda")) { - optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options); - } - - if (LOG.isInfoEnabled()) { - LOG.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: " - + options); + if ("rda".equalsIgnoreCase(options.get("regularization"))) { + AdaGrad adagrad = new AdaGrad(ndims, options); + optimizerImpl = new AdagradRDA(ndims, adagrad, options); + } else { + optimizerImpl = new AdaGrad(ndims, options); } + } else if ("adam".equalsIgnoreCase(optimizerName)) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } - return optimizerImpl; + if (LOG.isInfoEnabled()) { + LOG.info("Configured " + optimizerImpl.getOptimizerName() + " as the optimizer: " + + options); } - throw new IllegalArgumentException("`optimizer` not defined"); + + return optimizerImpl; } @NotThreadSafe @@ -79,17 +79,15 @@ public final class SparseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { - IWeightValue auxWeight; - if (auxWeights.containsKey(feature)) { - auxWeight = auxWeights.get(feature); - auxWeight.set(weight); - } else { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); } - update(auxWeight, gradient); - return auxWeight.get(); + return update(auxWeight, gradient); } } @@ -106,17 +104,15 @@ public final class 
SparseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { - IWeightValue auxWeight; - if (auxWeights.containsKey(feature)) { - auxWeight = auxWeights.get(feature); - auxWeight.set(weight); - } else { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); } - update(auxWeight, gradient); - return auxWeight.get(); + return update(auxWeight, gradient); } } @@ -133,17 +129,15 @@ public final class SparseOptimizerFactory { } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { - IWeightValue auxWeight; - if (auxWeights.containsKey(feature)) { - auxWeight = auxWeights.get(feature); - auxWeight.set(weight); - } else { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); } - update(auxWeight, gradient); - return auxWeight.get(); + return update(auxWeight, gradient); } } @@ -154,23 +148,22 @@ public final class SparseOptimizerFactory { @Nonnull private final OpenHashMap<Object, IWeightValue> auxWeights; - public AdagradRDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) { + public AdagradRDA(int size, @Nonnull Optimizer.AdaGrad optimizerImpl, + @Nonnull Map<String, String> options) { super(optimizerImpl, options); this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); } @Override - public float update(@Nonnull Object feature, float weight, float gradient) { - IWeightValue auxWeight; - if (auxWeights.containsKey(feature)) { - auxWeight = 
auxWeights.get(feature); - auxWeight.set(weight); - } else { + public float update(@Nonnull final Object feature, final float weight, final float gradient) { + IWeightValue auxWeight = auxWeights.get(feature); + if (auxWeight == null) { auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f); auxWeights.put(feature, auxWeight); + } else { + auxWeight.set(weight); } - update(auxWeight, gradient); - return auxWeight.get(); + return update(auxWeight, gradient); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java index 5137dd3..160d92d 100644 --- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -50,8 +50,9 @@ public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF { } @Override - protected void checkLossFunction(@Nonnull LossFunction lossFunction) throws UDFArgumentException { - if (lossFunction.forBinaryClassification()) { + protected void checkLossFunction(@Nonnull LossFunction lossFunction) + throws UDFArgumentException { + if (!lossFunction.forRegression()) { throw new UDFArgumentException("The loss function `" + lossFunction.getType() + "` is not designed for regression"); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java index d7ae8d6..5ce34a4 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java 
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java @@ -354,9 +354,9 @@ public class IntOpenHashMap<V> implements Externalizable { return key & 0x7fffffff; } - protected void recordAccess(int idx) {}; + protected void recordAccess(int idx) {} - protected void recordRemoval(int idx) {}; + protected void recordRemoval(int idx) {} public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(_threshold); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java index e558e67..6ed783c 100644 --- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java @@ -38,7 +38,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - import org.junit.Assert; import org.junit.Test; @@ -107,8 +106,8 @@ public class GeneralClassifierUDTFTest { udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); - float cumLossPrev = Float.MAX_VALUE; - float cumLoss = 0.f; + double cumLossPrev = Double.MAX_VALUE; + double cumLoss = 0.d; int it = 0; while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { cumLossPrev = cumLoss; @@ -119,7 +118,7 @@ public class GeneralClassifierUDTFTest { cumLoss = udtf.getCumulativeLoss(); println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); } - Assert.assertTrue(cumLoss / samplesList.size() < 0.5f); + Assert.assertTrue(cumLoss / 
samplesList.size() < 0.5d); int numTests = 0; int numCorrect = 0; @@ -176,6 +175,7 @@ public class GeneralClassifierUDTFTest { } } + @SuppressWarnings("unchecked") @Test public void testNews20() throws IOException, ParseException, HiveException { int nIter = 10; @@ -205,7 +205,7 @@ public class GeneralClassifierUDTFTest { udtf.process(new Object[] {words, label}); labels.add(label); - wordsList.add((ArrayList) words.clone()); + wordsList.add((ArrayList<String>) words.clone()); words.clear(); line = news20.readLine(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/50b4c9a7/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java index 15dcc22..df5c643 100644 --- a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java +++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java @@ -121,8 +121,8 @@ public class GeneralRegressionUDTFTest { udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); - float cumLossPrev = Float.MAX_VALUE; - float cumLoss = 0.f; + double cumLossPrev = Double.MAX_VALUE; + double cumLoss = 0.d; int it = 0; while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { cumLossPrev = cumLoss; @@ -133,7 +133,7 @@ public class GeneralRegressionUDTFTest { cumLoss = udtf.getCumulativeLoss(); println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); } - Assert.assertTrue(cumLoss / numTrain < 0.1f); + Assert.assertTrue(cumLoss / numTrain < 0.1d); float accum = 0.f;
